@@ -41,7 +41,6 @@ def check(instance, pattern):
 
     for key_group in KEYGROUP_SET:
         for index, attr in enumerate(key_group):
-
             def gen_property(attr, attr_idx, key_group):
                 checker = make_valid_attribs_checker(key_group)
@@ -122,7 +121,7 @@ def _linearize_entry_id(self, *args):
         if len(args) == 1 and isinstance(args[0], (list, tuple)):
             args = args[0]
         if len(args) == 1:
-            args = args + (0, )
+            args = args + (0,)
         # TODO(#1004): See if it's possible to support indexing at runtime
         for i, a in enumerate(args):
             if not isinstance(a, int):
@@ -244,19 +243,16 @@ def _subscript(self, is_global_mat, *indices):
         if self.any_array_access:
             return self.any_array_access.subscript(i, j)
         if self.local_tensor_proxy is not None:
-            assert self.dynamic_index_stride is not None
             if len(indices) == 1:
-                return impl.make_tensor_element_expr(self.local_tensor_proxy,
-                                                     (i, ), (self.n, ),
-                                                     self.dynamic_index_stride)
-            return impl.make_tensor_element_expr(self.local_tensor_proxy,
-                                                 (i, j), (self.n, self.m),
-                                                 self.dynamic_index_stride)
+                return impl.make_index_expr(self.local_tensor_proxy,
+                                            (i,))
+            return impl.make_index_expr(self.local_tensor_proxy,
+                                        (i, j))
         if impl.current_cfg(
         ).dynamic_index and is_global_mat and self.dynamic_index_stride:
-            return impl.make_tensor_element_expr(self.entries[0].ptr, (i, j),
-                                                 (self.n, self.m),
-                                                 self.dynamic_index_stride)
+            return impl.make_stride_expr(self.entries[0].ptr, (i, j),
+                                         (self.n, self.m),
+                                         self.dynamic_index_stride)
         return self._get_entry(i, j)
 
     def _calc_slice(self, index, dim):
@@ -318,17 +314,15 @@ def with_dynamic_index(self, arr, dt):
         local_tensor_proxy = impl.expr_init_local_tensor(
             [len(arr)], dt,
             expr.make_expr_group([expr.Expr(x) for x in arr]))
-        dynamic_index_stride = 1
         mat = []
         for i in range(len(arr)):
             mat.append(
                 list([
-                    impl.make_tensor_element_expr(
+                    impl.make_index_expr(
                         local_tensor_proxy,
-                        (expr.Expr(i, dtype=primitive_types.i32), ),
-                        (len(arr), ), dynamic_index_stride)
+                        (expr.Expr(i, dtype=primitive_types.i32),))
                 ]))
-        return local_tensor_proxy, dynamic_index_stride, mat
+        return local_tensor_proxy, mat
 
     def _get_entry_to_infer(self, arr):
         return arr[0]
@@ -348,18 +342,16 @@ def with_dynamic_index(self, arr, dt):
             expr.make_expr_group(
                 [expr.Expr(x) for row in arr for x in row]))
 
-        dynamic_index_stride = 1
         mat = []
         for i in range(len(arr)):
             mat.append([])
             for j in range(len(arr[0])):
                 mat[i].append(
-                    impl.make_tensor_element_expr(
+                    impl.make_index_expr(
                         local_tensor_proxy,
                         (expr.Expr(i, dtype=primitive_types.i32),
-                         expr.Expr(j, dtype=primitive_types.i32)),
-                        (len(arr), len(arr[0])), dynamic_index_stride))
-        return local_tensor_proxy, dynamic_index_stride, mat
+                         expr.Expr(j, dtype=primitive_types.i32))))
+        return local_tensor_proxy, mat
 
     def _get_entry_to_infer(self, arr):
         return arr[0][0]
@@ -413,7 +405,6 @@ class Matrix(TaichiOperations):
 
     def __init__(self, arr, dt=None, suppress_warning=False, is_ref=False):
         local_tensor_proxy = None
