Skip to content
This commit does not belong to any branch on this repository, and may belong to a fork outside of the repository.

Commit c3f43b8

Browse files
committed Feb 19, 2025 ·
feat: added warmstart example to end2end test
1 parent 1725aae commit c3f43b8

File tree

2 files changed

+88
-0
lines changed

2 files changed

+88
-0
lines changed
 

‎tests/tests.py

Lines changed: 11 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -122,6 +122,17 @@ def main(cpu: bool = False, single_gpu: bool = False, multi_gpu: bool = False, d
122122

123123
check_existence_and_clear_getting_started_example_output(run_getting_started_example_directory, date_of_run)
124124

125+
# warmstart example
126+
print("\n=== RUN WARMSTART EXAMPLE ===")
127+
run_warmstart_example_directory = _ROOT_DIR / "tutorials/warmstart/scripts"
128+
run_warmstart_example_script = _ROOT_DIR / "tutorials/warmstart/scripts/pre_train_and_warmstart.sh"
129+
assert isfile(run_warmstart_example_script), f"ERROR! {run_warmstart_example_script} does not exist."
130+
command_warmstart_example = (
131+
f"cd {run_warmstart_example_directory}; sh pre_train_and_warmstart.sh {devices[0]} {devices[1]}"
132+
)
133+
print(command_warmstart_example)
134+
subprocess.run(command_warmstart_example, shell=True, capture_output=False, text=True)
135+
125136
print("\n=== DONE ===")
126137

127138

Lines changed: 77 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,77 @@
1+
import glob
2+
import json
3+
import re
4+
from pathlib import Path
5+
6+
7+
def _get_checkpoint_file_name_without_eid(checkpoint_file_name: str) -> str:
8+
# Remove the experiment id from the checkpoint file name
9+
return re.sub(r"^eid_\d{4}-\d{2}-\d{2}__\d{2}-\d{2}-\d{2}_[a-f0-9]+-", "", checkpoint_file_name)
10+
11+
12+
def test_checkpoint_files_exist(
    checkpoint_folder_path: Path,
    expected_checkpoint_names: list[str],
    expected_num_checkpoints: int = 6,
):
    """Assert that the expected ``*.bin`` checkpoint files exist below a folder.

    Args:
        checkpoint_folder_path: Folder that is searched recursively for ``*.bin`` files.
        expected_checkpoint_names: Checkpoint file names with the experiment-id prefix removed.
        expected_num_checkpoints: Number of ``*.bin`` files that must be found (defaults to 6).

    Raises:
        AssertionError: If the number of files differs from ``expected_num_checkpoints``, if a
            found file has an unexpected name, or if an expected name was not found at all.
    """
    checkpoint_paths = glob.glob(str(Path(checkpoint_folder_path) / "**/*.bin"), recursive=True)

    assert (
        len(checkpoint_paths) == expected_num_checkpoints
    ), f"ERROR! Expected {expected_num_checkpoints} checkpoint files, found {len(checkpoint_paths)}."

    found_checkpoint_names = set()
    for checkpoint_path in checkpoint_paths:
        checkpoint_file_name = Path(checkpoint_path).name
        cleaned_checkpoint_file_name = _get_checkpoint_file_name_without_eid(checkpoint_file_name)

        assert (
            cleaned_checkpoint_file_name in expected_checkpoint_names
        ), f"ERROR! {checkpoint_file_name} is not a valid checkpoint file name."
        found_checkpoint_names.add(cleaned_checkpoint_file_name)

    # The count check alone does not prove completeness: two files that differ only in their
    # experiment id collapse to the same cleaned name and would hide a missing checkpoint.
    missing_checkpoint_names = sorted(set(expected_checkpoint_names) - found_checkpoint_names)
    assert not missing_checkpoint_names, f"ERROR! Missing checkpoint files: {missing_checkpoint_names}"
25+
26+
27+
def check_last_checkpoint_info_correctness(
    checkpoint_folder_path: Path,
    expected_last_checkpoint_names: list[str],
    expected_num_checkpoint_infos: int = 2,
):
    """Assert that the last-checkpoint info files reference the correct checkpoint files.

    Every ``*.json`` file found recursively below ``checkpoint_folder_path`` is loaded and its
    ``model_checkpoint_path`` and ``optimizer_checkpoint_path`` entries are checked.

    Args:
        checkpoint_folder_path: Folder that is searched recursively for ``*.json`` files.
        expected_last_checkpoint_names: Allowed checkpoint file names with the experiment-id
            prefix removed.
        expected_num_checkpoint_infos: Number of ``*.json`` files that must be found
            (defaults to 2).

    Raises:
        AssertionError: If the number of info files is wrong, if duplicates are found, if a
            referenced checkpoint path does not exist, or if a referenced file name is not in
            ``expected_last_checkpoint_names``.
    """
    checkpoint_info_paths = glob.glob(str(Path(checkpoint_folder_path) / "**/*.json"), recursive=True)

    assert len(checkpoint_info_paths) == expected_num_checkpoint_infos, (
        f"ERROR! Expected {expected_num_checkpoint_infos} checkpoint info files, "
        f"found {len(checkpoint_info_paths)}."
    )

    assert len(set(checkpoint_info_paths)) == len(
        checkpoint_info_paths
    ), "ERROR! Duplicate checkpoint info files found."

    for checkpoint_info_path in checkpoint_info_paths:
        # Explicit encoding so decoding does not depend on the platform's locale default.
        with open(checkpoint_info_path, "r", encoding="utf-8") as f:
            checkpoint_info = json.load(f)

        model_checkpoint_path = Path(checkpoint_info["model_checkpoint_path"])
        optimizer_checkpoint_path = Path(checkpoint_info["optimizer_checkpoint_path"])
        assert model_checkpoint_path.exists(), f"ERROR! {model_checkpoint_path} does not exist."
        assert optimizer_checkpoint_path.exists(), f"ERROR! {optimizer_checkpoint_path} does not exist."

        cleaned_model_checkpoint_file_name = _get_checkpoint_file_name_without_eid(model_checkpoint_path.name)
        cleaned_optimizer_checkpoint_file_name = _get_checkpoint_file_name_without_eid(optimizer_checkpoint_path.name)

        assert (
            cleaned_model_checkpoint_file_name in expected_last_checkpoint_names
        ), f"ERROR! {model_checkpoint_path.name} is not an expected last model checkpoint."
        assert (
            cleaned_optimizer_checkpoint_file_name in expected_last_checkpoint_names
        ), f"ERROR! {optimizer_checkpoint_path.name} is not an expected last optimizer checkpoint."
51+
52+
53+
if __name__ == "__main__":
    # Script entry point: validate the checkpoints found under ../data/checkpoints
    # (resolved against the current working directory).
    checkpoint_folder_path = Path("../data/checkpoints")

    # pretrain checkpoint
    pretrain_checkpoint_names = [
        "model-seen_steps_11-seen_tokens_45056-target_steps_20-target_tokens_81920.bin",
        "optimizer-seen_steps_11-seen_tokens_45056-target_steps_20-target_tokens_81920.bin",
    ]
    # warmstart checkpoints
    intermediate_warmstart_checkpoint_names = [
        "model-seen_steps_15-seen_tokens_61440-target_steps_20-target_tokens_81920.bin",
        "optimizer-seen_steps_15-seen_tokens_61440-target_steps_20-target_tokens_81920.bin",
    ]
    final_warmstart_checkpoint_names = [
        "model-seen_steps_20-seen_tokens_81920-target_steps_20-target_tokens_81920.bin",
        "optimizer-seen_steps_20-seen_tokens_81920-target_steps_20-target_tokens_81920.bin",
    ]

    # Every checkpoint file that must be present on disk.
    expected_checkpoint_names = (
        pretrain_checkpoint_names + intermediate_warmstart_checkpoint_names + final_warmstart_checkpoint_names
    )
    # Only the checkpoints that the info files are allowed to reference.
    expected_last_checkpoint_names = pretrain_checkpoint_names + final_warmstart_checkpoint_names

    test_checkpoint_files_exist(checkpoint_folder_path, expected_checkpoint_names)
    check_last_checkpoint_info_correctness(checkpoint_folder_path, expected_last_checkpoint_names)

0 commit comments

Comments
 (0)
Please sign in to comment.