Commit 0aecbbb

Mike Ruberry authored and facebook-github-bot committed on Jun 10, 2020
Changes TensorIterator computation to not consider out kwarg, lets UnaryOps safe cast to out (pytorch#39655)
Summary:

**BC breaking note:** In PyTorch 1.5 passing the out= kwarg to some functions, like torch.add, could affect the computation. That is,

```
out = torch.add(a, b)
```

could produce a different tensor than

```
torch.add(a, b, out=out)
```

This is because previously the out argument participated in the type promotion rules. For greater consistency with NumPy, Python, and C++, in PyTorch 1.6 the out argument no longer participates in type promotion, and has no effect on the computation performed.

**ORIGINAL PR NOTE**

This PR effectively rewrites TensorIterator's "compute_types" function to both clarify its behavior and change how our type promotion works to never consider the out argument when determining the iterator's "common dtype," AKA its "computation type." That is,

```
a = op(b, c)
```

should always produce the same result as

```
op(b, c, out=a)
```

This is consistent with NumPy and programming languages like Python and C++.

The conceptual model for this change is that a TensorIterator may have a "common computation type" that all inputs are cast to and its computation performed in. This common computation type, if it exists, is determined by applying our type promotion rules to the inputs. A common computation type is natural for some classes of functions, like many binary elementwise functions (e.g. add, sub, mul, div...). (NumPy describes these as "universal functions.") Many functions, however, like indexing operations, don't have a natural common computation type. In the future we'll likely want to support setting the TensorIterator's common computation type explicitly to enable "floating ufuncs" like the sin function that promote integer types to the default scalar type. Logic like that is beyond the type promotion system, which can only review inputs.

Implementing this change in a readable and maintainable manner was challenging because compute_types() has had many small modifications from many authors over a ~2-year period, and the existing logic was in some places outdated and in other places unnecessarily complicated. The existing "strategies" approach also painted with a broad brush, and two of the strategies no longer made conceptual sense after this change. As a result, the new version of this function has a small set of flags to control its behavior. This has the positive effect of disentangling checks like all operands having the same device and their having the same dtype.

Additional changes in this PR:

- Unary operations now support out arguments with different dtypes. Like binary ops they check canCast(computation type, out dtype).
- The dtype checking for lerp was outdated and its error message included the wrong variable. It has been fixed.
- The check for whether all tensors are on the same device has been separated from other checks. TensorIterators used by copy disable this check.
- As a result of this change, the output dtype can be computed if only the input types are available.
- The "fast path" for checking if a common dtype computation is necessary has been updated and simplified to also handle zero-dim tensors.
- A couple helper functions for compute_types() have been inlined to improve readability.
- The confusingly named and no longer used promote_gpu_output_dtypes_ has been removed. This variable was intended to support casting fp16 reductions on GPU, but it had become a no-op. That logic is now implemented here: https://github.com/pytorch/pytorch/blob/856215509d89c935cd1768ce4b496d4fc0e919a6/aten/src/ATen/native/ReduceOpsUtils.h#L207.
Pull Request resolved: pytorch#39655
Differential Revision: D21970878
Pulled By: mruberry
fbshipit-source-id: 5e6354c78240877ab5d6b1f7cfb351bd89049012
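To make the BC-breaking note concrete, here is a minimal Python sketch of the 1.6 behavior described above (the tensors `a`, `b`, and `out` are hypothetical examples, not taken from the commit):

```python
import torch

a = torch.tensor([1, 2, 3], dtype=torch.int64)
b = torch.tensor([4, 5, 6], dtype=torch.int64)
out = torch.empty(3, dtype=torch.float32)

# The out tensor no longer participates in type promotion: the computation
# type is int64 (promotion over the inputs only), and the int64 result is
# then safe-cast into the float32 out tensor.
torch.add(a, b, out=out)   # same values as torch.add(a, b), stored as float32

# Unary ops follow the same rule and now accept an out with a different
# dtype whenever canCast(computation type, out dtype) holds.
torch.neg(a, out=out)

# Lossy casts are still rejected, e.g. a float computation cannot be
# written into an integer out tensor:
# torch.add(a.float(), b.float(), out=torch.empty(3, dtype=torch.int64))  # RuntimeError
```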
1 parent acc13ac commit 0aecbbb

29 files changed: +368 −282 lines
 

‎aten/src/ATen/native/BinaryOps.cpp

+18 −7

@@ -111,23 +111,34 @@ Tensor& remainder_(Tensor& self, const Tensor& other) {
 }
 
 Tensor& true_divide_out(Tensor& result, const Tensor& self, const Tensor& divisor) {
-  TORCH_CHECK(!isIntegralType(result.scalar_type(), /*includeBool=*/ true),
-              "True division requires a floating output type, but got ",
-              result.scalar_type());
+  // If both inputs have integral (or bool) types, creates
+  // temporary float copies as new inputs.
+  if (isIntegralType(self.scalar_type(), /*includeBool=*/ true)
+      && isIntegralType(divisor.scalar_type(), /*includeBool=*/ true)) {
+    const auto scalar_type = typeMetaToScalarType(c10::get_default_dtype());
+    auto iter = TensorIterator::binary_op(result,
+                                          self.to(scalar_type),
+                                          divisor.to(scalar_type),
+                                          /*check_mem_overlap=*/ true);
+    div_stub(iter.device_type(), iter);
+    return result;
+  }
   auto iter = TensorIterator::binary_op(result, self, divisor, /*check_mem_overlap=*/ true);
   div_stub(iter.device_type(), iter);
   return result;
 }
 
 Tensor true_divide(const Tensor& self, const Tensor& divisor) {
-  // If both inputs have integral (or bool) types, sets the output to have
-  // the default (floating) scalar type
+  // If both inputs have integral (or bool) types, creates
+  // temporary float copies as new inputs and sets the result's type to
+  // the default scalar type
   if (isIntegralType(self.scalar_type(), /*includeBool=*/ true)
       && isIntegralType(divisor.scalar_type(), /*includeBool=*/ true)) {
     const auto scalar_type = typeMetaToScalarType(c10::get_default_dtype());
     Tensor result = at::empty({0}, self.options().dtype(scalar_type));
-
-    auto iter = TensorIterator::binary_op(result, self, divisor);
+    auto iter = TensorIterator::binary_op(result,
+                                          self.to(scalar_type),
+                                          divisor.to(scalar_type));
     div_stub(iter.device_type(), iter);
     return result;
   }
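For reference, a small Python illustration of the true_divide behavior this hunk preserves: integral (and bool) inputs are computed as the default scalar type, so the division is not truncated (the tensors here are hypothetical):

```python
import torch

x = torch.tensor([1, 2, 3])   # int64
y = torch.tensor([2, 2, 2])   # int64

z = torch.true_divide(x, y)
print(z.dtype)  # torch.float32 -- the default scalar type
print(z)        # tensor([0.5000, 1.0000, 1.5000])
```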

‎aten/src/ATen/native/Copy.cpp

+2 −1

@@ -140,7 +140,8 @@ static Tensor & copy_impl(Tensor & self, const Tensor & src, bool non_blocking)
   iter.add_output(self);
   iter.add_input(src);
   iter.dont_resize_outputs();
-  iter.dont_compute_common_dtype();
+  iter.check_all_same_dtype(false);
+  iter.check_all_same_device(false);
   iter.build();
 
   if (iter.numel() == 0) {
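The copy iterator intentionally mixes dtypes (and devices), which is why it opts out of both checks above. A rough Python illustration (hypothetical tensors):

```python
import torch

dst = torch.empty(3, dtype=torch.float32)
src = torch.tensor([1, 2, 3], dtype=torch.int64)

# copy_ converts between dtypes, so its TensorIterator must not require a
# common dtype across operands.
dst.copy_(src)
print(dst)  # tensor([1., 2., 3.])

# When a GPU is present, copy_ also crosses devices, hence the separate
# same-device check being disabled as well.
if torch.cuda.is_available():
    dst.copy_(src.cuda())
```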

‎aten/src/ATen/native/ReduceOpsUtils.h

+3 −3

@@ -212,7 +212,7 @@ static TensorIterator make_reduction(
 
 static TensorIterator make_reduction(
   const char* name, Tensor& result1, Tensor& result2, const Tensor& self, IntArrayRef dim,
-  bool keepdim, ScalarType dtype1, ScalarType dtype2, bool promote_gpu_output_dtypes=true)
+  bool keepdim, ScalarType dtype1, ScalarType dtype2)
 {
   // check that result type and dtype match if provided
   TORCH_CHECK(

@@ -240,9 +240,9 @@ static TensorIterator make_reduction(
   // product of templated kernel launches.
   if (self.scalar_type() == dtype1 ||
       (self.is_cuda() && self.scalar_type() == kHalf && dtype1 == kFloat)) {
-    return TensorIterator::reduce_op(viewed_result1, viewed_result2, self, promote_gpu_output_dtypes);
+    return TensorIterator::reduce_op(viewed_result1, viewed_result2, self);
   }
-  return TensorIterator::reduce_op(viewed_result1, viewed_result2, self.to(dtype1), promote_gpu_output_dtypes);
+  return TensorIterator::reduce_op(viewed_result1, viewed_result2, self.to(dtype1));
 }
 
 static TensorIterator make_reduction(

‎aten/src/ATen/native/TensorAdvancedIndexing.cpp

+12 −8

@@ -215,11 +215,15 @@ static AdvancedIndex make_info(Tensor self, TensorList orig) {
 static TensorIterator make_index_put_iterator(const AdvancedIndex& info, const Tensor& value) {
   TORCH_CHECK(is_expandable_to(value.sizes(), info.src.sizes()), "shape mismatch: value tensor of shape ", value.sizes(),
              " cannot be broadcast to indexing result of shape ", info.src.sizes());
+  TORCH_CHECK(value.scalar_type() == info.src.scalar_type(),
+              "Index put requires the source and destination dtypes match, "
+              "got ", info.src.scalar_type(), " for the destination "
+              "and ", value.scalar_type(), " for the source.");
   auto iter = TensorIterator();
-  iter.dont_compute_common_dtype();
   iter.dont_resize_outputs();
+  iter.check_all_same_dtype(false);
   iter.add_output(info.src);
-  iter.add_input(value, info.src.device(), info.src.scalar_type());
+  iter.add_input(value);
   for (auto& index : info.indices) {
     iter.add_input(index);
   }

@@ -229,7 +233,7 @@ static TensorIterator make_index_put_iterator(const AdvancedIndex& info, const T
 
 static TensorIterator make_index_iterator(const AdvancedIndex& info) {
   auto iter = TensorIterator();
-  iter.dont_compute_common_dtype();
+  iter.check_all_same_dtype(false);
   iter.add_output(Tensor(), info.src.device(), info.src.scalar_type());
   iter.add_input(info.src);
   for (auto& index : info.indices) {

@@ -241,7 +245,7 @@ static TensorIterator make_index_iterator(const AdvancedIndex& info) {
 
 static TensorIterator make_index_out_iterator(const AdvancedIndex& info, Tensor& result) {
   auto iter = TensorIterator();
-  iter.dont_compute_common_dtype();
+  iter.check_all_same_dtype(false);
   iter.add_output(result, info.src.device(), info.src.scalar_type());
   iter.add_input(info.src);
   for (auto& index : info.indices) {

@@ -437,7 +441,7 @@ Tensor & index_select_out_cpu_(Tensor & result, const Tensor & self, int64_t dim
   auto slice_size = selfSlice.numel();
 
   auto iter = TensorIterator();
-  iter.dont_compute_common_dtype();
+  iter.check_all_same_dtype(false);
   iter.dont_resize_outputs();
   iter.add_output(resultSlice);
   iter.add_input(selfSlice);

@@ -571,7 +575,7 @@ static Tensor & masked_fill_impl_cpu(Tensor & self, const Tensor & mask, Scalar
 }
 
   auto iter = TensorIterator();
-  iter.dont_compute_common_dtype();
+  iter.check_all_same_dtype(false);
   iter.dont_resize_outputs();
   iter.add_output(self);
   iter.add_input(mask);

@@ -658,7 +662,7 @@ static Tensor & masked_select_out_impl_cpu(Tensor & result, const Tensor & self,
   bool use_serial_kernel = self.numel() < at::internal::GRAIN_SIZE || at::get_num_threads() == 1;
   if (use_serial_kernel) {
     auto iter = TensorIterator();
-    iter.dont_compute_common_dtype();
+    iter.check_all_same_dtype(false);
     iter.dont_resize_outputs();
     iter.add_output(result_strided);
     iter.add_input(_self);

@@ -681,7 +685,7 @@ static Tensor & masked_select_out_impl_cpu(Tensor & result, const Tensor & self,
   std::partial_sum(mask_long_data, mask_long_data + mask_long.numel(), mask_prefix_sum_data);
 
   auto iter = TensorIterator();
-  iter.dont_compute_common_dtype();
+  iter.check_all_same_dtype(false);
   iter.dont_resize_outputs();
   iter.add_output(result_strided);
   iter.add_input(_self);
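The new TORCH_CHECK in make_index_put_iterator shows up in Python as an explicit dtype-mismatch error for index_put-style assignment. A small sketch (hypothetical tensors):

```python
import torch

dst = torch.zeros(5, dtype=torch.float32)
idx = torch.tensor([0, 2])

# Matching dtypes are fine.
dst[idx] = torch.tensor([1.0, 2.0])

# A mismatched source dtype now fails the check added above, e.g.
# "Index put requires the source and destination dtypes match, ..."
try:
    dst[idx] = torch.tensor([1, 2], dtype=torch.int64)
except RuntimeError as err:
    print(err)
```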

‎aten/src/ATen/native/TensorCompare.cpp

+1 −1

@@ -151,12 +151,12 @@ Tensor _s_where(const Tensor& condition, const Tensor& self, const Tensor& other
   TORCH_CHECK(self.dtype() == other.dtype(), "expected scalar type ", self.dtype(), " but found ", other.dtype());
   Tensor ret = at::empty(self.sizes(), self.options());
   auto iter = at::TensorIterator();
+  iter.check_all_same_dtype(false);
   iter.set_check_mem_overlap(true);
   iter.add_output(ret);
   iter.add_input(condition);
   iter.add_input(self);
   iter.add_input(other);
-  iter.dont_compute_common_dtype();
   iter.build();
   where_kernel(iter.device_type(), iter, condition.scalar_type());
   return ret;
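As the unchanged TORCH_CHECK above shows, _s_where still requires self and other to share a dtype as of this commit, even though its iterator no longer computes a common dtype. Roughly, in Python (hypothetical tensors):

```python
import torch

cond = torch.tensor([True, False, True])
x = torch.tensor([1.0, 2.0, 3.0])
y = torch.tensor([10.0, 20.0, 30.0])

print(torch.where(cond, x, y))  # tensor([ 1., 20.,  3.])

# Mismatched dtypes for the two value tensors are rejected by the check
# above ("expected scalar type Float but found Double").
try:
    torch.where(cond, x, y.double())
except RuntimeError as err:
    print(err)
```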
