[bug] Improve error message with GlobalPtrStmt indexing #5841

Merged · 5 commits · Aug 24, 2022
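In summary, this PR threads the Python-side source info (`tb`) through `subscript` and the related frontend expressions, then appends it to the runtime assert messages for out-of-bound accesses, negative integer pow exponents, bit-shift overflows, and autodiff data-access violations. A minimal sketch of the user-visible effect, assuming debug mode on CPU (the field, kernel, and index below are illustrative, not from the PR):

import taichi as ti

ti.init(arch=ti.cpu, debug=True)  # debug mode enables the runtime bound checks

x = ti.field(ti.i32, shape=4)

@ti.kernel
def write(i: ti.i32):
    x[i] = 1  # with this PR, the assert message appends this line's source info

write(10)  # out of bounds: the error now points at `x[i] = 1` above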
6 changes: 4 additions & 2 deletions python/taichi/lang/any_array.py
@@ -1,4 +1,5 @@
 from taichi._lib import core as _ti_core
+from taichi.lang import impl
 from taichi.lang.enums import Layout
 from taichi.lang.expr import Expr, make_expr_group
 from taichi.lang.util import taichi_scope
@@ -67,8 +68,9 @@ def subscript(self, i, j):
             indices = indices_second + self.indices_first
         else:
             indices = self.indices_first + indices_second
-        return Expr(_ti_core.subscript(self.arr.ptr,
-                                       make_expr_group(*indices)))
+        return Expr(
+            _ti_core.subscript(self.arr.ptr, make_expr_group(*indices),
+                               impl.get_runtime().get_current_src_info()))


 __all__ = []
18 changes: 13 additions & 5 deletions python/taichi/lang/impl.py
@@ -32,7 +32,8 @@
 @taichi_scope
 def expr_init_local_tensor(shape, element_type, elements):
     return get_runtime().prog.current_ast_builder().expr_alloca_local_tensor(
-        shape, element_type, elements)
+        shape, element_type, elements,
+        get_runtime().get_current_src_info())


 @taichi_scope
@@ -72,7 +73,8 @@ def expr_init(rhs):
     if hasattr(rhs, '_data_oriented'):
         return rhs
     return Expr(get_runtime().prog.current_ast_builder().expr_var(
-        Expr(rhs).ptr))
+        Expr(rhs).ptr,
+        get_runtime().get_current_src_info()))


 @taichi_scope
@@ -182,7 +184,9 @@ def subscript(value, *_indices, skip_reordered=False, get_ref=False):
             entries = {k: subscript(v, *_indices) for k, v in value._items}
             entries['__struct_methods'] = value.struct_methods
             return _IntermediateStruct(entries)
-        return Expr(_ti_core.subscript(_var, indices_expr_group))
+        return Expr(
+            _ti_core.subscript(_var, indices_expr_group,
+                               get_runtime().get_current_src_info()))
     if isinstance(value, AnyArray):
         # TODO: deprecate using get_attribute to get dim
         field_dim = int(value.ptr.get_attribute("dim"))
@@ -192,7 +196,9 @@ def subscript(value, *_indices, skip_reordered=False, get_ref=False):
             f'Field with dim {field_dim - element_dim} accessed with indices of dim {index_dim}'
         )
         if element_dim == 0:
-            return Expr(_ti_core.subscript(value.ptr, indices_expr_group))
+            return Expr(
+                _ti_core.subscript(value.ptr, indices_expr_group,
+                                   get_runtime().get_current_src_info()))
         n = value.element_shape[0]
         m = 1 if element_dim == 1 else value.element_shape[1]
         any_array_access = AnyArrayAccess(value, _indices)
@@ -217,7 +223,9 @@ def make_stride_expr(_var, _indices, shape, stride):


 @taichi_scope
 def make_index_expr(_var, _indices):
-    return Expr(_ti_core.make_index_expr(_var, make_expr_group(*_indices)))
+    return Expr(
+        _ti_core.make_index_expr(_var, make_expr_group(*_indices),
+                                 get_runtime().get_current_src_info()))


 class SrcInfoGuard:
5 changes: 4 additions & 1 deletion python/taichi/lang/matrix.py
@@ -1496,7 +1496,10 @@ def __init__(self, field, indices):
         super().__init__(
             field.n,
             field.m, [
-                expr.Expr(ti_python_core.subscript(e.ptr, indices))
+                expr.Expr(
+                    ti_python_core.subscript(
+                        e.ptr, indices,
+                        impl.get_runtime().get_current_src_info()))
                 for e in field._get_field_members()
             ],
             ndim=getattr(field, "ndim", 2))
6 changes: 4 additions & 2 deletions python/taichi/lang/mesh.py
@@ -660,8 +660,10 @@ def __init__(self, mesh: MeshInstance, element_type: MeshElementType,
                 var = attr._get_field_members()[0].ptr
                 setattr(
                     self, key,
-                    impl.Expr(_ti_core.subscript(var,
-                                                 global_entry_expr_group)))
+                    impl.Expr(
+                        _ti_core.subscript(
+                            var, global_entry_expr_group,
+                            impl.get_runtime().get_current_src_info())))

         for element_type in self.mesh._type.elements:
             setattr(self, element_type_name(element_type),
23 changes: 17 additions & 6 deletions taichi/ir/frontend_ir.cpp
@@ -358,10 +358,12 @@ void InternalFuncCallExpression::flatten(FlattenContext *ctx) {
   ctx->push_back<InternalFuncStmt>(func_name, args_stmts, nullptr,
                                    with_runtime_context);
   stmt = ctx->back_stmt();
+  stmt->tb = tb;
 }

 void ExternalTensorExpression::flatten(FlattenContext *ctx) {
   auto ptr = Stmt::make<ArgLoadStmt>(arg_id, dt, /*is_ptr=*/true);
+  ptr->tb = tb;
   ctx->push_back(std::move(ptr));
   stmt = ctx->back_stmt();
 }
@@ -370,6 +372,7 @@ void GlobalVariableExpression::flatten(FlattenContext *ctx) {
   TI_ASSERT(snode->num_active_indices == 0);
   auto ptr = Stmt::make<GlobalPtrStmt>(LaneAttribute<SNode *>(snode),
                                        std::vector<Stmt *>());
+  ptr->tb = tb;
   ctx->push_back(std::move(ptr));
 }

@@ -483,6 +486,7 @@ void IndexExpression::flatten(FlattenContext *ctx) {
     stmt = make_tensor_access(
         ctx, var, indices, var->ret_type->cast<TensorType>()->get_shape(), 1);
   }
+  stmt->tb = tb;
 }

 void StrideExpression::type_check(CompileConfig *) {
@@ -603,6 +607,7 @@ void SNodeOpExpression::flatten(FlattenContext *ctx) {
     indices_stmt.push_back(indices[i]->stmt);
   }
   auto ptr = ctx->push_back<GlobalPtrStmt>(snode, indices_stmt);
+  ptr->tb = tb;
   if (op_type == SNodeOpType::is_active) {
     TI_ERROR_IF(snode->type != SNodeType::pointer &&
                     snode->type != SNodeType::hash &&
@@ -835,22 +840,27 @@ void ASTBuilder::stop_gradient(SNode *snode) {
   stack_.back()->stop_gradients.push_back(snode);
 }

-void ASTBuilder::insert_assignment(Expr &lhs, const Expr &rhs) {
+void ASTBuilder::insert_assignment(Expr &lhs,
+                                   const Expr &rhs,
+                                   const std::string &tb) {
   // Inside a kernel or a function
   // Create an assignment in the IR
   if (lhs.expr == nullptr) {
     lhs.set(rhs);
   } else if (lhs.expr->is_lvalue()) {
-    this->insert(std::make_unique<FrontendAssignStmt>(lhs, rhs));
+    auto stmt = std::make_unique<FrontendAssignStmt>(lhs, rhs);
+    stmt->tb = tb;
+    this->insert(std::move(stmt));
+
   } else {
     TI_ERROR("Cannot assign to non-lvalue: {}",
              ExpressionHumanFriendlyPrinter::expr_to_string(lhs));
   }
 }

-Expr ASTBuilder::make_var(const Expr &x) {
+Expr ASTBuilder::make_var(const Expr &x, std::string tb) {
   auto var = this->expr_alloca();
-  this->insert_assignment(var, x);
+  this->insert_assignment(var, x, tb);
   return var;
 }

@@ -962,7 +972,8 @@ Expr ASTBuilder::expr_alloca() {

 Expr ASTBuilder::expr_alloca_local_tensor(const std::vector<int> &shape,
                                           const DataType &element_type,
-                                          const ExprGroup &elements) {
+                                          const ExprGroup &elements,
+                                          std::string tb) {
   auto var = Expr(std::make_shared<IdExpression>(get_next_id()));
   this->insert(std::make_unique<FrontendAllocaStmt>(
       std::static_pointer_cast<IdExpression>(var.expr)->id, shape,
@@ -980,7 +991,7 @@ Expr ASTBuilder::expr_alloca_local_tensor(const std::vector<int> &shape,
     for (int d = 0; d < (int)shape.size(); ++d)
       indices.push_back(reversed_indices[(int)shape.size() - 1 - d]);
     this->insert(std::make_unique<FrontendAssignStmt>(
-        Expr::make<IndexExpression>(var, indices), elements.exprs[i]));
+        Expr::make<IndexExpression>(var, indices, tb), elements.exprs[i]));
   }
   return var;
 }
14 changes: 10 additions & 4 deletions taichi/ir/frontend_ir.h
@@ -511,8 +511,11 @@ class IndexExpression : public Expression {
   Expr var;
   ExprGroup indices;

-  IndexExpression(const Expr &var, const ExprGroup &indices)
+  IndexExpression(const Expr &var,
+                  const ExprGroup &indices,
+                  std::string tb = "")
       : var(var), indices(indices) {
+    this->tb = tb;
   }

   void type_check(CompileConfig *config) override;
@@ -853,8 +856,10 @@ class ASTBuilder {
   Block *current_block();
   Stmt *get_last_stmt();
   void stop_gradient(SNode *);
-  void insert_assignment(Expr &lhs, const Expr &rhs);
-  Expr make_var(const Expr &x);
+  void insert_assignment(Expr &lhs,
+                         const Expr &rhs,
+                         const std::string &tb = "");
+  Expr make_var(const Expr &x, std::string tb);
   void insert_for(const Expr &s,
                   const Expr &e,
                   const std::function<void(Expr)> &func);
@@ -878,7 +883,8 @@ class ASTBuilder {
   Expr expr_alloca();
   Expr expr_alloca_local_tensor(const std::vector<int> &shape,
                                 const DataType &element_type,
-                                const ExprGroup &elements);
+                                const ExprGroup &elements,
+                                std::string tb);
   Expr expr_alloca_shared_array(const std::vector<int> &shape,
                                 const DataType &element_type);
   void expr_assign(const Expr &lhs, const Expr &rhs, std::string tb);
3 changes: 2 additions & 1 deletion taichi/math/svd.h
@@ -58,8 +58,9 @@ sifakis_svd_export(ASTBuilder *ast_builder,
   constexpr Tf Sine_Pi_Over_Eight = 0.3826834323650897f;
   constexpr Tf Cosine_Pi_Over_Eight = 0.9238795325112867f;

+  std::string tb = "";
   auto Var =
-      std::bind(&ASTBuilder::make_var, ast_builder, std::placeholders::_1);
+      std::bind(&ASTBuilder::make_var, ast_builder, std::placeholders::_1, tb);

   auto Sfour_gamma_squared = Var(Expr(Tf(0.0)));
   auto Ssine_pi_over_eight = Var(Expr(Tf(0.0)));
6 changes: 4 additions & 2 deletions taichi/program/program.cpp
@@ -364,8 +364,10 @@ Kernel &Program::get_snode_writer(SNode *snode) {
     }
     auto expr = Expr(snode_to_glb_var_exprs_.at(snode))[indices];
     this->current_ast_builder()->insert_assignment(
-        expr, Expr::make<ArgLoadExpression>(snode->num_active_indices,
-                                            snode->dt->get_compute_type()));
+        expr,
+        Expr::make<ArgLoadExpression>(snode->num_active_indices,
+                                      snode->dt->get_compute_type()),
+        expr->tb);
   });
   ker.set_arch(get_accessor_arch());
   ker.name = kernel_name;
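`get_snode_writer` backs Python-scope stores such as `x[0] = 1.0`; with this change the generated assignment reuses the pointer expression's source info (`expr->tb`). A small sketch of the code path it serves, assuming CPU (the field name and values are illustrative):

import taichi as ti

ti.init(arch=ti.cpu)

x = ti.field(ti.f32, shape=8)
x[0] = 1.0   # Python-scope write: compiled through Program::get_snode_writer
print(x[0])  # Python-scope read: the matching get_snode_reader path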
13 changes: 8 additions & 5 deletions taichi/python/export_lang.cpp
@@ -958,12 +958,15 @@ void export_lang(py::module &m) {
       });
   m.def("data_type_name", data_type_name);

-  m.def("subscript", [](const Expr &expr, const ExprGroup &expr_group) {
-    return expr[expr_group];
-  });
+  m.def("subscript",
+        [](const Expr &expr, const ExprGroup &expr_group, std::string tb) {
+          Expr idx_expr = expr[expr_group];
+          idx_expr.set_tb(tb);
+          return idx_expr;
+        });

-  m.def("make_index_expr",
-        Expr::make<IndexExpression, const Expr &, const ExprGroup &>);
+  m.def("make_index_expr", Expr::make<IndexExpression, const Expr &,
+                                      const ExprGroup &, std::string>);

   m.def("make_stride_expr",
         Expr::make<StrideExpression, const Expr &, const ExprGroup &,
3 changes: 3 additions & 0 deletions taichi/transforms/auto_diff.cpp
@@ -3,6 +3,7 @@
 #include "taichi/ir/statements.h"
 #include "taichi/ir/transforms.h"
 #include "taichi/ir/visitors.h"
+#include "taichi/transforms/utils.h"

 #include <typeinfo>
 #include <algorithm>
@@ -1601,6 +1602,8 @@ class GloablDataAccessRuleChecker : public BasicStmtVisitor {
         "(kernel={}) Breaks the global data access rule. Snode {} is "
         "overwritten unexpectedly.",
         kernel_name_, dest->snodes[0]->get_node_type_name());
+    msg += "\n" + stmt->tb;
+
     stmt->insert_before_me(
         Stmt::make<AssertStmt>(check_equal, msg, std::vector<Stmt *>()));
   }
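The `GloablDataAccessRuleChecker` assert above fires during autodiff validation; with this PR its message ends with the offending source line. A hedged repro sketch, assuming `ti.ad.Tape(..., validation=True)` enables the checker in this release (the kernel and fields are illustrative):

import taichi as ti

ti.init(arch=ti.cpu, debug=True)

x = ti.field(ti.f32, shape=(), needs_grad=True)
loss = ti.field(ti.f32, shape=(), needs_grad=True)

@ti.kernel
def forward():
    loss[None] = x[None] * 2.0
    loss[None] = x[None] * 3.0  # second write breaks the global data access rule

with ti.ad.Tape(loss=loss, validation=True):
    forward()  # the assert now reports which line overwrote the snode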
3 changes: 3 additions & 0 deletions taichi/transforms/check_out_of_bound.cpp
@@ -3,6 +3,7 @@
 #include "taichi/ir/transforms.h"
 #include "taichi/ir/visitors.h"
 #include "taichi/transforms/check_out_of_bound.h"
+#include "taichi/transforms/utils.h"
 #include <set>

 TLANG_NAMESPACE_BEGIN
@@ -98,6 +99,7 @@ class CheckOutOfBound : public BasicStmtVisitor {
       msg += "%d";
     }
     msg += ")";
+    msg += "\n" + stmt->tb;

     new_stmts.push_back<AssertStmt>(result, msg, args);
     modifier.insert_before(stmt, std::move(new_stmts));
@@ -117,6 +119,7 @@ class CheckOutOfBound : public BasicStmtVisitor {
         BinaryOpType::cmp_ge, stmt->rhs, compare_rhs.get());
     compare->ret_type = PrimitiveType::i32;
     std::string msg = "Negative exponent for integer pows are not allowed";
+    msg += "\n" + stmt->tb;
     auto assert_stmt = std::make_unique<AssertStmt>(compare.get(), msg,
                                                     std::vector<Stmt *>());
     assert_stmt->accept(this);
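Both asserts in this pass now carry the source snippet; for the negative-exponent branch, a minimal repro sketch, assuming debug mode (the kernel is illustrative):

import taichi as ti

ti.init(arch=ti.cpu, debug=True)

@ti.kernel
def int_pow(n: ti.i32) -> ti.i32:
    return 2 ** n  # integer pow: the runtime assert fires when n is negative

int_pow(-1)  # "Negative exponent for integer pows are not allowed" + source line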
7 changes: 4 additions & 3 deletions taichi/transforms/simplify.cpp
@@ -6,6 +6,7 @@
 #include "taichi/transforms/simplify.h"
 #include "taichi/program/kernel.h"
 #include "taichi/program/program.h"
+#include "taichi/transforms/utils.h"
 #include <set>
 #include <unordered_set>
 #include <utility>
@@ -301,9 +302,9 @@ class BasicBlockSimplify : public IRVisitor {
        auto zero = Stmt::make<ConstStmt>(LaneAttribute<TypedConstant>(0));
        auto check_sum =
            Stmt::make<BinaryOpStmt>(BinaryOpType::cmp_ge, sum.get(), zero.get());
-       auto assert = Stmt::make<AssertStmt>(check_sum.get(),
-                                            "The indices provided are too big!",
-                                            std::vector<Stmt *>());
+       auto assert = Stmt::make<AssertStmt>(
+           check_sum.get(), "The indices provided are too big!\n" + stmt->tb,
+           std::vector<Stmt *>());
        // Because Taichi's assertion is checked only after the execution of the
        // kernel, when the linear index overflows and goes negative, we have to
        // replace that with 0 to make sure that the rest of the kernel can still
2 changes: 2 additions & 0 deletions taichi/transforms/type_check.cpp
@@ -5,6 +5,7 @@
 #include "taichi/ir/transforms.h"
 #include "taichi/ir/analysis.h"
 #include "taichi/ir/frontend_ir.h"
+#include "taichi/transforms/utils.h"

 TLANG_NAMESPACE_BEGIN

@@ -252,6 +253,7 @@ class TypeCheck : public IRVisitor {
     std::string msg =
         "Detected overflow for bit_shift_op with rhs = %d, exceeding limit of "
         "%d.";
+    msg += "\n" + stmt->tb;
     std::vector<Stmt *> args = {rhs, const_stmt.get()};
     auto assert_stmt =
         Stmt::make<AssertStmt>(cond_stmt.get(), msg, std::move(args));
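The bit-shift overflow assert gets the same treatment; a trigger sketch, assuming the check is active in debug mode (the kernel is illustrative):

import taichi as ti

ti.init(arch=ti.cpu, debug=True)

@ti.kernel
def shift(n: ti.i32) -> ti.i32:
    one = 1          # defaults to i32 in the Taichi scope
    return one << n  # rhs >= 32 exceeds the i32 bit width

shift(33)  # "Detected overflow for bit_shift_op ..." now includes the source line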
2 changes: 1 addition & 1 deletion tests/cpp/ir/frontend_type_inference_test.cpp
@@ -32,7 +32,7 @@ TEST(FrontendTypeInference, Id) {
   Callable::CurrentCallableGuard _(kernel->program, kernel.get());
   auto const_i32 = value<int32>(-(1 << 20));
   const_i32->type_check(nullptr);
-  auto id_i32 = prog->current_ast_builder()->make_var(const_i32);
+  auto id_i32 = prog->current_ast_builder()->make_var(const_i32, const_i32->tb);
   EXPECT_EQ(id_i32->ret_type, PrimitiveType::i32);
 }