
Commit 6b97279

Will Feng authored and facebook-github-bot committed on Jun 19, 2019
Add torch.__future__._overwrite_module_params_on_conversion global flag, and check it in nn.Module._apply() (pytorch#21613)
Summary:
pytorch#17072 breaks `model.to(xla_device)`, because moving `model` to an XLA device involves changing its parameters' TensorImpl type, and the current implementation of `nn.Module.to()` doesn't support changing module parameters' TensorImpl type:

```python
# https://github.com/pytorch/pytorch/blob/6dc445e1a84dc5d093d640de54f038f021d13227/torch/nn/modules/module.py#L192-L208
def _apply(self, fn):
    ...
    for param in self._parameters.values():
        if param is not None:
            # Tensors stored in modules are graph leaves, and we don't
            # want to create copy nodes, so we have to unpack the data.
            param.data = fn(param.data)  # NOTE: this doesn't allow changing `param.data`'s TensorImpl type
            if param._grad is not None:
                param._grad.data = fn(param._grad.data)  # NOTE: this doesn't allow changing `param._grad.data`'s TensorImpl type
    ...
```

To fix this problem, this commit adds a `torch.__future__._overwrite_module_params_on_conversion` global flag (with `set_overwrite_module_params_on_conversion()` / `get_overwrite_module_params_on_conversion()` accessors) and checks it in `nn.Module._apply()`. When the flag is set, or when the conversion function changes a parameter's TensorImpl type, `_apply()` assigns new tensors to the module's parameters and gradients instead of mutating them through `.data =`, which supports changing the parameters to any TensorImpl type. Because the conversion function is applied to the original parameters themselves, an in-place conversion still bumps the version counter of the original parameters, so they are correctly invalidated in any autograd graph they participate in. Overwriting module parameters on conversion is intended to become the default behavior in a future release.

This unblocks adding XLA to our CI test suite, which also allows XLA to catch up with other changes in our codebase, notably the c10 dispatcher.

[xla ci]

cc. resistor ailzhang

Pull Request resolved: pytorch#21613
Differential Revision: D15895387
Pulled By: yf225
fbshipit-source-id: b79f230fb06019122a37fdf0711bf2130a016fe6
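A minimal usage sketch of the behavior change (not part of the diff; it mirrors the new tests in `test/test_nn.py` below and assumes the APIs added by this commit):

```python
import torch
import torch.nn as nn

# Default (flag is False): dtype conversion mutates the existing Parameter
# in place via `param.data = ...`, so old references observe the new dtype.
torch.__future__.set_overwrite_module_params_on_conversion(False)
m = nn.Linear(20, 10).float()
weight_ref = m.weight                     # keep a reference to the original Parameter
m.double()
assert weight_ref.dtype == torch.double   # same Parameter object, new dtype

# Flag set to True: `_apply()` installs brand-new Parameters, so old
# references keep pointing at the stale float tensor.
torch.__future__.set_overwrite_module_params_on_conversion(True)
m2 = nn.Linear(20, 10).float()
weight_ref2 = m2.weight
m2.double()
assert weight_ref2.dtype == torch.float   # old reference is not updated
assert m2.weight.dtype == torch.double    # module now holds a new Parameter

# Restore the default so later code is unaffected.
torch.__future__.set_overwrite_module_params_on_conversion(False)
```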
1 parent 056a033 commit 6b97279

8 files changed, +204 −10 lines changed
 

aten/src/ATen/native/TypeProperties.cpp

+5 lines changed

```diff
@@ -38,6 +38,11 @@ bool is_quantized(const Tensor& self) {
   return self.is_quantized();
 }
 
+// True if `self` has the same derived type of TensorImpl as `other`.
+bool _has_same_tensorimpl_type(const Tensor& self, const Tensor& other) {
+  return typeid(*(self.unsafeGetTensorImpl())) == typeid(*(other.unsafeGetTensorImpl()));
+}
+
 Tensor type_as(const Tensor& self, const Tensor& other) {
   return self.toType(other.type());
 }
```

aten/src/ATen/native/native_functions.yaml

+3 lines changed

```diff
@@ -2036,6 +2036,9 @@
 - func: type_as(Tensor self, Tensor other) -> Tensor
   variants: method
 
+- func: _has_same_tensorimpl_type(Tensor self, Tensor other) -> bool
+  variants: function
+
 - func: _unique(Tensor self, bool sorted=True, bool return_inverse=False) -> (Tensor, Tensor)
   variants: function
   dispatch:
```
test/test_nn.py

+138 lines changed

```diff
@@ -1602,6 +1602,144 @@ def add_one_inplace(t):
         with self.assertRaisesRegex(RuntimeError, "modified by an inplace operation"):
             pgm.backward(torch.randn(10, 20))
 
+    def test_overwrite_module_params_on_conversion(self):
+        torch.__future__.set_overwrite_module_params_on_conversion(False)
+
+        # Test that if the conversion function passed to `module._apply()`
+        # changes the TensorImpl type of `module`'s parameters, the `module`'s
+        # parameters are always overwritten, regardless of the value of
+        # `torch.__future__.get_overwrite_module_params_on_conversion()`.
+        m = nn.Linear(20, 10)
+        m.weight.grad = torch.randn(10, 20)
+        weight_ref = m.weight
+        weight_grad_ref = m.weight.grad
+        m = m._apply(lambda t: torch.sparse_coo_tensor(torch.zeros([2, 1]), torch.ones([1]), torch.Size([10, 20])))
+        self.assertNotEqual(weight_ref.layout, m.weight.layout)
+        self.assertNotEqual(weight_grad_ref.layout, m.weight.grad.layout)
+
+        # Test that under the current default settings
+        # (`torch.__future__.get_overwrite_module_params_on_conversion() == False`),
+        # a view to a module's parameters is not pointing to the same storage as
+        # its base variable after converting the module to a different dtype.
+        m = nn.Linear(20, 10).float()
+        mw = m.weight[:]
+        m.double()
+        mw[0][0] = 5
+        with self.assertRaisesRegex(RuntimeError, "Expected object of scalar type Float but got scalar type Double"):
+            mw[0][0] == mw._base[0][0]
+
+        torch.__future__.set_overwrite_module_params_on_conversion(True)
+
+        # Test that if `torch.__future__.get_overwrite_module_params_on_conversion() == True`,
+        # a view to a module's parameters is still pointing to the same storage as
+        # its base variable after converting the module to a different dtype.
+        m = nn.Linear(20, 10).float()
+        mw = m.weight[:]
+        m.double()
+        mw[0][0] = 5
+        self.assertTrue(mw[0][0] == mw._base[0][0])
+
+        # Test that if `torch.__future__.get_overwrite_module_params_on_conversion() == True`,
+        # `float_module.double()` doesn't preserve previous references to
+        # `float_module`'s parameters or gradients.
+        m = nn.Linear(20, 10).float()
+        m.weight.grad = torch.randn(10, 20).float()
+        weight_ref = m.weight
+        weight_grad_ref = m.weight.grad
+        m.double()
+        self.assertNotEqual(weight_ref.dtype, m.weight.dtype)
+        self.assertNotEqual(weight_grad_ref.dtype, m.weight.grad.dtype)
+
+        def add_one_inplace(t):
+            return t.add_(1.0)
+
+        # Test that if `torch.__future__.get_overwrite_module_params_on_conversion() == True`,
+        # applying an in-place operation to a module would bump the module's
+        # original parameters' version counter.
+        m = nn.Linear(20, 10)
+        pvm = m.weight.mul(m.weight)
+        weight_ref = m.weight
+        m_weight_version_saved = weight_ref._version
+        m = m._apply(add_one_inplace)
+        # Test that the in-place operation bumps the original parameter's version counter
+        self.assertGreater(weight_ref._version, m_weight_version_saved)
+        with self.assertRaisesRegex(RuntimeError, "modified by an inplace operation"):
+            pvm.backward(torch.randn(10, 20))
+
+        # Test that if `torch.__future__.get_overwrite_module_params_on_conversion() == True`,
+        # applying an in-place operation to a module would bump the module's
+        # original parameters' gradients' version counter.
+        m = nn.Linear(20, 10)
+        m.weight.grad = torch.randn(10, 20).requires_grad_()
+        pgm = m.weight.grad.mul(m.weight.grad)
+        weight_grad_ref = m.weight.grad
+        m_weight_grad_version_saved = weight_grad_ref._version
+        m = m._apply(add_one_inplace)
+        self.assertGreater(weight_grad_ref._version, m_weight_grad_version_saved)
+        with self.assertRaisesRegex(RuntimeError, "modified by an inplace operation"):
+            pgm.backward(torch.randn(10, 20))
+
+        # Test that if `torch.__future__.get_overwrite_module_params_on_conversion() == True`,
+        # applying an out-of-place operation to a module doesn't bump
+        # the module's original parameters' version counter.
+        m = nn.Linear(20, 10)
+        weight_ref = m.weight
+        m_weight_version_saved = weight_ref._version
+        m = m._apply(lambda t: torch.randn(t.shape))
+        self.assertEqual(weight_ref._version, m_weight_version_saved)
+
+        # Test that if `torch.__future__.get_overwrite_module_params_on_conversion() == True`,
+        # applying an out-of-place operation to a module doesn't bump
+        # the module's original parameters' gradients' version counter.
+        m = nn.Linear(20, 10)
+        m.weight.grad = torch.randn(10, 20).requires_grad_()
+        weight_grad_ref = m.weight.grad
+        m_weight_grad_version_saved = weight_grad_ref._version
+        m = m._apply(lambda t: torch.randn(t.shape))
+        self.assertEqual(weight_grad_ref._version, m_weight_grad_version_saved)
+
+        torch.__future__.set_overwrite_module_params_on_conversion(False)
+
+    @unittest.skipIf(not TEST_CUDA, "CUDA unavailable")
+    def test_overwrite_module_params_on_conversion_cpu_cuda(self):
+        torch.__future__.set_overwrite_module_params_on_conversion(False)
+
+        # Test that under the current default settings
+        # (`torch.__future__.get_overwrite_module_params_on_conversion() == False`),
+        # a view to a module's parameters is not pointing to the same storage as
+        # its base variable after converting the module to a different device.
+        m = nn.Linear(20, 10)
+        mw = m.weight[:]
+        m.to('cuda')
+        with torch.no_grad():
+            # Without using `torch.no_grad()`, this will leak CUDA memory.
+            # (Issue is filed at https://github.com/pytorch/pytorch/issues/21875)
+            mw[0][0] = 5
+        with self.assertRaisesRegex(RuntimeError, "Expected object of backend CPU but got backend CUDA"):
+            mw[0][0] == mw._base[0][0]
+
+        torch.__future__.set_overwrite_module_params_on_conversion(True)
+
+        # Test that if `torch.__future__.get_overwrite_module_params_on_conversion() == True`,
+        # a view to a module's parameters is still pointing to the same storage as
+        # its base variable after converting the module to a different device.
+        m = nn.Linear(20, 10)
+        mw = m.weight[:]
+        m.to('cuda')
+        mw[0][0] = 5
+        self.assertTrue(mw[0][0] == mw._base[0][0])
+
+        # Test that if `torch.__future__.get_overwrite_module_params_on_conversion() == True`,
+        # `cpu_module.to("cuda")` doesn't preserve previous references to
+        # `cpu_module`'s parameters or gradients.
+        m = nn.Linear(20, 10)
+        m.weight.grad = torch.randn(10, 20)
+        weight_ref = m.weight
+        weight_grad_ref = m.weight.grad
+        m.to('cuda')
+        self.assertNotEqual(weight_ref.device, m.weight.device)
+        self.assertNotEqual(weight_grad_ref.device, m.weight.grad.device)
+
     def test_type(self):
         l = nn.Linear(10, 20)
         net = nn.Module()
```

torch/__future__.py

+19 lines changed (new file)

```diff
@@ -0,0 +1,19 @@
+"""
+This global flag controls whether to assign new tensors to the parameters
+instead of changing the existing parameters in-place when converting an `nn.Module`
+using the following methods:
+1. `module.cuda()` / `.cpu()` (for moving `module` between devices)
+2. `module.float()` / `.double()` / `.half()` (for converting `module` to a different dtype)
+3. `module.to()` / `.type()` (for changing `module`'s device or dtype)
+4. `module._apply(fn)` (for generic functions applied to `module`)
+
+Default: False
+"""
+_overwrite_module_params_on_conversion = False
+
+def set_overwrite_module_params_on_conversion(value):
+    global _overwrite_module_params_on_conversion
+    _overwrite_module_params_on_conversion = value
+
+def get_overwrite_module_params_on_conversion():
+    return _overwrite_module_params_on_conversion
```
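Since the flag is process-global state, callers that flip it (as the new tests do) should restore the default afterwards; a small sketch, assuming the module as added above:

```python
import torch

assert torch.__future__.get_overwrite_module_params_on_conversion() is False  # default

torch.__future__.set_overwrite_module_params_on_conversion(True)
try:
    pass  # convert modules here; `nn.Module._apply()` will assign new Parameters
finally:
    # restore the default so unrelated code keeps the current (in-place) behavior
    torch.__future__.set_overwrite_module_params_on_conversion(False)
```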

torch/__init__.py

+1 line changed

```diff
@@ -310,6 +310,7 @@ def manager_path():
 import torch.backends.mkl
 import torch.backends.openmp
 import torch.__config__
+import torch.__future__
 
 _C._init_names(list(torch._storage_classes))
 
```

torch/csrc/autograd/variable.cpp

+1 −1 lines changed

```diff
@@ -88,7 +88,7 @@ void Variable::set_data(const at::Tensor &new_data) {
   // from `new_data` to `var`. It requires that `new_data` has the same derived
   // type of TensorImpl as `var`.
   TORCH_CHECK(
-    typeid(*(this->unsafeGetTensorImpl())) == typeid(*(new_data.unsafeGetTensorImpl())),
+    _has_same_tensorimpl_type(*this, new_data),
     "Attempted to call `variable.set_data(tensor)`, but `variable` and `tensor` have different types of TensorImpl.");
 
   // Resets gradient accumulator if metadata is out of date
```

torch/csrc/autograd/variable.h

+3 −4 lines changed

```diff
@@ -248,10 +248,9 @@ struct TORCH_API Variable : public at::Tensor {
       bool keep_graph,
       bool create_graph) const;
 
-  /// Sets the `Tensor` held by this `Variable` to the one supplied.
-  /// It is rarely necessary to call this; it's used, for example, when
-  /// a non-sparse gradient gets added to a sparse gradient, requiring
-  /// the type of the gradient `Variable` to become non-sparse.
+  /// Sets the tensor data held by this `Variable` to be the same as `new_data`.
+  /// It requires that `new_data` has the same derived type of TensorImpl as
+  /// this `Variable`, by checking `_has_same_tensorimpl_type(this, new_data)`.
   void set_data(const at::Tensor &new_data);
 
   /// Set the gradient edge -- i.e. `grad_fn` and `input_nr` -- of the
```
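For intuition (a sketch, not part of the diff): the Python `.data` setter routes through `Variable::set_data`, so with the check above an assignment that would change the TensorImpl type raises, which is exactly the case where `nn.Module._apply()` has to replace the Parameter instead of mutating it:

```python
import torch

dense = torch.randn(10, 20)
sparse = torch.sparse_coo_tensor(torch.zeros([2, 1], dtype=torch.long),
                                 torch.ones([1]), (10, 20))
try:
    dense.data = sparse   # dense TensorImpl vs. SparseTensorImpl
except RuntimeError as err:
    # "... `variable` and `tensor` have different types of TensorImpl."
    print(err)
```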

torch/nn/modules/module.py

+34 −5 lines changed

```diff
@@ -193,15 +193,44 @@ def _apply(self, fn):
         for module in self.children():
             module._apply(fn)
 
-        for param in self._parameters.values():
+        def compute_should_use_set_data(tensor, tensor_applied):
+            if torch._has_same_tensorimpl_type(tensor, tensor_applied):
+                # If the new tensor has the same TensorImpl type as the existing tensor,
+                # the current behavior is to change the tensor in-place using `.data =`,
+                # and the future behavior is to overwrite the existing tensor. However,
+                # changing the current behavior is a BC-breaking change, and we want it
+                # to happen in future releases. So for now we introduce the
+                # `torch.__future__.get_overwrite_module_params_on_conversion()`
+                # global flag to let the user control whether they want the future
+                # behavior of overwriting the existing tensor or not.
+                return not torch.__future__.get_overwrite_module_params_on_conversion()
+            else:
+                return False
+
+        for key, param in self._parameters.items():
             if param is not None:
+                # Tensors stored in modules are graph leaves, and we don't want to
+                # track autograd history of `param_applied`, so we have to use
+                # `with torch.no_grad():`
                 with torch.no_grad():
                     param_applied = fn(param)
-                param.data = param_applied
-                if param._grad is not None:
+                should_use_set_data = compute_should_use_set_data(param, param_applied)
+                if should_use_set_data:
+                    param.data = param_applied
+                else:
+                    assert isinstance(param, Parameter)
+                    assert param.is_leaf
+                    self._parameters[key] = Parameter(param_applied, param.requires_grad)
+
+                if param.grad is not None:
                     with torch.no_grad():
-                        grad_applied = fn(param._grad)
-                        param._grad.data = grad_applied
+                        grad_applied = fn(param.grad)
+                    should_use_set_data = compute_should_use_set_data(param.grad, grad_applied)
+                    if should_use_set_data:
+                        param.grad.data = grad_applied
+                    else:
+                        assert param.grad.is_leaf
+                        self._parameters[key].grad = grad_applied.requires_grad_(param.grad.requires_grad)
 
         for key, buf in self._buffers.items():
             if buf is not None:
```
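A short sketch (not part of the diff) of the decision `compute_should_use_set_data` makes per parameter: the in-place `.data =` path is kept only when the TensorImpl type is unchanged and the future flag is off; otherwise a fresh `Parameter` is installed, which is what makes `model.to(xla_device)`-style conversions possible:

```python
import torch
import torch.nn as nn

m = nn.Linear(20, 10)
weight_ref = m.weight

# fn changes the TensorImpl type (dense -> sparse), so `.data =` is impossible
# and the parameter is overwritten even with the flag at its default of False.
m = m._apply(lambda t: torch.sparse_coo_tensor(torch.zeros([2, 1], dtype=torch.long),
                                               torch.ones([1]), (10, 20)))
assert m.weight.layout == torch.sparse_coo   # module now holds a sparse Parameter
assert weight_ref.layout == torch.strided    # old reference keeps the dense tensor
```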
