@@ -1448,6 +1448,64 @@ static void MAG_HOTPROC mag_blas_matmul_f32(const mag_compute_payload_t* payload
}
}

+#else
+
+static void MAG_HOTPROC mag_blas_matmul_f32(const mag_compute_payload_t* payload, mag_kernel_context_t* ctx) {
+    mag_tensor_t* r = payload->node;
+    const mag_tensor_t* x = r->op_inputs[0];
+    const mag_tensor_t* y = r->op_inputs[1];
+    mag_f32_t* br = mag_f32p_mut(r);
+    const mag_f32_t* bx = mag_f32p(x);
+    const mag_f32_t* by = mag_f32p(y);
+    mag_load_local_storage_group(r, rd, shape);
+    mag_load_local_storage_group(r, rs, strides);
+    mag_load_local_storage_group(x, xd, shape);
+    mag_load_local_storage_group(x, xs, strides);
+    mag_load_local_storage_group(y, yd, shape);
+    mag_load_local_storage_group(y, ys, strides);
+    int64_t ti = payload->thread_idx;
+    if (ti != 0) return;
+    mag_assert2(mag_tensor_is_contiguous(x) && mag_tensor_is_contiguous(y) && mag_tensor_is_contiguous(r));
+    bool trans_a = mag_tensor_is_transposed(x);
+    if (x->op == MAG_OP_CLONE && x->op_inputs[0]) trans_a |= mag_tensor_is_transposed(x->op_inputs[0]);
+    memset(br, 0, mag_tensor_data_size(r));
+    int64_t b2 = yd2/xd2;
+    int64_t b3 = yd3/xd3;
+    int64_t b4 = yd4/xd4;
+    int64_t b5 = yd5/xd5;
+    for (int64_t i5=0; i5 < xd5; ++i5) {
+        for (int64_t i4=0; i4 < xd4; ++i4) {
+            for (int64_t i3=0; i3 < xd3; ++i3) {
+                for (int64_t i2=0; i2 < xd2; ++i2) {
+                    int64_t xi5 = i5/b5;
+                    int64_t xi4 = i4/b4;
+                    int64_t xi3 = i3/b3;
+                    int64_t xi2 = i2/b2;
+                    const mag_f32_t* px = bx + xi5*xs5 + xi4*xs4 + xi3*xs3 + xi2*xs2;
+                    const mag_f32_t* py = by + i5*ys5 + i4*ys4 + i3*ys3 + i2*ys2;
+                    mag_f32_t* pr = br + i5*rs5 + i4*rs4 + i3*rs3 + i2*rs2;
+                    mag_bnd_chk(pr, br, mag_tensor_data_size(r));
+                    mag_bnd_chk(px, bx, mag_tensor_data_size(x));
+                    mag_bnd_chk(py, by, mag_tensor_data_size(y));
+                    for (int64_t i = 0; i < x->numel; ++i) { /* Rows */
+                        for (int64_t k = 0; k < xd1; ++k) { /* Inner dim */
+                            const mag_f32_t* ppx = px + (trans_a ? k*xd0 + i : xd1*i + k);
+                            mag_bnd_chk(px, bx, mag_tensor_data_size(x));
+                            for (int64_t j = 0; j < yd1; ++j) { /* Columns */
+                                mag_f32_t* ppr = br + rd1*i + j;
+                                const mag_f32_t* ppy = py + yd1*k + j;
+                                mag_bnd_chk(pr, br, mag_tensor_data_size(r));
+                                mag_bnd_chk(py, by, mag_tensor_data_size(y));
+                                *ppr += (*ppx) * (*ppy);
+                            }
+                        }
+                    }
+                }
+            }
+        }
+    }
+}
+

#endif

#ifndef MAG_BLAS_SPECIALIZATION
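
Note on the added fallback: after zeroing the output, the kernel accumulates r[i][j] += x[i][k] * y[k][j] in an i-k-j loop order per batch slice, reading x column-wise when the input is (or is a clone of) a transposed tensor. The standalone sketch below mirrors that accumulation for a single 2D slice; it is only an illustration of the loop order, not part of the magnetron API, and the name naive_matmul_f32 and the raw-pointer signature are made up for this example.

#include <stdint.h>
#include <stdio.h>
#include <string.h>

/* Sketch: C (M x N) += A (M x K) * B (K x N), with A optionally read in
 * transposed storage order, analogous to the trans_a path in the kernel.
 * Plain float pointers stand in for mag_tensor_t storage. */
static void naive_matmul_f32(float* c, const float* a, const float* b,
                             int64_t M, int64_t K, int64_t N, int trans_a) {
    memset(c, 0, sizeof(*c)*(size_t)(M*N));           /* result starts at zero */
    for (int64_t i = 0; i < M; ++i)                    /* rows */
        for (int64_t k = 0; k < K; ++k) {              /* inner dim */
            float aik = trans_a ? a[k*M + i] : a[i*K + k];
            for (int64_t j = 0; j < N; ++j)            /* columns */
                c[i*N + j] += aik * b[k*N + j];        /* r += x*y accumulate */
        }
}

int main(void) {
    float a[6] = {1,2,3, 4,5,6};       /* 2x3 */
    float b[6] = {7,8, 9,10, 11,12};   /* 3x2 */
    float c[4];
    naive_matmul_f32(c, a, b, 2, 3, 2, 0);
    printf("%g %g\n%g %g\n", c[0], c[1], c[2], c[3]);  /* 58 64 / 139 154 */
    return 0;
}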