Skip to content

Commit 958c208

Browse files
Zafar authored and facebook-github-bot committed on Sep 26, 2020
[quant] conv_transpose graph patterns (pytorch#45078)
Summary: Pull Request resolved: pytorch#45078 Test Plan: Imported from OSS Reviewed By: vkuzo Differential Revision: D23821580 Pulled By: z-a-f fbshipit-source-id: 813a4ef1bbc429720765d61791fe754b6678a334
1 parent 606b1a9 commit 958c208

File tree

7 files changed

+221
-35
lines changed

7 files changed

+221
-35
lines changed
 

‎test/quantization/test_quantize_jit.py

+31
Original file line numberDiff line numberDiff line change
@@ -51,6 +51,7 @@
5151
SkipQuantModel,
5252
NestedModel,
5353
ConvModel,
54+
ConvTransposeModel,
5455
default_per_channel_qconfig,
5556
test_only_eval_fn,
5657
ConvBnModel,
@@ -61,6 +62,7 @@
6162
AnnotatedSkipQuantModel,
6263
AnnotatedNestedModel,
6364
AnnotatedConvModel,
65+
AnnotatedConvTransposeModel,
6466
AnnotatedConvBnModel,
6567
)
6668

@@ -3171,6 +3173,35 @@ def test_conv(self):
31713173
inplace=False)
31723174
self.assertEqual(model_quantized(self.img_data_2d[0][0]), result_eager)
31733175

