Skip to content

Commit 613a48d

Browse files
committed
Fixed minor issue
1 parent 6739192 commit 613a48d

File tree

7 files changed

+68
-24
lines changed

7 files changed

+68
-24
lines changed

c_api/src/taichi_core_impl.cpp

+6-2
Original file line number | Diff line number | Diff line change
@@ -558,8 +558,12 @@ void ti_launch_compute_graph(TiRuntime runtime,
558558
}
559559
}
560560

561-
ndarrays.emplace_back(
562-
taichi::lang::Ndarray(devalloc, *prim_ty, shape, elem_shape));
561+
taichi::lang::DataType dtype = *prim_ty;
562+
if (elem_shape.size() > 0) {
563+
dtype = taichi::lang::TypeFactory::get_instance().get_tensor_type(
564+
elem_shape, dtype);
565+
}
566+
ndarrays.emplace_back(taichi::lang::Ndarray(devalloc, dtype, shape));
563567
arg_map.emplace(std::make_pair(
564568
arg.name, taichi::lang::aot::IValue::create(ndarrays.back())));
565569
break;

python/taichi/lang/matrix.py

+7-7
Original file line number | Diff line number | Diff line change
@@ -1900,10 +1900,10 @@ def __init__(self, n, m, dtype, shape, layout):
19001900

19011901
self.layout = layout
19021902
self.shape = tuple(shape)
1903-
self.element_type = TensorType((self.n, self.m), self.dtype)
1903+
self.element_type = TensorType((self.n, self.m), dtype)
19041904
# TODO: we should pass in element_type, shape, layout instead.
19051905
self.arr = impl.get_runtime().prog.create_ndarray(
1906-
self.element_type, shape, layout)
1906+
cook_dtype(self.element_type.ptr), shape, layout)
19071907

19081908
@property
19091909
def element_shape(self):
@@ -1915,7 +1915,7 @@ def element_shape(self):
19151915
>>> arr.element_shape
19161916
(2, 2)
19171917
"""
1918-
return tuple(self.arr.element_shape)
1918+
return tuple(self.arr.element_shape())
19191919

19201920
@python_scope
19211921
def __setitem__(self, key, value):
@@ -1999,12 +1999,12 @@ def __init__(self, n, dtype, shape, layout):
19991999
super().__init__()
20002000
# TODO(zhanlue): remove self.dtype and migrate its usages to element_type
20012001
self.dtype = cook_dtype(dtype)
2002+
20022003
self.layout = layout
20032004
self.shape = tuple(shape)
2004-
self.element_type = TensorType((n, ), self.dtype)
2005-
# TODO: pass in element_type, shape, layout directly
2005+
self.element_type = TensorType((n, ), dtype)
20062006
self.arr = impl.get_runtime().prog.create_ndarray(
2007-
self.element_type, shape, layout)
2007+
cook_dtype(self.element_type.ptr), shape, layout)
20082008

20092009
@property
20102010
def element_shape(self):
@@ -2016,7 +2016,7 @@ def element_shape(self):
20162016
>>> a.element_shape
20172017
(3,)
20182018
"""
2019-
return tuple(self.arr.element_shape)
2019+
return tuple(self.arr.element_shape())
20202020

20212021
@python_scope
20222022
def __setitem__(self, key, value):

python/taichi/types/compound_types.py

+15-2
Original file line number | Diff line number | Diff line change
@@ -1,14 +1,27 @@
1+
from taichi._lib.utils import ti_python_core as _ti_python_core
2+
13
import taichi
24

5+
_type_factory = _ti_python_core.get_type_factory_instance()
6+
37

48
class CompoundType:
59
pass
610

711

812
class TensorType(CompoundType):
913
def __init__(self, shape, dtype):
10-
self.dtype = dtype
11-
self.shape = shape
14+
if isinstance(dtype, _ti_python_core.DataType):
15+
dtype = dtype.get_ptr()
16+
self.ptr = _type_factory.get_tensor_type(shape, dtype)
17+
self.shape = self.get_shape()
18+
self.dtype = self.get_element_type()
19+
20+
def get_shape(self):
21+
return self.ptr.get_element_shape()
22+
23+
def get_element_type(self):
24+
return self.ptr.get_element_type()
1225

1326

1427
# TODO: maybe move MatrixType, StructType here to avoid the circular import?

taichi/aot/graph_data.cpp

