Skip to content

Commit

Permalink
Merge pull request #483 from neutrinoceros/numpy2_new_api
Browse files Browse the repository at this point in the history
(NEP 18) Implement and test array functions new in numpy 2.0
  • Loading branch information
jzuhone authored Jan 9, 2024
2 parents f384ff2 + 883ec5e commit c1582f7
Show file tree
Hide file tree
Showing 3 changed files with 205 additions and 17 deletions.
6 changes: 6 additions & 0 deletions unyt/_array_functions.py
Original file line number Diff line number Diff line change
Expand Up @@ -961,6 +961,12 @@ def asfarray(a, dtype=np.double):
ret_units = a.units
return np.asfarray._implementation(np.asarray(a), dtype=dtype) * ret_units

elif NUMPY_VERSION >= Version("2.0.0dev0"):
# functions that were added in numpy 2.0.0
@implements(np.linalg.outer)
def linalg_outer(x1, x2, /):
return product_helper(x1, x2, out=None, func=np.linalg.outer)


# functions with pending deprecations
if hasattr(np, "trapz"):
Expand Down
5 changes: 4 additions & 1 deletion unyt/testing.py
Original file line number Diff line number Diff line change
Expand Up @@ -75,7 +75,10 @@ def assert_array_equal_units(x, y, **kwargs):
from numpy.testing import assert_array_equal

assert_array_equal(x, y, **kwargs)
assert getattr(x, "units", NULL_UNIT) == getattr(y, "units", NULL_UNIT)
if not (xu := getattr(x, "units", NULL_UNIT)) == (
yu := getattr(y, "units", NULL_UNIT)
):
raise AssertionError(f"Arguments' units do not match (got {xu} and {yu})")


def _process_warning(op, message, warning_class, args=(), kwargs=None):
Expand Down
211 changes: 195 additions & 16 deletions unyt/tests/test_array_functions.py
Original file line number Diff line number Diff line change
Expand Up @@ -4,6 +4,7 @@

import numpy as np
import pytest
from numpy.testing import assert_allclose
from packaging.version import Version

from unyt import A, K, cm, degC, delta_degC, g, km, rad, s
Expand Down Expand Up @@ -153,9 +154,25 @@


if NUMPY_VERSION >= Version("2.0.0dev0"):
# the followin all work out of the box (tested)
NOOP_FUNCTIONS |= {
np.linalg.diagonal, # works out of the box (tested)
np.linalg.trace, # works out of the box (tested)
np.linalg.cross,
np.linalg.diagonal,
np.linalg.matmul,
np.linalg.matrix_norm,
np.linalg.matrix_transpose,
np.linalg.svdvals,
np.linalg.tensordot,
np.linalg.trace,
np.linalg.vecdot,
np.linalg.vector_norm,
np.astype,
np.matrix_transpose,
np.unique_all,
np.unique_counts,
np.unique_inverse,
np.unique_values,
np.vecdot,
}

# Functions for which behaviour is intentionally left to default
Expand Down Expand Up @@ -498,6 +515,110 @@ def test_linalg_trace():
assert b.units == a.units


@pytest.mark.skipif(
NUMPY_VERSION < Version("2.0.0dev0"), reason="linalg.outer is new in numpy 2.0"
)
def test_linalg_outer():
a = np.arange(10) * cm
assert_array_equal_units(np.linalg.outer(a, a), np.outer(a, a))


@pytest.mark.skipif(
NUMPY_VERSION < Version("2.0.0dev0"), reason="linalg.cross is new in numpy 2.0"
)
def test_linalg_cross():
a = np.arange(3) * cm
assert_array_equal_units(np.linalg.cross(a, a), np.cross(a, a))


@pytest.mark.skipif(
NUMPY_VERSION < Version("2.0.0dev0"), reason="linalg.matmul is new in numpy 2.0"
)
def test_linalg_matmul():
a = np.eye(3) * cm
assert_array_equal_units(np.linalg.matmul(a, a), np.matmul(a, a))


@pytest.mark.skipif(
NUMPY_VERSION < Version("2.0.0dev0"),
reason="linalg.matrix_norm is new in numpy 2.0",
)
def test_linalg_matrix_norm():
a = np.eye(3) * cm
assert_array_equal_units(np.linalg.matrix_norm(a), np.linalg.norm(a))


@pytest.mark.skipif(
NUMPY_VERSION < Version("2.0.0dev0"), reason="matrix_transpose is new in numpy 2.0"
)
@pytest.mark.parametrize("namespace", [None, "linalg"])
def test_matrix_transpose(namespace):
if namespace is None:
func = np.matrix_transpose
else:
func = getattr(np, namespace).matrix_transpose
a = np.arange(0, 9).reshape(3, 3)
assert_array_equal_units(func(a), np.transpose(a))


@pytest.mark.skipif(
NUMPY_VERSION < Version("2.0.0dev0"), reason="vecdot is new in numpy 2.0"
)
@pytest.mark.parametrize("namespace", [None, "linalg"])
def test_vecdot(namespace):
if namespace is None:
func = np.vecdot
else:
func = getattr(np, namespace).vecdot
a = np.arange(0, 9)
assert_array_equal_units(func(a, a), np.vdot(a, a))


