Skip to content

Commit d4736ef

Browse files
mrshenlifacebook-github-bot
authored andcommittedJul 24, 2020
Add done() API to Future (pytorch#42013)
Summary: Pull Request resolved: pytorch#42013 Test Plan: Imported from OSS Reviewed By: rohan-varma Differential Revision: D22729596 Pulled By: mrshenli fbshipit-source-id: ed31021a35af6e2c3393b9b14e4572cf51013bc0
1 parent 890b52e commit d4736ef

File tree

6 files changed

+56
-0
lines changed

6 files changed

+56
-0
lines changed
 

‎test/test_futures.py

+24
Original file line numberDiff line numberDiff line change
@@ -11,6 +11,30 @@ def add_one(fut):
1111

1212

1313
class TestFuture(TestCase):
14+
15+
def test_done(self) -> None:
16+
f = Future[torch.Tensor]()
17+
self.assertFalse(f.done())
18+
19+
f.set_result(torch.ones(2, 2))
20+
self.assertTrue(f.done())
21+
22+
def test_done_exception(self) -> None:
23+
err_msg = "Intentional Value Error"
24+
25+
def raise_exception(unused_future):
26+
raise RuntimeError(err_msg)
27+
28+
f1 = Future[torch.Tensor]()
29+
self.assertFalse(f1.done())
30+
f1.set_result(torch.ones(2, 2))
31+
self.assertTrue(f1.done())
32+
33+
f2 = f1.then(raise_exception)
34+
self.assertTrue(f2.done())
35+
with self.assertRaisesRegex(RuntimeError, err_msg):
36+
f2.wait()
37+
1438
def test_wait(self) -> None:
1539
f = Future[torch.Tensor]()
1640
f.set_result(torch.ones(2, 2))

‎torch/_C/__init__.pyi.in

+1
Original file line numberDiff line numberDiff line change
@@ -120,6 +120,7 @@ class _LegacyVariableBase(object):
120120
# Defined in torch/csrc/jit/python/init.cpp
121121
class Future(object):
122122
def __init__(self) -> None: ...
123+
def done(self) -> _bool: ...
123124
def wait(self) -> Any: ...
124125
def then(self, callback: Callable) -> Future: ...
125126
def set_result(self, result: Any) -> None: ...

‎torch/csrc/jit/python/init.cpp

+4
Original file line numberDiff line numberDiff line change
@@ -942,6 +942,10 @@ void initJITBindings(PyObject* module) {
942942
return std::make_shared<PythonFutureWrapper>(
943943
c10::make_intrusive<c10::ivalue::Future>(PyObjectType::get()));
944944
}))
945+
.def(
946+
"done",
947+
// Intentionally not releasing GIL
948+
&PythonFutureWrapper::done)
945949
.def(
946950
"wait",
947951
&PythonFutureWrapper::wait,

‎torch/csrc/jit/python/pybind_utils.h

+4
Original file line numberDiff line numberDiff line change
@@ -74,6 +74,10 @@ struct VISIBILITY_HIDDEN PythonFutureWrapper
7474
explicit PythonFutureWrapper(const PythonFutureWrapper&) = delete;
7575
PythonFutureWrapper& operator=(const PythonFutureWrapper&) = delete;
7676

77+
bool done() {
78+
return fut->completed();
79+
}
80+
7781
py::object wait() {
7882
fut->wait();
7983
if (jit::tracer::isTracing()) {

‎torch/futures/__init__.py

+8
Original file line numberDiff line numberDiff line change
@@ -22,6 +22,14 @@ class Future(torch._C.Future, Generic[T], metaclass=_PyFutureMeta):
2222
execution of a callable, e.g. :meth:`~torch.distributed.rpc.rpc_async`. It
2323
also exposes a set of APIs to add callback functions and set results.
2424
"""
25+
26+
def done(self) -> bool:
27+
r"""
28+
Return ``True`` if this ``Future`` is done. A ``Future`` is done if it
29+
has a result or an exception.
30+
"""
31+
return super().done()
32+
2533
def wait(self) -> T:
2634
r"""
2735
Block until the value of this ``Future`` is ready.

‎torch/testing/_internal/distributed/rpc/rpc_test.py

+15
Original file line numberDiff line numberDiff line change
@@ -3105,6 +3105,21 @@ def test_pickle_future(self):
31053105
with self.assertRaisesRegex(RuntimeError, errMsg):
31063106
rpc.remote(dst, fail_on_fut, args=(fut,))
31073107

3108+
@dist_init
3109+
def test_future_done(self):
3110+
dst = worker_name((self.rank + 1) % self.world_size)
3111+
fut = rpc.rpc_async(dst, torch.add, args=(torch.zeros(2), 1))
3112+
fut.wait()
3113+
self.assertTrue(fut.done())
3114+
3115+
@dist_init
3116+
def test_future_done_exception(self):
3117+
dst = worker_name((self.rank + 1) % self.world_size)
3118+
fut = rpc.rpc_async(dst, raise_func)
3119+
with self.assertRaisesRegex(ValueError, "Expected error"):
3120+
fut.wait()
3121+
self.assertTrue(fut.done())
3122+
31083123
def _test_future_cb(self, func):
31093124
dst1 = worker_name((self.rank + 1) % self.world_size)
31103125
dst2 = worker_name((self.rank + 2) % self.world_size)

0 commit comments

Comments
 (0)
Please sign in to comment.