@@ -52,12 +52,11 @@ typedef half (*vec_dot_KQ_f16_t)(
 typedef float (*vec_dot_KQ_f32_t)(
     const char * __restrict__ K_c, const void * __restrict__ Q_v, const int * __restrict__ Q_q8, const void * __restrict__ Q_ds);

-template <typename T, int D>
+template <typename T, int D, int warp_size>
 static __device__ __forceinline__ T vec_dot_fattn_vec_KQ_q4_0(
     const char * __restrict__ K_c, const void * __restrict__ Q_v, const int * __restrict__ Q_q8, const void * __restrict__ Q_ds_v) {

     const block_q4_0 * K_q4_0 = (const block_q4_0 *) K_c;
-    constexpr int warp_size = ggml_cuda_get_physical_warp_size();
     GGML_UNUSED(Q_v);

     T sum = 0.0f;
@@ -93,12 +92,11 @@ static __device__ __forceinline__ T vec_dot_fattn_vec_KQ_q4_0(
     return sum;
 }

-template <typename T, int D>
+template <typename T, int D, int warp_size>
 static __device__ __forceinline__ T vec_dot_fattn_vec_KQ_q4_1(
     const char * __restrict__ K_c, const void * __restrict__ Q_v, const int * __restrict__ Q_q8, const void * __restrict__ Q_ds_v) {

     const block_q4_1 * K_q4_1 = (const block_q4_1 *) K_c;
-    constexpr int warp_size = ggml_cuda_get_physical_warp_size();
     GGML_UNUSED(Q_v);

     T sum = 0.0f;
@@ -138,12 +136,11 @@ static __device__ __forceinline__ T vec_dot_fattn_vec_KQ_q4_1(
     return sum;
 }

-template <typename T, int D>
+template <typename T, int D, int warp_size>
 static __device__ __forceinline__ T vec_dot_fattn_vec_KQ_q5_0(
     const char * __restrict__ K_c, const void * __restrict__ Q_v, const int * __restrict__ Q_q8, const void * __restrict__ Q_ds_v) {

     const block_q5_0 * K_q5_0 = (const block_q5_0 *) K_c;
-    constexpr int warp_size = ggml_cuda_get_physical_warp_size();
     GGML_UNUSED(Q_v);

     T sum = 0.0f;
@@ -186,12 +183,11 @@ static __device__ __forceinline__ T vec_dot_fattn_vec_KQ_q5_0(
     return sum;
 }

-template <typename T, int D>
+template <typename T, int D, int warp_size>
 static __device__ __forceinline__ T vec_dot_fattn_vec_KQ_q5_1(
     const char * __restrict__ K_c, const void * __restrict__ Q_v, const int * __restrict__ Q_q8, const void * __restrict__ Q_ds_v) {

     const block_q5_1 * K_q5_1 = (const block_q5_1 *) K_c;
-    constexpr int warp_size = ggml_cuda_get_physical_warp_size();
     GGML_UNUSED(Q_v);

     T sum = 0.0f;
@@ -238,12 +234,11 @@ static __device__ __forceinline__ T vec_dot_fattn_vec_KQ_q5_1(
     return sum;
 }

-template <typename T, int D>
+template <typename T, int D, int warp_size>
 static __device__ __forceinline__ T vec_dot_fattn_vec_KQ_q8_0(
     const char * __restrict__ K_c, const void * __restrict__ Q_v, const int * __restrict__ Q_q8, const void * __restrict__ Q_ds_v) {

     const block_q8_0 * K_q8_0 = (const block_q8_0 *) K_c;
-    constexpr int warp_size = ggml_cuda_get_physical_warp_size();
     GGML_UNUSED(Q_v);

     T sum = 0.0f;
@@ -272,12 +267,11 @@ static __device__ __forceinline__ T vec_dot_fattn_vec_KQ_q8_0(
     return sum;
 }

-template <typename T, int D>
+template <typename T, int D, int warp_size>
 static __device__ __forceinline__ T vec_dot_fattn_vec_KQ_f16(
     const char * __restrict__ K_c, const void * __restrict__ Q_v, const int * __restrict__ Q_q8, const void * __restrict__ Q_ds_v) {

     const half2 * K_h2 = (const half2 *) K_c;
-    constexpr int warp_size = ggml_cuda_get_physical_warp_size();
     GGML_UNUSED(Q_q8);
     GGML_UNUSED(Q_ds_v);

@@ -480,25 +474,25 @@ static __device__ __forceinline__ T dequantize_1_f16(const void * __restrict__ v
     return x[i];
 }

-template <int D>
+template <int D, int warp_size = WARP_SIZE>
 constexpr __device__ vec_dot_KQ_f16_t get_vec_dot_KQ_f16(ggml_type type_K) {
-    return type_K == GGML_TYPE_Q4_0 ? vec_dot_fattn_vec_KQ_q4_0<half, D> :
-        type_K == GGML_TYPE_Q4_1 ? vec_dot_fattn_vec_KQ_q4_1<half, D> :
-        type_K == GGML_TYPE_Q5_0 ? vec_dot_fattn_vec_KQ_q5_0<half, D> :
-        type_K == GGML_TYPE_Q5_1 ? vec_dot_fattn_vec_KQ_q5_1<half, D> :
-        type_K == GGML_TYPE_Q8_0 ? vec_dot_fattn_vec_KQ_q8_0<half, D> :
-        type_K == GGML_TYPE_F16  ? vec_dot_fattn_vec_KQ_f16<half, D> :
+    return type_K == GGML_TYPE_Q4_0 ? vec_dot_fattn_vec_KQ_q4_0<half, D, warp_size> :
+        type_K == GGML_TYPE_Q4_1 ? vec_dot_fattn_vec_KQ_q4_1<half, D, warp_size> :
+        type_K == GGML_TYPE_Q5_0 ? vec_dot_fattn_vec_KQ_q5_0<half, D, warp_size> :
+        type_K == GGML_TYPE_Q5_1 ? vec_dot_fattn_vec_KQ_q5_1<half, D, warp_size> :
+        type_K == GGML_TYPE_Q8_0 ? vec_dot_fattn_vec_KQ_q8_0<half, D, warp_size> :
+        type_K == GGML_TYPE_F16  ? vec_dot_fattn_vec_KQ_f16<half, D, warp_size> :
         nullptr;
 }

-template <int D>
+template <int D, int warp_size = WARP_SIZE>
 constexpr __device__ vec_dot_KQ_f32_t get_vec_dot_KQ_f32(ggml_type type_K) {
-    return type_K == GGML_TYPE_Q4_0 ? vec_dot_fattn_vec_KQ_q4_0<float, D> :
-        type_K == GGML_TYPE_Q4_1 ? vec_dot_fattn_vec_KQ_q4_1<float, D> :
-        type_K == GGML_TYPE_Q5_0 ? vec_dot_fattn_vec_KQ_q5_0<float, D> :
-        type_K == GGML_TYPE_Q5_1 ? vec_dot_fattn_vec_KQ_q5_1<float, D> :
-        type_K == GGML_TYPE_Q8_0 ? vec_dot_fattn_vec_KQ_q8_0<float, D> :
-        type_K == GGML_TYPE_F16  ? vec_dot_fattn_vec_KQ_f16<float, D> :
+    return type_K == GGML_TYPE_Q4_0 ? vec_dot_fattn_vec_KQ_q4_0<float, D, warp_size> :
+        type_K == GGML_TYPE_Q4_1 ? vec_dot_fattn_vec_KQ_q4_1<float, D, warp_size> :
+        type_K == GGML_TYPE_Q5_0 ? vec_dot_fattn_vec_KQ_q5_0<float, D, warp_size> :
+        type_K == GGML_TYPE_Q5_1 ? vec_dot_fattn_vec_KQ_q5_1<float, D, warp_size> :
+        type_K == GGML_TYPE_Q8_0 ? vec_dot_fattn_vec_KQ_q8_0<float, D, warp_size> :
+        type_K == GGML_TYPE_F16  ? vec_dot_fattn_vec_KQ_f16<float, D, warp_size> :
         nullptr;
 }

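With `warp_size` defaulting to `WARP_SIZE`, existing CUDA call sites compile unchanged, while back-ends with a different physical warp width can instantiate the dispatch explicitly. A minimal sketch of how a kernel might consume this — the kernel name is hypothetical, but the pattern of templating on `type_K` so the constexpr ternary chain collapses to a single function pointer mirrors the existing fattn vector kernels:

// Hypothetical kernel, templated on D, warp_size, and the K tensor type.
template <int D, int warp_size, ggml_type type_K>
static __global__ void flash_attn_vec_sketch(const char * __restrict__ K /* ... */) {
    // Resolved entirely at compile time for this instantiation:
    constexpr vec_dot_KQ_f16_t vec_dot_KQ = get_vec_dot_KQ_f16<D, warp_size>(type_K);
    // ... each thread accumulates its K·Q partial sums via vec_dot_KQ ...
}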
@@ -681,7 +675,8 @@ static void on_no_fattn_vec_case(const int D) {
 template <int D, int ncols1, int ncols2, int parallel_blocks, int KQ_stride>
 void launch_fattn(
     ggml_backend_cuda_context & ctx, ggml_tensor * dst, fattn_kernel_t fattn_kernel,
-    const int nwarps, const size_t nbytes_shared, const bool need_f16_K, const bool need_f16_V
+    const int nwarps, const size_t nbytes_shared, const bool need_f16_K, const bool need_f16_V,
+    const int warp_size = WARP_SIZE
 ) {
     constexpr int ncols = ncols1 * ncols2;

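Because `warp_size` is now an ordinary defaulted argument of `launch_fattn`, the per-device query moves out of the function body (see the deletion in the next hunk) and into callers that need a non-default width. A hedged sketch of such a call site, with placeholder template arguments:

// Hypothetical caller: query the device warp size once on the host and
// forward it; CUDA-only callers can simply rely on the WARP_SIZE default.
const int warp_size = ggml_cuda_info().devices[ctx.device].warp_size;
launch_fattn<D, ncols1, ncols2, parallel_blocks, KQ_stride>(
    ctx, dst, fattn_kernel, nwarps, nbytes_shared,
    need_f16_K, need_f16_V, warp_size);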
@@ -704,8 +699,6 @@ void launch_fattn(

     GGML_ASSERT(Q->ne[3] == 1);

-    const int warp_size = ggml_cuda_info().devices[ctx.device].warp_size;
-
     ggml_cuda_pool & pool = ctx.pool();
     cudaStream_t main_stream = ctx.stream();
     const int id = ggml_cuda_get_device();
@@ -805,7 +798,6 @@ void launch_fattn(
     const float m1 = powf(2.0f, -(max_bias / 2.0f) / n_head_log2);

     GGML_ASSERT(block_dim.x % warp_size == 0);
-    GGML_ASSERT(!GGML_CUDA_CC_IS_AMD(cc) || block_dim.x * block_dim.y <= 4 * (unsigned int)warp_size);
     fattn_kernel<<<blocks_num, block_dim, nbytes_shared, main_stream>>>(
         (const char *) Q->data,
         K_data,
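The payoff of making `warp_size` a compile-time constant in the device code is that warp-level loops can be fully unrolled. An illustrative reduction in the spirit of ggml's warp_reduce_sum, written here as a sketch rather than copied from the PR:

// With warp_size known at compile time, #pragma unroll turns this loop
// into a fixed sequence of shuffle instructions (CUDA-style full mask).
template <int warp_size>
static __device__ __forceinline__ float warp_reduce_sum_sketch(float x) {
#pragma unroll
    for (int offset = warp_size/2; offset > 0; offset >>= 1) {
        x += __shfl_xor_sync(0xFFFFFFFF, x, offset, warp_size);
    }
    return x;
}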