Commit 0d71282

henrylhtsang authored and facebook-github-bot committed
Add test for optimizer load state dict
Summary: This test checks two things:

* For an optimizer with state, load_state_dict must be called on the optimizer as well; otherwise the two models are not numerically equivalent.
* Optimizer state_dict and load_state_dict work as expected.

Differential Revision: D57921839
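The pattern under test can be sketched outside the TorchRec harness. A minimal sketch, assuming plain torch.optim.Adagrad as a stand-in for the fused EXACT_ROWWISE_ADAGRAD kernel; the toy model, data, and step helper are hypothetical:

import copy

import torch

# Two toy models that start with identical weights, each with its own
# stateful Adagrad optimizer (stand-in for the fused rowwise Adagrad).
m1 = torch.nn.Linear(4, 2)
m2 = copy.deepcopy(m1)
opt1 = torch.optim.Adagrad(m1.parameters(), lr=0.1)
opt2 = torch.optim.Adagrad(m2.parameters(), lr=0.1)

def step(model: torch.nn.Module, opt: torch.optim.Optimizer, x: torch.Tensor) -> None:
    # One training step on a trivial sum-of-outputs loss.
    opt.zero_grad()
    model(x).sum().backward()
    opt.step()

x = torch.randn(8, 4)

# Train m1 alone so its Adagrad accumulator (sum of squared grads) is non-zero.
step(m1, opt1, x)

# Syncing the model weights alone is not enough: the optimizer states still
# differ, so the same gradient produces different effective learning rates.
m2.load_state_dict(m1.state_dict())
step(m1, opt1, x)
step(m2, opt2, x)
assert not torch.allclose(m1.weight, m2.weight)

# Syncing both the weights and the optimizer state restores equivalence.
m2.load_state_dict(m1.state_dict())
opt2.load_state_dict(opt1.state_dict())
step(m1, opt1, x)
step(m2, opt2, x)
assert torch.allclose(m1.weight, m2.weight)

This mirrors the structure of the new test below: train one model, sync only the model state dict, assert divergence, then sync the optimizer state as well and assert equivalence.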
1 parent b22a19d · commit 0d71282

File tree

1 file changed: +74 -0 lines changed


torchrec/distributed/test_utils/test_model_parallel_base.py

@@ -594,6 +594,80 @@ def test_load_state_dict(
         self._eval_models(m1, m2, batch)
         self._compare_models(m1, m2)

+    # pyre-ignore[56]
+    @given(
+        sharder_type=st.sampled_from(
+            [
+                SharderType.EMBEDDING_BAG_COLLECTION.value,
+            ]
+        ),
+        sharding_type=st.sampled_from(
+            [
+                ShardingType.TABLE_WISE.value,
+                ShardingType.COLUMN_WISE.value,
+                ShardingType.ROW_WISE.value,
+                ShardingType.TABLE_ROW_WISE.value,
+                ShardingType.TABLE_COLUMN_WISE.value,
+            ]
+        ),
+        kernel_type=st.sampled_from(
+            [
+                EmbeddingComputeKernel.FUSED.value,
+                EmbeddingComputeKernel.FUSED_UVM_CACHING.value,
+                EmbeddingComputeKernel.FUSED_UVM.value,
+            ]
+        ),
+    )
+    @settings(verbosity=Verbosity.verbose, max_examples=4, deadline=None)
+    def test_optimizer_load_state_dict(
+        self,
+        sharder_type: str,
+        sharding_type: str,
+        kernel_type: str,
+    ) -> None:
+        if (
+            self.device == torch.device("cpu")
+            and kernel_type != EmbeddingComputeKernel.FUSED.value
+        ):
+            self.skipTest("CPU does not support uvm.")
+
+        sharders = [
+            cast(
+                ModuleSharder[nn.Module],
+                create_test_sharder(
+                    sharder_type,
+                    sharding_type,
+                    kernel_type,
+                    fused_params={
+                        "optimizer": EmbOptimType.EXACT_ROWWISE_ADAGRAD,
+                    },
+                ),
+            ),
+        ]
+        models, batch = self._generate_dmps_and_batch(sharders)
+        m1, m2 = models
+
+        # train m1 a bit, to make sure the optimizer state is not zero
+        self._train_models(m1, m1, batch)
+        # sync the model state dict
+        m2.load_state_dict(cast("OrderedDict[str, torch.Tensor]", m1.state_dict()))
+        # train both models, so they should diverge
+        self._train_models(m1, m2, batch)
+        # expect eval to fail, since one model starts with non-zero optimizer state
+        with self.assertRaises(AssertionError):
+            self._eval_models(m1, m2, batch)
+
+        # sync the model state dict again
+        m2.load_state_dict(cast("OrderedDict[str, torch.Tensor]", m1.state_dict()))
+        # load the state dict for the optimizer as well
+        opt1 = m1.fused_optimizer
+        opt2 = m2.fused_optimizer
+        opt1.load_state_dict(opt2.state_dict())
+
+        self._train_models(m1, m2, batch)
+        self._eval_models(m1, m2, batch)
+        self._compare_models(m1, m2)
+
     # pyre-ignore[56]
     @given(
         sharder_type=st.sampled_from(
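Why equivalence fails without the optimizer sync: rowwise Adagrad keeps a per-row accumulator G of squared gradients and applies an update of roughly w -= lr * g / (sqrt(G) + eps). After the first weight sync, m1's accumulator is non-zero while m2's is still fresh, so identical gradients yield different effective learning rates and the weights drift apart. Note also that once the model state dicts are re-synced, the direction of the optimizer sync (opt1.load_state_dict(opt2.state_dict()) here, but the reverse would do) is immaterial for the final checks: either way, both optimizers end up with identical state.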
