From 1816faed9a6d5eb18d5610007d276562d113f0b1 Mon Sep 17 00:00:00 2001
From: Ailing Zhang <ailing@taichi.graphics>
Date: Mon, 11 Jul 2022 21:10:40 +0800
Subject: [PATCH] [refactor] Default dtype of ndarray type should be None
 instead of f32

---
 python/taichi/aot/utils.py          |  2 +-
 python/taichi/types/ndarray_type.py | 16 ++++++----------
 tests/python/test_graph.py          |  2 +-
 3 files changed, 8 insertions(+), 12 deletions(-)

diff --git a/python/taichi/aot/utils.py b/python/taichi/aot/utils.py
index 05bc0a7c1e074..57dc17b65b222 100644
--- a/python/taichi/aot/utils.py
+++ b/python/taichi/aot/utils.py
@@ -54,7 +54,7 @@ def produce_injected_args(kernel, symbolic_args=None):
                 raise TaichiCompilationError(
                     f'{field_dim} from Arg {arg.name} doesn\'t match kernel\'s annotated field_dim={anno.field_dim}'
                 )
-            if dtype != anno.dtype:
+            if anno.dtype is not None and dtype != anno.dtype:
                 raise TaichiCompilationError(
                     f' Arg {arg.name}\'s dtype {dtype.to_string()} doesn\'t match kernel\'s annotated dtype={anno.dtype.to_string()}'
                 )
diff --git a/python/taichi/types/ndarray_type.py b/python/taichi/types/ndarray_type.py
index 7c097b2e52f1a..95d51ff531a99 100644
--- a/python/taichi/types/ndarray_type.py
+++ b/python/taichi/types/ndarray_type.py
@@ -1,6 +1,3 @@
-from taichi.types.primitive_types import f32
-
-
 class NdarrayTypeMetadata:
     def __init__(self, element_type, shape=None, layout=None):
         self.element_type = element_type
@@ -20,13 +17,12 @@ class NdarrayType:
         field_dim (Union[Int, NoneType]): None if not specified, number of field dimensions. This argument is ignored for external arrays for now.
         layout (Union[Layout, NoneType], optional): None if not specified (will be treated as Layout.AOS for external arrays), Layout.AOS or Layout.SOA.
     """
-    def __init__(
-            self,
-            dtype=f32,  # TODO: default should be None
-            element_dim=None,
-            element_shape=None,
-            field_dim=None,
-            layout=None):
+    def __init__(self,
+                 dtype=None,
+                 element_dim=None,
+                 element_shape=None,
+                 field_dim=None,
+                 layout=None):
         if element_dim is not None and (element_dim < 0 or element_dim > 2):
             raise ValueError(
                 "Only scalars, vectors, and matrices are allowed as elements of ti.types.ndarray()"
diff --git a/tests/python/test_graph.py b/tests/python/test_graph.py
index 7e27516e5455a..d8fa102e1c32b 100644
--- a/tests/python/test_graph.py
+++ b/tests/python/test_graph.py
@@ -142,7 +142,7 @@ def test_arg_mismatched_ndarray_dtype():
     n = 4
 
     @ti.kernel
-    def test(pos: ti.types.ndarray(field_dim=1)):
+    def test(pos: ti.types.ndarray(dtype=ti.f32, field_dim=1)):
         for i in range(n):
             pos[i] = 2.5