
Commit 79b8328

vkuzo authored and facebook-github-bot committed on Aug 8, 2020
optimize_for_mobile: bring packed params to root module (pytorch#42740)
Summary:

Pull Request resolved: pytorch#42740

Adds a pass to hoist conv packed params to root module. The benefit is that if there is nothing else in the conv module, subsequent passes will delete it, which will reduce module size. For context, freezing does not handle this because conv packed params is a custom object.

Test Plan:

```
PYTORCH_JIT_LOG_LEVEL=">hoist_conv_packed_params.cpp" python test/test_mobile_optimizer.py TestOptimizer.test_hoist_conv_packed_params
```

Imported from OSS

Reviewed By: kimishpatel

Differential Revision: D23005961

fbshipit-source-id: 31ab1f5c42a627cb74629566483cdc91f3770a94
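For readers who want to see the pass end to end, here is a minimal, hedged Python sketch assembled from the test added in this commit. The module ``M``, the calibration input, and the final assertion mirror the test but are illustrative only; it assumes a build with XNNPACK and the qnnpack quantized engine available.

```python
# A minimal sketch, not part of the commit: quantize a tiny model with the
# qnnpack engine, script it, run optimize_for_mobile, and observe that the
# conv submodule is gone because its packed params were hoisted to the root.
import torch
import torch.nn as nn
from torch.utils.mobile_optimizer import optimize_for_mobile

class M(nn.Module):
    def __init__(self):
        super(M, self).__init__()
        self.quant = torch.quantization.QuantStub()
        self.conv1 = nn.Conv2d(1, 1, 1)
        self.dequant = torch.quantization.DeQuantStub()

    def forward(self, x):
        return self.dequant(self.conv1(self.quant(x)))

torch.backends.quantized.engine = 'qnnpack'
m = M()
m.qconfig = torch.quantization.get_default_qconfig('qnnpack')
torch.quantization.prepare(m, inplace=True)
m(torch.randn(4, 1, 4, 4))                    # calibration run
torch.quantization.convert(m, inplace=True)   # conv1 now holds _packed_params

m_optim = optimize_for_mobile(torch.jit.script(m))

# The hoisting pass copied conv1's packed params to a root-module attribute
# (prefixed "_jit_pass_hoist_conv_packed_params"), so freezing could drop conv1.
assert not hasattr(m_optim, "conv1")
```

The new test below exercises both a standalone module and a parent/child hierarchy and makes the equivalent checks with ``FileCheck`` on the optimized graph.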
1 parent d8801f5 commit 79b8328

9 files changed: +268 -6 lines
 

docs/source/mobile_optimizer.rst (+1)

@@ -12,6 +12,7 @@ By default, if optimization blacklist is None or empty, ``optimize_for_mobile``
 - **Insert and Fold prepacked ops** (blacklisting option `MobileOptimizerType::INSERT_FOLD_PREPACK_OPS`): This optimization pass rewrites the graph to replace 2D convolutions and linear ops with their prepacked counterparts. Prepacked ops are stateful: they require some state to be created up front, such as prepacked weights, and they use that state during op execution. XNNPACK is one such backend that provides prepacked ops, with kernels optimized for mobile platforms (such as ARM CPUs). Prepacking the weights enables efficient memory access and thus faster kernel execution. At the moment, ``optimize_for_mobile`` rewrites the graph to replace ``Conv2D/Linear`` with 1) an op that pre-packs the weights for the XNNPACK conv2d/linear ops and 2) an op that takes the pre-packed weights and activations as input and generates the output activations. Since 1 needs to happen only once, we fold the weight pre-packing so that it is done only once, at model load time. This pass of ``optimize_for_mobile`` does 1 and 2 and then folds, i.e. removes, the weight pre-packing ops.
 - **ReLU/Hardtanh fusion**: XNNPACK ops support fusion of clamping. That is, clamping of the output activation is done as part of the kernel, including for 2D convolution and linear op kernels, so clamping effectively comes for free. Thus any op that can be expressed as a clamping op, such as ``ReLU`` or ``hardtanh``, can be fused with the preceding ``Conv2D`` or ``linear`` op in XNNPACK. This pass rewrites the graph by finding ``ReLU/hardtanh`` ops that follow the XNNPACK ``Conv2D/linear`` ops written by the previous pass, and fuses them together.
 - **Dropout removal** (blacklisting option `MobileOptimizerType::REMOVE_DROPOUT`): This optimization pass removes ``dropout`` and ``dropout_`` nodes from the module when training is false.
+- **Conv packed params hoisting** (blacklisting option `MobileOptimizerType::HOIST_CONV_PACKED_PARAMS`): This optimization pass moves convolution packed params to the root module, so that the convolution structs can be deleted. This decreases model size without impacting numerics.
 
 ``optimize_for_mobile`` will also invoke the freeze_module pass, which preserves only the ``forward`` method. If you have other methods that need to be preserved, add them to the preserved methods list and pass it into ``optimize_for_mobile``.
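For orientation, here is a small, hedged usage sketch of the options described above. It assumes a build with XNNPACK enabled; the model, the exported ``preprocess`` method, and the file name are illustrative, and the positional argument order (blacklist set, then preserved methods) reflects the Python wrapper as of this commit and should be double-checked against the installed version.

```python
# Hedged sketch, not from the commit: disable the new hoisting pass via the
# optimization blacklist and keep an extra exported method alive through the
# freeze_module pass. TinyModel, preprocess, and the file name are made up.
import torch
import torch.nn as nn
from torch._C import MobileOptimizerType
from torch.utils.mobile_optimizer import optimize_for_mobile

class TinyModel(nn.Module):
    def __init__(self):
        super(TinyModel, self).__init__()
        self.linear = nn.Linear(4, 4)

    def forward(self, x):
        return self.linear(x)

    @torch.jit.export
    def preprocess(self, x):  # hypothetical extra method to preserve
        return x * 2.0

scripted = torch.jit.script(TinyModel())
optimized = optimize_for_mobile(
    scripted,
    {MobileOptimizerType.HOIST_CONV_PACKED_PARAMS},  # optimization blacklist
    ['preprocess'],                                  # preserved methods
)
torch.jit.save(optimized, "tiny_optimized.pt")
```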

test/test_mobile_optimizer.py (+97 -4)

@@ -121,8 +121,6 @@ def forward(self, x):
         optimization_blacklist_no_prepack = {MobileOptimizerType.INSERT_FOLD_PREPACK_OPS}
         bn_fold_scripted_module = optimize_for_mobile(bn_scripted_module, optimization_blacklist_no_prepack)
         self.assertEqual(len(torch.jit.export_opnames(bn_fold_scripted_module)), 1)
-        FileCheck().check_count("prim::CallMethod[name=\"forward\"]", 1, exactly=True) \
-            .run(str(get_forward_graph(bn_fold_scripted_module._c)))
         bn_input = torch.rand(1, 1, 6, 6)
         torch.testing.assert_allclose(bn_scripted_module(bn_input), bn_fold_scripted_module(bn_input), rtol=1e-2, atol=1e-3)

@@ -201,8 +199,6 @@ def forward(self, x):
         model = torch.jit.script(model)
         # this line should not have ASAN failures
         model_optim = optimize_for_mobile(model)
-        self.assertFalse(hasattr(model_optim.conv1, "bias"))
-        self.assertFalse(hasattr(model_optim.child.conv2, "bias"))
 
     def test_generate_mobile_module_lints(self):
         class MyTestModule(torch.nn.Module):
@@ -255,5 +251,102 @@ def get_lint_count_by_type(lint_type, module_lint_List):
         bi_module_lint_list = generate_mobile_module_lints(bi_module)
         self.assertEqual(len(bi_module_lint_list), 0)
 
+    @unittest.skipUnless(torch.backends.xnnpack.enabled,
+                         " XNNPACK must be enabled for these tests."
+                         " Please build with USE_XNNPACK=1.")
+    def test_hoist_conv_packed_params(self):
+
+        if 'qnnpack' not in torch.backends.quantized.supported_engines:
+            return
+
+        class Standalone(nn.Module):
+            def __init__(self):
+                super(Standalone, self).__init__()
+                self.quant = torch.quantization.QuantStub()
+                self.conv1 = nn.Conv2d(1, 1, 1)
+                self.conv2 = nn.Conv2d(1, 1, 1)
+                self.relu = nn.ReLU()
+                self.dequant = torch.quantization.DeQuantStub()
+
+            def forward(self, x):
+                x = self.quant(x)
+                x = self.conv1(x)
+                x = self.conv2(x)
+                x = self.relu(x)
+                x = self.dequant(x)
+                return x
+
+            def fuse_model(self):
+                torch.quantization.fuse_modules(self, [['conv2', 'relu']], inplace=True)
+                pass
+
+        class Child(nn.Module):
+            def __init__(self):
+                super(Child, self).__init__()
+                self.conv1 = nn.Conv2d(1, 1, 1)
+
+            def forward(self, x):
+                x = self.conv1(x)
+                return x
+
+        class Parent(nn.Module):
+            def __init__(self):
+                super(Parent, self).__init__()
+                self.quant = torch.quantization.QuantStub()
+                self.conv1 = nn.Conv2d(1, 1, 1)
+                self.child = Child()
+                # TODO: test nn.Sequential after #42039 is fixed
+                self.dequant = torch.quantization.DeQuantStub()
+
+            def forward(self, x):
+                x = self.quant(x)
+                x = self.conv1(x)
+                x = self.child(x)
+                x = self.dequant(x)
+                return x
+
+            def fuse_model(self):
+                pass
+
+        with override_quantized_engine('qnnpack'):
+            def _quant_script_and_optimize(model):
+                model.qconfig = torch.quantization.get_default_qconfig('qnnpack')
+                model.fuse_model()
+                torch.quantization.prepare(model, inplace=True)
+                model(torch.randn(4, 1, 4, 4))
+                torch.quantization.convert(model, inplace=True)
+                model = torch.jit.script(model)
+                model_optim = optimize_for_mobile(model)
+                return model, model_optim
+
+            # basic case
+
+            m, m_optim = _quant_script_and_optimize(Standalone())
+            FileCheck().check_not("Conv2d = prim::GetAttr[name=\"conv1\"]") \
+                .check_count("_jit_pass_hoist_conv_packed_params", 2, exactly=True) \
+                .run(m_optim.graph)
+            self.assertFalse(hasattr(m_optim, "conv1"))
+            self.assertFalse(hasattr(m_optim, "conv2"))
+
+            data = torch.randn(4, 1, 4, 4)
+            m_res = m(data)
+            m_optim_res = m_optim(data)
+            torch.testing.assert_allclose(m_res, m_optim_res, rtol=1e-2, atol=1e-3)
+
+            # generic case
+
+            m, m_optim = _quant_script_and_optimize(Parent())
+            FileCheck().check_not("Conv2d = prim::GetAttr[name=\"conv1\"]") \
+                .check_count("_jit_pass_hoist_conv_packed_params", 2, exactly=True) \
+                .run(m_optim.graph)
+            self.assertFalse(hasattr(m_optim, "conv1"))
+            self.assertFalse(hasattr(m_optim, "child"))
+
+            data = torch.randn(4, 1, 4, 4)
+            m_res = m(data)
+            m_optim_res = m_optim(data)
+            torch.testing.assert_allclose(m_res, m_optim_res, rtol=1e-2, atol=1e-3)
+
+
 if __name__ == '__main__':
     unittest.main()

tools/build_variables.bzl (+1)

@@ -173,6 +173,7 @@ core_sources_full = [
     "torch/csrc/jit/passes/graph_fuser.cpp",
     "torch/csrc/jit/passes/graph_rewrite_helper.cpp",
     "torch/csrc/jit/passes/guard_elimination.cpp",
+    "torch/csrc/jit/passes/hoist_conv_packed_params.cpp",
     "torch/csrc/jit/passes/inline_autodiff_subgraphs.cpp",
     "torch/csrc/jit/passes/inline_forked_closures.cpp",
     "torch/csrc/jit/passes/inliner.cpp",

torch/csrc/jit/passes/hoist_conv_packed_params.cpp (new file, +134)

@@ -0,0 +1,134 @@
+#include <stack>
+
+#include <torch/csrc/jit/api/module.h>
+#include <torch/csrc/jit/jit_log.h>
+#include <torch/csrc/jit/passes/constant_pooling.h>
+#include <torch/csrc/jit/passes/constant_propagation.h>
+#include <torch/csrc/jit/passes/hoist_conv_packed_params.h>
+#include <torch/csrc/jit/passes/quantization/helper.h>
+
+namespace torch {
+namespace jit {
+
+// Hoists packed params from a conv module to the parent module.
+// The benefit is that after this hoisting, the conv module
+// no longer holds anything and can be deleted, reducing model
+// size.
+//
+// Before (easy case):
+//
+// %1 = prim::GetAttr[name="conv1"][%self]
+// %2 = prim::GetAttr[name="_packed_params][%1]
+//
+// After (easy case):
+//
+// %2 = prim::GetAttr[name="{prefix}.conv1._packed_params"][%self]
+//
+// Before (generic case):
+//
+// %1 = prim::GetAttr[name="name1"][%self]
+// %2 = prim::GetAttr[name="name2"][%1]
+// ...
+// %n = prim::GetAttr[name="_packed_params][%n-1]
+//
+// After (generic case):
+//
+// %n =
+//   prim::GetAttr[name="{prefix}.name1{...}.name(n-1)._packed_params"][%self]
+//
+void hoistConvPackedParams(
+    Module& rootModule,
+    Node* getConvPackedParamsNode,
+    const std::string& prefix,
+    int& nameUniqueCounter) {
+  auto method = rootModule.get_method("forward");
+  auto graph = method.graph();
+  Value* rootModuleAsValue = graph->inputs()[0];
+
+  // get a path from root module to conv module
+  Value* convModuleAsValue = getConvPackedParamsNode->inputs()[0];
+  std::vector<std::string> rootToConvPath =
+      getModuleAccessPath(convModuleAsValue, rootModuleAsValue);
+
+  // get a module object representing the conv
+  Module convModule = findChildModule(rootModule, rootToConvPath);
+
+  // get the packed params value
+  c10::IValue packedParams = convModule.attr("_packed_params");
+
+  // create the new name
+
+  std::string suffix = "";
+  for (const auto& attrName : rootToConvPath) {
+    suffix += attrName + ".";
+  }
+  std::string newNameBase = prefix + "." + suffix + "_packed_params";
+  nameUniqueCounter++;
+  std::string newName = newNameBase + "." + c10::to_string(nameUniqueCounter);
+  while (rootModule.hasattr(newName)) {
+    nameUniqueCounter++;
+    newName = newNameBase + "." + c10::to_string(nameUniqueCounter);
+  }
+
+  // copy the packed params
+  rootModule.register_attribute(newName, packedParams.type(), packedParams);
+
+  // change target module to rootModule
+  getConvPackedParamsNode->replaceInput(0, rootModuleAsValue);
+
+  // change attribute name to new name
+  getConvPackedParamsNode->s_(Symbol::attr("name"), newName);
+}
+
+void HoistConvPackedParams(script::Module& m) {
+  auto method = m.get_method("forward");
+  auto graph = method.graph();
+
+  std::stack<Block*> blocks_to_visit;
+  blocks_to_visit.push(graph->block());
+  std::string attr_name_base = "_jit_pass_hoist_conv_packed_params";
+  // counter to ensure new attribute names are unique
+  int nameUniqueCounter = 0;
+
+  while (!blocks_to_visit.empty()) {
+    Block* b = blocks_to_visit.top();
+    blocks_to_visit.pop();
+
+    for (Node* n : b->nodes()) {
+      // make sure this node is fetching {foo}.{_packed_params}
+      bool isGetPackedParamsNode =
+          n->kind() == prim::GetAttr && n->s(attr::name) == "_packed_params";
+      if (isGetPackedParamsNode) {
+        // make sure the foo in {foo}.{_packed_params} is a quantized conv
+        c10::optional<std::string> moduleName = getModuleName(n->inputs()[0]);
+        bool moduleNameIsQuantizedConv = moduleName.has_value() &&
+            (moduleName.value() ==
+                 "__torch__.torch.nn.quantized.modules.conv.Conv1d" ||
+             moduleName.value() ==
+                 "__torch__.torch.nn.quantized.modules.conv.Conv2d" ||
+             moduleName.value() ==
+                 "__torch__.torch.nn.quantized.modules.conv.Conv3d" ||
+             moduleName.value() ==
+                 "__torch__.torch.nn.intrinsic.quantized.modules.conv_relu.ConvReLU1d" ||
+             moduleName.value() ==
+                 "__torch__.torch.nn.intrinsic.quantized.modules.conv_relu.ConvReLU2d" ||
+             moduleName.value() ==
+                 "__torch__.torch.nn.intrinsic.quantized.modules.conv_relu.ConvReLU3d");
+
+        if (moduleNameIsQuantizedConv) {
+          GRAPH_UPDATE("Hoisting ", *n, " to root module.");
+          hoistConvPackedParams(m, n, attr_name_base, nameUniqueCounter);
+        }
+      }
+
+      for (Block* subblock : n->blocks()) {
+        blocks_to_visit.push(subblock);
+      }
+
+    } // for
+
+  } // while
+}
+
+} // namespace jit
+} // namespace torch

torch/csrc/jit/passes/hoist_conv_packed_params.h (new file, +12)

@@ -0,0 +1,12 @@
+#pragma once
+
+#include <torch/csrc/jit/api/module.h>
+#include <torch/csrc/jit/ir/ir.h>
+
+namespace torch {
+namespace jit {
+
+void HoistConvPackedParams(script::Module& m);
+
+} // namespace jit
+} // namespace torch

torch/csrc/jit/passes/quantization/helper.cpp (+8 -2)

@@ -530,8 +530,10 @@ bool hitGraphInput(Value* value) {
 // Get the module access path for a Value representing a module instance
 // by tracing back the GetAttr nodes and recording all the attribute
 // names along the way.
-// For example, the module access path will be ['sub', 'basic_block', 'conv1']
-// for `self.sub.basic_block.conv1`
+// Assuming 'self.sub.basic_block.conv1',
+// Input1: Value instance of conv1
+// Input2: Value instance of self
+// Output: ['sub', 'basic_block', 'conv1']
 std::vector<std::string> getModuleAccessPath(Value* instance, Value* self) {
   std::vector<std::string> path;
   // Iterator to traverse back the GetAttr calls
@@ -555,6 +557,10 @@ std::vector<std::string> getModuleAccessPath(Value* instance, Value* self) {
   return path;
 }
 
+// Assuming self.foo.bar.conv1,
+// Input1: Module instance of self
+// Input2: ['foo', 'bar', 'conv1']
+// Output: Module instance of conv1
 Module findChildModule(
     const Module& module,
     const std::vector<std::string>& path) {

torch/csrc/jit/passes/xnnpack_rewrite.cpp (+11)

@@ -9,6 +9,8 @@
 #include <torch/csrc/jit/passes/fuse_linear.h>
 #include <torch/csrc/jit/passes/fuse_relu.h>
 #include <torch/csrc/jit/passes/graph_rewrite_helper.h>
+#include <torch/csrc/jit/passes/hoist_conv_packed_params.h>
+#include <torch/csrc/jit/passes/inliner.h>
 #include <torch/csrc/jit/passes/prepack_folding.h>
 #include <torch/csrc/jit/passes/remove_dropout.h>
 #include <torch/csrc/jit/passes/subgraph_rewrite.h>
@@ -294,6 +296,15 @@ script::Module optimizeForMobile(
     FoldPrePackingOps(cloned_module);
   }
 
+  if (!optimization_blocklist.count(
+          MobileOptimizerType::HOIST_CONV_PACKED_PARAMS)) {
+    // freeze again in case it was not done in previous optional passes
+    cloned_module = freeze_module(cloned_module, preserved_methods);
+    HoistConvPackedParams(cloned_module);
+    // and freeze yet again to remove the empty QuantizedConv modules
+    cloned_module = freeze_module(cloned_module, preserved_methods);
+  }
+
   // Run canonical optimizations post freezing
   // since freezing inlines the graph. Otherwise we
   // will have to explicitly call Inlining pass.

torch/csrc/jit/passes/xnnpack_rewrite.h (+1)

@@ -11,6 +11,7 @@ enum class MobileOptimizerType : int8_t {
   INSERT_FOLD_PREPACK_OPS,
   REMOVE_DROPOUT,
   FUSE_ADD_RELU,
+  HOIST_CONV_PACKED_PARAMS,
 };
 
 TORCH_API void insertPrePackedOps(std::shared_ptr<Graph>& graph);

torch/csrc/jit/python/init.cpp (+3)

@@ -725,6 +725,9 @@ void initJITBindings(PyObject* module) {
           MobileOptimizerType::INSERT_FOLD_PREPACK_OPS)
       .value("REMOVE_DROPOUT", MobileOptimizerType::REMOVE_DROPOUT)
       .value("FUSE_ADD_RELU", MobileOptimizerType::FUSE_ADD_RELU)
+      .value(
+          "HOIST_CONV_PACKED_PARAMS",
+          MobileOptimizerType::HOIST_CONV_PACKED_PARAMS)
       .export_values();
 
   // This allows PyTorchStreamReader to read from a Python buffer. It requires
