#include <thrust/execution_policy.h>
#include <thrust/unique.h>
-#include <thrust/device_vector.h>
+

namespace at {
namespace native {
@@ -82,7 +82,8 @@ __global__ void compute_grad_weight_bags(
    int64_t *offset2bag, int64_t *count, ptrdiff_t numel,
    int64_t stride, int mode_mean, const int64_t *bag_size,
    scalar_t *per_sample_weights, int64_t per_sample_weights_stride,
-    int64_t *segment_offsets, int64_t num_of_segments, scalar_t *grad_weight_per_segment,
+    int64_t *segment_offsets, int64_t num_of_segments,
+    acc_type<scalar_t, true> *grad_weight_per_segment,
    const int64_t stride_warped) {

  const int gid = blockIdx.x * blockDim.x + threadIdx.x;
@@ -126,7 +127,7 @@ __global__ void compute_grad_weight(
    int64_t stride,
    int64_t *segment_offsets,
    int64_t num_of_segments,
-    scalar_t *grad_weight_per_segment,
+    acc_type<scalar_t, true> *grad_weight_per_segment,
    int padding_idx,
    const int64_t stride_warped) {

@@ -142,9 +143,6 @@ __global__ void compute_grad_weight(
  }
  const int idx_begin = segment_offsets[id];
  const int idx_end = (id == num_of_segments-1)?numel:segment_offsets[id+1];
-  if (idx_begin == padding_idx) {
-    return;
-  }

  accscalar_t weight = 0;
  for (int idx=idx_begin; idx < idx_end; ++idx) {
@@ -161,7 +159,8 @@ __global__ void compute_grad_weight(
template <typename scalar_t>
__global__ void sum_and_scatter(
    int64_t *input, scalar_t *gradWeight, int64_t stride,
-    int64_t *segment_offsets, int64_t num_of_segments, const scalar_t *grad_weight_per_segment,
+    int64_t *segment_offsets, int64_t num_of_segments,
+    const acc_type<scalar_t, true> *grad_weight_per_segment,
    const int64_t *segment_sizes_offsets, int64_t num_of_partial_segments,
    const int64_t stride_warped) {

@@ -212,7 +211,7 @@ Tensor embedding_backward_cuda_kernel(
  // spawn a warp per index. In this context, a segment is a number of rows that should
  // be summarized.
  // Unit: index in `sorted_indices` and `orig_indices`
-  thrust::device_vector<int64_t> segment_offsets(numel);
+  auto segment_offsets = at::empty({numel}, orig_indices.options());
  int64_t num_of_segments;
  {
    auto sorted_indices_dev = thrust::device_ptr<int64_t>(sorted_indices.data<int64_t>());
@@ -224,18 +223,18 @@ Tensor embedding_backward_cuda_kernel(
        sorted_indices_dev + numel,
        thrust::make_counting_iterator(0),
        dummy_dev,
-        thrust::raw_pointer_cast(segment_offsets.data()));
+        thrust::device_ptr<int64_t>(segment_offsets.data<int64_t>()));
    num_of_segments = thrust::get<0>(ends) - dummy_dev;
  }

  // We split the segments up into sizes of `NROWS_PER_THREAD`
  // Compute the number partial-segments per segment (some partial-segments
  // may not be the full `NROWS_PER_THREAD` number of rows)
-  thrust::device_vector<int64_t> partials_per_segment(num_of_segments);
+  auto partials_per_segment = at::empty({num_of_segments}, orig_indices.options());
  {
    krn_partials_per_segment<<<ceil_div(num_of_segments, 32), 32, 0, stream>>>(
-        thrust::raw_pointer_cast(partials_per_segment.data()),
-        thrust::raw_pointer_cast(segment_offsets.data()),
+        partials_per_segment.data<int64_t>(),
+        segment_offsets.data<int64_t>(),
        num_of_segments,
        numel);
  }
@@ -244,82 +243,85 @@ Tensor embedding_backward_cuda_kernel(
  // of each partial-segment in `sorted_indices`, we need to compute the
  // start position of each _segment_ in `partial_segment_offset`.
  // Unit: index in `partial_segment_offset`
-  thrust::device_vector<int64_t> partials_per_segment_offset(num_of_segments);
+  auto partials_per_segment_offset = at::empty({num_of_segments}, orig_indices.options());
  thrust::exclusive_scan(
      policy,
-      partials_per_segment.begin(),
-      partials_per_segment.end(),
-      partials_per_segment_offset.begin());
+      thrust::device_ptr<int64_t>(partials_per_segment.data<int64_t>()),
+      thrust::device_ptr<int64_t>(partials_per_segment.data<int64_t>()+num_of_segments),
+      thrust::device_ptr<int64_t>(partials_per_segment_offset.data<int64_t>()));

  // The total number of partial-segments is the sum of `partials_per_segment_offset`
-  const int num_of_partial_segments = partials_per_segment[num_of_segments-1] +
-      partials_per_segment_offset[num_of_segments-1];
+  const int num_of_partial_segments = partials_per_segment[num_of_segments-1].item<int64_t>() +
+      partials_per_segment_offset[num_of_segments-1].item<int64_t>();

  // Now we can compute the start position of each partial-segment
  // Unit: index in `sorted_indices` and `orig_indices`
-  thrust::device_vector<int64_t> partial_segment_offset(num_of_partial_segments);
+  auto partial_segment_offset = at::empty({num_of_partial_segments}, orig_indices.options());
  {
    krn_partial_segment_offset<<<ceil_div(num_of_segments, 32), 32, 0, stream>>>(
-        thrust::raw_pointer_cast(partial_segment_offset.data()),
-        thrust::raw_pointer_cast(partials_per_segment.data()),
-        thrust::raw_pointer_cast(partials_per_segment_offset.data()),
-        thrust::raw_pointer_cast(segment_offsets.data()),
+        partial_segment_offset.data<int64_t>(),
+        partials_per_segment.data<int64_t>(),
+        partials_per_segment_offset.data<int64_t>(),
+        segment_offsets.data<int64_t>(),
        num_of_segments);
  }

-  auto grad_weight_per_segment = at::empty({num_of_partial_segments, stride}, grad.options());
  const int stride_warped = ceil_div(stride, WARP_SIZE)*WARP_SIZE;
  const int block = std::min(stride_warped, MAX_BLOCK_SIZE);
  const int grid = ceil_div(num_of_partial_segments*stride_warped, block);

-  // Compute the sum of each partial-segment and handle bags
-  if (offset2bag.defined()) {
-    AT_DISPATCH_FLOATING_TYPES_AND_HALF(
-      grad.scalar_type(), "embedding_bag_backward_cuda_compute_grad_weight", [&] {
-        compute_grad_weight_bags<scalar_t><<<grid, block, 0, stream>>>(
-          orig_indices.data<int64_t>(),
-          grad.data<scalar_t>(),
-          offset2bag.data<int64_t>(),
-          count.defined() ? count.data<int64_t>() : nullptr, numel, stride,
-          mode_mean, bag_size.data<int64_t>(),
-          per_sample_weights.defined() ? per_sample_weights.data<scalar_t>() : NULL,
-          per_sample_weights.defined() ? per_sample_weights.stride(0) : 0,
-          thrust::raw_pointer_cast(partial_segment_offset.data()),
-          num_of_partial_segments, grad_weight_per_segment.data<scalar_t>(),
-          stride_warped);
-      });
-  } else {
-    AT_DISPATCH_FLOATING_TYPES_AND_HALF(
-      grad.scalar_type(), "embedding_bag_backward_cuda_compute_grad_weight", [&] {
-        compute_grad_weight<scalar_t><<<grid, block, 0, stream>>>(
-          orig_indices.data<int64_t>(),
-          grad.data<scalar_t>(),
-          count.defined() ? count.data<int64_t>() : nullptr,
-          numel, stride,
-          thrust::raw_pointer_cast(partial_segment_offset.data()),
-          num_of_partial_segments,
-          grad_weight_per_segment.data<scalar_t>(),
-          padding_idx,
-          stride_warped);
-      });
-  }
-  THCudaCheck(cudaGetLastError());
-
-  // Finally, we sum all the partial-sums and scatter them
-  // into `grad_weight`.
-  const int grid2 = ceil_div(num_of_segments*stride_warped, block);
  AT_DISPATCH_FLOATING_TYPES_AND_HALF(
-    grad.scalar_type(), "embedding_bag_backward_cuda_sum_and_scatter", [&] {
-      sum_and_scatter<scalar_t><<<grid2, block, 0, stream>>>(
-        sorted_indices.data<int64_t>(),
-        grad_weight.data<scalar_t>(),
-        stride,
-        thrust::raw_pointer_cast(segment_offsets.data()),
-        num_of_segments, grad_weight_per_segment.data<scalar_t>(),
-        thrust::raw_pointer_cast(partials_per_segment_offset.data()),
-        num_of_partial_segments, stride_warped);
+    grad.scalar_type(), "embedding_bag_backward_cuda_compute_grad_weight", [&] {
+      // For numerical stability, the dtype of `grad_weight_per_segment`
+      // should match `acc_type`
+      using partial_weight_t = acc_type<scalar_t, true>;
+      TensorOptions op;
+      if (grad.dtype() == at::kHalf) {
+        op = grad.options().dtype(at::kFloat);
+      } else {
+        op = grad.options();
+      }
+      auto grad_weight_per_segment = at::empty({num_of_partial_segments, stride}, op);
+      // Compute the sum of each partial-segment and handle bags
+      if (offset2bag.defined()) {
+        compute_grad_weight_bags<scalar_t><<<grid, block, 0, stream>>>(
+          orig_indices.data<int64_t>(),
+          grad.data<scalar_t>(),
+          offset2bag.data<int64_t>(),
+          count.defined() ? count.data<int64_t>() : nullptr, numel, stride,
+          mode_mean, bag_size.data<int64_t>(),
+          per_sample_weights.defined() ? per_sample_weights.data<scalar_t>() : NULL,
+          per_sample_weights.defined() ? per_sample_weights.stride(0) : 0,
+          partial_segment_offset.data<int64_t>(),
+          num_of_partial_segments, grad_weight_per_segment.data<partial_weight_t>(),
+          stride_warped);
+      } else {
+        compute_grad_weight<scalar_t><<<grid, block, 0, stream>>>(
+          orig_indices.data<int64_t>(),
+          grad.data<scalar_t>(),
+          count.defined() ? count.data<int64_t>() : nullptr,
+          numel, stride,
+          partial_segment_offset.data<int64_t>(),
+          num_of_partial_segments,
+          grad_weight_per_segment.data<partial_weight_t>(),
+          padding_idx,
+          stride_warped);
+      }
+      THCudaCheck(cudaGetLastError());
+
+      // Finally, we sum all the partial-sums and scatter them
+      // into `grad_weight`.
+      const int grid2 = ceil_div(num_of_segments*stride_warped, block);
+      sum_and_scatter<scalar_t><<<grid2, block, 0, stream>>>(
+        sorted_indices.data<int64_t>(),
+        grad_weight.data<scalar_t>(),
+        stride,
+        segment_offsets.data<int64_t>(),
+        num_of_segments, grad_weight_per_segment.data<partial_weight_t>(),
+        partials_per_segment_offset.data<int64_t>(),
+        num_of_partial_segments, stride_warped);
+      THCudaCheck(cudaGetLastError());
  });
-  THCudaCheck(cudaGetLastError());
  return grad_weight;
}
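For reference, here is a minimal sketch, not part of the patch, of what `acc_type<scalar_t, true>` resolves to, assuming the ATen headers are available. On CUDA the accumulate type widens `Half` to `float`, which is why `grad_weight_per_segment` above is allocated with a float dtype whenever `grad` is half.

// Illustrative only: how acc_type picks the accumulation dtype (assumes ATen headers).
#include <ATen/ATen.h>
#include <ATen/AccumulateType.h>
#include <type_traits>

// With is_cuda = true, half gradients are accumulated in float, while
// float and double keep their own precision.
static_assert(std::is_same<at::acc_type<at::Half, true>, float>::value,
              "half accumulates in float");
static_assert(std::is_same<at::acc_type<float, true>, float>::value,
              "float accumulates in float");
static_assert(std::is_same<at::acc_type<double, true>, double>::value,
              "double accumulates in double");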
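The other recurring change is dropping `thrust::device_vector` scratch buffers in favor of ATen tensors handed to thrust through `thrust::device_ptr`. Below is a hedged sketch of that pattern under the same `.data<T>()` API the diff uses; `exclusive_scan_counts` and its `counts` argument are hypothetical names, not part of the patch.

// Illustrative only: an ATen-allocated buffer used with a thrust scan.
#include <ATen/ATen.h>
#include <thrust/device_ptr.h>
#include <thrust/execution_policy.h>
#include <thrust/scan.h>

at::Tensor exclusive_scan_counts(const at::Tensor& counts /* int64 CUDA tensor */) {
  // The buffer comes from at::empty_like, so it is managed by PyTorch's
  // CUDA caching allocator rather than by thrust::device_vector.
  auto offsets = at::empty_like(counts);
  auto first = thrust::device_ptr<int64_t>(counts.data<int64_t>());
  auto out = thrust::device_ptr<int64_t>(offsets.data<int64_t>());
  // Same call shape as the patched exclusive_scan over partials_per_segment.
  thrust::exclusive_scan(thrust::device, first, first + counts.numel(), out);
  return offsets;
}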