|
| 1 | +import glob |
| 2 | +import json |
| 3 | +import re |
| 4 | +from pathlib import Path |
| 5 | + |
| 6 | + |
| 7 | +def _get_checkpoint_file_name_without_eid(checkpoint_file_name: str) -> str: |
| 8 | + # Remove the experiment id from the checkpoint file name |
| 9 | + return re.sub(r"^eid_\d{4}-\d{2}-\d{2}__\d{2}-\d{2}-\d{2}_[a-f0-9]+-", "", checkpoint_file_name) |
| 10 | + |
| 11 | + |
| 12 | +def test_checkpoint_files_exist(checkpoint_folder_path: list[Path], expected_checkpoint_names: list[str]): |
| 13 | + # Check if all the checkpoint files exist and have the correct names |
| 14 | + checkpoint_paths = glob.glob(str(checkpoint_folder_path / "**/*.bin"), recursive=True) |
| 15 | + |
| 16 | + assert len(checkpoint_paths) == 6, "ERROR! Expected 6 checkpoint files." |
| 17 | + |
| 18 | + for checkpoint_path in checkpoint_paths: |
| 19 | + checkpoint_file_name = Path(checkpoint_path).name |
| 20 | + cleaned_checkpoint_file_name = _get_checkpoint_file_name_without_eid(checkpoint_file_name) |
| 21 | + |
| 22 | + assert ( |
| 23 | + cleaned_checkpoint_file_name in expected_checkpoint_names |
| 24 | + ), f"ERROR! {checkpoint_file_name} is not a valid checkpoint file name." |
| 25 | + |
| 26 | + |
| 27 | +def check_last_checkpoint_info_correctness(checkpoint_folder_path: Path, expected_last_checkpoint_names: list[str]): |
| 28 | + # Check if the last checkpoint info files reference the correct checkpoint files |
| 29 | + |
| 30 | + checkpoint_info_paths = glob.glob(str(checkpoint_folder_path / "**/*.json"), recursive=True) |
| 31 | + |
| 32 | + assert len(checkpoint_info_paths) == 2, "ERROR! Expected 2 checkpoint info files." |
| 33 | + |
| 34 | + assert len(set(checkpoint_info_paths)) == len( |
| 35 | + checkpoint_info_paths |
| 36 | + ), "ERROR! Duplicate checkpoint info files found." |
| 37 | + |
| 38 | + for checkpoint_info_path in checkpoint_info_paths: |
| 39 | + with open(checkpoint_info_path, "r") as f: |
| 40 | + checkpoint_info = json.load(f) |
| 41 | + model_checkpoint_path = Path(checkpoint_info["model_checkpoint_path"]) |
| 42 | + optimizer_checkpoint_path = Path(checkpoint_info["optimizer_checkpoint_path"]) |
| 43 | + assert model_checkpoint_path.exists(), f"ERROR! {model_checkpoint_path} does not exist." |
| 44 | + assert optimizer_checkpoint_path.exists(), f"ERROR! {optimizer_checkpoint_path} does not exist." |
| 45 | + |
| 46 | + cleaned_model_checkpoint_file_name = _get_checkpoint_file_name_without_eid(model_checkpoint_path.name) |
| 47 | + cleaned_optimizer_checkpoint_file_name = _get_checkpoint_file_name_without_eid(optimizer_checkpoint_path.name) |
| 48 | + |
| 49 | + assert cleaned_model_checkpoint_file_name in expected_last_checkpoint_names |
| 50 | + assert cleaned_optimizer_checkpoint_file_name in expected_last_checkpoint_names |
| 51 | + |
| 52 | + |
| 53 | +if __name__ == "__main__": |
| 54 | + checkpoint_folder_path = Path("../data/checkpoints") |
| 55 | + |
| 56 | + expected_checkpoint_names = [ |
| 57 | + # pretrain checkpoint |
| 58 | + "model-seen_steps_11-seen_tokens_45056-target_steps_20-target_tokens_81920.bin", |
| 59 | + "optimizer-seen_steps_11-seen_tokens_45056-target_steps_20-target_tokens_81920.bin", |
| 60 | + # warmstart checkpoints |
| 61 | + "model-seen_steps_15-seen_tokens_61440-target_steps_20-target_tokens_81920.bin", |
| 62 | + "optimizer-seen_steps_15-seen_tokens_61440-target_steps_20-target_tokens_81920.bin", |
| 63 | + "model-seen_steps_20-seen_tokens_81920-target_steps_20-target_tokens_81920.bin", |
| 64 | + "optimizer-seen_steps_20-seen_tokens_81920-target_steps_20-target_tokens_81920.bin", |
| 65 | + ] |
| 66 | + |
| 67 | + expected_last_checkpoint_names = [ |
| 68 | + # pretrain checkpoint |
| 69 | + "model-seen_steps_11-seen_tokens_45056-target_steps_20-target_tokens_81920.bin", |
| 70 | + "optimizer-seen_steps_11-seen_tokens_45056-target_steps_20-target_tokens_81920.bin", |
| 71 | + # warmstart checkpoints |
| 72 | + "model-seen_steps_20-seen_tokens_81920-target_steps_20-target_tokens_81920.bin", |
| 73 | + "optimizer-seen_steps_20-seen_tokens_81920-target_steps_20-target_tokens_81920.bin", |
| 74 | + ] |
| 75 | + |
| 76 | + test_checkpoint_files_exist(checkpoint_folder_path, expected_checkpoint_names) |
| 77 | + check_last_checkpoint_info_correctness(checkpoint_folder_path, expected_last_checkpoint_names) |
0 commit comments