
Commit 4c19a1e

rgommers authored and facebook-github-bot committed Aug 31, 2020
Move torch/autograd/grad_mode.pyi stubs inline (pytorch#43415)
Summary:

- Add `torch._C` bindings from `torch/csrc/autograd/init.cpp`
- Rename `torch._C.set_grad_enabled` to `torch._C._set_grad_enabled` so it doesn't conflict with `torch.set_grad_enabled` anymore

This is a continuation of pytorch gh-38201. All I did was resolve merge conflicts and finish the annotation of `_DecoratorContextManager.__call__` that ezyang started in the first commit.

~Reverts commit b5cd3a8, which was only motivated by not having `typing_extensions` available.~ (JIT can't be made to understand `Literal[False]`, so keep as is.)

Pull Request resolved: pytorch#43415

Reviewed By: ngimel

Differential Revision: D23301168

Pulled By: malfet

fbshipit-source-id: cb5290f2e556b4036592655b9fe54564cbb036f6
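A small usage sketch of the two names after the rename (not part of the commit; the tensor values are illustrative): `torch.set_grad_enabled` stays the public Python context manager, while the C binding is now reachable only as `torch._C._set_grad_enabled`.

import torch

# Public API: the Python context manager from torch/autograd/grad_mode.py.
with torch.set_grad_enabled(False):
    y = torch.ones(2, requires_grad=True) * 2
assert not y.requires_grad

# Private C binding from torch/csrc/autograd/init.cpp, renamed in this commit
# so it no longer shares a name with the public API above.
torch._C._set_grad_enabled(True)
assert torch.is_grad_enabled()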
1 parent e941a46 commit 4c19a1e

File tree

4 files changed (+41 -35 lines)

torch/_C/__init__.pyi.in (+11)

@@ -217,6 +217,17 @@ has_cudnn: _bool
 _GLIBCXX_USE_CXX11_ABI: _bool
 default_generator: Generator

+# Defined in torch/csrc/autograd/init.cpp
+def _set_grad_enabled(enabled: _bool) -> None: ...
+def is_grad_enabled() -> _bool: ...
+def set_autocast_enabled(enabled: _bool) -> None: ...
+def is_autocast_enabled() -> _bool: ...
+def clear_autocast_cache() -> None: ...
+def autocast_increment_nesting() -> None: ...
+def autocast_decrement_nesting() -> None: ...
+def set_anomaly_enabled(enabled: _bool) -> None: ...
+def is_anomaly_enabled() -> _bool: ...
+
 # Defined in torch/csrc/jit/python/script_init.cpp
 class FileCheck(object):
     # TODO
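With these declarations inline in the stub, a type checker such as mypy can see the bindings on `torch._C` directly; a minimal sketch of what that enables (assuming the checker picks up the bundled stubs):

import torch

torch._C._set_grad_enabled(True)            # matches the declared (_bool) -> None signature
enabled: bool = torch._C.is_grad_enabled()  # declared to return _bool

# Because the signature is in the stub, a checker can reject a mistyped call
# such as torch._C._set_grad_enabled("yes") instead of letting it through.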

torch/autograd/grad_mode.py (+29 -13)

@@ -1,20 +1,30 @@
 import torch
 import functools
 import inspect
-from typing import Any
+from typing import Any, Callable, TypeVar, cast
+
+
+__all__ = ['no_grad', 'enable_grad', 'set_grad_enabled']
+
+
+# Used for annotating the decorator usage of 'no_grad' and 'enable_grad'.
+# See https://mypy.readthedocs.io/en/latest/generics.html#declaring-decorators
+FuncType = Callable[..., Any]
+F = TypeVar('F', bound=FuncType)
+

 class _DecoratorContextManager:
     """Allow a context manager to be used as a decorator"""

-    def __call__(self, func):
+    def __call__(self, func: F) -> F:
         if inspect.isgeneratorfunction(func):
             return self._wrap_generator(func)

         @functools.wraps(func)
         def decorate_context(*args, **kwargs):
             with self:
                 return func(*args, **kwargs)
-        return decorate_context
+        return cast(F, decorate_context)

     def _wrap_generator(self, func):
         """Wrap each generator invocation with the context manager"""

@@ -30,6 +40,12 @@ def generator_context(*args, **kwargs):
                     break
         return generator_context

+    def __enter__(self) -> None:
+        raise NotImplementedError
+
+    def __exit__(self, exc_type: Any, exc_value: Any, traceback: Any) -> None:
+        raise NotImplementedError
+

 class no_grad(_DecoratorContextManager):
     r"""Context-manager that disabled gradient calculation.

@@ -70,7 +86,7 @@ def __enter__(self):
         self.prev = torch.is_grad_enabled()
         torch.set_grad_enabled(False)

-    def __exit__(self, exc_type: Any, exc_value: Any, traceback: Any):
+    def __exit__(self, exc_type: Any, exc_value: Any, traceback: Any) -> None:
         torch.set_grad_enabled(self.prev)


@@ -105,12 +121,12 @@ class enable_grad(_DecoratorContextManager):
        True

    """
-    def __enter__(self):
+    def __enter__(self) -> None:
         self.prev = torch.is_grad_enabled()
-        torch._C.set_grad_enabled(True)
+        torch._C._set_grad_enabled(True)

-    def __exit__(self, *args):
-        torch.set_grad_enabled(self.prev)
+    def __exit__(self, exc_type: Any, exc_value: Any, traceback: Any) -> None:
+        torch._C._set_grad_enabled(self.prev)


 class set_grad_enabled(object):

@@ -147,12 +163,12 @@ class set_grad_enabled(object):

    """

-    def __init__(self, mode):
+    def __init__(self, mode: bool) -> None:
         self.prev = torch.is_grad_enabled()
-        torch._C.set_grad_enabled(mode)
+        torch._C._set_grad_enabled(mode)

-    def __enter__(self):
+    def __enter__(self) -> None:
         pass

-    def __exit__(self, *args):
-        torch.set_grad_enabled(self.prev)
+    def __exit__(self, exc_type: Any, exc_value: Any, traceback: Any) -> None:
+        torch._C._set_grad_enabled(self.prev)
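The `F` TypeVar bound to `Callable[..., Any]` is what lets a type checker see that decorating a function with `no_grad`/`enable_grad` preserves its signature; a small sketch (the `scale` function is illustrative, not from the commit):

import torch


@torch.no_grad()
def scale(x: torch.Tensor, factor: float) -> torch.Tensor:
    # Runs with gradients disabled; since __call__ is typed (F) -> F,
    # checkers still see scale as (Tensor, float) -> Tensor rather than Any.
    return x * factor


out = scale(torch.ones(3, requires_grad=True), 2.0)
assert not out.requires_grad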

torch/autograd/grad_mode.pyi (-21)

This file was deleted.

torch/csrc/autograd/init.cpp (+1 -1)

@@ -163,7 +163,7 @@ static PyObject * is_anomaly_mode_enabled(PyObject* _unused, PyObject *arg) {

 // autograd methods on torch._C
 static PyMethodDef methods[] = { // NOLINT
-  {"set_grad_enabled", (PyCFunction)set_grad_enabled, METH_O, nullptr},
+  {"_set_grad_enabled", (PyCFunction)set_grad_enabled, METH_O, nullptr},
   {"is_grad_enabled", (PyCFunction)is_grad_enabled, METH_NOARGS, nullptr},
   {"set_autocast_enabled", (PyCFunction)set_autocast_enabled, METH_O, nullptr},
   {"is_autocast_enabled", (PyCFunction)is_autocast_enabled, METH_NOARGS, nullptr},