@pytest.mark.skipif(
NUMPY_VERSION < Version("2.0.0dev0"),
reason="linalg.vector_norm is new in numpy 2.0",
)
def test_linalg_vector_norm():
a = np.arange(0, 9)
assert_array_equal_units(np.linalg.vector_norm(a), np.linalg.norm(a))


def test_linalg_svd():
rng = np.random.default_rng()
a = (rng.standard_normal(size=(9, 6)) + 1j * rng.standard_normal(size=(9, 6))) * cm
u, s, vh = np.linalg.svd(a)
assert type(u) is np.ndarray
assert type(vh) is np.ndarray
assert type(s) is unyt_array
assert s.units == cm

s = np.linalg.svd(a, compute_uv=False)
assert type(s) is unyt_array
assert s.units == cm


@pytest.mark.skipif(
NUMPY_VERSION < Version("2.0.0dev0"), reason="linalg.svdvals is new in numpy 2.0"
)
def test_linalg_svdvals():
q = np.arange(9).reshape(3, 3) * cm

_, ref, _ = np.linalg.svd(q)
res = np.linalg.svdvals(q)
assert type(res) is unyt_array
assert_allclose(res, ref, rtol=5e-16)


@pytest.mark.skipif(
NUMPY_VERSION < Version("2.0.0dev0"), reason="linalg.tensordot is new in numpy 2.0"
)
def test_linalg_tensordot():
q = np.arange(9).reshape(3, 3) * cm
ref = np.tensordot(q, q)
res = np.linalg.tensordot(q, q)
assert_array_equal_units(res, ref)


def test_histogram():
rng = np.random.default_rng()
arr = rng.normal(size=1000) * cm
Expand Down Expand Up @@ -1009,6 +1130,16 @@ def test_copyto_edge_cases():
assert type(y) is np.ndarray


@pytest.mark.skipif(
NUMPY_VERSION < Version("2.0.0dev0"), reason="astype is new in numpy 2.0"
)
def test_astype():
x = np.array([1, 2, 3], dtype="int64") * cm
res = np.astype(x, "int32")
assert type(res) is unyt_array
assert res.units == cm


def test_meshgrid():
x = [1, 2, 3] * cm
y = [1, 2, 3] * s
Expand Down Expand Up @@ -1270,20 +1401,6 @@ def test_eigvals(func):
assert w.units == cm


def test_linalg_svd():
rng = np.random.default_rng()
a = (rng.standard_normal(size=(9, 6)) + 1j * rng.standard_normal(size=(9, 6))) * cm
u, s, vh = np.linalg.svd(a)
assert type(u) is np.ndarray
assert type(vh) is np.ndarray
assert type(s) is unyt_array
assert s.units == cm

s = np.linalg.svd(a, compute_uv=False)
assert type(s) is unyt_array
assert s.units == cm


def test_savetxt(tmp_path):
a = [1, 2, 3] * cm
with pytest.raises(
Expand Down Expand Up @@ -1421,6 +1538,68 @@ def test_unique():
assert res.units == cm


@pytest.mark.skipif(
NUMPY_VERSION < Version("2.0.0dev0"), reason="unique_all is new in numpy 2.0"
)
def test_unique_all():
q = np.arange(9).reshape(3, 3) * cm
values, indices, inverse_indices, counts = np.unique(
q,
return_index=True,
return_inverse=True,
return_counts=True,
equal_nan=False,
)
res = np.unique_all(q)
assert len(res) == 4
assert_array_equal_units(res.values, values)
assert_array_equal_units(res.indices, indices)
assert_array_equal_units(res.inverse_indices, inverse_indices)
assert_array_equal_units(res.counts, counts)


@pytest.mark.skipif(
NUMPY_VERSION < Version("2.0.0dev0"), reason="unique_counts is new in numpy 2.0"
)
def test_unique_counts():
q = np.arange(9).reshape(3, 3) * cm
values, counts = np.unique(
q,
return_counts=True,
equal_nan=False,
)
res = np.unique_counts(q)
assert len(res) == 2
assert_array_equal_units(res.values, values)
assert_array_equal_units(res.counts, counts)


@pytest.mark.skipif(
NUMPY_VERSION < Version("2.0.0dev0"), reason="unique_inverse is new in numpy 2.0"
)
def test_unique_inverse():
q = np.arange(9).reshape(3, 3) * cm
values, inverse_indices = np.unique(
q,
return_inverse=True,
equal_nan=False,
)
res = np.unique_inverse(q)
assert len(res) == 2
assert_array_equal_units(res.values, values)
assert_array_equal_units(res.inverse_indices, inverse_indices)


@pytest.mark.skipif(
NUMPY_VERSION < Version("2.0.0dev0"), reason="unique_values is new in numpy 2.0"
)
def test_unique_values():
q = np.arange(9).reshape(3, 3) * cm
values = np.unique(q, equal_nan=False)
res = np.unique_values(q)
assert_array_equal_units(res, values)


def test_take():
a = [1, 2, 3] * cm
res = np.take(a, [0, 1])
Expand Down

0 comments on commit c1582f7

Please sign in to comment.