Skip to content

Commit 92189b3

Browse files
scottxu0730facebook-github-bot
authored andcommittedSep 26, 2020
Add get_all_users_of function to GraphManipulation (pytorch#45216)
Summary: This PR adds get_all_users_of function. The function returns all the users of a specific node. A test unit is also added. Pull Request resolved: pytorch#45216 Reviewed By: ezyang Differential Revision: D23883572 Pulled By: scottxu0730 fbshipit-source-id: 3eb68a411c3c6db39ed2506c9cb7bb7337520ee4
1 parent 7763e1d commit 92189b3

File tree

3 files changed

+52
-0
lines changed

3 files changed

+52
-0
lines changed
 

‎test/test_fx.py

+21
Original file line numberDiff line numberDiff line change
@@ -6,6 +6,7 @@
66
import copy
77
from pathlib import Path
88
from torch.fx import symbolic_trace, Proxy, Node, GraphModule, Tracer, Graph
9+
from torch.fx.experimental import GraphManipulation
910

1011
from torch.fx.proxy import TraceError
1112

@@ -619,6 +620,26 @@ def forward(self, x):
619620
with self.assertRaisesRegex(AssertionError, message):
620621
traced(torch.rand(4, 3))
621622

623+
def test_get_all_users_of(self):
624+
graph : torch.fx.Graph = torch.fx.Graph()
625+
a : torch.fx.Node = graph.create_node('placeholder', 'x')
626+
b : torch.fx.Node = graph.create_node('call_module', 'linear_mod', args=(a,))
627+
c : torch.fx.Node = graph.create_node('get_attr', 'y_attr')
628+
d : torch.fx.Node = graph.create_node('call_function', operator.add, args=(b, c))
629+
graph.output(d)
630+
linear_mod : torch.nn.Module = torch.nn.Linear(3, 4)
631+
add_param : torch.Tensor = torch.rand(3, 4)
632+
gm : torch.fx.GraphModule = torch.fx.GraphModule(
633+
{'linear_mod': linear_mod, 'y_attr' : add_param}, graph)
634+
expected_uses: Dict[int, List[int]] = {
635+
0: [1],
636+
1: [3],
637+
2: [3],
638+
3: []
639+
}
640+
for i, node in enumerate(graph.nodes):
641+
user_indexes = GraphManipulation.get_all_users_of(gm, i)
642+
assert user_indexes == expected_uses[i]
622643

623644
if __name__ == '__main__':
624645
run_tests()
+31
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,31 @@
1+
from typing import List
2+
from torch.fx.graph_module import GraphModule
3+
from typing import Any
4+
from torch.fx.node import Node
5+
6+
"""find_use is used to find out if the node is another node's arg or kwargs."""
7+
def find_use(arg: Any, node: Node) -> bool:
8+
if isinstance(arg, (tuple, list)):
9+
return any(find_use(elem, node) for elem in arg)
10+
elif isinstance(arg, dict):
11+
return any(find_use(v, node) for k, v in arg.items())
12+
elif isinstance(arg, slice):
13+
return any([find_use(arg.start, node), find_use(arg.stop, node), find_use(arg.step, node)])
14+
elif isinstance(arg, Node):
15+
return arg is node
16+
else:
17+
return False
18+
19+
def get_all_users_of(fx_module: GraphModule, index: int) -> List[int]:
20+
"""Given the graph(fx_module) and an index, return a list of all node indexes that use this node"""
21+
graph = fx_module.graph
22+
current_node = graph.nodes[index]
23+
user_indexes: List[int] = []
24+
"""if the node A is in node B's args, then B is the user of A
25+
go through all the nodes, if the input node in any node's args,
26+
then that node is the input node's user
27+
"""
28+
for i, n in enumerate(graph.nodes):
29+
if find_use(n.args, current_node) or find_use(n.kwargs, current_node):
30+
user_indexes.append(i)
31+
return user_indexes

‎torch/fx/experimental/__init__.py

Whitespace-only changes.

0 commit comments

Comments
 (0)