Skip to content

Commit a054d05

Browse files
dreiss authored and facebook-github-bot committed on Apr 3, 2020
Add torch.utils.show_pickle for showing pickle contents in saved models (pytorch#35168)
Summary: Pull Request resolved: pytorch#35168 Sometimes when a saved model isn't working, it's nice to be able to look at the contents of the pickle files. Unfortunately, pickletools output isn't particularly readable, and unpickling is often either not possible or runs so much post-processing code that it's not possible to tell exactly what is present in the pickled data. This script uses a custom Unpickler to unpickle (almost) any data into stub objects that have no dependency on torch or any other runtime types and suppress (almost) any postprocessing code. As a convenience, the wrapper can search through zip files, supporting command lines like `python -m torch.utils.show_pickle /path/to/model.pt1@*/data.pkl` When the module is invoked as main, we also install a hack in pprint to allow semi-reasonable formatting of our stub objects. Test Plan: Ran it on a data.pkl, constants.pkl, and a debug pkl Differential Revision: D20842550 Pulled By: dreiss fbshipit-source-id: ef662d8915fc5795039054d1f8fef2e1c51cf40a
1 parent ec3b355 commit a054d05

File tree

3 files changed

+150
-0
lines changed

3 files changed

+150
-0
lines changed
 

‎test/run_test.py

+1
Original file line numberDiff line numberDiff line change
@@ -59,6 +59,7 @@
5959
'quantization/test_quantize_script',
6060
'test_sparse',
6161
'test_serialization',
62+
'test_show_pickle',
6263
'test_torch',
6364
'test_type_info',
6465
'test_type_hints',

‎test/test_show_pickle.py

+34
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,34 @@
1+
import unittest
2+
import io
3+
import tempfile
4+
import torch
5+
import torch.utils.show_pickle
6+
7+
from torch.testing._internal.common_utils import IS_WINDOWS
8+
9+
class TestShowPickle(unittest.TestCase):
10+
11+
@unittest.skipIf(IS_WINDOWS, "Can't re-open temp file on Windows")
12+
def test_scripted_model(self):
13+
class MyCoolModule(torch.nn.Module):
14+
def __init__(self, weight):
15+
super().__init__()
16+
self.weight = weight
17+
18+
def forward(self, x):
19+
return x * self.weight
20+
21+
m = torch.jit.script(MyCoolModule(torch.tensor([2.0])))
22+
23+
with tempfile.NamedTemporaryFile() as tmp:
24+
torch.jit.save(m, tmp)
25+
tmp.flush()
26+
buf = io.StringIO()
27+
torch.utils.show_pickle.main(["", tmp.name + "@*/data.pkl"], output_stream=buf)
28+
output = buf.getvalue()
29+
self.assertRegex(output, "MyCoolModule")
30+
self.assertRegex(output, "weight")
31+
32+
33+
# Allow running this test module directly as a script.
if __name__ == '__main__':
    unittest.main()

‎torch/utils/show_pickle.py

+115
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,115 @@
1+
#!/usr/bin/env python3
2+
import sys
3+
import pickle
4+
import pprint
5+
import zipfile
6+
import fnmatch
7+
8+
9+
class FakeObject(object):
    """Stub recording how an object in a pickle stream was constructed.

    Attributes (all plain data, no behaviour from the pickled type):
        module: module name of the callable that produced the object.
        name: name of that callable within ``module``.
        args: positional arguments the callable was invoked with.
        state: value passed to ``__setstate__``, or ``None`` if it was never
            called.
    """

    def __init__(self, module, name, args):
        self.module = module
        self.name = name
        self.args = args
        # NOTE: We don't distinguish between state never set and state set to None.
        self.state = None

    def __repr__(self):
        """Return ``module.name(args...)``, plus ``(state=...)`` when state is set."""
        state_str = "" if self.state is None else f"(state={self.state!r})"
        return f"{self.module}.{self.name}{self.args!r}{state_str}"

    def __setstate__(self, state):
        """Record the state instead of applying it to a real object."""
        self.state = state

    @staticmethod
    def pp_format(printer, obj, stream, indent, allowance, context, level):
        """Write a multi-line rendering of ``obj`` to ``stream``.

        The signature matches the callables stored in
        ``pprint.PrettyPrinter._dispatch`` (the ``__main__`` block of this
        module registers this function there), with ``printer`` being the
        active ``PrettyPrinter``.

        Layout mirrors ``__repr__``: ``module.name`` followed by the args
        tuple, then, if state is present, ``(state=`` and the state formatted
        on its own line one indent level deeper, closed by ``)``.

        Previously an object carrying both args and state raised
        ``Exception("Need to implement")``; that combination is now rendered
        as the args followed by the state section.
        """
        head = f"{obj.module}.{obj.name}"

        if obj.state is None:
            if not obj.args:
                # Nothing nested to lay out; the plain repr is enough.
                stream.write(repr(obj))
                return
            stream.write(head)
            printer._format(obj.args, stream, indent + 1, allowance + 1, context, level)
            return

        stream.write(head)
        if obj.args:
            # Let pprint lay out the args tuple right after the name.
            printer._format(obj.args, stream, indent + 1, allowance + 1, context, level)
        else:
            stream.write("()")
        stream.write("(state=\n")
        indent += printer._indent_per_level
        stream.write(" " * indent)
        printer._format(obj.state, stream, indent, allowance + 1, context, level + 1)
        stream.write(")")
41+
42+
43+
class FakeClass(object):
    """Callable placeholder for a global (class or function) named in a pickle.

    Calling an instance, or invoking its ``__new__`` attribute, yields a
    ``FakeObject`` that remembers the module, the name and the arguments.
    """

    def __init__(self, module, name):
        # Instance-level ``__new__`` so that an explicit
        # ``fake_cls.__new__(fake_cls, *args)`` call lands in ``fake_new``
        # (presumably how the unpickler instantiates objects for some opcodes;
        # verify against the pickle module).
        self.__new__ = self.fake_new
        self.module = module
        self.name = name

    def __repr__(self):
        return f"{self.module}.{self.name}"

    def _stub(self, ctor_args):
        # Single construction point shared by both entry points.
        return FakeObject(self.module, self.name, ctor_args)

    def __call__(self, *args):
        return self._stub(args)

    def fake_new(self, *args):
        # The first positional argument is dropped; in a
        # ``__new__(cls, *args)`` style call it is the class itself.
        ctor_args = args[1:]
        return self._stub(ctor_args)
57+
58+
59+
class DumpUnpickler(pickle._Unpickler):
    """Unpickler that substitutes stubs for every global and persistent id.

    Built on the pure-Python ``pickle._Unpickler`` with both lookup hooks
    overridden so that they return stub objects.
    """

    def find_class(self, module, name):
        # Hand back a placeholder for the requested ``module.name``.
        return FakeClass(module, name)

    def persistent_load(self, pid):
        # Persistent ids are wrapped as a ``pers.obj(pid)`` stub.
        return FakeObject("pers", "obj", (pid,))

    @classmethod
    def dump(cls, in_stream, out_stream):
        """Unpickle ``in_stream`` into stubs and pretty-print to ``out_stream``."""
        unpickler = cls(in_stream)
        loaded = unpickler.load()
        pprint.pprint(loaded, stream=out_stream)
70+
71+
72+
def main(argv, output_stream=None):
    """Dump the pickle named by ``argv[1]`` to ``output_stream``.

    ``argv[1]`` is either a path to a pickle file or ``ZIPFILE@MEMBER``.
    When ``MEMBER`` contains ``*`` it is matched against the archive's member
    names with ``fnmatch`` and only the first match is dumped.

    Args:
        argv: two-element sequence; element 0 is ignored, element 1 is the
            file specification.
        output_stream: stream handed to ``DumpUnpickler.dump``.

    Returns:
        2 if ``argv`` has the wrong length and ``output_stream`` is ``None``
        (usage text is written to stderr); otherwise ``None``.

    Raises:
        Exception: if ``argv`` has the wrong length while an
            ``output_stream`` was supplied, or if no archive member matches
            the glob pattern.
    """
    if len(argv) != 2:
        # Don't spam stderr if not using stdout.
        if output_stream is not None:
            raise Exception("Pass argv of length 2.")
        usage_lines = (
            "usage: show_pickle PICKLE_FILE\n",
            " PICKLE_FILE can be any of:\n",
            " path to a pickle file\n",
            " file.zip@member.pkl\n",
            " file.zip@*/pattern.*\n",
            " (shell glob pattern for members)\n",
            " (only first match will be shown)\n",
        )
        for usage_line in usage_lines:
            sys.stderr.write(usage_line)
        return 2

    fname = argv[1]
    # Everything before the first "@" is the archive, the rest is the member.
    zfname, at_sign, mname = fname.partition("@")

    if not at_sign:
        # No "@": treat the argument as a plain pickle file on disk.
        with open(fname, "rb") as handle:
            DumpUnpickler.dump(handle, output_stream)
        return None

    with zipfile.ZipFile(zfname) as zf:
        if "*" in mname:
            # Glob form: take the first member whose name matches.
            target = next(
                (info for info in zf.infolist() if fnmatch.fnmatch(info.filename, mname)),
                None,
            )
            if target is None:
                raise Exception(f"Could not find member matching {mname} in {zfname}")
        else:
            # Exact member name.
            target = mname
        with zf.open(target) as handle:
            DumpUnpickler.dump(handle, output_stream)
    return None
106+
107+
108+
if __name__ == "__main__":
    # HACK: register FakeObject.pp_format in pprint's private dispatch table,
    # keyed by FakeObject.__repr__, so that pprint formats stub objects with
    # our own layout. According to the original author this works on every
    # Python version tested, which was:
    #   3.7.4
    dispatch_table = pprint.PrettyPrinter._dispatch
    dispatch_table[FakeObject.__repr__] = FakeObject.pp_format

    exit_code = main(sys.argv)
    sys.exit(exit_code)

0 commit comments

Comments
 (0)
Please sign in to comment.