
Commit 574f9af

mingzhe09088 authored and facebook-github-bot committed on Sep 16, 2020
[NCCL] Add option to run NCCL on high priority cuda stream (pytorch#43796)
Summary: Pull Request resolved: pytorch#43796

This diff adds an option for the process group NCCL backend to pick high priority cuda streams.

Test Plan: waitforsandcastle
Reviewed By: jiayisuse
Differential Revision: D23404286
fbshipit-source-id: b79ae097b7cd945a26e8ba1dd13ad3147ac790eb
1 parent 161490d commit 574f9af
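
What this enables from Python, as a minimal sketch that is not part of the commit: the new ProcessGroupNCCL.Options binding added in init.cpp below exposes is_high_priority and op_timeout. The snippet assumes an NCCL-enabled build with a visible GPU, and uses a single-process FileStore (path and world size are illustrative only) in place of the env:// rendezvous used by the new test in distributed_test.py.

import torch.distributed as dist

# Hypothetical single-process setup; a FileStore stands in for the
# env:// rendezvous used by the new test in distributed_test.py.
store = dist.FileStore("/tmp/nccl_pg_example_store", 1)

opts = dist.ProcessGroupNCCL.Options()
opts.is_high_priority = True  # schedule NCCL collectives on high priority CUDA streams
pg = dist.ProcessGroupNCCL(store, 0, 1, opts)  # rank 0, world size 1

op_timeout can be set on the same Options object, taking the place of the old timeout-only constructor argument.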

5 files changed: +92 -26

 

‎torch/csrc/distributed/c10d/init.cpp

+21 -6

@@ -685,19 +685,34 @@ They are used in specifying strategies for reduction collectives, e.g.,
 #endif
 
 #ifdef USE_C10D_NCCL
-  shared_ptr_class_<::c10d::ProcessGroupNCCL>(
+  auto processGroupNCCL = shared_ptr_class_<::c10d::ProcessGroupNCCL>(
       module, "ProcessGroupNCCL", processGroup)
+      .def(py::init<
+               const std::shared_ptr<::c10d::Store>&,
+               int,
+               int,
+               ::c10d::ProcessGroupNCCL::Options>())
       .def(
-          py::init<
-              const std::shared_ptr<::c10d::Store>&,
-              int,
-              int,
-              const std::chrono::milliseconds&>(),
+          py::init([](const std::shared_ptr<::c10d::Store>& store,
+                      int rank,
+                      int size,
+                      const std::chrono::milliseconds& timeout){
+            ::c10d::ProcessGroupNCCL::Options options;
+            options.isHighPriorityStream = false;
+            options.opTimeout = timeout;
+            return std::make_shared<::c10d::ProcessGroupNCCL>(
+                store, rank, size, options);
+          }),
          py::arg("store"),
          py::arg("rank"),
          py::arg("size"),
          py::arg("timeout") = std::chrono::milliseconds(
              ::c10d::ProcessGroupNCCL::kProcessGroupNCCLOpTimeoutMillis));
+
+  py::class_<::c10d::ProcessGroupNCCL::Options>(processGroupNCCL, "Options")
+      .def(py::init<>())
+      .def_readwrite("is_high_priority", &::c10d::ProcessGroupNCCL::Options::isHighPriorityStream)
+      .def_readwrite("op_timeout", &::c10d::ProcessGroupNCCL::Options::opTimeout);
 #endif
 
 #ifdef USE_C10D_MPI
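
Note that the lambda-based py::init above keeps the pre-existing timeout-only Python constructor working (it simply builds an Options with is_high_priority left at False), so existing callers are unaffected. A hedged sketch of the two construction paths, reusing the store from the example near the top and assuming pybind11's chrono conversion for the timeout argument (which the existing timeout binding already relies on):

from datetime import timedelta
import torch.distributed as dist

# Legacy path: routed through the py::init lambda, is_high_priority stays False.
pg_default = dist.ProcessGroupNCCL(store, 0, 1, timeout=timedelta(seconds=60))

# New path: explicit Options object, giving access to both knobs.
opts = dist.ProcessGroupNCCL.Options()
opts.is_high_priority = True
pg_high_priority = dist.ProcessGroupNCCL(store, 0, 1, opts)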

‎torch/lib/c10d/ProcessGroupNCCL.cpp

+8 -5

@@ -430,13 +430,14 @@ ProcessGroupNCCL::ProcessGroupNCCL(
     const std::shared_ptr<Store>& store,
     int rank,
     int size,
-    const std::chrono::milliseconds& opTimeout)
+    Options options)
     : ProcessGroup(rank, size),
       store_(store),
       ncclCommCounter_(0),
       terminateProcessGroup_(false),
-      opTimeout_(opTimeout),
-      futureNCCLCallbackStreams_(c10::cuda::device_count()) {
+      opTimeout_(options.opTimeout),
+      futureNCCLCallbackStreams_(c10::cuda::device_count()),
+      isHighPriorityStream_(options.isHighPriorityStream) {
   try {
     parseNcclBlockingWait();
   } catch (std::exception& e) {

@@ -769,14 +770,14 @@ std::vector<std::shared_ptr<NCCLComm>>& ProcessGroupNCCL::getNCCLComm(
     ncclComms[i] = NCCLComm::create(numRanks, rank, ncclID);
 
     // Creates the NCCL streams
-    streamVal.push_back(at::cuda::getStreamFromPool());
+    streamVal.push_back(at::cuda::getStreamFromPool(isHighPriorityStream_));
 
     // If not set before, get a dedicated stream for the device to run
     // FutureNCCL then callbacks.
     std::lock_guard<std::mutex> lock(mutex_);
     if (futureNCCLCallbackStreams_[deviceIndex] == nullptr) {
       futureNCCLCallbackStreams_[deviceIndex] =
-          std::make_shared<at::cuda::CUDAStream>(at::cuda::getStreamFromPool());
+          std::make_shared<at::cuda::CUDAStream>(at::cuda::getStreamFromPool(isHighPriorityStream_));
     }
   }
 
@@ -931,6 +932,8 @@ void ProcessGroupNCCL::workEnqueue(
     workList_.emplace_back(std::move(work));
   }
 }
+ProcessGroupNCCL::Options::Options()
+    : opTimeout(kProcessGroupNCCLOpTimeoutMillis), isHighPriorityStream(false) {}
 
 template <typename Fn, typename PreProcess, typename PostProcess>
 std::shared_ptr<ProcessGroup::Work> ProcessGroupNCCL::collective(

‎torch/lib/c10d/ProcessGroupNCCL.hpp

+13 -5

@@ -161,6 +161,13 @@ class ProcessGroupNCCL : public ProcessGroup {
     friend class ProcessGroupNCCL;
   };
 
+  struct Options {
+    explicit Options();
+
+    std::chrono::milliseconds opTimeout;
+    bool isHighPriorityStream;
+  };
+
   // FutureNCCL is a subclass of ivalue's Future. The goal is to use
   // this class in getFuture API of WorkNCCL. This Future is mostly a
   // wrapper to synchronize streams appropriately and it mostly enables

@@ -341,8 +348,7 @@ class ProcessGroupNCCL : public ProcessGroup {
       const std::shared_ptr<Store>& store,
       int rank,
       int size,
-      const std::chrono::milliseconds& opTimeout =
-          std::chrono::milliseconds(kProcessGroupNCCLOpTimeoutMillis));
+      Options options = Options());
 
   // This constructor includes the deprecated `groupName` argument.
   // If you have existing code that uses the `groupName`, you can replace

@@ -352,9 +358,8 @@ class ProcessGroupNCCL : public ProcessGroup {
       int rank,
       int size,
       const std::string& groupName,
-      const std::chrono::milliseconds& opTimeout =
-          std::chrono::milliseconds(kProcessGroupNCCLOpTimeoutMillis))
-      : ProcessGroupNCCL(store, rank, size, opTimeout) {}
+      Options options = Options())
+      : ProcessGroupNCCL(store, rank, size, options) {}
 
   virtual ~ProcessGroupNCCL();
 
@@ -626,6 +631,9 @@ class ProcessGroupNCCL : public ProcessGroup {
   // of the corresponding device inside ProcessGroupNCCL::getNCCLComm if not set
   // before.
   std::vector<std::shared_ptr<at::cuda::CUDAStream>> futureNCCLCallbackStreams_;
+
+  // Schedule NCCL operations on high priority CUDA streams.
+  bool isHighPriorityStream_ = false;
 };
 
 } // namespace c10d
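
Per the Options::Options() definition in ProcessGroupNCCL.cpp above, a default-constructed Options keeps today's behavior: opTimeout starts at kProcessGroupNCCLOpTimeoutMillis and isHighPriorityStream at false. A small Python-side sanity sketch, not part of the commit (assumes an NCCL-enabled build; reading op_timeout back as a timedelta relies on pybind11's chrono caster):

import torch.distributed as dist

opts = dist.ProcessGroupNCCL.Options()
assert opts.is_high_priority is False  # default set in Options::Options()
print(opts.op_timeout)                 # kProcessGroupNCCLOpTimeoutMillis, exposed as a timedelta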

‎torch/lib/c10d/test/ProcessGroupNCCLErrorsTest.cpp

+13 -7

@@ -38,8 +38,8 @@ class ProcessGroupNCCLSimulateErrors : public c10d::ProcessGroupNCCL {
       const std::shared_ptr<c10d::Store>& store,
       int rank,
       int size,
-      std::chrono::milliseconds timeout)
-      : ProcessGroupNCCL(store, rank, size, timeout), simulate_error_(false) {}
+      c10d::ProcessGroupNCCL::Options opts)
+      : ProcessGroupNCCL(store, rank, size, opts), simulate_error_(false) {}
 
   std::exception_ptr checkForNCCLErrors(
       const std::vector<std::shared_ptr<c10d::NCCLComm>>& ncclComms) override {

@@ -100,8 +100,8 @@ class ProcessGroupNCCLTimedOutErrors : public ProcessGroupNCCLSimulateErrors {
       const std::shared_ptr<c10d::Store>& store,
       int rank,
       int size,
-      std::chrono::milliseconds timeout)
-      : ProcessGroupNCCLSimulateErrors(store, rank, size, timeout),
+      c10d::ProcessGroupNCCL::Options opts)
+      : ProcessGroupNCCLSimulateErrors(store, rank, size, opts),
         set_timedout_error_(false) {}
 
   std::shared_ptr<ProcessGroupNCCL::WorkNCCL> initWork(

@@ -165,8 +165,10 @@ TEST_F(ProcessGroupNCCLErrorsTest, testNCCLErrorsBlocking) {
   }
 
   ASSERT_TRUE(setenv(c10d::NCCL_BLOCKING_WAIT, "1", 1) == 0);
+  c10d::ProcessGroupNCCL::Options options;
+  options.opTimeout = std::chrono::milliseconds(1000);
   ProcessGroupNCCLSimulateErrors pg(
-      store_, 0, 1, std::chrono::milliseconds(1000));
+      store_, 0, 1, options);
 
   auto work = pg.allreduce(tensors_);
   work->wait();

@@ -192,8 +194,10 @@ TEST_F(ProcessGroupNCCLErrorsTest, testNCCLTimedoutErrorsBlocking) {
   }
 
   ASSERT_TRUE(setenv(c10d::NCCL_BLOCKING_WAIT, "1", 1) == 0);
+  c10d::ProcessGroupNCCL::Options options;
+  options.opTimeout = std::chrono::milliseconds(3000);
   ProcessGroupNCCLTimedOutErrors pg(
-      store_, 0, 1, std::chrono::milliseconds(3000));
+      store_, 0, 1, options);
 
   auto work = pg.allreduce(tensors_);
   work->wait();

@@ -213,8 +217,10 @@ TEST_F(ProcessGroupNCCLErrorsTest, testNCCLErrorsNonBlocking) {
     return;
   }
 
+  c10d::ProcessGroupNCCL::Options options;
+  options.opTimeout = std::chrono::milliseconds(3000);
   ProcessGroupNCCLSimulateErrors pg(
-      store_, 0, 1, std::chrono::milliseconds(3000));
+      store_, 0, 1, options);
 
   auto work = pg.allreduce(tensors_);
   pg.barrier()->wait();

‎torch/testing/_internal/distributed/distributed_test.py

+37 -3

@@ -689,7 +689,7 @@ def test_irecv(self):
 
     # BROADCAST
     def _test_broadcast_helper(
-        self, group, group_id, rank, cuda=False, rank_to_GPU=None
+        self, group, group_id, rank, cuda=False, rank_to_GPU=None, with_options=False
     ):
         for dtype, value, requires_cuda in [
             (torch.float, -1e-10, False),

@@ -707,12 +707,24 @@ def _test_broadcast_helper(
                 if cuda:
                     expected_tensor = expected_tensor.cuda(rank_to_GPU[rank][0])
                 if rank == src:
-                    dist.broadcast(expected_tensor, src, group_id)
+                    if with_options:
+                        opts = dist.BroadcastOptions()
+                        opts.rootTensor = 0
+                        opts.rootRank = src
+                        group_id.broadcast([expected_tensor], opts).wait()
+                    else:
+                        dist.broadcast(expected_tensor, src, group_id)
                 else:
                     tensor = _build_tensor(src + 1, -1, dtype)
                     if cuda:
                         tensor = tensor.cuda(rank_to_GPU[rank][0])
-                    dist.broadcast(tensor, src, group_id)
+                    if with_options:
+                        opts = dist.BroadcastOptions()
+                        opts.rootTensor = 0
+                        opts.rootRank = src
+                        group_id.broadcast([tensor], opts).wait()
+                    else:
+                        dist.broadcast(tensor, src, group_id)
                     self.assertEqual(tensor.size(), expected_tensor.size())
                     self.assertEqual(tensor.ne(expected_tensor).max(), torch.tensor(False))
 
@@ -744,6 +756,28 @@ def test_broadcast_full_group(self):
         group, group_id, rank = self._init_full_group_test()
         self._test_broadcast_helper(group, group_id, rank)
 
+    @unittest.skipIf(
+        BACKEND != "nccl",
+        "Only NCCL backend supports high priority stream",
+    )
+    @skip_if_no_gpu
+    @skip_if_rocm
+    def test_nccl_high_priority_stream(self):
+        group, _, rank = self._init_global_test()
+        rank_to_GPU = self._init_multigpu_helper()
+
+        new_port = str(MASTER_PORT + 1)
+        os.environ['MASTER_PORT'] = new_port
+        gen_iterator = dist.rendezvous('env://', rank, dist.get_world_size())
+        store, rank, size = next(gen_iterator)
+        store = dist.PrefixStore(new_port, store)
+
+        opts = dist.ProcessGroupNCCL.Options()
+        opts.is_high_priority = False
+        group_id = dist.ProcessGroupNCCL(store, rank, size, opts)
+
+        self._test_broadcast_helper(group, group_id, rank, True, rank_to_GPU, True)
+
     # REDUCE
     def _test_reduce_helper(
         self,
