
Commit d2537f1

Enable triton on XPU devices (#695)
1 parent 9affc3e commit d2537f1

5 files changed, +20 -7 lines changed

README.md

+1 line changed

@@ -49,6 +49,7 @@ AutoAWQ is an easy-to-use package for 4-bit quantized models. AutoAWQ speeds up
   - Your ROCm version must be compatible with Triton.
 - Intel CPU and Intel GPU:
   - Your torch and intel_extension_for_pytorch package version should at least 2.4 for optimized performance.
+  - Alternatively, you can rely on triton kernels for GPU, then you'll need to install [intel-xpu-backend-for-triton](https://github.com/intel/intel-xpu-backend-for-triton) along with compatible torch and transformers. Easiest way is to use [pre-built wheels](https://github.com/intel/intel-xpu-backend-for-triton?tab=readme-ov-file#install-pytorch-and-triton-from-nightly-wheels).
 
 ### Install from PyPi
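
If you take the Triton route on an Intel GPU, a quick post-install check along these lines (not part of the commit; torch.xpu is the PyTorch 2.4+ XPU namespace) confirms both pieces are visible:

```python
# Hedged post-install sanity check, not from the commit: verify that triton
# imports and that PyTorch can see an Intel GPU before relying on the Triton path.
import torch

try:
    import triton  # provided by intel-xpu-backend-for-triton on XPU setups
    triton_ok = True
except ImportError:
    triton_ok = False

print("triton importable:", triton_ok)
print("xpu available:", hasattr(torch, "xpu") and torch.xpu.is_available())
```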

awq/models/base.py

+3 -2 lines changed

@@ -29,7 +29,7 @@
     exclude_layers_to_not_quantize,
     try_import,
 )
-from awq.utils.utils import get_best_device, ipex_available
+from awq.utils.utils import get_best_device, ipex_available, triton_available
 from transformers import (
     AutoConfig,
     PreTrainedModel,
@@ -499,7 +499,8 @@ def from_quantized(
         )
 
         best_device = get_best_device()
-        use_ipex = use_ipex or best_device in ["cpu", "xpu:0"]
+        if best_device == "cpu" or (best_device == "xpu:0" and not triton_available):
+            use_ipex = True
         if use_ipex and not ipex_available:
             raise ImportError(
                 "Please install intel_extension_for_pytorch with "

awq/modules/linear/gemm.py

+2 -3 lines changed

@@ -14,9 +14,8 @@
 try:
     from awq.modules.triton.gemm import awq_gemm_triton, awq_dequantize_triton
 
-    # covers both CUDA and ROCm
-    if torch.cuda.is_available():
-        TRITON_AVAILABLE = True
+    # covers CUDA, ROCm and XPU. If we can import triton, then we can use it.
+    TRITON_AVAILABLE = True
 
 except ImportError:
     TRITON_AVAILABLE = False
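
The dropped torch.cuda.is_available() gate is what blocked XPU: on an Intel-GPU-only host CUDA reports unavailable even though Triton itself imports and runs through its XPU backend, so importability is the better signal. An illustrative probe (not from the commit):

```python
# Illustrative probe, not part of the commit: on an XPU-only machine the old
# CUDA gate is False while triton can still be imported and used.
import torch

try:
    import triton  # noqa: F401
    triton_importable = True
except ImportError:
    triton_importable = False

print("cuda available:   ", torch.cuda.is_available())
print("triton importable:", triton_importable)
```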

awq/modules/triton/gemm.py

+8 -2 lines changed

@@ -20,6 +20,12 @@
 
 AWQ_TRITON_SUPPORTED_GROUP_SIZES = [-1, 32, 64, 128]
 
+def get_same_device_cm(t):
+    if t.device.type == 'xpu':
+        return torch.xpu.device(t.device.index)
+    else:
+        return torch.cuda.device(t.device.index)
+
 
 @triton.jit
 def awq_dequantize_kernel(
@@ -280,7 +286,7 @@ def awq_dequantize_triton(
         triton.cdiv(X, META["BLOCK_SIZE_X"]),
         triton.cdiv(Y, META["BLOCK_SIZE_Y"]),
     )
-    with torch.cuda.device(qweight.device.index):
+    with get_same_device_cm(qweight):
         awq_dequantize_kernel[grid](
             qweight,
             scales,
@@ -333,7 +339,7 @@ def awq_gemm_triton(
 
     # A = input, B = qweight, C = result
     # A = M x K, B = K x N, C = M x N
-    with torch.cuda.device(qweight.device.index):
+    with get_same_device_cm(qweight):
         awq_gemm_kernel[grid](
             input,
             qweight,
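
torch.cuda.device and torch.xpu.device are both context managers that scope the current device, so get_same_device_cm lets the same kernel-launch code serve CUDA, ROCm and XPU. A hedged usage sketch, guarded so it is a no-op on machines without a supported GPU:

```python
# Usage sketch for the device-scoping helper introduced above.
import torch

def get_same_device_cm(t):
    # Pick the backend-appropriate "current device" context manager for tensor t.
    if t.device.type == "xpu":
        return torch.xpu.device(t.device.index)
    return torch.cuda.device(t.device.index)

if torch.cuda.is_available() or (hasattr(torch, "xpu") and torch.xpu.is_available()):
    dev = "cuda:0" if torch.cuda.is_available() else "xpu:0"
    x = torch.empty(16, 16, device=dev)
    with get_same_device_cm(x):
        # Kernels launched here target x's device without a global device switch.
        y = x.clone()
    print(y.device)
```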

awq/utils/utils.py

+6 lines changed

@@ -5,6 +5,12 @@
 
 
 ipex_available = importlib.util.find_spec("intel_extension_for_pytorch") is not None
+try:
+    import triton as tl
+    triton_available = True
+except ImportError:
+    triton_available = False
+
 
 
 def get_module_by_name_suffix(model, module_name: str):
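
Like ipex_available a few lines up, the flag could also be derived without importing Triton eagerly; a sketch of that lighter-weight alternative (not what the commit does, which imports triton directly and leaves the tl alias unused):

```python
# Alternative sketch, not the commit's approach: probe for triton the same way
# ipex_available is computed, avoiding the import cost at module load time.
import importlib.util

triton_available = importlib.util.find_spec("triton") is not None
print("triton available:", triton_available)
```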
