Skip to content

Commit

Permalink
Update ROCm float multiplication in sparse Marlin MMA
Browse files (browse the repository at this point in the history)
Replace __builtin_amdgcn_fmul_f32 with __ocml_fmul_f32 for more accurate and consistent float multiplication in the scale_floats function on AMD GPU platforms.
  • Loading branch information
petrex committed Mar 11, 2025
1 parent 66691c3 commit 04014e7
Showing 1 changed file with 10 additions and 10 deletions.
20 changes: 10 additions & 10 deletions torchao/csrc/cuda/sparse_marlin/mma.h
Original file line number Diff line number Diff line change
Expand Up @@ -35,8 +35,8 @@ namespace torchao {
// On CUDA earlier than 12.5, the ordered_metadata version of this instruction
// is not supported. On later versions of CUDA the version without ordered
// metadata results in the following warning:
// | Advisory: Modifier ‘.sp::ordered_metadata should be used on instruction
// | mma instead of modifier ‘.sp’ as it is expected to have substantially
// | Advisory: Modifier 'sp::ordered_metadata' should be used on instruction
// | 'mma' instead of modifier 'sp' as it is expected to have substantially
// | reduced performance on some future architectures

#if defined(USE_ROCM)
Expand Down Expand Up @@ -281,15 +281,15 @@ __device__ inline void scale_floats(float* c0, float* c1, float* c2, float* c3,
float* c7, FragS& s1) {
#ifdef USE_ROCM
// AMD implementation - fixed
*c0 = __builtin_amdgcn_fmul_f32(*c0, __half2float(s0[0].x));
*c1 = __builtin_amdgcn_fmul_f32(*c1, __half2float(s0[0].y));
*c2 = __builtin_amdgcn_fmul_f32(*c2, __half2float(s0[1].x));
*c3 = __builtin_amdgcn_fmul_f32(*c3, __half2float(s0[1].y));
*c0 = __ocml_fmul_f32(*c0, __half2float(s0[0].x));
*c1 = __ocml_fmul_f32(*c1, __half2float(s0[0].y));
*c2 = __ocml_fmul_f32(*c2, __half2float(s0[1].x));
*c3 = __ocml_fmul_f32(*c3, __half2float(s0[1].y));

*c4 = __builtin_amdgcn_fmul_f32(*c4, __half2float(s1[0].x));
*c5 = __builtin_amdgcn_fmul_f32(*c5, __half2float(s1[0].y));
*c6 = __builtin_amdgcn_fmul_f32(*c6, __half2float(s1[1].x));
*c7 = __builtin_amdgcn_fmul_f32(*c7, __half2float(s1[1].y));
*c4 = __ocml_fmul_f32(*c4, __half2float(s1[0].x));
*c5 = __ocml_fmul_f32(*c5, __half2float(s1[0].y));
*c6 = __ocml_fmul_f32(*c6, __half2float(s1[1].x));
*c7 = __ocml_fmul_f32(*c7, __half2float(s1[1].y));
#else
// NVIDIA implementation
*c0 = __fmul_rn(*c0, __half2float(s0[0].x));
Expand Down

0 comments on commit 04014e7

Please sign in to comment.