
Commit 4d80c8c

bzinodev authored and facebook-github-bot committed on Sep 23, 2020
Fix inlining interface call in fork subgraph (pytorch#43790)
Summary: Pull Request resolved: pytorch#43790

Interface calls were not handled properly when used in a fork subgraph. This PR fixes that.

Test Plan: Imported from OSS
Reviewed By: eellison
Differential Revision: D23402039
Pulled By: bzinodev
fbshipit-source-id: 41adc5ee7d942250e732e243ab30e356d78d9bf7
1 parent da4033d commit 4d80c8c
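For context, here is the previously failing pattern in miniature, condensed from the test added to test/jit/test_module_interface.py below (a sketch, not part of the patch; it keeps that test's class names but drops the extra submodule):

import torch

@torch.jit.interface
class ModInterface(torch.nn.Module):
    def forward(self, x):
        # type: (Tensor) -> Tensor
        pass

class OrigMod(torch.nn.Module):
    def __init__(self):
        super(OrigMod, self).__init__()
        self.a = torch.tensor([0.5])

    def forward(self, x):
        return self.a

class TestModule(torch.nn.Module):
    proxy_mod: ModInterface

    def __init__(self):
        super(TestModule, self).__init__()
        self.proxy_mod = OrigMod()

    def forward(self, x):
        # Interface-typed call; freezing must inline it to fold attributes.
        return self.proxy_mod(x)

class MainModule(torch.nn.Module):
    def __init__(self):
        super(MainModule, self).__init__()
        self.test = TestModule()

    def forward(self, x):
        # The interface call ends up inside the prim::fork subgraph; before
        # this commit it was not inlined there, breaking freezeInterfaces.
        fut = torch.jit._fork(self.test.forward, x)
        return self.test(x) + torch.jit._wait(fut)

m = torch.jit.script(MainModule())
m.eval()
mf = torch._C._freeze_module(m._c, freezeInterfaces=True)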

File tree: 3 files changed, +121 −20 lines
 

test/jit/test_freezing.py (+48 −4)
@@ -237,8 +237,8 @@ def forward(self, x):
 
     def test_freeze_module_with_fork2(self):
         @torch.jit.script
-        def foo(x, y):
-            return x * y
+        def foo(x):
+            return x * 2
 
         class TestModule(nn.Module):
             def __init__(self):
@@ -247,8 +247,8 @@ def __init__(self):
                 self.b = torch.ones(20, 20)
 
             def forward(self, x):
-                fut = torch.jit._fork(foo, self.a, self.b)
-                y_hat = foo(self.a, self.b)
+                fut = torch.jit._fork(foo, self.a)
+                y_hat = foo(self.b)
                 y = torch.jit._wait(fut)
                 return y_hat + y
 
@@ -272,6 +272,50 @@ def forward(self, x):
         # conservatively assumes there is a mutation because attributes are
         # passed to fork subgraph. both 'a' and 'b' are preserved.
         self.assertTrue(mf.hasattr('a'))
+        self.assertFalse(mf.hasattr('b'))
+        output_f = mf.forward(input)
+        self.assertEqual(output_s, output_f)
+
+    def test_freeze_module_with_fork_calling_module_method(self):
+        @torch.jit.script
+        def foo(x, y):
+            return x * y
+
+        class TestModule(nn.Module):
+            def __init__(self):
+                super(TestModule, self).__init__()
+                self.a = torch.ones(20, 20)
+                self.b = torch.ones(20, 20)
+
+            @torch.jit.export
+            def foo(self, x):
+                return x * self.a
+
+            @torch.jit.export
+            def bar(self, x):
+                return x * self.b
+
+            def forward(self, x):
+                fut = torch.jit._fork(self.foo, self.b)
+                y_hat = self.bar(self.a)
+                y = torch.jit._wait(fut)
+                return y_hat + y
+
+        m = torch.jit.script(TestModule())
+        m.eval()
+        input = torch.randn(2, 2)
+        output_s = m.forward(input)
+        mf = torch._C._freeze_module(m._c)
+        # Check that the frozen module looks as below:
+        # module m {
+        #   attributes {
+        #     self.b = ..
+        #   }
+        #   ...
+        # TODO: Although there are no mutations, the alias analysis
+        # conservatively assumes there is a mutation because attributes are
+        # passed to the fork subgraph. 'b' is preserved.
+        self.assertFalse(mf.hasattr('a'))
         self.assertTrue(mf.hasattr('b'))
         output_f = mf.forward(input)
         self.assertEqual(output_s, output_f)
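The hasattr assertions above encode the conservative rule from the TODO comments: an attribute passed into a fork subgraph is preserved because alias analysis assumes it may be mutated, while an attribute whose only uses fold away is dropped. A stripped-down sketch of the test_freeze_module_with_fork2 check (assuming a build with this fix; M and double are illustrative names):

import torch

@torch.jit.script
def double(x):
    return x * 2

class M(torch.nn.Module):
    def __init__(self):
        super(M, self).__init__()
        self.a = torch.ones(20, 20)  # passed into fork: conservatively preserved
        self.b = torch.ones(20, 20)  # used only inline: folded and dropped

    def forward(self, x):
        fut = torch.jit._fork(double, self.a)
        y_hat = double(self.b)
        return y_hat + torch.jit._wait(fut)

m = torch.jit.script(M())
m.eval()
mf = torch._C._freeze_module(m._c)
assert mf.hasattr('a') and not mf.hasattr('b')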

test/jit/test_module_interface.py (+52 −0)
@@ -595,6 +595,58 @@ def forward(self, x):
         with self.assertRaisesRegex(RuntimeError, "failed to freeze interface attribute 'proxy_mod'"):
             mf = torch._C._freeze_module(m._c, freezeInterfaces = True)
 
+    def test_freeze_module_with_interface_and_fork(self):
+        class SubModule(torch.nn.Module):
+            def __init__(self):
+                super(SubModule, self).__init__()
+                self.b = torch.tensor([1.5])
+
+            def forward(self, x):
+                self.b[0] += 3.2
+                return self.b
+
+        class OrigMod(torch.nn.Module):
+            def __init__(self):
+                super(OrigMod, self).__init__()
+                self.a = torch.tensor([0.5])
+
+            def forward(self, x):
+                return self.a
+
+        @torch.jit.interface
+        class ModInterface(torch.nn.Module):
+            def forward(self, x):
+                # type: (Tensor) -> Tensor
+                pass
+
+        class TestModule(torch.nn.Module):
+            proxy_mod: ModInterface
+
+            def __init__(self):
+                super(TestModule, self).__init__()
+                self.proxy_mod = OrigMod()
+                self.sub = SubModule()
+
+            def forward(self, x):
+                y = self.proxy_mod(x)
+                z = self.sub(x)
+                return y + z
+
+        class MainModule(torch.nn.Module):
+            def __init__(self):
+                super(MainModule, self).__init__()
+                self.test = TestModule()
+
+            def forward(self, x):
+                fut = torch.jit._fork(self.test.forward, x)
+                y = self.test(x)
+                z = torch.jit._wait(fut)
+                return y + z
+
+        m = torch.jit.script(MainModule())
+        m.eval()
+        mf = torch._C._freeze_module(m._c, freezeInterfaces = True)
+
     def test_module_apis_interface(self):
         @torch.jit.interface
         class ModuleInterface(nn.Module):
torch/csrc/jit/passes/freeze_module.cpp (+21 −16)
@@ -97,12 +97,7 @@ class AttributePropagator {
       auto graph = function->graph();
       optimizeSubGraphs(graph, applyInline);
       if (freezeInterfaces_) {
-        optimizeSubGraphs(
-            graph,
-            std::bind(
-                &AttributePropagator::inlineInterfaceCalls,
-                *this,
-                std::placeholders::_1));
+        inlineInterfaceCalls(graph);
       }
       // Record Attributes that are explicitly set in the module.
       // They cannot be folded.
@@ -379,6 +374,14 @@ class AttributePropagator {
           inlineInterfaceCall(n, attr);
           // Reset the GetAttr to concrete module type.
           n->output()->setType(attr.type());
+        } else if (n->kind() == prim::fork) {
+          applyToForkSubgraph(
+              n,
+              graph,
+              std::bind(
+                  &AttributePropagator::inlineInterfaceCalls,
+                  *this,
+                  std::placeholders::_1));
         }
       }
     }
@@ -476,18 +479,20 @@ class AttributePropagator {
     auto node = n->inputs()[0]->node();
     // Check if first parameter of fork is a module. This module is used
     // as the base module (similar to 'self' in forward) to resolve GetAttrs.
-    if (node->kind() != prim::GetAttr) {
-      return;
-    }
-    auto name = node->s(attr::name);
-    auto input = node->inputs()[0];
-    if (!findConstantAttr(input, name, attrModule, graph)) {
-      // Module needs to be preserved.
-      return;
+    // Otherwise freezing is applied using module_
+    if (node->kind() == prim::GetAttr &&
+        node->output()->type()->cast<ClassType>()) {
+      auto name = node->s(attr::name);
+      auto input = node->inputs()[0];
+      if (!findConstantAttr(input, name, attrModule, graph)) {
+        // Module needs to be preserved.
+        return;
+      }
+      attrModule = attrModule.attr(name).toModule();
+      std::swap(module_, attrModule);
     }
-    attrModule = attrModule.attr(name).toModule();
+
     auto subgraph = n->g(attr::Subgraph);
-    std::swap(module_, attrModule);
     func(subgraph);
     module_ = attrModule;
   }
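For readers following the C++ change: applyToForkSubgraph now distinguishes whether the fork's first input is a GetAttr that yields a module (ClassType). A hedged Python-side sketch of the two shapes it handles (Sub, foo, and M are illustrative names, not from the patch; the comments describe behavior per the diff above):

import torch

@torch.jit.script
def foo(x):
    return x * 2

class Sub(torch.nn.Module):
    def __init__(self):
        super(Sub, self).__init__()
        self.w = torch.ones(2, 2)

    def forward(self, x):
        return x * self.w

class M(torch.nn.Module):
    def __init__(self):
        super(M, self).__init__()
        self.sub = Sub()

    def forward(self, x):
        # Shape 1: the fork's first input is a GetAttr producing a module,
        # so the pass swaps module_ to self.sub before processing the fork
        # subgraph and resolves GetAttrs against that submodule.
        fut1 = torch.jit._fork(self.sub.forward, x)
        # Shape 2: the callee is a free function, so there is no module
        # GetAttr. The old code returned early and skipped the subgraph;
        # the new code still runs the callback with module_ unchanged.
        fut2 = torch.jit._fork(foo, x)
        return torch.jit._wait(fut1) + torch.jit._wait(fut2)

m = torch.jit.script(M())
m.eval()
mf = torch._C._freeze_module(m._c)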
