Skip to content

Commit 439930c

Browse files
bdhirsh authored and facebook-github-bot committed Sep 25, 2020
adding a beta parameter to the smooth_l1 loss fn (pytorch#44433)
Summary: Pull Request resolved: pytorch#44433. Not entirely sure why, but changing the type of beta from `float` to `double` in autocast_mode.cpp and FunctionsManual.h fixes my compiler errors, failing instead at link time. Fixing some type errors; updated the fn signature in a few more files. Removing my usage of Scalar, making beta a double everywhere instead. Test Plan: Imported from OSS. Reviewed By: mrshenli. Differential Revision: D23636720. Pulled By: bdhirsh. fbshipit-source-id: caea2a1f8dd72b3b5fd1d72dd886b2fcd690af6d
1 parent 37513a1 commit 439930c

18 files changed

+170
-86
lines changed
 

‎aten/src/ATen/autocast_mode.cpp

+1-1
Original file line numberDiff line numberDiff line change
@@ -357,7 +357,7 @@ TORCH_LIBRARY_IMPL(aten, Autocast, m) {
357357
KERNEL(ADD_NS(hinge_embedding_loss), "hinge_embedding_loss", Tensor (const Tensor &, const Tensor &, double, int64_t), fp32)
358358
KERNEL(ADD_NS(kl_div), "kl_div", Tensor (const Tensor &, const Tensor &, int64_t, bool), fp32)
359359
KERNEL(ADD_NS(l1_loss), "l1_loss", Tensor (const Tensor &, const Tensor &, int64_t), fp32)
360-
KERNEL(ADD_NS(smooth_l1_loss), "smooth_l1_loss", Tensor (const Tensor &, const Tensor &, int64_t), fp32)
360+
KERNEL(ADD_NS(smooth_l1_loss), "smooth_l1_loss", Tensor (const Tensor &, const Tensor &, int64_t, double), fp32)
361361
KERNEL(ADD_NS(mse_loss), "mse_loss", Tensor (const Tensor &, const Tensor &, int64_t), fp32)
362362
KERNEL(ADD_NS(margin_ranking_loss), "margin_ranking_loss", Tensor (const Tensor &, const Tensor &, const Tensor &, double, int64_t), fp32)
363363
KERNEL(ADD_NS(multilabel_margin_loss), "multilabel_margin_loss", Tensor (const Tensor &, const Tensor &, int64_t), fp32)

‎aten/src/ATen/native/BinaryOps.h

+2-1
Original file line numberDiff line numberDiff line change
@@ -25,6 +25,7 @@ inline void sub_check(const Tensor& self, const Tensor& other) {
2525
}
2626

2727
using binary_fn_alpha = void(*)(TensorIterator&, Scalar alpha);
28+
using binary_fn_beta = void(*)(TensorIterator&, double beta);
2829
using binary_fn = void(*)(TensorIterator&);
2930
using binary_clamp_fn_alpha =
3031
void(*)(TensorIterator&, Scalar alpha, Scalar min_val, Scalar max_val);
@@ -54,7 +55,7 @@ DECLARE_DISPATCH(binary_fn, max_elementwise_stub);
5455
DECLARE_DISPATCH(binary_fn, min_elementwise_stub);
5556
DECLARE_DISPATCH(binary_fn, maximum_stub);
5657
DECLARE_DISPATCH(binary_fn, minimum_stub);
57-
DECLARE_DISPATCH(binary_fn, smooth_l1_stub);
58+
DECLARE_DISPATCH(binary_fn_beta, smooth_l1_stub);
5859
DECLARE_DISPATCH(binary_fn, sigmoid_backward_stub);
5960
DECLARE_DISPATCH(binary_fn_alpha, logit_backward_stub);
6061
DECLARE_DISPATCH(binary_fn, tanh_backward_stub);

‎aten/src/ATen/native/Loss.cpp

+24-9
Original file line numberDiff line numberDiff line change
@@ -295,38 +295,53 @@ Tensor soft_margin_loss(
295295
return output;
296296
}
297297

298-
Tensor smooth_l1_loss(const Tensor& input, const Tensor& target, const int64_t reduction) {
298+
Tensor smooth_l1_loss(const Tensor& input, const Tensor& target, const int64_t reduction, double beta) {
299+
if (beta <= 0)
300+
return at::native::l1_loss(input, target, reduction);
299301
Tensor loss;
300302
auto iter = TensorIterator::binary_op(loss, input, target);
301-
smooth_l1_stub(iter.device_type(), iter);
303+
smooth_l1_stub(iter.device_type(), iter, beta);
302304
return apply_loss_reduction(iter.output(), reduction);
303305
}
304306

305-
Tensor& smooth_l1_loss_out(Tensor& result, const Tensor& input, const Tensor& target, int64_t reduction) {
307+
Tensor& smooth_l1_loss_out(Tensor& result, const Tensor& input, const Tensor& target, int64_t reduction, double beta) {
308+
if (beta <= 0)
309+
return at::native::l1_loss_out(result, input, target, reduction);
306310
if (reduction != Reduction::None) {
307-
result = at::smooth_l1_loss(input, target, reduction);
311+
Tensor loss;
312+
auto iter = TensorIterator::binary_op(loss, input, target);
313+
smooth_l1_stub(iter.device_type(), iter, beta);
314+
if (reduction == Reduction::Mean) {
315+
at::mean_out(result, iter.output(), 0);
316+
} else {
317+
at::sum_out(result, iter.output(), 0);
318+
}
308319
} else {
309320
auto iter = TensorIterator::binary_op(result, input, target);
310-
smooth_l1_stub(iter.device_type(), iter);
321+
smooth_l1_stub(iter.device_type(), iter, beta);
311322
}
312323
return result;
313324
}
314325

315-
Tensor& smooth_l1_loss_backward_out(Tensor& grad_input, const Tensor& grad_output, const Tensor& input, const Tensor& target, int64_t reduction) {
326+
Tensor& smooth_l1_loss_backward_out(Tensor& grad_input, const Tensor& grad_output, const Tensor& input, const Tensor& target, int64_t reduction, double beta) {
327+
if (beta <= 0)
328+
return at::native::l1_loss_backward_out(grad_input, grad_output, input, target, reduction);
316329
auto norm = reduction == Reduction::Mean ? 1. / input.numel() : 1.;
317330
auto iter = at::TensorIteratorConfig()
318331
.add_output(grad_input)
319332
.add_input(input)
320333
.add_input(target)
321334
.add_input(grad_output)
322335
.build();
323-
smooth_l1_backward_stub(iter.device_type(), iter, norm);
336+
smooth_l1_backward_stub(iter.device_type(), iter, norm, beta);
324337
return grad_input;
325338
}
326339

327-
Tensor smooth_l1_loss_backward(const Tensor& grad_output, const Tensor& input, const Tensor& target, int64_t reduction) {
340+
Tensor smooth_l1_loss_backward(const Tensor& grad_output, const Tensor& input, const Tensor& target, int64_t reduction, double beta) {
341+
if (beta <= 0)
342+
return at::native::l1_loss_backward(grad_output, input, target, reduction);
328343
auto grad_input = at::zeros_like(input, LEGACY_CONTIGUOUS_MEMORY_FORMAT);
329-
return at::smooth_l1_loss_backward_out(grad_input, grad_output, input, target, reduction);
344+
return at::smooth_l1_loss_backward_out(grad_input, grad_output, input, target, reduction, beta);
330345
}
331346

332347
Tensor mse_loss(const Tensor& input, const Tensor& target, int64_t reduction) {

‎aten/src/ATen/native/PointwiseOps.h

+2-1
Original file line numberDiff line numberDiff line change
@@ -11,10 +11,11 @@ struct TensorIterator;
1111
namespace native {
1212

1313
using pointwise_fn = void (*)(TensorIterator&, Scalar scalar);
14+
using pointwise_fn_beta = void (*)(TensorIterator&, Scalar scalar, double beta);
1415

1516
DECLARE_DISPATCH(pointwise_fn, addcmul_stub);
1617
DECLARE_DISPATCH(pointwise_fn, addcdiv_stub);
17-
DECLARE_DISPATCH(pointwise_fn, smooth_l1_backward_stub);
18+
DECLARE_DISPATCH(pointwise_fn_beta, smooth_l1_backward_stub);
1819
DECLARE_DISPATCH(pointwise_fn, mse_backward_stub);
1920

2021
} // namespace native

‎aten/src/ATen/native/cpu/BinaryOpsKernel.cpp

+9-8
Original file line numberDiff line numberDiff line change
@@ -502,24 +502,25 @@ void minimum_kernel(TensorIterator& iter) {
502502
}
503503
}
504504

505-
void smooth_l1_kernel(TensorIterator& iter) {
505+
void smooth_l1_kernel(TensorIterator& iter, double beta) {
506506
AT_DISPATCH_FLOATING_TYPES_AND2(
507507
kBFloat16, kHalf, iter.dtype(), "smooth_l1_cpu", [&]() {
508508
using Vec = Vec256<scalar_t>;
509-
const Vec one_vec(static_cast<scalar_t>(1));
509+
const scalar_t beta_val(beta);
510+
const Vec beta_val_vec(beta_val);
510511
const Vec point_five_vec(static_cast<scalar_t>(0.5));
511512
cpu_kernel_vec(
512513
iter,
513-
[](scalar_t a, scalar_t b) -> scalar_t {
514+
[&beta_val](scalar_t a, scalar_t b) -> scalar_t {
514515
auto z = std::abs(a - b);
515-
return z < static_cast<scalar_t>(1)
516-
? static_cast<scalar_t>(0.5) * z * z
517-
: z - static_cast<scalar_t>(0.5);
516+
return z < beta_val
517+
? static_cast<scalar_t>(0.5) * z * z / beta_val
518+
: z - static_cast<scalar_t>(0.5) * beta_val;
518519
},
519-
[&one_vec, &point_five_vec](Vec a, Vec b) {
520+
[&beta_val_vec, &point_five_vec](Vec a, Vec b) {
520521
auto z = (a - b).abs();
521522
return Vec::blendv(
522-
point_five_vec * z * z, z - point_five_vec, z >= one_vec);
523+
point_five_vec * z * z / beta_val_vec, z - point_five_vec * beta_val_vec, z >= beta_val_vec);
523524
});
524525
});
525526
}

‎aten/src/ATen/native/cpu/PointwiseOpsKernel.cpp

+19-8
Original file line numberDiff line numberDiff line change
@@ -46,28 +46,39 @@ static void addcdiv_cpu_kernel(TensorIterator& iter, Scalar value) {
4646
});
4747
}
4848

49-
static void smooth_l1_backward_cpu_kernel(TensorIterator& iter, Scalar norm) {
49+
static void smooth_l1_backward_cpu_kernel(TensorIterator& iter, Scalar norm, double beta) {
5050
ScalarType dtype = iter.dtype(0);
5151
AT_DISPATCH_ALL_TYPES(dtype, "smooth_l1_backward_cpu_out", [&] {
5252
auto norm_val = norm.to<scalar_t>();
53+
scalar_t beta_val(beta);
5354
auto norm_val_vec = Vec256<scalar_t>(norm_val);
55+
auto beta_val_vec = Vec256<scalar_t>(beta_val);
5456
const auto neg_1_vec = Vec256<scalar_t>(-1);
57+
const auto zero_vec = Vec256<scalar_t>(0);
5558
const auto pos_1_vec = Vec256<scalar_t>(1);
5659
cpu_kernel_vec(iter,
5760
[=](scalar_t input, scalar_t target, scalar_t grad_output) -> scalar_t {
5861
const auto x = input - target;
59-
if (x < -1.)
62+
if (x <= -beta)
6063
return -norm_val * grad_output;
61-
else if (x > 1.)
64+
else if (x >= beta)
6265
return norm_val * grad_output;
6366
else
64-
return norm_val * x * grad_output;
67+
return norm_val * x * grad_output / beta;
6568
},
66-
[norm_val_vec, neg_1_vec, pos_1_vec](
69+
[norm_val_vec, beta_val_vec, neg_1_vec, zero_vec, pos_1_vec](
6770
Vec256<scalar_t> input, Vec256<scalar_t> target, Vec256<scalar_t> grad_output) -> Vec256<scalar_t> {
68-
auto x = input - target;
69-
x = clamp(x, neg_1_vec, pos_1_vec);
70-
return norm_val_vec * x * grad_output;
71+
// using two blendv calls to simulate the 3 cases
72+
// 1 if x >= beta
73+
// -1 if x <= -beta
74+
// x / beta if |x| < beta
75+
const auto x = input - target;
76+
const auto pos_or_neg_1_vec = Vec256<scalar_t>::blendv(
77+
neg_1_vec, pos_1_vec, x > zero_vec);
78+
const auto x_abs = x.abs();
79+
const auto output = Vec256<scalar_t>::blendv(
80+
x / beta_val_vec, pos_or_neg_1_vec, x_abs >= beta_val_vec);
81+
return norm_val_vec * output * grad_output;
7182
}
7283
);
7384
});

‎aten/src/ATen/native/cuda/BinaryMiscOpsKernels.cu

+5-4
Original file line numberDiff line numberDiff line change
@@ -19,11 +19,12 @@ void atan2_kernel_cuda(TensorIterator& iter) {
1919
});
2020
}
2121

22-
void smooth_l1_kernel_cuda(TensorIterator& iter) {
23-
AT_DISPATCH_FLOATING_TYPES_AND_HALF(iter.dtype(), "smooth_l1_cuda", [&]() {
24-
gpu_kernel(iter, [] GPU_LAMBDA (scalar_t a, scalar_t b) -> scalar_t {
22+
void smooth_l1_kernel_cuda(TensorIterator& iter, double beta) {
23+
AT_DISPATCH_FLOATING_TYPES_AND_HALF(iter.dtype(), "smooth_l1_cuda", [&iter, beta]() {
24+
scalar_t beta_val(beta);
25+
gpu_kernel(iter, [beta_val] GPU_LAMBDA (scalar_t a, scalar_t b) -> scalar_t {
2526
auto z = ::abs(a - b);
26-
return z < scalar_t(1.) ? scalar_t(0.5) * z * z : z - scalar_t(0.5);
27+
return z < beta_val ? scalar_t(0.5) * z * z / beta_val : z - scalar_t(0.5) * beta_val;
2728
});
2829
});
2930
}

‎aten/src/ATen/native/cuda/PointwiseOpsKernel.cu

+7-6
Original file line numberDiff line numberDiff line change
@@ -26,17 +26,18 @@ void addcdiv_cuda_kernel(TensorIterator& iter, Scalar value) {
2626
});
2727
}
2828

29-
void smooth_l1_backward_cuda_kernel(TensorIterator& iter, Scalar norm) {
30-
AT_DISPATCH_ALL_TYPES_AND(kHalf, iter.dtype(), "smooth_l1_backward_cuda", [&]() {
29+
void smooth_l1_backward_cuda_kernel(TensorIterator& iter, Scalar norm, double beta) {
30+
AT_DISPATCH_ALL_TYPES_AND(kHalf, iter.dtype(), "smooth_l1_backward_cuda", [&iter, &norm, beta] {
3131
auto norm_val = norm.to<scalar_t>();
32-
gpu_kernel(iter, [norm_val]GPU_LAMBDA(scalar_t input, scalar_t target, scalar_t grad_output) -> scalar_t {
32+
scalar_t beta_val(beta);
33+
gpu_kernel(iter, [norm_val, beta_val]GPU_LAMBDA(scalar_t input, scalar_t target, scalar_t grad_output) -> scalar_t {
3334
const auto x = input - target;
34-
if (x < scalar_t(-1))
35+
if (x < -beta_val)
3536
return -norm_val * grad_output;
36-
else if (x > scalar_t(1))
37+
else if (x > beta_val)
3738
return norm_val * grad_output;
3839
else
39-
return norm_val * x * grad_output;
40+
return norm_val * x * grad_output / beta_val;
4041
});
4142
});
4243
}

‎aten/src/ATen/native/native_functions.yaml

+4-4
Original file line numberDiff line numberDiff line change
@@ -6767,25 +6767,25 @@
67676767
CPU: nll_loss2d_backward_cpu
67686768
CUDA: legacy::cuda::_thnn_nll_loss2d_backward
67696769

6770-
- func: smooth_l1_loss.out(Tensor self, Tensor target, int reduction=Mean, *, Tensor(a!) out) -> Tensor(a!)
6770+
- func: smooth_l1_loss.out(Tensor self, Tensor target, int reduction=Mean, float beta=1.0, *, Tensor(a!) out) -> Tensor(a!)
67716771
python_module: nn
67726772
dispatch:
67736773
CPU: smooth_l1_loss_out
67746774
CUDA: smooth_l1_loss_out
67756775

6776-
- func: smooth_l1_loss(Tensor self, Tensor target, int reduction=Mean) -> Tensor
6776+
- func: smooth_l1_loss(Tensor self, Tensor target, int reduction=Mean, float beta=1.0) -> Tensor
67776777
use_c10_dispatcher: full
67786778
python_module: nn
67796779
dispatch:
67806780
CPU, CUDA: smooth_l1_loss
67816781

6782-
- func: smooth_l1_loss_backward.grad_input(Tensor grad_output, Tensor self, Tensor target, int reduction, *, Tensor(a!) grad_input) -> Tensor(a!)
6782+
- func: smooth_l1_loss_backward.grad_input(Tensor grad_output, Tensor self, Tensor target, int reduction, float beta=1.0, *, Tensor(a!) grad_input) -> Tensor(a!)
67836783
python_module: nn
67846784
dispatch:
67856785
CPU: smooth_l1_loss_backward_out
67866786
CUDA: smooth_l1_loss_backward_out
67876787

6788-
- func: smooth_l1_loss_backward(Tensor grad_output, Tensor self, Tensor target, int reduction) -> Tensor
6788+
- func: smooth_l1_loss_backward(Tensor grad_output, Tensor self, Tensor target, int reduction, float beta=1.0) -> Tensor
67896789
use_c10_dispatcher: full
67906790
python_module: nn
67916791

‎test/cpp/api/functional.cpp

+12
Original file line numberDiff line numberDiff line change
@@ -246,6 +246,18 @@ TEST_F(FunctionalTest, SmoothL1LossDefaultOptions) {
246246
ASSERT_TRUE(input.sizes() == input.grad().sizes());
247247
}
248248

249+
TEST_F(FunctionalTest, SmoothL1LossBeta) {
250+
auto input = torch::tensor({0.1, 1.5, 10.0}, torch::dtype(torch::kFloat).requires_grad(true));
251+
auto target = torch::tensor({0., 1., 5.}, torch::kFloat);
252+
auto output =
253+
F::smooth_l1_loss(input, target, /*reduction=*/torch::kMean, /*beta=*/0.5);
254+
auto expected = torch::tensor(1.67, torch::kFloat);
255+
auto s = output.sum();
256+
s.backward();
257+
ASSERT_TRUE(output.allclose(expected));
258+
ASSERT_TRUE(input.sizes() == input.grad().sizes());
259+
}
260+
249261
TEST_F(FunctionalTest, SmoothL1LossNoReduction) {
250262
auto input = torch::tensor({0.1, 1.2, 4.7}, torch::dtype(torch::kFloat).requires_grad(true));
251263
auto target = torch::tensor({0., 1., 5.}, torch::kFloat);

‎tools/autograd/derivatives.yaml

+7-7
Original file line numberDiff line numberDiff line change
@@ -1221,9 +1221,9 @@
12211221
self: nll_loss2d_backward(grad, self, target, weight, reduction, ignore_index, total_weight)
12221222
target: non_differentiable
12231223

1224-
- name: smooth_l1_loss(Tensor self, Tensor target, int reduction=Mean) -> Tensor
1225-
self: smooth_l1_loss_backward(grad, self, target, reduction)
1226-
target: smooth_l1_loss_backward(grad, target, self, reduction)
1224+
- name: smooth_l1_loss(Tensor self, Tensor target, int reduction=Mean, float beta=1.0) -> Tensor
1225+
self: smooth_l1_loss_backward(grad, self, target, reduction, beta)
1226+
target: smooth_l1_loss_backward(grad, target, self, reduction, beta)
12271227

12281228
- name: soft_margin_loss(Tensor self, Tensor target, int reduction=Mean) -> Tensor
12291229
self: soft_margin_loss_backward(grad, self, target, reduction)
@@ -1589,10 +1589,10 @@
15891589
grad_output: replication_pad3d(grad, padding)
15901590
self: zeros_like(self)
15911591

1592-
- name: smooth_l1_loss_backward(Tensor grad_output, Tensor self, Tensor target, int reduction) -> Tensor
1593-
grad_output: smooth_l1_loss_double_backward_grad_output(grad, grad_output, self, target, reduction)
1594-
self: smooth_l1_loss_double_backward(grad * grad_output, self, target, reduction)
1595-
target: -smooth_l1_loss_double_backward(grad * grad_output, self, target, reduction)
1592+
- name: smooth_l1_loss_backward(Tensor grad_output, Tensor self, Tensor target, int reduction, float beta=1.0) -> Tensor
1593+
grad_output: smooth_l1_loss_double_backward_grad_output(grad, grad_output, self, target, reduction, beta)
1594+
self: smooth_l1_loss_double_backward(grad * grad_output, self, target, reduction, beta)
1595+
target: -smooth_l1_loss_double_backward(grad * grad_output, self, target, reduction, beta)
15961596

15971597
- name: softplus_backward(Tensor grad_output, Tensor self, Scalar beta, Scalar threshold, Tensor output) -> Tensor
15981598
grad_output: softplus_backward(grad, self, beta, threshold, output)

‎torch/csrc/api/include/torch/nn/functional/loss.h

+8-6
Original file line numberDiff line numberDiff line change
@@ -307,25 +307,26 @@ inline Tensor cosine_embedding_loss(
307307

308308
// ============================================================================
309309

310-
inline Tensor _smooth_l1_loss(const Tensor& input, const Tensor& target) {
310+
inline Tensor _smooth_l1_loss(const Tensor& input, const Tensor& target, double beta = 1.) {
311311
auto t = torch::abs(input - target);
312-
return torch::where(t < 1, 0.5 * torch::pow(t, 2), t - 0.5);
312+
return torch::where(t < beta, 0.5 * torch::pow(t, 2) / beta, t - 0.5 * beta);
313313
}
314314

315315
#ifndef DOXYGEN_SHOULD_SKIP_THIS
316316
namespace detail {
317317
inline Tensor smooth_l1_loss(
318318
const Tensor& input,
319319
const Tensor& target,
320-
SmoothL1LossFuncOptions::reduction_t reduction) {
320+
SmoothL1LossFuncOptions::reduction_t reduction,
321+
double beta = 1.) {
321322
if (target.sizes() != input.sizes()) {
322323
TORCH_WARN("Using a target size (", target.sizes(), ") that is different to the input size (", input.sizes(), "). ",
323324
"This will likely lead to incorrect results due to broadcasting. ",
324325
"Please ensure they have the same size.");
325326
}
326327

327328
std::vector<Tensor> expanded_tensors = torch::broadcast_tensors({input, target});
328-
return torch::smooth_l1_loss(expanded_tensors[0], expanded_tensors[1], enumtype::reduction_get_enum(reduction));
329+
return torch::smooth_l1_loss(expanded_tensors[0], expanded_tensors[1], enumtype::reduction_get_enum(reduction), beta);
329330
}
330331
} // namespace detail
331332
#endif /* DOXYGEN_SHOULD_SKIP_THIS */
@@ -344,8 +345,9 @@ inline Tensor smooth_l1_loss(
344345
inline Tensor smooth_l1_loss(
345346
const Tensor& input,
346347
const Tensor& target,
347-
const SmoothL1LossFuncOptions& options = {}) {
348-
return detail::smooth_l1_loss(input, target, options.reduction());
348+
const SmoothL1LossFuncOptions& options = {},
349+
double beta = 1.) {
350+
return detail::smooth_l1_loss(input, target, options.reduction(), beta);
349351
}
350352

351353
// ============================================================================

‎torch/csrc/autograd/FunctionsManual.cpp

+9-5
Original file line numberDiff line numberDiff line change
@@ -957,20 +957,24 @@ Tensor l1_loss_double_backward_grad_output(const Tensor & grad, const Tensor & i
957957
return output;
958958
}
959959

960-
Tensor smooth_l1_loss_double_backward(const Tensor & grad, const Tensor & input, const Tensor & target, int64_t reduction) {
960+
Tensor smooth_l1_loss_double_backward(const Tensor & grad, const Tensor & input, const Tensor & target, int64_t reduction, double beta) {
961+
// special case to protect against a divide-by-zero.
962+
if (beta == 0) {
963+
return at::zeros(grad.sizes(), grad.options());
964+
}
961965
auto d = (input - target).abs();
962-
auto grad_input = grad * (d < 1).type_as(grad);
966+
auto grad_input = grad * (d < beta).type_as(grad) / beta;
963967
if (reduction == at::Reduction::Mean) {
964968
grad_input /= input.numel();
965969
}
966970
return grad_input;
967971
}
968972

969-
Tensor smooth_l1_loss_double_backward_grad_output(const Tensor & grad, const Tensor & grad_output, const Tensor & input, const Tensor & target, int64_t reduction) {
973+
Tensor smooth_l1_loss_double_backward_grad_output(const Tensor & grad, const Tensor & grad_output, const Tensor & input, const Tensor & target, int64_t reduction, double beta) {
970974
if (reduction == at::Reduction::None) {
971-
return smooth_l1_loss_backward(grad, input, target, reduction);
975+
return smooth_l1_loss_backward(grad, input, target, reduction, beta);
972976
}
973-
auto r = smooth_l1_loss_backward(ones_like(grad_output), input, target, reduction);
977+
auto r = smooth_l1_loss_backward(ones_like(grad_output), input, target, reduction, beta);
974978
return (r * grad).sum();
975979
}
976980

‎torch/csrc/autograd/FunctionsManual.h

+2-2
Original file line numberDiff line numberDiff line change
@@ -104,8 +104,8 @@ at::Tensor log_softmax_double_backward(const at::Tensor & grad, const at::Tensor
104104
at::Tensor binary_cross_entropy_double_backward(const at::Tensor & grad_output, const at::Tensor & grad, const at::Tensor & input, const at::Tensor & target, const c10::optional<at::Tensor>& weight, int64_t reduction);
105105
at::Tensor binary_cross_entropy_double_backward_grad_output(const at::Tensor & grad, const at::Tensor & input, const at::Tensor & target, const c10::optional<at::Tensor>& weight, int64_t reduction);
106106
at::Tensor l1_loss_double_backward_grad_output(const at::Tensor & grad, const at::Tensor & input, const at::Tensor & target, int64_t reduction);
107-
at::Tensor smooth_l1_loss_double_backward(const at::Tensor & grad, const at::Tensor & input, const at::Tensor & target, int64_t reduction);
108-
at::Tensor smooth_l1_loss_double_backward_grad_output(const at::Tensor & grad, const at::Tensor & grad_output, const at::Tensor & input, const at::Tensor & target, int64_t reduction);
107+
at::Tensor smooth_l1_loss_double_backward(const at::Tensor & grad, const at::Tensor & input, const at::Tensor & target, int64_t reduction, double beta);
108+
at::Tensor smooth_l1_loss_double_backward_grad_output(const at::Tensor & grad, const at::Tensor & grad_output, const at::Tensor & input, const at::Tensor & target, int64_t reduction, double beta);
109109
at::Tensor mse_loss_double_backward(const at::Tensor & grad, const at::Tensor & input, int64_t reduction);
110110
at::Tensor mse_loss_double_backward_grad_output(const at::Tensor & grad, const at::Tensor & grad_output, const at::Tensor & input, const at::Tensor & target, int64_t reduction);
111111
at::Tensor soft_margin_loss_double_backward(const at::Tensor & grad, const at::Tensor & input, const at::Tensor & target, int64_t reduction);

0 commit comments

Comments
 (0)
Please sign in to comment.