3
3
4
4
import collections.abc

import torch._six
import torch.cuda
from typing import Optional, Sequence, Union
__all__ = ['all_reduce' , 'reduce' , 'broadcast' , 'all_gather' , 'reduce_scatter' ]
@@ -43,7 +44,7 @@ def init_rank(num_ranks, uid, rank):
43
44
return torch ._C ._nccl_init_rank (num_ranks , uid , rank )
44
45
45
46
46
- def _check_sequence_type (inputs ) :
47
+ def _check_sequence_type (inputs : Union [ torch . Tensor , Sequence [ torch . Tensor ]]) -> None :
47
48
if not isinstance (inputs , collections .Container ) or isinstance (inputs , torch .Tensor ):
48
49
raise TypeError ("Inputs should be a collection of tensors" )
49
50
@@ -58,8 +59,15 @@ def all_reduce(inputs, outputs=None, op=SUM, streams=None, comms=None):
58
59
59
60
# `output` used to be `outputs`, taking in a list of tensors. So we have two
60
61
# arguments for BC reasons.
61
- def reduce (inputs , output = None , root = 0 , op = SUM , streams = None , comms = None , * , outputs = None ):
62
+ def reduce (inputs : Sequence [torch .Tensor ],
63
+ output : Optional [Union [torch .Tensor , Sequence [torch .Tensor ]]] = None ,
64
+ root : int = 0 ,
65
+ op : int = SUM ,
66
+ streams : Optional [Sequence [torch .cuda .Stream ]] = None ,
67
+ comms = None , * ,
68
+ outputs : Optional [Sequence [torch .Tensor ]] = None ) -> None :
62
69
_check_sequence_type (inputs )
70
+ _output : torch .Tensor
63
71
if outputs is not None :
64
72
if output is not None :
65
73
raise ValueError (
@@ -70,30 +78,33 @@ def reduce(inputs, output=None, root=0, op=SUM, streams=None, comms=None, *, out
70
78
warnings .warn (
71
79
"nccl.reduce with an output tensor list is deprecated. "
72
80
"Please specify a single output tensor with argument 'output' instead instead." )
73
- output = outputs [root ]
81
+ _output = outputs [root ]
74
82
elif not isinstance (output , torch .Tensor ) and isinstance (output , torch ._six .container_abcs .Sequence ):
75
83
# User called old API with positional arguments of list of output tensors.
76
84
warnings .warn (
77
85
"nccl.reduce with an output tensor list is deprecated. "
78
86
"Please specify a single output tensor." )
79
- output = output [root ]
80
- elif output is None :
81
- output = inputs [root ]
82
- torch ._C ._nccl_reduce (inputs , output , root , op , streams , comms )
87
+ _output = output [root ]
88
+ else :
89
+ _output = inputs [root ] if output is None else output
90
+ torch ._C ._nccl_reduce (inputs , _output , root , op , streams , comms )
83
91
84
92
85
def broadcast(inputs: Sequence[torch.Tensor],
              root: int = 0,
              streams=None,
              comms=None) -> None:
    """Broadcast via NCCL from the tensor at index ``root``.

    Checks that ``inputs`` is a collection of tensors (one per participating
    device) and forwards everything to the C binding
    ``torch._C._nccl_broadcast``.
    """
    _check_sequence_type(inputs)
    torch._C._nccl_broadcast(inputs, root, streams, comms)
88
96
89
97
90
def all_gather(inputs: Sequence[torch.Tensor],
               outputs: Sequence[torch.Tensor],
               streams=None,
               comms=None) -> None:
    """All-gather via NCCL across the tensors in ``inputs``.

    Both arguments are validated as collections of tensors before the call is
    forwarded to the C binding ``torch._C._nccl_all_gather``.
    """
    for seq in (inputs, outputs):
        _check_sequence_type(seq)
    torch._C._nccl_all_gather(inputs, outputs, streams, comms)
94
102
95
103
96
def reduce_scatter(inputs: Sequence[torch.Tensor],
                   outputs: Sequence[torch.Tensor],
                   op: int = SUM,
                   streams=None, comms=None) -> None:
    """Reduce-scatter via NCCL with reduction operator ``op`` (default ``SUM``).

    Both arguments are validated as collections of tensors before the call is
    forwarded to the C binding ``torch._C._nccl_reduce_scatter``.
    """
    for seq in (inputs, outputs):
        _check_sequence_type(seq)
    torch._C._nccl_reduce_scatter(inputs, outputs, op, streams, comms)
0 commit comments