Skip to content

Commit be0cc2f

Browse files
[test] Remove tests with real_matrix=True and real_matrix_scalarize=True (#6873)
Issue: #5819 ### Brief Summary As these two options are enabled by default (#6801), we no longer need separate tests for them. Co-authored-by: pre-commit-ci[bot] <66853113+pre-commit-ci[bot]@users.noreply.github.com>
1 parent 040c243 commit be0cc2f

30 files changed

+180
-1194
lines changed

tests/python/test_ad_gdar_diffmpm.py

+2-15
Original file line numberDiff line numberDiff line change
@@ -4,7 +4,8 @@
44
from tests import test_utils
55

66

7-
def _test_gdar_mpm():
7+
@test_utils.test(require=ti.extension.assertion, debug=True, exclude=[ti.cc])
8+
def test_gdar_mpm():
89
real = ti.f32
910

1011
dim = 2
@@ -182,17 +183,3 @@ def substep(s):
182183
learning_rate = 10
183184
init_v[None][0] -= learning_rate * grad[0]
184185
init_v[None][1] -= learning_rate * grad[1]
185-
186-
187-
@test_utils.test(require=ti.extension.assertion, debug=True, exclude=[ti.cc])
188-
def test_gdar_mpm():
189-
_test_gdar_mpm()
190-
191-
192-
@test_utils.test(require=ti.extension.assertion,
193-
debug=True,
194-
exclude=[ti.cc],
195-
real_matrix=True,
196-
real_matrix_scalarize=True)
197-
def test_gdar_mpm_real_matrix_scalarize():
198-
_test_gdar_mpm()

tests/python/test_ast_refactor.py

+12-63
Original file line numberDiff line numberDiff line change
@@ -432,7 +432,8 @@ def func():
432432
assert x[i, j, k] == 0
433433

434434

435-
def _test_grouped_ndrange_for():
435+
@test_utils.test(print_preprocessed_ir=True)
436+
def test_grouped_ndrange_for():
436437
x = ti.field(ti.i32, shape=(6, 6, 6))
437438
y = ti.field(ti.i32, shape=(6, 6, 6))
438439

@@ -457,18 +458,6 @@ def func():
457458
assert x[i, j, k] == y[i, j, k]
458459

459460

460-
@test_utils.test(print_preprocessed_ir=True)
461-
def test_grouped_ndrange_for():
462-
_test_grouped_ndrange_for()
463-
464-
465-
@test_utils.test(print_preprocessed_ir=True,
466-
real_matrix=True,
467-
real_matrix_scalarize=True)
468-
def test_grouped_ndrange_for_matrix_scalarize():
469-
_test_grouped_ndrange_for()
470-
471-
472461
@test_utils.test(print_preprocessed_ir=True)
473462
def test_static_for_break():
474463
n = 10
@@ -493,7 +482,8 @@ def foo(a: ti.template()):
493482
assert a[i] == 0
494483

495484

496-
def _test_static_grouped_for_break():
485+
@test_utils.test(print_preprocessed_ir=True)
486+
def test_static_grouped_for_break():
497487
n = 4
498488

499489
@ti.kernel
@@ -517,18 +507,6 @@ def foo(a: ti.template()):
517507
assert a[i, j] == 0
518508

519509

520-
@test_utils.test(print_preprocessed_ir=True)
521-
def test_static_grouped_for_break():
522-
_test_static_grouped_for_break()
523-
524-
525-
@test_utils.test(print_preprocessed_ir=True,
526-
real_matrix=True,
527-
real_matrix_scalarize=True)
528-
def test_static_grouped_for_break_matrix_scalarize():
529-
_test_static_grouped_for_break()
530-
531-
532510
@test_utils.test(print_preprocessed_ir=True)
533511
def test_static_for_continue():
534512
n = 10
@@ -551,7 +529,8 @@ def foo(a: ti.template()):
551529
assert a[i] == 3
552530

553531

554-
def _test_static_grouped_for_continue():
532+
@test_utils.test(print_preprocessed_ir=True)
533+
def test_static_grouped_for_continue():
555534
n = 4
556535

557536
@ti.kernel
@@ -573,18 +552,6 @@ def foo(a: ti.template()):
573552
assert a[i, j] == 3
574553

575554

576-
@test_utils.test(print_preprocessed_ir=True)
577-
def test_static_grouped_for_continue():
578-
_test_static_grouped_for_continue()
579-
580-
581-
@test_utils.test(print_preprocessed_ir=True,
582-
real_matrix=True,
583-
real_matrix_scalarize=True)
584-
def test_static_grouped_for_continue_matrix_scalarize():
585-
_test_static_grouped_for_continue()
586-
587-
588555
@test_utils.test(print_preprocessed_ir=True)
589556
def test_for_break():
590557
n = 4
@@ -885,8 +852,8 @@ def foo(x: ti.template()) -> ti.i32:
885852
foo(2)
886853

887854

888-
@test_utils.test(real_matrix=True, real_matrix_scalarize=True)
889-
def test_single_listcomp_matrix_scalarize():
855+
@test_utils.test()
856+
def test_single_listcomp():
890857
@ti.func
891858
def identity(dt, n: ti.template()):
892859
return ti.Matrix([[ti.cast(int(i == j), dt) for j in range(n)]
@@ -904,7 +871,8 @@ def foo(n: ti.template()) -> ti.i32:
904871
assert foo(5) == 1
905872

906873

907-
def _test_listcomp():
874+
@test_utils.test()
875+
def test_listcomp():
908876
@ti.func
909877
def identity(dt, n: ti.template()):
910878
return ti.Matrix([[ti.cast(int(i == j), dt) for j in range(n)]
@@ -923,16 +891,6 @@ def foo(n: ti.template()) -> ti.i32:
923891
assert foo(5) == 1 + 4 + 9 + 16
924892

925893

926-
@test_utils.test()
927-
def test_listcomp():
928-
_test_listcomp()
929-
930-
931-
@test_utils.test(real_matrix=True, real_matrix_scalarize=True)
932-
def test_listcomp_matrix_scalarize():
933-
_test_listcomp()
934-
935-
936894
@test_utils.test()
937895
def test_dictcomp():
938896
@ti.kernel
@@ -1061,7 +1019,8 @@ def foo() -> ti.i32:
10611019
assert foo() == 123
10621020

10631021

1064-
def _test_grouped_static_for_cast():
1022+
@test_utils.test()
1023+
def test_grouped_static_for_cast():
10651024
@ti.kernel
10661025
def foo() -> ti.f32:
10671026
ret = 0.
@@ -1071,13 +1030,3 @@ def foo() -> ti.f32:
10711030
return ret
10721031

10731032
assert foo() == test_utils.approx(10)
1074-
1075-
1076-
@test_utils.test()
1077-
def test_grouped_static_for_cast():
1078-
_test_grouped_static_for_cast()
1079-
1080-
1081-
@test_utils.test(real_matrix=True, real_matrix_scalarize=True)
1082-
def test_grouped_static_for_cast_matrix_scalarize():
1083-
_test_grouped_static_for_cast()

tests/python/test_bls.py

+16-104
Original file line numberDiff line numberDiff line change
@@ -62,142 +62,54 @@ def _test_bls_stencil(*args, **kwargs):
6262
bls_test_template(*args, **kwargs)
6363

6464

65-
def _test_gather_1d_trivial():
65+
@test_utils.test(require=ti.extension.bls)
66+
def test_gather_1d_trivial():
6667
# y[i] = x[i]
6768
_test_bls_stencil(1, 128, bs=32, stencil=((0, ), ))
6869

6970

70-
def _test_gather_1d():
71+
@test_utils.test(require=ti.extension.bls)
72+
def test_gather_1d():
7173
# y[i] = x[i - 1] + x[i]
7274
_test_bls_stencil(1, 128, bs=32, stencil=((-1, ), (0, )))
7375

7476

75-
def _test_gather_2d():
77+
@test_utils.test(require=ti.extension.bls)
78+
def test_gather_2d():
7679
stencil = [(0, 0), (0, -1), (0, 1), (1, 0)]
7780
_test_bls_stencil(2, 128, bs=16, stencil=stencil)
7881

7982

80-
def _test_gather_2d_nonsquare():
83+
@test_utils.test(require=ti.extension.bls)
84+
def test_gather_2d_nonsquare():
8185
stencil = [(0, 0), (0, -1), (0, 1), (1, 0)]
8286
_test_bls_stencil(2, 128, bs=(4, 16), stencil=stencil)
8387

8488

85-
def _test_gather_3d():
89+
@test_utils.test(require=ti.extension.bls)
90+
def test_gather_3d():
8691
stencil = [(-1, -1, -1), (2, 0, 1)]
8792
_test_bls_stencil(3, 64, bs=(4, 8, 16), stencil=stencil)
8893

8994

90-
def _test_scatter_1d_trivial():
95+
@test_utils.test(require=ti.extension.bls)
96+
def test_scatter_1d_trivial():
9197
# y[i] = x[i]
9298
_test_bls_stencil(1, 128, bs=32, stencil=((0, ), ), scatter=True)
9399

94100

95-
def _test_scatter_1d():
101+
@test_utils.test(require=ti.extension.bls)
102+
def test_scatter_1d():
96103
_test_bls_stencil(1, 128, bs=32, stencil=(
97104
(1, ),
98105
(0, ),
99106
), scatter=True)
100107

101108

102-
def _test_scatter_2d():
103-
stencil = [(0, 0), (0, -1), (0, 1), (1, 0)]
104-
_test_bls_stencil(2, 128, bs=16, stencil=stencil, scatter=True)
105-
106-
107-
@test_utils.test(require=ti.extension.bls)
108-
def test_gather_1d_trivial():
109-
_test_gather_1d_trivial()
110-
111-
112-
@test_utils.test(require=ti.extension.bls)
113-
def test_gather_1d():
114-
_test_gather_1d()
115-
116-
117-
@test_utils.test(require=ti.extension.bls)
118-
def test_gather_2d():
119-
_test_gather_2d()
120-
121-
122-
@test_utils.test(require=ti.extension.bls)
123-
def test_gather_2d_nonsquare():
124-
_test_gather_2d_nonsquare()
125-
126-
127-
@test_utils.test(require=ti.extension.bls)
128-
def test_gather_3d():
129-
_test_gather_3d()
130-
131-
132-
@test_utils.test(require=ti.extension.bls)
133-
def test_scatter_1d_trivial():
134-
_test_scatter_1d_trivial()
135-
136-
137-
@test_utils.test(require=ti.extension.bls)
138-
def test_scatter_1d():
139-
_test_scatter_1d()
140-
141-
142109
@test_utils.test(require=ti.extension.bls)
143110
def test_scatter_2d():
144-
_test_scatter_2d()
145-
146-
147-
@test_utils.test(require=ti.extension.bls,
148-
real_matrix=True,
149-
real_matrix_scalarize=True)
150-
def test_gather_1d_trivial_matrix_scalarize():
151-
_test_gather_1d_trivial()
152-
153-
154-
@test_utils.test(require=ti.extension.bls,
155-
real_matrix=True,
156-
real_matrix_scalarize=True)
157-
def test_gather_1d_matrix_scalarize():
158-
_test_gather_1d()
159-
160-
161-
@test_utils.test(require=ti.extension.bls,
162-
real_matrix=True,
163-
real_matrix_scalarize=True)
164-
def test_gather_2d_matrix_scalarize():
165-
_test_gather_2d()
166-
167-
168-
@test_utils.test(require=ti.extension.bls,
169-
real_matrix=True,
170-
real_matrix_scalarize=True)
171-
def test_gather_2d_nonsquare_matrix_scalarize():
172-
_test_gather_2d_nonsquare()
173-
174-
175-
@test_utils.test(require=ti.extension.bls,
176-
real_matrix=True,
177-
real_matrix_scalarize=True)
178-
def test_gather_3d_matrix_scalarize():
179-
_test_gather_3d()
180-
181-
182-
@test_utils.test(require=ti.extension.bls,
183-
real_matrix=True,
184-
real_matrix_scalarize=True)
185-
def test_scatter_1d_trivial_matrix_scalarize():
186-
_test_scatter_1d_trivial()
187-
188-
189-
@test_utils.test(require=ti.extension.bls,
190-
real_matrix=True,
191-
real_matrix_scalarize=True)
192-
def test_scatter_1d_matrix_scalarize():
193-
_test_scatter_1d()
194-
195-
196-
@test_utils.test(require=ti.extension.bls,
197-
real_matrix=True,
198-
real_matrix_scalarize=True)
199-
def test_scatter_2d_matrix_scalarize():
200-
_test_scatter_2d()
111+
stencil = [(0, 0), (0, -1), (0, 1), (1, 0)]
112+
_test_bls_stencil(2, 128, bs=16, stencil=stencil, scatter=True)
201113

202114

203115
@test_utils.test(require=ti.extension.bls)

0 commit comments

Comments
 (0)