Commit ae77736

AWQ Triton kernels. Make autoawq-kernels optional. (#608)
1 parent 8d903b2 commit ae77736

15 files changed, +484 −232 lines

README.md

+9 −31

@@ -46,41 +46,19 @@ AutoAWQ is an easy-to-use package for 4-bit quantized models. AutoAWQ speeds up
   - Your NVIDIA GPU(s) must be of Compute Capability 7.5. Turing and later architectures are supported.
   - Your CUDA version must be CUDA 11.8 or later.
 - AMD:
-  - Your ROCm version must be ROCm 5.6 or later.
+  - Your ROCm version must be compatible with Triton.

 ### Install from PyPi

-To install the newest AutoAWQ from PyPi, you need CUDA 12.1 installed.
+There are a few ways to install AutoAWQ:

-```
-pip install autoawq
-```
-
-### Build from source
-
-For CUDA 11.8, ROCm 5.6, and ROCm 5.7, you can install wheels from the [release page](https://github.com/casper-hansen/AutoAWQ/releases/latest):
-
-```
-pip install autoawq@https://github.com/casper-hansen/AutoAWQ/releases/download/v0.2.0/autoawq-0.2.0+cu118-cp310-cp310-linux_x86_64.whl
-```
-
-Or from the main branch directly:
-
-```
-pip install autoawq@https://github.com/casper-hansen/AutoAWQ.git
-```
-
-Or by cloning the repository and installing from source:
-
-```
-git clone https://github.com/casper-hansen/AutoAWQ
-cd AutoAWQ
-pip install -e .
-```
-
-All three methods will install the latest and correct kernels for your system from [AutoAWQ_Kernels](https://github.com/casper-hansen/AutoAWQ_kernels/releases).
-
-If your system is not supported (i.e. not on the release page), you can build the kernels yourself by following the instructions in [AutoAWQ_Kernels](https://github.com/casper-hansen/AutoAWQ_kernels/releases) and then install AutoAWQ from source.
+1. Default:
+  - `pip install autoawq`
+  - NOTE: The default installation includes no external kernels and relies on Triton for inference.
+
+2. From main branch with kernels:
+  - `INSTALL_KERNELS=1 pip install git+https://github.com/casper-hansen/AutoAWQ.git`
+  - NOTE: This installs https://github.com/casper-hansen/AutoAWQ_kernels

 ## Usage

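Since the default `pip install autoawq` path now leans on Triton rather than prebuilt kernels, a quick sanity check that the Triton fallback is usable on a given machine can look like the following sketch (not part of this commit; it only mirrors the availability checks introduced in `awq/modules/linear/gemm.py` further below):

```python
# Sketch: check whether the Triton fallback assumed by the new default install is usable.
# Mirrors the gating in gemm.py (Triton import + torch.cuda.is_available(), covering CUDA and ROCm).
import torch

try:
    import triton  # noqa: F401  # ships with recent PyTorch CUDA/ROCm wheels
    triton_usable = torch.cuda.is_available()
except ImportError:
    triton_usable = False

print("Triton fallback usable:", triton_usable)
```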

awq/models/base.py

+7 −1

@@ -1,6 +1,7 @@
 import os
 import gc
 import json
+import warnings
 import logging
 import torch
 import transformers
@@ -30,6 +31,7 @@
     get_named_linears,
     set_op_by_name,
     exclude_layers_to_not_quantize,
+    try_import,
 )
 from awq.utils.utils import get_best_device, qbits_available
 from transformers import (
@@ -530,8 +532,12 @@ def from_quantized(
         )

         # Dispath to devices
+        awq_ext, msg = try_import("awq_ext")
         if fuse_layers:
-            self.fuse_layers(model)
+            if awq_ext is None:
+                warnings.warn("Skipping fusing modules because AWQ extension is not installed." + msg)
+            else:
+                self.fuse_layers(model)

         if use_cpu_qbits:
             dtype = torch.bfloat16 if check_isa_supported("AMX") else torch.float32
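All of the changed modules import a `try_import` helper from `awq.utils.module`; the helper itself is not shown in this excerpt. Based on how it is used (`module, msg = try_import("name")`, with `msg` appended to warnings and errors), a minimal sketch of such a helper could be:

```python
# Hypothetical sketch of the try_import helper assumed by these diffs.
# The real implementation lives in awq/utils/module.py and may differ in detail.
import importlib

def try_import(module_name: str):
    """Return (module, "") on success, or (None, an error message) if the import fails."""
    try:
        module = importlib.import_module(module_name)
        return module, ""
    except Exception as ex:  # missing package, ABI mismatch, etc.
        return None, f" Details: {ex}"
```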

awq/modules/linear/exllama.py

+4 −17

@@ -1,15 +1,9 @@
 import torch
-import warnings
 import torch.nn as nn
+from awq.utils.module import try_import
 from awq.utils.packing_utils import unpack_reorder_pack

-try:
-    import exl_ext  # with CUDA kernels (AutoAWQ_kernels)
-
-    EXL_INSTALLED = True
-except Exception as ex:
-    EXL_INSTALLED = False
-    warnings.warn(f"AutoAWQ could not load ExLlama kernels extension. Details: {ex}")
+exl_ext, msg = try_import("exl_ext")

 # Dummy tensor to pass instead of g_idx since there is no way to pass "None" to a C++ extension
 none_tensor = torch.empty((1, 1), device="meta")
@@ -106,15 +100,8 @@ def forward(self, x):
             "module.post_init() must be called before module.forward(). "
             "Use exllama_post_init() on the whole model."
         )
-        assert EXL_INSTALLED, (
-            "Exllama kernels could not be loaded. "
-            "Please install them from https://github.com/casper-hansen/AutoAWQ_kernels"
-        )
-
-        assert EXL_INSTALLED, (
-            "ExllamaV2 kernels are not installed. "
-            "Please install AWQ compatible ExllamaV2 kernels from AutoAWQ_kernels."
-        )
+        if exl_ext is None:
+            raise ModuleNotFoundError("External ExLlama kernels are not properly installed." + msg)

         input_dtype = x.dtype
         out_shape = x.shape[:-1] + (self.out_features,)
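With this pattern (repeated in the other linear modules below), a missing kernel package no longer triggers an import-time warning or assertion; it surfaces as a `ModuleNotFoundError` only when the affected layer is actually executed. A hedged sketch, not part of the commit, of how calling code might turn that into a friendlier message:

```python
# Sketch: converting the deferred kernel error into an actionable message.
import torch

def run_layer(layer: torch.nn.Module, x: torch.Tensor) -> torch.Tensor:
    try:
        return layer(x)
    except ModuleNotFoundError as err:
        # Raised by the ExLlama/ExLlamaV2/GEMV/Marlin forward() paths when their extension is absent.
        raise RuntimeError(
            "Optional AWQ kernels are missing. Install them with: "
            "INSTALL_KERNELS=1 pip install git+https://github.com/casper-hansen/AutoAWQ.git"
        ) from err
```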

awq/modules/linear/exllamav2.py

+4 −11

@@ -2,15 +2,10 @@
 import warnings
 import torch.nn as nn
 from typing import Dict
+from awq.utils.module import try_import
 from awq.utils.packing_utils import unpack_reorder_pack

-try:
-    import exlv2_ext  # with CUDA kernels (AutoAWQ_kernels)
-
-    EXLV2_INSTALLED = True
-except Exception as ex:
-    EXLV2_INSTALLED = False
-    warnings.warn(f"AutoAWQ could not load ExLlamaV2 kernels extension. Details: {ex}")
+exlv2_ext, msg = try_import("exlv2_ext")

 # Dummy tensor to pass instead of g_idx since there is no way to pass "None" to a C++ extension
 none_tensor = torch.empty((1, 1), device="meta")
@@ -133,10 +128,8 @@ def forward(self, x):
             "module.post_init() must be called before module.forward(). "
             "Use exllamav2_post_init() on the whole model."
         )
-        assert EXLV2_INSTALLED, (
-            "ExllamaV2 kernels are not installed. "
-            "Please install AWQ compatible ExllamaV2 kernels from AutoAWQ_kernels."
-        )
+        if exlv2_ext is None:
+            raise ModuleNotFoundError("External ExLlamaV2 kernels are not properly installed." + msg)

         input_dtype = x.dtype
         out_shape = x.shape[:-1] + (self.out_features,)

awq/modules/linear/gemm.py

+40 −12

@@ -2,16 +2,24 @@
 import warnings
 import torch.nn as nn
 from torch.autograd import Function
+from awq.utils.module import try_import
 from awq.utils.utils import get_best_device
 from awq.utils.packing_utils import dequantize_gemm

+# NOTE: We check if awq_ext or triton is available. awq_ext will be preferred if both are installed.
+
+awq_ext, msg = try_import("awq_ext")
+user_has_been_warned = False
+
 try:
-    import awq_ext  # with CUDA kernels (AutoAWQ_kernels)
+    from awq.modules.triton.gemm import awq_gemm_triton, awq_dequantize_triton

-    AWQ_INSTALLED = True
-except Exception as ex:
-    AWQ_INSTALLED = False
-    warnings.warn(f"AutoAWQ could not load GEMM kernels extension. Details: {ex}")
+    # covers both CUDA and ROCm
+    if torch.cuda.is_available():
+        TRITON_AVAILABLE = True
+
+except ImportError:
+    TRITON_AVAILABLE = False

 # Adapted from https://github.com/compressa-ai/AutoAWQ/tree/dev
 class WQLinearMMFunction(Function):
@@ -35,7 +43,7 @@ def forward(
         out_shape = x.shape[:-1] + (out_features,)
         x = x.to(torch.float16)

-        if AWQ_INSTALLED:
+        if awq_ext is not None:
             FP16_MATMUL_HEURISTIC_CONDITION = x.shape[0] * x.shape[1] >= 1024

             if FP16_MATMUL_HEURISTIC_CONDITION:
@@ -47,7 +55,22 @@ def forward(
                 out = awq_ext.gemm_forward_cuda(
                     x.reshape(-1, x.shape[-1]), qweight, scales, qzeros, 8
                 )
+
+        elif TRITON_AVAILABLE:
+            FP16_MATMUL_HEURISTIC_CONDITION = x.shape[0] * x.shape[1] >= 1024
+
+            if FP16_MATMUL_HEURISTIC_CONDITION:
+                out = awq_dequantize_triton(qweight, scales, qzeros)
+                out = torch.matmul(x, out)
+            else:
+                out = awq_gemm_triton(
+                    x.reshape(-1, x.shape[-1]), qweight, scales, qzeros, split_k_iters=8,
+                )
+
         else:
+            if not user_has_been_warned:
+                warnings.warn("Using naive (slow) implementation." + msg)
+                user_has_been_warned = True
             out = dequantize_gemm(qweight, qzeros, scales, w_bit, group_size)
             out = torch.matmul(x, out)

@@ -64,16 +87,21 @@ def forward(
     def backward(ctx, grad_output):
         input, qweight, qzeros, scales, bias = ctx.saved_tensors

-        if not AWQ_INSTALLED:
+        if awq_ext is None and not TRITON_AVAILABLE:
             raise ValueError(
-                "auto-awq kernels is needed to be installed to use `.backward()`. Make sure to install the auto-awq kernels"
+                "either triton or autoawq-kernels is needed to be installed to use `.backward()`. Make sure to install the auto-awq kernels"
                 " by following the installation guides in https://github.com/casper-hansen/AutoAWQ_kernels"
             )
-
+
         # Cast to correct dtype for mixed precision training
-        weights = awq_ext.dequantize_weights_cuda(
-            qweight, scales, qzeros, 1, 0, 0, False
-        ).to(grad_output.dtype)
+        if awq_ext is not None:
+            weights = awq_ext.dequantize_weights_cuda(
+                qweight, scales, qzeros, 1, 0, 0, False
+            ).to(grad_output.dtype)
+        else:
+            weights = awq_dequantize_triton(
+                qweight, scales, qzeros
+            ).to(grad_output.dtype)

         if ctx.needs_input_grad[0]:
             # 3D matmul using torch.bmm: https://pytorch.org/docs/stable/generated/torch.bmm.html#torch.bmm
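The GEMM forward pass now has three tiers: `awq_ext` if the external kernels are installed, then the Triton kernels, then a naive dequantize-and-matmul fallback with a one-time warning. A small sketch (module names as in the diff; not part of the commit) that reports which tier a given environment will hit:

```python
# Sketch: report which AWQ GEMM backend this environment will use after the commit.
import torch
from awq.utils.module import try_import

awq_ext, msg = try_import("awq_ext")

if awq_ext is not None:
    backend = "awq_ext (external CUDA/ROCm kernels, preferred)"
else:
    try:
        from awq.modules.triton.gemm import awq_gemm_triton  # noqa: F401
        backend = "Triton" if torch.cuda.is_available() else "naive (Triton present but no GPU)"
    except ImportError:
        backend = "naive (slow) dequantize + matmul" + msg

print("AWQ GEMM backend:", backend)
```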

awq/modules/linear/gemm_qbits.py

+12 −12

@@ -1,13 +1,11 @@
 import torch
 import torch.nn as nn
+from awq.utils.module import try_import
 from ...utils.packing_utils import reverse_awq_order, unpack_awq

-try:
-    from intel_extension_for_transformers import qbits  # with QBits kernels ()
-
-    QBITS_INSTALLED = True
-except:
-    QBITS_INSTALLED = False
+intel_extension_for_transformers, msg = try_import("intel_extension_for_transformers")
+if intel_extension_for_transformers is not None:
+    qbits = getattr(intel_extension_for_transformers, 'qbits')

 BITS_DTYPE_MAPPING = {
     4: "int4_clip",
@@ -34,8 +32,8 @@ class WQLinear_QBits(nn.Module):

     def __init__(self, w_bit, group_size, in_features, out_features, bias, zero_point, dev):
         super().__init__()
-        assert QBITS_INSTALLED, \
-            "Please install ITREX qbits package with `pip install intel-extension-for-transformers`."
+        if intel_extension_for_transformers is None:
+            raise ModuleNotFoundError("Please install ITREX qbits package with `pip install intel-extension-for-transformers`." + msg)

         self.use_bf16 = qbits.check_isa_supported("AMX")

@@ -118,10 +116,12 @@ def from_linear(cls, linear, w_bit, group_size, init_only=False, scales=None, ze
 
     @torch.no_grad()
     def forward(self, x):
-        assert QBITS_INSTALLED, (
-            "QBits kernels could not be loaded. "
-            "Please install with `pip install intel-extension-for-transformers` and "
-            "refer to the detial https://github.com/intel/intel-extension-for-transformers/blob/main/docs/qbits.md")
+        if intel_extension_for_transformers is None:
+            raise ModuleNotFoundError(
+                "QBits kernels could not be loaded. "
+                "Please install with `pip install intel-extension-for-transformers` and "
+                "refer to the detial https://github.com/intel/intel-extension-for-transformers/blob/main/docs/qbits.md"
+            )

         input_dtype = x.dtype
         out_shape = x.shape[:-1] + (self.out_features,)
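The QBits (CPU) path follows the same lazy pattern. A minimal availability check, sketched under the assumption that `try_import` behaves as outlined earlier:

```python
# Sketch: confirm the ITREX QBits CPU path is usable before loading a model on CPU.
from awq.utils.module import try_import

itrex, msg = try_import("intel_extension_for_transformers")
if itrex is None:
    print("QBits unavailable." + msg)
    print("Install with: pip install intel-extension-for-transformers")
else:
    qbits = getattr(itrex, "qbits")  # same access pattern as the diff above
    print("AMX supported:", qbits.check_isa_supported("AMX"))
```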

awq/modules/linear/gemv.py

+4 −11

@@ -1,14 +1,9 @@
 import torch
 import warnings
 import torch.nn as nn
+from awq.utils.module import try_import

-try:
-    import awq_ext  # with CUDA kernels
-
-    AWQ_INSTALLED = True
-except Exception as ex:
-    AWQ_INSTALLED = False
-    warnings.warn(f"AutoAWQ could not load GEMV kernels extension. Details: {ex}")
+awq_ext, msg = try_import("awq_ext")

 def make_divisible(c, divisor):
     return (c + divisor - 1) // divisor
@@ -160,10 +155,8 @@ def from_linear(

     @torch.no_grad()
     def forward(self, x):
-        assert AWQ_INSTALLED, (
-            "AWQ kernels could not be loaded. "
-            "Please install them from https://github.com/casper-hansen/AutoAWQ_kernels"
-        )
+        if awq_ext is None:
+            raise ModuleNotFoundError("External AWQ kernels are not properly installed." + msg)

         out_shape = x.shape[:-1] + (self.out_features,)
         inputs = x.reshape(-1, x.shape[-1])

awq/modules/linear/gemv_fast.py

+4 −7

@@ -1,13 +1,8 @@
 import torch
 import warnings
+from awq.utils.module import try_import

-try:
-    import awq_v2_ext  # with CUDA kernels (AutoAWQ_kernels)
-
-    AWQ_INSTALLED = True
-except Exception as ex:
-    AWQ_INSTALLED = False
-    warnings.warn(f"AutoAWQ could not load GEMVFast kernels extension. Details: {ex}")
+awq_v2_ext, msg = try_import("awq_v2_ext")

 def make_divisible(c, divisor):
     return (c + divisor - 1) // divisor
@@ -189,6 +184,8 @@ def from_linear(

     @torch.no_grad()
     def forward(self, x):
+        if awq_v2_ext is None:
+            raise ModuleNotFoundError("External AWQ V2 kernels are not properly installed." + msg)
         inputs = x
         batch_size, n_tokens, _ = inputs.shape
         if batch_size < 8 and n_tokens == 1:

awq/modules/linear/marlin.py

+4 −11

@@ -1,14 +1,9 @@
 import torch
 import torch.nn as nn
 import numpy as np
+from awq.utils.module import try_import

-try:
-    import marlin_cuda  # with CUDA kernels (AutoAWQ_kernels)
-
-    MARLIN_INSTALLED = True
-except:
-    MARLIN_INSTALLED = False
-
+marlin_cuda, msg = try_import("marlin_cuda")

 def _get_perms():
     perm = []
@@ -179,10 +174,8 @@ def forward(self, x):
             "module.post_init() must be called before module.forward(). "
             "Use marlin_post_init() on the whole model."
         )
-        assert MARLIN_INSTALLED, (
-            "Marlin kernels are not installed. "
-            "Please install AWQ compatible Marlin kernels from AutoAWQ_kernels."
-        )
+        if marlin_cuda is None:
+            raise ModuleNotFoundError("External Marlin kernels are not properly installed." + msg)

         out_shape = x.shape[:-1] + (self.out_features,)

awq/modules/triton/__init__.py

Whitespace-only changes.
