Skip to content

Commit 993628c

Browse files
Krovatkin authored and facebook-github-bot committed on Sep 28, 2020
Build shape expressions and remove outputs that are only used by aten::sizes (pytorch#45080)
Summary: Currently, TE materializes all intermediate results even if they are only used for computing their shapes. This diff ports the approach the OF (Old Fuser) took to deal with this issue. Namely, given the structure of a fusion group we infer all the sizes outside a fusion group based on fusion group's inputs. A simple example would be: ``` def test_fuse(a, b): c = a + b d = c + b return d ``` Here we don't need to cache `c` as computing a gradient for `b` in `d = c + b` doesn't need it. We do need to compute sizes for all arguments here in case broadcasts happen. Without this optimization, TE would need to materialize `c` so we can get its size ``` [DUMP profiling_graph_executor_impl.cpp:499] Optimized Graph: [DUMP profiling_graph_executor_impl.cpp:499] graph(%a.1 : Tensor, [DUMP profiling_graph_executor_impl.cpp:499] %b.1 : Tensor): [DUMP profiling_graph_executor_impl.cpp:499] %11 : Tensor = prim::DifferentiableGraph_0(%b.1, %a.1) [DUMP profiling_graph_executor_impl.cpp:499] return (%11) [DUMP profiling_graph_executor_impl.cpp:499] with prim::DifferentiableGraph_0 = graph(%11 : Tensor, [DUMP profiling_graph_executor_impl.cpp:499] %13 : Tensor): [DUMP profiling_graph_executor_impl.cpp:499] %59 : int[] = aten::size(%13) # <string>:3:44 [DUMP profiling_graph_executor_impl.cpp:499] %62 : int[] = aten::size(%11) # <string>:3:93 [DUMP profiling_graph_executor_impl.cpp:499] %83 : Double(1:1, requires_grad=0, device=cuda:0), %84 : Double(1:1, requires_grad=0, device=cuda:0), %85 : bool = prim::TypeCheck(%11, %13) [DUMP profiling_graph_executor_impl.cpp:499] %86 : Tensor, %87 : Tensor = prim::If(%85) [DUMP profiling_graph_executor_impl.cpp:499] block0(): [DUMP profiling_graph_executor_impl.cpp:499] %d.4 : Double(1:1, requires_grad=0, device=cuda:0), %c.4 : Double(1:1, requires_grad=0, device=cuda:0) = prim::TensorExprGroup_0(%83, %84) [DUMP profiling_graph_executor_impl.cpp:499] -> (%d.4, %c.4) [DUMP profiling_graph_executor_impl.cpp:499] block1(): [DUMP 
profiling_graph_executor_impl.cpp:499] %94 : Function = prim::Constant[name="fallback_function", fallback=1]() [DUMP profiling_graph_executor_impl.cpp:499] %95 : (Tensor, Tensor) = prim::CallFunction(%94, %11, %13) [DUMP profiling_graph_executor_impl.cpp:499] %96 : Tensor, %97 : Tensor = prim::TupleUnpack(%95) [DUMP profiling_graph_executor_impl.cpp:499] -> (%96, %97) [DUMP profiling_graph_executor_impl.cpp:499] %60 : int[] = aten::size(%87) # <string>:3:55 [DUMP profiling_graph_executor_impl.cpp:499] %61 : int[]? = aten::_size_if_not_equal(%59, %60) # <string>:3:19 [DUMP profiling_graph_executor_impl.cpp:499] %64 : int[]? = aten::_size_if_not_equal(%62, %60) # <string>:3:68 [DUMP profiling_graph_executor_impl.cpp:499] %67 : int[] = aten::size(%86) # <string>:3:55 [DUMP profiling_graph_executor_impl.cpp:499] %68 : int[]? = aten::_size_if_not_equal(%60, %67) # <string>:3:19 [DUMP profiling_graph_executor_impl.cpp:499] %71 : int[]? = aten::_size_if_not_equal(%62, %67) # <string>:3:68 [DUMP profiling_graph_executor_impl.cpp:499] return (%86, %61, %64, %68, %71) [DUMP profiling_graph_executor_impl.cpp:499] with prim::TensorExprGroup_0 = graph(%1 : Double(1:1, requires_grad=0, device=cuda:0), [DUMP profiling_graph_executor_impl.cpp:499] %4 : Double(1:1, requires_grad=0, device=cuda:0)): [DUMP profiling_graph_executor_impl.cpp:499] %5 : int = prim::Constant[value=1]() [DUMP profiling_graph_executor_impl.cpp:499] %c.3 : Double(1:1, requires_grad=0, device=cuda:0) = aten::add(%4, %1, %5) # /scratch/villedepommes/pytorches/bench/test/test_jit.py:2872:16 [DUMP profiling_graph_executor_impl.cpp:499] %2 : int = prim::Constant[value=1]() [DUMP profiling_graph_executor_impl.cpp:499] %d.3 : Double(1:1, requires_grad=0, device=cuda:0) = aten::add(%c.3, %1, %2) # /scratch/villedepommes/pytorches/bench/test/test_jit.py:2873:16 [DUMP profiling_graph_executor_impl.cpp:499] return (%d.3, %c.3) ``` With this optimization we use `prim::BroadcastSizes` to compute the size of `c`. 
No need to materialize it. ``` [DUMP profiling_graph_executor_impl.cpp:499] Optimized Graph: [DUMP profiling_graph_executor_impl.cpp:499] graph(%a.1 : Tensor, [DUMP profiling_graph_executor_impl.cpp:499] %b.1 : Tensor): [DUMP profiling_graph_executor_impl.cpp:499] %11 : Tensor = prim::DifferentiableGraph_0(%b.1, %a.1) [DUMP profiling_graph_executor_impl.cpp:499] return (%11) [DUMP profiling_graph_executor_impl.cpp:499] with prim::DifferentiableGraph_0 = graph(%11 : Tensor, [DUMP profiling_graph_executor_impl.cpp:499] %13 : Tensor): [DUMP profiling_graph_executor_impl.cpp:499] %59 : int[] = aten::size(%13) # <string>:3:44 [DUMP profiling_graph_executor_impl.cpp:499] %62 : int[] = aten::size(%11) # <string>:3:93 [DUMP profiling_graph_executor_impl.cpp:499] %88 : Double(1:1, requires_grad=0, device=cuda:0), %89 : Double(1:1, requires_grad=0, device=cuda:0), %90 : bool = prim::TypeCheck(%11, %13) [DUMP profiling_graph_executor_impl.cpp:499] %91 : Tensor = prim::If(%90) [DUMP profiling_graph_executor_impl.cpp:499] block0(): [DUMP profiling_graph_executor_impl.cpp:499] %d.4 : Double(1:1, requires_grad=0, device=cuda:0) = prim::TensorExprGroup_0(%88, %89) [DUMP profiling_graph_executor_impl.cpp:499] -> (%d.4) [DUMP profiling_graph_executor_impl.cpp:499] block1(): [DUMP profiling_graph_executor_impl.cpp:499] %97 : Function = prim::Constant[name="fallback_function", fallback=1]() [DUMP profiling_graph_executor_impl.cpp:499] %98 : (Tensor) = prim::CallFunction(%97, %11, %13) [DUMP profiling_graph_executor_impl.cpp:499] %99 : Tensor = prim::TupleUnpack(%98) [DUMP profiling_graph_executor_impl.cpp:499] -> (%99) [DUMP profiling_graph_executor_impl.cpp:499] %85 : int[] = aten::size(%91) [DUMP profiling_graph_executor_impl.cpp:499] %86 : int[] = prim::BroadcastSizes(%59, %62) [DUMP profiling_graph_executor_impl.cpp:499] %61 : int[]? = aten::_size_if_not_equal(%59, %86) # <string>:3:19 [DUMP profiling_graph_executor_impl.cpp:499] %64 : int[]? 
= aten::_size_if_not_equal(%62, %86) # <string>:3:68 [DUMP profiling_graph_executor_impl.cpp:499] %68 : int[]? = aten::_size_if_not_equal(%86, %85) # <string>:3:19 [DUMP profiling_graph_executor_impl.cpp:499] %71 : int[]? = aten::_size_if_not_equal(%62, %85) # <string>:3:68 [DUMP profiling_graph_executor_impl.cpp:499] return (%91, %61, %64, %68, %71) [DUMP profiling_graph_executor_impl.cpp:499] with prim::TensorExprGroup_0 = graph(%1 : Double(1:1, requires_grad=0, device=cuda:0), [DUMP profiling_graph_executor_impl.cpp:499] %4 : Double(1:1, requires_grad=0, device=cuda:0)): [DUMP profiling_graph_executor_impl.cpp:499] %5 : int = prim::Constant[value=1]() [DUMP profiling_graph_executor_impl.cpp:499] %c.3 : Double(1:1, requires_grad=0, device=cuda:0) = aten::add(%4, %1, %5) # /scratch/villedepommes/pytorches/bench/test/test_jit.py:2872:16 [DUMP profiling_graph_executor_impl.cpp:499] %2 : int = prim::Constant[value=1]() [DUMP profiling_graph_executor_impl.cpp:499] %d.3 : Double(1:1, requires_grad=0, device=cuda:0) = aten::add(%c.3, %1, %2) # /scratch/villedepommes/pytorches/bench/test/test_jit.py:2873:16 [DUMP profiling_graph_executor_impl.cpp:499] return (%d.3) ``` Pull Request resolved: pytorch#45080 Reviewed By: bertmaher Differential Revision: D23856410 Pulled By: Krovatkin fbshipit-source-id: 2956286eb03a4894a5baa151c35e6092466322b1
1 parent e5242aa commit 993628c

