Skip to content

Commit 170a7e5

Browse files
committed
[Lang] [ir] Refactor indexing expressions in AST & enforce integer indices
1 parent e10a11b commit 170a7e5

12 files changed

+229
-209
lines changed

python/taichi/lang/impl.py

+13-6
Original file line numberDiff line numberDiff line change
@@ -197,10 +197,16 @@ def subscript(value, *_indices, skip_reordered=False):
197197

198198

199199
@taichi_scope
200-
def make_tensor_element_expr(_var, _indices, shape, stride):
200+
def make_stride_expr(_var, _indices, shape, stride):
201201
return Expr(
202-
_ti_core.make_tensor_element_expr(_var, make_expr_group(*_indices),
203-
shape, stride))
202+
_ti_core.make_stride_expr(_var, make_expr_group(*_indices),
203+
shape, stride))
204+
205+
206+
@taichi_scope
207+
def make_index_expr(_var, _indices):
208+
return Expr(
209+
_ti_core.make_index_expr(_var, make_expr_group(*_indices)))
204210

205211

206212
class SrcInfoGuard:
@@ -433,6 +439,7 @@ def deactivate_all_snodes():
433439

434440
class _Root:
435441
"""Wrapper around the default root FieldsBuilder instance."""
442+
436443
@staticmethod
437444
def parent(n=1):
438445
"""Same as :func:`taichi.SNode.parent`"""
@@ -539,10 +546,10 @@ def field(dtype, shape=None, name="", offset=None, needs_grad=False):
539546
"""
540547

541548
if isinstance(shape, numbers.Number):
542-
shape = (shape, )
549+
shape = (shape,)
543550

544551
if isinstance(offset, numbers.Number):
545-
offset = (offset, )
552+
offset = (offset,)
546553

547554
if shape is not None and offset is not None:
548555
assert len(shape) == len(
@@ -583,7 +590,7 @@ def ndarray(dtype, shape, layout=Layout.NULL):
583590
>>> z = ti.ndarray(matrix_ty, shape=(4, 5), layout=ti.Layout.SOA) # ndarray of shape (4, 5), each element is a matrix of (3, 4) ti.float scalars.
584591
"""
585592
if isinstance(shape, numbers.Number):
586-
shape = (shape, )
593+
shape = (shape,)
587594
if dtype in types:
588595
assert layout == Layout.NULL
589596
return ScalarNdarray(dtype, shape)

python/taichi/lang/matrix.py

+55-58
Original file line numberDiff line numberDiff line change
@@ -41,7 +41,6 @@ def check(instance, pattern):
4141

4242
for key_group in KEYGROUP_SET:
4343
for index, attr in enumerate(key_group):
44-
4544
def gen_property(attr, attr_idx, key_group):
4645
checker = make_valid_attribs_checker(key_group)
4746

@@ -122,7 +121,7 @@ def _linearize_entry_id(self, *args):
122121
if len(args) == 1 and isinstance(args[0], (list, tuple)):
123122
args = args[0]
124123
if len(args) == 1:
125-
args = args + (0, )
124+
args = args + (0,)
126125
# TODO(#1004): See if it's possible to support indexing at runtime
127126
for i, a in enumerate(args):
128127
if not isinstance(a, int):
@@ -244,19 +243,16 @@ def _subscript(self, is_global_mat, *indices):
244243
if self.any_array_access:
245244
return self.any_array_access.subscript(i, j)
246245
if self.local_tensor_proxy is not None:
247-
assert self.dynamic_index_stride is not None
248246
if len(indices) == 1:
249-
return impl.make_tensor_element_expr(self.local_tensor_proxy,
250-
(i, ), (self.n, ),
251-
self.dynamic_index_stride)
252-
return impl.make_tensor_element_expr(self.local_tensor_proxy,
253-
(i, j), (self.n, self.m),
254-
self.dynamic_index_stride)
247+
return impl.make_index_expr(self.local_tensor_proxy,
248+
(i,))
249+
return impl.make_index_expr(self.local_tensor_proxy,
250+
(i, j))
255251
if impl.current_cfg(
256252
).dynamic_index and is_global_mat and self.dynamic_index_stride:
257-
return impl.make_tensor_element_expr(self.entries[0].ptr, (i, j),
258-
(self.n, self.m),
259-
self.dynamic_index_stride)
253+
return impl.make_stride_expr(self.entries[0].ptr, (i, j),
254+
(self.n, self.m),
255+
self.dynamic_index_stride)
260256
return self._get_entry(i, j)
261257

262258
def _calc_slice(self, index, dim):
@@ -318,17 +314,15 @@ def with_dynamic_index(self, arr, dt):
318314
local_tensor_proxy = impl.expr_init_local_tensor(
319315
[len(arr)], dt,
320316
expr.make_expr_group([expr.Expr(x) for x in arr]))
321-
dynamic_index_stride = 1
322317
mat = []
323318
for i in range(len(arr)):
324319
mat.append(
325320
list([
326-
impl.make_tensor_element_expr(
321+
impl.make_index_expr(
327322
local_tensor_proxy,
328-
(expr.Expr(i, dtype=primitive_types.i32), ),
329-
(len(arr), ), dynamic_index_stride)
323+
(expr.Expr(i, dtype=primitive_types.i32),))
330324
]))
331-
return local_tensor_proxy, dynamic_index_stride, mat
325+
return local_tensor_proxy, mat
332326

333327
def _get_entry_to_infer(self, arr):
334328
return arr[0]
@@ -348,18 +342,16 @@ def with_dynamic_index(self, arr, dt):
348342
expr.make_expr_group(
349343
[expr.Expr(x) for row in arr for x in row]))
350344

351-
dynamic_index_stride = 1
352345
mat = []
353346
for i in range(len(arr)):
354347
mat.append([])
355348
for j in range(len(arr[0])):
356349
mat[i].append(
357-
impl.make_tensor_element_expr(
350+
impl.make_index_expr(
358351
local_tensor_proxy,
359352
(expr.Expr(i, dtype=primitive_types.i32),
360-
expr.Expr(j, dtype=primitive_types.i32)),
361-
(len(arr), len(arr[0])), dynamic_index_stride))
362-
return local_tensor_proxy, dynamic_index_stride, mat
353+
expr.Expr(j, dtype=primitive_types.i32))))
354+
return local_tensor_proxy, mat
363355

364356
def _get_entry_to_infer(self, arr):
365357
return arr[0][0]
@@ -413,7 +405,6 @@ class Matrix(TaichiOperations):
413405

414406
def __init__(self, arr, dt=None, suppress_warning=False, is_ref=False):
415407
local_tensor_proxy = None
416-
dynamic_index_stride = None
417408

418409
if not isinstance(arr, (list, tuple, np.ndarray)):
419410
raise TaichiTypeError(
@@ -440,7 +431,7 @@ def __init__(self, arr, dt=None, suppress_warning=False, is_ref=False):
440431
)
441432
if dt is None:
442433
dt = initializer.infer_dt(arr)
443-
local_tensor_proxy, dynamic_index_stride, mat = initializer.with_dynamic_index(
434+
local_tensor_proxy, mat = initializer.with_dynamic_index(
444435
arr, dt)
445436

446437
self.n, self.m = len(mat), 1
@@ -464,7 +455,7 @@ def __init__(self, arr, dt=None, suppress_warning=False, is_ref=False):
464455
self._impl = _PyScopeMatrixImpl(m, n, entries)
465456
else:
466457
self._impl = _TiScopeMatrixImpl(m, n, entries, local_tensor_proxy,
467-
dynamic_index_stride)
458+
None)
468459

469460
def _element_wise_binary(self, foo, other):
470461
other = self._broadcast_copy(other)
@@ -680,8 +671,8 @@ def E(x, y):
680671
for i in range(n):
681672
for j in range(n):
682673
entries[j][i] = inv_determinant * (
683-
E(i + 1, j + 1) * E(i + 2, j + 2) -
684-
E(i + 2, j + 1) * E(i + 1, j + 2))
674+
E(i + 1, j + 1) * E(i + 2, j + 2) -
675+
E(i + 2, j + 1) * E(i + 1, j + 2))
685676
return Matrix(entries)
686677
if self.n == 4:
687678
n = 4
@@ -693,14 +684,14 @@ def E(x, y):
693684

694685
for i in range(n):
695686
for j in range(n):
696-
entries[j][i] = inv_determinant * (-1)**(i + j) * ((
697-
E(i + 1, j + 1) *
698-
(E(i + 2, j + 2) * E(i + 3, j + 3) -
699-
E(i + 3, j + 2) * E(i + 2, j + 3)) - E(i + 2, j + 1) *
700-
(E(i + 1, j + 2) * E(i + 3, j + 3) -
701-
E(i + 3, j + 2) * E(i + 1, j + 3)) + E(i + 3, j + 1) *
702-
(E(i + 1, j + 2) * E(i + 2, j + 3) -
703-
E(i + 2, j + 2) * E(i + 1, j + 3))))
687+
entries[j][i] = inv_determinant * (-1) ** (i + j) * ((
688+
E(i + 1, j + 1) *
689+
(E(i + 2, j + 2) * E(i + 3, j + 3) -
690+
E(i + 3, j + 2) * E(i + 2, j + 3)) - E(i + 2, j + 1) *
691+
(E(i + 1, j + 2) * E(i + 3, j + 3) -
692+
E(i + 3, j + 2) * E(i + 1, j + 3)) + E(i + 3, j + 1) *
693+
(E(i + 1, j + 2) * E(i + 2, j + 3) -
694+
E(i + 2, j + 2) * E(i + 1, j + 3))))
704695
return Matrix(entries)
705696
raise Exception(
706697
"Inversions of matrices with sizes >= 5 are not supported")
@@ -760,7 +751,7 @@ def determinant(a):
760751
if a.n == 3 and a.m == 3:
761752
return a(0, 0) * (a(1, 1) * a(2, 2) - a(2, 1) * a(1, 2)) - a(
762753
1, 0) * (a(0, 1) * a(2, 2) - a(2, 1) * a(0, 2)) + a(
763-
2, 0) * (a(0, 1) * a(1, 2) - a(1, 1) * a(0, 2))
754+
2, 0) * (a(0, 1) * a(1, 2) - a(1, 1) * a(0, 2))
764755
if a.n == 4 and a.m == 4:
765756
n = 4
766757

@@ -769,14 +760,14 @@ def E(x, y):
769760

770761
det = impl.expr_init(0.0)
771762
for i in range(4):
772-
det = det + (-1.0)**i * (
773-
a(i, 0) *
774-
(E(i + 1, 1) *
775-
(E(i + 2, 2) * E(i + 3, 3) - E(i + 3, 2) * E(i + 2, 3)) -
776-
E(i + 2, 1) *
777-
(E(i + 1, 2) * E(i + 3, 3) - E(i + 3, 2) * E(i + 1, 3)) +
778-
E(i + 3, 1) *
779-
(E(i + 1, 2) * E(i + 2, 3) - E(i + 2, 2) * E(i + 1, 3))))
763+
det = det + (-1.0) ** i * (
764+
a(i, 0) *
765+
(E(i + 1, 1) *
766+
(E(i + 2, 2) * E(i + 3, 3) - E(i + 3, 2) * E(i + 2, 3)) -
767+
E(i + 2, 1) *
768+
(E(i + 1, 2) * E(i + 3, 3) - E(i + 3, 2) * E(i + 1, 3)) +
769+
E(i + 3, 1) *
770+
(E(i + 1, 2) * E(i + 2, 3) - E(i + 2, 2) * E(i + 1, 3))))
780771
return det
781772
raise Exception(
782773
"Determinants of matrices with sizes >= 5 are not supported")
@@ -908,6 +899,7 @@ def fill(self, val):
908899
>>> A
909900
[-1, -1, -1, -1]
910901
"""
902+
911903
def assign_renamed(x, y):
912904
return ops_mod.assign(x, y)
913905

@@ -933,7 +925,7 @@ def to_numpy(self, keep_dims=False):
933925
array([0, 1, 2, 3])
934926
"""
935927
as_vector = self.m == 1 and not keep_dims
936-
shape_ext = (self.n, ) if as_vector else (self.n, self.m)
928+
shape_ext = (self.n,) if as_vector else (self.n, self.m)
937929
return np.array(self.to_list()).reshape(shape_ext)
938930

939931
@taichi_scope
@@ -1128,9 +1120,9 @@ def field(cls,
11281120

11291121
if shape is not None:
11301122
if isinstance(shape, numbers.Number):
1131-
shape = (shape, )
1123+
shape = (shape,)
11321124
if isinstance(offset, numbers.Number):
1133-
offset = (offset, )
1125+
offset = (offset,)
11341126

11351127
if offset is not None:
11361128
assert len(shape) == len(
@@ -1182,7 +1174,7 @@ def ndarray(cls, n, m, dtype, shape, layout=Layout.AOS):
11821174
>>> x = ti.Matrix.ndarray(4, 5, ti.f32, shape=(16, 8))
11831175
"""
11841176
if isinstance(shape, numbers.Number):
1185-
shape = (shape, )
1177+
shape = (shape,)
11861178
return MatrixNdarray(n, m, dtype, shape, layout)
11871179

11881180
@classmethod
@@ -1202,7 +1194,7 @@ def _Vector_ndarray(cls, n, dtype, shape, layout=Layout.AOS):
12021194
>>> x = ti.Vector.ndarray(3, ti.f32, shape=(16, 8))
12031195
"""
12041196
if isinstance(shape, numbers.Number):
1205-
shape = (shape, )
1197+
shape = (shape,)
12061198
return VectorNdarray(n, dtype, shape, layout)
12071199

12081200
@staticmethod
@@ -1392,6 +1384,7 @@ class _IntermediateMatrix(Matrix):
13921384
m (int): Number of columns of the matrix.
13931385
entries (List[Expr]): All entries of the matrix.
13941386
"""
1387+
13951388
def __init__(self, n, m, entries):
13961389
assert isinstance(entries, list)
13971390
assert n * m == len(entries), "Number of entries doesn't match n * m"
@@ -1411,6 +1404,7 @@ class _MatrixFieldElement(_IntermediateMatrix):
14111404
field (MatrixField): The matrix field.
14121405
indices (taichi_core.ExprGroup): Indices of the element.
14131406
"""
1407+
14141408
def __init__(self, field, indices):
14151409
super().__init__(field.n, field.m, [
14161410
expr.Expr(ti_core.subscript(e.ptr, indices))
@@ -1427,6 +1421,7 @@ class MatrixField(Field):
14271421
n (Int): Number of rows.
14281422
m (Int): Number of columns.
14291423
"""
1424+
14301425
def __init__(self, _vars, n, m):
14311426
assert len(_vars) == n * m
14321427
super().__init__(_vars)
@@ -1472,7 +1467,7 @@ def _calc_dynamic_index_stride(self):
14721467
i + 1]._offset_bytes_in_parent_cell for path in paths):
14731468
return
14741469
stride = paths[1][depth_below_lca]._offset_bytes_in_parent_cell - \
1475-
paths[0][depth_below_lca]._offset_bytes_in_parent_cell
1470+
paths[0][depth_below_lca]._offset_bytes_in_parent_cell
14761471
for i in range(2, num_members):
14771472
if stride != paths[i][depth_below_lca]._offset_bytes_in_parent_cell \
14781473
- paths[i - 1][depth_below_lca]._offset_bytes_in_parent_cell:
@@ -1493,7 +1488,7 @@ def fill(self, val):
14931488
elif isinstance(val,
14941489
(list, tuple)) and isinstance(val[0], numbers.Number):
14951490
assert self.m == 1
1496-
val = tuple([(v, ) for v in val])
1491+
val = tuple([(v,) for v in val])
14971492
elif isinstance(val, Matrix):
14981493
val_tuple = []
14991494
for i in range(val.n):
@@ -1525,7 +1520,7 @@ def to_numpy(self, keep_dims=False, dtype=None):
15251520
if dtype is None:
15261521
dtype = to_numpy_type(self.dtype)
15271522
as_vector = self.m == 1 and not keep_dims
1528-
shape_ext = (self.n, ) if as_vector else (self.n, self.m)
1523+
shape_ext = (self.n,) if as_vector else (self.n, self.m)
15291524
arr = np.zeros(self.shape + shape_ext, dtype=dtype)
15301525
from taichi._kernels import matrix_to_ext_arr # pylint: disable=C0415
15311526
matrix_to_ext_arr(self, arr, as_vector)
@@ -1545,7 +1540,7 @@ def to_torch(self, device=None, keep_dims=False):
15451540
"""
15461541
import torch # pylint: disable=C0415
15471542
as_vector = self.m == 1 and not keep_dims
1548-
shape_ext = (self.n, ) if as_vector else (self.n, self.m)
1543+
shape_ext = (self.n,) if as_vector else (self.n, self.m)
15491544
# pylint: disable=E1101
15501545
arr = torch.empty(self.shape + shape_ext,
15511546
dtype=to_pytorch_type(self.dtype),
@@ -1568,7 +1563,7 @@ def to_paddle(self, place=None, keep_dims=False):
15681563
"""
15691564
import paddle # pylint: disable=C0415
15701565
as_vector = self.m == 1 and not keep_dims
1571-
shape_ext = (self.n, ) if as_vector else (self.n, self.m)
1566+
shape_ext = (self.n,) if as_vector else (self.n, self.m)
15721567
# pylint: disable=E1101
15731568
# paddle.empty() doesn't support argument `place``
15741569
arr = paddle.to_tensor(paddle.empty(self.shape + shape_ext,
@@ -1688,6 +1683,7 @@ class MatrixNdarray(Ndarray):
16881683
16891684
>>> arr = ti.MatrixNdarray(2, 2, ti.f32, shape=(3, 3), layout=Layout.SOA)
16901685
"""
1686+
16911687
def __init__(self, n, m, dtype, shape, layout):
16921688
self.n = n
16931689
self.m = m
@@ -1725,7 +1721,7 @@ def __setitem__(self, key, value):
17251721
@python_scope
17261722
def __getitem__(self, key):
17271723
key = () if key is None else (
1728-
key, ) if isinstance(key, numbers.Number) else tuple(key)
1724+
key,) if isinstance(key, numbers.Number) else tuple(key)
17291725
return Matrix(
17301726
[[NdarrayHostAccess(self, key, (i, j)) for j in range(self.m)]
17311727
for i in range(self.n)])
@@ -1786,13 +1782,14 @@ class VectorNdarray(Ndarray):
17861782
17871783
>>> a = ti.VectorNdarray(3, ti.f32, (3, 3), layout=Layout.SOA)
17881784
"""
1785+
17891786
def __init__(self, n, dtype, shape, layout):
17901787
self.n = n
17911788
super().__init__()
17921789
self.dtype = cook_dtype(dtype)
17931790
self.layout = layout
17941791
self.shape = tuple(shape)
1795-
self.element_type = TensorType((n, ), self.dtype)
1792+
self.element_type = TensorType((n,), self.dtype)
17961793
# TODO: pass in element_type, shape, layout directly
17971794
self.arr = impl.get_runtime().prog.create_ndarray(
17981795
self.element_type.dtype, shape, self.element_type.shape, layout)
@@ -1819,9 +1816,9 @@ def __setitem__(self, key, value):
18191816
@python_scope
18201817
def __getitem__(self, key):
18211818
key = () if key is None else (
1822-
key, ) if isinstance(key, numbers.Number) else tuple(key)
1819+
key,) if isinstance(key, numbers.Number) else tuple(key)
18231820
return Vector(
1824-
[NdarrayHostAccess(self, key, (i, )) for i in range(self.n)])
1821+
[NdarrayHostAccess(self, key, (i,)) for i in range(self.n)])
18251822

18261823
@python_scope
18271824
def to_numpy(self):

0 commit comments

Comments
 (0)