@@ -2873,15 +2873,16 @@ static void mag_collect_topo_iterative(mag_tensor_t* root, mag_tensor_array_t* o
     (*mag_alloc)(visited, 0);
 }
 
-static void mag_tensor_patch_grad(mag_tensor_t* dst, mag_tensor_t* new_grad, const char* stage) {
+static void mag_tensor_patch_grad(const mag_tensor_t* root, mag_tensor_t* dst, mag_tensor_t* new_grad, const char* stage) {
     if (!mag_tensor_is_shape_eq(dst, new_grad)) {
         char shape_dst[MAG_FMT_DIM_BUF_SIZE];
         char shape_grad[MAG_FMT_DIM_BUF_SIZE];
         mag_fmt_dims(&shape_dst, &dst->shape, dst->rank);
         mag_fmt_dims(&shape_grad, &new_grad->shape, new_grad->rank);
         const char* dst_op = mag_op_meta_of(dst->op)->mnemonic;
         const char* grad_op = mag_op_meta_of(new_grad->op)->mnemonic;
-        mag_panic("Shape mismatch: %s (%s) != %s (%s) !%s", shape_dst, dst_op, shape_grad, grad_op, stage);
+        const char* root_op = mag_op_meta_of(root->op)->mnemonic;
+        mag_panic("Shape mismatch: %s (%s) != %s (%s) Stage: %s, Root Op: %s\n", shape_dst, dst_op, shape_grad, grad_op, stage, root_op);
     }
     mag_tensor_fmt_name(new_grad, "%s (grad)", dst->name);
     new_grad->flags = (new_grad->flags | MAG_TFLAG_IS_GRAD) & ~MAG_TFLAG_REQUIRES_GRAD;
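
Note: with the extra root parameter, the shape-mismatch panic now also reports which backward stage and which root op produced the offending gradient. A hypothetical illustration of the resulting message layout, with made-up shapes and op mnemonics (only the format string is taken from the hunk above; real values come from mag_fmt_dims() and mag_op_meta_of() at the panic site):

    #include <stdio.h>

    int main(void) {
        /* Placeholder values for illustration only. */
        const char* shape_dst  = "(4, 3)";
        const char* dst_op     = "add";
        const char* shape_grad = "(3, 4)";
        const char* grad_op    = "matmul";
        printf("Shape mismatch: %s (%s) != %s (%s) Stage: %s, Root Op: %s\n",
               shape_dst, dst_op, shape_grad, grad_op, "acc", "sum");
        return 0;
    }
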
@@ -2906,7 +2907,7 @@ void mag_tensor_backward(mag_tensor_t* root) {
         if (!child->grad) {
             mag_tensor_t* grad = mag_tensor_create(child->ctx, child->dtype, child->shape, child->rank, NULL, 0);
             mag_tensor_fill(grad, 1.0f);
-            mag_tensor_patch_grad(child, grad, "init");
+            mag_tensor_patch_grad(child, child, grad, "init");
         }
         if (mag_unlikely(child->op == MAG_OP_NOP)) continue;
         mag_tensor_t* grads[MAG_MAX_OP_PARAMS] = {0};
@@ -2917,15 +2918,15 @@ void mag_tensor_backward(mag_tensor_t* root) {
         for (uint32_t i = 0; i < numin; ++i) {
             mag_tensor_t* input = child->op_inputs[i];
             mag_assert2(input);
-            if (!(input->flags & MAG_TFLAG_REQUIRES_GRAD) || input->op == MAG_OP_NOP) continue;
+            if (!(input->flags & MAG_TFLAG_REQUIRES_GRAD)) continue;
             mag_tensor_t* gri = grads[i];
             mag_assert(gri, "Gradient for op %s, input #%d is not computed", meta->mnemonic, i);
             if (!input->grad) {
-                mag_tensor_patch_grad(input, gri, "patch");
+                mag_tensor_patch_grad(child, input, gri, "patch");
             } else {
                 mag_tensor_t* acc = mag_add_(gri, input->grad);
                 mag_tensor_decref(input->grad);
-                mag_tensor_patch_grad(input, acc, "acc");
+                mag_tensor_patch_grad(child, input, acc, "acc");
                 mag_tensor_decref(gri);
             }
         }
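
Note: the backward-pass hunks above follow an attach-or-accumulate pattern: a tensor with no gradient yet gets the freshly computed one patched in ("patch"), while later contributions are summed onto the existing gradient ("acc"). A minimal self-contained sketch of that pattern, using a hypothetical toy_tensor type rather than the real mag_* API:

    #include <stdio.h>
    #include <stdlib.h>

    typedef struct toy_tensor {
        float value;              /* single scalar stands in for a full tensor  */
        struct toy_tensor* grad;  /* accumulated gradient, NULL until first use */
    } toy_tensor;

    /* Attach the first gradient contribution, or accumulate further ones. */
    static void toy_patch_grad(toy_tensor* dst, float contribution) {
        if (!dst->grad) {                          /* "patch": first contribution */
            dst->grad = calloc(1, sizeof *dst->grad);
            dst->grad->value = contribution;
        } else {                                   /* "acc": sum onto existing    */
            dst->grad->value += contribution;
        }
    }

    int main(void) {
        toy_tensor x = {2.0f, NULL};
        toy_patch_grad(&x, 3.0f);                  /* first contribution  */
        toy_patch_grad(&x, 0.5f);                  /* accumulated on top  */
        printf("dL/dx = %g\n", x.grad->value);     /* prints 3.5          */
        free(x.grad);
        return 0;
    }
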