
Commit ca1b8eb

suo authored and facebook-github-bot committed on Jul 13, 2020
move misc implementation out of jit/__init__.py (pytorch#41154)
Summary: Pull Request resolved: pytorch#41154
Test Plan: Imported from OSS
Reviewed By: ailzhang
Differential Revision: D22445213
Pulled By: suo
fbshipit-source-id: 200545715c5ef13beb1437f49e01efb21498ddb7
1 parent 6392713 commit ca1b8eb

21 files changed: +403 -357 lines
 

‎test/jit/test_models.py

+1 -1

@@ -487,7 +487,7 @@ def forward(self, input):
                 return self.seq.forward(input)
 
         # disabled due to a jitter issues that will be fixed by using load/store in the compiler
-        with torch.jit._disable_emit_hooks():
+        with torch._jit_internal._disable_emit_hooks():
            # TODO: toggle export_import once above issues are fixed
            self.checkTrace(Traced(), (torch.rand(3, 4),),
                            export_import=False)

‎test/quantization/test_quantize.py

+1 -1

@@ -1817,7 +1817,7 @@ def weight(self):
            def weight(self, w):
                self._packed_weight = torch.ops.quantized.linear_prepack(w)
 
-        with torch.jit._disable_emit_hooks():
+        with torch._jit_internal._disable_emit_hooks():
            x = torch.jit.script(Linear(10, 10))
            torch._C._jit_pass_erase_shape_information(x.graph)

‎test/test_jit.py

+9 -9

@@ -2159,7 +2159,7 @@ def foo(a):
        self.assertExpected(cu.foo.code)
 
    def test_import_method(self):
-        with torch.jit._disable_emit_hooks():
+        with torch._jit_internal._disable_emit_hooks():
            class Foo(torch.jit.ScriptModule):
                def __init__(self):
                    super(Foo, self).__init__()
@@ -3596,7 +3596,7 @@ def test_annoying_doubles(self):
        mod.ninf = float("-inf")
        mod.nan = float("nan")
 
-        with torch.jit._disable_emit_hooks():
+        with torch._jit_internal._disable_emit_hooks():
            class Foo(torch.jit.ScriptModule):
                def __init__(self):
                    super(Foo, self).__init__()
@@ -9122,7 +9122,7 @@ def pack_padded_pad_packed_script(x, seq_lens):
                x[seq_lens[b]:, b, :] = 0
 
        eager_seq, eager_lengths = pack_padded_pad_packed_script(x, seq_lens)
-        with torch.jit._disable_emit_hooks():
+        with torch._jit_internal._disable_emit_hooks():
            scripted_pack_padded_seq = torch.jit.script(pack_padded_pad_packed_script)
        script_seq, script_lengths = scripted_pack_padded_seq(x, seq_lens)
        self.assertEqual(eager_seq, script_seq)
@@ -9145,7 +9145,7 @@ def forward(self, input):
 
        lstm = ExperimentalLSTM(input_dim=2, hidden_dim=2)
 
-        with torch.jit._disable_emit_hooks():
+        with torch._jit_internal._disable_emit_hooks():
            self.checkModule(lstm, [torch.ones(2, 2)])
 
    def test_script_pad_sequence_pack_sequence(self):
@@ -9165,7 +9165,7 @@ def pack_sequence_func(tensor_list, enforce_sorted=True):
        tensor1 = torch.tensor([1, 2, 3])
        tensor2 = torch.tensor([4, 5])
        tensor3 = torch.tensor([6])
-        with torch.jit._disable_emit_hooks():
+        with torch._jit_internal._disable_emit_hooks():
            self.checkScript(pad_sequence_func,
                             ([ones3, ones4, ones5],))
            self.checkScript(pad_sequence_func,
@@ -9361,7 +9361,7 @@ def bar():
 
    def test_tuples(self):
        # TODO: jitter issue.
-        with torch.jit._disable_emit_hooks():  # TODO: Python print broadcasting list
+        with torch._jit_internal._disable_emit_hooks():  # TODO: Python print broadcasting list
            def foo(i):
                a = (i + 4, i * 2)
                c = a
@@ -12613,7 +12613,7 @@ def foo(a, b):
        self.checkScript(foo, (torch.rand(2, 3), torch.rand(3)))
 
    def test_bool_dispatch(self):
-        with torch.jit._disable_emit_hooks():  # TODO: Python print broadcasting list
+        with torch._jit_internal._disable_emit_hooks():  # TODO: Python print broadcasting list
            def kwarg_false(x):
                # type: (Tensor) -> Tensor
                return F.max_pool1d(x, 1, 1, return_indices=False)
@@ -14237,7 +14237,7 @@ def forward(self, key):
                # type: (str) -> Tensor
                return self.table[key] + self.x
 
-        with torch.jit._disable_emit_hooks():
+        with torch._jit_internal._disable_emit_hooks():
            # TODO: re-enable module hook when Python printing of attributes is
            # supported
            m = M({char : torch.ones(1) + ord(char) - ord("a") for char in "abcdefg"})
@@ -15393,7 +15393,7 @@ def run_test():
            self.assertAutodiffNode(script_fn.last_graph, should_autodiff_node, autodiff_nodes, fusible_nodes)
 
        if test_name in EXCLUDE_PYTHON_PRINT:
-            with torch.jit._disable_emit_hooks():
+            with torch._jit_internal._disable_emit_hooks():
                run_test()
        else:
            run_test()

‎test/test_jit_fuser.py

+1 -1

@@ -474,7 +474,7 @@ def test_exp_cuda(self):
 
    @unittest.skipIf(not RUN_CUDA, "fuser requires CUDA")
    @unittest.skipIf(GRAPH_EXECUTOR != ProfilingMode.LEGACY, "broken with profiling on")
-    @torch.jit._disable_emit_hooks_decorator
+    @torch._jit_internal._disable_emit_hooks_decorator
    @_inline_everything
    def test_fuse_decompose_normalization(self):
        class ResLike(torch.jit.ScriptModule):

‎test/test_jit_fuser_te.py

+1 -1

@@ -507,7 +507,7 @@ def test_exp_cuda(self):
 
    @unittest.skipIf(not RUN_CUDA, "fuser requires CUDA")
    @unittest.skipIf(GRAPH_EXECUTOR != ProfilingMode.LEGACY, "broken with profiling on")
-    @torch.jit._disable_emit_hooks_decorator
+    @torch._jit_internal._disable_emit_hooks_decorator
    @_inline_everything
    def test_fuse_decompose_normalization(self):
        class ResLike(torch.jit.ScriptModule):

‎torch/_jit_internal.py

+44

@@ -4,6 +4,8 @@
 circular dependency problems
 """
 
+import contextlib
+import collections
 import inspect
 import weakref
 import warnings
@@ -767,3 +769,45 @@ def __init__(self, source, filename, file_lineno, leading_whitespace_len, uses_t
 
 def fake_range():
     return SourceContext('', None, 0, 0).make_raw_range(0, 1)
+
+
+def _try_get_dispatched_fn(fn):
+    if not callable(fn):
+        return None
+    return boolean_dispatched.get(fn)
+
+
+def _get_named_tuple_properties(obj):
+    assert issubclass(obj, tuple) and hasattr(obj, '_fields')
+    fields = list(obj._fields)
+    annotations = []
+    has_annotations = hasattr(obj, '__annotations__')
+    for field in fields:
+        if has_annotations and field in obj.__annotations__:
+            the_type = torch.jit.annotations.ann_to_type(obj.__annotations__[field], fake_range())
+            annotations.append(the_type)
+        else:
+            annotations.append(torch._C.TensorType.get())
+    return type(obj).__name__, fields, annotations
+
+
+def _create_named_tuple(t, unqual_name, field_names):
+    TupleType = collections.namedtuple(unqual_name, field_names)
+    return TupleType(*t)
+
+
+@contextlib.contextmanager
+def _disable_emit_hooks():
+    hooks = torch._C._jit_get_emit_hooks()
+    torch._C._jit_set_emit_hooks(None, None)
+    yield
+    torch._C._jit_set_emit_hooks(hooks[0], hooks[1])
+
+
+def _disable_emit_hooks_decorator(_DecoratorContextManager):  # noqa: F811
+    def __enter__(self):
+        self.hooks = torch._C._jit_get_emit_hooks()
+        torch._C._jit_set_emit_hooks(None, None)
+
+    def __exit__(self, *args):
+        torch._C._jit_set_emit_hooks(self.hooks[0], self.hooks[1])
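
This context manager is what every call site in this commit switches over to. A minimal usage sketch, not part of the diff, assuming the internal torch._C emit-hook API behaves as shown above:

    import torch
    import torch._jit_internal

    def double(x):
        return x * 2

    # Hooks saved on entry are restored when the block exits, so scripting
    # inside the block does not fire module-emission callbacks.
    with torch._jit_internal._disable_emit_hooks():
        scripted = torch.jit.script(double)

    print(scripted(torch.ones(2)))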

‎torch/_ops.py

+1 -1

@@ -61,7 +61,7 @@ def __getattr__(self, op_name):
        op = torch._C._jit_get_operation(qualified_op_name)
        # let the script frontend know that op is identical to the builtin op
        # with qualified_op_name
-        torch.jit._register_builtin(op, qualified_op_name)
+        torch.jit._builtins._register_builtin(op, qualified_op_name)
        setattr(self, op_name, op)
        op.__module__ = self.__module__ + "." + self.name
        return op

‎torch/csrc/jit/python/pybind_utils.h

+1 -1

@@ -839,7 +839,7 @@ inline py::object toPyObject(IValue ivalue) {
    auto fieldNames = fmap(
        tuple->type()->schema()->arguments(),
        [](const Argument& arg) { return arg.name(); });
-    return py::module::import("torch.jit")
+    return py::module::import("torch._jit_internal")
        .attr("_create_named_tuple")(t, unqualName, fieldNames);
  } else {
    return std::move(t);

‎torch/csrc/jit/python/python_sugared_value.cpp

+5 -5

@@ -686,8 +686,8 @@ TypePtr registerNamedTuple(const py::object& obj, const SourceRange& loc) {
    }
  }
 
-  py::object props =
-      py::module::import("torch.jit").attr("_get_named_tuple_properties")(obj);
+  py::object props = py::module::import("torch._jit_internal")
+                         .attr("_get_named_tuple_properties")(obj);
  std::string unqualName;
  std::vector<std::string> fields;
  std::vector<TypePtr> annotations;
@@ -788,7 +788,7 @@ std::shared_ptr<SugaredValue> toSugaredValue(
  }
 
  py::object builtin_name =
-      py::module::import("torch.jit").attr("_find_builtin")(obj);
+      py::module::import("torch.jit._builtins").attr("_find_builtin")(obj);
  if (!builtin_name.is_none()) {
    return std::make_shared<BuiltinFunction>(
        Symbol::fromQualString(py::str(builtin_name)), c10::nullopt);
@@ -801,8 +801,8 @@ std::shared_ptr<SugaredValue> toSugaredValue(
    }
  }
 
-  py::object dispatched_fn =
-      py::module::import("torch.jit").attr("_try_get_dispatched_fn")(obj);
+  py::object dispatched_fn = py::module::import("torch._jit_internal")
+                                 .attr("_try_get_dispatched_fn")(obj);
  if (!dispatched_fn.is_none()) {
    return std::make_shared<BooleanDispatchValue>(std::move(dispatched_fn));
  }

‎torch/distributed/rpc/api.py

+3 -3

@@ -160,7 +160,7 @@ def _wait_all_workers():
 
    is_leader_worker = leader_worker_name == self_worker_name
    # Set a long enough timeout for all shutdown messages to be processed.
-    timeout = 5  # seconds
+    timeout = 5  # second
 
    # Phase 1: Followers send intents.
    # All followers report intents to the leader.
@@ -522,7 +522,7 @@ def remote(to, func, args=None, kwargs=None, timeout=UNSET_RPC_TIMEOUT):
        >>> rpc.init_rpc("worker1", rank=1, world_size=2)
        >>> rpc.shutdown()
    """
-    qualified_name = torch.jit._find_builtin(func)
+    qualified_name = torch.jit._builtins._find_builtin(func)
    dst_worker_info = _to_worker_info(to)
    should_profile = torch.autograd._profiler_enabled()
 
@@ -594,7 +594,7 @@ def _invoke_rpc(to, func, rpc_type, args=None, kwargs=None, rpc_timeout=UNSET_RP
    if not callable(func):
        raise TypeError("function should be callable.")
 
-    qualified_name = torch.jit._find_builtin(func)
+    qualified_name = torch.jit._builtins._find_builtin(func)
    dst_worker_info = _to_worker_info(to)
 
    # TODO: profiling logic does not really belong in invoke_rpc, it should be
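
Both RPC paths above use _find_builtin to decide whether func is a TorchScript builtin, in which case the call is dispatched by its qualified operator name. A small lookup sketch from the helper's canonical home (illustrative; the exact string returned is an assumption):

    import torch
    from torch.jit._builtins import _find_builtin

    # Builtins resolve to a qualified ATen operator name, e.g. 'aten::mul';
    # ordinary Python functions resolve to None.
    print(_find_builtin(torch.mul))

    def not_a_builtin(x):
        return x

    print(_find_builtin(not_a_builtin))  # None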

‎torch/jit/__init__.py

+37 -326

@@ -1,306 +1,59 @@
 import torch._C
-import torch._jit_internal as _jit_internal
 
-from torch.jit._builtins import _find_builtin, _get_builtin_table, _register_builtin  # noqa
-from torch._jit_internal import Future
-from torch.nn import Module
 from torch.utils import set_module
-from torch.autograd.grad_mode import _DecoratorContextManager
-from typing import Optional, List
-
-import collections
-import contextlib
-import functools
-import os
-import pathlib
 
 # These are imported so users can access them from the `torch.jit` module
-from torch._jit_internal import Final, _overload, _overload_method
-from torch._jit_internal import ignore, export, unused
-from torch.jit._script import script, Attribute, ScriptModule, is_scripting, script_method, \
-    RecursiveScriptModule, ScriptWarning, interface
-from torch.jit._trace import trace, trace_module, TracedModule, TracerWarning, TracingCheckError, \
-    is_tracing, ONNXTracedModule, _unique_state_dict, _flatten, TopLevelTracedModule
+from torch._jit_internal import (
+    Final,
+    Future,
+    _overload,
+    _overload_method,
+    ignore,
+    export,
+    unused,
+)
+from torch.jit._script import (
+    script,
+    Attribute,
+    ScriptModule,
+    is_scripting,
+    script_method,
+    RecursiveScriptModule,
+    ScriptWarning,
+    interface,
+    CompilationUnit,
+    ScriptFunction,
+    _unwrap_optional,
+)
+from torch.jit._trace import (
+    trace,
+    trace_module,
+    TracedModule,
+    TracerWarning,
+    TracingCheckError,
+    is_tracing,
+    ONNXTracedModule,
+    TopLevelTracedModule,
+    _unique_state_dict,
+    _flatten,
+    _script_if_tracing,
+    _get_trace_graph,
+)
 from torch.jit._async import fork, wait
 from torch.jit._serialization import save, load
-
-set_module(Future, "torch.jit")
+from torch.jit._fuser import optimized_execution, fuser, last_executed_optimized_graph
 
 # For backwards compatibility
 _fork = fork
 _wait = wait
 
-@contextlib.contextmanager
-def optimized_execution(should_optimize):
-    """
-    A context manager that controls whether the JIT's executor will run
-    optimizations before executing a function.
-    """
-    stored_flag = torch._C._get_graph_executor_optimize()
-    torch._C._set_graph_executor_optimize(should_optimize)
-    try:
-        yield
-    finally:
-        torch._C._set_graph_executor_optimize(stored_flag)
-
-@contextlib.contextmanager
-def fuser(name):
-    """
-    A context manager that facilitates switching between
-    backend fusers.
-
-    Valid names:
-    * ``fuser0`` - enables only legacy fuser
-    * ``fuser1`` - enables only NNC
-    * ``fuser2`` - enables only nvFuser
-    """
-    old_cpu_fuse = torch._C._jit_can_fuse_on_cpu()
-    old_gpu_fuse = torch._C._jit_can_fuse_on_gpu()
-    old_texpr_fuser_state = torch._C._jit_texpr_fuser_enabled()
-    old_nvfuser_state = torch._C._jit_nvfuser_enabled()
-    if name == 'fuser0':  # legacy fuser
-        torch._C._jit_override_can_fuse_on_cpu(True)
-        torch._C._jit_override_can_fuse_on_gpu(True)
-        torch._C._jit_set_texpr_fuser_enabled(False)
-        torch._C._jit_set_nvfuser_enabled(False)
-    elif name == 'fuser1':  # NNC
-        old_profiling_executor = torch._C._jit_set_profiling_executor(True)
-        old_profiling_mode = torch._C._jit_set_profiling_mode(True)
-        torch._C._jit_override_can_fuse_on_cpu(False)
-        torch._C._jit_override_can_fuse_on_gpu(False)
-        torch._C._jit_set_texpr_fuser_enabled(True)
-        torch._C._jit_set_nvfuser_enabled(False)
-    elif name == 'fuser2':  # nvFuser
-        torch._C._jit_override_can_fuse_on_cpu(False)
-        torch._C._jit_override_can_fuse_on_gpu(False)
-        torch._C._jit_set_texpr_fuser_enabled(False)
-        torch._C._jit_set_nvfuser_enabled(True)
-    else:
-        raise Exception("unrecognized fuser option")
-    try:
-        yield
-    finally:
-        if name == 'fuser1':  # NNC
-            torch._C._jit_set_profiling_executor(old_profiling_executor)
-            torch._C._jit_set_profiling_mode(old_profiling_mode)
-        # recover the previous values
-        torch._C._jit_override_can_fuse_on_cpu(old_cpu_fuse)
-        torch._C._jit_override_can_fuse_on_gpu(old_gpu_fuse)
-        torch._C._jit_set_texpr_fuser_enabled(old_texpr_fuser_state)
-        torch._C._jit_set_nvfuser_enabled(old_nvfuser_state)
 
 def export_opnames(m):
     r"""
     Returns a list of operator names of a script module and its submodules
     """
     return torch._C._export_opnames(m._c)
 
-def _get_trace_graph(f, args=(), kwargs=None, strict=True, _force_outplace=False,
-                     return_inputs=False, _return_inputs_states=False):
-    """
-    .. warning::
-        This function is internal-only and should only be used by the ONNX
-        exporter. If you are trying to get a graph through tracing, please go
-        through the public API instead::
-
-            trace = torch.jit.trace(nn.LSTMCell(), (input, hidden))
-            trace_graph = trace.graph
-
-    Trace a function or model, returning a tuple consisting of the both the
-    *trace* of an execution, as well as the original return value. If return_inputs,
-    also returns the trace inputs as part of the tuple
-
-    Tracing is guaranteed not to change the semantics of the function/module
-    that is traced.
-
-    Arguments:
-        f (torch.nn.Module or function): the function or module
-            to be traced.
-        args (tuple or Tensor): the positional arguments to pass to the
-            function/module to be traced.  A non-tuple is assumed to
-            be a single positional argument to be passed to the model.
-        kwargs (dict): the keyword arguments to pass to the function/module
-            to be traced.
-
-    Example (trace a cell):
-
-    .. testcode::
-
-        trace = torch.jit.trace(nn.LSTMCell(), (input, hidden))
-    """
-    if kwargs is None:
-        kwargs = {}
-    if not isinstance(args, tuple):
-        args = (args,)
-    outs = ONNXTracedModule(f, strict, _force_outplace, return_inputs, _return_inputs_states)(*args, **kwargs)
-    return outs
-
-
-def freeze(mod, preserved_attrs : Optional[List[str]] = None):
-    r"""
-    Freezing a :class:`ScriptModule` will clone it and attempt to inline the cloned
-    module's submodules, parameters, and attributes as constants in the TorchScript IR Graph.
-    By default, `forward` will be preserved, as well as attributes & methods specified in
-    `preserved_attrs`. Additionally, any attribute that is modified within a preserved
-    method will be preserved.
-
-    Freezing currently only accepts ScriptModules that are in eval mode.
-
-    Arguments:
-        mod (:class:`ScriptModule`): a module to be frozen
-
-        preserved_attrs (Optional[List[str]]): a list of attributes to preserve in addition to the forward method.
-        Attributes modified in preserved methods will also be preserved.
-
-    Returns:
-        Frozen :class:`ScriptModule`.
-
-    Example (Freezing a simple module with a Parameter):
-
-    .. testcode::
-        import torch
-        class MyModule(torch.nn.Module):
-            def __init__(self, N, M):
-                super(MyModule, self).__init__()
-                self.weight = torch.nn.Parameter(torch.rand(N, M))
-                self.linear = torch.nn.Linear(N, M)
-
-            def forward(self, input):
-                output = self.weight.mm(input)
-                output = self.linear(output)
-                return output
-
-        scripted_module = torch.jit.script(MyModule(2, 3).eval())
-        frozen_module = torch.jit.freeze(scripted_module)
-        # parameters have been removed and inlined into the Graph as constants
-        assert len(list(frozen_module.named_parameters())) == 0
-        # See the compiled graph as Python code
-        print(frozen_module.code)
-
-    Example (Freezing a module with preserved attributes)
-
-    .. testcode::
-        import torch
-        class MyModule2(torch.nn.Module):
-            def __init__(self):
-                super(MyModule2, self).__init__()
-                self.modified_tensor = torch.tensor(10.)
-                self.version = 1
-
-            def forward(self, input):
-                self.modified_tensor += 1
-                return input + self.modified_tensor
-
-        scripted_module = torch.jit.script(MyModule2().eval())
-        frozen_module = torch.jit.freeze(scripted_module, preserved_attrs=["version"])
-        # we've manually preserved `version`, so it still exists on the frozen module and can be modified
-        assert frozen_module.version == 1
-        frozen_module.version = 2
-        # `modified_tensor` is detected as being mutated in the forward, so freezing preserves
-        # it to retain model semantics
-        assert frozen_module(torch.tensor(1)) == torch.tensor(12)
-        # now that we've run it once, the next result will be incremented by one
-        assert frozen_module(torch.tensor(1)) == torch.tensor(13)
-
-    Note:
-        If you're not sure why an attribute is not being inlined as a constant, you can run
-        `dump_alias_db` on frozen_module.forward.graph to see if freezing has detected the
-        attribute is being modified.
-    """
-    if not isinstance(mod, ScriptModule):
-        raise RuntimeError("Freezing expects a ScriptModule as input. "
-                           "Please use torch.jit.script or torch.jit.trace to script your 'nn.Module'.")
-
-    if mod.training:
-        raise RuntimeError("Freezing is currently only implemented for modules in eval mode. "
-                           "Please call .eval() on your module before freezing.")
-
-    preserved_attrs = preserved_attrs if preserved_attrs is not None else []
-
-    out = RecursiveScriptModule(torch._C._freeze_module(mod._c, preserved_attrs))
-    RecursiveScriptModule._finalize_scriptmodule(out)
-
-    return out
-
-
-class CompilationUnit(object):
-    def __init__(self, lang=None, _frames_up=0):
-        self._c = torch._C.CompilationUnit()
-        if lang is not None:
-            self.define(lang, _frames_up=_frames_up + 1)
-
-    def define(self, lang, rcb=None, _frames_up=0):
-        if not rcb:
-            rcb = _jit_internal.createResolutionCallbackFromFrame(_frames_up + 1)
-        self._c.define(lang, rcb)
-
-    def __getattr__(self, attr):
-        r = self._c.find_function(attr)
-        if r is None:
-            raise AttributeError("'CompilationUnit' has no attribute '{}'".format(attr))
-        return r
-
-
-def _try_get_dispatched_fn(fn):
-    if not callable(fn):
-        return None
-    return _jit_internal.boolean_dispatched.get(fn)
-
-
-def _try_get_overloaded_fn(mod, field):
-    return mod._overloads.get(field, None) if isinstance(mod, ScriptModule) else None
-
-
-@contextlib.contextmanager
-def _disable_emit_hooks():
-    hooks = torch._C._jit_get_emit_hooks()
-    torch._C._jit_set_emit_hooks(None, None)
-    yield
-    torch._C._jit_set_emit_hooks(hooks[0], hooks[1])
-
-
-def _disable_emit_hooks_decorator(_DecoratorContextManager):  # noqa: F811
-    def __enter__(self):
-        self.hooks = torch._C._jit_get_emit_hooks()
-        torch._C._jit_set_emit_hooks(None, None)
-
-    def __exit__(self, *args):
-        torch._C._jit_set_emit_hooks(self.hooks[0], self.hooks[1])
-
-
-def _script_if_tracing(fn):
-    """
-    Compiles ``fn`` when it is first called during tracing. ``torch.jit.script``
-    has a non-negligible start up time when it is first called due to
-    lazy-initializations of many compiler builtins. Therefore you should not use
-    it in library code. However, you may want to have parts of your library work
-    in tracing even if they use control flow. In these cases, you should use
-    ``@torch.jit._script_if_tracing`` to substitute for
-    ``torch.jit.script``.
-    """
-    @functools.wraps(fn)
-    def wrapper(*args, **kwargs):
-        if not is_tracing():
-            # Not tracing, don't do anything
-            return fn(*args, **kwargs)
-
-        compiled_fn = script(wrapper.__original_fn)
-        return compiled_fn(*args, **kwargs)
-
-    wrapper.__original_fn = fn
-    wrapper.__script_if_tracing_wrapper = True
-
-    return wrapper
-
-def _unwrap_optional(x):
-    assert x is not None, "Unwrapping null optional"
-    return x
-
-_register_builtin(_unwrap_optional, 'aten::_unwrap_optional')
-_register_builtin(_wait, 'aten::wait')
-_register_builtin(wait, 'aten::wait')
-_register_builtin(is_scripting, 'aten::is_scripting')
-
 
 # torch.jit.Error
 Error = torch._C.JITException
@@ -309,53 +62,11 @@ def _unwrap_optional(x):
 Error.__name__ = "Error"
 Error.__qualname__ = "Error"
 
-def _get_named_tuple_properties(obj):
-    assert issubclass(obj, tuple) and hasattr(obj, '_fields')
-    fields = list(obj._fields)
-    annotations = []
-    has_annotations = hasattr(obj, '__annotations__')
-    for field in fields:
-        if has_annotations and field in obj.__annotations__:
-            the_type = torch.jit.annotations.ann_to_type(obj.__annotations__[field], _jit_internal.fake_range())
-            annotations.append(the_type)
-        else:
-            annotations.append(torch._C.TensorType.get())
-    return type(obj).__name__, fields, annotations
-
-def _create_named_tuple(t, unqual_name, field_names):
-    TupleType = collections.namedtuple(unqual_name, field_names)
-    return TupleType(*t)
-
-class _disable_tracing(object):
-    def __enter__(self):
-        self.state = torch._C._get_tracing_state()
-        torch._C._set_tracing_state(None)
-
-    def __exit__(self, *args):
-        torch._C._set_tracing_state(self.state)
-        self.state = None
-
-
 # for use in python if using annotate
 def annotate(the_type, the_value):
     # noop in python
     return the_value
 
-last_executed_optimized_graph = torch._C._last_executed_optimized_graph
-
-
-def _graph_for(self, *args, **kwargs):
-    self(*args, **kwargs)
-    return last_executed_optimized_graph()
-
-torch._C.ScriptMethod.graph_for = _graph_for
-torch._C.ScriptFunction.graph_for = _graph_for
-ScriptFunction = torch._C.ScriptFunction
-ScriptFunction.__doc__ = """
-Functionally equivalent to a :class:`ScriptModule`, but represents a single
-function and does not have any attributes or Parameters.
-"""
-set_module(ScriptFunction, "torch.jit")
 
 if not torch._C._jit_init():
     raise RuntimeError("JIT initialization failed")
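
The net effect of the rewritten __init__.py is that the public surface stays the same while the implementations move elsewhere. A quick sanity sketch of that claim (illustrative, not part of the diff):

    import torch

    def add_one(x):
        return x + 1

    # Still reachable from torch.jit, though now defined in torch.jit._script:
    scripted = torch.jit.script(add_one)

    # Still reachable from torch.jit, though now defined in torch.jit._fuser:
    with torch.jit.optimized_execution(False):
        scripted(torch.zeros(3))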

‎torch/jit/_async.py

+9

@@ -9,6 +9,12 @@
 
 import torch
 
+from torch.utils import set_module
+from torch.jit._builtins import _register_builtin
+from torch._jit_internal import Future
+
+set_module(Future, "torch.jit")
+
 
 def fork(func, *args, **kwargs):
     """
@@ -84,3 +90,6 @@ def wait(future):
        `T`: the return value of the the completed task
    """
    return torch._C.wait(future)
+
+
+_register_builtin(wait, "aten::wait")
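
Registering wait as the aten::wait builtin at import time is what lets scripted code call it. A minimal fork/wait sketch (illustrative):

    import torch

    def add(a, b):
        return a + b

    # fork schedules the call asynchronously and returns a Future;
    # wait blocks until the Future completes and returns its value.
    fut = torch.jit.fork(add, torch.ones(2), torch.ones(2))
    result = torch.jit.wait(fut)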

‎torch/jit/_freeze.py

+101

@@ -0,0 +1,101 @@
+"""Freezing
+
+This is not intended to be imported directly; please use the exposed
+functionalities in `torch.jit`.
+"""
+
+from typing import Optional, List
+
+import torch
+from torch.jit._script import RecursiveScriptModule, ScriptModule
+
+
+def freeze(mod, preserved_attrs: Optional[List[str]] = None):
+    r"""
+    Freezing a :class:`ScriptModule` will clone it and attempt to inline the cloned
+    module's submodules, parameters, and attributes as constants in the TorchScript IR Graph.
+    By default, `forward` will be preserved, as well as attributes & methods specified in
+    `preserved_attrs`. Additionally, any attribute that is modified within a preserved
+    method will be preserved.
+
+    Freezing currently only accepts ScriptModules that are in eval mode.
+
+    Arguments:
+        mod (:class:`ScriptModule`): a module to be frozen
+
+        preserved_attrs (Optional[List[str]]): a list of attributes to preserve in addition to the forward method.
+        Attributes modified in preserved methods will also be preserved.
+
+    Returns:
+        Frozen :class:`ScriptModule`.
+
+    Example (Freezing a simple module with a Parameter):
+
+    .. testcode::
+        import torch
+        class MyModule(torch.nn.Module):
+            def __init__(self, N, M):
+                super(MyModule, self).__init__()
+                self.weight = torch.nn.Parameter(torch.rand(N, M))
+                self.linear = torch.nn.Linear(N, M)
+
+            def forward(self, input):
+                output = self.weight.mm(input)
+                output = self.linear(output)
+                return output
+
+        scripted_module = torch.jit.script(MyModule(2, 3).eval())
+        frozen_module = torch.jit.freeze(scripted_module)
+        # parameters have been removed and inlined into the Graph as constants
+        assert len(list(frozen_module.named_parameters())) == 0
+        # See the compiled graph as Python code
+        print(frozen_module.code)
+
+    Example (Freezing a module with preserved attributes)
+
+    .. testcode::
+        import torch
+        class MyModule2(torch.nn.Module):
+            def __init__(self):
+                super(MyModule2, self).__init__()
+                self.modified_tensor = torch.tensor(10.)
+                self.version = 1
+
+            def forward(self, input):
+                self.modified_tensor += 1
+                return input + self.modified_tensor
+
+        scripted_module = torch.jit.script(MyModule2().eval())
+        frozen_module = torch.jit.freeze(scripted_module, preserved_attrs=["version"])
+        # we've manually preserved `version`, so it still exists on the frozen module and can be modified
+        assert frozen_module.version == 1
+        frozen_module.version = 2
+        # `modified_tensor` is detected as being mutated in the forward, so freezing preserves
+        # it to retain model semantics
+        assert frozen_module(torch.tensor(1)) == torch.tensor(12)
+        # now that we've run it once, the next result will be incremented by one
+        assert frozen_module(torch.tensor(1)) == torch.tensor(13)
+
+    Note:
+        If you're not sure why an attribute is not being inlined as a constant, you can run
+        `dump_alias_db` on frozen_module.forward.graph to see if freezing has detected the
+        attribute is being modified.
+    """
+    if not isinstance(mod, ScriptModule):
+        raise RuntimeError(
+            "Freezing expects a ScriptModule as input. "
+            "Please use torch.jit.script or torch.jit.trace to script your 'nn.Module'."
+        )
+
+    if mod.training:
+        raise RuntimeError(
+            "Freezing is currently only implemented for modules in eval mode. "
+            "Please call .eval() on your module before freezing."
+        )
+
+    preserved_attrs = preserved_attrs if preserved_attrs is not None else []
+
+    out = RecursiveScriptModule(torch._C._freeze_module(mod._c, preserved_attrs))
+    RecursiveScriptModule._finalize_scriptmodule(out)
+
+    return out
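
The function body is unchanged from the version deleted out of __init__.py; only its home moved. A minimal call sketch against the new location (illustrative):

    import torch
    from torch.jit._freeze import freeze  # new home introduced by this commit

    class M(torch.nn.Module):
        def forward(self, x):
            return x * 2

    # freeze requires a ScriptModule in eval mode.
    frozen = freeze(torch.jit.script(M()).eval())
    print(frozen.graph)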

‎torch/jit/_fuser.py

+70

@@ -0,0 +1,70 @@
+import contextlib
+
+import torch
+
+@contextlib.contextmanager
+def optimized_execution(should_optimize):
+    """
+    A context manager that controls whether the JIT's executor will run
+    optimizations before executing a function.
+    """
+    stored_flag = torch._C._get_graph_executor_optimize()
+    torch._C._set_graph_executor_optimize(should_optimize)
+    try:
+        yield
+    finally:
+        torch._C._set_graph_executor_optimize(stored_flag)
+
+@contextlib.contextmanager
+def fuser(name):
+    """
+    A context manager that facilitates switching between
+    backend fusers.
+
+    Valid names:
+    * ``fuser0`` - enables only legacy fuser
+    * ``fuser1`` - enables only NNC
+    * ``fuser2`` - enables only nvFuser
+    """
+    old_cpu_fuse = torch._C._jit_can_fuse_on_cpu()
+    old_gpu_fuse = torch._C._jit_can_fuse_on_gpu()
+    old_texpr_fuser_state = torch._C._jit_texpr_fuser_enabled()
+    old_nvfuser_state = torch._C._jit_nvfuser_enabled()
+    if name == 'fuser0':  # legacy fuser
+        torch._C._jit_override_can_fuse_on_cpu(True)
+        torch._C._jit_override_can_fuse_on_gpu(True)
+        torch._C._jit_set_texpr_fuser_enabled(False)
+        torch._C._jit_set_nvfuser_enabled(False)
+    elif name == 'fuser1':  # NNC
+        old_profiling_executor = torch._C._jit_set_profiling_executor(True)
+        old_profiling_mode = torch._C._jit_set_profiling_mode(True)
+        torch._C._jit_override_can_fuse_on_cpu(False)
+        torch._C._jit_override_can_fuse_on_gpu(False)
+        torch._C._jit_set_texpr_fuser_enabled(True)
+        torch._C._jit_set_nvfuser_enabled(False)
+    elif name == 'fuser2':  # nvFuser
+        torch._C._jit_override_can_fuse_on_cpu(False)
+        torch._C._jit_override_can_fuse_on_gpu(False)
+        torch._C._jit_set_texpr_fuser_enabled(False)
+        torch._C._jit_set_nvfuser_enabled(True)
+    else:
+        raise Exception("unrecognized fuser option")
+    try:
+        yield
+    finally:
+        if name == 'fuser1':  # NNC
+            torch._C._jit_set_profiling_executor(old_profiling_executor)
+            torch._C._jit_set_profiling_mode(old_profiling_mode)
+        # recover the previous values
+        torch._C._jit_override_can_fuse_on_cpu(old_cpu_fuse)
+        torch._C._jit_override_can_fuse_on_gpu(old_gpu_fuse)
+        torch._C._jit_set_texpr_fuser_enabled(old_texpr_fuser_state)
+        torch._C._jit_set_nvfuser_enabled(old_nvfuser_state)
+
+
+last_executed_optimized_graph = torch._C._last_executed_optimized_graph
+
+
+def _graph_for(self, *args, **kwargs):
+    self(*args, **kwargs)
+    return last_executed_optimized_graph()
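
Both context managers restore the previous executor and fuser flags on exit, so they can be wrapped around individual runs. A usage sketch (illustrative; per the docstring above, 'fuser1' selects the tensor-expression fuser, NNC):

    import torch

    def f(x):
        return torch.sin(x) + torch.cos(x)

    scripted = torch.jit.script(f)

    with torch.jit.fuser('fuser1'):
        # a few warm-up runs so the profiling executor can decide to fuse
        for _ in range(3):
            scripted(torch.randn(8))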

‎torch/jit/_recursive.py

+1 -1

@@ -609,7 +609,7 @@ def compile_unbound_method(concrete_type, fn):
    if _jit_internal.is_ignored_fn(fn):
        return None
    stub = make_stub(fn, fn.__name__)
-    with torch.jit._disable_emit_hooks():
+    with torch._jit_internal._disable_emit_hooks():
        # We don't want to call the hooks here since the graph that is calling
        # this function is not yet complete
        create_methods_from_stubs(concrete_type, (stub,))

‎torch/jit/_script.py

+42

@@ -15,19 +15,32 @@
 
 import torch
 import torch._jit_internal as _jit_internal
+from torch.utils import set_module
 from torch.jit._recursive import ScriptMethodStub, wrap_cpp_module
 from torch.nn import Module
 from torch.jit._state import _enabled
+from torch.jit._builtins import _register_builtin
 from torch._six import with_metaclass, get_function_from_type
 from torch.jit.frontend import get_jit_def, get_default_args, get_jit_class_def
 from torch._jit_internal import _qualified_name
+from torch.jit._fuser import _graph_for
 from torch.jit._state import (
     _try_get_jit_cached_function,
     _try_get_jit_cached_overloads,
     _set_jit_function_cache,
     _set_jit_overload_cache,
 )
 
+torch._C.ScriptMethod.graph_for = _graph_for
+torch._C.ScriptFunction.graph_for = _graph_for
+ScriptFunction = torch._C.ScriptFunction
+ScriptFunction.__doc__ = """
+Functionally equivalent to a :class:`ScriptModule`, but represents a single
+function and does not have any attributes or Parameters.
+"""
+set_module(ScriptFunction, "torch.jit")
+
+
 if _enabled:
     Attribute = collections.namedtuple("Attribute", ["value", "type"])
 else:
@@ -1053,3 +1066,32 @@ def _recursive_compile_class(obj, loc):
     error_stack = torch._C.CallStack(_qual_name, loc)
     rcb = _jit_internal.createResolutionCallbackForClassMethods(obj)
     _compile_and_register_class(obj, rcb, _qual_name)
+
+
+_register_builtin(is_scripting, "aten::is_scripting")
+
+
+class CompilationUnit(object):
+    def __init__(self, lang=None, _frames_up=0):
+        self._c = torch._C.CompilationUnit()
+        if lang is not None:
+            self.define(lang, _frames_up=_frames_up + 1)
+
+    def define(self, lang, rcb=None, _frames_up=0):
+        if not rcb:
+            rcb = _jit_internal.createResolutionCallbackFromFrame(_frames_up + 1)
+        self._c.define(lang, rcb)
+
+    def __getattr__(self, attr):
+        r = self._c.find_function(attr)
+        if r is None:
+            raise AttributeError("'CompilationUnit' has no attribute '{}'".format(attr))
+        return r
+
+
+def _unwrap_optional(x):
+    assert x is not None, "Unwrapping null optional"
+    return x
+
+
+_register_builtin(_unwrap_optional, "aten::_unwrap_optional")
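
CompilationUnit now lives here alongside the rest of the scripting front end, and is still re-exported as torch.jit.CompilationUnit. A usage sketch (illustrative):

    import torch

    # Compile free-standing TorchScript source; function lookups go through
    # __getattr__, which delegates to the underlying torch._C.CompilationUnit.
    cu = torch.jit.CompilationUnit('''
    def add(a, b):
        return a + b
    ''')
    print(cu.add(torch.ones(2), torch.ones(2)))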

‎torch/jit/_trace.py

+69 -1

@@ -11,12 +11,13 @@
 
 import os
 import contextlib
+import functools
 import warnings
 import inspect
 import re
 
 from torch.jit._state import _python_cu, _enabled
-from torch.jit._script import ScriptModule, _CachedForward
+from torch.jit._script import ScriptModule, _CachedForward, script
 from torch._jit_internal import _qualified_name
 from torch.autograd import function
 from torch import _jit_internal
@@ -1077,3 +1078,70 @@ def _reconstruct(self, cpp_module):
            cpp_module: The C++ module that this TopLevelTracedModule will be rebuilt around.
        """
        self.__dict__["_actual_script_module"]._reconstruct(cpp_module)
+
+
+def _script_if_tracing(fn):
+    """
+    Compiles ``fn`` when it is first called during tracing. ``torch.jit.script``
+    has a non-negligible start up time when it is first called due to
+    lazy-initializations of many compiler builtins. Therefore you should not use
+    it in library code. However, you may want to have parts of your library work
+    in tracing even if they use control flow. In these cases, you should use
+    ``@torch.jit._script_if_tracing`` to substitute for
+    ``torch.jit.script``.
+    """
+
+    @functools.wraps(fn)
+    def wrapper(*args, **kwargs):
+        if not is_tracing():
+            # Not tracing, don't do anything
+            return fn(*args, **kwargs)
+
+        compiled_fn = script(wrapper.__original_fn)
+        return compiled_fn(*args, **kwargs)
+
+    wrapper.__original_fn = fn
+    wrapper.__script_if_tracing_wrapper = True
+
+    return wrapper
+
+
+def _get_trace_graph(f, args=(), kwargs=None, strict=True, _force_outplace=False,
+                     return_inputs=False, _return_inputs_states=False):
+    """
+    .. warning::
+        This function is internal-only and should only be used by the ONNX
+        exporter. If you are trying to get a graph through tracing, please go
+        through the public API instead::
+
+            trace = torch.jit.trace(nn.LSTMCell(), (input, hidden))
+            trace_graph = trace.graph
+
+    Trace a function or model, returning a tuple consisting of the both the
+    *trace* of an execution, as well as the original return value. If return_inputs,
+    also returns the trace inputs as part of the tuple
+
+    Tracing is guaranteed not to change the semantics of the function/module
+    that is traced.
+
+    Arguments:
+        f (torch.nn.Module or function): the function or module
+            to be traced.
+        args (tuple or Tensor): the positional arguments to pass to the
+            function/module to be traced.  A non-tuple is assumed to
+            be a single positional argument to be passed to the model.
+        kwargs (dict): the keyword arguments to pass to the function/module
+            to be traced.
+
+    Example (trace a cell):
+
+    .. testcode::
+
+        trace = torch.jit.trace(nn.LSTMCell(), (input, hidden))
+    """
+    if kwargs is None:
+        kwargs = {}
+    if not isinstance(args, tuple):
+        args = (args,)
+    outs = ONNXTracedModule(f, strict, _force_outplace, return_inputs, _return_inputs_states)(*args, **kwargs)
+    return outs
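
_script_if_tracing matters when a traced model calls library code with data-dependent control flow: plain tracing would bake in one path, while this decorator scripts the callee on first use during tracing. A sketch (illustrative; add_n and its loop bound n are hypothetical example names):

    import torch
    from torch.jit._trace import _script_if_tracing

    @_script_if_tracing
    def add_n(x, n: int):
        # scripted (not traced) when invoked under tracing, so the loop survives
        for _ in range(n):
            x = x + 1
        return x

    def model(x):
        return add_n(x, 3)

    traced = torch.jit.trace(model, torch.zeros(2))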

‎torch/jit/supported_ops.py

+4 -3

@@ -1,4 +1,5 @@
 import torch.jit
+from torch.jit._builtins import _find_builtin
 import inspect
 import textwrap
 # this file is for generating documentation using sphinx autodoc
@@ -92,7 +93,7 @@ def _get_nn_functional_ops():
    for mod in torch.jit._builtins._modules_containing_builtins:
        name = mod.__name__
        for elem in dir(mod):
-            builtin = torch.jit._find_builtin(getattr(mod, elem))
+            builtin = _find_builtin(getattr(mod, elem))
            if builtin is not None:
                schemas = torch._C._jit_get_schemas_for_operator(builtin)
                for schema in schemas:
@@ -133,7 +134,7 @@ def _get_torchscript_builtins():
    # Iterate over the specially added builtins
    for fn, _builtin_name in builtins:
        mod = inspect.getmodule(fn)
-        builtin = torch.jit._find_builtin(fn)
+        builtin = _find_builtin(fn)
        if builtin is not None:
            schemas = torch._C._jit_get_schemas_for_operator(builtin)
            for schema in schemas:
@@ -150,7 +151,7 @@ def _get_math_builtins():
    # Iterate over the specially added builtins
    for fn, _builtin_name in builtins:
        mod = inspect.getmodule(fn)
-        builtin = torch.jit._find_builtin(fn)
+        builtin = _find_builtin(fn)
        if builtin is not None:
            schemas = torch._C._jit_get_schemas_for_operator(builtin)
            for schema in schemas:

‎torch/testing/_internal/distributed/rpc/rpc_test.py

+1 -1

@@ -1222,7 +1222,7 @@ def validate_profiling_workload(self, dst, prof):
        events = prof.function_events
 
        rpc_mul_event = get_function_event(
-            events, torch.jit._find_builtin(torch.mul)
+            events, torch.jit._builtins._find_builtin(torch.mul)
        )
 
        remote_events = {

‎torch/testing/_internal/jit_metaprogramming_utils.py

+1 -1

@@ -358,7 +358,7 @@ def get_nn_functional_compiled_fn_and_inputs(name, self_size, args, variant_name
 
    f_args_variable = (self_variable,) + args_variable
    f_args_tensor = (self_tensor,) + args_tensor
-    with torch.jit._disable_emit_hooks():
+    with torch._jit_internal._disable_emit_hooks():
        script_fn, inputs = gen_script_fn_and_args(name, "nn_functional", *f_args_variable)
    return script_fn, inputs

‎torch/testing/_internal/jit_utils.py

+1 -1

@@ -159,7 +159,7 @@ def extract_files(buffer):
        return code_files, debug_files
 
    # disable the hook while we parse code, otherwise we will re-enter the hook
-    with torch.jit._disable_emit_hooks():
+    with torch._jit_internal._disable_emit_hooks():
        try:
            # short-circuit if this is an empty function or module
            if len(m.code) == 0:
