Skip to content

Commit 09228b2

Browse files
strongoierpre-commit-ci[bot]
authored andcommitted
[Lang] Remove the real_matrix switch (taichi-dev#6885)
Issue: taichi-dev#5819 ### Brief Summary We no longer need the switch after taichi-dev#6801. Co-authored-by: pre-commit-ci[bot] <66853113+pre-commit-ci[bot]@users.noreply.github.com>
1 parent a1c73cb commit 09228b2

21 files changed

+77
-191
lines changed

python/taichi/_funcs.py

+1-4
Original file line numberDiff line numberDiff line change
@@ -162,10 +162,7 @@ def _svd3d(A, dt, iters=None):
162162
Decomposed 3x3 matrices `U`, 'S' and `V`.
163163
"""
164164
assert A.n == 3 and A.m == 3
165-
if impl.current_cfg().real_matrix:
166-
inputs = get_runtime().prog.current_ast_builder().expand_expr([A.ptr])
167-
else:
168-
inputs = tuple([e.ptr for e in A.entries])
165+
inputs = get_runtime().prog.current_ast_builder().expand_expr([A.ptr])
169166
assert dt in [f32, f64]
170167
if iters is None:
171168
if dt == f32:

python/taichi/lang/ast/ast_transformer.py

+42-73
Original file line numberDiff line numberDiff line change
@@ -18,7 +18,6 @@
1818
TaichiTypeError)
1919
from taichi.lang.expr import Expr, make_expr_group
2020
from taichi.lang.field import Field
21-
from taichi.lang.impl import current_cfg
2221
from taichi.lang.matrix import Matrix, MatrixType, Vector, is_vector
2322
from taichi.lang.snode import append, deactivate, length
2423
from taichi.lang.struct import Struct, StructType
@@ -300,8 +299,7 @@ def process_generators(ctx, node, now_comp, func, result):
300299
with ctx.static_scope_guard():
301300
_iter = build_stmt(ctx, node.generators[now_comp].iter)
302301

303-
if impl.current_cfg().real_matrix and isinstance(
304-
_iter, impl.Expr) and _iter.ptr.is_tensor():
302+
if isinstance(_iter, impl.Expr) and _iter.ptr.is_tensor():
305303
shape = _iter.ptr.get_shape()
306304
flattened = [
307305
Expr(x) for x in ctx.ast_builder.expand_expr([_iter.ptr])
@@ -505,8 +503,7 @@ def build_Call(ctx, node):
505503
for arg in node.args:
506504
if isinstance(arg, ast.Starred):
507505
arg_list = arg.ptr
508-
if impl.current_cfg().real_matrix and isinstance(
509-
arg_list, Expr):
506+
if isinstance(arg_list, Expr) and arg_list.is_tensor():
510507
# Expand Expr with Matrix-type return into list of Exprs
511508
arg_list = [
512509
Expr(x)
@@ -529,8 +526,7 @@ def build_Call(ctx, node):
529526
node.ptr = impl.ti_format(*args, **keywords)
530527
return node.ptr
531528

532-
if ((id(func) == id(Matrix)
533-
or id(func) == id(Vector))) and impl.current_cfg().real_matrix:
529+
if id(func) == id(Matrix) or id(func) == id(Vector):
534530
node.ptr = matrix.make_matrix(*args, **keywords)
535531
return node.ptr
536532

@@ -654,57 +650,40 @@ def transform_as_kernel():
654650
if isinstance(ctx.func.arguments[i].annotation,
655651
(MatrixType)):
656652

657-
if current_cfg().real_matrix:
658-
# with real_matrix=True, "data" is expected to be an Expr here
659-
# Therefore we simply call "impl.expr_init_func(data)" to perform:
660-
#
661-
# TensorType* t = alloca()
662-
# assign(t, data)
663-
#
664-
# We created local variable "t" - a copy of the passed-in argument "data"
665-
if not isinstance(
666-
data,
667-
expr.Expr) or not data.ptr.is_tensor():
668-
raise TaichiSyntaxError(
669-
f"Argument {arg.arg} of type {ctx.func.arguments[i].annotation} is expected to be a Matrix, but got {type(data)}."
670-
)
671-
672-
element_shape = data.ptr.get_ret_type().shape()
673-
if len(element_shape
674-
) != ctx.func.arguments[i].annotation.ndim:
675-
raise TaichiSyntaxError(
676-
f"Argument {arg.arg} of type {ctx.func.arguments[i].annotation} is expected to be a Matrix with ndim {ctx.func.arguments[i].annotation.ndim}, but got {len(element_shape)}."
677-
)
678-
679-
assert ctx.func.arguments[i].annotation.ndim > 0
680-
if element_shape[0] != ctx.func.arguments[
681-
i].annotation.n:
682-
raise TaichiSyntaxError(
683-
f"Argument {arg.arg} of type {ctx.func.arguments[i].annotation} is expected to be a Matrix with n {ctx.func.arguments[i].annotation.n}, but got {element_shape[0]}."
684-
)
685-
686-
if ctx.func.arguments[
687-
i].annotation.ndim == 2 and element_shape[
688-
1] != ctx.func.arguments[
689-
i].annotation.m:
690-
raise TaichiSyntaxError(
691-
f"Argument {arg.arg} of type {ctx.func.arguments[i].annotation} is expected to be a Matrix with m {ctx.func.arguments[i].annotation.m}, but got {element_shape[0]}."
692-
)
693-
else:
694-
if not isinstance(data, Matrix):
695-
raise TaichiSyntaxError(
696-
f"Argument {arg.arg} of type {ctx.func.arguments[i].annotation} is expected to be a Matrix, but got {type(data)}."
697-
)
698-
699-
if data.m != ctx.func.arguments[i].annotation.m:
700-
raise TaichiSyntaxError(
701-
f"Argument {arg.arg} of type {ctx.func.arguments[i].annotation} is expected to be a Matrix with m {ctx.func.arguments[i].annotation.m}, but got {data.m}."
702-
)
703-
704-
if data.n != ctx.func.arguments[i].annotation.n:
705-
raise TaichiSyntaxError(
706-
f"Argument {arg.arg} of type {ctx.func.arguments[i].annotation} is expected to be a Matrix with n {ctx.func.arguments[i].annotation.n}, but got {data.n}."
707-
)
653+
# "data" is expected to be an Expr here,
654+
# so we simply call "impl.expr_init_func(data)" to perform:
655+
#
656+
# TensorType* t = alloca()
657+
# assign(t, data)
658+
#
659+
# We created local variable "t" - a copy of the passed-in argument "data"
660+
if not isinstance(
661+
data, expr.Expr) or not data.ptr.is_tensor():
662+
raise TaichiSyntaxError(
663+
f"Argument {arg.arg} of type {ctx.func.arguments[i].annotation} is expected to be a Matrix, but got {type(data)}."
664+
)
665+
666+
element_shape = data.ptr.get_ret_type().shape()
667+
if len(element_shape
668+
) != ctx.func.arguments[i].annotation.ndim:
669+
raise TaichiSyntaxError(
670+
f"Argument {arg.arg} of type {ctx.func.arguments[i].annotation} is expected to be a Matrix with ndim {ctx.func.arguments[i].annotation.ndim}, but got {len(element_shape)}."
671+
)
672+
673+
assert ctx.func.arguments[i].annotation.ndim > 0
674+
if element_shape[0] != ctx.func.arguments[
675+
i].annotation.n:
676+
raise TaichiSyntaxError(
677+
f"Argument {arg.arg} of type {ctx.func.arguments[i].annotation} is expected to be a Matrix with n {ctx.func.arguments[i].annotation.n}, but got {element_shape[0]}."
678+
)
679+
680+
if ctx.func.arguments[
681+
i].annotation.ndim == 2 and element_shape[
682+
1] != ctx.func.arguments[i].annotation.m:
683+
raise TaichiSyntaxError(
684+
f"Argument {arg.arg} of type {ctx.func.arguments[i].annotation} is expected to be a Matrix with m {ctx.func.arguments[i].annotation.m}, but got {element_shape[0]}."
685+
)
686+
708687
ctx.create_variable(arg.arg, impl.expr_init_func(data))
709688
continue
710689

@@ -1189,12 +1168,8 @@ def build_grouped_ndrange_for(ctx, node):
11891168
f"Group for should have 1 loop target, found {len(targets)}"
11901169
)
11911170
target = targets[0]
1192-
if current_cfg().real_matrix:
1193-
mat = matrix.make_matrix([0] * len(ndrange_var.dimensions),
1194-
dt=primitive_types.i32)
1195-
else:
1196-
mat = matrix.Vector([0] * len(ndrange_var.dimensions),
1197-
dt=primitive_types.i32)
1171+
mat = matrix.make_matrix([0] * len(ndrange_var.dimensions),
1172+
dt=primitive_types.i32)
11981173
target_var = impl.expr_init(mat)
11991174

12001175
ctx.create_variable(target, target_var)
@@ -1236,15 +1211,9 @@ def build_struct_for(ctx, node, is_grouped):
12361211
expr_group = expr.make_expr_group(loop_indices)
12371212
impl.begin_frontend_struct_for(ctx.ast_builder, expr_group,
12381213
loop_var)
1239-
if impl.current_cfg().real_matrix:
1240-
ctx.create_variable(
1241-
target,
1242-
matrix.make_matrix(loop_indices,
1243-
dt=primitive_types.i32))
1244-
else:
1245-
ctx.create_variable(
1246-
target,
1247-
matrix.Vector(loop_indices, dt=primitive_types.i32))
1214+
ctx.create_variable(
1215+
target,
1216+
matrix.make_matrix(loop_indices, dt=primitive_types.i32))
12481217
build_stmts(ctx, node.body)
12491218
ctx.ast_builder.end_frontend_struct_for()
12501219
else:

python/taichi/lang/ast/ast_transformer_utils.py

-4
Original file line numberDiff line numberDiff line change
@@ -10,7 +10,6 @@
1010
from taichi.lang.exception import (TaichiCompilationError, TaichiNameError,
1111
TaichiSyntaxError,
1212
handle_exception_from_cpp)
13-
from taichi.lang.matrix import Matrix
1413

1514

1615
class Builder:
@@ -246,9 +245,6 @@ def get_var_by_name(self, name):
246245
if name in s:
247246
return s[name]
248247
if name in self.global_vars:
249-
if isinstance(self.global_vars[name],
250-
Matrix) and impl.current_cfg().real_matrix:
251-
return impl.expr_init(self.global_vars[name])
252248
return self.global_vars[name]
253249
try:
254250
return getattr(builtins, name)

python/taichi/lang/expr.py

+1-3
Original file line numberDiff line numberDiff line change
@@ -25,7 +25,6 @@ def __init__(self, *args, tb=None, dtype=None):
2525
'Cannot initialize scalar expression from '
2626
f'taichi class: {type(args[0])}')
2727
elif isinstance(args[0], (list, tuple)):
28-
assert impl.current_cfg().real_matrix
2928
self.ptr = make_matrix(args[0]).ptr
3029
else:
3130
# assume to be constant
@@ -173,8 +172,7 @@ def _get_flattened_ptrs(val):
173172
for item in val._members:
174173
ptrs.extend(_get_flattened_ptrs(item))
175174
return ptrs
176-
if impl.current_cfg().real_matrix and isinstance(
177-
val, Expr) and val.ptr.is_tensor():
175+
if isinstance(val, Expr) and val.ptr.is_tensor():
178176
return impl.get_runtime().prog.current_ast_builder().expand_expr(
179177
[val.ptr])
180178
return [Expr(val).ptr]

python/taichi/lang/impl.py

+13-36
Original file line numberDiff line numberDiff line change
@@ -8,7 +8,7 @@
88
from taichi.lang._ndarray import ScalarNdarray
99
from taichi.lang._ndrange import GroupedNDRange, _Ndrange
1010
from taichi.lang._texture import RWTextureAccessor
11-
from taichi.lang.any_array import AnyArray, AnyArrayAccess
11+
from taichi.lang.any_array import AnyArray
1212
from taichi.lang.enums import SNodeGradType
1313
from taichi.lang.exception import (TaichiCompilationError, TaichiIndexError,
1414
TaichiRuntimeError, TaichiSyntaxError,
@@ -17,8 +17,7 @@
1717
from taichi.lang.field import Field, ScalarField
1818
from taichi.lang.kernel_arguments import SparseMatrixProxy
1919
from taichi.lang.matrix import (Matrix, MatrixField, MatrixNdarray, MatrixType,
20-
Vector, VectorNdarray, _IntermediateMatrix,
21-
_MatrixFieldElement, make_matrix)
20+
VectorNdarray, make_matrix)
2221
from taichi.lang.mesh import (ConvType, MeshElementFieldProxy, MeshInstance,
2322
MeshRelationAccessProxy,
2423
MeshReorderedMatrixFieldProxy,
@@ -58,18 +57,11 @@ def expr_init(rhs):
5857
if isinstance(rhs, Matrix) and (hasattr(rhs, "_DIM")):
5958
return Matrix(*rhs.to_list(), ndim=rhs.ndim)
6059
if isinstance(rhs, Matrix):
61-
if current_cfg().real_matrix:
62-
if rhs.ndim == 1:
63-
entries = [rhs(i) for i in range(rhs.n)]
64-
else:
65-
entries = [[rhs(i, j) for j in range(rhs.m)]
66-
for i in range(rhs.n)]
67-
return make_matrix(entries)
68-
if (isinstance(rhs, Vector)
69-
or getattr(rhs, "ndim", None) == 1) and rhs.m == 1:
70-
# _IntermediateMatrix may reach here
71-
return Vector(rhs.to_list(), ndim=rhs.ndim)
72-
return Matrix(rhs.to_list(), ndim=rhs.ndim)
60+
if rhs.ndim == 1:
61+
entries = [rhs(i) for i in range(rhs.n)]
62+
else:
63+
entries = [[rhs(i, j) for j in range(rhs.m)] for i in range(rhs.n)]
64+
return make_matrix(entries)
7365
if isinstance(rhs, SharedArray):
7466
return rhs
7567
if isinstance(rhs, Struct):
@@ -230,11 +222,9 @@ def subscript(ast_builder, value, *_indices, skip_reordered=False):
230222
f'Field with dim {field_dim} accessed with indices of dim {index_dim}'
231223
)
232224
if isinstance(value, MatrixField):
233-
if current_cfg().real_matrix:
234-
return Expr(
235-
_ti_core.subscript(value.ptr, indices_expr_group,
236-
get_runtime().get_current_src_info()))
237-
return _MatrixFieldElement(value, indices_expr_group)
225+
return Expr(
226+
_ti_core.subscript(value.ptr, indices_expr_group,
227+
get_runtime().get_current_src_info()))
238228
if isinstance(value, StructField):
239229
entries = {
240230
k: subscript(ast_builder, v, *indices)
@@ -252,25 +242,12 @@ def subscript(ast_builder, value, *_indices, skip_reordered=False):
252242
raise IndexError(
253243
f'Field with dim {dim - element_dim} accessed with indices of dim {index_dim}'
254244
)
255-
if element_dim == 0 or current_cfg().real_matrix:
256-
return Expr(
257-
_ti_core.subscript(value.ptr, indices_expr_group,
258-
get_runtime().get_current_src_info()))
259-
n = value.element_shape()[0]
260-
m = 1 if element_dim == 1 else value.element_shape()[1]
261-
any_array_access = AnyArrayAccess(value, indices)
262-
ret = _IntermediateMatrix(n,
263-
m, [
264-
any_array_access.subscript(i, j)
265-
for i in range(n) for j in range(m)
266-
],
267-
ndim=element_dim)
268-
ret.any_array_access = any_array_access
269-
return ret
245+
return Expr(
246+
_ti_core.subscript(value.ptr, indices_expr_group,
247+
get_runtime().get_current_src_info()))
270248
if isinstance(value, Expr):
271249
# Index into TensorType
272250
# value: IndexExpression with ret_type = TensorType
273-
assert current_cfg().real_matrix
274251
assert value.is_tensor()
275252

276253
if has_slice:

python/taichi/lang/kernel_impl.py

+2-2
Original file line numberDiff line numberDiff line change
@@ -255,8 +255,8 @@ def func_call_rvalue(self, key, args):
255255
elif isinstance(anno, primitive_types.RefType):
256256
non_template_args.append(
257257
_ti_core.make_reference(args[i].ptr))
258-
elif impl.current_cfg().real_matrix and isinstance(
259-
args[i], impl.Expr) and args[i].ptr.is_tensor():
258+
elif isinstance(args[i],
259+
impl.Expr) and args[i].ptr.is_tensor():
260260
non_template_args.extend([
261261
Expr(x) for x in impl.get_runtime().prog.
262262
current_ast_builder().expand_expr([args[i].ptr])

python/taichi/lang/matrix.py

+2-5
Original file line numberDiff line numberDiff line change
@@ -666,9 +666,6 @@ def _subscript(self, *indices):
666666
is_global_mat = isinstance(self, _MatrixFieldElement)
667667
return self._impl._subscript(is_global_mat, *indices)
668668

669-
def _make_matrix(self):
670-
return make_matrix(self._impl.entries)
671-
672669
def to_list(self):
673670
"""Return this matrix as a 1D `list`.
674671
@@ -1797,7 +1794,7 @@ def cast(self, mat):
17971794
if isinstance(mat, impl.Expr) and mat.ptr.is_tensor():
17981795
return ops_mod.cast(mat, self.dtype)
17991796

1800-
if isinstance(mat, Matrix) and impl.current_cfg().real_matrix:
1797+
if isinstance(mat, Matrix):
18011798
arr = [[mat(i, j) for j in range(self.m)] for i in range(self.n)]
18021799
return ops_mod.cast(make_matrix(arr), self.dtype)
18031800

@@ -1898,7 +1895,7 @@ def cast(self, vec):
18981895
if isinstance(vec, impl.Expr) and vec.ptr.is_tensor():
18991896
return ops_mod.cast(vec, self.dtype)
19001897

1901-
if isinstance(vec, Matrix) and impl.current_cfg().real_matrix:
1898+
if isinstance(vec, Matrix):
19021899
arr = vec.entries
19031900
return ops_mod.cast(make_matrix(arr), self.dtype)
19041901

python/taichi/lang/ops.py

+1-1
Original file line numberDiff line numberDiff line change
@@ -21,7 +21,7 @@ def uniform_matrix_inputs(*args):
2121
results = []
2222
for arg in args:
2323
if has_real_matrix and is_matrix_class(arg):
24-
results.append(arg._make_matrix())
24+
results.append(impl.expr_init(arg))
2525
else:
2626
results.append(arg)
2727

taichi/analysis/offline_cache_util.cpp

-1
Original file line numberDiff line numberDiff line change
@@ -65,7 +65,6 @@ static std::vector<std::uint8_t> get_offline_cache_key_of_compile_config(
6565
serializer(config->experimental_auto_mesh_local);
6666
serializer(config->auto_mesh_local_default_occupacy);
6767
serializer(config->dynamic_index);
68-
serializer(config->real_matrix);
6968
serializer(config->real_matrix_scalarize);
7069
serializer.finalize();
7170

taichi/codegen/codegen_utils.h

+1-5
Original file line numberDiff line numberDiff line change
@@ -4,11 +4,7 @@
44
namespace taichi::lang {
55

66
inline bool codegen_vector_type(CompileConfig *config) {
7-
if (config->real_matrix && !config->real_matrix_scalarize) {
8-
return true;
9-
}
10-
11-
return false;
7+
return !config->real_matrix_scalarize;
128
}
139

1410
} // namespace taichi::lang

taichi/ir/frontend_ir.cpp

+1-7
Original file line numberDiff line numberDiff line change
@@ -152,10 +152,7 @@ void UnaryOpExpression::type_check(CompileConfig *config) {
152152
auto operand_primitive_type = operand->ret_type.get_element_type();
153153
auto ret_primitive_type = ret_type;
154154

155-
if (config->real_matrix) {
156-
TI_ASSERT(operand_primitive_type->is<PrimitiveType>());
157-
158-
} else if (!operand->ret_type->is<PrimitiveType>()) {
155+
if (!operand_primitive_type->is<PrimitiveType>()) {
159156
throw TaichiTypeError(fmt::format(
160157
"unsupported operand type(s) for '{}': '{}'", unary_op_type_name(type),
161158
operand_primitive_type->to_string()));
@@ -539,9 +536,6 @@ void ExternalTensorExpression::flatten(FlattenContext *ctx) {
539536
// The scalarization should happen after
540537
// irpass::lower_access()
541538
auto prim_dt = dt;
542-
if (!get_compile_config()->real_matrix) {
543-
prim_dt = dt.get_element_type();
544-
}
545539
auto ptr = Stmt::make<ArgLoadStmt>(arg_id, prim_dt, /*is_ptr=*/true);
546540

547541
int external_dims = dim - std::abs(element_dim);

taichi/ir/frontend_ir.h

+1-1
Original file line numberDiff line numberDiff line change
@@ -553,7 +553,7 @@ class MatrixFieldExpression : public Expression {
553553

554554
/**
555555
* Creating a local matrix;
556-
* lowered from ti.Matrix with real_matrix=True
556+
* lowered from ti.Matrix
557557
*/
558558
class MatrixExpression : public Expression {
559559
public:

0 commit comments

Comments
 (0)