Skip to content

Commit cb75add

Browse files
zdevito and facebook-github-bot
authored and committed Sep 23, 2020
torch.package - a way to package models and code (pytorch#45015)
Summary: Pull Request resolved: pytorch#45015 torch.package allows you to write packages of code, pickled python data, and arbitrary binary and text resources into a self-contained package. torch.package.PackageExporter writes the packages and torch.package.PackageImporter reads them. The importers can load this code in a hermetic way, such that code is loaded from the package rather than the normal python import system. This allows for the packaging of PyTorch model code and data so that it can be run on a server or used in the future for transfer learning. The code contained in packages is copied file-by-file from the original source when it is created, and the file format is a specially organized zip file. Future users of the package can unzip the package, and edit the code in order to perform custom modifications to it. The importer for packages ensures that code in the module can only be loaded from within the package, except for modules explicitly listed as external using :method:`extern_module`. The file `extern_modules` in the zip archive lists all the modules that a package externally depends on. This prevents "implicit" dependencies where the package runs locally because it is importing a locally-installed package, but then fails when the package is copied to another machine. Test Plan: Imported from OSS Reviewed By: SplitInfinity Differential Revision: D23824337 Pulled By: zdevito fbshipit-source-id: 1247c34ba9b656f9db68a83e31f2a0fbe3bea6bd
1 parent d4a634c commit cb75add

15 files changed

+1439
-3
lines changed
 

‎test/module_a.py

+1
Original file line numberDiff line numberDiff line change
@@ -0,0 +1 @@
1+
result = 'module_a'

‎test/namespace_b/subpackage.py

Whitespace-only changes.

‎test/package_a/__init__.py

+7
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,7 @@
1+
# Fixture package for test_package: the tests assert on `result` and
# round-trip PackageAObject through save_pickle/load_pickle.
result = 'package_a'


class PackageAObject:
    """Minimal slotted object used to exercise pickling of package code."""

    __slots__ = ['obj']

    def __init__(self, obj):
        self.obj = obj

‎test/package_a/subpackage.py

+3
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,3 @@
1+
# Fixture submodule for test_package: pickling PackageASubpackageObject
# should pull this module into the package as a dependency.
result = 'package_a.subpackage'


class PackageASubpackageObject:
    """Placeholder class whose pickling references this submodule."""
    pass

‎test/run_test.py

+2-1
Original file line numberDiff line numberDiff line change
@@ -89,7 +89,8 @@
8989
'test_determination',
9090
'test_futures',
9191
'test_fx',
92-
'test_functional_autograd_benchmark'
92+
'test_functional_autograd_benchmark',
93+
'test_package',
9394
]
9495

