Skip to content

Commit 78a0139

Browse files
[Lang] [ir] Add short-circuit if-then-else operator (#5022)
* [Lang] [ir] Add a short-circuit if-then-else operator and use it to implement IfExp * [pre-commit.ci] auto fixes from pre-commit.com hooks for more information, see https://pre-commit.ci Co-authored-by: pre-commit-ci[bot] <66853113+pre-commit-ci[bot]@users.noreply.github.com>
1 parent 8966069 commit 78a0139

9 files changed

+78
-25
lines changed

python/taichi/lang/ast/ast_transformer.py

+1-11
Original file line numberDiff line numberDiff line change
@@ -1108,17 +1108,7 @@ def build_IfExp(ctx, node):
11081108
node.ptr = build_stmt(ctx, node.orelse)
11091109
return node.ptr
11101110

1111-
val = impl.expr_init(None)
1112-
1113-
impl.begin_frontend_if(ctx.ast_builder, node.test.ptr)
1114-
ctx.ast_builder.begin_frontend_if_true()
1115-
val._assign(node.body.ptr)
1116-
ctx.ast_builder.pop_scope()
1117-
ctx.ast_builder.begin_frontend_if_false()
1118-
val._assign(node.orelse.ptr)
1119-
ctx.ast_builder.pop_scope()
1120-
1121-
node.ptr = val
1111+
node.ptr = ti_ops.ifte(node.test.ptr, node.body.ptr, node.orelse.ptr)
11221112
return node.ptr
11231113

11241114
@staticmethod

python/taichi/lang/ops.py

+23
Original file line numberDiff line numberDiff line change
@@ -1143,6 +1143,29 @@ def py_select(cond, x1, x2):
11431143
return _ternary_operation(_ti_core.expr_select, py_select, cond, x1, x2)
11441144

11451145

1146+
@ternary
1147+
def ifte(cond, x1, x2):
1148+
"""Evaluate and return `x1` if `cond` is true; otherwise evaluate and return `x2`. This operator guarantees
1149+
short-circuit semantics: exactly one of `x1` or `x2` will be evaluated.
1150+
1151+
Args:
1152+
cond (:mod:`~taichi.types.primitive_types`): \
1153+
The condition.
1154+
x1, x2 (:mod:`~taichi.types.primitive_types`): \
1155+
The outputs.
1156+
1157+
Returns:
1158+
`x1` if `cond` is true and `x2` otherwise.
1159+
"""
1160+
# TODO: systematically resolve `-1 = True` problem by introducing u1:
1161+
cond = logical_not(logical_not(cond))
1162+
1163+
def py_ifte(cond, x1, x2):
1164+
return x1 if cond else x2
1165+
1166+
return _ternary_operation(_ti_core.expr_ifte, py_ifte, cond, x1, x2)
1167+
1168+
11461169
@writeback_binary
11471170
def atomic_add(x, y):
11481171
"""Atomically compute `x + y`, store the result in `x`,

taichi/ir/expression_ops.h

+1
Original file line numberDiff line numberDiff line change
@@ -124,6 +124,7 @@ DEFINE_EXPRESSION_FUNC_BINARY(floordiv)
124124
DEFINE_EXPRESSION_FUNC_BINARY(bit_shr)
125125

126126
DEFINE_EXPRESSION_FUNC_TERNARY(select)
127+
DEFINE_EXPRESSION_FUNC_TERNARY(ifte)
127128

128129
} // namespace lang
129130
} // namespace taichi

taichi/ir/frontend_ir.cpp

+44-7
Original file line numberDiff line numberDiff line change
@@ -263,6 +263,37 @@ void BinaryOpExpression::flatten(FlattenContext *ctx) {
263263
stmt = ctx->back_stmt();
264264
}
265265

266+
void make_ifte(Expression::FlattenContext *ctx,
267+
DataType ret_type,
268+
Expr cond,
269+
Expr true_val,
270+
Expr false_val) {
271+
auto result = ctx->push_back<AllocaStmt>(ret_type);
272+
flatten_rvalue(cond, ctx);
273+
auto if_stmt = ctx->push_back<IfStmt>(cond->stmt);
274+
275+
Expression::FlattenContext lctx;
276+
lctx.current_block = ctx->current_block;
277+
flatten_rvalue(true_val, &lctx);
278+
lctx.push_back<LocalStoreStmt>(result, true_val->stmt);
279+
280+
Expression::FlattenContext rctx;
281+
rctx.current_block = ctx->current_block;
282+
flatten_rvalue(false_val, &rctx);
283+
rctx.push_back<LocalStoreStmt>(result, false_val->stmt);
284+
285+
auto true_block = std::make_unique<Block>();
286+
true_block->set_statements(std::move(lctx.stmts));
287+
if_stmt->set_true_statements(std::move(true_block));
288+
289+
auto false_block = std::make_unique<Block>();
290+
false_block->set_statements(std::move(rctx.stmts));
291+
if_stmt->set_false_statements(std::move(false_block));
292+
293+
ctx->push_back<LocalLoadStmt>(LocalAddress(result, 0));
294+
return;
295+
}
296+
266297
void TernaryOpExpression::type_check(CompileConfig *) {
267298
TI_ASSERT_TYPE_CHECKED(op1);
268299
TI_ASSERT_TYPE_CHECKED(op2);
@@ -276,21 +307,27 @@ void TernaryOpExpression::type_check(CompileConfig *) {
276307
ternary_type_name(type), op1->ret_type->to_string(),
277308
op2->ret_type->to_string(), op3->ret_type->to_string()));
278309
};
279-
if (!is_integral(op1_type) || !op2_type->is<PrimitiveType>() ||
280-
!op3_type->is<PrimitiveType>())
310+
if (op1_type != PrimitiveType::i32)
311+
error();
312+
if (!op2_type->is<PrimitiveType>() || !op3_type->is<PrimitiveType>())
281313
error();
282314
ret_type = promoted_type(op2_type, op3_type);
283315
}
284316

285317
void TernaryOpExpression::flatten(FlattenContext *ctx) {
286318
// if (stmt)
287319
// return;
288-
flatten_rvalue(op1, ctx);
289-
flatten_rvalue(op2, ctx);
290-
flatten_rvalue(op3, ctx);
291-
ctx->push_back(
292-
std::make_unique<TernaryOpStmt>(type, op1->stmt, op2->stmt, op3->stmt));
320+
if (type == TernaryOpType::select) {
321+
flatten_rvalue(op1, ctx);
322+
flatten_rvalue(op2, ctx);
323+
flatten_rvalue(op3, ctx);
324+
ctx->push_back(
325+
std::make_unique<TernaryOpStmt>(type, op1->stmt, op2->stmt, op3->stmt));
326+
} else if (type == TernaryOpType::ifte) {
327+
make_ifte(ctx, ret_type, op1, op2, op3);
328+
}
293329
stmt = ctx->back_stmt();
330+
stmt->tb = tb;
294331
}
295332

296333
void InternalFuncCallExpression::type_check(CompileConfig *) {

taichi/ir/stmt_op_types.cpp

+1
Original file line numberDiff line numberDiff line change
@@ -77,6 +77,7 @@ std::string ternary_type_name(TernaryOpType type) {
7777
return #i;
7878

7979
REGISTER_TYPE(select);
80+
REGISTER_TYPE(ifte);
8081

8182
#undef REGISTER_TYPE
8283
default:

taichi/ir/stmt_op_types.h

+1-1
Original file line numberDiff line numberDiff line change
@@ -62,7 +62,7 @@ inline bool is_bit_op(BinaryOpType type) {
6262

6363
std::string binary_op_type_symbol(BinaryOpType type);
6464

65-
enum class TernaryOpType : int { select, undefined };
65+
enum class TernaryOpType : int { select, ifte, undefined };
6666

6767
std::string ternary_type_name(TernaryOpType type);
6868

taichi/python/export_lang.cpp

+1
Original file line numberDiff line numberDiff line change
@@ -763,6 +763,7 @@ void export_lang(py::module &m) {
763763
DEFINE_EXPRESSION_OP(log)
764764

765765
DEFINE_EXPRESSION_OP(select)
766+
DEFINE_EXPRESSION_OP(ifte)
766767

767768
DEFINE_EXPRESSION_OP(cmp_le)
768769
DEFINE_EXPRESSION_OP(cmp_lt)

tests/cpp/ir/frontend_type_inference_test.cpp

+5-5
Original file line numberDiff line numberDiff line change
@@ -66,16 +66,16 @@ TEST(FrontendTypeInference, UnaryOp) {
6666
}
6767

6868
TEST(FrontendTypeInference, TernaryOp) {
69-
auto const_i16 = value<int16>(-(1 << 10));
70-
const_i16->type_check(nullptr);
71-
EXPECT_EQ(const_i16->ret_type, PrimitiveType::i16);
72-
auto cast_i8 = cast(const_i16, PrimitiveType::i8);
69+
auto const_i32 = value<int32>(-(1 << 10));
70+
const_i32->type_check(nullptr);
71+
EXPECT_EQ(const_i32->ret_type, PrimitiveType::i32);
72+
auto cast_i8 = cast(const_i32, PrimitiveType::i8);
7373
cast_i8->type_check(nullptr);
7474
EXPECT_EQ(cast_i8->ret_type, PrimitiveType::i8);
7575
auto const_f32 = value<float32>(5.0);
7676
const_f32->type_check(nullptr);
7777
EXPECT_EQ(const_f32->ret_type, PrimitiveType::f32);
78-
auto ternary_f32 = expr_select(const_i16, cast_i8, const_f32);
78+
auto ternary_f32 = expr_select(const_i32, cast_i8, const_f32);
7979
ternary_f32->type_check(nullptr);
8080
EXPECT_EQ(ternary_f32->ret_type, PrimitiveType::f32);
8181
}

tests/python/test_type_check.py

+1-1
Original file line numberDiff line numberDiff line change
@@ -41,7 +41,7 @@ def select():
4141
d = b if a else c
4242

4343
with pytest.raises(TypeError,
44-
match="`if` conditions must be of type int32"):
44+
match="unsupported operand type\\(s\\) for 'ifte'"):
4545
select()
4646

4747

0 commit comments

Comments
 (0)