[FEATURE] Profiling Improvements #67

Merged 17 commits on Jun 11, 2024
90 changes: 64 additions & 26 deletions profile.sh
@@ -1,28 +1,66 @@
# This script uses export_stacks() to generate profiling output
# /tmp/profile_0.txt, /tmp/profile_1.txt, etc. (1 file per process.)
#
# Output files are generated using export_stacks(); note there are some
# outstanding issues to be aware of:
# https://github.com/pytorch/pytorch/issues/100253
#
# Profiling output files can be used with speedscope or other tools.
#
# For additional information, see:
# https://pytorch.org/tutorials/recipes/recipes/profiler_recipe.html
# Running the command below will produce a directory `Llama-2-7b_qlora-{local_rank}` with the following artifacts:
# - Llama-2-7b_qlora-chrome-trace.json.gz - interactive trace that can be viewed using `chrome://tracing` or `perfetto`
# - Llama-2-7b_qlora-key_averages.txt - sorted table of events, e.g.:
# | Name | Self CPU % | Self CPU | CPU total % | CPU total | CPU time avg | Self CUDA | Self CUDA % | CUDA total | CUDA time avg | CPU Mem | Self CPU Mem | CUDA Mem | Self CUDA Mem | # of Calls | Source Location |
# |---------------------------------------|------------|------------|-------------|------------|--------------|-------------|-------------|-------------|---------------|---------|--------------|----------|---------------|------------|----------------------------------------------------------------------------------|
# | ProfilerStep* | 0.00% | 0.000us | 0.00% | 0.000us | 0.000us | 4.816s | 44.60% | 4.816s | 963.233ms | 0 b | 0 b | 0 b | 0 b | 5 | <built-in method to of Tensor object at 0x7f20bf709310> |
# | | | | | | | | | | | | | | | | train.py(962): fsdp_main |
# | | | | | | | | | | | | | | | | torch/multiprocessing/spawn.py(75): _wrap |
# | | | | | | | | | | | | | | | | multiprocessing/process.py(108): run |
# | | | | | | | | | | | | | | | | multiprocessing/process.py(314): _bootstrap |
# | FullyShardedDataParallel.forward | 0.00% | 0.000us | 0.00% | 0.000us | 0.000us | 2.208s | 20.45% | 2.208s | 441.555ms | 0 b | 0 b | 0 b | 0 b | 5 | <built-in method embedding of type object at 0x7f21e21797c0> |
# | | | | | | | | | | | | | | | | torch/nn/functional.py(2154): embedding |
# | | | | | | | | | | | | | | | | torch/nn/modules/sparse.py(162): forward |
# | | | | | | | | | | | | | | | | torch/nn/modules/module.py(1534): _call_impl |
# | | | | | | | | | | | | | | | | nn.Module: Embedding_0 |
# | aten::mm | 0.44% | 31.314ms | 0.69% | 48.739ms | 43.517us | 332.421ms | 3.08% | 337.208ms | 301.079us | 0 b | 0 b | 3.26 Gb | 3.26 Gb | 1120 | <built-in function linear> |
# | | | | | | | | | | | | | | | | bitsandbytes/autograd/_functions.py(492): forward |
# | | | | | | | | | | | | | | | | <built-in method apply of FunctionMeta object at 0x827a410> |
# | | | | | | | | | | | | | | | | torch/autograd/function.py(582): apply |
# | | | | | | | | | | | | | | | | bitsandbytes/autograd/_functions.py(559): matmul_4bit |
# | MatMul4Bit | 2.81% | 198.511ms | 4.93% | 347.437ms | 310.212us | 284.169ms | 2.63% | 630.417ms | 562.872us | 0 b | 0 b | 3.26 Gb | -62.31 Gb | 1120 | <built-in method apply of FunctionMeta object at 0x827a410> |
# | | | | | | | | | | | | | | | | torch/autograd/function.py(582): apply |
# | | | | | | | | | | | | | | | | bitsandbytes/autograd/_functions.py(559): matmul_4bit |
# | | | | | | | | | | | | | | | | bitsandbytes/nn/modules.py(442): forward |
# | | | | | | | | | | | | | | | | torch/nn/modules/module.py(1534): _call_impl |

python train.py \
--model_name meta-llama/Meta-Llama-3-8B \
--train_type hqq_dora \
--n_bits 4 \
--precision bf16 \
--dataset orca_math \
--dataset_samples 8 \
--batch_size 2 \
--context_length 512 \
--gradient_accumulation_steps 2 \
--use_gradient_checkpointing False \
--use_cpu_offload False \
--use_activation_cpu_offload False \
--save_model False \
--profiling_output /tmp/profile
# - Llama-2-7b_qlora-memory-timeline.html - Stacked time series plot of memory use broken down by `Parameter`, `Gradients`, `Activations`, etc.
# - Llama-2-7b_qlora-stacks.txt - Stack trace. See [docs](https://pytorch.org/docs/stable/profiler.html#torch.profiler._KinetoProfile.export_stacks).

# Detailed `CLI` options:
# - `profile` - whether to profile
# - `profiling_output` - output directory for `torch.profiler` artifacts
# - `export_trace` - enables export of an interactive trace that can be viewed and analyzed using `chrome://tracing`
# - `export_memory_timeline` - exports an HTML memory timeline which shows memory use by category (`parameters`, `activations`, `gradients`, etc.)
# - `with_stack` - exports stack trace
# - `with_shapes` - adds shapes of operators to the trace
# - `{wait, warmup, active}_steps` - controls how many profiling steps are recorded:
# - `wait_steps` - number of steps for the profiler to wait before starting to profile
# - `warmup_steps` - number of steps for the profiler to run without recording
# - `active_steps` - number of steps to record
# See [docs](https://pytorch.org/docs/stable/profiler.html#torch.profiler.schedule) for further details.

# The default profiler schedule is set such that only 2 steps of each epoch are recorded (`wait` and `warmup` steps are not counted).

# Note that `with_stack` and `with_shapes` are overridden by `export_memory_timeline` since the memory profile requires these options to be `True`.

# **IMPORTANT**: There are issues with recording stack traces and exporting traces simultaneously (see this [issue](https://github.com/pytorch/pytorch/issues/113564)), depending on the `python` version. The only combination with which I was able to get both working at the same time was `python=3.11.9` and `torch=2.3.0`.
# Tested on `python=3.11.9` and `torch=2.3.0`.

python train.py \
--model_name "meta-llama/Llama-2-7b-hf" \
--gradient_accumulation_steps 2 \
--batch_size 1 \
--context_length 256 \
--num_epochs 1 \
--sharding_strategy full_shard \
--precision bf16 \
--train_type qlora \
--use_gradient_checkpointing false \
--use_cpu_offload false \
--log_to stdout \
--dataset dummy \
--profile true \
--export_trace true \
--export_memory_timeline true \
--max_steps 10
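For reference, the `{wait, warmup, active}_steps` and `repeat` options above map directly onto `torch.profiler.schedule`. A minimal sketch (not part of the diff) of how the default schedule from `main()` (`wait=1, warmup=1, active=2, repeat=1`) plays out:

```python
import torch
from torch.profiler import profile, schedule, ProfilerActivity

# wait=1: step 0 is skipped; warmup=1: step 1 runs the profiler without recording;
# active=2: steps 2-3 are recorded, then on_trace_ready (if set) fires.
prof_schedule = schedule(wait=1, warmup=1, active=2, repeat=1)

with profile(
    activities=[ProfilerActivity.CPU],  # add ProfilerActivity.CUDA when running on GPU
    schedule=prof_schedule,
) as prof:
    for step in range(4):
        torch.randn(256, 256) @ torch.randn(256, 256)  # stand-in for one training step
        prof.step()  # advances the schedule; only "active" steps end up in the profile
```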
62 changes: 62 additions & 0 deletions profiling_tools.py
@@ -0,0 +1,62 @@
import torch
import os
from datetime import datetime
from functools import partial
def trace_handler(
prof: torch.profiler.profile,
rank: int,
export_trace=True,
export_memory_timeline=False,
with_stack: bool = True,
group_by_stack: int = 0,
group_by_input_shapes: bool = False,
prefix="",
out_dir="./profiles",
time_fmt_str: str = "%m_%d_%H",
metric="self_cuda_time_total",
row_limit=25,
verbose=False,
):
# Prefix for file names.
timestamp = datetime.now().strftime(time_fmt_str)
file_prefix = os.path.join(out_dir, f"{prefix}-{timestamp}")
if not os.path.exists(out_dir):
os.makedirs(out_dir)

# Construct the trace file.
if export_trace:
prof.export_chrome_trace(f"{file_prefix}-chrome-trace.json.gz")

# Construct the memory timeline file.
if export_memory_timeline:
prof.export_memory_timeline(
f"{file_prefix}-memory-timeline.html"
)

if with_stack:
prof.export_stacks(f"{file_prefix}-stacks.txt", metric=metric)

key_avgs = prof.key_averages(
group_by_input_shape=group_by_input_shapes, group_by_stack_n=group_by_stack
).table(sort_by=metric, row_limit=row_limit)
with open(f"{file_prefix}-key_averages.txt", "w") as f:
print(
key_avgs, file=f
)
if rank == 0:
print(f"Saving profiling results to {out_dir}")
if verbose:
print(key_avgs)

class FakeContext:
"""
No-op stand-in for the profiler context when profiling is disabled, so the
training loop can call `prof.step()` unconditionally.
"""
def __enter__(self):
return self
def __exit__(self, *args, **kwargs):
pass

def step(self):
pass
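A minimal sketch (mirroring the wiring in `train.py` below) of how `trace_handler` and `FakeContext` are meant to be used: the handler is bound with `functools.partial` and passed as `on_trace_ready`, and `FakeContext` stands in when profiling is disabled so the loop can call `prof.step()` unconditionally. The flag, paths, and step counts here are illustrative only.

```python
import functools
import torch
from torch.profiler import profile, schedule, ProfilerActivity
from profiling_tools import trace_handler, FakeContext

profiling_enabled = True  # illustrative flag

if profiling_enabled:
    # on_trace_ready receives the profiler instance; the remaining args are bound here.
    # metric is switched to a CPU metric since this sketch profiles CPU only.
    callback = functools.partial(
        trace_handler, rank=0, prefix="demo", out_dir="./profiles",
        metric="self_cpu_time_total",
    )
    ctx = profile(
        activities=[ProfilerActivity.CPU],  # add ProfilerActivity.CUDA on GPU
        with_stack=True,
        schedule=schedule(wait=1, warmup=1, active=2, repeat=1),
        on_trace_ready=callback,
    )
else:
    ctx = FakeContext()  # no-op: __enter__/__exit__/step() all do nothing

with ctx as prof:
    for step in range(4):
        torch.randn(128, 128) @ torch.randn(128, 128)  # stand-in for one training step
        prof.step()
```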
92 changes: 78 additions & 14 deletions train.py
@@ -86,6 +86,7 @@
sys.path.append("./scripts")
from lora import LORA
from dora import BNBDORA, HQQDORA, DORALayer, MagnitudeLayer
from profiling_tools import trace_handler, FakeContext

class Logger:
def __init__(self, args, log_to="stdout", project_name="fsdp_qlora", entity=None, group=None, name=None, rank=0):
@@ -481,14 +482,44 @@ def mlp_policy_fn(module):

# Main function, run on each process
def fsdp_main(local_rank:int, world_size:int, args:Dict):
profiler_context = profile(
activities=[ProfilerActivity.CPU, ProfilerActivity.CUDA],
with_stack=True,
use_cuda=True,
record_shapes=False,
experimental_config=torch._C._profiler._ExperimentalConfig(verbose=True) # https://github.com/pytorch/pytorch/issues/100253
) if args["profiling_output"] else nullcontext()

#Profiler args
if args["profile"]:
#Profiler args
profile_memory = args["export_memory_timeline"]
export_trace = args["export_trace"]
export_memory_timeline = args["export_memory_timeline"]
with_stack = args["with_stack"] or args["export_memory_timeline"]
with_shapes = args["with_shapes"] or export_memory_timeline
model_name = args["model_name"].split("/")[-1]
train_type = args["train_type"]
prefix = f"{model_name}_{train_type}"
output_dir = args["profiling_output"] if args["profiling_output"] else f"./{model_name}_{train_type}-{local_rank}"
schedule = torch.profiler.schedule(wait=args["wait_steps"], warmup=args["warmup_steps"], active=args["active_steps"], repeat=args["repeat"])
callback = functools.partial(trace_handler,
export_trace=export_trace,
export_memory_timeline=export_memory_timeline,
with_stack=with_stack,
group_by_stack=5 if with_stack else 0,
prefix=prefix,
out_dir=output_dir,
rank=local_rank,
verbose=args["verbose"])
#Instantiate profiler
profiler_context = profile(
activities=[ProfilerActivity.CPU, ProfilerActivity.CUDA],
with_stack=with_stack,
profile_memory=profile_memory,
use_cuda=True,
record_shapes=with_shapes,
schedule=schedule,
on_trace_ready=callback,
experimental_config=torch._C._profiler._ExperimentalConfig(verbose=True) if with_stack else None, # https://github.com/pytorch/pytorch/issues/100253
)
if args["max_steps"] > 0:
#Ensure enough steps to accommodate profiler schedule
args["max_steps"] = max(args["max_steps"], args["wait_steps"] + args["warmup_steps"] + args["active_steps"])
else:
profiler_context = FakeContext()
with profiler_context as prof:
# Setup and initialize the process group
os.environ['MASTER_ADDR'] = args["master_addr"]
@@ -826,6 +857,7 @@ def load_and_quantize_parallel(name_param, model, **kwargs):
ddp_loss = torch.zeros(2).to(local_rank)

for batch_idx, batch in enumerate(dataloader):

accumulate_grads = (batch_idx+1) % gradient_accumulation_steps == 0

# Prevent gradient syncing until update step if using no_sync option.
@@ -925,6 +957,17 @@ def load_and_quantize_parallel(name_param, model, **kwargs):
if args["log_to"] == 'wandb':
logger.log({"loss": log_loss, "lr": log_lr}, rank)
ddp_loss = torch.zeros(2).to(local_rank)

if rank == 0:
print(f"Batch idx {batch_idx}")


prof.step()

if args["max_steps"] > 0 and batch_idx > args["max_steps"]:
if rank == 0:
print("Max steps reached, skipping rest of epoch")
break

# Print + log peak memory usage for the whole fourth step of training
if epoch == 0 and (rank == 0 or args['verbose']):
@@ -999,9 +1042,9 @@ def load_and_quantize_parallel(name_param, model, **kwargs):

# Clean up
dist.destroy_process_group()
if args["profiling_output"]:
prof.export_stacks(path = f"{args['profiling_output']}_{local_rank}.txt",
metric = "self_cuda_time_total")
# if args["profiling_output"]:
Contributor (review comment): Can just delete this if it's not needed anymore.

# prof.export_stacks(path = f"{args['profiling_output']}_{local_rank}.txt",
# metric = "self_cuda_time_total")

def validate_args(args):
if args["n_bits"] != 4 and args["train_type"] not in ["hqq_lora", "hqq_dora", "hqq_llama_pro"]:
@@ -1051,12 +1094,23 @@ def fsdp_qlora(
group: str = None, # For wandb logging
entity: str = None, # For wandb logging
n_bits: int = 4, # passed to hqq
profiling_output: str = None, # Output file for profiling
):
profile: bool = False, # Enable PyTorch profiler
profiling_output: str = None, # Output directory for torch.profiler artifacts
with_stack: bool = False, # Output stacks for profiling
with_shapes: bool = False, # Output shapes for profiling
export_trace: bool = True, # Output trace for profiling
export_memory_timeline: bool = False, # Output memory timeline for profiling
wait_steps: int = 1, # Wait steps when running profiler
warmup_steps: int = 0, # Warmup steps when running profiler
active_steps: int = 5, # Active steps when running profiler
repeat: int = 1, #Number of profiler cycles (wait + warmup + active)
max_steps: int = -1, # Max number of training steps (in units of batches) per epoch. -1 means no limit; otherwise the training loop breaks after `max_steps` batches each epoch.
):
"""
Train a model with FSDP and QLoRA/QDoRA.

Args:

world_size: Number of GPUs to use. -1 = all available GPUs.
train_type: "full", "lora", "qlora", or "custom_qlora"
llama_pro_path: Path to the quantized llama pro model
@@ -1186,6 +1240,16 @@ def main(
group: str = None, # For wandb logging
entity: str = None, # For wandb logging
n_bits: int = 4, # passed to hqq
profiling_output: str = "", # Output file prefix for profiling
profile: bool_arg = False, # Whether to profile with torch.profiler
profiling_output: str = "profiles", # Output file prefix for profiling
with_stack: bool_arg = False, # Output stacks for profiling. Note that setting `export_memory_timeline` will automatically export stack traces since `with_stack` must be true to profile memory.
Contributor (review comment): If we weren't refactoring, it would probably be good to make profiler-specific args easily identifiable (e.g. with a `prof_` prefix). For now we can hold off until the refactoring to address this question of how to organize the arg list, though.

Contributor Author (review comment): OK if we leave as is? All the profiling args are demarcated from the rest of the CLI args and are off by default. Also, the documentation clearly lays out the meaning of each of these args.

Contributor (review comment): Yes, fine to leave as is.

with_shapes: bool_arg = False, # Output shapes for profiling. Note that setting `export_memory_timeline` will automatically record shapes since `with_shapes` must be true to profile memory.
export_trace: bool_arg = True, # Output trace for profiling
export_memory_timeline: bool_arg = False, # Output memory timeline for profiling
wait_steps: int = 1, # Wait steps when running profiler
warmup_steps: int = 1, # Warmup steps when running profiler
active_steps: int = 2, # Active steps when running profiler
repeat: int = 1, #Number of profiler cycles (wait + warmup + active)
max_steps: int = -1, # Max number of training steps (in units of batches) per epoch. -1 means no limit; otherwise the training loop breaks after `max_steps` batches each epoch.
):
fsdp_qlora(**locals())
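As a usage note, the same profiling flags can also be passed programmatically, since `main()` simply forwards its arguments to `fsdp_qlora(**locals())`. A hedged sketch (argument names taken from the CLI command in `profile.sh` above; whether a direct call fits your environment depends on the usual FSDP/multiprocessing requirements):

```python
# Hedged sketch: invoking training with profiling enabled from Python rather than the CLI.
from train import fsdp_qlora

fsdp_qlora(
    model_name="meta-llama/Llama-2-7b-hf",
    train_type="qlora",
    precision="bf16",
    dataset="dummy",
    batch_size=1,
    context_length=256,
    gradient_accumulation_steps=2,
    use_gradient_checkpointing=False,
    use_cpu_offload=False,
    log_to="stdout",
    profile=True,                  # enable torch.profiler
    export_trace=True,             # interactive chrome trace
    export_memory_timeline=True,   # HTML memory timeline
    profiling_output="./profiles",
    max_steps=10,                  # keep the run short, as in profile.sh
)
```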