
Commit ccfbfe5

jerryzh168 authored and facebook-github-bot committed on Sep 23, 2020
[quant][graphmode][fx] Custom module support (pytorch#44766)
Summary:
Pull Request resolved: pytorch#44766

Some modules are not symbolically traceable, e.g. LSTM (it has input-dependent control flow). To support quantization in these cases, the user provides the corresponding observed and quantized versions of the custom module: the observed custom module has observers already inserted, and the quantized version has the corresponding ops quantized. The mappings are registered with

```python
from torch.quantization import register_observed_custom_module_mapping
from torch.quantization import register_quantized_custom_module_mapping

register_observed_custom_module_mapping(CustomModule, ObservedCustomModule)
register_quantized_custom_module_mapping(CustomModule, QuantizedCustomModule)
```

We also need to define a custom delegate class for symbolic trace so that the custom module is not traced through:

```python
class CustomDelegate(DefaultDelegate):
    def is_leaf_module(self, m):
        return (m.__module__.startswith('torch.nn') and
                not isinstance(m, torch.nn.Sequential)) or \
            isinstance(m, CustomModule)

m = symbolic_trace(original_m, delegate_class=CustomDelegate)
```

Test Plan: Imported from OSS

Reviewed By: z-a-f

Differential Revision: D23723455

fbshipit-source-id: 50d666e29b94cbcbea5fb6bcc73b00cff87eb77a
1 parent 7f4a27b commit ccfbfe5
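For orientation, here is a condensed, runnable sketch of the workflow this commit enables, assembled from the test added below in test/quantization/test_quantize_fx.py. The `prepare_static_fx`/`convert_static_fx` entry points, the `Tracer.is_leaf_module` override, and the `from_float`/`from_observed` classmethods are the ones used by that test at this commit; the module names here are just illustrative stand-ins, and the quantized backend (FBGEMM) is assumed to be available, as in the test's `skipIfNoFBGEMM` guard:

```python
import torch
import torch.nn.quantized as nnq
from torch.fx.symbolic_trace import Tracer
from torch.quantization import (
    default_qconfig,
    prepare_static_fx,
    convert_static_fx,
    register_observed_custom_module_mapping,
    register_quantized_custom_module_mapping,
)

class CustomModule(torch.nn.Module):
    """Stands in for a module that cannot be symbolically traced."""
    def __init__(self):
        super().__init__()
        self.conv = torch.nn.Conv2d(1, 1, 1)

    def forward(self, x):
        return self.conv(x)

class ObservedCustomModule(torch.nn.Module):
    def __init__(self, conv):
        super().__init__()
        self.conv = conv

    def forward(self, x):
        return self.conv(x)

    @classmethod
    def from_float(cls, float_module):
        # called during prepare; must carry the qconfig over
        observed = cls(float_module.conv)
        observed.qconfig = float_module.qconfig
        return observed

class QuantizedCustomModule(torch.nn.Module):
    def __init__(self, conv):
        super().__init__()
        self.conv = conv

    def forward(self, x):
        return self.conv(x)

    @classmethod
    def from_observed(cls, observed_module):
        # called during convert; reuse the output observer attached by the framework
        observed_module.conv.activation_post_process = observed_module.activation_post_process
        return cls(nnq.Conv2d.from_float(observed_module.conv))

class M(torch.nn.Module):
    def __init__(self):
        super().__init__()
        self.conv = torch.nn.Conv2d(1, 1, 1)   # a normal, traceable module
        self.custom = CustomModule()           # the non-traceable one

    def forward(self, x):
        return self.custom(self.conv(x))

# 1. register the observed/quantized counterparts of the float custom module
register_observed_custom_module_mapping(CustomModule, ObservedCustomModule)
register_quantized_custom_module_mapping(CustomModule, QuantizedCustomModule)

# 2. trace with a tracer that keeps CustomModule as a leaf
class CustomTracer(Tracer):
    def is_leaf_module(self, m, module_qualified_name):
        return (m.__module__.startswith('torch.nn') and
                not isinstance(m, torch.nn.Sequential)) or isinstance(m, CustomModule)

m = CustomTracer().trace(M()).eval()

# 3. prepare swaps CustomModule -> ObservedCustomModule via from_float, then calibrate
m = prepare_static_fx(m, {'': default_qconfig})
m(torch.randn(1, 1, 1, 1))

# 4. convert swaps ObservedCustomModule -> QuantizedCustomModule via from_observed
m = convert_static_fx(m)
```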

File tree

6 files changed: +274 -2 lines changed

test/quantization/test_quantize_fx.py
torch/nn/quantized/modules/conv.py
torch/quantization/__init__.py
torch/quantization/custom_module_class_mappings.py
torch/quantization/fx/quantization_patterns.py
torch/quantization/fx/quantize.py

 

test/quantization/test_quantize_fx.py (+136)

```diff
@@ -20,6 +20,8 @@
     quantize_static_fx,
     quantize_dynamic_fx,
     prepare_qat_fx,
+    register_observed_custom_module_mapping,
+    register_quantized_custom_module_mapping,
 )

 from torch.quantization import (
@@ -482,6 +484,140 @@ def forward(self, x):
         # Verify that loaded state dict produces same results.
         self.assertEqual(quant(x), quant_2(x))

+    @skipIfNoFBGEMM
+    def test_custom_module_class(self):
+        class CustomModule(torch.nn.Module):
+            def __init__(self):
+                super().__init__()
+                self.conv = torch.nn.Conv2d(1, 1, 1)
+
+            def forward(self, x):
+                return self.conv(x)
+
+        class ObservedCustomModule(torch.nn.Module):
+            def __init__(self, conv):
+                super().__init__()
+                self.conv = conv
+
+            def forward(self, x):
+                return self.conv(x)
+
+            @classmethod
+            def from_float(cls, float_module):
+                assert hasattr(float_module, 'qconfig')
+                observed = cls(float_module.conv)
+                observed.qconfig = float_module.qconfig
+                return observed
+
+        class QuantizedCustomModule(torch.nn.Module):
+            def __init__(self, conv):
+                super().__init__()
+                self.conv = conv
+
+            def forward(self, x):
+                return self.conv(x)
+
+            @classmethod
+            def from_observed(cls, observed_module):
+                assert hasattr(observed_module, 'qconfig')
+                assert hasattr(observed_module, 'activation_post_process')
+                observed_module.conv.activation_post_process = \
+                    observed_module.activation_post_process
+                quantized = cls(nnq.Conv2d.from_float(observed_module.conv))
+                return quantized
+
+        class DynamicallyQuantizedCustomModule(torch.nn.Module):
+            def __init__(self, conv):
+                super().__init__()
+                self.conv = conv
+
+            def forward(self, x):
+                return self.conv(x)
+
+            @classmethod
+            def from_observed(cls, observed_module):
+                assert hasattr(observed_module, 'qconfig')
+                assert hasattr(observed_module, 'activation_post_process')
+                quantized = cls(nnqd.Conv2d.from_float(observed_module.conv))
+                return quantized
+
+        class M(torch.nn.Module):
+            def __init__(self):
+                super().__init__()
+                self.conv = torch.nn.Conv2d(1, 1, 1)
+                self.custom = CustomModule()
+
+            def forward(self, x):
+                x = self.conv(x)
+                x = self.custom(x)
+                return x
+
+        class RefM(torch.nn.Module):
+            def __init__(self):
+                super().__init__()
+                self.conv1 = torch.nn.Conv2d(1, 1, 1)
+                self.conv2 = torch.nn.Conv2d(1, 1, 1)
+
+            def forward(self, x):
+                x = self.conv1(x)
+                x = self.conv2(x)
+                return x
+
+        data = torch.randn(1, 1, 1, 1)
+        # instantiate M and RefM and align the parameters
+        original_m = M()
+        original_ref_m = RefM()
+        original_ref_m.conv1.weight = torch.nn.Parameter(original_m.conv.weight.detach())
+        original_ref_m.conv1.bias = torch.nn.Parameter(original_m.conv.bias.detach())
+        original_ref_m.conv2.weight = torch.nn.Parameter(original_m.custom.conv.weight.detach())
+        original_ref_m.conv2.bias = torch.nn.Parameter(original_m.custom.conv.bias.detach())
+
+        from torch.fx.symbolic_trace import Tracer
+
+        # define a custom tracer to not trace through the custom module
+
+        class CustomTracer(Tracer):
+            def is_leaf_module(self, m, module_qualified_name):
+                return (m.__module__.startswith('torch.nn') and
+                        not isinstance(m, torch.nn.Sequential)) or \
+                    isinstance(m, CustomModule)
+
+        # TODO: add other quant types after mixed mode support
+        for quant_type in [QuantType.STATIC]:
+            # register observed and quantized custom module classes
+            register_observed_custom_module_mapping(CustomModule, ObservedCustomModule)
+            register_quantized_custom_module_mapping(CustomModule, QuantizedCustomModule)
+
+            m = CustomTracer().trace(original_m).eval()
+            qconfig_dict = {'': default_qconfig}
+            # check prepared model
+            m = prepare_static_fx(m, qconfig_dict)
+            # calibration
+            m(data)
+            # all activation observers are inserted in the top level module
+            count_check = {
+                ns.call_module(torch.quantization.MinMaxObserver): 3
+            }
+            self.checkGraphModuleNodes(m, expected_node_occurrence=count_check)
+
+            # check converted/quantized model
+            m = convert_static_fx(m)
+            count_check = {
+                ns.call_function(torch.quantize_per_tensor): 1,
+                ns.call_module(nnq.Conv2d): 1,
+                ns.call_method('dequantize'): 1,
+            }
+            self.checkGraphModuleNodes(m, expected_node_occurrence=count_check)
+            res = m(data)
+
+            # quantize the reference model
+            ref_m = symbolic_trace(original_ref_m).eval()
+            ref_m = prepare_fx(ref_m, qconfig_dict)
+            ref_m(data)
+            ref_m = convert_fx(ref_m)
+            ref_res = ref_m(data)
+            self.assertEqual(res, ref_res)
+
 class TestQuantizeFxOps(QuantizationTestCase):
     """Unit tests for individual ops
     """
```

torch/nn/quantized/modules/conv.py (+1, -1)

```diff
@@ -146,7 +146,7 @@ def __setstate__(self, state):

     @classmethod
     def get_qconv(cls, mod, activation_post_process, weight_post_process=None):
-        r"""Creates a qconv object and returns it.
+        r"""Creates a qconv object and returns it.
         """
         if weight_post_process is None:
             weight_post_process = mod.qconfig.weight()
```

torch/quantization/__init__.py (+6)

```diff
@@ -9,6 +9,7 @@
 from .quantize_fx import *
 from .quantization_mappings import *
 from .fuser_method_mappings import *
+from .custom_module_class_mappings import *

 def default_eval_fn(model, calib_data):
     r"""
@@ -40,6 +41,11 @@ def default_eval_fn(model, calib_data):
     'get_compare_output_module_list',
     'register_quantized_operator_mapping', 'get_quantized_operator',
     'register_fuser_method', 'get_fuser_method',
+    'register_observed_custom_module_mapping',
+    'get_observed_custom_module_class',
+    'register_quantized_custom_mdoule_mapping',
+    'get_quantized_custom_module_class',
+    'is_custom_module_class',
     # Sub functions for `prepare` and `swap_module`
     'propagate_qconfig_', 'add_quant_dequant', 'add_observer_', 'swap_module',
     'default_eval_fn', 'get_observer_dict',
```
torch/quantization/custom_module_class_mappings.py (new file, +75)

```diff
@@ -0,0 +1,75 @@
+OBSERVED_CUSTOM_MODULE_CLASS_MAPPINGS = dict()
+
+def register_observed_custom_module_mapping(float_custom_module_class, observed_custom_module_class):
+    """ Register a mapping from `float_custom_module_class` to
+    `observed_custom_module_class`
+    `observed_custom_module_class` will have a `from_float` classmethod,
+    which will return an observed custom module instance given
+    a float custom module instance.
+    This will be used in the prepare step of post training static quantization or
+    quantization aware training
+    """
+    assert hasattr(observed_custom_module_class, 'from_float'), 'from_float must be' + \
+        ' defined in observed custom module class'
+    OBSERVED_CUSTOM_MODULE_CLASS_MAPPINGS[float_custom_module_class] = \
+        observed_custom_module_class
+
+def get_observed_custom_module_class(float_custom_module_class):
+    """ Get the corresponding observed module class for a given
+    float custom module.
+    """
+    observed_custom_module_class = \
+        OBSERVED_CUSTOM_MODULE_CLASS_MAPPINGS.get(float_custom_module_class, None)
+    assert observed_custom_module_class is not None, \
+        'Float Custom module class {}'.format(float_custom_module_class) + \
+        ' does not have a corresponding observed module class'
+    return observed_custom_module_class
+
+QUANTIZED_CUSTOM_MODULE_CLASS_MAPPINGS = dict()
+
+def register_quantized_custom_module_mapping(float_custom_module_class, quantized_custom_module_class):
+    """ Register a mapping from `float_custom_module_class` to `quantized_custom_module_class`
+    A quantized custom module class should accept quantized input and
+    return quantized output. (we can relax this condition in the
+    future if there is a need)
+    `quantized_custom_module_class` will have a `from_observed` classmethod,
+    which will return a quantized custom module instance given
+    an observed custom module instance.
+    This will be used in the convert step of post training static quantization or
+    quantization aware training
+    """
+    assert hasattr(quantized_custom_module_class, 'from_observed'), 'from_observed' + \
+        ' must be defined in quantized custom module class'
+    QUANTIZED_CUSTOM_MODULE_CLASS_MAPPINGS[float_custom_module_class] = \
+        quantized_custom_module_class
+
+def get_quantized_custom_module_class(float_custom_module_class):
+    """ Get the corresponding quantized module class for a given
+    float custom module.
+    """
+    quantized_custom_module_class = \
+        QUANTIZED_CUSTOM_MODULE_CLASS_MAPPINGS.get(float_custom_module_class, None)
+    assert quantized_custom_module_class is not None, \
+        'Float Custom module class {}'.format(float_custom_module_class) + \
+        ' does not have a corresponding quantized module class'
+    return quantized_custom_module_class
+
+def is_custom_module_class(module_class):
+    """ Check if a given module class is a custom module class
+    """
+    return module_class in OBSERVED_CUSTOM_MODULE_CLASS_MAPPINGS and \
+        module_class in QUANTIZED_CUSTOM_MODULE_CLASS_MAPPINGS
+
+def mark_observed_custom_module(module, custom_module_class):
+    """ Mark a module as an observed custom module, so that
+    it can be identified during the convert step
+    """
+    module._is_observed_custom_module = True
+    module._FLOAT_MODULE = custom_module_class
+
+def is_observed_custom_module(module):
+    """ Check whether a module is marked as an observed custom module
+    or not
+    """
+    return hasattr(module, '_is_observed_custom_module') and \
+        module._is_observed_custom_module
```
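The registry added in this new file is just a pair of module-level dicts keyed by the float custom module class. A minimal sketch of how the helpers fit together, assuming the re-exports added to torch/quantization/__init__.py above; the `MyRNN` classes are made-up placeholders that only satisfy the `from_float`/`from_observed` requirements checked by the register functions:

```python
import torch
from torch.quantization import (
    register_observed_custom_module_mapping,
    register_quantized_custom_module_mapping,
    is_custom_module_class,
    get_observed_custom_module_class,
    get_quantized_custom_module_class,
)

class MyRNN(torch.nn.Module):            # hypothetical float custom module
    pass

class ObservedMyRNN(torch.nn.Module):    # hypothetical observed counterpart
    @classmethod
    def from_float(cls, float_module):   # presence of from_float is asserted at registration
        raise NotImplementedError

class QuantizedMyRNN(torch.nn.Module):   # hypothetical quantized counterpart
    @classmethod
    def from_observed(cls, observed_module):  # presence of from_observed is asserted at registration
        raise NotImplementedError

register_observed_custom_module_mapping(MyRNN, ObservedMyRNN)
register_quantized_custom_module_mapping(MyRNN, QuantizedMyRNN)

# is_custom_module_class is True only once both mappings are registered
assert is_custom_module_class(MyRNN)
assert get_observed_custom_module_class(MyRNN) is ObservedMyRNN
assert get_quantized_custom_module_class(MyRNN) is QuantizedMyRNN
```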

torch/quantization/fx/quantization_patterns.py (+25)

```diff
@@ -6,6 +6,9 @@
     get_static_quant_module_class,
     get_quantized_operator,
 )
+from ..custom_module_class_mappings import (
+    get_quantized_custom_module_class,
+)
 from .pattern_utils import (
     register_quant_pattern,
     register_dynamic_quant_pattern,
@@ -507,6 +510,28 @@ def convert(self, quantizer, node):
             quantizer.quantized_graph,
             node, quantizer.activation_post_process_map[node.name])

+class CustomModuleQuantizeHandler(QuantizeHandler):
+    def convert(self, quantizer, node, load_arg, debug=False):
+        """ Convert a float custom module to a quantized custom module
+        """
+        assert node.op == 'call_module'
+        observed_custom_module = quantizer.modules[node.target]
+        if node.name in quantizer.activation_post_process_map:
+            observed_custom_module.activation_post_process = \
+                quantizer.activation_post_process_map[node.name]
+        quantized_custom_module_class = \
+            get_quantized_custom_module_class(observed_custom_module._FLOAT_MODULE)
+        quantized_custom_module = \
+            quantized_custom_module_class.from_observed(observed_custom_module)
+        parent_name, name = _parent_name(node.target)
+        setattr(quantizer.modules[parent_name], name, quantized_custom_module)
+        # hardcoded the quantized input to be None (take whatever is in the environment),
+        # we can extend this
+        # if there is a need, e.g. get the indexes of quantized inputs from some
+        # module attribute like module._QUANTIZED_INPUT_INDEXES
+        return quantizer.quantized_graph.node_copy(node, load_arg(quantized=None))
+
+
 # 2. Post Training Dynamic Quantization Patterns
 @register_dynamic_quant_pattern(torch.nn.Linear)
 @register_dynamic_quant_pattern(torch.nn.functional.linear)
```
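The handler above pins down the convert-time contract for a user's quantized class: by the time `from_observed` is called, the observed module already carries the output observer as `activation_post_process` (for static quantization) and the original float class as `_FLOAT_MODULE`, and whatever `from_observed` returns is installed on the parent module in place of the observed instance. A minimal sketch of a class written against that contract, modeled on the `QuantizedCustomModule` in the test above; `QuantizedCustomConv` is a made-up name:

```python
import torch
import torch.nn.quantized as nnq

class QuantizedCustomConv(torch.nn.Module):
    """Hypothetical quantized counterpart of a custom module wrapping one Conv2d."""

    def __init__(self, conv):
        super().__init__()
        self.conv = conv

    def forward(self, x):
        # expected to take quantized input and produce quantized output
        return self.conv(x)

    @classmethod
    def from_observed(cls, observed_module):
        # CustomModuleQuantizeHandler.convert attaches the output observer and
        # records the float class in observed_module._FLOAT_MODULE before this
        # is called; reuse that observer to build the quantized conv.
        observed_module.conv.activation_post_process = \
            observed_module.activation_post_process
        return cls(nnq.Conv2d.from_float(observed_module.conv))
```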

torch/quantization/fx/quantize.py (+31, -1)

```diff
@@ -18,6 +18,12 @@
 from ..quantization_mappings import (
     get_qat_module_mappings,
 )
+from ..custom_module_class_mappings import (
+    is_custom_module_class,
+    get_observed_custom_module_class,
+    mark_observed_custom_module,
+    is_observed_custom_module,
+)

 from ..quantize import _remove_qconfig

@@ -193,7 +199,6 @@ def _prepare(self, model, qconfig_dict, inplace, is_dynamic_quant):
         if not inplace:
             model = copy.deepcopy(model)
         self.is_dynamic_quant = is_dynamic_quant
-        # TODO: allow user specified patterns
         if self.is_dynamic_quant:
             self.patterns = get_dynamic_quant_patterns()
         else:
@@ -235,6 +240,8 @@ def load_arg(a):
                 env[node.name] = observed_graph.node_copy(node, load_arg)
             elif root_node is node:
                 env[node.name] = observed_graph.node_copy(node, load_arg)
+                if qconfig is None:
+                    continue

                 def insert_observer(node, observer, device):
                     get_new_observer_name = get_new_attr_name_with_prefix(prefix)
@@ -246,10 +253,22 @@ def insert_observer(node, observer, device):
                     if device:
                         getattr(model, observer_name).to(device)

+                if isinstance(obj, CustomModuleQuantizeHandler):
+                    custom_module = self.modules[node.target]
+                    observed_custom_module_class = \
+                        get_observed_custom_module_class(type(custom_module))
+                    observed_custom_module = \
+                        observed_custom_module_class.from_float(custom_module)
+                    mark_observed_custom_module(observed_custom_module, type(custom_module))
+                    parent_name, name = _parent_name(node.target)
+                    setattr(self.modules[parent_name], name, observed_custom_module)
+
                 # don't need to insert observer for output in dynamic quantization
                 if self.is_dynamic_quant:
                     continue

+                # insert observers for the output of the observed module, or mark
+                # the output as observed
                 if isinstance(obj, CopyNode):
                     assert node.op in [
                         'call_module',
@@ -355,6 +374,7 @@ def _convert(self, model, inplace=False, debug=False, is_dynamic_quant=False):
         self.modules = dict(model.named_modules())

         matches = self._find_matches(model.graph, self.modules, self.patterns)
+
         quants = self._find_quants(model.graph, matches)
         self.quantized_graph = Graph()
         env = {}
@@ -619,6 +639,16 @@ def record_match(pattern, node, matched):
                            all_matched.add(n.name)
                        # break after finding the first match
                        break
+
+        # add custom module instances to the match result
+        for node in graph.nodes:
+            if node.op == 'call_module' and \
+               (is_custom_module_class(type(self.modules[node.target])) or
+                    is_observed_custom_module(self.modules[node.target])):
+                custom_module_qconfig = self.qconfig_map[node.name]
+                match_map[node.name] = (
+                    node, [node], CustomModuleQuantizeHandler(self, node), custom_module_qconfig)
+
         return match_map

     def _find_quants(self, graph, matches):
```
