Skip to content

Commit c9091bf

Browse files
ailzhangAiling Zhang
authored and
Ailing Zhang
committed
[bug] Support indexing via np.integer for field (taichi-dev#5712)
1 parent af873d1 commit c9091bf

File tree

3 files changed

+25
-2
lines changed

3 files changed

+25
-2
lines changed

python/taichi/lang/field.py

+2-1
Original file line numberDiff line numberDiff line change
@@ -352,8 +352,9 @@ def __getitem__(self, key):
352352
# Check for potential slicing behaviour
353353
# for instance: x[0, :]
354354
padded_key = self._pad_key(key)
355+
import numpy as np # pylint: disable=C0415
355356
for key in padded_key:
356-
if not isinstance(key, int):
357+
if not isinstance(key, (int, np.integer)):
357358
raise TypeError(
358359
f"Detected illegal element of type: {type(key)}. "
359360
f"Please be aware that slicing a ti.field is not supported so far."

python/taichi/lang/matrix.py

+1-1
Original file line numberDiff line numberDiff line change
@@ -125,7 +125,7 @@ def _linearize_entry_id(self, *args):
125125
args = args + (0, )
126126
# TODO(#1004): See if it's possible to support indexing at runtime
127127
for i, a in enumerate(args):
128-
if not isinstance(a, int):
128+
if not isinstance(a, (int, np.integer)):
129129
raise TaichiSyntaxError(
130130
f'The {i}-th index of a Matrix/Vector must be a compile-time constant '
131131
f'integer, got {type(a)}.\n'

tests/python/test_field.py

+22
Original file line numberDiff line numberDiff line change
@@ -2,6 +2,7 @@
22
To test our new `ti.field` API is functional (#1500)
33
'''
44

5+
import numpy as np
56
import pytest
67
from taichi.lang import impl
78
from taichi.lang.misc import get_host_arch_list
@@ -282,6 +283,27 @@ def test_invalid_slicing():
282283
val[0, :]
283284

284285

286+
@test_utils.test()
287+
def test_indexing_with_np_int():
288+
val = ti.field(ti.i32, shape=(2))
289+
idx = np.int32(0)
290+
val[idx]
291+
292+
293+
@test_utils.test()
294+
def test_indexing_vec_field_with_np_int():
295+
val = ti.Vector.field(2, ti.i32, shape=(2))
296+
idx = np.int32(0)
297+
val[idx][idx]
298+
299+
300+
@test_utils.test()
301+
def test_indexing_mat_field_with_np_int():
302+
val = ti.Matrix.field(2, 2, ti.i32, shape=(2))
303+
idx = np.int32(0)
304+
val[idx][idx, idx]
305+
306+
285307
@test_utils.test(exclude=[ti.cc], debug=True)
286308
def test_field_fill():
287309
x = ti.field(int, shape=(3, 3))

0 commit comments

Comments
 (0)