@@ -524,6 +524,77 @@ def forward(self, x):
524
524
self .assertEqual (output_s , output_f )
525
525
526
526
527
+ def test_freeze_module_with_preserve_sub_module (self ):
528
+ class SubModule (nn .Module ):
529
+ def __init__ (self ):
530
+ super (SubModule , self ).__init__ ()
531
+ self .a = torch .tensor ([1.1 ])
532
+ self .b = 2.2
533
+
534
+ def forward (self , x ):
535
+ return self .a
536
+
537
+ class TestModule (nn .Module ):
538
+ def __init__ (self ):
539
+ super (TestModule , self ).__init__ ()
540
+ self .sub1 = SubModule () # aliasing
541
+ self .sub2 = SubModule ()
542
+
543
+ def forward (self , x ):
544
+ return self .sub2 (x ) + self .sub1 (x )
545
+ m = TestModule ()
546
+ ms = torch .jit .script (m )
547
+ ms .eval ()
548
+ mf = torch ._C ._freeze_module (ms ._c , ["sub1" ])
549
+
550
+ # Test that 'sub1' is preserved entirely and 'sub2' is completely folded
551
+ self .assertTrue (mf .hasattr ('sub1' ))
552
+ self .assertTrue (mf .sub1 .hasattr ('a' ))
553
+ self .assertTrue (mf .sub1 .hasattr ('b' ))
554
+ self .assertFalse (mf .hasattr ('sub2' ))
555
+ input = torch .randn (2 , 2 )
556
+ output_s = ms .forward (input )
557
+ output_f = mf .forward (input )
558
+ self .assertEqual (output_s , output_f )
559
+
560
+ def test_freeze_module_with_preserve_sub_module_and_mutation (self ):
561
+ class SubModule (nn .Module ):
562
+ def __init__ (self ):
563
+ super (SubModule , self ).__init__ ()
564
+ self .a = torch .tensor ([1.1 ])
565
+ self .b = 2.2
566
+
567
+ def forward (self , x ):
568
+ self .a [0 ] = 3.3
569
+ return self .a
570
+
571
+ class TestModule (nn .Module ):
572
+ def __init__ (self ):
573
+ super (TestModule , self ).__init__ ()
574
+ self .sub1 = SubModule () # aliasing
575
+ self .sub2 = SubModule ()
576
+
577
+ def forward (self , x ):
578
+ return self .sub2 (x ) + self .sub1 (x )
579
+ m = TestModule ()
580
+ ms = torch .jit .script (m )
581
+ ms .eval ()
582
+ mf = torch ._C ._freeze_module (ms ._c , ["sub1" ])
583
+
584
+ # Test that be both sub1 and sub1 are preserved and 'b' is preserved
585
+ # even if it is not used. To fulfill user request to preserve 'sub1'
586
+ self .assertTrue (mf .hasattr ('sub1' ))
587
+ self .assertTrue (mf .sub1 .hasattr ('a' ))
588
+ self .assertTrue (mf .sub1 .hasattr ('b' ))
589
+ self .assertTrue (mf .hasattr ('sub2' ))
590
+ self .assertTrue (mf .sub2 .hasattr ('a' ))
591
+ self .assertTrue (mf .sub2 .hasattr ('b' ))
592
+ input = torch .randn (2 , 2 )
593
+ output_s = ms .forward (input )
594
+ output_f = mf .forward (input )
595
+ self .assertEqual (output_s , output_f )
596
+
597
+
527
598
def test_freeze_module_with_helperfunction (self ):
528
599
class SubModule (nn .Module ):
529
600
def __init__ (self ):
0 commit comments