+17-4
Original file line number | Diff line number | Diff line change
@@ -25,6 +25,7 @@ void CompiledGraph::run(
2525
const aot::IValue &ival = found->second;
2626
if (ival.tag == aot::ArgKind::kNdarray) {
2727
Ndarray *arr = reinterpret_cast<Ndarray *>(ival.val);
28+
2829
TI_ERROR_IF(arr->get_element_shape() != symbolic_arg.element_shape,
2930
"Mismatched shape information for argument {}",
3031
symbolic_arg.name);
@@ -33,12 +34,24 @@ void CompiledGraph::run(
3334
"field_dim={} but got an ndarray with field_dim={}",
3435
symbolic_arg.name, symbolic_arg.field_dim,
3536
arr->shape.size());
36-
TI_ERROR_IF(arr->dtype != symbolic_arg.dtype(),
37+
38+
DataType symbolic_arg_primitive_dtype = symbolic_arg.dtype();
39+
if (symbolic_arg.dtype()->is<TensorType>()) {
40+
symbolic_arg_primitive_dtype =
41+
symbolic_arg.dtype()->cast<TensorType>()->get_element_type();
42+
}
43+
44+
DataType arr_primitive_dtype = arr->dtype;
45+
if (arr->dtype->is<TensorType>()) {
46+
arr_primitive_dtype =
47+
arr->dtype->cast<TensorType>()->get_element_type();
48+
}
49+
50+
TI_ERROR_IF(arr_primitive_dtype != symbolic_arg_primitive_dtype,
3751
"Dispatch node is compiled for argument {} with "
3852
"dtype={} but got an ndarray with dtype={}",
39-
symbolic_arg.name, symbolic_arg.dtype().to_string(),
40-
arr->dtype.to_string());
41-
53+
symbolic_arg.name, symbolic_arg_primitive_dtype.to_string(),
54+
arr_primitive_dtype.to_string());
4255
ctx.set_arg_ndarray(i, arr->get_device_allocation_ptr_as_int(),
4356
arr->shape);
4457
} else if (ival.tag == aot::ArgKind::kScalar) {

taichi/ir/type.h

+11-3
Original file line number | Diff line number | Diff line change
@@ -42,6 +42,14 @@ class TI_DLL_EXPORT Type {
4242

4343
bool is_primitive(PrimitiveTypeID type) const;
4444

45+
virtual std::vector<int> get_shape() const {
46+
return {};
47+
}
48+
49+
virtual Type *get_element_type() const {
50+
return nullptr;
51+
}
52+
4553
virtual Type *get_compute_type() {
4654
TI_NOT_IMPLEMENTED;
4755
}
@@ -160,7 +168,7 @@ class TensorType : public Type {
160168
: shape_(std::move(shape)), element_(element) {
161169
}
162170

163-
Type *get_element_type() const {
171+
Type *get_element_type() const override {
164172
return element_;
165173
}
166174

@@ -171,7 +179,7 @@ class TensorType : public Type {
171179
return num_elements;
172180
}
173181

174-
std::vector<int> get_shape() const {
182+
std::vector<int> get_shape() const override {
175183
return shape_;
176184
}
177185

@@ -339,7 +347,7 @@ class QuantArrayType : public Type {
339347
return physical_type_;
340348
}
341349

342-
Type *get_element_type() const {
350+
Type *get_element_type() const override {
343351
return element_type_;
344352
}
345353

taichi/program/ndarray.cpp

+5-4
Original file line number | Diff line number | Diff line change
@@ -91,10 +91,11 @@ Ndarray::Ndarray(DeviceAllocation &devalloc,
9191
const DataType type,
9292
const std::vector<int> &shape,
9393
const std::vector<int> &element_shape,
94-
ExternalArrayLayout layout) {
95-
TI_ASSERT(type->is<PrimitiveType>());
96-
auto tensor_type = TypeFactory::create_tensor_type(element_shape, type);
97-
Ndarray(devalloc, tensor_type, shape, layout);
94+
ExternalArrayLayout layout)
95+
: Ndarray(devalloc,
96+
TypeFactory::create_tensor_type(element_shape, type),
97+
shape,
98+
layout) {
9899
}
99100

100101
Ndarray::~Ndarray() {

taichi/python/export_lang.cpp

+7-2
Original file line number | Diff line number | Diff line change
@@ -1090,7 +1090,10 @@ void export_lang(py::module &m) {
10901090

10911091
// Type system
10921092

1093-
py::class_<Type>(m, "Type").def("to_string", &Type::to_string);
1093+
py::class_<Type>(m, "Type")
1094+
.def("to_string", &Type::to_string)
1095+
.def("get_element_shape", &Type::get_shape)
1096+
.def("get_element_type", &Type::get_element_type);
10941097

10951098
// Note that it is important to specify py::return_value_policy::reference for
10961099
// the factory methods, otherwise pybind11 will delete the Types owned by
@@ -1104,7 +1107,9 @@ void export_lang(py::module &m) {
11041107
py::return_value_policy::reference)
11051108
.def("get_quant_float_type", &TypeFactory::get_quant_float_type,
11061109
py::arg("digits_type"), py::arg("exponent_type"),
1107-
py::arg("compute_type"), py::return_value_policy::reference);
1110+
py::arg("compute_type"), py::return_value_policy::reference)
1111+
.def("get_tensor_type", &TypeFactory::get_tensor_type, py::arg("shape"),
1112+
py::arg("element_type"), py::return_value_policy::reference);
11081113

11091114
m.def("get_type_factory_instance", TypeFactory::get_instance,
11101115
py::return_value_policy::reference);

0 commit comments

Comments (0)