9596
WINDOWS_BLOCKLIST = [

‎test/test_package.py

+309
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,309 @@
1+
from unittest import main, skipIf
2+
from torch.testing._internal.common_utils import TestCase, IS_WINDOWS
3+
from tempfile import NamedTemporaryFile
4+
from torch.package import PackageExporter, PackageImporter
5+
from pathlib import Path
6+
from tempfile import TemporaryDirectory
7+
import torch
8+
from sys import version_info
9+
10+
# torchvision is an optional dependency: tests that need it are decorated
# with skipIfNoTorchVision and skipped when it is not installed.
try:
    from torchvision.models import resnet18
    HAS_TORCHVISION = True
except ImportError:
    HAS_TORCHVISION = False
skipIfNoTorchVision = skipIf(not HAS_TORCHVISION, "no torchvision")


# Directory containing this test file and its fixture modules/packages.
packaging_directory = Path(__file__).parent
20+
21+
class PackagingTest(TestCase):
    """End-to-end tests for torch.package: export source files, strings,
    modules, pickles and resources with PackageExporter, then load them back
    hermetically with PackageImporter."""

    def __init__(self, *args, **kwargs):
        super().__init__(*args, **kwargs)
        # Temp files created by temp(); closed (and deleted) in tearDown.
        self._temporary_files = []

    def temp(self):
        """Return the path of a fresh temporary file to write a package to."""
        t = NamedTemporaryFile()
        name = t.name
        if IS_WINDOWS:
            t.close()  # can't read an open file in windows
        else:
            self._temporary_files.append(t)
        return name

    def tearDown(self):
        for t in self._temporary_files:
            t.close()
        self._temporary_files = []

    def test_saving_source(self):
        filename = self.temp()
        with PackageExporter(filename, verbose=False) as he:
            he.save_source_file('foo', str(packaging_directory / 'module_a.py'))
            he.save_source_file('foodir', str(packaging_directory / 'package_a'))
        hi = PackageImporter(filename)
        foo = hi.import_module('foo')
        s = hi.import_module('foodir.subpackage')
        self.assertEqual(foo.result, 'module_a')
        self.assertEqual(s.result, 'package_a.subpackage')

    def test_saving_string(self):
        filename = self.temp()
        with PackageExporter(filename, verbose=False) as he:
            src = """\
import math
the_math = math
"""
            he.save_source_string('my_mod', src)
        hi = PackageImporter(filename)
        m = hi.import_module('math')
        import math
        self.assertIs(m, math)
        my_mod = hi.import_module('my_mod')
        self.assertIs(my_mod.math, math)

    def test_save_module(self):
        filename = self.temp()
        with PackageExporter(filename, verbose=False) as he:
            import module_a
            import package_a
            he.save_module(module_a.__name__)
            he.save_module(package_a.__name__)
        hi = PackageImporter(filename)
        module_a_i = hi.import_module('module_a')
        self.assertEqual(module_a_i.result, 'module_a')
        self.assertIsNot(module_a, module_a_i)
        package_a_i = hi.import_module('package_a')
        self.assertEqual(package_a_i.result, 'package_a')
        self.assertIsNot(package_a_i, package_a)

    def test_pickle(self):
        import package_a.subpackage
        obj = package_a.subpackage.PackageASubpackageObject()
        obj2 = package_a.PackageAObject(obj)

        filename = self.temp()
        with PackageExporter(filename, verbose=False) as he:
            he.save_pickle('obj', 'obj.pkl', obj2)
        hi = PackageImporter(filename)

        # check we got dependencies
        sp = hi.import_module('package_a.subpackage')
        # check we didn't get other stuff
        with self.assertRaises(ImportError):
            hi.import_module('module_a')

        obj_loaded = hi.load_pickle('obj', 'obj.pkl')
        self.assertIsNot(obj2, obj_loaded)
        self.assertIsInstance(obj_loaded.obj, sp.PackageASubpackageObject)
        self.assertIsNot(package_a.subpackage.PackageASubpackageObject, sp.PackageASubpackageObject)

    def test_resources(self):
        filename = self.temp()
        with PackageExporter(filename, verbose=False) as he:
            he.save_text('main', 'main', "my string")
            he.save_binary('main', 'main_binary', "my string".encode('utf-8'))
            src = """\
import resources
t = resources.load_text('main', 'main')
b = resources.load_binary('main', 'main_binary')
"""
            he.save_source_string('main', src, is_package=True)
        hi = PackageImporter(filename)
        m = hi.import_module('main')
        self.assertEqual(m.t, "my string")
        self.assertEqual(m.b, "my string".encode('utf-8'))

    def test_extern(self):
        filename = self.temp()
        with PackageExporter(filename, verbose=False) as he:
            he.extern_modules(['package_a.subpackage', 'module_a'])
            he.save_module('package_a')
        hi = PackageImporter(filename)
        import package_a.subpackage
        import module_a

        module_a_im = hi.import_module('module_a')
        hi.import_module('package_a.subpackage')
        package_a_im = hi.import_module('package_a')

        # extern'd modules resolve to the locally installed ones...
        self.assertIs(module_a, module_a_im)
        # ...while packaged modules are loaded fresh from the archive.
        self.assertIsNot(package_a, package_a_im)
        self.assertIs(package_a.subpackage, package_a_im.subpackage)

    # FIX: the original checked `major < 3 or minor < 7`, which misclassifies
    # any future major version; tuple comparison is the correct idiom.
    @skipIf(version_info < (3, 7), 'mock uses __getattr__ a 3.7 feature')
    def test_mock(self):
        filename = self.temp()
        with PackageExporter(filename, verbose=False) as he:
            he.mock_modules(['package_a.subpackage', 'module_a'])
            he.save_module('package_a')
        hi = PackageImporter(filename)
        import package_a.subpackage
        _ = package_a.subpackage
        import module_a
        _ = module_a

        m = hi.import_module('package_a.subpackage')
        r = m.result
        with self.assertRaisesRegex(NotImplementedError, 'was mocked out'):
            r()

    @skipIf(version_info < (3, 7), 'mock uses __getattr__ a 3.7 feature')
    def test_custom_requires(self):
        filename = self.temp()

        class Custom(PackageExporter):
            def require_module(self, name, dependencies):
                if name == 'module_a':
                    self.mock_module('module_a')
                elif name == 'package_a':
                    self.save_source_string('package_a', 'import module_a\nresult = 5\n')
                else:
                    raise NotImplementedError('wat')

        with Custom(filename, verbose=False) as he:
            he.save_source_string('main', 'import package_a\n')

        hi = PackageImporter(filename)
        hi.import_module('module_a').should_be_mocked
        bar = hi.import_module('package_a')
        self.assertEqual(bar.result, 5)

    @skipIfNoTorchVision
    def test_resnet(self):
        resnet = resnet18()

        f1 = self.temp()

        # create a package that will save it along with its code
        with PackageExporter(f1, verbose=False) as e:
            # put the pickled resnet in the package, by default
            # this will also save all the code files referenced by
            # the objects in the pickle
            e.save_pickle('model', 'model.pkl', resnet)

        # we can now load the saved model
        i = PackageImporter(f1)
        r2 = i.load_pickle('model', 'model.pkl')

        # test that it works
        inp = torch.rand(1, 3, 224, 224)
        ref = resnet(inp)
        self.assertTrue(torch.allclose(r2(inp), ref))

        # functions exist also to get at the private modules in each package
        torchvision = i.import_module('torchvision')

        f2 = self.temp()
        # if we are doing transfer learning we might want to re-save
        # things that were loaded from a package
        with PackageExporter(f2, verbose=False) as e:
            # We need to tell the exporter about any modules that
            # came from imported packages so that it can resolve
            # class names like torchvision.models.resnet.ResNet
            # to their source code.
            e.importers.insert(0, i.import_module)
            # e.importers is a list of module importing functions
            # that by default contains importlib.import_module.
            # it is searched in order until the first success and
            # that module is taken to be what torchvision.models.resnet
            # should be in this code package. In the case of name collisions,
            # such as trying to save a ResNet from two different packages,
            # we take the first thing found in the path, so only ResNet objects from
            # one importer will work. This avoids a bunch of name mangling in
            # the source code. If you need to actually mix ResNet objects,
            # we suggest reconstructing the model objects using code from a single package
            # using functions like save_state_dict and load_state_dict to transfer state
            # to the correct code objects.
            e.save_pickle('model', 'model.pkl', r2)

        i2 = PackageImporter(f2)
        r3 = i2.load_pickle('model', 'model.pkl')
        self.assertTrue(torch.allclose(r3(inp), ref))

        # test we can load from a directory
        import zipfile
        zf = zipfile.ZipFile(f1, 'r')

        with TemporaryDirectory() as td:
            zf.extractall(path=td)
            iz = PackageImporter(str(Path(td) / Path(f1).name))
            r4 = iz.load_pickle('model', 'model.pkl')
            self.assertTrue(torch.allclose(r4(inp), ref))

    @skipIfNoTorchVision
    def test_model_save(self):
        # This example shows how you might package a model
        # so that the creator of the model has flexibility about
        # how they want to save it but the 'server' can always
        # use the same API to load the package.

        # The convention is for each model to provide a
        # 'model' package with a 'load' function that actually
        # reads the model out of the archive.

        # How the load function is implemented is up to the
        # the packager.

        # get our normal torchvision resnet
        resnet = resnet18()

        f1 = self.temp()
        # Option 1: save by pickling the whole model
        # + single-line, similar to torch.jit.save
        # - more difficult to edit the code after the model is created
        with PackageExporter(f1, verbose=False) as e:
            e.save_pickle('model', 'pickled', resnet)
            # note that this source is the same for all models in this approach
            # so it can be made part of an API that just takes the model and
            # packages it with this source.
            src = """\
import resources # gives you access to the importer from within the package

# server knows to call model.load() to get the model,
# maybe in the future it passes options as arguments by convension
def load():
    return resources.load_pickle('model', 'pickled')
"""
            e.save_source_string('model', src, is_package=True)

        f2 = self.temp()
        # Option 2: save with state dict
        # - more code to write to save/load the model
        # + but this code can be edited later to adjust/adapt the model
        with PackageExporter(f2, verbose=False) as e:
            e.save_pickle('model', 'state_dict', resnet.state_dict())
            src = """\
import resources # gives you access to the importer from within the package
from torchvision.models.resnet import resnet18
def load():
    # if you want, you can later edit how resnet is constructed here
    # to edit the model in the package, while still loading the original
    # state dict weights
    r = resnet18()
    state_dict = resources.load_pickle('model', 'state_dict')
    r.load_state_dict(state_dict)
    return r
"""
            e.save_source_string('model', src, is_package=True)

        # regardless of how we chose to package, we can now use the model in a server in the same way
        inp = torch.rand(1, 3, 224, 224)
        results = []
        for m in [f1, f2]:
            importer = PackageImporter(m)
            the_model = importer.import_module('model').load()
            r = the_model(inp)
            results.append(r)

        self.assertTrue(torch.allclose(*results))
307+
308+
# Allow running this test file directly with `python test_package.py`.
if __name__ == '__main__':
    main()

‎torch/package/__init__.py

+2
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,2 @@
1+
from .importer import PackageImporter
2+
from .exporter import PackageExporter
+78
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,78 @@
1+
from pickle import _Pickler, _getattribute, whichmodule, _extension_registry, _compat_pickle # type: ignore
2+
from pickle import GLOBAL, STACK_GLOBAL, EXT1, EXT2, EXT4, PicklingError
3+
from struct import pack
4+
5+
class CustomImportPickler(_Pickler):
    """A Pickler whose global lookup goes through a user-supplied
    ``import_module`` callable instead of ``__import__``, so that pickled
    globals resolve against a package's hermetic module namespace."""

    def __init__(self, import_module, *args, **kwargs):
        # import_module: callable(name) -> module, used in save_global below.
        self.import_module = import_module
        super().__init__(*args, **kwargs)

    def save_global(self, obj, name=None):
        # unfortunately the pickler code is factored in a way that
        # forces us to copy/paste this function. The only change is marked
        # CHANGED below.
        write = self.write
        memo = self.memo

        if name is None:
            name = getattr(obj, '__qualname__', None)
        if name is None:
            name = obj.__name__

        module_name = whichmodule(obj, name)
        try:
            # CHANGED: self.import_module rather than
            # __import__
            module = self.import_module(module_name)
            obj2, parent = _getattribute(module, name)
        except (ImportError, KeyError, AttributeError):
            raise PicklingError(
                "Can't pickle %r: it's not found as %s.%s" %
                (obj, module_name, name)) from None
        else:
            if obj2 is not obj:
                raise PicklingError(
                    "Can't pickle %r: it's not the same object as %s.%s" %
                    (obj, module_name, name))

        if self.proto >= 2:
            code = _extension_registry.get((module_name, name))
            if code:
                assert code > 0
                if code <= 0xff:
                    write(EXT1 + pack("<B", code))
                elif code <= 0xffff:
                    write(EXT2 + pack("<H", code))
                else:
                    write(EXT4 + pack("<i", code))
                return
        lastname = name.rpartition('.')[2]
        if parent is module:
            name = lastname
        # Non-ASCII identifiers are supported only with protocols >= 3.
        if self.proto >= 4:
            self.save(module_name)
            self.save(name)
            write(STACK_GLOBAL)
        elif parent is not module:
            self.save_reduce(getattr, (parent, lastname))
        elif self.proto >= 3:
            write(GLOBAL + bytes(module_name, "utf-8") + b'\n' +
                  bytes(name, "utf-8") + b'\n')
        else:
            if self.fix_imports:
                r_name_mapping = _compat_pickle.REVERSE_NAME_MAPPING
                r_import_mapping = _compat_pickle.REVERSE_IMPORT_MAPPING
                if (module_name, name) in r_name_mapping:
                    module_name, name = r_name_mapping[(module_name, name)]
                elif module_name in r_import_mapping:
                    module_name = r_import_mapping[module_name]
            try:
                write(GLOBAL + bytes(module_name, "ascii") + b'\n' +
                      bytes(name, "ascii") + b'\n')
            except UnicodeEncodeError:
                raise PicklingError(
                    "can't pickle global identifier '%s.%s' using "
                    "pickle protocol %i" % (module, name, self.proto)) from None

        self.memoize(obj)

‎torch/package/_importlib.py

+83
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,83 @@
1+
import _warnings
2+
import os.path
3+
# note: implementations
4+
# copied from cpython's import code
5+
6+
7+
# _zip_searchorder defines how we search for a module in the Zip
8+
# archive: we first search for a package __init__, then for
9+
# non-package .pyc, and .py entries. The .pyc entries
10+
# are swapped by initzipimport() if we run in optimized mode. Also,
11+
# '/' is replaced by path_sep there.
12+
13+
_zip_searchorder = (
14+
('/__init__.py', True),
15+
('.py', False),
16+
)
17+
18+
# Replace any occurrences of '\r\n?' in the input string with '\n'.
19+
# This converts DOS and Mac line endings to Unix line endings.
20+
def _normalize_line_endings(source):
21+
source = source.replace(b'\r\n', b'\n')
22+
source = source.replace(b'\r', b'\n')
23+
return source
24+
25+
def _resolve_name(name, package, level):
26+
"""Resolve a relative module name to an absolute one."""
27+
bits = package.rsplit('.', level - 1)
28+
if len(bits) < level:
29+
raise ValueError('attempted relative import beyond top-level package')
30+
base = bits[0]
31+
return '{}.{}'.format(base, name) if name else base
32+
33+
def _sanity_check(name, package, level):
34+
"""Verify arguments are "sane"."""
35+
if not isinstance(name, str):
36+
raise TypeError('module name must be str, not {}'.format(type(name)))
37+
if level < 0:
38+
raise ValueError('level must be >= 0')
39+
if level > 0:
40+
if not isinstance(package, str):
41+
raise TypeError('__package__ not set to a string')
42+
elif not package:
43+
raise ImportError('attempted relative import with no known parent '
44+
'package')
45+
if not name and level == 0:
46+
raise ValueError('Empty module name')
47+
48+
def _calc___package__(globals):
49+
"""Calculate what __package__ should be.
50+
51+
__package__ is not guaranteed to be defined or could be set to None
52+
to represent that its proper value is unknown.
53+
54+
"""
55+
package = globals.get('__package__')
56+
spec = globals.get('__spec__')
57+
if package is not None:
58+
if spec is not None and package != spec.parent:
59+
_warnings.warn("__package__ != __spec__.parent "
60+
f"({package!r} != {spec.parent!r})",
61+
ImportWarning, stacklevel=3)
62+
return package
63+
elif spec is not None:
64+
return spec.parent
65+
else:
66+
_warnings.warn("can't resolve package from __spec__ or __package__, "
67+
"falling back on __name__ and __path__",
68+
ImportWarning, stacklevel=3)
69+
package = globals['__name__']
70+
if '__path__' not in globals:
71+
package = package.rpartition('.')[0]
72+
return package
73+
74+
def _normalize_path(path):
75+
"""Normalize a path by ensuring it is a string.
76+
77+
If the resulting string contains path separators, an exception is raised.
78+
"""
79+
parent, file_name = os.path.split(path)
80+
if parent:
81+
raise ValueError('{!r} must be only a file name'.format(path))
82+
else:
83+
return file_name

‎torch/package/_mock.py

+39
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,39 @@
1+
2+
# Magic methods that a mocked-out object could plausibly be used with.
# Each is replaced below by a stub that raises NotImplementedError, so any
# actual *use* of a mocked object fails loudly with a descriptive message.
_magic_methods = ['__subclasscheck__', '__hex__', '__rmul__',
                  '__float__', '__idiv__', '__setattr__', '__div__', '__invert__',
                  '__nonzero__', '__rshift__',
                  '__eq__', '__pos__', '__round__',
                  '__rand__', '__or__', '__complex__', '__divmod__',
                  '__len__', '__reversed__', '__copy__', '__reduce__',
                  '__deepcopy__', '__rdivmod__', '__rrshift__', '__ifloordiv__',
                  '__hash__', '__iand__', '__xor__', '__isub__', '__oct__',
                  '__ceil__', '__imod__', '__add__', '__truediv__',
                  '__unicode__', '__le__', '__delitem__', '__sizeof__', '__sub__',
                  '__ne__', '__pow__', '__bytes__', '__mul__',
                  '__itruediv__', '__bool__', '__iter__', '__abs__',
                  '__gt__', '__iadd__', '__enter__',
                  '__floordiv__', '__call__', '__neg__',
                  '__and__', '__ixor__', '__getitem__', '__exit__', '__cmp__',
                  '__getstate__', '__index__', '__contains__', '__floor__', '__lt__', '__getattr__',
                  '__mod__', '__trunc__', '__delattr__', '__instancecheck__', '__setitem__', '__ipow__',
                  '__ilshift__', '__long__', '__irshift__', '__imul__',
                  '__lshift__', '__dir__', '__ge__', '__int__', '__ior__']


class MockedObject:
    """Stand-in for anything that was mocked out during packaging."""

    _name: str

    def __init__(self, name):
        # Write through __dict__ because __setattr__ itself is mocked out.
        self.__dict__['_name'] = name

    def __repr__(self):
        return f"MockedObject({self._name})"


def install_method(method_name):
    # Closure binds the specific method name into the error message.
    def _not_implemented(self, *args, **kwargs):
        raise NotImplementedError(f"Object '{self._name}' was mocked out during packaging but it is being used in {method_name}")
    setattr(MockedObject, method_name, _not_implemented)


for method_name in _magic_methods:
    install_method(method_name)

‎torch/package/_mock_zipreader.py

+48
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,48 @@
1+
import torch
2+
from glob import glob
3+
import os.path
4+
from typing import List, Any
5+
6+
_storages : List[Any] = [
7+
torch.DoubleStorage,
8+
torch.FloatStorage,
9+
torch.LongStorage,
10+
torch.IntStorage,
11+
torch.ShortStorage,
12+
torch.CharStorage,
13+
torch.ByteStorage,
14+
torch.BoolStorage,
15+
]
16+
_dtype_to_storage = {
17+
data_type(0).dtype: data_type for data_type in _storages
18+
}
19+
20+
# because get_storage_from_record returns a tensor!?
21+
class _HasStorage(object):
22+
def __init__(self, storage):
23+
self._storage = storage
24+
25+
def storage(self):
26+
return self._storage
27+
28+
29+
class MockZipReader(object):
    """Reader over an *extracted* package directory that mimics the record
    API of torch's zip file reader (get_record / get_storage_from_record /
    get_all_records)."""

    def __init__(self, directory):
        self.directory = directory

    def get_record(self, name):
        """Return the raw bytes of the record stored at *name*."""
        with open(f'{self.directory}/{name}', 'rb') as f:
            return f.read()

    def get_storage_from_record(self, name, numel, dtype):
        """Memory-map the record *name* as a storage of *numel* elements."""
        storage = _dtype_to_storage[dtype]
        filename = f'{self.directory}/{name}'
        return _HasStorage(storage.from_file(filename=filename, size=numel))

    def get_all_records(self, ):
        """List every record (file path relative to the directory root)."""
        prefix_len = len(self.directory) + 1
        return [
            path[prefix_len:]
            for path in glob(f'{self.directory}/**', recursive=True)
            if not os.path.isdir(path)
        ]

‎torch/package/exporter.py

+435
Large diffs are not rendered by default.
+42
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,42 @@
1+
from typing import List, Optional, Tuple
2+
import ast
3+
from ._importlib import _resolve_name
4+
5+
class _ExtractModuleReferences(ast.NodeVisitor):
    """
    Extract the list of global variables a block of code will read and write
    """

    @classmethod
    def run(cls, src: str, package: str) -> List[Tuple[str, Optional[str]]]:
        """Parse *src* and return (module, attribute-or-None) references."""
        visitor = cls(package)
        visitor.visit(ast.parse(src))
        return list(visitor.references.keys())

    def __init__(self, package):
        super().__init__()
        self.package = package
        # dict used as an ordered set of (module, name) pairs
        self.references = {}

    def _absmodule(self, module_name: str, level: int) -> str:
        # Resolve relative imports against the enclosing package.
        if level > 0:
            return _resolve_name(module_name, self.package, level)
        return module_name

    def visit_Import(self, node):
        for alias in node.names:
            self.references[(alias.name, None)] = True

    def visit_ImportFrom(self, node):
        name = self._absmodule(node.module, 0 if node.level is None else node.level)
        for alias in node.names:
            # from my_package import foo
            # foo may be a module, so we have to add it to the list of
            # potential references, if import of it fails, we will ignore it
            if alias.name != '*':
                self.references[(name, alias.name)] = True
            else:
                self.references[(name, None)] = True


find_files_source_depends_on = _ExtractModuleReferences.run

‎torch/package/importer.py

+388
Large diffs are not rendered by default.

‎torch/serialization.py

+2-2
Original file line numberDiff line numberDiff line change
@@ -821,7 +821,7 @@ def restore_location(storage, location):
821821
return restore_location
822822

823823

824-
def _load(zip_file, map_location, pickle_module, **pickle_load_args):
824+
def _load(zip_file, map_location, pickle_module, pickle_file='data.pkl', **pickle_load_args):
825825
restore_location = _get_restore_location(map_location)
826826

827827
loaded_storages = {}
@@ -847,7 +847,7 @@ def persistent_load(saved_id):
847847
return storage
848848

849849
# Load the data (which may in turn use `persistent_load` to load tensors)
850-
data_file = io.BytesIO(zip_file.get_record('data.pkl'))
850+
data_file = io.BytesIO(zip_file.get_record(pickle_file))
851851
unpickler = pickle_module.Unpickler(data_file, **pickle_load_args)
852852
unpickler.persistent_load = persistent_load
853853
result = unpickler.load()

0 commit comments

Comments
 (0)
Please sign in to comment.