
Commit a303fd2

xuhdev authored and facebook-github-bot committed on Jun 30, 2020
Let exp support complex types on CUDA and enable device/dtype in complex tests (pytorch#39087)
Summary: Pull Request resolved: pytorch#39087
Differential Revision: D22169697
Pulled By: anjali411
fbshipit-source-id: 4866b7be6742508cc40540ed1ac811f005531d8b
1 parent ef5a314 commit a303fd2

File tree

3 files changed: +105 -24 lines
 

aten/src/ATen/native/cuda/UnaryOpsKernel.cu

+2 -1
@@ -7,6 +7,7 @@
 #include <ATen/native/DispatchStub.h>
 #include <ATen/native/TensorIterator.h>
 #include <ATen/native/cuda/Math.cuh>
+#include <c10/util/complex.h>

 namespace at { namespace native {

@@ -25,7 +26,7 @@ void bitwise_not_kernel_cuda(TensorIterator& iter) {
 }

 void exp_kernel_cuda(TensorIterator& iter) {
-  AT_DISPATCH_FLOATING_TYPES_AND2(at::ScalarType::Half, at::ScalarType::BFloat16, iter.dtype(), "exp_cuda", [&]() {
+  AT_DISPATCH_FLOATING_AND_COMPLEX_TYPES_AND2(at::ScalarType::Half, at::ScalarType::BFloat16, iter.dtype(), "exp_cuda", [&]() {
     AT_SKIP_BFLOAT16_IF_NOT_ROCM(scalar_t, "exp_cuda", [&] {
       gpu_kernel(iter, []GPU_LAMBDA(scalar_t a) -> scalar_t {
         return ::exp(a);
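What the widened dispatch macro above enables, as a minimal sketch: previously the CUDA exp kernel only dispatched floating-point dtypes and would reject complex inputs; with this change, complex64 and complex128 take the exp_kernel_cuda path as well. This assumes a CUDA-capable build of this revision, and the printed values in the comments are approximate.

import cmath

import torch

# complex tensors on the GPU now go through exp_kernel_cuda
z = torch.tensor([0j, cmath.pi * 1j], dtype=torch.complex64, device='cuda')
print(torch.exp(z))                       # roughly [1+0j, -1+0j], since exp(i*pi) = -1
print(torch.exp(z.to(torch.complex128)))  # complex128 is dispatched as well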

test/test_complex.py

+14 -23
@@ -1,36 +1,27 @@
-import math
 import torch
-from torch.testing._internal.common_utils import TestCase, run_tests, TEST_NUMPY
-import unittest
-
-if TEST_NUMPY:
-    import numpy as np
+from torch.testing._internal.common_device_type import instantiate_device_type_tests, dtypes
+from torch.testing._internal.common_utils import TestCase, run_tests

 devices = (torch.device('cpu'), torch.device('cuda:0'))


 class TestComplexTensor(TestCase):
-    def test_to_list_with_complex_64(self):
+    @dtypes(*torch.testing.get_all_complex_dtypes())
+    def test_to_list(self, device, dtype):
         # test that the complex float tensor has expected values and
         # there's no garbage value in the resultant list
-        self.assertEqual(torch.zeros((2, 2), dtype=torch.complex64).tolist(), [[0j, 0j], [0j, 0j]])
-
-    @unittest.skipIf(not TEST_NUMPY, "Numpy not found")
-    def test_exp(self):
-        def exp_fn(dtype):
-            a = torch.tensor(1j, dtype=dtype) * torch.arange(18) / 3 * math.pi
-            expected = np.exp(a.numpy())
-            actual = torch.exp(a)
-            self.assertEqual(actual, torch.from_numpy(expected))
+        self.assertEqual(torch.zeros((2, 2), device=device, dtype=dtype).tolist(), [[0j, 0j], [0j, 0j]])

-        exp_fn(torch.complex64)
-        exp_fn(torch.complex128)
-
-    def test_dtype_inference(self):
+    @dtypes(torch.float32, torch.float64)
+    def test_dtype_inference(self, device, dtype):
         # issue: https://github.com/pytorch/pytorch/issues/36834
-        torch.set_default_dtype(torch.double)
-        x = torch.tensor([3., 3. + 5.j])
-        self.assertEqual(x.dtype, torch.cdouble)
+        default_dtype = torch.get_default_dtype()
+        torch.set_default_dtype(dtype)
+        x = torch.tensor([3., 3. + 5.j], device=device)
+        torch.set_default_dtype(default_dtype)
+        self.assertEqual(x.dtype, torch.cdouble if dtype == torch.float64 else torch.cfloat)
+
+instantiate_device_type_tests(TestComplexTensor, globals())


 if __name__ == '__main__':
     run_tests()
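test_complex.py now relies on the device-generic test framework rather than hand-rolled per-dtype methods. Roughly, @dtypes records which dtypes a template method should run under, and instantiate_device_type_tests expands the template class into concrete per-device classes whose tests each receive a device string and a dtype. A minimal sketch of the pattern; the class and test names here are illustrative, not part of this diff.

import torch
from torch.testing._internal.common_device_type import instantiate_device_type_tests, dtypes
from torch.testing._internal.common_utils import TestCase, run_tests

class TestComplexExample(TestCase):
    @dtypes(torch.complex64, torch.complex128)
    def test_zeros_tolist(self, device, dtype):
        # each generated variant gets a concrete device ('cpu', 'cuda:0', ...) and dtype
        t = torch.zeros(3, device=device, dtype=dtype)
        self.assertEqual(t.tolist(), [0j, 0j, 0j])

# generates e.g. TestComplexExampleCPU.test_zeros_tolist_cpu_complex64 and, when CUDA
# is available, TestComplexExampleCUDA.test_zeros_tolist_cuda_complex128
instantiate_device_type_tests(TestComplexExample, globals())

if __name__ == '__main__':
    run_tests()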

test/test_torch.py

+89 -0
@@ -11063,6 +11063,95 @@ def test_exponential(self, device, dtype):
         with self.assertRaises(RuntimeError):
             torch.empty((1,), device=device, dtype=dtype).exponential_(-0.5)

+    @unittest.skipIf(not TEST_NUMPY, "Numpy not found")
+    @dtypes(*(torch.testing.get_all_fp_dtypes(include_half=False) +
+              torch.testing.get_all_complex_dtypes()))
+    @dtypesIfCUDA(*(torch.testing.get_all_fp_dtypes(include_half=True) +
+                    torch.testing.get_all_complex_dtypes()))
+    def test_exp(self, device, dtype):
+        for v in (2, -2) + ((1j, 1 + 1j) if dtype.is_complex else ()):
+            if dtype == torch.bfloat16:
+                # Currently multiply a bfloat16 type with floating-point causes error:
+                #   RuntimeError: dtype != ScalarType::Undefined INTERNAL ASSERT FAILED at
+                #   "/pytorch/aten/src/ATen/native/TensorIterator.cpp":125, please report a bug to PyTorch.
+                # We skip bfloat16 for now, but we should fix it. https://github.com/pytorch/pytorch/issues/40580
+                if self.device_type == 'cpu' or self.device_type == 'cuda':
+                    with self.assertRaises(RuntimeError):
+                        torch.tensor(v, dtype=dtype, device=device) * torch.arange(18, device=device)
+                    return
+                elif self.device_type == 'xla':
+                    # Error:
+                    # Traceback (most recent call last):
+                    #   File "/opt/conda/lib/python3.6/site-packages/torch/testing/_internal/common_device_type.py",
+                    #     line 241, in instantiated_test
+                    #     result = test(self, device_arg, dtype)
+                    #   File "/var/lib/jenkins/workspace/xla/test/../../test/test_torch.py", line 11062, in test_exp
+                    #     self.compare_with_numpy(torch.exp, np.exp, a)
+                    #   File "/opt/conda/lib/python3.6/site-packages/torch/testing/_internal/common_utils.py", line 878,
+                    #     in compare_with_numpy
+                    #     a = tensor_like.detach().cpu().numpy()
+                    # TypeError: Got unsupported ScalarType BFloat16
+                    return
+
+            a = torch.tensor(v, dtype=dtype, device=device) * torch.arange(18, device=device) / 3 * math.pi
+            a = a.to(dtype)
+            self.compare_with_numpy(torch.exp, np.exp, a)
+
+            if dtype.is_complex:
+                inf_real_zero_imag_in = torch.tensor(complex(float('inf'), 0), device=device, dtype=dtype)
+                inf_real_zero_imag_out = torch.exp(inf_real_zero_imag_in).item()
+                self.assertTrue(math.isinf(inf_real_zero_imag_out.real))
+                if self.device_type == 'cpu':
+                    if not IS_WINDOWS:  # Windows tests don't show some bugs consistently
+                        # This is incorrect. It should be zero. Need fix!
+                        # https://github.com/pytorch/pytorch/issues/40590
+                        self.assertNotEqual(inf_real_zero_imag_out.imag, 0)
+                        # This is incorrect. They should equal. Need fix!
+                        # https://github.com/pytorch/pytorch/issues/40590
+                        with self.assertRaises(AssertionError):
+                            self.compare_with_numpy(torch.exp, np.exp, inf_real_zero_imag_in)
+                else:
+                    self.assertEqual(inf_real_zero_imag_out.imag, 0, atol=0, rtol=0)
+                    self.compare_with_numpy(torch.exp, np.exp, inf_real_zero_imag_in)
+
+                zero_real_inf_imag_in = torch.tensor(complex(0, float('inf')), device=device, dtype=dtype)
+                zero_real_inf_imag_out = torch.exp(zero_real_inf_imag_in).item()
+                self.assertTrue(math.isnan(zero_real_inf_imag_out.real))
+                self.assertTrue(math.isnan(zero_real_inf_imag_out.imag))
+                # Ensure we are notified when NumPy changes its behavior
+                self.compare_with_numpy(torch.exp, np.exp, zero_real_inf_imag_in)
+
+                inf_real_imag_in = torch.tensor(complex(float('inf'), float('inf')), device=device, dtype=dtype)
+                inf_real_imag_out = torch.exp(inf_real_imag_in).item()
+                if self.device_type == 'cpu':
+                    if not IS_WINDOWS:  # Windows tests don't show some bugs consistently
+                        # This is incorrect. Need fix! https://github.com/pytorch/pytorch/issues/40590
+                        with self.assertRaises(AssertionError):
+                            self.compare_with_numpy(torch.exp, np.exp, inf_real_imag_in)
+                else:
+                    self.assertTrue(math.isinf(inf_real_imag_out.real))
+                    self.assertTrue(math.isnan(inf_real_imag_out.imag))
+                    self.compare_with_numpy(torch.exp, np.exp, inf_real_imag_in)
+
+                inf_real_nan_imag_in = torch.tensor(complex(float('inf'), float('nan')), device=device, dtype=dtype)
+                inf_real_nan_imag_out = torch.exp(inf_real_nan_imag_in).item()
+                if self.device_type == 'cpu':
+                    if not IS_WINDOWS:  # Windows tests don't show some bugs consistently
+                        # This is incorrect. It should be inf. Need fix! https://github.com/pytorch/pytorch/issues/40590
+                        with self.assertRaises(AssertionError):
+                            self.compare_with_numpy(torch.exp, np.exp, inf_real_nan_imag_in)
+                else:
+                    self.assertTrue(math.isinf(inf_real_nan_imag_out.real))
+                    self.assertTrue(math.isnan(inf_real_nan_imag_out.imag))
+                    self.compare_with_numpy(torch.exp, np.exp, inf_real_nan_imag_in)
+
+                nan_real_inf_imag_in = torch.tensor(complex(float('nan'), float('inf')), device=device, dtype=dtype)
+                nan_real_inf_imag_out = torch.exp(nan_real_inf_imag_in).item()
+                self.assertTrue(math.isnan(nan_real_inf_imag_out.real))
+                self.assertTrue(math.isnan(nan_real_inf_imag_out.imag))
+                # Ensure we are notified when NumPy changes its behavior
+                self.compare_with_numpy(torch.exp, np.exp, nan_real_inf_imag_in)
+
     @skipIfNoSciPy
     @dtypes(*torch.testing.get_all_fp_dtypes())
     def test_uniform_kstest(self, device, dtype):
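The new test_exp checks torch.exp against np.exp via self.compare_with_numpy, including inf/nan edge cases that currently disagree on CPU (pytorch#40590). Roughly the same comparison can be sketched outside the test harness as below; the helper name and tolerances are illustrative, not taken from the diff.

import math

import numpy as np
import torch

def check_exp_against_numpy(t):
    # compare torch.exp and np.exp elementwise on the same values
    torch_out = torch.exp(t).cpu().numpy()
    numpy_out = np.exp(t.cpu().numpy())
    np.testing.assert_allclose(torch_out, numpy_out, rtol=1e-5, atol=1e-7, equal_nan=True)

# ordinary values, mirroring the tensor built in test_exp
a = torch.tensor(1 + 1j, dtype=torch.complex128) * torch.arange(18) / 3 * math.pi
check_exp_against_numpy(a)

# a special value exercised by the test: exp(0 + inf*j) is nan + nan*j
# in both libraries at the time of this change
check_exp_against_numpy(torch.tensor(complex(0, float('inf')), dtype=torch.complex128))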
