Changes TensorIterator computation to not consider out kwarg, lets UnaryOps safe cast to out (pytorch#39655)
Summary:
**BC breaking note:**
In PyTorch 1.5, passing the out= kwarg to some functions, like torch.add, could affect the computation. That is,
```
out = torch.add(a, b)
```
could produce a different tensor than
```
torch.add(a, b, out=out)
```
This is because previously the out argument participated in the type promotion rules. For greater consistency with NumPy, Python, and C++, in PyTorch 1.6 the out argument no longer participates in type promotion, and has no effect on the computation performed.
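The difference can be sketched with a toy model in plain Python. This is not PyTorch code: the dtype strings, the `RANK` table, and the helpers `promote`, `add_in`, `add_with_out_old`, and `add_with_out_new` are all hypothetical, and the "old" rule is an assumption that the out dtype was simply folded into promotion.

```python
# Toy model only: dtypes are strings and "tensors" are single Python numbers.
RANK = {"uint8": 0, "float32": 1}

def promote(*dtypes):
    """Return the widest dtype under the toy ranking."""
    return max(dtypes, key=RANK.__getitem__)

def add_in(dtype, x, y):
    """Add x and y as if the arithmetic ran in `dtype`."""
    if dtype == "uint8":
        return (x + y) % 256  # 8-bit unsigned wraparound
    return float(x) + float(y)

def add_with_out_old(x, y, in_dtype, out_dtype):
    """Assumed 1.5-style rule: out's dtype joins the promotion."""
    return add_in(promote(in_dtype, out_dtype), x, y)

def add_with_out_new(x, y, in_dtype, out_dtype):
    """1.6-style rule: compute from the inputs, then cast the result to out's dtype."""
    result = add_in(in_dtype, x, y)
    return float(result) if out_dtype == "float32" else result

print(add_with_out_old(200, 100, "uint8", "float32"))  # 300.0
print(add_with_out_new(200, 100, "uint8", "float32"))  # 44.0
```

In this model the new rule wraps in uint8 first and only then converts to float32, which is the same value the call would produce with no out argument at all.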
**ORIGINAL PR NOTE**
This PR effectively rewrites TensorIterator's "compute_types" function, both to clarify its behavior and to change our type promotion so that it never considers the out argument when determining the iterator's "common dtype," AKA its "computation type." That is,
```
a = op(b, c)
```
should always produce the same result as
```
op(b, c, out=a)
```
This is consistent with NumPy and programming languages like Python and C++.
The conceptual model for this change is that a TensorIterator may have a "common computation type" that all inputs are cast to and in which its computation is performed. This common computation type, if it exists, is determined by applying our type promotion rules to the inputs.
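A minimal sketch of that model, in plain Python rather than ATen: promote over the input dtypes, cast every input to the result, then compute. The names `common_dtype`, `cast`, and `run_elementwise` and the four-entry ranking are hypothetical, and the sketch ignores the special handling that real promotion gives to scalars and zero-dim tensors.

```python
# Toy model of a "common computation type"; dtype names and ranks are illustrative only.
RANK = {"int32": 0, "int64": 1, "float32": 2, "float64": 3}

def common_dtype(input_dtypes):
    """Promote over the inputs only; no output dtype is consulted."""
    return max(input_dtypes, key=RANK.__getitem__)

def cast(value, dtype):
    """Cast a Python number to the toy dtype (float dtypes give floats, int dtypes give ints)."""
    return float(value) if dtype.startswith("float") else int(value)

def run_elementwise(op, operands):
    """operands is a list of (value, dtype) pairs; returns (result, computation dtype)."""
    dtype = common_dtype([d for _, d in operands])
    args = [cast(v, dtype) for v, _ in operands]  # every input is cast first
    return cast(op(*args), dtype), dtype

print(run_elementwise(lambda x, y: x + y, [(1, "int32"), (2.5, "float32")]))
# (3.5, 'float32')
```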
A common computation type is natural for some classes of functions, like many binary elementwise functions (e.g. add, sub, mul, div...). (NumPy describes these as "universal functions.") Many functions, however, like indexing operations, don't have a natural common computation type. In the future we'll likely want to support setting the TensorIterator's common computation type explicitly to enable "floating ufuncs" like the sin function that promote integer types to the default scalar type. Logic like that is beyond the type promotion system, which can only review inputs.
Implementing this change in a readable and maintainable manner was challenging because compute_types() has had many small modifications from many authors over a roughly two-year period, and the existing logic was in some places outdated and in other places unnecessarily complicated. The existing "strategies" approach also painted with a broad brush, and two of the strategies no longer made conceptual sense after this change. As a result, the new version of this function has a small set of flags to control its behavior. This has the positive effect of disentangling checks, such as whether all operands share a device and whether they share a dtype.
Additional changes in this PR:
- Unary operations now support out arguments with different dtypes. Like binary ops, they check canCast(computation type, out dtype).
- The dtype checking for lerp was outdated and its error message included the wrong variable. It has been fixed.
- The check for whether all tensors are on the same device has been separated from other checks. TensorIterators used by copy disable this check.
- As a result of this change, the output dtype can be computed if only the input types are available.
- The "fast path" for checking if a common dtype computation is necessary has been updated and simplified to also handle zero-dim tensors.
- A couple helper functions for compute_types() have been inlined to improve readability.
- The confusingly named and no longer used promote_gpu_output_dtypes_ has been removed. This variable was intended to support casting fp16 reductions on GPU, but it has become a no-op. That logic is now implemented here: https://github.com/pytorch/pytorch/blob/856215509d89c935cd1768ce4b496d4fc0e919a6/aten/src/ATen/native/ReduceOpsUtils.h#L207.
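The canCast check that unary ops now share with binary ops can be sketched as a toy model. It assumes the cast rule is category based (bool, then integral, then floating, then complex), with casts to a lower category rejected. The `CATEGORY` table and the helper `check_unary_out` are hypothetical and are not the ATen implementation.

```python
# Toy model: the category ranks are an assumption about the cast rule, not ATen code.
CATEGORY = {"bool": 0, "uint8": 1, "int64": 1, "float32": 2, "float64": 2, "complex64": 3}

def can_cast(src, dst):
    """Allow a cast unless it would drop to a lower dtype category."""
    return CATEGORY[src] <= CATEGORY[dst]

def check_unary_out(input_dtype, out_dtype):
    """Hypothetical helper: the computation dtype depends on the input alone."""
    computation_dtype = input_dtype
    if not can_cast(computation_dtype, out_dtype):
        raise TypeError(f"result type {computation_dtype} can't be cast to {out_dtype}")
    return out_dtype

print(check_unary_out("float32", "float64"))  # float64
print(can_cast("float32", "int64"))           # False
```

Because the computation dtype here never looks at out, the same function also shows why the output dtype can be derived from the input types alone.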
Pull Request resolved: pytorch#39655
Differential Revision: D21970878
Pulled By: mruberry
fbshipit-source-id: 5e6354c78240877ab5d6b1f7cfb351bd89049012