Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

[lang] Texture support 0/n: IR changes #5134

Merged
merged 1 commit into from
Jun 11, 2022
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
13 changes: 13 additions & 0 deletions taichi/analysis/gen_offline_cache_key.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -82,6 +82,18 @@ class ASTSerializer : public IRVisitor, public ExpressionVisitor {
emit(expr->arg_id);
}

void visit(TexturePtrExpression *expr) override {
emit(ExprOpCode::TexturePtrExpression);
emit(expr->arg_id);
}

void visit(TextureOpExpression *expr) override {
emit(ExprOpCode::TextureOpExpression);
emit(expr->op);
emit(expr->texture_ptr);
emit(expr->args.exprs);
}

void visit(RandExpression *expr) override {
emit(ExprOpCode::RandExpression);
emit(expr->dt);
Expand Down Expand Up @@ -611,6 +623,7 @@ class ASTSerializer : public IRVisitor, public ExpressionVisitor {
DEFINE_EMIT_ENUM(SNodeAccessFlag);
DEFINE_EMIT_ENUM(MeshRelationAccessType);
DEFINE_EMIT_ENUM(ExternalFuncType);
DEFINE_EMIT_ENUM(TextureOpType);
DEFINE_EMIT_ENUM(mesh::MeshElementType);
DEFINE_EMIT_ENUM(mesh::MeshRelationType);
DEFINE_EMIT_ENUM(mesh::ConvType);
Expand Down
2 changes: 2 additions & 0 deletions taichi/inc/expressions.inc.h
Original file line number Diff line number Diff line change
Expand Up @@ -20,3 +20,5 @@ PER_EXPRESSION(MeshPatchIndexExpression)
PER_EXPRESSION(MeshRelationAccessExpression)
PER_EXPRESSION(MeshIndexConversionExpression)
PER_EXPRESSION(ReferenceExpression)
PER_EXPRESSION(TextureOpExpression)
PER_EXPRESSION(TexturePtrExpression)
3 changes: 3 additions & 0 deletions taichi/inc/statements.inc.h
Original file line number Diff line number Diff line change
Expand Up @@ -79,5 +79,8 @@ PER_STATEMENT(BlockLocalPtrStmt)
// Special
PER_STATEMENT(InternalFuncStmt)

PER_STATEMENT(TexturePtrStmt)
PER_STATEMENT(TextureOpStmt)

// Quantization
PER_STATEMENT(BitStructStoreStmt)
10 changes: 10 additions & 0 deletions taichi/ir/expression_printer.h
Original file line number Diff line number Diff line change
Expand Up @@ -41,6 +41,16 @@ class ExpressionHumanFriendlyPrinter : public ExpressionPrinter {
fmt::format("arg[{}] (dt={})", expr->arg_id, data_type_name(expr->dt)));
}

void visit(TexturePtrExpression *expr) override {
emit(fmt::format("(Texture *)(arg[{}])", expr->arg_id));
}

void visit(TextureOpExpression *expr) override {
emit(fmt::format("texture_{}(", texture_op_type_name(expr->op)));
visit(expr->args);
emit(")");
}

void visit(RandExpression *expr) override {
emit(fmt::format("rand<{}>()", data_type_name(expr->dt)));
}
Expand Down
55 changes: 55 additions & 0 deletions taichi/ir/frontend_ir.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -126,6 +126,15 @@ void ArgLoadExpression::flatten(FlattenContext *ctx) {
stmt = ctx->back_stmt();
}

void TexturePtrExpression::type_check(CompileConfig *config) {
}

void TexturePtrExpression::flatten(FlattenContext *ctx) {
ctx->push_back<ArgLoadStmt>(arg_id, PrimitiveType::f32, true);
ctx->push_back<TexturePtrStmt>(ctx->back_stmt());
stmt = ctx->back_stmt();
}

void RandExpression::type_check(CompileConfig *) {
TI_ASSERT_INFO(dt->is<PrimitiveType>() && dt != PrimitiveType::unknown,
"Invalid dt [{}] for RandExpression", dt->to_string());
Expand Down Expand Up @@ -589,6 +598,52 @@ void SNodeOpExpression::flatten(FlattenContext *ctx) {
stmt = ctx->back_stmt();
}

void TextureOpExpression::type_check(CompileConfig *config) {
if (op == TextureOpType::sample_lod) {
// UV, Lod
TI_ASSERT_INFO(args.size() == 3,
"Invalid number of args for sample_lod Texture op");
TI_ASSERT_TYPE_CHECKED(args[0]);
TI_ASSERT_TYPE_CHECKED(args[1]);
TI_ASSERT_TYPE_CHECKED(args[2]);
if (args[0].get_ret_type() != PrimitiveType::f32 ||
args[1].get_ret_type() != PrimitiveType::f32 ||
args[2].get_ret_type() != PrimitiveType::f32) {
throw TaichiTypeError(
fmt::format("All arguments to sample_lod Texture op must be FP32"));
}
} else if (op == TextureOpType::fetch_texel) {
// index, int LOD
TI_ASSERT_INFO(args.size() == 3,
"Invalid number of args for fetch_texel Texture op");
TI_ASSERT_TYPE_CHECKED(args[0]);
TI_ASSERT_TYPE_CHECKED(args[1]);
TI_ASSERT_TYPE_CHECKED(args[2]);
if (args[0].get_ret_type() != PrimitiveType::i32 ||
args[1].get_ret_type() != PrimitiveType::i32 ||
args[2].get_ret_type() != PrimitiveType::i32) {
throw TaichiTypeError(
fmt::format("All arguments to fetch_texel Texture op must be i32"));
}
} else {
TI_ERROR("Invalid TextureOpType");
}
ret_type =
TypeFactory::get_instance().get_pointer_type(PrimitiveType::f32,
/*is_bit_pointer=*/false);
}

void TextureOpExpression::flatten(FlattenContext *ctx) {
flatten_rvalue(texture_ptr, ctx);
std::vector<Stmt *> arg_stmts;
for (Expr &arg : args.exprs) {
flatten_rvalue(arg, ctx);
arg_stmts.push_back(arg->stmt);
}
ctx->push_back<TextureOpStmt>(op, texture_ptr->stmt, arg_stmts);
stmt = ctx->back_stmt();
}

void ConstExpression::type_check(CompileConfig *) {
TI_ASSERT_INFO(
val.dt->is<PrimitiveType>() && val.dt != PrimitiveType::unknown,
Expand Down
35 changes: 35 additions & 0 deletions taichi/ir/frontend_ir.h
Original file line number Diff line number Diff line change
Expand Up @@ -293,6 +293,22 @@ class ArgLoadExpression : public Expression {
TI_DEFINE_ACCEPT_FOR_EXPRESSION
};

class Texture;

class TexturePtrExpression : public Expression {
public:
int arg_id;

TexturePtrExpression(int arg_id) : arg_id(arg_id) {
}

void type_check(CompileConfig *config) override;

void flatten(FlattenContext *ctx) override;

TI_DEFINE_ACCEPT_FOR_EXPRESSION
};

class RandExpression : public Expression {
public:
DataType dt;
Expand Down Expand Up @@ -612,6 +628,25 @@ class SNodeOpExpression : public Expression {
TI_DEFINE_ACCEPT_FOR_EXPRESSION
};

class TextureOpExpression : public Expression {
public:
TextureOpType op;
Expr texture_ptr;
ExprGroup args;

explicit TextureOpExpression(TextureOpType op,
Expr texture_ptr,
const ExprGroup &args)
: op(op), texture_ptr(texture_ptr), args(args) {
}

void type_check(CompileConfig *config) override;

void flatten(FlattenContext *ctx) override;

TI_DEFINE_ACCEPT_FOR_EXPRESSION
};

class ConstExpression : public Expression {
public:
TypedConstant val;
Expand Down
31 changes: 31 additions & 0 deletions taichi/ir/statements.h
Original file line number Diff line number Diff line change
Expand Up @@ -1446,6 +1446,37 @@ class InternalFuncStmt : public Stmt {
TI_DEFINE_ACCEPT_AND_CLONE
};

class Texture;

class TexturePtrStmt : public Stmt {
public:
Stmt *arg_load_stmt{nullptr};

explicit TexturePtrStmt(Stmt *stmt) : arg_load_stmt(stmt) {
TI_STMT_REG_FIELDS;
}

TI_STMT_DEF_FIELDS(arg_load_stmt);
TI_DEFINE_ACCEPT_AND_CLONE
};

class TextureOpStmt : public Stmt {
public:
TextureOpType op;
Stmt *texture_ptr;
std::vector<Stmt *> args;

explicit TextureOpStmt(TextureOpType op,
Stmt *texture_ptr,
const std::vector<Stmt *> &args)
: op(op), texture_ptr(texture_ptr), args(args) {
TI_STMT_REG_FIELDS;
}

TI_STMT_DEF_FIELDS(op, texture_ptr, args);
TI_DEFINE_ACCEPT_AND_CLONE
};

/**
* A local AD-stack.
*/
Expand Down
16 changes: 16 additions & 0 deletions taichi/ir/stmt_op_types.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -146,5 +146,21 @@ std::string snode_op_type_name(SNodeOpType type) {
}
}

std::string texture_op_type_name(TextureOpType type) {
switch (type) {
#define REGISTER_TYPE(i) \
case TextureOpType::i: \
return #i;

REGISTER_TYPE(sample_lod);
REGISTER_TYPE(fetch_texel);
REGISTER_TYPE(undefined);

#undef REGISTER_TYPE
default:
TI_NOT_IMPLEMENTED
}
}

} // namespace lang
} // namespace taichi
4 changes: 4 additions & 0 deletions taichi/ir/stmt_op_types.h
Original file line number Diff line number Diff line change
Expand Up @@ -84,5 +84,9 @@ enum class SNodeOpType : int {

std::string snode_op_type_name(SNodeOpType type);

enum class TextureOpType : int { sample_lod, fetch_texel, undefined };

std::string texture_op_type_name(TextureOpType type);

} // namespace lang
} // namespace taichi
10 changes: 10 additions & 0 deletions taichi/transforms/ir_printer.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -424,6 +424,16 @@ class IRPrinter : public IRVisitor {
print("{}{} = arg[{}]", stmt->type_hint(), stmt->name(), stmt->arg_id);
}

void visit(TexturePtrStmt *stmt) override {
print("<*Texture> {} = {}", stmt->name(), stmt->arg_load_stmt->name());
}

void visit(TextureOpStmt *stmt) override {
print("<struct> {} = texture_{}({}, {}, {})", stmt->name(),
texture_op_type_name(stmt->op), stmt->args[0]->name(),
stmt->args[1]->name(), stmt->args[2]->name());
}

void visit(FrontendReturnStmt *stmt) override {
print("{}{} : return [{}]", stmt->type_hint(), stmt->name(),
expr_group_to_string(stmt->values));
Expand Down