@@ -594,6 +594,80 @@ def test_load_state_dict(
         self._eval_models(m1, m2, batch)
         self._compare_models(m1, m2)

+    # pyre-ignore[56]
+    @given(
+        sharder_type=st.sampled_from(
+            [
+                SharderType.EMBEDDING_BAG_COLLECTION.value,
+            ]
+        ),
+        sharding_type=st.sampled_from(
+            [
+                ShardingType.TABLE_WISE.value,
+                ShardingType.COLUMN_WISE.value,
+                ShardingType.ROW_WISE.value,
+                ShardingType.TABLE_ROW_WISE.value,
+                ShardingType.TABLE_COLUMN_WISE.value,
+            ]
+        ),
+        kernel_type=st.sampled_from(
+            [
+                EmbeddingComputeKernel.FUSED.value,
+                EmbeddingComputeKernel.FUSED_UVM_CACHING.value,
+                EmbeddingComputeKernel.FUSED_UVM.value,
+            ]
+        ),
+    )
+    @settings(verbosity=Verbosity.verbose, max_examples=4, deadline=None)
+    def test_optimizer_load_state_dict(
+        self,
+        sharder_type: str,
+        sharding_type: str,
+        kernel_type: str,
+    ) -> None:
+        if (
+            self.device == torch.device("cpu")
+            and kernel_type != EmbeddingComputeKernel.FUSED.value
+        ):
+            self.skipTest("CPU does not support uvm.")
+
+        sharders = [
+            cast(
+                ModuleSharder[nn.Module],
+                create_test_sharder(
+                    sharder_type,
+                    sharding_type,
+                    kernel_type,
+                    fused_params={
+                        "optimizer": EmbOptimType.EXACT_ROWWISE_ADAGRAD,
+                    },
+                ),
+            ),
+        ]
+        models, batch = self._generate_dmps_and_batch(sharders)
+        m1, m2 = models
+
+        # train m1 a bit, to make sure the optimizer state is not zero
+        self._train_models(m1, m1, batch)
+        # sync the state dict
+        m2.load_state_dict(cast("OrderedDict[str, torch.Tensor]", m1.state_dict()))
+        # train both models, so they should diverge
+        self._train_models(m1, m2, batch)
+        # expect eval models to fail, since one model starts with non-zero optimizer state
+        with self.assertRaises(AssertionError):
+            self._eval_models(m1, m2, batch)
+
+        # sync the state dict again
+        m2.load_state_dict(cast("OrderedDict[str, torch.Tensor]", m1.state_dict()))
+        # load the state dict for the optimizer as well
+        opt1 = m1.fused_optimizer
+        opt2 = m2.fused_optimizer
+        opt1.load_state_dict(opt2.state_dict())
+
+        self._train_models(m1, m2, batch)
+        self._eval_models(m1, m2, batch)
+        self._compare_models(m1, m2)
+
     # pyre-ignore[56]
     @given(
         sharder_type=st.sampled_from(
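For anyone trying the pattern outside the test harness, here is a minimal sketch of the sync step the new test exercises: both the module state dict and the fused optimizer state dict have to be copied, otherwise two `DistributedModelParallel` replicas diverge on the next training step (the `AssertionError` branch above). The `sync_replicas` helper name is illustrative; the only calls assumed are the ones the test itself uses (`state_dict`, `load_state_dict`, and the `fused_optimizer` attribute).

```python
from typing import cast

from torchrec.distributed.model_parallel import DistributedModelParallel


def sync_replicas(src: DistributedModelParallel, dst: DistributedModelParallel) -> None:
    # Copy module parameters/buffers, as the test does from m1 to m2.
    dst.load_state_dict(cast("OrderedDict[str, torch.Tensor]", src.state_dict()))
    # Copy the fused optimizer state as well; skipping this is exactly the
    # divergence case the new test asserts on.
    dst.fused_optimizer.load_state_dict(src.fused_optimizer.state_dict())
```

After both copies, training the two replicas on the same batch should keep their eval outputs and parameters identical, which is what `_eval_models` and `_compare_models` check at the end of the test.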