@@ -395,6 +395,182 @@ CUDA_ATOMIC_WRAPPER(Add, complex<double>) {
                         CudaAtomicAdd(imag, val.imag));
 }
 
+// For atomicMul.
+CUDA_ATOMIC_WRAPPER(Mul, int) {
+  int res = *address, old = res;  // NOLINT
+  do {
+    old = res;
+    res = atomicCAS(address,     // NOLINT
+                    old,         // NOLINT
+                    val * old);  // NOLINT
+  } while (old != res);
+  return res;
+}
+
+CUDA_ATOMIC_WRAPPER(Mul, unsigned int) {
+  unsigned int res = *address, old = res;  // NOLINT
+  do {
+    old = res;
+    res = atomicCAS(address,     // NOLINT
+                    old,         // NOLINT
+                    val * old);  // NOLINT
+  } while (old != res);
+  return res;
+}
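Both integer wrappers above use the same compare-and-swap retry loop: read the word, compute val * old, and publish the product only if the word is still unchanged, otherwise retry against the freshly observed value. A minimal host-side sketch of the same pattern with std::atomic, for illustration only (not part of the patch):

#include <atomic>

// The same retry loop as the wrappers above. compare_exchange_weak
// refreshes `old` with the current value on failure, so each retry
// recomputes the product against the latest contents of the word.
int atomic_mul(std::atomic<int> *address, int val) {
  int old = address->load();
  while (!address->compare_exchange_weak(old, val * old)) {
    // `old` was refreshed; loop and try again.
  }
  return old;  // the value before this thread's multiply, like atomicCAS
}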
+// The CUDA API uses unsigned long long int; we cannot use uint64_t here,
+// because unsigned long long int is not necessarily the same type as
+// uint64_t.
+CUDA_ATOMIC_WRAPPER(Mul, unsigned long long int) {  // NOLINT
+  unsigned long long int old = *address, assumed;   // NOLINT
+
+  do {
+    assumed = old;
+    old = atomicCAS(address, assumed, val * assumed);
+  } while (assumed != old);
+  return old;
+}
+
+CUDA_ATOMIC_WRAPPER(Mul, int64_t) {
+  // Check that long long int is the same size as int64_t.
+  static_assert(sizeof(int64_t) == sizeof(long long int),  // NOLINT
+                "long long should be int64");
+  long long int res = *address, old = res;  // NOLINT
+  do {
+    old = res;
+    res = (long long int)atomicCAS(                                  // NOLINT
+        (unsigned long long int *)address,                           // NOLINT
+        (unsigned long long int)old,                                 // NOLINT
+        (unsigned long long int)val * (unsigned long long int)old);  // NOLINT
+  } while (old != res);
+  return res;
+}
+
+CUDA_ATOMIC_WRAPPER(Mul, float) {
+  int *const address_as_i = reinterpret_cast<int *>(address);
+  int old = *address_as_i, assumed;
+
+  do {
+    assumed = old;
+    old = atomicCAS(
+        address_as_i, assumed, __float_as_int(val * __int_as_float(assumed)));
+  } while (assumed != old);
+
+  return __int_as_float(old);
+}
+
+CUDA_ATOMIC_WRAPPER(Mul, double) {
+  unsigned long long int *const address_as_ull =            // NOLINT
+      reinterpret_cast<unsigned long long int *>(address);  // NOLINT
+  unsigned long long int old = *address_as_ull, assumed;    // NOLINT
+
+  do {
+    assumed = old;
+
+    old = atomicCAS(address_as_ull,
+                    assumed,
+                    __double_as_longlong(val * __longlong_as_double(assumed)));
+  } while (assumed != old);
+
+  return __longlong_as_double(old);
+}
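atomicCAS compares raw integer words, so the float and double wrappers round-trip each value through its bit pattern (__float_as_int / __double_as_longlong) and let the comparison happen on the bits rather than on floating-point values. A host-side sketch of the same trick, assuming C++20 std::bit_cast (illustrative, not part of the patch):

#include <atomic>
#include <bit>
#include <cstdint>

// CAS on the 64-bit pattern of a double: the multiply happens on the
// decoded value, the compare-and-swap on the raw bits.
double atomic_mul(std::atomic<uint64_t> *bits, double val) {
  uint64_t old = bits->load();
  uint64_t desired;
  do {
    desired = std::bit_cast<uint64_t>(val * std::bit_cast<double>(old));
  } while (!bits->compare_exchange_weak(old, desired));
  return std::bit_cast<double>(old);  // value before the multiply
}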
+
+#ifdef PADDLE_CUDA_FP16
+inline static __device__ uint32_t mul_to_low_half(uint32_t val, float x) {
+  phi::dtype::float16 low_half;
+  // The float16 stored in the lower 16 bits
+  low_half.x = static_cast<uint16_t>(val & 0xFFFFu);
+  low_half =
+      static_cast<phi::dtype::float16>(static_cast<float>(low_half) * x);
+  return (val & 0xFFFF0000u) | low_half.x;
+}
+
+inline static __device__ uint32_t mul_to_high_half(uint32_t val, float x) {
+  phi::dtype::float16 high_half;
+  // The float16 stored in the higher 16 bits
+  high_half.x = static_cast<uint16_t>(val >> 16);
+  high_half =
+      static_cast<phi::dtype::float16>(static_cast<float>(high_half) * x);
+  return (val & 0xFFFFu) | (static_cast<uint32_t>(high_half.x) << 16);
+}
+
+CUDA_ATOMIC_WRAPPER(Mul, phi::dtype::float16) {
+  uint32_t *address_as_ui = reinterpret_cast<uint32_t *>(
+      reinterpret_cast<char *>(address) -
+      (reinterpret_cast<uintptr_t>(address) & 0x02));
+  float val_f = static_cast<float>(val);
+  uint32_t old = *address_as_ui;
+  uint32_t assumed;
+  if (((uintptr_t)address & 0x02) == 0) {
+    // The float16 value stays in the lower 16 bits of the aligned word.
+    do {
+      assumed = old;
+      old = atomicCAS(address_as_ui, assumed, mul_to_low_half(assumed, val_f));
+    } while (old != assumed);
+    phi::dtype::float16 ret;
+    ret.x = old & 0xFFFFu;
+    return ret;
+  } else {
+    // The float16 value stays in the higher 16 bits of the aligned word.
+    do {
+      assumed = old;
+      old =
+          atomicCAS(address_as_ui, assumed, mul_to_high_half(assumed, val_f));
+    } while (old != assumed);
+    phi::dtype::float16 ret;
+    ret.x = old >> 16;
+    return ret;
+  }
+}
+#endif
+
+inline static __device__ uint32_t bf16_mul_to_low_half(uint32_t val, float x) {
+  phi::dtype::bfloat16 low_half;
+  // The bfloat16 stored in the lower 16 bits
+  low_half.x = static_cast<uint16_t>(val & 0xFFFFu);
+  low_half =
+      static_cast<phi::dtype::bfloat16>(static_cast<float>(low_half) * x);
+  return (val & 0xFFFF0000u) | low_half.x;
+}
+
+inline static __device__ uint32_t bf16_mul_to_high_half(uint32_t val,
+                                                        float x) {
+  phi::dtype::bfloat16 high_half;
+  // The bfloat16 stored in the higher 16 bits
+  high_half.x = static_cast<uint16_t>(val >> 16);
+  high_half =
+      static_cast<phi::dtype::bfloat16>(static_cast<float>(high_half) * x);
+  return (val & 0xFFFFu) | (static_cast<uint32_t>(high_half.x) << 16);
+}
+
+CUDA_ATOMIC_WRAPPER(Mul, phi::dtype::bfloat16) {
+  uint32_t *address_as_ui = reinterpret_cast<uint32_t *>(
+      reinterpret_cast<char *>(address) -
+      (reinterpret_cast<uintptr_t>(address) & 0x02));
+  float val_f = static_cast<float>(val);
+  uint32_t old = *address_as_ui;
+  uint32_t assumed;
+  if (((uintptr_t)address & 0x02) == 0) {
+    // The bfloat16 value stays in the lower 16 bits of the aligned word.
+    do {
+      assumed = old;
+      old = atomicCAS(
+          address_as_ui, assumed, bf16_mul_to_low_half(assumed, val_f));
+    } while (old != assumed);
+    phi::dtype::bfloat16 ret;
+    ret.x = old & 0xFFFFu;
+    return ret;
+  } else {
+    // The bfloat16 value stays in the higher 16 bits of the aligned word.
+    do {
+      assumed = old;
+      old = atomicCAS(
+          address_as_ui, assumed, bf16_mul_to_high_half(assumed, val_f));
+    } while (old != assumed);
+    phi::dtype::bfloat16 ret;
+    ret.x = old >> 16;
+    return ret;
+  }
+}
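There is no 16-bit atomicCAS, so the float16 and bfloat16 wrappers both operate on the aligned 32-bit word containing the value: the address is rounded down to 4-byte alignment, the relevant half is decoded, multiplied, and re-encoded, and the neighbouring half is preserved by masking. A host-side illustration of that packing (names hypothetical, not part of the patch):

#include <cstdint>
#include <cstdio>

// Rebuild a 32-bit word with one 16-bit half replaced, exactly as the
// *_to_low_half / *_to_high_half helpers above do after the multiply.
uint32_t replace_low_half(uint32_t word, uint16_t half) {
  return (word & 0xFFFF0000u) | half;  // keep the high half intact
}
uint32_t replace_high_half(uint32_t word, uint16_t half) {
  return (word & 0x0000FFFFu) | (static_cast<uint32_t>(half) << 16);
}

int main() {
  uint32_t word = 0xAAAABBBBu;
  std::printf("%08X\n", replace_low_half(word, 0x1111));   // AAAA1111
  std::printf("%08X\n", replace_high_half(word, 0x2222));  // 2222BBBB
}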
+
 // For atomicMax
 USE_CUDA_ATOMIC(Max, int);
 USE_CUDA_ATOMIC(Max, unsigned int);
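For reference, a hypothetical use of the new wrapper. This assumes CUDA_ATOMIC_WRAPPER(Mul, T) expands to a device function CudaAtomicMul(T *address, T val), mirroring the CudaAtomicAdd call visible in the context above; the kernel below is illustrative only, not part of the patch:

__global__ void ProductKernel(const float *in, float *out, int n) {
  int i = blockIdx.x * blockDim.x + threadIdx.x;
  if (i < n) {
    // Each thread folds its element into the running product; the CAS
    // loop inside the wrapper makes the read-multiply-write atomic.
    CudaAtomicMul(out, in[i]);
  }
}

// Launch with *out initialized to 1.0f, e.g.:
//   ProductKernel<<<(n + 255) / 256, 256>>>(d_in, d_out, n);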