Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

[type] [refactor] Misc improvements to quant codegen #5129

Merged
merged 5 commits into from
Jun 13, 2022
Merged
Changes from 1 commit
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
Next Next commit
Replace is_custom_type() with is_quant()
strongoier committed Jun 10, 2022
commit 9a4705fb8abdff022fe88927a143328c9bb2e21b
5 changes: 1 addition & 4 deletions python/taichi/lang/matrix.py
Original file line number Diff line number Diff line change
@@ -1456,10 +1456,7 @@ def _calc_dynamic_index_stride(self):
self.dynamic_index_stride = 0
return
length = len(paths[0])
if any(
len(path) != length or ti_core.is_custom_type(path[length -
1]._dtype)
for path in paths):
if any(len(path) != length or ti_core.is_quant(path[length - 1]._dtype) for path in paths):
return
for i in range(length):
if any(path[i] != paths[0][i] for path in paths):
6 changes: 2 additions & 4 deletions taichi/ir/frontend_ir.cpp
Original file line number Diff line number Diff line change
@@ -510,10 +510,8 @@ void AtomicOpExpression::type_check(CompileConfig *) {
};
if (!val->ret_type->is<PrimitiveType>())
error();
if (auto cit = dest->ret_type->cast<CustomIntType>()) {
ret_type = cit->get_compute_type();
} else if (auto cft = dest->ret_type->cast<CustomFloatType>()) {
ret_type = cft->get_compute_type();
if (is_quant(dest->ret_type)) {
ret_type = dest->ret_type->get_compute_type();
} else if (dest->ret_type->is<PrimitiveType>()) {
ret_type = dest->ret_type;
} else {
2 changes: 1 addition & 1 deletion taichi/ir/type_utils.h
Original file line number Diff line number Diff line change
@@ -73,7 +73,7 @@ inline PrimitiveTypeID get_primitive_data_type() {
}
}

inline bool is_custom_type(DataType dt) {
inline bool is_quant(DataType dt) {
return dt->is<CustomIntType>() || dt->is<CustomFloatType>();
}

2 changes: 1 addition & 1 deletion taichi/python/export_lang.cpp
Original file line number Diff line number Diff line change
@@ -859,7 +859,7 @@ void export_lang(py::module &m) {
#undef PER_TYPE

m.def("data_type_size", data_type_size);
m.def("is_custom_type", is_custom_type);
m.def("is_quant", is_quant);
m.def("is_integral", is_integral);
m.def("is_signed", is_signed);
m.def("is_real", is_real);
2 changes: 1 addition & 1 deletion taichi/transforms/type_check.cpp
Original file line number Diff line number Diff line change
@@ -23,7 +23,7 @@ class TypeCheck : public IRVisitor {
Stmt *&val,
const std::string &stmt_name) {
auto dst_type = dst->ret_type.ptr_removed();
if (dst_type->is<CustomIntType>() || dst_type->is<CustomFloatType>()) {
if (is_quant(dst_type)) {
// We force the value type to be the compute_type of the bit pointer.
// Casting from compute_type to physical_type is handled in codegen.
dst_type = dst_type->get_compute_type();