3176+
@override_qengines
3177+
def test_conv_transpose(self):
3178+
r"""Compare the result of quantizing conv_transpose layer in
3179+
eager mode and graph mode
3180+
"""
3181+
if not qengine_is_qnnpack():
3182+
return # Currently only qnnpack is supported
3183+
# eager mode
3184+
annotated_conv_model = AnnotatedConvTransposeModel(
3185+
torch.backends.quantized.engine).eval()
3186+
conv_model = ConvTransposeModel().eval()
3187+
# copy the weight from eager mode so that we can
3188+
# compare the result of the two quantized models later
3189+
conv_model.conv.weight = torch.nn.Parameter(annotated_conv_model.conv.weight.detach())
3190+
model_eager = quantize(annotated_conv_model, test_only_eval_fn, self.img_data_2d)
3191+
qconfig_dict = {'': get_default_qconfig(torch.backends.quantized.engine)}
3192+
model_traced = torch.jit.trace(conv_model, self.img_data_2d[0][0])
3193+
model_script = torch.jit.script(conv_model)
3194+
result_eager = model_eager(self.img_data_2d[0][0])
3195+
for model_under_test in [model_traced, model_script]:
3196+
model_quantized = quantize_jit(
3197+
model_under_test,
3198+
qconfig_dict,
3199+
test_only_eval_fn,
3200+
[self.img_data_2d],
3201+
inplace=False)
3202+
self.assertEqual(model_quantized(self.img_data_2d[0][0]),
3203+
result_eager)
3204+
31743205
@override_qengines
31753206
def test_conv_bn(self):
31763207
r"""Compare the result of quantizing conv + bn layer in

‎torch/csrc/jit/passes/graph_rewrite_helper.cpp

+56-28
Original file line numberDiff line numberDiff line change
@@ -84,43 +84,51 @@ void replaceConvolutionWithAtenConv(std::shared_ptr<Graph>& graph) {
8484
%r = aten::conv2d(%a, %w, %b, %stride, %padding, %dilation, %groups)
8585
return (%r) )";
8686

87-
std::string conv_transpose2d_for_deprecated_conv = R"(
87+
std::string conv1d_for_deprecated_conv = R"(
8888
graph(%a, %w, %b, %stride:int[], %padding:int[], %dilation:int[],
8989
%transposed:bool, %output_padding:int[], %groups:int, %benchmark:bool,
9090
%deterministic:bool, %cudnn_enabled:bool):
91-
%r = aten::conv_transpose2d(%a, %w, %b, %stride, %padding, %output_padding, %groups, %dilation)
91+
%r = aten::conv1d(%a, %w, %b, %stride, %padding, %dilation, %groups)
9292
return (%r) )";
93-
std::string conv_transpose2d = R"(
93+
std::string conv1d = R"(
9494
graph(%a, %w, %b, %stride:int[], %padding:int[], %dilation:int[],
9595
%transposed:bool, %output_padding:int[], %groups:int, %benchmark:bool,
9696
%deterministic:bool, %cudnn_enabled:bool, %allow_tf32:bool):
97-
%r = aten::conv_transpose2d(%a, %w, %b, %stride, %padding, %output_padding, %groups, %dilation)
97+
%r = aten::conv1d(%a, %w, %b, %stride, %padding, %dilation, %groups)
9898
return (%r) )";
9999

100-
std::string conv1d_for_deprecated_conv = R"(
100+
std::string conv3d_for_deprecated_conv = R"(
101101
graph(%a, %w, %b, %stride:int[], %padding:int[], %dilation:int[],
102102
%transposed:bool, %output_padding:int[], %groups:int, %benchmark:bool,
103103
%deterministic:bool, %cudnn_enabled:bool):
104-
%r = aten::conv1d(%a, %w, %b, %stride, %padding, %dilation, %groups)
104+
%r = aten::conv3d(%a, %w, %b, %stride, %padding, %dilation, %groups)
105105
return (%r) )";
106-
std::string conv1d = R"(
106+
std::string conv3d = R"(
107107
graph(%a, %w, %b, %stride:int[], %padding:int[], %dilation:int[],
108108
%transposed:bool, %output_padding:int[], %groups:int, %benchmark:bool,
109109
%deterministic:bool, %cudnn_enabled:bool, %allow_tf32:bool):
110-
%r = aten::conv1d(%a, %w, %b, %stride, %padding, %dilation, %groups)
110+
%r = aten::conv3d(%a, %w, %b, %stride, %padding, %dilation, %groups)
111111
return (%r) )";
112112

113-
std::string conv3d_for_deprecated_conv = R"(
113+
std::string conv_transpose1d = R"(
114+
graph(%a, %w, %b, %stride:int[], %padding:int[], %dilation:int[],
115+
%transposed:bool, %output_padding:int[], %groups:int, %benchmark:bool,
116+
%deterministic:bool, %cudnn_enabled:bool, %allow_tf32:bool):
117+
%r = aten::conv_transpose1d(%a, %w, %b, %stride, %padding, %output_padding, %groups, %dilation)
118+
return (%r) )";
119+
120+
std::string conv_transpose2d_for_deprecated_conv = R"(
114121
graph(%a, %w, %b, %stride:int[], %padding:int[], %dilation:int[],
115122
%transposed:bool, %output_padding:int[], %groups:int, %benchmark:bool,
116123
%deterministic:bool, %cudnn_enabled:bool):
117-
%r = aten::conv3d(%a, %w, %b, %stride, %padding, %dilation, %groups)
124+
%r = aten::conv_transpose2d(%a, %w, %b, %stride, %padding, %output_padding, %groups, %dilation)
118125
return (%r) )";
119-
std::string conv3d = R"(
126+
127+
std::string conv_transpose2d = R"(
120128
graph(%a, %w, %b, %stride:int[], %padding:int[], %dilation:int[],
121129
%transposed:bool, %output_padding:int[], %groups:int, %benchmark:bool,
122130
%deterministic:bool, %cudnn_enabled:bool, %allow_tf32:bool):
123-
%r = aten::conv3d(%a, %w, %b, %stride, %padding, %dilation, %groups)
131+
%r = aten::conv_transpose2d(%a, %w, %b, %stride, %padding, %output_padding, %groups, %dilation)
124132
return (%r) )";
125133

126134
// Filter the unsupported case
@@ -146,6 +154,29 @@ void replaceConvolutionWithAtenConv(std::shared_ptr<Graph>& graph) {
146154
}
147155
return !calc_value_map["transposed"].toBool();
148156
};
157+
auto filter_conv3d = [](const Match& match,
158+
const std::unordered_map<std::string, Value*>& vmap) {
159+
auto calc_value_map = getConvParams(match, vmap);
160+
if (calc_value_map["output_padding"].toIntList().size() != 3 ||
161+
calc_value_map["stride"].toIntList().size() != 3 ||
162+
calc_value_map["padding"].toIntList().size() != 3 ||
163+
calc_value_map["dilation"].toIntList().size() != 3) {
164+
return false;
165+
}
166+
return !calc_value_map["transposed"].toBool();
167+
};
168+
auto filter_conv_transpose1d =
169+
[](const Match& match,
170+
const std::unordered_map<std::string, Value*>& vmap) {
171+
auto calc_value_map = getConvParams(match, vmap);
172+
if (calc_value_map["output_padding"].toIntList().size() != 1 ||
173+
calc_value_map["stride"].toIntList().size() != 1 ||
174+
calc_value_map["padding"].toIntList().size() != 1 ||
175+
calc_value_map["dilation"].toIntList().size() != 1) {
176+
return false;
177+
}
178+
return calc_value_map["transposed"].toBool();
179+
};
149180
auto filter_conv_transpose2d =
150181
[](const Match& match,
151182
const std::unordered_map<std::string, Value*>& vmap) {
@@ -158,39 +189,36 @@ void replaceConvolutionWithAtenConv(std::shared_ptr<Graph>& graph) {
158189
}
159190
return calc_value_map["transposed"].toBool();
160191
};
161-
auto filter_conv3d = [](const Match& match,
162-
const std::unordered_map<std::string, Value*>& vmap) {
163-
auto calc_value_map = getConvParams(match, vmap);
164-
if (calc_value_map["output_padding"].toIntList().size() != 3 ||
165-
calc_value_map["stride"].toIntList().size() != 3 ||
166-
calc_value_map["padding"].toIntList().size() != 3 ||
167-
calc_value_map["dilation"].toIntList().size() != 3) {
168-
return false;
169-
}
170-
return !calc_value_map["transposed"].toBool();
171-
};
172192

173193
SubgraphRewriter rewriter_conv1d;
174194
rewriter_conv1d.RegisterRewritePattern(convolution, conv1d);
175195
rewriter_conv1d.RegisterRewritePattern(
176196
convolution_deprecated, conv1d_for_deprecated_conv);
177197
rewriter_conv1d.runOnGraph(graph, filter_conv1d);
198+
178199
SubgraphRewriter rewriter_conv2d;
179200
rewriter_conv2d.RegisterRewritePattern(convolution, conv2d);
180201
rewriter_conv2d.RegisterRewritePattern(
181202
convolution_deprecated, conv2d_for_deprecated_conv);
182203
rewriter_conv2d.runOnGraph(graph, filter_conv2d);
204+
205+
SubgraphRewriter rewriter_conv3d;
206+
rewriter_conv3d.RegisterRewritePattern(convolution, conv3d);
207+
rewriter_conv3d.RegisterRewritePattern(
208+
convolution_deprecated, conv3d_for_deprecated_conv);
209+
rewriter_conv3d.runOnGraph(graph, filter_conv3d);
210+
211+
SubgraphRewriter rewriter_conv_transpose1d;
212+
rewriter_conv_transpose1d.RegisterRewritePattern(
213+
convolution, conv_transpose1d);
214+
rewriter_conv_transpose1d.runOnGraph(graph, filter_conv_transpose1d);
215+
183216
SubgraphRewriter rewriter_conv_transpose2d;
184217
rewriter_conv_transpose2d.RegisterRewritePattern(
185218
convolution, conv_transpose2d);
186219
rewriter_conv_transpose2d.RegisterRewritePattern(
187220
convolution_deprecated, conv_transpose2d_for_deprecated_conv);
188221
rewriter_conv_transpose2d.runOnGraph(graph, filter_conv_transpose2d);
189-
SubgraphRewriter rewriter_conv3d;
190-
rewriter_conv3d.RegisterRewritePattern(convolution, conv3d);
191-
rewriter_conv3d.RegisterRewritePattern(
192-
convolution_deprecated, conv3d_for_deprecated_conv);
193-
rewriter_conv3d.runOnGraph(graph, filter_conv3d);
194222
}
195223

196224
bool isClampFusable(

‎torch/csrc/jit/passes/quantization/finalize.cpp

+6-2
Original file line numberDiff line numberDiff line change
@@ -65,10 +65,14 @@ void InsertPrepackUnpack(Module& module) {
6565
void FoldQuantizedPrepackingOps(Module& module) {
6666
auto filter_fn = [](const Node* n) -> bool {
6767
return (
68-
(n->kind() == Symbol::fromQualString("quantized::linear_prepack")) ||
68+
n->kind() == Symbol::fromQualString("quantized::linear_prepack") ||
6969
n->kind() == Symbol::fromQualString("quantized::conv1d_prepack") ||
7070
n->kind() == Symbol::fromQualString("quantized::conv2d_prepack") ||
71-
n->kind() == Symbol::fromQualString("quantized::conv3d_prepack"));
71+
n->kind() == Symbol::fromQualString("quantized::conv3d_prepack") ||
72+
n->kind() ==
73+
Symbol::fromQualString("quantized::conv_transpose1d_prepack") ||
74+
n->kind() ==
75+
Symbol::fromQualString("quantized::conv_transpose2d_prepack"));
7276
};
7377
PrePackingOpsFolder(module, filter_fn, "quantized");
7478
}

‎torch/csrc/jit/passes/quantization/helper.cpp

+24-2
Original file line numberDiff line numberDiff line change
@@ -32,6 +32,8 @@ std::vector<std::string> _static_quantizable_aten_funcs = {
3232
"conv1d",
3333
"conv2d",
3434
"conv3d",
35+
"conv_transpose1d",
36+
"conv_transpose2d",
3537
"linear",
3638
"hardswish",
3739
"hardswish_",
@@ -273,6 +275,8 @@ bool isWeight(Value* v) {
273275
AtenFuncArgs({{"conv1d", 1},
274276
{"conv2d", 1},
275277
{"conv3d", 1},
278+
{"conv_transpose1d", 1},
279+
{"conv_transpose2d", 1},
276280
{"linear", 1},
277281
{"embedding_bag", 0}}),
278282
// embedding_bag - prim::CallFunction(%func, %input.1, %weight,
@@ -285,8 +289,12 @@ bool isWeight(Value* v) {
285289
bool isBiasOfConvOrLinear(Value* v) {
286290
bool result = matchArgPattern(
287291
v,
288-
AtenFuncArgs(
289-
{{"conv1d", 2}, {"conv2d", 2}, {"conv3d", 2}, {"linear", 2}}),
292+
AtenFuncArgs({{"conv1d", 2},
293+
{"conv2d", 2},
294+
{"conv3d", 2},
295+
{"conv_transpose1d", 2},
296+
{"conv_transpose2d", 2},
297+
{"linear", 2}}),
290298
CallFuncArgs({{"linear", 3}}));
291299
return result;
292300
}
@@ -728,6 +736,20 @@ bool is_conv3d_module(
728736
match, vmap, "conv", "__torch__.torch.nn.modules.conv.Conv3d");
729737
}
730738

739+
bool is_conv_transpose1d_module(
740+
const Match& match,
741+
const std::unordered_map<std::string, Value*>& vmap) {
742+
return is_module(
743+
match, vmap, "conv", "__torch__.torch.nn.modules.conv.ConvTranspose1d");
744+
}
745+
746+
bool is_conv_transpose2d_module(
747+
const Match& match,
748+
const std::unordered_map<std::string, Value*>& vmap) {
749+
return is_module(
750+
match, vmap, "conv", "__torch__.torch.nn.modules.conv.ConvTranspose2d");
751+
}
752+
731753
bool is_batchnorm2d_module(
732754
const Match& match,
733755
const std::unordered_map<std::string, Value*>& vmap) {

‎torch/csrc/jit/passes/quantization/helper.h

+8
Original file line numberDiff line numberDiff line change
@@ -194,6 +194,14 @@ bool is_conv3d_module(
194194
const Match& match,
195195
const std::unordered_map<std::string, Value*>& vmap);
196196

197+
bool is_conv_transpose1d_module(
198+
const Match& match,
199+
const std::unordered_map<std::string, Value*>& vmap);
200+
201+
bool is_conv_transpose2d_module(
202+
const Match& match,
203+
const std::unordered_map<std::string, Value*>& vmap);
204+
197205
bool is_batchnorm2d_module(
198206
const Match& match,
199207
const std::unordered_map<std::string, Value*>& vmap);

‎torch/csrc/jit/passes/quantization/quantization_patterns.h

+73-3
Original file line numberDiff line numberDiff line change
@@ -407,6 +407,38 @@ graph(%a_quant, %packed_params, %r_scale, %r_zero_point, %r_dtype, %stride, %pad
407407
%r_quant = quantized::conv3d_relu(%a_quant, %packed_params, %r_scale, %r_zero_point)
408408
return (%r_quant) )";
409409

410+
// aten::conv_transpose1d
411+
std::string conv_transpose1d = R"(
412+
graph(%a_quant, %packed_params, %r_scale, %r_zero_point, %r_dtype, %stride, %padding, %output_padding, %groups, %dilation):
413+
%a_dequant = aten::dequantize(%a_quant)
414+
%w_quant : Tensor, %b : Tensor? = quantized::conv_transpose1d_unpack(%packed_params)
415+
%w_dequant = aten::dequantize(%w_quant)
416+
%r = aten::conv_transpose1d(%a_dequant, %w_dequant, %b, %stride, %padding, %output_padding, %groups, %dilation)
417+
%r_quant = aten::quantize_per_tensor(%r, %r_scale, %r_zero_point, %r_dtype)
418+
return (%r_quant) )";
419+
420+
// quantized::conv_transpose1d
421+
std::string quantized_conv_transpose1d = R"(
422+
graph(%a_quant, %packed_params, %r_scale, %r_zero_point, %r_dtype, %stride, %padding, %output_padding, %groups, %dilation):
423+
%r_quant = quantized::conv_transpose1d(%a_quant, %packed_params, %r_scale, %r_zero_point)
424+
return (%r_quant) )";
425+
426+
// aten::conv_transpose2d
427+
std::string conv_transpose2d = R"(
428+
graph(%a_quant, %packed_params, %r_scale, %r_zero_point, %r_dtype, %stride, %padding, %output_padding, %groups, %dilation):
429+
%a_dequant = aten::dequantize(%a_quant)
430+
%w_quant : Tensor, %b : Tensor? = quantized::conv_transpose2d_unpack(%packed_params)
431+
%w_dequant = aten::dequantize(%w_quant)
432+
%r = aten::conv_transpose2d(%a_dequant, %w_dequant, %b, %stride, %padding, %output_padding, %groups, %dilation)
433+
%r_quant = aten::quantize_per_tensor(%r, %r_scale, %r_zero_point, %r_dtype)
434+
return (%r_quant) )";
435+
436+
// quantized::conv_transpose2d
437+
std::string quantized_conv_transpose2d = R"(
438+
graph(%a_quant, %packed_params, %r_scale, %r_zero_point, %r_dtype, %stride, %padding, %output_padding, %groups, %dilation):
439+
%r_quant = quantized::conv_transpose2d(%a_quant, %packed_params, %r_scale, %r_zero_point)
440+
return (%r_quant) )";
441+
410442
std::string add_relu = R"(
411443
graph(%a_quant, %b_quant, %alpha, %scale, %zero_point, %dtype):
412444
%a_dequant = aten::dequantize(%a_quant)
@@ -907,6 +939,12 @@ graph(%a_quant, %alpha, %scale, %input_scale, %r_scale, %r_zero_point, %r_dtype)
907939
{"quantized::conv3d", conv3d, quantized_conv3d},
908940
{"quantized::conv3d_relu", conv3d_relu, quantized_conv3d_relu},
909941
{"quantized::conv3d_relu", conv3d_inplace_relu, quantized_conv3d_relu},
942+
{"quantized::conv_transpose1d",
943+
conv_transpose1d,
944+
quantized_conv_transpose1d},
945+
{"quantized::conv_transpose2d",
946+
conv_transpose2d,
947+
quantized_conv_transpose2d},
910948
{"quantized::linear", linear, quantized_linear},
911949
{"quantized::linear_relu", linear_relu, quantized_linear_relu},
912950
{"quantized::linear_relu", linear_inplace_relu, quantized_linear_relu},
@@ -1128,12 +1166,44 @@ graph(%a_dequant, %w_quant, %b, %stride, %padding, %dilation, %groups):
11281166
%r = aten::conv3d(%a_dequant, %w_dequant, %b_unpacked, %stride, %padding, %dilation, %groups)
11291167
return (%r) )";
11301168

1169+
std::string conv_transpose1d_with_quant = R"(
1170+
graph(%a_dequant, %w_quant, %b, %stride, %padding, %output_padding, %groups, %dilation):
1171+
%w_dequant = aten::dequantize(%w_quant)
1172+
%r = aten::conv_transpose1d(%a_dequant, %w_dequant, %b, %stride, %padding, %output_padding, %groups, %dilation)
1173+
return (%r) )";
1174+
1175+
std::string conv_transpose1d_with_quant_prepack = R"(
1176+
graph(%a_dequant, %w_quant, %b, %stride, %padding, %output_padding, %groups, %dilation):
1177+
%packed_params : __torch__.torch.classes.quantized.Conv2dPackedParamsBase = quantized::conv_transpose1d_prepack(%w_quant, %b, %stride, %padding, %output_padding, %dilation, %groups)
1178+
%w_quant_unpacked : Tensor, %b_unpacked : Tensor? = quantized::conv_transpose1d_unpack(%packed_params)
1179+
%w_dequant = aten::dequantize(%w_quant_unpacked)
1180+
%r = aten::conv_transpose1d(%a_dequant, %w_dequant, %b_unpacked, %stride, %padding, %output_padding, %groups, %dilation)
1181+
return (%r) )";
1182+
1183+
std::string conv_transpose2d_with_quant = R"(
1184+
graph(%a_dequant, %w_quant, %b, %stride, %padding, %output_padding, %groups, %dilation):
1185+
%w_dequant = aten::dequantize(%w_quant)
1186+
%r = aten::conv_transpose2d(%a_dequant, %w_dequant, %b, %stride, %padding, %output_padding, %groups, %dilation)
1187+
return (%r) )";
1188+
1189+
std::string conv_transpose2d_with_quant_prepack = R"(
1190+
graph(%a_dequant, %w_quant, %b, %stride, %padding, %output_padding, %groups, %dilation):
1191+
%packed_params : __torch__.torch.classes.quantized.Conv2dPackedParamsBase = quantized::conv_transpose2d_prepack(%w_quant, %b, %stride, %padding, %output_padding, %dilation, %groups)
1192+
%w_quant_unpacked : Tensor, %b_unpacked : Tensor? = quantized::conv_transpose2d_unpack(%packed_params)
1193+
%w_dequant = aten::dequantize(%w_quant_unpacked)
1194+
%r = aten::conv_transpose2d(%a_dequant, %w_dequant, %b_unpacked, %stride, %padding, %output_padding, %groups, %dilation)
1195+
return (%r) )";
1196+
11311197
return {
11321198
{"conv1d_prepack_unpack", conv1d_with_quant, conv1d_with_quant_prepack},
11331199
{"conv2d_prepack_unpack", conv2d_with_quant, conv2d_with_quant_prepack},
1134-
{"conv3d_prepack_unpack", conv3d_with_quant, conv3d_with_quant_prepack}
1135-
1136-
};
1200+
{"conv3d_prepack_unpack", conv3d_with_quant, conv3d_with_quant_prepack},
1201+
{"conv_transpose1d_prepack_unpack",
1202+
conv_transpose1d_with_quant,
1203+
conv_transpose1d_with_quant_prepack},
1204+
{"conv_transpose2d_prepack_unpack",
1205+
conv_transpose2d_with_quant,
1206+
conv_transpose2d_with_quant_prepack}};
11371207
}
11381208

11391209
} // namespace jit

‎torch/testing/_internal/common_quantization.py

+23
Original file line numberDiff line numberDiff line change
@@ -810,6 +810,15 @@ def forward(self, x):
810810
x = self.conv(x)
811811
return x
812812

813+
class ConvTransposeModel(torch.nn.Module):
814+
def __init__(self):
815+
super().__init__()
816+
self.conv = torch.nn.ConvTranspose2d(3, 5, 3, bias=False).to(dtype=torch.float)
817+
818+
def forward(self, x):
819+
x = self.conv(x)
820+
return x
821+
813822
class AnnotatedConvModel(torch.nn.Module):
814823
def __init__(self, qengine):
815824
super().__init__()
@@ -824,6 +833,20 @@ def forward(self, x):
824833
x = self.dequant(x)
825834
return x
826835

836+
class AnnotatedConvTransposeModel(torch.nn.Module):
837+
def __init__(self, qengine):
838+
super().__init__()
839+
self.qconfig = torch.quantization.get_default_qconfig(qengine)
840+
self.conv = torch.nn.ConvTranspose2d(3, 5, 3, bias=False).to(dtype=torch.float)
841+
self.quant = QuantStub()
842+
self.dequant = DeQuantStub()
843+
844+
def forward(self, x):
845+
x = self.quant(x)
846+
x = self.conv(x)
847+
x = self.dequant(x)
848+
return x
849+
827850
class ConvBnModel(torch.nn.Module):
828851
def __init__(self):
829852
super().__init__()

0 commit comments

Comments
 (0)
Please sign in to comment.