AWQ Triton kernels. Make autoawq-kernels optional. #608

Merged · 7 commits · Sep 12, 2024

40 changes: 9 additions & 31 deletions README.md
@@ -46,41 +46,19 @@ AutoAWQ is an easy-to-use package for 4-bit quantized models. AutoAWQ speeds up
- Your NVIDIA GPU(s) must be of Compute Capability 7.5 or newer. Turing and later architectures are supported.
- Your CUDA version must be CUDA 11.8 or later.
- AMD:
- Your ROCm version must be ROCm 5.6 or later.
- Your ROCm version must be compatible with Triton.

### Install from PyPi

To install the newest AutoAWQ from PyPi, you need CUDA 12.1 installed.
There are a few ways to install AutoAWQ:

```
pip install autoawq
```

### Build from source

For CUDA 11.8, ROCm 5.6, and ROCm 5.7, you can install wheels from the [release page](https://github.com/casper-hansen/AutoAWQ/releases/latest):

```
pip install autoawq@https://github.com/casper-hansen/AutoAWQ/releases/download/v0.2.0/autoawq-0.2.0+cu118-cp310-cp310-linux_x86_64.whl
```

Or from the main branch directly:

```
pip install autoawq@https://github.com/casper-hansen/AutoAWQ.git
```

Or by cloning the repository and installing from source:

```
git clone https://github.com/casper-hansen/AutoAWQ
cd AutoAWQ
pip install -e .
```

All three methods will install the latest and correct kernels for your system from [AutoAWQ_Kernels](https://github.com/casper-hansen/AutoAWQ_kernels/releases).

If your system is not supported (i.e. not on the release page), you can build the kernels yourself by following the instructions in [AutoAWQ_Kernels](https://github.com/casper-hansen/AutoAWQ_kernels/releases) and then install AutoAWQ from source.
1. Default:
- `pip install autoawq`
- NOTE: The default installation includes no external kernels and relies on Triton for inference.

2. From main branch with kernels:
- `INSTALL_KERNELS=1 pip install git+https://github.com/casper-hansen/AutoAWQ.git`
- NOTE: This installs https://github.com/casper-hansen/AutoAWQ_kernels
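
A quick, illustrative way to see which backend a given environment ends up with is to probe for the optional modules. This is not an official AutoAWQ API; it only relies on the module names that appear in the diffs below (`awq_ext` from autoawq-kernels, `triton` for the default fallback):

```python
# Illustrative check: which GEMM backend will AutoAWQ pick in this environment?
# With the default `pip install autoawq`, awq_ext is absent and Triton is used.
import importlib.util

has_awq_ext = importlib.util.find_spec("awq_ext") is not None  # ships with autoawq-kernels
has_triton = importlib.util.find_spec("triton") is not None    # default fallback backend
print(f"awq_ext (autoawq-kernels): {has_awq_ext}, triton: {has_triton}")
```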

## Usage

8 changes: 7 additions & 1 deletion awq/models/base.py
@@ -1,6 +1,7 @@
import os
import gc
import json
import warnings
import logging
import torch
import transformers
@@ -30,6 +31,7 @@
get_named_linears,
set_op_by_name,
exclude_layers_to_not_quantize,
try_import,
)
from awq.utils.utils import get_best_device, qbits_available
from transformers import (
@@ -530,8 +532,12 @@ def from_quantized(
)

# Dispatch to devices
awq_ext, msg = try_import("awq_ext")
if fuse_layers:
self.fuse_layers(model)
if awq_ext is None:
warnings.warn("Skipping fusing modules because AWQ extension is not installed." + msg)
else:
self.fuse_layers(model)

if use_cpu_qbits:
dtype = torch.bfloat16 if check_isa_supported("AMX") else torch.float32
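
The change above (and the modules below) route every optional import through `try_import` from `awq.utils.module`, whose implementation is not part of this diff. A minimal sketch of what such a helper might look like, assuming it returns the imported module (or `None`) plus a hint message that callers append to their warnings and errors:

```python
# Hypothetical sketch of awq.utils.module.try_import -- the real implementation
# is not shown in this PR. Assumed contract: (module | None, hint_message).
import importlib
from types import ModuleType
from typing import Optional, Tuple

def try_import(module_name: str) -> Tuple[Optional[ModuleType], str]:
    """Return (module, "") on success, or (None, hint) when the import fails."""
    try:
        return importlib.import_module(module_name), ""
    except Exception as ex:  # compiled extensions can fail for many reasons, not just ImportError
        return None, f" Details: failed to import {module_name}: {ex}"
```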
21 changes: 4 additions & 17 deletions awq/modules/linear/exllama.py
@@ -1,15 +1,9 @@
import torch
import warnings
import torch.nn as nn
from awq.utils.module import try_import
from awq.utils.packing_utils import unpack_reorder_pack

try:
import exl_ext # with CUDA kernels (AutoAWQ_kernels)

EXL_INSTALLED = True
except Exception as ex:
EXL_INSTALLED = False
warnings.warn(f"AutoAWQ could not load ExLlama kernels extension. Details: {ex}")
exl_ext, msg = try_import("exl_ext")

# Dummy tensor to pass instead of g_idx since there is no way to pass "None" to a C++ extension
none_tensor = torch.empty((1, 1), device="meta")
@@ -106,15 +100,8 @@ def forward(self, x):
"module.post_init() must be called before module.forward(). "
"Use exllama_post_init() on the whole model."
)
assert EXL_INSTALLED, (
"Exllama kernels could not be loaded. "
"Please install them from https://github.com/casper-hansen/AutoAWQ_kernels"
)

assert EXL_INSTALLED, (
"ExllamaV2 kernels are not installed. "
"Please install AWQ compatible ExllamaV2 kernels from AutoAWQ_kernels."
)
if exl_ext is None:
raise ModuleNotFoundError("External ExLlama kernels are not properly installed." + msg)

input_dtype = x.dtype
out_shape = x.shape[:-1] + (self.out_features,)
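
The same refactor repeats in `exllamav2.py`, `gemv.py`, `gemv_fast.py`, and `marlin.py` below: the module-level try/except and `*_INSTALLED` flag give way to a `try_import` call, and the hard failure moves from import time to `forward()`, so merely importing AutoAWQ no longer warns or asserts when a kernel package is missing. Condensed, the pattern looks roughly like this (an illustrative sketch, not code from the PR):

```python
# Condensed before/after of the pattern applied to each linear module.
from awq.utils.module import try_import  # assumes an AutoAWQ checkout with this PR applied

exl_ext, msg = try_import("exl_ext")  # replaces the module-level try/except + EXL_INSTALLED flag

def forward_guard():
    """Mirrors the check each forward() now performs instead of an import-time assert."""
    if exl_ext is None:
        raise ModuleNotFoundError("External ExLlama kernels are not properly installed." + msg)
```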
15 changes: 4 additions & 11 deletions awq/modules/linear/exllamav2.py
@@ -2,15 +2,10 @@
import warnings
import torch.nn as nn
from typing import Dict
from awq.utils.module import try_import
from awq.utils.packing_utils import unpack_reorder_pack

try:
import exlv2_ext # with CUDA kernels (AutoAWQ_kernels)

EXLV2_INSTALLED = True
except Exception as ex:
EXLV2_INSTALLED = False
warnings.warn(f"AutoAWQ could not load ExLlamaV2 kernels extension. Details: {ex}")
exlv2_ext, msg = try_import("exlv2_ext")

# Dummy tensor to pass instead of g_idx since there is no way to pass "None" to a C++ extension
none_tensor = torch.empty((1, 1), device="meta")
@@ -133,10 +128,8 @@ def forward(self, x):
"module.post_init() must be called before module.forward(). "
"Use exllamav2_post_init() on the whole model."
)
assert EXLV2_INSTALLED, (
"ExllamaV2 kernels are not installed. "
"Please install AWQ compatible ExllamaV2 kernels from AutoAWQ_kernels."
)
if exlv2_ext is None:
raise ModuleNotFoundError("External ExLlamaV2 kernels are not properly installed." + msg)

input_dtype = x.dtype
out_shape = x.shape[:-1] + (self.out_features,)
52 changes: 40 additions & 12 deletions awq/modules/linear/gemm.py
@@ -2,16 +2,24 @@
import warnings
import torch.nn as nn
from torch.autograd import Function
from awq.utils.module import try_import
from awq.utils.utils import get_best_device
from awq.utils.packing_utils import dequantize_gemm

# NOTE: We check if awq_ext or triton is available. awq_ext will be preferred if both are installed.

awq_ext, msg = try_import("awq_ext")
user_has_been_warned = False

try:
import awq_ext # with CUDA kernels (AutoAWQ_kernels)
from awq.modules.triton.gemm import awq_gemm_triton, awq_dequantize_triton

AWQ_INSTALLED = True
except Exception as ex:
AWQ_INSTALLED = False
warnings.warn(f"AutoAWQ could not load GEMM kernels extension. Details: {ex}")
# covers both CUDA and ROCm
if torch.cuda.is_available():
TRITON_AVAILABLE = True

except ImportError:
TRITON_AVAILABLE = False

# Adapted from https://github.com/compressa-ai/AutoAWQ/tree/dev
class WQLinearMMFunction(Function):
@@ -35,7 +43,7 @@ def forward(
out_shape = x.shape[:-1] + (out_features,)
x = x.to(torch.float16)

if AWQ_INSTALLED:
if awq_ext is not None:
FP16_MATMUL_HEURISTIC_CONDITION = x.shape[0] * x.shape[1] >= 1024

if FP16_MATMUL_HEURISTIC_CONDITION:
@@ -47,7 +55,22 @@ def forward(
out = awq_ext.gemm_forward_cuda(
x.reshape(-1, x.shape[-1]), qweight, scales, qzeros, 8
)

elif TRITON_AVAILABLE:
FP16_MATMUL_HEURISTIC_CONDITION = x.shape[0] * x.shape[1] >= 1024

if FP16_MATMUL_HEURISTIC_CONDITION:
out = awq_dequantize_triton(qweight, scales, qzeros)
out = torch.matmul(x, out)
else:
out = awq_gemm_triton(
x.reshape(-1, x.shape[-1]), qweight, scales, qzeros, split_k_iters=8,
)

else:
global user_has_been_warned  # update the module-level flag; otherwise this name would be treated as a new local
if not user_has_been_warned:
warnings.warn("Using naive (slow) implementation." + msg)
user_has_been_warned = True
out = dequantize_gemm(qweight, qzeros, scales, w_bit, group_size)
out = torch.matmul(x, out)

@@ -64,16 +87,21 @@ def forward(
def backward(ctx, grad_output):
input, qweight, qzeros, scales, bias = ctx.saved_tensors

if not AWQ_INSTALLED:
if awq_ext is None and not TRITON_AVAILABLE:
raise ValueError(
"auto-awq kernels is needed to be installed to use `.backward()`. Make sure to install the auto-awq kernels"
"either triton or autoawq-kernels is needed to be installed to use `.backward()`. Make sure to install the auto-awq kernels"
" by following the installation guides in https://github.com/casper-hansen/AutoAWQ_kernels"
)

# Cast to correct dtype for mixed precision training
weights = awq_ext.dequantize_weights_cuda(
qweight, scales, qzeros, 1, 0, 0, False
).to(grad_output.dtype)
if awq_ext is not None:
weights = awq_ext.dequantize_weights_cuda(
qweight, scales, qzeros, 1, 0, 0, False
).to(grad_output.dtype)
else:
weights = awq_dequantize_triton(
qweight, scales, qzeros
).to(grad_output.dtype)

if ctx.needs_input_grad[0]:
# 3D matmul using torch.bmm: https://pytorch.org/docs/stable/generated/torch.bmm.html#torch.bmm
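
The new `forward` keeps the existing size heuristic and applies it to both backends: workloads with at least 1024 input rows (`batch * seq_len >= 1024`) are dequantized to FP16 and handed to a dense matmul, smaller ones use the fused 4-bit GEMM, and if neither `awq_ext` nor Triton is available it falls back to a naive dequantize-and-matmul with a one-time warning. A sketch of that dispatch order, with the backend calls summarized as labels rather than real signatures:

```python
# Sketch of the dispatch order above; only the heuristic and the preference
# order come from the diff, the returned labels are illustrative.
def pick_gemm_path(x, awq_ext, triton_available):
    large = x.shape[0] * x.shape[1] >= 1024  # FP16_MATMUL_HEURISTIC_CONDITION

    if awq_ext is not None:        # CUDA kernels from autoawq-kernels, preferred
        return "awq_ext: dequantize + fp16 matmul" if large else "awq_ext: fused 4-bit GEMM"
    if triton_available:           # default install, covers CUDA and ROCm
        return "triton: dequantize + fp16 matmul" if large else "triton: fused 4-bit GEMM"
    return "naive: dequantize_gemm + torch.matmul (slow, warns once)"
```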
24 changes: 12 additions & 12 deletions awq/modules/linear/gemm_qbits.py
@@ -1,13 +1,11 @@
import torch
import torch.nn as nn
from awq.utils.module import try_import
from ...utils.packing_utils import reverse_awq_order, unpack_awq

try:
from intel_extension_for_transformers import qbits # with QBits kernels ()

QBITS_INSTALLED = True
except:
QBITS_INSTALLED = False
intel_extension_for_transformers, msg = try_import("intel_extension_for_transformers")
if intel_extension_for_transformers is not None:
qbits = getattr(intel_extension_for_transformers, 'qbits')

BITS_DTYPE_MAPPING = {
4: "int4_clip",
@@ -34,8 +32,8 @@ class WQLinear_QBits(nn.Module):

def __init__(self, w_bit, group_size, in_features, out_features, bias, zero_point, dev):
super().__init__()
assert QBITS_INSTALLED, \
"Please install ITREX qbits package with `pip install intel-extension-for-transformers`."
if intel_extension_for_transformers is None:
raise ModuleNotFoundError("Please install ITREX qbits package with `pip install intel-extension-for-transformers`." + msg)

self.use_bf16 = qbits.check_isa_supported("AMX")

@@ -118,10 +116,12 @@ def from_linear(cls, linear, w_bit, group_size, init_only=False, scales=None, ze

@torch.no_grad()
def forward(self, x):
assert QBITS_INSTALLED, (
"QBits kernels could not be loaded. "
"Please install with `pip install intel-extension-for-transformers` and "
"refer to the detial https://github.com/intel/intel-extension-for-transformers/blob/main/docs/qbits.md")
if intel_extension_for_transformers is None:
raise ModuleNotFoundError(
"QBits kernels could not be loaded. "
"Please install with `pip install intel-extension-for-transformers` and "
"refer to the detial https://github.com/intel/intel-extension-for-transformers/blob/main/docs/qbits.md"
)

input_dtype = x.dtype
out_shape = x.shape[:-1] + (self.out_features,)
15 changes: 4 additions & 11 deletions awq/modules/linear/gemv.py
@@ -1,14 +1,9 @@
import torch
import warnings
import torch.nn as nn
from awq.utils.module import try_import

try:
import awq_ext # with CUDA kernels

AWQ_INSTALLED = True
except Exception as ex:
AWQ_INSTALLED = False
warnings.warn(f"AutoAWQ could not load GEMV kernels extension. Details: {ex}")
awq_ext, msg = try_import("awq_ext")

def make_divisible(c, divisor):
return (c + divisor - 1) // divisor
@@ -160,10 +155,8 @@ def from_linear(

@torch.no_grad()
def forward(self, x):
assert AWQ_INSTALLED, (
"AWQ kernels could not be loaded. "
"Please install them from https://github.com/casper-hansen/AutoAWQ_kernels"
)
if awq_ext is None:
raise ModuleNotFoundError("External AWQ kernels are not properly installed." + msg)

out_shape = x.shape[:-1] + (self.out_features,)
inputs = x.reshape(-1, x.shape[-1])
11 changes: 4 additions & 7 deletions awq/modules/linear/gemv_fast.py
@@ -1,13 +1,8 @@
import torch
import warnings
from awq.utils.module import try_import

try:
import awq_v2_ext # with CUDA kernels (AutoAWQ_kernels)

AWQ_INSTALLED = True
except Exception as ex:
AWQ_INSTALLED = False
warnings.warn(f"AutoAWQ could not load GEMVFast kernels extension. Details: {ex}")
awq_v2_ext, msg = try_import("awq_v2_ext")

def make_divisible(c, divisor):
return (c + divisor - 1) // divisor
@@ -189,6 +184,8 @@ def from_linear(

@torch.no_grad()
def forward(self, x):
if awq_v2_ext is None:
raise ModuleNotFoundError("External AWQ V2 kernels are not properly installed." + msg)
inputs = x
batch_size, n_tokens, _ = inputs.shape
if batch_size < 8 and n_tokens == 1:
15 changes: 4 additions & 11 deletions awq/modules/linear/marlin.py
@@ -1,14 +1,9 @@
import torch
import torch.nn as nn
import numpy as np
from awq.utils.module import try_import

try:
import marlin_cuda # with CUDA kernels (AutoAWQ_kernels)

MARLIN_INSTALLED = True
except:
MARLIN_INSTALLED = False

marlin_cuda, msg = try_import("marlin_cuda")

def _get_perms():
perm = []
@@ -179,10 +174,8 @@ def forward(self, x):
"module.post_init() must be called before module.forward(). "
"Use marlin_post_init() on the whole model."
)
assert MARLIN_INSTALLED, (
"Marlin kernels are not installed. "
"Please install AWQ compatible Marlin kernels from AutoAWQ_kernels."
)
if marlin_cuda is None:
raise ModuleNotFoundError("External Marlin kernels are not properly installed." + msg)

out_shape = x.shape[:-1] + (self.out_features,)

Empty file added awq/modules/triton/__init__.py