From e2e232dd3f5ba61332b95f4d320307c52695bf28 Mon Sep 17 00:00:00 2001 From: Yi Xu Date: Thu, 15 Sep 2022 19:45:50 +0800 Subject: [PATCH 1/3] [Lang] Disallow invalid matrix field definition --- taichi/ir/frontend_ir.h | 9 ++ tests/python/test_matrix_different_type.py | 98 +++++----------------- 2 files changed, 29 insertions(+), 78 deletions(-) diff --git a/taichi/ir/frontend_ir.h b/taichi/ir/frontend_ir.h index cadf3663b78fb..760f91e20e347 100644 --- a/taichi/ir/frontend_ir.h +++ b/taichi/ir/frontend_ir.h @@ -531,6 +531,15 @@ class MatrixFieldExpression : public Expression { MatrixFieldExpression(const std::vector &fields, const std::vector &element_shape) : fields(fields), element_shape(element_shape) { + for (auto &field : fields) { + TI_ASSERT(field.is()); + } + auto compute_type = fields[0].cast()->dt->get_compute_type(); + for (auto &field : fields) { + if (field.cast()->dt->get_compute_type() != compute_type) { + throw TaichiRuntimeError("Member fields of a matrix field must have the same compute type"); + } + } } void type_check(CompileConfig *config) override { diff --git a/tests/python/test_matrix_different_type.py b/tests/python/test_matrix_different_type.py index 397e72c40f3c1..b51b344c0b683 100644 --- a/tests/python/test_matrix_different_type.py +++ b/tests/python/test_matrix_different_type.py @@ -1,80 +1,13 @@ -from pytest import approx - +import pytest import taichi as ti from tests import test_utils -# TODO: test more matrix operations -@test_utils.test() -def test_vector(): - type_list = [ti.f32, ti.i32] - - a = ti.Vector.field(len(type_list), dtype=type_list, shape=()) - b = ti.Vector.field(len(type_list), dtype=type_list, shape=()) - c = ti.Vector.field(len(type_list), dtype=type_list, shape=()) - - @ti.kernel - def init(): - a[None] = [1.0, 3] - b[None] = [2.0, 4] - c[None] = a[None] + b[None] - - def verify(): - assert isinstance(a[None][0], float) - assert isinstance(a[None][1], int) - assert isinstance(b[None][0], float) - assert isinstance(b[None][1], int) - assert c[None][0] == 3.0 - assert c[None][1] == 7 - - init() - verify() - - -# TODO: Support different element types of Matrix on opengl -@test_utils.test(require=ti.extension.data64, exclude=ti.opengl) -def test_matrix(): - type_list = [[ti.f32, ti.i32], [ti.i64, ti.f32]] - a = ti.Matrix.field(len(type_list), - len(type_list[0]), - dtype=type_list, - shape=()) - b = ti.Matrix.field(len(type_list), - len(type_list[0]), - dtype=type_list, - shape=()) - c = ti.Matrix.field(len(type_list), - len(type_list[0]), - dtype=type_list, - shape=()) - - @ti.kernel - def init(): - a[None] = [[1.0, 3], [1, 3.0]] - b[None] = [[2.0, 4], [-2, -3.0]] - c[None] = a[None] + b[None] - - def verify(): - assert isinstance(a[None][0, 0], float) - assert isinstance(a[None][0, 1], int) - assert isinstance(b[None][0, 0], float) - assert isinstance(b[None][0, 1], int) - assert c[None][0, 0] == 3.0 - assert c[None][0, 1] == 7 - assert c[None][1, 0] == -1 - assert c[None][1, 1] == 0.0 - - init() - verify() - - @test_utils.test(require=ti.extension.quant_basic) -def test_quant_type(): - qit1 = ti.types.quant.int(bits=10, signed=True) - qfxt1 = ti.types.quant.fixed(bits=10, signed=True, scale=0.1) - qit2 = ti.types.quant.int(bits=22, signed=False) - qfxt2 = ti.types.quant.fixed(bits=22, signed=False, scale=0.1) - type_list = [[qit1, qfxt2], [qfxt1, qit2]] +def test_valid(): + qflt = ti.types.quant.float(exp=8, frac=5, signed=True) + qfxt = ti.types.quant.fixed(bits=10, signed=True, scale=0.1) + type_list = [[qflt, qfxt], [qflt, qfxt]] a = ti.Matrix.field(len(type_list), len(type_list[0]), dtype=type_list) b = ti.Matrix.field(len(type_list), len(type_list[0]), dtype=type_list) c = ti.Matrix.field(len(type_list), len(type_list[0]), dtype=type_list) @@ -99,15 +32,24 @@ def test_quant_type(): @ti.kernel def init(): - a[0] = [[1, 3.], [2., 1]] - b[0] = [[2, 4.], [-2., 1]] + a[0] = [[1.0, 3.0], [2.0, 1.0]] + b[0] = [[2.0, 4.0], [-2.0, 1.0]] c[0] = a[0] + b[0] def verify(): - assert c[0][0, 0] == approx(3, 1e-3) - assert c[0][0, 1] == approx(7.0, 1e-3) - assert c[0][1, 0] == approx(0, 1e-3) - assert c[0][1, 1] == approx(2, 1e-3) + assert c[0][0, 0] == pytest.approx(3.0) + assert c[0][0, 1] == pytest.approx(7.0) + assert c[0][1, 0] == pytest.approx(0.0) + assert c[0][1, 1] == pytest.approx(2.0) init() verify() + + +@test_utils.test(require=ti.extension.quant_basic) +def test_invalid(): + qit = ti.types.quant.int(bits=10, signed=True) + qfxt = ti.types.quant.fixed(bits=10, signed=True, scale=0.1) + type_list = [qit, qfxt] + with pytest.raises(RuntimeError, match='Member fields of a matrix field must have the same compute type'): + a = ti.Vector.field(len(type_list), dtype=type_list) From 8564793c51ffadacb24452e826950e91b83856d8 Mon Sep 17 00:00:00 2001 From: "pre-commit-ci[bot]" <66853113+pre-commit-ci[bot]@users.noreply.github.com> Date: Thu, 15 Sep 2022 11:55:32 +0000 Subject: [PATCH 2/3] [pre-commit.ci] auto fixes from pre-commit.com hooks for more information, see https://pre-commit.ci --- taichi/ir/frontend_ir.h | 9 ++++++--- tests/python/test_matrix_different_type.py | 6 +++++- 2 files changed, 11 insertions(+), 4 deletions(-) diff --git a/taichi/ir/frontend_ir.h b/taichi/ir/frontend_ir.h index 760f91e20e347..07d3cad88be37 100644 --- a/taichi/ir/frontend_ir.h +++ b/taichi/ir/frontend_ir.h @@ -534,10 +534,13 @@ class MatrixFieldExpression : public Expression { for (auto &field : fields) { TI_ASSERT(field.is()); } - auto compute_type = fields[0].cast()->dt->get_compute_type(); + auto compute_type = + fields[0].cast()->dt->get_compute_type(); for (auto &field : fields) { - if (field.cast()->dt->get_compute_type() != compute_type) { - throw TaichiRuntimeError("Member fields of a matrix field must have the same compute type"); + if (field.cast()->dt->get_compute_type() != + compute_type) { + throw TaichiRuntimeError( + "Member fields of a matrix field must have the same compute type"); } } } diff --git a/tests/python/test_matrix_different_type.py b/tests/python/test_matrix_different_type.py index b51b344c0b683..9a33ca4e7d369 100644 --- a/tests/python/test_matrix_different_type.py +++ b/tests/python/test_matrix_different_type.py @@ -1,4 +1,5 @@ import pytest + import taichi as ti from tests import test_utils @@ -51,5 +52,8 @@ def test_invalid(): qit = ti.types.quant.int(bits=10, signed=True) qfxt = ti.types.quant.fixed(bits=10, signed=True, scale=0.1) type_list = [qit, qfxt] - with pytest.raises(RuntimeError, match='Member fields of a matrix field must have the same compute type'): + with pytest.raises( + RuntimeError, + match= + 'Member fields of a matrix field must have the same compute type'): a = ti.Vector.field(len(type_list), dtype=type_list) From fc747ed32efbe4e36163922c984d371522a6aac6 Mon Sep 17 00:00:00 2001 From: Yi Xu Date: Fri, 16 Sep 2022 13:47:43 +0800 Subject: [PATCH 3/3] Add assertion --- taichi/ir/frontend_ir.h | 1 + 1 file changed, 1 insertion(+) diff --git a/taichi/ir/frontend_ir.h b/taichi/ir/frontend_ir.h index 07d3cad88be37..cd022734835fb 100644 --- a/taichi/ir/frontend_ir.h +++ b/taichi/ir/frontend_ir.h @@ -534,6 +534,7 @@ class MatrixFieldExpression : public Expression { for (auto &field : fields) { TI_ASSERT(field.is()); } + TI_ASSERT(!fields.empty()); auto compute_type = fields[0].cast()->dt->get_compute_type(); for (auto &field : fields) {