Skip to content

Commit 52e44ce

Browse files
authored
[bug] Improve error message with GlobalPtrStmt indexing (#5841)
* [bug] Improve error message with GlobalPtrStmt indexing * Fix minor bug * Bug fix * Fixed interface issues * Addressed review comments
1 parent 2b6e753 commit 52e44ce

14 files changed

+79
-32
lines changed

python/taichi/lang/any_array.py

+4-2
Original file line numberDiff line numberDiff line change
@@ -1,4 +1,5 @@
11
from taichi._lib import core as _ti_core
2+
from taichi.lang import impl
23
from taichi.lang.enums import Layout
34
from taichi.lang.expr import Expr, make_expr_group
45
from taichi.lang.util import taichi_scope
@@ -74,8 +75,9 @@ def subscript(self, i, j):
7475
indices = indices_second + self.indices_first
7576
else:
7677
indices = self.indices_first + indices_second
77-
return Expr(_ti_core.subscript(self.arr.ptr,
78-
make_expr_group(*indices)))
78+
return Expr(
79+
_ti_core.subscript(self.arr.ptr, make_expr_group(*indices),
80+
impl.get_runtime().get_current_src_info()))
7981

8082

8183
__all__ = []

python/taichi/lang/impl.py

+13-5
Original file line numberDiff line numberDiff line change
@@ -32,7 +32,8 @@
3232
@taichi_scope
3333
def expr_init_local_tensor(shape, element_type, elements):
3434
return get_runtime().prog.current_ast_builder().expr_alloca_local_tensor(
35-
shape, element_type, elements)
35+
shape, element_type, elements,
36+
get_runtime().get_current_src_info())
3637

3738

3839
@taichi_scope
@@ -72,7 +73,8 @@ def expr_init(rhs):
7273
if hasattr(rhs, '_data_oriented'):
7374
return rhs
7475
return Expr(get_runtime().prog.current_ast_builder().expr_var(
75-
Expr(rhs).ptr))
76+
Expr(rhs).ptr,
77+
get_runtime().get_current_src_info()))
7678

7779

7880
@taichi_scope
@@ -182,7 +184,9 @@ def subscript(value, *_indices, skip_reordered=False, get_ref=False):
182184
entries = {k: subscript(v, *_indices) for k, v in value._items}
183185
entries['__struct_methods'] = value.struct_methods
184186
return _IntermediateStruct(entries)
185-
return Expr(_ti_core.subscript(_var, indices_expr_group))
187+
return Expr(
188+
_ti_core.subscript(_var, indices_expr_group,
189+
get_runtime().get_current_src_info()))
186190
if isinstance(value, AnyArray):
187191
# TODO: deprecate using get_attribute to get dim
188192
field_dim = int(value.ptr.get_attribute("dim"))
@@ -192,7 +196,9 @@ def subscript(value, *_indices, skip_reordered=False, get_ref=False):
192196
f'Field with dim {field_dim - element_dim} accessed with indices of dim {index_dim}'
193197
)
194198
if element_dim == 0:
195-
return Expr(_ti_core.subscript(value.ptr, indices_expr_group))
199+
return Expr(
200+
_ti_core.subscript(value.ptr, indices_expr_group,
201+
get_runtime().get_current_src_info()))
196202
n = value.element_shape[0]
197203
m = 1 if element_dim == 1 else value.element_shape[1]
198204
any_array_access = AnyArrayAccess(value, _indices)
@@ -217,7 +223,9 @@ def make_stride_expr(_var, _indices, shape, stride):
217223

218224
@taichi_scope
219225
def make_index_expr(_var, _indices):
220-
return Expr(_ti_core.make_index_expr(_var, make_expr_group(*_indices)))
226+
return Expr(
227+
_ti_core.make_index_expr(_var, make_expr_group(*_indices),
228+
get_runtime().get_current_src_info()))
221229

222230

223231
class SrcInfoGuard:

python/taichi/lang/matrix.py

+4-1
Original file line numberDiff line numberDiff line change
@@ -1500,7 +1500,10 @@ def __init__(self, field, indices):
15001500
super().__init__(
15011501
field.n,
15021502
field.m, [
1503-
expr.Expr(ti_python_core.subscript(e.ptr, indices))
1503+
expr.Expr(
1504+
ti_python_core.subscript(
1505+
e.ptr, indices,
1506+
impl.get_runtime().get_current_src_info()))
15041507
for e in field._get_field_members()
15051508
],
15061509
ndim=getattr(field, "ndim", 2))

python/taichi/lang/mesh.py

+4-2
Original file line numberDiff line numberDiff line change
@@ -660,8 +660,10 @@ def __init__(self, mesh: MeshInstance, element_type: MeshElementType,
660660
var = attr._get_field_members()[0].ptr
661661
setattr(
662662
self, key,
663-
impl.Expr(_ti_core.subscript(var,
664-
global_entry_expr_group)))
663+
impl.Expr(
664+
_ti_core.subscript(
665+
var, global_entry_expr_group,
666+
impl.get_runtime().get_current_src_info())))
665667

