
Commit c68cc78

Yi Wang authored and facebook-github-bot committed on Sep 18, 2020
Add a device parameter to RemoteModule (pytorch#44254)
Summary: Pull Request resolved: pytorch#44254

Add a device parameter to RemoteModule, so it can be placed on any device and not just CPU.

Original PR issue: RemoteModule enhancements pytorch#40550

Test Plan: buck test test/distributed/rpc:process_group_agent -- RemoteModule

Reviewed By: pritamdamania87

Differential Revision: D23483803

fbshipit-source-id: 4918583c15c6a38a255ccbf12c9168660ab7f6db
1 parent cff0e57 commit c68cc78

File tree

3 files changed: +161 -73 lines
 

torch/distributed/nn/api/remote_module.py

+23 -5
@@ -1,6 +1,17 @@
 #!/usr/bin/python3
 import types
-from typing import Any, Callable, Dict, Iterator, Optional, Set, Tuple, TypeVar, Union, List
+from typing import (
+    Any,
+    Callable,
+    Dict,
+    Iterator,
+    List,
+    Optional,
+    Set,
+    Tuple,
+    TypeVar,
+    Union,
+)
 
 import torch
 import torch.distributed.rpc as rpc
@@ -26,7 +37,7 @@ def _instantiate_template(module_interface_cls):
     instantiator.instantiate_scriptable_remote_module_template(module_interface_cls)
 
 
-def _create_module(module_cls, args, kwargs, module_interface_cls=None):
+def _create_module(module_cls, args, kwargs, device="cpu", module_interface_cls=None):
     module = module_cls(*args, **kwargs)
     if not isinstance(module, nn.Module):
         raise ValueError(
@@ -35,6 +46,7 @@ def _create_module(module_cls, args, kwargs, module_interface_cls=None):
         )
     if module_interface_cls is not None:
         module = torch.jit.script(module)
+    module.to(device)
     return rpc.RRef(module, module_interface_cls)
 
 
@@ -53,6 +65,7 @@ class _RemoteModule(nn.Module):
     def __init__(
         self,
         on: str,
+        device: torch.device,
         module_cls: nn.Module,
         args: Tuple = None,
         kwargs: Dict[str, Any] = None,
@@ -88,6 +101,7 @@ def __init__(
 
         Arguments:
             on (str or WorkerInfo): id or name of the destination worker.
+            device (torch.device): Device on the destination worker where we'd like to place this module.
            module_cls (nn.Module): For example,
                 >>> class MyModule(nn.Module):
                 >>>     def forward(input):
@@ -118,7 +132,7 @@ def __init__(
             >>>
             >>> rpc.init_rpc("worker0", rank=0, world_size=2)
             >>> remote_linear_module = RemoteModule(
-            >>>     "worker1", nn.Linear, args=(20, 30),
+            >>>     "worker1", "cpu", nn.Linear, args=(20, 30),
             >>> )
             >>> input = torch.randn(128, 20)
             >>> ret_fut = remote_linear_module.forward_async(input)
@@ -164,7 +178,9 @@ def __init__(
 
         # Create the module on the remote side.
         self.module_rref = rpc.rpc_sync(
-            on, _create_module, (module_cls, args, kwargs, _module_interface_cls)
+            on,
+            _create_module,
+            (module_cls, args, kwargs, device, _module_interface_cls),
         )
 
         # Install generated methods.
@@ -314,6 +330,7 @@ class RemoteModule(_RemoteModule):
 
     Arguments:
         to (str or WorkerInfo): id or name of the destination worker.
+        device (torch.device): Device on the destination worker where we'd like to place this module.
         module_cls (nn.Module): For example,
             >>> class MyModule(nn.Module):
             >>>     def forward(input):
@@ -358,8 +375,9 @@ class RemoteModule(_RemoteModule):
     def __init__(
         self,
         on: str,
+        device: torch.device,
         module_cls: nn.Module,
         args: Tuple = None,
         kwargs: Dict[str, Any] = None,
     ):
-        super().__init__(on, module_cls, args, kwargs)
+        super().__init__(on, device, module_cls, args, kwargs)
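
For context, here is a minimal caller-side sketch of how the new positional device argument is used, adapted from the docstring example in this diff. The worker names, ranks, and the note about passing a CUDA device string are illustrative assumptions, not part of the commit.

# Minimal sketch (assumes two RPC workers named "worker0" and "worker1";
# "worker1" must separately run rpc.init_rpc("worker1", rank=1, world_size=2)).
import torch
import torch.distributed.rpc as rpc
import torch.nn as nn
from torch.distributed.nn.api.remote_module import RemoteModule

rpc.init_rpc("worker0", rank=0, world_size=2)

# The second positional argument is the new device parameter; before this
# change the remote module was always constructed on CPU. Passing e.g.
# "cuda:0" instead would place the nn.Linear on worker1's first GPU.
remote_linear = RemoteModule("worker1", "cpu", nn.Linear, args=(20, 30))

input = torch.randn(128, 20)
ret_fut = remote_linear.forward_async(input)
print(ret_fut.wait().shape)  # torch.Size([128, 30])

rpc.shutdown()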
