Skip to content

Commit

Permalink
TST: add tests for new wrappable numpy API
Browse files Browse the repository at this point in the history
  • Loading branch information
neutrinoceros committed Dec 29, 2023
1 parent 73cd72f commit 883ec5e
Showing 1 changed file with 50 additions and 0 deletions.
50 changes: 50 additions & 0 deletions unyt/tests/test_array_functions.py
Original file line number Diff line number Diff line change
Expand Up @@ -159,14 +159,20 @@
np.linalg.cross,
np.linalg.diagonal,
np.linalg.matmul,
np.linalg.matrix_norm,
np.linalg.matrix_transpose,
np.linalg.svdvals,
np.linalg.tensordot,
np.linalg.trace,
np.linalg.vecdot,
np.linalg.vector_norm,
np.astype,
np.matrix_transpose,
np.unique_all,
np.unique_counts,
np.unique_inverse,
np.unique_values,
np.vecdot,
}

# Functions for which behaviour is intentionally left to default
Expand Down Expand Up @@ -524,6 +530,50 @@ def test_linalg_matmul():
assert_array_equal_units(np.linalg.matmul(a, a), np.matmul(a, a))


@pytest.mark.skipif(
NUMPY_VERSION < Version("2.0.0dev0"),
reason="linalg.matrix_norm is new in numpy 2.0",
)
def test_linalg_matrix_norm():
a = np.eye(3) * cm
assert_array_equal_units(np.linalg.matrix_norm(a), np.linalg.norm(a))


@pytest.mark.skipif(
NUMPY_VERSION < Version("2.0.0dev0"), reason="matrix_transpose is new in numpy 2.0"
)
@pytest.mark.parametrize("namespace", [None, "linalg"])
def test_matrix_transpose(namespace):
if namespace is None:
func = np.matrix_transpose
else:
func = getattr(np, namespace).matrix_transpose
a = np.arange(0, 9).reshape(3, 3)
assert_array_equal_units(func(a), np.transpose(a))


@pytest.mark.skipif(
NUMPY_VERSION < Version("2.0.0dev0"), reason="vecdot is new in numpy 2.0"
)
@pytest.mark.parametrize("namespace", [None, "linalg"])
def test_vecdot(namespace):
if namespace is None:
func = np.vecdot
else:
func = getattr(np, namespace).vecdot
a = np.arange(0, 9)
assert_array_equal_units(func(a, a), np.vdot(a, a))


@pytest.mark.skipif(
NUMPY_VERSION < Version("2.0.0dev0"),
reason="linalg.vector_norm is new in numpy 2.0",
)
def test_linalg_vector_norm():
a = np.arange(0, 9)
assert_array_equal_units(np.linalg.vector_norm(a), np.linalg.norm(a))


def test_linalg_svd():
rng = np.random.default_rng()
a = (rng.standard_normal(size=(9, 6)) + 1j * rng.standard_normal(size=(9, 6))) * cm
Expand Down

0 comments on commit 883ec5e

Please sign in to comment.