Commit 97a6594

Committed May 23, 2025
refactor: warmstart tutorial now uses FSDP2
1 parent f060089 · commit 97a6594

File tree

3 files changed (+103, -130 lines)
tutorials/warmstart/configs/pre_training_config.yaml

Lines changed: 39 additions & 27 deletions
@@ -1,4 +1,4 @@
-settings:
+settings:
   experiment_id: ${modalities_env:experiment_id}
   config_file_path: ${modalities_env:config_file_path}
   referencing_keys:
@@ -28,7 +28,7 @@ settings:
   training_target:
     num_target_tokens: 81920 # num_target_steps * world_size * local_train_micro_batch_size * sequence_length * gradient_accumulation_steps
     num_target_steps: 20 # we want to run for exactly 20 steps (although we will only get one checkpoint after 11 steps)
-  training_progress:
+  training_progress:
     global_num_seen_tokens: 0
     num_seen_steps: 0
     num_seen_samples: 0
@@ -95,7 +95,7 @@ checkpoint_saving:
       k: -1 # -1 to save all checkpoints
   checkpoint_saving_execution:
     component_key: checkpoint_saving_execution
-    variant_key: fsdp1
+    variant_key: dcp
    config:
       checkpoint_path: ${settings.paths.checkpoint_saving_path}
       global_rank: ${settings.cuda_env.global_rank}
@@ -108,12 +108,21 @@ loss_fn:
   target_key: ${settings.referencing_keys.target_key}
   prediction_key: ${settings.referencing_keys.prediction_key}

+device_mesh:
+  component_key: device_mesh
+  variant_key: default
+  config:
+    device_type: cuda
+    data_parallel_replicate_degree: 1
+    data_parallel_shard_degree: ${settings.cuda_env.world_size} # i.e., fully sharded
+    world_size: ${settings.cuda_env.world_size}
+
 app_state:
   component_key: app_state
   variant_key: raw
   config:
     model:
-      instance_key: wrapped_model
+      instance_key: initialized_model
       pass_type: BY_REFERENCE
     optimizer:
       instance_key: optimizer
@@ -122,24 +131,12 @@ app_state:
       instance_key: lr_scheduler
       pass_type: BY_REFERENCE

-wrapped_model:
-  component_key: model
-  variant_key: fsdp1_wrapped
-  config:
-    model:
-      instance_key: model
-      pass_type: BY_REFERENCE
-    sync_module_states: true
-    mixed_precision_settings: BF_16
-    sharding_strategy: FULL_SHARD
-    block_names: [GPT2Block]
-
-model:
+initialized_model:
   component_key: model
   variant_key: model_initialized
   config:
     model:
-      instance_key: model_raw
+      instance_key: fsdp_model
       pass_type: BY_REFERENCE
     model_initializer:
       component_key: model_initialization
@@ -151,6 +148,21 @@ model:
         std: 0.02
         num_layers: ${model_raw.config.n_layer}

+fsdp_model:
+  component_key: model
+  variant_key: fsdp2_wrapped
+  config:
+    model:
+      instance_key: model_raw
+      pass_type: BY_REFERENCE
+    device_mesh:
+      instance_key: device_mesh
+      pass_type: BY_REFERENCE
+    mixed_precision_settings:
+      param_dtype: BF_16
+      reduce_dtype: BF_16
+    block_names: [GPT2Block]
+
 model_raw:
   component_key: model
   variant_key: gpt2
@@ -171,12 +183,12 @@ model_raw:
     bias: false
     attention_config:
       qkv_transforms:
-      - type_hint: RotaryTransform
-        config:
-          n_embd: ${model_raw.config.n_embd}
-          n_head: ${model_raw.config.n_head_q}
-          seq_length_dim: -2
-          base_freq: 100000
+        - type_hint: RotaryTransform
+          config:
+            n_embd: ${model_raw.config.n_embd}
+            n_head: ${model_raw.config.n_head_q}
+            seq_length_dim: -2
+            base_freq: 100000
     attention_implementation: pytorch_flash
     activation_type: swiglu
     attention_norm_config:
@@ -223,15 +235,15 @@ optimizer:
     weight_decay: 1e-1
     weight_decay_groups_excluded: [embedding, layernorm]
     wrapped_model:
-      instance_key: wrapped_model
+      instance_key: initialized_model
       pass_type: BY_REFERENCE

 gradient_clipper:
   component_key: gradient_clipper
-  variant_key: fsdp1
+  variant_key: fsdp2
   config:
     wrapped_model:
-      instance_key: wrapped_model
+      instance_key: initialized_model
       pass_type: BY_REFERENCE
     norm_type: P2_NORM
     max_norm: 1.0
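Note the inverted composition order: with FSDP1 the raw model was initialized first and then wrapped (model_raw → model → wrapped_model), while with FSDP2 the raw model is wrapped first and initialized afterwards (model_raw → fsdp_model → initialized_model), which is why app_state, optimizer, and gradient_clipper now reference initialized_model. As orientation, here is a minimal PyTorch sketch of what the new device_mesh and fsdp2_wrapped components correspond to; this is not the Modalities implementation, the function name is made up, and an initialized distributed process group is assumed:

import torch
import torch.nn as nn
from torch.distributed.device_mesh import init_device_mesh
from torch.distributed.fsdp import MixedPrecisionPolicy, fully_shard


def shard_with_fsdp2(model: nn.Module, world_size: int) -> nn.Module:
    # data_parallel_replicate_degree: 1, data_parallel_shard_degree: world_size
    # -> a 1-D mesh, i.e., parameters fully sharded across all ranks.
    mesh = init_device_mesh("cuda", (world_size,), mesh_dim_names=("dp_shard",))

    # mixed_precision_settings: param_dtype BF_16, reduce_dtype BF_16
    mp = MixedPrecisionPolicy(param_dtype=torch.bfloat16, reduce_dtype=torch.bfloat16)

    # block_names: [GPT2Block] -> shard each transformer block first;
    # the root call then picks up the remaining parameters.
    for module in model.modules():
        if type(module).__name__ == "GPT2Block":
            fully_shard(module, mesh=mesh, mp_policy=mp)
    fully_shard(model, mesh=mesh, mp_policy=mp)
    return model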

tutorials/warmstart/configs/warmstart_config.yaml

Lines changed: 52 additions & 49 deletions
@@ -33,12 +33,12 @@ settings:
       component_key: number_conversion
       variant_key: global_num_seen_tokens_from_checkpoint_path
       config:
-        checkpoint_path: ${settings.warmstart_checkpoint_paths.model_checkpoint_path}
+        checkpoint_path: ${settings.warmstart_checkpoint_paths.checkpoint_folder_path}
     num_seen_steps: # for the batch progress subscriber
       component_key: number_conversion
       variant_key: num_seen_steps_from_checkpoint_path
       config:
-        checkpoint_path: ${settings.warmstart_checkpoint_paths.model_checkpoint_path}
+        checkpoint_path: ${settings.warmstart_checkpoint_paths.checkpoint_folder_path}
     num_seen_samples:
       component_key: number_conversion
       variant_key: num_samples_from_num_tokens
@@ -49,7 +49,7 @@ settings:
       component_key: number_conversion
       variant_key: last_step_from_checkpoint_path
       config:
-        checkpoint_path: ${settings.warmstart_checkpoint_paths.model_checkpoint_path}
+        checkpoint_path: ${settings.warmstart_checkpoint_paths.checkpoint_folder_path}
   warmstart_checkpoint_paths: ${warmstart_env:checkpoint_paths}

 collate_fn:
@@ -104,12 +104,9 @@ eval_dataloaders: []

 checkpoint_loading:
   component_key: checkpoint_loading
-  variant_key: fsdp1
+  variant_key: dcp
   config:
     global_rank: ${settings.cuda_env.global_rank}
-    block_names: [GPT2Block]
-    mixed_precision_settings: BF_16
-    sharding_strategy: FULL_SHARD

 checkpoint_saving:
   component_key: checkpoint_saving
@@ -122,7 +119,7 @@ checkpoint_saving:
       k: -1 # -1 to save all checkpoints
   checkpoint_saving_execution:
     component_key: checkpoint_saving_execution
-    variant_key: fsdp1
+    variant_key: dcp
    config:
       checkpoint_path: ${settings.paths.checkpoint_saving_path}
       global_rank: ${settings.cuda_env.global_rank}
@@ -135,12 +132,30 @@ loss_fn:
   target_key: ${settings.referencing_keys.target_key}
   prediction_key: ${settings.referencing_keys.prediction_key}

+device_mesh:
+  component_key: device_mesh
+  variant_key: default
+  config:
+    device_type: cuda
+    data_parallel_replicate_degree: 1
+    data_parallel_shard_degree: ${settings.cuda_env.world_size}
+    world_size: ${settings.cuda_env.world_size}
+
 app_state:
+  component_key: app_state
+  variant_key: dcp
+  config:
+    raw_app_state:
+      instance_key: app_state_raw
+      pass_type: BY_REFERENCE
+    checkpoint_dir_path: ${settings.warmstart_checkpoint_paths.checkpoint_folder_path}
+
+app_state_raw:
   component_key: app_state
   variant_key: raw
   config:
-    model:
-      instance_key: wrapped_model
+    model:
+      instance_key: initialized_model
       pass_type: BY_REFERENCE
     optimizer:
       instance_key: optimizer
@@ -149,24 +164,12 @@ app_state:
       instance_key: lr_scheduler
       pass_type: BY_REFERENCE

-wrapped_model:
-  component_key: model
-  variant_key: fsdp1_checkpointed
-  config:
-    model:
-      instance_key: model
-      pass_type: BY_REFERENCE
-    checkpoint_loading:
-      instance_key: checkpoint_loading
-      pass_type: BY_REFERENCE
-    checkpoint_path: ${settings.warmstart_checkpoint_paths.model_checkpoint_path}
-
-model:
+initialized_model:
   component_key: model
   variant_key: model_initialized
   config:
     model:
-      instance_key: model_raw
+      instance_key: fsdp_model
       pass_type: BY_REFERENCE
     model_initializer:
       component_key: model_initialization
@@ -178,6 +181,21 @@ model:
         std: 0.02
         num_layers: ${model_raw.config.n_layer}

+fsdp_model:
+  component_key: model
+  variant_key: fsdp2_wrapped
+  config:
+    model:
+      instance_key: model_raw
+      pass_type: BY_REFERENCE
+    device_mesh:
+      instance_key: device_mesh
+      pass_type: BY_REFERENCE
+    mixed_precision_settings:
+      param_dtype: BF_16
+      reduce_dtype: BF_16
+    block_names: [GPT2Block]
+
 model_raw:
   component_key: model
   variant_key: gpt2
@@ -198,12 +216,12 @@ model_raw:
     bias: false
     attention_config:
       qkv_transforms:
-      - type_hint: RotaryTransform
-        config:
-          n_embd: ${model_raw.config.n_embd}
-          n_head: ${model_raw.config.n_head_q}
-          seq_length_dim: -2
-          base_freq: 100000
+        - type_hint: RotaryTransform
+          config:
+            n_embd: ${model_raw.config.n_embd}
+            n_head: ${model_raw.config.n_head_q}
+            seq_length_dim: -2
+            base_freq: 100000
     attention_implementation: pytorch_flash
     activation_type: swiglu
     attention_norm_config:
@@ -238,24 +256,9 @@ lr_scheduler:
     total_steps: ${settings.training_target.num_target_steps}
     pct_start: 0.01
     anneal_strategy: cos
-    last_epoch: ${settings.training_progress.last_step}
+    # last_epoch: ${settings.training_progress.last_step}

 optimizer:
-  component_key: optimizer
-  variant_key: fsdp1_checkpointed
-  config:
-    optimizer:
-      instance_key: optimizer_original
-      pass_type: BY_REFERENCE
-    wrapped_model:
-      instance_key: wrapped_model
-      pass_type: BY_REFERENCE
-    checkpoint_loading:
-      instance_key: checkpoint_loading
-      pass_type: BY_REFERENCE
-    checkpoint_path: ${settings.warmstart_checkpoint_paths.optimizer_checkpoint_path}
-
-optimizer_original:
   component_key: optimizer
   variant_key: adam_w
   config:
@@ -265,15 +268,15 @@ optimizer_original:
     weight_decay: 1e-1
     weight_decay_groups_excluded: [embedding, layernorm]
     wrapped_model:
-      instance_key: wrapped_model
+      instance_key: initialized_model
       pass_type: BY_REFERENCE

 gradient_clipper:
   component_key: gradient_clipper
-  variant_key: fsdp1
+  variant_key: fsdp2
   config:
     wrapped_model:
-      instance_key: wrapped_model
+      instance_key: initialized_model
       pass_type: BY_REFERENCE
     norm_type: P2_NORM
     max_norm: 1.0
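Two things drive the warmstart changes: a DCP checkpoint is a directory rather than a model/optimizer .bin file pair, so every number_conversion component now parses steps and tokens from checkpoint_folder_path instead of model_checkpoint_path, and restoring state now goes through the dcp app_state variant, which loads into the already-wrapped model and optimizer in place instead of using fsdp1_checkpointed wrappers. As a rough sketch of the torch.distributed.checkpoint calls this builds on (not Modalities' actual code; the function names are made up and the folder name is taken from this tutorial's checkpoints):

import torch.distributed.checkpoint as dcp
from torch.distributed.checkpoint.state_dict import get_state_dict, set_state_dict

CKPT_DIR = "checkpoints/seen_steps_20-seen_tokens_81920-target_steps_20-target_tokens_81920"


def save_checkpoint(model, optimizer) -> None:
    # Every rank writes its own shards into one checkpoint folder;
    # there is no separate model/optimizer file pair as with FSDP1.
    model_sd, optim_sd = get_state_dict(model, optimizer)
    dcp.save({"model": model_sd, "optim": optim_sd}, checkpoint_id=CKPT_DIR)


def load_checkpoint(model, optimizer) -> None:
    # Warmstart: load shards in place into the FSDP2-wrapped modules,
    # then push the loaded state back into model and optimizer.
    model_sd, optim_sd = get_state_dict(model, optimizer)
    dcp.load({"model": model_sd, "optim": optim_sd}, checkpoint_id=CKPT_DIR)
    set_state_dict(model, optimizer, model_state_dict=model_sd, optim_state_dict=optim_sd)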
Lines changed: 12 additions & 54 deletions
@@ -1,5 +1,5 @@
11
import glob
2-
import json
2+
import os
33
import re
44
from pathlib import Path
55

@@ -11,67 +11,25 @@ def _get_checkpoint_file_name_without_eid(checkpoint_file_name: str) -> str:
1111

1212
def test_checkpoint_files_exist(checkpoint_folder_path: list[Path], expected_checkpoint_names: list[str]):
1313
# Check if all the checkpoint files exist and have the correct names
14-
checkpoint_paths = glob.glob(str(checkpoint_folder_path / "**/*.bin"), recursive=True)
14+
checkpoint_paths = glob.glob(str(checkpoint_folder_path / "**/*"), recursive=True)
1515

16-
assert len(checkpoint_paths) == 6, "ERROR! Expected 6 checkpoint files."
16+
assert len(checkpoint_paths) == 17, "ERROR! Expected 6 checkpoint files."
1717

18-
for checkpoint_path in checkpoint_paths:
19-
checkpoint_file_name = Path(checkpoint_path).name
20-
cleaned_checkpoint_file_name = _get_checkpoint_file_name_without_eid(checkpoint_file_name)
21-
22-
assert (
23-
cleaned_checkpoint_file_name in expected_checkpoint_names
24-
), f"ERROR! {checkpoint_file_name} is not a valid checkpoint file name."
25-
26-
27-
def check_last_checkpoint_info_correctness(checkpoint_folder_path: Path, expected_last_checkpoint_names: list[str]):
28-
# Check if the last checkpoint info files reference the correct checkpoint files
29-
30-
checkpoint_info_paths = glob.glob(str(checkpoint_folder_path / "**/*.json"), recursive=True)
31-
32-
assert len(checkpoint_info_paths) == 2, "ERROR! Expected 2 checkpoint info files."
33-
34-
assert len(set(checkpoint_info_paths)) == len(
35-
checkpoint_info_paths
36-
), "ERROR! Duplicate checkpoint info files found."
37-
38-
for checkpoint_info_path in checkpoint_info_paths:
39-
with open(checkpoint_info_path, "r") as f:
40-
checkpoint_info = json.load(f)
41-
model_checkpoint_path = Path(checkpoint_info["model_checkpoint_path"])
42-
optimizer_checkpoint_path = Path(checkpoint_info["optimizer_checkpoint_path"])
43-
assert model_checkpoint_path.exists(), f"ERROR! {model_checkpoint_path} does not exist."
44-
assert optimizer_checkpoint_path.exists(), f"ERROR! {optimizer_checkpoint_path} does not exist."
45-
46-
cleaned_model_checkpoint_file_name = _get_checkpoint_file_name_without_eid(model_checkpoint_path.name)
47-
cleaned_optimizer_checkpoint_file_name = _get_checkpoint_file_name_without_eid(optimizer_checkpoint_path.name)
48-
49-
assert cleaned_model_checkpoint_file_name in expected_last_checkpoint_names
50-
assert cleaned_optimizer_checkpoint_file_name in expected_last_checkpoint_names
18+
assert len([p for p in checkpoint_paths if p.endswith(".distcp")]), "ERROR! Expected 6 checkpoint files."
5119

5220

5321
if __name__ == "__main__":
54-
checkpoint_folder_path = Path("../data/checkpoints")
22+
current_file_path = Path(__file__).resolve()
23+
os.chdir(current_file_path.parent)
5524

56-
expected_checkpoint_names = [
57-
# pretrain checkpoint
58-
"model-seen_steps_11-seen_tokens_45056-target_steps_20-target_tokens_81920.bin",
59-
"optimizer-seen_steps_11-seen_tokens_45056-target_steps_20-target_tokens_81920.bin",
60-
# warmstart checkpoints
61-
"model-seen_steps_15-seen_tokens_61440-target_steps_20-target_tokens_81920.bin",
62-
"optimizer-seen_steps_15-seen_tokens_61440-target_steps_20-target_tokens_81920.bin",
63-
"model-seen_steps_20-seen_tokens_81920-target_steps_20-target_tokens_81920.bin",
64-
"optimizer-seen_steps_20-seen_tokens_81920-target_steps_20-target_tokens_81920.bin",
65-
]
25+
checkpoint_folder_path = Path("../data/checkpoints")
6626

67-
expected_last_checkpoint_names = [
27+
expected_checkpoint_folder_names = [
6828
# pretrain checkpoint
69-
"model-seen_steps_11-seen_tokens_45056-target_steps_20-target_tokens_81920.bin",
70-
"optimizer-seen_steps_11-seen_tokens_45056-target_steps_20-target_tokens_81920.bin",
29+
"seen_steps_11-seen_tokens_45056-target_steps_20-target_tokens_81920",
7130
# warmstart checkpoints
72-
"model-seen_steps_20-seen_tokens_81920-target_steps_20-target_tokens_81920.bin",
73-
"optimizer-seen_steps_20-seen_tokens_81920-target_steps_20-target_tokens_81920.bin",
31+
"seen_steps_15-seen_tokens_61440-target_steps_20-target_tokens_81920",
32+
"seen_steps_20-seen_tokens_81920-target_steps_20-target_tokens_81920",
7433
]
7534

76-
test_checkpoint_files_exist(checkpoint_folder_path, expected_checkpoint_names)
77-
check_last_checkpoint_info_correctness(checkpoint_folder_path, expected_last_checkpoint_names)
35+
test_checkpoint_files_exist(checkpoint_folder_path, expected_checkpoint_folder_names)
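For reference, each DCP checkpoint above is a folder rather than a .bin pair, which is why the script now matches checkpoint folder names and .distcp shard files, and why the old JSON last-checkpoint-info check was dropped. The on-disk layout is roughly the following (shard file names illustrative; DCP typically writes one __<rank>_<shard>.distcp file per writer plus a .metadata file):

seen_steps_11-seen_tokens_45056-target_steps_20-target_tokens_81920/
    .metadata
    __0_0.distcp
    __1_0.distcp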