-        dynamic_index_stride = None
 
         if not isinstance(arr, (list, tuple, np.ndarray)):
             raise TaichiTypeError(
@@ -440,7 +431,7 @@ def __init__(self, arr, dt=None, suppress_warning=False, is_ref=False):
                     )
                 if dt is None:
                     dt = initializer.infer_dt(arr)
-                local_tensor_proxy, dynamic_index_stride, mat = initializer.with_dynamic_index(
+                local_tensor_proxy, mat = initializer.with_dynamic_index(
                     arr, dt)
 
             self.n, self.m = len(mat), 1
@@ -464,7 +455,7 @@ def __init__(self, arr, dt=None, suppress_warning=False, is_ref=False):
             self._impl = _PyScopeMatrixImpl(m, n, entries)
         else:
             self._impl = _TiScopeMatrixImpl(m, n, entries, local_tensor_proxy,
-                                            dynamic_index_stride)
+                                            None)
 
     def _element_wise_binary(self, foo, other):
         other = self._broadcast_copy(other)
@@ -680,8 +671,8 @@ def E(x, y):
             for i in range(n):
                 for j in range(n):
                     entries[j][i] = inv_determinant * (
-                        E(i + 1, j + 1) * E(i + 2, j + 2) -
-                        E(i + 2, j + 1) * E(i + 1, j + 2))
+                        E(i + 1, j + 1) * E(i + 2, j + 2) -
+                        E(i + 2, j + 1) * E(i + 1, j + 2))
             return Matrix(entries)
         if self.n == 4:
             n = 4
@@ -693,14 +684,14 @@ def E(x, y):
 
             for i in range(n):
                 for j in range(n):
-                    entries[j][i] = inv_determinant * (-1)**(i + j) * ((
-                        E(i + 1, j + 1) *
-                        (E(i + 2, j + 2) * E(i + 3, j + 3) -
-                         E(i + 3, j + 2) * E(i + 2, j + 3)) - E(i + 2, j + 1) *
-                        (E(i + 1, j + 2) * E(i + 3, j + 3) -
-                         E(i + 3, j + 2) * E(i + 1, j + 3)) + E(i + 3, j + 1) *
-                        (E(i + 1, j + 2) * E(i + 2, j + 3) -
-                         E(i + 2, j + 2) * E(i + 1, j + 3))))
+                    entries[j][i] = inv_determinant * (-1) ** (i + j) * ((
+                        E(i + 1, j + 1) *
+                        (E(i + 2, j + 2) * E(i + 3, j + 3) -
+                         E(i + 3, j + 2) * E(i + 2, j + 3)) - E(i + 2, j + 1) *
+                        (E(i + 1, j + 2) * E(i + 3, j + 3) -
+                         E(i + 3, j + 2) * E(i + 1, j + 3)) + E(i + 3, j + 1) *
+                        (E(i + 1, j + 2) * E(i + 2, j + 3) -
+                         E(i + 2, j + 2) * E(i + 1, j + 3))))
             return Matrix(entries)
         raise Exception(
             "Inversions of matrices with sizes >= 5 are not supported")
@@ -760,7 +751,7 @@ def determinant(a):
         if a.n == 3 and a.m == 3:
             return a(0, 0) * (a(1, 1) * a(2, 2) - a(2, 1) * a(1, 2)) - a(
                 1, 0) * (a(0, 1) * a(2, 2) - a(2, 1) * a(0, 2)) + a(
-                    2, 0) * (a(0, 1) * a(1, 2) - a(1, 1) * a(0, 2))
+                2, 0) * (a(0, 1) * a(1, 2) - a(1, 1) * a(0, 2))
         if a.n == 4 and a.m == 4:
             n = 4
 
@@ -769,14 +760,14 @@ def E(x, y):
 
             det = impl.expr_init(0.0)
             for i in range(4):
