Skip to content

Commit

Permalink
Add ROCm support for Marlin kernel function attribute setting
Browse files Browse the repository at this point in the history
Modify the Marlin kernel launch code to set the kernel's function attribute through the HIP/ROCm-specific API (hipFuncSetAttribute) when compiled for AMD platforms (__HIP_PLATFORM_AMD__), while keeping the existing cudaFuncSetAttribute path for CUDA builds. This ensures the maximum dynamic shared memory size is configured properly across different GPU architectures.
  • Loading branch information
petrex committed Mar 11, 2025
1 parent 6f43e01 commit bf81763
Showing 1 changed file with 13 additions and 4 deletions.
17 changes: 13 additions & 4 deletions torchao/csrc/cuda/sparse_marlin/marlin_kernel_nm.cu
Original file line number Diff line number Diff line change
Expand Up @@ -861,10 +861,19 @@ __global__ void Marlin_24(
thread_n_blocks == THREAD_N_BLOCKS && \
thread_k_blocks == THREAD_K_BLOCKS && \
group_blocks == GROUP_BLOCKS) { \
cudaFuncSetAttribute( \
Marlin_24<NUM_BITS, THREADS, THREAD_N_BLOCKS, THREAD_M_BLOCKS, \
THREAD_K_BLOCKS, STAGES, GROUP_BLOCKS>, \
cudaFuncAttributeMaxDynamicSharedMemorySize, max_shared_mem); \
#ifdef __HIP_PLATFORM_AMD__
// For ROCm/HIP, cast the kernel function to const void* for hipFuncSetAttribute
hipFuncSetAttribute(
reinterpret_cast<const void*>(&Marlin_24<NUM_BITS, THREADS, THREAD_N_BLOCKS, THREAD_M_BLOCKS,
THREAD_K_BLOCKS, STAGES, GROUP_BLOCKS>),
hipFuncAttributeMaxDynamicSharedMemorySize, max_shared_mem);
#else
// For CUDA, use the template function directly
cudaFuncSetAttribute(
Marlin_24<NUM_BITS, THREADS, THREAD_N_BLOCKS, THREAD_M_BLOCKS,
THREAD_K_BLOCKS, STAGES, GROUP_BLOCKS>,
cudaFuncAttributeMaxDynamicSharedMemorySize, max_shared_mem);
#endif
Marlin_24<NUM_BITS, THREADS, THREAD_N_BLOCKS, THREAD_M_BLOCKS, \
THREAD_K_BLOCKS, STAGES, GROUP_BLOCKS> \
<<<blocks, THREADS, max_shared_mem, stream>>>(A_ptr, B_ptr, meta_ptr, \
Expand Down

0 comments on commit bf81763

Please sign in to comment.