
Commit 95df865

Mike Ruberry authored and facebook-github-bot committed on Sep 25, 2020
Enables test linalg (pytorch#45278)
Summary: Fixes pytorch#45271.

Pull Request resolved: pytorch#45278
Reviewed By: ngimel
Differential Revision: D23926124
Pulled By: mruberry
fbshipit-source-id: 26692597f9a1988e5fa846f97b8430c3689cac27
1 parent bdf329e commit 95df865


2 files changed (+9, -4 lines)


test/run_test.py

+1 line

@@ -41,6 +41,7 @@
     'test_foreach',
     'test_indexing',
     'test_jit',
+    'test_linalg',
     'test_logging',
     'test_mkldnn',
     'test_multiprocessing',
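Adding 'test_linalg' to run_test.py's TESTS list is what makes the suite part of the default test run. As a quick local sanity check, the sketch below drives the same selection path; it is illustrative rather than part of the commit, and assumes run_test.py's --include selector and an importable torch build (the module can also be executed on its own, since it wires up run_tests() as its entry point).

    # Illustrative local invocation of the newly enabled suite (not part of the commit).
    # Assumes it is run from the repository root and that run_test.py accepts an
    # --include selector naming entries from its TESTS list.
    import subprocess
    import sys

    subprocess.run(
        [sys.executable, "test/run_test.py", "--include", "test_linalg"],
        check=True,  # raise if the suite exits non-zero
    )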

test/test_linalg.py

+8, -4 lines
@@ -5,7 +5,7 @@
 from math import inf, nan, isnan
 
 from torch.testing._internal.common_utils import \
-    (TestCase, run_tests, TEST_NUMPY)
+    (TestCase, run_tests, TEST_NUMPY, IS_MACOS, IS_WINDOWS, TEST_WITH_ASAN)
 from torch.testing._internal.common_device_type import \
     (instantiate_device_type_tests, dtypes, skipCUDAIfNoMagma, skipCPUIfNoLapack)
 from torch.testing._internal.jit_metaprogramming_utils import gen_script_fn_and_args
@@ -56,11 +56,12 @@ def test_det(self, device, dtype):
 
         # NOTE: det requires a 2D+ tensor
         t = torch.randn(1, device=device, dtype=dtype)
-        with self.assertRaises(IndexError):
+        with self.assertRaises(RuntimeError):
             op(t)
 
     # This test confirms that torch.linalg.norm's dtype argument works
     # as expected, according to the function's documentation
+    @skipCUDAIfNoMagma
     def test_norm_dtype(self, device):
         def run_test_case(input_size, ord, keepdim, from_dtype, to_dtype, compare_dtype):
             msg = (
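The one behavioral change in this hunk tightens the expected exception for sub-2D inputs: the determinant ops under test must now reject a 1-D tensor with a RuntimeError instead of an IndexError. A minimal sketch of what the updated assertion encodes, assuming torch.linalg.det is among the ops that test_det loops over (the hunk only shows the generic op(t) call):

    # Minimal sketch of the expectation behind assertRaises(RuntimeError); assumes
    # torch.linalg.det is one of the ops exercised by test_det.
    import torch

    t = torch.randn(1)  # det requires a 2D+ tensor, so a 1-D input is invalid
    try:
        torch.linalg.det(t)
    except RuntimeError:
        print("RuntimeError raised, matching the updated test expectation")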
@@ -154,6 +155,7 @@ def run_test_case(input, p, dim, keepdim):
 
     # This test compares torch.linalg.norm and numpy.linalg.norm to ensure that
     # their matrix norm results match
+    @skipCUDAIfNoMagma
     @unittest.skipIf(not TEST_NUMPY, "NumPy not found")
     @dtypes(torch.float, torch.double)
     def test_norm_matrix(self, device, dtype):
@@ -400,6 +402,8 @@ def gen_error_message(input_size, ord, keepdim, dim=None):
 
     # Test that linal.norm gives the same result as numpy when inputs
     # contain extreme values (inf, -inf, nan)
+    @unittest.skipIf(IS_WINDOWS, "Skipped on Windows!")
+    @unittest.skipIf(IS_MACOS, "Skipped on MacOS!")
     @skipCUDAIfNoMagma
     @skipCPUIfNoLapack
     @unittest.skipIf(not TEST_NUMPY, "Numpy not found")
@@ -440,14 +444,14 @@ def is_broken_matrix_norm_case(ord, x):
                 result_n = np.linalg.norm(x_n, ord=ord)
 
                 if is_broken_matrix_norm_case(ord, x):
-                    self.assertNotEqual(result, result_n, msg=msg)
+                    continue
                 else:
                     self.assertEqual(result, result_n, msg=msg)
 
     # Test degenerate shape results match numpy for linalg.norm vector norms
     @skipCUDAIfNoMagma
     @skipCPUIfNoLapack
-    @unittest.skipIf(not TEST_NUMPY, "Numpy not found")
+    @unittest.skipIf(TEST_WITH_ASAN, "Skipped on ASAN since it checks for undefined behavior.")
     @dtypes(torch.float, torch.double, torch.cfloat, torch.cdouble)
     def test_norm_vector_degenerate_shapes(self, device, dtype):
         def run_test_case(input, ord, dim, keepdim, should_error):
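All of the new decorators rely on the same gating pattern: IS_WINDOWS, IS_MACOS, and TEST_WITH_ASAN are plain booleans imported from common_utils, so unittest.skipIf evaluates them once when the module is imported and marks the affected tests as skipped rather than running them. A standalone sketch of that pattern follows; the flag definitions are simplified stand-ins, on the assumption that the real ones are derived from sys.platform and an ASAN-related environment variable:

    # Simplified stand-ins for the torch.testing._internal.common_utils flags
    # (assumption: the real flags come from sys.platform and a PYTORCH_TEST_WITH_ASAN
    # environment variable; only the skipIf gating pattern is the point here).
    import os
    import sys
    import unittest

    IS_WINDOWS = sys.platform == "win32"
    IS_MACOS = sys.platform == "darwin"
    TEST_WITH_ASAN = os.getenv("PYTORCH_TEST_WITH_ASAN", "0") == "1"

    class ExampleGatingTest(unittest.TestCase):
        @unittest.skipIf(IS_WINDOWS, "Skipped on Windows!")
        @unittest.skipIf(IS_MACOS, "Skipped on MacOS!")
        @unittest.skipIf(TEST_WITH_ASAN, "Skipped on ASAN since it checks for undefined behavior.")
        def test_extreme_values(self):
            self.assertTrue(True)  # placeholder body standing in for the real check

    if __name__ == "__main__":
        unittest.main()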

0 commit comments
