Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

[Lang] Enable matrix representation in Frontend/CHI IR #5551

Closed
wants to merge 38 commits into from
Closed
Show file tree
Hide file tree
Changes from all commits
Commits
Show all changes
38 commits
Select commit Hold shift + click to select a range
ab37bfe
init: the real matrix
AD1024 Jul 19, 2022
841b405
save: refactor matrix init
AD1024 Jul 21, 2022
2412238
fix some passes
AD1024 Jul 22, 2022
49d3d2b
PrintStmt for TensorType
AD1024 Jul 22, 2022
82df596
use MatrixInitStmt
AD1024 Jul 23, 2022
85b868b
try codegen
AD1024 Jul 23, 2022
2c746c6
finish codegen
AD1024 Jul 25, 2022
f21373c
[impl] basic indexing
AD1024 Jul 27, 2022
23656ab
format code
AD1024 Jul 27, 2022
1342f7c
[impl] basic operators step 1
AD1024 Jul 27, 2022
fb39d78
[fix] skip alg simp for some cases
AD1024 Jul 27, 2022
ef33ff1
add simple ad hoc shape check placeholder
AD1024 Jul 27, 2022
a90a901
save
AD1024 Aug 2, 2022
743393a
fix cfg pass
AD1024 Aug 3, 2022
e7cce02
Merge branch 'master' into impl-matrix
AD1024 Aug 4, 2022
b7fc15d
[pre-commit.ci] auto fixes from pre-commit.com hooks
pre-commit-ci[bot] Aug 4, 2022
ff8b73c
get rid of reduce
AD1024 Aug 5, 2022
378e61a
fix typo bug
AD1024 Aug 5, 2022
bcca7f5
[pre-commit.ci] auto fixes from pre-commit.com hooks
pre-commit-ci[bot] Aug 5, 2022
3132514
save
AD1024 Aug 9, 2022
25f2ed8
shape check for indexing
AD1024 Aug 9, 2022
46b3829
Merge branch 'impl-matrix' of github.com:AD1024/taichi into impl-matrix
AD1024 Aug 9, 2022
08a337c
format
AD1024 Aug 9, 2022
b493e12
remove log
AD1024 Aug 9, 2022
d3daeef
oopsss
AD1024 Aug 9, 2022
a868388
fix cfg for matrix
AD1024 Aug 9, 2022
feef9e8
[pre-commit.ci] auto fixes from pre-commit.com hooks
pre-commit-ci[bot] Aug 9, 2022
1e6511d
fix error for windows
AD1024 Aug 10, 2022
a32fa06
Merge branch 'impl-matrix' of github.com:AD1024/taichi into impl-matrix
AD1024 Aug 10, 2022
b94e0a5
format
AD1024 Aug 10, 2022
8490a62
wat
AD1024 Aug 10, 2022
4f1325d
format yet again
AD1024 Aug 10, 2022
acb9fd2
fix struct_for loop
AD1024 Aug 10, 2022
bdfb179
format yet yet again
AD1024 Aug 10, 2022
1b7315d
try fix loop
AD1024 Aug 10, 2022
3dddae8
add one more pick
AD1024 Aug 10, 2022
8699512
fix some cases
AD1024 Aug 11, 2022
c513682
[pre-commit.ci] auto fixes from pre-commit.com hooks
pre-commit-ci[bot] Aug 12, 2022
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
7 changes: 6 additions & 1 deletion python/taichi/lang/ast/ast_transformer.py
Original file line number Diff line number Diff line change
Expand Up @@ -14,7 +14,7 @@
from taichi.lang.ast.symbol_resolver import ASTResolver
from taichi.lang.exception import TaichiSyntaxError, TaichiTypeError
from taichi.lang.field import Field
from taichi.lang.matrix import (Matrix, MatrixType, _PyScopeMatrixImpl,
from taichi.lang.matrix import (Matrix, MatrixType, Vector, _PyScopeMatrixImpl,
_TiScopeMatrixImpl)
from taichi.lang.snode import append
from taichi.lang.util import in_taichi_scope, is_taichi_class, to_taichi_type
Expand Down Expand Up @@ -488,6 +488,11 @@ def build_Call(ctx, node):
node.ptr = impl.ti_format(*args, **keywords)
return node.ptr

if isinstance(node.func,
ast.Attribute) and func == Matrix or func == Vector:
node.ptr = matrix.make_matrix(*args, **keywords)
return node.ptr

if ASTTransformer.build_call_if_is_builtin(ctx, node, args, keywords):
return node.ptr

Expand Down
9 changes: 9 additions & 0 deletions python/taichi/lang/expr.py
Original file line number Diff line number Diff line change
Expand Up @@ -38,6 +38,15 @@ def __init__(self, *args, tb=None, dtype=None):
self.ptr.set_tb(self.tb)
self.ptr.type_check(impl.get_runtime().prog.config)

def __getitem__(self, *indices):
if not isinstance(indices, (list, tuple)):
indices = (indices, )

indices = make_expr_group(*indices)
return Expr(
impl.get_runtime().prog.current_ast_builder().expr_indexed_matrix(
self.ptr, indices))

def __hash__(self):
return self.ptr.get_raw_address()

Expand Down
8 changes: 8 additions & 0 deletions python/taichi/lang/impl.py
Original file line number Diff line number Diff line change
Expand Up @@ -35,6 +35,12 @@ def expr_init_local_tensor(shape, element_type, elements):
shape, element_type, elements)


