@@ -29,6 +29,23 @@ bool isSupportedForBlock(Node* node) {
   }
 }
 
+bool usedOnlyInSize(Value* v) {
+  const auto& uses = v->uses();
+  return std::all_of(uses.begin(), uses.end(), [](const Use& u) {
+    return u.user->matches("aten::size(Tensor self) -> int[]");
+  });
+}
+
+Value* broadcastSizes(at::ArrayRef<Value*> sizes, AliasDb* db) {
+  AT_ASSERT(!sizes.empty());
+  Graph* graph = sizes[0]->owningGraph();
+  Node* broadcast_n =
+      graph->insertNode(graph->create(prim::BroadcastSizes, sizes));
+  broadcast_n->output()->setType(ListType::ofInts());
+  db->createValue(broadcast_n->output());
+  return broadcast_n->output();
+}
+
 namespace tensorexpr {
 bool isSupported(Node* node) {
   // For Block codegen we allow limited ops.
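Note that `broadcastSizes` above only emits a `prim::BroadcastSizes` node into the graph; the broadcasted size itself is computed when that node runs. As a point of reference only (not part of this patch), here is a minimal standalone sketch of the broadcasting rule such a node is assumed to follow, with the hypothetical helper name `broadcastTwoSizes`: shapes are aligned from the trailing dimension, and each pair of dimensions must match or one of them must be 1.

#include <algorithm>
#include <cstdint>
#include <stdexcept>
#include <vector>

// Broadcast two shapes NumPy/ATen-style (illustrative sketch only).
std::vector<int64_t> broadcastTwoSizes(
    const std::vector<int64_t>& a,
    const std::vector<int64_t>& b) {
  std::vector<int64_t> result(std::max(a.size(), b.size()));
  for (size_t i = 0; i < result.size(); ++i) {
    // Walk both shapes from the back; a missing dimension behaves like 1.
    const int64_t da = i < a.size() ? a[a.size() - 1 - i] : 1;
    const int64_t db = i < b.size() ? b[b.size() - 1 - i] : 1;
    if (da != db && da != 1 && db != 1) {
      throw std::runtime_error("sizes are not broadcastable");
    }
    result[result.size() - 1 - i] = std::max(da, db);
  }
  return result;
}

// Example: broadcastTwoSizes({3, 1, 5}, {4, 5}) yields {3, 4, 5}.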
@@ -287,6 +304,132 @@ class TensorExprFuser {
         min_group_size_(min_group_size),
         disable_shape_checks_(disable_shape_checks) {}
 
+  // Builds up expressions that compute shapes of all intermediates (and
+  // outputs) of the fusion group, based on the sizes of inputs. You should run
+  // DCE to remove those that you end up not using.
+  std::unordered_map<Value*, Value*> buildShapeExpressions(Node* fusion_group) {
+    GRAPH_DUMP("buildShapeExpressions for ", fusion_group->g(attr::Subgraph));
+    WithInsertPoint insert_guard{fusion_group->next()};
+    std::unordered_map<Value*, Value*> shape_of;
+
+    Graph* graph = fusion_group->owningGraph();
+    auto subgraph = fusion_group->g(attr::Subgraph);
+
+    auto inputs = fusion_group->inputs();
+    auto sinputs = subgraph->inputs();
+    AT_ASSERT(inputs.size() == sinputs.size());
+    for (size_t i = 0; i < inputs.size(); ++i) {
+      if (inputs[i]->type()->isSubtypeOf(TensorType::get())) {
+        Value* soutput = graph->insert(aten::size, {inputs[i]});
+        aliasDb_->createValue(soutput);
+        GRAPH_DEBUG(
+            "Adding a mapping for %",
+            sinputs[i]->debugName(),
+            " ",
+            getHeader(soutput->node()));
+        shape_of[sinputs[i]] = soutput;
+      }
+    }
+
+    // When we have a guarantee that an output won't be removed, because it's
+    // used in expressions that don't involve size checks, we can use its size
+    // instead of computing a long chain of broadcasts, starting from the
+    // beginning of the kernel.
+    auto outputs = fusion_group->outputs();
+    auto soutputs = subgraph->outputs();
+    AT_ASSERT(outputs.size() == soutputs.size());
+    for (size_t i = 0; i < outputs.size(); ++i) {
+      if (usedOnlyInSize(outputs[i]))
+        continue;
+      Value* soutput = graph->insert(aten::size, {outputs[i]});
+      aliasDb_->createValue(soutput);
+      shape_of[soutputs[i]] = soutput;
+    }
+
+    for (Node* n : subgraph->nodes()) {
+      // XXX: Use of shape_of.emplace is crucial to the output shape
+      // optimization!
+      if (n->kind() == aten::cat) {
+        // This is a bit more involved, because we have to account for the case
+        // when inputs have different shapes, but fortunately those tensors are
+        // always outputs, and so we can simply avoid replacing their queries,
+        // because it won't help us.
+        continue;
+      }
+      if (n->kind() == prim::Constant) {
+        continue;
+      }
+      if (n->kind() == prim::ConstantChunk) {
+        Node* sizes_node = graph->insertNode(
+            graph->create(prim::ChunkSizes, shape_of.at(n->input()), 2));
+        sizes_node->i_(attr::dim, n->i(attr::dim));
+        sizes_node->i_(attr::chunks, n->i(attr::chunks));
+        for (Value* output : sizes_node->outputs()) {
+          aliasDb_->createValue(output);
+        }
+        Value* regular_size = sizes_node->outputs().at(0);
+        Value* last_size = sizes_node->outputs().at(1);
+        regular_size->setType(ListType::ofInts());
+        last_size->setType(ListType::ofInts());
+        auto outputs = n->outputs();
+        for (Value* o : outputs.slice(0, outputs.size() - 1)) {
+          shape_of.emplace(o, regular_size);
+        }
+        shape_of.emplace(outputs.at(outputs.size() - 1), last_size);
+        continue;
+      }
+      auto tensor_inputs = filter(n->inputs(), [](Value* v) {
+        return v->type()->isSubtypeOf(TensorType::get());
+      });
+      GRAPH_DEBUG("Building sizes for ", getHeader(n));
+      bool all_inputs_have_sizes = true;
+      auto shapes = fmap(tensor_inputs, [&](Value* v) {
+        GRAPH_DEBUG("Getting aten::size for %", v->debugName());
+        all_inputs_have_sizes &= shape_of.count(v);
+        return shape_of.count(v) != 0 ? shape_of.at(v) : nullptr;
+      });
+
+      if (!all_inputs_have_sizes) {
+        GRAPH_DEBUG(
+            "Not all tensor arguments have sizes available to compute the broadcasted size",
+            getHeader(n));
+        continue;
+      }
+      shape_of.emplace(
+          n->output(),
+          shapes.size() == 1 ? shapes[0]
+                             : broadcastSizes(shapes, aliasDb_.get()));
+    }
+    return shape_of;
+  }
+
+  void removeOutputsUsedOnlyInSize(Node* fusion_group) {
+    if (fusion_group->kind() != prim::TensorExprGroup)
+      return;
+    auto subgraph = fusion_group->g(attr::Subgraph);
+
+    auto shape_of = buildShapeExpressions(fusion_group);
+    auto outputs = fusion_group->outputs().vec();
+    auto soutputs = subgraph->outputs().vec();
+    // XXX: Iterating in this order is not only good for performance reasons!
+    // It is also crucial for correctness (i has to reflect the current true
+    // index of outputs[i])!
+    for (int64_t i = static_cast<int64_t>(outputs.size()) - 1; i >= 0; --i) {
+      auto output = outputs[i];
+      auto soutput = soutputs[i];
+      if (usedOnlyInSize(output) && shape_of.count(soutput) > 0) {
+        auto uses = output->uses();
+        for (Use u : uses) {
+          AT_ASSERT(u.user->matches("aten::size(Tensor self) -> int[]"));
+          u.user->output()->replaceAllUsesWith(shape_of.at(soutput));
+          u.user->destroy();
+        }
+        fusion_group->eraseOutput(i);
+        subgraph->eraseOutput(i);
+      }
+    }
+  }
+
   void run() {
     aliasDb_ = torch::make_unique<AliasDb>(graph_);
     RemoveRedundantProfiles(graph_);
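The XXX comment in `removeOutputsUsedOnlyInSize` above is worth spelling out: outputs are erased while the loop is still iterating, so the walk has to go from the last index down to the first. Erasing entry i only shifts the elements after i, which a backward walk has already visited, so i keeps pointing at the element that was inspected. A self-contained sketch of the same pattern on a plain std::vector (names here are illustrative, not from the patch):

#include <cstdint>
#include <string>
#include <vector>

// Erase every entry marked "unused", iterating from the back so that
// index i remains valid after each erase.
void eraseMarked(std::vector<std::string>& outputs) {
  for (int64_t i = static_cast<int64_t>(outputs.size()) - 1; i >= 0; --i) {
    if (outputs[i] == "unused") {
      // Safe: only elements after position i shift, and those were
      // already visited.
      outputs.erase(outputs.begin() + i);
    }
  }
}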
@@ -298,7 +441,7 @@ class TensorExprFuser {
     // fusion is done.
     inlineSmallFusionGroups(graph_->block());
     GRAPH_DUMP("After inlining small fusion groups: ", graph_);
-    guardFusionGroups(graph_->block());
+    guardFusionGroupsAndRemoveOutputs(graph_->block());
     GRAPH_DUMP("After guarding fusion groups: ", graph_);
     removeTensorTypeSpecializations(graph_->block());
     GRAPH_DUMP("After removing tensor type specializations: ", graph_);
@@ -772,17 +915,18 @@ class TensorExprFuser {
     }
   }
 
-  void guardFusionGroups(Block* block) {
+  void guardFusionGroupsAndRemoveOutputs(Block* block) {
     std::vector<Node*> fusion_groups;
     for (Node* n : block->nodes()) {
       for (Block* b : n->blocks()) {
-        guardFusionGroups(b);
+        guardFusionGroupsAndRemoveOutputs(b);
       }
       if (n->kind() == prim::TensorExprGroup) {
         fusion_groups.push_back(n);
       }
     }
     for (Node* fusion_group : fusion_groups) {
+      removeOutputsUsedOnlyInSize(fusion_group);
       guardFusionGroup(fusion_group);
     }
   }