Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

[refactor] [ir] Remove legacy LocalAddress / VectorElement / create_vector_or_scalar_type() #5918

Merged
merged 4 commits into from
Aug 30, 2022
Merged
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
2 changes: 1 addition & 1 deletion taichi/analysis/data_source_analysis.cpp
Original file line number Diff line number Diff line change
@@ -9,7 +9,7 @@ namespace irpass::analysis {
std::vector<Stmt *> get_load_pointers(Stmt *load_stmt) {
// If load_stmt loads some variables or a stack, return the pointers of them.
if (auto local_load = load_stmt->cast<LocalLoadStmt>()) {
return std::vector<Stmt *>(1, local_load->src.var);
return std::vector<Stmt *>(1, local_load->src);
} else if (auto global_load = load_stmt->cast<GlobalLoadStmt>()) {
return std::vector<Stmt *>(1, global_load->src);
} else if (auto atomic = load_stmt->cast<AtomicOpStmt>()) {
3 changes: 1 addition & 2 deletions taichi/analysis/verify.cpp
Original file line number Diff line number Diff line change
@@ -104,8 +104,7 @@ class IRVerifier : public BasicStmtVisitor {

void visit(LocalLoadStmt *stmt) override {
basic_verify(stmt);
TI_ASSERT(stmt->src.var->is<AllocaStmt>() ||
stmt->src.var->is<PtrOffsetStmt>());
TI_ASSERT(stmt->src->is<AllocaStmt>() || stmt->src->is<PtrOffsetStmt>());
}

void visit(LocalStoreStmt *stmt) override {
2 changes: 1 addition & 1 deletion taichi/codegen/cc/codegen_cc.cpp
Original file line number Diff line number Diff line change
@@ -217,7 +217,7 @@ class CCTransformer : public IRVisitor {
void visit(LocalLoadStmt *stmt) override {
auto var =
define_var(cc_data_type_name(stmt->element_type()), stmt->raw_name());
emit("{} = {};", var, stmt->src.var->raw_name());
emit("{} = {};", var, stmt->src->raw_name());
}

void visit(LocalStoreStmt *stmt) override {
20 changes: 9 additions & 11 deletions taichi/codegen/llvm/codegen_llvm.cpp
Original file line number Diff line number Diff line change
@@ -1191,16 +1191,16 @@ void TaskCodeGenLLVM::visit(LocalLoadStmt *stmt) {
#ifdef TI_LLVM_15
// FIXME: get ptr_ty from taichi instead of llvm.
llvm::Type *ptr_ty = nullptr;
auto *val = llvm_val[stmt->src.var];
auto *val = llvm_val[stmt->src];
if (auto *alloc = llvm::dyn_cast<llvm::AllocaInst>(val))
ptr_ty = alloc->getAllocatedType();
if (!ptr_ty && stmt->src.var->element_type().is_pointer()) {
ptr_ty = llvm_type(stmt->src.var->element_type().ptr_removed());
if (!ptr_ty && stmt->src->element_type().is_pointer()) {
ptr_ty = llvm_type(stmt->src->element_type().ptr_removed());
}
TI_ASSERT(ptr_ty);
llvm_val[stmt] = builder->CreateLoad(ptr_ty, llvm_val[stmt->src.var]);
llvm_val[stmt] = builder->CreateLoad(ptr_ty, llvm_val[stmt->src]);
#else
llvm_val[stmt] = builder->CreateLoad(llvm_val[stmt->src.var]);
llvm_val[stmt] = builder->CreateLoad(llvm_val[stmt->src]);
#endif
}

@@ -1887,9 +1887,8 @@ std::tuple<llvm::Value *, llvm::Value *> TaskCodeGenLLVM::get_range_for_bounds(
if (stmt->const_begin) {
begin = tlctx->get_constant(stmt->begin_value);
} else {
auto begin_stmt = Stmt::make<GlobalTemporaryStmt>(
stmt->begin_offset,
TypeFactory::create_vector_or_scalar_type(1, PrimitiveType::i32));
auto begin_stmt =
Stmt::make<GlobalTemporaryStmt>(stmt->begin_offset, PrimitiveType::i32);
begin_stmt->accept(this);
begin = builder->CreateLoad(
#ifdef TI_LLVM_15
@@ -1900,9 +1899,8 @@ std::tuple<llvm::Value *, llvm::Value *> TaskCodeGenLLVM::get_range_for_bounds(
if (stmt->const_end) {
end = tlctx->get_constant(stmt->end_value);
} else {
auto end_stmt = Stmt::make<GlobalTemporaryStmt>(
stmt->end_offset,
TypeFactory::create_vector_or_scalar_type(1, PrimitiveType::i32));
auto end_stmt =
Stmt::make<GlobalTemporaryStmt>(stmt->end_offset, PrimitiveType::i32);
end_stmt->accept(this);
end = builder->CreateLoad(
#ifdef TI_LLVM_15
7 changes: 3 additions & 4 deletions taichi/codegen/metal/codegen_metal.cpp
Original file line number Diff line number Diff line change
@@ -260,7 +260,7 @@ class KernelCodegenImpl : public IRVisitor {
}

void visit(LocalLoadStmt *stmt) override {
auto ptr = stmt->src.var;
auto ptr = stmt->src;
emit("const {} {}({});", metal_data_type_name(stmt->element_type()),
stmt->raw_name(), ptr->raw_name());
}
@@ -1475,11 +1475,10 @@ class KernelCodegenImpl : public IRVisitor {

std::string inject_load_global_tmp(int offset,
DataType dt = PrimitiveType::i32) {
const auto vt = TypeFactory::create_vector_or_scalar_type(1, dt);
auto gtmp = Stmt::make<GlobalTemporaryStmt>(offset, vt);
auto gtmp = Stmt::make<GlobalTemporaryStmt>(offset, dt);
gtmp->accept(this);
auto gload = Stmt::make<GlobalLoadStmt>(gtmp.get());
gload->ret_type = vt;
gload->ret_type = dt;
gload->accept(this);
return gload->raw_name();
}
2 changes: 1 addition & 1 deletion taichi/codegen/spirv/spirv_codegen.cpp
Original file line number Diff line number Diff line change
@@ -234,7 +234,7 @@ class TaskCodegen : public IRVisitor {
}

void visit(LocalLoadStmt *stmt) override {
auto ptr = stmt->src.var;
auto ptr = stmt->src;
spirv::Value ptr_val = ir_->query_value(ptr->raw_name());
spirv::Value val = ir_->load_variable(
ptr_val, ir_->get_primitive_type(stmt->element_type()));
5 changes: 2 additions & 3 deletions taichi/ir/control_flow_graph.cpp
Original file line number Diff line number Diff line change
@@ -262,7 +262,7 @@ bool CFGNode::store_to_load_forwarding(bool after_lower_access,
auto stmt = block->statements[i].get();
Stmt *result = nullptr;
if (auto local_load = stmt->cast<LocalLoadStmt>()) {
result = get_store_forwarding_data(local_load->src.var, i);
result = get_store_forwarding_data(local_load->src, i);
} else if (auto global_load = stmt->cast<GlobalLoadStmt>()) {
if (!after_lower_access && !autodiff_enabled) {
result = get_store_forwarding_data(global_load->src, i);
@@ -430,8 +430,7 @@ bool CFGNode::dead_store_elimination(bool after_lower_access) {
auto atomic = stmt->cast<AtomicOpStmt>();
// Weaken the atomic operation to a load.
if (atomic->dest->is<AllocaStmt>()) {
auto local_load =
Stmt::make<LocalLoadStmt>(LocalAddress(atomic->dest, 0));
auto local_load = Stmt::make<LocalLoadStmt>(atomic->dest);
local_load->ret_type = atomic->ret_type;
// Notice that we have a load here
// (the return value of AtomicOpStmt).
8 changes: 4 additions & 4 deletions taichi/ir/frontend_ir.cpp
Original file line number Diff line number Diff line change
@@ -242,7 +242,7 @@ void BinaryOpExpression::flatten(FlattenContext *ctx) {
if (binary_is_logical(type)) {
auto result = ctx->push_back<AllocaStmt>(ret_type);
ctx->push_back<LocalStoreStmt>(result, lhs->stmt);
auto cond = ctx->push_back<LocalLoadStmt>(LocalAddress(result, 0));
auto cond = ctx->push_back<LocalLoadStmt>(result);
auto if_stmt = ctx->push_back<IfStmt>(cond);

FlattenContext rctx;
@@ -262,7 +262,7 @@ void BinaryOpExpression::flatten(FlattenContext *ctx) {
}
if_stmt->set_false_statements(std::move(false_block));

auto ret = ctx->push_back<LocalLoadStmt>(LocalAddress(result, 0));
auto ret = ctx->push_back<LocalLoadStmt>(result);
ret->tb = tb;
stmt = ret;
return;
@@ -300,7 +300,7 @@ void make_ifte(Expression::FlattenContext *ctx,
false_block->set_statements(std::move(rctx.stmts));
if_stmt->set_false_statements(std::move(false_block));

ctx->push_back<LocalLoadStmt>(LocalAddress(result, 0));
ctx->push_back<LocalLoadStmt>(result);
return;
}

@@ -1153,7 +1153,7 @@ void flatten_global_load(Expr ptr, Expression::FlattenContext *ctx) {
}

void flatten_local_load(Expr ptr, Expression::FlattenContext *ctx) {
ctx->push_back<LocalLoadStmt>(LocalAddress(ptr->stmt, 0));
ctx->push_back<LocalLoadStmt>(ptr->stmt);
ptr->stmt = ctx->back_stmt();
}

2 changes: 1 addition & 1 deletion taichi/ir/frontend_ir.h
Original file line number Diff line number Diff line change
@@ -65,7 +65,7 @@ class FrontendAllocaStmt : public Stmt {

FrontendAllocaStmt(const Identifier &lhs, DataType type)
: ident(lhs), is_shared(false) {
ret_type = TypeFactory::create_vector_or_scalar_type(1, type);
ret_type = type;
}

FrontendAllocaStmt(const Identifier &lhs,
4 changes: 0 additions & 4 deletions taichi/ir/ir.cpp
Original file line number Diff line number Diff line change
@@ -496,9 +496,5 @@ void DelayedIRModifier::mark_as_modified() {
modified_ = true;
}

LocalAddress::LocalAddress(Stmt *var, int offset) : var(var), offset(offset) {
TI_ASSERT(var->is<AllocaStmt>() || var->is<PtrOffsetStmt>());
}

} // namespace lang
} // namespace taichi
27 changes: 0 additions & 27 deletions taichi/ir/ir.h
Original file line number Diff line number Diff line change
@@ -573,25 +573,6 @@ class DelayedIRModifier {
void mark_as_modified();
};

struct LocalAddress {
Stmt *var;
int offset;

LocalAddress(Stmt *var, int offset);
};

class VectorElement {
public:
Stmt *stmt;
int index;

VectorElement() : stmt(nullptr), index(0) {
}

VectorElement(Stmt *stmt, int index) : stmt(stmt), index(index) {
}
};

template <typename T>
inline void StmtFieldManager::operator()(const char *key, T &&value) {
using decay_T = typename std::decay<T>::type;
@@ -612,14 +593,6 @@ inline void StmtFieldManager::operator()(const char *key, T &&value) {
}
} else if constexpr (std::is_same<decay_T, Stmt *>::value) {
stmt_->register_operand(const_cast<Stmt *&>(value));
} else if constexpr (std::is_same<decay_T, LocalAddress>::value) {
stmt_->register_operand(const_cast<Stmt *&>(value.var));
stmt_->field_manager.fields.emplace_back(
std::make_unique<StmtFieldNumeric<int>>(value.offset));
} else if constexpr (std::is_same<decay_T, VectorElement>::value) {
stmt_->register_operand(const_cast<Stmt *&>(value.stmt));
stmt_->field_manager.fields.emplace_back(
std::make_unique<StmtFieldNumeric<int>>(value.index));
} else if constexpr (std::is_same<decay_T, SNode *>::value) {
stmt_->field_manager.fields.emplace_back(
std::make_unique<StmtFieldSNode>(value));
2 changes: 1 addition & 1 deletion taichi/ir/ir_builder.cpp
Original file line number Diff line number Diff line change
@@ -405,7 +405,7 @@ AllocaStmt *IRBuilder::create_local_var(DataType dt) {
}

LocalLoadStmt *IRBuilder::create_local_load(AllocaStmt *ptr) {
return insert(Stmt::make_typed<LocalLoadStmt>(LocalAddress(ptr, 0)));
return insert(Stmt::make_typed<LocalLoadStmt>(ptr));
}

void IRBuilder::create_local_store(AllocaStmt *ptr, Stmt *data) {
16 changes: 5 additions & 11 deletions taichi/ir/statements.h
Original file line number Diff line number Diff line change
@@ -19,12 +19,7 @@ class Function;
class AllocaStmt : public Stmt {
public:
AllocaStmt(DataType type) : is_shared(false) {
ret_type = TypeFactory::create_vector_or_scalar_type(1, type);
TI_STMT_REG_FIELDS;
}

AllocaStmt(int width, DataType type) : is_shared(false) {
ret_type = TypeFactory::create_vector_or_scalar_type(width, type);
ret_type = type;
TI_STMT_REG_FIELDS;
}

@@ -169,7 +164,7 @@ class ArgLoadStmt : public Stmt {

ArgLoadStmt(int arg_id, const DataType &dt, bool is_ptr = false)
: arg_id(arg_id) {
this->ret_type = TypeFactory::create_vector_or_scalar_type(1, dt);
this->ret_type = dt;
this->is_ptr = is_ptr;
TI_STMT_REG_FIELDS;
}
@@ -593,9 +588,9 @@ class GlobalStoreStmt : public Stmt {
*/
class LocalLoadStmt : public Stmt {
public:
LocalAddress src;
Stmt *src;

explicit LocalLoadStmt(const LocalAddress &src) : src(src) {
explicit LocalLoadStmt(Stmt *src) : src(src) {
TI_STMT_REG_FIELDS;
}

@@ -1394,8 +1389,7 @@ class InternalFuncStmt : public Stmt {
args(args),
with_runtime_context(with_runtime_context) {
if (ret_type == nullptr) {
this->ret_type =
TypeFactory::create_vector_or_scalar_type(1, PrimitiveType::i32);
this->ret_type = PrimitiveType::i32;
} else {
this->ret_type = ret_type;
}
11 changes: 0 additions & 11 deletions taichi/ir/type_factory.cpp
Original file line number Diff line number Diff line change
@@ -131,17 +131,6 @@ PrimitiveType *TypeFactory::get_primitive_real_type(int bits) {
return real_type->cast<PrimitiveType>();
}

DataType TypeFactory::create_vector_or_scalar_type(int width,
DataType element,
bool element_is_pointer) {
TI_ASSERT(width == 1);
if (element_is_pointer) {
return TypeFactory::get_instance().get_pointer_type(element);
} else {
return element;
}
}

DataType TypeFactory::create_tensor_type(std::vector<int> shape,
DataType element) {
return TypeFactory::get_instance().get_tensor_type(shape, element);
4 changes: 0 additions & 4 deletions taichi/ir/type_factory.h
Original file line number Diff line number Diff line change
@@ -44,10 +44,6 @@ class TypeFactory {
Type *element_type,
int num_elements);

static DataType create_vector_or_scalar_type(int width,
DataType element,
bool element_is_pointer = false);

static DataType create_tensor_type(std::vector<int> shape, DataType element);

private:
Loading