|
| 1 | +from unittest import main, skipIf |
| 2 | +from torch.testing._internal.common_utils import TestCase, IS_WINDOWS |
| 3 | +from tempfile import NamedTemporaryFile |
| 4 | +from torch.package import PackageExporter, PackageImporter |
| 5 | +from pathlib import Path |
| 6 | +from tempfile import TemporaryDirectory |
| 7 | +import torch |
| 8 | +from sys import version_info |
| 9 | + |
# torchvision is an optional dependency: record whether it can be imported so
# the tests that need a real model can be skipped instead of erroring out.
try:
    from torchvision.models import resnet18
except ImportError:
    HAS_TORCHVISION = False
else:
    HAS_TORCHVISION = True

# Decorator for tests that require torchvision.
skipIfNoTorchVision = skipIf(not HAS_TORCHVISION, "no torchvision")

# Directory containing this test file; the tests load fixture sources
# ('module_a.py', 'package_a') relative to it.
packaging_directory = Path(__file__).parent
| 20 | + |
class PackagingTest(TestCase):
    """Round-trip tests for ``torch.package``.

    Each test writes an archive with ``PackageExporter`` and reads it back
    with ``PackageImporter``.  The tests rely on the fixture modules
    ``module_a`` and ``package_a`` being importable, and on their sources
    living next to this file (see ``packaging_directory``).
    """

    def __init__(self, *args, **kwargs):
        super().__init__(*args, **kwargs)
        # Open NamedTemporaryFile handles; closing them deletes the file.
        self._temporary_files = []
        # Paths handed out on Windows, where the handle is closed up front
        # and the file written later has to be removed by name.
        self._temporary_paths = []

    def temp(self):
        """Return the name of a temporary file that is cleaned up in tearDown."""
        t = NamedTemporaryFile()
        name = t.name
        if IS_WINDOWS:
            t.close()  # can't read an open file in windows
            # Closing removed the placeholder; whatever the test writes to
            # this name afterwards must be deleted explicitly.
            self._temporary_paths.append(name)
        else:
            self._temporary_files.append(t)
        return name

    def tearDown(self):
        """Release every temporary file handed out by :meth:`temp`."""
        for t in self._temporary_files:
            t.close()
        self._temporary_files = []

        for path in self._temporary_paths:
            try:
                Path(path).unlink()
            except OSError:
                # Best effort: the test may never have written the file, or
                # something may still hold it open.
                pass
        self._temporary_paths = []

        super().tearDown()

    def test_saving_source(self):
        """Sources saved from a file and from a directory import under the saved names."""
        filename = self.temp()
        with PackageExporter(filename, verbose=False) as he:
            he.save_source_file('foo', str(packaging_directory / 'module_a.py'))
            he.save_source_file('foodir', str(packaging_directory / 'package_a'))
        hi = PackageImporter(filename)
        foo = hi.import_module('foo')
        s = hi.import_module('foodir.subpackage')
        self.assertEqual(foo.result, 'module_a')
        self.assertEqual(s.result, 'package_a.subpackage')

    def test_saving_string(self):
        """A module saved from a string sees the host interpreter's ``math``."""
        filename = self.temp()
        with PackageExporter(filename, verbose=False) as he:
            src = """\
import math
the_math = math
"""
            he.save_source_string('my_mod', src)
        hi = PackageImporter(filename)
        m = hi.import_module('math')
        import math
        self.assertIs(m, math)
        my_mod = hi.import_module('my_mod')
        self.assertIs(my_mod.math, math)

    def test_save_module(self):
        """Modules saved by name are re-imported as objects distinct from the originals."""
        filename = self.temp()
        with PackageExporter(filename, verbose=False) as he:
            import module_a
            import package_a
            he.save_module(module_a.__name__)
            he.save_module(package_a.__name__)
        hi = PackageImporter(filename)
        module_a_i = hi.import_module('module_a')
        self.assertEqual(module_a_i.result, 'module_a')
        self.assertIsNot(module_a, module_a_i)
        package_a_i = hi.import_module('package_a')
        self.assertEqual(package_a_i.result, 'package_a')
        self.assertIsNot(package_a_i, package_a)

    def test_pickle(self):
        """Pickling an object packages the code it depends on, and nothing else."""
        import package_a.subpackage
        obj = package_a.subpackage.PackageASubpackageObject()
        obj2 = package_a.PackageAObject(obj)

        filename = self.temp()
        with PackageExporter(filename, verbose=False) as he:
            he.save_pickle('obj', 'obj.pkl', obj2)
        hi = PackageImporter(filename)

        # check we got dependencies
        sp = hi.import_module('package_a.subpackage')
        # check we didn't get other stuff
        with self.assertRaises(ImportError):
            hi.import_module('module_a')

        obj_loaded = hi.load_pickle('obj', 'obj.pkl')
        self.assertIsNot(obj2, obj_loaded)
        self.assertIsInstance(obj_loaded.obj, sp.PackageASubpackageObject)
        # The packaged class is a separate object from the host's class.
        self.assertIsNot(package_a.subpackage.PackageASubpackageObject, sp.PackageASubpackageObject)

    def test_resources(self):
        """Text and binary resources round-trip through the in-package ``resources`` module."""
        filename = self.temp()
        with PackageExporter(filename, verbose=False) as he:
            he.save_text('main', 'main', "my string")
            he.save_binary('main', 'main_binary', "my string".encode('utf-8'))
            src = """\
import resources
t = resources.load_text('main', 'main')
b = resources.load_binary('main', 'main_binary')
"""
            he.save_source_string('main', src, is_package=True)
        hi = PackageImporter(filename)
        m = hi.import_module('main')
        self.assertEqual(m.t, "my string")
        self.assertEqual(m.b, "my string".encode('utf-8'))

    def test_extern(self):
        """Externed modules resolve to the host's module objects."""
        filename = self.temp()
        with PackageExporter(filename, verbose=False) as he:
            he.extern_modules(['package_a.subpackage', 'module_a'])
            he.save_module('package_a')
        hi = PackageImporter(filename)
        import package_a.subpackage
        import module_a

        module_a_im = hi.import_module('module_a')
        hi.import_module('package_a.subpackage')
        package_a_im = hi.import_module('package_a')

        self.assertIs(module_a, module_a_im)
        # package_a itself was saved, so it is a packaged copy ...
        self.assertIsNot(package_a, package_a_im)
        # ... but its externed subpackage is the host's module.
        self.assertIs(package_a.subpackage, package_a_im.subpackage)

    # Tuple comparison: only interpreters older than 3.7 are skipped.
    @skipIf(version_info < (3, 7), 'mock uses __getattr__ a 3.7 feature')
    def test_mock(self):
        """Calling an attribute of a mocked module raises ``NotImplementedError``."""
        filename = self.temp()
        with PackageExporter(filename, verbose=False) as he:
            he.mock_modules(['package_a.subpackage', 'module_a'])
            he.save_module('package_a')
        hi = PackageImporter(filename)
        # The throwaway bindings only keep linters from flagging the imports.
        import package_a.subpackage
        _ = package_a.subpackage
        import module_a
        _ = module_a

        m = hi.import_module('package_a.subpackage')
        r = m.result
        with self.assertRaisesRegex(NotImplementedError, 'was mocked out'):
            r()

    # Tuple comparison: only interpreters older than 3.7 are skipped.
    @skipIf(version_info < (3, 7), 'mock uses __getattr__ a 3.7 feature')
    def test_custom_requires(self):
        """An exporter subclass can decide how each required module is provided."""
        filename = self.temp()

        class Custom(PackageExporter):
            def require_module(self, name, dependencies):
                if name == 'module_a':
                    self.mock_module('module_a')
                elif name == 'package_a':
                    self.save_source_string('package_a', 'import module_a\nresult = 5\n')
                else:
                    raise NotImplementedError('wat')

        with Custom(filename, verbose=False) as he:
            he.save_source_string('main', 'import package_a\n')

        hi = PackageImporter(filename)
        # Bare attribute access on the mocked module must not raise.
        # NOTE(review): presumably the real module_a has no such attribute,
        # so this would fail if the mock had not been applied -- confirm.
        hi.import_module('module_a').should_be_mocked
        bar = hi.import_module('package_a')
        self.assertEqual(bar.result, 5)

    @skipIfNoTorchVision
    def test_resnet(self):
        """Package a torchvision ResNet, reload it, re-export it, and load it from a directory."""
        resnet = resnet18()

        f1 = self.temp()

        # create a package that will save it along with its code
        with PackageExporter(f1, verbose=False) as e:
            # put the pickled resnet in the package, by default
            # this will also save all the code files referenced by
            # the objects in the pickle
            e.save_pickle('model', 'model.pkl', resnet)

        # we can now load the saved model
        i = PackageImporter(f1)
        r2 = i.load_pickle('model', 'model.pkl')

        # test that it works
        inputs = torch.rand(1, 3, 224, 224)
        ref = resnet(inputs)
        self.assertTrue(torch.allclose(r2(inputs), ref))

        # functions exist also to get at the private modules in each package
        i.import_module('torchvision')

        f2 = self.temp()
        # if we are doing transfer learning we might want to re-save
        # things that were loaded from a package
        with PackageExporter(f2, verbose=False) as e:
            # We need to tell the exporter about any modules that
            # came from imported packages so that it can resolve
            # class names like torchvision.models.resnet.ResNet
            # to their source code.

            e.importers.insert(0, i.import_module)

            # e.importers is a list of module importing functions
            # that by default contains importlib.import_module.
            # it is searched in order until the first success and
            # that module is taken to be what torchvision.models.resnet
            # should be in this code package. In the case of name collisions,
            # such as trying to save a ResNet from two different packages,
            # we take the first thing found in the path, so only ResNet objects from
            # one importer will work. This avoids a bunch of name mangling in
            # the source code. If you need to actually mix ResNet objects,
            # we suggest reconstructing the model objects using code from a single package
            # using functions like save_state_dict and load_state_dict to transfer state
            # to the correct code objects.
            e.save_pickle('model', 'model.pkl', r2)

        i2 = PackageImporter(f2)
        r3 = i2.load_pickle('model', 'model.pkl')
        self.assertTrue(torch.allclose(r3(inputs), ref))

        # test we can load from a directory
        import zipfile

        with TemporaryDirectory() as td:
            # Close the archive as soon as it has been unpacked.
            with zipfile.ZipFile(f1, 'r') as zf:
                zf.extractall(path=td)
            iz = PackageImporter(str(Path(td) / Path(f1).name))
            r4 = iz.load_pickle('model', 'model.pkl')
            self.assertTrue(torch.allclose(r4(inputs), ref))

    @skipIfNoTorchVision
    def test_model_save(self):
        """Two packaging strategies expose the same ``model.load()`` entry point."""

        # This example shows how you might package a model
        # so that the creator of the model has flexibility about
        # how they want to save it but the 'server' can always
        # use the same API to load the package.

        # The convention is for each model to provide a
        # 'model' package with a 'load' function that actually
        # reads the model out of the archive.

        # How the load function is implemented is up to the
        # packager.

        # get our normal torchvision resnet
        resnet = resnet18()

        f1 = self.temp()
        # Option 1: save by pickling the whole model
        # + single-line, similar to torch.jit.save
        # - more difficult to edit the code after the model is created
        with PackageExporter(f1, verbose=False) as e:
            e.save_pickle('model', 'pickled', resnet)
            # note that this source is the same for all models in this approach
            # so it can be made part of an API that just takes the model and
            # packages it with this source.
            src = """\
import resources # gives you access to the importer from within the package

# server knows to call model.load() to get the model,
# maybe in the future it passes options as arguments by convension
def load():
    return resources.load_pickle('model', 'pickled')
        """
            e.save_source_string('model', src, is_package=True)

        f2 = self.temp()
        # Option 2: save with state dict
        # - more code to write to save/load the model
        # + but this code can be edited later to adapt the model
        with PackageExporter(f2, verbose=False) as e:
            e.save_pickle('model', 'state_dict', resnet.state_dict())
            src = """\
import resources # gives you access to the importer from within the package
from torchvision.models.resnet import resnet18
def load():
    # if you want, you can later edit how resnet is constructed here
    # to edit the model in the package, while still loading the original
    # state dict weights
    r = resnet18()
    state_dict = resources.load_pickle('model', 'state_dict')
    r.load_state_dict(state_dict)
    return r
        """
            e.save_source_string('model', src, is_package=True)

        # regardless of how we chose to package, we can now use the model in a server in the same way
        inputs = torch.rand(1, 3, 224, 224)
        results = []
        for m in [f1, f2]:
            importer = PackageImporter(m)
            the_model = importer.import_module('model').load()
            r = the_model(inputs)
            results.append(r)

        # Both archives hold the same weights, so the outputs must agree.
        self.assertTrue(torch.allclose(*results))
| 307 | + |
# Hand control to unittest's command-line runner when executed as a script.
if __name__ == '__main__':
    main()
0 commit comments