Skip to content

Commit 13f76f2

Browse files
bzinodevfacebook-github-bot
authored andcommittedSep 28, 2020
Fix preserve submodule attribute in freezing (pytorch#45143)
Summary: Pull Request resolved: pytorch#45143 This PR prevents freezing cleaning up a submodule when user requests to preserve a submodule. Test Plan: Imported from OSS Reviewed By: eellison Differential Revision: D23844969 Pulled By: bzinodev fbshipit-source-id: 80e6db3fc12460d62e634ea0336ae2a3551c2151
1 parent c3bf402 commit 13f76f2

File tree

2 files changed

+83
-2
lines changed

2 files changed

+83
-2
lines changed
 

‎test/jit/test_freezing.py

+71
Original file line numberDiff line numberDiff line change
@@ -524,6 +524,77 @@ def forward(self, x):
524524
self.assertEqual(output_s, output_f)
525525

526526

527+
def test_freeze_module_with_preserve_sub_module(self):
528+
class SubModule(nn.Module):
529+
def __init__(self):
530+
super(SubModule, self).__init__()
531+
self.a = torch.tensor([1.1])
532+
self.b = 2.2
533+
534+
def forward(self, x):
535+
return self.a
536+
537+
class TestModule(nn.Module):
538+
def __init__(self):
539+
super(TestModule, self).__init__()
540+
self.sub1 = SubModule() # aliasing
541+
self.sub2 = SubModule()
542+
543+
def forward(self, x):
544+
return self.sub2(x) + self.sub1(x)
545+
m = TestModule()
546+
ms = torch.jit.script(m)
547+
ms.eval()
548+
mf = torch._C._freeze_module(ms._c, ["sub1"])
549+
550+
# Test that 'sub1' is preserved entirely and 'sub2' is completely folded
551+
self.assertTrue(mf.hasattr('sub1'))
552+
self.assertTrue(mf.sub1.hasattr('a'))
553+
self.assertTrue(mf.sub1.hasattr('b'))
554+
self.assertFalse(mf.hasattr('sub2'))
555+
input = torch.randn(2, 2)
556+
output_s = ms.forward(input)
557+
output_f = mf.forward(input)
558+
self.assertEqual(output_s, output_f)
559+
560+
def test_freeze_module_with_preserve_sub_module_and_mutation(self):
561+
class SubModule(nn.Module):
562+
def __init__(self):
563+
super(SubModule, self).__init__()
564+
self.a = torch.tensor([1.1])
565+
self.b = 2.2
566+
567+
def forward(self, x):
568+
self.a[0] = 3.3
569+
return self.a
570+
571+
class TestModule(nn.Module):
572+
def __init__(self):
573+
super(TestModule, self).__init__()
574+
self.sub1 = SubModule() # aliasing
575+
self.sub2 = SubModule()
576+
577+
def forward(self, x):
578+
return self.sub2(x) + self.sub1(x)
579+
m = TestModule()
580+
ms = torch.jit.script(m)
581+
ms.eval()
582+
mf = torch._C._freeze_module(ms._c, ["sub1"])
583+
584+
# Test that be both sub1 and sub1 are preserved and 'b' is preserved
585+
# even if it is not used. To fulfill user request to preserve 'sub1'
586+
self.assertTrue(mf.hasattr('sub1'))
587+
self.assertTrue(mf.sub1.hasattr('a'))
588+
self.assertTrue(mf.sub1.hasattr('b'))
589+
self.assertTrue(mf.hasattr('sub2'))
590+
self.assertTrue(mf.sub2.hasattr('a'))
591+
self.assertTrue(mf.sub2.hasattr('b'))
592+
input = torch.randn(2, 2)
593+
output_s = ms.forward(input)
594+
output_f = mf.forward(input)
595+
self.assertEqual(output_s, output_f)
596+
597+
527598
def test_freeze_module_with_helperfunction(self):
528599
class SubModule(nn.Module):
529600
def __init__(self):

‎torch/csrc/jit/passes/freeze_module.cpp

+12-2
Original file line numberDiff line numberDiff line change
@@ -42,7 +42,14 @@ class AttributePropagator {
4242
// explicitly.
4343
auto checkName = [this](std::string& name) {
4444
if (module_.hasattr(name)) {
45-
insertMutableAttr(name, module_.attr(name), module_._ivalue());
45+
auto attr = module_.attr(name);
46+
47+
// Freezing client wants to presever this submodule. When cleaning
48+
// the frozen module, make sure it will be preserved entirely.
49+
if (attr.isModule()) {
50+
preservedSubModule_.insert(attr.toModule()._ivalue());
51+
}
52+
insertMutableAttr(name, attr, module_._ivalue());
4653
return true;
4754
}
4855

@@ -503,7 +510,7 @@ class AttributePropagator {
503510
return true;
504511
}
505512
}
506-
return false;
513+
return preservedSubModule_.count(subModule._ivalue());
507514
}
508515

509516
void removeExtraWaitCalls(Block* b) {
@@ -683,6 +690,9 @@ class AttributePropagator {
683690
// Contains user specified methods to be preserved in frozen module.
684691
std::unordered_set<Function*> preservedMethods_;
685692

693+
// Contains user specified sub module to be preserve in frozen module.
694+
std::unordered_set<ModulePtr> preservedSubModule_;
695+
686696
// Track all used attributes ivalues that can be aliased.
687697
IValue::HashAliasedIValues usedAttrs_;
688698

0 commit comments

Comments
 (0)
Please sign in to comment.