Summary:
# Plan
Stage 1 aims to ensure that the SSD TBE can run and won't break under normal operations (e.g. checkpointing).
Checkpointing (i.e. state_dict and load_state_dict) is still a work in progress, and we also need to guarantee checkpointing for optimizer states.
Stage 2: save state_dict (mostly on the fbgemm side)
* the current hope is that we can rely on flush to save the state dict (see the sketch below)
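To make the flush idea concrete, here is a minimal sketch of hooking a flush into state_dict collection. The wrapper class, and the assumption that the SSD TBE exposes a `flush()` method that persists cached rows to the SSD store, are illustrative only; the real integration would live inside TorchRec's sharded embedding modules.

```python
import torch

class SSDCheckpointWrapper(torch.nn.Module):
    """Illustrative wrapper that flushes the SSD-backed TBE before state_dict."""

    def __init__(self, ssd_tbe: torch.nn.Module) -> None:
        super().__init__()
        self.ssd_tbe = ssd_tbe
        # Run the flush *before* state_dict is collected, so the captured
        # weights reflect rows still sitting in the on-device cache.
        self.register_state_dict_pre_hook(self._flush_before_state_dict)

    def _flush_before_state_dict(self, module, prefix, keep_vars) -> None:
        # Assumed API: flush() persists in-flight cache rows to the SSD store.
        self.ssd_tbe.flush()

    def forward(self, indices: torch.Tensor, offsets: torch.Tensor) -> torch.Tensor:
        return self.ssd_tbe(indices, offsets)
```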
Stage 3: load_state_dict (needs more thought)
* the solution should be similar to that of PS (parameter server); see the sketch below
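A hedged sketch of what a PS-style restore could look like: stream the checkpointed weights back into the SSD store in row chunks rather than materializing one huge dense tensor. `write_rows` is a hypothetical backend method, standing in for whatever key-value write path fbgemm ends up exposing.

```python
import torch

def load_into_ssd_backend(backend, weights: torch.Tensor, chunk_size: int = 4096) -> None:
    """Stream checkpointed rows back into the SSD key-value store in chunks."""
    num_rows = weights.size(0)
    for start in range(0, num_rows, chunk_size):
        end = min(start + chunk_size, num_rows)
        ids = torch.arange(start, end, dtype=torch.int64)
        # Hypothetical call: write this contiguous slice of rows to the store.
        backend.write_rows(ids, weights[start:end])
```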
Stage 4: optimizer state checkpointing (torchrec side, should be pretty standard)
* should be straightforward
* needs fbgemm to support the split_embedding_weights API (see the sketch below)
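For illustration, a sketch of how the torchrec side could assemble a checkpointable optimizer state once the SSD TBE supports the per-table accessors. `split_embedding_weights()` and `split_optimizer_states()` mirror the accessors on the existing (non-SSD) FBGEMM TBE; the key naming below is an assumption.

```python
from typing import Dict, List
import torch

def optimizer_state_dict(tbe, table_names: List[str]) -> Dict[str, torch.Tensor]:
    """Assemble per-table weight and optimizer-state tensors for checkpointing."""
    state: Dict[str, torch.Tensor] = {}
    weights = tbe.split_embedding_weights()    # one weight view per table
    opt_states = tbe.split_optimizer_states()  # one list of state tensors per table
    for name, w, states in zip(table_names, weights, opt_states):
        state[f"{name}.weight"] = w
        for i, s in enumerate(states):
            # e.g. the momentum tensor for rowwise Adagrad
            state[f"{name}.opt_state.{i}"] = s
    return state
```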
# Outstanding issues
* initialization is not the same as before
* SSD TBE doesn't support mixed embedding dimensions
# Design doc
TODO:
# Tests should cover
* state_dict and load_state_dict (done)
* copying dense parts should work and not break
* deterministic output (done)
* numerical equivalence to a normal TBE (done; see the sketch after this list)
* changing the learning rate and warm-up policy (done)
* works for different sharding types (done)
* works with mixed kernels (done)
* works with mixed sharding types
* multi-GPU training (todo)
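A condensed sketch of the numerical-equivalence check from the list above: run the same lookups through a regular TBE and an SSD TBE initialized identically, then compare outputs. `make_tbe` and `make_ssd_tbe` are hypothetical test helpers; module construction is elided.

```python
import torch

def assert_ssd_matches_dense(make_tbe, make_ssd_tbe, indices, offsets) -> None:
    torch.manual_seed(0)
    ref = make_tbe()      # regular TBE
    torch.manual_seed(0)  # identical initialization for the SSD variant
    ssd = make_ssd_tbe()

    ref_out = ref(indices, offsets)
    ssd_out = ssd(indices, offsets)
    torch.testing.assert_close(ssd_out, ref_out)

    # Backward applies the fused optimizer update; a fuller test would flush
    # the SSD cache afterwards and compare the updated weights as well.
    ref_out.sum().backward()
    ssd_out.sum().backward()
```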
# OSS
NOTE: SSD TBE won't work in an OSS environment due to a RocksDB issue.
# Ad hoc
* the SSD kernel is gated: users must explicitly specify it in the planner constraints to use it (see the sketch below)
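For illustration, opting in might look like the sketch below. `ParameterConstraints` is the standard TorchRec planner constraint type; the exact enum value for the SSD kernel (shown here as `EmbeddingComputeKernel.KEY_VALUE`) is an assumption and may differ.

```python
from torchrec.distributed.embedding_types import EmbeddingComputeKernel
from torchrec.distributed.planner.types import ParameterConstraints

constraints = {
    # Only tables listed here are allowed to use the SSD-backed kernel.
    "large_table": ParameterConstraints(
        compute_kernels=[EmbeddingComputeKernel.KEY_VALUE.value],  # assumed enum value
        sharding_types=["row_wise"],
    ),
}
# `constraints` would then be passed to TorchRec's EmbeddingShardingPlanner.
```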
Differential Revision: D57452256