Skip to content

Commit 95a97e5

Browse files
neginraooffacebook-github-bot
authored andcommittedSep 28, 2020
[ONNX] Improve scripting inplace indexing ops (pytorch#44351)
Summary: Fix a couple of issues with scripting inplace indexing in prepare_inplace_ops_for_onnx pass. 1- Tracing index copy (such as cases lik x[1:3] = data) already applies broadcasting on rhs if needed. The broadcasting node (aten::expand) is missing in scripting cases. 2- Inplace indexing with ellipsis (aten::copy_) is replaced with aten::index_put and then handled with slice+select in this pass. Support for negative indices for this op added. Shape inference is also enabled for scripting tests using new JIT API. A few more tests are enabled for scripting. Pull Request resolved: pytorch#44351 Reviewed By: ezyang Differential Revision: D23880267 Pulled By: bzinodev fbshipit-source-id: 78b33444633eb7ae0fbabc7415e3b16001f5207f
1 parent 13f76f2 commit 95a97e5

8 files changed

+304
-154
lines changed
 

‎aten/src/ATen/core/interned_strings.h

+1
Original file line numberDiff line numberDiff line change
@@ -272,6 +272,7 @@ namespace c10 {
272272
_(prim, grad) \
273273
_(aten, zero_) \
274274
_(aten, fill_) \
275+
_(aten, masked_fill_) \
275276
FORALL_ATEN_BASE_SYMBOLS(_) \
276277
_(onnx, Add) \
277278
_(onnx, Concat) \

‎test/onnx/test_models.py

+4-12
Original file line numberDiff line numberDiff line change
@@ -49,7 +49,6 @@ class TestModels(TestCase):
4949
opset_version = _export_onnx_opset_version
5050

5151
def exportTest(self, model, inputs, rtol=1e-2, atol=1e-7):
52-
self.is_script_test_enabled = True
5352
with torch.onnx.select_model_mode_for_export(model, None):
5453
graph = torch.onnx.utils._trace(model, inputs, OperatorExportTypes.ONNX)
5554
torch._C._jit_pass_lint(graph)
@@ -94,14 +93,12 @@ def test_srresnet(self):
9493
self.exportTest(toC(SRResNet(rescale_factor=4, n_filters=64, n_blocks=8)), toC(x))
9594

9695
@skipIfNoLapack
97-
@disableScriptTest()
9896
def test_super_resolution(self):
9997
x = Variable(
10098
torch.randn(BATCH_SIZE, 1, 224, 224).fill_(1.0)
10199
)
102100
self.exportTest(toC(SuperResolutionNet(upscale_factor=3)), toC(x), atol=1e-6)
103101

104-
@disableScriptTest()
105102
def test_alexnet(self):
106103
x = Variable(
107104
torch.randn(BATCH_SIZE, 3, 224, 224).fill_(1.0)
@@ -137,13 +134,12 @@ def test_vgg19_bn(self):
137134
x = Variable(torch.randn(BATCH_SIZE, 3, 224, 224).fill_(1.0))
138135
self.exportTest(toC(vgg19_bn()), toC(x))
139136

140-
@disableScriptTest()
141137
def test_resnet(self):
142138
# ResNet50 model
143139
x = Variable(torch.randn(BATCH_SIZE, 3, 224, 224).fill_(1.0))
144140
self.exportTest(toC(resnet50()), toC(x), atol=1e-6)
145141

146-
@disableScriptTest()
142+
@disableScriptTest() # None type in outputs
147143
def test_inception(self):
148144
x = Variable(
149145
torch.randn(BATCH_SIZE, 3, 299, 299) + 1.)
@@ -208,22 +204,20 @@ def test_qat_resnet(self):
208204

209205
self.exportTest(toC(qat_resnet50), toC(x))
210206

211-
@disableScriptTest()
207+
@disableScriptTest() # None type in outputs
212208
def test_googlenet(self):
213209
x = Variable(torch.randn(BATCH_SIZE, 3, 224, 224).fill_(1.0))
214210
self.exportTest(toC(googlenet()), toC(x), rtol=1e-3, atol=1e-5)
215211

216-
@disableScriptTest()
217212
def test_mnasnet(self):
218213
x = Variable(torch.randn(BATCH_SIZE, 3, 224, 224).fill_(1.0))
219214
self.exportTest(toC(mnasnet1_0()), toC(x), rtol=1e-3, atol=1e-5)
220215

221-
@disableScriptTest()
222216
def test_mobilenet(self):
223217
x = Variable(torch.randn(BATCH_SIZE, 3, 224, 224).fill_(1.0))
224218
self.exportTest(toC(mobilenet_v2()), toC(x), rtol=1e-3, atol=1e-5)
225219

226-
@disableScriptTest()
220+
@disableScriptTest() # prim_data
227221
def test_shufflenet(self):
228222
x = Variable(torch.randn(BATCH_SIZE, 3, 224, 224).fill_(1.0))
229223
self.exportTest(toC(shufflenet_v2_x1_0()), toC(x), rtol=1e-3, atol=1e-5)
@@ -238,20 +232,18 @@ def test_deeplab(self):
238232
x = Variable(torch.randn(BATCH_SIZE, 3, 224, 224).fill_(1.0))
239233
self.exportTest(toC(deeplabv3_resnet101()), toC(x), rtol=1e-3, atol=1e-5)
240234

241-
@disableScriptTest()
242235
def test_r3d_18_video(self):
243236
x = Variable(torch.randn(1, 3, 4, 112, 112).fill_(1.0))
244237
self.exportTest(toC(r3d_18()), toC(x), rtol=1e-3, atol=1e-5)
245238

246-
@disableScriptTest()
247239
def test_mc3_18_video(self):
248240
x = Variable(torch.randn(1, 3, 4, 112, 112).fill_(1.0))
249241
self.exportTest(toC(mc3_18()), toC(x), rtol=1e-3, atol=1e-5)
250242

251-
@disableScriptTest()
252243
def test_r2plus1d_18_video(self):
253244
x = Variable(torch.randn(1, 3, 4, 112, 112).fill_(1.0))
254245
self.exportTest(toC(r2plus1d_18()), toC(x), rtol=1e-3, atol=1e-5)
255246

247+
256248
if __name__ == '__main__':
257249
run_tests()

‎test/onnx/test_models_onnxruntime.py

+21-3
Original file line numberDiff line numberDiff line change
@@ -15,13 +15,31 @@ def exportTest(self, model, inputs, rtol=1e-2, atol=1e-7, opset_versions=None):
1515
input=inputs, rtol=rtol, atol=atol)
1616

1717
if self.is_script_test_enabled and opset_version > 11:
18+
TestModels.use_new_jit_passes = True
19+
TestModels.onnx_shape_inference = True
20+
1821
outputs = model(inputs)
1922
script_model = torch.jit.script(model)
2023
run_model_test(self, script_model, False, example_outputs=outputs,
21-
input=inputs, rtol=rtol, atol=atol, use_new_jit_passes=True)
24+
input=inputs, rtol=rtol, atol=atol)
25+
26+
27+
TestModels = type(str("TestModels"),
28+
(unittest.TestCase,),
29+
dict(TestModels.__dict__,
30+
is_script_test_enabled=False,
31+
exportTest=exportTest))
32+
33+
34+
# model tests for scripting with new JIT APIs and shape inference
35+
TestModels_new_jit_API = type(str("TestModels_new_jit_API"),
36+
(unittest.TestCase,),
37+
dict(TestModels.__dict__,
38+
exportTest=exportTest,
39+
is_script_test_enabled=True,
40+
use_new_jit_passes=True,
41+
onnx_shape_inference=True))
2242

2343

2444
if __name__ == '__main__':
25-
TestModels.is_script_test_enabled = True
26-
TestModels.exportTest = exportTest
2745
unittest.main()

‎test/onnx/test_pytorch_onnx_onnxruntime.py

+225-121
Large diffs are not rendered by default.

‎torch/csrc/jit/passes/onnx/remove_inplace_ops_for_onnx.cpp

+48-16
Original file line numberDiff line numberDiff line change
@@ -105,15 +105,27 @@ std::unordered_map<int64_t, ConvertedIndex> MergeSliceAndSelectToIndices(
105105
// Loop over fetched slice and select nodes and convert them to index tensors.
106106
// keep track of which dimension the current slice/select node is applying to.
107107
int64_t cur_dim = 0;
108-
// select does not keep dims,
109-
// this creates offset for latter slice and select nodes.
110108
int64_t dim_offset = 0;
111109
const auto orig_tensor_indices = index_put_node->input(1)->node()->inputs();
112110
for (auto it = slice_and_select_nodes.rbegin();
113111
it != slice_and_select_nodes.rend();
114112
++it) {
115113
auto node = *it;
116-
auto dim = node->get(attr::dim)->toInt() + dim_offset;
114+
// select does not keep dims,
115+
// this creates offset for latter slice and select nodes.
116+
auto dim = node->get(attr::dim)->toInt();
117+
if (dim < 0) {
118+
auto input_type = node->input(0)->type()->expect<TensorType>();
119+
if (input_type->dim().has_value()) {
120+
auto rank = input_type->dim().value();
121+
dim = dim + rank;
122+
} else {
123+
std::cerr
124+
<< "Error: ONNX Remove Inplace Ops - Cannot export ellipsis indexing for input "
125+
<< "of unknown rank.";
126+
}
127+
}
128+
dim = dim + dim_offset;
117129

118130
while (cur_dim < dim) {
119131
// Handle skipped dims, these are created from ..., or tensor indices
@@ -340,14 +352,23 @@ void PrepareCopyForONNX(Block* block) {
340352
// Remove aten::copy_, and replace it with index_put.
341353
// 1. create an empty listConstruct node as indices input for index_put.
342354
// 2. create index_put node.
355+
356+
// Tracing aten::copy_ broadcasts the rhs values.
357+
// 3. Apply broadcasting for scripting.
343358
WithInsertPoint guard(node);
344359
auto graph = node->owningGraph();
345360
auto dummy_list =
346361
graph->insertNode(graph->createList(OptionalType::ofTensor(), {}))
347362
->output();
363+
364+
auto expanded_value =
365+
graph->insert(aten::expand_as, {node->input(1), node->input(0)});
366+
expanded_value->node()->setSourceRange(node->sourceRange());
367+
expanded_value->copyMetadata(node->input(1));
368+
348369
auto index_put = graph->insert(
349370
aten::index_put,
350-
{node->input(0), dummy_list, node->input(1), node->input(2)});
371+
{node->input(0), dummy_list, expanded_value, node->input(2)});
351372
index_put->node()->setSourceRange(node->sourceRange());
352373
index_put->copyMetadata(node->output());
353374
node->output()->replaceAllUsesWith(index_put);
@@ -452,18 +473,29 @@ static void PrepareForRemoveMutations(MutationRemover& mr, Block* b) {
452473
<< "Warning: ONNX Preprocess - Removing mutation on block inputs. "
453474
<< "This changes graph semantics." << std::endl;
454475

455-
auto newNode = node->owningGraph()->create(aten::clone, 1);
456-
newNode->output()->copyMetadata(input);
457-
newNode->addInput(input);
458-
459-
auto* noneNode = node->owningGraph()->create(prim::Constant);
460-
noneNode->output()->setType(NoneType::get());
461-
newNode->addInput(noneNode->output());
462-
463-
newNode->insertBefore(node);
464-
noneNode->insertBefore(newNode);
465-
node->replaceInput(index, newNode->output());
466-
input->replaceAllUsesAfterNodeWith(node, newNode->output());
476+
if (input->type()->kind() == TypeKind::ListType) {
477+
// Create an aten::list to clone the list in graph inputs
478+
auto newNode = node->owningGraph()->create(aten::list, 1);
479+
newNode->output()->copyMetadata(input);
480+
newNode->addInput(input);
481+
newNode->insertBefore(node);
482+
node->replaceInput(index, newNode->output());
483+
input->replaceAllUsesAfterNodeWith(node, newNode->output());
484+
} else {
485+
// Create an aten::clone to clone the tensor in graph inputs
486+
auto newNode = node->owningGraph()->create(aten::clone, 1);
487+
newNode->output()->copyMetadata(input);
488+
newNode->addInput(input);
489+
490+
auto* noneNode = node->owningGraph()->create(prim::Constant);
491+
noneNode->output()->setType(NoneType::get());
492+
newNode->addInput(noneNode->output());
493+
494+
newNode->insertBefore(node);
495+
noneNode->insertBefore(newNode);
496+
node->replaceInput(index, newNode->output());
497+
input->replaceAllUsesAfterNodeWith(node, newNode->output());
498+
}
467499
}
468500
}
469501
}

‎torch/csrc/jit/passes/remove_inplace_ops.cpp

+1
Original file line numberDiff line numberDiff line change
@@ -8,6 +8,7 @@ static const std::unordered_map<NodeKind, NodeKind> inPlaceToOutOfPlace = {
88
{aten::sub_, aten::sub},
99
{aten::div_, aten::div},
1010
{aten::mul_, aten::mul},
11+
{aten::masked_fill_, aten::masked_fill},
1112
{aten::zero_, aten::zeros_like},
1213
{aten::fill_, aten::full_like}};
1314

‎torch/onnx/symbolic_opset11.py

+1-1
Original file line numberDiff line numberDiff line change
@@ -272,7 +272,7 @@ def masked_scatter(g, self, mask, source):
272272

273273

274274
def _len(g, self):
275-
if self.type().isSubtypeOf(torch._C.ListType.ofTensors()):
275+
if self.type().isSubtypeOf(torch._C.ListType.ofTensors()) or self.node().kind() == "onnx::SplitToSequence":
276276
return g.op("SequenceLength", self)
277277
return g.op("Size", self)
278278

‎torch/onnx/symbolic_opset9.py

+3-1
Original file line numberDiff line numberDiff line change
@@ -186,9 +186,11 @@ def stack(g, tensor_list, dim):
186186
unsqueezed = [g.op("Unsqueeze", t, axes_i=[dim]) for t in sym_help._unpack_list(tensor_list)]
187187
return g.op("Concat", *unsqueezed, axis_i=dim)
188188

189+
189190
def _list(g, self):
190191
return self
191192

193+
192194
def mm(g, self, other):
193195
# Create a dummy C tensor. Only needed for API purposes, the value is
194196
# since beta = 0
@@ -1558,7 +1560,7 @@ def tensor(g, data, dtype=None, device=None, requires_grad=False):
15581560
return g.op("Concat", *input_list, axis_i=0)
15591561
else:
15601562
if dtype is None:
1561-
dtype = sym_help._maybe_get_const(data, 't').type().scalarType()
1563+
dtype = data.type().scalarType()
15621564
dtype = sym_help.scalar_type_to_onnx.index(sym_help.cast_pytorch_to_onnx[dtype])
15631565
return g.op("Cast", data, to_i=sym_help.scalar_type_to_onnx[dtype])
15641566

0 commit comments

Comments
 (0)
Please sign in to comment.