-                det = det + (-1.0)**i * (
-                    a(i, 0) *
-                    (E(i + 1, 1) *
-                     (E(i + 2, 2) * E(i + 3, 3) - E(i + 3, 2) * E(i + 2, 3)) -
-                     E(i + 2, 1) *
-                     (E(i + 1, 2) * E(i + 3, 3) - E(i + 3, 2) * E(i + 1, 3)) +
-                     E(i + 3, 1) *
-                     (E(i + 1, 2) * E(i + 2, 3) - E(i + 2, 2) * E(i + 1, 3))))
+                det = det + (-1.0) ** i * (
+                    a(i, 0) *
+                    (E(i + 1, 1) *
+                     (E(i + 2, 2) * E(i + 3, 3) - E(i + 3, 2) * E(i + 2, 3)) -
+                     E(i + 2, 1) *
+                     (E(i + 1, 2) * E(i + 3, 3) - E(i + 3, 2) * E(i + 1, 3)) +
+                     E(i + 3, 1) *
+                     (E(i + 1, 2) * E(i + 2, 3) - E(i + 2, 2) * E(i + 1, 3))))
             return det
         raise Exception(
             "Determinants of matrices with sizes >= 5 are not supported")
@@ -908,6 +899,7 @@ def fill(self, val):
         >>> A
         [-1, -1, -1, -1]
         """
+
         def assign_renamed(x, y):
             return ops_mod.assign(x, y)
 
@@ -933,7 +925,7 @@ def to_numpy(self, keep_dims=False):
         array([0, 1, 2, 3])
         """
         as_vector = self.m == 1 and not keep_dims
-        shape_ext = (self.n, ) if as_vector else (self.n, self.m)
+        shape_ext = (self.n,) if as_vector else (self.n, self.m)
         return np.array(self.to_list()).reshape(shape_ext)
 
     @taichi_scope