666668
for element_type in self.mesh._type.elements:
667669
setattr(self, element_type_name(element_type),

taichi/ir/frontend_ir.cpp

+17-6
Original file line numberDiff line numberDiff line change
@@ -358,10 +358,12 @@ void InternalFuncCallExpression::flatten(FlattenContext *ctx) {
358358
ctx->push_back<InternalFuncStmt>(func_name, args_stmts, nullptr,
359359
with_runtime_context);
360360
stmt = ctx->back_stmt();
361+
stmt->tb = tb;
361362
}
362363

363364
void ExternalTensorExpression::flatten(FlattenContext *ctx) {
364365
auto ptr = Stmt::make<ArgLoadStmt>(arg_id, dt, /*is_ptr=*/true);
366+
ptr->tb = tb;
365367
ctx->push_back(std::move(ptr));
366368
stmt = ctx->back_stmt();
367369
}
@@ -370,6 +372,7 @@ void GlobalVariableExpression::flatten(FlattenContext *ctx) {
370372
TI_ASSERT(snode->num_active_indices == 0);
371373
auto ptr = Stmt::make<GlobalPtrStmt>(LaneAttribute<SNode *>(snode),
372374
std::vector<Stmt *>());
375+
ptr->tb = tb;
373376
ctx->push_back(std::move(ptr));
374377
}
375378

@@ -483,6 +486,7 @@ void IndexExpression::flatten(FlattenContext *ctx) {
483486
stmt = make_tensor_access(
484487
ctx, var, indices, var->ret_type->cast<TensorType>()->get_shape(), 1);
485488
}
489+
stmt->tb = tb;
486490
}
487491

