Commit a2b4177

pritamdamania authored and facebook-github-bot committed on Sep 25, 2020
Add barrier() at the end of init_process_group and new_group. (pytorch#45181)
Summary: Pull Request resolved: pytorch#45181

`init_process_group` and `new_group` update a number of global variables after initializing the actual process group. As a result, there is a race: after initializing the process group on, say, rank 0, if we immediately check the default process group on rank 1 (say via RPC), we might get an error because rank 1 has not yet updated its _default_pg variable.

To resolve this issue, I've added a barrier() at the end of both of these calls. This ensures that once these calls return, we are guaranteed the initialization is correct on all ranks. Since these calls are mostly made during initialization, the overhead of a barrier() here should be acceptable.

Closes: pytorch#40434, pytorch#40378

ghstack-source-id: 112923112

Test Plan: Reproduced the failures in pytorch#40434 and pytorch#40378 and verified that this PR fixes them.

Reviewed By: mrshenli

Differential Revision: D23858025

fbshipit-source-id: c4d5e46c2157981caf3ba1525dec5310dcbc1830
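For readers who want to see the failure mode concretely, here is a minimal, illustrative sketch (not taken from the PR or the linked issues) of the pattern that used to race: one rank returns from init_process_group and immediately asks a peer, over RPC, whether its default process group exists. The worker names, ports, and two-rank layout are invented for the example.

import os

import torch.distributed as dist
import torch.distributed.rpc as rpc
import torch.multiprocessing as mp


def worker(rank: int, world_size: int) -> None:
    # Rendezvous for the RPC agent (hypothetical local address/port).
    os.environ["MASTER_ADDR"] = "127.0.0.1"
    os.environ["MASTER_PORT"] = "29500"
    rpc.init_rpc(f"worker{rank}", rank=rank, world_size=world_size)

    # Separate rendezvous for the process group so it does not clash with
    # the RPC agent's MASTER_ADDR/MASTER_PORT rendezvous.
    dist.init_process_group(
        backend="gloo",
        init_method="tcp://127.0.0.1:29501",
        rank=rank,
        world_size=world_size,
    )

    if rank == 0:
        # Before this change, rank 1's RPC server threads could run this check
        # while rank 1's main thread was still updating its process-group
        # globals; the trailing barrier() closes that window.
        assert rpc.rpc_sync("worker1", dist.is_initialized)

    rpc.shutdown()
    dist.destroy_process_group()


if __name__ == "__main__":
    mp.spawn(worker, args=(2,), nprocs=2)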
1 parent: 3b7e4f8 · commit: a2b4177

File tree

5 files changed: +32 −9 lines changed


test/cpp_extensions/cpp_c10d_extension.cpp (+1 −1)

@@ -63,7 +63,7 @@ std::shared_ptr<ProcessGroup::Work> ProcessGroupTest::allgather_base(

 std::shared_ptr<ProcessGroup::Work> ProcessGroupTest::barrier(
     const BarrierOptions& opts) {
-  throw std::runtime_error("ProcessGroupTest does not support barrier");
+  return std::make_shared<ProcessGroupTest::WorkTest>();
 }

 std::shared_ptr<ProcessGroup::Work> ProcessGroupTest::gather(

test/distributed/test_c10d.py (+6 −4)

@@ -334,11 +334,11 @@ def test_unknown_handler(self):
 @skip_if_win32()
 class RendezvousEnvTest(TestCase):
     @retry_on_connect_failures
+    @requires_nccl()
     def test_common_errors(self):
-        # TODO remove this hack
-        if not hasattr(c10d, "ProcessGroupNCCL"):
-            raise unittest.SkipTest("C10D is not built with NCCL process group,"
-                                    " skipping test")
+        if torch.cuda.device_count() == 0:
+            raise unittest.SkipTest("No GPUs available, skipping test")
+
         vars = {
             "WORLD_SIZE": "1",
             "RANK": "0",
@@ -579,6 +579,8 @@ def _test_default_store_timeout(self, backend):
     @requires_nccl()
     @retry_on_connect_failures
     def test_default_store_timeout_nccl(self):
+        if torch.cuda.device_count() == 0:
+            raise unittest.SkipTest("No GPUs available, skipping test")
         self._test_default_store_timeout('nccl')

     @requires_gloo()

torch/distributed/distributed_c10d.py (+9 −0)

@@ -436,6 +436,10 @@ def init_process_group(backend,
     _backend = _pg_map[_default_pg][0]
     _default_pg_init_method = init_method

+    # barrier at the end to ensure that once we return from this method, all
+    # process groups including global variables are updated correctly on all
+    # ranks.
+    barrier()

 def _new_process_group_helper(world_size,
                               rank,
@@ -2025,4 +2029,9 @@ def new_group(ranks=None, timeout=default_pg_timeout, backend=None):
         for group_rank, global_rank in enumerate(ranks)
     }

+    # barrier at the end to ensure that once we return from this method, all
+    # process groups including global variables are updated correctly on all
+    # ranks.
+    barrier()
+
     return pg
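Worth noting for callers: new_group is documented to require that every process in the default group call it (and create groups in the same order when several are created), even processes that will not be members of the new group, and the trailing barrier() means ranks that do call it now also block until all ranks have. This is also why the DDP-under-dist-autograd test below adds dist.new_group(TRAINER_RANKS) to the master and remote-worker processes. A hedged, stand-alone sketch of the expected calling pattern, with an invented rank split and rendezvous address:

import torch
import torch.distributed as dist
import torch.multiprocessing as mp


def worker(rank: int, world_size: int) -> None:
    dist.init_process_group(
        backend="gloo",
        init_method="tcp://127.0.0.1:29502",  # arbitrary local rendezvous
        rank=rank,
        world_size=world_size,
    )

    # Every rank makes the same new_group() call, including rank 0, which is
    # not a member of the subgroup; non-members simply never use the handle.
    trainer_ranks = list(range(1, world_size))
    trainer_group = dist.new_group(trainer_ranks)

    if rank in trainer_ranks:
        t = torch.ones(1)
        dist.all_reduce(t, group=trainer_group)
        assert t.item() == len(trainer_ranks)

    dist.destroy_process_group()


if __name__ == "__main__":
    mp.spawn(worker, args=(3,), nprocs=3)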

torch/testing/_internal/distributed/ddp_under_dist_autograd_test.py (+12 −2)

@@ -325,14 +325,19 @@ def trainer_name(self, rank):
         # The name has to be consistent with that in 'dist_init' decorator.
         return f"worker{rank}"

-    def _remote_worker_process(self):
+    def _remote_worker_process(self, ddp_mode):
         gLogger.info("The remote worker is running.")
         dist.init_process_group(
             backend="gloo",
             init_method=INIT_METHOD_TEMPLATE.format(file_name=self.file_name),
             world_size=self.world_size,
             rank=self.rank,
         )
+
+        if ddp_mode in (DdpMode.INSIDE, DdpMode.OUTSIDE):
+            # new_group needs to be called on ranks.
+            dist.new_group(TRAINER_RANKS)
+
         global shutdown_signal
         with shutdown_signal:
             shutdown_signal.wait()
@@ -367,6 +372,7 @@ def _master_process(self, ddp_mode: DdpMode, simulate_uneven_inputs: bool):
             world_size=self.world_size,
             rank=self.rank,
         )
+
         remote_em_rref = rpc.remote(
             self.remote_worker_name(), RemoteEM, args=(NUM_EM_ROW, D_SPARSE)
         )
@@ -401,6 +407,10 @@ def do_test_on_master(
             )
         )

+        if ddp_mode in (DdpMode.INSIDE, DdpMode.OUTSIDE):
+            # new_group needs to be called on ranks.
+            dist.new_group(TRAINER_RANKS)
+
         training_examples = get_training_examples()
         for _ in range(3):
             futures = []
@@ -455,7 +465,7 @@ def _do_test(self, ddp_mode, simulate_uneven_inputs=False):
         if self.rank == MASTER_RANK:
             self._master_process(ddp_mode, simulate_uneven_inputs)
         elif self.rank == REMOTE_WORKER_RANK:
-            self._remote_worker_process()
+            self._remote_worker_process(ddp_mode)
         elif self.rank in TRAINER_RANKS:
             self._trainer_process(self.rank)
         else:

torch/testing/_internal/distributed/distributed_test.py (+4 −2)

@@ -285,6 +285,8 @@ def init_method(self):

     @classmethod
     def _run(cls, rank, test_name, file_name):
+        if BACKEND == 'nccl' and not torch.cuda.is_available():
+            sys.exit(TEST_SKIPS['no_cuda'].exit_code)
         self = cls(test_name)
         self.rank = rank
         self.file_name = file_name
@@ -2283,7 +2285,7 @@ def test_DistributedDataParallel_requires_grad(self):
     @skip_if_lt_x_gpu(int(os.environ["WORLD_SIZE"]))
     @skip_if_rocm
     def test_DistributedDataParallel_non_default_stream(self):
-        stream = torch.cuda.Stream()
+        stream = torch.cuda.Stream(self.rank)
         rank = self.rank
         with torch.cuda.stream(stream):
             net = torch.nn.parallel.DistributedDataParallel(
@@ -3020,7 +3022,7 @@ def _run_uneven_inputs_test(
         rank = self.rank
         sync_interval = test_case.sync_interval
         # Ensure all outsanding GPU work is comlete so this test runs independently.
-        torch.cuda.synchronize()
+        dist.barrier()
         # Bucket_cap_mb is intentionally low to test allreduce scheduling when
         # there are many buckets.
         net = torch.nn.parallel.DistributedDataParallel(
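A side note on the last hunk: the swap is from a device-level wait to a process-level one, and the two are not interchangeable. torch.cuda.synchronize() only waits for the calling process's own pending GPU work, while dist.barrier() blocks until every participating rank reaches the same point. A small sketch of the distinction (the helper name is invented; it assumes an already-initialized default process group):

import torch
import torch.distributed as dist


def quiesce_before_test(rank: int) -> None:
    # Local only: wait for this process's own outstanding CUDA work.
    if torch.cuda.is_available():
        torch.cuda.synchronize(rank)

    # Cross-process: return only once every rank has reached this point.
    dist.barrier()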
