Skip to content

Commit 8a8df44

Browse files
committed
fixes
1 parent 96cdb5b commit 8a8df44

File tree

1 file changed

+7
-6
lines changed

1 file changed

+7
-6
lines changed

magnetron/magnetron.c

+7-6
Original file line numberDiff line numberDiff line change
@@ -2873,15 +2873,16 @@ static void mag_collect_topo_iterative(mag_tensor_t* root, mag_tensor_array_t* o
28732873
(*mag_alloc)(visited, 0);
28742874
}
28752875

2876-
static void mag_tensor_patch_grad(mag_tensor_t* dst, mag_tensor_t* new_grad, const char* stage) {
2876+
static void mag_tensor_patch_grad(const mag_tensor_t* root, mag_tensor_t* dst, mag_tensor_t* new_grad, const char* stage) {
28772877
if (!mag_tensor_is_shape_eq(dst, new_grad)) {
28782878
char shape_dst[MAG_FMT_DIM_BUF_SIZE];
28792879
char shape_grad[MAG_FMT_DIM_BUF_SIZE];
28802880
mag_fmt_dims(&shape_dst, &dst->shape, dst->rank);
28812881
mag_fmt_dims(&shape_grad, &new_grad->shape, new_grad->rank);
28822882
const char* dst_op = mag_op_meta_of(dst->op)->mnemonic;
28832883
const char* grad_op = mag_op_meta_of(new_grad->op)->mnemonic;
2884-
mag_panic("Shape mismatch: %s (%s) != %s (%s) !%s", shape_dst, dst_op, shape_grad, grad_op, stage);
2884+
const char* root_op = mag_op_meta_of(root->op)->mnemonic;
2885+
mag_panic("Shape mismatch: %s (%s) != %s (%s) Stage: %s, Root Op: %s\n", shape_dst, dst_op, shape_grad, grad_op, stage, root_op);
28852886
}
28862887
mag_tensor_fmt_name(new_grad, "%s (grad)", dst->name);
28872888
new_grad->flags = (new_grad->flags | MAG_TFLAG_IS_GRAD) & ~MAG_TFLAG_REQUIRES_GRAD;
@@ -2906,7 +2907,7 @@ void mag_tensor_backward(mag_tensor_t* root) {
29062907
if (!child->grad) {
29072908
mag_tensor_t* grad = mag_tensor_create(child->ctx, child->dtype, child->shape, child->rank, NULL, 0);
29082909
mag_tensor_fill(grad, 1.0f);
2909-
mag_tensor_patch_grad(child, grad, "init");
2910+
mag_tensor_patch_grad(child, child, grad, "init");
29102911
}
29112912
if (mag_unlikely(child->op == MAG_OP_NOP)) continue;
29122913
mag_tensor_t* grads[MAG_MAX_OP_PARAMS] = {0};
@@ -2917,15 +2918,15 @@ void mag_tensor_backward(mag_tensor_t* root) {
29172918
for (uint32_t i = 0; i < numin; ++i) {
29182919
mag_tensor_t* input = child->op_inputs[i];
29192920
mag_assert2(input);
2920-
if (!(input->flags & MAG_TFLAG_REQUIRES_GRAD) || input->op == MAG_OP_NOP) continue;
2921+
if (!(input->flags & MAG_TFLAG_REQUIRES_GRAD)) continue;
29212922
mag_tensor_t* gri = grads[i];
29222923
mag_assert(gri, "Gradient for op %s, input #%d is not computed", meta->mnemonic, i);
29232924
if (!input->grad) {
2924-
mag_tensor_patch_grad(input, gri, "patch");
2925+
mag_tensor_patch_grad(child, input, gri, "patch");
29252926
} else {
29262927
mag_tensor_t* acc = mag_add_(gri, input->grad);
29272928
mag_tensor_decref(input->grad);
2928-
mag_tensor_patch_grad(input, acc, "acc");
2929+
mag_tensor_patch_grad(child, input, acc, "acc");
29292930
mag_tensor_decref(gri);
29302931
}
29312932
}

0 commit comments

Comments
 (0)