Commit c48f511

Mike Ruberry authored and facebook-github-bot committed Sep 11, 2020
Moves some of TestTorchMathOps to OpInfos (pytorch#44277)
Summary: This PR fixes three OpInfo-related bugs and moves some functions from TestTorchMathOps to be tested using the OpInfo pattern. The bugs are:

- A skip-test path in test_ops.py incorrectly formatted its string argument.
- Decorating the tests in common_device_type.py always applied op-specific decorators to the original test, not to the op-specific variant of the test. This could cause the same decorator to be applied multiple times, overriding past applications.
- make_tensor was incorrectly constructing tensors in some cases.

The functions moved are: asin, asinh, sinh, acosh, tan, atan, atanh, tanh, log, log10, log1p, log2.

In a follow-up PR most or all of the remaining functions in TestTorchMathOps will be refactored as OpInfo-based tests.

Pull Request resolved: pytorch#44277
Reviewed By: mrshenli, ngimel
Differential Revision: D23617361
Pulled By: mruberry
fbshipit-source-id: edb292947769967de9383f6a84eb327f027509e0
1 parent 2e744b1 commit c48f511
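
The sketch below is a deliberately simplified, hypothetical illustration of the OpInfo pattern referenced in the summary; it is not PyTorch's actual OpInfo/UnaryUfuncInfo classes or the @ops decorator. It only demonstrates the idea: each database entry bundles an operator with a NumPy reference, a valid input domain, and flags such as test_inplace_grad, and a single generic test body is then run once per entry.

# Illustrative sketch only -- a drastically simplified stand-in for the OpInfo
# pattern, not PyTorch's real testing infrastructure.
import numpy as np
import torch

class SimpleOpInfo:
    def __init__(self, name, ref, domain, test_inplace_grad=True):
        self.name = name
        self.op = getattr(torch, name)   # torch function under test
        self.ref = ref                   # NumPy reference implementation
        self.domain = domain             # interval on which the op is defined
        self.test_inplace_grad = test_inplace_grad

simple_op_db = [
    SimpleOpInfo('asin', np.arcsin, domain=(-1.0, 1.0)),
    SimpleOpInfo('log', np.log, domain=(0.01, 10.0)),
    SimpleOpInfo('acosh', np.arccosh, domain=(1.0, 10.0), test_inplace_grad=False),
]

def check_against_reference(info, n=128):
    # One generic test body reused for every entry in the database
    lo, hi = info.domain
    t = torch.empty(n, dtype=torch.double).uniform_(lo, hi)
    expected = torch.from_numpy(info.ref(t.numpy()))
    assert torch.allclose(info.op(t), expected)

for info in simple_op_db:
    check_against_reference(info)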

6 files changed: +195 -70 lines changed

test/test_ops.py (+4)

@@ -102,6 +102,8 @@ def test_method_grad(self, device, dtype, op):
     @dtypes(torch.double, torch.cdouble)
     @ops(op_db)
     def test_inplace_grad(self, device, dtype, op):
+        if not op.test_inplace_grad:
+            self.skipTest("Skipped! Inplace gradcheck marked to skip.")
         self._grad_test_helper(device, dtype, op, self._get_safe_inplace(op.get_inplace()))

     # Test that gradients of gradients are computed correctly
@@ -118,6 +120,8 @@ def test_method_gradgrad(self, device, dtype, op):
     @dtypes(torch.double, torch.cdouble)
     @ops(op_db)
     def test_inplace_gradgrad(self, device, dtype, op):
+        if not op.test_inplace_grad:
+            self.skipTest("Skipped! Inplace gradgradcheck marked to skip.")
         self._gradgrad_test_helper(device, dtype, op, self._get_safe_inplace(op.get_inplace()))
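
As a standalone illustration of the skip gate added above, the snippet below mocks the behavior with plain unittest. FakeOpInfo and its 'acosh' entry are hypothetical, and this is not PyTorch's test harness; it only shows how a generated test can consult a per-op flag and call skipTest before doing any work.

# Hypothetical, minimal mock of the per-op skip gate shown in the diff above.
import unittest

class FakeOpInfo:
    def __init__(self, name, test_inplace_grad=True):
        self.name = name
        self.test_inplace_grad = test_inplace_grad

class TestGradientsSketch(unittest.TestCase):
    op = FakeOpInfo('acosh', test_inplace_grad=False)  # inplace derivative not implemented

    def test_inplace_grad(self):
        if not self.op.test_inplace_grad:
            self.skipTest("Skipped! Inplace gradcheck marked to skip.")
        # ... gradcheck of the inplace variant would run here ...

if __name__ == '__main__':
    unittest.main()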

test/test_torch.py (+1 -13)

@@ -20279,19 +20279,7 @@ def __init__(self,
         self.dtypes = dtypes
         self.replace_inf_with_nan = replace_inf_with_nan

-torch_op_tests = [_TorchMathTestMeta('asin', reffn='arcsin'),
-                  _TorchMathTestMeta('asinh', reffn='arcsinh'),
-                  _TorchMathTestMeta('sinh'),
-                  _TorchMathTestMeta('acosh', reffn='arccosh'),
-                  _TorchMathTestMeta('tan'),
-                  _TorchMathTestMeta('atan', reffn='arctan'),
-                  _TorchMathTestMeta('atanh', reffn='arctanh'),
-                  _TorchMathTestMeta('tanh'),
-                  _TorchMathTestMeta('log'),
-                  _TorchMathTestMeta('log10'),
-                  _TorchMathTestMeta('log1p'),
-                  _TorchMathTestMeta('log2'),
-                  _TorchMathTestMeta('sqrt'),
+torch_op_tests = [_TorchMathTestMeta('sqrt'),
                   _TorchMathTestMeta('erf', ref_backend='scipy'),
                   _TorchMathTestMeta('erfc', ref_backend='scipy'),
                   _TorchMathTestMeta('exp'),

test/test_unary_ufuncs.py (+42 -38)

@@ -19,6 +19,21 @@
 if TEST_NUMPY:
     import numpy as np

+# Tests for unary "universal functions (ufuncs)" that accept a single
+# tensor and have common properties like:
+#   - they are elementwise functions
+#   - the input shape is the output shape
+#   - they typically have method and inplace variants
+#   - they typically support the out kwarg
+#   - they typically have NumPy or SciPy references
+
+# See NumPy's universal function documentation
+# (https://numpy.org/doc/1.18/reference/ufuncs.html) for more details
+# about the concept of ufuncs.
+
+# Functions tested here:
+#
+
 # Interesting values and extremal values for different dtypes
 _unsigned_int_vals = (0, 1, 55, 127)
 _int_vals = (0, -1, 1, -55, 55, -127, 127, -128, 128)
@@ -117,50 +132,13 @@ def generate_numeric_tensors(device, dtype, *,

     return chain(empty_tensors, scalar_tensors, small_tensors, (medium_tensor,), (large_tensor,))

-# Tests for unary "universal functions (ufuncs)" that accept a single
-# tensor and have common properties like:
-#   - they are elementwise functions
-#   - the input shape is the output shape
-#   - they typically have method and inplace variants
-#   - they typically support the out kwarg
-#   - they typically have NumPy or SciPy references
-
-# See NumPy's universal function documentation
-# (https://numpy.org/doc/1.18/reference/ufuncs.html) for more details
-# about the concept of ufuncs.
-
 # TODO: port test_unary_out_op_mem_overlap
 # TODO: add out= tests (different devices, dtypes, mismatched sizes,
 #   correct sizes, 0 size, broadcasted out)
 # TODO: add test for inplace variants erroring on broadcasted inputs
 class TestUnaryUfuncs(TestCase):
     exact_dtype = True

-    # Helper for comparing torch tensors and numpy arrays
-    # TODO: should this or assertEqual also validate that strides are equal?
-    def assertEqualHelper(self, actual, expected, *, dtype, exact_dtype=True, **kwargs):
-        assert isinstance(actual, torch.Tensor)
-
-        # Some NumPy functions return scalars, not arrays
-        if isinstance(expected, Number):
-            self.assertEqual(actual.item(), expected)
-        elif isinstance(expected, np.ndarray):
-            # Handles exact dtype comparisons between arrays and tensors
-            if exact_dtype:
-                # Allows array dtype to be float32 when comparing with bfloat16 tensors
-                #   since NumPy doesn't support the bfloat16 dtype
-                if expected.dtype == np.float32:
-                    assert actual.dtype in (torch.bfloat16, torch.float32)
-                else:
-                    assert expected.dtype == torch_to_numpy_dtype_dict[actual.dtype]
-
-            self.assertEqual(actual,
-                             torch.from_numpy(expected).to(actual.dtype),
-                             exact_device=False,
-                             **kwargs)
-        else:
-            self.assertEqual(actual, expected, exact_device=False, **kwargs)
-
     # Tests bool tensor negation raises the correct error
     def test_neg_error_message(self, device):
         msg = ("Negation, the `\\-` operator, on a bool tensor is not supported."
@@ -234,6 +212,32 @@ def _fn(t):
             actual = alt(t.clone())
             self.assertEqual(actual, expected, rtol=0, atol=0)

+    # Helper for comparing torch tensors and numpy arrays
+    # TODO: should this or assertEqual also validate that strides are equal?
+    def assertEqualHelper(self, actual, expected, msg, *, dtype, exact_dtype=True, **kwargs):
+        assert isinstance(actual, torch.Tensor)
+
+        # Some NumPy functions return scalars, not arrays
+        if isinstance(expected, Number):
+            self.assertEqual(actual.item(), expected)
+        elif isinstance(expected, np.ndarray):
+            # Handles exact dtype comparisons between arrays and tensors
+            if exact_dtype:
+                # Allows array dtype to be float32 when comparing with bfloat16 tensors
+                #   since NumPy doesn't support the bfloat16 dtype
+                if expected.dtype == np.float32:
+                    assert actual.dtype in (torch.bfloat16, torch.float32)
+                else:
+                    assert expected.dtype == torch_to_numpy_dtype_dict[actual.dtype]
+
+            self.assertEqual(actual,
+                             torch.from_numpy(expected).to(actual.dtype),
+                             msg,
+                             exact_device=False,
+                             **kwargs)
+        else:
+            self.assertEqual(actual, expected, msg, exact_device=False, **kwargs)
+
     # Tests that the function and its (array-accepting) reference produce the same
     # values on a range of tensors, including empty tensors, scalar tensors,
     # 1D tensors and a large 2D tensor with interesting and extremal values
@@ -266,7 +270,7 @@ def test_reference_numerics(self, device, dtype, op):
         else:
             msg = None

-        self.assertEqualHelper(actual, expected, dtype=dtype, msg=msg)
+        self.assertEqualHelper(actual, expected, msg, dtype=dtype)

     # Tests for testing (dis)contiguity consistency

torch/testing/_internal/common_device_type.py (+20 -8)

@@ -229,33 +229,45 @@ def instantiate_test(cls, name, test, *, generic_cls=None):

         def instantiate_test_helper(cls, name, *, test, dtype, op):

-            # wraps test with op decorators
+            # Constructs the test's name
+            test_name = _construct_test_name(name, op, cls.device_type, dtype)
+
+            # wraps instantiated test with op decorators
+            # NOTE: test_wrapper exists because we don't want to apply
+            #   op-specific decorators to the original test.
+            #   Test-sepcific decorators are applied to the original test,
+            #   however.
             if op is not None and op.decorators is not None:
+                @wraps(test)
+                def test_wrapper(*args, **kwargs):
+                    return test(*args, **kwargs)
+
                 for decorator in op.decorators:
-                    test = decorator(test)
+                    test_wrapper = decorator(test_wrapper)

-            # Constructs the test's name
-            test_name = _construct_test_name(name, op, cls.device_type, dtype)
+                test_fn = test_wrapper
+            else:
+                test_fn = test

             # Constructs the test
             @wraps(test)
-            def instantiated_test(self, name=name, test=test, dtype=dtype, op=op):
+            def instantiated_test(self, name=name, test=test_fn, dtype=dtype, op=op):
                 if op is not None and op.should_skip(generic_cls.__name__, name,
                                                      self.device_type, dtype):
                     self.skipTest("Skipped!")

                 device_arg = cls.get_primary_device()
-                if hasattr(test, 'num_required_devices'):
+                if hasattr(test_fn, 'num_required_devices'):
                     device_arg = cls.get_all_devices()

                 # Sets precision and runs test
                 # Note: precision is reset after the test is run
                 guard_precision = self.precision
                 try:
-                    self.precision = self._get_precision_override(test, dtype)
+                    self.precision = self._get_precision_override(test_fn, dtype)
                     args = (device_arg, dtype, op)
                     args = (arg for arg in args if arg is not None)
-                    result = test(self, *args)
+                    result = test_fn(self, *args)
                 finally:
                     self.precision = guard_precision
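
The snippet below is a self-contained illustration of the failure mode that the test_wrapper above fixes; it is not PyTorch code. The names precision_override and shared_test are made up, loosely modeled on attribute-setting decorators like precisionOverride. Applying per-op decorators directly to the shared test mutates one function object for every op, so the last op's settings silently override the earlier ones, whereas decorating a fresh per-op wrapper keeps each variant independent.

# Standalone sketch of the decorator bug; names here are illustrative only.
from functools import wraps

def precision_override(value):
    # Attribute-setting decorator, similar in spirit to precisionOverride.
    def decorator(fn):
        fn.precision = value
        return fn
    return decorator

def shared_test():
    pass

ops = {'op_a': 1e-2, 'op_b': 1e-4}

# Buggy pattern: both ops decorate the same shared function object.
buggy = {}
for name, prec in ops.items():
    buggy[name] = precision_override(prec)(shared_test)
print(buggy['op_a'].precision, buggy['op_b'].precision)  # 0.0001 0.0001 -- op_a lost its override

# Fixed pattern: decorate a fresh per-op wrapper; shared_test stays untouched.
fixed = {}
for name, prec in ops.items():
    @wraps(shared_test)
    def test_wrapper(*args, **kwargs):
        return shared_test(*args, **kwargs)
    fixed[name] = precision_override(prec)(test_wrapper)
print(fixed['op_a'].precision, fixed['op_b'].precision)  # 0.01 0.0001 -- each op keeps its own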

torch/testing/_internal/common_methods_invocations.py (+124 -7)

@@ -10,7 +10,7 @@
 from torch.testing import \
     (make_non_contiguous,
      _dispatch_dtypes,
-     floating_types, floating_types_and,
+     floating_types, floating_types_and, floating_types_and_half,
      floating_and_complex_types, floating_and_complex_types_and,
      all_types_and_complex_and)
 from torch.testing._internal.common_device_type import \
@@ -62,6 +62,7 @@ def __init__(self,
                  dtypesIfCPU=None,  # dtypes this function is expected to work with on CPU
                  dtypesIfCUDA=None,  # dtypes this function is expected to work with on CUDA
                  dtypesIfROCM=None,  # dtypes this function is expected to work with on ROCM
+                 test_inplace_grad=True,  # whether to gradcheck and gradgradcheck the inplace variant
                  skips=tuple(),  # information about which tests to skip
                  decorators=None):  # decorators to apply to generated tests
         # Validates the dtypes are generated from the dispatch-related functions
@@ -83,6 +84,8 @@ def __init__(self,
         inplace_name = name + "_"
         self.inplace_variant = getattr(torch.Tensor, inplace_name) if hasattr(torch.Tensor, name) else None

+        self.test_inplace_grad = test_inplace_grad
+
         self.skips = skips
         self.decorators = decorators

@@ -197,7 +200,7 @@ def sample_inputs(self, device, dtype, requires_grad=False):


-# Operator database
+# Operator database (sorted alphabetically)
 op_db = [
     # NOTE: CPU complex acos produces incorrect outputs (https://github.com/pytorch/pytorch/issues/42952)
     UnaryUfuncInfo('acos',
@@ -212,13 +215,56 @@ def sample_inputs(self, device, dtype, requires_grad=False):
                                device_type='cpu', dtypes=[torch.cfloat, torch.cdouble]),
                       SkipInfo('TestUnaryUfuncs', 'test_reference_numerics',
                                dtypes=[torch.cfloat, torch.cdouble], active_if=IS_WINDOWS),
+                      SkipInfo('TestUnaryUfuncs', 'test_reference_numerics',
+                               device_type='cuda', dtypes=[torch.float16],
+                               active_if=TEST_WITH_ROCM),
                       SkipInfo('TestGradients', 'test_fn_grad',
                                dtypes=[torch.cdouble], active_if=IS_WINDOWS),
                       SkipInfo('TestGradients', 'test_method_grad',
                                dtypes=[torch.cdouble], active_if=IS_WINDOWS),
                       SkipInfo('TestGradients', 'test_inplace_grad',
                                dtypes=[torch.cdouble], active_if=IS_WINDOWS),
                   )),
+    # NOTE: the derivative for inplace acosh is not implemented
+    UnaryUfuncInfo('acosh',
+                   ref=np.arccosh,
+                   domain=(1, float('inf')),
+                   dtypesIfCPU=floating_types(),
+                   dtypesIfCUDA=floating_types_and_half(),
+                   test_inplace_grad=False),
+    UnaryUfuncInfo('asin',
+                   ref=np.arcsin,
+                   domain=(-1, 1),
+                   decorators=(precisionOverride({torch.bfloat16: 1e-2}),),
+                   skips=(
+                       SkipInfo('TestUnaryUfuncs', 'test_reference_numerics',
+                                device_type='cpu', dtypes=[torch.cfloat, torch.cdouble]),
+                       SkipInfo('TestUnaryUfuncs', 'test_reference_numerics',
+                                device_type='cuda', dtypes=[torch.cfloat, torch.cdouble],
+                                active_if=IS_WINDOWS),
+                   )),
+    # NOTE: derivative for inplace asinh is not implemented
+    UnaryUfuncInfo('asinh',
+                   ref=np.arcsinh,
+                   dtypesIfCPU=floating_types(),
+                   dtypesIfCUDA=floating_types_and_half(),
+                   test_inplace_grad=False),
+    UnaryUfuncInfo('atan',
+                   ref=np.arctan,
+                   decorators=(precisionOverride({torch.bfloat16: 1e-2}),),
+                   skips=(
+                       SkipInfo('TestUnaryUfuncs', 'test_reference_numerics',
+                                device_type='cpu', dtypes=[torch.cfloat, torch.cdouble]),
+                       SkipInfo('TestUnaryUfuncs', 'test_reference_numerics',
+                                device_type='cuda', dtypes=[torch.cfloat, torch.cdouble],
+                                active_if=IS_WINDOWS),
+                   )),
+    UnaryUfuncInfo('atanh',
+                   ref=np.arctanh,
+                   domain=(-1, 1),
+                   dtypesIfCPU=floating_types(),
+                   dtypesIfCUDA=floating_types_and_half(),
+                   test_inplace_grad=False),
     UnaryUfuncInfo('cos',
                    ref=np.cos,
                    dtypesIfCUDA=floating_and_complex_types_and(torch.half, torch.bfloat16),
@@ -241,6 +287,49 @@ def sample_inputs(self, device, dtype, requires_grad=False):
                       SkipInfo('TestUnaryUfuncs', 'test_reference_numerics', device_type='cpu',
                                dtypes=[torch.cfloat, torch.cdouble], active_if=IS_MACOS),
                   )),
+    UnaryUfuncInfo('log',
+                   ref=np.log,
+                   domain=(0, float('inf')),
+                   skips=(
+                       SkipInfo('TestUnaryUfuncs', 'test_reference_numerics',
+                                device_type='cpu', dtypes=[torch.bfloat16]),
+                       SkipInfo('TestUnaryUfuncs', 'test_reference_numerics',
+                                device_type='cuda', dtypes=[torch.cfloat, torch.cdouble]),
+                       SkipInfo('TestUnaryUfuncs', 'test_reference_numerics',
+                                device_type='cpu', dtypes=[torch.cfloat, torch.cdouble],
+                                active_if=IS_WINDOWS),
+                   )),
+    UnaryUfuncInfo('log10',
+                   ref=np.log10,
+                   domain=(0, float('inf')),
+                   decorators=(precisionOverride({torch.bfloat16: 1e-2}),),
+                   skips=(
+                       SkipInfo('TestUnaryUfuncs', 'test_reference_numerics',
+                                device_type='cuda', dtypes=[torch.cfloat, torch.cdouble]),
+                       SkipInfo('TestUnaryUfuncs', 'test_reference_numerics',
+                                device_type='cpu', dtypes=[torch.cfloat, torch.cdouble],
+                                active_if=IS_WINDOWS),
+                   )),
+    UnaryUfuncInfo('log1p',
+                   ref=np.log1p,
+                   domain=(-1, float('inf')),
+                   dtypesIfCPU=floating_types_and(torch.bfloat16),
+                   dtypesIfCUDA=floating_types_and_half(),
+                   decorators=(precisionOverride({torch.bfloat16: 1e-1}),)),
+    UnaryUfuncInfo('log2',
+                   ref=np.log2,
+                   domain=(0, float('inf')),
+                   skips=(
+                       SkipInfo('TestUnaryUfuncs', 'test_reference_numerics',
+                                device_type='cpu', dtypes=[torch.bfloat16]),
+                       SkipInfo('TestUnaryUfuncs', 'test_reference_numerics',
+                                dtypes=[torch.cfloat, torch.cdouble]),
+                   )),
+    UnaryUfuncInfo('neg',
+                   ref=np.negative,
+                   dtypes=all_types_and_complex_and(torch.half, torch.bfloat16),
+                   dtypesIfCPU=all_types_and_complex_and(torch.half, torch.bfloat16),
+                   dtypesIfCUDA=all_types_and_complex_and(torch.half)),
     UnaryUfuncInfo('sin',
                    ref=np.sin,
                    handles_large_floats=False,
@@ -252,11 +341,39 @@ def sample_inputs(self, device, dtype, requires_grad=False):
                       SkipInfo('TestUnaryUfuncs', 'test_reference_numerics',
                                dtypes=[torch.float], active_if=TEST_WITH_ROCM),
                   )),
-    UnaryUfuncInfo('neg',
-                   ref=np.negative,
-                   dtypes=all_types_and_complex_and(torch.half, torch.bfloat16),
-                   dtypesIfCPU=all_types_and_complex_and(torch.half, torch.bfloat16),
-                   dtypesIfCUDA=all_types_and_complex_and(torch.half)),
+    UnaryUfuncInfo('sinh',
+                   ref=np.sinh,
+                   dtypesIfCPU=floating_and_complex_types(),
+                   decorators=(precisionOverride({torch.float16: 1e-2}),),
+                   skips=(
+                       SkipInfo('TestUnaryUfuncs', 'test_reference_numerics',
+                                device_type='cpu', dtypes=[torch.cfloat, torch.cdouble],
+                                active_if=(IS_MACOS or IS_WINDOWS)),
+                       SkipInfo('TestUnaryUfuncs', 'test_reference_numerics',
+                                device_type='cuda', dtypes=[torch.cfloat, torch.cdouble],
+                                active_if=IS_WINDOWS),
+                   )),
+    UnaryUfuncInfo('tan',
+                   ref=np.tan,
+                   skips=(
+                       SkipInfo('TestUnaryUfuncs', 'test_reference_numerics',
+                                device_type='cuda', dtypes=[torch.cfloat, torch.cdouble]),
+                       SkipInfo('TestUnaryUfuncs', 'test_reference_numerics',
+                                device_type='cpu', dtypes=[torch.bfloat16]),
+                       SkipInfo('TestUnaryUfuncs', 'test_reference_numerics',
+                                device_type='cpu', dtypes=[torch.cfloat, torch.cdouble],
+                                active_if=(IS_MACOS or IS_WINDOWS)),
+                   )),
+    UnaryUfuncInfo('tanh',
+                   ref=np.tanh,
+                   decorators=(precisionOverride({torch.bfloat16: 1e-2}),),
+                   skips=(
+                       SkipInfo('TestUnaryUfuncs', 'test_reference_numerics',
+                                device_type='cuda', dtypes=[torch.cfloat, torch.cdouble]),
+                       SkipInfo('TestUnaryUfuncs', 'test_reference_numerics',
+                                device_type='cpu', dtypes=[torch.cfloat, torch.cdouble],
+                                active_if=(IS_MACOS or IS_WINDOWS)),
+                   )),
 ]

 # Common operator groupings
