Skip to content

Commit 5973b44

Browse files
gchananfacebook-github-bot
authored andcommittedSep 3, 2020
Rename NewCriterionTest to CriterionTest. (pytorch#44056)
Summary: Pull Request resolved: pytorch#44056 Test Plan: Imported from OSS Reviewed By: zou3519 Differential Revision: D23482573 Pulled By: gchanan fbshipit-source-id: dde0f1624330dc85f48e5a0b9d98fb55fdb72f68
1 parent 7d95eb8 commit 5973b44

File tree

4 files changed

+8
-8
lines changed

4 files changed

+8
-8
lines changed
 

‎test/cpp_api_parity/utils.py

+2-2
Original file line numberDiff line numberDiff line change
@@ -21,7 +21,7 @@
2121
# Unique identifier for this module config (e.g. "BCELoss_weights_cuda")
2222
'module_variant_name',
2323

24-
# An instance of an NN test class (e.g. `NewCriterionTest`) which stores
24+
# An instance of an NN test class (e.g. `CriterionTest`) which stores
2525
# necessary information (e.g. input / target / extra_args) for running the Python test
2626
'test_instance',
2727

@@ -184,7 +184,7 @@ def move_cpp_tensors_to_device(cpp_tensor_stmts, device):
184184
return ['{}.to("{}")'.format(tensor_stmt, device) for tensor_stmt in cpp_tensor_stmts]
185185

186186
def is_criterion_test(test_instance):
187-
return isinstance(test_instance, common_nn.NewCriterionTest)
187+
return isinstance(test_instance, common_nn.CriterionTest)
188188

189189
# This function computes the following:
190190
# - What variable declaration statements should show up in the C++ parity test function

‎test/test_cpp_api_parity.py

+2-2
Original file line numberDiff line numberDiff line change
@@ -30,8 +30,8 @@ class TestCppApiParity(common.TestCase):
3030
(sample_functional.functional_tests, common_nn.NewModuleTest),
3131
(common_nn.module_tests, common_nn.ModuleTest),
3232
(common_nn.new_module_tests, common_nn.NewModuleTest),
33-
(common_nn.criterion_tests, common_nn.NewCriterionTest),
34-
(common_nn.new_criterion_tests, common_nn.NewCriterionTest),
33+
(common_nn.criterion_tests, common_nn.CriterionTest),
34+
(common_nn.new_criterion_tests, common_nn.CriterionTest),
3535
]:
3636
for test_params_dict in test_params_dicts:
3737
if test_params_dict.get('test_cpp_api_parity', True):

‎test/test_nn.py

+3-3
Original file line numberDiff line numberDiff line change
@@ -38,7 +38,7 @@
3838
get_function_arglist, load_tests, repeat_test_for_types, ALL_TENSORTYPES, \
3939
ALL_TENSORTYPES2, TemporaryFileName, TEST_WITH_UBSAN, IS_PPC
4040
from torch.testing._internal.common_cuda import TEST_CUDA, TEST_MULTIGPU, TEST_CUDNN, TEST_CUDNN_VERSION
41-
from torch.testing._internal.common_nn import NNTestCase, NewModuleTest, NewCriterionTest, \
41+
from torch.testing._internal.common_nn import NNTestCase, NewModuleTest, CriterionTest, \
4242
module_tests, criterion_tests, new_criterion_tests, loss_reference_fns, \
4343
ctcloss_reference, new_module_tests
4444
from torch.testing._internal.common_device_type import instantiate_device_type_tests, dtypes, \
@@ -8742,7 +8742,7 @@ def reference_fn(i, p, m):
87428742
for test_params in criterion_tests + new_criterion_tests:
87438743
name = test_params.pop('module_name')
87448744
test_params['constructor'] = getattr(nn, name)
8745-
test = NewCriterionTest(**test_params)
8745+
test = CriterionTest(**test_params)
87468746
decorator = test_params.pop('decorator', None)
87478747
add_test(test, decorator)
87488748
if 'check_sum_reduction' in test_params:
@@ -8757,7 +8757,7 @@ def sum_reduction_constructor(*args, **kwargs):
87578757
return sum_reduction_constructor
87588758

87598759
test_params['constructor'] = gen_sum_reduction_constructor(test_params['constructor'])
8760-
test = NewCriterionTest(**test_params)
8760+
test = CriterionTest(**test_params)
87618761
add_test(test, decorator)
87628762

87638763

‎torch/testing/_internal/common_nn.py

+1-1
Original file line numberDiff line numberDiff line change
@@ -5032,7 +5032,7 @@ def constructor_args(self):
50325032
return self._get_arg('constructor_args', False)
50335033

50345034

5035-
class NewCriterionTest(InputVariableMixin, TestBase):
5035+
class CriterionTest(InputVariableMixin, TestBase):
50365036
# TODO: check that criterions don't ignore grad_output
50375037

50385038
_required_arg_names = TestBase._required_arg_names.union({'target'})

0 commit comments

Comments
 (0)