diff --git a/taichi/analysis/gen_offline_cache_key.cpp b/taichi/analysis/gen_offline_cache_key.cpp index f3ecaae6171e7..f9723e99f0044 100644 --- a/taichi/analysis/gen_offline_cache_key.cpp +++ b/taichi/analysis/gen_offline_cache_key.cpp @@ -82,6 +82,18 @@ class ASTSerializer : public IRVisitor, public ExpressionVisitor { emit(expr->arg_id); } + void visit(TexturePtrExpression *expr) override { + emit(ExprOpCode::TexturePtrExpression); + emit(expr->arg_id); + } + + void visit(TextureOpExpression *expr) override { + emit(ExprOpCode::TextureOpExpression); + emit(expr->op); + emit(expr->texture_ptr); + emit(expr->args.exprs); + } + void visit(RandExpression *expr) override { emit(ExprOpCode::RandExpression); emit(expr->dt); @@ -611,6 +623,7 @@ class ASTSerializer : public IRVisitor, public ExpressionVisitor { DEFINE_EMIT_ENUM(SNodeAccessFlag); DEFINE_EMIT_ENUM(MeshRelationAccessType); DEFINE_EMIT_ENUM(ExternalFuncType); + DEFINE_EMIT_ENUM(TextureOpType); DEFINE_EMIT_ENUM(mesh::MeshElementType); DEFINE_EMIT_ENUM(mesh::MeshRelationType); DEFINE_EMIT_ENUM(mesh::ConvType); diff --git a/taichi/inc/expressions.inc.h b/taichi/inc/expressions.inc.h index 4ec43c58357f9..d4ef5dbc6fa47 100644 --- a/taichi/inc/expressions.inc.h +++ b/taichi/inc/expressions.inc.h @@ -20,3 +20,5 @@ PER_EXPRESSION(MeshPatchIndexExpression) PER_EXPRESSION(MeshRelationAccessExpression) PER_EXPRESSION(MeshIndexConversionExpression) PER_EXPRESSION(ReferenceExpression) +PER_EXPRESSION(TextureOpExpression) +PER_EXPRESSION(TexturePtrExpression) diff --git a/taichi/inc/statements.inc.h b/taichi/inc/statements.inc.h index c40c89290afd8..fe12a8941f7f5 100644 --- a/taichi/inc/statements.inc.h +++ b/taichi/inc/statements.inc.h @@ -79,5 +79,8 @@ PER_STATEMENT(BlockLocalPtrStmt) // Special PER_STATEMENT(InternalFuncStmt) +PER_STATEMENT(TexturePtrStmt) +PER_STATEMENT(TextureOpStmt) + // Quantization PER_STATEMENT(BitStructStoreStmt) diff --git a/taichi/ir/expression_printer.h b/taichi/ir/expression_printer.h index 29f391d17334a..133b7e4eb27ba 100644 --- a/taichi/ir/expression_printer.h +++ b/taichi/ir/expression_printer.h @@ -41,6 +41,16 @@ class ExpressionHumanFriendlyPrinter : public ExpressionPrinter { fmt::format("arg[{}] (dt={})", expr->arg_id, data_type_name(expr->dt))); } + void visit(TexturePtrExpression *expr) override { + emit(fmt::format("(Texture *)(arg[{}])", expr->arg_id)); + } + + void visit(TextureOpExpression *expr) override { + emit(fmt::format("texture_{}(", texture_op_type_name(expr->op))); + visit(expr->args); + emit(")"); + } + void visit(RandExpression *expr) override { emit(fmt::format("rand<{}>()", data_type_name(expr->dt))); } diff --git a/taichi/ir/frontend_ir.cpp b/taichi/ir/frontend_ir.cpp index 9ded4ffae1146..93fca8a78e6ef 100644 --- a/taichi/ir/frontend_ir.cpp +++ b/taichi/ir/frontend_ir.cpp @@ -126,6 +126,15 @@ void ArgLoadExpression::flatten(FlattenContext *ctx) { stmt = ctx->back_stmt(); } +void TexturePtrExpression::type_check(CompileConfig *config) { +} + +void TexturePtrExpression::flatten(FlattenContext *ctx) { + ctx->push_back(arg_id, PrimitiveType::f32, true); + ctx->push_back(ctx->back_stmt()); + stmt = ctx->back_stmt(); +} + void RandExpression::type_check(CompileConfig *) { TI_ASSERT_INFO(dt->is() && dt != PrimitiveType::unknown, "Invalid dt [{}] for RandExpression", dt->to_string()); @@ -589,6 +598,52 @@ void SNodeOpExpression::flatten(FlattenContext *ctx) { stmt = ctx->back_stmt(); } +void TextureOpExpression::type_check(CompileConfig *config) { + if (op == TextureOpType::sample_lod) { + // UV, Lod + TI_ASSERT_INFO(args.size() == 3, + "Invalid number of args for sample_lod Texture op"); + TI_ASSERT_TYPE_CHECKED(args[0]); + TI_ASSERT_TYPE_CHECKED(args[1]); + TI_ASSERT_TYPE_CHECKED(args[2]); + if (args[0].get_ret_type() != PrimitiveType::f32 || + args[1].get_ret_type() != PrimitiveType::f32 || + args[2].get_ret_type() != PrimitiveType::f32) { + throw TaichiTypeError( + fmt::format("All arguments to sample_lod Texture op must be FP32")); + } + } else if (op == TextureOpType::fetch_texel) { + // index, int LOD + TI_ASSERT_INFO(args.size() == 3, + "Invalid number of args for fetch_texel Texture op"); + TI_ASSERT_TYPE_CHECKED(args[0]); + TI_ASSERT_TYPE_CHECKED(args[1]); + TI_ASSERT_TYPE_CHECKED(args[2]); + if (args[0].get_ret_type() != PrimitiveType::i32 || + args[1].get_ret_type() != PrimitiveType::i32 || + args[2].get_ret_type() != PrimitiveType::i32) { + throw TaichiTypeError( + fmt::format("All arguments to fetch_texel Texture op must be i32")); + } + } else { + TI_ERROR("Invalid TextureOpType"); + } + ret_type = + TypeFactory::get_instance().get_pointer_type(PrimitiveType::f32, + /*is_bit_pointer=*/false); +} + +void TextureOpExpression::flatten(FlattenContext *ctx) { + flatten_rvalue(texture_ptr, ctx); + std::vector arg_stmts; + for (Expr &arg : args.exprs) { + flatten_rvalue(arg, ctx); + arg_stmts.push_back(arg->stmt); + } + ctx->push_back(op, texture_ptr->stmt, arg_stmts); + stmt = ctx->back_stmt(); +} + void ConstExpression::type_check(CompileConfig *) { TI_ASSERT_INFO( val.dt->is() && val.dt != PrimitiveType::unknown, diff --git a/taichi/ir/frontend_ir.h b/taichi/ir/frontend_ir.h index 6343de337e831..2c2eba5bd12f2 100644 --- a/taichi/ir/frontend_ir.h +++ b/taichi/ir/frontend_ir.h @@ -293,6 +293,22 @@ class ArgLoadExpression : public Expression { TI_DEFINE_ACCEPT_FOR_EXPRESSION }; +class Texture; + +class TexturePtrExpression : public Expression { + public: + int arg_id; + + TexturePtrExpression(int arg_id) : arg_id(arg_id) { + } + + void type_check(CompileConfig *config) override; + + void flatten(FlattenContext *ctx) override; + + TI_DEFINE_ACCEPT_FOR_EXPRESSION +}; + class RandExpression : public Expression { public: DataType dt; @@ -612,6 +628,25 @@ class SNodeOpExpression : public Expression { TI_DEFINE_ACCEPT_FOR_EXPRESSION }; +class TextureOpExpression : public Expression { + public: + TextureOpType op; + Expr texture_ptr; + ExprGroup args; + + explicit TextureOpExpression(TextureOpType op, + Expr texture_ptr, + const ExprGroup &args) + : op(op), texture_ptr(texture_ptr), args(args) { + } + + void type_check(CompileConfig *config) override; + + void flatten(FlattenContext *ctx) override; + + TI_DEFINE_ACCEPT_FOR_EXPRESSION +}; + class ConstExpression : public Expression { public: TypedConstant val; diff --git a/taichi/ir/statements.h b/taichi/ir/statements.h index d9099ccdcf785..25ceed9a47a9e 100644 --- a/taichi/ir/statements.h +++ b/taichi/ir/statements.h @@ -1446,6 +1446,37 @@ class InternalFuncStmt : public Stmt { TI_DEFINE_ACCEPT_AND_CLONE }; +class Texture; + +class TexturePtrStmt : public Stmt { + public: + Stmt *arg_load_stmt{nullptr}; + + explicit TexturePtrStmt(Stmt *stmt) : arg_load_stmt(stmt) { + TI_STMT_REG_FIELDS; + } + + TI_STMT_DEF_FIELDS(arg_load_stmt); + TI_DEFINE_ACCEPT_AND_CLONE +}; + +class TextureOpStmt : public Stmt { + public: + TextureOpType op; + Stmt *texture_ptr; + std::vector args; + + explicit TextureOpStmt(TextureOpType op, + Stmt *texture_ptr, + const std::vector &args) + : op(op), texture_ptr(texture_ptr), args(args) { + TI_STMT_REG_FIELDS; + } + + TI_STMT_DEF_FIELDS(op, texture_ptr, args); + TI_DEFINE_ACCEPT_AND_CLONE +}; + /** * A local AD-stack. */ diff --git a/taichi/ir/stmt_op_types.cpp b/taichi/ir/stmt_op_types.cpp index a5f492f869c0a..0ab447c25f68a 100644 --- a/taichi/ir/stmt_op_types.cpp +++ b/taichi/ir/stmt_op_types.cpp @@ -146,5 +146,21 @@ std::string snode_op_type_name(SNodeOpType type) { } } +std::string texture_op_type_name(TextureOpType type) { + switch (type) { +#define REGISTER_TYPE(i) \ + case TextureOpType::i: \ + return #i; + + REGISTER_TYPE(sample_lod); + REGISTER_TYPE(fetch_texel); + REGISTER_TYPE(undefined); + +#undef REGISTER_TYPE + default: + TI_NOT_IMPLEMENTED + } +} + } // namespace lang } // namespace taichi diff --git a/taichi/ir/stmt_op_types.h b/taichi/ir/stmt_op_types.h index a71d5512cd6ed..12609d994fbc4 100644 --- a/taichi/ir/stmt_op_types.h +++ b/taichi/ir/stmt_op_types.h @@ -84,5 +84,9 @@ enum class SNodeOpType : int { std::string snode_op_type_name(SNodeOpType type); +enum class TextureOpType : int { sample_lod, fetch_texel, undefined }; + +std::string texture_op_type_name(TextureOpType type); + } // namespace lang } // namespace taichi diff --git a/taichi/transforms/ir_printer.cpp b/taichi/transforms/ir_printer.cpp index 7cb30527128d8..8a005ce45b082 100644 --- a/taichi/transforms/ir_printer.cpp +++ b/taichi/transforms/ir_printer.cpp @@ -424,6 +424,16 @@ class IRPrinter : public IRVisitor { print("{}{} = arg[{}]", stmt->type_hint(), stmt->name(), stmt->arg_id); } + void visit(TexturePtrStmt *stmt) override { + print("<*Texture> {} = {}", stmt->name(), stmt->arg_load_stmt->name()); + } + + void visit(TextureOpStmt *stmt) override { + print(" {} = texture_{}({}, {}, {})", stmt->name(), + texture_op_type_name(stmt->op), stmt->args[0]->name(), + stmt->args[1]->name(), stmt->args[2]->name()); + } + void visit(FrontendReturnStmt *stmt) override { print("{}{} : return [{}]", stmt->type_hint(), stmt->name(), expr_group_to_string(stmt->values));