@@ -19672,10 +19672,20 @@ def test_movedim_view(self, device):
19672
19672
torch.int8, torch.short, torch.int, torch.long
19673
19673
]
19674
19674
19675
+ _integer_types = [
19676
+ torch.uint8, torch.int8, torch.int16,
19677
+ torch.int32, torch.int64
19678
+ ]
19679
+
19675
19680
_cpu_types: List[torch.dtype] = []
19676
19681
19677
19682
_unsigned_types = [torch.uint8]
19678
19683
19684
+ # Binary Float Ops
19685
+ # Operators which use TensorIterator::binary_float_op
19686
+ # These Ops promote integer inputs to Float.
19687
+ binary_float_ops_inplace = ['atan2_', 'div_']
19688
+
19679
19689
# Helper values and functions for producing tensors and scalars to use in tensor op tests.
19680
19690
# Tensor dimension sizes (Small, Medium, Large, Giant)
19681
19691
_S = 5
@@ -19896,7 +19906,7 @@ def inner(self, device, dtype):
19896
19906
lambda t, d: [_number(0.5, 3, t), _number(0.4, 2, t), _medium_1d(t, d), _medium_1d(t, d)],
19897
19907
1e-2, 1e-1, 1e-4, _float_types2, _cpu_types, True,
19898
19908
[_wrap_maybe_warns("This overload of addr_? is deprecated")]),
19899
- ('atan2', '', _medium_2d, lambda t, d: [_medium_2d(t, d)], 1e-2, 1e-5, 1e-5, _float_types ),
19909
+ ('atan2', '', _medium_2d, lambda t, d: [_medium_2d(t, d)], 1e-2, 1e-5, 1e-5, _types, _types_no_half ),
19900
19910
('angle', '', _small_3d, lambda t, d: [], 0, 0, 0, _types_no_half, [torch.bfloat16], False),
19901
19911
('fmod', 'value', _small_3d, lambda t, d: [3], 1e-3),
19902
19912
('fmod', 'tensor', _small_3d, lambda t, d: [_small_3d(t, d, has_zeros=False)], 1e-3),
@@ -20188,6 +20198,15 @@ def fn(self, device, dtype) -> None:
20188
20198
(isinstance(arg, torch.Tensor) and arg.dtype == torch.float) else arg
20189
20199
for arg in device_args]
20190
20200
20201
+ # Special case for binary float ops (binary ops that promote int to float)
20202
+ if op_str in binary_float_ops_inplace and \
20203
+ 'inplace' in subtest_str and dtype in _integer_types:
20204
+ with self.assertRaisesRegex(RuntimeError, "result type Float can't be cast to "):
20205
+ cpu_result = getattr(cpu_tensor, op_str)(*cpu_args)
20206
+ with self.assertRaisesRegex(RuntimeError, "result type Float can't be cast to "):
20207
+ device_result = getattr(device_tensor, op_str)(*device_args)
20208
+ return # Nothing more to check
20209
+
20191
20210
# Runs the tensor op on CPU and device
20192
20211
cpu_result = getattr(cpu_tensor, op_str)(*cpu_args)
20193
20212
device_result = getattr(device_tensor, op_str)(*device_args)
0 commit comments