diff --git a/praxis/layers/checkpoint_policy.py b/praxis/layers/checkpoint_policy.py index 5f173fd0..904eb774 100644 --- a/praxis/layers/checkpoint_policy.py +++ b/praxis/layers/checkpoint_policy.py @@ -32,6 +32,7 @@ class AutodiffCheckpointType(str, enum.Enum): SAVE_CONTEXT = 'save_context' SAVE_CONTEXT_AND_OUT_PROJ = 'save_encoded_and_out_proj' SAVE_DOT_ONLY = 'save_dot_only' + OFFLOAD_DOT_WITH_NO_BATCH_DIM = 'offload_dot_with_no_batch_dims' SAVE_DOT_WITH_NO_BATCH_DIM = 'save_dot_with_no_batch_dims' SAVE_DOT_FOR_MLPERF_200B = 'save_dot_for_mlperf_200b' SAVE_ITERATION_INPUT = 'save_iteration_input' @@ -50,6 +51,10 @@ def custom_policy(checkpoint_policy: AutodiffCheckpointType): return jax.checkpoint_policies.everything_saveable if checkpoint_policy == AutodiffCheckpointType.SAVE_DOT_ONLY: return jax.checkpoint_policies.checkpoint_dots + if checkpoint_policy == AutodiffCheckpointType.OFFLOAD_DOT_WITH_NO_BATCH_DIM: + return jax.checkpoint_policies.offload_dot_with_no_batch_dims( + 'device', 'pinned_host' + ) if checkpoint_policy == AutodiffCheckpointType.SAVE_DOT_WITH_NO_BATCH_DIM: return jax.checkpoint_policies.checkpoint_dots_with_no_batch_dims if checkpoint_policy == AutodiffCheckpointType.SAVE_QKV_OUT_PROJ: