@@ -33,12 +33,12 @@ settings:
33
33
component_key : number_conversion
34
34
variant_key : global_num_seen_tokens_from_checkpoint_path
35
35
config :
36
- checkpoint_path : ${settings.warmstart_checkpoint_paths.model_checkpoint_path }
36
+ checkpoint_path : ${settings.warmstart_checkpoint_paths.checkpoint_folder_path }
37
37
num_seen_steps : # for the batch progress subscriber
38
38
component_key : number_conversion
39
39
variant_key : num_seen_steps_from_checkpoint_path
40
40
config :
41
- checkpoint_path : ${settings.warmstart_checkpoint_paths.model_checkpoint_path }
41
+ checkpoint_path : ${settings.warmstart_checkpoint_paths.checkpoint_folder_path }
42
42
num_seen_samples :
43
43
component_key : number_conversion
44
44
variant_key : num_samples_from_num_tokens
@@ -49,7 +49,7 @@ settings:
49
49
component_key : number_conversion
50
50
variant_key : last_step_from_checkpoint_path
51
51
config :
52
- checkpoint_path : ${settings.warmstart_checkpoint_paths.model_checkpoint_path }
52
+ checkpoint_path : ${settings.warmstart_checkpoint_paths.checkpoint_folder_path }
53
53
warmstart_checkpoint_paths : ${warmstart_env:checkpoint_paths}
54
54
55
55
collate_fn :
@@ -104,12 +104,9 @@ eval_dataloaders: []
104
104
105
105
checkpoint_loading :
106
106
component_key : checkpoint_loading
107
- variant_key : fsdp1
107
+ variant_key : dcp
108
108
config :
109
109
global_rank : ${settings.cuda_env.global_rank}
110
- block_names : [GPT2Block]
111
- mixed_precision_settings : BF_16
112
- sharding_strategy : FULL_SHARD
113
110
114
111
checkpoint_saving :
115
112
component_key : checkpoint_saving
@@ -122,7 +119,7 @@ checkpoint_saving:
122
119
k : -1 # -1 to save all checkpoints
123
120
checkpoint_saving_execution :
124
121
component_key : checkpoint_saving_execution
125
- variant_key : fsdp1
122
+ variant_key : dcp
126
123
config :
127
124
checkpoint_path : ${settings.paths.checkpoint_saving_path}
128
125
global_rank : ${settings.cuda_env.global_rank}
@@ -135,12 +132,30 @@ loss_fn:
135
132
target_key : ${settings.referencing_keys.target_key}
136
133
prediction_key : ${settings.referencing_keys.prediction_key}
137
134
135
+ device_mesh :
136
+ component_key : device_mesh
137
+ variant_key : default
138
+ config :
139
+ device_type : cuda
140
+ data_parallel_replicate_degree : 1
141
+ data_parallel_shard_degree : ${settings.cuda_env.world_size}
142
+ world_size : ${settings.cuda_env.world_size}
143
+
138
144
app_state :
145
+ component_key : app_state
146
+ variant_key : dcp
147
+ config :
148
+ raw_app_state :
149
+ instance_key : app_state_raw
150
+ pass_type : BY_REFERENCE
151
+ checkpoint_dir_path : ${settings.warmstart_checkpoint_paths.checkpoint_folder_path}
152
+
153
+ app_state_raw :
139
154
component_key : app_state
140
155
variant_key : raw
141
156
config :
142
- model :
143
- instance_key : wrapped_model
157
+ model :
158
+ instance_key : initialized_model
144
159
pass_type : BY_REFERENCE
145
160
optimizer :
146
161
instance_key : optimizer
@@ -149,24 +164,12 @@ app_state:
149
164
instance_key : lr_scheduler
150
165
pass_type : BY_REFERENCE
151
166
152
- wrapped_model :
153
- component_key : model
154
- variant_key : fsdp1_checkpointed
155
- config :
156
- model :
157
- instance_key : model
158
- pass_type : BY_REFERENCE
159
- checkpoint_loading :
160
- instance_key : checkpoint_loading
161
- pass_type : BY_REFERENCE
162
- checkpoint_path : ${settings.warmstart_checkpoint_paths.model_checkpoint_path}
163
-
164
- model :
167
+ initialized_model :
165
168
component_key : model
166
169
variant_key : model_initialized
167
170
config :
168
171
model :
169
- instance_key : model_raw
172
+ instance_key : fsdp_model
170
173
pass_type : BY_REFERENCE
171
174
model_initializer :
172
175
component_key : model_initialization
@@ -178,6 +181,21 @@ model:
178
181
std : 0.02
179
182
num_layers : ${model_raw.config.n_layer}
180
183
184
+ fsdp_model :
185
+ component_key : model
186
+ variant_key : fsdp2_wrapped
187
+ config :
188
+ model :
189
+ instance_key : model_raw
190
+ pass_type : BY_REFERENCE
191
+ device_mesh :
192
+ instance_key : device_mesh
193
+ pass_type : BY_REFERENCE
194
+ mixed_precision_settings :
195
+ param_dtype : BF_16
196
+ reduce_dtype : BF_16
197
+ block_names : [GPT2Block]
198
+
181
199
model_raw :
182
200
component_key : model
183
201
variant_key : gpt2
@@ -198,12 +216,12 @@ model_raw:
198
216
bias : false
199
217
attention_config :
200
218
qkv_transforms :
201
- - type_hint : RotaryTransform
202
- config :
203
- n_embd : ${model_raw.config.n_embd}
204
- n_head : ${model_raw.config.n_head_q}
205
- seq_length_dim : -2
206
- base_freq : 100000
219
+ - type_hint : RotaryTransform
220
+ config :
221
+ n_embd : ${model_raw.config.n_embd}
222
+ n_head : ${model_raw.config.n_head_q}
223
+ seq_length_dim : -2
224
+ base_freq : 100000
207
225
attention_implementation : pytorch_flash
208
226
activation_type : swiglu
209
227
attention_norm_config :
@@ -238,24 +256,9 @@ lr_scheduler:
238
256
total_steps : ${settings.training_target.num_target_steps}
239
257
pct_start : 0.01
240
258
anneal_strategy : cos
241
- last_epoch : ${settings.training_progress.last_step}
259
+ # last_epoch: ${settings.training_progress.last_step}
242
260
243
261
optimizer :
244
- component_key : optimizer
245
- variant_key : fsdp1_checkpointed
246
- config :
247
- optimizer :
248
- instance_key : optimizer_original
249
- pass_type : BY_REFERENCE
250
- wrapped_model :
251
- instance_key : wrapped_model
252
- pass_type : BY_REFERENCE
253
- checkpoint_loading :
254
- instance_key : checkpoint_loading
255
- pass_type : BY_REFERENCE
256
- checkpoint_path : ${settings.warmstart_checkpoint_paths.optimizer_checkpoint_path}
257
-
258
- optimizer_original :
259
262
component_key : optimizer
260
263
variant_key : adam_w
261
264
config :
@@ -265,15 +268,15 @@ optimizer_original:
265
268
weight_decay : 1e-1
266
269
weight_decay_groups_excluded : [embedding, layernorm]
267
270
wrapped_model :
268
- instance_key : wrapped_model
271
+ instance_key : initialized_model
269
272
pass_type : BY_REFERENCE
270
273
271
274
gradient_clipper :
272
275
component_key : gradient_clipper
273
- variant_key : fsdp1
276
+ variant_key : fsdp2
274
277
config :
275
278
wrapped_model :
276
- instance_key : wrapped_model
279
+ instance_key : initialized_model
277
280
pass_type : BY_REFERENCE
278
281
norm_type : P2_NORM
279
282
max_norm : 1.0
0 commit comments