@@ -358,10 +358,12 @@ void InternalFuncCallExpression::flatten(FlattenContext *ctx) {
   ctx->push_back<InternalFuncStmt>(func_name, args_stmts, nullptr,
                                    with_runtime_context);
   stmt = ctx->back_stmt();
+  stmt->tb = tb;
 }
 
 void ExternalTensorExpression::flatten(FlattenContext *ctx) {
   auto ptr = Stmt::make<ArgLoadStmt>(arg_id, dt, /*is_ptr=*/true);
+  ptr->tb = tb;
   ctx->push_back(std::move(ptr));
   stmt = ctx->back_stmt();
 }
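Every `flatten` hunk in this diff makes the same move: the traceback string (`tb`) recorded on the frontend expression is copied onto the IR statement it lowers to, so later passes and diagnostics can point back at the user's source line. Below is a minimal, self-contained sketch of that pattern using simplified stand-ins for the real `Expression`, `Stmt`, and `FlattenContext` classes (the actual Taichi types carry far more state; the traceback value is illustrative):

#include <memory>
#include <string>
#include <vector>

// Simplified stand-ins; the only detail that matters here is the tb field.
struct Stmt {
  std::string tb;  // traceback / source location, filled in during flattening
};

struct FlattenContext {
  std::vector<std::unique_ptr<Stmt>> stmts;
  Stmt *back_stmt() { return stmts.back().get(); }
  void push_back(std::unique_ptr<Stmt> s) { stmts.push_back(std::move(s)); }
};

struct Expression {
  std::string tb;        // recorded when the expression is built on the frontend
  Stmt *stmt = nullptr;  // the lowered statement, set by flatten()

  void flatten(FlattenContext *ctx) {
    ctx->push_back(std::make_unique<Stmt>());
    stmt = ctx->back_stmt();
    stmt->tb = tb;  // the step each hunk above adds
  }
};

int main() {
  FlattenContext ctx;
  Expression e;
  e.tb = "kernel.py:7";  // illustrative location string
  e.flatten(&ctx);
  // ctx.back_stmt()->tb is now "kernel.py:7".
}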
@@ -370,6 +372,7 @@ void GlobalVariableExpression::flatten(FlattenContext *ctx) {
   TI_ASSERT(snode->num_active_indices == 0);
   auto ptr = Stmt::make<GlobalPtrStmt>(LaneAttribute<SNode *>(snode),
                                        std::vector<Stmt *>());
+  ptr->tb = tb;
   ctx->push_back(std::move(ptr));
 }
 
@@ -483,6 +486,7 @@ void IndexExpression::flatten(FlattenContext *ctx) {
     stmt = make_tensor_access(
         ctx, var, indices, var->ret_type->cast<TensorType>()->get_shape(), 1);
   }
+  stmt->tb = tb;
 }
 
 void StrideExpression::type_check(CompileConfig *) {
@@ -603,6 +607,7 @@ void SNodeOpExpression::flatten(FlattenContext *ctx) {
     indices_stmt.push_back(indices[i]->stmt);
   }
   auto ptr = ctx->push_back<GlobalPtrStmt>(snode, indices_stmt);
+  ptr->tb = tb;
   if (op_type == SNodeOpType::is_active) {
     TI_ERROR_IF(snode->type != SNodeType::pointer &&
                 snode->type != SNodeType::hash &&
@@ -835,22 +840,27 @@ void ASTBuilder::stop_gradient(SNode *snode) {
   stack_.back()->stop_gradients.push_back(snode);
 }
 
-void ASTBuilder::insert_assignment(Expr &lhs, const Expr &rhs) {
+void ASTBuilder::insert_assignment(Expr &lhs,
+                                   const Expr &rhs,
+                                   const std::string &tb) {
   // Inside a kernel or a function
   // Create an assignment in the IR
   if (lhs.expr == nullptr) {
     lhs.set(rhs);
   } else if (lhs.expr->is_lvalue()) {
-    this->insert(std::make_unique<FrontendAssignStmt>(lhs, rhs));
+    auto stmt = std::make_unique<FrontendAssignStmt>(lhs, rhs);
+    stmt->tb = tb;
+    this->insert(std::move(stmt));
+
   } else {
     TI_ERROR("Cannot assign to non-lvalue: {}",
              ExpressionHumanFriendlyPrinter::expr_to_string(lhs));
   }
 }
 
-Expr ASTBuilder::make_var(const Expr &x) {
+Expr ASTBuilder::make_var(const Expr &x, std::string tb) {
   auto var = this->expr_alloca();
-  this->insert_assignment(var, x);
+  this->insert_assignment(var, x, tb);
   return var;
 }
 
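On the `ASTBuilder` side, the traceback becomes an explicit parameter: `make_var` forwards it to `insert_assignment`, which stamps it onto the `FrontendAssignStmt` before insertion. The sketch below (not the real Taichi API; the types are simplified assumptions) shows why that stamp matters — the statement remembers where in user code the assignment originated:

#include <iostream>
#include <memory>
#include <string>

// Simplified stand-in for FrontendAssignStmt: just enough to show the tb stamp.
struct AssignStmt {
  std::string lhs, rhs, tb;
};

std::unique_ptr<AssignStmt> insert_assignment(const std::string &lhs,
                                              const std::string &rhs,
                                              const std::string &tb) {
  auto stmt = std::make_unique<AssignStmt>();
  stmt->lhs = lhs;
  stmt->rhs = rhs;
  stmt->tb = tb;  // same move as the diff: the statement keeps the caller's traceback
  return stmt;
}

int main() {
  // A caller (e.g. a Python-facing binding) captures a traceback string and
  // threads it through, so a later diagnostic can name the user's source line.
  auto stmt = insert_assignment("x", "1", "kernel.py:12");
  std::cout << "assignment recorded at " << stmt->tb << "\n";
}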
@@ -962,7 +972,8 @@ Expr ASTBuilder::expr_alloca() {
 
 Expr ASTBuilder::expr_alloca_local_tensor(const std::vector<int> &shape,
                                           const DataType &element_type,
-                                          const ExprGroup &elements) {
+                                          const ExprGroup &elements,
+                                          std::string tb) {
   auto var = Expr(std::make_shared<IdExpression>(get_next_id()));
   this->insert(std::make_unique<FrontendAllocaStmt>(
       std::static_pointer_cast<IdExpression>(var.expr)->id, shape,
@@ -980,7 +991,7 @@ Expr ASTBuilder::expr_alloca_local_tensor(const std::vector<int> &shape,
     for (int d = 0; d < (int)shape.size(); ++d)
       indices.push_back(reversed_indices[(int)shape.size() - 1 - d]);
     this->insert(std::make_unique<FrontendAssignStmt>(
-        Expr::make<IndexExpression>(var, indices), elements.exprs[i]));
+        Expr::make<IndexExpression>(var, indices, tb), elements.exprs[i]));
   }
   return var;
 }
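For the local-tensor initializer, `expr_alloca_local_tensor` likewise takes a `tb` argument and forwards it into each `IndexExpression` it builds; presumably the `IndexExpression` constructor (changed outside these hunks) stores it in the expression's `tb` field, which the `IndexExpression::flatten` hunk above then copies onto the lowered statement via `stmt->tb = tb;`.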