File tree

4 files changed

+171
-20
lines changed

4 files changed

+171
-20
lines changed
 

‎test/test_jit_fuser_te.py

+20
Original file line numberDiff line numberDiff line change
@@ -683,6 +683,26 @@ def foo(hx, cx):
683683
# XXX: TE fuser can handle concats in a fusion group.
684684
# FileCheck().check("FusedConcat").check_next("return").run(str(graph))
685685

686+
@unittest.skipIf(not RUN_CUDA, "fuser requires CUDA")
def test_remove_output_used_only_in_size(self):
    # Fusing `c = a + b; d = c + b` should let the fuser drop `c` as a
    # fusion-group output: `c` is only needed for its size, which can be
    # recomputed outside the group from the input sizes.
    def fn(a, b):
        c = a + b
        d = c + b
        return d

    scripted = torch.jit.script(fn)
    lhs = torch.ones(1, requires_grad=True, device='cuda')
    rhs = torch.ones(1, requires_grad=True, device='cuda')
    warmup_forward(scripted, lhs, rhs)
    graph = torch.jit.last_executed_optimized_graph()
    diff_graphs = [n for n in graph.nodes()
                   if n.kind() == 'prim::DifferentiableGraph']
    self.assertEqual(len(diff_graphs), 1)
    subgraph = diff_graphs[0].g('Subgraph')
    guards = [n for n in subgraph.nodes() if n.kind() == 'prim::If']
    self.assertEqual(len(guards), 1)
    # Both the guard node and the fusion group it wraps should expose a
    # single output -- the intermediate `c` must have been removed.
    self.assertEqual(len(list(guards[0].outputs())), 1)
