1
1
#!/usr/bin/python3
2
2
import types
3
- from typing import Any , Callable , Dict , Iterator , Optional , Set , Tuple , TypeVar , Union , List
3
+ from typing import (
4
+ Any ,
5
+ Callable ,
6
+ Dict ,
7
+ Iterator ,
8
+ List ,
9
+ Optional ,
10
+ Set ,
11
+ Tuple ,
12
+ TypeVar ,
13
+ Union ,
14
+ )
4
15
5
16
import torch
6
17
import torch .distributed .rpc as rpc
@@ -26,7 +37,7 @@ def _instantiate_template(module_interface_cls):
26
37
instantiator .instantiate_scriptable_remote_module_template (module_interface_cls )
27
38
28
39
29
- def _create_module (module_cls , args , kwargs , module_interface_cls = None ):
40
+ def _create_module (module_cls , args , kwargs , device = "cpu" , module_interface_cls = None ):
30
41
module = module_cls (* args , ** kwargs )
31
42
if not isinstance (module , nn .Module ):
32
43
raise ValueError (
@@ -35,6 +46,7 @@ def _create_module(module_cls, args, kwargs, module_interface_cls=None):
35
46
)
36
47
if module_interface_cls is not None :
37
48
module = torch .jit .script (module )
49
+ module .to (device )
38
50
return rpc .RRef (module , module_interface_cls )
39
51
40
52
@@ -53,6 +65,7 @@ class _RemoteModule(nn.Module):
53
65
def __init__ (
54
66
self ,
55
67
on : str ,
68
+ device : torch .device ,
56
69
module_cls : nn .Module ,
57
70
args : Tuple = None ,
58
71
kwargs : Dict [str , Any ] = None ,
@@ -88,6 +101,7 @@ def __init__(
88
101
89
102
Arguments:
90
103
on (str or WorkerInfo): id or name of the destination worker.
104
+ device (torch.device): Device on the destination worker where we'd like to place this module.
91
105
module_cls (nn.Module): For example,
92
106
>>> class MyModule(nn.Module):
93
107
>>> def forward(input):
@@ -118,7 +132,7 @@ def __init__(
118
132
>>>
119
133
>>> rpc.init_rpc("worker0", rank=0, world_size=2)
120
134
>>> remote_linear_module = RemoteModule(
121
- >>> "worker1", nn.Linear, args=(20, 30),
135
+ >>> "worker1", "cpu", nn.Linear, args=(20, 30),
122
136
>>> )
123
137
>>> input = torch.randn(128, 20)
124
138
>>> ret_fut = remote_linear_module.forward_async(input)
@@ -164,7 +178,9 @@ def __init__(
164
178
165
179
# Create the module on the remote side.
166
180
self .module_rref = rpc .rpc_sync (
167
- on , _create_module , (module_cls , args , kwargs , _module_interface_cls )
181
+ on ,
182
+ _create_module ,
183
+ (module_cls , args , kwargs , device , _module_interface_cls ),
168
184
)
169
185
170
186
# Install generated methods.
@@ -314,6 +330,7 @@ class RemoteModule(_RemoteModule):
314
330
315
331
Arguments:
316
332
on (str or WorkerInfo): id or name of the destination worker.
333
+ device (torch.device): Device on the destination worker where we'd like to place this module.
317
334
module_cls (nn.Module): For example,
318
335
>>> class MyModule(nn.Module):
319
336
>>> def forward(input):
@@ -358,8 +375,9 @@ class RemoteModule(_RemoteModule):
358
375
def __init__ (
359
376
self ,
360
377
on : str ,
378
+ device : torch .device ,
361
379
module_cls : nn .Module ,
362
380
args : Tuple = None ,
363
381
kwargs : Dict [str , Any ] = None ,
364
382
):
365
- super ().__init__ (on , module_cls , args , kwargs )
383
+ super ().__init__ (on , device , module_cls , args , kwargs )
0 commit comments