@@ -47,36 +47,110 @@ static constexpr __device__ int get_vdr_mmvq(ggml_type type) {
47
47
1 ;
48
48
}
49
49
50
// Selects which launch-parameter table (nwarps / rows per block) applies to
// the current GPU architecture; see calc_nwarps() and calc_rows_per_block().
enum mmvq_parameter_table_id {
    MMVQ_PARAMETERS_GENERIC = 0, // NVIDIA and any architecture not listed below
    MMVQ_PARAMETERS_GCN,         // AMD GCN / CDNA
    MMVQ_PARAMETERS_RDNA2        // AMD RDNA2 / RDNA3
};
55
+
56
// Compile-time parameter-table selection for device code. The architecture
// macros (RDNA2/RDNA3/GCN/CDNA) are presumably set by the HIP compilation
// path for the target GPU — NOTE(review): confirm they are mutually
// exclusive per compilation unit.
static constexpr __device__ mmvq_parameter_table_id get_device_table_id() {
#if defined(RDNA2) || defined(RDNA3)
    return MMVQ_PARAMETERS_RDNA2;
#elif defined(GCN) || defined(CDNA)
    return MMVQ_PARAMETERS_GCN;
#else
    return MMVQ_PARAMETERS_GENERIC;
#endif // defined(RDNA2) || defined(RDNA3)
}
65
+
66
// Host-side parameter-table selection from a device compute capability.
// Mirrors the device overload above, but decided at runtime from `cc`.
static __host__ mmvq_parameter_table_id get_device_table_id(int cc) {
    const bool is_rdna2_or_3 = GGML_CUDA_CC_IS_RDNA2(cc) || GGML_CUDA_CC_IS_RDNA3(cc);
    const bool is_gcn_or_cdna = GGML_CUDA_CC_IS_GCN(cc) || GGML_CUDA_CC_IS_CDNA(cc);
    return is_rdna2_or_3  ? MMVQ_PARAMETERS_RDNA2
         : is_gcn_or_cdna ? MMVQ_PARAMETERS_GCN
                          : MMVQ_PARAMETERS_GENERIC;
}
75
+
76
// Number of warps per thread block for a given batch size (ncols_y) and
// architecture table. Used both for __launch_bounds__ on the kernel and for
// the block dimensions in calc_launch_params(), so it must be constexpr and
// usable from host and device.
static constexpr __host__ __device__ int calc_nwarps(int ncols_y, mmvq_parameter_table_id table_id) {
    switch (table_id) {
        case MMVQ_PARAMETERS_GENERIC:
            // 1..4 columns -> 4 warps, 5..8 -> 2 warps, anything else -> 1.
            return (ncols_y >= 1 && ncols_y <= 4) ? 4
                 : (ncols_y >= 5 && ncols_y <= 8) ? 2
                                                  : 1;
        case MMVQ_PARAMETERS_GCN:
            // 1..4 columns -> 2 warps, anything else -> 1.
            return (ncols_y >= 1 && ncols_y <= 4) ? 2 : 1;
        default:
            // RDNA2/RDNA3: always a single warp (wavefront).
            return 1;
    }
}
109
+
110
// Number of dst rows processed per thread block for a given batch size
// (ncols_y) and architecture table. Must stay in sync with calc_nwarps():
// together they define the launch configuration in calc_launch_params().
//
// Fix: the table id parameter was declared as plain `int`, inconsistent with
// calc_nwarps() and the enum-typed callers, and allowing any integer to be
// passed silently. It now takes mmvq_parameter_table_id directly; all
// call sites already pass the enum, so this is backward compatible.
static constexpr __host__ __device__ int calc_rows_per_block(int ncols_y, mmvq_parameter_table_id table_id) {
    if (table_id == MMVQ_PARAMETERS_GENERIC || table_id == MMVQ_PARAMETERS_GCN) {
        switch (ncols_y) {
            case 1:
                return 1;
            case 2:
            case 3:
            case 4:
            case 5:
            case 6:
            case 7:
            case 8:
                return 2;
            default:
                return 1;
        }
    }
    // RDNA2/RDNA3 (and any future table): one row per block.
    return 1;
}
129
+
50
130
template <ggml_type type, int ncols_y>
51
- #if !(defined(GGML_USE_HIP) && defined(__HIP_PLATFORM_AMD__))
52
131
// tell the compiler to use as many registers as it wants, see nwarps definition below
53
- __launch_bounds__ ((ncols_y <= 4 ? 4 : 2 )*WARP_SIZE, 1)
54
- #endif // !(defined(GGML_USE_HIP) && defined(__HIP_PLATFORM_AMD__))
132
+ __launch_bounds__ (calc_nwarps(ncols_y, get_device_table_id())*ggml_cuda_get_physical_warp_size(), 1)
55
133
static __global__ void mul_mat_vec_q(
56
134
const void * __restrict__ vx, const void * __restrict__ vy, float * __restrict__ dst,
57
135
const int ncols_x, const int nrows_x, const int nrows_y, const int nrows_dst) {
58
136
59
137
constexpr int qk = ggml_cuda_type_traits<type>::qk;
60
138
constexpr int qi = ggml_cuda_type_traits<type>::qi;
61
139
constexpr int vdr = get_vdr_mmvq (type);
140
+ constexpr mmvq_parameter_table_id table_id = get_device_table_id ();
141
+ constexpr int nwarps = calc_nwarps (ncols_y, table_id);
142
+ constexpr int rows_per_cuda_block = calc_rows_per_block (ncols_y, table_id);
143
+ constexpr int warp_size = ggml_cuda_get_physical_warp_size ();
62
144
63
145
constexpr vec_dot_q_cuda_t vec_dot_q_cuda = get_vec_dot_q_cuda (type);
64
146
65
- #if defined(GGML_USE_HIP) && defined(__HIP_PLATFORM_AMD__) && (defined(RDNA2) || defined(RDNA3))
66
- constexpr int nwarps = 1 ;
67
- constexpr int rows_per_cuda_block = 1 ;
68
- #else
69
- constexpr int nwarps = ncols_y <= 4 ? 4 : 2 ;
70
- constexpr int rows_per_cuda_block = ncols_y == 1 ? 1 : 2 ;
71
- #endif // defined(GGML_USE_HIP) && defined(__HIP_PLATFORM_AMD__) && !defined(RDNA2) && !defined(RDNA3)
72
-
73
- const int tid = WARP_SIZE*threadIdx .y + threadIdx .x ;
147
+ const int tid = warp_size*threadIdx .y + threadIdx .x ;
74
148
const int row0 = rows_per_cuda_block*blockIdx .x ;
75
149
const int blocks_per_row_x = ncols_x / qk;
76
150
const int blocks_per_col_y = nrows_y / QK8_1;
77
- constexpr int blocks_per_iter = vdr * nwarps*WARP_SIZE / qi;
151
+ constexpr int blocks_per_iter = vdr * nwarps*warp_size / qi;
78
152
79
- // partial sum for each thread
153
+ // partial sum for each thread
80
154
float tmp[ncols_y][rows_per_cuda_block] = {0 .0f };
81
155
82
156
const block_q8_1 * y = (const block_q8_1 *) vy;
@@ -96,7 +170,7 @@ static __global__ void mul_mat_vec_q(
96
170
}
97
171
}
98
172
99
- __shared__ float tmp_shared[nwarps-1 > 0 ? nwarps-1 : 1 ][ncols_y][rows_per_cuda_block][WARP_SIZE ];
173
+ __shared__ float tmp_shared[nwarps-1 > 0 ? nwarps-1 : 1 ][ncols_y][rows_per_cuda_block][warp_size ];
100
174
if (threadIdx .y > 0 ) {
101
175
#pragma unroll
102
176
for (int j = 0 ; j < ncols_y; ++j) {
@@ -120,7 +194,7 @@ static __global__ void mul_mat_vec_q(
120
194
for (int l = 0 ; l < nwarps-1 ; ++l) {
121
195
tmp[j][i] += tmp_shared[l][j][i][threadIdx .x ];
122
196
}
123
- tmp[j][i] = warp_reduce_sum (tmp[j][i]);
197
+ tmp[j][i] = warp_reduce_sum<warp_size> (tmp[j][i]);
124
198
}
125
199
126
200
if (threadIdx .x < rows_per_cuda_block && (rows_per_cuda_block == 1 || row0 + threadIdx .x < nrows_dst)) {
@@ -129,6 +203,13 @@ static __global__ void mul_mat_vec_q(
129
203
}
130
204
}
131
205
206
// Compute the kernel launch configuration for mul_mat_vec_q:
// returns {grid dims, block dims}. The grid covers nrows_x in chunks of
// calc_rows_per_block() rows; the block is warp_size wide by calc_nwarps() deep.
static std::pair<dim3, dim3> calc_launch_params(const int ncols_y, const int nrows_x, const int warp_size, const mmvq_parameter_table_id table_id) {
    const int     rows_per_block = calc_rows_per_block(ncols_y, table_id);
    const int64_t nblocks        = (nrows_x + rows_per_block - 1) / rows_per_block; // ceil-div
    const dim3 block_nums(nblocks, 1, 1);
    const dim3 block_dims(warp_size, calc_nwarps(ncols_y, table_id), 1);
    return {block_nums, block_dims};
}
212
+
132
213
template <ggml_type type>
133
214
static void mul_mat_vec_q_cuda (
134
215
const void * vx, const void * vy, float * dst,
@@ -137,65 +218,67 @@ static void mul_mat_vec_q_cuda(
137
218
GGML_ASSERT (ncols_x % ggml_blck_size (type) == 0 );
138
219
GGML_ASSERT (ncols_y <= MMVQ_MAX_BATCH_SIZE);
139
220
140
- int id = ggml_cuda_get_device ();
141
-
142
- int64_t nwarps = 1 ;
143
- int64_t rows_per_cuda_block = 1 ;
144
-
145
- if (ggml_cuda_info ().devices [id].cc < GGML_CUDA_CC_RDNA2) { // NVIDIA and AMD older than RDNA2
146
- switch (ncols_y) {
147
- case 1 :
148
- nwarps = 4 ;
149
- rows_per_cuda_block = 1 ;
150
- break ;
151
- case 2 :
152
- case 3 :
153
- case 4 :
154
- nwarps = 4 ;
155
- rows_per_cuda_block = 2 ;
156
- break ;
157
- case 5 :
158
- case 6 :
159
- case 7 :
160
- case 8 :
161
- nwarps = 2 ;
162
- rows_per_cuda_block = 2 ;
163
- break ;
164
- default :
165
- GGML_ABORT (" fatal error" );
166
- break ;
167
- }
168
- }
169
-
170
- const int64_t nblocks = (nrows_x + rows_per_cuda_block - 1 ) / rows_per_cuda_block;
171
- const dim3 block_nums (nblocks, 1 , 1 );
172
- const dim3 block_dims (WARP_SIZE, nwarps, 1 );
221
+ const int device = ggml_cuda_get_device ();
222
+ const int warp_size = ggml_cuda_info ().devices [device].warp_size ;
223
+ const mmvq_parameter_table_id table_id = get_device_table_id (ggml_cuda_info ().devices [device].cc );
173
224
174
225
switch (ncols_y) {
175
226
case 1 :
176
- mul_mat_vec_q<type, 1 ><<<block_nums, block_dims, 0 , stream>>> (vx, vy, dst, ncols_x, nrows_x, nrows_y, nrows_dst);
227
+ {
228
+ constexpr int c_ncols_y = 1 ;
229
+ std::pair<dim3 , dim3 > dims = calc_launch_params (c_ncols_y, nrows_x, warp_size, table_id);
230
+ mul_mat_vec_q<type, c_ncols_y><<<dims.first, dims.second, 0 , stream>>> (vx, vy, dst, ncols_x, nrows_x, nrows_y, nrows_dst);
177
231
break ;
232
+ }
178
233
case 2 :
179
- mul_mat_vec_q<type, 2 ><<<block_nums, block_dims, 0 , stream>>> (vx, vy, dst, ncols_x, nrows_x, nrows_y, nrows_dst);
234
+ {
235
+ constexpr int c_ncols_y = 2 ;
236
+ std::pair<dim3 , dim3 > dims = calc_launch_params (c_ncols_y, nrows_x, warp_size, table_id);
237
+ mul_mat_vec_q<type, c_ncols_y><<<dims.first, dims.second, 0 , stream>>> (vx, vy, dst, ncols_x, nrows_x, nrows_y, nrows_dst);
180
238
break ;
239
+ }
181
240
case 3 :
182
- mul_mat_vec_q<type, 3 ><<<block_nums, block_dims, 0 , stream>>> (vx, vy, dst, ncols_x, nrows_x, nrows_y, nrows_dst);
241
+ {
242
+ constexpr int c_ncols_y = 3 ;
243
+ std::pair<dim3 , dim3 > dims = calc_launch_params (c_ncols_y, nrows_x, warp_size, table_id);
244
+ mul_mat_vec_q<type, c_ncols_y><<<dims.first, dims.second, 0 , stream>>> (vx, vy, dst, ncols_x, nrows_x, nrows_y, nrows_dst);
183
245
break ;
246
+ }
184
247
case 4 :
185
- mul_mat_vec_q<type, 4 ><<<block_nums, block_dims, 0 , stream>>> (vx, vy, dst, ncols_x, nrows_x, nrows_y, nrows_dst);
248
+ {
249
+ constexpr int c_ncols_y = 4 ;
250
+ std::pair<dim3 , dim3 > dims = calc_launch_params (c_ncols_y, nrows_x, warp_size, table_id);
251
+ mul_mat_vec_q<type, c_ncols_y><<<dims.first, dims.second, 0 , stream>>> (vx, vy, dst, ncols_x, nrows_x, nrows_y, nrows_dst);
186
252
break ;
253
+ }
187
254
case 5 :
188
- mul_mat_vec_q<type, 5 ><<<block_nums, block_dims, 0 , stream>>> (vx, vy, dst, ncols_x, nrows_x, nrows_y, nrows_dst);
255
+ {
256
+ constexpr int c_ncols_y = 5 ;
257
+ std::pair<dim3 , dim3 > dims = calc_launch_params (c_ncols_y, nrows_x, warp_size, table_id);
258
+ mul_mat_vec_q<type, c_ncols_y><<<dims.first, dims.second, 0 , stream>>> (vx, vy, dst, ncols_x, nrows_x, nrows_y, nrows_dst);
189
259
break ;
260
+ }
190
261
case 6 :
191
- mul_mat_vec_q<type, 6 ><<<block_nums, block_dims, 0 , stream>>> (vx, vy, dst, ncols_x, nrows_x, nrows_y, nrows_dst);
262
+ {
263
+ constexpr int c_ncols_y = 6 ;
264
+ std::pair<dim3 , dim3 > dims = calc_launch_params (c_ncols_y, nrows_x, warp_size, table_id);
265
+ mul_mat_vec_q<type, c_ncols_y><<<dims.first, dims.second, 0 , stream>>> (vx, vy, dst, ncols_x, nrows_x, nrows_y, nrows_dst);
192
266
break ;
267
+ }
193
268
case 7 :
194
- mul_mat_vec_q<type, 7 ><<<block_nums, block_dims, 0 , stream>>> (vx, vy, dst, ncols_x, nrows_x, nrows_y, nrows_dst);
269
+ {
270
+ constexpr int c_ncols_y = 7 ;
271
+ std::pair<dim3 , dim3 > dims = calc_launch_params (c_ncols_y, nrows_x, warp_size, table_id);
272
+ mul_mat_vec_q<type, c_ncols_y><<<dims.first, dims.second, 0 , stream>>> (vx, vy, dst, ncols_x, nrows_x, nrows_y, nrows_dst);
195
273
break ;
274
+ }
196
275
case 8 :
197
- mul_mat_vec_q<type, 8 ><<<block_nums, block_dims, 0 , stream>>> (vx, vy, dst, ncols_x, nrows_x, nrows_y, nrows_dst);
276
+ {
277
+ constexpr int c_ncols_y = 8 ;
278
+ std::pair<dim3 , dim3 > dims = calc_launch_params (c_ncols_y, nrows_x, warp_size, table_id);
279
+ mul_mat_vec_q<type, c_ncols_y><<<dims.first, dims.second, 0 , stream>>> (vx, vy, dst, ncols_x, nrows_x, nrows_y, nrows_dst);
198
280
break ;
281
+ }
199
282
default :
200
283
GGML_ABORT (" fatal error" );
201
284
break ;
0 commit comments