
Commit c9a8413

madsbk authored and facebook-github-bot committed on Jul 2, 2019
Numerical stability of embedding kernels (pytorch#22401)
Summary:
Address the issue raised in pytorch#22377. The PR pytorch#22016 introduces a temporary tensor of weights `grad_weight_per_segment` with the same dtype as the end result, which can be a problem when using `float16`. With this PR, a `float32` temporary tensor is used when the input is `float16`.

ngimel, can I get you to review? I think I have fixed the issues you pointed out.

Pull Request resolved: pytorch#22401

Differential Revision: D16077319

Pulled By: mrshenli

fbshipit-source-id: 7cfad7f40b4d41a244052baa2982ab51bbbd7309
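In other words, the per-segment partial sums are now kept in the accumulation type rather than in the gradient's own dtype. The snippet below is a minimal, standalone sketch of that pattern, not the PyTorch kernel itself; the kernel name `partial_sum_per_segment` and its signature are made up for illustration. It accumulates `__half` gradients into a `float` buffer, which is what `acc_type<scalar_t, true>` resolves to for half precision on CUDA.

// Sketch only: accumulate low-precision gradients in float to avoid the
// precision loss of summing many __half values directly.
#include <cuda_fp16.h>
#include <cstdint>

template <typename scalar_t, typename acc_t>
__global__ void partial_sum_per_segment(
    const scalar_t* grad,            // per-row gradient values
    const int64_t* segment_offsets,  // start index of each segment
    int64_t num_of_segments,
    int64_t numel,
    acc_t* grad_per_segment) {       // partial sums kept in acc_t (float for __half)
  const int64_t id = blockIdx.x * blockDim.x + threadIdx.x;
  if (id >= num_of_segments) {
    return;
  }
  const int64_t begin = segment_offsets[id];
  const int64_t end = (id == num_of_segments - 1) ? numel : segment_offsets[id + 1];
  acc_t sum = 0;                     // accumulation happens in higher precision
  for (int64_t i = begin; i < end; ++i) {
    sum += static_cast<acc_t>(grad[i]);
  }
  grad_per_segment[id] = sum;        // casting back to the gradient dtype happens
                                     // later, when scattering into grad_weight
}

// Instantiation mirroring the commit's choice: __half input, float partial sums.
template __global__ void partial_sum_per_segment<__half, float>(
    const __half*, const int64_t*, int64_t, int64_t, float*);

The actual kernels in this diff additionally handle bags, per-sample weights, index counts, and padding_idx; the point of the sketch is only the accumulator dtype.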
1 parent b768777 commit c9a8413

File tree

1 file changed: +72 −70 lines changed

 

aten/src/ATen/native/cuda/EmbeddingBackwardKernel.cu

+72 −70
@@ -13,7 +13,7 @@
 
 #include <thrust/execution_policy.h>
 #include <thrust/unique.h>
-#include <thrust/device_vector.h>
+
 
 namespace at {
 namespace native {
@@ -82,7 +82,8 @@ __global__ void compute_grad_weight_bags(
     int64_t *offset2bag, int64_t *count, ptrdiff_t numel,
     int64_t stride, int mode_mean, const int64_t *bag_size,
     scalar_t* per_sample_weights, int64_t per_sample_weights_stride,
-    int64_t* segment_offsets, int64_t num_of_segments, scalar_t *grad_weight_per_segment,
+    int64_t* segment_offsets, int64_t num_of_segments,
+    acc_type<scalar_t, true> *grad_weight_per_segment,
     const int64_t stride_warped) {
 
   const int gid = blockIdx.x * blockDim.x + threadIdx.x;
@@ -126,7 +127,7 @@ __global__ void compute_grad_weight(
     int64_t stride,
     int64_t* segment_offsets,
     int64_t num_of_segments,
-    scalar_t *grad_weight_per_segment,
+    acc_type<scalar_t, true> *grad_weight_per_segment,
     int padding_idx,
     const int64_t stride_warped) {
 
@@ -142,9 +143,6 @@ __global__ void compute_grad_weight(
   }
   const int idx_begin = segment_offsets[id];
   const int idx_end = (id == num_of_segments-1)?numel:segment_offsets[id+1];
-  if (idx_begin == padding_idx) {
-    return;
-  }
 
   accscalar_t weight = 0;
   for (int idx=idx_begin; idx < idx_end; ++idx) {
@@ -161,7 +159,8 @@ __global__ void compute_grad_weight(
 template <typename scalar_t>
 __global__ void sum_and_scatter(
     int64_t *input, scalar_t *gradWeight, int64_t stride,
-    int64_t* segment_offsets, int64_t num_of_segments, const scalar_t *grad_weight_per_segment,
+    int64_t* segment_offsets, int64_t num_of_segments,
+    const acc_type<scalar_t, true> *grad_weight_per_segment,
     const int64_t *segment_sizes_offsets, int64_t num_of_partial_segments,
     const int64_t stride_warped) {
 
@@ -212,7 +211,7 @@ Tensor embedding_backward_cuda_kernel(
   // spawn a warp per index. In this context, a segment is a number of rows that should
   // be summarized.
   // Unit: index in `sorted_indices` and `orig_indices`
-  thrust::device_vector<int64_t> segment_offsets(numel);
+  auto segment_offsets = at::empty({numel}, orig_indices.options());
   int64_t num_of_segments;
   {
     auto sorted_indices_dev = thrust::device_ptr<int64_t>(sorted_indices.data<int64_t>());
@@ -224,18 +223,18 @@ Tensor embedding_backward_cuda_kernel(
         sorted_indices_dev + numel,
         thrust::make_counting_iterator(0),
         dummy_dev,
-        thrust::raw_pointer_cast(segment_offsets.data()));
+        thrust::device_ptr<int64_t>(segment_offsets.data<int64_t>()));
     num_of_segments = thrust::get<0>(ends) - dummy_dev;
   }
 
   // We split the segments up into sizes of `NROWS_PER_THREAD`
   // Compute the number partial-segments per segment (some partial-segments
   // may not be the full `NROWS_PER_THREAD` number of rows)
-  thrust::device_vector<int64_t> partials_per_segment(num_of_segments);
+  auto partials_per_segment = at::empty({num_of_segments}, orig_indices.options());
   {
     krn_partials_per_segment<<<ceil_div(num_of_segments, 32), 32, 0, stream>>> (
-        thrust::raw_pointer_cast(partials_per_segment.data()),
-        thrust::raw_pointer_cast(segment_offsets.data()),
+        partials_per_segment.data<int64_t>(),
+        segment_offsets.data<int64_t>(),
         num_of_segments,
         numel);
   }
@@ -244,82 +243,85 @@
   // of each partial-segment in `sorted_indices`, we need to compute the
   // start position of each _segment_ in `partial_segment_offset`.
   // Unit: index in `partial_segment_offset`
-  thrust::device_vector<int64_t> partials_per_segment_offset(num_of_segments);
+  auto partials_per_segment_offset = at::empty({num_of_segments}, orig_indices.options());
   thrust::exclusive_scan(
       policy,
-      partials_per_segment.begin(),
-      partials_per_segment.end(),
-      partials_per_segment_offset.begin());
+      thrust::device_ptr<int64_t>(partials_per_segment.data<int64_t>()),
+      thrust::device_ptr<int64_t>(partials_per_segment.data<int64_t>()+num_of_segments),
+      thrust::device_ptr<int64_t>(partials_per_segment_offset.data<int64_t>()));
 
   // The total number of partial-segments is the sum of `partials_per_segment_offset`
-  const int num_of_partial_segments = partials_per_segment[num_of_segments-1] +
-      partials_per_segment_offset[num_of_segments-1];
+  const int num_of_partial_segments = partials_per_segment[num_of_segments-1].item<int64_t>() +
+      partials_per_segment_offset[num_of_segments-1].item<int64_t>();
 
   // Now we can compute the start position of each partial-segment
   // Unit: index in `sorted_indices` and `orig_indices`
-  thrust::device_vector<int64_t> partial_segment_offset(num_of_partial_segments);
+  auto partial_segment_offset = at::empty({num_of_partial_segments}, orig_indices.options());
   {
     krn_partial_segment_offset<<<ceil_div(num_of_segments, 32), 32, 0, stream>>> (
-        thrust::raw_pointer_cast(partial_segment_offset.data()),
-        thrust::raw_pointer_cast(partials_per_segment.data()),
-        thrust::raw_pointer_cast(partials_per_segment_offset.data()),
-        thrust::raw_pointer_cast(segment_offsets.data()),
+        partial_segment_offset.data<int64_t>(),
+        partials_per_segment.data<int64_t>(),
+        partials_per_segment_offset.data<int64_t>(),
+        segment_offsets.data<int64_t>(),
         num_of_segments);
   }
 
-  auto grad_weight_per_segment = at::empty({num_of_partial_segments, stride}, grad.options());
   const int stride_warped = ceil_div(stride, WARP_SIZE)*WARP_SIZE;
   const int block = std::min(stride_warped, MAX_BLOCK_SIZE);
   const int grid = ceil_div(num_of_partial_segments*stride_warped, block);
 
-  // Compute the sum of each partial-segment and handle bags
-  if (offset2bag.defined()) {
-    AT_DISPATCH_FLOATING_TYPES_AND_HALF(
-      grad.scalar_type(), "embedding_bag_backward_cuda_compute_grad_weight", [&] {
-        compute_grad_weight_bags<scalar_t><<<grid, block, 0, stream>>>(
-          orig_indices.data<int64_t>(),
-          grad.data<scalar_t>(),
-          offset2bag.data<int64_t>(),
-          count.defined() ? count.data<int64_t>() : nullptr, numel, stride,
-          mode_mean, bag_size.data<int64_t>(),
-          per_sample_weights.defined() ? per_sample_weights.data<scalar_t>() : NULL,
-          per_sample_weights.defined() ? per_sample_weights.stride(0) : 0,
-          thrust::raw_pointer_cast(partial_segment_offset.data()),
-          num_of_partial_segments, grad_weight_per_segment.data<scalar_t>(),
-          stride_warped);
-      });
-  } else {
-    AT_DISPATCH_FLOATING_TYPES_AND_HALF(
-      grad.scalar_type(), "embedding_bag_backward_cuda_compute_grad_weight", [&] {
-        compute_grad_weight<scalar_t><<<grid, block, 0, stream>>>(
-          orig_indices.data<int64_t>(),
-          grad.data<scalar_t>(),
-          count.defined() ? count.data<int64_t>() : nullptr,
-          numel, stride,
-          thrust::raw_pointer_cast(partial_segment_offset.data()),
-          num_of_partial_segments,
-          grad_weight_per_segment.data<scalar_t>(),
-          padding_idx,
-          stride_warped);
-      });
-  }
-  THCudaCheck(cudaGetLastError());
-
-  // Finally, we sum all the partial-sums and scatter them
-  // into `grad_weight`.
-  const int grid2 = ceil_div(num_of_segments*stride_warped, block);
   AT_DISPATCH_FLOATING_TYPES_AND_HALF(
-    grad.scalar_type(), "embedding_bag_backward_cuda_sum_and_scatter", [&] {
-      sum_and_scatter<scalar_t><<<grid2, block, 0, stream>>>(
-        sorted_indices.data<int64_t>(),
-        grad_weight.data<scalar_t>(),
-        stride,
-        thrust::raw_pointer_cast(segment_offsets.data()),
-        num_of_segments, grad_weight_per_segment.data<scalar_t>(),
-        thrust::raw_pointer_cast(partials_per_segment_offset.data()),
-        num_of_partial_segments, stride_warped);
+    grad.scalar_type(), "embedding_bag_backward_cuda_compute_grad_weight", [&] {
+      // For numerical stability, the dtype of `grad_weight_per_segment`
+      // should match `acc_type`
+      using partial_weight_t = acc_type<scalar_t, true>;
+      TensorOptions op;
+      if(grad.dtype() == at::kHalf) {
+        op = grad.options().dtype(at::kFloat);
+      } else {
+        op = grad.options();
+      }
+      auto grad_weight_per_segment = at::empty({num_of_partial_segments, stride}, op);
+      // Compute the sum of each partial-segment and handle bags
+      if (offset2bag.defined()) {
+        compute_grad_weight_bags<scalar_t><<<grid, block, 0, stream>>>(
+            orig_indices.data<int64_t>(),
+            grad.data<scalar_t>(),
+            offset2bag.data<int64_t>(),
+            count.defined() ? count.data<int64_t>() : nullptr, numel, stride,
+            mode_mean, bag_size.data<int64_t>(),
+            per_sample_weights.defined() ? per_sample_weights.data<scalar_t>() : NULL,
+            per_sample_weights.defined() ? per_sample_weights.stride(0) : 0,
+            partial_segment_offset.data<int64_t>(),
+            num_of_partial_segments, grad_weight_per_segment.data<partial_weight_t>(),
+            stride_warped);
+      } else {
+        compute_grad_weight<scalar_t><<<grid, block, 0, stream>>>(
+            orig_indices.data<int64_t>(),
+            grad.data<scalar_t>(),
+            count.defined() ? count.data<int64_t>() : nullptr,
+            numel, stride,
+            partial_segment_offset.data<int64_t>(),
+            num_of_partial_segments,
+            grad_weight_per_segment.data<partial_weight_t>(),
+            padding_idx,
+            stride_warped);
+      }
+      THCudaCheck(cudaGetLastError());
+
+      // Finally, we sum all the partial-sums and scatter them
+      // into `grad_weight`.
+      const int grid2 = ceil_div(num_of_segments*stride_warped, block);
+      sum_and_scatter<scalar_t><<<grid2, block, 0, stream>>>(
+          sorted_indices.data<int64_t>(),
+          grad_weight.data<scalar_t>(),
+          stride,
+          segment_offsets.data<int64_t>(),
+          num_of_segments, grad_weight_per_segment.data<partial_weight_t>(),
+          partials_per_segment_offset.data<int64_t>(),
+          num_of_partial_segments, stride_warped);
+      THCudaCheck(cudaGetLastError());
     });
-  THCudaCheck(cudaGetLastError());
   return grad_weight;
 }
 
0 commit comments
