Add torch.__future__._overwrite_module_params_on_conversion global flag, and check it in nn.Module._apply() (pytorch#21613)
Summary:
pytorch#17072 breaks `model.to(xla_device)`, because moving `model` to an XLA device involves changing its parameters' TensorImpl type, and the current implementation of `nn.Module.to()` doesn't support changing module parameters' TensorImpl type:
```python
# https://github.com/pytorch/pytorch/blob/6dc445e1a84dc5d093d640de54f038f021d13227/torch/nn/modules/module.py#L192-L208
def _apply(self, fn):
    ...
    for param in self._parameters.values():
        if param is not None:
            # Tensors stored in modules are graph leaves, and we don't
            # want to create copy nodes, so we have to unpack the data.
            param.data = fn(param.data)  # NOTE: this doesn't allow changing `param.data`'s TensorImpl type
            if param._grad is not None:
                param._grad.data = fn(param._grad.data)  # NOTE: this doesn't allow changing `param._grad.data`'s TensorImpl type
    ...
```
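To make the limitation concrete, here is a small illustrative sketch (not code from this PR; the XLA line is commented out because it requires `torch_xla` and an `xla_device` handle):
```python
import torch

p = torch.nn.Parameter(torch.randn(3))

# Converting dtype keeps the same dense TensorImpl, so the `.data` assignment
# used by `_apply` works fine here.
p.data = p.data.double()

# An XLA tensor is backed by a different TensorImpl type, so the same
# `param.data = fn(param.data)` pattern cannot move the parameter to an XLA
# device; this is what breaks `model.to(xla_device)` after pytorch#17072.
# p.data = p.data.to(xla_device)  # requires torch_xla; unsupported under the old scheme
```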
yf225 TODO: fix the description here when we finish the implementation
To fix this problem, we introduce a new API `model.to_()` that always assigns new tensors to the parameters (thus supporting changing the parameters to any TensorImpl type), and also bumps the version counters of the original parameters correctly so that they are invalidated in any autograd graph they participate in.
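As an illustration only (not the code in this PR), a sketch of what an overwrite-style `_apply` could look like: instead of writing through `param.data`, it replaces each parameter with a freshly constructed `Parameter`, so the converted tensor may have any TensorImpl type. The helper name `_apply_overwrite` is made up for the example.
```python
import torch
from torch.nn import Parameter


def _apply_overwrite(module, fn):
    # Sketch: replace each parameter outright instead of assigning to `.data`,
    # so `fn` may return a tensor with a different TensorImpl type (e.g. an XLA tensor).
    for name, param in module._parameters.items():
        if param is None:
            continue
        with torch.no_grad():
            new_param = Parameter(fn(param), requires_grad=param.requires_grad)
            if param.grad is not None:
                # Convert the gradient the same way so it matches the new parameter.
                new_param.grad = fn(param.grad)
        module._parameters[name] = new_param
        # The real change additionally bumps the old parameter's version counter
        # (done in C++; there is no public Python API for it), so any autograd
        # graph still holding the old tensor fails loudly instead of silently
        # reusing stale data.
    for child in module.children():
        _apply_overwrite(child, fn)
    return module
```
For example, `_apply_overwrite(model, lambda t: t.to(device))` would leave `model` holding brand-new parameters on `device` rather than mutating the old storages in place.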
We also add a warning to the current `model.to()` API to inform users about the upcoming behavior change of `model.to()`: in future releases, it will create and return a new model instead of updating the current model in place.
This unblocks adding XLA to our CI test suite, which also allows XLA to catch up with other changes in our codebase, notably the c10 dispatcher.
[xla ci]
cc. resistor ailzhang
Pull Request resolved: pytorch#21613
Differential Revision: D15895387
Pulled By: yf225
fbshipit-source-id: b79f230fb06019122a37fdf0711bf2130a016fe6