
Commit d5748d9

Iurii Zdebskyi authored and facebook-github-bot committed on Sep 25, 2020
Enable binary ops with Scalar Lists for foreach APIs (pytorch#45298)
Summary: Pull Request resolved: pytorch#45298
Test Plan: Imported from OSS
Reviewed By: ngimel
Differential Revision: D23931986
Pulled By: izdeby
fbshipit-source-id: 281267cd6f90d57a169af89f9f10b0f4fcab47e3
1 parent f07ac6a commit d5748d9

14 files changed: +837 −115 lines
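The commit adds scalar-list variants of the foreach binary ops: instead of applying one scalar to every tensor in the list, each tensor is paired with its own scalar. A minimal sketch of the semantics, driving the CPU reference kernel generated by the macro below — the direct call and the NativeFunctions.h include are illustrative assumptions; real callers go through the dispatcher bindings registered elsewhere in this commit and not shown in this excerpt:

    #include <ATen/ATen.h>
    #include <ATen/NativeFunctions.h>  // assumed to carry the generated at::native declarations
    #include <vector>

    int main() {
      std::vector<at::Tensor> tensors = {at::ones({2, 2}), at::ones({3})};
      std::vector<double> scalars = {2.0, 3.0};  // one scalar per tensor, same list length

      // Generated by FOREACH_BINARY_OP_SCALARLIST(add) in ForeachOpsKernels.cpp:
      // result[i] = tensors[i] + scalars[i]
      auto result = at::native::foreach_tensor_add_scalarlist_kernel_slow(tensors, scalars);

      // In-place variant: tensors[i] += scalars[i]
      at::native::foreach_tensor_add_scalarlist_kernel_slow_(tensors, scalars);
      return 0;
    }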
 

‎aten/src/ATen/native/ForeachOpsKernels.cpp

+24 lines

@@ -24,6 +24,26 @@ std::vector<Tensor> foreach_tensor_##NAME##_scalar_kernel_slow(TensorList tensor
   return result; \
 }
 
+#define FOREACH_BINARY_OP_SCALARLIST(NAME) \
+void foreach_tensor_##NAME##_scalarlist_kernel_slow_(TensorList tensors, at::ArrayRef<double> scalars) { \
+  check_foreach_api_restrictions(tensors, scalars); \
+ \
+  for (int i = 0; i < tensors.size(); i++) { \
+    tensors[i].NAME##_(scalars[i]); \
+  } \
+} \
+ \
+std::vector<Tensor> foreach_tensor_##NAME##_scalarlist_kernel_slow(TensorList tensors, at::ArrayRef<double> scalars) { \
+  check_foreach_api_restrictions(tensors, scalars); \
+  std::vector<Tensor> result; \
+  result.reserve(tensors.size()); \
+  for (int i = 0; i < tensors.size(); i++) { \
+    result.emplace_back(tensors[i].NAME(scalars[i])); \
+  } \
+ \
+  return result; \
+}
+
 #define FOREACH_BINARY_OP_LIST(NAME) \
 std::vector<Tensor> foreach_tensor_##NAME##_list_kernel_slow(TensorList tensors1, TensorList tensors2) { \
   check_foreach_api_restrictions(tensors1, tensors2); \

@@ -117,6 +137,10 @@ FOREACH_BINARY_OP_SCALAR(add);
 FOREACH_BINARY_OP_SCALAR(sub);
 FOREACH_BINARY_OP_SCALAR(mul);
 FOREACH_BINARY_OP_SCALAR(div);
+FOREACH_BINARY_OP_SCALARLIST(add);
+FOREACH_BINARY_OP_SCALARLIST(sub);
+FOREACH_BINARY_OP_SCALARLIST(mul);
+FOREACH_BINARY_OP_SCALARLIST(div);
 FOREACH_BINARY_OP_LIST(mul);
 FOREACH_BINARY_OP_LIST(div);
 FOREACH_UNARY_OP(sqrt);
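For reference, expanding FOREACH_BINARY_OP_SCALARLIST(add) yields (up to whitespace) the following pair of reference kernels — a sketch of the preprocessor output, not code that appears in the diff:

    // In-place variant: mutate each tensor with its own scalar.
    void foreach_tensor_add_scalarlist_kernel_slow_(TensorList tensors, at::ArrayRef<double> scalars) {
      check_foreach_api_restrictions(tensors, scalars);

      for (int i = 0; i < tensors.size(); i++) {
        tensors[i].add_(scalars[i]);   // NAME##_ -> add_
      }
    }

    // Out-of-place variant: collect one result tensor per input.
    std::vector<Tensor> foreach_tensor_add_scalarlist_kernel_slow(TensorList tensors, at::ArrayRef<double> scalars) {
      check_foreach_api_restrictions(tensors, scalars);
      std::vector<Tensor> result;
      result.reserve(tensors.size());
      for (int i = 0; i < tensors.size(); i++) {
        result.emplace_back(tensors[i].add(scalars[i]));   // NAME -> add
      }

      return result;
    }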

‎aten/src/ATen/native/ForeachUtils.h

+14 lines

@@ -31,6 +31,12 @@ void check_foreach_api_restrictions(TensorList tensors1, TensorList tensors2) {
   }
 }
 
+void check_foreach_api_restrictions(TensorList tensors, ArrayRef<double> scalars) {
+  TORCH_CHECK(tensors.size() > 0, "Tensor list must have at least one tensor.");
+  TORCH_CHECK(scalars.size() > 0, "Scalars list must have at least one value.");
+  TORCH_CHECK(tensors.size() == scalars.size(), "Tensor list must have same number of elements as scalar list.");
+}
+
 // To go via 'fast' path, several conditions must be satisfied
 // - All tensors must be on the same device
 // - All tensors must have strided layout

@@ -132,5 +138,13 @@ bool can_use_fast_route(TensorList tensors) {
   return true;
 }
 
+bool can_use_fast_route(TensorList tensors, ArrayRef<double> scalars) {
+  TORCH_CHECK(tensors.size() > 0, "Tensor list must have at least one tensor.");
+  TORCH_CHECK(scalars.size() > 0, "Scalars list must have at least one value.");
+  TORCH_CHECK(tensors.size() == scalars.size(), "Tensor list must have same number of elements as scalar list.");
+
+  return can_use_fast_route(tensors);
+}
+
 }
 }} // at::native
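Both new helpers enforce the same contract — non-empty lists of equal length — and the scalar-list overload of can_use_fast_route then defers to the existing per-tensor conditions (same device, strided layout, and the rest listed above). A small illustration of the checks; the direct calls and the include path are for demonstration only, assuming an in-tree build:

    #include <ATen/ATen.h>
    #include <ATen/native/ForeachUtils.h>  // the header changed above
    #include <vector>

    int main() {
      std::vector<at::Tensor> tensors = {at::ones({2}), at::ones({2})};
      std::vector<double> scalars = {1.0};  // wrong length: 2 tensors vs. 1 scalar

      // Would throw: "Tensor list must have same number of elements as scalar list."
      // at::native::check_foreach_api_restrictions(tensors, scalars);

      scalars.push_back(2.0);  // lengths now match
      at::native::check_foreach_api_restrictions(tensors, scalars);  // passes

      // Re-validates the lengths, then falls through to can_use_fast_route(tensors).
      bool fast = at::native::can_use_fast_route(tensors, scalars);
      (void)fast;
      return 0;
    }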
aten/src/ATen/native/cuda/ — new CUDA kernel file (+60 lines; filename not captured in this excerpt)

@@ -0,0 +1,60 @@
+#include <ATen/Dispatch.h>
+#include <ATen/native/ForeachUtils.h>
+#include <ATen/native/cuda/ForeachFunctors.cuh>
+
+namespace at { namespace native {
+
+template<template<class> class Op>
+std::vector<Tensor> foreach_binary_op(TensorList tensors, at::ArrayRef<double> scalars) {
+  std::vector<std::vector<at::Tensor>> tensor_lists;
+  std::vector<at::Tensor> vec_res;
+  for (const auto& t: tensors) {
+    vec_res.emplace_back(at::native::empty_like(t));
+  }
+
+  tensor_lists.emplace_back(tensors.vec());
+  tensor_lists.emplace_back(vec_res);
+
+  AT_DISPATCH_ALL_TYPES_AND2(kBFloat16, kHalf, tensors[0].scalar_type(), "foreach_binary_op_scalarlist_cuda", [&]() {
+    multi_tensor_apply<2>(tensor_lists, scalars, BinaryOpScalarListFunctor<scalar_t, Op>());
+  });
+  return tensor_lists[1];
+}
+
+template<template<class> class Op>
+void foreach_binary_op_(TensorList tensors, at::ArrayRef<double> scalars) {
+  std::vector<std::vector<at::Tensor>> tensor_lists;
+  tensor_lists.emplace_back(tensors.vec());
+
+  AT_DISPATCH_ALL_TYPES_AND2(kBFloat16, kHalf, tensors[0].scalar_type(), "foreach_binary_op_scalarlist_cuda_", [&]() {
+    multi_tensor_apply<1>(tensor_lists, scalars, BinaryOpScalarListFunctor_<scalar_t, Op>());
+  });
+}
+
+#define FOREACH_BINARY_OP_SCALARLIST(NAME, OP) \
+void foreach_tensor_##NAME##_scalarlist_kernel_cuda_(TensorList tensors, at::ArrayRef<double> scalars) { \
+  check_foreach_api_restrictions(tensors); \
+ \
+  if (!can_use_fast_route(tensors, scalars)) { \
+    return at::native::foreach_tensor_##NAME##_scalarlist_kernel_slow_(tensors, scalars); \
+  } \
+ \
+  foreach_binary_op_<OP>(tensors, scalars); \
+} \
+ \
+std::vector<Tensor> foreach_tensor_##NAME##_scalarlist_kernel_cuda(TensorList tensors, at::ArrayRef<double> scalars) { \
+  check_foreach_api_restrictions(tensors); \
+ \
+  if (!can_use_fast_route(tensors, scalars)) { \
+    return at::native::foreach_tensor_##NAME##_scalarlist_kernel_slow(tensors, scalars); \
+  } \
+ \
+  return foreach_binary_op<OP>(tensors, scalars); \
+}
+
+FOREACH_BINARY_OP_SCALARLIST(add, std::plus);
+FOREACH_BINARY_OP_SCALARLIST(sub, std::minus);
+FOREACH_BINARY_OP_SCALARLIST(mul, std::multiplies);
+FOREACH_BINARY_OP_SCALARLIST(div, std::divides);
+
+}} // namespace at::native
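Each generated CUDA entry point validates the tensor list, then routes: if can_use_fast_route(tensors, scalars) fails — note that the scalar-list length check lives inside that call, since check_foreach_api_restrictions is invoked here with tensors only — it falls back to the slow reference kernel; otherwise it launches the multi_tensor_apply fast path, with std::plus/std::minus/std::multiplies/std::divides supplying the element-scalar operation. For example, FOREACH_BINARY_OP_SCALARLIST(add, std::plus) expands to roughly:

    void foreach_tensor_add_scalarlist_kernel_cuda_(TensorList tensors, at::ArrayRef<double> scalars) {
      check_foreach_api_restrictions(tensors);

      if (!can_use_fast_route(tensors, scalars)) {
        // Fall back to the reference kernel from ForeachOpsKernels.cpp.
        return at::native::foreach_tensor_add_scalarlist_kernel_slow_(tensors, scalars);
      }

      foreach_binary_op_<std::plus>(tensors, scalars);   // in-place fast path (depth 1)
    }

    std::vector<Tensor> foreach_tensor_add_scalarlist_kernel_cuda(TensorList tensors, at::ArrayRef<double> scalars) {
      check_foreach_api_restrictions(tensors);

      if (!can_use_fast_route(tensors, scalars)) {
        return at::native::foreach_tensor_add_scalarlist_kernel_slow(tensors, scalars);
      }

      return foreach_binary_op<std::plus>(tensors, scalars);   // out-of-place fast path (depth 2)
    }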

‎aten/src/ATen/native/cuda/ForeachFunctors.cuh

+115 lines

@@ -118,6 +118,121 @@ struct BinaryOpScalarFunctor {
   }
 };
 
+template<typename T, template<class> class Op>
+struct BinaryOpScalarListFunctor_ {
+  __device__ void operator() (
+      int chunk_size,
+      TensorListScalarListMetadata<1>& tl) {
+    int tensor_loc = tl.block_to_tensor[blockIdx.x];
+    int chunk_idx = tl.block_to_chunk[blockIdx.x];
+    int n = tl.sizes[tensor_loc];
+
+    T* x = (T*)tl.addresses[0][tensor_loc];
+    x += chunk_idx * chunk_size;
+
+    double y = tl.scalar_vals[tensor_loc];
+
+    n -= chunk_idx * chunk_size;
+
+    T r_x[kILP];
+
+    // to make things simple, we put aligned case in a different code path
+    if(n % kILP == 0 && chunk_size % kILP == 0 && is_aligned(x)) {
+      for(int i_start = threadIdx.x; i_start * kILP < n && i_start * kILP < chunk_size; i_start += blockDim.x) {
+        // load
+        load_store(r_x, x, 0 , i_start);
+#pragma unroll
+        for(int ii = 0; ii < kILP; ii++) {
+          r_x[ii] = Op<T>()(static_cast<T>(r_x[ii]), y);
+        }
+        // store
+        load_store(x, r_x, i_start, 0);
+      }
+    }
+    else {
+      for(int i_start = 0; i_start < n && i_start < chunk_size; i_start += blockDim.x * kILP) {
+#pragma unroll
+        for(int ii = 0; ii < kILP; ii++) {
+          r_x[ii] = 0;
+          int i = i_start + threadIdx.x + ii * blockDim.x;
+          if(i < n && i < chunk_size) {
+            r_x[ii] = x[i];
+          }
+        }
+#pragma unroll
+        for(int ii = 0; ii < kILP; ii++) {
+          r_x[ii] = Op<T>()(static_cast<T>(r_x[ii]), y);
+        }
+#pragma unroll
+        for(int ii = 0; ii < kILP; ii++) {
+          int i = i_start + threadIdx.x + ii * blockDim.x;
+          if(i < n && i < chunk_size)
+            x[i] = r_x[ii];
+        }
+      }
+    }
+  }
+};
+
+template<typename T, template<class> class Op>
+struct BinaryOpScalarListFunctor {
+  __device__ void operator() (
+      int chunk_size,
+      TensorListScalarListMetadata<2>& tl) {
+    int tensor_loc = tl.block_to_tensor[blockIdx.x];
+    int chunk_idx = tl.block_to_chunk[blockIdx.x];
+    int n = tl.sizes[tensor_loc];
+
+    T* x = (T*)tl.addresses[0][tensor_loc];
+    x += chunk_idx * chunk_size;
+
+    T* out = (T*)tl.addresses[1][tensor_loc];
+    out += chunk_idx * chunk_size;
+
+    double y = tl.scalar_vals[tensor_loc];
+
+    n -= chunk_idx * chunk_size;
+
+    T r_x[kILP];
+
+    // to make things simple, we put aligned case in a different code path
+    if(n % kILP == 0 && chunk_size % kILP == 0 && is_aligned(x) && is_aligned(out)) {
+      for(int i_start = threadIdx.x; i_start * kILP < n && i_start * kILP < chunk_size; i_start += blockDim.x) {
+        // load
+        load_store(r_x, x, 0 , i_start);
+#pragma unroll
+        for(int ii = 0; ii < kILP; ii++) {
+          r_x[ii] = Op<T>()(static_cast<T>(r_x[ii]), y);
+        }
+        // store
+        load_store(out, r_x, i_start, 0);
+      }
+    }
+    else {
+      for(int i_start = 0; i_start < n && i_start < chunk_size; i_start += blockDim.x * kILP) {
+#pragma unroll
+        for(int ii = 0; ii < kILP; ii++) {
+          r_x[ii] = 0;
+          int i = i_start + threadIdx.x + ii * blockDim.x;
+          if(i < n && i < chunk_size) {
+            r_x[ii] = x[i];
+          }
+        }
+#pragma unroll
+        for(int ii = 0; ii < kILP; ii++) {
+          r_x[ii] = Op<T>()(static_cast<T>(r_x[ii]), y);
+        }
+#pragma unroll
+        for(int ii = 0; ii < kILP; ii++) {
+          int i = i_start + threadIdx.x + ii * blockDim.x;
+          if(i < n && i < chunk_size)
+            out[i] = r_x[ii];
+        }
+      }
+    }
+  }
+};
+
 template<typename T, template<class> class Op>
 struct BinaryOpListAlphaFunctor_ {
   __device__ void operator() (
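Both functors follow the same per-chunk pattern as the existing BinaryOpScalarFunctor pair: each thread handles kILP elements per iteration, with a vectorized load_store fast path when the chunk is kILP-aligned and a bounds-checked element-wise path otherwise; the only new ingredient is that the scalar operand comes from tl.scalar_vals[tensor_loc] rather than a single kernel argument. A stripped-down restatement of the bounds-checked path as a hypothetical free device function (not part of the diff), shown for orientation:

    // One chunk of up to chunk_size elements at x, of which n are valid; every
    // thread of the block processes kILP strided elements per outer iteration,
    // applying Op against the per-tensor scalar y and writing back in place.
    template <typename T, template <class> class Op, int kILP>
    __device__ void apply_scalar_to_chunk(T* x, double y, int n, int chunk_size) {
      T r_x[kILP];
      for (int i_start = 0; i_start < n && i_start < chunk_size; i_start += blockDim.x * kILP) {
    #pragma unroll
        for (int ii = 0; ii < kILP; ii++) {
          int i = i_start + threadIdx.x + ii * blockDim.x;
          r_x[ii] = (i < n && i < chunk_size) ? x[i] : T(0);  // guarded load
        }
    #pragma unroll
        for (int ii = 0; ii < kILP; ii++) {
          r_x[ii] = Op<T>()(static_cast<T>(r_x[ii]), y);      // element (op) scalar
        }
    #pragma unroll
        for (int ii = 0; ii < kILP; ii++) {
          int i = i_start + threadIdx.x + ii * blockDim.x;
          if (i < n && i < chunk_size)
            x[i] = r_x[ii];                                   // guarded store
        }
      }
    }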

‎aten/src/ATen/native/cuda/MultiTensorApply.cuh

+70 lines

@@ -26,6 +26,7 @@ __device__ __forceinline__ void load_store(T* dst, T* src, int dst_offset, int s
 // TensorListMetadata has to be < 4KB - the limit for kernel launch argument
 static constexpr int depth_to_max_tensors[5] = {110, 64, 48, 36, 30};
 static constexpr int depth_to_max_blocks[5] = {320, 320, 320, 320, 320};
+static constexpr int depth_to_max_tensors_scalarlist[5] = {96, 64, 48, 36, 30};
 
 template<int n> struct TensorListMetadata
 {

@@ -35,6 +36,15 @@ template<int n> struct TensorListMetadata
   int block_to_chunk[depth_to_max_blocks[n-1]];
 };
 
+template<int n> struct TensorListScalarListMetadata
+{
+  void* addresses[n][depth_to_max_tensors_scalarlist[n-1]];
+  int sizes[depth_to_max_tensors_scalarlist[n-1]];
+  double scalar_vals[depth_to_max_tensors_scalarlist[n-1]];
+  unsigned char block_to_tensor[depth_to_max_blocks[n-1]];
+  int block_to_chunk[depth_to_max_blocks[n-1]];
+};
+
 template<typename T, typename U, typename... ArgTypes>
 C10_LAUNCH_BOUNDS_1(kBlockSize)
 __global__ void

@@ -49,11 +59,71 @@ multi_tensor_apply_kernel(
 template<int depth, typename T, typename... ArgTypes>
 void multi_tensor_apply(
   std::vector<std::vector<at::Tensor>>& tensor_lists,
+  at::ArrayRef<double> scalars,
   T callable,
   ArgTypes... args) {
   TORCH_CHECK(tensor_lists.size() == depth, "Number of tensor lists has to match the depth.");
   const cuda::OptionalCUDAGuard device_guard(device_of(tensor_lists[0][0]));
+  size_t n_tensors = tensor_lists[0].size();
+  TensorListScalarListMetadata<depth> tensorListMeta;
+
+  int loc_block_info = 0;
+  int loc_tensor_info = 0;
+  for(size_t t = 0; t < n_tensors; t++) {
+
+    tensorListMeta.scalar_vals[loc_tensor_info] = scalars[t];
+
+    tensorListMeta.sizes[loc_tensor_info] = tensor_lists[0][t].numel();
+    for (int d = 0; d < depth; d++) {
+      tensorListMeta.addresses[d][loc_tensor_info] = tensor_lists[d][t].data_ptr();
+    }
+    loc_tensor_info++;
+
+    int chunks = (tensor_lists[0][t].numel() + kChunkSize - 1)/kChunkSize;
+    for (int chunk = 0; chunk < chunks; chunk++) {
+      tensorListMeta.block_to_tensor[loc_block_info] = loc_tensor_info - 1;
+      tensorListMeta.block_to_chunk[loc_block_info] = chunk;
+      loc_block_info++;
+
+      bool tensors_full = (loc_tensor_info == depth_to_max_tensors_scalarlist[depth-1] &&
+                           chunk == chunks - 1);
+      bool blocks_full = (loc_block_info == depth_to_max_blocks[depth-1]);
+      bool last_chunk = (t == n_tensors - 1 && chunk == chunks - 1);
+
+      if (tensors_full || blocks_full || last_chunk) {
+        multi_tensor_apply_kernel<<<loc_block_info, kBlockSize, 0, at::cuda::getCurrentCUDAStream()>>>(
+          tensorListMeta,
+          callable,
+          args...);
+
+        AT_CUDA_CHECK(cudaGetLastError());
+
+        // Reset.
+        loc_block_info = 0;
+        if(chunk == chunks - 1) {
+          loc_tensor_info = 0;
+        }
+        else {
+          tensorListMeta.sizes[0] = tensorListMeta.sizes[loc_tensor_info-1];
+          tensorListMeta.scalar_vals[0] = tensorListMeta.scalar_vals[loc_tensor_info-1];
+          for(int d = 0; d < depth; d++) {
+            tensorListMeta.addresses[d][0] = tensorListMeta.addresses[d][loc_tensor_info-1];
+          }
+          loc_tensor_info = 1;
+        }
+      }
+    }
+  }
+}
+
 
+template<int depth, typename T, typename... ArgTypes>
+void multi_tensor_apply(
+  std::vector<std::vector<at::Tensor>>& tensor_lists,
+  T callable,
+  ArgTypes... args) {
+  TORCH_CHECK(tensor_lists.size() == depth, "Number of tensor lists has to match the depth.");
+  const cuda::OptionalCUDAGuard device_guard(device_of(tensor_lists[0][0]));
   size_t n_tensors = tensor_lists[0].size();
   TensorListMetadata<depth> tensorListMeta;
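TensorListScalarListMetadata carries one double per tensor on top of the pointer and size tables, which is why depth_to_max_tensors_scalarlist caps a depth-1 launch at 96 tensors instead of 110: the struct is passed by value as a kernel argument and has to stay under the roughly 4KB launch-parameter limit noted in the comment above, alongside chunk_size, the callable, and any extra arguments. A rough sizing sketch (8-byte pointers, ignoring alignment padding and those extra arguments), plus a hypothetical compile-time guard that is not part of the commit:

    // TensorListMetadata<1>:            110*8 (addresses) + 110*4 (sizes)
    //                                   + 320*1 + 320*4 (block maps)        ~= 2920 bytes
    // TensorListScalarListMetadata<1>:   96*8 + 96*4 + 96*8 (scalar_vals)
    //                                   + 320*1 + 320*4                     ~= 3520 bytes
    static_assert(sizeof(TensorListScalarListMetadata<1>) <= 4096,
                  "metadata must fit in the kernel launch argument budget");

When a flush happens mid-tensor (blocks_full before the tensor's last chunk), the loop copies the current tensor's sizes, scalar_vals, and addresses entries into slot 0 and restarts with loc_tensor_info = 1, so a large tensor can span several kernel launches without re-reading its metadata from the host lists.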