@@ -1128,9 +1120,9 @@ def field(cls,
 
         if shape is not None:
             if isinstance(shape, numbers.Number):
-                shape = (shape, )
+                shape = (shape,)
             if isinstance(offset, numbers.Number):
-                offset = (offset, )
+                offset = (offset,)
 
             if offset is not None:
                 assert len(shape) == len(
@@ -1182,7 +1174,7 @@ def ndarray(cls, n, m, dtype, shape, layout=Layout.AOS):
         >>> x = ti.Matrix.ndarray(4, 5, ti.f32, shape=(16, 8))
         """
         if isinstance(shape, numbers.Number):
-            shape = (shape, )
+            shape = (shape,)
         return MatrixNdarray(n, m, dtype, shape, layout)
 
     @classmethod
@@ -1202,7 +1194,7 @@ def _Vector_ndarray(cls, n, dtype, shape, layout=Layout.AOS):
         >>> x = ti.Vector.ndarray(3, ti.f32, shape=(16, 8))
         """
         if isinstance(shape, numbers.Number):
-            shape = (shape, )
+            shape = (shape,)
         return VectorNdarray(n, dtype, shape, layout)
 
     @staticmethod
@@ -1392,6 +1384,7 @@ class _IntermediateMatrix(Matrix):
         m (int): Number of columns of the matrix.
         entries (List[Expr]): All entries of the matrix.
     """
+
     def __init__(self, n, m, entries):
         assert isinstance(entries, list)
         assert n * m == len(entries), "Number of entries doesn't match n * m"
@@ -1411,6 +1404,7 @@ class _MatrixFieldElement(_IntermediateMatrix):
         field (MatrixField): The matrix field.
         indices (taichi_core.ExprGroup): Indices of the element.
     """
+
     def __init__(self, field, indices):
         super().__init__(field.n, field.m, [
             expr.Expr(ti_core.subscript(e.ptr, indices))
@@ -1427,6 +1421,7 @@ class MatrixField(Field):
         n (Int): Number of rows.
         m (Int): Number of columns.
     """
+
     def __init__(self, _vars, n, m):
         assert len(_vars) == n * m
         super().__init__(_vars)
@@ -1472,7 +1467,7 @@ def _calc_dynamic_index_stride(self):
                 i + 1]._offset_bytes_in_parent_cell for path in paths):
             return
         stride = paths[1][depth_below_lca]._offset_bytes_in_parent_cell - \
-                 paths[0][depth_below_lca]._offset_bytes_in_parent_cell
+            paths[0][depth_below_lca]._offset_bytes_in_parent_cell
         for i in range(2, num_members):
             if stride != paths[i][depth_below_lca]._offset_bytes_in_parent_cell \
                     - paths[i - 1][depth_below_lca]._offset_bytes_in_parent_cell:
@@ -1493,7 +1488,7 @@ def fill(self, val):
         elif isinstance(val,
                         (list, tuple)) and isinstance(val[0], numbers.Number):
             assert self.m == 1
-            val = tuple([(v, ) for v in val])
+            val = tuple([(v,) for v in val])
         elif isinstance(val, Matrix):
             val_tuple = []
             for i in range(val.n):
@@ -1525,7 +1520,7 @@ def to_numpy(self, keep_dims=False, dtype=None):
         if dtype is None:
             dtype = to_numpy_type(self.dtype)
         as_vector = self.m == 1 and not keep_dims
-        shape_ext = (self.n, ) if as_vector else (self.n, self.m)
+        shape_ext = (self.n,) if as_vector else (self.n, self.m)
         arr = np.zeros(self.shape + shape_ext, dtype=dtype)
         from taichi._kernels import matrix_to_ext_arr  # pylint: disable=C0415
         matrix_to_ext_arr(self, arr, as_vector)
@@ -1545,7 +1540,7 @@ def to_torch(self, device=None, keep_dims=False):
         """
         import torch  # pylint: disable=C0415
         as_vector = self.m == 1 and not keep_dims
-        shape_ext = (self.n, ) if as_vector else (self.n, self.m)
+        shape_ext = (self.n,) if as_vector else (self.n, self.m)
         # pylint: disable=E1101
         arr = torch.empty(self.shape + shape_ext,
                           dtype=to_pytorch_type(self.dtype),
@@ -1568,7 +1563,7 @@ def to_paddle(self, place=None, keep_dims=False):
         """
         import paddle  # pylint: disable=C0415
         as_vector = self.m == 1 and not keep_dims
-        shape_ext = (self.n, ) if as_vector else (self.n, self.m)
+        shape_ext = (self.n,) if as_vector else (self.n, self.m)
         # pylint: disable=E1101
         # paddle.empty() doesn't support argument `place``
         arr = paddle.to_tensor(paddle.empty(self.shape + shape_ext,
@@ -1688,6 +1683,7 @@ class MatrixNdarray(Ndarray):
 
     >>> arr = ti.MatrixNdarray(2, 2, ti.f32, shape=(3, 3), layout=Layout.SOA)
     """
+
     def __init__(self, n, m, dtype, shape, layout):
         self.n = n
         self.m = m
@@ -1725,7 +1721,7 @@ def __setitem__(self, key, value):
     @python_scope
     def __getitem__(self, key):
         key = () if key is None else (
-            key, ) if isinstance(key, numbers.Number) else tuple(key)
+            key,) if isinstance(key, numbers.Number) else tuple(key)
         return Matrix(
             [[NdarrayHostAccess(self, key, (i, j)) for j in range(self.m)]
              for i in range(self.n)])
@@ -1786,13 +1782,14 @@ class VectorNdarray(Ndarray):
 
     >>> a = ti.VectorNdarray(3, ti.f32, (3, 3), layout=Layout.SOA)
     """
+
     def __init__(self, n, dtype, shape, layout):
         self.n = n
         super().__init__()
         self.dtype = cook_dtype(dtype)
         self.layout = layout
         self.shape = tuple(shape)
-        self.element_type = TensorType((n, ), self.dtype)
+        self.element_type = TensorType((n,), self.dtype)
         # TODO: pass in element_type, shape, layout directly
         self.arr = impl.get_runtime().prog.create_ndarray(
             self.element_type.dtype, shape, self.element_type.shape, layout)
@@ -1819,9 +1816,9 @@ def __setitem__(self, key, value):
     @python_scope
     def __getitem__(self, key):
         key = () if key is None else (
-            key, ) if isinstance(key, numbers.Number) else tuple(key)
+            key,) if isinstance(key, numbers.Number) else tuple(key)
         return Vector(
-            [NdarrayHostAccess(self, key, (i, )) for i in range(self.n)])
+            [NdarrayHostAccess(self, key, (i,)) for i in range(self.n)])
 
     @python_scope
     def to_numpy(self):