diff --git a/README.md b/README.md index 084ea465..56d02a0f 100644 --- a/README.md +++ b/README.md @@ -46,41 +46,19 @@ AutoAWQ is an easy-to-use package for 4-bit quantized models. AutoAWQ speeds up - Your NVIDIA GPU(s) must be of Compute Capability 7.5. Turing and later architectures are supported. - Your CUDA version must be CUDA 11.8 or later. - AMD: - - Your ROCm version must be ROCm 5.6 or later. + - Your ROCm version must be compatible with Triton. ### Install from PyPi -To install the newest AutoAWQ from PyPi, you need CUDA 12.1 installed. +There are a few ways to install AutoAWQ: -``` -pip install autoawq -``` - -### Build from source - -For CUDA 11.8, ROCm 5.6, and ROCm 5.7, you can install wheels from the [release page](https://github.com/casper-hansen/AutoAWQ/releases/latest): - -``` -pip install autoawq@https://github.com/casper-hansen/AutoAWQ/releases/download/v0.2.0/autoawq-0.2.0+cu118-cp310-cp310-linux_x86_64.whl -``` - -Or from the main branch directly: - -``` -pip install autoawq@https://github.com/casper-hansen/AutoAWQ.git -``` - -Or by cloning the repository and installing from source: - -``` -git clone https://github.com/casper-hansen/AutoAWQ -cd AutoAWQ -pip install -e . -``` - -All three methods will install the latest and correct kernels for your system from [AutoAWQ_Kernels](https://github.com/casper-hansen/AutoAWQ_kernels/releases). - -If your system is not supported (i.e. not on the release page), you can build the kernels yourself by following the instructions in [AutoAWQ_Kernels](https://github.com/casper-hansen/AutoAWQ_kernels/releases) and then install AutoAWQ from source. +1. Default: + - `pip install autoawq` + - NOTE: The default installation includes no external kernels and relies on Triton for inference. + +2. From main branch with kernels: + - `INSTALL_KERNELS=1 pip install git+https://github.com/casper-hansen/AutoAWQ.git` + - NOTE: This installs https://github.com/casper-hansen/AutoAWQ_kernels ## Usage diff --git a/awq/models/base.py b/awq/models/base.py index 1d376fc0..475e48bb 100644 --- a/awq/models/base.py +++ b/awq/models/base.py @@ -1,6 +1,7 @@ import os import gc import json +import warnings import logging import torch import transformers @@ -30,6 +31,7 @@ get_named_linears, set_op_by_name, exclude_layers_to_not_quantize, + try_import, ) from awq.utils.utils import get_best_device, qbits_available from transformers import ( @@ -530,8 +532,12 @@ def from_quantized( ) # Dispath to devices + awq_ext, msg = try_import("awq_ext") if fuse_layers: - self.fuse_layers(model) + if awq_ext is None: + warnings.warn("Skipping fusing modules because AWQ extension is not installed." + msg) + else: + self.fuse_layers(model) if use_cpu_qbits: dtype = torch.bfloat16 if check_isa_supported("AMX") else torch.float32 diff --git a/awq/modules/linear/exllama.py b/awq/modules/linear/exllama.py index cfdf93aa..350b3264 100644 --- a/awq/modules/linear/exllama.py +++ b/awq/modules/linear/exllama.py @@ -1,15 +1,9 @@ import torch -import warnings import torch.nn as nn +from awq.utils.module import try_import from awq.utils.packing_utils import unpack_reorder_pack -try: - import exl_ext # with CUDA kernels (AutoAWQ_kernels) - - EXL_INSTALLED = True -except Exception as ex: - EXL_INSTALLED = False - warnings.warn(f"AutoAWQ could not load ExLlama kernels extension. 
Details: {ex}") +exl_ext, msg = try_import("exl_ext") # Dummy tensor to pass instead of g_idx since there is no way to pass "None" to a C++ extension none_tensor = torch.empty((1, 1), device="meta") @@ -106,15 +100,8 @@ def forward(self, x): "module.post_init() must be called before module.forward(). " "Use exllama_post_init() on the whole model." ) - assert EXL_INSTALLED, ( - "Exllama kernels could not be loaded. " - "Please install them from https://github.com/casper-hansen/AutoAWQ_kernels" - ) - - assert EXL_INSTALLED, ( - "ExllamaV2 kernels are not installed. " - "Please install AWQ compatible ExllamaV2 kernels from AutoAWQ_kernels." - ) + if exl_ext is None: + raise ModuleNotFoundError("External ExLlama kernels are not properly installed." + msg) input_dtype = x.dtype out_shape = x.shape[:-1] + (self.out_features,) diff --git a/awq/modules/linear/exllamav2.py b/awq/modules/linear/exllamav2.py index c560f549..8e5d0585 100644 --- a/awq/modules/linear/exllamav2.py +++ b/awq/modules/linear/exllamav2.py @@ -2,15 +2,10 @@ import warnings import torch.nn as nn from typing import Dict +from awq.utils.module import try_import from awq.utils.packing_utils import unpack_reorder_pack -try: - import exlv2_ext # with CUDA kernels (AutoAWQ_kernels) - - EXLV2_INSTALLED = True -except Exception as ex: - EXLV2_INSTALLED = False - warnings.warn(f"AutoAWQ could not load ExLlamaV2 kernels extension. Details: {ex}") +exlv2_ext, msg = try_import("exlv2_ext") # Dummy tensor to pass instead of g_idx since there is no way to pass "None" to a C++ extension none_tensor = torch.empty((1, 1), device="meta") @@ -133,10 +128,8 @@ def forward(self, x): "module.post_init() must be called before module.forward(). " "Use exllamav2_post_init() on the whole model." ) - assert EXLV2_INSTALLED, ( - "ExllamaV2 kernels are not installed. " - "Please install AWQ compatible ExllamaV2 kernels from AutoAWQ_kernels." - ) + if exlv2_ext is None: + raise ModuleNotFoundError("External ExLlamaV2 kernels are not properly installed." + msg) input_dtype = x.dtype out_shape = x.shape[:-1] + (self.out_features,) diff --git a/awq/modules/linear/gemm.py b/awq/modules/linear/gemm.py index 63b1c8f3..7ee89cc8 100644 --- a/awq/modules/linear/gemm.py +++ b/awq/modules/linear/gemm.py @@ -2,16 +2,24 @@ import warnings import torch.nn as nn from torch.autograd import Function +from awq.utils.module import try_import from awq.utils.utils import get_best_device from awq.utils.packing_utils import dequantize_gemm +# NOTE: We check if awq_ext or triton is available. awq_ext will be preferred if both are installed. + +awq_ext, msg = try_import("awq_ext") +user_has_been_warned = False + try: - import awq_ext # with CUDA kernels (AutoAWQ_kernels) + from awq.modules.triton.gemm import awq_gemm_triton, awq_dequantize_triton - AWQ_INSTALLED = True -except Exception as ex: - AWQ_INSTALLED = False - warnings.warn(f"AutoAWQ could not load GEMM kernels extension. 
Details: {ex}") + # covers both CUDA and ROCm + if torch.cuda.is_available(): + TRITON_AVAILABLE = True + +except ImportError: + TRITON_AVAILABLE = False # Adapted from https://github.com/compressa-ai/AutoAWQ/tree/dev class WQLinearMMFunction(Function): @@ -35,7 +43,7 @@ def forward( out_shape = x.shape[:-1] + (out_features,) x = x.to(torch.float16) - if AWQ_INSTALLED: + if awq_ext is not None: FP16_MATMUL_HEURISTIC_CONDITION = x.shape[0] * x.shape[1] >= 1024 if FP16_MATMUL_HEURISTIC_CONDITION: @@ -47,7 +55,22 @@ def forward( out = awq_ext.gemm_forward_cuda( x.reshape(-1, x.shape[-1]), qweight, scales, qzeros, 8 ) + + elif TRITON_AVAILABLE: + FP16_MATMUL_HEURISTIC_CONDITION = x.shape[0] * x.shape[1] >= 1024 + + if FP16_MATMUL_HEURISTIC_CONDITION: + out = awq_dequantize_triton(qweight, scales, qzeros) + out = torch.matmul(x, out) + else: + out = awq_gemm_triton( + x.reshape(-1, x.shape[-1]), qweight, scales, qzeros, split_k_iters=8, + ) + else: + if not user_has_been_warned: + warnings.warn("Using naive (slow) implementation." + msg) + user_has_been_warned = True out = dequantize_gemm(qweight, qzeros, scales, w_bit, group_size) out = torch.matmul(x, out) @@ -64,16 +87,21 @@ def forward( def backward(ctx, grad_output): input, qweight, qzeros, scales, bias = ctx.saved_tensors - if not AWQ_INSTALLED: + if awq_ext is None and not TRITON_AVAILABLE: raise ValueError( - "auto-awq kernels is needed to be installed to use `.backward()`. Make sure to install the auto-awq kernels" + "either triton or autoawq-kernels is needed to be installed to use `.backward()`. Make sure to install the auto-awq kernels" " by following the installation guides in https://github.com/casper-hansen/AutoAWQ_kernels" ) - + # Cast to correct dtype for mixed precision training - weights = awq_ext.dequantize_weights_cuda( - qweight, scales, qzeros, 1, 0, 0, False - ).to(grad_output.dtype) + if awq_ext is not None: + weights = awq_ext.dequantize_weights_cuda( + qweight, scales, qzeros, 1, 0, 0, False + ).to(grad_output.dtype) + else: + weights = awq_dequantize_triton( + qweight, scales, qzeros + ).to(grad_output.dtype) if ctx.needs_input_grad[0]: # 3D matmul using torch.bmm: https://pytorch.org/docs/stable/generated/torch.bmm.html#torch.bmm diff --git a/awq/modules/linear/gemm_qbits.py b/awq/modules/linear/gemm_qbits.py index a75f4dd2..126aad29 100644 --- a/awq/modules/linear/gemm_qbits.py +++ b/awq/modules/linear/gemm_qbits.py @@ -1,13 +1,11 @@ import torch import torch.nn as nn +from awq.utils.module import try_import from ...utils.packing_utils import reverse_awq_order, unpack_awq -try: - from intel_extension_for_transformers import qbits # with QBits kernels () - - QBITS_INSTALLED = True -except: - QBITS_INSTALLED = False +intel_extension_for_transformers, msg = try_import("intel_extension_for_transformers") +if intel_extension_for_transformers is not None: + qbits = getattr(intel_extension_for_transformers, 'qbits') BITS_DTYPE_MAPPING = { 4: "int4_clip", @@ -34,8 +32,8 @@ class WQLinear_QBits(nn.Module): def __init__(self, w_bit, group_size, in_features, out_features, bias, zero_point, dev): super().__init__() - assert QBITS_INSTALLED, \ - "Please install ITREX qbits package with `pip install intel-extension-for-transformers`." + if intel_extension_for_transformers is None: + raise ModuleNotFoundError("Please install ITREX qbits package with `pip install intel-extension-for-transformers`." 
+ msg) self.use_bf16 = qbits.check_isa_supported("AMX") @@ -118,10 +116,12 @@ def from_linear(cls, linear, w_bit, group_size, init_only=False, scales=None, ze @torch.no_grad() def forward(self, x): - assert QBITS_INSTALLED, ( - "QBits kernels could not be loaded. " - "Please install with `pip install intel-extension-for-transformers` and " - "refer to the detial https://github.com/intel/intel-extension-for-transformers/blob/main/docs/qbits.md") + if intel_extension_for_transformers is None: + raise ModuleNotFoundError( + "QBits kernels could not be loaded. " + "Please install with `pip install intel-extension-for-transformers` and " + "refer to the detial https://github.com/intel/intel-extension-for-transformers/blob/main/docs/qbits.md" + ) input_dtype = x.dtype out_shape = x.shape[:-1] + (self.out_features,) diff --git a/awq/modules/linear/gemv.py b/awq/modules/linear/gemv.py index 4ecb04dc..8da98e19 100644 --- a/awq/modules/linear/gemv.py +++ b/awq/modules/linear/gemv.py @@ -1,14 +1,9 @@ import torch import warnings import torch.nn as nn +from awq.utils.module import try_import -try: - import awq_ext # with CUDA kernels - - AWQ_INSTALLED = True -except Exception as ex: - AWQ_INSTALLED = False - warnings.warn(f"AutoAWQ could not load GEMV kernels extension. Details: {ex}") +awq_ext, msg = try_import("awq_ext") def make_divisible(c, divisor): return (c + divisor - 1) // divisor @@ -160,10 +155,8 @@ def from_linear( @torch.no_grad() def forward(self, x): - assert AWQ_INSTALLED, ( - "AWQ kernels could not be loaded. " - "Please install them from https://github.com/casper-hansen/AutoAWQ_kernels" - ) + if awq_ext is None: + raise ModuleNotFoundError("External AWQ kernels are not properly installed." + msg) out_shape = x.shape[:-1] + (self.out_features,) inputs = x.reshape(-1, x.shape[-1]) diff --git a/awq/modules/linear/gemv_fast.py b/awq/modules/linear/gemv_fast.py index e227fb1b..56b789a0 100644 --- a/awq/modules/linear/gemv_fast.py +++ b/awq/modules/linear/gemv_fast.py @@ -1,13 +1,8 @@ import torch import warnings +from awq.utils.module import try_import -try: - import awq_v2_ext # with CUDA kernels (AutoAWQ_kernels) - - AWQ_INSTALLED = True -except Exception as ex: - AWQ_INSTALLED = False - warnings.warn(f"AutoAWQ could not load GEMVFast kernels extension. Details: {ex}") +awq_v2_ext, msg = try_import("awq_v2_ext") def make_divisible(c, divisor): return (c + divisor - 1) // divisor @@ -189,6 +184,8 @@ def from_linear( @torch.no_grad() def forward(self, x): + if awq_v2_ext is None: + raise ModuleNotFoundError("External AWQ V2 kernels are not properly installed." + msg) inputs = x batch_size, n_tokens, _ = inputs.shape if batch_size < 8 and n_tokens == 1: diff --git a/awq/modules/linear/marlin.py b/awq/modules/linear/marlin.py index 2db8b7ee..82f24f9b 100644 --- a/awq/modules/linear/marlin.py +++ b/awq/modules/linear/marlin.py @@ -1,14 +1,9 @@ import torch import torch.nn as nn import numpy as np +from awq.utils.module import try_import -try: - import marlin_cuda # with CUDA kernels (AutoAWQ_kernels) - - MARLIN_INSTALLED = True -except: - MARLIN_INSTALLED = False - +marlin_cuda, msg = try_import("marlin_cuda") def _get_perms(): perm = [] @@ -179,10 +174,8 @@ def forward(self, x): "module.post_init() must be called before module.forward(). " "Use marlin_post_init() on the whole model." ) - assert MARLIN_INSTALLED, ( - "Marlin kernels are not installed. " - "Please install AWQ compatible Marlin kernels from AutoAWQ_kernels." 
- ) + if marlin_cuda is None: + raise ModuleNotFoundError("External Marlin kernels are not properly installed." + msg) out_shape = x.shape[:-1] + (self.out_features,) diff --git a/awq/modules/triton/__init__.py b/awq/modules/triton/__init__.py new file mode 100644 index 00000000..e69de29b diff --git a/awq/modules/triton/gemm.py b/awq/modules/triton/gemm.py new file mode 100644 index 00000000..9138e8c4 --- /dev/null +++ b/awq/modules/triton/gemm.py @@ -0,0 +1,351 @@ +# Copied from https://github.com/vllm-project/vllm/blob/main/vllm/model_executor/layers/quantization/awq_triton.py + +# Copyright 2024 The vLLM team. + +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +import torch +import triton +import triton.language as tl + +AWQ_TRITON_SUPPORTED_GROUP_SIZES = [-1, 32, 64, 128] + + +@triton.jit +def awq_dequantize_kernel( + qweight_ptr, # quantized matrix + scales_ptr, # scales, per group + zeros_ptr, # zeros, per group + group_size, # Should always be one of the supported group sizes + result_ptr, # Output matrix + num_cols, # input num cols in qweight + num_rows, # input num rows in qweight + BLOCK_SIZE_X: tl.constexpr, + BLOCK_SIZE_Y: tl.constexpr, +): + # Setup the pids. + pid_x = tl.program_id(axis=0) + pid_y = tl.program_id(axis=1) + + # Compute offsets and masks for qweight_ptr. + offsets_y = pid_y * BLOCK_SIZE_Y + tl.arange(0, BLOCK_SIZE_Y) + offsets_x = pid_x * BLOCK_SIZE_X + tl.arange(0, BLOCK_SIZE_X) + offsets = num_cols * offsets_y[:, None] + offsets_x[None, :] + + masks_y = offsets_y < num_rows + masks_x = offsets_x < num_cols + + masks = masks_y[:, None] & masks_x[None, :] + + # Compute offsets and masks for result output ptr. + result_offsets_y = pid_y * BLOCK_SIZE_Y + tl.arange(0, BLOCK_SIZE_Y) + result_offsets_x = pid_x * BLOCK_SIZE_X * 8 + tl.arange(0, BLOCK_SIZE_X * 8) + result_offsets = ( + 8 * num_cols * result_offsets_y[:, None] + result_offsets_x[None, :] + ) + + result_masks_y = result_offsets_y < num_rows + result_masks_x = result_offsets_x < num_cols * 8 + result_masks = result_masks_y[:, None] & result_masks_x[None, :] + + # Load the weights. + iweights = tl.load(qweight_ptr + offsets, masks) + iweights = tl.interleave(iweights, iweights) + iweights = tl.interleave(iweights, iweights) + iweights = tl.interleave(iweights, iweights) + + # Create reverse AWQ order as tensor: [0, 4, 1, 5, 2, 6, 3, 7] + # that will map given indices to the correct order. + reverse_awq_order_tensor = ( + (tl.arange(0, 2) * 4)[None, :] + tl.arange(0, 4)[:, None] + ).reshape(8) + + # Use this to compute a set of shifts that can be used to unpack and + # reorder the values in iweights and zeros. + shifts = reverse_awq_order_tensor * 4 + shifts = tl.broadcast_to(shifts[None, :], (BLOCK_SIZE_Y * BLOCK_SIZE_X, 8)) + shifts = tl.reshape(shifts, (BLOCK_SIZE_Y, BLOCK_SIZE_X * 8)) + + # Unpack and reorder: shift out the correct 4-bit value and mask. + iweights = (iweights >> shifts) & 0xF + + # Compute zero offsets and masks. 
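+    # The zeros tensor has one row per group and the same 8-values-per-int32 packing as
+    # qweight, so its row index is the qweight row divided by group_size and its columns
+    # match qweight's packed columns.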
+ zero_offsets_y = pid_y * BLOCK_SIZE_Y // group_size + tl.arange(0, 1) + zero_offsets_x = pid_x * BLOCK_SIZE_X + tl.arange(0, BLOCK_SIZE_X) + zero_offsets = num_cols * zero_offsets_y[:, None] + zero_offsets_x[None, :] + + zero_masks_y = zero_offsets_y < num_rows // group_size + zero_masks_x = zero_offsets_x < num_cols + zero_masks = zero_masks_y[:, None] & zero_masks_x[None, :] + + # Load the zeros. + zeros = tl.load(zeros_ptr + zero_offsets, zero_masks) + zeros = tl.interleave(zeros, zeros) + zeros = tl.interleave(zeros, zeros) + zeros = tl.interleave(zeros, zeros) + zeros = tl.broadcast_to(zeros, (BLOCK_SIZE_Y, BLOCK_SIZE_X * 8)) + + # Unpack and reorder: shift out the correct 4-bit value and mask. + zeros = (zeros >> shifts) & 0xF + + # Compute scale offsets and masks. + scale_offsets_y = pid_y * BLOCK_SIZE_Y // group_size + tl.arange(0, 1) + scale_offsets_x = pid_x * BLOCK_SIZE_X * 8 + tl.arange(0, BLOCK_SIZE_X * 8) + scale_offsets = num_cols * 8 * scale_offsets_y[:, None] + scale_offsets_x[None, :] + scale_masks_y = scale_offsets_y < num_rows // group_size + scale_masks_x = scale_offsets_x < num_cols * 8 + scale_masks = scale_masks_y[:, None] & scale_masks_x[None, :] + + # Load the scales. + scales = tl.load(scales_ptr + scale_offsets, scale_masks) + scales = tl.broadcast_to(scales, (BLOCK_SIZE_Y, BLOCK_SIZE_X * 8)) + + # Dequantize. + iweights = (iweights - zeros) * scales + iweights = iweights.to(result_ptr.type.element_ty) + + # Finally, store. + tl.store(result_ptr + result_offsets, iweights, result_masks) + + +@triton.jit +def awq_gemm_kernel( + a_ptr, + b_ptr, + c_ptr, + zeros_ptr, + scales_ptr, + M, + N, + K, + group_size, + BLOCK_SIZE_M: tl.constexpr, + BLOCK_SIZE_N: tl.constexpr, + BLOCK_SIZE_K: tl.constexpr, + SPLIT_K: tl.constexpr, +): + pid = tl.program_id(axis=0) + pid_z = tl.program_id(1) + + # NOTE: This doesn't work in TRITON_INTERPRET=1 mode. Use below instead. + # num_pid_n = (N + BLOCK_SIZE_N - 1) // BLOCK_SIZE_N + num_pid_n = tl.cdiv(N, BLOCK_SIZE_N) + + pid_m = pid // num_pid_n + pid_n = pid % num_pid_n + + accumulator_dtype = c_ptr.type.element_ty + + # NOTE: This doesn't work in TRITON_INTERPRET=1 mode. Use below instead. + # accumulator = tl.arange(0, BLOCK_SIZE_N) + # accumulator = tl.broadcast_to(accumulator[None, :], + # (BLOCK_SIZE_M, BLOCK_SIZE_N)) + # accumulator = accumulator & 0x0 + # accumulator = accumulator.to(accumulator_dtype) + accumulator = tl.zeros((BLOCK_SIZE_M, BLOCK_SIZE_N), dtype=accumulator_dtype) + + # Create reverse AWQ order as tensor: [0, 4, 1, 5, 2, 6, 3, 7] + # that will map given indices to the correct order. + reverse_awq_order_tensor = ( + (tl.arange(0, 2) * 4)[None, :] + tl.arange(0, 4)[:, None] + ).reshape(8) + + # Create the necessary shifts to use to unpack. + shifts = reverse_awq_order_tensor * 4 + shifts = tl.broadcast_to(shifts[None, :], (BLOCK_SIZE_K * (BLOCK_SIZE_N // 8), 8)) + shifts = tl.reshape(shifts, (BLOCK_SIZE_K, BLOCK_SIZE_N)) + + # Offsets and masks. 
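+    # a is the [M, K] input activations; qweight packs eight 4-bit weights per int32,
+    # so b and the zeros are indexed over N // 8 packed columns.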
+ offsets_am = pid_m * BLOCK_SIZE_M + tl.arange(0, BLOCK_SIZE_M) + masks_am = offsets_am < M + + offsets_bn = pid_n * (BLOCK_SIZE_N // 8) + tl.arange(0, BLOCK_SIZE_N // 8) + masks_bn = offsets_bn < N // 8 + + offsets_zn = pid_n * (BLOCK_SIZE_N // 8) + tl.arange(0, BLOCK_SIZE_N // 8) + masks_zn = offsets_zn < N // 8 + + offsets_sn = pid_n * BLOCK_SIZE_N + tl.arange(0, BLOCK_SIZE_N) + masks_sn = offsets_sn < N + + offsets_k = pid_z * BLOCK_SIZE_K + tl.arange(0, BLOCK_SIZE_K) + offsets_a = K * offsets_am[:, None] + offsets_k[None, :] + offsets_b = (N // 8) * offsets_k[:, None] + offsets_bn[None, :] + + a_ptrs = a_ptr + offsets_a + b_ptrs = b_ptr + offsets_b + + # NOTE: Use this in TRITON_INTERPRET=1 mode instead of tl.cdiv + # block_offset = BLOCK_SIZE_K * SPLIT_K + # for k in range(0, (K + block_offset - 1) // (block_offset)): + for k in range(0, tl.cdiv(K, BLOCK_SIZE_K * SPLIT_K)): + masks_k = offsets_k < K + masks_a = masks_am[:, None] & masks_k[None, :] + a = tl.load(a_ptrs, mask=masks_a) + + masks_b = masks_k[:, None] & masks_bn[None, :] + b = tl.load(b_ptrs, mask=masks_b) + b = tl.interleave(b, b) + b = tl.interleave(b, b) + b = tl.interleave(b, b) + + # Dequantize b. + offsets_szk = ( + BLOCK_SIZE_K * SPLIT_K * k + pid_z * BLOCK_SIZE_K + ) // group_size + tl.arange(0, 1) + offsets_z = (N // 8) * offsets_szk[:, None] + offsets_zn[None, :] + masks_zk = offsets_szk < K // group_size + masks_z = masks_zk[:, None] & masks_zn[None, :] + zeros_ptrs = zeros_ptr + offsets_z + zeros = tl.load(zeros_ptrs, mask=masks_z) + zeros = tl.interleave(zeros, zeros) + zeros = tl.interleave(zeros, zeros) + zeros = tl.interleave(zeros, zeros) + zeros = tl.broadcast_to(zeros, (BLOCK_SIZE_K, BLOCK_SIZE_N)) + + offsets_s = N * offsets_szk[:, None] + offsets_sn[None, :] + masks_sk = offsets_szk < K // group_size + masks_s = masks_sk[:, None] & masks_sn[None, :] + scales_ptrs = scales_ptr + offsets_s + scales = tl.load(scales_ptrs, mask=masks_s) + scales = tl.broadcast_to(scales, (BLOCK_SIZE_K, BLOCK_SIZE_N)) + + b = (b >> shifts) & 0xF + zeros = (zeros >> shifts) & 0xF + b = (b - zeros) * scales + b = b.to(c_ptr.type.element_ty) + + # Accumulate results. 
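+        # tl.dot multiplies the [BLOCK_SIZE_M, BLOCK_SIZE_K] tile of a by the dequantized
+        # [BLOCK_SIZE_K, BLOCK_SIZE_N] tile of b and adds the product into the accumulator.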
+ accumulator = tl.dot(a, b, accumulator, out_dtype=accumulator_dtype) + + offsets_k += BLOCK_SIZE_K * SPLIT_K + a_ptrs += BLOCK_SIZE_K * SPLIT_K + b_ptrs += BLOCK_SIZE_K * SPLIT_K * (N // 8) + + c = accumulator.to(c_ptr.type.element_ty) + offs_cm = pid_m * BLOCK_SIZE_M + tl.arange(0, BLOCK_SIZE_M) + offs_cn = pid_n * BLOCK_SIZE_N + tl.arange(0, BLOCK_SIZE_N) + c_ptrs = c_ptr + N * offs_cm[:, None] + offs_cn[None, :] + c_mask = (offs_cm[:, None] < M) & (offs_cn[None, :] < N) + if SPLIT_K == 1: + tl.store(c_ptrs, c, mask=c_mask) + else: + tl.atomic_add(c_ptrs, c, mask=c_mask) + + +# qweights - [K , M // 8], int32 +# scales - [K // G, M ], float16 +# zeros - [K // G, M // 8], int32 +def awq_dequantize_triton( + qweight: torch.Tensor, + scales: torch.Tensor, + zeros: torch.Tensor, + block_size_x: int = 32, + block_size_y: int = 32, +) -> torch.Tensor: + K = qweight.shape[0] + M = scales.shape[1] + group_size = qweight.shape[0] // scales.shape[0] + + assert K > 0 and M > 0 + assert scales.shape[0] == K // group_size and scales.shape[1] == M + assert zeros.shape[0] == K // group_size and zeros.shape[1] == M // 8 + assert group_size <= K + assert group_size in AWQ_TRITON_SUPPORTED_GROUP_SIZES or group_size == K + + # Result tensor: + # number of rows = same as input tensor + # number of cols = 8 x input tensor num cols + result = torch.empty( + qweight.shape[0], + qweight.shape[1] * 8, + device=qweight.device, + dtype=scales.dtype, + ) + + Y = qweight.shape[0] # num rows + X = qweight.shape[1] # num cols + + grid = lambda META: ( + triton.cdiv(X, META["BLOCK_SIZE_X"]), + triton.cdiv(Y, META["BLOCK_SIZE_Y"]), + ) + awq_dequantize_kernel[grid]( + qweight, + scales, + zeros, + group_size, + result, + X, + Y, + BLOCK_SIZE_X=block_size_x, + BLOCK_SIZE_Y=block_size_y, + ) + + return result + + +# input - [M, K] +# qweight - [K, N // 8] +# qzeros - [K // G, N // 8] +# scales - [K // G, N] +# split_k_iters - parallelism along K-dimension, int, power of 2. 
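+# Each program instance along the second grid axis handles one K-slice; when SPLIT_K > 1
+# the partial results are combined into the zero-initialized output with tl.atomic_add.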
+def awq_gemm_triton( + input: torch.Tensor, + qweight: torch.Tensor, + scales: torch.Tensor, + qzeros: torch.Tensor, + split_k_iters: int, + block_size_m: int = 32, + block_size_n: int = 32, + block_size_k: int = 32, +) -> torch.Tensor: + M, K = input.shape + N = qweight.shape[1] * 8 + group_size = qweight.shape[0] // qzeros.shape[0] + + assert N > 0 and K > 0 and M > 0 + assert qweight.shape[0] == K and qweight.shape[1] == N // 8 + assert qzeros.shape[0] == K // group_size and qzeros.shape[1] == N // 8 + assert scales.shape[0] == K // group_size and scales.shape[1] == N + assert split_k_iters & (split_k_iters - 1) == 0 and split_k_iters != 0 + assert split_k_iters <= 32 + assert group_size <= K + assert group_size in AWQ_TRITON_SUPPORTED_GROUP_SIZES or group_size == K + + grid = lambda META: ( + triton.cdiv(M, META["BLOCK_SIZE_M"]) * triton.cdiv(N, META["BLOCK_SIZE_N"]), + split_k_iters, + ) + + result = torch.zeros((M, N), dtype=scales.dtype, device=input.device) + + # A = input, B = qweight, C = result + # A = M x K, B = K x N, C = M x N + awq_gemm_kernel[grid]( + input, + qweight, + result, + qzeros, + scales, + M, + N, + K, + group_size, + BLOCK_SIZE_M=block_size_m, + BLOCK_SIZE_N=block_size_n, + BLOCK_SIZE_K=block_size_k, + SPLIT_K=split_k_iters, + ) + + return result diff --git a/awq/utils/module.py b/awq/utils/module.py index 12fff5c0..b2191f65 100644 --- a/awq/utils/module.py +++ b/awq/utils/module.py @@ -1,5 +1,12 @@ import torch.nn as nn +import importlib +def try_import(module_name): + try: + module = importlib.import_module(module_name) + return module, "" + except Exception as ex: + return None, str(ex) def get_named_linears(module): return {name: m for name, m in module.named_modules() if isinstance(m, nn.Linear)} diff --git a/docs/index.md b/docs/index.md index fa675a31..dedeb807 100644 --- a/docs/index.md +++ b/docs/index.md @@ -33,26 +33,4 @@ Example inference speed (RTX 4090, Ryzen 9 7950X, 64 tokens): ## Supported models -The detailed support list: - -| Models | Sizes | -| -------- | --------------------------- | -| LLaMA-2 | 7B/13B/70B | -| LLaMA | 7B/13B/30B/65B | -| Mistral | 7B | -| Vicuna | 7B/13B | -| MPT | 7B/30B | -| Falcon | 7B/40B | -| OPT | 125m/1.3B/2.7B/6.7B/13B/30B | -| Bloom | 560m/3B/7B/ | -| GPTJ | 6.7B | -| Aquila | 7B | -| Aquila2 | 7B/34B | -| Yi | 6B/34B | -| Qwen | 1.8B/7B/14B/72B | -| BigCode | 1B/7B/15B | -| GPT NeoX | 20B | -| GPT-J | 6B | -| LLaVa | 7B/13B | -| Mixtral | 8x7B | -| Baichuan | 7B/13B | +We support modern LLMs. You can find a list of supported Huggingface `model_types` in `awq/models`. \ No newline at end of file diff --git a/examples/generate.py b/examples/generate.py index 27efe221..803e5b9d 100644 --- a/examples/generate.py +++ b/examples/generate.py @@ -1,42 +1,37 @@ +import torch from awq import AutoAWQForCausalLM -from awq.utils.utils import get_best_device from transformers import AutoTokenizer, TextStreamer - -quant_path = "casperhansen/llama-3-8b-instruct-awq" - -# Load model -if get_best_device() == "cpu": - model = AutoAWQForCausalLM.from_quantized(quant_path, use_qbits=True, fuse_layers=False) -else: - model = AutoAWQForCausalLM.from_quantized(quant_path, fuse_layers=True) -tokenizer = AutoTokenizer.from_pretrained(quant_path, trust_remote_code=True) +model_id = "hugging-quants/Meta-Llama-3.1-8B-Instruct-AWQ-INT4" +tokenizer = AutoTokenizer.from_pretrained(model_id) streamer = TextStreamer(tokenizer, skip_prompt=True, skip_special_tokens=True) -prompt = "You're standing on the surface of the Earth. 
"\ - "You walk one mile south, one mile west and one mile north. "\ - "You end up exactly where you started. Where are you?" - -chat = [ - {"role": "system", "content": "You are a concise assistant that helps answer questions."}, - {"role": "user", "content": prompt}, -] - -terminators = [ - tokenizer.eos_token_id, - tokenizer.convert_tokens_to_ids("<|eot_id|>") -] - -tokens = tokenizer.apply_chat_template( - chat, - return_tensors="pt" +model = AutoAWQForCausalLM.from_quantized( + model_id, + torch_dtype=torch.float16, + low_cpu_mem_usage=True, + device_map="auto", ) -tokens = tokens.to(get_best_device()) -# Generate output -generation_output = model.generate( - tokens, +prompt = [ + {"role": "system", "content": "You are a helpful assistant, that responds as a pirate."}, + {"role": "user", "content": \ + "You're standing on the surface of the Earth. "\ + "You walk one mile south, one mile west and one mile north. "\ + "You end up exactly where you started. Where are you?"}, +] +inputs = tokenizer.apply_chat_template( + prompt, + tokenize=True, + add_generation_prompt=True, + return_tensors="pt", + return_dict=True, +).to("cuda") + +outputs = model.generate( + **inputs, + do_sample=True, + max_new_tokens=256, streamer=streamer, - max_new_tokens=64, - eos_token_id=terminators + eos_token_id=[tokenizer.eos_token_id, tokenizer.convert_tokens_to_ids("<|eot_id|>")] ) diff --git a/setup.py b/setup.py index ffecd001..12a8ce27 100644 --- a/setup.py +++ b/setup.py @@ -1,39 +1,12 @@ import os import torch -import platform -import requests from pathlib import Path from setuptools import setup, find_packages from torch.utils.cpp_extension import CUDAExtension - -def get_latest_kernels_version(repo): - """ - Get the latest version of the kernels from the github repo. - """ - response = requests.get(f"https://api.github.com/repos/{repo}/releases/latest") - data = response.json() - tag_name = data["tag_name"] - version = tag_name.replace("v", "") - return version - - -def get_kernels_whl_url( - gpu_system_version, - release_version, - python_version, - platform, - architecture, -): - """ - Get the url for the kernels wheel file. 
- """ - return f"https://github.com/casper-hansen/AutoAWQ_kernels/releases/download/v{release_version}/autoawq_kernels-{release_version}+{gpu_system_version}-cp{python_version}-cp{python_version}-{platform}_{architecture}.whl" - - AUTOAWQ_VERSION = "0.2.6" PYPI_BUILD = os.getenv("PYPI_BUILD", "0") == "1" -NO_KERNELS = int(os.getenv("NO_KERNELS", "0")) +INSTALL_KERNELS = os.getenv("INSTALL_KERNELS", "0") == "1" IS_CPU_ONLY = not torch.backends.mps.is_available() and not torch.cuda.is_available() CUDA_VERSION = os.getenv("CUDA_VERSION", None) or torch.version.cuda @@ -42,10 +15,8 @@ def get_kernels_whl_url( ROCM_VERSION = os.getenv("ROCM_VERSION", None) or torch.version.hip if ROCM_VERSION: - if ROCM_VERSION.startswith("5.7"): - ROCM_VERSION = "5.7.1" - - ROCM_VERSION = "".join(ROCM_VERSION.split("."))[:3] + ROCM_VERSION_LEN = min(len(ROCM_VERSION.split(".")), 3) + ROCM_VERSION = "".join(ROCM_VERSION.split("."))[:ROCM_VERSION_LEN] if not PYPI_BUILD: if IS_CPU_ONLY: @@ -88,6 +59,7 @@ def get_kernels_whl_url( requirements = [ "torch>=2.3.1", + "triton", "transformers>=4.35.0", "tokenizers>=0.12.1", "typing_extensions>=4.8.0", @@ -97,41 +69,15 @@ def get_kernels_whl_url( ] try: - if ROCM_VERSION: - import exlv2_ext - else: - import awq_ext + import awq_ext KERNELS_INSTALLED = True except ImportError: KERNELS_INSTALLED = False -# kernels can be downloaded from pypi for cuda+121 only -# for everything else, we need to download the wheels from github -if not KERNELS_INSTALLED and (CUDA_VERSION or ROCM_VERSION) and not NO_KERNELS: - if CUDA_VERSION and CUDA_VERSION.startswith("12"): - requirements.append("autoawq-kernels") - elif CUDA_VERSION and CUDA_VERSION.startswith("11") or ROCM_VERSION in ["571"]: - gpu_system_version = ( - f"cu{CUDA_VERSION}" if CUDA_VERSION else f"rocm{ROCM_VERSION}" - ) - kernels_version = get_latest_kernels_version("casper-hansen/AutoAWQ_kernels") - python_version = "".join(platform.python_version_tuple()[:2]) - platform_name = platform.system().lower() - architecture = platform.machine().lower() - latest_rocm_kernels_wheels = get_kernels_whl_url( - gpu_system_version, - kernels_version, - python_version, - platform_name, - architecture, - ) - requirements.append(f"autoawq-kernels@{latest_rocm_kernels_wheels}") - else: - raise RuntimeError( - "Your system have a GPU with an unsupported CUDA or ROCm version. " - "Please install the kernels manually from https://github.com/casper-hansen/AutoAWQ_kernels" - ) +if not KERNELS_INSTALLED and CUDA_VERSION and INSTALL_KERNELS and CUDA_VERSION.startswith("12"): + requirements.append("autoawq-kernels") + elif IS_CPU_ONLY: requirements.append("intel-extension-for-transformers>=1.4.2") @@ -152,7 +98,7 @@ def get_kernels_whl_url( install_requires=requirements, extras_require={ "eval": ["lm_eval==0.4.1", "tabulate", "protobuf", "evaluate", "scipy"], - "dev": ["black", "mkdocstrings-python", "mkdocs-material", "griffe-typingdoc"] + "dev": ["black", "mkdocstrings-python", "mkdocs-material", "griffe-typingdoc"], }, **common_setup_kwargs, )