41
41
from typing import Dict, List, Tuple, Union
42
42
import torch.backends.quantized
43
43
import torch.testing._internal.data
44
- from torch.testing._internal.common_cuda import tf32_on_and_off, tf32_is_not_fp32, \
44
+ from torch.testing._internal.common_cuda import tf32_on_and_off, tf32_is_not_fp32, with_tf32_off, \
45
45
_get_torch_cuda_version, TEST_MAGMA
46
46
47
47
48
-
49
48
# load_tests from torch.testing._internal.common_utils is used to automatically filter tests for
50
49
# sharding on sandcastle. The assignment below silences flake warnings.
51
50
load_tests = load_tests
@@ -7003,6 +7002,9 @@ def test_matrix_exp_boundary_cases(self, device, dtype):
7003
7002
@skipCUDAIfNoMagma
7004
7003
@skipCPUIfNoLapack
7005
7004
@dtypes(torch.float, torch.double)
7005
+ # matrix_exp itself always runs with TF32 disabled, but this test also calls
7006
+ # matmul, for which TF32 is enabled by default.
7007
+ @with_tf32_off
7006
7008
def test_matrix_exp_analytic(self, device, dtype):
7007
7009
# check zero matrix
7008
7010
x = torch.zeros(20, 20, dtype=dtype, device=device)
@@ -7144,6 +7146,9 @@ def run_test(*n):
7144
7146
@skipCUDAIfNoMagma
7145
7147
@skipCPUIfNoLapack
7146
7148
@dtypes(torch.float, torch.double)
7149
+ # matrix_exp itself always runs with TF32 disabled, but this test also calls
7150
+ # matmul, for which TF32 is enabled by default.
7151
+ @with_tf32_off
7147
7152
def test_matrix_exp_compare_with_taylor(self, device, dtype):
7148
7153
7149
7154
def normalize_to_1_operator_norm(sample, desired_norm):
@@ -16471,6 +16476,7 @@ def _test(row_major, incx, incy, lda_tail):
16471
16476
@dtypesIfCUDA(*torch.testing.get_all_complex_dtypes(), *torch.testing.get_all_fp_dtypes(include_bfloat16=AMPERE_OR_ROCM))
16472
16477
@dtypes(*torch.testing.get_all_complex_dtypes(), *torch.testing.get_all_fp_dtypes())
16473
16478
@unittest.skipIf(not TEST_NUMPY, "Numpy not found")
16479
+ @tf32_on_and_off(0.05)
16474
16480
def test_addmm(self, device, dtype):
16475
16481
M = torch.randn(10, 25, device=device).to(dtype)
16476
16482
m1 = torch.randn(10, 50, device=device).to(dtype)
@@ -19832,13 +19838,13 @@ def inner(self, device, dtype):
19832
19838
1e-1, 1e-1, 1e-5, torch.testing.get_all_fp_dtypes()),
19833
19839
('addbmm', '', _small_2d, lambda t, d: [_small_3d(t, d), _small_3d(t, d)],
19834
19840
1e-1, 1e-1, 1e-4, torch.testing.get_all_fp_dtypes(include_bfloat16=AMPERE_OR_ROCM) + _complex_types,
19835
- _cpu_types, True, [tf32_on_and_off(0.005 )]),
19841
+ _cpu_types, True, [tf32_on_and_off(0.01 )]),
19836
19842
('addbmm', 'scalar', _small_2d, lambda t, d: [_number(0.4, 2, t), _small_3d(t, d), _small_3d(t, d)],
19837
19843
1e-1, 1e-1, 1e-4, torch.testing.get_all_fp_dtypes(include_bfloat16=AMPERE_OR_ROCM) + _complex_types, _cpu_types, True,
19838
- [tf32_on_and_off(0.005 ), _wrap_maybe_warns("This overload of addbmm_? is deprecated")]),
19844
+ [tf32_on_and_off(0.01 ), _wrap_maybe_warns("This overload of addbmm_? is deprecated")]),
19839
19845
('addbmm', 'two_scalars', _small_2d, lambda t, d: [_number(0.5, 3, t), _number(0.4, 2, t), _small_3d(t, d), _small_3d(t, d)],
19840
19846
1e-1, 1e-1, 1e-4, torch.testing.get_all_fp_dtypes(include_bfloat16=AMPERE_OR_ROCM) + _complex_types, _cpu_types, True,
19841
- [tf32_on_and_off(0.005 ), _wrap_maybe_warns("This overload of addbmm_? is deprecated")]),
19847
+ [tf32_on_and_off(0.01 ), _wrap_maybe_warns("This overload of addbmm_? is deprecated")]),
19842
19848
('baddbmm', '', _small_3d, lambda t, d: [_small_3d(t, d), _small_3d(t, d)],
19843
19849
1e-2, 1e-1, 1e-4, torch.testing.get_all_fp_dtypes(include_bfloat16=AMPERE_OR_ROCM)),
19844
19850
('baddbmm', 'scalar', _small_3d, lambda t, d: [_number(0.4, 2, t), _small_3d(t, d), _small_3d(t, d)],
@@ -19865,26 +19871,26 @@ def inner(self, device, dtype):
19865
19871
[_wrap_maybe_warns("This overload of addcmul_? is deprecated")]),
19866
19872
('addmm', '', _medium_2d, lambda t, d: [_medium_2d(t, d), _medium_2d(t, d)], 1e-1, 1e-1, 1e-4,
19867
19873
torch.testing.get_all_fp_dtypes(include_bfloat16=AMPERE_OR_ROCM),
19868
- _cpu_types, True, [tf32_on_and_off(0.005 )], 0, True),
19874
+ _cpu_types, True, [tf32_on_and_off(0.01 )], 0, True),
19869
19875
('addmm', 'scalar', _medium_2d,
19870
19876
lambda t, d: [_number(0.4, 2, t), _medium_2d(t, d), _medium_2d(t, d)], 1e-1, 1e-1, 1e-4,
19871
19877
torch.testing.get_all_fp_dtypes(include_bfloat16=AMPERE_OR_ROCM), _cpu_types, True,
19872
- [tf32_on_and_off(0.005 ), _wrap_maybe_warns("This overload of addmm_? is deprecated")]),
19878
+ [tf32_on_and_off(0.01 ), _wrap_maybe_warns("This overload of addmm_? is deprecated")]),
19873
19879
('addmm', 'two_scalars', _medium_2d,
19874
19880
lambda t, d: [_number(0.5, 3, t), _number(0.4, 2, t), _medium_2d(t, d), _medium_2d(t, d)], 1e-1, 1e-1, 1e-4,
19875
19881
torch.testing.get_all_fp_dtypes(include_bfloat16=AMPERE_OR_ROCM), _cpu_types, True,
19876
- [tf32_on_and_off(0.005 ), _wrap_maybe_warns("This overload of addmm_? is deprecated")]),
19882
+ [tf32_on_and_off(0.01 ), _wrap_maybe_warns("This overload of addmm_? is deprecated")]),
19877
19883
('addmv', '', _medium_1d, lambda t, d: [_medium_2d(t, d), _medium_1d(t, d)], 1e-2, 1e-1, 1e-4,
19878
19884
torch.testing.get_all_fp_dtypes(include_bfloat16=AMPERE_OR_ROCM) + _complex_types_skip_rocm, _cpu_types,
19879
- True, [tf32_on_and_off(0.005) ], 0, True),
19885
+ True, [], 0, True),
19880
19886
('addmv', 'scalar', _medium_1d,
19881
19887
lambda t, d: [_number(0.4, 2, t), _medium_2d(t, d), _medium_1d(t, d)], 1e-2, 1e-1, 1e-4,
19882
19888
torch.testing.get_all_fp_dtypes(include_bfloat16=AMPERE_OR_ROCM) + _complex_types_skip_rocm, _cpu_types, True,
19883
- [tf32_on_and_off(0.005), _wrap_maybe_warns("This overload of addmv_? is deprecated")]),
19889
+ [_wrap_maybe_warns("This overload of addmv_? is deprecated")]),
19884
19890
('addmv', 'two_scalars', _medium_1d,
19885
19891
lambda t, d: [_number(0.5, 3, t), _number(0.4, 2, t), _medium_2d(t, d), _medium_1d(t, d)], 1e-2, 1e-1, 1e-4,
19886
19892
torch.testing.get_all_fp_dtypes(include_bfloat16=AMPERE_OR_ROCM) + _complex_types_skip_rocm, _cpu_types, True,
19887
- [tf32_on_and_off(0.005), _wrap_maybe_warns("This overload of addmv_? is deprecated")]),
19893
+ [_wrap_maybe_warns("This overload of addmv_? is deprecated")]),
19888
19894
('addr', '', _medium_2d, lambda t, d: [_medium_1d(t, d), _medium_1d(t, d)],
19889
19895
1e-2, 1e-1, 1e-4, _float_types2),
19890
19896
('addr', 'scalar', _medium_2d,
0 commit comments