Skip to content

Commit 03a00e8

Browse files
committed
Merge pull request #3090 from longjon/summarize-tool
A Python script for at-a-glance net summary
2 parents 9c9f94e + 84eb44e commit 03a00e8

File tree

1 file changed

+140
-0
lines changed

1 file changed

+140
-0
lines changed

tools/extra/summarize.py

+140
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,140 @@
1+
#!/usr/bin/env python
2+
3+
"""Net summarization tool.
4+
5+
This tool summarizes the structure of a net in a concise but comprehensive
6+
tabular listing, taking a prototxt file as input.
7+
8+
Use this tool to check at a glance that the computation you've specified is the
9+
computation you expect.
10+
"""
11+
12+
from caffe.proto import caffe_pb2
13+
from google import protobuf
14+
import re
15+
import argparse
16+
17+
# ANSI codes for coloring blobs (used cyclically)
18+
COLORS = ['92', '93', '94', '95', '97', '96', '42', '43;30', '100',
19+
'444', '103;30', '107;30']
20+
DISCONNECTED_COLOR = '41'
21+
22+
def read_net(filename):
23+
net = caffe_pb2.NetParameter()
24+
with open(filename) as f:
25+
protobuf.text_format.Parse(f.read(), net)
26+
return net
27+
28+
def format_param(param):
29+
out = []
30+
if len(param.name) > 0:
31+
out.append(param.name)
32+
if param.lr_mult != 1:
33+
out.append('x{}'.format(param.lr_mult))
34+
if param.decay_mult != 1:
35+
out.append('Dx{}'.format(param.decay_mult))
36+
return ' '.join(out)
37+
38+
def printed_len(s):
39+
return len(re.sub(r'\033\[[\d;]+m', '', s))
40+
41+
def print_table(table, max_width):
42+
"""Print a simple nicely-aligned table.
43+
44+
table must be a list of (equal-length) lists. Columns are space-separated,
45+
and as narrow as possible, but no wider than max_width. Text may overflow
46+
columns; note that unlike string.format, this will not affect subsequent
47+
columns, if possible."""
48+
49+
max_widths = [max_width] * len(table[0])
50+
column_widths = [max(printed_len(row[j]) + 1 for row in table)
51+
for j in range(len(table[0]))]
52+
column_widths = [min(w, max_w) for w, max_w in zip(column_widths, max_widths)]
53+
54+
for row in table:
55+
row_str = ''
56+
right_col = 0
57+
for cell, width in zip(row, column_widths):
58+
right_col += width
59+
row_str += cell + ' '
60+
row_str += ' ' * max(right_col - printed_len(row_str), 0)
61+
print row_str
62+
63+
def summarize_net(net):
64+
disconnected_tops = set()
65+
for lr in net.layer:
66+
disconnected_tops |= set(lr.top)
67+
disconnected_tops -= set(lr.bottom)
68+
69+
table = []
70+
colors = {}
71+
for lr in net.layer:
72+
tops = []
73+
for ind, top in enumerate(lr.top):
74+
color = colors.setdefault(top, COLORS[len(colors) % len(COLORS)])
75+
if top in disconnected_tops:
76+
top = '\033[1;4m' + top
77+
if len(lr.loss_weight) > 0:
78+
top = '{} * {}'.format(lr.loss_weight[ind], top)
79+
tops.append('\033[{}m{}\033[0m'.format(color, top))
80+
top_str = ', '.join(tops)
81+
82+
bottoms = []
83+
for bottom in lr.bottom:
84+
color = colors.get(bottom, DISCONNECTED_COLOR)
85+
bottoms.append('\033[{}m{}\033[0m'.format(color, bottom))
86+
bottom_str = ', '.join(bottoms)
87+
88+
if lr.type == 'Python':
89+
type_str = lr.python_param.module + '.' + lr.python_param.layer
90+
else:
91+
type_str = lr.type
92+
93+
# Summarize conv/pool parameters.
94+
# TODO support rectangular/ND parameters
95+
conv_param = lr.convolution_param
96+
if (lr.type in ['Convolution', 'Deconvolution']
97+
and len(conv_param.kernel_size) == 1):
98+
arg_str = str(conv_param.kernel_size[0])
99+
if len(conv_param.stride) > 0 and conv_param.stride[0] != 1:
100+
arg_str += '/' + str(conv_param.stride[0])
101+
if len(conv_param.pad) > 0 and conv_param.pad[0] != 0:
102+
arg_str += '+' + str(conv_param.pad[0])
103+
arg_str += ' ' + str(conv_param.num_output)
104+
if conv_param.group != 1:
105+
arg_str += '/' + str(conv_param.group)
106+
elif lr.type == 'Pooling':
107+
arg_str = str(lr.pooling_param.kernel_size)
108+
if lr.pooling_param.stride != 1:
109+
arg_str += '/' + str(lr.pooling_param.stride)
110+
if lr.pooling_param.pad != 0:
111+
arg_str += '+' + str(lr.pooling_param.pad)
112+
else:
113+
arg_str = ''
114+
115+
if len(lr.param) > 0:
116+
param_strs = map(format_param, lr.param)
117+
if max(map(len, param_strs)) > 0:
118+
param_str = '({})'.format(', '.join(param_strs))
119+
else:
120+
param_str = ''
121+
else:
122+
param_str = ''
123+
124+
table.append([lr.name, type_str, param_str, bottom_str, '->', top_str,
125+
arg_str])
126+
return table
127+
128+
def main():
129+
parser = argparse.ArgumentParser(description="Print a concise summary of net computation.")
130+
parser.add_argument('filename', help='net prototxt file to summarize')
131+
parser.add_argument('-w', '--max-width', help='maximum field width',
132+
type=int, default=30)
133+
args = parser.parse_args()
134+
135+
net = read_net(args.filename)
136+
table = summarize_net(net)
137+
print_table(table, max_width=args.max_width)
138+
139+
if __name__ == '__main__':
140+
main()

0 commit comments

Comments
 (0)