Skip to content

Commit 7f4a27b

Browse files
James Reed (with facebook-github-bot)
James Reed
authored and committed Sep 23, 2020
[resubmit][FX] s/get_param/get_attr/ (pytorch#45147)
Summary: Pull Request resolved: pytorch#45147 ghstack-source-id: 112605923 Test Plan: Imported from OSS Reviewed By: eellison Differential Revision: D23845096 fbshipit-source-id: 9ca209aa84cbaddd6e89c52b541e43b11197e2d5
1 parent 35cdb01 commit 7f4a27b

File tree

8 files changed

+21
-21
lines changed

8 files changed

+21
-21
lines changed
 

‎test/fx/quantization.py

+1-1
Original file line number | Diff line number | Diff line change
@@ -222,7 +222,7 @@ def load_arg(a):
222222
for node in self.graph.nodes:
223223
if node.op == 'placeholder':
224224
result = next(args_iter)
225-
elif node.op == 'get_param':
225+
elif node.op == 'get_attr':
226226
result = self.state_dict[node.target]
227227
elif node.op == 'call_function':
228228
result = node.target(*load_arg(node.args), **load_arg(node.kwargs))

‎test/test_fx.py

+3-3
Original file line number | Diff line number | Diff line change
@@ -338,7 +338,7 @@ def __init__(self, interpreter):
338338
placeholder_nodes.append(graph.create_node('placeholder', name))
339339

340340
# Get the interpreter object
341-
interpreter_node = graph.create_node('get_param', 'interpreter')
341+
interpreter_node = graph.create_node('get_attr', 'interpreter')
342342

343343
# Add a node to call the interpreter instance
344344
output_node = graph.create_node(
@@ -570,7 +570,7 @@ def test_graph_fns(self):
570570
g = Graph()
571571
a = g.placeholder('a')
572572
b = g.call_module('linear', (a,))
573-
c = g.get_param('bias')
573+
c = g.get_attr('bias')
574574
d = g.call_method('add', (b, c))
575575
e = g.call_function(torch.sin, (d,))
576576
g.output(e)
@@ -587,7 +587,7 @@ def test_construct_root_dict(self):
587587
graph : torch.fx.Graph = torch.fx.Graph()
588588
a : torch.fx.Node = graph.create_node('placeholder', 'x')
589589
b : torch.fx.Node = graph.create_node('call_module', 'foo.bar.baz', args=(a,))
590-
c : torch.fx.Node = graph.create_node('get_param', 'zip.zap.zam')
590+
c : torch.fx.Node = graph.create_node('get_attr', 'zip.zap.zam')
591591
d : torch.fx.Node = graph.create_node('call_function', operator.add, args=(b, c))
592592
graph.output(d)
593593

‎torch/fx/__init__.py

+2-2
Original file line number | Diff line number | Diff line change
@@ -36,7 +36,7 @@ def forward(self, x):
3636
opcode name target args kwargs
3737
------------- ------------- ------------------------------------------------------- ------------------ -----------
3838
placeholder x x () {}
39-
get_param linear_weight linear.weight () {}
39+
get_attr linear_weight linear.weight () {}
4040
call_function add_1 <built-in function add> (x, linear_weight) {}
4141
call_module linear_1 linear (add_1,) {}
4242
call_method relu_2 relu [linear_1] {}
@@ -48,7 +48,7 @@ def forward(self, x):
4848
4949
- `placeholder` represents a function input. The `name` attribute specifies the name this value will take on.
5050
`target` is similarly the name of the argument. `args` and `kwargs` are don't-care
51-
- `get_param` retrieves a parameter from the module hierarchy. `name` is similarly the name the result of the
51+
- `get_attr` retrieves a parameter from the module hierarchy. `name` is similarly the name the result of the
5252
fetch is assigned to. `target` is the fully-qualified name of the parameter's position in the module hierarchy.
5353
`args` and `kwargs` are don't-care
5454
- `call_function` applies a free function to some values. `name` is similarly the name of the value to assign

‎torch/fx/graph.py

+5-5
Original file line number | Diff line number | Diff line change
@@ -92,7 +92,7 @@ def create_node(self, op: str, target: Target,
9292
args: Optional[Tuple[Argument, ...]] = None,
9393
kwargs: Optional[Dict[str, Argument]] = None,
9494
name: Optional[str] = None) -> Node:
95-
assert op in ('call_function', 'call_method', 'get_param', 'call_module', 'placeholder')
95+
assert op in ('call_function', 'call_method', 'get_attr', 'call_module', 'placeholder')
9696
args = () if args is None else args
9797
kwargs = {} if kwargs is None else kwargs
9898
self._mark_uses(args)
@@ -105,8 +105,8 @@ def create_node(self, op: str, target: Target,
105105
def placeholder(self, name: str) -> Node:
106106
return self.create_node('placeholder', name)
107107

108-
def get_param(self, name: str) -> Node:
109-
return self.create_node('get_param', name)
108+
def get_attr(self, name: str) -> Node:
109+
return self.create_node('get_attr', name)
110110

111111
def call_module(self,
112112
module_name: str,
@@ -208,7 +208,7 @@ def python_code(self, root_module: str) -> Tuple[str, str, List[str]]:
208208
assert isinstance(node.target, str)
209209
body.append(f'{node.name} = {_format_target(root_module, node.target)}({_format_args(node.args, node.kwargs)})\n')
210210
continue
211-
elif node.op == 'get_param':
211+
elif node.op == 'get_attr':
212212
assert isinstance(node.target, str)
213213
body.append(f'{node.name} = {_format_target(root_module, node.target)}\n')
214214
continue
@@ -242,7 +242,7 @@ def format_node(n : Node) -> Optional[str]:
242242
assert isinstance(n.target, str)
243243
placeholder_names.append(n.target)
244244
return None
245-
elif n.op == 'get_param':
245+
elif n.op == 'get_attr':
246246
return f'%{n.name} : [uses={n.uses}] = self.{n.target}'
247247
else:
248248
return f'%{n.name} : [uses={n.uses}] = {n.op}[target={n.target}](' \

‎torch/fx/graph_module.py

+2-2
Original file line number | Diff line number | Diff line change
@@ -135,13 +135,13 @@ def __init__(self, root: Union[torch.nn.Module, Dict[str, Any]], graph: Graph):
135135
if hasattr(root, 'training'):
136136
self.training = root.training
137137
for node in graph.nodes:
138-
if node.op in ['get_param', 'call_module']:
138+
if node.op in ['get_attr', 'call_module']:
139139
assert isinstance(node.target, str)
140140
_copy_attr(root, self, node.target)
141141
elif isinstance(root, dict):
142142
targets_to_copy = []
143143
for node in graph.nodes:
144-
if node.op in ['get_param', 'call_module']:
144+
if node.op in ['get_attr', 'call_module']:
145145
assert isinstance(node.target, str)
146146
if node.target not in root:
147147
raise RuntimeError('Node ' + str(node) + ' referenced target ' + node.target +

‎torch/fx/symbolic_trace.py

+4-4
Original file line number | Diff line number | Diff line change
@@ -55,15 +55,15 @@ def create_arg(self, a: Any) -> Argument:
5555
if isinstance(a, torch.nn.Parameter):
5656
for n, p in self.root.named_parameters():
5757
if a is p:
58-
return self.create_node('get_param', n, (), {})
58+
return self.create_node('get_attr', n, (), {})
5959
raise NameError('parameter is not a member of this module')
6060
# Tensors do not have a reliable string repr() from which they can be
6161
# constructed (and we probably don't want to rely on that, either), so
6262
# for any constant Tensor values we encounter, first search for if they
6363
# are an attribute of some module in the module hierarchy. If so, emit
64-
# a get_param to retrieve that tensor. Otherwise, we'll store away the
64+
# a get_attr to retrieve that tensor. Otherwise, we'll store away the
6565
# tensor value into a special attribute on the Module s.t. we can
66-
# retrieve it with a get_param.
66+
# retrieve it with a get_attr.
6767
if isinstance(a, torch.Tensor):
6868
# TODO: slow
6969
def search_for_tensor(m : torch.nn.Module) -> Optional[List[str]]:
@@ -96,7 +96,7 @@ def search_for_tensor(m : torch.nn.Module) -> Optional[List[str]]:
9696
i += 1
9797
setattr(self.root, qualname, a)
9898

99-
return self.create_node('get_param', qualname, (), {})
99+
return self.create_node('get_attr', qualname, (), {})
100100
return super().create_arg(a)
101101

102102
def is_leaf_module(self, m: torch.nn.Module, module_qualified_name : str) -> bool:

‎torch/quantization/fx/quantize.py

+2-2
Original file line number | Diff line number | Diff line change
@@ -177,7 +177,7 @@ def get_qconfig(module):
177177

178178
self.qconfig_map = dict()
179179
for node in input_graph.nodes:
180-
if node.op == 'get_param':
180+
if node.op == 'get_attr':
181181
parent, _ = _parent_name(node.target)
182182
self.qconfig_map[node.name] = get_qconfig(self.modules[parent])
183183
elif node.op == 'call_function':
@@ -557,7 +557,7 @@ def load_arg(a):
557557
setattr(quantized_root, packed_weight_name, packed_weight)
558558
# replace prepack node with a getattr node
559559
env[node.name] = folded_graph.create_node(
560-
'get_param', packed_weight_name, (), {})
560+
'get_attr', packed_weight_name, (), {})
561561
elif prepack_node is not None:
562562
# remove the foled node
563563
continue

‎torch/quantization/fx/utils.py

+2-2
Original file line number | Diff line number | Diff line change
@@ -17,7 +17,7 @@ def graph_pretty_str(g, shorten=True) -> str:
1717
built_in_meth_re = re.compile('<built-in method (.*) of type.*>')
1818
op_dict = {
1919
'placeholder': 'plchdr',
20-
'get_param': 'gt_prm',
20+
'get_attr': 'gt_prm',
2121
'call_function': 'cl_fun',
2222
'call_module': 'cl_mod',
2323
'call_method': 'cl_meth',
@@ -136,5 +136,5 @@ def get_next_qparams_idx(module, qparams):
136136
for key, value in qparams.items():
137137
setattr(root_module, key + str(idx), value)
138138
qparam_full_path = key + str(idx)
139-
inputs.append(graph.create_node('get_param', qparam_full_path))
139+
inputs.append(graph.create_node('get_attr', qparam_full_path))
140140
return graph.create_node('call_function', quantize_op, tuple(inputs), {})

0 commit comments

Comments (0)
Please sign in to comment.