From d3a5306459262515dd4a3f6eb34cf826a840e166 Mon Sep 17 00:00:00 2001 From: Casper Date: Tue, 23 Jan 2024 23:04:19 +0100 Subject: [PATCH 1/6] Faster context processing [INITIAL] --- csrc/ops.h | 8 ++ csrc/pybind.cpp | 1 + csrc/quantization/awq/dequantize_kernels.cu | 120 ++++++++++++++++++ setup.py | 5 +- .../model_executor/layers/quantization/awq.py | 10 +- 5 files changed, 142 insertions(+), 2 deletions(-) create mode 100644 csrc/quantization/awq/dequantize_kernels.cu diff --git a/csrc/ops.h b/csrc/ops.h index 9340a60da1417..d49619644b182 100644 --- a/csrc/ops.h +++ b/csrc/ops.h @@ -70,6 +70,14 @@ torch::Tensor awq_gemm( torch::Tensor _scaling_factors, torch::Tensor _zeros, int split_k_iters); + +torch::Tensor awq_dequantize( + torch::Tensor _kernel, + torch::Tensor _scaling_factors, + torch::Tensor _zeros, + int split_k_iters, + int thx, + int thy); #endif void squeezellm_gemm( diff --git a/csrc/pybind.cpp b/csrc/pybind.cpp index 95f557686f337..7f922ab0a0626 100644 --- a/csrc/pybind.cpp +++ b/csrc/pybind.cpp @@ -51,6 +51,7 @@ PYBIND11_MODULE(TORCH_EXTENSION_NAME, m) { #ifndef USE_ROCM // Quantization ops ops.def("awq_gemm", &awq_gemm, "Quantized GEMM for AWQ"); + ops.def("awq_dequantize", &awq_dequantize, "Dequantization for AWQ"); #endif ops.def("gptq_gemm", &gptq_gemm, "Quantized GEMM for GPTQ"); ops.def("gptq_shuffle", &gptq_shuffle, "Post processing for GPTQ"); diff --git a/csrc/quantization/awq/dequantize_kernels.cu b/csrc/quantization/awq/dequantize_kernels.cu new file mode 100644 index 0000000000000..d80c91b3ed867 --- /dev/null +++ b/csrc/quantization/awq/dequantize_kernels.cu @@ -0,0 +1,120 @@ +#include +#include +#include "dequantize.cuh" +#include +#include +#include + + +namespace vllm { +namespace awq { + +__global__ void __launch_bounds__(64) dequantize_weights( + int* __restrict__ B, // 4096x64 4096 rows 64 cols + half* __restrict__ scaling_factors, // 32x512 32 rows 512 cols + int* __restrict__ zeros, // 32x64 32 rows 64 cols + half* __restrict__ C, // 4096x512 4096 rows 512 cols + int G +) +{ + int j_factors1 = 4; + int row_stride2 = 4; + int split_k_iters = 1; + static constexpr uint32_t ZERO = 0x0; + half B_shared[32 * (128 + 8)]; + + half* B_shared_ptr2 = B_shared; + + half B_shared_warp[32]; + int OC = 512; + + int N = blockDim.x * gridDim.x; // 2 + int col = (blockIdx.x * blockDim.x + threadIdx.x); + int row = blockIdx.y * blockDim.y + threadIdx.y; + int index1 = 8 * col + 8 * row * N; // + i (<8) + half* C_ptr2 = C + index1; + + int index2 = col + row * N; + int* B_ptr2 = B + index2; + + int index3 = col + (int)(row / G) * N; + int* zeros_ptr2 = zeros + index3; + int index4 = 8 * col + (int)(row / G) * N * 8; // + i (<8) + half* scaling_factors_ptr2 = scaling_factors + index4; + + + uint32_t zeros_loaded = *(uint32_t*)(zeros_ptr2); + uint4 B_loaded_zero = dequantize_s4_to_fp16x2(zeros_loaded); + uint4 B_loaded_scale = *(uint4*)(scaling_factors_ptr2); +int j=0; + + uint32_t B_loaded = *(uint32_t*)(B_ptr2 + j); + uint4 B_loaded_fp16 = dequantize_s4_to_fp16x2(B_loaded); + asm volatile("sub.f16x2 %0, %1, %2;\n" : "=r"(B_loaded_fp16.x) : "r"(B_loaded_fp16.x), "r"(B_loaded_zero.x)); + asm volatile("fma.rn.f16x2 %0, %1, %2, %3;\n" : "=r"(B_loaded_fp16.x) : "r"(B_loaded_fp16.x), "r"(B_loaded_scale.x), "r"(ZERO)); + asm volatile("sub.f16x2 %0, %1, %2;\n" : "=r"(B_loaded_fp16.y) : "r"(B_loaded_fp16.y), "r"(B_loaded_zero.y)); + asm volatile("fma.rn.f16x2 %0, %1, %2, %3;\n" : "=r"(B_loaded_fp16.y) : "r"(B_loaded_fp16.y), "r"(B_loaded_scale.y), "r"(ZERO)); + asm 
volatile("sub.f16x2 %0, %1, %2;\n" : "=r"(B_loaded_fp16.z) : "r"(B_loaded_fp16.z), "r"(B_loaded_zero.z)); + asm volatile("fma.rn.f16x2 %0, %1, %2, %3;\n" : "=r"(B_loaded_fp16.z) : "r"(B_loaded_fp16.z), "r"(B_loaded_scale.z), "r"(ZERO)); + asm volatile("sub.f16x2 %0, %1, %2;\n" : "=r"(B_loaded_fp16.w) : "r"(B_loaded_fp16.w), "r"(B_loaded_zero.w)); + asm volatile("fma.rn.f16x2 %0, %1, %2, %3;\n" : "=r"(B_loaded_fp16.w) : "r"(B_loaded_fp16.w), "r"(B_loaded_scale.w), "r"(ZERO)); + + *(uint4*)(B_shared_ptr2 + j) = B_loaded_fp16; + + for (int i=0; i<8; ++i) { + *(C_ptr2 + i) = B_shared[i]; + } +} + +} // namespace awq +} // namespace vllm + +// Dequantization to fp16 +torch::Tensor awq_dequantize( + torch::Tensor _kernel, + torch::Tensor _scaling_factors, + torch::Tensor _zeros, + int split_k_iters, + int thx, + int thy) +{ + int in_c = _kernel.size(0); + int qout_c = _kernel.size(1); + int out_c = qout_c * 8; + int G = in_c / _scaling_factors.size(0); + + int x_thread = thx; + int y_thread = thy; + + int x_blocks = 1; + int y_blocks = 1; + if (thx==0) { + x_thread = qout_c; + } + if (thy==0) { + y_thread = in_c; + } + if (thx==0 && thy==0) { + x_thread = 8; + y_thread = 8; + x_blocks = (int)(qout_c / 8); + y_blocks = (int)(in_c / 8); + } + + const at::cuda::OptionalCUDAGuard device_guard(device_of(_scaling_factors)); + + auto options = torch::TensorOptions().dtype(_scaling_factors.dtype()).device(_scaling_factors.device()); + at::Tensor _de_kernel = torch::empty({in_c, out_c}, options); // row, col 4096x512 + + auto kernel = reinterpret_cast(_kernel.data_ptr()); + auto de_kernel = reinterpret_cast(_de_kernel.data_ptr()); + auto scaling_factors = reinterpret_cast(_scaling_factors.data_ptr()); + auto zeros = reinterpret_cast(_zeros.data_ptr()); + + dim3 num_blocks(x_blocks, y_blocks); + dim3 threads_per_block(x_thread, y_thread); // col, row 64x4096 + + dequantize_weights<<>>(kernel, scaling_factors, zeros, de_kernel, G); + + return _de_kernel; +} diff --git a/setup.py b/setup.py index fb37a8d952314..b66c995793089 100644 --- a/setup.py +++ b/setup.py @@ -255,7 +255,10 @@ def get_torch_arch_list() -> Set[str]: ] if _is_cuda(): - vllm_extension_sources.append("csrc/quantization/awq/gemm_kernels.cu") + vllm_extension_sources.extend([ + "csrc/quantization/awq/gemm_kernels.cu", + "csrc/quantization/awq/dequantize_kernels.cu", + ]) if not _is_neuron(): vllm_extension = CUDAExtension( diff --git a/vllm/model_executor/layers/quantization/awq.py b/vllm/model_executor/layers/quantization/awq.py index 831576b1d7cd7..af767d0c13a2c 100644 --- a/vllm/model_executor/layers/quantization/awq.py +++ b/vllm/model_executor/layers/quantization/awq.py @@ -153,7 +153,15 @@ def apply_weights(self, pack_factor = self.quant_config.pack_factor out_shape = (x.shape[:-1] + (qweight.shape[-1] * pack_factor, )) reshaped_x = x.reshape(-1, x.shape[-1]) - out = ops.awq_gemm(reshaped_x, qweight, scales, qzeros, pack_factor) + + # batch_size*seq_len >= threshold + FP16_MATMUL_HEURISTIC_CONDITION = x.shape[0]*x.shape[1] >= 1024 + + if FP16_MATMUL_HEURISTIC_CONDITION: + out = ops.awq_dequantize(qweight, scales, qzeros, 0, 0, 0) + out = torch.matmul(reshaped_x, out) + else: + out = ops.awq_gemm(reshaped_x, qweight, scales, qzeros, pack_factor) if bias is not None: out = out + bias return out.reshape(out_shape) From bf62aed960fcaf4b465c9aac0187ce73d539cea2 Mon Sep 17 00:00:00 2001 From: Casper Hansen Date: Wed, 24 Jan 2024 21:01:25 +0000 Subject: [PATCH 2/6] Consolidate and fix workspace --- csrc/quantization/awq/dequantize_kernels.cu 
| 120 -------------------- csrc/quantization/awq/gemm_kernels.cu | 108 ++++++++++++++++++ setup.py | 1 - 3 files changed, 108 insertions(+), 121 deletions(-) delete mode 100644 csrc/quantization/awq/dequantize_kernels.cu diff --git a/csrc/quantization/awq/dequantize_kernels.cu b/csrc/quantization/awq/dequantize_kernels.cu deleted file mode 100644 index d80c91b3ed867..0000000000000 --- a/csrc/quantization/awq/dequantize_kernels.cu +++ /dev/null @@ -1,120 +0,0 @@ -#include -#include -#include "dequantize.cuh" -#include -#include -#include - - -namespace vllm { -namespace awq { - -__global__ void __launch_bounds__(64) dequantize_weights( - int* __restrict__ B, // 4096x64 4096 rows 64 cols - half* __restrict__ scaling_factors, // 32x512 32 rows 512 cols - int* __restrict__ zeros, // 32x64 32 rows 64 cols - half* __restrict__ C, // 4096x512 4096 rows 512 cols - int G -) -{ - int j_factors1 = 4; - int row_stride2 = 4; - int split_k_iters = 1; - static constexpr uint32_t ZERO = 0x0; - half B_shared[32 * (128 + 8)]; - - half* B_shared_ptr2 = B_shared; - - half B_shared_warp[32]; - int OC = 512; - - int N = blockDim.x * gridDim.x; // 2 - int col = (blockIdx.x * blockDim.x + threadIdx.x); - int row = blockIdx.y * blockDim.y + threadIdx.y; - int index1 = 8 * col + 8 * row * N; // + i (<8) - half* C_ptr2 = C + index1; - - int index2 = col + row * N; - int* B_ptr2 = B + index2; - - int index3 = col + (int)(row / G) * N; - int* zeros_ptr2 = zeros + index3; - int index4 = 8 * col + (int)(row / G) * N * 8; // + i (<8) - half* scaling_factors_ptr2 = scaling_factors + index4; - - - uint32_t zeros_loaded = *(uint32_t*)(zeros_ptr2); - uint4 B_loaded_zero = dequantize_s4_to_fp16x2(zeros_loaded); - uint4 B_loaded_scale = *(uint4*)(scaling_factors_ptr2); -int j=0; - - uint32_t B_loaded = *(uint32_t*)(B_ptr2 + j); - uint4 B_loaded_fp16 = dequantize_s4_to_fp16x2(B_loaded); - asm volatile("sub.f16x2 %0, %1, %2;\n" : "=r"(B_loaded_fp16.x) : "r"(B_loaded_fp16.x), "r"(B_loaded_zero.x)); - asm volatile("fma.rn.f16x2 %0, %1, %2, %3;\n" : "=r"(B_loaded_fp16.x) : "r"(B_loaded_fp16.x), "r"(B_loaded_scale.x), "r"(ZERO)); - asm volatile("sub.f16x2 %0, %1, %2;\n" : "=r"(B_loaded_fp16.y) : "r"(B_loaded_fp16.y), "r"(B_loaded_zero.y)); - asm volatile("fma.rn.f16x2 %0, %1, %2, %3;\n" : "=r"(B_loaded_fp16.y) : "r"(B_loaded_fp16.y), "r"(B_loaded_scale.y), "r"(ZERO)); - asm volatile("sub.f16x2 %0, %1, %2;\n" : "=r"(B_loaded_fp16.z) : "r"(B_loaded_fp16.z), "r"(B_loaded_zero.z)); - asm volatile("fma.rn.f16x2 %0, %1, %2, %3;\n" : "=r"(B_loaded_fp16.z) : "r"(B_loaded_fp16.z), "r"(B_loaded_scale.z), "r"(ZERO)); - asm volatile("sub.f16x2 %0, %1, %2;\n" : "=r"(B_loaded_fp16.w) : "r"(B_loaded_fp16.w), "r"(B_loaded_zero.w)); - asm volatile("fma.rn.f16x2 %0, %1, %2, %3;\n" : "=r"(B_loaded_fp16.w) : "r"(B_loaded_fp16.w), "r"(B_loaded_scale.w), "r"(ZERO)); - - *(uint4*)(B_shared_ptr2 + j) = B_loaded_fp16; - - for (int i=0; i<8; ++i) { - *(C_ptr2 + i) = B_shared[i]; - } -} - -} // namespace awq -} // namespace vllm - -// Dequantization to fp16 -torch::Tensor awq_dequantize( - torch::Tensor _kernel, - torch::Tensor _scaling_factors, - torch::Tensor _zeros, - int split_k_iters, - int thx, - int thy) -{ - int in_c = _kernel.size(0); - int qout_c = _kernel.size(1); - int out_c = qout_c * 8; - int G = in_c / _scaling_factors.size(0); - - int x_thread = thx; - int y_thread = thy; - - int x_blocks = 1; - int y_blocks = 1; - if (thx==0) { - x_thread = qout_c; - } - if (thy==0) { - y_thread = in_c; - } - if (thx==0 && thy==0) { - x_thread = 8; - y_thread 
= 8; - x_blocks = (int)(qout_c / 8); - y_blocks = (int)(in_c / 8); - } - - const at::cuda::OptionalCUDAGuard device_guard(device_of(_scaling_factors)); - - auto options = torch::TensorOptions().dtype(_scaling_factors.dtype()).device(_scaling_factors.device()); - at::Tensor _de_kernel = torch::empty({in_c, out_c}, options); // row, col 4096x512 - - auto kernel = reinterpret_cast(_kernel.data_ptr()); - auto de_kernel = reinterpret_cast(_de_kernel.data_ptr()); - auto scaling_factors = reinterpret_cast(_scaling_factors.data_ptr()); - auto zeros = reinterpret_cast(_zeros.data_ptr()); - - dim3 num_blocks(x_blocks, y_blocks); - dim3 threads_per_block(x_thread, y_thread); // col, row 64x4096 - - dequantize_weights<<>>(kernel, scaling_factors, zeros, de_kernel, G); - - return _de_kernel; -} diff --git a/csrc/quantization/awq/gemm_kernels.cu b/csrc/quantization/awq/gemm_kernels.cu index 04dfe8fe9b889..376c8ebfb9b7a 100644 --- a/csrc/quantization/awq/gemm_kernels.cu +++ b/csrc/quantization/awq/gemm_kernels.cu @@ -493,9 +493,117 @@ __global__ void __launch_bounds__(64) gemm_forward_4bit_cuda_m16n64k32(int G, in #endif } +__global__ void __launch_bounds__(64) dequantize_weights( + int* __restrict__ B, + half* __restrict__ scaling_factors, + int* __restrict__ zeros, + half* __restrict__ C, + int G +) +{ + int j_factors1 = 4; + int row_stride2 = 4; + int split_k_iters = 1; + static constexpr uint32_t ZERO = 0x0; + half B_shared[32 * (128 + 8)]; + + half* B_shared_ptr2 = B_shared; + + half B_shared_warp[32]; + int OC = 512; + + int N = blockDim.x * gridDim.x; // 2 + int col = (blockIdx.x * blockDim.x + threadIdx.x); + int row = blockIdx.y * blockDim.y + threadIdx.y; + int index1 = 8 * col + 8 * row * N; + half* C_ptr2 = C + index1; + + int index2 = col + row * N; + int* B_ptr2 = B + index2; + + int index3 = col + (int)(row / G) * N; + int* zeros_ptr2 = zeros + index3; + int index4 = 8 * col + (int)(row / G) * N * 8; + half* scaling_factors_ptr2 = scaling_factors + index4; + + + uint32_t zeros_loaded = *(uint32_t*)(zeros_ptr2); + uint4 B_loaded_zero = dequantize_s4_to_fp16x2(zeros_loaded); + uint4 B_loaded_scale = *(uint4*)(scaling_factors_ptr2); +int j=0; + + uint32_t B_loaded = *(uint32_t*)(B_ptr2 + j); + uint4 B_loaded_fp16 = dequantize_s4_to_fp16x2(B_loaded); + asm volatile("sub.f16x2 %0, %1, %2;\n" : "=r"(B_loaded_fp16.x) : "r"(B_loaded_fp16.x), "r"(B_loaded_zero.x)); + asm volatile("fma.rn.f16x2 %0, %1, %2, %3;\n" : "=r"(B_loaded_fp16.x) : "r"(B_loaded_fp16.x), "r"(B_loaded_scale.x), "r"(ZERO)); + asm volatile("sub.f16x2 %0, %1, %2;\n" : "=r"(B_loaded_fp16.y) : "r"(B_loaded_fp16.y), "r"(B_loaded_zero.y)); + asm volatile("fma.rn.f16x2 %0, %1, %2, %3;\n" : "=r"(B_loaded_fp16.y) : "r"(B_loaded_fp16.y), "r"(B_loaded_scale.y), "r"(ZERO)); + asm volatile("sub.f16x2 %0, %1, %2;\n" : "=r"(B_loaded_fp16.z) : "r"(B_loaded_fp16.z), "r"(B_loaded_zero.z)); + asm volatile("fma.rn.f16x2 %0, %1, %2, %3;\n" : "=r"(B_loaded_fp16.z) : "r"(B_loaded_fp16.z), "r"(B_loaded_scale.z), "r"(ZERO)); + asm volatile("sub.f16x2 %0, %1, %2;\n" : "=r"(B_loaded_fp16.w) : "r"(B_loaded_fp16.w), "r"(B_loaded_zero.w)); + asm volatile("fma.rn.f16x2 %0, %1, %2, %3;\n" : "=r"(B_loaded_fp16.w) : "r"(B_loaded_fp16.w), "r"(B_loaded_scale.w), "r"(ZERO)); + + *(uint4*)(B_shared_ptr2 + j) = B_loaded_fp16; + + for (int i=0; i<8; ++i) { + *(C_ptr2 + i) = B_shared[i]; + } +} + } // namespace awq } // namespace vllm +torch::Tensor awq_dequantize( + torch::Tensor _kernel, + torch::Tensor _scaling_factors, + torch::Tensor _zeros, + int split_k_iters, + 
int thx,
+    int thy)
+{
+    int in_c = _kernel.size(0);
+    int qout_c = _kernel.size(1);
+    int out_c = qout_c * 8;
+    int G = in_c / _scaling_factors.size(0);
+
+    int x_thread = thx;
+    int y_thread = thy;
+
+    int x_blocks = 1;
+    int y_blocks = 1;
+    if (thx==0) {
+        x_thread = qout_c;
+    }
+    if (thy==0) {
+        y_thread = in_c;
+    }
+    if (thx==0 && thy==0) {
+        x_thread = 8;
+        y_thread = 8;
+        x_blocks = (int)(qout_c / 8);
+        y_blocks = (int)(in_c / 8);
+    }
+
+    const at::cuda::OptionalCUDAGuard device_guard(device_of(_scaling_factors));
+
+    auto options = torch::TensorOptions().dtype(_scaling_factors.dtype()).device(_scaling_factors.device());
+    at::Tensor _de_kernel = torch::empty({in_c, out_c}, options);
+
+    auto kernel = reinterpret_cast<int*>(_kernel.data_ptr<int>());
+    auto de_kernel = reinterpret_cast<half*>(_de_kernel.data_ptr<at::Half>());
+    auto scaling_factors = reinterpret_cast<half*>(_scaling_factors.data_ptr<at::Half>());
+    auto zeros = reinterpret_cast<int*>(_zeros.data_ptr<int>());
+
+    dim3 num_blocks(x_blocks, y_blocks);
+    dim3 threads_per_block(x_thread, y_thread);
+
+    const cudaStream_t stream = at::cuda::getCurrentCUDAStream();
+    vllm::awq::dequantize_weights<<<num_blocks, threads_per_block, 0, stream>>>(
+        kernel, scaling_factors, zeros, de_kernel, G);
+
+    return _de_kernel;
+}
+
 // in_feats: M, IC [float16]
 // kernel: IC, OC // 8 [int32] -> cast to IC, OC [uint4b]
 // scaling_factors: IC // G, OC [float16]
diff --git a/setup.py b/setup.py
index b66c995793089..267dc970e12e7 100644
--- a/setup.py
+++ b/setup.py
@@ -257,7 +257,6 @@ def get_torch_arch_list() -> Set[str]:
     if _is_cuda():
         vllm_extension_sources.extend([
             "csrc/quantization/awq/gemm_kernels.cu",
-            "csrc/quantization/awq/dequantize_kernels.cu",
         ])

     if not _is_neuron():

From d10313b664976d53f5634e0dd4a9522347ae14f6 Mon Sep 17 00:00:00 2001
From: Casper Hansen
Date: Wed, 24 Jan 2024 21:13:22 +0000
Subject: [PATCH 3/6] Adjust heuristic to 256

---
 vllm/model_executor/layers/quantization/awq.py | 2 +-
 1 file changed, 1 insertion(+), 1 deletion(-)

diff --git a/vllm/model_executor/layers/quantization/awq.py b/vllm/model_executor/layers/quantization/awq.py
index af767d0c13a2c..e2a7c4390877c 100644
--- a/vllm/model_executor/layers/quantization/awq.py
+++ b/vllm/model_executor/layers/quantization/awq.py
@@ -155,7 +155,7 @@ def apply_weights(self,
         reshaped_x = x.reshape(-1, x.shape[-1])

         # batch_size*seq_len >= threshold
-        FP16_MATMUL_HEURISTIC_CONDITION = x.shape[0]*x.shape[1] >= 1024
+        FP16_MATMUL_HEURISTIC_CONDITION = x.shape[0]*x.shape[1] >= 256

         if FP16_MATMUL_HEURISTIC_CONDITION:
             out = ops.awq_dequantize(qweight, scales, qzeros, 0, 0, 0)
             out = torch.matmul(reshaped_x, out)

From 1b6521f7d4522896321c0332f92f3aa096e44751 Mon Sep 17 00:00:00 2001
From: Casper
Date: Wed, 24 Jan 2024 22:18:24 +0100
Subject: [PATCH 4/6] Formatting

---
 vllm/model_executor/layers/quantization/awq.py | 5 +++--
 1 file changed, 3 insertions(+), 2 deletions(-)

diff --git a/vllm/model_executor/layers/quantization/awq.py b/vllm/model_executor/layers/quantization/awq.py
index e2a7c4390877c..5571c7e63a3d3 100644
--- a/vllm/model_executor/layers/quantization/awq.py
+++ b/vllm/model_executor/layers/quantization/awq.py
@@ -153,7 +153,7 @@ def apply_weights(self,
         pack_factor = self.quant_config.pack_factor
         out_shape = (x.shape[:-1] + (qweight.shape[-1] * pack_factor, ))
         reshaped_x = x.reshape(-1, x.shape[-1])
-        
+
         # batch_size*seq_len >= threshold
         FP16_MATMUL_HEURISTIC_CONDITION = x.shape[0]*x.shape[1] >= 256

@@ -161,7 +161,8 @@ def apply_weights(self,
             out = ops.awq_dequantize(qweight, scales, qzeros, 0, 0, 0)
             out = torch.matmul(reshaped_x, out)
         else:
-            out = ops.awq_gemm(reshaped_x, qweight, scales,
qzeros, pack_factor) + out = ops.awq_gemm(reshaped_x, qweight, scales, qzeros, + pack_factor) if bias is not None: out = out + bias return out.reshape(out_shape) From 414908615ebe23b73e3c07df7208cffa749d0ebb Mon Sep 17 00:00:00 2001 From: Casper Date: Wed, 24 Jan 2024 22:22:27 +0100 Subject: [PATCH 5/6] Formatting (again) --- vllm/model_executor/layers/quantization/awq.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/vllm/model_executor/layers/quantization/awq.py b/vllm/model_executor/layers/quantization/awq.py index 5571c7e63a3d3..4d80bea676a67 100644 --- a/vllm/model_executor/layers/quantization/awq.py +++ b/vllm/model_executor/layers/quantization/awq.py @@ -155,7 +155,7 @@ def apply_weights(self, reshaped_x = x.reshape(-1, x.shape[-1]) # batch_size*seq_len >= threshold - FP16_MATMUL_HEURISTIC_CONDITION = x.shape[0]*x.shape[1] >= 256 + FP16_MATMUL_HEURISTIC_CONDITION = x.shape[0] * x.shape[1] >= 256 if FP16_MATMUL_HEURISTIC_CONDITION: out = ops.awq_dequantize(qweight, scales, qzeros, 0, 0, 0) From 77811c60cf1617298fdeda01c1d9d176ed2056ef Mon Sep 17 00:00:00 2001 From: Casper Hansen Date: Fri, 26 Jan 2024 10:53:17 +0000 Subject: [PATCH 6/6] Apply code suggestions --- setup.py | 4 +--- vllm/model_executor/layers/quantization/awq.py | 4 ++-- 2 files changed, 3 insertions(+), 5 deletions(-) diff --git a/setup.py b/setup.py index 267dc970e12e7..fb37a8d952314 100644 --- a/setup.py +++ b/setup.py @@ -255,9 +255,7 @@ def get_torch_arch_list() -> Set[str]: ] if _is_cuda(): - vllm_extension_sources.extend([ - "csrc/quantization/awq/gemm_kernels.cu", - ]) + vllm_extension_sources.append("csrc/quantization/awq/gemm_kernels.cu") if not _is_neuron(): vllm_extension = CUDAExtension( diff --git a/vllm/model_executor/layers/quantization/awq.py b/vllm/model_executor/layers/quantization/awq.py index 4d80bea676a67..4d3fd3ec0cc71 100644 --- a/vllm/model_executor/layers/quantization/awq.py +++ b/vllm/model_executor/layers/quantization/awq.py @@ -154,8 +154,8 @@ def apply_weights(self, out_shape = (x.shape[:-1] + (qweight.shape[-1] * pack_factor, )) reshaped_x = x.reshape(-1, x.shape[-1]) - # batch_size*seq_len >= threshold - FP16_MATMUL_HEURISTIC_CONDITION = x.shape[0] * x.shape[1] >= 256 + # num_tokens >= threshold + FP16_MATMUL_HEURISTIC_CONDITION = x.shape[:-1].numel() >= 256 if FP16_MATMUL_HEURISTIC_CONDITION: out = ops.awq_dequantize(qweight, scales, qzeros, 0, 0, 0)
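
Taken together, the series keeps the fused 4-bit AWQ GEMM for small batches (decoding) and switches to dequantize-then-FP16-matmul once a forward pass covers 256 or more tokens, which is where context (prefill) processing gets its speedup. The sketch below restates that dispatch logic in plain Python. It is a simplified illustration of the final awq.py change, not the exact vLLM implementation: the wrapper name awq_apply_weights is invented for the example, and the ops import path is assumed from the surrounding file rather than shown in this diff.

# Illustrative sketch of the dispatch heuristic introduced by this series.
# ops.awq_dequantize and ops.awq_gemm are the custom CUDA ops bound in
# csrc/pybind.cpp; argument order follows the diffs above.
import torch
from vllm._C import ops  # assumption: the binding module awq.py imports as `ops`

def awq_apply_weights(x, qweight, scales, qzeros, pack_factor, bias=None):
    out_shape = x.shape[:-1] + (qweight.shape[-1] * pack_factor, )
    reshaped_x = x.reshape(-1, x.shape[-1])

    # Prefill/context: many tokens per forward pass. Dequantizing the whole
    # weight to fp16 once and using a dense matmul beats the fused int4 GEMM.
    if x.shape[:-1].numel() >= 256:
        weight_fp16 = ops.awq_dequantize(qweight, scales, qzeros, 0, 0, 0)
        out = torch.matmul(reshaped_x, weight_fp16)
    # Decode: few tokens per step. The fused AWQ GEMM kernel remains faster.
    else:
        out = ops.awq_gemm(reshaped_x, qweight, scales, qzeros, pack_factor)

    if bias is not None:
        out = out + bias
    return out.reshape(out_shape)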