diff --git a/python/taichi/_kernels.py b/python/taichi/_kernels.py index 8043cc5e0de9a..dc7d88b43dc9c 100644 --- a/python/taichi/_kernels.py +++ b/python/taichi/_kernels.py @@ -71,10 +71,11 @@ def vector_to_fast_image(img: template(), out: ndarray_type.ndarray()): r, g, b = 0, 0, 0 color = img[i, img.shape[1] - 1 - j] if static(img.dtype in [f16, f32, f64]): - r, g, b = min(255, max(0, int(color * 255))) + r, g, b = min(255, max(0, int(color * 255)))[:3] else: static_assert(img.dtype == u8) - r, g, b = color + r, g, b = color[:3] + idx = j * img.shape[0] + i # We use i32 for |out| since OpenGL and Metal doesn't support u8 types if static(get_os_name() != 'osx'): diff --git a/python/taichi/math/mathimpl.py b/python/taichi/math/mathimpl.py index 085e7d169ff2b..9561ffd74bd7b 100644 --- a/python/taichi/math/mathimpl.py +++ b/python/taichi/math/mathimpl.py @@ -10,44 +10,44 @@ import taichi as ti -vec2 = ti.types.vector(2, float) # pylint: disable=E1101 -"""2D float vector type. -""" +_get_uint_ip = lambda: ti.u32 if impl.get_runtime( +).default_ip == ti.i32 else ti.u64 -vec3 = ti.types.vector(3, float) # pylint: disable=E1101 -"""3D float vector type. -""" -vec4 = ti.types.vector(4, float) # pylint: disable=E1101 -"""4D float vector type. -""" +def vec2(*args): + """2D floating vector type. + """ + return ti.types.vector(2, float)(*args) # pylint: disable=E1101 -ivec2 = ti.types.vector(2, int) # pylint: disable=E1101 -"""2D int vector type. -""" -ivec3 = ti.types.vector(3, int) # pylint: disable=E1101 -"""3D int vector type. -""" +def vec3(*args): + """3D floating vector type. + """ + return ti.types.vector(3, float)(*args) # pylint: disable=E1101 -ivec4 = ti.types.vector(4, int) # pylint: disable=E1101 -"""4D int vector type. -""" -mat2 = ti.types.matrix(2, 2, float) # pylint: disable=E1101 -"""2x2 float matrix type. -""" +def vec4(*args): + """4D floating vector type. + """ + return ti.types.vector(4, float)(*args) # pylint: disable=E1101 -mat3 = ti.types.matrix(3, 3, float) # pylint: disable=E1101 -"""3x3 float matrix type. -""" -mat4 = ti.types.matrix(4, 4, float) # pylint: disable=E1101 -"""4x4 float matrix type. -""" +def ivec2(*args): + """2D signed int vector type. + """ + return ti.types.vector(2, int)(*args) # pylint: disable=E1101 -_get_uint_ip = lambda: ti.u32 if impl.get_runtime( -).default_ip == ti.i32 else ti.u64 + +def ivec3(*args): + """3D signed int vector type. + """ + return ti.types.vector(3, int)(*args) # pylint: disable=E1101 + + +def ivec4(*args): + """4D signed int vector type. + """ + return ti.types.vector(4, int)(*args) # pylint: disable=E1101 def uvec2(*args): @@ -68,6 +68,24 @@ def uvec4(*args): return ti.types.vector(4, _get_uint_ip())(*args) # pylint: disable=E1101 +def mat2(*args): + """2x2 floating matrix type. + """ + return ti.types.matrix(2, 2, float)(*args) # pylint: disable=E1101 + + +def mat3(*args): + """3x3 floating matrix type. + """ + return ti.types.matrix(3, 3, float)(*args) # pylint: disable=E1101 + + +def mat4(*args): + """4x4 floating matrix type. + """ + return ti.types.matrix(4, 4, float)(*args) # pylint: disable=E1101 + + @ti.func def mix(x, y, a): """Performs a linear interpolation between `x` and `y` using @@ -611,12 +629,46 @@ def length(x): return x.norm() +@ti.func +def determinant(m): + """Alias for :func:`taichi.Matrix.determinant`. + """ + return m.determinant() + + +@ti.func +def inverse(mat): # pylint: disable=R1710 + """Calculate the inverse of a matrix. + + This function is equivalent to the `inverse` function in GLSL. + + Args: + mat (:class:`taichi.Matrix`): The matrix of which to take the inverse. + + Returns: + Inverse of the input matrix. + + Example:: + + >>> @ti.kernel + >>> def test(): + >>> m = mat3([(1, 1, 0), (0, 1, 1), (0, 0, 1)]) + >>> print(inverse(m)) + >>> + >>> test() + [[1.000000, -1.000000, 1.000000], + [0.000000, 1.000000, -1.000000], + [0.000000, 0.000000, 1.000000]] + """ + return mat.inverse() + + __all__ = [ "acos", "asin", "atan2", "ceil", "clamp", "cos", "cross", "degrees", - "distance", "dot", "e", "exp", "eye", "floor", "fract", "ivec2", "ivec3", - "ivec4", "length", "log", "log2", "mat2", "mat3", "mat4", "max", "min", - "mix", "mod", "normalize", "pi", "pow", "radians", "reflect", "refract", - "rot2", "rot3", "rotate2d", "rotate3d", "round", "sign", "sin", - "smoothstep", "sqrt", "step", "tan", "tanh", "uvec2", "uvec3", "uvec4", - "vec2", "vec3", "vec4" + "determinant", "distance", "dot", "e", "exp", "eye", "floor", "fract", + "inverse", "ivec2", "ivec3", "ivec4", "length", "log", "log2", "mat2", + "mat3", "mat4", "max", "min", "mix", "mod", "normalize", "pi", "pow", + "radians", "reflect", "refract", "rot2", "rot3", "rotate2d", "rotate3d", + "round", "sign", "sin", "smoothstep", "sqrt", "step", "tan", "tanh", + "uvec2", "uvec3", "uvec4", "vec2", "vec3", "vec4" ] diff --git a/tests/python/test_api.py b/tests/python/test_api.py index af554afcbd266..a83b600d9fc89 100644 --- a/tests/python/test_api.py +++ b/tests/python/test_api.py @@ -98,13 +98,13 @@ def _get_expected_matrix_apis(): ] user_api[ti.math] = [ 'acos', 'asin', 'atan2', 'cconj', 'cdiv', 'ceil', 'cexp', 'cinv', 'clamp', - 'clog', 'cmul', 'cos', 'cpow', 'cross', 'csqrt', 'degrees', 'distance', - 'dot', 'e', 'exp', 'eye', 'floor', 'fract', 'ivec2', 'ivec3', 'ivec4', - 'length', 'log', 'log2', 'mat2', 'mat3', 'mat4', 'max', 'min', 'mix', - 'mod', 'normalize', 'pi', 'pow', 'radians', 'reflect', 'refract', 'rot2', - 'rot3', 'rotate2d', 'rotate3d', 'round', 'sign', 'sin', 'smoothstep', - 'sqrt', 'step', 'tan', 'tanh', 'uvec2', 'uvec3', 'uvec4', 'vec2', 'vec3', - 'vec4' + 'clog', 'cmul', 'cos', 'cpow', 'cross', 'csqrt', 'degrees', 'determinant', + 'distance', 'dot', 'e', 'exp', 'eye', 'floor', 'fract', 'inverse', 'ivec2', + 'ivec3', 'ivec4', 'length', 'log', 'log2', 'mat2', 'mat3', 'mat4', 'max', + 'min', 'mix', 'mod', 'normalize', 'pi', 'pow', 'radians', 'reflect', + 'refract', 'rot2', 'rot3', 'rotate2d', 'rotate3d', 'round', 'sign', 'sin', + 'smoothstep', 'sqrt', 'step', 'tan', 'tanh', 'uvec2', 'uvec3', 'uvec4', + 'vec2', 'vec3', 'vec4' ] user_api[ti.Matrix] = _get_expected_matrix_apis() user_api[ti.MatrixField] = [