Skip to content

Commit c1fd7b2

Browse files
authored
[lang] Remove redundant codegen of integer pow (#6048)
Related issue = #5915 As `pow()` with integer exponent is already demoted (#6044), there's no need for backends to handle integer pows.
1 parent 959641c commit c1fd7b2

File tree

6 files changed

+5
-71
lines changed

6 files changed

+5
-71
lines changed

taichi/codegen/cuda/codegen_cuda.cpp

+2-4
Original file line numberDiff line numberDiff line change
@@ -663,14 +663,12 @@ class TaskCodeGenCUDA : public TaskCodeGenLLVM {
663663
TI_NOT_IMPLEMENTED
664664
}
665665
} else {
666+
// Note that ret_type here cannot be integral because pow with an
667+
// integral exponent has been demoted in the demote_operations pass
666668
if (ret_type->is_primitive(PrimitiveTypeID::f32)) {
667669
llvm_val[stmt] = create_call("__nv_powf", {lhs, rhs});
668670
} else if (ret_type->is_primitive(PrimitiveTypeID::f64)) {
669671
llvm_val[stmt] = create_call("__nv_pow", {lhs, rhs});
670-
} else if (ret_type->is_primitive(PrimitiveTypeID::i32)) {
671-
llvm_val[stmt] = create_call("pow_i32", {lhs, rhs});
672-
} else if (ret_type->is_primitive(PrimitiveTypeID::i64)) {
673-
llvm_val[stmt] = create_call("pow_i64", {lhs, rhs});
674672
} else {
675673
TI_P(data_type_name(ret_type));
676674
TI_NOT_IMPLEMENTED

taichi/codegen/llvm/codegen_llvm.cpp

+2-4
Original file line numberDiff line numberDiff line change
@@ -658,14 +658,12 @@ void TaskCodeGenLLVM::visit(BinaryOpStmt *stmt) {
658658
}
659659
} else if (op == BinaryOpType::pow) {
660660
if (arch_is_cpu(current_arch())) {
661+
// Note that ret_type here cannot be integral because pow with an
662+
// integral exponent has been demoted in the demote_operations pass
661663
if (ret_type->is_primitive(PrimitiveTypeID::f32)) {
662664
llvm_val[stmt] = create_call("pow_f32", {lhs, rhs});
663665
} else if (ret_type->is_primitive(PrimitiveTypeID::f64)) {
664666
llvm_val[stmt] = create_call("pow_f64", {lhs, rhs});
665-
} else if (ret_type->is_primitive(PrimitiveTypeID::i32)) {
666-
llvm_val[stmt] = create_call("pow_i32", {lhs, rhs});
667-
} else if (ret_type->is_primitive(PrimitiveTypeID::i64)) {
668-
llvm_val[stmt] = create_call("pow_i64", {lhs, rhs});
669667
} else {
670668
TI_P(data_type_name(ret_type));
671669
TI_NOT_IMPLEMENTED

taichi/codegen/metal/codegen_metal.cpp

-6
Original file line numberDiff line numberDiff line change
@@ -559,12 +559,6 @@ class KernelCodegenImpl : public IRVisitor {
559559
}
560560
return;
561561
}
562-
if (op_type == BinaryOpType::pow && is_integral(bin->ret_type)) {
563-
// TODO(k-ye): Make sure the type is not i64?
564-
emit("const {} {} = pow_i32({}, {});", dt_name, bin_name, lhs_name,
565-
rhs_name);
566-
return;
567-
}
568562
const auto binop = metal_binary_op_type_symbol(op_type);
569563
if (is_metal_binary_op_infix(op_type)) {
570564
if (is_comparison(op_type)) {

taichi/codegen/spirv/spirv_codegen.cpp

+1-29
Original file line numberDiff line numberDiff line change
@@ -846,35 +846,6 @@ class TaskCodegen : public IRVisitor {
846846
BINARY_OP_TO_SPIRV_LOGICAL(cmp_ne, ne)
847847
#undef BINARY_OP_TO_SPIRV_LOGICAL
848848

849-
#define INT_OR_FLOAT_BINARY_OP_TO_SPIRV_FLOAT_FUNC(op, instruction, \
850-
instruction_id, max_bits) \
851-
else if (op_type == BinaryOpType::op) { \
852-
const uint32_t instruction = instruction_id; \
853-
if (is_real(bin->element_type()) || is_integral(bin->element_type())) { \
854-
if (data_type_bits(bin->element_type()) > max_bits) { \
855-
TI_ERROR( \
856-
"[glsl450] the operand type of instruction {}({}) must <= {}bits", \
857-
#instruction, instruction_id, max_bits); \
858-
} \
859-
if (is_integral(bin->element_type())) { \
860-
bin_value = ir_->cast( \
861-
dst_type, \
862-
ir_->add(ir_->call_glsl450(ir_->f32_type(), instruction, \
863-
ir_->cast(ir_->f32_type(), lhs_value), \
864-
ir_->cast(ir_->f32_type(), rhs_value)), \
865-
ir_->float_immediate_number(ir_->f32_type(), 0.5f))); \
866-
} else { \
867-
bin_value = \
868-
ir_->call_glsl450(dst_type, instruction, lhs_value, rhs_value); \
869-
} \
870-
} else { \
871-
TI_NOT_IMPLEMENTED \
872-
} \
873-
}
874-
875-
INT_OR_FLOAT_BINARY_OP_TO_SPIRV_FLOAT_FUNC(pow, Pow, 26, 32)
876-
#undef INT_OR_FLOAT_BINARY_OP_TO_SPIRV_FLOAT_FUNC
877-
878849
#define FLOAT_BINARY_OP_TO_SPIRV_FLOAT_FUNC(op, instruction, instruction_id, \
879850
max_bits) \
880851
else if (op_type == BinaryOpType::op) { \
@@ -893,6 +864,7 @@ class TaskCodegen : public IRVisitor {
893864
}
894865

895866
FLOAT_BINARY_OP_TO_SPIRV_FLOAT_FUNC(atan2, Atan2, 25, 32)
867+
FLOAT_BINARY_OP_TO_SPIRV_FLOAT_FUNC(pow, Pow, 26, 32)
896868
#undef FLOAT_BINARY_OP_TO_SPIRV_FLOAT_FUNC
897869

898870
#define BINARY_OP_TO_SPIRV_FUNC(op, S_inst, S_inst_id, U_inst, U_inst_id, \

taichi/runtime/llvm/runtime_module/runtime.cpp

-16
Original file line numberDiff line numberDiff line change
@@ -205,22 +205,6 @@ DEFINE_UNARY_REAL_FUNC(asin)
205205
DEFINE_UNARY_REAL_FUNC(cos)
206206
DEFINE_UNARY_REAL_FUNC(sin)
207207

208-
#define DEFINE_FAST_POW(T) \
209-
T pow_##T(T x, T n) { \
210-
T ans = 1; \
211-
T tmp = x; \
212-
while (n > 0) { \
213-
if (n & 1) \
214-
ans *= tmp; \
215-
tmp *= tmp; \
216-
n >>= 1; \
217-
} \
218-
return ans; \
219-
}
220-
221-
DEFINE_FAST_POW(i32)
222-
DEFINE_FAST_POW(i64)
223-
224208
i32 abs_i32(i32 a) {
225209
return a >= 0 ? a : -a;
226210
}

taichi/runtime/metal/shaders/helpers.metal.h

-12
Original file line numberDiff line numberDiff line change
@@ -38,18 +38,6 @@ STR(
3838
: intm);
3939
}
4040

41-
int32_t pow_i32(int32_t x, int32_t n) {
42-
int32_t tmp = x;
43-
int32_t ans = 1;
44-
while (n > (int32_t)(0)) {
45-
if (n & 1)
46-
ans *= tmp;
47-
tmp *= tmp;
48-
n >>= 1;
49-
}
50-
return ans;
51-
}
52-
5341
float fatomic_fetch_add(device float *dest, const float operand) {
5442
// A huge hack! Metal does not support atomic floating point numbers
5543
// natively.

0 commit comments

Comments
 (0)