Skip to content

Commit f835379

Browse files
Workaround: illegal memory access (#421)
1 parent b5db7fc commit f835379

File tree

1 file changed

+2
-1
lines changed

1 file changed

+2
-1
lines changed

awq/modules/linear/gemv_fast.py

+2-1
Original file line numberDiff line numberDiff line change
@@ -189,7 +189,8 @@ def from_linear(
189189
@torch.no_grad()
190190
def forward(self, x):
191191
inputs = x
192-
if inputs.numel() / inputs.shape[-1] < 8:
192+
batch_size, n_tokens, _ = inputs.shape
193+
if batch_size < 8 and n_tokens == 1:
193194
out = awq_v2_ext.gemv_forward_cuda_decode(
194195
inputs,
195196
self.qweight,

0 commit comments

Comments
 (0)