Skip to content

Commit 823f7d1

Browse files
jim19930609pre-commit-ci[bot]strongoier
authored andcommitted
[bug] MatrixType bug fix: Fix error with static-grouped-ndrange (taichi-dev#6839)
Issue: taichi-dev#5819 ### Brief Summary Co-authored-by: pre-commit-ci[bot] <66853113+pre-commit-ci[bot]@users.noreply.github.com> Co-authored-by: Yi Xu <[email protected]>
1 parent 6f4dda1 commit 823f7d1

File tree

5 files changed

+65
-19
lines changed

5 files changed

+65
-19
lines changed

python/taichi/lang/_ndrange.py

+3-7
Original file line numberDiff line numberDiff line change
@@ -1,11 +1,10 @@
11
import collections.abc
22

33
import numpy as np
4-
from taichi.lang import impl, ops
4+
from taichi.lang import ops
55
from taichi.lang.exception import TaichiSyntaxError, TaichiTypeError
66
from taichi.lang.expr import Expr
7-
from taichi.lang.matrix import _IntermediateMatrix, make_matrix
8-
from taichi.types import primitive_types
7+
from taichi.lang.matrix import _IntermediateMatrix
98
from taichi.types.utils import is_integral
109

1110

@@ -145,10 +144,7 @@ def __init__(self, r):
145144

146145
def __iter__(self):
147146
for ind in self.r:
148-
if impl.current_cfg().real_matrix:
149-
yield make_matrix(list(ind), dt=primitive_types.i32)
150-
else:
151-
yield _IntermediateMatrix(len(ind), 1, list(ind), ndim=1)
147+
yield _IntermediateMatrix(len(ind), 1, list(ind), ndim=1)
152148

153149

154150
__all__ = ['ndrange']

python/taichi/lang/impl.py

+3-2
Original file line numberDiff line numberDiff line change
@@ -1071,9 +1071,10 @@ def static(x, *xs):
10711071
return [static(x)] + [static(x) for x in xs]
10721072

10731073
if isinstance(x,
1074-
(bool, int, float, range, list, tuple, enumerate, _Ndrange,
1075-
GroupedNDRange, zip, filter, map)) or x is None:
1074+
(bool, int, float, range, list, tuple, enumerate,
1075+
GroupedNDRange, _Ndrange, zip, filter, map)) or x is None:
10761076
return x
1077+
10771078
if isinstance(x, AnyArray):
10781079
return x
10791080
if isinstance(x, Field):

tests/python/test_ast_refactor.py

+37-6
Original file line numberDiff line numberDiff line change
@@ -493,8 +493,7 @@ def foo(a: ti.template()):
493493
assert a[i] == 0
494494

495495

496-
@test_utils.test(print_preprocessed_ir=True)
497-
def test_static_grouped_for_break():
496+
def _test_static_grouped_for_break():
498497
n = 4
499498

500499
@ti.kernel
@@ -518,6 +517,18 @@ def foo(a: ti.template()):
518517
assert a[i, j] == 0
519518

520519

520+
@test_utils.test(print_preprocessed_ir=True)
521+
def test_static_grouped_for_break():
522+
_test_static_grouped_for_break()
523+
524+
525+
@test_utils.test(print_preprocessed_ir=True,
526+
real_matrix=True,
527+
real_matrix_scalarize=True)
528+
def test_static_grouped_for_break_matrix_scalarize():
529+
_test_static_grouped_for_break()
530+
531+
521532
@test_utils.test(print_preprocessed_ir=True)
522533
def test_static_for_continue():
523534
n = 10
@@ -540,8 +551,7 @@ def foo(a: ti.template()):
540551
assert a[i] == 3
541552

542553

543-
@test_utils.test(print_preprocessed_ir=True)
544-
def test_static_grouped_for_continue():
554+
def _test_static_grouped_for_continue():
545555
n = 4
546556

547557
@ti.kernel
@@ -563,6 +573,18 @@ def foo(a: ti.template()):
563573
assert a[i, j] == 3
564574

565575

576+
@test_utils.test(print_preprocessed_ir=True)
577+
def test_static_grouped_for_continue():
578+
_test_static_grouped_for_continue()
579+
580+
581+
@test_utils.test(print_preprocessed_ir=True,
582+
real_matrix=True,
583+
real_matrix_scalarize=True)
584+
def test_static_grouped_for_continue_matrix_scalarize():
585+
_test_static_grouped_for_continue()
586+
587+
566588
@test_utils.test(print_preprocessed_ir=True)
567589
def test_for_break():
568590
n = 4
@@ -1039,8 +1061,7 @@ def foo() -> ti.i32:
10391061
assert foo() == 123
10401062

10411063

1042-
@test_utils.test()
1043-
def test_grouped_static_for_cast():
1064+
def _test_grouped_static_for_cast():
10441065
@ti.kernel
10451066
def foo() -> ti.f32:
10461067
ret = 0.
@@ -1050,3 +1071,13 @@ def foo() -> ti.f32:
10501071
return ret
10511072

10521073
assert foo() == test_utils.approx(10)
1074+
1075+
1076+
@test_utils.test()
1077+
def test_grouped_static_for_cast():
1078+
_test_grouped_static_for_cast()
1079+
1080+
1081+
@test_utils.test(real_matrix=True, real_matrix_scalarize=True)
1082+
def test_grouped_static_for_cast_matrix_scalarize():
1083+
_test_grouped_static_for_cast()

tests/python/test_grouped.py

+11-2
Original file line numberDiff line numberDiff line change
@@ -185,8 +185,7 @@ def test():
185185
assert val[None] == 42
186186

187187

188-
@test_utils.test()
189-
def test_static_grouped_func():
188+
def _test_static_grouped_func():
190189

191190
K = 3
192191
dim = 2
@@ -207,3 +206,13 @@ def p2g():
207206
for j in range(K):
208207
for k in range(K):
209208
assert v[i, j][k] == i + j * 3 + k * 10
209+
210+
211+
@test_utils.test()
212+
def test_static_grouped_func():
213+
_test_static_grouped_func()
214+
215+
216+
@test_utils.test(real_matrix=True, real_matrix_scalarize=True)
217+
def test_static_grouped_func_matrix_scalarize():
218+
_test_static_grouped_func()

tests/python/test_ndrange.py

+11-2
Original file line numberDiff line numberDiff line change
@@ -108,8 +108,7 @@ def func():
108108
assert x[i, j, k] == 0
109109

110110

111-
@test_utils.test()
112-
def test_static_grouped_static():
111+
def _test_static_grouped_static():
113112
x = ti.Matrix.field(2, 3, dtype=ti.f32, shape=(16, 4))
114113

115114
@ti.kernel
@@ -126,6 +125,16 @@ def func():
126125
assert x[i, j][k, l] == k + l * 10 + i + j * 4
127126

128127

128+
@test_utils.test()
129+
def test_static_grouped_static():
130+
_test_static_grouped_static()
131+
132+
133+
@test_utils.test(real_matrix=True, real_matrix_scalarize=True)
134+
def test_static_grouped_static_matrix_scalarize():
135+
_test_static_grouped_static()
136+
137+
129138
@test_utils.test()
130139
def test_field_init_eye():
131140
# https://github.com/taichi-dev/taichi/issues/1824

0 commit comments

Comments
 (0)