Commit 34c961b

CUDA/HIP: Fix fattn-vec-* when device warp size is not 32 (#12315)
When fattn-wmma was ported over to warp64, various bits that also touch fattn-vec were converted to a selectable warp size. However, the fattn-vec kernels don't work with 64-wide warps for now, so we need to avoid launching them with warp64 parameters.
1 parent 7841fc7 commit 34c961b
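
To make the mismatch concrete, here is a minimal, self-contained CUDA sketch; all names in it are hypothetical and it is not ggml code. On a standard warp32 CUDA GPU it shows that a warp-level reduction compiled for a fixed warp width only folds that many lanes together, so a kernel built for warp32 reports a wrong result when it is launched with warp64-style block dimensions, which is the kind of situation this commit avoids for the fattn-vec kernels.

// Hypothetical illustration only; ggml's real vec-dot kernels and helpers differ.
#include <cstdio>
#include <cuda_runtime.h>

template <int warp_size>
__device__ __forceinline__ float warp_reduce_sum_sketch(float x) {
    // Butterfly reduction: folds together exactly warp_size lanes, no more.
#pragma unroll
    for (int offset = warp_size/2; offset > 0; offset >>= 1) {
        x += __shfl_xor_sync(0xffffffff, x, offset, warp_size);
    }
    return x;
}

template <int warp_size>
__global__ void dot_sketch(const float * a, const float * b, float * out, const int n) {
    float sum = 0.0f;
    for (int i = threadIdx.x; i < n; i += blockDim.x) { // work is split across the whole block
        sum += a[i]*b[i];
    }
    sum = warp_reduce_sum_sketch<warp_size>(sum);       // ...but reduced per compiled warp width
    if (threadIdx.x == 0) {
        *out = sum; // correct only if blockDim.x matches warp_size
    }
}

int main() {
    const int n = 256;
    float *a, *b, *out;
    cudaMallocManaged(&a,   n*sizeof(float));
    cudaMallocManaged(&b,   n*sizeof(float));
    cudaMallocManaged(&out, sizeof(float));
    for (int i = 0; i < n; ++i) { a[i] = 1.0f; b[i] = 2.0f; }

    // Launch geometry matches the compiled warp width: full dot product.
    dot_sketch<32><<<1, 32>>>(a, b, out, n);
    cudaDeviceSynchronize();
    printf("32-thread launch: %.1f (expected %.1f)\n", *out, 2.0f*n);

    // warp64-style geometry with a warp32 kernel: only the first warp's partial
    // sums are folded in, so roughly half the result is lost.
    dot_sketch<32><<<1, 64>>>(a, b, out, n);
    cudaDeviceSynchronize();
    printf("64-thread launch: %.1f (expected %.1f)\n", *out, 2.0f*n);

    cudaFree(a); cudaFree(b); cudaFree(out);
    return 0;
}

The diff below therefore makes the warp size an explicit parameter that defaults to 32 (WARP_SIZE), so the fattn-vec launch paths stay on warp32 while fattn-wmma passes the real device warp size.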

2 files changed: +26 -33

ggml/src/ggml-cuda/fattn-common.cuh (+22 -30)
@@ -52,12 +52,11 @@ typedef half (*vec_dot_KQ_f16_t)(
 typedef float (*vec_dot_KQ_f32_t)(
     const char * __restrict__ K_c, const void * __restrict__ Q_v, const int * __restrict__ Q_q8 , const void * __restrict__ Q_ds);

-template<typename T, int D>
+template<typename T, int D, int warp_size>
 static __device__ __forceinline__ T vec_dot_fattn_vec_KQ_q4_0(
     const char * __restrict__ K_c, const void * __restrict__ Q_v, const int * __restrict__ Q_q8, const void * __restrict__ Q_ds_v) {

     const block_q4_0 * K_q4_0 = (const block_q4_0 *) K_c;
-    constexpr int warp_size = ggml_cuda_get_physical_warp_size();
     GGML_UNUSED(Q_v);

     T sum = 0.0f;
@@ -93,12 +92,11 @@ static __device__ __forceinline__ T vec_dot_fattn_vec_KQ_q4_0(
     return sum;
 }

-template<typename T, int D>
+template<typename T, int D, int warp_size>
 static __device__ __forceinline__ T vec_dot_fattn_vec_KQ_q4_1(
     const char * __restrict__ K_c, const void * __restrict__ Q_v, const int * __restrict__ Q_q8, const void * __restrict__ Q_ds_v) {

     const block_q4_1 * K_q4_1 = (const block_q4_1 *) K_c;
-    constexpr int warp_size = ggml_cuda_get_physical_warp_size();
     GGML_UNUSED(Q_v);

     T sum = 0.0f;
@@ -138,12 +136,11 @@ static __device__ __forceinline__ T vec_dot_fattn_vec_KQ_q4_1(
     return sum;
 }

-template<typename T, int D>
+template<typename T, int D, int warp_size>
 static __device__ __forceinline__ T vec_dot_fattn_vec_KQ_q5_0(
     const char * __restrict__ K_c, const void * __restrict__ Q_v, const int * __restrict__ Q_q8, const void * __restrict__ Q_ds_v) {

     const block_q5_0 * K_q5_0 = (const block_q5_0 *) K_c;
-    constexpr int warp_size = ggml_cuda_get_physical_warp_size();
     GGML_UNUSED(Q_v);

     T sum = 0.0f;
@@ -186,12 +183,11 @@ static __device__ __forceinline__ T vec_dot_fattn_vec_KQ_q5_0(
     return sum;
 }

-template<typename T, int D>
+template<typename T, int D, int warp_size>
 static __device__ __forceinline__ T vec_dot_fattn_vec_KQ_q5_1(
     const char * __restrict__ K_c, const void * __restrict__ Q_v, const int * __restrict__ Q_q8, const void * __restrict__ Q_ds_v) {

     const block_q5_1 * K_q5_1 = (const block_q5_1 *) K_c;
-    constexpr int warp_size = ggml_cuda_get_physical_warp_size();
     GGML_UNUSED(Q_v);

     T sum = 0.0f;
@@ -238,12 +234,11 @@ static __device__ __forceinline__ T vec_dot_fattn_vec_KQ_q5_1(
     return sum;
 }

-template <typename T, int D>
+template <typename T, int D, int warp_size>
 static __device__ __forceinline__ T vec_dot_fattn_vec_KQ_q8_0(
     const char * __restrict__ K_c, const void * __restrict__ Q_v, const int * __restrict__ Q_q8, const void * __restrict__ Q_ds_v) {

     const block_q8_0 * K_q8_0 = (const block_q8_0 *) K_c;
-    constexpr int warp_size = ggml_cuda_get_physical_warp_size();
     GGML_UNUSED(Q_v);

     T sum = 0.0f;
@@ -272,12 +267,11 @@ static __device__ __forceinline__ T vec_dot_fattn_vec_KQ_q8_0(
     return sum;
 }

-template <typename T, int D>
+template <typename T, int D, int warp_size>
 static __device__ __forceinline__ T vec_dot_fattn_vec_KQ_f16(
     const char * __restrict__ K_c, const void * __restrict__ Q_v, const int * __restrict__ Q_q8 , const void * __restrict__ Q_ds_v) {

     const half2 * K_h2 = (const half2 *) K_c;
-    constexpr int warp_size = ggml_cuda_get_physical_warp_size();
     GGML_UNUSED(Q_q8);
     GGML_UNUSED(Q_ds_v);

@@ -480,25 +474,25 @@ static __device__ __forceinline__ T dequantize_1_f16(const void * __restrict__ v
     return x[i];
 }

-template <int D>
+template <int D, int warp_size = WARP_SIZE>
 constexpr __device__ vec_dot_KQ_f16_t get_vec_dot_KQ_f16(ggml_type type_K) {
-    return type_K == GGML_TYPE_Q4_0 ? vec_dot_fattn_vec_KQ_q4_0<half, D> :
-        type_K == GGML_TYPE_Q4_1 ? vec_dot_fattn_vec_KQ_q4_1<half, D> :
-        type_K == GGML_TYPE_Q5_0 ? vec_dot_fattn_vec_KQ_q5_0<half, D> :
-        type_K == GGML_TYPE_Q5_1 ? vec_dot_fattn_vec_KQ_q5_1<half, D> :
-        type_K == GGML_TYPE_Q8_0 ? vec_dot_fattn_vec_KQ_q8_0<half, D> :
-        type_K == GGML_TYPE_F16 ? vec_dot_fattn_vec_KQ_f16<half, D> :
+    return type_K == GGML_TYPE_Q4_0 ? vec_dot_fattn_vec_KQ_q4_0<half, D, warp_size> :
+        type_K == GGML_TYPE_Q4_1 ? vec_dot_fattn_vec_KQ_q4_1<half, D, warp_size> :
+        type_K == GGML_TYPE_Q5_0 ? vec_dot_fattn_vec_KQ_q5_0<half, D, warp_size> :
+        type_K == GGML_TYPE_Q5_1 ? vec_dot_fattn_vec_KQ_q5_1<half, D, warp_size> :
+        type_K == GGML_TYPE_Q8_0 ? vec_dot_fattn_vec_KQ_q8_0<half, D, warp_size> :
+        type_K == GGML_TYPE_F16 ? vec_dot_fattn_vec_KQ_f16<half, D, warp_size> :
         nullptr;
 }

-template <int D>
+template <int D, int warp_size = WARP_SIZE>
 constexpr __device__ vec_dot_KQ_f32_t get_vec_dot_KQ_f32(ggml_type type_K) {
-    return type_K == GGML_TYPE_Q4_0 ? vec_dot_fattn_vec_KQ_q4_0<float, D> :
-        type_K == GGML_TYPE_Q4_1 ? vec_dot_fattn_vec_KQ_q4_1<float, D> :
-        type_K == GGML_TYPE_Q5_0 ? vec_dot_fattn_vec_KQ_q5_0<float, D> :
-        type_K == GGML_TYPE_Q5_1 ? vec_dot_fattn_vec_KQ_q5_1<float, D> :
-        type_K == GGML_TYPE_Q8_0 ? vec_dot_fattn_vec_KQ_q8_0<float, D> :
-        type_K == GGML_TYPE_F16 ? vec_dot_fattn_vec_KQ_f16<float, D> :
+    return type_K == GGML_TYPE_Q4_0 ? vec_dot_fattn_vec_KQ_q4_0<float, D, warp_size> :
+        type_K == GGML_TYPE_Q4_1 ? vec_dot_fattn_vec_KQ_q4_1<float, D, warp_size> :
+        type_K == GGML_TYPE_Q5_0 ? vec_dot_fattn_vec_KQ_q5_0<float, D, warp_size> :
+        type_K == GGML_TYPE_Q5_1 ? vec_dot_fattn_vec_KQ_q5_1<float, D, warp_size> :
+        type_K == GGML_TYPE_Q8_0 ? vec_dot_fattn_vec_KQ_q8_0<float, D, warp_size> :
+        type_K == GGML_TYPE_F16 ? vec_dot_fattn_vec_KQ_f16<float, D, warp_size> :
         nullptr;
 }

@@ -681,7 +675,8 @@ static void on_no_fattn_vec_case(const int D) {
 template <int D, int ncols1, int ncols2, int parallel_blocks, int KQ_stride>
 void launch_fattn(
         ggml_backend_cuda_context & ctx, ggml_tensor * dst, fattn_kernel_t fattn_kernel,
-        const int nwarps, const size_t nbytes_shared, const bool need_f16_K, const bool need_f16_V
+        const int nwarps, const size_t nbytes_shared, const bool need_f16_K, const bool need_f16_V,
+        const int warp_size = WARP_SIZE
 ) {
     constexpr int ncols = ncols1 * ncols2;

@@ -704,8 +699,6 @@ void launch_fattn(

     GGML_ASSERT(Q->ne[3] == 1);

-    const int warp_size = ggml_cuda_info().devices[ctx.device].warp_size;
-
     ggml_cuda_pool & pool = ctx.pool();
     cudaStream_t main_stream = ctx.stream();
     const int id = ggml_cuda_get_device();
@@ -805,7 +798,6 @@ void launch_fattn(
     const float m1 = powf(2.0f, -(max_bias / 2.0f) / n_head_log2);

     GGML_ASSERT(block_dim.x % warp_size == 0);
-    GGML_ASSERT(!GGML_CUDA_CC_IS_AMD(cc) || block_dim.x * block_dim.y <= 4 * (unsigned int)warp_size);
     fattn_kernel<<<blocks_num, block_dim, nbytes_shared, main_stream>>>(
         (const char *) Q->data,
         K_data,

ggml/src/ggml-cuda/fattn-wmma-f16.cu (+4 -3)
@@ -469,6 +469,7 @@ void ggml_cuda_flash_attn_ext_wmma_f16_case(ggml_backend_cuda_context & ctx, ggm
     constexpr int frag_m = cols_per_block == 8 && D % 32 == 0 ? 32 : 16;
     const int blocks_num_pb1 = ((Q->ne[1] + cols_per_block - 1) / cols_per_block)*Q->ne[2]*Q->ne[3];
     const int nsm = ggml_cuda_info().devices[ggml_cuda_get_device()].nsm;
+    const int warp_size = ggml_cuda_info().devices[ggml_cuda_get_device()].warp_size;

     float logit_softcap;
     memcpy(&logit_softcap, (const float *) KQV->op_params + 2, sizeof(float));
@@ -485,7 +486,7 @@ void ggml_cuda_flash_attn_ext_wmma_f16_case(ggml_backend_cuda_context & ctx, ggm
             fattn_kernel = flash_attn_ext_f16<
                 D, cols_per_block, nwarps, get_VKQ_stride(D, nwarps, frag_m), parallel_blocks, KQ_acc_t, use_logit_softcap>;
         }
-        launch_fattn<D, cols_per_block, 1, parallel_blocks, -1>(ctx, dst, fattn_kernel, nwarps, 0, true, true);
+        launch_fattn<D, cols_per_block, 1, parallel_blocks, -1>(ctx, dst, fattn_kernel, nwarps, 0, true, true, warp_size);
         return;
     }
     if (2*blocks_num_pb1 < 2*nsm) {
@@ -500,7 +501,7 @@ void ggml_cuda_flash_attn_ext_wmma_f16_case(ggml_backend_cuda_context & ctx, ggm
             fattn_kernel = flash_attn_ext_f16<
                 D, cols_per_block, nwarps, get_VKQ_stride(D, nwarps, frag_m), parallel_blocks, KQ_acc_t, use_logit_softcap>;
         }
-        launch_fattn<D, cols_per_block, 1, parallel_blocks, -1>(ctx, dst, fattn_kernel, nwarps, 0, true, true);
+        launch_fattn<D, cols_per_block, 1, parallel_blocks, -1>(ctx, dst, fattn_kernel, nwarps, 0, true, true, warp_size);
         return;
     }
     constexpr int parallel_blocks = 1;
@@ -514,7 +515,7 @@ void ggml_cuda_flash_attn_ext_wmma_f16_case(ggml_backend_cuda_context & ctx, ggm
         fattn_kernel = flash_attn_ext_f16<
             D, cols_per_block, nwarps, get_VKQ_stride(D, nwarps, frag_m), parallel_blocks, KQ_acc_t, use_logit_softcap>;
     }
-    launch_fattn<D, cols_per_block, 1, parallel_blocks, -1>(ctx, dst, fattn_kernel, nwarps, 0, true, true);
+    launch_fattn<D, cols_per_block, 1, parallel_blocks, -1>(ctx, dst, fattn_kernel, nwarps, 0, true, true, warp_size);
 }

 void ggml_cuda_flash_attn_ext_wmma_f16(ggml_backend_cuda_context & ctx, ggml_tensor * dst) {
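
For reference, a hedged comparison of the two call-site families after this change; the launch_fattn signature and the wmma call are taken from the hunks above, while the vec-side call is only illustrative since those launchers are not part of this diff:

// fattn-vec-* launchers omit the new trailing argument and therefore pick up the
// warp32 default (const int warp_size = WARP_SIZE):
//     launch_fattn<D, cols_per_block, 1, parallel_blocks, KQ_stride>(ctx, dst, fattn_kernel, nwarps, nbytes_shared, need_f16_K, need_f16_V);
// fattn-wmma-f16 queries the device and forwards its real warp width explicitly:
//     const int warp_size = ggml_cuda_info().devices[ggml_cuda_get_device()].warp_size;
//     launch_fattn<D, cols_per_block, 1, parallel_blocks, -1>(ctx, dst, fattn_kernel, nwarps, 0, true, true, warp_size);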
