Skip to content

Commit e10a11b

Browse files
authored
[lang] Texture support 0/n: IR changes (#5134)
1 parent a01c373 commit e10a11b

10 files changed

+179
-0
lines changed

taichi/analysis/gen_offline_cache_key.cpp

+13
Original file line numberDiff line numberDiff line change
@@ -82,6 +82,18 @@ class ASTSerializer : public IRVisitor, public ExpressionVisitor {
8282
emit(expr->arg_id);
8383
}
8484

85+
void visit(TexturePtrExpression *expr) override {
86+
emit(ExprOpCode::TexturePtrExpression);
87+
emit(expr->arg_id);
88+
}
89+
90+
void visit(TextureOpExpression *expr) override {
91+
emit(ExprOpCode::TextureOpExpression);
92+
emit(expr->op);
93+
emit(expr->texture_ptr);
94+
emit(expr->args.exprs);
95+
}
96+
8597
void visit(RandExpression *expr) override {
8698
emit(ExprOpCode::RandExpression);
8799
emit(expr->dt);
@@ -611,6 +623,7 @@ class ASTSerializer : public IRVisitor, public ExpressionVisitor {
611623
DEFINE_EMIT_ENUM(SNodeAccessFlag);
612624
DEFINE_EMIT_ENUM(MeshRelationAccessType);
613625
DEFINE_EMIT_ENUM(ExternalFuncType);
626+
DEFINE_EMIT_ENUM(TextureOpType);
614627
DEFINE_EMIT_ENUM(mesh::MeshElementType);
615628
DEFINE_EMIT_ENUM(mesh::MeshRelationType);
616629
DEFINE_EMIT_ENUM(mesh::ConvType);

taichi/inc/expressions.inc.h

+2
Original file line numberDiff line numberDiff line change
@@ -20,3 +20,5 @@ PER_EXPRESSION(MeshPatchIndexExpression)
2020
PER_EXPRESSION(MeshRelationAccessExpression)
2121
PER_EXPRESSION(MeshIndexConversionExpression)
2222
PER_EXPRESSION(ReferenceExpression)
23+
PER_EXPRESSION(TextureOpExpression)
24+
PER_EXPRESSION(TexturePtrExpression)

taichi/inc/statements.inc.h

+3
Original file line numberDiff line numberDiff line change
@@ -79,5 +79,8 @@ PER_STATEMENT(BlockLocalPtrStmt)
7979
// Special
8080
PER_STATEMENT(InternalFuncStmt)
8181

82+
PER_STATEMENT(TexturePtrStmt)
83+
PER_STATEMENT(TextureOpStmt)
84+
8285
// Quantization
8386
PER_STATEMENT(BitStructStoreStmt)

taichi/ir/expression_printer.h

+10
Original file line numberDiff line numberDiff line change
@@ -41,6 +41,16 @@ class ExpressionHumanFriendlyPrinter : public ExpressionPrinter {
4141
fmt::format("arg[{}] (dt={})", expr->arg_id, data_type_name(expr->dt)));
4242
}
4343

44+
void visit(TexturePtrExpression *expr) override {
45+
emit(fmt::format("(Texture *)(arg[{}])", expr->arg_id));
46+
}
47+
48+
void visit(TextureOpExpression *expr) override {
49+
emit(fmt::format("texture_{}(", texture_op_type_name(expr->op)));
50+
visit(expr->args);
51+
emit(")");
52+
}
53+
4454
void visit(RandExpression *expr) override {
4555
emit(fmt::format("rand<{}>()", data_type_name(expr->dt)));
4656
}

taichi/ir/frontend_ir.cpp

+55
Original file line numberDiff line numberDiff line change
@@ -126,6 +126,15 @@ void ArgLoadExpression::flatten(FlattenContext *ctx) {
126126
stmt = ctx->back_stmt();
127127
}
128128

129+
void TexturePtrExpression::type_check(CompileConfig *config) {
130+
}
131+
132+
void TexturePtrExpression::flatten(FlattenContext *ctx) {
133+
ctx->push_back<ArgLoadStmt>(arg_id, PrimitiveType::f32, true);
134+
ctx->push_back<TexturePtrStmt>(ctx->back_stmt());
135+
stmt = ctx->back_stmt();
136+
}
137+
129138
void RandExpression::type_check(CompileConfig *) {
130139
TI_ASSERT_INFO(dt->is<PrimitiveType>() && dt != PrimitiveType::unknown,
131140
"Invalid dt [{}] for RandExpression", dt->to_string());
@@ -589,6 +598,52 @@ void SNodeOpExpression::flatten(FlattenContext *ctx) {
589598
stmt = ctx->back_stmt();
590599
}
591600

601+
void TextureOpExpression::type_check(CompileConfig *config) {
602+
if (op == TextureOpType::sample_lod) {
603+
// UV, Lod
604+
TI_ASSERT_INFO(args.size() == 3,
605+
"Invalid number of args for sample_lod Texture op");
606+
TI_ASSERT_TYPE_CHECKED(args[0]);
607+
TI_ASSERT_TYPE_CHECKED(args[1]);
608+
TI_ASSERT_TYPE_CHECKED(args[2]);
609+
if (args[0].get_ret_type() != PrimitiveType::f32 ||
610+
args[1].get_ret_type() != PrimitiveType::f32 ||
611+
args[2].get_ret_type() != PrimitiveType::f32) {
612+
throw TaichiTypeError(
613+
fmt::format("All arguments to sample_lod Texture op must be FP32"));
614+
}
615+
} else if (op == TextureOpType::fetch_texel) {
616+
// index, int LOD
617+
TI_ASSERT_INFO(args.size() == 3,
618+
"Invalid number of args for fetch_texel Texture op");
619+
TI_ASSERT_TYPE_CHECKED(args[0]);
620+
TI_ASSERT_TYPE_CHECKED(args[1]);
621+
TI_ASSERT_TYPE_CHECKED(args[2]);
622+
if (args[0].get_ret_type() != PrimitiveType::i32 ||
623+
args[1].get_ret_type() != PrimitiveType::i32 ||
624+
args[2].get_ret_type() != PrimitiveType::i32) {
625+
throw TaichiTypeError(
626+
fmt::format("All arguments to fetch_texel Texture op must be i32"));
627+
}
628+
} else {
629+
TI_ERROR("Invalid TextureOpType");
630+
}
631+
ret_type =
632+
TypeFactory::get_instance().get_pointer_type(PrimitiveType::f32,
633+
/*is_bit_pointer=*/false);
634+
}
635+
636+
void TextureOpExpression::flatten(FlattenContext *ctx) {
637+
flatten_rvalue(texture_ptr, ctx);
638+
std::vector<Stmt *> arg_stmts;
639+
for (Expr &arg : args.exprs) {
640+
flatten_rvalue(arg, ctx);
641+
arg_stmts.push_back(arg->stmt);
642+
}
643+
ctx->push_back<TextureOpStmt>(op, texture_ptr->stmt, arg_stmts);
644+
stmt = ctx->back_stmt();
645+
}
646+
592647
void ConstExpression::type_check(CompileConfig *) {
593648
TI_ASSERT_INFO(
594649
val.dt->is<PrimitiveType>() && val.dt != PrimitiveType::unknown,

taichi/ir/frontend_ir.h

+35
Original file line numberDiff line numberDiff line change
@@ -293,6 +293,22 @@ class ArgLoadExpression : public Expression {
293293
TI_DEFINE_ACCEPT_FOR_EXPRESSION
294294
};
295295

296+
class Texture;
297+
298+
class TexturePtrExpression : public Expression {
299+
public:
300+
int arg_id;
301+
302+
TexturePtrExpression(int arg_id) : arg_id(arg_id) {
303+
}
304+
305+
void type_check(CompileConfig *config) override;
306+
307+
void flatten(FlattenContext *ctx) override;
308+
309+
TI_DEFINE_ACCEPT_FOR_EXPRESSION
310+
};
311+
296312
class RandExpression : public Expression {
297313
public:
298314
DataType dt;
@@ -612,6 +628,25 @@ class SNodeOpExpression : public Expression {
612628
TI_DEFINE_ACCEPT_FOR_EXPRESSION
613629
};
614630

631+
class TextureOpExpression : public Expression {
632+
public:
633+
TextureOpType op;
634+
Expr texture_ptr;
635+
ExprGroup args;
636+
637+
explicit TextureOpExpression(TextureOpType op,
638+
Expr texture_ptr,
639+
const ExprGroup &args)
640+
: op(op), texture_ptr(texture_ptr), args(args) {
641+
}
642+
643+
void type_check(CompileConfig *config) override;
644+
645+
void flatten(FlattenContext *ctx) override;
646+
647+
TI_DEFINE_ACCEPT_FOR_EXPRESSION
648+
};
649+
615650
class ConstExpression : public Expression {
616651
public:
617652
TypedConstant val;

taichi/ir/statements.h

+31
Original file line numberDiff line numberDiff line change
@@ -1446,6 +1446,37 @@ class InternalFuncStmt : public Stmt {
14461446
TI_DEFINE_ACCEPT_AND_CLONE
14471447
};
14481448

1449+
class Texture;
1450+
1451+
class TexturePtrStmt : public Stmt {
1452+
public:
1453+
Stmt *arg_load_stmt{nullptr};
1454+
1455+
explicit TexturePtrStmt(Stmt *stmt) : arg_load_stmt(stmt) {
1456+
TI_STMT_REG_FIELDS;
1457+
}
1458+
1459+
TI_STMT_DEF_FIELDS(arg_load_stmt);
1460+
TI_DEFINE_ACCEPT_AND_CLONE
1461+
};
1462+
1463+
class TextureOpStmt : public Stmt {
1464+
public:
1465+
TextureOpType op;
1466+
Stmt *texture_ptr;
1467+
std::vector<Stmt *> args;
1468+
1469+
explicit TextureOpStmt(TextureOpType op,
1470+
Stmt *texture_ptr,
1471+
const std::vector<Stmt *> &args)
1472+
: op(op), texture_ptr(texture_ptr), args(args) {
1473+
TI_STMT_REG_FIELDS;
1474+
}
1475+
1476+
TI_STMT_DEF_FIELDS(op, texture_ptr, args);
1477+
TI_DEFINE_ACCEPT_AND_CLONE
1478+
};
1479+
14491480
/**
14501481
* A local AD-stack.
14511482
*/

taichi/ir/stmt_op_types.cpp

+16
Original file line numberDiff line numberDiff line change
@@ -146,5 +146,21 @@ std::string snode_op_type_name(SNodeOpType type) {
146146
}
147147
}
148148

149+
std::string texture_op_type_name(TextureOpType type) {
150+
switch (type) {
151+
#define REGISTER_TYPE(i) \
152+
case TextureOpType::i: \
153+
return #i;
154+
155+
REGISTER_TYPE(sample_lod);
156+
REGISTER_TYPE(fetch_texel);
157+
REGISTER_TYPE(undefined);
158+
159+
#undef REGISTER_TYPE
160+
default:
161+
TI_NOT_IMPLEMENTED
162+
}
163+
}
164+
149165
} // namespace lang
150166
} // namespace taichi

taichi/ir/stmt_op_types.h

+4
Original file line numberDiff line numberDiff line change
@@ -84,5 +84,9 @@ enum class SNodeOpType : int {
8484

8585
std::string snode_op_type_name(SNodeOpType type);
8686

87+
enum class TextureOpType : int { sample_lod, fetch_texel, undefined };
88+
89+
std::string texture_op_type_name(TextureOpType type);
90+
8791
} // namespace lang
8892
} // namespace taichi

taichi/transforms/ir_printer.cpp

+10
Original file line numberDiff line numberDiff line change
@@ -424,6 +424,16 @@ class IRPrinter : public IRVisitor {
424424
print("{}{} = arg[{}]", stmt->type_hint(), stmt->name(), stmt->arg_id);
425425
}
426426

427+
void visit(TexturePtrStmt *stmt) override {
428+
print("<*Texture> {} = {}", stmt->name(), stmt->arg_load_stmt->name());
429+
}
430+
431+
void visit(TextureOpStmt *stmt) override {
432+
print("<struct> {} = texture_{}({}, {}, {})", stmt->name(),
433+
texture_op_type_name(stmt->op), stmt->args[0]->name(),
434+
stmt->args[1]->name(), stmt->args[2]->name());
435+
}
436+
427437
void visit(FrontendReturnStmt *stmt) override {
428438
print("{}{} : return [{}]", stmt->type_hint(), stmt->name(),
429439
expr_group_to_string(stmt->values));

0 commit comments

Comments
 (0)