import torch._C
- import torch._jit_internal as _jit_internal

- from torch.jit._builtins import _find_builtin, _get_builtin_table, _register_builtin  # noqa
- from torch._jit_internal import Future
- from torch.nn import Module
from torch.utils import set_module
- from torch.autograd.grad_mode import _DecoratorContextManager
- from typing import Optional, List
-
- import collections
- import contextlib
- import functools
- import os
- import pathlib

# These are imported so users can access them from the `torch.jit` module
- from torch._jit_internal import Final, _overload, _overload_method
- from torch._jit_internal import ignore, export, unused
- from torch.jit._script import script, Attribute, ScriptModule, is_scripting, script_method, \
-     RecursiveScriptModule, ScriptWarning, interface
- from torch.jit._trace import trace, trace_module, TracedModule, TracerWarning, TracingCheckError, \
-     is_tracing, ONNXTracedModule, _unique_state_dict, _flatten, TopLevelTracedModule
+ from torch._jit_internal import (
+     Final,
+     Future,
+     _overload,
+     _overload_method,
+     ignore,
+     export,
+     unused,
+ )
+ from torch.jit._script import (
+     script,
+     Attribute,
+     ScriptModule,
+     is_scripting,
+     script_method,
+     RecursiveScriptModule,
+     ScriptWarning,
+     interface,
+     CompilationUnit,
+     ScriptFunction,
+     _unwrap_optional,
+ )
+ from torch.jit._trace import (
+     trace,
+     trace_module,
+     TracedModule,
+     TracerWarning,
+     TracingCheckError,
+     is_tracing,
+     ONNXTracedModule,
+     TopLevelTracedModule,
+     _unique_state_dict,
+     _flatten,
+     _script_if_tracing,
+     _get_trace_graph,
+ )
from torch.jit._async import fork, wait
from torch.jit._serialization import save, load
-
- set_module(Future, "torch.jit")
+ from torch.jit._fuser import optimized_execution, fuser, last_executed_optimized_graph

# For backwards compatibility
_fork = fork
_wait = wait
- @contextlib.contextmanager
- def optimized_execution(should_optimize):
-     """
-     A context manager that controls whether the JIT's executor will run
-     optimizations before executing a function.
-     """
-     stored_flag = torch._C._get_graph_executor_optimize()
-     torch._C._set_graph_executor_optimize(should_optimize)
-     try:
-         yield
-     finally:
-         torch._C._set_graph_executor_optimize(stored_flag)
-
- @contextlib.contextmanager
- def fuser(name):
-     """
-     A context manager that facilitates switching between
-     backend fusers.
-
-     Valid names:
-     * ``fuser0`` - enables only legacy fuser
-     * ``fuser1`` - enables only NNC
-     * ``fuser2`` - enables only nvFuser
-     """
-     old_cpu_fuse = torch._C._jit_can_fuse_on_cpu()
-     old_gpu_fuse = torch._C._jit_can_fuse_on_gpu()
-     old_texpr_fuser_state = torch._C._jit_texpr_fuser_enabled()
-     old_nvfuser_state = torch._C._jit_nvfuser_enabled()
-     if name == 'fuser0':  # legacy fuser
-         torch._C._jit_override_can_fuse_on_cpu(True)
-         torch._C._jit_override_can_fuse_on_gpu(True)
-         torch._C._jit_set_texpr_fuser_enabled(False)
-         torch._C._jit_set_nvfuser_enabled(False)
-     elif name == 'fuser1':  # NNC
-         old_profiling_executor = torch._C._jit_set_profiling_executor(True)
-         old_profiling_mode = torch._C._jit_set_profiling_mode(True)
-         torch._C._jit_override_can_fuse_on_cpu(False)
-         torch._C._jit_override_can_fuse_on_gpu(False)
-         torch._C._jit_set_texpr_fuser_enabled(True)
-         torch._C._jit_set_nvfuser_enabled(False)
-     elif name == 'fuser2':  # nvFuser
-         torch._C._jit_override_can_fuse_on_cpu(False)
-         torch._C._jit_override_can_fuse_on_gpu(False)
-         torch._C._jit_set_texpr_fuser_enabled(False)
-         torch._C._jit_set_nvfuser_enabled(True)
-     else:
-         raise Exception("unrecognized fuser option")
-     try:
-         yield
-     finally:
-         if name == 'fuser1':  # NNC
-             torch._C._jit_set_profiling_executor(old_profiling_executor)
-             torch._C._jit_set_profiling_mode(old_profiling_mode)
-         # recover the previous values
-         torch._C._jit_override_can_fuse_on_cpu(old_cpu_fuse)
-         torch._C._jit_override_can_fuse_on_gpu(old_gpu_fuse)
-         torch._C._jit_set_texpr_fuser_enabled(old_texpr_fuser_state)
-         torch._C._jit_set_nvfuser_enabled(old_nvfuser_state)
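For reference, a minimal usage sketch (not part of this commit) of the two context managers moved out of this file; they remain reachable as torch.jit.optimized_execution and torch.jit.fuser via the new torch.jit._fuser re-export, and the scripted function below is illustrative only:

import torch

@torch.jit.script
def double(x):
    return x + x

# Run once with executor optimizations disabled.
with torch.jit.optimized_execution(False):
    double(torch.ones(3))

# Switch the active fusion backend for the duration of the block:
# "fuser0" = legacy fuser, "fuser1" = NNC, "fuser2" = nvFuser.
with torch.jit.fuser("fuser1"):
    double(torch.ones(3))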
def export_opnames(m):
    r"""
    Returns a list of operator names of a script module and its submodules
    """
    return torch._C._export_opnames(m._c)
- def _get_trace_graph(f, args=(), kwargs=None, strict=True, _force_outplace=False,
-                      return_inputs=False, _return_inputs_states=False):
-     """
-     .. warning::
-         This function is internal-only and should only be used by the ONNX
-         exporter. If you are trying to get a graph through tracing, please go
-         through the public API instead::
-
-             trace = torch.jit.trace(nn.LSTMCell(), (input, hidden))
-             trace_graph = trace.graph
-
-     Trace a function or model, returning a tuple consisting of the both the
-     *trace* of an execution, as well as the original return value. If return_inputs,
-     also returns the trace inputs as part of the tuple
-
-     Tracing is guaranteed not to change the semantics of the function/module
-     that is traced.
-
-     Arguments:
-         f (torch.nn.Module or function): the function or module
-             to be traced.
-         args (tuple or Tensor): the positional arguments to pass to the
-             function/module to be traced. A non-tuple is assumed to
-             be a single positional argument to be passed to the model.
-         kwargs (dict): the keyword arguments to pass to the function/module
-             to be traced.
-
-     Example (trace a cell):
-
-     .. testcode::
-
-         trace = torch.jit.trace(nn.LSTMCell(), (input, hidden))
-     """
-     if kwargs is None:
-         kwargs = {}
-     if not isinstance(args, tuple):
-         args = (args,)
-     outs = ONNXTracedModule(f, strict, _force_outplace, return_inputs, _return_inputs_states)(*args, **kwargs)
-     return outs
-
-
- def freeze(mod, preserved_attrs: Optional[List[str]] = None):
-     r"""
-     Freezing a :class:`ScriptModule` will clone it and attempt to inline the cloned
-     module's submodules, parameters, and attributes as constants in the TorchScript IR Graph.
-     By default, `forward` will be preserved, as well as attributes & methods specified in
-     `preserved_attrs`. Additionally, any attribute that is modified within a preserved
-     method will be preserved.
-
-     Freezing currently only accepts ScriptModules that are in eval mode.
-
-     Arguments:
-         mod (:class:`ScriptModule`): a module to be frozen
-
-         preserved_attrs (Optional[List[str]]): a list of attributes to preserve in addition to the forward method.
-             Attributes modified in preserved methods will also be preserved.
-
-     Returns:
-         Frozen :class:`ScriptModule`.
-
-     Example (Freezing a simple module with a Parameter):
-
-     .. testcode::
-         import torch
-         class MyModule(torch.nn.Module):
-             def __init__(self, N, M):
-                 super(MyModule, self).__init__()
-                 self.weight = torch.nn.Parameter(torch.rand(N, M))
-                 self.linear = torch.nn.Linear(N, M)
-
-             def forward(self, input):
-                 output = self.weight.mm(input)
-                 output = self.linear(output)
-                 return output
-
-         scripted_module = torch.jit.script(MyModule(2, 3).eval())
-         frozen_module = torch.jit.freeze(scripted_module)
-         # parameters have been removed and inlined into the Graph as constants
-         assert len(list(frozen_module.named_parameters())) == 0
-         # See the compiled graph as Python code
-         print(frozen_module.code)
-
-     Example (Freezing a module with preserved attributes)
-
-     .. testcode::
-         import torch
-         class MyModule2(torch.nn.Module):
-             def __init__(self):
-                 super(MyModule2, self).__init__()
-                 self.modified_tensor = torch.tensor(10.)
-                 self.version = 1
-
-             def forward(self, input):
-                 self.modified_tensor += 1
-                 return input + self.modified_tensor
-
-         scripted_module = torch.jit.script(MyModule2().eval())
-         frozen_module = torch.jit.freeze(scripted_module, preserved_attrs=["version"])
-         # we've manually preserved `version`, so it still exists on the frozen module and can be modified
-         assert frozen_module.version == 1
-         frozen_module.version = 2
-         # `modified_tensor` is detected as being mutated in the forward, so freezing preserves
-         # it to retain model semantics
-         assert frozen_module(torch.tensor(1)) == torch.tensor(12)
-         # now that we've run it once, the next result will be incremented by one
-         assert frozen_module(torch.tensor(1)) == torch.tensor(13)
-
-     Note:
-         If you're not sure why an attribute is not being inlined as a constant, you can run
-         `dump_alias_db` on frozen_module.forward.graph to see if freezing has detected the
-         attribute is being modified.
-     """
-     if not isinstance(mod, ScriptModule):
-         raise RuntimeError("Freezing expects a ScriptModule as input. "
-                            "Please use torch.jit.script or torch.jit.trace to script your 'nn.Module'.")
-
-     if mod.training:
-         raise RuntimeError("Freezing is currently only implemented for modules in eval mode. "
-                            "Please call .eval() on your module before freezing.")
-
-     preserved_attrs = preserved_attrs if preserved_attrs is not None else []
-
-     out = RecursiveScriptModule(torch._C._freeze_module(mod._c, preserved_attrs))
-     RecursiveScriptModule._finalize_scriptmodule(out)
-
-     return out
-
-
- class CompilationUnit(object):
-     def __init__(self, lang=None, _frames_up=0):
-         self._c = torch._C.CompilationUnit()
-         if lang is not None:
-             self.define(lang, _frames_up=_frames_up + 1)
-
-     def define(self, lang, rcb=None, _frames_up=0):
-         if not rcb:
-             rcb = _jit_internal.createResolutionCallbackFromFrame(_frames_up + 1)
-         self._c.define(lang, rcb)
-
-     def __getattr__(self, attr):
-         r = self._c.find_function(attr)
-         if r is None:
-             raise AttributeError("'CompilationUnit' has no attribute '{}'".format(attr))
-         return r
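A minimal sketch of how the class moved here is typically used; it stays available as torch.jit.CompilationUnit through the torch.jit._script import added above, and the function defined in the source string is illustrative only:

import torch

# Compile free-standing TorchScript functions from a source string,
# then call them through attribute access.
cu = torch.jit.CompilationUnit('''
def addmul(a, b, c):
    return a * b + c
''')
print(cu.addmul(torch.ones(2), torch.ones(2), torch.ones(2)))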
-
-
- def _try_get_dispatched_fn(fn):
-     if not callable(fn):
-         return None
-     return _jit_internal.boolean_dispatched.get(fn)
-
-
- def _try_get_overloaded_fn(mod, field):
-     return mod._overloads.get(field, None) if isinstance(mod, ScriptModule) else None
-
-
- @contextlib.contextmanager
- def _disable_emit_hooks():
-     hooks = torch._C._jit_get_emit_hooks()
-     torch._C._jit_set_emit_hooks(None, None)
-     yield
-     torch._C._jit_set_emit_hooks(hooks[0], hooks[1])
-
-
- def _disable_emit_hooks_decorator(_DecoratorContextManager):  # noqa: F811
-     def __enter__(self):
-         self.hooks = torch._C._jit_get_emit_hooks()
-         torch._C._jit_set_emit_hooks(None, None)
-
-     def __exit__(self, *args):
-         torch._C._jit_set_emit_hooks(self.hooks[0], self.hooks[1])
-
-
- def _script_if_tracing(fn):
-     """
-     Compiles ``fn`` when it is first called during tracing. ``torch.jit.script``
-     has a non-negligible start up time when it is first called due to
-     lazy-initializations of many compiler builtins. Therefore you should not use
-     it in library code. However, you may want to have parts of your library work
-     in tracing even if they use control flow. In these cases, you should use
-     ``@torch.jit._script_if_tracing`` to substitute for
-     ``torch.jit.script``.
-     """
-     @functools.wraps(fn)
-     def wrapper(*args, **kwargs):
-         if not is_tracing():
-             # Not tracing, don't do anything
-             return fn(*args, **kwargs)
-
-         compiled_fn = script(wrapper.__original_fn)
-         return compiled_fn(*args, **kwargs)
-
-     wrapper.__original_fn = fn
-     wrapper.__script_if_tracing_wrapper = True
-
-     return wrapper
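A small sketch of the pattern the docstring above describes, assuming the decorator keeps its private torch.jit._script_if_tracing name after the move; the helper function below is hypothetical:

import torch

@torch.jit._script_if_tracing
def clip_rows(x, max_rows: int):
    # Data-dependent control flow would be baked in by the tracer; the
    # decorator scripts this helper lazily on its first call during tracing.
    if x.size(0) > max_rows:
        x = x[:max_rows]
    return x

def forward_like(x):
    return clip_rows(x, 4) * 2

traced = torch.jit.trace(forward_like, torch.randn(8, 3))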
-
- def _unwrap_optional(x):
-     assert x is not None, "Unwrapping null optional"
-     return x
-
- _register_builtin(_unwrap_optional, 'aten::_unwrap_optional')
- _register_builtin(_wait, 'aten::wait')
- _register_builtin(wait, 'aten::wait')
- _register_builtin(is_scripting, 'aten::is_scripting')
-
# torch.jit.Error
Error = torch._C.JITException
@@ -309,53 +62,11 @@ def _unwrap_optional(x):
Error.__name__ = "Error"
Error.__qualname__ = "Error"
- def _get_named_tuple_properties(obj):
-     assert issubclass(obj, tuple) and hasattr(obj, '_fields')
-     fields = list(obj._fields)
-     annotations = []
-     has_annotations = hasattr(obj, '__annotations__')
-     for field in fields:
-         if has_annotations and field in obj.__annotations__:
-             the_type = torch.jit.annotations.ann_to_type(obj.__annotations__[field], _jit_internal.fake_range())
-             annotations.append(the_type)
-         else:
-             annotations.append(torch._C.TensorType.get())
-     return type(obj).__name__, fields, annotations
-
- def _create_named_tuple(t, unqual_name, field_names):
-     TupleType = collections.namedtuple(unqual_name, field_names)
-     return TupleType(*t)
-
- class _disable_tracing(object):
-     def __enter__(self):
-         self.state = torch._C._get_tracing_state()
-         torch._C._set_tracing_state(None)
-
-     def __exit__(self, *args):
-         torch._C._set_tracing_state(self.state)
-         self.state = None
-
-
# for use in python if using annotate
def annotate(the_type, the_value):
    # noop in python
    return the_value
- last_executed_optimized_graph = torch._C._last_executed_optimized_graph
-
-
- def _graph_for(self, *args, **kwargs):
-     self(*args, **kwargs)
-     return last_executed_optimized_graph()
-
- torch._C.ScriptMethod.graph_for = _graph_for
- torch._C.ScriptFunction.graph_for = _graph_for
- ScriptFunction = torch._C.ScriptFunction
- ScriptFunction.__doc__ = """
- Functionally equivalent to a :class:`ScriptModule`, but represents a single
- function and does not have any attributes or Parameters.
- """
- set_module(ScriptFunction, "torch.jit")
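For reference, a minimal sketch of the graph_for hook attached above: it runs the compiled callable once, then returns the most recently executed optimized graph; the scripted function is illustrative only:

import torch

@torch.jit.script
def relu_twice(x):
    return torch.relu(torch.relu(x))

# Executes the function, then fetches the optimized graph the executor produced.
print(relu_twice.graph_for(torch.randn(4)))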
if not torch._C._jit_init():
    raise RuntimeError("JIT initialization failed")