
Commit 5b702ab

Krovatkin authored and facebook-github-bot committed on Nov 11, 2019
switching to a simple/full executor
Summary: Pull Request resolved: pytorch#29230

Differential Revision: D18402229

Pulled By: Krovatkin

fbshipit-source-id: 62f4bc9bc89c0c7369359bba1359c22a2fa80f46
1 parent cedca37 commit 5b702ab

19 files changed: +415 −265 lines
 

‎.circleci/config.yml

+16
@@ -1895,6 +1895,22 @@ workflows:
           ios_platform: "OS"
           requires:
             - setup
+      - pytorch_linux_test:
+          name: pytorch_linux_xenial_py3_6_gcc5_4_ge_config_legacy_test
+          requires:
+            - setup
+            - pytorch_linux_xenial_py3_6_gcc5_4_build
+          build_environment: "pytorch-linux-xenial-py3.6-gcc5.4-ge_config_legacy-test"
+          docker_image: "308535385114.dkr.ecr.us-east-1.amazonaws.com/pytorch/pytorch-linux-xenial-py3.6-gcc5.4:347"
+          resource_class: large
+      - pytorch_linux_test:
+          name: pytorch_linux_xenial_py3_6_gcc5_4_ge_config_simple_test
+          requires:
+            - setup
+            - pytorch_linux_xenial_py3_6_gcc5_4_build
+          build_environment: "pytorch-linux-xenial-py3.6-gcc5.4-ge_config_simple-test"
+          docker_image: "308535385114.dkr.ecr.us-east-1.amazonaws.com/pytorch/pytorch-linux-xenial-py3.6-gcc5.4:347"
+          resource_class: large
       - caffe2_linux_build:
           name: caffe2_onnx_py2_gcc5_ubuntu16_04_build
           requires:

‎.circleci/generate_config_yml.py

+1
@@ -93,6 +93,7 @@ def write(self, output_filehandle):
            File("workflows-pytorch-macos-builds.yml"),
            File("workflows-pytorch-android-gradle-build.yml"),
            File("workflows-pytorch-ios-builds.yml"),
+            File("workflows-pytorch-ge-config-tests.yml"),
            Listgen(caffe2_build_definitions.get_workflow_jobs, 3),
            File("workflows-binary-builds-smoke-subset.yml"),
            Listgen(binary_build_definitions.get_binary_smoke_test_jobs, 3),

‎.circleci/scripts/should_run_job.py

+4
@@ -65,6 +65,10 @@
     # XLA
     'pytorch-xla-linux-xenial-py3.6-clang7',
 
+    # GraphExecutor config jobs
+    'pytorch-linux-xenial-py3.6-gcc5.4-ge_config_simple-test',
+    'pytorch-linux-xenial-py3.6-gcc5.4-ge_config_legacy-test',
+
     # Other checks
     'pytorch-short-perf-test-gpu',
     'pytorch-python-doc-push',
‎.circleci/verbatim-sources/workflows-pytorch-ge-config-tests.yml

+16
@@ -0,0 +1,16 @@
+      - pytorch_linux_test:
+          name: pytorch_linux_xenial_py3_6_gcc5_4_ge_config_legacy_test
+          requires:
+            - setup
+            - pytorch_linux_xenial_py3_6_gcc5_4_build
+          build_environment: "pytorch-linux-xenial-py3.6-gcc5.4-ge_config_legacy-test"
+          docker_image: "308535385114.dkr.ecr.us-east-1.amazonaws.com/pytorch/pytorch-linux-xenial-py3.6-gcc5.4:347"
+          resource_class: large
+      - pytorch_linux_test:
+          name: pytorch_linux_xenial_py3_6_gcc5_4_ge_config_simple_test
+          requires:
+            - setup
+            - pytorch_linux_xenial_py3_6_gcc5_4_build
+          build_environment: "pytorch-linux-xenial-py3.6-gcc5.4-ge_config_simple-test"
+          docker_image: "308535385114.dkr.ecr.us-east-1.amazonaws.com/pytorch/pytorch-linux-xenial-py3.6-gcc5.4:347"
+          resource_class: large

‎.jenkins/pytorch/test.sh

+15-1
@@ -109,8 +109,18 @@ test_python_nn() {
   assert_git_not_dirty
 }
 
+test_python_ge_config_simple() {
+  time python test/run_test.py --include jit_simple --verbose
+  assert_git_not_dirty
+}
+
+test_python_ge_config_legacy() {
+  time python test/run_test.py --include jit_legacy jit_fuser_legacy --verbose
+  assert_git_not_dirty
+}
+
 test_python_all_except_nn() {
-  time python test/run_test.py --exclude nn --verbose --bring-to-front quantization quantized quantized_tensor quantized_nn_mods
+  time python test/run_test.py --exclude nn jit_simple jit_legacy jit_fuser_legacy --verbose --bring-to-front quantization quantized quantized_tensor quantized_nn_mods
   assert_git_not_dirty
 }
 
@@ -219,6 +229,10 @@ if [[ "${BUILD_ENVIRONMENT}" == *backward* ]]; then
 elif [[ "${BUILD_ENVIRONMENT}" == *xla* || "${JOB_BASE_NAME}" == *xla* ]]; then
   test_torchvision
   test_xla
+elif [[ "${BUILD_ENVIRONMENT}" == *ge_config_legacy* || "${JOB_BASE_NAME}" == *ge_config_legacy* ]]; then
+  test_python_ge_config_legacy
+elif [[ "${BUILD_ENVIRONMENT}" == *ge_config_simple* || "${JOB_BASE_NAME}" == *ge_config_simple* ]]; then
+  test_python_ge_config_simple
 elif [[ "${BUILD_ENVIRONMENT}" == *libtorch* ]]; then
   # TODO: run some C++ tests
   echo "no-op at the moment"

‎test/common_utils.py

+50-2
@@ -39,17 +39,65 @@
 from torch._six import string_classes, inf
 import torch.backends.cudnn
 import torch.backends.mkl
-
+from enum import Enum
 
 torch.backends.disable_global_flags()
 
+IS_SANDCASTLE = os.getenv('SANDCASTLE') == '1' or os.getenv('TW_JOB_USER') == 'sandcastle'
+
+class ProfilingMode(Enum):
+    LEGACY = 1
+    SIMPLE = 2
+    PROFILING = 3
+
+@contextmanager
+def enable_profiling_mode():
+    if GRAPH_EXECUTOR == ProfilingMode.PROFILING:
+        old_prof_exec_state = torch._C._jit_set_profiling_executor(True)
+        old_prof_mode_state = torch._C._jit_set_profiling_mode(True)
+    try:
+        yield
+    finally:
+        if GRAPH_EXECUTOR == ProfilingMode.PROFILING:
+            torch._C._jit_set_profiling_executor(old_prof_exec_state)
+            torch._C._jit_set_profiling_mode(old_prof_mode_state)
+
+func_call = torch._C.ScriptFunction.__call__
+meth_call = torch._C.ScriptMethod.__call__
+
+def prof_callable(callable, *args, **kwargs):
+    if 'profile_and_replay' in kwargs:
+        del kwargs['profile_and_replay']
+        if GRAPH_EXECUTOR == ProfilingMode.PROFILING:
+            with enable_profiling_mode():
+                callable(*args, **kwargs)
+                return callable(*args, **kwargs)
+
+    return callable(*args, **kwargs)
+
+def prof_func_call(*args, **kwargs):
+    return prof_callable(func_call, *args, **kwargs)
+
+def prof_meth_call(*args, **kwargs):
+    return prof_callable(meth_call, *args, **kwargs)
+
+torch._C.ScriptFunction.__call__ = prof_func_call
+torch._C.ScriptMethod.__call__ = prof_meth_call
 
 parser = argparse.ArgumentParser(add_help=False)
 parser.add_argument('--subprocess', action='store_true',
                     help='whether to run each test in a subprocess')
 parser.add_argument('--seed', type=int, default=1234)
 parser.add_argument('--accept', action='store_true')
+parser.add_argument('--ge_config', type=str)
+
+GRAPH_EXECUTOR = ProfilingMode.SIMPLE if IS_SANDCASTLE else ProfilingMode.PROFILING
 args, remaining = parser.parse_known_args()
+if args.ge_config == 'legacy':
+    GRAPH_EXECUTOR = ProfilingMode.LEGACY
+elif args.ge_config == 'simple':
+    GRAPH_EXECUTOR = ProfilingMode.SIMPLE
+
 TEST_IN_SUBPROCESS = args.subprocess
 SEED = args.seed
 if not expecttest.ACCEPT:
@@ -1229,7 +1277,7 @@ def get_int64_dtype(dtype):
         int64_dtype, layout, device, fv + 5, False)
 
 
-IS_SANDCASTLE = os.getenv('SANDCASTLE') == '1' or os.getenv('TW_JOB_USER') == 'sandcastle'
+
 
 THESE_TAKE_WAY_TOO_LONG = {
     'test_Conv3d_groups',

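Note: the helpers added to common_utils.py above replace the profiling machinery that jit_utils.py used to carry. GRAPH_EXECUTOR selects the executor configuration (LEGACY, SIMPLE, or PROFILING) from the new --ge_config flag, enable_profiling_mode() flips the profiling-executor bindings only in the PROFILING configuration, and the patched ScriptFunction/ScriptMethod __call__ wrappers accept a profile_and_replay keyword that runs the callable one extra time under profiling before the real call. A minimal usage sketch follows (illustrative only; my_fn is a made-up function, not part of this commit):

import torch
from common_utils import GRAPH_EXECUTOR, ProfilingMode, enable_profiling_mode

def my_fn(x):
    return x * 2 + 1

if GRAPH_EXECUTOR != ProfilingMode.SIMPLE:  # gradient/profiling-specific tests skip the SIMPLE config
    scripted = torch.jit.script(my_fn)
    with enable_profiling_mode():
        # the patched __call__ strips profile_and_replay and, under PROFILING,
        # performs a warm-up call so the profiled graph is ready for the second run
        out = scripted(torch.randn(3), profile_and_replay=True)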
‎test/cpp/jit/test_misc.cpp

+4-1
@@ -1001,7 +1001,10 @@ graph(%a):
     return stack;
   };
   run(graph, stack);
-  AT_ASSERT(testPassValue);
+  // we will not run fusion in simple mode
+  if (!getExecutorMode()) {
+    AT_ASSERT(testPassValue);
+  }
 }
 
 static void checkShape(

‎test/jit/test_autodiff_subgraph_slicing.py

+20-10
@@ -1,6 +1,7 @@
 import os
 import sys
-
+import unittest
+from common_utils import GRAPH_EXECUTOR, ProfilingMode, enable_profiling_mode
 import torch
 
 # Make the helper files in test/ importable
@@ -21,18 +22,21 @@
 def pyfn(a, b):
     return a * b
 
+@unittest.skipIf(GRAPH_EXECUTOR == ProfilingMode.SIMPLE, "Simple Executor doesn't support gradients")
 class TestAutodiffSubgraphSlicing(JitTestCase):
     # TODO: It is better if we can test directly on graphs instead of the current
     # end-to-end fashion.
     def _perform_ad_subgraph_slicing(self, fn, *input_sizes):
         with disable_autodiff_subgraph_inlining():
-            ge = torch.jit.script(fn)
-            inputs = [torch.randn(size, requires_grad=True) for size in input_sizes]
-            ge(*inputs)
-            return ge.graph_for(*inputs)
+            with enable_profiling_mode():
+                ge = torch.jit.script(fn)
+                inputs = [torch.randn(size, requires_grad=True) for size in input_sizes]
+                ge(*inputs, profile_and_replay=True)
+                return ge.graph_for(*inputs)
 
     def assertGraphSize(self, graph, size):
-        self.assertEqual(len(list(graph.nodes())), size)
+        nodes = list(filter(lambda n : n.kind() != "prim::BailOut" and n.kind() != "prim::BailoutTemplate", graph.nodes()))
+        self.assertEqual(len(list(nodes)), size)
 
     def test_chunk_constant_script_ad(self):
         @torch.jit.script
@@ -42,8 +46,9 @@ def func(x):
 
         input = torch.rand(6, 10).requires_grad_()
         with disable_autodiff_subgraph_inlining():
-            output = func(input)
-            self.assertAutodiffNode(func.graph_for(input), True, ['prim::ConstantChunk'], [])
+            with enable_profiling_mode():
+                output = func(input, profile_and_replay=True)
+                self.assertAutodiffNode(func.graph_for(input), True, ['prim::ConstantChunk'], [])
 
     def test_simple_merge(self):
         # o --> o
@@ -156,8 +161,13 @@ def fn(v, w, x, y):
 
         graph = self._perform_ad_subgraph_slicing(fn, 1, 1, 1, 1)
 
-        self.assertGraphSize(graph, 3)
-        self.assertGraphContainsExactly(graph, 'prim::DifferentiableGraph', 1)
+        # GuardElimination can't get rid of a prim::BailOut on ^pyfn
+        # which makes us create two `prim::DifferentiableGraph`s
+        # instead of just one
+        num_nodes = 4 if GRAPH_EXECUTOR == ProfilingMode.PROFILING else 3
+        self.assertGraphSize(graph, num_nodes)
+        num_diff_nodes = 2 if GRAPH_EXECUTOR == ProfilingMode.PROFILING else 1
+        self.assertGraphContainsExactly(graph, 'prim::DifferentiableGraph', num_diff_nodes)
 
     def test_respects_lexical_scoping(self):
         def fn(x, k):

‎test/jit/test_models.py

+11-8
@@ -1,7 +1,7 @@
 import os
 import sys
 import unittest
-
+from common_utils import enable_profiling_mode
 import torch
 import torch.nn as nn
 import torch.nn.functional as F
@@ -228,8 +228,9 @@ def test_neural_style_cuda(self):
     @staticmethod
     def _test_mnist(self, device, check_export_import=True):
         # eval() is present because dropout makes this nondeterministic
-        self.checkTrace(MnistNet().to(device).eval(), (torch.rand(5, 1, 28, 28, device=device),),
-                        export_import=check_export_import)
+        with enable_profiling_mode():
+            self.checkTrace(MnistNet().to(device).eval(), (torch.rand(5, 1, 28, 28, device=device),),
+                            export_import=check_export_import)
 
     def test_mnist(self):
         self._test_mnist(self, device='cpu')
@@ -277,8 +278,9 @@ def forward(self, x):
             action_scores = self.affine2(x)
             return F.softmax(action_scores, dim=1)
 
-        self.checkTrace(Policy().to(device), (torch.rand(1, 4, device=device),),
-                        export_import=test_export_import)
+        with enable_profiling_mode():
+            self.checkTrace(Policy().to(device), (torch.rand(1, 4, device=device),),
+                            export_import=test_export_import)
 
     def test_reinforcement_learning(self):
         self._test_reinforcement_learning(self, device='cpu')
@@ -526,9 +528,10 @@ def forward(self, x):
                             export_import=False, allow_unused=True,
                             inputs_require_grads=False)
         else:
-            # eval() is present because randn_like makes this nondeterministic
-            self.checkTrace(VAE().to(device).eval(), (torch.rand(128, 1, 28, 28, device=device),),
-                            export_import=check_export_import)
+            with enable_profiling_mode():
+                # eval() is present because randn_like makes this nondeterministic
+                self.checkTrace(VAE().to(device).eval(), (torch.rand(128, 1, 28, 28, device=device),),
+                                export_import=check_export_import)
 
     def test_vae(self):
         self._test_vae(self, device='cpu')

‎test/jit_utils.py

+7-26
@@ -12,11 +12,10 @@
 import torch.jit.quantized
 import zipfile
 import functools
-from enum import Enum
 
 # Testing utils
 from common_utils import TestCase, IS_WINDOWS, \
-    freeze_rng_state, TemporaryFileName
+    freeze_rng_state, TemporaryFileName, enable_profiling_mode, ProfilingMode
 
 # Standard library
 from contextlib import contextmanager
@@ -33,28 +32,9 @@
 import tempfile
 import textwrap
 
-IN_TRANSITION_TO_PROFILING_GRAPH_EXECUTOR = False
-
 RUN_CUDA = torch.cuda.is_available()
 RUN_CUDA_MULTI_GPU = RUN_CUDA and torch.cuda.device_count() > 1
 
-class ProfilingMode(Enum):
-    OFF = 1
-    EXECUTOR = 2
-    FULL = 3
-
-@contextmanager
-def enable_profiling_mode(flag):
-    if IN_TRANSITION_TO_PROFILING_GRAPH_EXECUTOR:
-        old_prof_exec_state = torch._C._jit_set_profiling_executor(flag != ProfilingMode.OFF)
-        old_prof_mode_state = torch._C._jit_set_profiling_mode(flag == ProfilingMode.FULL)
-    try:
-        yield
-    finally:
-        if IN_TRANSITION_TO_PROFILING_GRAPH_EXECUTOR:
-            torch._C._jit_set_profiling_executor(old_prof_exec_state)
-            torch._C._jit_set_profiling_mode(old_prof_mode_state)
-
 def execWrapper(code, glob, loc):
     if PY2:
         exec(code) in glob, loc
@@ -325,13 +305,13 @@ def get_frame_vars(self, frames_up):
         return defined_vars
 
     def checkScriptRaisesRegex(self, script, inputs, exception, regex,
-                               outputs=None, capture_output=False, profiling=ProfilingMode.FULL):
+                               outputs=None, capture_output=False, profiling=ProfilingMode.PROFILING):
         """
         Checks that a given function will throw the correct exception,
         when executed with normal python, the string frontend, and the AST frontend
         """
 
-        with enable_profiling_mode(profiling):
+        with enable_profiling_mode():
             # normal python
             with self.assertRaisesRegex(exception, regex):
                 script(*inputs)
@@ -362,12 +342,12 @@ def checkScript(self,
                     inputs_requires_grad=False,
                     capture_output=False,
                     frames_up=1,
-                    profiling=ProfilingMode.FULL):
+                    profiling=ProfilingMode.PROFILING):
         with torch.jit.optimized_execution(optimize):
-            with enable_profiling_mode(profiling):
+            with enable_profiling_mode():
                 if isinstance(script, str):
                     # Compile the string to a Script function
-                    # with enable_profiling_mode(profiling):
+                    # with enable_profiling_mode():
                     cu = torch.jit.CompilationUnit(script, _frames_up=frames_up)
 
                     # Execute the Python function so we can run it later and get its
@@ -473,6 +453,7 @@ def input_reduce(input, fn, acc):
             outputs_ge = ge(*nograd_inputs)
             self.assertEqual(outputs, outputs_ge)
 
+            # test gradients case
             outputs = func(*recording_inputs)
             if inputs_require_grads:
                 grads = torch.autograd.grad(allSum(outputs), flattened_recording_inputs,

‎test/run_test.py

+3-3
@@ -54,6 +54,9 @@
     'utils',
     'namedtuple_return_api',
     'jit_fuser',
+    'jit_simple',
+    'jit_legacy',
+    'jit_fuser_legacy',
     'tensorboard',
     'namedtensor',
     'type_promotion',
@@ -135,15 +138,13 @@ def run_test(executable, test_module, test_directory, options, *extra_unittest_args):
     # Can't call `python -m unittest test_*` here because it doesn't run code
     # in `if __name__ == '__main__': `. So call `python test_*.py` instead.
     argv = [test_module + '.py'] + unittest_args + list(extra_unittest_args)
-
     command = executable + argv
     return shell(command, test_directory)
 
 
 def test_cuda_primary_ctx(executable, test_module, test_directory, options):
     return run_test(executable, test_module, test_directory, options, '--subprocess')
 
-
 def test_cpp_extensions(executable, test_module, test_directory, options):
     try:
         cpp_extension.verify_ninja_availability()
@@ -444,7 +445,6 @@ def main():
         signal_name = SIGNALS_TO_NAMES_DICT[-return_code]
         message += ' Received signal: {}'.format(signal_name)
         raise RuntimeError(message)
-
     if options.coverage:
         shell(['coverage', 'combine'])
         shell(['coverage', 'html'])

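Note: each entry in the TESTS list resolves to a test_<name>.py module under test/, so the three new shard names map to the wrapper files added later in this commit. A single shard can be run the same way CI does; a small sketch, assuming the repository root as the working directory:

# Sketch (not part of the commit): invoke the legacy-executor JIT shards the
# same way .jenkins/pytorch/test.sh's test_python_ge_config_legacy does.
import subprocess
subprocess.check_call(
    ['python', 'test/run_test.py', '--include', 'jit_legacy', 'jit_fuser_legacy', '--verbose'])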
‎test/test_jit.py

+204-190
Large diffs are not rendered by default.

‎test/test_jit_fuser.py

+12-10
@@ -9,16 +9,16 @@
 import torch.nn.functional as F
 from torch.testing import FileCheck
 
-from common_utils import run_tests, IS_SANDCASTLE
+from common_utils import run_tests, IS_SANDCASTLE, ProfilingMode, GRAPH_EXECUTOR, \
+    enable_profiling_mode
 from textwrap import dedent
 from itertools import product, permutations
 
 from test_jit import JitTestCase, enable_cpu_fuser, RUN_CUDA, RUN_CUDA_HALF, RUN_CUDA_MULTI_GPU, \
     backward_graph, all_backward_graphs, get_lstm_inputs, get_milstm_inputs, \
     LSTMCellC, LSTMCellF, LSTMCellS, MiLSTMCell, _inline_everything
-from jit_utils import enable_profiling_mode, ProfilingMode, IN_TRANSITION_TO_PROFILING_GRAPH_EXECUTOR
 
-if IN_TRANSITION_TO_PROFILING_GRAPH_EXECUTOR:
+if GRAPH_EXECUTOR == ProfilingMode.PROFILING:
     torch._C._jit_set_profiling_executor(True)
     torch._C._jit_set_profiling_mode(True)
 
@@ -123,7 +123,7 @@ def scaleshift(x, scale, shift):
 
     @unittest.skipIf(not RUN_CUDA, "fuser requires CUDA")
     @unittest.skipIf(not RUN_CUDA_HALF, "no half support")
-    @unittest.skipIf(IN_TRANSITION_TO_PROFILING_GRAPH_EXECUTOR, "no half support with profiling on")
+    @unittest.skipIf(GRAPH_EXECUTOR != ProfilingMode.LEGACY, "no half support with profiling on")
     def test_cuda_half(self):
         x = torch.randn(4, 4, dtype=torch.half, device='cuda')
         y = torch.randn(4, 4, dtype=torch.half, device='cuda')
@@ -303,15 +303,16 @@ def funcOptMax(a, b):
         funcs = (func2, funcInf, funcOptMin, funcOptMax)
         for f, inputs in product(funcs, [[a, b], [a, nan]]):
             inp1, inp2 = inputs
-            s = self.checkScript(f, (inp1, inp2), profiling=ProfilingMode.FULL)
+            s = self.checkScript(f, (inp1, inp2), profiling=ProfilingMode.PROFILING)
             self.assertAllFused(s.graph_for(inp1, inp2), except_for={'aten::size', 'aten::_size_if_not_equal'})
             c = s(inp1, inp2)
-            with enable_profiling_mode(ProfilingMode.FULL):
+            with enable_profiling_mode():
                 warmup_backward(c.sum())
             graph = backward_graph(s)
             self.assertAllFused(graph, except_for={'aten::Float'})
 
     @unittest.skipIf(not RUN_CUDA, "fuser requires CUDA")
+    @unittest.skipIf(GRAPH_EXECUTOR != ProfilingMode.LEGACY, "no half support with profiling on")
     def test_dropout(self):
         def func(x):
             x = torch.nn.functional.dropout(x)
@@ -461,7 +462,7 @@ def test_exp_cuda(self):
         self.assertAllFused(ge.graph_for(x, y))
 
     @unittest.skipIf(not RUN_CUDA, "fuser requires CUDA")
-    @unittest.skipIf(IN_TRANSITION_TO_PROFILING_GRAPH_EXECUTOR, "broken with profiling on")
+    @unittest.skipIf(GRAPH_EXECUTOR != ProfilingMode.LEGACY, "broken with profiling on")
     @_inline_everything
     def test_fuse_decompose_normalization(self):
         class ResLike(torch.jit.ScriptModule):
@@ -552,7 +553,7 @@ def fn_test_scalar_arg_requires_grad(x, p):
                                             "aten::_size_if_not_equal"))
 
     @unittest.skipIf(IS_SANDCASTLE, "NYI: fuser CPU support for Sandcastle")
-    @unittest.skipIf(IN_TRANSITION_TO_PROFILING_GRAPH_EXECUTOR, "broken with profiling on")
+    @unittest.skipIf(GRAPH_EXECUTOR != ProfilingMode.LEGACY, "broken with profiling on")
     @enable_cpu_fuser
     def test_fuser_deduplication(self):
         # See that fusion kernel outputs are deduplicated when removing _grad_sum_to_size in the fuser's compilation
@@ -905,6 +906,7 @@ def f(x, y):
         self.assertAllFused(script_f.graph_for(x, y), except_for={'prim::TupleConstruct'})
 
     @unittest.skipIf(not RUN_CUDA, "fuser requires CUDA")
+    @unittest.skipIf(GRAPH_EXECUTOR != ProfilingMode.LEGACY, "no half support with profiling on")
     def test_grad_sum_to_size_elimination(self):
 
         def my_broadcasted_cell(a, b, c):
@@ -913,7 +915,7 @@ def my_broadcasted_cell(a, b, c):
         s1 = torch.randn(5, 1, requires_grad=True, device='cuda')
         s2 = torch.randn(5, 5, requires_grad=True, device='cuda')
 
-        module = self.checkScript(my_broadcasted_cell, (s1, s1, s1), profiling=ProfilingMode.FULL)
+        module = self.checkScript(my_broadcasted_cell, (s1, s1, s1), profiling=ProfilingMode.PROFILING)
         forward_graph = module.graph_for(s1, s1, s1)
         self.assertAllFused(forward_graph, except_for=("aten::size", "prim::BroadcastSizes",
                                                        "aten::_size_if_not_equal"))
@@ -925,7 +927,7 @@ def my_broadcasted_cell(a, b, c):
             args = s2 if i < 1 else s1, s2 if i < 2 else s1, s2
             args = [a.detach_().requires_grad_() for a in args]
             # recompile, so we don't trigger bailouts
-            module = self.checkScript(my_broadcasted_cell, args, profiling=ProfilingMode.FULL)
+            module = self.checkScript(my_broadcasted_cell, args, profiling=ProfilingMode.PROFILING)
             res = module(s2 if i < 1 else s1, s2 if i < 2 else s1, s2)
             warmup_backward(res.sum(), args)
             grads = torch.autograd.grad(res.sum(), args)

‎test/test_jit_fuser_legacy.py

+6
@@ -0,0 +1,6 @@
+import sys
+sys.argv.append("--ge_config=legacy")
+from test_jit_fuser import *
+
+if __name__ == '__main__':
+    run_tests()

‎test/test_jit_legacy.py

+10
@@ -0,0 +1,10 @@
+import sys
+sys.argv.append("--ge_config=legacy")
+from test_jit import *
+
+if __name__ == '__main__':
+    run_tests()
+    if not PY2:
+        import test_jit_py3
+        suite = unittest.findTestCases(test_jit_py3)
+        unittest.TextTestRunner().run(suite)

‎test/test_jit_simple.py

+10
@@ -0,0 +1,10 @@
+import sys
+sys.argv.append("--ge_config=simple")
+from test_jit import *
+
+if __name__ == '__main__':
+    run_tests()
+    if not PY2:
+        import test_jit_py3
+        suite = unittest.findTestCases(test_jit_py3)
+        unittest.TextTestRunner().run(suite)
unittest.TextTestRunner().run(suite)
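Note: the wrapper modules above force an executor configuration by appending --ge_config to sys.argv before importing the real suite; common_utils.py parses the flag at import time via parse_known_args and sets GRAPH_EXECUTOR accordingly, so the same flag can presumably also be passed directly, e.g. python test/test_jit.py --ge_config=legacy. A small sketch of how the flag is consumed (mirrors the parser added in common_utils.py; not part of the commit):

import argparse

parser = argparse.ArgumentParser(add_help=False)
parser.add_argument('--ge_config', type=str)
# parse_known_args ignores unrelated arguments, so unittest's own flags pass through
args, remaining = parser.parse_known_args(['--ge_config=legacy'])
assert args.ge_config == 'legacy'   # -> GRAPH_EXECUTOR = ProfilingMode.LEGACY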

‎torch/csrc/jit/graph_executor.cpp

+1-1
@@ -495,7 +495,7 @@ struct GraphExecutorImpl : public GraphExecutorImplBase {
   }
 
   ExecutionPlan getPlanFor(Stack& stack) override {
-    return getGraphExecutorOptimize() ? getOrCompile(stack)
+    return getGraphExecutorOptimize() ? getOrCompile(stack)
                                       : getOrCompileFallback();
   }
 

‎torch/csrc/jit/passes/alias_analysis.cpp

+3-1
@@ -363,8 +363,10 @@ void AliasDb::analyzeImpl(Node* node) {
       // TODO: this can be improved with summarizes of what the function does
       // for now we assume the worst
       return analyzeConservative(node);
-    case prim::Print:
     case prim::Uninitialized:
+      giveFreshAlias(node->output());
+      return;
+    case prim::Print:
     case prim::isinstance:
       // These ops do nothing
       return;

‎torch/csrc/jit/profiling_graph_executor_impl.cpp

+22-12
@@ -19,8 +19,14 @@
 namespace torch {
 namespace jit {
 
+#ifdef FBCODE_CAFFE2
 static std::atomic<bool> profiling_mode{false};
 static std::atomic<bool> executor_mode{false};
+#else
+static std::atomic<bool> executor_mode{true};
+static std::atomic<bool> profiling_mode{true};
+#endif
+
 
 std::atomic<bool>& getProfilingMode() {
   return profiling_mode;
@@ -112,22 +118,26 @@ ExecutionPlan ProfilingGraphExecutorImpl::getPlanFor(Stack& stack) {
   // TODO: insert grad propagation
   bool needs_gradient = getProfilingMode()
       ? needsGradientInProfilingMode(copy->block())
-      : needsGradient(copy);
+      : true;
   if (needs_gradient) {
-    auto diff_nodes = CreateAutodiffSubgraphs(
+    // for Simple Executor skip creating autodiff graphs
+    // and let autograd handle backward for us
+    if (getProfilingMode()) {
+      auto diff_nodes = CreateAutodiffSubgraphs(
        copy,
        getAutodiffSubgraphInlining() ? autodiffSubgraphNodeThreshold : 1);
-    for (Node *dnode : diff_nodes) {
-      auto diff_graph = std::move(dnode->g(attr::Subgraph));
-      Gradient gradient = differentiate(diff_graph);
-      runOptimization(gradient.f);
-      // run non diff optimization on the forward graph
-      runNondiffOptimization(gradient.f);
-      packGradient(gradient, dnode);
+      for (Node *dnode : diff_nodes) {
+        auto diff_graph = std::move(dnode->g(attr::Subgraph));
+        Gradient gradient = differentiate(diff_graph);
+        runOptimization(gradient.f);
+        // run non diff optimization on the forward graph
+        runNondiffOptimization(gradient.f);
+        packGradient(gradient, dnode);
+      }
+      InlineAutodiffSubgraphs(copy, getAutodiffSubgraphInlining()
+        ? autodiffSubgraphInlineThreshold
+        : 1);
     }
-    InlineAutodiffSubgraphs(copy, getAutodiffSubgraphInlining()
-      ? autodiffSubgraphInlineThreshold
-      : 1);
   } else {
     runNondiffOptimization(copy);
   }

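Note: with this change the OSS build defaults both executor_mode and profiling_mode to true, while fbcode keeps the previous defaults of false. The Python bindings used throughout the test changes toggle these same two flags and return the previous value, so the state can be restored afterwards. A minimal sketch of flipping them around a region of code (not part of the commit):

import torch

# each setter returns the previous value, as enable_profiling_mode() in common_utils.py relies on
old_exec = torch._C._jit_set_profiling_executor(True)
old_mode = torch._C._jit_set_profiling_mode(True)
try:
    pass  # run TorchScript code under the profiling executor here
finally:
    # restore whatever configuration was active before
    torch._C._jit_set_profiling_executor(old_exec)
    torch._C._jit_set_profiling_mode(old_mode)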