Draft: [wip] Rocm sparse fix #1868
petrex wants to merge 23 commits into pytorch:main from petrex:rocm_sparse_fix (base: main)
+163 −85
Conversation
Update GPU architecture check to use gcnArchName and improve detection of gfx942 support
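As a hedged illustration of what such a check can look like at the HIP runtime level (the PR performs it in `setup.py`, presumably through PyTorch's device properties; the function and file below are a sketch, not the PR's code), `hipDeviceProp_t::gcnArchName` reports strings such as `"gfx942:sramecc+:xnack-"`:

```cpp
#include <hip/hip_runtime.h>
#include <cstring>
#include <cstdio>

// Sketch only: detect gfx942 (MI300-class) support by prefix-matching the
// architecture name reported by the HIP runtime.
bool device_is_gfx942(int device_id) {
  hipDeviceProp_t props;
  if (hipGetDeviceProperties(&props, device_id) != hipSuccess) {
    return false;
  }
  return std::strncmp(props.gcnArchName, "gfx942", 6) == 0;
}

int main() {
  std::printf("gfx942 support: %d\n", device_is_gfx942(0));
  return 0;
}
```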
Reorganize source file selection logic for CUDA and ROCm builds, improving conditional handling of GPU sources and CUTLASS kernels. Simplify the source file selection process and improve readability of the build configuration.
Modify CUTLASS kernel configuration to explicitly check for non-ROCm platforms when enabling support, ensuring more precise build configuration for different GPU environments.
Move source file collection logic to maintain consistent code organization and improve readability of the build configuration. No functional changes were made to the source file selection process.
Remove the `-t=0` flag from NVCC compilation options, which appears to be unnecessary. This simplifies the compilation configuration without impacting build behavior.
🔗 Helpful Links
🧪 See artifacts and rendered test results at hud.pytorch.org/pr/pytorch/ao/1868
Note: links to docs will display an error until the docs builds have been completed.
❌ 1 New Failure: as of commit 3e5a411 with merge base ce05b3f, one new job failure was reported.
This comment was automatically generated by Dr. CI and updates every 15 minutes.
Add conditional compilation for ROCm platforms in the sparse Marlin matrix multiply accumulate (MMA) function. This ensures proper inline assembly implementation for both CUDA and ROCm environments, using platform-specific register and instruction handling.
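The shape of that split, as a minimal sketch (the PR's actual sparse Marlin `mma()` carries fragment operands and issues real instructions on both branches; everything here is illustrative, including the `USE_ROCM` guard, which follows the PyTorch convention):

```cpp
// Illustrative skeleton of the platform split described above.
__device__ inline void mma_sp_dispatch() {
#if defined(USE_ROCM)
  // AMD path: MFMA inline assembly with vector-register constraints
  // (a concrete sketch appears after the MFMA commit below).
#else
  // NVIDIA path: PTX sparse tensor-core instruction, e.g. a
  // "mma.sp.sync.aligned..." asm volatile block.
#endif
}
```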
Use __builtin_bit_cast to correctly convert float pairs to packed half-precision values stored as uint32_t on AMD GPU platforms, ensuring proper type handling in the sparse Marlin matrix multiply accumulate (MMA) implementation.
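A minimal sketch of that conversion for the ROCm path, assuming a pack-two-floats helper in the spirit of `to_half4` (the name `pack_half2` and the rounding intrinsic are assumptions):

```cpp
#include <cstdint>
#include <hip/hip_fp16.h>

// Hypothetical helper, per the commit above: convert a float pair to a
// packed half2 and expose the 32-bit pattern as uint32_t.
__device__ inline uint32_t pack_half2(float a, float b) {
  __half2 h = __floats2half2_rn(a, b);     // round-to-nearest-even conversion
  return __builtin_bit_cast(uint32_t, h);  // reinterpret bits, no value change
}
```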
Update CUDA half-precision operations to use the __hsub2 and __hfma2 intrinsics, improving performance and precision in sparse matrix multiply-accumulate (MMA) computations.
Update AMD GPU implementation to use __hsub2 and __hmul2 intrinsics for improved performance and precision in half-precision sparse matrix multiply-accumulate computations.
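Putting the two commits together, a hedged sketch of a dequant-style step using packed-half intrinsics on both platforms (`dequant_step` and its signature are illustrative; the PR's `dequant_4bit`/`dequant_8bit` differ in detail):

```cpp
#if defined(USE_ROCM)
#include <hip/hip_fp16.h>
#else
#include <cuda_fp16.h>
#endif

// Illustrative only: computes (q - bias) * scale + acc on packed half2 lanes.
__device__ inline __half2 dequant_step(__half2 q, __half2 bias,
                                       __half2 scale, __half2 acc) {
  __half2 centered = __hsub2(q, bias);            // remove the zero-point
#if defined(USE_ROCM)
  return __hadd2(__hmul2(centered, scale), acc);  // separate mul + add (AMD)
#else
  return __hfma2(centered, scale, acc);           // fused multiply-add (CUDA)
#endif
}
```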
Update AMD GPU implementation to use __builtin_amdgcn_fmul_f32 instead of __builtin_amdgcn_fmul_legacy for more accurate float multiplication in the scale_floats function.
Include necessary ROCm-specific headers for HIP runtime and half-precision operations, with comments addressing potential compiler and architecture considerations for AMD GPU platforms.
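The include block in question plausibly looks like this (guard macro per the PyTorch convention; treat it as a sketch):

```cpp
#if defined(USE_ROCM)
#include <hip/hip_runtime.h>  // HIP runtime types and device macros
#include <hip/hip_fp16.h>     // __half2 and the __hsub2/__hmul2 family
#endif
```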
Replace __builtin_amdgcn_fmul_f32 with __ocml_fmul_f32 for more accurate and consistent float multiplication in the scale_floats function on AMD GPU platforms.
Replace __builtin_amdgcn_global_load_lds with inline assembly using the ds_load_b128 instruction for more precise and direct global to local data store (LDS) transfer on MI300X AMD GPUs.
Replace __ocml_fmul_f32 with standard C++ multiplication for more readable and straightforward float scaling on AMD MI300X GPUs.
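After that change, `scale_floats` reduces to ordinary IEEE multiplies; a sketch under an assumed signature (the commit only says the body became standard C++ multiplication):

```cpp
// Assumed signature; the compiler lowers each product to a plain v_mul_f32.
__device__ inline void scale_floats(float* out, const float* in,
                                    float s, int n) {
  for (int i = 0; i < n; ++i) {
    out[i] = in[i] * s;  // no builtin or OCML call needed
  }
}
```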
Update cudaFuncSetAttribute call to use reinterpret_cast for correct function pointer handling in the Marlin_24 CUDA kernel, ensuring proper dynamic shared memory configuration.
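A minimal reproduction of that fix (`Marlin_24_stub` stands in for the PR's `Marlin_24` template instantiation):

```cpp
#include <cuda_runtime.h>

__global__ void Marlin_24_stub() {}  // placeholder for the real kernel

void set_dynamic_smem(int max_shared_mem) {
  // ROCm's clang rejects the implicit function-pointer-to-void* conversion
  // that nvcc tolerates, so the cast is spelled out explicitly.
  cudaFuncSetAttribute(
      reinterpret_cast<const void*>(&Marlin_24_stub),
      cudaFuncAttributeMaxDynamicSharedMemorySize,
      max_shared_mem);
}
```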
Refactor cp_async4 functions for ROCm to use explicit ds_load instructions for 4, 8, and 16-byte transfers. Add a fallback mechanism using __builtin_memcpy for unsupported sizes, improving the precision and flexibility of global to local data store (LDS) transfers on MI300X AMD GPUs.
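A hedged sketch of that commit's shape, with the CUDA path included for contrast (the name mirrors the PR's `cp_async4`; the signature is an assumption, and the predication/zero-fill variants are omitted):

```cpp
#include <cstdint>

__device__ inline void cp_async4(void* smem_ptr, const void* glob_ptr) {
#if defined(USE_ROCM)
  // Fallback path from this commit: a plain 16-byte copy that the compiler
  // lowers to a global load plus an LDS write.
  __builtin_memcpy(smem_ptr, glob_ptr, 16);
#else
  // CUDA path: asynchronous 16-byte global->shared copy (sm_80+).
  uint32_t smem = static_cast<uint32_t>(__cvta_generic_to_shared(smem_ptr));
  asm volatile("cp.async.cg.shared.global [%0], [%1], 16;\n"
               :: "r"(smem), "l"(glob_ptr));
#endif
}
```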
Add missing closing braces in cp_async4_pred_zfill, cp_async4_pred, and cp_async4 functions to ensure proper code structure and prevent potential compilation issues in the ROCm sparse Marlin MMA implementation.
Simplify ROCm global to LDS transfer by removing the fallback __builtin_memcpy from the cp_async4_pred_zfill, cp_async4_pred, and cp_async4 functions, reducing code complexity while maintaining the primary ds_load_b128 transfer mechanism.
Simplify ROCm global to LDS transfer further by removing the 16-byte ds_load_b128 instruction from the cp_async4_pred_zfill, cp_async4_pred, and cp_async4 functions, further reducing code complexity while maintaining the core transfer mechanism.
Replace global_load_dwordx4 with multiple ds_read_b32 instructions for better compatibility and support across different ROCm platforms. Modify the ldsm4 and ldsm4_t functions to use more widely supported memory load techniques.
Update ldsm4_m device function to use separate ds_read_b32 instructions instead of a single ds_read_b64, improving compatibility and load behavior on ROCm platforms.
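The ldsm4/ldsm4_t/ldsm4_m commits above all come down to issuing narrow 32-bit LDS reads. Expressed as plain C++ under an assumed signature (the PR writes these as explicit inline-assembly `ds_read_b32` instructions; ordinary 32-bit loads from LDS typically compile to the same instruction and are easier to show correctly here):

```cpp
#include <cstdint>

// Illustrative only: four independent 32-bit reads, one ds_read_b32 per
// element when lds_ptr points into LDS, instead of a single wide load.
__device__ inline void ldsm4(uint32_t frag[4], const uint32_t* lds_ptr) {
#pragma unroll
  for (int i = 0; i < 4; ++i) {
    frag[i] = lds_ptr[i];
  }
}
```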
Modify the MFMA instruction assembly for AMD GPUs to use correct syntax and operand handling. Replace register constraints with vector register constraints and simplify the instruction format to improve compatibility and readability on ROCm platforms.
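For the MFMA commit, a hedged sketch using the compiler builtin that such inline assembly corresponds to (instruction choice, tile shape, and vector types are assumptions for gfx90a/gfx942-class hardware; the PR itself uses asm with "v" vector-register constraints):

```cpp
// Clang extended vector types matching the MFMA operand shapes.
typedef _Float16 half4_t __attribute__((ext_vector_type(4)));
typedef float float4_t __attribute__((ext_vector_type(4)));

#if defined(__gfx90a__) || defined(__gfx942__)
__device__ inline float4_t mfma_16x16x16(half4_t a, half4_t b, float4_t acc) {
  // D = A*B + C on a 16x16x16 f16 tile; the trailing zeros are the
  // cbsz/abid/blgp modifiers, unused here.
  return __builtin_amdgcn_mfma_f32_16x16x16f16(a, b, acc, 0, 0, 0);
}
#endif
```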
Labels: ciflow/rocm, CLA Signed, module: rocm, topic: bug fix
TLDR: fix sparse marlin kernel for rocm

This pull request includes several updates to the setup.py file and modifications to CUDA and ROCm specific code in the torchao/csrc/cuda/sparse_marlin directory. The most important changes focus on improving compatibility with ROCm, updating assertions, and refining the build process for CUDA and ROCm extensions.

Updates to setup.py:
- Removed the -t=0 flag from the nvcc compile arguments.

Modifications to CUDA and ROCm specific code:
- Updated the cp_async4_pred_zfill, cp_async4_pred, and cp_async4 functions to use the LDS.G instruction for global to LDS transfers on MI300X. [1] [2] [3]
- Changed mma.h to support the ROCm architecture and improve performance on MI300X. [1] [2] [3] [4]
- Updated to_half4, dequant_4bit, dequant_8bit, scale, and scale_floats to use appropriate ROCm intrinsics and improve compatibility. [1] [2] [3] [4] [5]