Skip to content

Commit 5704111

Browse files
author
ailzhang
committed
[refactor] Specialized Ndarray Type is (element_type, shape, layout)
ghstack-source-id: 977cd453359b8ccc09deccacc62a915abcd42734 Pull Request resolved: #5065
1 parent e5b6639 commit 5704111

File tree

9 files changed

+70
-49
lines changed

9 files changed

+70
-49
lines changed

python/taichi/aot/utils.py

+17-13
Original file line numberDiff line numberDiff line change
@@ -35,29 +35,33 @@ def produce_injected_args(kernel, symbolic_args=None):
3535
raise TaichiCompilationError(
3636
f'Expected Ndaray type, got {anno}')
3737
if symbolic_args is not None:
38-
anno.element_shape = tuple(symbolic_args[i].element_shape)
39-
anno.element_dim = len(anno.element_shape)
40-
anno.dtype = symbolic_args[i].dtype()
38+
element_shape = tuple(symbolic_args[i].element_shape)
39+
element_dim = len(element_shape)
40+
dtype = symbolic_args[i].dtype()
41+
else:
42+
element_shape = anno.element_shape
43+
element_dim = anno.element_dim
44+
dtype = anno.dtype
4145

42-
if anno.element_shape is None or anno.field_dim is None:
46+
if element_shape is None or anno.field_dim is None:
4347
raise TaichiCompilationError(
4448
'Please either specify both `element_shape` and `field_dim` '
4549
'in the param annotation, or provide an example '
4650
f'ndarray for param={arg.name}')
47-
if anno.element_dim == 0:
51+
if element_dim is None or element_dim == 0:
4852
injected_args.append(
49-
ScalarNdarray(anno.dtype, (2, ) * anno.field_dim))
50-
elif anno.element_dim == 1:
53+
ScalarNdarray(dtype, (2, ) * anno.field_dim))
54+
elif element_dim == 1:
5155
injected_args.append(
52-
VectorNdarray(anno.element_shape[0],
53-
dtype=anno.dtype,
56+
VectorNdarray(element_shape[0],
57+
dtype=dtype,
5458
shape=(2, ) * anno.field_dim,
5559
layout=Layout.AOS))
56-
elif anno.element_dim == 2:
60+
elif element_dim == 2:
5761
injected_args.append(
58-
MatrixNdarray(anno.element_shape[0],
59-
anno.element_shape[1],
60-
dtype=anno.dtype,
62+
MatrixNdarray(element_shape[0],
63+
element_shape[1],
64+
dtype=dtype,
6165
shape=(2, ) * anno.field_dim,
6266
layout=Layout.AOS))
6367
else:

python/taichi/lang/_ndarray.py

+9
Original file line numberDiff line numberDiff line change
@@ -4,6 +4,7 @@
44
from taichi.lang.enums import Layout
55
from taichi.lang.util import cook_dtype, python_scope, to_numpy_type
66
from taichi.types import primitive_types
7+
from taichi.types.ndarray_type import SpecializeNdarrayType
78

89

910
class Ndarray:
@@ -15,10 +16,17 @@ class Ndarray:
1516
"""
1617
def __init__(self, dtype, arr_shape):
1718
self.host_accessor = None
19+
self.layout = None
20+
self.shape = None
21+
self.element_type = None
1822
self.dtype = cook_dtype(dtype)
1923
self.arr = impl.get_runtime().prog.create_ndarray(
2024
cook_dtype(dtype), arr_shape)
2125

26+
def get_type(self):
27+
return SpecializeNdarrayType(self.element_type, self.shape,
28+
self.layout)
29+
2230
@property
2331
def element_shape(self):
2432
"""Gets ndarray element shape.
@@ -209,6 +217,7 @@ class ScalarNdarray(Ndarray):
209217
def __init__(self, dtype, arr_shape):
210218
super().__init__(dtype, arr_shape)
211219
self.shape = tuple(self.arr.shape)
220+
self.element_type = dtype
212221

213222
@property
214223
def element_shape(self):

python/taichi/lang/kernel_impl.py

+3-11
Original file line numberDiff line numberDiff line change
@@ -337,21 +337,13 @@ def extract_arg(arg, anno):
337337
return arg
338338
if isinstance(anno, ndarray_type.NdarrayType):
339339
if isinstance(arg, taichi.lang._ndarray.ScalarNdarray):
340-
anno._check_element_dim(arg, 0)
341-
anno._check_element_shape(())
342-
anno._check_field_dim(len(arg.shape))
340+
anno.match(arg.get_type())
343341
return arg.dtype, len(arg.shape), (), Layout.AOS
344342
if isinstance(arg, taichi.lang.matrix.VectorNdarray):
345-
anno._check_element_dim(arg, 1)
346-
anno._check_element_shape((arg.n, ))
347-
anno._check_field_dim(len(arg.shape))
348-
anno._check_layout(arg)
343+
anno.match(arg.get_type())
349344
return arg.dtype, len(arg.shape) + 1, (arg.n, ), arg.layout
350345
if isinstance(arg, taichi.lang.matrix.MatrixNdarray):
351-
anno._check_element_dim(arg, 2)
352-
anno._check_element_shape((arg.n, arg.m))
353-
anno._check_field_dim(len(arg.shape))
354-
anno._check_layout(arg)
346+
anno.match(arg.get_type())
355347
return arg.dtype, len(arg.shape) + 2, (arg.n,
356348
arg.m), arg.layout
357349
# external arrays

python/taichi/lang/matrix.py

+9-5
Original file line numberDiff line numberDiff line change
@@ -17,7 +17,7 @@
1717
taichi_scope, to_numpy_type, to_paddle_type,
1818
to_pytorch_type, warning)
1919
from taichi.types import primitive_types
20-
from taichi.types.compound_types import CompoundType
20+
from taichi.types.compound_types import CompoundType, TensorType
2121

2222

2323
def _gen_swizzles(cls):
@@ -1688,12 +1688,14 @@ class MatrixNdarray(Ndarray):
16881688
>>> arr = ti.MatrixNdarray(2, 2, ti.f32, shape=(3, 3), layout=Layout.SOA)
16891689
"""
16901690
def __init__(self, n, m, dtype, shape, layout):
1691-
self.layout = layout
1692-
self.shape = shape
16931691
self.n = n
16941692
self.m = m
1693+
# TODO: we should pass in element_type, shape, layout instead.
16951694
arr_shape = (n, m) + shape if layout == Layout.SOA else shape + (n, m)
16961695
super().__init__(dtype, arr_shape)
1696+
self.layout = layout
1697+
self.shape = shape
1698+
self.element_type = TensorType((self.n, self.m), dtype)
16971699

16981700
@property
16991701
def element_shape(self):
@@ -1783,11 +1785,13 @@ class VectorNdarray(Ndarray):
17831785
>>> a = ti.VectorNdarray(3, ti.f32, (3, 3), layout=Layout.SOA)
17841786
"""
17851787
def __init__(self, n, dtype, shape, layout):
1786-
self.layout = layout
1787-
self.shape = shape
17881788
self.n = n
1789+
# TODO: pass in element_type, shape, layout directly
17891790
arr_shape = (n, ) + shape if layout == Layout.SOA else shape + (n, )
17901791
super().__init__(dtype, arr_shape)
1792+
self.layout = layout
1793+
self.shape = shape
1794+
self.element_type = TensorType((n, ), dtype)
17911795

17921796
@property
17931797
def element_shape(self):

python/taichi/types/compound_types.py

+6
Original file line numberDiff line numberDiff line change
@@ -5,6 +5,12 @@ class CompoundType:
55
pass
66

77

8+
class TensorType(CompoundType):
9+
def __init__(self, shape, dtype):
10+
self.dtype = taichi.lang.util.cook_dtype(dtype)
11+
self.shape = shape
12+
13+
814
# TODO: maybe move MatrixType, StructType here to avoid the circular import?
915
def matrix(n, m, dtype):
1016
"""Creates a matrix type with given shape and data type.

python/taichi/types/ndarray_type.py

+19-12
Original file line numberDiff line numberDiff line change
@@ -1,6 +1,13 @@
11
from taichi.types.primitive_types import f32
22

33

4+
class SpecializeNdarrayType:
5+
def __init__(self, element_type, shape=None, layout=None):
6+
self.element_type = element_type
7+
self.shape = shape
8+
self.layout = layout
9+
10+
411
class NdarrayType:
512
"""Type annotation for arbitrary arrays, including external arrays (numpy ndarrays and torch tensors) and Taichi ndarrays.
613
@@ -32,31 +39,31 @@ def __init__(self,
3239
self.element_shape = element_shape
3340
self.element_dim = len(
3441
element_shape) if element_shape is not None else element_dim
42+
3543
self.field_dim = field_dim
3644
self.layout = layout
3745

38-
def _check_element_dim(self, arg, arg_dim):
39-
if self.element_dim is not None and self.element_dim != arg_dim:
46+
def match(self, ndarray_type: SpecializeNdarrayType):
47+
if self.element_dim is not None and self.element_dim != len(
48+
ndarray_type.element_type.shape):
4049
raise ValueError(
41-
f"Invalid argument into ti.types.ndarray() - required element_dim={self.element_dim}, but {arg} is provided"
50+
f"Invalid argument into ti.types.ndarray() - required element_dim={self.element_dim}, but {len(ndarray_type.element_type.shape)} is provided"
4251
)
4352

44-
def _check_layout(self, arg):
45-
if self.layout is not None and self.layout != arg.layout:
53+
if self.element_shape is not None and self.element_shape != ndarray_type.element_type.shape:
4654
raise ValueError(
47-
f"Invalid argument into ti.types.ndarray() - required layout={self.layout}, but {arg} is provided"
55+
f"Invalid argument into ti.types.ndarray() - required element_shape={self.element_shape}, but {ndarray_type.element_type.shape} is provided"
4856
)
4957

50-
def _check_element_shape(self, shapes):
51-
if self.element_shape is not None and shapes != self.element_shape:
58+
if self.layout is not None and self.layout != ndarray_type.layout:
5259
raise ValueError(
53-
f"Invalid argument into ti.types.ndarray() - required element_shape={self.element_shape}, but {shapes} is provided"
60+
f"Invalid argument into ti.types.ndarray() - required layout={self.layout}, but {ndarray_type.layout} is provided"
5461
)
5562

56-
def _check_field_dim(self, field_dim):
57-
if self.field_dim is not None and field_dim != self.field_dim:
63+
if self.field_dim is not None and self.field_dim != len(
64+
ndarray_type.shape):
5865
raise ValueError(
59-
f"Invalid argument into ti.types.ndarray() - required field_dim={self.field_dim}, but {field_dim} is provided"
66+
f"Invalid argument into ti.types.ndarray() - required field_dim={self.field_dim}, but {ndarray_type.element_type} is provided"
6067
)
6168

6269

tests/python/test_aot.py

+2-3
Original file line numberDiff line numberDiff line change
@@ -84,15 +84,14 @@ def test_aot_bind_id():
8484
density1 = ti.ndarray(dtype=ti.f32, shape=(8, 8))
8585

8686
@ti.kernel
87-
def init(x: ti.f32, density1: ti.types.ndarray(field_dim=2,
88-
element_shape=())):
87+
def init(x: ti.f32, density1: ti.types.ndarray(field_dim=2)):
8988
for i, j in density1:
9089
density[i, j] = x
9190
density1[i, j] = x + 1
9291

9392
with tempfile.TemporaryDirectory() as tmpdir:
9493
m = ti.aot.Module(ti.lang.impl.current_cfg().arch)
95-
m.add_kernel(init)
94+
m.add_kernel(init, {'density1': density1})
9695
m.save(tmpdir, '')
9796
with open(os.path.join(tmpdir, 'metadata.json')) as json_file:
9897
res = json.load(json_file)

tests/python/test_api.py

+4-4
Original file line numberDiff line numberDiff line change
@@ -113,9 +113,9 @@ def _get_expected_matrix_apis():
113113
'to_torch'
114114
]
115115
user_api[ti.MatrixNdarray] = [
116-
'copy_from', 'element_shape', 'fill', 'from_numpy', 'to_numpy'
116+
'copy_from', 'element_shape', 'fill', 'from_numpy', 'get_type', 'to_numpy'
117117
]
118-
user_api[ti.Ndarray] = ['copy_from', 'element_shape', 'fill']
118+
user_api[ti.Ndarray] = ['copy_from', 'element_shape', 'fill', 'get_type']
119119
user_api[ti.SNode] = [
120120
'bit_array', 'bit_struct', 'bitmasked', 'deactivate_all', 'dense',
121121
'dynamic', 'lazy_grad', 'parent', 'place', 'pointer', 'shape'
@@ -125,7 +125,7 @@ def _get_expected_matrix_apis():
125125
'parent', 'shape', 'snode', 'to_numpy', 'to_paddle', 'to_torch'
126126
]
127127
user_api[ti.ScalarNdarray] = [
128-
'copy_from', 'element_shape', 'fill', 'from_numpy', 'to_numpy'
128+
'copy_from', 'element_shape', 'fill', 'from_numpy', 'get_type', 'to_numpy'
129129
]
130130
user_api[ti.Struct] = ['field', 'fill', 'items', 'keys', 'to_dict']
131131
user_api[ti.StructField] = [
@@ -134,7 +134,7 @@ def _get_expected_matrix_apis():
134134
'to_paddle', 'to_torch'
135135
]
136136
user_api[ti.VectorNdarray] = [
137-
'copy_from', 'element_shape', 'fill', 'from_numpy', 'to_numpy'
137+
'copy_from', 'element_shape', 'fill', 'from_numpy', 'get_type', 'to_numpy'
138138
]
139139

140140

tests/python/test_graph.py

+1-1
Original file line numberDiff line numberDiff line change
@@ -9,7 +9,7 @@ def test_ndarray_int():
99
n = 4
1010

1111
@ti.kernel
12-
def test(pos: ti.types.ndarray(field_dim=1, element_shape=())):
12+
def test(pos: ti.types.ndarray(field_dim=1)):
1313
for i in range(n):
1414
pos[i] = 1
1515

0 commit comments

Comments
 (0)