From 90ad8755020c1e3d911882d33d52413b65089948 Mon Sep 17 00:00:00 2001 From: jeromeku Date: Thu, 23 May 2024 16:38:28 +0000 Subject: [PATCH 01/16] add profiling tools --- profile.sh | 46 +++++++++++++++++-------- profiling_tools.py | 58 ++++++++++++++++++++++++++++++++ train.py | 83 +++++++++++++++++++++++++++++++++++++++------- 3 files changed, 161 insertions(+), 26 deletions(-) create mode 100644 profiling_tools.py diff --git a/profile.sh b/profile.sh index 5c22cf2..09c379b 100755 --- a/profile.sh +++ b/profile.sh @@ -10,19 +10,37 @@ # For additional information, see: # https://pytorch.org/tutorials/recipes/recipes/profiler_recipe.html +# python train.py \ +# --model_name meta-llama/Meta-Llama-3-8B \ +# --train_type hqq_dora \ +# --n_bits 4 \ +# --precision bf16 \ +# --dataset orca_math \ +# --dataset_samples 8 \ +# --batch_size 2 \ +# --context_length 512 \ +# --gradient_accumulation_steps 2 \ +# --use_gradient_checkpointing False \ +# --use_cpu_offload False \ +# --use_activation_cpu_offload False \ +# --save_model False \ +# --profiling_output /tmp/profile +# export CUDA_VISIBLE_DEVICES=3,4 python train.py \ - --model_name meta-llama/Meta-Llama-3-8B \ - --train_type hqq_dora \ - --n_bits 4 \ - --precision bf16 \ - --dataset orca_math \ - --dataset_samples 8 \ - --batch_size 2 \ - --context_length 512 \ - --gradient_accumulation_steps 2 \ - --use_gradient_checkpointing False \ - --use_cpu_offload False \ - --use_activation_cpu_offload False \ - --save_model False \ - --profiling_output /tmp/profile +--world_size -1 \ +--model_name "hf-internal-testing/tiny-random-LlamaForCausalLM" \ +--gradient_accumulation_steps 2 \ +--batch_size 1 \ +--context_length 256 \ +--num_epochs 1 \ +--sharding_strategy full_shard \ +--precision bf16 \ +--train_type qlora \ +--use_gradient_checkpointing false \ +--use_cpu_offload false \ +--log_to stdout \ +--dataset dummy \ +--profile false \ +--with_stack true \ +--verbose true diff --git a/profiling_tools.py b/profiling_tools.py new file mode 100644 index 0000000..c4575c6 --- /dev/null +++ b/profiling_tools.py @@ -0,0 +1,58 @@ +import torch +import os +from datetime import datetime +from functools import partial +def trace_handler( + prof: torch.profiler.profile, + rank: int, + export_trace=True, + export_memory_timeline=False, + with_stack: bool = True, + group_by_stack: int = 0, + group_by_input_shapes: bool = False, + prefix="", + out_dir="./profiles", + time_fmt_str: str = "%m_%d_%H", + metric="self_cuda_time_total", + row_limit=25, + verbose=False, +): + # Prefix for file names. + timestamp = datetime.now().strftime(time_fmt_str) + file_prefix = os.path.join(out_dir, f"{prefix}-{timestamp}") + if not os.path.exists(out_dir): + os.makedirs(out_dir) + + # Construct the trace file. + if export_trace: + prof.export_chrome_trace(f"{file_prefix}-chrome-trace.json.gz") + + # Construct the memory timeline file. + if export_memory_timeline: + prof.export_memory_timeline( + f"{file_prefix}-memory-timeline.html" + ) + + if with_stack: + prof.export_stacks(f"{file_prefix}-stacks.txt", metric=metric) + + key_avgs = prof.key_averages( + group_by_input_shape=group_by_input_shapes, group_by_stack_n=group_by_stack + ).table(sort_by=metric, row_limit=row_limit) + with open(f"{file_prefix}-key_averages.txt", "w") as f: + print( + key_avgs, file=f + ) + if rank == 0: + print(f"Saving profiling results to {out_dir}") + if verbose: + print(key_avgs) + +class FakeContext: + def __enter__(self): + return self + def __exit__(self, *args, **kwargs): + pass + + def step(self): + pass \ No newline at end of file diff --git a/train.py b/train.py index 6d2f68b..403c6df 100644 --- a/train.py +++ b/train.py @@ -86,6 +86,7 @@ sys.path.append("./scripts") from lora import LORA from dora import BNBDORA, HQQDORA, DORALayer, MagnitudeLayer +from profiling_tools import trace_handler, FakeContext class Logger: def __init__(self, args, log_to="stdout", project_name="fsdp_qlora", entity=None, group=None, name=None, rank=0): @@ -481,14 +482,41 @@ def mlp_policy_fn(module): # Main function, run on each process def fsdp_main(local_rank:int, world_size:int, args:Dict): - profiler_context = profile( - activities=[ProfilerActivity.CPU, ProfilerActivity.CUDA], - with_stack=True, - use_cuda=True, - record_shapes=False, - experimental_config=torch._C._profiler._ExperimentalConfig(verbose=True) # https://github.com/pytorch/pytorch/issues/100253 - ) if args["profiling_output"] else nullcontext() - + #Profiler args + if args["profile"]: + #Profiler args + with_stack = False if args["export_trace"] else args["with_stack"] #See https://github.com/pytorch/pytorch/issues/121219 + with_shapes = args["with_shapes"] + profile_memory = args["export_memory_timeline"] + export_trace = args["export_trace"] + export_memory_timeline = args["export_memory_timeline"] + model_name = args["model_name"].split("/")[-1] + train_type = args["train_type"] + prefix = f"{model_name}_{train_type}" + output_dir = args["profiling_output"] if args["profiling_output"] else f"./{model_name}_{train_type}-{local_rank}" + schedule = torch.profiler.schedule(wait=args["wait_steps"], warmup=args["warmup_steps"], active=args["active_steps"], repeat=1) + callback = functools.partial(trace_handler, + export_trace=export_trace, + export_memory_timeline=export_memory_timeline, + with_stack=with_stack, + group_by_stack=5 if with_stack else 0, + prefix=prefix, + out_dir=output_dir, + rank=local_rank, + verbose=args["verbose"]) + #Instantiate profiler + profiler_context = profile( + activities=[ProfilerActivity.CPU, ProfilerActivity.CUDA], + with_stack=with_stack, + profile_memory=profile_memory, + use_cuda=True, + record_shapes=with_shapes, + schedule=schedule, + on_trace_ready=callback, + experimental_config=torch._C._profiler._ExperimentalConfig(verbose=True) if with_stack else None, # https://github.com/pytorch/pytorch/issues/100253 + ) + else: + profiler_context = FakeContext() with profiler_context as prof: # Setup and initialize the process group os.environ['MASTER_ADDR'] = args["master_addr"] @@ -826,6 +854,7 @@ def load_and_quantize_parallel(name_param, model, **kwargs): ddp_loss = torch.zeros(2).to(local_rank) for batch_idx, batch in enumerate(dataloader): + accumulate_grads = (batch_idx+1) % gradient_accumulation_steps == 0 # Prevent gradient syncing until update step if using no_sync option. @@ -925,6 +954,17 @@ def load_and_quantize_parallel(name_param, model, **kwargs): if args["log_to"] == 'wandb': logger.log({"loss": log_loss, "lr": log_lr}, rank) ddp_loss = torch.zeros(2).to(local_rank) + + if rank == 0: + print(f"Batch idx {batch_idx}") + + + prof.step() + + if batch_idx > args["max_steps"]: + if rank == 0: + print("Max steps reached, stopping training", rank) + break # Print + log peak memory usage for the whole fourth step of training if epoch == 0 and (rank == 0 or args['verbose']): @@ -999,9 +1039,9 @@ def load_and_quantize_parallel(name_param, model, **kwargs): # Clean up dist.destroy_process_group() - if args["profiling_output"]: - prof.export_stacks(path = f"{args['profiling_output']}_{local_rank}.txt", - metric = "self_cuda_time_total") + # if args["profiling_output"]: + # prof.export_stacks(path = f"{args['profiling_output']}_{local_rank}.txt", + # metric = "self_cuda_time_total") def validate_args(args): if args["n_bits"] != 4 and args["train_type"] not in ["hqq_lora", "hqq_dora", "hqq_llama_pro"]: @@ -1051,12 +1091,22 @@ def fsdp_qlora( group: str = None, # For wandb logging entity: str = None, # For wandb logging n_bits: int = 4, # passed to hqq - profiling_output: str = None, # Output file for profiling + profile: bool = False, # Enable PyTorch profiler + profiling_output: str = None, # Output directory for torch.profiler artifacts + with_stack: bool = False, # Output stacks for profiling + with_shapes: bool = False, # Output shapes for profiling + export_trace: bool = True, # Output trace for profiling + export_memory_timeline: bool = False, # Output memory timelinefor profiling + wait_steps: int = 1, # Wait steps when running profiler + warmup_steps: int = 0, # Warmup steps when running profiler + active_steps: int = 5, # Active steps when running profiler + max_steps: int = 10, # For debugging ): """ Train a model with FSDP and QLoRA/QDoRA. Args: + world_size: Number of GPUs to use. -1 = all available GPUs. train_type: "full", "lora", "qlora", or "custom_qlora" llama_pro_path: Path to the quantized llama pro model @@ -1186,6 +1236,15 @@ def main( group: str = None, # For wandb logging entity: str = None, # For wandb logging n_bits: int = 4, # passed to hqq + profile: bool_arg = False, # Whether to profile with torch.profiler profiling_output: str = "", # Output file prefix for profiling + with_stack: bool_arg = False, # Output stacks for profiling + with_shapes: bool_arg = False, # Output shapes for profiling + export_trace: bool_arg = True, # Output trace for profiling + export_memory_timeline: bool_arg = False, # Output memory timelinefor profiling + wait_steps: int = 1, # Wait steps when running profiler + warmup_steps: int = 1, # Warmup steps when running profiler + active_steps: int = 5, # Active steps when running profiler + max_steps: int = 10, # Max number of training steps (in units of batches), only for debugging when epochs is set to 1 ): fsdp_qlora(**locals()) From e0ce93f3967d2a010ec0bab1ce698e3e5d19bada Mon Sep 17 00:00:00 2001 From: jeromeku Date: Thu, 23 May 2024 19:03:34 +0000 Subject: [PATCH 02/16] change default settings --- profile.sh | 7 ++++--- train.py | 12 ++++++------ 2 files changed, 10 insertions(+), 9 deletions(-) diff --git a/profile.sh b/profile.sh index 09c379b..1e46132 100755 --- a/profile.sh +++ b/profile.sh @@ -28,7 +28,7 @@ # export CUDA_VISIBLE_DEVICES=3,4 python train.py \ --world_size -1 \ ---model_name "hf-internal-testing/tiny-random-LlamaForCausalLM" \ +--model_name "meta-llama/Llama-2-7b-hf" \ --gradient_accumulation_steps 2 \ --batch_size 1 \ --context_length 256 \ @@ -40,7 +40,8 @@ python train.py \ --use_cpu_offload false \ --log_to stdout \ --dataset dummy \ ---profile false \ ---with_stack true \ +--profile true \ +--export_trace true \ +--export_memory_timeline true \ --verbose true diff --git a/train.py b/train.py index 403c6df..9723784 100644 --- a/train.py +++ b/train.py @@ -485,11 +485,11 @@ def fsdp_main(local_rank:int, world_size:int, args:Dict): #Profiler args if args["profile"]: #Profiler args - with_stack = False if args["export_trace"] else args["with_stack"] #See https://github.com/pytorch/pytorch/issues/121219 - with_shapes = args["with_shapes"] profile_memory = args["export_memory_timeline"] export_trace = args["export_trace"] export_memory_timeline = args["export_memory_timeline"] + with_stack = args["with_stack"] or args["export_memory_timeline"]#False if args["export_trace"] else (args["with_stack"] or args["export_memory_timeline"]) #See https://github.com/pytorch/pytorch/issues/121219 + with_shapes = args["with_shapes"] or export_memory_timeline model_name = args["model_name"].split("/")[-1] train_type = args["train_type"] prefix = f"{model_name}_{train_type}" @@ -1237,14 +1237,14 @@ def main( entity: str = None, # For wandb logging n_bits: int = 4, # passed to hqq profile: bool_arg = False, # Whether to profile with torch.profiler - profiling_output: str = "", # Output file prefix for profiling - with_stack: bool_arg = False, # Output stacks for profiling - with_shapes: bool_arg = False, # Output shapes for profiling + profiling_output: str = "profiles", # Output file prefix for profiling + with_stack: bool_arg = False, # Output stacks for profiling. Note that setting export_memory_timeline will automatically export traces since `with_stack` must be true to profile memory. + with_shapes: bool_arg = False, # Output shapes for profiling. Note that setting export_memory_timeline will automatically export traces since `with_shapes` must be true to profile memory. export_trace: bool_arg = True, # Output trace for profiling export_memory_timeline: bool_arg = False, # Output memory timelinefor profiling wait_steps: int = 1, # Wait steps when running profiler warmup_steps: int = 1, # Warmup steps when running profiler - active_steps: int = 5, # Active steps when running profiler + active_steps: int = 2, # Active steps when running profiler max_steps: int = 10, # Max number of training steps (in units of batches), only for debugging when epochs is set to 1 ): fsdp_qlora(**locals()) From ca343091ba06e3305de0de8e22c1089434cb1860 Mon Sep 17 00:00:00 2001 From: jeromeku Date: Thu, 23 May 2024 20:29:11 +0000 Subject: [PATCH 03/16] update profile.sh --- profile.sh | 72 ++++++++++++++++++++++++++++++++++-------------------- 1 file changed, 45 insertions(+), 27 deletions(-) diff --git a/profile.sh b/profile.sh index 1e46132..aac6437 100755 --- a/profile.sh +++ b/profile.sh @@ -1,31 +1,49 @@ -# This generates uses export_stacks to generate profiling output -# /tmp/profile_0.txt, /tmp/profile_1.txt, etc. (1 file per process.) -# -# Output files are generated using export_stacks(), note there are some -# outstanding issues to be aware of: -# https://github.com/pytorch/pytorch/issues/100253 -# -# Profiling output files can be used with speedscope or other tools. -# -# For additional information, see: -# https://pytorch.org/tutorials/recipes/recipes/profiler_recipe.html +# Running below will result in a directory `Llama-2-7b_qlora-{local_rank}` with the following artifacts: +# - Llama-2-7b_qlora-chrome-trace.json.gz - interactive trace that can be viewed using `chrome::tracing` or `perfetto` +# - Llama-2-7b_qlora-key_averages.txt - sorted table of events, e.g.: +# | Name | Self CPU % | Self CPU | CPU total % | CPU total | CPU time avg | Self CUDA | Self CUDA % | CUDA total | CUDA time avg | CPU Mem | Self CPU Mem | CUDA Mem | Self CUDA Mem | # of Calls | Source Location | +# |---------------------------------------|------------|------------|-------------|------------|--------------|-------------|-------------|-------------|---------------|---------|--------------|----------|---------------|------------|----------------------------------------------------------------------------------| +# | ProfilerStep* | 0.00% | 0.000us | 0.00% | 0.000us | 0.000us | 4.816s | 44.60% | 4.816s | 963.233ms | 0 b | 0 b | 0 b | 0 b | 5 | | +# | | | | | | | | | | | | | | | | train.py(962): fsdp_main | +# | | | | | | | | | | | | | | | | torch/multiprocessing/spawn.py(75): _wrap | +# | | | | | | | | | | | | | | | | multiprocessing/process.py(108): run | +# | | | | | | | | | | | | | | | | multiprocessing/process.py(314): _bootstrap | +# | FullyShardedDataParallel.forward | 0.00% | 0.000us | 0.00% | 0.000us | 0.000us | 2.208s | 20.45% | 2.208s | 441.555ms | 0 b | 0 b | 0 b | 0 b | 5 | | +# | | | | | | | | | | | | | | | | torch/nn/functional.py(2154): embedding | +# | | | | | | | | | | | | | | | | torch/nn/modules/sparse.py(162): forward | +# | | | | | | | | | | | | | | | | torch/nn/modules/module.py(1534): _call_impl | +# | | | | | | | | | | | | | | | | nn.Module: Embedding_0 | +# | aten::mm | 0.44% | 31.314ms | 0.69% | 48.739ms | 43.517us | 332.421ms | 3.08% | 337.208ms | 301.079us | 0 b | 0 b | 3.26 Gb | 3.26 Gb | 1120 | | +# | | | | | | | | | | | | | | | | bitsandbytes/autograd/_functions.py(492): forward | +# | | | | | | | | | | | | | | | | | +# | | | | | | | | | | | | | | | | torch/autograd/function.py(582): apply | +# | | | | | | | | | | | | | | | | bitsandbytes/autograd/_functions.py(559): matmul_4bit | +# | MatMul4Bit | 2.81% | 198.511ms | 4.93% | 347.437ms | 310.212us | 284.169ms | 2.63% | 630.417ms | 562.872us | 0 b | 0 b | 3.26 Gb | -62.31 Gb | 1120 | | +# | | | | | | | | | | | | | | | | torch/autograd/function.py(582): apply | +# | | | | | | | | | | | | | | | | bitsandbytes/autograd/_functions.py(559): matmul_4bit | +# | | | | | | | | | | | | | | | | bitsandbytes/nn/modules.py(442): forward | +# | | | | | | | | | | | | | | | | torch/nn/modules/module.py(1534): _call_impl | + +# - Llama-2-7b_qlora-memory-timeline.html - Stacked time series plot of memory use broken down by `Parameter`, `Gradients`, `Activations`, etc. +# - Llama-2-7b_qlora-stacks.txt - Stack trace. See [docs](https://pytorch.org/docs/stable/profiler.html#torch.profiler._KinetoProfile.export_stacks). + +# Detailed `CLI` options: +# - `profile` - whether to profile +# - `profiling_outputs` - output directory for `torch.profiler` artifacts +# - `export_trace` - enables exporting of interactive trace that can be viewed and analyzed using `chrome::tracing` +# - `export_memory_timeline` - exports an HTML memory timeline which shows memory use by category (`parameters`, `activations`, `gradients`, etc.) +# - `with_stack` - exports stack trace +# - `with_shapes` - adds shapes of operators to the trace +# - `{wait, warmup, active}_steps` - controls how many profiling steps are recorded: +# - `wait_steps` - number of steps for the profiler to wait before starting to profile +# - `warmup_steps` - number of steps for profiler to profile without recording +# - `active_steps` - number of steps to record +# See [docs](https://pytorch.org/docs/stable/profiler.html#torch.profiler.schedule) for further details. + +# The default schedule for the profiler is set such that only 2 steps of the each epoch are recorded (not counting `wait` and `warmup` steps which are not recorded). + +# Note that `with_stack` and `with_shapes` are overridden by `export_memory_timeline` since the memory profile requires these options to be `True`. -# python train.py \ -# --model_name meta-llama/Meta-Llama-3-8B \ -# --train_type hqq_dora \ -# --n_bits 4 \ -# --precision bf16 \ -# --dataset orca_math \ -# --dataset_samples 8 \ -# --batch_size 2 \ -# --context_length 512 \ -# --gradient_accumulation_steps 2 \ -# --use_gradient_checkpointing False \ -# --use_cpu_offload False \ -# --use_activation_cpu_offload False \ -# --save_model False \ -# --profiling_output /tmp/profile -# export CUDA_VISIBLE_DEVICES=3,4 python train.py \ --world_size -1 \ --model_name "meta-llama/Llama-2-7b-hf" \ From 175cf8fbc302b9297a15e5d08e351c0c5ddb4ecb Mon Sep 17 00:00:00 2001 From: jeromeku Date: Thu, 23 May 2024 20:44:20 +0000 Subject: [PATCH 04/16] add usage notes to profile.sh --- profile.sh | 6 ++++-- train.py | 16 ++++++++++------ 2 files changed, 14 insertions(+), 8 deletions(-) diff --git a/profile.sh b/profile.sh index aac6437..d10fd79 100755 --- a/profile.sh +++ b/profile.sh @@ -44,6 +44,9 @@ # Note that `with_stack` and `with_shapes` are overridden by `export_memory_timeline` since the memory profile requires these options to be `True`. +#**IMPORTANT** There are issues with recording stack traces and exporting traces simultaneously (see this [issue](https://github.com/pytorch/pytorch/issues/113564)) depending on `python` version. The only combination I was able to get both to work at the same time was with `python=3.11.9` and `torch=2.3.0`. +#Tested on `python=3.11.9 and torch=2.3.0`` + python train.py \ --world_size -1 \ --model_name "meta-llama/Llama-2-7b-hf" \ @@ -61,5 +64,4 @@ python train.py \ --profile true \ --export_trace true \ --export_memory_timeline true \ ---verbose true - +--max_steps 10 diff --git a/train.py b/train.py index 9723784..f4d1a67 100644 --- a/train.py +++ b/train.py @@ -494,7 +494,7 @@ def fsdp_main(local_rank:int, world_size:int, args:Dict): train_type = args["train_type"] prefix = f"{model_name}_{train_type}" output_dir = args["profiling_output"] if args["profiling_output"] else f"./{model_name}_{train_type}-{local_rank}" - schedule = torch.profiler.schedule(wait=args["wait_steps"], warmup=args["warmup_steps"], active=args["active_steps"], repeat=1) + schedule = torch.profiler.schedule(wait=args["wait_steps"], warmup=args["warmup_steps"], active=args["active_steps"], repeat=args["repeat"]) callback = functools.partial(trace_handler, export_trace=export_trace, export_memory_timeline=export_memory_timeline, @@ -515,6 +515,9 @@ def fsdp_main(local_rank:int, world_size:int, args:Dict): on_trace_ready=callback, experimental_config=torch._C._profiler._ExperimentalConfig(verbose=True) if with_stack else None, # https://github.com/pytorch/pytorch/issues/100253 ) + if args["max_steps"] > 0: + #Ensure enough steps to accommodate profiler schedule + args["max_steps"] = max(args["max_steps"], args["wait_steps"] + args["warmup_steps"] + args["active_steps"]) else: profiler_context = FakeContext() with profiler_context as prof: @@ -961,9 +964,9 @@ def load_and_quantize_parallel(name_param, model, **kwargs): prof.step() - if batch_idx > args["max_steps"]: + if args["max_steps"] > 0 and batch_idx > args["max_steps"]: if rank == 0: - print("Max steps reached, stopping training", rank) + print("Max steps reached, skipping rest of epoch") break # Print + log peak memory usage for the whole fourth step of training @@ -1100,8 +1103,8 @@ def fsdp_qlora( wait_steps: int = 1, # Wait steps when running profiler warmup_steps: int = 0, # Warmup steps when running profiler active_steps: int = 5, # Active steps when running profiler - max_steps: int = 10, # For debugging - ): + repeat: int = 1, #Number of profiler cycles (wait + warmup + active) + max_steps: int = -1, # Max number of training steps (in units of batches) per epoch. -1 means no max_steps, otherwise training loop breaks after `max_steps` each epoch. ): """ Train a model with FSDP and QLoRA/QDoRA. @@ -1245,6 +1248,7 @@ def main( wait_steps: int = 1, # Wait steps when running profiler warmup_steps: int = 1, # Warmup steps when running profiler active_steps: int = 2, # Active steps when running profiler - max_steps: int = 10, # Max number of training steps (in units of batches), only for debugging when epochs is set to 1 + repeat: int = 1, #Number of profiler cycles (wait + warmup + active) + max_steps: int = -1, # Max number of training steps (in units of batches) per epoch. -1 means no max_steps, otherwise training loop breaks after `max_steps` each epoch. ): fsdp_qlora(**locals()) From 6e438fd66cf9e62f5bf50d73842f48e44033ab2a Mon Sep 17 00:00:00 2001 From: jeromeku Date: Thu, 23 May 2024 21:04:32 +0000 Subject: [PATCH 05/16] update profile.sh --- profile.sh | 1 - train.py | 3 ++- 2 files changed, 2 insertions(+), 2 deletions(-) diff --git a/profile.sh b/profile.sh index d10fd79..14a917d 100755 --- a/profile.sh +++ b/profile.sh @@ -48,7 +48,6 @@ #Tested on `python=3.11.9 and torch=2.3.0`` python train.py \ ---world_size -1 \ --model_name "meta-llama/Llama-2-7b-hf" \ --gradient_accumulation_steps 2 \ --batch_size 1 \ diff --git a/train.py b/train.py index f4d1a67..a3005eb 100644 --- a/train.py +++ b/train.py @@ -1104,7 +1104,8 @@ def fsdp_qlora( warmup_steps: int = 0, # Warmup steps when running profiler active_steps: int = 5, # Active steps when running profiler repeat: int = 1, #Number of profiler cycles (wait + warmup + active) - max_steps: int = -1, # Max number of training steps (in units of batches) per epoch. -1 means no max_steps, otherwise training loop breaks after `max_steps` each epoch. ): + max_steps: int = -1, # Max number of training steps (in units of batches) per epoch. -1 means no max_steps, otherwise training loop breaks after `max_steps` each epoch. +): """ Train a model with FSDP and QLoRA/QDoRA. From 512b3a3a890701701c9777bc47979eedfed3bfbb Mon Sep 17 00:00:00 2001 From: jeromeku Date: Thu, 23 May 2024 21:08:48 +0000 Subject: [PATCH 06/16] clean up comments --- profiling_tools.py | 4 ++++ train.py | 2 +- 2 files changed, 5 insertions(+), 1 deletion(-) diff --git a/profiling_tools.py b/profiling_tools.py index c4575c6..2e09161 100644 --- a/profiling_tools.py +++ b/profiling_tools.py @@ -49,6 +49,10 @@ def trace_handler( print(key_avgs) class FakeContext: + """ + Fake context when not using profiler with profiling script. + + """ def __enter__(self): return self def __exit__(self, *args, **kwargs): diff --git a/train.py b/train.py index a3005eb..7ae5fb1 100644 --- a/train.py +++ b/train.py @@ -488,7 +488,7 @@ def fsdp_main(local_rank:int, world_size:int, args:Dict): profile_memory = args["export_memory_timeline"] export_trace = args["export_trace"] export_memory_timeline = args["export_memory_timeline"] - with_stack = args["with_stack"] or args["export_memory_timeline"]#False if args["export_trace"] else (args["with_stack"] or args["export_memory_timeline"]) #See https://github.com/pytorch/pytorch/issues/121219 + with_stack = args["with_stack"] or args["export_memory_timeline"] with_shapes = args["with_shapes"] or export_memory_timeline model_name = args["model_name"].split("/")[-1] train_type = args["train_type"] From 9b29cfd83f79908dd43aae84275f0d52ad36789c Mon Sep 17 00:00:00 2001 From: Ubuntu Date: Wed, 29 May 2024 20:58:57 +0000 Subject: [PATCH 07/16] refactor profiling tools --- PROFILING.md | 61 ++++ profiling_utils.py | 125 +++++++ train.py | 786 +++++++++++++++++++++------------------------ 3 files changed, 560 insertions(+), 412 deletions(-) create mode 100644 PROFILING.md create mode 100644 profiling_utils.py diff --git a/PROFILING.md b/PROFILING.md new file mode 100644 index 0000000..59682b5 --- /dev/null +++ b/PROFILING.md @@ -0,0 +1,61 @@ +## Profiling + +## Usage + +**IMPORTANT** +There are issues with recording stack traces and exporting traces simultaneously (see this [issue](https://github.com/pytorch/pytorch/issues/113564)) depending on `python` version. The only combination I was able to get both to work at the same time was with `python=3.11.9` and `torch=2.3.0`. + +Running the following: + +``` +python train.py \ +--model_name "meta-llama/Llama-2-7b-hf" \ +--train_type qlora \ +--profile true \ +--export_trace true \ +--export_memory_timeline true \ +--max_steps 10 +``` + +will result in a directory `{model_name}_{train_type}-{local_rank}` with the following artifacts: + +- `{model_name}-{train_type}-chrome-trace.json.gz` - interactive trace that can be viewed using `chrome::tracing` or `perfetto` +- `{model_name}-{train_type}-key_averages.txt` - sorted table of events, e.g.: + +``` + +``` + +- `{model_name}-{train_type}-memory-timeline.html` - Stacked time series plot of memory use broken down by `Parameter`, `Gradients`, `Activations`, etc. +- `{model_name}-{train_type}-stacks.txt` - Stack trace. See [docs](https://pytorch.org/docs/stable/profiler.html#torch.profiler._KinetoProfile.export_stacks). + +Detailed `CLI` options: + +- `profile` - whether to profile +- `profiling_outputs` - output directory for `torch.profiler` artifacts +- `export_trace` - enables exporting of interactive trace that can be viewed and analyzed using `chrome::tracing` +- `export_memory_timeline` - exports an HTML memory timeline which shows memory use by category (`parameters`, `activations`, `gradients`, etc.) +- `with_stack` - exports stack trace +- `with_shapes` - adds shapes of operators to the trace +- `{wait, warmup, active}_steps, repeat, profiling_frequency` - controls the profiling schedule: + + - `wait_steps` - number of steps for the profiler to wait before starting to profile. Overridden if `repeat=0` (see note below). + - `warmup_steps` - number of steps for profiler to profile without recording + - `active_steps` - number of steps to record + - `repeat` - number of times to repeat the above cycle of `wait, warmup, active` if `repeat > 0` else cycles forever + - `profiling_frequency` - profiling frequency in steps. Only used if `repeat = 0`, in which case `wait_steps = profiling_frequency - (warmup_steps + active_steps)` such that the effective cycle length = `profiling_frequency`. E.g., if `profiling_frequency=10`, `warmup_steps=2`, `active_steps=1`, then the profiler will wait 8 steps, warmup for 2, record for 1, then repeat. + + **Note**: Simplest to think of 2 ways of scheduling the profiler: + + 1. Set `repeat` to the number of total number of desired profiling cycles. For example if `wait=1`, `warmup=1`, `active=1`, and `repeat=1`, then the profiler will wait for 1 step, warmup for 1, and record for 1 then stop. + 2. Set `repeat` to `0` and `profiling_frequency` to the cycle length. E.g., with `repeat=0`, `profiling_frequency=10`, `warmup=2`, `active=1`, then `wait` will be automatically set to `profiling_frequency - (warmup + active) = 7`. The profiler will then continuously execute the following cycle: wait for 7 steps, warmup for 2, record for 1. + + See [docs](https://pytorch.org/docs/stable/profiler.html#torch.profiler.schedule) for further details. + +- `max_steps` - maximum number of batches per epoch. E.g., with `num_epochs=1`, stops training after `max_steps` of batches. Note that this is automatically adjusted to accommodate the profiler schedule; for example, if `max_steps < wait_steps + warmup_steps + active_steps`, it will automatically be set to `wait_steps + warmup_steps + active_steps` such that the profiler can run for at least 1 cycle. + +#### Additional Notes + +The default schedule for the profiler is set such that to continuously execute a 10-step cycle: wait for 7, warmup for 2, record for 1. + +`with_stack` and `with_shapes` are overridden by `export_memory_timeline` since the memory profile requires these options to be `True`. diff --git a/profiling_utils.py b/profiling_utils.py new file mode 100644 index 0000000..cabf1b7 --- /dev/null +++ b/profiling_utils.py @@ -0,0 +1,125 @@ +# Copyright (c) Meta Platforms, Inc. and affiliates. +# All rights reserved. +# +# This source code is licensed under the BSD-style license found in the +# LICENSE file in the root directory of this source tree. + +import contextlib +import os +import time +import logging +import torch +import torch.distributed +from functools import partial +WARMUP = 3 + +logger = logging.getLogger() + +#adapted from https://github.com/pytorch/torchtitan + +def trace_handler(prof, rank, export_memory_timeline, output_dir, metric="self_cuda_time_total", with_stack=True, group_by_stack=0, group_by_input_shapes=False, row_limit=25): + curr_trace_dir_name = str(prof.step_num) + curr_trace_dir = os.path.join(output_dir, curr_trace_dir_name) + if not os.path.exists(curr_trace_dir): + os.makedirs(curr_trace_dir, exist_ok=True) + + #Export chrome / tensorboard trace + logger.info(f"Dumping traces at step {prof.step_num}") + begin = time.monotonic() + prof.export_chrome_trace(f"{curr_trace_dir}/rank{rank}_trace.json") + logger.info( + f"Finished dumping traces in {time.monotonic() - begin:.2f} seconds" + ) + + #Construct the memory timeline file. + if export_memory_timeline: + prof.export_memory_timeline( + f"{curr_trace_dir}/rank{rank}_memory-timeline.html" + ) + + #Dump stack traces + if with_stack: + prof.export_stacks(f"{curr_trace_dir}/rank{rank}_stacks.txt", metric=metric) + + #Export event averages + key_avgs = prof.key_averages( + group_by_input_shape=group_by_input_shapes, group_by_stack_n=group_by_stack + ).table(sort_by=metric, row_limit=row_limit) + with open(f"{curr_trace_dir}/rank{rank}_key_averages.txt", "w") as f: + print( + key_avgs, file=f + ) + if rank == 0: + print(f"Saving profiling results to {curr_trace_dir}") + + #TODO: Is this necessary? + torch.distributed.barrier() + +@contextlib.contextmanager +def profiling_context(args, rank, *, global_step: int = 0): + enable_profiling = args.profile + + if enable_profiling: + + logger.info(f"Profiling enabled. Traces will be saved at {output_dir}") + + if not os.path.exists(output_dir): + os.makedirs(output_dir, exist_ok=True) + + warmup = args["warmup_steps"] + active = args["active_steps"] + repeat = args["repeat"] + + if repeat == 0: + steps_per_cycle = args["profiling_frequency"] + wait = steps_per_cycle - (active + warmup) + else: + wait = args["wait_steps"] + steps_per_cycle = wait + warmup + active + assert ( + wait >= 0 + ), "profile_freq must be greater than or equal to warmup + active" + logger.info(f"Profiler schedule - steps per cycle: {steps_per_cycle} wait: {wait} warmup: {warmup} active: {active} repeat: {repeat if repeat !=0 else 'inf'}") + + profile_memory = args["export_memory_timeline"] + export_memory_timeline = args["export_memory_timeline"] + with_stack = args["with_stack"] or args["export_memory_timeline"] + with_shapes = args["with_shapes"] or export_memory_timeline + model_name = args["model_name"].split("/")[-1] + train_type = args["train_type"] + output_dir = args["profiling_output"] if args["profiling_output"] else f"./{model_name}_{train_type}" + callback = partial(trace_handler, rank=rank, + export_memory_timeline=export_memory_timeline, + output_dir=output_dir, + with_stack=with_stack, + group_by_input_shape=with_shapes, + group_by_stack=5 if with_stack else 0) + + with torch.profiler.profile( + activities=[ + torch.profiler.ProfilerActivity.CPU, + torch.profiler.ProfilerActivity.CUDA, + ], + with_stack=with_stack, + profile_memory=profile_memory, + with_shapes=with_shapes, + schedule=torch.profiler.schedule(wait=wait, warmup=warmup, active=active, repeat=repeat), + on_trace_ready=callback, + experimental_config=torch._C._profiler._ExperimentalConfig(verbose=True) if with_stack else None, + ) as torch_profiler: + yield torch_profiler + else: + class FakeProfiler: + """ + Fake profiler object when profiling is not enabled. + + """ + def __enter__(self): + return self + def __exit__(self, *args, **kwargs): + pass + + def step(self): + pass + + yield FakeProfiler() diff --git a/train.py b/train.py index 7ae5fb1..40dd51b 100644 --- a/train.py +++ b/train.py @@ -86,7 +86,7 @@ sys.path.append("./scripts") from lora import LORA from dora import BNBDORA, HQQDORA, DORALayer, MagnitudeLayer -from profiling_tools import trace_handler, FakeContext +from profiling_utils import profiling_context class Logger: def __init__(self, args, log_to="stdout", project_name="fsdp_qlora", entity=None, group=None, name=None, rank=0): @@ -482,382 +482,346 @@ def mlp_policy_fn(module): # Main function, run on each process def fsdp_main(local_rank:int, world_size:int, args:Dict): - #Profiler args - if args["profile"]: - #Profiler args - profile_memory = args["export_memory_timeline"] - export_trace = args["export_trace"] - export_memory_timeline = args["export_memory_timeline"] - with_stack = args["with_stack"] or args["export_memory_timeline"] - with_shapes = args["with_shapes"] or export_memory_timeline - model_name = args["model_name"].split("/")[-1] - train_type = args["train_type"] - prefix = f"{model_name}_{train_type}" - output_dir = args["profiling_output"] if args["profiling_output"] else f"./{model_name}_{train_type}-{local_rank}" - schedule = torch.profiler.schedule(wait=args["wait_steps"], warmup=args["warmup_steps"], active=args["active_steps"], repeat=args["repeat"]) - callback = functools.partial(trace_handler, - export_trace=export_trace, - export_memory_timeline=export_memory_timeline, - with_stack=with_stack, - group_by_stack=5 if with_stack else 0, - prefix=prefix, - out_dir=output_dir, - rank=local_rank, - verbose=args["verbose"]) - #Instantiate profiler - profiler_context = profile( - activities=[ProfilerActivity.CPU, ProfilerActivity.CUDA], - with_stack=with_stack, - profile_memory=profile_memory, - use_cuda=True, - record_shapes=with_shapes, - schedule=schedule, - on_trace_ready=callback, - experimental_config=torch._C._profiler._ExperimentalConfig(verbose=True) if with_stack else None, # https://github.com/pytorch/pytorch/issues/100253 - ) - if args["max_steps"] > 0: - #Ensure enough steps to accommodate profiler schedule - args["max_steps"] = max(args["max_steps"], args["wait_steps"] + args["warmup_steps"] + args["active_steps"]) + + # Setup and initialize the process group + os.environ['MASTER_ADDR'] = args["master_addr"] + os.environ['MASTER_PORT'] = args["master_port"] + if 'SLURM_PROCID' in os.environ: + # assumes same number of GPUs per node. + rank = int(os.environ['SLURM_PROCID']) * torch.cuda.device_count() + local_rank else: - profiler_context = FakeContext() - with profiler_context as prof: - # Setup and initialize the process group - os.environ['MASTER_ADDR'] = args["master_addr"] - os.environ['MASTER_PORT'] = args["master_port"] - if 'SLURM_PROCID' in os.environ: - # assumes same number of GPUs per node. - rank = int(os.environ['SLURM_PROCID']) * torch.cuda.device_count() + local_rank - else: - rank = local_rank - - dist.init_process_group("nccl", rank=rank, world_size=world_size) - torch.cuda.set_device(local_rank) - if args["use_cpu_offload"]: - torch.set_num_threads(os.cpu_count()//(min(world_size, torch.cuda.device_count()))) - - # Start logging - logger = Logger(args, log_to=args["log_to"], project_name=args["project_name"], - entity=args["entity"], group=args["group"], name=args["name"], rank=rank) - - # Timing stuff - init_start_event = torch.cuda.Event(enable_timing=True) - init_end_event = torch.cuda.Event(enable_timing=True) - - # model precision, qlora compute precison, and FSDP mixed precision policy. - # The Linear4Bit quant_storage dtype should always match the FSDP param_dtype. The compute_dtype should match the AMP compute dtype. - # MixedPrecision(param_dtype=fp32, reduce_dtype=fp32, buffer_dtype=fp32) uses `torch.amp.autocast` to control precision. - # limited qlora testing shows that fp16 only works with autocast while bf16 trains with both pure and autocast modes. - # TODO: test how often this holds for mp_fp16 - mp_policy = None - load_param_skip_names = [] - if args["precision"] == "bf16": - torch_dtype, compute_dtype = torch.bfloat16, torch.bfloat16 - elif args["precision"] == "fp32": - torch_dtype, compute_dtype = torch.float32, torch.float16 - elif args["precision"] == "fp16_autocast": - compute_dtype, torch_dtype = torch.float16, torch.float32 - mp_policy = MixedPrecision(param_dtype=torch.float32, reduce_dtype=torch.float32, buffer_dtype=torch.float32) - elif args["precision"] == "bf16_autocast": - compute_dtype, torch_dtype = torch.bfloat16, torch.float32 - mp_policy = MixedPrecision(param_dtype=torch.float32, reduce_dtype=torch.float32, buffer_dtype=torch.float32) - elif args["precision"] == "bf16_buffers_autocast": - compute_dtype, torch_dtype = torch.bfloat16, torch.bfloat16 - mp_policy = MixedPrecision(param_dtype=torch.bfloat16, reduce_dtype=torch.bfloat16, buffer_dtype=torch.float32) - load_param_skip_names = ['inv_freq'] - else: - raise ValueError("Invalid precision") + rank = local_rank + + dist.init_process_group("nccl", rank=rank, world_size=world_size) + torch.cuda.set_device(local_rank) + if args["use_cpu_offload"]: + torch.set_num_threads(os.cpu_count()//(min(world_size, torch.cuda.device_count()))) + + # Start logging + logger = Logger(args, log_to=args["log_to"], project_name=args["project_name"], + entity=args["entity"], group=args["group"], name=args["name"], rank=rank) + + # Timing stuff + init_start_event = torch.cuda.Event(enable_timing=True) + init_end_event = torch.cuda.Event(enable_timing=True) + + # model precision, qlora compute precison, and FSDP mixed precision policy. + # The Linear4Bit quant_storage dtype should always match the FSDP param_dtype. The compute_dtype should match the AMP compute dtype. + # MixedPrecision(param_dtype=fp32, reduce_dtype=fp32, buffer_dtype=fp32) uses `torch.amp.autocast` to control precision. + # limited qlora testing shows that fp16 only works with autocast while bf16 trains with both pure and autocast modes. + # TODO: test how often this holds for mp_fp16 + mp_policy = None + load_param_skip_names = [] + if args["precision"] == "bf16": + torch_dtype, compute_dtype = torch.bfloat16, torch.bfloat16 + elif args["precision"] == "fp32": + torch_dtype, compute_dtype = torch.float32, torch.float16 + elif args["precision"] == "fp16_autocast": + compute_dtype, torch_dtype = torch.float16, torch.float32 + mp_policy = MixedPrecision(param_dtype=torch.float32, reduce_dtype=torch.float32, buffer_dtype=torch.float32) + elif args["precision"] == "bf16_autocast": + compute_dtype, torch_dtype = torch.bfloat16, torch.float32 + mp_policy = MixedPrecision(param_dtype=torch.float32, reduce_dtype=torch.float32, buffer_dtype=torch.float32) + elif args["precision"] == "bf16_buffers_autocast": + compute_dtype, torch_dtype = torch.bfloat16, torch.bfloat16 + mp_policy = MixedPrecision(param_dtype=torch.bfloat16, reduce_dtype=torch.bfloat16, buffer_dtype=torch.float32) + load_param_skip_names = ['inv_freq'] + else: + raise ValueError("Invalid precision") - # Load tokenizer - tokenizer = AutoTokenizer.from_pretrained(args["model_name"]) - tokenizer.pad_token_id = tokenizer.eos_token_id # TODO check if it exists first + # Load tokenizer + tokenizer = AutoTokenizer.from_pretrained(args["model_name"]) + tokenizer.pad_token_id = tokenizer.eos_token_id # TODO check if it exists first - # Set up dataloader - dataloader = get_dataloader(tokenizer, args) + # Set up dataloader + dataloader = get_dataloader(tokenizer, args) - # Create model - cfg = None - attn_impl = "sdpa" # torch 2.2 sdpa uses flash attn 2 - if rank == 0 or args['verbose']: - print("Creating model", rank) - if args["train_type"] in ["full", "lora", "custom_lora"]: - if (args["low_memory"] and rank == 0) or (not args["low_memory"]): - model = AutoModelForCausalLM.from_pretrained( - args["model_name"], - use_cache=False, - torch_dtype=torch_dtype, - _attn_implementation=attn_impl - ) - dtype = torch_dtype if args["precision"] == "bf16" else None - model.to(dtype=dtype, device="cpu" if args["low_memory"] else rank) - else: - cfg = AutoConfig.from_pretrained(args["model_name"]) - cfg.use_cache = False - cfg._attn_implementation = attn_impl - with init_empty_weights(): - model = AutoModelForCausalLM.from_config(cfg, torch_dtype=torch_dtype) - if args["precision"] == "bf16": - model.to(torch_dtype) - elif args["train_type"] in ["qlora", "custom_qlora", "hqq_lora", "hqq_dora", "bnb_dora", "bnb_llama_pro", "hqq_llama_pro"]: # Our custom loading + # Create model + cfg = None + attn_impl = "sdpa" # torch 2.2 sdpa uses flash attn 2 + if rank == 0 or args['verbose']: + print("Creating model", rank) + if args["train_type"] in ["full", "lora", "custom_lora"]: + if (args["low_memory"] and rank == 0) or (not args["low_memory"]): + model = AutoModelForCausalLM.from_pretrained( + args["model_name"], + use_cache=False, + torch_dtype=torch_dtype, + _attn_implementation=attn_impl + ) + dtype = torch_dtype if args["precision"] == "bf16" else None + model.to(dtype=dtype, device="cpu" if args["low_memory"] else rank) + else: cfg = AutoConfig.from_pretrained(args["model_name"]) cfg.use_cache = False cfg._attn_implementation = attn_impl - skip_modules = ["lm_head"] - - if args["train_type"] in ["bnb_llama_pro", "hqq_llama_pro"]: - llama_pro_path = Path(args["llama_pro_path"]) - num_original_layers, num_expanded_layers = llama_pro_path.name.split("blk_exp-")[1].split("-") - num_original_layers, num_expanded_layers = int(num_original_layers), int(num_expanded_layers) - total_new_layers = num_expanded_layers - num_original_layers - split = int(num_original_layers / (num_expanded_layers - num_original_layers)) - new_layer_ids = [split+(split+1)*n for n in range(total_new_layers)] - new_layer_names = [f"layers.{i}" for i in new_layer_ids] - skip_modules += [str(lid) for lid in new_layer_ids] - cfg.num_hidden_layers = num_expanded_layers - - # load model on meta device without calling init and replace nn.Linear with Linear4bit with init_empty_weights(): - model = AutoModelForCausalLM.from_config(cfg) - if args["train_type"] in ["hqq_lora", "hqq_dora", "hqq_llama_pro"]: - # TODO: Tune BaseQuantizeConfig. - quant_config = BaseQuantizeConfig(nbits=int(args["n_bits"]), group_size=64, quant_zero=True, - quant_scale=True, offload_meta=True, view_as_float=True) - model.model = replace_linear(model.model, HQQLinear, quant_config, device=rank, - compute_dtype=compute_dtype, del_orig=True, initialize=False, skip_modules=skip_modules) - HQQLinear.set_backend(HQQBackend.ATEN_BACKPROP) - else: - model.model = replace_linear(model.model, Linear4bit, compute_dtype=compute_dtype, - quant_type='nf4', quant_storage=torch_dtype, skip_modules=skip_modules) - model.is_loaded_in_4bit = True - - # Grab the safetensors files that hold the weights - if args["train_type"] in ["bnb_llama_pro", "hqq_llama_pro"]: - files = glob(str(llama_pro_path/"*.safetensors")) + model = AutoModelForCausalLM.from_config(cfg, torch_dtype=torch_dtype) + if args["precision"] == "bf16": + model.to(torch_dtype) + elif args["train_type"] in ["qlora", "custom_qlora", "hqq_lora", "hqq_dora", "bnb_dora", "bnb_llama_pro", "hqq_llama_pro"]: # Our custom loading + cfg = AutoConfig.from_pretrained(args["model_name"]) + cfg.use_cache = False + cfg._attn_implementation = attn_impl + skip_modules = ["lm_head"] + + if args["train_type"] in ["bnb_llama_pro", "hqq_llama_pro"]: + llama_pro_path = Path(args["llama_pro_path"]) + num_original_layers, num_expanded_layers = llama_pro_path.name.split("blk_exp-")[1].split("-") + num_original_layers, num_expanded_layers = int(num_original_layers), int(num_expanded_layers) + total_new_layers = num_expanded_layers - num_original_layers + split = int(num_original_layers / (num_expanded_layers - num_original_layers)) + new_layer_ids = [split+(split+1)*n for n in range(total_new_layers)] + new_layer_names = [f"layers.{i}" for i in new_layer_ids] + skip_modules += [str(lid) for lid in new_layer_ids] + cfg.num_hidden_layers = num_expanded_layers + + # load model on meta device without calling init and replace nn.Linear with Linear4bit + with init_empty_weights(): + model = AutoModelForCausalLM.from_config(cfg) + if args["train_type"] in ["hqq_lora", "hqq_dora", "hqq_llama_pro"]: + # TODO: Tune BaseQuantizeConfig. + quant_config = BaseQuantizeConfig(nbits=int(args["n_bits"]), group_size=64, quant_zero=True, + quant_scale=True, offload_meta=True, view_as_float=True) + model.model = replace_linear(model.model, HQQLinear, quant_config, device=rank, + compute_dtype=compute_dtype, del_orig=True, initialize=False, skip_modules=skip_modules) + HQQLinear.set_backend(HQQBackend.ATEN_BACKPROP) else: + model.model = replace_linear(model.model, Linear4bit, compute_dtype=compute_dtype, + quant_type='nf4', quant_storage=torch_dtype, skip_modules=skip_modules) + model.is_loaded_in_4bit = True + + # Grab the safetensors files that hold the weights + if args["train_type"] in ["bnb_llama_pro", "hqq_llama_pro"]: + files = glob(str(llama_pro_path/"*.safetensors")) + else: + try: + idx = hub.cached_file(args["model_name"], SAFE_WEIGHTS_INDEX_NAME) + files, _ = hub.get_checkpoint_shard_files(args["model_name"], idx) + except OSError: try: - idx = hub.cached_file(args["model_name"], SAFE_WEIGHTS_INDEX_NAME) - files, _ = hub.get_checkpoint_shard_files(args["model_name"], idx) - except OSError: - try: - # This means the model doesn't have a model.safetensors.index.json because it is not sharded - files = [] - files.append(hub.cached_file(args["model_name"], SAFE_WEIGHTS_NAME)) - except OSError as e: - # This means the model probably doesn't have a safetensors file - raise e - - # Load in the weights, using our custom load_and_quantize method which quantizes Params4bit on the fly - # and then places each layer on CPU or meta if using low_memory to minimize GPU memory usage - def load_and_quantize_parallel(name_param, model, **kwargs): - name, param = name_param - load_and_quantize(model, name, param, **kwargs) - - quant_method = "hqq" if args["train_type"] in ["hqq_lora", "hqq_dora", "hqq_llama_pro"] else "bnb" - param_count = sum((p.numel() for n,p in model.named_parameters())) - if rank == 0 or args['verbose']: - print("Loading model", rank) - if rank == 0 and args['verbose']: - print(f"Total model params: {param_count}") - - n_workers = n_loading_workers(quant_method, param_count) if args["loading_workers"]==-1 else args["loading_workers"] - if rank == 0 and args['verbose']: - print(f"Using n_workers: {n_workers} for loading") - - start = time.time() - for filename in tqdm(files, desc="Loading & Quantizing Model Shards", disable=rank!=0, position=0): - weights = safetensors.torch.load_file(filename) - parallel(load_and_quantize_parallel, iter(weights.items()), n_workers=n_workers, threadpool=True, - model=model, dtype=torch_dtype, device=local_rank, skip_names=load_param_skip_names, - to_cpu=(args["low_memory"] and rank==0), to_meta=(args["low_memory"] and rank!=0), - verbose=args["verbose"], quant_method=quant_method, is_dora=(args["train_type"] in ["hqq_dora", "bnb_dora"])) - - if rank == 0 and args["verbose"]: - print(f"Loaded model weights in {time.time()-start:.3f} seconds") - # cleanup any extra memory usage from parallel loading - torch.cuda.empty_cache() + # This means the model doesn't have a model.safetensors.index.json because it is not sharded + files = [] + files.append(hub.cached_file(args["model_name"], SAFE_WEIGHTS_NAME)) + except OSError as e: + # This means the model probably doesn't have a safetensors file + raise e + + # Load in the weights, using our custom load_and_quantize method which quantizes Params4bit on the fly + # and then places each layer on CPU or meta if using low_memory to minimize GPU memory usage + def load_and_quantize_parallel(name_param, model, **kwargs): + name, param = name_param + load_and_quantize(model, name, param, **kwargs) + + quant_method = "hqq" if args["train_type"] in ["hqq_lora", "hqq_dora", "hqq_llama_pro"] else "bnb" + param_count = sum((p.numel() for n,p in model.named_parameters())) if rank == 0 or args['verbose']: - print(f"Rank {rank}: Model created: {torch.cuda.memory_reserved(local_rank)/2**30:.3f} GiB") - - - # PEFT setup (LoRA and QLoRA) - if args["train_type"] in ["lora", "qlora"]: - from peft import get_peft_model, LoraConfig, TaskType - - peft_config = LoraConfig( - task_type=TaskType.CAUSAL_LM, inference_mode=False, - r=args["lora_rank"], - lora_alpha=args["lora_alpha"], - lora_dropout=args["lora_dropout"], - target_modules=args["lora_target_modules"], - ) - # PEFT will move quant_state to meta device, so this method prevents that - # from happening by replacing quant_state.to with a dummy function - if rank!=0 and args["low_memory"]: - setup_quantized_meta_for_peft(model) - - model = get_peft_model(model, peft_config) - - if rank==0: - model.print_trainable_parameters() - elif args['low_memory']: - # And then setup_quantized_peft_meta_for_training sets quant_state.to back to normal - setup_quantized_peft_meta_for_training(model) - elif args["train_type"] in ["custom_qlora", "custom_lora", "hqq_lora", "hqq_dora", "bnb_dora"]: - if args["train_type"] == "hqq_dora": - print("Using HQQDORA", rank) - lora_cls = HQQDORA - elif args["train_type"] == "bnb_dora": - print("Using BNB DORA", rank) - lora_cls = BNBDORA - else: - print("Using LORA", rank) - lora_cls = LORA - # Create LORA layers. - for name, _ in model.named_modules(): - module_key, _, value_key = name.rpartition('.') - if value_key in args['lora_target_modules']: - m = model.get_submodule(name) - qlora_layer = lora_cls(m, args["lora_rank"], args["lora_alpha"], args["lora_dropout"]) - parent_module = model.get_submodule(module_key) - setattr(parent_module, value_key, qlora_layer) - for n,p in model.named_parameters(): - if any([lora_name in n for lora_name in ['lora_AB', 'lora_A', 'lora_B', 'magnitude']]): - p.requires_grad = True - if args['verbose']: - print("Trainable LORA layer", n) - else: - p.requires_grad = False - if rank == 0 or args['verbose']: - print(f"Rank {rank}: LoRA layers added: {torch.cuda.memory_reserved(local_rank)/2**30:.3f} GiB") - - elif args["train_type"] in ["bnb_llama_pro", "hqq_llama_pro"]: - for n,p in model.named_parameters(): - if any([layer_name in n for layer_name in new_layer_names]): - p.requires_grad = True - if args['verbose']: - print("Trainable Llama-Pro layer", n) - else: - p.requires_grad = False - - if args["log_to"] == 'wandb': - logger.log({"memory/allocated_after_model_created": torch.cuda.memory_allocated(local_rank)}, rank) - logger.log({"memory/reserved_after_model_creation": torch.cuda.memory_reserved(local_rank)}, rank) - - - # Wrap model with llama-recipies or custom LoRA policy - my_auto_wrap_policy = get_wrapping_policy(custom_policy=args["train_type"] in ["custom_qlora", "hqq_lora", "hqq_dora", "bnb_dora"], - vanilla_policy=args["train_type"] in ["full", "bnb_llama_pro", "hqq_llama_pro"]) + print("Loading model", rank) + if rank == 0 and args['verbose']: + print(f"Total model params: {param_count}") - if rank == 0 or args['verbose']: - print("Wrapping model w/ FSDP", rank) - - if args["sharding_strategy"] == "full_shard": - sharding_strategy = ShardingStrategy.FULL_SHARD - elif args["sharding_strategy"] == "shard_grad_op": - sharding_strategy = ShardingStrategy.SHARD_GRAD_OP - elif args["sharding_strategy"] == "ddp": - sharding_strategy = ShardingStrategy.NO_SHARD - elif args["sharding_strategy"] == "hybrid_full_shard": - sharding_strategy = ShardingStrategy.HYBRID_SHARD - elif args["sharding_strategy"] == "hybrid_shard_grad_op": - sharding_strategy = ShardingStrategy._HYBRID_SHARD_ZERO2 - else: - raise ValueError("Invalid FSDP sharding strategy") - - model = FSDP( - model, - sharding_strategy=sharding_strategy, - auto_wrap_policy=my_auto_wrap_policy, - # backward_prefetch=None, #BackwardPrefetch.BACKWARD_PRE - use_orig_params=False, - cpu_offload=CPUOffload(offload_params=True) if args["use_cpu_offload"] else None, - limit_all_gathers=True, # See https://github.com/pytorch/pytorch/issues/91165 - device_id=torch.cuda.current_device(), - sync_module_states=args["low_memory"], - param_init_fn=lambda module: module.to_empty(device=torch.device("cuda"), recurse=False) - if (rank!=0 and args["low_memory"]) else None, # TODO note about meta device and why we need this - mixed_precision=mp_policy, + n_workers = n_loading_workers(quant_method, param_count) if args["loading_workers"]==-1 else args["loading_workers"] + if rank == 0 and args['verbose']: + print(f"Using n_workers: {n_workers} for loading") + + start = time.time() + for filename in tqdm(files, desc="Loading & Quantizing Model Shards", disable=rank!=0, position=0): + weights = safetensors.torch.load_file(filename) + parallel(load_and_quantize_parallel, iter(weights.items()), n_workers=n_workers, threadpool=True, + model=model, dtype=torch_dtype, device=local_rank, skip_names=load_param_skip_names, + to_cpu=(args["low_memory"] and rank==0), to_meta=(args["low_memory"] and rank!=0), + verbose=args["verbose"], quant_method=quant_method, is_dora=(args["train_type"] in ["hqq_dora", "bnb_dora"])) + + if rank == 0 and args["verbose"]: + print(f"Loaded model weights in {time.time()-start:.3f} seconds") + # cleanup any extra memory usage from parallel loading + torch.cuda.empty_cache() + if rank == 0 or args['verbose']: + print(f"Rank {rank}: Model created: {torch.cuda.memory_reserved(local_rank)/2**30:.3f} GiB") + + + # PEFT setup (LoRA and QLoRA) + if args["train_type"] in ["lora", "qlora"]: + from peft import get_peft_model, LoraConfig, TaskType + + peft_config = LoraConfig( + task_type=TaskType.CAUSAL_LM, inference_mode=False, + r=args["lora_rank"], + lora_alpha=args["lora_alpha"], + lora_dropout=args["lora_dropout"], + target_modules=args["lora_target_modules"], ) + # PEFT will move quant_state to meta device, so this method prevents that + # from happening by replacing quant_state.to with a dummy function + if rank!=0 and args["low_memory"]: + setup_quantized_meta_for_peft(model) + + model = get_peft_model(model, peft_config) + + if rank==0: + model.print_trainable_parameters() + elif args['low_memory']: + # And then setup_quantized_peft_meta_for_training sets quant_state.to back to normal + setup_quantized_peft_meta_for_training(model) + elif args["train_type"] in ["custom_qlora", "custom_lora", "hqq_lora", "hqq_dora", "bnb_dora"]: + if args["train_type"] == "hqq_dora": + print("Using HQQDORA", rank) + lora_cls = HQQDORA + elif args["train_type"] == "bnb_dora": + print("Using BNB DORA", rank) + lora_cls = BNBDORA + else: + print("Using LORA", rank) + lora_cls = LORA + # Create LORA layers. + for name, _ in model.named_modules(): + module_key, _, value_key = name.rpartition('.') + if value_key in args['lora_target_modules']: + m = model.get_submodule(name) + qlora_layer = lora_cls(m, args["lora_rank"], args["lora_alpha"], args["lora_dropout"]) + parent_module = model.get_submodule(module_key) + setattr(parent_module, value_key, qlora_layer) + for n,p in model.named_parameters(): + if any([lora_name in n for lora_name in ['lora_AB', 'lora_A', 'lora_B', 'magnitude']]): + p.requires_grad = True + if args['verbose']: + print("Trainable LORA layer", n) + else: + p.requires_grad = False if rank == 0 or args['verbose']: - print(f"Rank {rank}: Wrapped model: {torch.cuda.memory_reserved(local_rank)/2**30:.3f} GiB") - if args["log_to"] == 'wandb': - logger.log({"memory/allocated_after_model_wrap": torch.cuda.memory_allocated(local_rank)}, rank) - logger.log({"memory/reserved_after_model_wrap": torch.cuda.memory_reserved(local_rank)}, rank) + print(f"Rank {rank}: LoRA layers added: {torch.cuda.memory_reserved(local_rank)/2**30:.3f} GiB") + + elif args["train_type"] in ["bnb_llama_pro", "hqq_llama_pro"]: + for n,p in model.named_parameters(): + if any([layer_name in n for layer_name in new_layer_names]): + p.requires_grad = True + if args['verbose']: + print("Trainable Llama-Pro layer", n) + else: + p.requires_grad = False + + if args["log_to"] == 'wandb': + logger.log({"memory/allocated_after_model_created": torch.cuda.memory_allocated(local_rank)}, rank) + logger.log({"memory/reserved_after_model_creation": torch.cuda.memory_reserved(local_rank)}, rank) + + + # Wrap model with llama-recipies or custom LoRA policy + my_auto_wrap_policy = get_wrapping_policy(custom_policy=args["train_type"] in ["custom_qlora", "hqq_lora", "hqq_dora", "bnb_dora"], + vanilla_policy=args["train_type"] in ["full", "bnb_llama_pro", "hqq_llama_pro"]) + + if rank == 0 or args['verbose']: + print("Wrapping model w/ FSDP", rank) + + if args["sharding_strategy"] == "full_shard": + sharding_strategy = ShardingStrategy.FULL_SHARD + elif args["sharding_strategy"] == "shard_grad_op": + sharding_strategy = ShardingStrategy.SHARD_GRAD_OP + elif args["sharding_strategy"] == "ddp": + sharding_strategy = ShardingStrategy.NO_SHARD + elif args["sharding_strategy"] == "hybrid_full_shard": + sharding_strategy = ShardingStrategy.HYBRID_SHARD + elif args["sharding_strategy"] == "hybrid_shard_grad_op": + sharding_strategy = ShardingStrategy._HYBRID_SHARD_ZERO2 + else: + raise ValueError("Invalid FSDP sharding strategy") + + model = FSDP( + model, + sharding_strategy=sharding_strategy, + auto_wrap_policy=my_auto_wrap_policy, + # backward_prefetch=None, #BackwardPrefetch.BACKWARD_PRE + use_orig_params=False, + cpu_offload=CPUOffload(offload_params=True) if args["use_cpu_offload"] else None, + limit_all_gathers=True, # See https://github.com/pytorch/pytorch/issues/91165 + device_id=torch.cuda.current_device(), + sync_module_states=args["low_memory"], + param_init_fn=lambda module: module.to_empty(device=torch.device("cuda"), recurse=False) + if (rank!=0 and args["low_memory"]) else None, # TODO note about meta device and why we need this + mixed_precision=mp_policy, + ) + if rank == 0 or args['verbose']: + print(f"Rank {rank}: Wrapped model: {torch.cuda.memory_reserved(local_rank)/2**30:.3f} GiB") + if args["log_to"] == 'wandb': + logger.log({"memory/allocated_after_model_wrap": torch.cuda.memory_allocated(local_rank)}, rank) + logger.log({"memory/reserved_after_model_wrap": torch.cuda.memory_reserved(local_rank)}, rank) - # Synchronize at the start - dist.barrier() + # Synchronize at the start + dist.barrier() - # Apply activation checkpointing - if args["use_gradient_checkpointing"]: - if args['reentrant_checkpointing']: - model.enable_input_require_grads() - non_reentrant_wrapper = functools.partial( - checkpoint_wrapper, - checkpoint_impl=CheckpointImpl.REENTRANT if args['reentrant_checkpointing'] else CheckpointImpl.NO_REENTRANT, + # Apply activation checkpointing + if args["use_gradient_checkpointing"]: + if args['reentrant_checkpointing']: + model.enable_input_require_grads() + non_reentrant_wrapper = functools.partial( + checkpoint_wrapper, + checkpoint_impl=CheckpointImpl.REENTRANT if args['reentrant_checkpointing'] else CheckpointImpl.NO_REENTRANT, - ) + ) - check_fn = lambda submodule: isinstance(submodule, (LlamaDecoderLayer, MistralDecoderLayer)) - if rank == 0 or args['verbose']: - print("Applying activation checkpointing", rank) - apply_activation_checkpointing( - model, checkpoint_wrapper_fn=non_reentrant_wrapper, check_fn=check_fn - ) + check_fn = lambda submodule: isinstance(submodule, (LlamaDecoderLayer, MistralDecoderLayer)) + if rank == 0 or args['verbose']: + print("Applying activation checkpointing", rank) + apply_activation_checkpointing( + model, checkpoint_wrapper_fn=non_reentrant_wrapper, check_fn=check_fn + ) - if args["use_activation_cpu_offload"]: - if rank == 0 or args['verbose']: - print("Applying activation offloading", rank) - model = offload_wrapper(model) + if args["use_activation_cpu_offload"]: + if rank == 0 or args['verbose']: + print("Applying activation offloading", rank) + model = offload_wrapper(model) - if rank == 0 and args['verbose']: - print("Config:") - print(cfg) - print("Model:") - print(model) - print("Starting training") + if rank == 0 and args['verbose']: + print("Config:") + print(cfg) + print("Model:") + print(model) + print("Starting training") - # Create the optimizer - optimizer = get_optimizer(model, args) + # Create the optimizer + optimizer = get_optimizer(model, args) - # LR scheduler. - gradient_accumulation_steps = max(1, args['gradient_accumulation_steps']) - lr_scheduler, num_training_steps = get_lr_scheduler(optimizer, dataloader, gradient_accumulation_steps, args) + # LR scheduler. + gradient_accumulation_steps = max(1, args['gradient_accumulation_steps']) + lr_scheduler, num_training_steps = get_lr_scheduler(optimizer, dataloader, gradient_accumulation_steps, args) - # Sanity check: see what parameters the optimizer has and which require grad: - if rank == 0 and args['verbose']: - print("Optimizer params:") - for group in optimizer.param_groups: - for param in group['params']: - print(f"Shape: {param.shape}, Requires Grad: {param.requires_grad}, Dtype: {param.dtype}") + # Sanity check: see what parameters the optimizer has and which require grad: + if rank == 0 and args['verbose']: + print("Optimizer params:") + for group in optimizer.param_groups: + for param in group['params']: + print(f"Shape: {param.shape}, Requires Grad: {param.requires_grad}, Dtype: {param.dtype}") - # Autocast for mixed precision with fp16/bf16 compute types with fp32 params - if args["precision"] in ["fp16_autocast", "bf16_autocast", "bf16_buffers_autocast"]: - autocast = torch.cuda.amp.autocast(enabled=True, dtype=compute_dtype) - else: - autocast = nullcontext() - scaler = ShardedGradScaler() if args["precision"] == "fp16_autocast" else None - scale_grads = scaler is not None + # Autocast for mixed precision with fp16/bf16 compute types with fp32 params + if args["precision"] in ["fp16_autocast", "bf16_autocast", "bf16_buffers_autocast"]: + autocast = torch.cuda.amp.autocast(enabled=True, dtype=compute_dtype) + else: + autocast = nullcontext() + scaler = ShardedGradScaler() if args["precision"] == "fp16_autocast" else None + scale_grads = scaler is not None - if rank == 0: - print("Total Training Steps:", num_training_steps) - memory_stats = [] - progress_bar = tqdm(range(num_training_steps), disable=rank != 0) - init_start_event.record() - log_loss, log_lr = 0.0, -1 - # Reset peak memory to track that - torch.cuda.reset_peak_memory_stats(local_rank) + if rank == 0: + print("Total Training Steps:", num_training_steps) + memory_stats = [] + progress_bar = tqdm(range(num_training_steps), disable=rank != 0) + init_start_event.record() + log_loss, log_lr = 0.0, -1 + # Reset peak memory to track that + torch.cuda.reset_peak_memory_stats(local_rank) + with profiling_context(args, rank=rank) as prof: for epoch in range(args['num_epochs']): update_progress_bar(progress_bar, epoch, log_loss, log_lr, rank) model.train() ddp_loss = torch.zeros(2).to(local_rank) for batch_idx, batch in enumerate(dataloader): - + prof.step_num = f"epoch{epoch}-batch{batch_idx}" + accumulate_grads = (batch_idx+1) % gradient_accumulation_steps == 0 # Prevent gradient syncing until update step if using no_sync option. @@ -961,7 +925,6 @@ def load_and_quantize_parallel(name_param, model, **kwargs): if rank == 0: print(f"Batch idx {batch_idx}") - prof.step() if args["max_steps"] > 0 and batch_idx > args["max_steps"]: @@ -979,72 +942,69 @@ def load_and_quantize_parallel(name_param, model, **kwargs): logger.log({"memory/allocated_peak": peak_allocated_memory}, rank) logger.log({"memory/reserved_peak": peak_reserved_memory}, rank) - # Synchronize at the end and record time - init_end_event.record() - dist.barrier() - torch.cuda.synchronize() + # Synchronize at the end and record time + init_end_event.record() + dist.barrier() + torch.cuda.synchronize() - if rank == 0: - print("Finished training", rank) + if rank == 0: + print("Finished training", rank) - # Print time, model, & memory stats - time_taken = init_start_event.elapsed_time(init_end_event) / 1000 - dist.barrier() - torch.cuda.synchronize() + # Print time, model, & memory stats + time_taken = init_start_event.elapsed_time(init_end_event) / 1000 + dist.barrier() + torch.cuda.synchronize() + if rank == 0: + print(f"CUDA event elapsed time: {time_taken} sec") + logger.log({"time_taken": time_taken}, rank) + for line in memory_stats: + print(line) + + # End logging + logger.finish(rank=rank) + + # Save model - ref: https://github.com/pytorch/pytorch/issues/98823 + # HQQLinear custom state_dict() method causes issues when saving. + # Model is saved fine when `state_dict()` method is removed. + # Non param/buffer types are not saved with FSDP. + # It might be better to just save the trained lora layers. + # summon_full_params on lora layers and save. + if args["save_model"]: if rank == 0: - print(f"CUDA event elapsed time: {time_taken} sec") - logger.log({"time_taken": time_taken}, rank) - for line in memory_stats: - print(line) - - # End logging - logger.finish(rank=rank) - - # Save model - ref: https://github.com/pytorch/pytorch/issues/98823 - # HQQLinear custom state_dict() method causes issues when saving. - # Model is saved fine when `state_dict()` method is removed. - # Non param/buffer types are not saved with FSDP. - # It might be better to just save the trained lora layers. - # summon_full_params on lora layers and save. - if args["save_model"]: - if rank == 0: - os.makedirs(args["output_dir"], exist_ok=True) - dist.barrier() - save_policy = FullStateDictConfig(offload_to_cpu=True, rank0_only=True) - if args["train_type"] in ["custom_lora", "custom_qlora", "hqq_lora", "hqq_dora", "bnb_dora", "bnb_llama_pro", "hqq_llama_pro"]: - cpu_state_dict = {} - if args["train_type"] in ["bnb_llama_pro", "hqq_llama_pro"]: - trainable_fsdp_modules =[(n,m) for n,m in model.named_modules() if n.endswith(tuple(new_layer_names))] - else: - trainable_fsdp_modules = [(n,m) for n,m in model.named_modules() if n.endswith(('lora_AB', 'dora_layer', 'magnitude_layer'))] - for prefix, module in trainable_fsdp_modules: - prefix = (prefix.replace("_fsdp_wrapped_module.", "") - .replace("_checkpoint_wrapped_module.", "") - .replace("_offload_wrapped_module.", "")) - if args['verbose']: print(f"Saving {prefix}") - with FSDP.state_dict_type(module, StateDictType.FULL_STATE_DICT, save_policy): - cpu_state_dict = {**cpu_state_dict, **{f"{prefix}.{k}":v for k,v in module.state_dict().items()}} - dist.barrier() - torch.cuda.synchronize() + os.makedirs(args["output_dir"], exist_ok=True) + dist.barrier() + save_policy = FullStateDictConfig(offload_to_cpu=True, rank0_only=True) + if args["train_type"] in ["custom_lora", "custom_qlora", "hqq_lora", "hqq_dora", "bnb_dora", "bnb_llama_pro", "hqq_llama_pro"]: + cpu_state_dict = {} + if args["train_type"] in ["bnb_llama_pro", "hqq_llama_pro"]: + trainable_fsdp_modules =[(n,m) for n,m in model.named_modules() if n.endswith(tuple(new_layer_names))] + else: + trainable_fsdp_modules = [(n,m) for n,m in model.named_modules() if n.endswith(('lora_AB', 'dora_layer', 'magnitude_layer'))] + for prefix, module in trainable_fsdp_modules: + prefix = (prefix.replace("_fsdp_wrapped_module.", "") + .replace("_checkpoint_wrapped_module.", "") + .replace("_offload_wrapped_module.", "")) + if args['verbose']: print(f"Saving {prefix}") + with FSDP.state_dict_type(module, StateDictType.FULL_STATE_DICT, save_policy): + cpu_state_dict = {**cpu_state_dict, **{f"{prefix}.{k}":v for k,v in module.state_dict().items()}} + dist.barrier() + torch.cuda.synchronize() + if rank==0: + print("Saving trained LoRA weights.") + save_file(cpu_state_dict, os.path.join(args["output_dir"], "model_state_dict.safetensors")) + print("Done", rank) + else: + with FSDP.state_dict_type(model, StateDictType.FULL_STATE_DICT, save_policy): + cpu_state_dict = model.state_dict() if rank==0: - print("Saving trained LoRA weights.") + print("Saving full model weights.") save_file(cpu_state_dict, os.path.join(args["output_dir"], "model_state_dict.safetensors")) print("Done", rank) - else: - with FSDP.state_dict_type(model, StateDictType.FULL_STATE_DICT, save_policy): - cpu_state_dict = model.state_dict() - if rank==0: - print("Saving full model weights.") - save_file(cpu_state_dict, os.path.join(args["output_dir"], "model_state_dict.safetensors")) - print("Done", rank) - dist.barrier() # Stop other processes ending while model saving - probably not needed? + dist.barrier() # Stop other processes ending while model saving - probably not needed? - # Clean up - dist.destroy_process_group() - # if args["profiling_output"]: - # prof.export_stacks(path = f"{args['profiling_output']}_{local_rank}.txt", - # metric = "self_cuda_time_total") + # Clean up + dist.destroy_process_group() def validate_args(args): if args["n_bits"] != 4 and args["train_type"] not in ["hqq_lora", "hqq_dora", "hqq_llama_pro"]: @@ -1094,16 +1054,17 @@ def fsdp_qlora( group: str = None, # For wandb logging entity: str = None, # For wandb logging n_bits: int = 4, # passed to hqq - profile: bool = False, # Enable PyTorch profiler - profiling_output: str = None, # Output directory for torch.profiler artifacts - with_stack: bool = False, # Output stacks for profiling - with_shapes: bool = False, # Output shapes for profiling - export_trace: bool = True, # Output trace for profiling - export_memory_timeline: bool = False, # Output memory timelinefor profiling - wait_steps: int = 1, # Wait steps when running profiler - warmup_steps: int = 0, # Warmup steps when running profiler - active_steps: int = 5, # Active steps when running profiler - repeat: int = 1, #Number of profiler cycles (wait + warmup + active) + profile: bool_arg = False, # Whether to profile with torch.profiler + profiling_output: str = "profiles", # Output file prefix for profiling + with_stack: bool_arg = False, # Output stacks for profiling. Note that setting export_memory_timeline will automatically export traces since `with_stack` must be true to profile memory. + with_shapes: bool_arg = False, # Output shapes for profiling. Can impact performance. Note that setting export_memory_timeline will automatically export traces since `with_shapes` must be true to profile memory. + export_trace: bool_arg = True, # Output trace for profiling + export_memory_timeline: bool_arg = False, # Output memory timelinefor profiling + wait_steps: int = 1, # Wait steps when running profiler. Only used if repeat != 0. + warmup_steps: int = 1, # Warmup steps when running profiler + active_steps: int = 2, # Active steps when running profiler + repeat: int = 0, #Number of profiler cycles (wait + warmup + active) if > 0, else repeats forever + profiling_frequency: int = 10, # Profiling frequency in steps. Only used if repeat == 0, in which case wait_steps will be set to profiling_frequency - (warmup_steps + active_steps) such that the effective cycle length == profiling_frequency max_steps: int = -1, # Max number of training steps (in units of batches) per epoch. -1 means no max_steps, otherwise training loop breaks after `max_steps` each epoch. ): """ @@ -1243,13 +1204,14 @@ def main( profile: bool_arg = False, # Whether to profile with torch.profiler profiling_output: str = "profiles", # Output file prefix for profiling with_stack: bool_arg = False, # Output stacks for profiling. Note that setting export_memory_timeline will automatically export traces since `with_stack` must be true to profile memory. - with_shapes: bool_arg = False, # Output shapes for profiling. Note that setting export_memory_timeline will automatically export traces since `with_shapes` must be true to profile memory. + with_shapes: bool_arg = False, # Output shapes for profiling. Can impact performance. Note that setting export_memory_timeline will automatically export traces since `with_shapes` must be true to profile memory. export_trace: bool_arg = True, # Output trace for profiling export_memory_timeline: bool_arg = False, # Output memory timelinefor profiling - wait_steps: int = 1, # Wait steps when running profiler + wait_steps: int = 1, # Wait steps when running profiler. Only used if repeat != 0. warmup_steps: int = 1, # Warmup steps when running profiler active_steps: int = 2, # Active steps when running profiler - repeat: int = 1, #Number of profiler cycles (wait + warmup + active) + repeat: int = 0, #Number of profiler cycles (wait + warmup + active) if > 0, else repeats forever + profiling_frequency: int = 10, # Profiling frequency in steps. Only used if repeat == 0, in which case wait_steps will be set to profiling_frequency - (warmup_steps + active_steps) such that the effective cycle length == profiling_frequency max_steps: int = -1, # Max number of training steps (in units of batches) per epoch. -1 means no max_steps, otherwise training loop breaks after `max_steps` each epoch. ): fsdp_qlora(**locals()) From 373114d759af449c339a17757453885124aec597 Mon Sep 17 00:00:00 2001 From: Ubuntu Date: Mon, 3 Jun 2024 19:55:39 +0000 Subject: [PATCH 08/16] fix step_num --- profile.sh | 10 +++++++--- profiling_utils.py | 20 ++++++++++---------- train.py | 2 +- 3 files changed, 18 insertions(+), 14 deletions(-) diff --git a/profile.sh b/profile.sh index 14a917d..de79f22 100755 --- a/profile.sh +++ b/profile.sh @@ -47,8 +47,10 @@ #**IMPORTANT** There are issues with recording stack traces and exporting traces simultaneously (see this [issue](https://github.com/pytorch/pytorch/issues/113564)) depending on `python` version. The only combination I was able to get both to work at the same time was with `python=3.11.9` and `torch=2.3.0`. #Tested on `python=3.11.9 and torch=2.3.0`` +#"meta-llama/Llama-2-7b-hf" + python train.py \ ---model_name "meta-llama/Llama-2-7b-hf" \ +--model_name "hf-internal-testing/tiny-random-LlamaForCausalLM" \ --gradient_accumulation_steps 2 \ --batch_size 1 \ --context_length 256 \ @@ -62,5 +64,7 @@ python train.py \ --dataset dummy \ --profile true \ --export_trace true \ ---export_memory_timeline true \ ---max_steps 10 +--export_memory_timeline false \ +--with_stack true \ +--max_steps 10 \ +--repeat 1 diff --git a/profiling_utils.py b/profiling_utils.py index cabf1b7..aaaf94d 100644 --- a/profiling_utils.py +++ b/profiling_utils.py @@ -17,8 +17,8 @@ #adapted from https://github.com/pytorch/torchtitan -def trace_handler(prof, rank, export_memory_timeline, output_dir, metric="self_cuda_time_total", with_stack=True, group_by_stack=0, group_by_input_shapes=False, row_limit=25): - curr_trace_dir_name = str(prof.step_num) +def trace_handler(prof, rank, export_memory_timeline, output_dir, metric="self_cuda_time_total", with_stack=True, group_by_stack=0, group_by_input_shape=False, row_limit=25): + curr_trace_dir_name = "iteration_" + str(prof.step_num) curr_trace_dir = os.path.join(output_dir, curr_trace_dir_name) if not os.path.exists(curr_trace_dir): os.makedirs(curr_trace_dir, exist_ok=True) @@ -43,7 +43,7 @@ def trace_handler(prof, rank, export_memory_timeline, output_dir, metric="self_c #Export event averages key_avgs = prof.key_averages( - group_by_input_shape=group_by_input_shapes, group_by_stack_n=group_by_stack + group_by_input_shape=group_by_input_shape, group_by_stack_n=group_by_stack ).table(sort_by=metric, row_limit=row_limit) with open(f"{curr_trace_dir}/rank{rank}_key_averages.txt", "w") as f: print( @@ -56,10 +56,13 @@ def trace_handler(prof, rank, export_memory_timeline, output_dir, metric="self_c torch.distributed.barrier() @contextlib.contextmanager -def profiling_context(args, rank, *, global_step: int = 0): - enable_profiling = args.profile - +def profiling_context(args, rank): + enable_profiling = args["profile"] + if enable_profiling: + output_dir = args["profiling_output"] if args["profiling_output"] else f"./{model_name}_{train_type}" + model_name = args["model_name"].split("/")[-1] + train_type = args["train_type"] logger.info(f"Profiling enabled. Traces will be saved at {output_dir}") @@ -85,9 +88,6 @@ def profiling_context(args, rank, *, global_step: int = 0): export_memory_timeline = args["export_memory_timeline"] with_stack = args["with_stack"] or args["export_memory_timeline"] with_shapes = args["with_shapes"] or export_memory_timeline - model_name = args["model_name"].split("/")[-1] - train_type = args["train_type"] - output_dir = args["profiling_output"] if args["profiling_output"] else f"./{model_name}_{train_type}" callback = partial(trace_handler, rank=rank, export_memory_timeline=export_memory_timeline, output_dir=output_dir, @@ -102,7 +102,7 @@ def profiling_context(args, rank, *, global_step: int = 0): ], with_stack=with_stack, profile_memory=profile_memory, - with_shapes=with_shapes, + record_shapes=with_shapes, schedule=torch.profiler.schedule(wait=wait, warmup=warmup, active=active, repeat=repeat), on_trace_ready=callback, experimental_config=torch._C._profiler._ExperimentalConfig(verbose=True) if with_stack else None, diff --git a/train.py b/train.py index 40dd51b..bec349f 100644 --- a/train.py +++ b/train.py @@ -820,7 +820,7 @@ def load_and_quantize_parallel(name_param, model, **kwargs): ddp_loss = torch.zeros(2).to(local_rank) for batch_idx, batch in enumerate(dataloader): - prof.step_num = f"epoch{epoch}-batch{batch_idx}" + #prof.step_num = f"epoch{epoch}-batch{batch_idx}" accumulate_grads = (batch_idx+1) % gradient_accumulation_steps == 0 From 4f9066c03bb8693bef4b9b151c1d2edf1e13d13c Mon Sep 17 00:00:00 2001 From: jeromeku Date: Mon, 3 Jun 2024 20:39:04 +0000 Subject: [PATCH 09/16] fix chrome exporting --- profile.sh | 5 +++-- profiling_utils.py | 18 +++++++++++++----- train.py | 2 +- 3 files changed, 17 insertions(+), 8 deletions(-) diff --git a/profile.sh b/profile.sh index de79f22..a7b8450 100755 --- a/profile.sh +++ b/profile.sh @@ -66,5 +66,6 @@ python train.py \ --export_trace true \ --export_memory_timeline false \ --with_stack true \ ---max_steps 10 \ ---repeat 1 +--max_steps 50 \ +--repeat 2 \ +--profiling_output llama-test diff --git a/profiling_utils.py b/profiling_utils.py index aaaf94d..9be5544 100644 --- a/profiling_utils.py +++ b/profiling_utils.py @@ -11,6 +11,8 @@ import torch import torch.distributed from functools import partial +import shutil +from torch.profiler import tensorboard_trace_handler WARMUP = 3 logger = logging.getLogger() @@ -26,7 +28,13 @@ def trace_handler(prof, rank, export_memory_timeline, output_dir, metric="self_c #Export chrome / tensorboard trace logger.info(f"Dumping traces at step {prof.step_num}") begin = time.monotonic() - prof.export_chrome_trace(f"{curr_trace_dir}/rank{rank}_trace.json") + + #Use tensorboard trace handler rather than directly exporting chrome traces since + #tensorboard doesn't seem to be able to parse traces when with prof.export_chrome_trace + exporter = tensorboard_trace_handler(curr_trace_dir, worker_name=f"rank{rank}", use_gzip=True) + exporter(prof) + #prof.export_chrome_trace(f"{curr_trace_dir}/rank{rank}_trace.json") + logger.info( f"Finished dumping traces in {time.monotonic() - begin:.2f} seconds" ) @@ -60,14 +68,14 @@ def profiling_context(args, rank): enable_profiling = args["profile"] if enable_profiling: - output_dir = args["profiling_output"] if args["profiling_output"] else f"./{model_name}_{train_type}" model_name = args["model_name"].split("/")[-1] train_type = args["train_type"] - - logger.info(f"Profiling enabled. Traces will be saved at {output_dir}") - + output_dir = args["profiling_output"] if args["profiling_output"] else f"./{model_name}_{train_type}" + if not os.path.exists(output_dir): os.makedirs(output_dir, exist_ok=True) + + logger.info(f"Profiling enabled. Traces will be saved at {output_dir}") warmup = args["warmup_steps"] active = args["active_steps"] diff --git a/train.py b/train.py index bec349f..57712c8 100644 --- a/train.py +++ b/train.py @@ -820,7 +820,6 @@ def load_and_quantize_parallel(name_param, model, **kwargs): ddp_loss = torch.zeros(2).to(local_rank) for batch_idx, batch in enumerate(dataloader): - #prof.step_num = f"epoch{epoch}-batch{batch_idx}" accumulate_grads = (batch_idx+1) % gradient_accumulation_steps == 0 @@ -1056,6 +1055,7 @@ def fsdp_qlora( n_bits: int = 4, # passed to hqq profile: bool_arg = False, # Whether to profile with torch.profiler profiling_output: str = "profiles", # Output file prefix for profiling + overwrite_profiling_output: bool = True, # Overwrite output directory with_stack: bool_arg = False, # Output stacks for profiling. Note that setting export_memory_timeline will automatically export traces since `with_stack` must be true to profile memory. with_shapes: bool_arg = False, # Output shapes for profiling. Can impact performance. Note that setting export_memory_timeline will automatically export traces since `with_shapes` must be true to profile memory. export_trace: bool_arg = True, # Output trace for profiling From 5ca555194493c8703f898a249192bb7a29312319 Mon Sep 17 00:00:00 2001 From: jeromeku Date: Mon, 3 Jun 2024 20:39:54 +0000 Subject: [PATCH 10/16] remove deprecated utils --- profiling_utils.py | 133 --------------------------------------------- 1 file changed, 133 deletions(-) delete mode 100644 profiling_utils.py diff --git a/profiling_utils.py b/profiling_utils.py deleted file mode 100644 index 9be5544..0000000 --- a/profiling_utils.py +++ /dev/null @@ -1,133 +0,0 @@ -# Copyright (c) Meta Platforms, Inc. and affiliates. -# All rights reserved. -# -# This source code is licensed under the BSD-style license found in the -# LICENSE file in the root directory of this source tree. - -import contextlib -import os -import time -import logging -import torch -import torch.distributed -from functools import partial -import shutil -from torch.profiler import tensorboard_trace_handler -WARMUP = 3 - -logger = logging.getLogger() - -#adapted from https://github.com/pytorch/torchtitan - -def trace_handler(prof, rank, export_memory_timeline, output_dir, metric="self_cuda_time_total", with_stack=True, group_by_stack=0, group_by_input_shape=False, row_limit=25): - curr_trace_dir_name = "iteration_" + str(prof.step_num) - curr_trace_dir = os.path.join(output_dir, curr_trace_dir_name) - if not os.path.exists(curr_trace_dir): - os.makedirs(curr_trace_dir, exist_ok=True) - - #Export chrome / tensorboard trace - logger.info(f"Dumping traces at step {prof.step_num}") - begin = time.monotonic() - - #Use tensorboard trace handler rather than directly exporting chrome traces since - #tensorboard doesn't seem to be able to parse traces when with prof.export_chrome_trace - exporter = tensorboard_trace_handler(curr_trace_dir, worker_name=f"rank{rank}", use_gzip=True) - exporter(prof) - #prof.export_chrome_trace(f"{curr_trace_dir}/rank{rank}_trace.json") - - logger.info( - f"Finished dumping traces in {time.monotonic() - begin:.2f} seconds" - ) - - #Construct the memory timeline file. - if export_memory_timeline: - prof.export_memory_timeline( - f"{curr_trace_dir}/rank{rank}_memory-timeline.html" - ) - - #Dump stack traces - if with_stack: - prof.export_stacks(f"{curr_trace_dir}/rank{rank}_stacks.txt", metric=metric) - - #Export event averages - key_avgs = prof.key_averages( - group_by_input_shape=group_by_input_shape, group_by_stack_n=group_by_stack - ).table(sort_by=metric, row_limit=row_limit) - with open(f"{curr_trace_dir}/rank{rank}_key_averages.txt", "w") as f: - print( - key_avgs, file=f - ) - if rank == 0: - print(f"Saving profiling results to {curr_trace_dir}") - - #TODO: Is this necessary? - torch.distributed.barrier() - -@contextlib.contextmanager -def profiling_context(args, rank): - enable_profiling = args["profile"] - - if enable_profiling: - model_name = args["model_name"].split("/")[-1] - train_type = args["train_type"] - output_dir = args["profiling_output"] if args["profiling_output"] else f"./{model_name}_{train_type}" - - if not os.path.exists(output_dir): - os.makedirs(output_dir, exist_ok=True) - - logger.info(f"Profiling enabled. Traces will be saved at {output_dir}") - - warmup = args["warmup_steps"] - active = args["active_steps"] - repeat = args["repeat"] - - if repeat == 0: - steps_per_cycle = args["profiling_frequency"] - wait = steps_per_cycle - (active + warmup) - else: - wait = args["wait_steps"] - steps_per_cycle = wait + warmup + active - assert ( - wait >= 0 - ), "profile_freq must be greater than or equal to warmup + active" - logger.info(f"Profiler schedule - steps per cycle: {steps_per_cycle} wait: {wait} warmup: {warmup} active: {active} repeat: {repeat if repeat !=0 else 'inf'}") - - profile_memory = args["export_memory_timeline"] - export_memory_timeline = args["export_memory_timeline"] - with_stack = args["with_stack"] or args["export_memory_timeline"] - with_shapes = args["with_shapes"] or export_memory_timeline - callback = partial(trace_handler, rank=rank, - export_memory_timeline=export_memory_timeline, - output_dir=output_dir, - with_stack=with_stack, - group_by_input_shape=with_shapes, - group_by_stack=5 if with_stack else 0) - - with torch.profiler.profile( - activities=[ - torch.profiler.ProfilerActivity.CPU, - torch.profiler.ProfilerActivity.CUDA, - ], - with_stack=with_stack, - profile_memory=profile_memory, - record_shapes=with_shapes, - schedule=torch.profiler.schedule(wait=wait, warmup=warmup, active=active, repeat=repeat), - on_trace_ready=callback, - experimental_config=torch._C._profiler._ExperimentalConfig(verbose=True) if with_stack else None, - ) as torch_profiler: - yield torch_profiler - else: - class FakeProfiler: - """ - Fake profiler object when profiling is not enabled. - - """ - def __enter__(self): - return self - def __exit__(self, *args, **kwargs): - pass - - def step(self): - pass - - yield FakeProfiler() From b2d2475af2b1e6fa3429fdcf630bcf9241f0fee7 Mon Sep 17 00:00:00 2001 From: jeromeku Date: Mon, 3 Jun 2024 20:49:51 +0000 Subject: [PATCH 11/16] remove deprecated profiling tools --- profile.sh | 7 ++++-- profiling_tools.py | 62 ---------------------------------------------- 2 files changed, 5 insertions(+), 64 deletions(-) delete mode 100644 profiling_tools.py diff --git a/profile.sh b/profile.sh index a7b8450..f948f73 100755 --- a/profile.sh +++ b/profile.sh @@ -66,6 +66,9 @@ python train.py \ --export_trace true \ --export_memory_timeline false \ --with_stack true \ ---max_steps 50 \ ---repeat 2 \ +--max_steps 20 \ +--repeat 0 \ +--warmup_steps 4 \ +--active_steps 1 \ +--profiling_frequency 5 \ --profiling_output llama-test diff --git a/profiling_tools.py b/profiling_tools.py deleted file mode 100644 index 2e09161..0000000 --- a/profiling_tools.py +++ /dev/null @@ -1,62 +0,0 @@ -import torch -import os -from datetime import datetime -from functools import partial -def trace_handler( - prof: torch.profiler.profile, - rank: int, - export_trace=True, - export_memory_timeline=False, - with_stack: bool = True, - group_by_stack: int = 0, - group_by_input_shapes: bool = False, - prefix="", - out_dir="./profiles", - time_fmt_str: str = "%m_%d_%H", - metric="self_cuda_time_total", - row_limit=25, - verbose=False, -): - # Prefix for file names. - timestamp = datetime.now().strftime(time_fmt_str) - file_prefix = os.path.join(out_dir, f"{prefix}-{timestamp}") - if not os.path.exists(out_dir): - os.makedirs(out_dir) - - # Construct the trace file. - if export_trace: - prof.export_chrome_trace(f"{file_prefix}-chrome-trace.json.gz") - - # Construct the memory timeline file. - if export_memory_timeline: - prof.export_memory_timeline( - f"{file_prefix}-memory-timeline.html" - ) - - if with_stack: - prof.export_stacks(f"{file_prefix}-stacks.txt", metric=metric) - - key_avgs = prof.key_averages( - group_by_input_shape=group_by_input_shapes, group_by_stack_n=group_by_stack - ).table(sort_by=metric, row_limit=row_limit) - with open(f"{file_prefix}-key_averages.txt", "w") as f: - print( - key_avgs, file=f - ) - if rank == 0: - print(f"Saving profiling results to {out_dir}") - if verbose: - print(key_avgs) - -class FakeContext: - """ - Fake context when not using profiler with profiling script. - - """ - def __enter__(self): - return self - def __exit__(self, *args, **kwargs): - pass - - def step(self): - pass \ No newline at end of file From 474dbae9a08ad3554ac57e3854029a18aa38044d Mon Sep 17 00:00:00 2001 From: jeromeku Date: Mon, 3 Jun 2024 20:52:42 +0000 Subject: [PATCH 12/16] add back profiling utils --- profiling_utils.py | 133 +++++++++++++++++++++++++++++++++++++++++++++ 1 file changed, 133 insertions(+) create mode 100644 profiling_utils.py diff --git a/profiling_utils.py b/profiling_utils.py new file mode 100644 index 0000000..1f97af5 --- /dev/null +++ b/profiling_utils.py @@ -0,0 +1,133 @@ +# Copyright (c) Meta Platforms, Inc. and affiliates. +# All rights reserved. +# +# This source code is licensed under the BSD-style license found in the +# LICENSE file in the root directory of this source tree. + +import contextlib +import os +import time +import logging +import torch +import torch.distributed +from functools import partial +import shutil +from torch.profiler import tensorboard_trace_handler +WARMUP = 3 + +logger = logging.getLogger() + +#adapted from https://github.com/pytorch/torchtitan + +def trace_handler(prof, rank, export_memory_timeline, output_dir, metric="self_cuda_time_total", with_stack=True, group_by_stack=0, group_by_input_shape=False, row_limit=25): + curr_trace_dir_name = "iteration_" + str(prof.step_num) + curr_trace_dir = os.path.join(output_dir, curr_trace_dir_name) + if not os.path.exists(curr_trace_dir): + os.makedirs(curr_trace_dir, exist_ok=True) + + #Export chrome / tensorboard trace + logger.info(f"Dumping traces at step {prof.step_num}") + begin = time.monotonic() + + #Use tensorboard trace handler rather than directly exporting chrome traces since + #tensorboard doesn't seem to be able to parse traces when with prof.export_chrome_trace + exporter = tensorboard_trace_handler(curr_trace_dir, worker_name=f"rank{rank}", use_gzip=True) + exporter(prof) + #prof.export_chrome_trace(f"{curr_trace_dir}/rank{rank}_trace.json") + + logger.info( + f"Finished dumping traces in {time.monotonic() - begin:.2f} seconds" + ) + + #Construct the memory timeline file. + if export_memory_timeline: + prof.export_memory_timeline( + f"{curr_trace_dir}/rank{rank}_memory-timeline.html" + ) + + #Dump stack traces + if with_stack: + prof.export_stacks(f"{curr_trace_dir}/rank{rank}_stacks.txt", metric=metric) + + #Export event averages + key_avgs = prof.key_averages( + group_by_input_shape=group_by_input_shape, group_by_stack_n=group_by_stack + ).table(sort_by=metric, row_limit=row_limit) + with open(f"{curr_trace_dir}/rank{rank}_key_averages.txt", "w") as f: + print( + key_avgs, file=f + ) + if rank == 0: + print(f"Saving profiling results to {curr_trace_dir}") + + #TODO: Is this necessary? + torch.distributed.barrier() + +@contextlib.contextmanager +def profiling_context(args, rank): + enable_profiling = args["profile"] + + if enable_profiling: + model_name = args["model_name"].split("/")[-1] + train_type = args["train_type"] + output_dir = args["profiling_output"] if args["profiling_output"] else f"./{model_name}_{train_type}" + + if not os.path.exists(output_dir): + os.makedirs(output_dir, exist_ok=True) + + logger.info(f"Profiling enabled. Traces will be saved at {output_dir}") + + warmup = args["warmup_steps"] + active = args["active_steps"] + repeat = args["repeat"] + + if repeat == 0: + steps_per_cycle = args["profiling_frequency"] + wait = steps_per_cycle - (active + warmup) + else: + wait = args["wait_steps"] + steps_per_cycle = wait + warmup + active + assert ( + wait >= 0 + ), "profile_freq must be greater than or equal to warmup + active" + logger.info(f"Profiler schedule - steps per cycle: {steps_per_cycle} wait: {wait} warmup: {warmup} active: {active} repeat: {repeat if repeat !=0 else 'inf'}") + + profile_memory = args["export_memory_timeline"] + export_memory_timeline = args["export_memory_timeline"] + with_stack = args["with_stack"] or args["export_memory_timeline"] + with_shapes = args["with_shapes"] or export_memory_timeline + callback = partial(trace_handler, rank=rank, + export_memory_timeline=export_memory_timeline, + output_dir=output_dir, + with_stack=with_stack, + group_by_input_shape=with_shapes, + group_by_stack=5 if with_stack else 0) + + with torch.profiler.profile( + activities=[ + torch.profiler.ProfilerActivity.CPU, + torch.profiler.ProfilerActivity.CUDA, + ], + with_stack=with_stack, + profile_memory=profile_memory, + record_shapes=with_shapes, + schedule=torch.profiler.schedule(wait=wait, warmup=warmup, active=active, repeat=repeat), + on_trace_ready=callback, + experimental_config=torch._C._profiler._ExperimentalConfig(verbose=True) if with_stack else None, + ) as torch_profiler: + yield torch_profiler + else: + class FakeProfiler: + """ + Fake profiler object when profiling is not enabled. + + """ + def __enter__(self): + return self + def __exit__(self, *args, **kwargs): + pass + + def step(self): + pass + + yield FakeProfiler() \ No newline at end of file From 5981ae483af643afcf1cd922bd9b84c73e204e54 Mon Sep 17 00:00:00 2001 From: jeromeku Date: Mon, 3 Jun 2024 22:04:37 +0000 Subject: [PATCH 13/16] update profiling docs --- PROFILING.md | 66 +++++++++++++++++++++++++++++++++++++++++++++ profile.sh | 67 +++++++++++++++++++++++++++++++--------------- profiling_utils.py | 23 ++++++++++++---- train.py | 2 +- 4 files changed, 130 insertions(+), 28 deletions(-) diff --git a/PROFILING.md b/PROFILING.md index 59682b5..31946dd 100644 --- a/PROFILING.md +++ b/PROFILING.md @@ -59,3 +59,69 @@ Detailed `CLI` options: The default schedule for the profiler is set such that to continuously execute a 10-step cycle: wait for 7, warmup for 2, record for 1. `with_stack` and `with_shapes` are overridden by `export_memory_timeline` since the memory profile requires these options to be `True`. + +#### Examples + +- Record every 5th step, exporting a `chrome` / `tensorboard` trace for each cycle: + + ``` + python train.py \ + --model_name "hf-internal-testing/tiny-random-LlamaForCausalLM" \ + --gradient_accumulation_steps 2 \ + --batch_size 1 \ + --context_length 256 \ + --num_epochs 1 \ + --sharding_strategy full_shard \ + --precision bf16 \ + --train_type qlora \ + --use_gradient_checkpointing false \ + --use_cpu_offload false \ + --log_to stdout \ + --dataset dummy \ + --profile true \ + --export_trace true \ + --export_memory_timeline false \ + --with_stack true \ + --num_epochs 1 \ + --max_steps 20 \ + --repeat 0 \ + --warmup_steps 4 \ + --active_steps 1 \ + --profiling_frequency 5 \ + --profiling_output llama-test + ``` + + The output will be a 4 trace output folders, at iteration 5, 10, ..., each containing a trace with a single training step at that iteration. + + Also in the folder will be exported stacks (which can be visualized using flamegraphs or other stack viewers) and `key_averages`, which is a summary table of operations ordered by `cuda` time. + + Note that we set `max_steps=20` so that the training loop will exit after 20 batches. If `max_steps=-1` (the default setting), the profiler will repeat the cycle during the entire training run. + +- Record 5 steps (after 1 warmup step) then stop profiling: + ``` + python train.py \ + --model_name "hf-internal-testing/tiny-random-LlamaForCausalLM" \ + --gradient_accumulation_steps 2 \ + --batch_size 1 \ + --context_length 256 \ + --num_epochs 1 \ + --sharding_strategy full_shard \ + --precision bf16 \ + --train_type qlora \ + --use_gradient_checkpointing false \ + --use_cpu_offload false \ + --log_to stdout \ + --dataset dummy \ + --profile true \ + --export_trace true \ + --export_memory_timeline true \ + --with_stack true \ + --num_epochs 1 \ + --max_steps 20 \ + --repeat 1 \ + --warmup_steps 1 \ + --active_steps 5 \ + --profiling_output llama-test2 + ``` + The output will be a single trace at `iteration_6` which contains 5 training steps. + In addition to the `stacks` and `key_averages` artifacts, there will be a `memory_timeline` `html`, which shows a breakdown of memory usage by `parameter`, `gradients`, `activations`, etc. diff --git a/profile.sh b/profile.sh index f948f73..b2b1990 100755 --- a/profile.sh +++ b/profile.sh @@ -49,26 +49,49 @@ #"meta-llama/Llama-2-7b-hf" +# python train.py \ +# --model_name "hf-internal-testing/tiny-random-LlamaForCausalLM" \ +# --gradient_accumulation_steps 2 \ +# --batch_size 1 \ +# --context_length 256 \ +# --num_epochs 1 \ +# --sharding_strategy full_shard \ +# --precision bf16 \ +# --train_type qlora \ +# --use_gradient_checkpointing false \ +# --use_cpu_offload false \ +# --log_to stdout \ +# --dataset dummy \ +# --profile true \ +# --export_trace true \ +# --export_memory_timeline false \ +# --with_stack true \ +# --max_steps 20 \ +# --repeat 0 \ +# --warmup_steps 4 \ +# --active_steps 1 \ +# --profiling_frequency 5 \ +# --profiling_output llama-test python train.py \ ---model_name "hf-internal-testing/tiny-random-LlamaForCausalLM" \ ---gradient_accumulation_steps 2 \ ---batch_size 1 \ ---context_length 256 \ ---num_epochs 1 \ ---sharding_strategy full_shard \ ---precision bf16 \ ---train_type qlora \ ---use_gradient_checkpointing false \ ---use_cpu_offload false \ ---log_to stdout \ ---dataset dummy \ ---profile true \ ---export_trace true \ ---export_memory_timeline false \ ---with_stack true \ ---max_steps 20 \ ---repeat 0 \ ---warmup_steps 4 \ ---active_steps 1 \ ---profiling_frequency 5 \ ---profiling_output llama-test + --model_name "hf-internal-testing/tiny-random-LlamaForCausalLM" \ + --gradient_accumulation_steps 2 \ + --batch_size 1 \ + --context_length 256 \ + --num_epochs 1 \ + --sharding_strategy full_shard \ + --precision bf16 \ + --train_type qlora \ + --use_gradient_checkpointing false \ + --use_cpu_offload false \ + --log_to stdout \ + --dataset dummy \ + --profile true \ + --export_trace true \ + --export_memory_timeline true \ + --with_stack true \ + --num_epochs 1 \ + --max_steps 20 \ + --repeat 1 \ + --warmup_steps 1 \ + --active_steps 4 \ + --profiling_output llama-test2 \ No newline at end of file diff --git a/profiling_utils.py b/profiling_utils.py index 1f97af5..a2e9bb8 100644 --- a/profiling_utils.py +++ b/profiling_utils.py @@ -19,7 +19,7 @@ #adapted from https://github.com/pytorch/torchtitan -def trace_handler(prof, rank, export_memory_timeline, output_dir, metric="self_cuda_time_total", with_stack=True, group_by_stack=0, group_by_input_shape=False, row_limit=25): +def trace_handler(prof: torch.profiler.profiler.profile, rank, export_memory_timeline, output_dir, metric="self_cuda_time_total", with_stack=True, group_by_stack=0, group_by_input_shape=False, row_limit=25): curr_trace_dir_name = "iteration_" + str(prof.step_num) curr_trace_dir = os.path.join(output_dir, curr_trace_dir_name) if not os.path.exists(curr_trace_dir): @@ -41,10 +41,23 @@ def trace_handler(prof, rank, export_memory_timeline, output_dir, metric="self_c #Construct the memory timeline file. if export_memory_timeline: - prof.export_memory_timeline( - f"{curr_trace_dir}/rank{rank}_memory-timeline.html" - ) - + try: + prof.export_memory_timeline( + f"{curr_trace_dir}/rank{rank}_memory-timeline.html" + ) + except: + logger.info("Failed to export memory timeline to html, retrying as gzipped json.") + try: + prof.export_memory_timeline( + f"{curr_trace_dir}/rank{rank}_memory-timeline.json.gz" + ) + except: + + logger.info("Failed to export memory timeline to gzipped json. Saving profiler timeline object instead.") + from torch.profiler._memory_profiler import MemoryProfileTimeline + memory_profile = MemoryProfileTimeline(prof._memory_profile()) + torch.save(memory_profile, f"{curr_trace_dir}/rank{rank}_memory-timeline.pt") + #Dump stack traces if with_stack: prof.export_stacks(f"{curr_trace_dir}/rank{rank}_stacks.txt", metric=metric) diff --git a/train.py b/train.py index 57712c8..12d6451 100644 --- a/train.py +++ b/train.py @@ -1207,7 +1207,7 @@ def main( with_shapes: bool_arg = False, # Output shapes for profiling. Can impact performance. Note that setting export_memory_timeline will automatically export traces since `with_shapes` must be true to profile memory. export_trace: bool_arg = True, # Output trace for profiling export_memory_timeline: bool_arg = False, # Output memory timelinefor profiling - wait_steps: int = 1, # Wait steps when running profiler. Only used if repeat != 0. + wait_steps: int = 0, # Wait steps when running profiler. Only used if repeat != 0. warmup_steps: int = 1, # Warmup steps when running profiler active_steps: int = 2, # Active steps when running profiler repeat: int = 0, #Number of profiler cycles (wait + warmup + active) if > 0, else repeats forever From 26d65954e7cb401811f1f750f2d5a05c37f1ec92 Mon Sep 17 00:00:00 2001 From: jeromeku Date: Tue, 4 Jun 2024 19:10:46 +0000 Subject: [PATCH 14/16] update docs --- PROFILING.md | 44 ++++++++++++++++++++++++++++++++------------ 1 file changed, 32 insertions(+), 12 deletions(-) diff --git a/PROFILING.md b/PROFILING.md index 31946dd..2bcb315 100644 --- a/PROFILING.md +++ b/PROFILING.md @@ -1,9 +1,18 @@ ## Profiling -## Usage +Documentation for how to profile your training runs. + +**Tips** + +- Only record what is necessary as profiling can significantly slow down training process. +- Set a `torch.profile.schedule` when running the profiler (description below), as trace artifacts are exported at the end of each profiling cycle and can be very large (on the order of hundreds of MBs each). **IMPORTANT** -There are issues with recording stack traces and exporting traces simultaneously (see this [issue](https://github.com/pytorch/pytorch/issues/113564)) depending on `python` version. The only combination I was able to get both to work at the same time was with `python=3.11.9` and `torch=2.3.0`. +There are issues with recording stack traces and exporting traces simultaneously (see this [issue](https://github.com/pytorch/pytorch/issues/113564)) depending on `python` version. + +Tested with `python=3.11.9` and `torch=2.3.0`. + +## Quickstart Running the following: @@ -19,17 +28,28 @@ python train.py \ will result in a directory `{model_name}_{train_type}-{local_rank}` with the following artifacts: -- `{model_name}-{train_type}-chrome-trace.json.gz` - interactive trace that can be viewed using `chrome::tracing` or `perfetto` +- `{model_name}-{train_type}-chrome-trace.json.gz` - interactive trace that can be viewed using `chrome::tracing`, `perfetto`, or `tensorboard` - `{model_name}-{train_type}-key_averages.txt` - sorted table of events, e.g.: -``` - -``` +| Name | Self CPU % | Self CPU | CPU total % | CPU total | CPU time avg | Self CUDA | Self CUDA % | CUDA total | CUDA time avg | # of Calls | Source Location | +| --------------------------------------------------------------------------------- | ---------- | -------- | ----------- | --------- | ------------ | --------- | ----------- | ---------- | ------------- | ---------- | ------------------------------------------------------------------------ | +| ncclDevKernel_AllGather_RING_LL(ncclDevComm*, unsigned int*, unsigned int\*, int) | 0.00% | 0.000us | 0.00% | 0.000us | 0.000us | 88.038ms | 12.14% | 88.038ms | 830.547us | 106 | | +| | | | | | | | | | | | torch/distributed/distributed_c10d.py(2864): all_gather_into_tensor | +| | | | | | | | | | | | torch/distributed/c10d_logger.py(72): wrapper | +| | | | | | | | | | | | torch/distributed/fsdp/\_flat_param.py(1366): \_all_gather_flat_param | +| | | | | | | | | | | | torch/distributed/fsdp/\_flat_param.py(1285): unshard | +| FullyShardedDataParallel.forward | 0.00% | 0.000us | 0.00% | 0.000us | 0.000us | 59.050ms | 8.14% | 59.050ms | 59.050ms | 1 | | +| | | | | | | | | | | | torch/nn/functional.py(2154): embedding | +| | | | | | | | | | | | torch/nn/modules/sparse.py(162): forward | +| | | | | | | | | | | | torch/nn/modules/module.py(1534): \_call_impl | +| | | | | | | | | | | | nn.Module: Embedding_0 | - `{model_name}-{train_type}-memory-timeline.html` - Stacked time series plot of memory use broken down by `Parameter`, `Gradients`, `Activations`, etc. - `{model_name}-{train_type}-stacks.txt` - Stack trace. See [docs](https://pytorch.org/docs/stable/profiler.html#torch.profiler._KinetoProfile.export_stacks). -Detailed `CLI` options: +## Detailed Usage + +`CLI` options in full: - `profile` - whether to profile - `profiling_outputs` - output directory for `torch.profiler` artifacts @@ -48,19 +68,19 @@ Detailed `CLI` options: **Note**: Simplest to think of 2 ways of scheduling the profiler: 1. Set `repeat` to the number of total number of desired profiling cycles. For example if `wait=1`, `warmup=1`, `active=1`, and `repeat=1`, then the profiler will wait for 1 step, warmup for 1, and record for 1 then stop. - 2. Set `repeat` to `0` and `profiling_frequency` to the cycle length. E.g., with `repeat=0`, `profiling_frequency=10`, `warmup=2`, `active=1`, then `wait` will be automatically set to `profiling_frequency - (warmup + active) = 7`. The profiler will then continuously execute the following cycle: wait for 7 steps, warmup for 2, record for 1. + 2. Set `repeat` to `0` and `profiling_frequency` to the cycle length. E.g., with `repeat=0`, `profiling_frequency=10`, `warmup=2`, `active=1`, then `wait` will be automatically set to `profiling_frequency - (warmup + active) = 7`. The profiler will then continuously execute the following cycle: wait for 7 steps, warmup for 2, record for 1 for the entire training run. See [docs](https://pytorch.org/docs/stable/profiler.html#torch.profiler.schedule) for further details. - `max_steps` - maximum number of batches per epoch. E.g., with `num_epochs=1`, stops training after `max_steps` of batches. Note that this is automatically adjusted to accommodate the profiler schedule; for example, if `max_steps < wait_steps + warmup_steps + active_steps`, it will automatically be set to `wait_steps + warmup_steps + active_steps` such that the profiler can run for at least 1 cycle. -#### Additional Notes +## Additional Notes -The default schedule for the profiler is set such that to continuously execute a 10-step cycle: wait for 7, warmup for 2, record for 1. +The default schedule for the profiler is set to continuously execute a 10-step cycle: wait for 7, warmup for 2, record for 1. `with_stack` and `with_shapes` are overridden by `export_memory_timeline` since the memory profile requires these options to be `True`. -#### Examples +## Examples - Record every 5th step, exporting a `chrome` / `tensorboard` trace for each cycle: @@ -118,9 +138,9 @@ The default schedule for the profiler is set such that to continuously execute a --with_stack true \ --num_epochs 1 \ --max_steps 20 \ - --repeat 1 \ --warmup_steps 1 \ --active_steps 5 \ + --repeat 1 \ --profiling_output llama-test2 ``` The output will be a single trace at `iteration_6` which contains 5 training steps. From 1b3f1f547fb6e65241c8e9ccdd346409de8a73f4 Mon Sep 17 00:00:00 2001 From: jeromeku Date: Tue, 4 Jun 2024 19:13:21 +0000 Subject: [PATCH 15/16] clean up profile.sh --- profile.sh | 142 ++++++++++++++++++----------------------------------- 1 file changed, 48 insertions(+), 94 deletions(-) diff --git a/profile.sh b/profile.sh index b2b1990..80264bf 100755 --- a/profile.sh +++ b/profile.sh @@ -1,97 +1,51 @@ -# Running below will result in a directory `Llama-2-7b_qlora-{local_rank}` with the following artifacts: -# - Llama-2-7b_qlora-chrome-trace.json.gz - interactive trace that can be viewed using `chrome::tracing` or `perfetto` -# - Llama-2-7b_qlora-key_averages.txt - sorted table of events, e.g.: -# | Name | Self CPU % | Self CPU | CPU total % | CPU total | CPU time avg | Self CUDA | Self CUDA % | CUDA total | CUDA time avg | CPU Mem | Self CPU Mem | CUDA Mem | Self CUDA Mem | # of Calls | Source Location | -# |---------------------------------------|------------|------------|-------------|------------|--------------|-------------|-------------|-------------|---------------|---------|--------------|----------|---------------|------------|----------------------------------------------------------------------------------| -# | ProfilerStep* | 0.00% | 0.000us | 0.00% | 0.000us | 0.000us | 4.816s | 44.60% | 4.816s | 963.233ms | 0 b | 0 b | 0 b | 0 b | 5 | | -# | | | | | | | | | | | | | | | | train.py(962): fsdp_main | -# | | | | | | | | | | | | | | | | torch/multiprocessing/spawn.py(75): _wrap | -# | | | | | | | | | | | | | | | | multiprocessing/process.py(108): run | -# | | | | | | | | | | | | | | | | multiprocessing/process.py(314): _bootstrap | -# | FullyShardedDataParallel.forward | 0.00% | 0.000us | 0.00% | 0.000us | 0.000us | 2.208s | 20.45% | 2.208s | 441.555ms | 0 b | 0 b | 0 b | 0 b | 5 | | -# | | | | | | | | | | | | | | | | torch/nn/functional.py(2154): embedding | -# | | | | | | | | | | | | | | | | torch/nn/modules/sparse.py(162): forward | -# | | | | | | | | | | | | | | | | torch/nn/modules/module.py(1534): _call_impl | -# | | | | | | | | | | | | | | | | nn.Module: Embedding_0 | -# | aten::mm | 0.44% | 31.314ms | 0.69% | 48.739ms | 43.517us | 332.421ms | 3.08% | 337.208ms | 301.079us | 0 b | 0 b | 3.26 Gb | 3.26 Gb | 1120 | | -# | | | | | | | | | | | | | | | | bitsandbytes/autograd/_functions.py(492): forward | -# | | | | | | | | | | | | | | | | | -# | | | | | | | | | | | | | | | | torch/autograd/function.py(582): apply | -# | | | | | | | | | | | | | | | | bitsandbytes/autograd/_functions.py(559): matmul_4bit | -# | MatMul4Bit | 2.81% | 198.511ms | 4.93% | 347.437ms | 310.212us | 284.169ms | 2.63% | 630.417ms | 562.872us | 0 b | 0 b | 3.26 Gb | -62.31 Gb | 1120 | | -# | | | | | | | | | | | | | | | | torch/autograd/function.py(582): apply | -# | | | | | | | | | | | | | | | | bitsandbytes/autograd/_functions.py(559): matmul_4bit | -# | | | | | | | | | | | | | | | | bitsandbytes/nn/modules.py(442): forward | -# | | | | | | | | | | | | | | | | torch/nn/modules/module.py(1534): _call_impl | +#See PROFILING.md for documentation -# - Llama-2-7b_qlora-memory-timeline.html - Stacked time series plot of memory use broken down by `Parameter`, `Gradients`, `Activations`, etc. -# - Llama-2-7b_qlora-stacks.txt - Stack trace. See [docs](https://pytorch.org/docs/stable/profiler.html#torch.profiler._KinetoProfile.export_stacks). - -# Detailed `CLI` options: -# - `profile` - whether to profile -# - `profiling_outputs` - output directory for `torch.profiler` artifacts -# - `export_trace` - enables exporting of interactive trace that can be viewed and analyzed using `chrome::tracing` -# - `export_memory_timeline` - exports an HTML memory timeline which shows memory use by category (`parameters`, `activations`, `gradients`, etc.) -# - `with_stack` - exports stack trace -# - `with_shapes` - adds shapes of operators to the trace -# - `{wait, warmup, active}_steps` - controls how many profiling steps are recorded: -# - `wait_steps` - number of steps for the profiler to wait before starting to profile -# - `warmup_steps` - number of steps for profiler to profile without recording -# - `active_steps` - number of steps to record -# See [docs](https://pytorch.org/docs/stable/profiler.html#torch.profiler.schedule) for further details. - -# The default schedule for the profiler is set such that only 2 steps of the each epoch are recorded (not counting `wait` and `warmup` steps which are not recorded). - -# Note that `with_stack` and `with_shapes` are overridden by `export_memory_timeline` since the memory profile requires these options to be `True`. - -#**IMPORTANT** There are issues with recording stack traces and exporting traces simultaneously (see this [issue](https://github.com/pytorch/pytorch/issues/113564)) depending on `python` version. The only combination I was able to get both to work at the same time was with `python=3.11.9` and `torch=2.3.0`. -#Tested on `python=3.11.9 and torch=2.3.0`` - -#"meta-llama/Llama-2-7b-hf" +# Run profiler contiguously on a 5-step cycle: 4 warmup steps and 1 active (recording) step. +python train.py \ +--model_name "hf-internal-testing/tiny-random-LlamaForCausalLM" \ +--gradient_accumulation_steps 2 \ +--batch_size 1 \ +--context_length 256 \ +--num_epochs 1 \ +--sharding_strategy full_shard \ +--precision bf16 \ +--train_type qlora \ +--use_gradient_checkpointing false \ +--use_cpu_offload false \ +--log_to stdout \ +--dataset dummy \ +--profile true \ +--export_trace true \ +--export_memory_timeline false \ +--with_stack true \ +--max_steps 20 \ +--repeat 0 \ +--warmup_steps 4 \ +--active_steps 1 \ +--profiling_frequency 5 \ +--profiling_output llama-test +# Run for 1 cycle then stop profiling # python train.py \ -# --model_name "hf-internal-testing/tiny-random-LlamaForCausalLM" \ -# --gradient_accumulation_steps 2 \ -# --batch_size 1 \ -# --context_length 256 \ -# --num_epochs 1 \ -# --sharding_strategy full_shard \ -# --precision bf16 \ -# --train_type qlora \ -# --use_gradient_checkpointing false \ -# --use_cpu_offload false \ -# --log_to stdout \ -# --dataset dummy \ -# --profile true \ -# --export_trace true \ -# --export_memory_timeline false \ -# --with_stack true \ -# --max_steps 20 \ -# --repeat 0 \ -# --warmup_steps 4 \ -# --active_steps 1 \ -# --profiling_frequency 5 \ -# --profiling_output llama-test -python train.py \ - --model_name "hf-internal-testing/tiny-random-LlamaForCausalLM" \ - --gradient_accumulation_steps 2 \ - --batch_size 1 \ - --context_length 256 \ - --num_epochs 1 \ - --sharding_strategy full_shard \ - --precision bf16 \ - --train_type qlora \ - --use_gradient_checkpointing false \ - --use_cpu_offload false \ - --log_to stdout \ - --dataset dummy \ - --profile true \ - --export_trace true \ - --export_memory_timeline true \ - --with_stack true \ - --num_epochs 1 \ - --max_steps 20 \ - --repeat 1 \ - --warmup_steps 1 \ - --active_steps 4 \ - --profiling_output llama-test2 \ No newline at end of file +# --model_name "hf-internal-testing/tiny-random-LlamaForCausalLM" \ +# --gradient_accumulation_steps 2 \ +# --batch_size 1 \ +# --context_length 256 \ +# --num_epochs 1 \ +# --sharding_strategy full_shard \ +# --precision bf16 \ +# --train_type qlora \ +# --use_gradient_checkpointing false \ +# --use_cpu_offload false \ +# --log_to stdout \ +# --dataset dummy \ +# --profile true \ +# --export_trace true \ +# --export_memory_timeline true \ +# --with_stack true \ +# --num_epochs 1 \ +# --max_steps 20 \ +# --repeat 1 \ +# --warmup_steps 1 \ +# --active_steps 4 \ +# --profiling_output llama-test2 \ No newline at end of file From cf04925e16b7176b6a627551f30c634f726c3389 Mon Sep 17 00:00:00 2001 From: jeromeku Date: Tue, 4 Jun 2024 19:17:39 +0000 Subject: [PATCH 16/16] minor edits --- train.py | 101 ++++++++++++++++++++++++++++++++++--------------------- 1 file changed, 63 insertions(+), 38 deletions(-) diff --git a/train.py b/train.py index 12d6451..6d303b4 100644 --- a/train.py +++ b/train.py @@ -15,63 +15,84 @@ # Imports # General -import torch, os, gc, time, safetensors, copy, math, types, sys +import copy import functools -import torch.optim as optim -from torch.optim.lr_scheduler import LambdaLR -from transformers.optimization import get_linear_schedule_with_warmup +import gc +import math +import os +import sys +import time +import types +from contextlib import nullcontext +from glob import glob +from pathlib import Path +from typing import Dict, List + import bitsandbytes as bnb +import safetensors +import torch import torch.distributed as dist import torch.multiprocessing as mp -from torch.profiler import profile, record_function, ProfilerActivity -from contextlib import nullcontext -from safetensors.torch import save_file -from tqdm.auto import tqdm -from typing import List, Dict -from pathlib import Path -from glob import glob -from packaging.version import parse +import torch.optim as optim +from accelerate import init_empty_weights +from accelerate.utils import set_seed + +# Model loading +from bitsandbytes.nn import Linear4bit, Params4bit +from fastcore.parallel import parallel # Argument parsing -from fastcore.script import call_parse, bool_arg, Param +from fastcore.script import Param, bool_arg, call_parse +from packaging.version import parse +from safetensors.torch import save_file # Torch + distributed training -from torch import nn, Tensor -from torch.nn.utils.rnn import pad_sequence -from torch.utils.data import Dataset, DataLoader, DistributedSampler - -# FSDP -from torch.distributed.fsdp import MixedPrecision, FullyShardedDataParallel as FSDP -from torch.distributed.fsdp.wrap import _or_policy, lambda_auto_wrap_policy, transformer_auto_wrap_policy -from torch.distributed.fsdp.api import BackwardPrefetch, CPUOffload, ShardingStrategy -from torch.distributed.fsdp.sharded_grad_scaler import ShardedGradScaler -from torch.distributed.fsdp import StateDictType, FullStateDictConfig +from torch import Tensor, nn from torch.distributed.algorithms._checkpoint.checkpoint_wrapper import ( - checkpoint_wrapper, - offload_wrapper, CheckpointImpl, apply_activation_checkpointing, + checkpoint_wrapper, + offload_wrapper, ) -# Model loading -from bitsandbytes.nn import Linear4bit, Params4bit -from accelerate import init_empty_weights -from accelerate.utils import set_seed -from transformers.utils import hub, SAFE_WEIGHTS_NAME, SAFE_WEIGHTS_INDEX_NAME -from transformers import AutoTokenizer, AutoModelForCausalLM, AutoConfig +# FSDP +from torch.distributed.fsdp import FullStateDictConfig, MixedPrecision, StateDictType +from torch.distributed.fsdp import FullyShardedDataParallel as FSDP +from torch.distributed.fsdp.api import BackwardPrefetch, CPUOffload, ShardingStrategy +from torch.distributed.fsdp.sharded_grad_scaler import ShardedGradScaler +from torch.distributed.fsdp.wrap import ( + _or_policy, + lambda_auto_wrap_policy, + transformer_auto_wrap_policy, +) +from torch.nn.utils.rnn import pad_sequence +from torch.optim.lr_scheduler import LambdaLR +from torch.profiler import ProfilerActivity, profile, record_function +from torch.utils.data import DataLoader, Dataset, DistributedSampler +from tqdm.auto import tqdm +from transformers import AutoConfig, AutoModelForCausalLM, AutoTokenizer +from transformers.optimization import get_linear_schedule_with_warmup from transformers.tokenization_utils_fast import PreTrainedTokenizerFast -from fastcore.parallel import parallel +from transformers.utils import SAFE_WEIGHTS_INDEX_NAME, SAFE_WEIGHTS_NAME, hub try: - from hqq.core.quantize import HQQLinear, HQQBackend, BaseQuantizeConfig + from hqq.core.quantize import BaseQuantizeConfig, HQQBackend, HQQLinear except ImportError: HQQLinear = None pass # To add a new model, import the transformer, attention, & MLP layers # for the wrapping policy and `check_fn` in activation checkpointing -from transformers.models.llama.modeling_llama import LlamaDecoderLayer, LLAMA_ATTENTION_CLASSES, LlamaMLP -from transformers.models.mistral.modeling_mistral import MistralDecoderLayer, MISTRAL_ATTENTION_CLASSES, MistralMLP +from transformers.models.llama.modeling_llama import ( + LLAMA_ATTENTION_CLASSES, + LlamaDecoderLayer, + LlamaMLP, +) +from transformers.models.mistral.modeling_mistral import ( + MISTRAL_ATTENTION_CLASSES, + MistralDecoderLayer, + MistralMLP, +) # To get rid of tokenizers warnings for now os.environ["TOKENIZERS_PARALLELISM"] = "false" @@ -84,10 +105,12 @@ # LoRA and DORA modules sys.path.append("./scripts") -from lora import LORA from dora import BNBDORA, HQQDORA, DORALayer, MagnitudeLayer +from lora import LORA + from profiling_utils import profiling_context + class Logger: def __init__(self, args, log_to="stdout", project_name="fsdp_qlora", entity=None, group=None, name=None, rank=0): # self.log_every_n_steps = log_every_n_steps TODO: add this back as an option @@ -443,7 +466,7 @@ def get_optimizer(model:nn.Module, args:Dict): # Wrap the model using LoRA policy from llama-recipes or custom policy: # This checks for lora layers (has weight and requires_grad) def get_wrapping_policy(custom_policy:bool=False, vanilla_policy:bool=False): - from peft.tuners import PromptEncoder, PromptEmbedding, PrefixEncoder + from peft.tuners import PrefixEncoder, PromptEmbedding, PromptEncoder if custom_policy: def lambda_policy_fn(module): @@ -644,7 +667,7 @@ def load_and_quantize_parallel(name_param, model, **kwargs): # PEFT setup (LoRA and QLoRA) if args["train_type"] in ["lora", "qlora"]: - from peft import get_peft_model, LoraConfig, TaskType + from peft import LoraConfig, TaskType, get_peft_model peft_config = LoraConfig( task_type=TaskType.CAUSAL_LM, inference_mode=False, @@ -926,6 +949,7 @@ def load_and_quantize_parallel(name_param, model, **kwargs): prof.step() + #Primarily for debugging if args["max_steps"] > 0 and batch_idx > args["max_steps"]: if rank == 0: print("Max steps reached, skipping rest of epoch") @@ -1053,6 +1077,7 @@ def fsdp_qlora( group: str = None, # For wandb logging entity: str = None, # For wandb logging n_bits: int = 4, # passed to hqq + #Profiling args profile: bool_arg = False, # Whether to profile with torch.profiler profiling_output: str = "profiles", # Output file prefix for profiling overwrite_profiling_output: bool = True, # Overwrite output directory