Commit 3f5eee6

zasdfgbnm authored and facebook-github-bot committed on Sep 24, 2020
Adjust TF32 tests (pytorch#44240)
Summary:
- The thresholds of some tests are bumped up. Depending on the random generator, these tests sometimes fail with errors like "0.0059 is not smaller than 0.005". I ran `test_nn.py` and `test_torch.py` 10+ times to check that these are no longer flaky.
- Add `tf32_on_and_off` to the new `matrix_exp` tests.
- Disable TF32 on test suites other than `test_nn.py` and `test_torch.py`.

cc: ptrblck

Pull Request resolved: pytorch#44240
Reviewed By: mruberry
Differential Revision: D23882498
Pulled By: ngimel
fbshipit-source-id: 44a9ec08802c93a2efaf4e01d7487222478b6df8
1 parent b8eab8c commit 3f5eee6
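
For context on the threshold bumps: TF32 matmuls on Ampere GPUs round inputs to a 10-bit mantissa, so FP32 results can drift by more than the old 0.005 tolerance depending on the random draw. A minimal sketch of the effect, assuming a CUDA device with TF32 support (exact numbers vary with seed and device):

import torch

a = torch.randn(256, 256, device="cuda")
b = torch.randn(256, 256, device="cuda")

torch.backends.cuda.matmul.allow_tf32 = False
ref = a @ b       # full FP32 accumulation

torch.backends.cuda.matmul.allow_tf32 = True
approx = a @ b    # TF32 path: inputs rounded to a 10-bit mantissa

# Relative error is typically on the order of 1e-3, which is why a fixed
# 0.005 threshold is occasionally exceeded by some random draws.
print(((approx - ref).abs().max() / ref.abs().max()).item())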

File tree

11 files changed, +124 -28 lines:

    aten/src/ATen/Context.cpp
    aten/src/ATen/Context.h
    aten/src/ATen/cuda/CUDABlas.cpp
    aten/src/ATen/cuda/CublasHandlePool.cpp
    aten/src/ATen/native/LinearAlgebra.cpp
    test/jit/test_tracer.py
    test/test_jit_fuser.py
    test/test_nn.py
    test/test_torch.py
    torch/testing/_internal/common_cuda.py
    torch/testing/_internal/common_nn.py
 

‎aten/src/ATen/Context.cpp

+23 lines changed

@@ -230,4 +230,27 @@ Allocator* getCPUAllocator() {
   return getTHDefaultAllocator();
 }
 
+// override_allow_tf32_flag = true
+//    means the allow_tf32 flags are overridden and tf32 is force disabled
+// override_allow_tf32_flag = false
+//    means the original allow_tf32 flags are followed
+thread_local bool override_allow_tf32_flag = false;
+
+NoTF32Guard::NoTF32Guard() {
+  if (!override_allow_tf32_flag) {
+    changed = true;
+    override_allow_tf32_flag = true;
+  }
+}
+
+NoTF32Guard::~NoTF32Guard() {
+  if (changed) {
+    override_allow_tf32_flag = false;
+  }
+}
+
+bool NoTF32Guard::should_disable_tf32() {
+  return override_allow_tf32_flag;
+}
+
 } // namespace at
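
The guard added above is a thread-local, nesting-safe override rather than a plain toggle: only the outermost guard flips the flag, and only that guard restores it. A rough Python analogue of the same pattern (illustrative only; the real guard is internal C++ and not exposed as a Python API):

import contextlib
import threading

_state = threading.local()

@contextlib.contextmanager
def no_tf32_guard():
    # Mirror NoTF32Guard's `changed` member: only the outermost guard
    # flips the thread-local flag, so nested guards restore correctly.
    outermost = not getattr(_state, "override_allow_tf32_flag", False)
    if outermost:
        _state.override_allow_tf32_flag = True
    try:
        yield
    finally:
        if outermost:
            _state.override_allow_tf32_flag = False

def should_disable_tf32():
    return getattr(_state, "override_allow_tf32_flag", False)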

‎aten/src/ATen/Context.h

+16 lines changed

@@ -327,4 +327,20 @@ static inline void manual_seed(uint64_t seed) {
   }
 }
 
+// When the global flag `allow_tf32` is set to true, cuBLAS handles are
+// automatically configured to use math mode CUBLAS_TF32_TENSOR_OP_MATH.
+// For some operators, such as addmv, TF32 offers no performance improvement
+// but causes precision loss. To help this case, this class implements
+// a RAII guard that can be used to quickly disable TF32 within its scope.
+//
+// Usage:
+//     NoTF32Guard disable_tf32;
+struct TORCH_API NoTF32Guard {
+  NoTF32Guard();
+  ~NoTF32Guard();
+  static bool should_disable_tf32();
+ private:
+  bool changed = false;
+};
+
 } // namespace at

‎aten/src/ATen/cuda/CUDABlas.cpp

+19, -13 lines changed

@@ -407,19 +407,22 @@ void gemm<at::BFloat16>(CUDABLAS_GEMM_ARGTYPES(at::BFloat16)) {
 #endif
 
 #if !defined(__HIP_PLATFORM_HCC__) || (defined(__HIP_PLATFORM_HCC__) && HIP_VERSION >= 210)
-template <>
-void gemv<c10::complex<float>>(CUDABLAS_GEMV_ARGTYPES(c10::complex<float>)) {
-  // See Note [Writing Nondeterministic Operations]
-  globalContext().alertCuBLASConfigNotDeterministic();
-  cublasHandle_t handle = at::cuda::getCurrentCUDABlasHandle();
-  cublasOperation_t op = _cublasOpFromChar(trans);
-  _cublasAdjustLdLevel2(m, n, &lda);
-  GEMV_CHECK_ARGVALUES(c10::complex<float>);
-  TORCH_CUDABLAS_CHECK(
-      cublasCgemv(handle, op, m, n, reinterpret_cast<const cuComplex*>(&alpha), reinterpret_cast<const cuComplex*>(a),
-                  lda, reinterpret_cast<const cuComplex*>(x), incx, reinterpret_cast<const cuComplex*>(&beta),
-                  reinterpret_cast<cuComplex*>(y), incy));
-}
+template <>
+void gemv<c10::complex<float>>(CUDABLAS_GEMV_ARGTYPES(c10::complex<float>)) {
+  // gemv is bw bound, and does not benefit from TF32. But the precision
+  // loss still happens on TF32. So we disable it here.
+  NoTF32Guard disable_tf32;
+  // See Note [Writing Nondeterministic Operations]
+  globalContext().alertCuBLASConfigNotDeterministic();
+  cublasHandle_t handle = at::cuda::getCurrentCUDABlasHandle();
+  cublasOperation_t op = _cublasOpFromChar(trans);
+  _cublasAdjustLdLevel2(m, n, &lda);
+  GEMV_CHECK_ARGVALUES(c10::complex<float>);
+  TORCH_CUDABLAS_CHECK(
+      cublasCgemv(handle, op, m, n, reinterpret_cast<const cuComplex*>(&alpha), reinterpret_cast<const cuComplex*>(a),
+                  lda, reinterpret_cast<const cuComplex*>(x), incx, reinterpret_cast<const cuComplex*>(&beta),
+                  reinterpret_cast<cuComplex*>(y), incy));
+}
 #endif
 
 template <>

@@ -436,6 +439,9 @@ void gemv<double>(CUDABLAS_GEMV_ARGTYPES(double)) {
 
 template <>
 void gemv<float>(CUDABLAS_GEMV_ARGTYPES(float)) {
+  // gemv is bw bound, and does not benefit from TF32. But the precision
+  // loss still happens on TF32. So we disable it here.
+  NoTF32Guard disable_tf32;
   // See Note [Writing Nondeterministic Operations]
   globalContext().alertCuBLASConfigNotDeterministic();
   cublasHandle_t handle = at::cuda::getCurrentCUDABlasHandle();
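
Since gemv is bandwidth bound, wrapping it in NoTF32Guard costs nothing in performance but removes the TF32 rounding from ops that lower to it, such as addmv. A hedged sanity check of that behavior (assumes an Ampere-class GPU with this change applied):

import torch

bias = torch.randn(64, device="cuda")
mat = torch.randn(64, 128, device="cuda")
vec = torch.randn(128, device="cuda")

torch.backends.cuda.matmul.allow_tf32 = True
out_flag_on = torch.addmv(bias, mat, vec)

torch.backends.cuda.matmul.allow_tf32 = False
out_flag_off = torch.addmv(bias, mat, vec)

# With gemv wrapped in NoTF32Guard, both runs should take the FP32 path,
# so the outputs are expected to match exactly.
print(torch.equal(out_flag_on, out_flag_off))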

‎aten/src/ATen/cuda/CublasHandlePool.cpp

+1, -1 lines changed

@@ -45,7 +45,7 @@ cublasHandle_t getCurrentCUDABlasHandle() {
   // On CUDA >= 11, and architecture >= Ampere, cuBLAS can use TF32 to speedup
   // FP32 data type calculations based on the value of the allow_tf32 flag.
   // To enable TF32, set the math mode of the handle to CUBLAS_TF32_TENSOR_OP_MATH.
-  if (at::globalContext().allowTF32CuBLAS()) {
+  if (!NoTF32Guard::should_disable_tf32() && at::globalContext().allowTF32CuBLAS()) {
     TORCH_CUDABLAS_CHECK(cublasSetMathMode(handle, CUBLAS_TF32_TENSOR_OP_MATH));
   } else {
     TORCH_CUDABLAS_CHECK(cublasSetMathMode(handle, CUBLAS_DEFAULT_MATH));

‎aten/src/ATen/native/LinearAlgebra.cpp

+3 lines changed

@@ -1223,6 +1223,8 @@ Tensor matrix_exp(const Tensor& a) {
               "matrix_exp(", a.scalar_type(), "{", a.sizes(), "}): expected a tensor "
               "of squared matrices");
 
+  NoTF32Guard disable_tf32;
+
   if (a.size(-1) == 1) {
     return a.exp();
   }

@@ -1231,6 +1233,7 @@ Tensor matrix_exp(const Tensor& a) {
 }
 
 Tensor matrix_exp_backward(const Tensor& self, const Tensor& grad) {
+  NoTF32Guard disable_tf32;
   return backward_analytic_function_of_a_matrix(
       self, grad,
       [](const Tensor& a) {
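
With these guards, matrix_exp and its backward always run at full FP32 precision regardless of the global allow_tf32 flag. A quick way one could verify that, assuming a CUDA device with TF32 support:

import torch

x = torch.randn(20, 20, device="cuda")

torch.backends.cuda.matmul.allow_tf32 = True
y_on = torch.matrix_exp(x)

torch.backends.cuda.matmul.allow_tf32 = False
y_off = torch.matrix_exp(x)

# The NoTF32Guard inside matrix_exp makes both calls use full FP32,
# so the difference is expected to be exactly zero.
print((y_on - y_off).abs().max().item())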

‎test/jit/test_tracer.py

+4 lines changed

@@ -18,6 +18,7 @@
     IS_SANDCASTLE, IS_WINDOWS
 from torch.testing._internal.jit_utils import JitTestCase, enable_cpu_fuser, \
     _tmp_donotuse_dont_inline_everything, _trace, RUN_CUDA, RUN_CUDA_MULTI_GPU
+from torch.testing._internal.common_cuda import with_tf32_off
 from typing import List, Tuple
 from torch import Tensor
 

@@ -900,6 +901,9 @@ def foo(a):
         self.assertEqual(foo(x), x + x + x)
 
     @unittest.skipIf(not RUN_CUDA, "calls .cuda()")
+    # By default, on Ampere or later GPUs, nn.Linear computes float tensors at TF32 precision.
+    # We want float tensors to be computed at full precision in order to use the default precision
+    @with_tf32_off
     def test_traced_module_cuda(self):
         class Model(nn.Module):
             def __init__(self, num_features, num_layers):

‎test/test_jit_fuser.py

+7 lines changed

@@ -10,6 +10,7 @@
     RUN_CUDA, RUN_CUDA_HALF, RUN_CUDA_MULTI_GPU, warmup_backward
 from textwrap import dedent
 from itertools import product, permutations
+from torch.testing._internal.common_cuda import with_tf32_off
 
 from test_jit import backward_graph, all_backward_graphs, get_lstm_inputs, get_milstm_inputs, \
     LSTMCellC, LSTMCellF, LSTMCellS, MiLSTMCell

@@ -710,6 +711,9 @@ def test_lstm_cuda(self):
                                           "aten::_grad_sum_to_size"))
 
     @unittest.skipIf(not RUN_CUDA, "fuser requires CUDA")
+    # By default, on Ampere or later GPUs, LSTM computes float tensors at TF32 precision.
+    # We want float tensors to be computed at full precision in order to use the default precision
+    @with_tf32_off
     def test_lstm_concat_cuda(self):
         inputs = get_lstm_inputs('cuda')
         ge = self.checkTrace(LSTMCellC, inputs)

@@ -740,6 +744,9 @@ def cell(x, hx, cx, w_ih, w_hh, b_ih, b_hh):
 
     # TODO: Fuser doesn't work at all when inputs require grad. Fix that
     @unittest.skipIf(not RUN_CUDA, "fuser requires CUDA")
+    # By default, on Ampere or later GPUs, LSTM computes float tensors at TF32 precision.
+    # We want float tensors to be computed at full precision in order to use the default precision
+    @with_tf32_off
     def test_lstm_traced_cuda(self):
         inputs = get_lstm_inputs('cuda')
         ge = self.checkTrace(LSTMCellF, inputs)

‎test/test_nn.py

+1 line changed

@@ -12018,6 +12018,7 @@ def _test_conv_cudnn_nhwc_nchw(self, layer, n, c, h, w, k, filter_size, device):
     @onlyCUDA
     @skipCUDAIfRocm
     @skipCUDAIfCudnnVersionLessThan(7603)
+    @tf32_on_and_off(0.05)
     def test_conv_cudnn_mismatch_memory_format(self, device):
         configs = [
             [4, 2, 8, 8, 4, 2],
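
tf32_on_and_off is the existing helper from torch.testing._internal.common_cuda: conceptually, it runs the decorated test once with TF32 disabled at the default tolerance and once with TF32 enabled at the looser tolerance passed to the decorator (0.05 here). The sketch below only captures that idea and is not the actual implementation, which also gates on CUDA version and device capability:

import functools
import torch

def tf32_on_and_off_sketch(tf32_precision=1e-5):
    # Conceptual sketch: run the test with TF32 off (default tolerance),
    # then with TF32 on (looser tf32_precision), restoring state afterwards.
    def wrapper(f):
        @functools.wraps(f)
        def wrapped(self, device, dtype):
            old_allow = torch.backends.cuda.matmul.allow_tf32
            old_precision = self.precision
            try:
                torch.backends.cuda.matmul.allow_tf32 = False
                f(self, device, dtype)
                if torch.cuda.is_available() and torch.cuda.get_device_capability()[0] >= 8:
                    torch.backends.cuda.matmul.allow_tf32 = True
                    self.precision = tf32_precision
                    f(self, device, dtype)
            finally:
                torch.backends.cuda.matmul.allow_tf32 = old_allow
                self.precision = old_precision
        return wrapped
    return wrapper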

‎test/test_torch.py

+17, -11 lines changed

@@ -41,11 +41,10 @@
 from typing import Dict, List, Tuple, Union
 import torch.backends.quantized
 import torch.testing._internal.data
-from torch.testing._internal.common_cuda import tf32_on_and_off, tf32_is_not_fp32, \
+from torch.testing._internal.common_cuda import tf32_on_and_off, tf32_is_not_fp32, with_tf32_off, \
     _get_torch_cuda_version, TEST_MAGMA
 
 
-
 # load_tests from torch.testing._internal.common_utils is used to automatically filter tests for
 # sharding on sandcastle. This line silences flake warnings
 load_tests = load_tests

@@ -7003,6 +7002,9 @@ def test_matrix_exp_boundary_cases(self, device, dtype):
     @skipCUDAIfNoMagma
     @skipCPUIfNoLapack
     @dtypes(torch.float, torch.double)
+    # Although tf32 is always disabled on matrix_exp, this test uses matmul,
+    # which has tf32 on by default
+    @with_tf32_off
     def test_matrix_exp_analytic(self, device, dtype):
         # check zero matrix
         x = torch.zeros(20, 20, dtype=dtype, device=device)

@@ -7144,6 +7146,9 @@ def run_test(*n):
     @skipCUDAIfNoMagma
     @skipCPUIfNoLapack
     @dtypes(torch.float, torch.double)
+    # Although tf32 is always disabled on matrix_exp, this test uses matmul,
+    # which has tf32 on by default
+    @with_tf32_off
     def test_matrix_exp_compare_with_taylor(self, device, dtype):
 
         def normalize_to_1_operator_norm(sample, desired_norm):

@@ -16471,6 +16476,7 @@ def _test(row_major, incx, incy, lda_tail):
     @dtypesIfCUDA(*torch.testing.get_all_complex_dtypes(), *torch.testing.get_all_fp_dtypes(include_bfloat16=AMPERE_OR_ROCM))
     @dtypes(*torch.testing.get_all_complex_dtypes(), *torch.testing.get_all_fp_dtypes())
     @unittest.skipIf(not TEST_NUMPY, "Numpy not found")
+    @tf32_on_and_off(0.05)
     def test_addmm(self, device, dtype):
         M = torch.randn(10, 25, device=device).to(dtype)
         m1 = torch.randn(10, 50, device=device).to(dtype)

@@ -19832,13 +19838,13 @@ def inner(self, device, dtype):
         1e-1, 1e-1, 1e-5, torch.testing.get_all_fp_dtypes()),
     ('addbmm', '', _small_2d, lambda t, d: [_small_3d(t, d), _small_3d(t, d)],
         1e-1, 1e-1, 1e-4, torch.testing.get_all_fp_dtypes(include_bfloat16=AMPERE_OR_ROCM) + _complex_types,
-        _cpu_types, True, [tf32_on_and_off(0.005)]),
+        _cpu_types, True, [tf32_on_and_off(0.01)]),
     ('addbmm', 'scalar', _small_2d, lambda t, d: [_number(0.4, 2, t), _small_3d(t, d), _small_3d(t, d)],
         1e-1, 1e-1, 1e-4, torch.testing.get_all_fp_dtypes(include_bfloat16=AMPERE_OR_ROCM) + _complex_types, _cpu_types, True,
-        [tf32_on_and_off(0.005), _wrap_maybe_warns("This overload of addbmm_? is deprecated")]),
+        [tf32_on_and_off(0.01), _wrap_maybe_warns("This overload of addbmm_? is deprecated")]),
     ('addbmm', 'two_scalars', _small_2d, lambda t, d: [_number(0.5, 3, t), _number(0.4, 2, t), _small_3d(t, d), _small_3d(t, d)],
         1e-1, 1e-1, 1e-4, torch.testing.get_all_fp_dtypes(include_bfloat16=AMPERE_OR_ROCM) + _complex_types, _cpu_types, True,
-        [tf32_on_and_off(0.005), _wrap_maybe_warns("This overload of addbmm_? is deprecated")]),
+        [tf32_on_and_off(0.01), _wrap_maybe_warns("This overload of addbmm_? is deprecated")]),
     ('baddbmm', '', _small_3d, lambda t, d: [_small_3d(t, d), _small_3d(t, d)],
         1e-2, 1e-1, 1e-4, torch.testing.get_all_fp_dtypes(include_bfloat16=AMPERE_OR_ROCM)),
     ('baddbmm', 'scalar', _small_3d, lambda t, d: [_number(0.4, 2, t), _small_3d(t, d), _small_3d(t, d)],

@@ -19865,26 +19871,26 @@ def inner(self, device, dtype):
         [_wrap_maybe_warns("This overload of addcmul_? is deprecated")]),
     ('addmm', '', _medium_2d, lambda t, d: [_medium_2d(t, d), _medium_2d(t, d)], 1e-1, 1e-1, 1e-4,
         torch.testing.get_all_fp_dtypes(include_bfloat16=AMPERE_OR_ROCM),
-        _cpu_types, True, [tf32_on_and_off(0.005)], 0, True),
+        _cpu_types, True, [tf32_on_and_off(0.01)], 0, True),
     ('addmm', 'scalar', _medium_2d,
         lambda t, d: [_number(0.4, 2, t), _medium_2d(t, d), _medium_2d(t, d)], 1e-1, 1e-1, 1e-4,
         torch.testing.get_all_fp_dtypes(include_bfloat16=AMPERE_OR_ROCM), _cpu_types, True,
-        [tf32_on_and_off(0.005), _wrap_maybe_warns("This overload of addmm_? is deprecated")]),
+        [tf32_on_and_off(0.01), _wrap_maybe_warns("This overload of addmm_? is deprecated")]),
     ('addmm', 'two_scalars', _medium_2d,
         lambda t, d: [_number(0.5, 3, t), _number(0.4, 2, t), _medium_2d(t, d), _medium_2d(t, d)], 1e-1, 1e-1, 1e-4,
         torch.testing.get_all_fp_dtypes(include_bfloat16=AMPERE_OR_ROCM), _cpu_types, True,
-        [tf32_on_and_off(0.005), _wrap_maybe_warns("This overload of addmm_? is deprecated")]),
+        [tf32_on_and_off(0.01), _wrap_maybe_warns("This overload of addmm_? is deprecated")]),
     ('addmv', '', _medium_1d, lambda t, d: [_medium_2d(t, d), _medium_1d(t, d)], 1e-2, 1e-1, 1e-4,
         torch.testing.get_all_fp_dtypes(include_bfloat16=AMPERE_OR_ROCM) + _complex_types_skip_rocm, _cpu_types,
-        True, [tf32_on_and_off(0.005)], 0, True),
+        True, [], 0, True),
     ('addmv', 'scalar', _medium_1d,
         lambda t, d: [_number(0.4, 2, t), _medium_2d(t, d), _medium_1d(t, d)], 1e-2, 1e-1, 1e-4,
         torch.testing.get_all_fp_dtypes(include_bfloat16=AMPERE_OR_ROCM) + _complex_types_skip_rocm, _cpu_types, True,
-        [tf32_on_and_off(0.005), _wrap_maybe_warns("This overload of addmv_? is deprecated")]),
+        [_wrap_maybe_warns("This overload of addmv_? is deprecated")]),
     ('addmv', 'two_scalars', _medium_1d,
         lambda t, d: [_number(0.5, 3, t), _number(0.4, 2, t), _medium_2d(t, d), _medium_1d(t, d)], 1e-2, 1e-1, 1e-4,
         torch.testing.get_all_fp_dtypes(include_bfloat16=AMPERE_OR_ROCM) + _complex_types_skip_rocm, _cpu_types, True,
-        [tf32_on_and_off(0.005), _wrap_maybe_warns("This overload of addmv_? is deprecated")]),
+        [_wrap_maybe_warns("This overload of addmv_? is deprecated")]),
     ('addr', '', _medium_2d, lambda t, d: [_medium_1d(t, d), _medium_1d(t, d)],
         1e-2, 1e-1, 1e-4, _float_types2),
     ('addr', 'scalar', _medium_2d,

‎torch/testing/_internal/common_cuda.py

+15 lines changed

@@ -127,6 +127,21 @@ def wrapped(self, device, dtype):
         return wrapped
     return wrapper
 
+
+# This is a wrapper that wraps a test to run it with TF32 turned off.
+# This wrapper is designed to be used when a test uses matmul or convolutions
+# but the purpose of that test is not testing matmul or convolutions.
+# Disabling TF32 will enforce torch.float tensors to be always computed
+# at full precision.
+def with_tf32_off(f):
+    @functools.wraps(f)
+    def wrapped(*args, **kwargs):
+        with tf32_off():
+            return f(*args, **kwargs)
+
+    return wrapped
+
+
 def _get_torch_cuda_version():
     if torch.version.cuda is None:
         return [0, 0]
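
Typical usage mirrors the test-file changes above: any test that happens to use matmul or convolutions but is not about their accuracy can be pinned to full precision. A hypothetical example (the test class and its body are illustrative):

import unittest
import torch
from torch.testing._internal.common_cuda import with_tf32_off

class ExampleLinearTest(unittest.TestCase):
    @unittest.skipIf(not torch.cuda.is_available(), "requires CUDA")
    @with_tf32_off
    def test_linear_matches_matmul(self):
        x = torch.randn(8, 16, device="cuda")
        w = torch.randn(4, 16, device="cuda")
        # Inside the decorated test, float matmuls run at full FP32,
        # so a tight default tolerance is safe even on Ampere GPUs.
        self.assertTrue(torch.allclose(torch.nn.functional.linear(x, w), x @ w.t()))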

‎torch/testing/_internal/common_nn.py

+18, -3 lines changed

@@ -1601,6 +1601,7 @@ def fractional_max_pool3d_test(test_case):
         input_size=(2, 4, 10),
         cudnn=True,
         with_tf32=True,
+        tf32_precision=0.005,
     ),
     dict(
         module_name='Conv1d',

@@ -1620,6 +1621,7 @@ def fractional_max_pool3d_test(test_case):
         cudnn=True,
         desc='pad1',
         with_tf32=True,
+        tf32_precision=0.005,
     ),
     dict(
         module_name='Conv1d',

@@ -1629,6 +1631,7 @@ def fractional_max_pool3d_test(test_case):
         cudnn=True,
         desc='pad2',
         with_tf32=True,
+        tf32_precision=0.005,
     ),
     dict(
         module_name='Conv1d',

@@ -1638,6 +1641,7 @@ def fractional_max_pool3d_test(test_case):
         cudnn=True,
         desc='pad1size1',
         with_tf32=True,
+        tf32_precision=0.005,
     ),
     dict(
         module_name='Conv1d',

@@ -1647,6 +1651,7 @@ def fractional_max_pool3d_test(test_case):
         cudnn=True,
         desc='pad2size1',
         with_tf32=True,
+        tf32_precision=0.005,
     ),
     dict(
         module_name='Conv1d',

@@ -1657,13 +1662,15 @@ def fractional_max_pool3d_test(test_case):
         desc='zero_batch',
         test_cuda=(not TEST_WITH_ROCM),
         with_tf32=True,
+        tf32_precision=0.005,
     ),
     dict(
         fullname='Conv1d_dilated',
         constructor=lambda: nn.Conv1d(4, 5, kernel_size=3, dilation=2),
         cpp_constructor_args='torch::nn::Conv1dOptions(4, 5, 3).dilation(2)',
         input_size=(2, 4, 10),
         with_tf32=True,
+        tf32_precision=0.005,
     ),
     dict(
         fullname='Conv1d_groups',

@@ -1672,6 +1679,7 @@ def fractional_max_pool3d_test(test_case):
         input_size=(2, 4, 6),
         cudnn=True,
         with_tf32=True,
+        tf32_precision=0.005,
     ),
     dict(
         fullname='ConvTranspose1d',

@@ -1702,6 +1710,7 @@ def fractional_max_pool3d_test(test_case):
         cudnn=True,
         desc='dilated',
         with_tf32=True,
+        tf32_precision=0.005,
     ),
     dict(
         fullname='ConvTranspose1d_groups',

@@ -2117,7 +2126,7 @@ def fractional_max_pool3d_test(test_case):
         cudnn=True,
         check_with_long_tensor=True,
         with_tf32=True,
-        tf32_precision=0.005,
+        tf32_precision=0.05,
     ),
     dict(
         module_name='Conv3d',

@@ -2140,7 +2149,7 @@ def fractional_max_pool3d_test(test_case):
         desc='stride',
         check_with_long_tensor=True,
         with_tf32=True,
-        tf32_precision=0.005,
+        tf32_precision=0.05,
     ),
     dict(
         module_name='Conv3d',

@@ -2151,7 +2160,7 @@ def fractional_max_pool3d_test(test_case):
         desc='stride_padding',
         check_with_long_tensor=True,
         with_tf32=True,
-        tf32_precision=0.01,
+        tf32_precision=0.05,
     ),
     dict(
         module_name='Conv3d',

@@ -2180,13 +2189,15 @@ def fractional_max_pool3d_test(test_case):
         cpp_constructor_args='torch::nn::Conv3dOptions(3, 4, 2).dilation(2)',
         input_size=(2, 3, 5, 5, 5),
         with_tf32=True,
+        tf32_precision=0.05,
     ),
     dict(
         fullname='Conv3d_dilated_strided',
         constructor=lambda: nn.Conv3d(3, 4, kernel_size=2, dilation=2, stride=2),
         cpp_constructor_args='torch::nn::Conv3dOptions(3, 4, 2).dilation(2).stride(2)',
         input_size=(2, 3, 5, 5, 5),
         with_tf32=True,
+        tf32_precision=0.05
     ),
     dict(
         module_name='ConvTranspose3d',

@@ -2195,6 +2206,7 @@ def fractional_max_pool3d_test(test_case):
         cudnn=True,
         input_size=(1, 2, 4, 5, 4),
         with_tf32=True,
+        tf32_precision=0.05
     ),
     dict(
         module_name='ConvTranspose3d',

@@ -2205,6 +2217,7 @@ def fractional_max_pool3d_test(test_case):
         input_size=(1, 2, 4, 5, 4),
         desc='dilated',
         with_tf32=True,
+        tf32_precision=0.05
     ),
     dict(
         module_name='MaxPool3d',

@@ -5005,6 +5018,8 @@ def __init__(self, *args, **kwargs):
         self.check_bfloat16 = kwargs.get('check_bfloat16', False)
         self.convert_target = kwargs.get('convert_target', True)
         self.test_cpu = kwargs.get('test_cpu', True)
+        self.with_tf32 = kwargs.get('with_tf32', True)
+        self.tf32_precision = kwargs.get('tf32_precision', 0.001)
 
     def __call__(self, test_case):
         module = self.constructor(*self.constructor_args)
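
The with_tf32/tf32_precision fields feed the generated module tests: a test dict that opts in with with_tf32=True is expected to be re-run with TF32 enabled and checked against the looser tf32_precision instead of the default tolerance. A hedged sketch of how such a harness could consume these fields (illustrative only, not the actual common_nn.py machinery):

import torch

def check_module_output(test_dict, actual, expected):
    # Illustrative: pick the tolerance based on whether this comparison
    # is for a TF32 pass of the test.
    tf32_pass = (
        test_dict.get('with_tf32', True)
        and torch.cuda.is_available()
        and torch.backends.cuda.matmul.allow_tf32
    )
    tol = test_dict.get('tf32_precision', 0.001) if tf32_pass else 1e-5
    assert torch.allclose(actual, expected, rtol=tol, atol=tol)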