705+
686706
@unittest.skipIf(not RUN_CUDA, "fuser requires CUDA")
687707
def test_concat_invariant_cuda(self):
688708
# Invariant: the output of prim::FusedConcat may

‎torch/csrc/jit/passes/graph_fuser.cpp

+1-17
Original file line numberDiff line numberDiff line change
@@ -7,6 +7,7 @@
77
#include <torch/csrc/jit/passes/common_subexpression_elimination.h>
88
#include <torch/csrc/jit/passes/constant_pooling.h>
99
#include <torch/csrc/jit/passes/dead_code_elimination.h>
10+
#include <torch/csrc/jit/passes/tensorexpr_fuser.h>
1011
#include <torch/csrc/jit/passes/utils/subgraph_utils.h>
1112
#include <torch/csrc/jit/runtime/autodiff.h>
1213
#include <torch/csrc/jit/runtime/custom_operator.h>
@@ -120,16 +121,6 @@ bool isSimpleMap(Node* node) {
120121
return true;
121122
}
122123

123-
Value* broadcastSizes(at::ArrayRef<Value*> sizes, AliasDb* db) {
124-
AT_ASSERT(!sizes.empty());
125-
Graph* graph = sizes[0]->owningGraph();
126-
Node* broadcast_n =
127-
graph->insertNode(graph->create(prim::BroadcastSizes, sizes));
128-
broadcast_n->output()->setType(ListType::ofInts());
129-
db->createValue(broadcast_n->output());
130-
return broadcast_n->output();
131-
}
132-
133124
struct GraphFuser {
134125
using FusionCallback = std::function<bool(GraphFuser*, Node*)>;
135126

@@ -926,13 +917,6 @@ struct GraphFuser {
926917
}
927918
}
928919

