Skip to content

Commit 6dd26ac

Browse files
authored
[vulkan] Relax number of array args for each kernel (#5689)
1 parent bf20327 commit 6dd26ac

File tree

5 files changed

+60
-45
lines changed

5 files changed

+60
-45
lines changed

docs/lang/articles/kernels/syntax.md

+2-2
Original file line numberDiff line numberDiff line change
@@ -72,12 +72,12 @@ my_kernel(24, 3.2) # The system prints 27.2
7272

7373
The upper limit for element numbers is backend-specific:
7474

75-
- 8 on OpenGL backend
76-
- 64 on CPU, Vulkan, CUDA, or Metal
75+
- 64 on CPU, Vulkan, CUDA, OpenGL or Metal
7776

7877
:::note
7978
- The number of elements in a scalar argument is always 1.
8079
- The number of the elements in a `ti.Matrix` or in a `ti.Vector` is the actual number of scalars inside of them.
80+
- The upper limit of array arguments (`ti.types.ndarray()`) is 32 and they must be the first 32 among all 64 arguments.
8181
:::
8282

8383
```python

python/taichi/lang/kernel_impl.py

+4-6
Original file line numberDiff line numberDiff line change
@@ -759,16 +759,14 @@ def func__(*args):
759759
) and self.runtime.target_tape and not self.runtime.grad_replaced:
760760
self.runtime.target_tape.insert(self, args)
761761

762-
if actual_argument_slot > 8 and (
763-
impl.current_cfg().arch == _ti_core.opengl
764-
or impl.current_cfg().arch == _ti_core.cc):
762+
if actual_argument_slot > 8 and impl.current_cfg(
763+
).arch == _ti_core.cc:
765764
raise TaichiRuntimeError(
766765
f"The number of elements in kernel arguments is too big! Do not exceed 8 on {_ti_core.arch_name(impl.current_cfg().arch)} backend."
767766
)
768767

769-
if actual_argument_slot > 64 and (
770-
(impl.current_cfg().arch != _ti_core.opengl
771-
and impl.current_cfg().arch != _ti_core.cc)):
768+
if actual_argument_slot > 64 and impl.current_cfg(
769+
).arch != _ti_core.cc:
772770
raise TaichiRuntimeError(
773771
f"The number of elements in kernel arguments is too big! Do not exceed 64 on {_ti_core.arch_name(impl.current_cfg().arch)} backend."
774772
)

taichi/inc/constants.h

+2-2
Original file line numberDiff line numberDiff line change
@@ -5,10 +5,10 @@
55
constexpr int taichi_max_num_indices = 8;
66
// legacy: only used in cc and opengl backends
77
constexpr int taichi_max_num_args = 8;
8-
// used in llvm backend: only the first 16 arguments can be types.ndarray
8+
// used in llvm backend: only the first 32 arguments can be types.ndarray
99
// TODO: refine argument passing
1010
constexpr int taichi_max_num_args_total = 64;
11-
constexpr int taichi_max_num_args_extra = 16;
11+
constexpr int taichi_max_num_args_extra = 32;
1212
constexpr int taichi_max_num_snodes = 1024;
1313
constexpr int kMaxNumSnodeTreesLlvm = 512;
1414
constexpr int taichi_max_gpu_block_dim = 1024;

tests/python/test_aot.py

-34
Original file line numberDiff line numberDiff line change
@@ -264,40 +264,6 @@ def init(d: ti.i32, density1: ti.types.ndarray(),
264264
assert (density6.to_numpy() == (np.zeros(shape=(n, n)) + 6)).all()
265265

266266

267-
@test_utils.test(arch=ti.opengl)
268-
def test_opengl_exceed_max_ssbo():
269-
# 8 ndarrays + args > 8 (maximum allowed)
270-
n = 4
271-
density1 = ti.ndarray(dtype=ti.f32, shape=(n, n))
272-
density2 = ti.ndarray(dtype=ti.f32, shape=(n, n))
273-
density3 = ti.ndarray(dtype=ti.f32, shape=(n, n))
274-
density4 = ti.ndarray(dtype=ti.f32, shape=(n, n))
275-
density5 = ti.ndarray(dtype=ti.f32, shape=(n, n))
276-
density6 = ti.ndarray(dtype=ti.f32, shape=(n, n))
277-
density7 = ti.ndarray(dtype=ti.f32, shape=(n, n))
278-
density8 = ti.ndarray(dtype=ti.f32, shape=(n, n))
279-
280-
@ti.kernel
281-
def init(d: ti.i32, density1: ti.types.ndarray(),
282-
density2: ti.types.ndarray(), density3: ti.types.ndarray(),
283-
density4: ti.types.ndarray(), density5: ti.types.ndarray(),
284-
density6: ti.types.ndarray(), density7: ti.types.ndarray(),
285-
density8: ti.types.ndarray()):
286-
for i, j in density1:
287-
density1[i, j] = d + 1
288-
density2[i, j] = d + 2
289-
density3[i, j] = d + 3
290-
density4[i, j] = d + 4
291-
density5[i, j] = d + 5
292-
density6[i, j] = d + 6
293-
density7[i, j] = d + 7
294-
density8[i, j] = d + 8
295-
296-
with pytest.raises(RuntimeError):
297-
init(0, density1, density2, density3, density4, density5, density6,
298-
density7, density8)
299-
300-
301267
@test_utils.test(arch=[ti.opengl, ti.vulkan])
302268
def test_mpm99_aot():
303269
quality = 1 # Use a larger value for higher-res simulations

tests/python/test_argument.py

+52-1
Original file line numberDiff line numberDiff line change
@@ -4,7 +4,7 @@
44
from tests import test_utils
55

66

7-
@test_utils.test(arch=[ti.opengl, ti.cc])
7+
@test_utils.test(arch=[ti.cc])
88
def test_exceed_max_eight():
99
@ti.kernel
1010
def foo1(a: ti.i32, b: ti.i32, c: ti.i32, d: ti.i32, e: ti.i32, f: ti.i32,
@@ -181,3 +181,54 @@ def duplicate():
181181
with pytest.raises(ti.TaichiSyntaxError,
182182
match="Multiple values for argument 'a'"):
183183
duplicate()
184+
185+
186+
@test_utils.test(exclude=[ti.cc])
187+
def test_args_with_many_ndarrays():
188+
189+
particle_num = 0
190+
cluster_num = 0
191+
permu_num = 0
192+
193+
particlePosition = ti.Vector.ndarray(3, ti.f32, shape=10)
194+
outClusterPosition = ti.Vector.ndarray(3, ti.f32, shape=10)
195+
outClusterOffsets = ti.ndarray(ti.i32, shape=10)
196+
outClusterSizes = ti.ndarray(ti.i32, shape=10)
197+
outClusterIndices = ti.ndarray(ti.i32, shape=10)
198+
199+
particle_pos = ti.Vector.ndarray(3, ti.f32, shape=20)
200+
particle_prev_pos = ti.Vector.ndarray(3, ti.f32, shape=20)
201+
particle_rest_pos = ti.Vector.ndarray(3, ti.f32, shape=20)
202+
particle_index = ti.ndarray(ti.i32, shape=20)
203+
204+
cluster_rest_mass_center = ti.Vector.ndarray(3, ti.f32, shape=20)
205+
cluster_begin = ti.ndarray(ti.i32, shape=20)
206+
207+
@ti.kernel
208+
def ti_import_cluster_data(
209+
center: ti.types.vector(3,
210+
ti.f32), particle_num: int, cluster_num: int,
211+
permu_num: int, particlePosition: ti.types.ndarray(field_dim=1),
212+
outClusterPosition: ti.types.ndarray(field_dim=1),
213+
outClusterOffsets: ti.types.ndarray(field_dim=1),
214+
outClusterSizes: ti.types.ndarray(field_dim=1),
215+
outClusterIndices: ti.types.ndarray(field_dim=1),
216+
particle_pos: ti.types.ndarray(field_dim=1),
217+
particle_prev_pos: ti.types.ndarray(field_dim=1),
218+
particle_rest_pos: ti.types.ndarray(field_dim=1),
219+
cluster_rest_mass_center: ti.types.ndarray(field_dim=1),
220+
cluster_begin: ti.types.ndarray(field_dim=1),
221+
particle_index: ti.types.ndarray(field_dim=1)):
222+
223+
added_permu_num = outClusterIndices.shape[0]
224+
225+
for i in range(added_permu_num):
226+
particle_index[i] = 1.0
227+
228+
center = ti.math.vec3(0, 0, 0)
229+
ti_import_cluster_data(center, particle_num, cluster_num, permu_num,
230+
particlePosition, outClusterPosition,
231+
outClusterOffsets, outClusterSizes,
232+
outClusterIndices, particle_pos, particle_prev_pos,
233+
particle_rest_pos, cluster_rest_mass_center,
234+
cluster_begin, particle_index)

0 commit comments

Comments
 (0)