Commit 0638940

zasdfgbnm authored and facebook-github-bot committed Sep 18, 2020
CUDA BFloat activations 1 (pytorch#44834)
Summary:
Pull Request resolved: pytorch#44834

Reviewed By: mruberry

Differential Revision: D23752660

Pulled By: ngimel

fbshipit-source-id: 209a937e8a9afe12b7dd86ecfa493c9417fd22fb
1 parent 76a109c commit 0638940
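The change itself is mechanical: each activation kernel below loses its `AT_SKIP_BFLOAT16_IF_NOT_ROCM` wrapper, so the BFloat16 instantiation that the `AT_DISPATCH_*_AND2` macros already generate is no longer skipped on non-ROCm (i.e. regular CUDA) builds. A minimal sketch of the user-visible effect, not part of the commit, assuming a CUDA device and a build where `torch.bfloat16` is supported:

```python
import torch

# Before this change, the bfloat16 paths of these ops were gated to ROCm.
x = torch.randn(8, device="cuda", dtype=torch.bfloat16, requires_grad=True)

y = torch.nn.functional.hardshrink(x, lambd=0.5)
y = torch.nn.functional.softplus(y, beta=1.0, threshold=20.0)
y = torch.nn.functional.elu(y, alpha=1.0)
y = torch.nn.functional.leaky_relu(y, negative_slope=0.01)

# The backward kernels are covered by the same dispatch change.
y.sum().backward()
assert x.grad.dtype == torch.bfloat16
```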

File tree: 2 files changed, +35 -56 lines

aten/src/ATen/native/cuda/Activation.cu  (+35, -55)
```diff
@@ -246,33 +246,27 @@ std::tuple<Tensor, Tensor> prelu_backward_cuda(const Tensor& grad_out_, const Te
 // -----------------------------------
 void hardshrink_kernel(TensorIterator& iter, Scalar value) {
   AT_DISPATCH_FLOATING_TYPES_AND2(at::ScalarType::Half, at::ScalarType::BFloat16, iter.dtype(), "hardshrink_cuda", [&]() {
-    AT_SKIP_BFLOAT16_IF_NOT_ROCM(scalar_t, "hardshrink_cuda", [&] {
-      auto lambd = value.to<scalar_t>();
-      gpu_kernel(iter, [lambd]GPU_LAMBDA(scalar_t a) -> scalar_t {
-        return (a >= -lambd && a <= lambd) ? scalar_t(0) : a;
-      });
+    auto lambd = value.to<scalar_t>();
+    gpu_kernel(iter, [lambd]GPU_LAMBDA(scalar_t a) -> scalar_t {
+      return (a >= -lambd && a <= lambd) ? scalar_t(0) : a;
     });
   });
 }
 
 void softshrink_kernel(TensorIterator& iter, Scalar value) {
   AT_DISPATCH_FLOATING_TYPES_AND2(at::ScalarType::Half, at::ScalarType::BFloat16, iter.dtype(), "softshrink_cuda", [&]() {
-    AT_SKIP_BFLOAT16_IF_NOT_ROCM(scalar_t, "softshrink_cuda", [&] {
-      auto lambd = value.to<scalar_t>();
-      gpu_kernel(iter, [lambd]GPU_LAMBDA(scalar_t a) -> scalar_t {
-        return a > lambd ? a - lambd : (a < -lambd ? a + lambd : scalar_t(0));
-      });
+    auto lambd = value.to<scalar_t>();
+    gpu_kernel(iter, [lambd]GPU_LAMBDA(scalar_t a) -> scalar_t {
+      return a > lambd ? a - lambd : (a < -lambd ? a + lambd : scalar_t(0));
     });
   });
 }
 
 void shrink_backward_kernel(TensorIterator& iter, Scalar value) {
   AT_DISPATCH_FLOATING_TYPES_AND2(at::ScalarType::Half, at::ScalarType::BFloat16, iter.dtype(), "shrink_backward_cuda", [&]() {
-    AT_SKIP_BFLOAT16_IF_NOT_ROCM(scalar_t, "shrink_backward_cuda", [&] {
-      auto lambd = value.to<scalar_t>();
-      gpu_kernel(iter, [lambd]GPU_LAMBDA(scalar_t grad_val, scalar_t self_val) -> scalar_t {
-        return (self_val >= -lambd && self_val <= lambd) ? scalar_t(0) : grad_val;
-      });
+    auto lambd = value.to<scalar_t>();
+    gpu_kernel(iter, [lambd]GPU_LAMBDA(scalar_t grad_val, scalar_t self_val) -> scalar_t {
+      return (self_val >= -lambd && self_val <= lambd) ? scalar_t(0) : grad_val;
     });
   });
 }
```
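For reference, the elementwise math these three lambdas compute, restated as a plain-Python sketch (illustrative names, not PyTorch API):

```python
def hardshrink(a: float, lambd: float) -> float:
    # Zero out values inside [-lambd, lambd]; pass the rest through.
    return 0.0 if -lambd <= a <= lambd else a

def softshrink(a: float, lambd: float) -> float:
    # Shrink values toward zero by lambd; zero inside the dead zone.
    if a > lambd:
        return a - lambd
    return a + lambd if a < -lambd else 0.0

def shrink_backward(grad_val: float, self_val: float, lambd: float) -> float:
    # Shared backward for both shrinks: gradient is blocked in the dead zone.
    return 0.0 if -lambd <= self_val <= lambd else grad_val
```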
```diff
@@ -289,25 +283,21 @@ void hardtanh_backward_kernel(TensorIterator& iter, Scalar min, Scalar max) {
 
 void softplus_kernel(TensorIterator& iter, Scalar beta_, Scalar threshold_) {
   AT_DISPATCH_FLOATING_TYPES_AND2(at::ScalarType::Half, at::ScalarType::BFloat16, iter.dtype(), "softplus_cuda", [&]() {
-    AT_SKIP_BFLOAT16_IF_NOT_ROCM(scalar_t, "softplus_cuda", [&] {
-      auto beta = beta_.to<scalar_t>();
-      auto threshold = threshold_.to<scalar_t>();
-      gpu_kernel(iter, [beta, threshold]GPU_LAMBDA(scalar_t a) -> scalar_t {
-        return (a * beta) > threshold ? a : static_cast<scalar_t>(::log1p(std::exp(a * beta))) / beta;
-      });
+    auto beta = beta_.to<scalar_t>();
+    auto threshold = threshold_.to<scalar_t>();
+    gpu_kernel(iter, [beta, threshold]GPU_LAMBDA(scalar_t a) -> scalar_t {
+      return (a * beta) > threshold ? a : static_cast<scalar_t>(::log1p(std::exp(a * beta))) / beta;
     });
   });
 }
 
 void softplus_backward_kernel(TensorIterator& iter, Scalar beta_, Scalar threshold_) {
   AT_DISPATCH_FLOATING_TYPES_AND2(at::ScalarType::Half, at::ScalarType::BFloat16, iter.dtype(), "softplus_backward_cuda", [&]() {
-    AT_SKIP_BFLOAT16_IF_NOT_ROCM(scalar_t, "softplus_backward_cuda", [&] {
-      auto beta = beta_.to<scalar_t>();
-      auto threshold = threshold_.to<scalar_t>();
-      gpu_kernel(iter, [beta, threshold]GPU_LAMBDA(scalar_t a, scalar_t b) -> scalar_t {
-        scalar_t z = std::exp(b * beta);
-        return (b * beta) > threshold ? a : a * (z - scalar_t(1.)) / z;
-      });
+    auto beta = beta_.to<scalar_t>();
+    auto threshold = threshold_.to<scalar_t>();
+    gpu_kernel(iter, [beta, threshold]GPU_LAMBDA(scalar_t a, scalar_t b) -> scalar_t {
+      scalar_t z = std::exp(b * beta);
+      return (b * beta) > threshold ? a : a * (z - scalar_t(1.)) / z;
     });
   });
 }
```
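The softplus pair, restated the same way; the `threshold` branch keeps the kernel exactly linear for large inputs so `exp` cannot overflow (sketch only, mirroring the lambdas above):

```python
import math

def softplus(a: float, beta: float, threshold: float) -> float:
    # Identity in the saturated regime, log1p(exp(a * beta)) / beta otherwise.
    return a if a * beta > threshold else math.log1p(math.exp(a * beta)) / beta

def softplus_backward(a: float, b: float, beta: float, threshold: float) -> float:
    # a is the incoming gradient, b the saved tensor the kernel reads;
    # (z - 1) / z == 1 - exp(-b * beta), the sigmoid-style scaling factor.
    z = math.exp(b * beta)
    return a if b * beta > threshold else a * (z - 1.0) / z
```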
```diff
@@ -321,34 +311,28 @@ void threshold_kernel_impl(TensorIterator& iter, scalar_t threshold, scalar_t va
 
 static void threshold_kernel(TensorIterator& iter, Scalar threshold, Scalar value) {
   AT_DISPATCH_ALL_TYPES_AND2(at::ScalarType::Half, at::ScalarType::BFloat16, iter.dtype(), "threshold_cuda", [&] {
-    AT_SKIP_BFLOAT16_IF_NOT_ROCM(scalar_t, "threshold_cuda", [&] {
-      threshold_kernel_impl<scalar_t>(iter, threshold.to<scalar_t>(), value.to<scalar_t>());
-    });
+    threshold_kernel_impl<scalar_t>(iter, threshold.to<scalar_t>(), value.to<scalar_t>());
   });
 }
 
 void elu_kernel(TensorIterator& iter, Scalar alpha, Scalar scale, Scalar input_scale) {
   AT_DISPATCH_FLOATING_TYPES_AND2(at::ScalarType::Half, at::ScalarType::BFloat16, iter.dtype(), "elu_cuda", [&]() {
-    AT_SKIP_BFLOAT16_IF_NOT_ROCM(scalar_t, "elu_cuda", [&] {
-      auto negcoef = alpha.to<scalar_t>() * scale.to<scalar_t>();
-      auto poscoef = scale.to<scalar_t>();
-      auto negiptcoef = input_scale.to<scalar_t>();
-      gpu_kernel(iter, [negcoef, poscoef, negiptcoef]GPU_LAMBDA(scalar_t a) -> scalar_t {
-        return a > scalar_t(0) ? a * poscoef : (static_cast<scalar_t>(std::exp(a * negiptcoef)) - scalar_t(1.)) * negcoef;
-      });
+    auto negcoef = alpha.to<scalar_t>() * scale.to<scalar_t>();
+    auto poscoef = scale.to<scalar_t>();
+    auto negiptcoef = input_scale.to<scalar_t>();
+    gpu_kernel(iter, [negcoef, poscoef, negiptcoef]GPU_LAMBDA(scalar_t a) -> scalar_t {
+      return a > scalar_t(0) ? a * poscoef : (static_cast<scalar_t>(std::exp(a * negiptcoef)) - scalar_t(1.)) * negcoef;
     });
   });
 }
 
 void elu_backward_kernel(TensorIterator& iter, Scalar alpha, Scalar scale, Scalar input_scale) {
   AT_DISPATCH_FLOATING_TYPES_AND2(at::ScalarType::Half, at::ScalarType::BFloat16, iter.dtype(), "elu_backward_cuda", [&]() {
-    AT_SKIP_BFLOAT16_IF_NOT_ROCM(scalar_t, "elu_backward_cuda", [&] {
-      auto negcoef = alpha.to<scalar_t>() * scale.to<scalar_t>();
-      auto poscoef = scale.to<scalar_t>();
-      auto negiptcoef = input_scale.to<scalar_t>();
-      gpu_kernel(iter, [negcoef, poscoef, negiptcoef]GPU_LAMBDA(scalar_t a, scalar_t b) -> scalar_t {
-        return b <= scalar_t(0) ? a * negiptcoef * (b + negcoef) : a * poscoef;
-      });
+    auto negcoef = alpha.to<scalar_t>() * scale.to<scalar_t>();
+    auto poscoef = scale.to<scalar_t>();
+    auto negiptcoef = input_scale.to<scalar_t>();
+    gpu_kernel(iter, [negcoef, poscoef, negiptcoef]GPU_LAMBDA(scalar_t a, scalar_t b) -> scalar_t {
+      return b <= scalar_t(0) ? a * negiptcoef * (b + negcoef) : a * poscoef;
     });
   });
 }
```
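Same exercise for the threshold and ELU kernels. `threshold_kernel_impl` itself sits outside this hunk, so its line below reflects the standard `torch.nn.Threshold` semantics rather than code shown in the diff:

```python
import math

def threshold(a: float, thresh: float, value: float) -> float:
    # Standard Threshold semantics: replace values at or below `thresh`.
    return a if a > thresh else value

def elu(a: float, alpha: float, scale: float, input_scale: float) -> float:
    # negcoef = alpha * scale, poscoef = scale, negiptcoef = input_scale.
    if a > 0.0:
        return a * scale
    return (math.exp(a * input_scale) - 1.0) * alpha * scale

def elu_backward(a: float, b: float, alpha: float, scale: float,
                 input_scale: float) -> float:
    # a is the incoming gradient, b the saved tensor the kernel reads.
    negcoef = alpha * scale
    return a * input_scale * (b + negcoef) if b <= 0.0 else a * scale
```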
```diff
@@ -387,22 +371,18 @@ void GeluBackwardCUDAKernelImpl(TensorIterator& it) {
 
 void leaky_relu_kernel(TensorIterator& iter, Scalar negval_) {
   AT_DISPATCH_FLOATING_TYPES_AND2(at::ScalarType::Half, at::ScalarType::BFloat16, iter.dtype(), "leaky_relu_cuda", [&]() {
-    AT_SKIP_BFLOAT16_IF_NOT_ROCM(scalar_t, "leaky_relu_cuda", [&] {
-      auto negval = negval_.to<scalar_t>();
-      gpu_kernel(iter, [negval]GPU_LAMBDA(scalar_t a) -> scalar_t {
-        return a > scalar_t(0) ? a : a * negval;
-      });
+    auto negval = negval_.to<scalar_t>();
+    gpu_kernel(iter, [negval]GPU_LAMBDA(scalar_t a) -> scalar_t {
+      return a > scalar_t(0) ? a : a * negval;
     });
   });
 }
 
 void leaky_relu_backward_kernel(TensorIterator& iter, Scalar negval_) {
   AT_DISPATCH_FLOATING_TYPES_AND2(at::ScalarType::Half, at::ScalarType::BFloat16, iter.dtype(), "leaky_relu_backward_cuda", [&]() {
-    AT_SKIP_BFLOAT16_IF_NOT_ROCM(scalar_t, "leaky_relu_backward_cuda", [&] {
-      auto negval = negval_.to<scalar_t>();
-      gpu_kernel(iter, [negval]GPU_LAMBDA(scalar_t a, scalar_t b) -> scalar_t {
-        return a > scalar_t(0) ? b : b * negval;
-      });
+    auto negval = negval_.to<scalar_t>();
+    gpu_kernel(iter, [negval]GPU_LAMBDA(scalar_t a, scalar_t b) -> scalar_t {
+      return a > scalar_t(0) ? b : b * negval;
     });
   });
 }
```
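And the leaky ReLU pair. Note the backward lambda's argument order: `a` is the reference tensor and `b` the incoming gradient, the opposite of the kernels above (sketch only):

```python
def leaky_relu(a: float, negval: float) -> float:
    # Identity above zero, slope `negval` at or below.
    return a if a > 0.0 else a * negval

def leaky_relu_backward(a: float, b: float, negval: float) -> float:
    # a: reference tensor, b: incoming gradient.
    return b if a > 0.0 else b * negval
```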

test/test_nn.py  (-1)
```diff
@@ -11832,7 +11832,6 @@ def _test_bfloat16_ops(self, op, device, inp_dims=(), prec=1e-2):
         self.assertEqual(input1.grad.data, input2.grad.data, atol=prec, rtol=0, exact_dtype=False)
 
     @onlyCUDA
-    @skipCUDAIfNotRocm
     def test_activations_bfloat16(self, device):
         self._test_bfloat16_ops(torch.nn.ReLU(), device, inp_dims=(5), prec=1e-2)
         self._test_bfloat16_ops(torch.nn.Threshold(0.1, 20), device, inp_dims=(5), prec=1e-2)
```
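With `@skipCUDAIfNotRocm` removed, `test_activations_bfloat16` now exercises these kernels on any CUDA device. The `_test_bfloat16_ops` helper it calls compares a bfloat16 run against a float32 reference; a condensed, self-contained sketch of that pattern, using a hypothetical helper name and the current `torch.testing` API:

```python
import torch

def check_bfloat16_op(op, device="cuda", inp_dims=(5,), prec=1e-2):
    # Run the same module in float32 and bfloat16, compare outputs and grads.
    input1 = torch.randn(inp_dims, dtype=torch.float32,
                         device=device, requires_grad=True)
    input2 = input1.detach().to(torch.bfloat16).requires_grad_()

    out1, out2 = op(input1), op(input2)
    torch.testing.assert_close(out1, out2.float(), atol=prec, rtol=0)

    out1.sum().backward()
    out2.sum().backward()
    torch.testing.assert_close(input1.grad, input2.grad.float(),
                               atol=prec, rtol=0)

check_bfloat16_op(torch.nn.Hardshrink(0.5))
check_bfloat16_op(torch.nn.Softplus())
check_bfloat16_op(torch.nn.ELU())
check_bfloat16_op(torch.nn.LeakyReLU(0.1))
```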
