Skip to content

Commit 0b6b735

Browse files
kshitij12345 authored and facebook-github-bot committed on Sep 24, 2020
[fix] type promotion atan2 (pytorch#43466)
Summary: Fixes pytorch#43360. Pull Request resolved: pytorch#43466. Reviewed By: malfet. Differential Revision: D23834928. Pulled By: mruberry. fbshipit-source-id: 2e7e0b4fcf1a846efc171c275d65a6daffd3c631
1 parent 6a2e9eb commit 0b6b735

File tree

3 files changed

+57
-5
lines changed

3 files changed

+57
-5
lines changed
 

‎aten/src/ATen/native/BinaryOps.cpp

+5-3
Original file line number | Diff line number | Diff line change
@@ -390,14 +390,16 @@ Tensor rsub(const Tensor& self, const Tensor& other, Scalar alpha) {
390390
}
391391

392392
Tensor& atan2_out(Tensor& result, const Tensor& self, const Tensor& other) {
393-
auto iter = TensorIterator::binary_op(result, self, other);
393+
auto iter = TensorIterator::binary_float_op(result, self, other);
394394
atan2_stub(iter.device_type(), iter);
395395
return result;
396396
}
397397

398398
Tensor atan2(const Tensor& self, const Tensor& other) {
399-
Tensor result = at::empty({0}, self.options());
400-
return native::atan2_out(result, self, other);
399+
Tensor result;
400+
auto iter = TensorIterator::binary_float_op(result, self, other);
401+
atan2_stub(iter.device_type(), iter);
402+
return iter.output();
401403
}
402404

403405
Tensor& atan2_(Tensor& self, const Tensor& other) {

‎test/test_torch.py

+20-1
Original file line number | Diff line number | Diff line change
@@ -19672,10 +19672,20 @@ def test_movedim_view(self, device):
1967219672
torch.int8, torch.short, torch.int, torch.long
1967319673
]
1967419674

19675+
_integer_types = [
19676+
torch.uint8, torch.int8, torch.int16,
19677+
torch.int32, torch.int64
19678+
]
19679+
1967519680
_cpu_types: List[torch.dtype] = []
1967619681

1967719682
_unsigned_types = [torch.uint8]
1967819683

19684+
# Binary Float Ops
19685+
# Operators which use TensorIterator::binary_float_op
19686+
# These Ops promote integer inputs to Float.
19687+
binary_float_ops_inplace = ['atan2_', 'div_']
19688+
1967919689
# Helper values and functions for producing tensors and scalars to use in tensor op tests.
1968019690
# Tensor dimension sizes (Small, Medium, Large, Giant)
1968119691
_S = 5
@@ -19896,7 +19906,7 @@ def inner(self, device, dtype):
1989619906
lambda t, d: [_number(0.5, 3, t), _number(0.4, 2, t), _medium_1d(t, d), _medium_1d(t, d)],
1989719907
1e-2, 1e-1, 1e-4, _float_types2, _cpu_types, True,
1989819908
[_wrap_maybe_warns("This overload of addr_? is deprecated")]),
19899-
('atan2', '', _medium_2d, lambda t, d: [_medium_2d(t, d)], 1e-2, 1e-5, 1e-5, _float_types),
19909+
('atan2', '', _medium_2d, lambda t, d: [_medium_2d(t, d)], 1e-2, 1e-5, 1e-5, _types, _types_no_half),
1990019910
('angle', '', _small_3d, lambda t, d: [], 0, 0, 0, _types_no_half, [torch.bfloat16], False),
1990119911
('fmod', 'value', _small_3d, lambda t, d: [3], 1e-3),
1990219912
('fmod', 'tensor', _small_3d, lambda t, d: [_small_3d(t, d, has_zeros=False)], 1e-3),
@@ -20188,6 +20198,15 @@ def fn(self, device, dtype) -> None:
2018820198
(isinstance(arg, torch.Tensor) and arg.dtype == torch.float) else arg
2018920199
for arg in device_args]
2019020200

20201+
# Special case for binary float ops (binary ops that promote int to float)
20202+
if op_str in binary_float_ops_inplace and \
20203+
'inplace' in subtest_str and dtype in _integer_types:
20204+
with self.assertRaisesRegex(RuntimeError, "result type Float can't be cast to "):
20205+
cpu_result = getattr(cpu_tensor, op_str)(*cpu_args)
20206+
with self.assertRaisesRegex(RuntimeError, "result type Float can't be cast to "):
20207+
device_result = getattr(device_tensor, op_str)(*device_args)
20208+
return # Nothing more to check
20209+
2019120210
# Runs the tensor op on CPU and device
2019220211
cpu_result = getattr(cpu_tensor, op_str)(*cpu_args)
2019320212
device_result = getattr(device_tensor, op_str)(*device_args)

‎test/test_type_promotion.py

+32-1
Original file line number | Diff line number | Diff line change
@@ -7,7 +7,7 @@
77
from torch.testing._internal.common_utils import (TestCase, run_tests, load_tests,
88
TEST_NUMPY, torch_to_numpy_dtype_dict)
99
from torch.testing._internal.common_device_type import (instantiate_device_type_tests, onlyOnCPUAndCUDA,
10-
dtypes, onlyCPU)
10+
dtypes, dtypesIfCUDA, onlyCPU)
1111

1212
if TEST_NUMPY:
1313
import numpy as np
@@ -958,6 +958,37 @@ def test_computation_ignores_out(self, device):
958958
self.assertEqual(result, a - b, exact_dtype=False)
959959
self.assertNotEqual(result, a.double() - b, exact_dtype=False)
960960

961+
@dtypesIfCUDA(*itertools.product(torch.testing.get_all_dtypes(include_bfloat16=False, include_complex=False),
962+
torch.testing.get_all_dtypes(include_bfloat16=False, include_complex=False)))
963+
@dtypes(*itertools.product(torch.testing.get_all_dtypes(include_half=False, include_bfloat16=False,
964+
include_complex=False),
965+
torch.testing.get_all_dtypes(include_half=False, include_bfloat16=False,
966+
include_complex=False)))
967+
def test_atan2_type_promotion(self, device, dtypes):
968+
dtype1, dtype2 = dtypes
969+
default_float = torch.get_default_dtype()
970+
971+
def is_int(dtype):
972+
return dtype in torch.testing.get_all_int_dtypes() + [torch.bool]
973+
974+
def is_float(dtype):
975+
return dtype in torch.testing.get_all_fp_dtypes(include_half=True, include_bfloat16=False)
976+
977+
def get_binary_float_result_type(x, y):
978+
dtype1 = x.dtype
979+
dtype2 = y.dtype
980+
if is_float(dtype1) and is_float(dtype2):
981+
return torch.result_type(x, y)
982+
elif is_float(dtype1) and is_int(dtype2):
983+
return dtype1
984+
elif is_int(dtype1) and is_float(dtype2):
985+
return dtype2
986+
elif is_int(dtype1) and is_int(dtype2):
987+
return default_float
988+
989+
x = torch.tensor(1, dtype=dtype1, device=device)
990+
y = torch.tensor(2, dtype=dtype2, device=device)
991+
self.assertEqual(get_binary_float_result_type(x, y), torch.atan2(x, y).dtype)
961992

962993
instantiate_device_type_tests(TestTypePromotion, globals())
963994

0 commit comments

Comments (0)
Please sign in to comment.