Commit

initial commit
mreso committed Jan 22, 2025
1 parent 39a40b2 commit c052e42
Showing 4 changed files with 212 additions and 30 deletions.
13 changes: 12 additions & 1 deletion torchft/checkpointing.py
@@ -21,6 +21,7 @@
from typing import Callable, Generic, TypeVar

import torch
import torch.distributed as dist

from torchft.http import _IPv6HTTPServer

@@ -76,6 +77,14 @@ def do_GET(self):

sd = state_dict()

def func(obj):
if isinstance(obj, dist.tensor.DTensor) and hasattr(obj, "device_mesh") and hasattr(obj.device_mesh, "replicate_pg"):
obj.device_mesh.replicate_pg = None

from torch.utils._pytree import tree_map

tree_map(func, sd["user"])

torch.save(sd, self.wfile)
except Exception as e:
logger.exception(
@@ -113,7 +122,9 @@ def load_from_address(cls, address: str, timeout: timedelta) -> T:
data = f.read()

reader = io.BytesIO(data)
return torch.load(reader, weights_only=True)
print(f"{reader.read(100)=}")
reader.seek(0)
return torch.load(reader, weights_only=False)

def address(self) -> str:
"""
9 changes: 9 additions & 0 deletions torchft/manager.py
@@ -487,6 +487,15 @@ def _async_quorum(
self._pending_state_dict = CheckpointServer.load_from_address(
checkpoint_server_address, timeout=self._timeout
)

def func(obj):
if isinstance(obj, dist.tensor.DTensor) and hasattr(obj, "device_mesh") and hasattr(obj.device_mesh, "replicate_pg"):
obj.device_mesh.replicate_pg = self._pg

from torch.utils._pytree import tree_map

tree_map(func, self._pending_state_dict["user"])

self.load_state_dict(self._pending_state_dict["torchft"])
# we apply the user state dict only when safe from the main thread

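Taken together, the checkpointing.py and manager.py hunks implement a strip-and-restore pattern around checkpoint transfer: before torch.save, the sender walks the user state dict with tree_map and clears the replicate_pg reference on each DTensor's device mesh (process group handles are not picklable), and after loading, the receiver walks the same structure and re-attaches its own live process group. Because the serialized object now carries DTensor and device-mesh metadata rather than plain tensors, the loader also switches to weights_only=False. A minimal sketch of the shared traversal, assuming a state dict whose leaves may be DTensors; the helper name and the pg argument are illustrative, not part of the diff:

import torch.distributed as dist
from torch.utils._pytree import tree_map


def _swap_replicate_pg(sd_user, pg):
    # Pass pg=None before torch.save to drop the unpicklable process group
    # handle; pass the live group after torch.load to restore it.
    def visit(obj):
        mesh = getattr(obj, "device_mesh", None)
        if isinstance(obj, dist.tensor.DTensor) and hasattr(mesh, "replicate_pg"):
            mesh.replicate_pg = pg
        return obj

    tree_map(visit, sd_user)

In the diff, the save side applies this with None inside do_GET, and the load side applies the equivalent with self._pg in _async_quorum before the torchft state dict is loaded.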
62 changes: 33 additions & 29 deletions torchft/process_group.py
@@ -864,7 +864,7 @@ def __init__(
raise ValueError(
"ManagedDeviceMesh doesn't support both mesh and parent are None."
)
self.mesh = mesh
self._mesh = mesh
self.mesh_dim_names = mesh_dim_names
self.replicate_pg = replicate_pg
self.replicate_dim = replicate_dim
@@ -893,17 +893,17 @@ def __getitem__(self, mesh_dim_names: Union[str, Tuple[str, ...]]) -> DeviceMesh
elif mesh_dim_names in self.flatten_meshes:
return self.flatten_meshes[mesh_dim_names]
else:
assert self.mesh is not None
return self.mesh[mesh_dim_names]
assert self._mesh is not None
return self._mesh[mesh_dim_names]
else:
assert isinstance(mesh_dim_names, tuple)
if self.replicate_dim_name in mesh_dim_names:
assert self.mesh is not None
return self.mesh[mesh_dim_names]
assert self._mesh is not None
return self._mesh[mesh_dim_names]
else:
assert self.mesh is not None
assert self._mesh is not None
return ManagedDeviceMesh(
self.mesh[mesh_dim_names],
self._mesh[mesh_dim_names],
mesh_dim_names,
self.replicate_pg,
mesh_dim_names.index(self.replicate_dim_name),
@@ -924,8 +924,8 @@ def get_group(self, mesh_dim: Optional[Union[int, str]] = None) -> BaseProcessGr
elif dim == self.replicate_dim:
return self.replicate_pg
else:
assert self.mesh is not None
return self.mesh.get_group(self._real_mesh_dim(dim))
assert self._mesh is not None
return self._mesh.get_group(self._real_mesh_dim(dim))

def _flatten(self, mesh_dim_name: Optional[str]) -> "DeviceMesh":
flatten_mesh = _FlattenDeviceMesh(self)
@@ -939,32 +939,32 @@ def _flatten(self, mesh_dim_name: Optional[str]) -> "DeviceMesh":

def size(self, mesh_dim: Optional[int] = None) -> int:
if mesh_dim is None:
if self.mesh is None:
if self._mesh is None:
return self.replicate_pg.size()
else:
assert self.mesh is not None
return self.mesh.size() * self.replicate_pg.size()
assert self._mesh is not None
return self._mesh.size() * self.replicate_pg.size()
elif mesh_dim == self.replicate_dim:
return self.replicate_pg.size()
else:
assert self.mesh is not None
return self.mesh.size(self._real_mesh_dim(mesh_dim))
assert self._mesh is not None
return self._mesh.size(self._real_mesh_dim(mesh_dim))

@property
def ndim(self) -> int:
assert self.mesh is not None
return self.mesh.ndim + 1
assert self._mesh is not None
return self._mesh.ndim + 1

@property
def shape(self) -> Tuple[int, ...]:
assert self.mesh is not None
ret: List[int] = list(self.mesh.shape)
assert self._mesh is not None
ret: List[int] = list(self._mesh.shape)
ret.insert(self.replicate_dim, self.replicate_pg.size())
return tuple(ret)

def get_rank(self) -> int:
assert self.mesh is not None
return self.mesh.get_rank()
assert self._mesh is not None
return self._mesh.get_rank()

def get_local_rank(self, mesh_dim: Optional[Union[int, str]] = None) -> int:
if isinstance(mesh_dim, str):
@@ -973,33 +973,37 @@ def get_local_rank(self, mesh_dim: Optional[Union[int, str]] = None) -> int:
dim = 0 if mesh_dim is None else int(mesh_dim)

if mesh_dim is None:
if self.mesh is None:
if self._mesh is None:
return get_rank(self.replicate_pg)

assert self.replicate_dim == 0, "replicate_dim must be the first one"
assert self.mesh is not None
other_dim_size = self.mesh.size()
assert self.mesh is not None
other_dim_rank = self.mesh.get_local_rank()
assert self._mesh is not None
other_dim_size = self._mesh.size()
assert self._mesh is not None
other_dim_rank = self._mesh.get_local_rank()
replicate_pg_rank = get_rank(self.replicate_pg)
return other_dim_size * replicate_pg_rank + other_dim_rank
elif dim == self.replicate_dim:
return get_rank(self.replicate_pg)
else:
assert self.mesh is not None
return self.mesh.get_local_rank(self._real_mesh_dim(dim))
assert self._mesh is not None
return self._mesh.get_local_rank(self._real_mesh_dim(dim))

def get_coordinate(self) -> Optional[List[int]]:
"""
Return the relative indices of this rank relative to all
dimensions of the mesh. If this rank is not part of the mesh, return None.
"""
assert self.mesh is not None
return self.mesh._coordinate_on_dim if self.mesh._coordinate_on_dim else None
assert self._mesh is not None
return self._mesh._coordinate_on_dim if self._mesh._coordinate_on_dim else None

def get_all_groups(self) -> List[BaseProcessGroup]:
raise NotImplementedError

@property
def mesh(self):
return self._mesh.mesh


class _FlattenDeviceMesh(DeviceMesh):
def __init__(self, managed_mesh: ManagedDeviceMesh) -> None:
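The process_group.py hunk renames ManagedDeviceMesh's stored mesh to the private attribute _mesh and adds a read-only mesh property that delegates to the wrapped DeviceMesh, so callers that read the mesh attribute directly (the underlying rank tensor on a plain DeviceMesh) keep working. A stripped-down sketch of just that rename, not the full class; the class name here is illustrative:

from typing import Optional

from torch.distributed.device_mesh import DeviceMesh


class _ManagedMeshSketch:
    def __init__(self, mesh: Optional[DeviceMesh]) -> None:
        # The wrapped DeviceMesh now lives in a private slot so that the
        # public name "mesh" can be exposed as a property instead.
        self._mesh = mesh

    @property
    def mesh(self):
        # Mirrors the diff: forward to the wrapped mesh's own .mesh,
        # i.e. the rank tensor of the non-replicated sub-mesh.
        return self._mesh.mesh

As in the diff, the property has no None guard, so it assumes _mesh has been set.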
158 changes: 158 additions & 0 deletions train_fsdp.py
@@ -0,0 +1,158 @@
import os
from datasets import load_dataset

import torch
from transformers import LlamaForCausalLM, AutoTokenizer
from torch.distributed._composable.fsdp import fully_shard
import torch.distributed as dist
from tqdm import tqdm
from transformers.data import DataCollatorForSeq2Seq
from transformers.models.llama.modeling_llama import LlamaDecoderLayer
from torch.distributed.checkpoint.state_dict import get_state_dict, set_state_dict

from torchdata.stateful_dataloader import StatefulDataLoader

from torchft import (
DistributedSampler,
Manager,
Optimizer,
ProcessGroupBabyNCCL,
ProcessGroupGloo,
)
from torchft.process_group import ft_init_device_mesh

def hsdp_device_mesh(replica_group_size, sharding_group_size, device=None, manager=None):

if replica_group_size is None or sharding_group_size is None:
raise ValueError("Both replica_group_size and sharding_group_size must be provided.")

device = device or f"cuda"

device_mesh = ft_init_device_mesh(
device_type=device,
mesh_shape=(replica_group_size, sharding_group_size),
mesh_dim_names=("dp_replicate", "dp_shard"),
replicate_dim=0,
manager=manager,
)
if device_mesh is None:
raise RuntimeError("Failed to create a valid device mesh.")

return device_mesh

def parallelize_llama(model, mesh):
sharding_conditions = [lambda m: isinstance(m, LlamaDecoderLayer)]

for m in reversed(list(model.modules())):
if any(c(m) for c in sharding_conditions):
# fully_shard(m, mesh=mesh, reshard_after_forward=True)
fully_shard(m, mesh=mesh)
# fully_shard([model.model.embed_tokens, model.lm_head], mesh=mesh)
fully_shard(model, mesh=mesh)

def main():
REPLICA_GROUP_ID = int(os.environ.get("REPLICA_GROUP_ID", 0))
NUM_REPLICA_GROUPS = int(os.environ.get("NUM_REPLICA_GROUPS", 2))
NUM_REPLICAS = int(os.environ.get("NUM_REPLICAS", 2))

rank = int(os.environ.get("RANK", 0))

model_name = "Meta-Llama/Llama-3.2-1B-Instruct"
tokenizer = AutoTokenizer.from_pretrained(model_name)
model = LlamaForCausalLM.from_pretrained(model_name)

if not tokenizer.pad_token_id:
tokenizer.pad_token_id = tokenizer.eos_token_id

# If there is a mismatch between tokenizer vocab size and embedding matrix,
# throw a warning and then expand the embedding matrix
assert len(tokenizer) == model.get_input_embeddings().weight.shape[0]

train_data = load_dataset("samsum", split="train")

class SAMSumDataset(torch.utils.data.Dataset):
def __init__(self, data, tokenizer):
self.data = data
self.tokenizer = tokenizer
def __getitem__(self, idx):
text = self.data[idx]
prompt = self.tokenizer.encode(tokenizer.bos_token + f"Summarize this dialog: {text['dialogue']}\n---\nSummary: ", add_special_tokens=False)
summary = self.tokenizer.encode(text["summary"] + self.tokenizer.eos_token, add_special_tokens=False)
input_ids = prompt + summary
labels = len(prompt) * [-100] + summary
return {"input_ids": input_ids, "labels": labels}
def __len__(self):
return len(self.data)


train_dataset = SAMSumDataset(train_data, tokenizer)

batch_size = 8

sampler = DistributedSampler(
train_dataset,
replica_group=REPLICA_GROUP_ID,
num_replica_groups=NUM_REPLICA_GROUPS,
rank=rank,
shuffle=True,
num_replicas=NUM_REPLICAS,
)

train_dataloader = StatefulDataLoader(train_dataset, batch_size=batch_size, collate_fn=DataCollatorForSeq2Seq(tokenizer), sampler=sampler)

def load_state_dict(state_dict):
set_state_dict(
model,
optimizer.optim,
model_state_dict=state_dict["model"],
optim_state_dict=state_dict["optim"],
)


def state_dict():
model_state_dict, optimizer_state_dict = get_state_dict(model, optimizer.optim)
return {
"model": model_state_dict,
"optim": optimizer_state_dict,
}

device = torch.device("cuda" if torch.cuda.is_available() else "cpu")

pg = ProcessGroupBabyNCCL() if torch.cuda.is_available() else ProcessGroupGloo()

manager = Manager(
pg=pg,
min_replica_size=1,
load_state_dict=load_state_dict,
state_dict=state_dict,
replica_id=f"train_fsdp_{REPLICA_GROUP_ID}",
)

mesh = hsdp_device_mesh(NUM_REPLICA_GROUPS, NUM_REPLICAS, "cuda" if torch.cuda.is_available() else "cpu", manager=manager)

parallelize_llama(model, mesh)

model.to(device)

optimizer = Optimizer(manager, torch.optim.Adam(model.parameters(), lr=1e-5))

optimizer.zero_grad()

while manager.current_step() < 500:
model.train()
for batch in tqdm(train_dataloader):
input_ids = batch["input_ids"].to(device)
labels = batch["labels"].to(device)
optimizer.zero_grad()

outputs = model(input_ids, labels=labels)
loss = outputs.loss
loss.backward()
optimizer.step()

if manager.current_step() % 100 == 0:
print(f"[{manager.current_step()}] loss = {loss.item()}")


if __name__ == "__main__":
main()
