Skip to content

Commit 6abd34d

Browse files
committed
matmul fallback kernel
1 parent da0203e commit 6abd34d

File tree

2 files changed

+59
-1
lines changed

2 files changed

+59
-1
lines changed

magnetron/magnetron.c

+1-1
Original file line numberDiff line numberDiff line change
@@ -3998,7 +3998,7 @@ static MAG_COLDPROC void mag_graphviz_dump(const mag_tensor_t* node, FILE *fp, m
39983998
for (unsigned i=0; i < MAG_MAX_INPUT_TENSORS; ++i) {
39993999
mag_tensor_t* input = node->op_inputs[i];
40004000
if (!input) continue;
4001-
char name[64];
4001+
char name[128];
40024002
if (*input->name) snprintf(name, sizeof(name), " in %u (%s)", i, input->name);
40034003
else snprintf(name, sizeof(name), " in %u", i);
40044004
fprintf(fp, " \"%p\" -> \"%p\" [label=\"%s\"];\n", (void*)input, (void*)node, name);

magnetron/magnetron_cpu_blas.inl

+58
Original file line numberDiff line numberDiff line change
@@ -1448,6 +1448,64 @@ static void MAG_HOTPROC mag_blas_matmul_f32(const mag_compute_payload_t* payload
14481448
}
14491449
}
14501450

1451+
#else
1452+
1453+
static void MAG_HOTPROC mag_blas_matmul_f32(const mag_compute_payload_t* payload, mag_kernel_context_t* ctx) {
1454+
mag_tensor_t* r = payload->node;
1455+
const mag_tensor_t* x = r->op_inputs[0];
1456+
const mag_tensor_t* y = r->op_inputs[1];
1457+
mag_f32_t* br = mag_f32p_mut(r);
1458+
const mag_f32_t* bx = mag_f32p(x);
1459+
const mag_f32_t* by = mag_f32p(y);
1460+
mag_load_local_storage_group(r, rd, shape);
1461+
mag_load_local_storage_group(r, rs, strides);
1462+
mag_load_local_storage_group(x, xd, shape);
1463+
mag_load_local_storage_group(x, xs, strides);
1464+
mag_load_local_storage_group(y, yd, shape);
1465+
mag_load_local_storage_group(y, ys, strides);
1466+
int64_t ti = payload->thread_idx;
1467+
if (ti != 0) return;
1468+
mag_assert2(mag_tensor_is_contiguous(x) && mag_tensor_is_contiguous(y) && mag_tensor_is_contiguous(r));
1469+
bool trans_a = mag_tensor_is_transposed(x);
1470+
if (x->op == MAG_OP_CLONE && x->op_inputs[0]) trans_a |= mag_tensor_is_transposed(x->op_inputs[0]);
1471+
memset(br, 0, mag_tensor_data_size(r));
1472+
int64_t b2 = yd2/xd2;
1473+
int64_t b3 = yd3/xd3;
1474+
int64_t b4 = yd4/xd4;
1475+
int64_t b5 = yd5/xd5;
1476+
for (int64_t i5=0; i5 < xd5; ++i5) {
1477+
for (int64_t i4=0; i4 < xd4; ++i4) {
1478+
for (int64_t i3=0; i3 < xd3; ++i3) {
1479+
for (int64_t i2=0; i2 < xd2; ++i2) {
1480+
int64_t xi5 = i5/b5;
1481+
int64_t xi4 = i4/b4;
1482+
int64_t xi3 = i3/b3;
1483+
int64_t xi2 = i2/b2;
1484+
const mag_f32_t* px = bx + xi5*xs5 + xi4*xs4 + xi3*xs3 + xi2*xs2;
1485+
const mag_f32_t* py = by + i5*ys5 + i4*ys4 + i3*ys3 + i2*ys2;
1486+
mag_f32_t* pr = br + i5*rs5 + i4*rs4 + i3*rs3 + i2*rs2;
1487+
mag_bnd_chk(pr, br, mag_tensor_data_size(r));
1488+
mag_bnd_chk(px, bx, mag_tensor_data_size(x));
1489+
mag_bnd_chk(py, by, mag_tensor_data_size(y));
1490+
for (int64_t i = 0; i < x->numel; ++i) { /* Rows */
1491+
for (int64_t k = 0; k < xd1; ++k) { /* Inner dim */
1492+
const mag_f32_t* ppx = px + (trans_a ? k*xd0 + i : xd1*i + k);
1493+
mag_bnd_chk(px, bx, mag_tensor_data_size(x));
1494+
for (int64_t j = 0; j < yd1; ++j) { /* Columns */
1495+
mag_f32_t* ppr = br + rd1*i + j;
1496+
const mag_f32_t* ppy = py + yd1*k + j;
1497+
mag_bnd_chk(pr, br, mag_tensor_data_size(r));
1498+
mag_bnd_chk(py, by, mag_tensor_data_size(y));
1499+
*ppr += (*ppx) * (*ppy);
1500+
}
1501+
}
1502+
}
1503+
}
1504+
}
1505+
}
1506+
}
1507+
}
1508+
14511509
#endif
14521510

14531511
#ifndef MAG_BLAS_SPECIALIZATION

0 commit comments

Comments
 (0)