929-
bool usedOnlyInSize(Value* v) {
930-
const auto& uses = v->uses();
931-
return std::all_of(uses.begin(), uses.end(), [](const Use& u) {
932-
return u.user->matches("aten::size(Tensor self) -> int[]");
933-
});
934-
}
935-
936920
// Builds up expressions that compute shapes of all intermediates (and
937921
// outputs) of the fusion group, based on the sizes of inputs. You should run
938922
// DCE to remove those that you end up not using.

‎torch/csrc/jit/passes/tensorexpr_fuser.cpp

+147-3
Original file line numberDiff line numberDiff line change
@@ -29,6 +29,23 @@ bool isSupportedForBlock(Node* node) {
2929
}
3030
}
3131

32+
// Returns true iff every use of `v` is an aten::size call -- i.e. only the
// shape of the value is ever consumed, never its contents.
bool usedOnlyInSize(Value* v) {
  for (const Use& use : v->uses()) {
    if (!use.user->matches("aten::size(Tensor self) -> int[]")) {
      return false;
    }
  }
  return true;
}
38+
39+
// Inserts a prim::BroadcastSizes node over the given int[] size values and
// returns its output (typed int[]), registering it with the alias database.
// Requires a non-empty `sizes` list.
Value* broadcastSizes(at::ArrayRef<Value*> sizes, AliasDb* db) {
  AT_ASSERT(!sizes.empty());
  Graph* graph = sizes[0]->owningGraph();
  Node* broadcast = graph->create(prim::BroadcastSizes, sizes);
  graph->insertNode(broadcast);
  Value* result = broadcast->output();
  result->setType(ListType::ofInts());
  db->createValue(result);
  return result;
}
48+
3249
namespace tensorexpr {
3350
bool isSupported(Node* node) {
3451
// For Block codegen we allow limited ops.
@@ -287,6 +304,132 @@ class TensorExprFuser {
287304
min_group_size_(min_group_size),
288305
disable_shape_checks_(disable_shape_checks) {}
289306

307+
// Builds up expressions that compute shapes of all intermediates (and
// outputs) of the fusion group, based on the sizes of inputs. You should run
// DCE to remove those that you end up not using.
//
// Returns a map from subgraph Value* to a Value* (in the outer graph) that
// evaluates to its int[] size. Nodes whose inputs have no known sizes (and
// aten::cat / prim::Constant) are skipped, so the map may be partial.
std::unordered_map<Value*, Value*> buildShapeExpressions(Node* fusion_group) {
  GRAPH_DUMP("buildShapeExpressions for ", fusion_group->g(attr::Subgraph));
  // All size expressions are inserted into the outer graph, right after the
  // fusion group node.
  WithInsertPoint insert_guard{fusion_group->next()};
  std::unordered_map<Value*, Value*> shape_of;

  Graph* graph = fusion_group->owningGraph();
  auto subgraph = fusion_group->g(attr::Subgraph);

  // Seed the map: for every Tensor input of the group, emit an aten::size
  // call on the outer-graph input and map the corresponding subgraph input
  // to it.
  auto inputs = fusion_group->inputs();
  auto sinputs = subgraph->inputs();
  AT_ASSERT(inputs.size() == sinputs.size());
  for (size_t i = 0; i < inputs.size(); ++i) {
    if (inputs[i]->type()->isSubtypeOf(TensorType::get())) {
      Value* soutput = graph->insert(aten::size, {inputs[i]});
      aliasDb_->createValue(soutput);
      GRAPH_DEBUG(
          "Adding a mapping for %",
          sinputs[i]->debugName(),
          " ",
          getHeader(soutput->node()));
      shape_of[sinputs[i]] = soutput;
    }
  }

  // When we have a guarantee that an output won't be removed, because it's
  // used in expressions that don't involve size checks, we can use its size
  // instead of computing a long chain of broadcasts, starting from the
  // beginning of the kernel. (Outputs used only in aten::size are skipped --
  // those are exactly the ones we hope to remove later.)
  auto outputs = fusion_group->outputs();
  auto soutputs = subgraph->outputs();
  AT_ASSERT(outputs.size() == soutputs.size());
  for (size_t i = 0; i < outputs.size(); ++i) {
    if (usedOnlyInSize(outputs[i]))
      continue;
    Value* soutput = graph->insert(aten::size, {outputs[i]});
    aliasDb_->createValue(soutput);
    shape_of[soutputs[i]] = soutput;
  }

  // Walk the subgraph in order, deriving each node's output size from the
  // (already computed) sizes of its tensor inputs.
  for (Node* n : subgraph->nodes()) {
    // XXX: Use of shape_of.emplace is crucial to the output shape
    // optimization! (emplace does not overwrite entries seeded from the
    // outputs above, so direct aten::size of an output wins over a
    // broadcast chain.)
    if (n->kind() == aten::cat) {
      // This is a bit more involved, because we have to account for the case
      // when inputs have different shapes, but fortunately those tensors are
      // always outputs, and so we can simply avoid replacing their queries,
      // because it won't help us.
      continue;
    }
    if (n->kind() == prim::Constant) {
      continue;
    }
    if (n->kind() == prim::ConstantChunk) {
      // prim::ChunkSizes yields two int[] outputs: the size of the regular
      // chunks and the size of the (possibly smaller) last chunk.
      Node* sizes_node = graph->insertNode(
          graph->create(prim::ChunkSizes, shape_of.at(n->input()), 2));
      sizes_node->i_(attr::dim, n->i(attr::dim));
      sizes_node->i_(attr::chunks, n->i(attr::chunks));
      for (Value* output : sizes_node->outputs()) {
        aliasDb_->createValue(output);
      }
      Value* regular_size = sizes_node->outputs().at(0);
      Value* last_size = sizes_node->outputs().at(1);
      regular_size->setType(ListType::ofInts());
      last_size->setType(ListType::ofInts());
      auto outputs = n->outputs();
      for (Value* o : outputs.slice(0, outputs.size() - 1)) {
        shape_of.emplace(o, regular_size);
      }
      shape_of.emplace(outputs.at(outputs.size() - 1), last_size);
      continue;
    }
    auto tensor_inputs = filter(n->inputs(), [](Value* v) {
      return v->type()->isSubtypeOf(TensorType::get());
    });
    GRAPH_DEBUG("Building sizes for ", getHeader(n));
    bool all_inputs_have_sizes = true;
    auto shapes = fmap(tensor_inputs, [&](Value* v) {
      GRAPH_DEBUG("Getting aten::size for %", v->debugName());
      all_inputs_have_sizes &= shape_of.count(v);
      return shape_of.count(v) != 0 ? shape_of.at(v) : nullptr;
    });

    if (!all_inputs_have_sizes) {
      GRAPH_DEBUG(
          "Not all tensor arguments have sizes available to compute the broadcasted size",
          getHeader(n));
      continue;
    }
    // NOTE(review): only n->output() (the single/first output) receives a
    // shape entry here; multi-output ops other than ConstantChunk are
    // presumably not admitted by the fuser -- confirm against the op
    // whitelist.
    shape_of.emplace(
        n->output(),
        shapes.size() == 1 ? shapes[0]
                           : broadcastSizes(shapes, aliasDb_.get()));
  }
  return shape_of;
}
405+
406+
// Drops fusion-group outputs that are consumed exclusively by aten::size.
// Each such size query is rewired to a shape expression built outside the
// group (see buildShapeExpressions), after which the now-dead output is
// erased from both the group node and its subgraph.
void removeOutputsUsedOnlyInSize(Node* fusion_group) {
  if (fusion_group->kind() != prim::TensorExprGroup)
    return;
  auto subgraph = fusion_group->g(attr::Subgraph);

  auto shape_of = buildShapeExpressions(fusion_group);
  auto group_outputs = fusion_group->outputs().vec();
  auto subgraph_outputs = subgraph->outputs().vec();
  // XXX: Traverse outputs back-to-front. This is not only good for
  // performance: erasing output i shifts every later index, so only a
  // reverse walk keeps idx pointing at the true current output.
  for (size_t idx = group_outputs.size(); idx-- > 0;) {
    Value* output = group_outputs[idx];
    Value* soutput = subgraph_outputs[idx];
    if (!usedOnlyInSize(output) || shape_of.count(soutput) == 0) {
      continue;
    }
    // Snapshot the uses: destroying a user mutates the use list.
    auto size_uses = output->uses();
    for (const Use& u : size_uses) {
      AT_ASSERT(u.user->matches("aten::size(Tensor self) -> int[]"));
      u.user->output()->replaceAllUsesWith(shape_of.at(soutput));
      u.user->destroy();
    }
    fusion_group->eraseOutput(idx);
    subgraph->eraseOutput(idx);
  }
}
432+
290433
void run() {
291434
aliasDb_ = torch::make_unique<AliasDb>(graph_);
292435
RemoveRedundantProfiles(graph_);
@@ -298,7 +441,7 @@ class TensorExprFuser {
298441
// fusion is done.
299442
inlineSmallFusionGroups(graph_->block());
300443
GRAPH_DUMP("After inlining small fusion groups: ", graph_);
301-
guardFusionGroups(graph_->block());
444+
guardFusionGroupsAndRemoveOutputs(graph_->block());
302445
GRAPH_DUMP("After guarding fusion groups: ", graph_);
303446
removeTensorTypeSpecializations(graph_->block());
304447
GRAPH_DUMP("After removing tensor type specializations: ", graph_);
@@ -772,17 +915,18 @@ class TensorExprFuser {
772915
}
773916
}
774917

775-
void guardFusionGroups(Block* block) {
918+
void guardFusionGroupsAndRemoveOutputs(Block* block) {
776919
std::vector<Node*> fusion_groups;
777920
for (Node* n : block->nodes()) {
778921
for (Block* b : n->blocks()) {
779-
guardFusionGroups(b);
922+
guardFusionGroupsAndRemoveOutputs(b);
780923
}
781924
if (n->kind() == prim::TensorExprGroup) {
782925
fusion_groups.push_back(n);
783926
}
784927
}
785928
for (Node* fusion_group : fusion_groups) {
929+
removeOutputsUsedOnlyInSize(fusion_group);
786930
guardFusionGroup(fusion_group);
787931
}
788932
}

‎torch/csrc/jit/passes/tensorexpr_fuser.h

+3
Original file line numberDiff line numberDiff line change
@@ -29,6 +29,9 @@ TORCH_API void RemoveProfileNodesAndSpecializeTypes(
2929
std::shared_ptr<Graph>& graph);
3030
TORCH_API void RemoveTensorTypeSpecializations(std::shared_ptr<Graph>& graph);
3131

32+
TORCH_API bool usedOnlyInSize(Value* v);
33+
TORCH_API Value* broadcastSizes(at::ArrayRef<Value*> sizes, AliasDb* db);
34+
3235
namespace tensorexpr {
3336
TORCH_API bool isSupported(Node* node);
3437
}

0 commit comments

Comments
 (0)
Please sign in to comment.