@@ -84,43 +84,51 @@ void replaceConvolutionWithAtenConv(std::shared_ptr<Graph>& graph) {
84
84
%r = aten::conv2d(%a, %w, %b, %stride, %padding, %dilation, %groups)
85
85
return (%r) )" ;
86
86
87
- std::string conv_transpose2d_for_deprecated_conv = R"(
87
+ std::string conv1d_for_deprecated_conv = R"(
88
88
graph(%a, %w, %b, %stride:int[], %padding:int[], %dilation:int[],
89
89
%transposed:bool, %output_padding:int[], %groups:int, %benchmark:bool,
90
90
%deterministic:bool, %cudnn_enabled:bool):
91
- %r = aten::conv_transpose2d (%a, %w, %b, %stride, %padding, %output_padding , %groups, %dilation )
91
+ %r = aten::conv1d (%a, %w, %b, %stride, %padding, %dilation , %groups)
92
92
return (%r) )" ;
93
- std::string conv_transpose2d = R"(
93
+ std::string conv1d = R"(
94
94
graph(%a, %w, %b, %stride:int[], %padding:int[], %dilation:int[],
95
95
%transposed:bool, %output_padding:int[], %groups:int, %benchmark:bool,
96
96
%deterministic:bool, %cudnn_enabled:bool, %allow_tf32:bool):
97
- %r = aten::conv_transpose2d (%a, %w, %b, %stride, %padding, %output_padding , %groups, %dilation )
97
+ %r = aten::conv1d (%a, %w, %b, %stride, %padding, %dilation , %groups)
98
98
return (%r) )" ;
99
99
100
- std::string conv1d_for_deprecated_conv = R"(
100
+ std::string conv3d_for_deprecated_conv = R"(
101
101
graph(%a, %w, %b, %stride:int[], %padding:int[], %dilation:int[],
102
102
%transposed:bool, %output_padding:int[], %groups:int, %benchmark:bool,
103
103
%deterministic:bool, %cudnn_enabled:bool):
104
- %r = aten::conv1d (%a, %w, %b, %stride, %padding, %dilation, %groups)
104
+ %r = aten::conv3d (%a, %w, %b, %stride, %padding, %dilation, %groups)
105
105
return (%r) )" ;
106
- std::string conv1d = R"(
106
+ std::string conv3d = R"(
107
107
graph(%a, %w, %b, %stride:int[], %padding:int[], %dilation:int[],
108
108
%transposed:bool, %output_padding:int[], %groups:int, %benchmark:bool,
109
109
%deterministic:bool, %cudnn_enabled:bool, %allow_tf32:bool):
110
- %r = aten::conv1d (%a, %w, %b, %stride, %padding, %dilation, %groups)
110
+ %r = aten::conv3d (%a, %w, %b, %stride, %padding, %dilation, %groups)
111
111
return (%r) )" ;
112
112
113
- std::string conv3d_for_deprecated_conv = R"(
113
+ std::string conv_transpose1d = R"(
114
+ graph(%a, %w, %b, %stride:int[], %padding:int[], %dilation:int[],
115
+ %transposed:bool, %output_padding:int[], %groups:int, %benchmark:bool,
116
+ %deterministic:bool, %cudnn_enabled:bool, %allow_tf32:bool):
117
+ %r = aten::conv_transpose1d(%a, %w, %b, %stride, %padding, %output_padding, %groups, %dilation)
118
+ return (%r) )" ;
119
+
120
+ std::string conv_transpose2d_for_deprecated_conv = R"(
114
121
graph(%a, %w, %b, %stride:int[], %padding:int[], %dilation:int[],
115
122
%transposed:bool, %output_padding:int[], %groups:int, %benchmark:bool,
116
123
%deterministic:bool, %cudnn_enabled:bool):
117
- %r = aten::conv3d (%a, %w, %b, %stride, %padding, %dilation , %groups)
124
+ %r = aten::conv_transpose2d (%a, %w, %b, %stride, %padding, %output_padding , %groups, %dilation )
118
125
return (%r) )" ;
119
- std::string conv3d = R"(
126
+
127
+ std::string conv_transpose2d = R"(
120
128
graph(%a, %w, %b, %stride:int[], %padding:int[], %dilation:int[],
121
129
%transposed:bool, %output_padding:int[], %groups:int, %benchmark:bool,
122
130
%deterministic:bool, %cudnn_enabled:bool, %allow_tf32:bool):
123
- %r = aten::conv3d (%a, %w, %b, %stride, %padding, %dilation , %groups)
131
+ %r = aten::conv_transpose2d (%a, %w, %b, %stride, %padding, %output_padding , %groups, %dilation )
124
132
return (%r) )" ;
125
133
126
134
// Filter the unsupported case
@@ -146,6 +154,29 @@ void replaceConvolutionWithAtenConv(std::shared_ptr<Graph>& graph) {
146
154
}
147
155
return !calc_value_map[" transposed" ].toBool ();
148
156
};
157
+ auto filter_conv3d = [](const Match& match,
158
+ const std::unordered_map<std::string, Value*>& vmap) {
159
+ auto calc_value_map = getConvParams (match, vmap);
160
+ if (calc_value_map[" output_padding" ].toIntList ().size () != 3 ||
161
+ calc_value_map[" stride" ].toIntList ().size () != 3 ||
162
+ calc_value_map[" padding" ].toIntList ().size () != 3 ||
163
+ calc_value_map[" dilation" ].toIntList ().size () != 3 ) {
164
+ return false ;
165
+ }
166
+ return !calc_value_map[" transposed" ].toBool ();
167
+ };
168
+ auto filter_conv_transpose1d =
169
+ [](const Match& match,
170
+ const std::unordered_map<std::string, Value*>& vmap) {
171
+ auto calc_value_map = getConvParams (match, vmap);
172
+ if (calc_value_map[" output_padding" ].toIntList ().size () != 1 ||
173
+ calc_value_map[" stride" ].toIntList ().size () != 1 ||
174
+ calc_value_map[" padding" ].toIntList ().size () != 1 ||
175
+ calc_value_map[" dilation" ].toIntList ().size () != 1 ) {
176
+ return false ;
177
+ }
178
+ return calc_value_map[" transposed" ].toBool ();
179
+ };
149
180
auto filter_conv_transpose2d =
150
181
[](const Match& match,
151
182
const std::unordered_map<std::string, Value*>& vmap) {
@@ -158,39 +189,36 @@ void replaceConvolutionWithAtenConv(std::shared_ptr<Graph>& graph) {
158
189
}
159
190
return calc_value_map[" transposed" ].toBool ();
160
191
};
161
- auto filter_conv3d = [](const Match& match,
162
- const std::unordered_map<std::string, Value*>& vmap) {
163
- auto calc_value_map = getConvParams (match, vmap);
164
- if (calc_value_map[" output_padding" ].toIntList ().size () != 3 ||
165
- calc_value_map[" stride" ].toIntList ().size () != 3 ||
166
- calc_value_map[" padding" ].toIntList ().size () != 3 ||
167
- calc_value_map[" dilation" ].toIntList ().size () != 3 ) {
168
- return false ;
169
- }
170
- return !calc_value_map[" transposed" ].toBool ();
171
- };
172
192
173
193
SubgraphRewriter rewriter_conv1d;
174
194
rewriter_conv1d.RegisterRewritePattern (convolution, conv1d);
175
195
rewriter_conv1d.RegisterRewritePattern (
176
196
convolution_deprecated, conv1d_for_deprecated_conv);
177
197
rewriter_conv1d.runOnGraph (graph, filter_conv1d);
198
+
178
199
SubgraphRewriter rewriter_conv2d;
179
200
rewriter_conv2d.RegisterRewritePattern (convolution, conv2d);
180
201
rewriter_conv2d.RegisterRewritePattern (
181
202
convolution_deprecated, conv2d_for_deprecated_conv);
182
203
rewriter_conv2d.runOnGraph (graph, filter_conv2d);
204
+
205
+ SubgraphRewriter rewriter_conv3d;
206
+ rewriter_conv3d.RegisterRewritePattern (convolution, conv3d);
207
+ rewriter_conv3d.RegisterRewritePattern (
208
+ convolution_deprecated, conv3d_for_deprecated_conv);
209
+ rewriter_conv3d.runOnGraph (graph, filter_conv3d);
210
+
211
+ SubgraphRewriter rewriter_conv_transpose1d;
212
+ rewriter_conv_transpose1d.RegisterRewritePattern (
213
+ convolution, conv_transpose1d);
214
+ rewriter_conv_transpose1d.runOnGraph (graph, filter_conv_transpose1d);
215
+
183
216
SubgraphRewriter rewriter_conv_transpose2d;
184
217
rewriter_conv_transpose2d.RegisterRewritePattern (
185
218
convolution, conv_transpose2d);
186
219
rewriter_conv_transpose2d.RegisterRewritePattern (
187
220
convolution_deprecated, conv_transpose2d_for_deprecated_conv);
188
221
rewriter_conv_transpose2d.runOnGraph (graph, filter_conv_transpose2d);
189
- SubgraphRewriter rewriter_conv3d;
190
- rewriter_conv3d.RegisterRewritePattern (convolution, conv3d);
191
- rewriter_conv3d.RegisterRewritePattern (
192
- convolution_deprecated, conv3d_for_deprecated_conv);
193
- rewriter_conv3d.runOnGraph (graph, filter_conv3d);
194
222
}
195
223
196
224
bool isClampFusable (
0 commit comments