
fix nccl future execution #126

Merged · 1 commit · Mar 10, 2025
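In short (as reconstructed from the diff below): the change addresses a race in ProcessGroupBabyNCCL where the done-callback attached to a collective's future could run, and the parent-side future could be resolved, before the NCCL kernel had actually finished on its CUDA stream. The worker now records a torch.cuda.Event(interprocess=True) on the op's stream inside the callback and ships it through the future pipe, and _future_handler waits on that event before setting the result. A new test_manager_allreduce integration test exercises Manager.allreduce on both the Gloo and BabyNCCL backends and checks that the replicas agree on the result.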
109 changes: 102 additions & 7 deletions torchft/manager_integ_test.py
@@ -2,24 +2,26 @@
import logging
import threading
import time
import traceback
from concurrent.futures import ThreadPoolExecutor, as_completed
from contextlib import ExitStack, contextmanager
from dataclasses import dataclass, field
from datetime import timedelta
from typing import Any, Dict, Generator, List, Protocol, Set, Tuple
from typing import Any, Dict, Generator, List, Optional, Protocol, Set, Tuple, TypeVar
from unittest import TestCase

import torch
import torch.distributed as dist
from parameterized import parameterized
from torch import nn, optim
from torch._dynamo.utils import timed

from torchft._torchft import LighthouseServer
from torchft.ddp import DistributedDataParallel
from torchft.local_sgd import DiLoCo, LocalSGD
from torchft.manager import Manager
from torchft.optim import OptimizerWrapper
from torchft.process_group import ProcessGroupGloo
from torchft.process_group import ProcessGroupBabyNCCL, ProcessGroupGloo

logger: logging.Logger = logging.getLogger(__name__)

@@ -69,10 +71,14 @@ def check(self, rank: int, step: int) -> None:
raise InjectedFailure(f"injected failure {rank=} {step=}")


class TrainLoop(Protocol):
# R for an arbitrary return type
R = TypeVar("R", covariant=True)


class TrainLoop(Protocol[R]):
def __call__(
self, rank: int, store_port: int, device: torch.device, runner: "Runner"
) -> Dict[str, Dict[str, object]]: ...
) -> R: ...


@dataclass
@@ -81,15 +87,15 @@ class Runner:
num_replicas: int
lighthouse_address: str
failure_injector: FailureInjector
train_loop: TrainLoop
train_loop: TrainLoop[object]

use_cuda: bool = False
world_size: int = 1
attempts: int = 3
manager_args: Dict[str, object] = field(default_factory=dict)
train_loop_args: Dict[str, Any] = field(default_factory=dict)

def _replica_main(self) -> List[Dict[str, Dict[str, object]]]:
def _replica_main(self) -> List[object]:
store = dist.TCPStore(
host_name="localhost",
port=0,
@@ -131,7 +137,7 @@ def _replica_main(self) -> List[Dict[str, Dict[str, object]]]:

return [fut.result() for fut in futures]

def run_replica(self) -> List[Dict[str, Dict[str, object]]]:
def run_replica(self) -> List[object]:
for i in range(self.attempts):
try:
print(
@@ -391,3 +397,92 @@ def test_quorum_timeout(self) -> None:
"status: Cancelled, message.*Timeout expired",
):
manager.should_commit(timeout=timedelta(seconds=0.01))

@parameterized.expand(
[
(True,), # Test with CUDA
(False,), # Test without CUDA (CPU)
]
)
def test_manager_allreduce(self, use_cuda: bool) -> None:
# Skip the test if use_cuda is True and there are not enough GPUs
if use_cuda and torch.cuda.device_count() < 2:
self.skipTest("Not enough GPUs for CUDA test")

# The manager supports allreduce, but we found an issue where the future callback was called
# before the allreduce had completed. This test ensures that the callback performs stream synchronization.
lighthouse = LighthouseServer(
bind="[::]:0",
min_replicas=2,
)
num_replicas = 2
futures = []

with ThreadPoolExecutor(max_workers=num_replicas) as executor:
for replica_id in range(num_replicas):
failure_injector = FailureInjector()
runner = Runner(
replica_id=replica_id,
num_replicas=num_replicas,
lighthouse_address=lighthouse.address(),
failure_injector=failure_injector,
train_loop=all_reduce_callback,
use_cuda=use_cuda,
)
futures.append(executor.submit(runner.run_replica))

results = []
for fut in as_completed(futures):
try:
results.append(fut.result()[0])
except Exception as e:
print(e, flush=True)
traceback.print_exc()
raise

lighthouse.shutdown()

print(results)
r0, r1 = results
torch.testing.assert_close(r0, r1, check_device=False)


def all_reduce_callback(
rank: int,
store_port: int,
device: torch.device,
runner: Runner,
) -> Optional[torch.Tensor]:
with ExitStack() as stack:
print(f"worker {runner.replica_id=} {rank=} {runner.world_size=} starting")

if device.type == "cuda":
pg = ProcessGroupBabyNCCL()
else:
pg = ProcessGroupGloo()
manager = Manager(
pg=pg,
min_replica_size=2,
use_async_quorum=False,
load_state_dict=lambda x: None,
state_dict=lambda: None,
replica_id=str(runner.replica_id),
store_addr="localhost",
store_port=store_port,
rank=rank,
world_size=runner.world_size,
lighthouse_addr=runner.lighthouse_address,
port=19530 + runner.replica_id,
timeout=timedelta(seconds=10),
quorum_timeout=timedelta(seconds=10),
# pyre-fixme[6]: Incompatible parameter type
**runner.manager_args,
)
stack.callback(lambda: manager.shutdown(wait=False))

manager.start_quorum()
t1 = torch.ones((1, 3), device=device)
fut = manager.allreduce(t1)
fut.wait()
return t1
return None
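
The comment in test_manager_allreduce above names the underlying hazard: with asynchronous CUDA execution, a collective's CPU-side future can report completion while the NCCL kernel is still running on its communication stream, so a callback that reads the output tensor must first order itself after that stream. Below is a minimal sketch of that guard outside of torchft, assuming a CUDA build of PyTorch; the names consume_when_safe, result, and comm_stream are illustrative, not torchft API.

import torch

def consume_when_safe(result: torch.Tensor, comm_stream: torch.cuda.Stream) -> torch.Tensor:
    # The CPU-side future may already be "done" while the collective is still
    # executing on comm_stream, so record an event on that stream...
    event = comm_stream.record_event()
    # ...and make the consumer's current stream wait for it before the tensor is used.
    torch.cuda.current_stream().wait_event(event)
    return result

Without the wait_event call, any kernel launched on the consumer's stream may read result before the allreduce has finished writing it.
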
32 changes: 26 additions & 6 deletions torchft/process_group.py
@@ -1093,10 +1093,12 @@ def _worker(

args = _PickleSafeOptions.unsafe_args(args)
fn = getattr(pg, func_name)

work[op_id] = _OpMetadata(
work=fn(*args, **kwargs),
stream=stream,
)

elif cmd == "wait":
op_id, timeout = cast(tuple[int, timedelta], op[1:])

@@ -1126,15 +1128,29 @@ def _worker(
del work[op_id]
elif cmd == "future":
op_id: int = cast(int, op[1])
metadata: _OpMetadata = work[op_id]

def callback(fut: Future[object]) -> None:
def callback(fut: Future[object], metadata: _OpMetadata) -> None:
try:
fut.wait()
future_pipe.send((op_id, _FUTURE_RESULT, None))
# create an event after the collective has been issued so that the parent
# can wait on it before resolving the corresponding future
with metadata.set_stream():
fut.wait()
event = (
torch.cuda.current_stream().record_event(
torch.cuda.Event(interprocess=True)
)
if metadata.stream is not None
else None
)

future_pipe.send((op_id, _FUTURE_RESULT, None, event))
except Exception as e:
future_pipe.send((op_id, _FUTURE_EXCEPTION, e))
future_pipe.send((op_id, _FUTURE_EXCEPTION, e, None))

work[op_id].work.get_future().add_done_callback(callback)
metadata.work.get_future().add_done_callback(
lambda fut: callback(fut, metadata)
)
elif cmd == "num_active_work":
req_pipe.send(len(work))
else:
@@ -1153,11 +1169,15 @@ def _future_handler(self, future_pipe: _MonitoredPipe) -> None:
except TimeoutError:
continue

op_id, mode, data = cast(Tuple[int, str, object], cmd)
op_id, mode, data, event = cast(
Tuple[int, str, object, Optional[torch.cuda.Event]], cmd
)
with self._futures_lock:
fut = self._futures[op_id]
del self._futures[op_id]
if mode == _FUTURE_RESULT:
if event is not None:
event.wait()
fut.set_result(data)
elif mode == _FUTURE_EXCEPTION:
fut.set_exception(data)
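
Taken together, the two hunks above form a round trip: the worker records an interprocess CUDA event on the op's stream once the collective's future completes and sends it through the future pipe, and the parent-side handler waits on that event before fulfilling the user-visible future. Below is a reduced sketch of that handoff with the pipe and bookkeeping stripped away; the function names are illustrative rather than torchft API, and the event is assumed to reach the parent via the pipe's serialization.

from typing import Optional

import torch

def worker_side(collective_fut: torch.futures.Future, comm_stream: torch.cuda.Stream) -> torch.cuda.Event:
    # In the worker: wait for the CPU-side future on the collective's stream,
    # then record an event so another process can order its work after the GPU work.
    with torch.cuda.stream(comm_stream):
        collective_fut.wait()
        return torch.cuda.current_stream().record_event(
            torch.cuda.Event(interprocess=True)
        )

def parent_side(user_fut: torch.futures.Future, event: Optional[torch.cuda.Event], value: object) -> None:
    # In the parent: order the current stream after the recorded event before
    # resolving the user-visible future (event is None for CPU-only backends).
    if event is not None:
        event.wait()
    user_fut.set_result(value)

Note that event.wait() is a stream-ordered wait on the caller's current stream and does not block the host thread; event.synchronize() would be the CPU-blocking alternative.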