
Commit eefc79d

YibinLiu666 authored and warrentdrew committed

[Hackathon 5th No.6] Enhance the put_along_axis API for Paddle -part (PaddlePaddle#59674)
1 parent f025385 commit eefc79d

21 files changed: +2730 -135 lines

paddle/phi/api/yaml/backward.yaml (+2 -2)

```diff
@@ -1760,8 +1760,8 @@
     optional : boxes_num

 - backward_op : put_along_axis_grad
-  forward : put_along_axis (Tensor arr, Tensor indices, Tensor values, int axis, str reduce = "assign") -> Tensor(out)
-  args : (Tensor arr, Tensor indices, Tensor out_grad, int axis, str reduce)
+  forward : put_along_axis (Tensor arr, Tensor indices, Tensor values, int axis, str reduce = "assign", bool include_self = true) -> Tensor(out)
+  args : (Tensor arr, Tensor indices, Tensor values, Tensor out, Tensor out_grad, int axis, str reduce, bool include_self)
   output : Tensor(arr_grad), Tensor(values_grad)
   infer_meta :
     func : GeneralBinaryGradInferMeta
```
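
Note that the backward op now receives the forward input `values` and the forward result `out` in addition to `out_grad`. One plausible reason, sketched below under stated assumptions rather than taken from Paddle's kernels: for multiplicative reductions, the gradient at a written position depends on the forward output itself. A minimal hypothetical 1-D sketch for `reduce = "mul"` with `include_self = true` (the helper name and the zero handling are illustrative only):

```cpp
#include <cstddef>
#include <cstdint>
#include <vector>

// Hypothetical 1-D sketch (not Paddle's kernel) for reduce = "mul" with
// include_self = true: out[i] = arr[i] * prod(values[k] : indices[k] == i),
// so d out[i]/d arr[i]    = out[i] / arr[i]    (product of the other factors)
// and d out[i]/d values[k] = out[i] / values[k]; both need the forward out,
// which is why the backward signature gains it.
void put_along_axis_mul_grad_1d(const std::vector<float>& arr,
                                const std::vector<int64_t>& indices,
                                const std::vector<float>& values,
                                const std::vector<float>& out,
                                const std::vector<float>& out_grad,
                                std::vector<float>* arr_grad,
                                std::vector<float>* values_grad) {
  arr_grad->assign(arr.size(), 0.0f);
  values_grad->assign(values.size(), 0.0f);
  for (std::size_t i = 0; i < arr.size(); ++i) {
    // Assumes arr[i] != 0; a real kernel must treat zeros separately.
    (*arr_grad)[i] = out_grad[i] * out[i] / arr[i];
  }
  for (std::size_t k = 0; k < values.size(); ++k) {
    // Assumes values[k] != 0.
    (*values_grad)[k] = out_grad[indices[k]] * out[indices[k]] / values[k];
  }
}
```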

paddle/phi/api/yaml/op_compat.yaml (+1 -1)

```diff
@@ -2432,7 +2432,7 @@
   outputs :
     out : Result
   attrs :
-    {axis : Axis, reduce : Reduce}
+    {axis : Axis, reduce : Reduce, include_self: Include_self}

 - op : pylayer
   backward : pylayer_grad
```

paddle/phi/api/yaml/ops.yaml (+1 -1)

```diff
@@ -2032,7 +2032,7 @@
   backward : psroi_pool_grad

 - op : put_along_axis
-  args : (Tensor arr, Tensor indices, Tensor values, int axis, str reduce = "assign")
+  args : (Tensor arr, Tensor indices, Tensor values, int axis, str reduce = "assign", bool include_self = true)
   output : Tensor(out)
   infer_meta :
     func : UnchangedInferMeta
```
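
The forward signature gains `include_self` with a default of `true`, which preserves the old behavior. A minimal 1-D sketch of the assumed semantics for `reduce = "add"`, modeled on the analogous flag in other tensor libraries; this is an illustration, not Paddle's kernel:

```cpp
#include <cstddef>
#include <cstdint>
#include <vector>

// Hypothetical 1-D illustration: with include_self = true the original arr
// value joins the sum at each indexed position; with false, the first
// scattered value replaces it and later ones accumulate.
std::vector<float> put_along_axis_add_1d(std::vector<float> arr,
                                         const std::vector<int64_t>& indices,
                                         const std::vector<float>& values,
                                         bool include_self) {
  std::vector<bool> touched(arr.size(), false);
  for (std::size_t k = 0; k < indices.size(); ++k) {
    const int64_t i = indices[k];
    if (!include_self && !touched[i]) {
      arr[i] = values[k];  // drop the self value on the first write
    } else {
      arr[i] += values[k];
    }
    touched[i] = true;
  }
  return arr;
}

// arr = {1, 1}, indices = {0, 0}, values = {2, 3}:
//   include_self = true  -> {6, 1}   (1 + 2 + 3)
//   include_self = false -> {5, 1}   (2 + 3)
```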

paddle/phi/backends/gpu/gpu_primitives.h (+176 -0)

```diff
@@ -395,6 +395,182 @@ CUDA_ATOMIC_WRAPPER(Add, complex<double>) {
                      CudaAtomicAdd(imag, val.imag));
 }

+// For atomicMul.
+CUDA_ATOMIC_WRAPPER(Mul, int) {
+  int res = *address, old = res;  // NOLINT
+  do {
+    old = res;
+    res = atomicCAS(address,     // NOLINT
+                    old,         // NOLINT
+                    val * old);  // NOLINT
+  } while (old != res);
+  return res;
+}
+
+CUDA_ATOMIC_WRAPPER(Mul, unsigned int) {
+  unsigned int res = *address, old = res;  // NOLINT
+  do {
+    old = res;
+    res = atomicCAS(address,     // NOLINT
+                    old,         // NOLINT
+                    val * old);  // NOLINT
+  } while (old != res);
+  return res;
+}
+
+// The CUDA API uses unsigned long long int; we cannot use uint64_t here
+// because unsigned long long int is not necessarily uint64_t.
+CUDA_ATOMIC_WRAPPER(Mul, unsigned long long int) {  // NOLINT
+  unsigned long long int old = *address, assumed;  // NOLINT
+
+  do {
+    assumed = old;
+    old = atomicCAS(address, assumed, val * assumed);
+  } while (assumed != old);
+  return old;
+}
+
+CUDA_ATOMIC_WRAPPER(Mul, int64_t) {
+  // Check that long long int is int64_t.
+  static_assert(sizeof(int64_t) == sizeof(long long int),  // NOLINT
+                "long long should be int64");
+  long long int res = *address, old = res;  // NOLINT
+  do {
+    old = res;
+    res = (long long int)atomicCAS(                                  // NOLINT
+        (unsigned long long int *)address,                           // NOLINT
+        (unsigned long long int)old,                                 // NOLINT
+        (unsigned long long int)val * (unsigned long long int)old);  // NOLINT
+  } while (old != res);
+  return res;
+}
+
+CUDA_ATOMIC_WRAPPER(Mul, float) {
+  int *const address_as_i = reinterpret_cast<int *>(address);
+  int old = *address_as_i, assumed;
+
+  do {
+    assumed = old;
+    old = atomicCAS(
+        address_as_i, assumed, __float_as_int(val * __int_as_float(assumed)));
+  } while (assumed != old);
+
+  return __int_as_float(old);
+}
+
+CUDA_ATOMIC_WRAPPER(Mul, double) {
+  unsigned long long int *const address_as_ull =            // NOLINT
+      reinterpret_cast<unsigned long long int *>(address);  // NOLINT
+  unsigned long long int old = *address_as_ull, assumed;    // NOLINT
+
+  do {
+    assumed = old;
+
+    old = atomicCAS(address_as_ull,
+                    assumed,
+                    __double_as_longlong(val * __longlong_as_double(assumed)));
+  } while (assumed != old);
+
+  return __longlong_as_double(old);
+}
+
+#ifdef PADDLE_CUDA_FP16
+inline static __device__ uint32_t mul_to_low_half(uint32_t val, float x) {
+  phi::dtype::float16 low_half;
+  // The float16 in the lower 16 bits.
+  low_half.x = static_cast<uint16_t>(val & 0xFFFFu);
+  low_half = static_cast<phi::dtype::float16>(static_cast<float>(low_half) * x);
+  return (val & 0xFFFF0000u) | low_half.x;
+}
+
+inline static __device__ uint32_t mul_to_high_half(uint32_t val, float x) {
+  phi::dtype::float16 high_half;
+  // The float16 in the higher 16 bits.
+  high_half.x = static_cast<uint16_t>(val >> 16);
+  high_half =
+      static_cast<phi::dtype::float16>(static_cast<float>(high_half) * x);
+  return (val & 0xFFFFu) | (static_cast<uint32_t>(high_half.x) << 16);
+}
+
+CUDA_ATOMIC_WRAPPER(Mul, phi::dtype::float16) {
+  // Locate the aligned 32-bit word that contains the 16-bit value.
+  uint32_t *address_as_ui = reinterpret_cast<uint32_t *>(
+      reinterpret_cast<char *>(address) -
+      (reinterpret_cast<uintptr_t>(address) & 0x02));
+  float val_f = static_cast<float>(val);
+  uint32_t old = *address_as_ui;
+  uint32_t assumed;
+  if (((uintptr_t)address & 0x02) == 0) {
+    // The float16 value stays at the lower 16 bits of the address.
+    do {
+      assumed = old;
+      old = atomicCAS(address_as_ui, assumed, mul_to_low_half(assumed, val_f));
+    } while (old != assumed);
+    phi::dtype::float16 ret;
+    ret.x = old & 0xFFFFu;
+    return ret;
+  } else {
+    // The float16 value stays at the higher 16 bits of the address.
+    do {
+      assumed = old;
+      old = atomicCAS(address_as_ui, assumed, mul_to_high_half(assumed, val_f));
+    } while (old != assumed);
+    phi::dtype::float16 ret;
+    ret.x = old >> 16;
+    return ret;
+  }
+}
+#endif
+
+inline static __device__ uint32_t bf16_mul_to_low_half(uint32_t val, float x) {
+  phi::dtype::bfloat16 low_half;
+  // The bfloat16 in the lower 16 bits.
+  low_half.x = static_cast<uint16_t>(val & 0xFFFFu);
+  low_half =
+      static_cast<phi::dtype::bfloat16>(static_cast<float>(low_half) * x);
+  return (val & 0xFFFF0000u) | low_half.x;
+}
+
+inline static __device__ uint32_t bf16_mul_to_high_half(uint32_t val, float x) {
+  phi::dtype::bfloat16 high_half;
+  // The bfloat16 in the higher 16 bits.
+  high_half.x = static_cast<uint16_t>(val >> 16);
+  high_half =
+      static_cast<phi::dtype::bfloat16>(static_cast<float>(high_half) * x);
+  return (val & 0xFFFFu) | (static_cast<uint32_t>(high_half.x) << 16);
+}
+
+CUDA_ATOMIC_WRAPPER(Mul, phi::dtype::bfloat16) {
+  uint32_t *address_as_ui = reinterpret_cast<uint32_t *>(
+      reinterpret_cast<char *>(address) -
+      (reinterpret_cast<uintptr_t>(address) & 0x02));
+  float val_f = static_cast<float>(val);
+  uint32_t old = *address_as_ui;
+  uint32_t assumed;
+  if (((uintptr_t)address & 0x02) == 0) {
+    // The bfloat16 value stays at the lower 16 bits of the address.
+    do {
+      assumed = old;
+      old = atomicCAS(
+          address_as_ui, assumed, bf16_mul_to_low_half(assumed, val_f));
+    } while (old != assumed);
+    phi::dtype::bfloat16 ret;
+    ret.x = old & 0xFFFFu;
+    return ret;
+  } else {
+    // The bfloat16 value stays at the higher 16 bits of the address.
+    do {
+      assumed = old;
+      old = atomicCAS(
+          address_as_ui, assumed, bf16_mul_to_high_half(assumed, val_f));
+    } while (old != assumed);
+    phi::dtype::bfloat16 ret;
+    ret.x = old >> 16;
+    return ret;
+  }
+}
+
 // For atomicMax
 USE_CUDA_ATOMIC(Max, int);
 USE_CUDA_ATOMIC(Max, unsigned int);
```

(The float16 wrapper in the raw diff opened with an early return `if (*address >= val) { return *address; }`, apparently copied from the atomicMax wrapper; it is removed above because it would skip the multiplication whenever the stored value is at least `val`.)
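
All of these wrappers follow the same compare-and-swap retry pattern: read the word, compute the product, and `atomicCAS` it back, retrying if another thread changed the word in between. A standalone, runnable sketch of that pattern for `float`, written as a free function (`atomic_mul_float` and `mul_all` are made-up names; Paddle's real wrappers go through the `CUDA_ATOMIC_WRAPPER` macro):

```cuda
#include <cstdio>

// CAS-loop atomic multiply on a float, mirroring the wrapper above:
// reinterpret the float's bits as int, multiply, retry until no race.
__device__ float atomic_mul_float(float *address, float val) {
  int *address_as_i = reinterpret_cast<int *>(address);
  int old = *address_as_i, assumed;
  do {
    assumed = old;
    old = atomicCAS(address_as_i, assumed,
                    __float_as_int(val * __int_as_float(assumed)));
  } while (assumed != old);
  return __int_as_float(old);
}

__global__ void mul_all(float *acc, const float *xs, int n) {
  int i = blockIdx.x * blockDim.x + threadIdx.x;
  if (i < n) atomic_mul_float(acc, xs[i]);
}

int main() {
  float h_xs[4] = {2.f, 3.f, 4.f, 5.f}, h_acc = 1.f;
  float *d_xs, *d_acc;
  cudaMalloc(&d_xs, sizeof(h_xs));
  cudaMalloc(&d_acc, sizeof(float));
  cudaMemcpy(d_xs, h_xs, sizeof(h_xs), cudaMemcpyHostToDevice);
  cudaMemcpy(d_acc, &h_acc, sizeof(float), cudaMemcpyHostToDevice);
  mul_all<<<1, 4>>>(d_acc, d_xs, 4);
  cudaMemcpy(&h_acc, d_acc, sizeof(float), cudaMemcpyDeviceToHost);
  printf("%f\n", h_acc);  // 120.0: every factor applied exactly once
  cudaFree(d_xs);
  cudaFree(d_acc);
  return 0;
}
```

The retry loop guarantees each factor is applied exactly once even under contention, which is what makes the wrappers safe for the scatter-multiply paths that `put_along_axis` with `reduce = "mul"` needs.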

paddle/phi/kernels/cpu/cum_maxmin_grad_kernel.cc (+4 -4)

```diff
@@ -38,10 +38,10 @@ void CummaxGradKernel(const Context& dev_ctx,
   }
   if (dtype == DataType::INT32) {
     phi::funcs::cpu_scatter_add_kernel<T, int32_t>(
-        *x_grad, axis, indices, out_grad, dev_ctx);
+        *x_grad, axis, indices, out_grad, true, dev_ctx);
   } else if (dtype == DataType::INT64) {
     phi::funcs::cpu_scatter_add_kernel<T, int64_t>(
-        *x_grad, axis, indices, out_grad, dev_ctx);
+        *x_grad, axis, indices, out_grad, true, dev_ctx);
   }
 }

@@ -61,10 +61,10 @@ void CumminGradKernel(const Context& dev_ctx,
   }
   if (dtype == DataType::INT32) {
     phi::funcs::cpu_scatter_add_kernel<T, int32_t>(
-        *x_grad, axis, indices, out_grad, dev_ctx);
+        *x_grad, axis, indices, out_grad, true, dev_ctx);
   } else if (dtype == DataType::INT64) {
     phi::funcs::cpu_scatter_add_kernel<T, int64_t>(
-        *x_grad, axis, indices, out_grad, dev_ctx);
+        *x_grad, axis, indices, out_grad, true, dev_ctx);
   }
 }
```
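
The extra `true` argument corresponds to the `include_self` parameter that `cpu_scatter_add_kernel` gains in this change; the cummax/cummin backward passes `true` so the scatter accumulates on top of the zero-initialized `x_grad`, preserving the pre-change behavior. A hypothetical 1-D sketch of what this scatter-add does in the cummax backward (illustration only, not Paddle's actual helper):

```cpp
#include <cstddef>
#include <cstdint>
#include <vector>

// Each out_grad element flows back to the input position that produced the
// running maximum, recorded in `indices` by the forward pass; collisions
// accumulate. The zero-filled x_grad is the include_self base.
std::vector<float> cummax_grad_1d(const std::vector<float>& out_grad,
                                  const std::vector<int64_t>& indices) {
  std::vector<float> x_grad(out_grad.size(), 0.0f);
  for (std::size_t i = 0; i < out_grad.size(); ++i) {
    x_grad[indices[i]] += out_grad[i];  // scatter-add along the axis
  }
  return x_grad;
}
```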
