Skip to content

Commit 9c82b57

Browse files
James Reedfacebook-github-bot
James Reed
authored andcommittedJul 6, 2020
Fix delegating to jit.load from torch.load (pytorch#40937)
Summary: Pull Request resolved: pytorch#40937 Test Plan: Imported from OSS Differential Revision: D22363816 Pulled By: jamesr66a fbshipit-source-id: 50fc318869407fe8b215368026eaceb129b68a46
1 parent 73c5a78 commit 9c82b57

File tree

2 files changed

+13
-1
lines changed

2 files changed

+13
-1
lines changed
 

‎test/test_serialization.py

+7
Original file line numberDiff line numberDiff line change
@@ -713,6 +713,13 @@ def test(name_or_buffer):
713713

714714
test(io.BytesIO())
715715

716+
@unittest.skipIf(IS_WINDOWS, "NamedTemporaryFile on windows")
717+
def test_serialization_zipfile_actually_jit(self):
718+
with tempfile.NamedTemporaryFile() as f:
719+
torch.jit.save(torch.jit.script(torch.nn.Linear(3, 4)), f)
720+
f.seek(0)
721+
torch.load(f)
722+
716723
# Ensure large zip64 serialization works properly
717724
def test_serialization_2gb_file(self):
718725
big_model = torch.nn.Conv2d(20000, 3200, kernel_size=3)

‎torch/serialization.py

+6-1
Original file line numberDiff line numberDiff line change
@@ -580,12 +580,17 @@ def load(f, map_location=None, pickle_module=pickle, **pickle_load_args):
580580

581581
with _open_file_like(f, 'rb') as opened_file:
582582
if _is_zipfile(opened_file):
583+
# The zipfile reader is going to advance the current file position.
584+
# If we want to actually tail call to torch.jit.load, we need to
585+
# reset back to the original position.
586+
orig_position = opened_file.tell()
583587
with _open_zipfile_reader(opened_file) as opened_zipfile:
584588
if _is_torchscript_zip(opened_zipfile):
585589
warnings.warn("'torch.load' received a zip file that looks like a TorchScript archive"
586590
" dispatching to 'torch.jit.load' (call 'torch.jit.load' directly to"
587591
" silence this warning)", UserWarning)
588-
return torch.jit.load(f)
592+
opened_file.seek(orig_position)
593+
return torch.jit.load(opened_file)
589594
return _load(opened_zipfile, map_location, pickle_module, **pickle_load_args)
590595
return _legacy_load(opened_file, map_location, pickle_module, **pickle_load_args)
591596

0 commit comments

Comments
 (0)
Please sign in to comment.