
Commit 10f2e81

CUDA/HIP: refactor mmvq to unify the calculation of nwarps and rows per block between host and device code. (#12177)
refactor mmvq to unify the calculation of nwarps and rows per block between host and device code. --------- Co-authored-by: Johannes Gäßler <[email protected]>
1 parent ba76543 commit 10f2e81
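The core of the refactor is that the block geometry is produced by constexpr functions qualified as both __host__ and __device__, so a single definition feeds __launch_bounds__, the constants inside the kernel, and the host-side launch configuration. A minimal hedged sketch of that pattern, using hypothetical names (pick_nwarps, my_kernel, launch) rather than anything taken from this commit:

// One constexpr table shared by host and device code (hypothetical example).
static constexpr __host__ __device__ int pick_nwarps(int ncols) {
    return ncols <= 4 ? 4 : 2;                    // same arithmetic on both sides by construction
}

template <int ncols>
__launch_bounds__(pick_nwarps(ncols)*32, 1)       // device: occupancy bound derived from the shared table
static __global__ void my_kernel(int * out) {
    constexpr int nwarps = pick_nwarps(ncols);    // device: identical value inside the kernel
    if (blockIdx.x == 0 && threadIdx.x == 0 && threadIdx.y == 0) {
        *out = nwarps;
    }
}

static void launch(int * out, cudaStream_t stream) {
    const dim3 block_dims(32, pick_nwarps(4), 1); // host: same value for the launch configuration
    my_kernel<4><<<1, block_dims, 0, stream>>>(out);
}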

File tree

2 files changed: +142 −59 lines changed


ggml/src/ggml-cuda/common.cuh

+2 −2
@@ -395,11 +395,11 @@ static __device__ __forceinline__ uint32_t __hgt2_mask(const half2 a, const half
 
 static __device__ __forceinline__ int ggml_cuda_dp4a(const int a, const int b, int c) {
 #if defined(GGML_USE_HIP) && defined(__HIP_PLATFORM_AMD__)
-#if defined(__gfx906__) || defined(__gfx908__) || defined(__gfx90a__) || defined(RDNA2)
+#if defined(CDNA) || defined(RDNA2) || defined(__gfx906__)
     c = __builtin_amdgcn_sdot4(a, b, c, false);
 #elif defined(RDNA3)
     c = __builtin_amdgcn_sudot4( true, a, true, b, c, false);
-#elif defined(__gfx1010__) || defined(__gfx900__)
+#elif defined(RDNA1) || defined(__gfx900__)
     int tmp1;
     int tmp2;
     asm("\n \
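The change above only alters which AMD targets are routed to the native dot-product builtins; the operation itself is the usual dp4a: a dot product of four packed signed 8-bit values from each 32-bit operand, accumulated into a 32-bit integer. A scalar reference, written here as a hedged illustration rather than code taken from this commit:

// Scalar reference for a dp4a-style operation: interpret each 32-bit operand as
// four signed 8-bit lanes, multiply lane-wise, and accumulate into c.
static __host__ __device__ int dp4a_reference(const int a, const int b, int c) {
    const int8_t * a8 = (const int8_t *) &a;
    const int8_t * b8 = (const int8_t *) &b;
    for (int i = 0; i < 4; ++i) {
        c += a8[i] * b8[i];
    }
    return c;
}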

ggml/src/ggml-cuda/mmvq.cu

+140 −57
@@ -47,36 +47,110 @@ static constexpr __device__ int get_vdr_mmvq(ggml_type type) {
         1;
 }
 
+enum mmvq_parameter_table_id {
+    MMVQ_PARAMETERS_GENERIC = 0,
+    MMVQ_PARAMETERS_GCN,
+    MMVQ_PARAMETERS_RDNA2
+};
+
+static constexpr __device__ mmvq_parameter_table_id get_device_table_id() {
+#if defined(RDNA2) || defined(RDNA3)
+    return MMVQ_PARAMETERS_RDNA2;
+#elif defined(GCN) || defined(CDNA)
+    return MMVQ_PARAMETERS_GCN;
+#else
+    return MMVQ_PARAMETERS_GENERIC;
+#endif
+}
+
+static __host__ mmvq_parameter_table_id get_device_table_id(int cc) {
+    if (GGML_CUDA_CC_IS_RDNA2(cc) || GGML_CUDA_CC_IS_RDNA3(cc)) {
+        return MMVQ_PARAMETERS_RDNA2;
+    }
+    if (GGML_CUDA_CC_IS_GCN(cc) || GGML_CUDA_CC_IS_CDNA(cc)) {
+        return MMVQ_PARAMETERS_GCN;
+    }
+    return MMVQ_PARAMETERS_GENERIC;
+}
+
+static constexpr __host__ __device__ int calc_nwarps(int ncols_y, mmvq_parameter_table_id table_id) {
+    if (table_id == MMVQ_PARAMETERS_GENERIC) {
+        switch (ncols_y) {
+            case 1:
+            case 2:
+            case 3:
+            case 4:
+                return 4;
+            case 5:
+            case 6:
+            case 7:
+            case 8:
+                return 2;
+            default:
+                return 1;
+        }
+    } else if (table_id == MMVQ_PARAMETERS_GCN) {
+        switch (ncols_y) {
+            case 1:
+            case 2:
+            case 3:
+            case 4:
+                return 2;
+            case 5:
+            case 6:
+            case 7:
+            case 8:
+            default:
+                return 1;
+        }
+    }
+    return 1;
+}
+
+static constexpr __host__ __device__ int calc_rows_per_block(int ncols_y, int table_id) {
+    if (table_id == MMVQ_PARAMETERS_GENERIC || table_id == MMVQ_PARAMETERS_GCN) {
+        switch (ncols_y) {
+            case 1:
+                return 1;
+            case 2:
+            case 3:
+            case 4:
+            case 5:
+            case 6:
+            case 7:
+            case 8:
+                return 2;
+            default:
+                return 1;
+        }
+    }
+    return 1;
+}
+
 template <ggml_type type, int ncols_y>
-#if !(defined(GGML_USE_HIP) && defined(__HIP_PLATFORM_AMD__))
 // tell the compiler to use as many registers as it wants, see nwarps definition below
-__launch_bounds__((ncols_y <= 4 ? 4 : 2)*WARP_SIZE, 1)
-#endif // !(defined(GGML_USE_HIP) && defined(__HIP_PLATFORM_AMD__))
+__launch_bounds__(calc_nwarps(ncols_y, get_device_table_id())*ggml_cuda_get_physical_warp_size(), 1)
 static __global__ void mul_mat_vec_q(
     const void * __restrict__ vx, const void * __restrict__ vy, float * __restrict__ dst,
     const int ncols_x, const int nrows_x, const int nrows_y, const int nrows_dst) {
 
     constexpr int qk  = ggml_cuda_type_traits<type>::qk;
     constexpr int qi  = ggml_cuda_type_traits<type>::qi;
     constexpr int vdr = get_vdr_mmvq(type);
+    constexpr mmvq_parameter_table_id table_id = get_device_table_id();
+    constexpr int nwarps = calc_nwarps(ncols_y, table_id);
+    constexpr int rows_per_cuda_block = calc_rows_per_block(ncols_y, table_id);
+    constexpr int warp_size = ggml_cuda_get_physical_warp_size();
 
     constexpr vec_dot_q_cuda_t vec_dot_q_cuda = get_vec_dot_q_cuda(type);
 
-#if defined(GGML_USE_HIP) && defined(__HIP_PLATFORM_AMD__) && (defined(RDNA2) || defined(RDNA3))
-    constexpr int nwarps = 1;
-    constexpr int rows_per_cuda_block = 1;
-#else
-    constexpr int nwarps = ncols_y <= 4 ? 4 : 2;
-    constexpr int rows_per_cuda_block = ncols_y == 1 ? 1 : 2;
-#endif // defined(GGML_USE_HIP) && defined(__HIP_PLATFORM_AMD__) && !defined(RDNA2) && !defined(RDNA3)
-
-    const int tid = WARP_SIZE*threadIdx.y + threadIdx.x;
+    const int tid = warp_size*threadIdx.y + threadIdx.x;
     const int row0 = rows_per_cuda_block*blockIdx.x;
     const int blocks_per_row_x = ncols_x / qk;
     const int blocks_per_col_y = nrows_y / QK8_1;
-    constexpr int blocks_per_iter = vdr * nwarps*WARP_SIZE / qi;
+    constexpr int blocks_per_iter = vdr * nwarps*warp_size / qi;
 
-// partial sum for each thread
+    // partial sum for each thread
     float tmp[ncols_y][rows_per_cuda_block] = {0.0f};
 
     const block_q8_1 * y = (const block_q8_1 *) vy;
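The two tables above reproduce the launch parameters that were previously hard-coded on the host (and approximated in __launch_bounds__), now shared between host and device. Reading them off for the common configurations, followed by a hedged compile-time sanity check one could add locally (not part of the patch):

// Generic (non-AMD) table with warp_size = 32:
//   ncols_y = 1    -> nwarps = 4, rows_per_cuda_block = 1  (128 threads, 1 output row per block)
//   ncols_y = 2..4 -> nwarps = 4, rows_per_cuda_block = 2  (128 threads, 2 output rows per block)
//   ncols_y = 5..8 -> nwarps = 2, rows_per_cuda_block = 2  ( 64 threads, 2 output rows per block)
// GCN/CDNA halve the warp count (their warps are 64 lanes wide); RDNA2/RDNA3 use 1 warp and 1 row per block.
static_assert(calc_nwarps(1, MMVQ_PARAMETERS_GENERIC) == 4, "generic table: 4 warps when ncols_y == 1");
static_assert(calc_rows_per_block(8, MMVQ_PARAMETERS_RDNA2) == 1, "RDNA2 table: 1 row per block");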
@@ -96,7 +170,7 @@ static __global__ void mul_mat_vec_q(
         }
     }
 
-    __shared__ float tmp_shared[nwarps-1 > 0 ? nwarps-1 : 1][ncols_y][rows_per_cuda_block][WARP_SIZE];
+    __shared__ float tmp_shared[nwarps-1 > 0 ? nwarps-1 : 1][ncols_y][rows_per_cuda_block][warp_size];
     if (threadIdx.y > 0) {
 #pragma unroll
         for (int j = 0; j < ncols_y; ++j) {
@@ -120,7 +194,7 @@ static __global__ void mul_mat_vec_q(
             for (int l = 0; l < nwarps-1; ++l) {
                 tmp[j][i] += tmp_shared[l][j][i][threadIdx.x];
             }
-            tmp[j][i] = warp_reduce_sum(tmp[j][i]);
+            tmp[j][i] = warp_reduce_sum<warp_size>(tmp[j][i]);
         }
 
     if (threadIdx.x < rows_per_cuda_block && (rows_per_cuda_block == 1 || row0 + threadIdx.x < nrows_dst)) {
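warp_reduce_sum is now instantiated with the physical warp size, so the final reduction spans all 64 lanes on Wave64 AMD hardware instead of a hard-coded 32. The helper itself lives in common.cuh; as a rough, hedged sketch of what a width-parameterized shuffle reduction generally looks like (illustrative only, written against the CUDA __shfl_xor_sync intrinsic, not the repository's exact implementation):

// Butterfly (XOR) reduction: after log2(width) steps every one of the `width`
// participating lanes holds the sum of all their values.
template <int width>
static __device__ __forceinline__ float warp_reduce_sum_sketch(float x) {
#pragma unroll
    for (int offset = width/2; offset > 0; offset >>= 1) {
        x += __shfl_xor_sync(0xffffffff, x, offset, width);
    }
    return x;
}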
@@ -129,6 +203,13 @@ static __global__ void mul_mat_vec_q(
     }
 }
 
+static std::pair<dim3, dim3> calc_launch_params(const int ncols_y, const int nrows_x, const int warp_size, const mmvq_parameter_table_id table_id) {
+    const int64_t nblocks = (nrows_x + calc_rows_per_block(ncols_y, table_id) - 1) / calc_rows_per_block(ncols_y, table_id);
+    const dim3 block_nums(nblocks, 1, 1);
+    const dim3 block_dims(warp_size, calc_nwarps(ncols_y, table_id), 1);
+    return {block_nums, block_dims};
+}
+
 template <ggml_type type>
 static void mul_mat_vec_q_cuda(
     const void * vx, const void * vy, float * dst,
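calc_launch_params centralizes the grid/block computation that the old host code spelled out inline, so the host launches exactly the geometry described by the shared tables. A small hedged usage sketch; the numbers are assumed inputs for illustration, not values from the commit:

// Worked example: nrows_x = 4096, ncols_y = 1, warp_size = 32, generic table.
//   rows_per_block = calc_rows_per_block(1, MMVQ_PARAMETERS_GENERIC) = 1
//   nwarps         = calc_nwarps(1, MMVQ_PARAMETERS_GENERIC)         = 4
//   nblocks        = (4096 + 1 - 1) / 1                              = 4096
//   block_nums     = dim3(4096, 1, 1)   -- one thread block per output row
//   block_dims     = dim3(32, 4, 1)     -- 4 warps of 32 threads each
std::pair<dim3, dim3> dims = calc_launch_params(/*ncols_y=*/1, /*nrows_x=*/4096,
                                                /*warp_size=*/32, MMVQ_PARAMETERS_GENERIC);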
@@ -137,65 +218,67 @@ static void mul_mat_vec_q_cuda(
     GGML_ASSERT(ncols_x % ggml_blck_size(type) == 0);
     GGML_ASSERT(ncols_y <= MMVQ_MAX_BATCH_SIZE);
 
-    int id = ggml_cuda_get_device();
-
-    int64_t nwarps = 1;
-    int64_t rows_per_cuda_block = 1;
-
-    if (ggml_cuda_info().devices[id].cc < GGML_CUDA_CC_RDNA2) { // NVIDIA and AMD older than RDNA2
-        switch(ncols_y) {
-            case 1:
-                nwarps = 4;
-                rows_per_cuda_block = 1;
-                break;
-            case 2:
-            case 3:
-            case 4:
-                nwarps = 4;
-                rows_per_cuda_block = 2;
-                break;
-            case 5:
-            case 6:
-            case 7:
-            case 8:
-                nwarps = 2;
-                rows_per_cuda_block = 2;
-                break;
-            default:
-                GGML_ABORT("fatal error");
-                break;
-        }
-    }
-
-    const int64_t nblocks = (nrows_x + rows_per_cuda_block - 1) / rows_per_cuda_block;
-    const dim3 block_nums(nblocks, 1, 1);
-    const dim3 block_dims(WARP_SIZE, nwarps, 1);
+    const int device = ggml_cuda_get_device();
+    const int warp_size = ggml_cuda_info().devices[device].warp_size;
+    const mmvq_parameter_table_id table_id = get_device_table_id(ggml_cuda_info().devices[device].cc);
 
     switch (ncols_y) {
         case 1:
-            mul_mat_vec_q<type, 1><<<block_nums, block_dims, 0, stream>>>(vx, vy, dst, ncols_x, nrows_x, nrows_y, nrows_dst);
+        {
+            constexpr int c_ncols_y = 1;
+            std::pair<dim3, dim3> dims = calc_launch_params(c_ncols_y, nrows_x, warp_size, table_id);
+            mul_mat_vec_q<type, c_ncols_y><<<dims.first, dims.second, 0, stream>>>(vx, vy, dst, ncols_x, nrows_x, nrows_y, nrows_dst);
             break;
+        }
         case 2:
-            mul_mat_vec_q<type, 2><<<block_nums, block_dims, 0, stream>>>(vx, vy, dst, ncols_x, nrows_x, nrows_y, nrows_dst);
+        {
+            constexpr int c_ncols_y = 2;
+            std::pair<dim3, dim3> dims = calc_launch_params(c_ncols_y, nrows_x, warp_size, table_id);
+            mul_mat_vec_q<type, c_ncols_y><<<dims.first, dims.second, 0, stream>>>(vx, vy, dst, ncols_x, nrows_x, nrows_y, nrows_dst);
             break;
+        }
         case 3:
-            mul_mat_vec_q<type, 3><<<block_nums, block_dims, 0, stream>>>(vx, vy, dst, ncols_x, nrows_x, nrows_y, nrows_dst);
+        {
+            constexpr int c_ncols_y = 3;
+            std::pair<dim3, dim3> dims = calc_launch_params(c_ncols_y, nrows_x, warp_size, table_id);
+            mul_mat_vec_q<type, c_ncols_y><<<dims.first, dims.second, 0, stream>>>(vx, vy, dst, ncols_x, nrows_x, nrows_y, nrows_dst);
             break;
+        }
         case 4:
-            mul_mat_vec_q<type, 4><<<block_nums, block_dims, 0, stream>>>(vx, vy, dst, ncols_x, nrows_x, nrows_y, nrows_dst);
+        {
+            constexpr int c_ncols_y = 4;
+            std::pair<dim3, dim3> dims = calc_launch_params(c_ncols_y, nrows_x, warp_size, table_id);
+            mul_mat_vec_q<type, c_ncols_y><<<dims.first, dims.second, 0, stream>>>(vx, vy, dst, ncols_x, nrows_x, nrows_y, nrows_dst);
            break;
+        }
         case 5:
-            mul_mat_vec_q<type, 5><<<block_nums, block_dims, 0, stream>>>(vx, vy, dst, ncols_x, nrows_x, nrows_y, nrows_dst);
+        {
+            constexpr int c_ncols_y = 5;
+            std::pair<dim3, dim3> dims = calc_launch_params(c_ncols_y, nrows_x, warp_size, table_id);
+            mul_mat_vec_q<type, c_ncols_y><<<dims.first, dims.second, 0, stream>>>(vx, vy, dst, ncols_x, nrows_x, nrows_y, nrows_dst);
             break;
+        }
         case 6:
-            mul_mat_vec_q<type, 6><<<block_nums, block_dims, 0, stream>>>(vx, vy, dst, ncols_x, nrows_x, nrows_y, nrows_dst);
+        {
+            constexpr int c_ncols_y = 6;
+            std::pair<dim3, dim3> dims = calc_launch_params(c_ncols_y, nrows_x, warp_size, table_id);
+            mul_mat_vec_q<type, c_ncols_y><<<dims.first, dims.second, 0, stream>>>(vx, vy, dst, ncols_x, nrows_x, nrows_y, nrows_dst);
             break;
+        }
         case 7:
-            mul_mat_vec_q<type, 7><<<block_nums, block_dims, 0, stream>>>(vx, vy, dst, ncols_x, nrows_x, nrows_y, nrows_dst);
+        {
+            constexpr int c_ncols_y = 7;
+            std::pair<dim3, dim3> dims = calc_launch_params(c_ncols_y, nrows_x, warp_size, table_id);
+            mul_mat_vec_q<type, c_ncols_y><<<dims.first, dims.second, 0, stream>>>(vx, vy, dst, ncols_x, nrows_x, nrows_y, nrows_dst);
             break;
+        }
         case 8:
-            mul_mat_vec_q<type, 8><<<block_nums, block_dims, 0, stream>>>(vx, vy, dst, ncols_x, nrows_x, nrows_y, nrows_dst);
+        {
+            constexpr int c_ncols_y = 8;
+            std::pair<dim3, dim3> dims = calc_launch_params(c_ncols_y, nrows_x, warp_size, table_id);
+            mul_mat_vec_q<type, c_ncols_y><<<dims.first, dims.second, 0, stream>>>(vx, vy, dst, ncols_x, nrows_x, nrows_y, nrows_dst);
             break;
+        }
         default:
             GGML_ABORT("fatal error");
             break;
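Note the design choice in the host dispatcher: each case introduces a constexpr c_ncols_y that is used both as the template argument of mul_mat_vec_q and as the ncols_y passed to calc_launch_params, so the launch geometry chosen at runtime is, by construction, the same one the kernel specialization was compiled and __launch_bounds__-constrained against.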