488492
void StrideExpression::type_check(CompileConfig *) {
@@ -603,6 +607,7 @@ void SNodeOpExpression::flatten(FlattenContext *ctx) {
603607
indices_stmt.push_back(indices[i]->stmt);
604608
}
605609
auto ptr = ctx->push_back<GlobalPtrStmt>(snode, indices_stmt);
610+
ptr->tb = tb;
606611
if (op_type == SNodeOpType::is_active) {
607612
TI_ERROR_IF(snode->type != SNodeType::pointer &&
608613
snode->type != SNodeType::hash &&
@@ -835,22 +840,27 @@ void ASTBuilder::stop_gradient(SNode *snode) {
835840
stack_.back()->stop_gradients.push_back(snode);
836841
}
837842

838-
void ASTBuilder::insert_assignment(Expr &lhs, const Expr &rhs) {
843+
void ASTBuilder::insert_assignment(Expr &lhs,
844+
const Expr &rhs,
845+
const std::string &tb) {
839846
// Inside a kernel or a function
840847
// Create an assignment in the IR
841848
if (lhs.expr == nullptr) {
842849
lhs.set(rhs);
843850
} else if (lhs.expr->is_lvalue()) {
844-
this->insert(std::make_unique<FrontendAssignStmt>(lhs, rhs));
851+
auto stmt = std::make_unique<FrontendAssignStmt>(lhs, rhs);
852+
stmt->tb = tb;
853+
this->insert(std::move(stmt));
854+
845855
} else {
846856
TI_ERROR("Cannot assign to non-lvalue: {}",
847857
ExpressionHumanFriendlyPrinter::expr_to_string(lhs));
848858
}
849859
}
850860

851-
Expr ASTBuilder::make_var(const Expr &x) {
861+
Expr ASTBuilder::make_var(const Expr &x, std::string tb) {
852862
auto var = this->expr_alloca();
853-
this->insert_assignment(var, x);
863+
this->insert_assignment(var, x, tb);
854864
return var;
855865
}
856866

@@ -962,7 +972,8 @@ Expr ASTBuilder::expr_alloca() {
962972

963973
Expr ASTBuilder::expr_alloca_local_tensor(const std::vector<int> &shape,
964974
const DataType &element_type,
965-
const ExprGroup &elements) {
975+
const ExprGroup &elements,
976+
std::string tb) {
966977
auto var = Expr(std::make_shared<IdExpression>(get_next_id()));
967978
this->insert(std::make_unique<FrontendAllocaStmt>(
968979
std::static_pointer_cast<IdExpression>(var.expr)->id, shape,
@@ -980,7 +991,7 @@ Expr ASTBuilder::expr_alloca_local_tensor(const std::vector<int> &shape,
980991
for (int d = 0; d < (int)shape.size(); ++d)
981992
indices.push_back(reversed_indices[(int)shape.size() - 1 - d]);
982993
this->insert(std::make_unique<FrontendAssignStmt>(
983-
Expr::make<IndexExpression>(var, indices), elements.exprs[i]));
994+
Expr::make<IndexExpression>(var, indices, tb), elements.exprs[i]));
984995
}
985996
return var;
986997
}

taichi/ir/frontend_ir.h

+10-4
Original file line numberDiff line numberDiff line change
@@ -511,8 +511,11 @@ class IndexExpression : public Expression {
511511
Expr var;
512512
ExprGroup indices;
513513

514-
IndexExpression(const Expr &var, const ExprGroup &indices)
514+
IndexExpression(const Expr &var,
515+
const ExprGroup &indices,
516+
std::string tb = "")
515517
: var(var), indices(indices) {
518+
this->tb = tb;
516519
}
517520

518521
void type_check(CompileConfig *config) override;
@@ -853,8 +856,10 @@ class ASTBuilder {
853856
Block *current_block();
854857
Stmt *get_last_stmt();
855858
void stop_gradient(SNode *);
856-
void insert_assignment(Expr &lhs, const Expr &rhs);
857-
Expr make_var(const Expr &x);
859+
void insert_assignment(Expr &lhs,
860+
const Expr &rhs,
861+
const std::string &tb = "");
862+
Expr make_var(const Expr &x, std::string tb);
858863
void insert_for(const Expr &s,
859864
const Expr &e,
860865
const std::function<void(Expr)> &func);
@@ -878,7 +883,8 @@ class ASTBuilder {
878883
Expr expr_alloca();
879884
Expr expr_alloca_local_tensor(const std::vector<int> &shape,
880885
const DataType &element_type,
881-
const ExprGroup &elements);
886+
const ExprGroup &elements,
887+
std::string tb);
882888
Expr expr_alloca_shared_array(const std::vector<int> &shape,
883889
const DataType &element_type);
884890
void expr_assign(const Expr &lhs, const Expr &rhs, std::string tb);

taichi/math/svd.h

+2-1
Original file line numberDiff line numberDiff line change
@@ -58,8 +58,9 @@ sifakis_svd_export(ASTBuilder *ast_builder,
5858
constexpr Tf Sine_Pi_Over_Eight = 0.3826834323650897f;
5959
constexpr Tf Cosine_Pi_Over_Eight = 0.9238795325112867f;
6060

61+
std::string tb = "";
6162
auto Var =
62-
std::bind(&ASTBuilder::make_var, ast_builder, std::placeholders::_1);
63+
std::bind(&ASTBuilder::make_var, ast_builder, std::placeholders::_1, tb);
6364

6465
auto Sfour_gamma_squared = Var(Expr(Tf(0.0)));
6566
auto Ssine_pi_over_eight = Var(Expr(Tf(0.0)));

taichi/program/program.cpp

+4-2
Original file line numberDiff line numberDiff line change
@@ -366,8 +366,10 @@ Kernel &Program::get_snode_writer(SNode *snode) {
366366
}
367367
auto expr = Expr(snode_to_glb_var_exprs_.at(snode))[indices];
368368
this->current_ast_builder()->insert_assignment(
369-
expr, Expr::make<ArgLoadExpression>(snode->num_active_indices,
370-
snode->dt->get_compute_type()));
369+
expr,
370+
Expr::make<ArgLoadExpression>(snode->num_active_indices,
371+
snode->dt->get_compute_type()),
372+
expr->tb);
371373
});
372374
ker.set_arch(get_accessor_arch());
373375
ker.name = kernel_name;

taichi/python/export_lang.cpp

+8-5
Original file line numberDiff line numberDiff line change
@@ -964,12 +964,15 @@ void export_lang(py::module &m) {
964964

965965
m.def("data_type_name", data_type_name);
966966

967-
m.def("subscript", [](const Expr &expr, const ExprGroup &expr_group) {
968-
return expr[expr_group];
969-
});
967+
m.def("subscript",
968+
[](const Expr &expr, const ExprGroup &expr_group, std::string tb) {
969+
Expr idx_expr = expr[expr_group];
970+
idx_expr.set_tb(tb);
971+
return idx_expr;
972+
});
970973

971-
m.def("make_index_expr",
972-
Expr::make<IndexExpression, const Expr &, const ExprGroup &>);
974+
m.def("make_index_expr", Expr::make<IndexExpression, const Expr &,
975+
const ExprGroup &, std::string>);
973976

974977
m.def("make_stride_expr",
975978
Expr::make<StrideExpression, const Expr &, const ExprGroup &,

taichi/transforms/auto_diff.cpp

+3
Original file line numberDiff line numberDiff line change
@@ -3,6 +3,7 @@
33
#include "taichi/ir/statements.h"
44
#include "taichi/ir/transforms.h"
55
#include "taichi/ir/visitors.h"
6+
#include "taichi/transforms/utils.h"
67

78
#include <typeinfo>
89
#include <algorithm>
@@ -1601,6 +1602,8 @@ class GloablDataAccessRuleChecker : public BasicStmtVisitor {
16011602
"(kernel={}) Breaks the global data access rule. Snode {} is "
16021603
"overwritten unexpectedly.",
16031604
kernel_name_, dest->snodes[0]->get_node_type_name());
1605+
msg += "\n" + stmt->tb;
1606+
16041607
stmt->insert_before_me(
16051608
Stmt::make<AssertStmt>(check_equal, msg, std::vector<Stmt *>()));
16061609
}

taichi/transforms/check_out_of_bound.cpp

+3
Original file line numberDiff line numberDiff line change
@@ -3,6 +3,7 @@
33
#include "taichi/ir/transforms.h"
44
#include "taichi/ir/visitors.h"
55
#include "taichi/transforms/check_out_of_bound.h"
6+
#include "taichi/transforms/utils.h"
67
#include <set>
78

89
TLANG_NAMESPACE_BEGIN
@@ -98,6 +99,7 @@ class CheckOutOfBound : public BasicStmtVisitor {
9899
msg += "%d";
99100
}
100101
msg += ")";
102+
msg += "\n" + stmt->tb;
101103

102104
new_stmts.push_back<AssertStmt>(result, msg, args);
103105
modifier.insert_before(stmt, std::move(new_stmts));
@@ -117,6 +119,7 @@ class CheckOutOfBound : public BasicStmtVisitor {
117119
BinaryOpType::cmp_ge, stmt->rhs, compare_rhs.get());
118120
compare->ret_type = PrimitiveType::i32;
119121
std::string msg = "Negative exponent for integer pows are not allowed";
122+
msg += "\n" + stmt->tb;
120123
auto assert_stmt = std::make_unique<AssertStmt>(compare.get(), msg,
121124
std::vector<Stmt *>());
122125
assert_stmt->accept(this);

taichi/transforms/simplify.cpp

+4-3
Original file line numberDiff line numberDiff line change
@@ -6,6 +6,7 @@
66
#include "taichi/transforms/simplify.h"
77
#include "taichi/program/kernel.h"
88
#include "taichi/program/program.h"
9+
#include "taichi/transforms/utils.h"
910
#include <set>
1011
#include <unordered_set>
1112
#include <utility>
@@ -301,9 +302,9 @@ class BasicBlockSimplify : public IRVisitor {
301302
auto zero = Stmt::make<ConstStmt>(LaneAttribute<TypedConstant>(0));
302303
auto check_sum =
303304
Stmt::make<BinaryOpStmt>(BinaryOpType::cmp_ge, sum.get(), zero.get());
304-
auto assert = Stmt::make<AssertStmt>(check_sum.get(),
305-
"The indices provided are too big!",
306-
std::vector<Stmt *>());
305+
auto assert = Stmt::make<AssertStmt>(
306+
check_sum.get(), "The indices provided are too big!\n" + stmt->tb,
307+
std::vector<Stmt *>());
307308
// Because Taichi's assertion is checked only after the execution of the
308309
// kernel, when the linear index overflows and goes negative, we have to
309310
// replace that with 0 to make sure that the rest of the kernel can still

taichi/transforms/type_check.cpp

+2
Original file line numberDiff line numberDiff line change
@@ -5,6 +5,7 @@
55
#include "taichi/ir/transforms.h"
66
#include "taichi/ir/analysis.h"
77
#include "taichi/ir/frontend_ir.h"
8+
#include "taichi/transforms/utils.h"
89

910
TLANG_NAMESPACE_BEGIN
1011

@@ -252,6 +253,7 @@ class TypeCheck : public IRVisitor {
252253
std::string msg =
253254
"Detected overflow for bit_shift_op with rhs = %d, exceeding limit of "
254255
"%d.";
256+
msg += "\n" + stmt->tb;
255257
std::vector<Stmt *> args = {rhs, const_stmt.get()};
256258
auto assert_stmt =
257259
Stmt::make<AssertStmt>(cond_stmt.get(), msg, std::move(args));

tests/cpp/ir/frontend_type_inference_test.cpp

+1-1
Original file line numberDiff line numberDiff line change
@@ -32,7 +32,7 @@ TEST(FrontendTypeInference, Id) {
3232
Callable::CurrentCallableGuard _(kernel->program, kernel.get());
3333
auto const_i32 = value<int32>(-(1 << 20));
3434
const_i32->type_check(nullptr);
35-
auto id_i32 = prog->current_ast_builder()->make_var(const_i32);
35+
auto id_i32 = prog->current_ast_builder()->make_var(const_i32, const_i32->tb);
3636
EXPECT_EQ(id_i32->ret_type, PrimitiveType::i32);
3737
}
3838

0 commit comments

Comments
 (0)