@@ -9,16 +9,16 @@
 import torch.nn.functional as F
 from torch.testing import FileCheck

-from common_utils import run_tests, IS_SANDCASTLE
+from common_utils import run_tests, IS_SANDCASTLE, ProfilingMode, GRAPH_EXECUTOR, \
+    enable_profiling_mode
 from textwrap import dedent
 from itertools import product, permutations

 from test_jit import JitTestCase, enable_cpu_fuser, RUN_CUDA, RUN_CUDA_HALF, RUN_CUDA_MULTI_GPU, \
     backward_graph, all_backward_graphs, get_lstm_inputs, get_milstm_inputs, \
     LSTMCellC, LSTMCellF, LSTMCellS, MiLSTMCell, _inline_everything
-from jit_utils import enable_profiling_mode, ProfilingMode, IN_TRANSITION_TO_PROFILING_GRAPH_EXECUTOR

-if IN_TRANSITION_TO_PROFILING_GRAPH_EXECUTOR:
+if GRAPH_EXECUTOR == ProfilingMode.PROFILING:
     torch._C._jit_set_profiling_executor(True)
     torch._C._jit_set_profiling_mode(True)

@@ -123,7 +123,7 @@ def scaleshift(x, scale, shift):

     @unittest.skipIf(not RUN_CUDA, "fuser requires CUDA")
     @unittest.skipIf(not RUN_CUDA_HALF, "no half support")
-    @unittest.skipIf(IN_TRANSITION_TO_PROFILING_GRAPH_EXECUTOR, "no half support with profiling on")
+    @unittest.skipIf(GRAPH_EXECUTOR != ProfilingMode.LEGACY, "no half support with profiling on")
     def test_cuda_half(self):
         x = torch.randn(4, 4, dtype=torch.half, device='cuda')
         y = torch.randn(4, 4, dtype=torch.half, device='cuda')
@@ -303,15 +303,16 @@ def funcOptMax(a, b):
         funcs = (func2, funcInf, funcOptMin, funcOptMax)
         for f, inputs in product(funcs, [[a, b], [a, nan]]):
             inp1, inp2 = inputs
-            s = self.checkScript(f, (inp1, inp2), profiling=ProfilingMode.FULL)
+            s = self.checkScript(f, (inp1, inp2), profiling=ProfilingMode.PROFILING)
             self.assertAllFused(s.graph_for(inp1, inp2), except_for={'aten::size', 'aten::_size_if_not_equal'})
             c = s(inp1, inp2)
-            with enable_profiling_mode(ProfilingMode.FULL):
+            with enable_profiling_mode():
                 warmup_backward(c.sum())
             graph = backward_graph(s)
             self.assertAllFused(graph, except_for={'aten::Float'})

     @unittest.skipIf(not RUN_CUDA, "fuser requires CUDA")
+    @unittest.skipIf(GRAPH_EXECUTOR != ProfilingMode.LEGACY, "no half support with profiling on")
     def test_dropout(self):
         def func(x):
             x = torch.nn.functional.dropout(x)
@@ -461,7 +462,7 @@ def test_exp_cuda(self):
         self.assertAllFused(ge.graph_for(x, y))

     @unittest.skipIf(not RUN_CUDA, "fuser requires CUDA")
-    @unittest.skipIf(IN_TRANSITION_TO_PROFILING_GRAPH_EXECUTOR, "broken with profiling on")
+    @unittest.skipIf(GRAPH_EXECUTOR != ProfilingMode.LEGACY, "broken with profiling on")
     @_inline_everything
     def test_fuse_decompose_normalization(self):
         class ResLike(torch.jit.ScriptModule):
@@ -552,7 +553,7 @@ def fn_test_scalar_arg_requires_grad(x, p):
                                                     "aten::_size_if_not_equal"))

     @unittest.skipIf(IS_SANDCASTLE, "NYI: fuser CPU support for Sandcastle")
-    @unittest.skipIf(IN_TRANSITION_TO_PROFILING_GRAPH_EXECUTOR, "broken with profiling on")
+    @unittest.skipIf(GRAPH_EXECUTOR != ProfilingMode.LEGACY, "broken with profiling on")
     @enable_cpu_fuser
     def test_fuser_deduplication(self):
         # See that fusion kernel outputs are deduplicated when removing _grad_sum_to_size in the fuser's compilation
@@ -905,6 +906,7 @@ def f(x, y):
         self.assertAllFused(script_f.graph_for(x, y), except_for={'prim::TupleConstruct'})

     @unittest.skipIf(not RUN_CUDA, "fuser requires CUDA")
+    @unittest.skipIf(GRAPH_EXECUTOR != ProfilingMode.LEGACY, "no half support with profiling on")
     def test_grad_sum_to_size_elimination(self):

         def my_broadcasted_cell(a, b, c):
@@ -913,7 +915,7 @@ def my_broadcasted_cell(a, b, c):
         s1 = torch.randn(5, 1, requires_grad=True, device='cuda')
         s2 = torch.randn(5, 5, requires_grad=True, device='cuda')

-        module = self.checkScript(my_broadcasted_cell, (s1, s1, s1), profiling=ProfilingMode.FULL)
+        module = self.checkScript(my_broadcasted_cell, (s1, s1, s1), profiling=ProfilingMode.PROFILING)
         forward_graph = module.graph_for(s1, s1, s1)
         self.assertAllFused(forward_graph, except_for=("aten::size", "prim::BroadcastSizes",
                                                        "aten::_size_if_not_equal"))
@@ -925,7 +927,7 @@ def my_broadcasted_cell(a, b, c):
             args = s2 if i < 1 else s1, s2 if i < 2 else s1, s2
             args = [a.detach_().requires_grad_() for a in args]
             # recompile, so we don't trigger bailouts
-            module = self.checkScript(my_broadcasted_cell, args, profiling=ProfilingMode.FULL)
+            module = self.checkScript(my_broadcasted_cell, args, profiling=ProfilingMode.PROFILING)
             res = module(s2 if i < 1 else s1, s2 if i < 2 else s1, s2)
             warmup_backward(res.sum(), args)
             grads = torch.autograd.grad(res.sum(), args)