@taichi_scope
def expr_init_matrix(shape, element_type, elements):
return get_runtime().prog.current_ast_builder().expr_alloca_matrix(
shape, element_type, elements)


@taichi_scope
def expr_init_shared_array(shape, element_type):
return get_runtime().prog.current_ast_builder().expr_alloca_shared_array(
Expand All @@ -48,6 +54,8 @@ def expr_init(rhs):
if isinstance(rhs, Matrix) and (hasattr(rhs, "_DIM")):
return type(rhs)(*rhs.to_list())
if isinstance(rhs, Matrix):
if current_cfg().real_matrix:
return rhs
return Matrix(rhs.to_list())
if isinstance(rhs, SharedArray):
return rhs
Expand Down
16 changes: 16 additions & 0 deletions python/taichi/lang/matrix.py
Original file line number Diff line number Diff line change
Expand Up @@ -97,6 +97,22 @@ def prop_setter(instance, value):
return cls


def make_matrix(arr, dt=None, suppress_warning=False, is_ref=False, **kwargs):
if not impl.current_cfg().real_matrix or in_python_scope():
return Matrix(arr, dt, suppress_warning, is_ref, **kwargs)
cast = (lambda x: ops_mod.cast(x, dt)) if dt else (
lambda x: x if isinstance(x, expr.Expr) else expr.Expr(x))
if len(arr) == 0:
return impl.expr_init(impl.expr_init_matrix([0], dt, []))
if not isinstance(arr[0], Iterable):
return impl.expr_init(
impl.expr_init_matrix([len(arr)], dt,
[cast(elt).ptr for elt in arr]))
return impl.expr_init(
impl.expr_init_matrix([len(arr), len(arr[0])], dt,
[cast(elt).ptr for row in arr for elt in row]))


class _MatrixBaseImpl:
def __init__(self, m, n, entries):
self.m = m
Expand Down
6 changes: 5 additions & 1 deletion taichi/analysis/data_source_analysis.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -37,6 +37,10 @@ std::vector<Stmt *> get_load_pointers(Stmt *load_stmt) {
return external_func->arg_stmts;
} else if (auto ref = load_stmt->cast<ReferenceStmt>()) {
return {ref->var};
} else if (auto matrix_init = load_stmt->cast<MatrixInitStmt>()) {
return matrix_init->values;
} else if (auto ptr_offset = load_stmt->cast<PtrOffsetStmt>()) {
return {ptr_offset->origin};
} else {
return std::vector<Stmt *>();
}
Expand All @@ -59,7 +63,7 @@ Stmt *get_store_data(Stmt *store_stmt) {

std::vector<Stmt *> get_store_destination(Stmt *store_stmt) {
// If store_stmt provides some data sources, return the pointers of the data.
if (store_stmt->is<AllocaStmt>() && !store_stmt->ret_type->is<TensorType>()) {
if (store_stmt->is<AllocaStmt>()) {
// The statement itself provides a data source (const [0]).
return std::vector<Stmt *>(1, store_stmt);
} else if (auto local_store = store_stmt->cast<LocalStoreStmt>()) {
Expand Down
8 changes: 8 additions & 0 deletions taichi/analysis/gen_offline_cache_key.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -159,6 +159,14 @@ class ASTSerializer : public IRVisitor, public ExpressionVisitor {
emit(expr->dual);
}

void visit(MatrixExpression *expr) override {
emit(ExprOpCode::MatrixExpression);
emit(expr->dt);
for (auto elt : expr->elements) {
emit(elt);
}
}

void visit(IndexExpression *expr) override {
emit(ExprOpCode::IndexExpression);
emit(expr->var);
Expand Down
1 change: 1 addition & 0 deletions taichi/analysis/offline_cache_util.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -64,6 +64,7 @@ static std::vector<std::uint8_t> get_offline_cache_key_of_compile_config(
serializer(config->demote_no_access_mesh_fors);
serializer(config->experimental_auto_mesh_local);
serializer(config->auto_mesh_local_default_occupacy);
serializer(config->real_matrix);
serializer.finalize();

return serializer.data;
Expand Down
18 changes: 18 additions & 0 deletions taichi/analysis/same_statements.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -196,6 +196,24 @@ class IRNodeComparator : public IRVisitor {
basic_check(stmt);
}

void visit(MatrixInitStmt *stmt) override {
basic_check(stmt);
if (!same)
return;
auto o = other_node_->as<MatrixInitStmt>();
if (stmt->values.size() != o->values.size()) {
same = false;
return;
}
for (int i = 0; i < stmt->values.size(); ++i) {
other_node_ = o->values[i];
stmt->values[i]->accept(this);
other_node_ = o;
if (!same)
return;
}
}

void visit(IfStmt *stmt) override {
basic_check(stmt);
if (!same)
Expand Down
Loading