Skip to content

Commit cce7680

Browse files
mrshenlifacebook-github-bot
authored andcommittedSep 16, 2020
Add bound method tests for async_execution with RRef helper (pytorch#44716)
Summary: Pull Request resolved: pytorch#44716 Test Plan: Imported from OSS Reviewed By: rohan-varma Differential Revision: D23707326 Pulled By: mrshenli fbshipit-source-id: a2f8db17447e9f82c9f6ed941ff1f8cb9090ad74
1 parent 257c6d0 commit cce7680

File tree

2 files changed

+36
-1
lines changed

2 files changed

+36
-1
lines changed
 

‎torch/distributed/rpc/functions.py

+26
Original file line numberDiff line numberDiff line change
@@ -115,6 +115,12 @@ def async_execution(fn):
115115
>>> )
116116
>>> return ret_fut
117117
>>>
118+
>>> @rpc.functions.async_execution
119+
>>> def bound_async_add(self, to, x, y, z):
120+
>>> return rpc.rpc_async(to, torch.add, args=(x, y)).then(
121+
>>> lambda fut: fut.wait() + z
122+
>>> )
123+
>>>
118124
>>> # On worker0
119125
>>> ret = rpc.rpc_sync(
120126
>>> "worker1",
@@ -129,6 +135,26 @@ def async_execution(fn):
129135
>>> args=("worker2", torch.ones(2), 1, 2)
130136
>>> )
131137
>>> print(ret) # prints tensor([4., 4.])
138+
139+
This decorator also works with RRef helpers, i.e., .
140+
:meth:`torch.distributed.rpc.RRef.rpc_sync`,
141+
:meth:`torch.distributed.rpc.RRef.rpc_async`, and
142+
:meth:`torch.distributed.rpc.RRef.remote`.
143+
144+
>>> from torch.distributed import rpc
145+
>>>
146+
>>> # reuse the AsyncExecutionClass class above
147+
>>> rref = rpc.remote("worker1", AsyncExecutionClass)
148+
>>> ret = rref.rpc_sync().static_async_add("worker2", torch.ones(2), 1, 2)
149+
>>> print(ret) # prints tensor([4., 4.])
150+
>>>
151+
>>> rref = rpc.remote("worker1", AsyncExecutionClass)
152+
>>> ret = rref.rpc_async().static_async_add("worker2", torch.ones(2), 1, 2).wait()
153+
>>> print(ret) # prints tensor([4., 4.])
154+
>>>
155+
>>> rref = rpc.remote("worker1", AsyncExecutionClass)
156+
>>> ret = rref.remote().static_async_add("worker2", torch.ones(2), 1, 2).to_here()
157+
>>> print(ret) # prints tensor([4., 4.])
132158
"""
133159
@functools.wraps(fn)
134160
def wrapper(*args, **kwargs):

‎torch/testing/_internal/distributed/rpc/rpc_test.py

+10-1
Original file line numberDiff line numberDiff line change
@@ -448,6 +448,12 @@ def class_async_add(cls, to, x, y, z):
448448
)
449449
return ret_fut
450450

451+
@rpc.functions.async_execution
452+
def bound_async_add(self, to, x, y, z):
453+
return rpc.rpc_async(to, torch.add, args=(x, y)).then(
454+
lambda fut: fut.wait() + z
455+
)
456+
451457

452458
def return_future():
453459
return torch.futures.Future()
@@ -3051,14 +3057,17 @@ def _test_test_async_class_rref_proxy(self, mode=RPCExecMode.SYNC):
30513057
if mode == RPCExecMode.SYNC:
30523058
ret = rref.rpc_sync().static_async_add(dst2, x, x, y)
30533059
ret += rref.rpc_sync().class_async_add(dst2, x, x, y)
3060+
ret += rref.rpc_sync().bound_async_add(dst2, x, x, y)
30543061
elif mode == RPCExecMode.ASYNC:
30553062
ret = rref.rpc_async().static_async_add(dst2, x, x, y).wait()
30563063
ret += rref.rpc_async().class_async_add(dst2, x, x, y).wait()
3064+
ret += rref.rpc_async().bound_async_add(dst2, x, x, y).wait()
30573065
elif mode == RPCExecMode.REMOTE:
30583066
ret = rref.remote().static_async_add(dst2, x, x, y).to_here()
30593067
ret += rref.remote().class_async_add(dst2, x, x, y).to_here()
3068+
ret += rref.remote().bound_async_add(dst2, x, x, y).to_here()
30603069

3061-
self.assertEqual(ret, 2 * 4 * x)
3070+
self.assertEqual(ret, 3 * 4 * x)
30623071

30633072
@dist_init
30643073
def test_async_class_rref_proxy(self):

0 commit comments

Comments
 (0)