Commit 8ab2ad3

malfet authored and facebook-github-bot committed on Sep 26, 2020
Enable torch.cuda.nccl typechecking (pytorch#45344)
Summary: Fixes pytorch#45336

Pull Request resolved: pytorch#45344
Reviewed By: walterddr
Differential Revision: D23935306
Pulled By: malfet
fbshipit-source-id: dd09d4f8ff7a327131764487158675027a13bf69
1 parent 5211fb9 commit 8ab2ad3

File tree

4 files changed: +56 −19 lines
 

mypy.ini (−6 lines)

@@ -186,12 +186,6 @@ ignore_errors = True
 [mypy-torch.cuda.amp.*]
 ignore_errors = True
 
-#[mypy-torch.cuda.comm]
-#ignore_errors = True
-
-[mypy-torch.cuda.nccl]
-ignore_errors = True
-
 [mypy-torch._lobpcg]
 ignore_errors = True
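With the ignore_errors override removed, mypy now type-checks torch.cuda.nccl like any other module. As a minimal sketch of the effect (the file name is made up and the exact diagnostic wording depends on the mypy version), a call that passes a bare tensor where the annotations require a sequence is now flagged statically:

# check_nccl_types.py -- illustrative only; run `mypy check_nccl_types.py`
# against a PyTorch checkout that includes the stubs from this commit.
import torch
import torch.cuda.nccl as nccl

t = torch.zeros(4)
# mypy: Argument 1 to "broadcast" has incompatible type "Tensor";
#       expected "Sequence[Tensor]"   (approximate wording)
nccl.broadcast(t)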

torch/_C/__init__.pyi.in (+26 lines)

@@ -511,6 +511,32 @@ def _cuda_lock_mutex() -> None: ...
 def _cuda_unlock_mutex() -> None: ...
 def _nccl_version() -> _int: ...
 def _nccl_unique_id() -> bytes: ...
+def _nccl_init_rank(nranks: _int, comm_id: bytes, rank: _int) -> object: ...
+def _nccl_reduce(input: Sequence[Tensor],
+                 output: Tensor,
+                 root: _int,
+                 op: _int,
+                 streams: Optional[Sequence[_CudaStreamBase]],
+                 comms: Optional[Sequence[object]]) -> None: ...
+def _nccl_all_reduce(input: Sequence[Tensor],
+                     output: Sequence[Tensor],
+                     op: _int,
+                     streams: Optional[Sequence[_CudaStreamBase]],
+                     comms: Optional[Sequence[object]]) -> None: ...
+def _nccl_broadcast(input: Sequence[Tensor],
+                    root: _int,
+                    streams: Optional[Sequence[_CudaStreamBase]],
+                    comms: Optional[Sequence[object]]) -> None: ...
+def _nccl_all_gather(input: Sequence[Tensor],
+                     output: Sequence[Tensor],
+                     streams: Optional[Sequence[_CudaStreamBase]],
+                     comms: Optional[Sequence[object]]) -> None: ...
+def _nccl_reduce_scatter(input: Sequence[Tensor],
+                         output: Sequence[Tensor],
+                         op: _int,
+                         streams: Optional[Sequence[_CudaStreamBase]],
+                         comms: Optional[Sequence[object]]) -> None: ...
+
 
 class _CudaDeviceProperties:
     name: str
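Sequence and Optional come from typing, while Tensor, _int, and _CudaStreamBase are names the .pyi template already defines, so the new stubs need no extra imports. A minimal sketch of a caller that type-checks against the _nccl_all_reduce stub (the two-device setup and tensor shapes are illustrative assumptions; passing the same list as input and output makes the reduction in place):

from typing import Optional, Sequence
import torch
import torch.cuda.nccl as nccl

bufs: Sequence[torch.Tensor] = [torch.ones(8, device=f"cuda:{i}") for i in range(2)]
streams: Optional[Sequence[torch.cuda.Stream]] = None
# Matches the stub (input, output, op, streams, comms) -> None;
# nccl.SUM is the reduction-op constant used by torch/cuda/nccl.py.
torch._C._nccl_all_reduce(bufs, bufs, nccl.SUM, streams, None)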

torch/csrc/cuda/python_nccl.cpp (+9 −3 lines)

@@ -199,7 +199,9 @@ PyObject* THCPModule_nccl_broadcast(PyObject* self, PyObject* args) {
       nullptr,
       "nccl_broadcast",
       1,
-      "(sequence[Tensor] inputs, int root)");
+      "(sequence[Tensor] inputs, int root"
+      " sequence[torch.cuda.Stream] streams,"
+      " sequence[torch.cuda.nccl.Communicator] comms)");
   return nullptr;
 }
@@ -228,7 +230,9 @@ PyObject* THCPModule_nccl_all_gather(PyObject* self, PyObject* args) {
       nullptr,
       "nccl_all_gather",
       1,
-      "(sequence[Tensor] inputs, sequence[Tensor] outputs");
+      "(sequence[Tensor] inputs, sequence[Tensor] outputs"
+      " sequence[torch.cuda.Stream] streams,"
+      " sequence[torch.cuda.nccl.Communicator] comms)");
   return nullptr;
 }
@@ -258,7 +262,9 @@ PyObject* THCPModule_nccl_reduce_scatter(PyObject* self, PyObject* args) {
       nullptr,
       "nccl_reduce_scatter",
       1,
-      "(sequence[Tensor] inputs, sequence[Tensor] outputs, int op");
+      "(sequence[Tensor] inputs, sequence[Tensor] outputs, int op"
+      " sequence[torch.cuda.Stream] streams,"
+      " sequence[torch.cuda.nccl.Communicator] comms)");
   return nullptr;
 }
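These string literals feed the invalid-arguments error path, so a miscall from Python now advertises the streams and comms parameters as well. A hedged sketch of triggering it (the exact message text is assembled by the C++ helper and may differ; _nccl_broadcast is a private binding, used here only to demonstrate the error):

import torch

try:
    # Deliberately wrong argument types; the resulting TypeError's message
    # should now list the streams and comms parameters shown in the diff above.
    torch._C._nccl_broadcast("not-a-tensor-list", "not-an-int", None, None)
except TypeError as err:
    print(err)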

torch/cuda/nccl.py (+21 −10 lines)

@@ -3,6 +3,7 @@
 
 import torch._six
 import torch.cuda
+from typing import Optional, Sequence, Union
 
 
 __all__ = ['all_reduce', 'reduce', 'broadcast', 'all_gather', 'reduce_scatter']
@@ -43,7 +44,7 @@ def init_rank(num_ranks, uid, rank):
     return torch._C._nccl_init_rank(num_ranks, uid, rank)
 
 
-def _check_sequence_type(inputs):
+def _check_sequence_type(inputs: Union[torch.Tensor, Sequence[torch.Tensor]]) -> None:
     if not isinstance(inputs, collections.Container) or isinstance(inputs, torch.Tensor):
         raise TypeError("Inputs should be a collection of tensors")
 
@@ -58,8 +59,15 @@ def all_reduce(inputs, outputs=None, op=SUM, streams=None, comms=None):
 
 # `output` used to be `outputs`, taking in a list of tensors. So we have two
 # arguments for BC reasons.
-def reduce(inputs, output=None, root=0, op=SUM, streams=None, comms=None, *, outputs=None):
+def reduce(inputs: Sequence[torch.Tensor],
+           output: Optional[Union[torch.Tensor, Sequence[torch.Tensor]]] = None,
+           root: int = 0,
+           op: int = SUM,
+           streams: Optional[Sequence[torch.cuda.Stream]] = None,
+           comms=None, *,
+           outputs: Optional[Sequence[torch.Tensor]] = None) -> None:
     _check_sequence_type(inputs)
+    _output: torch.Tensor
     if outputs is not None:
         if output is not None:
             raise ValueError(
@@ -70,30 +78,33 @@ def reduce(inputs, output=None, root=0, op=SUM, streams=None, comms=None, *, out
         warnings.warn(
             "nccl.reduce with an output tensor list is deprecated. "
             "Please specify a single output tensor with argument 'output' instead instead.")
-        output = outputs[root]
+        _output = outputs[root]
     elif not isinstance(output, torch.Tensor) and isinstance(output, torch._six.container_abcs.Sequence):
         # User called old API with positional arguments of list of output tensors.
         warnings.warn(
             "nccl.reduce with an output tensor list is deprecated. "
             "Please specify a single output tensor.")
-        output = output[root]
-    elif output is None:
-        output = inputs[root]
-    torch._C._nccl_reduce(inputs, output, root, op, streams, comms)
+        _output = output[root]
+    else:
+        _output = inputs[root] if output is None else output
+    torch._C._nccl_reduce(inputs, _output, root, op, streams, comms)
 
 
-def broadcast(inputs, root=0, streams=None, comms=None):
+def broadcast(inputs: Sequence[torch.Tensor], root: int = 0, streams=None, comms=None) -> None:
     _check_sequence_type(inputs)
    torch._C._nccl_broadcast(inputs, root, streams, comms)
 
 
-def all_gather(inputs, outputs, streams=None, comms=None):
+def all_gather(inputs: Sequence[torch.Tensor], outputs: Sequence[torch.Tensor], streams=None, comms=None) -> None:
     _check_sequence_type(inputs)
     _check_sequence_type(outputs)
     torch._C._nccl_all_gather(inputs, outputs, streams, comms)
 
 
-def reduce_scatter(inputs, outputs, op=SUM, streams=None, comms=None):
+def reduce_scatter(inputs: Sequence[torch.Tensor],
+                   outputs: Sequence[torch.Tensor],
+                   op: int = SUM,
+                   streams=None, comms=None) -> None:
     _check_sequence_type(inputs)
     _check_sequence_type(outputs)
     torch._C._nccl_reduce_scatter(inputs, outputs, op, streams, comms)
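Taken together, the annotations cover the public entry points without changing runtime behavior; the separate _output: torch.Tensor local lets output keep its Optional[Union[...]] type through the deprecated list-of-outputs branches while the value handed to torch._C._nccl_reduce is always a plain Tensor, which is the usual pattern for keeping mypy happy about narrowing. A short usage sketch of the annotated API (the two-GPU setup and shapes are illustrative assumptions):

import torch
import torch.cuda.nccl as nccl

# One tensor per device; a single process drives the collective.
inputs = [torch.full((4,), float(i + 1), device=f"cuda:{i}") for i in range(2)]

nccl.all_reduce(inputs)         # in-place elementwise sum across devices
nccl.reduce(inputs, root=0)     # with no output given, result lands in inputs[0]
nccl.broadcast(inputs, root=0)  # copies inputs[0] to every other device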
