|
5 | 5 | from math import inf, nan, isnan
|
6 | 6 |
|
7 | 7 | from torch.testing._internal.common_utils import \
|
8 |
| - (TestCase, run_tests, TEST_NUMPY) |
| 8 | + (TestCase, run_tests, TEST_NUMPY, IS_MACOS, IS_WINDOWS, TEST_WITH_ASAN) |
9 | 9 | from torch.testing._internal.common_device_type import \
|
10 | 10 | (instantiate_device_type_tests, dtypes, skipCUDAIfNoMagma, skipCPUIfNoLapack)
|
11 | 11 | from torch.testing._internal.jit_metaprogramming_utils import gen_script_fn_and_args
|
@@ -56,11 +56,12 @@ def test_det(self, device, dtype):
|
56 | 56 |
|
57 | 57 | # NOTE: det requires a 2D+ tensor
|
58 | 58 | t = torch.randn(1, device=device, dtype=dtype)
|
59 |
| - with self.assertRaises(IndexError): |
| 59 | + with self.assertRaises(RuntimeError): |
60 | 60 | op(t)
|
61 | 61 |
|
62 | 62 | # This test confirms that torch.linalg.norm's dtype argument works
|
63 | 63 | # as expected, according to the function's documentation
|
| 64 | + @skipCUDAIfNoMagma |
64 | 65 | def test_norm_dtype(self, device):
|
65 | 66 | def run_test_case(input_size, ord, keepdim, from_dtype, to_dtype, compare_dtype):
|
66 | 67 | msg = (
|
@@ -154,6 +155,7 @@ def run_test_case(input, p, dim, keepdim):
|
154 | 155 |
|
155 | 156 | # This test compares torch.linalg.norm and numpy.linalg.norm to ensure that
|
156 | 157 | # their matrix norm results match
|
| 158 | + @skipCUDAIfNoMagma |
157 | 159 | @unittest.skipIf(not TEST_NUMPY, "NumPy not found")
|
158 | 160 | @dtypes(torch.float, torch.double)
|
159 | 161 | def test_norm_matrix(self, device, dtype):
|
@@ -400,6 +402,8 @@ def gen_error_message(input_size, ord, keepdim, dim=None):
|
400 | 402 |
|
401 | 403 | # Test that linalg.norm gives the same result as numpy when inputs
|
402 | 404 | # contain extreme values (inf, -inf, nan)
|
| 405 | + @unittest.skipIf(IS_WINDOWS, "Skipped on Windows!") |
| 406 | + @unittest.skipIf(IS_MACOS, "Skipped on MacOS!") |
403 | 407 | @skipCUDAIfNoMagma
|
404 | 408 | @skipCPUIfNoLapack
|
405 | 409 | @unittest.skipIf(not TEST_NUMPY, "Numpy not found")
|
@@ -440,14 +444,14 @@ def is_broken_matrix_norm_case(ord, x):
|
440 | 444 | result_n = np.linalg.norm(x_n, ord=ord)
|
441 | 445 |
|
442 | 446 | if is_broken_matrix_norm_case(ord, x):
|
443 |
| - self.assertNotEqual(result, result_n, msg=msg) |
| 447 | + continue |
444 | 448 | else:
|
445 | 449 | self.assertEqual(result, result_n, msg=msg)
|
446 | 450 |
|
447 | 451 | # Test degenerate shape results match numpy for linalg.norm vector norms
|
448 | 452 | @skipCUDAIfNoMagma
|
449 | 453 | @skipCPUIfNoLapack
|
450 |
| - @unittest.skipIf(not TEST_NUMPY, "Numpy not found") |
| 454 | + @unittest.skipIf(TEST_WITH_ASAN, "Skipped on ASAN since it checks for undefined behavior.") |
451 | 455 | @dtypes(torch.float, torch.double, torch.cfloat, torch.cdouble)
|
452 | 456 | def test_norm_vector_degenerate_shapes(self, device, dtype):
|
453 | 457 | def run_test_case(input, ord, dim, keepdim, should_error):
|
|
0 commit comments