From 5c028b247cdb78a10fac06ecedb4ef95b8013314 Mon Sep 17 00:00:00 2001 From: Hameer Abbasi <2190658+hameerabbasi@users.noreply.github.com> Date: Thu, 30 May 2024 08:29:42 +0200 Subject: [PATCH 01/13] Fix signatures for a few functions. --- ci/Numba-array-api-xfails.txt | 13 ---- sparse/numba_backend/__init__.py | 2 + sparse/numba_backend/_common.py | 65 +++++++++++++++----- sparse/numba_backend/_coo/common.py | 10 ++- sparse/numba_backend/tests/test_namespace.py | 1 + 5 files changed, 60 insertions(+), 31 deletions(-) diff --git a/ci/Numba-array-api-xfails.txt b/ci/Numba-array-api-xfails.txt index c3de1977..7d715bbf 100644 --- a/ci/Numba-array-api-xfails.txt +++ b/ci/Numba-array-api-xfails.txt @@ -73,24 +73,11 @@ array_api_tests/test_set_functions.py::test_unique_inverse array_api_tests/test_signatures.py::test_func_signature[unique_all] array_api_tests/test_signatures.py::test_func_signature[unique_inverse] array_api_tests/test_signatures.py::test_func_signature[arange] -array_api_tests/test_signatures.py::test_func_signature[empty] -array_api_tests/test_signatures.py::test_func_signature[empty_like] -array_api_tests/test_signatures.py::test_func_signature[eye] array_api_tests/test_signatures.py::test_func_signature[from_dlpack] -array_api_tests/test_signatures.py::test_func_signature[full] -array_api_tests/test_signatures.py::test_func_signature[full_like] array_api_tests/test_signatures.py::test_func_signature[linspace] array_api_tests/test_signatures.py::test_func_signature[meshgrid] -array_api_tests/test_signatures.py::test_func_signature[ones] -array_api_tests/test_signatures.py::test_func_signature[ones_like] -array_api_tests/test_signatures.py::test_func_signature[zeros] -array_api_tests/test_signatures.py::test_func_signature[zeros_like] -array_api_tests/test_signatures.py::test_func_signature[broadcast_to] -array_api_tests/test_signatures.py::test_func_signature[squeeze] array_api_tests/test_signatures.py::test_func_signature[argsort] -array_api_tests/test_signatures.py::test_func_signature[sort] array_api_tests/test_signatures.py::test_func_signature[isdtype] -array_api_tests/test_signatures.py::test_func_signature[conj] array_api_tests/test_signatures.py::test_array_method_signature[__dlpack__] array_api_tests/test_signatures.py::test_array_method_signature[__dlpack_device__] array_api_tests/test_signatures.py::test_array_method_signature[__setitem__] diff --git a/sparse/numba_backend/__init__.py b/sparse/numba_backend/__init__.py index b51cbfe3..419003dd 100644 --- a/sparse/numba_backend/__init__.py +++ b/sparse/numba_backend/__init__.py @@ -88,6 +88,7 @@ broadcast_to, concat, concatenate, + conj, dot, einsum, empty, @@ -201,6 +202,7 @@ "complex64", "concat", "concatenate", + "conj", "cos", "cosh", "diagonal", diff --git a/sparse/numba_backend/_common.py b/sparse/numba_backend/_common.py index 83644677..c01ac9a0 100644 --- a/sparse/numba_backend/_common.py +++ b/sparse/numba_backend/_common.py @@ -29,6 +29,17 @@ def _is_scipy_sparse_obj(x): return False +def _check_device(func): + @wraps(func) + def wrapped(*args, **kwargs): + device = kwargs.get("device", None) + if device not in {"cpu", None}: + raise ValueError("Device must be `'cpu'` or `None`.") + return func(*args, **kwargs) + + return wrapped + + def _is_sparse(x): """ Tests if the supplied argument is a SciPy sparse object, or one from this library. @@ -1533,7 +1544,8 @@ def concatenate(arrays, axis=0, compressed_axes=None): concat = concatenate -def eye(N, M=None, k=0, dtype=float, format="coo", **kwargs): +@_check_device +def eye(N, M=None, k=0, dtype=float, format="coo", *, device=None, **kwargs): """Return a 2-D array in the specified format with ones on the diagonal and zeros elsewhere. Parameters @@ -1595,7 +1607,8 @@ def eye(N, M=None, k=0, dtype=float, format="coo", **kwargs): return COO(coords, data=data, shape=(N, M), has_duplicates=False, sorted=True).asformat(format, **kwargs) -def full(shape, fill_value, dtype=None, format="coo", order="C", **kwargs): +@_check_device +def full(shape, fill_value, dtype=None, format="coo", order="C", *, device=None, **kwargs): """Return a SparseArray of given shape and type, filled with `fill_value`. Parameters @@ -1649,7 +1662,8 @@ def full(shape, fill_value, dtype=None, format="coo", order="C", **kwargs): ).asformat(format, **kwargs) -def full_like(a, fill_value, dtype=None, shape=None, format=None, **kwargs): +@_check_device +def full_like(a, fill_value, dtype=None, shape=None, format=None, *, device=None, **kwargs): """Return a full array with the same shape and type as a given array. Parameters @@ -1692,7 +1706,7 @@ def full_like(a, fill_value, dtype=None, shape=None, format=None, **kwargs): ) -def zeros(shape, dtype=float, format="coo", **kwargs): +def zeros(shape, dtype=float, format="coo", *, device=None, **kwargs): """Return a SparseArray of given shape and type, filled with zeros. Parameters @@ -1721,10 +1735,10 @@ def zeros(shape, dtype=float, format="coo", **kwargs): array([[0, 0], [0, 0]]) """ - return full(shape, fill_value=0, dtype=np.dtype(dtype), format=format, **kwargs) + return full(shape, fill_value=0, dtype=np.dtype(dtype), format=format, device=device, **kwargs) -def zeros_like(a, dtype=None, shape=None, format=None, **kwargs): +def zeros_like(a, dtype=None, shape=None, format=None, *, device=None, **kwargs): """Return a SparseArray of zeros with the same shape and type as ``a``. Parameters @@ -1750,10 +1764,10 @@ def zeros_like(a, dtype=None, shape=None, format=None, **kwargs): array([[0, 0, 0], [0, 0, 0]]) """ - return full_like(a, fill_value=0, dtype=dtype, shape=shape, format=format, **kwargs) + return full_like(a, fill_value=0, dtype=dtype, shape=shape, format=format, device=device, **kwargs) -def ones(shape, dtype=float, format="coo", **kwargs): +def ones(shape, dtype=float, format="coo", *, device=None, **kwargs): """Return a SparseArray of given shape and type, filled with ones. Parameters @@ -1782,10 +1796,10 @@ def ones(shape, dtype=float, format="coo", **kwargs): array([[1, 1], [1, 1]]) """ - return full(shape, fill_value=1, dtype=np.dtype(dtype), format=format, **kwargs) + return full(shape, fill_value=1, dtype=np.dtype(dtype), format=format, device=device, **kwargs) -def ones_like(a, dtype=None, shape=None, format=None, **kwargs): +def ones_like(a, dtype=None, shape=None, format=None, *, device=None, **kwargs): """Return a SparseArray of ones with the same shape and type as ``a``. Parameters @@ -1811,18 +1825,18 @@ def ones_like(a, dtype=None, shape=None, format=None, **kwargs): array([[1, 1, 1], [1, 1, 1]]) """ - return full_like(a, fill_value=1, dtype=dtype, shape=shape, format=format, **kwargs) + return full_like(a, fill_value=1, dtype=dtype, shape=shape, format=format, device=device, **kwargs) -def empty(shape, dtype=float, format="coo", **kwargs): - return full(shape, fill_value=0, dtype=np.dtype(dtype), format=format, **kwargs) +def empty(shape, dtype=float, format="coo", *, device=None, **kwargs): + return full(shape, fill_value=0, dtype=np.dtype(dtype), format=format, device=device, **kwargs) empty.__doc__ = zeros.__doc__ -def empty_like(a, dtype=None, shape=None, format=None, **kwargs): - return full_like(a, fill_value=0, dtype=dtype, shape=shape, format=format, **kwargs) +def empty_like(a, dtype=None, shape=None, format=None, *, device=None, **kwargs): + return full_like(a, fill_value=0, dtype=dtype, shape=shape, format=format, device=device, **kwargs) empty_like.__doc__ = zeros_like.__doc__ @@ -2004,7 +2018,8 @@ def format_to_string(format): raise ValueError(f"invalid format: {format}") -def asarray(obj, /, *, dtype=None, format="coo", device=None, copy=False): +@_check_device +def asarray(obj, /, *, dtype=None, format="coo", copy=False, device=None): """ Convert the input to a sparse array. @@ -2066,6 +2081,7 @@ def _support_numpy(func): we want to flag it and dispatch to NumPy. """ + @wraps(func) def wrapper_func(*args, **kwargs): x = args[0] if isinstance(x, np.ndarray | np.number): @@ -2130,12 +2146,29 @@ def reshape(x, /, shape, *, copy=None): return x.reshape(shape=shape) +def conj(x, /): + return x.conj() + + def astype(x, dtype, /, *, copy=True): return x.astype(dtype, copy=copy) @_support_numpy def squeeze(x, /, axis=None): + """Remove singleton dimensions from array. + + Parameters + ---------- + x : SparseArray + axis : int or tuple[int, ...], optional + The singleton axes to remove. By default all singleton axes are removed. + + Returns + ------- + SparseArray + Array with singleton dimensions removed. + """ return x.squeeze(axis=axis) diff --git a/sparse/numba_backend/_coo/common.py b/sparse/numba_backend/_coo/common.py index 5fe8c2c0..cc6eb79d 100644 --- a/sparse/numba_backend/_coo/common.py +++ b/sparse/numba_backend/_coo/common.py @@ -1243,7 +1243,7 @@ def unique_values(x, /): return values -def sort(x, /, *, axis=-1, descending=False): +def sort(x, /, *, axis=-1, descending=False, stable=False): """ Returns a sorted copy of an input array ``x``. @@ -1258,6 +1258,9 @@ def sort(x, /, *, axis=-1, descending=False): Sort order. If ``True``, the array must be sorted in descending order (by value). If ``False``, the array must be sorted in ascending order (by value). Default: ``False``. + stable : bool + Whether the sort is stable. Provided for compatibility with the Array API, only + ``False`` (the default) is currently supported. Returns ------- @@ -1279,12 +1282,14 @@ def sort(x, /, *, axis=-1, descending=False): array([ 2, 2, 1, 0, 0, -3]) """ - from .._common import moveaxis from .core import COO x = _validate_coo_input(x) + if stable: + raise ValueError("`stable=True` isn't currently supported.") + original_ndim = x.ndim if x.ndim == 1: x = x[None, :] @@ -1362,6 +1367,7 @@ def _sort_coo( fill_value: float, sort_axis_len: int, descending: bool, + stable: bool, ) -> tuple[np.ndarray, np.ndarray]: assert coords.shape[0] == 2 group_coords = coords[0, :] diff --git a/sparse/numba_backend/tests/test_namespace.py b/sparse/numba_backend/tests/test_namespace.py index 902b1ffd..39556f99 100644 --- a/sparse/numba_backend/tests/test_namespace.py +++ b/sparse/numba_backend/tests/test_namespace.py @@ -43,6 +43,7 @@ def test_namespace(): "complex64", "concat", "concatenate", + "conj", "cos", "cosh", "diagonal", From 034b454ce2537a105e37abab2cbbc67c5e524572 Mon Sep 17 00:00:00 2001 From: Hameer Abbasi <2190658+hameerabbasi@users.noreply.github.com> Date: Thu, 30 May 2024 11:38:04 +0200 Subject: [PATCH 02/13] Fix scalar casting issues. --- ci/Numba-array-api-xfails.txt | 19 -------------- .../numba_backend/_compressed/compressed.py | 10 ++++++++ sparse/numba_backend/_coo/common.py | 7 +----- sparse/numba_backend/_coo/core.py | 10 ++++++++ sparse/numba_backend/_sparse_array.py | 15 +++++++++-- sparse/numba_backend/_umath.py | 25 ++++++++++++------- sparse/numba_backend/_utils.py | 22 +++++++++++----- 7 files changed, 66 insertions(+), 42 deletions(-) diff --git a/ci/Numba-array-api-xfails.txt b/ci/Numba-array-api-xfails.txt index 7d715bbf..c8747e9d 100644 --- a/ci/Numba-array-api-xfails.txt +++ b/ci/Numba-array-api-xfails.txt @@ -37,33 +37,15 @@ array_api_tests/test_has_names.py::test_has_names[creation-linspace] array_api_tests/test_has_names.py::test_has_names[creation-meshgrid] array_api_tests/test_has_names.py::test_has_names[sorting-argsort] array_api_tests/test_has_names.py::test_has_names[data_type-isdtype] -array_api_tests/test_has_names.py::test_has_names[elementwise-conj] array_api_tests/test_has_names.py::test_has_names[array_method-__dlpack__] array_api_tests/test_has_names.py::test_has_names[array_method-__dlpack_device__] array_api_tests/test_has_names.py::test_has_names[array_method-__setitem__] -array_api_tests/test_has_names.py::test_has_names[array_method-to_device] -array_api_tests/test_has_names.py::test_has_names[array_attribute-device] -array_api_tests/test_has_names.py::test_has_names[array_attribute-mT] array_api_tests/test_indexing_functions.py::test_take array_api_tests/test_linalg.py::test_vecdot array_api_tests/test_manipulation_functions.py::test_concat -array_api_tests/test_operators_and_elementwise_functions.py::test_add[__add__(x, s)] -array_api_tests/test_operators_and_elementwise_functions.py::test_add[__iadd__(x, s)] -array_api_tests/test_operators_and_elementwise_functions.py::test_bitwise_left_shift[__lshift__(x, s)] -array_api_tests/test_operators_and_elementwise_functions.py::test_bitwise_left_shift[__ilshift__(x, s)] -array_api_tests/test_operators_and_elementwise_functions.py::test_bitwise_right_shift[__rshift__(x, s)] -array_api_tests/test_operators_and_elementwise_functions.py::test_bitwise_right_shift[__irshift__(x, s)] array_api_tests/test_operators_and_elementwise_functions.py::test_ceil -array_api_tests/test_operators_and_elementwise_functions.py::test_conj -array_api_tests/test_operators_and_elementwise_functions.py::test_divide[__truediv__(x, s)] array_api_tests/test_operators_and_elementwise_functions.py::test_imag -array_api_tests/test_operators_and_elementwise_functions.py::test_multiply[__mul__(x, s)] -array_api_tests/test_operators_and_elementwise_functions.py::test_multiply[__imul__(x, s)] -array_api_tests/test_operators_and_elementwise_functions.py::test_pow[__pow__(x, s)] -array_api_tests/test_operators_and_elementwise_functions.py::test_pow[__ipow__(x, s)] array_api_tests/test_operators_and_elementwise_functions.py::test_real -array_api_tests/test_operators_and_elementwise_functions.py::test_subtract[__sub__(x, s)] -array_api_tests/test_operators_and_elementwise_functions.py::test_subtract[__isub__(x, s)] array_api_tests/test_operators_and_elementwise_functions.py::test_trunc array_api_tests/test_searching_functions.py::test_argmax array_api_tests/test_searching_functions.py::test_argmin @@ -81,7 +63,6 @@ array_api_tests/test_signatures.py::test_func_signature[isdtype] array_api_tests/test_signatures.py::test_array_method_signature[__dlpack__] array_api_tests/test_signatures.py::test_array_method_signature[__dlpack_device__] array_api_tests/test_signatures.py::test_array_method_signature[__setitem__] -array_api_tests/test_signatures.py::test_array_method_signature[to_device] array_api_tests/test_sorting_functions.py::test_argsort array_api_tests/test_sorting_functions.py::test_sort array_api_tests/test_special_cases.py::test_unary[isfinite(x_i is NaN) -> False] diff --git a/sparse/numba_backend/_compressed/compressed.py b/sparse/numba_backend/_compressed/compressed.py index f92c3687..314ca9f9 100644 --- a/sparse/numba_backend/_compressed/compressed.py +++ b/sparse/numba_backend/_compressed/compressed.py @@ -324,6 +324,16 @@ def _reordered_shape(self): def T(self): return self.transpose() + @property + def mT(self): + if self.ndim < 2: + raise ValueError("Cannot compute matrix transpose if `ndim < 2`.") + + axis = list(range(self.ndim)) + axis[-1], axis[-2] = axis[-2], axis[-1] + + return self.transpose(axis) + def __str__(self): summary = ( f" tuple[np.ndarray, np.ndarray]: assert coords.shape[0] == 2 group_coords = coords[0, :] diff --git a/sparse/numba_backend/_coo/core.py b/sparse/numba_backend/_coo/core.py index ae5d039d..73cf2f9b 100644 --- a/sparse/numba_backend/_coo/core.py +++ b/sparse/numba_backend/_coo/core.py @@ -849,6 +849,16 @@ def T(self): """ return self.transpose(tuple(range(self.ndim))[::-1]) + @property + def mT(self): + if self.ndim < 2: + raise ValueError("Cannot compute matrix transpose if `ndim < 2`.") + + axis = list(range(self.ndim)) + axis[-1], axis[-2] = axis[-2], axis[-1] + + return self.transpose(axis) + def swapaxes(self, axis1, axis2): """Returns array that has axes axis1 and axis2 swapped. diff --git a/sparse/numba_backend/_sparse_array.py b/sparse/numba_backend/_sparse_array.py index 0c9b2e0b..242414c1 100644 --- a/sparse/numba_backend/_sparse_array.py +++ b/sparse/numba_backend/_sparse_array.py @@ -47,6 +47,17 @@ def __init__(self, shape, fill_value=None): dtype = None + @property + def device(self): + data = getattr(self, "data", None) + return getattr(data, "device", "cpu") + + def to_device(self, device, /, *, stream=None): + if device != "cpu": + raise ValueError("Only `device='cpu'` is supported.") + + return self + @property @abstractmethod def nnz(self): @@ -315,11 +326,11 @@ def __array_ufunc__(self, ufunc, method, *inputs, **kwargs): return self.__array_function__(ufunc, (np.ndarray, type(self)), inputs, kwargs) if out is not None: - test_args = [np.empty(1, dtype=a.dtype) if hasattr(a, "dtype") else [a] for a in inputs] + test_args = [np.empty((1,), dtype=a.dtype) if hasattr(a, "dtype") else a for a in inputs] test_kwargs = kwargs.copy() if method == "reduce": test_kwargs["axis"] = None - test_out = tuple(np.empty(1, dtype=a.dtype) for a in out) + test_out = tuple(np.empty((1,), dtype=a.dtype) for a in out) if len(test_out) == 1: test_out = test_out[0] getattr(ufunc, method)(*test_args, out=test_out, **test_kwargs) diff --git a/sparse/numba_backend/_umath.py b/sparse/numba_backend/_umath.py index 1d77e1a9..786f062f 100644 --- a/sparse/numba_backend/_umath.py +++ b/sparse/numba_backend/_umath.py @@ -431,7 +431,7 @@ def __init__(self, func, *args, **kwargs): processed_args.append(COO.from_scipy_sparse(arg)) elif isscalar(arg) or isinstance(arg, np.ndarray): # Faster and more reliable to pass ()-shaped ndarrays as scalars. - processed_args.append(np.asarray(arg)) + processed_args.append(arg) elif isinstance(arg, SparseArray): if not isinstance(arg, COO): arg = arg.asformat(COO) @@ -513,15 +513,22 @@ def _get_fill_value(self): """ from ._coo import COO - zero_args = tuple( - np.asarray(arg.fill_value, like=arg.data) if isinstance(arg, COO) else arg for arg in self.args - ) + def get_zero_arg(x): + if isinstance(x, COO): + return np.atleast_1d(x.fill_value) + + if isinstance(x, np.generic | np.ndarray): + return np.atleast_1d(x) + + return x + + zero_args = tuple(get_zero_arg(a) for a in self.args) # Some elemwise functions require a dtype argument, some abhorr it. try: - fill_value_array = self.func(*np.broadcast_arrays(*zero_args), dtype=self.dtype, **self.kwargs) + fill_value_array = self.func(*zero_args, dtype=self.dtype, **self.kwargs) except TypeError: - fill_value_array = self.func(*np.broadcast_arrays(*zero_args), **self.kwargs) + fill_value_array = self.func(*zero_args, **self.kwargs) try: fill_value = fill_value_array[(0,) * fill_value_array.ndim] @@ -531,7 +538,7 @@ def _get_fill_value(self): ) fill_value = self.func(*zero_args, **self.kwargs)[()] - equivalent_fv = equivalent(fill_value, fill_value_array).all() + equivalent_fv = equivalent(fill_value, fill_value_array, loose=True).all() if not equivalent_fv and self.shape != self.ndarray_shape: raise ValueError( "Performing a mixed sparse-dense operation that would result in a dense array. " @@ -558,7 +565,7 @@ def _check_broadcast(self): """ from ._coo import COO - full_shape = _get_nary_broadcast_shape(*tuple(arg.shape for arg in self.args)) + full_shape = _get_nary_broadcast_shape(*tuple(np.shape(arg) for arg in self.args)) non_ndarray_shape = _get_nary_broadcast_shape(*tuple(arg.shape for arg in self.args if isinstance(arg, COO))) ndarray_shape = _get_nary_broadcast_shape(*tuple(arg.shape for arg in self.args if isinstance(arg, np.ndarray))) @@ -587,7 +594,7 @@ def _get_func_coords_data(self, mask): ndarray_args = [arg for arg, m in zip(self.args, mask, strict=True) if m is None] matched_broadcast_shape = _get_nary_broadcast_shape( - *tuple(arg.shape for arg in itertools.chain(matched_args, ndarray_args)) + *tuple(np.shape(arg) for arg in itertools.chain(matched_args, ndarray_args)) ) matched_arrays = self._match_coo(*matched_args, cache=self.cache, broadcast_shape=matched_broadcast_shape) diff --git a/sparse/numba_backend/_utils.py b/sparse/numba_backend/_utils.py index 08a39ecf..5695dd5e 100644 --- a/sparse/numba_backend/_utils.py +++ b/sparse/numba_backend/_utils.py @@ -406,7 +406,7 @@ def normalize_axis(axis, ndim): raise ValueError(f"axis {axis} not understood") -def equivalent(x, y): +def equivalent(x, y, /, loose=False): """ Checks the equivalence of two scalars or arrays with broadcasting. Assumes a consistent dtype. @@ -432,17 +432,27 @@ def equivalent(x, y): >>> equivalent(np.inf, np.inf) True >>> equivalent(np.PZERO, np.NZERO) - True + False """ x = np.asarray(x) y = np.asarray(y) # Can't contain NaNs - if any(np.issubdtype(x.dtype, t) for t in [np.integer, np.bool_, np.character]): + dt = np.result_type(x.dtype, y.dtype) + if not any(np.issubdtype(dt, t) for t in [np.floating, np.complexfloating]): return x == y - # Can contain NaNs - # FIXME: Complex floats and np.void with multiple values can't be compared properly. - return (x == y) | ((x != x) & (y != y)) # noqa: PLR0124 + if loose: + if np.issubdtype(dt, np.complexfloating): + return equivalent(x.real, y.real) & equivalent(x.imag, y.imag) + + # TODO: Rec array handling + return (x == y) | ((x != x) & (y != y)) + + if x.size == 0 or y.size == 0: + shape = np.broadcast_shapes(x.shape, y.shape) + return np.empty(shape, dtype=np.bool_) + x, y = np.broadcast_arrays(x[..., None], y[..., None]) + return (x.astype(dt).view(np.uint8) == y.astype(dt).view(np.uint8)).all(axis=-1) # copied from zarr From 975c09d0c41563c2da1be36bd5867d5915f6f9aa Mon Sep 17 00:00:00 2001 From: Hameer Abbasi <2190658+hameerabbasi@users.noreply.github.com> Date: Mon, 3 Jun 2024 08:42:39 +0200 Subject: [PATCH 03/13] Add some tests. --- sparse/numba_backend/_utils.py | 2 +- sparse/numba_backend/tests/test_coo.py | 59 ++++++++++++++++++++++++-- 2 files changed, 56 insertions(+), 5 deletions(-) diff --git a/sparse/numba_backend/_utils.py b/sparse/numba_backend/_utils.py index 5695dd5e..8d1fb5ed 100644 --- a/sparse/numba_backend/_utils.py +++ b/sparse/numba_backend/_utils.py @@ -559,7 +559,7 @@ def check_fill_value(x, /, *, accept_fv=None) -> None: if not isinstance(accept_fv, Iterable): accept_fv = [accept_fv] - if not any(equivalent(fv, x.fill_value) for fv in accept_fv): + if not any(equivalent(fv, x.fill_value, loose=True) for fv in accept_fv): raise ValueError(f"{x.fill_value=} but should be in {accept_fv}.") diff --git a/sparse/numba_backend/tests/test_coo.py b/sparse/numba_backend/tests/test_coo.py index 0c548c6c..d8336916 100644 --- a/sparse/numba_backend/tests/test_coo.py +++ b/sparse/numba_backend/tests/test_coo.py @@ -1793,14 +1793,19 @@ def test_expand_dims(axis): @pytest.mark.parametrize("fill_value", [-1, 0, 1, 3]) @pytest.mark.parametrize("axis", [0, 1, -1]) @pytest.mark.parametrize("descending", [False, True]) -def test_sort(arr, fill_value, axis, descending): +@pytest.mark.parametrize( + "stable", [False, pytest.param(True, marks=pytest.mark.xfail(reason="Numba doesn't support `stable=True`."))] +) +def test_sort(arr, fill_value, axis, descending, stable): if axis >= arr.ndim: return s_arr = sparse.COO.from_numpy(arr, fill_value) - result = sparse.sort(s_arr, axis=axis, descending=descending) - expected = -np.sort(-arr, axis=axis) if descending else np.sort(arr, axis=axis) + kind = "mergesort" if stable else "quicksort" + + result = sparse.sort(s_arr, axis=axis, descending=descending, stable=stable) + expected = -np.sort(-arr, axis=axis, kind=kind) if descending else np.sort(arr, axis=axis, kind=kind) np.testing.assert_equal(result.todense(), expected) # make sure no inplace changes happened @@ -1868,7 +1873,8 @@ def test_matrix_transpose(ndim, density): expected = np.transpose(xd, axes=transpose_axes) actual = sparse.matrix_transpose(xs) - np.testing.assert_equal(actual.todense(), expected) + assert_eq(actual, expected) + assert_eq(xs.mT, expected) @pytest.mark.parametrize( @@ -1913,3 +1919,48 @@ def np_vecdot(x1, x2, /, *, axis=-1): actual = sparse.vecdot(s1, s2, axis=axis) np.testing.assert_allclose(actual.todense(), expected) + + +@pytest.mark.parametrize( + ("func", "args", "kwargs"), + [ + (sparse.eye, (5,), {}), + (sparse.zeros, ((5,)), {}), + (sparse.ones, ((5,)), {}), + (sparse.full, ((5,), 5), {}), + (sparse.empty, ((5,)), {}), + (sparse.full_like, (5,), {}), + (sparse.ones_like, (), {}), + (sparse.zeros_like, (), {}), + (sparse.empty_like, (), {}), + (sparse.asarray, (), {}), + ], +) +def test_invalid_device(func, args, kwargs): + if func.__name__.endswith("_like") or func is sparse.asarray: + like = sparse.random((5, 5), density=0.5) + args = (like,) + args + + with pytest.raises(ValueError, match="Device must be"): + func(*args, device="invalid_device", **kwargs) + + +def test_device(): + s = sparse.random((5, 5), density=0.5) + data = getattr(s, "data", None) + device = getattr(data, "device", "cpu") + + assert s.device == device + + +def test_to_device(): + s = sparse.random((5, 5), density=0.5) + s2 = s.to_device(s.device) + + assert s is s2 + + +def test_to_invalid_device(): + s = sparse.random((5, 5), density=0.5) + with pytest.raises(ValueError, match=r"Only .* is supported."): + s.to_device("invalid_device") From cd6acaaca1f68b1ba5e21d585d91f03294944854 Mon Sep 17 00:00:00 2001 From: Hameer Abbasi <2190658+hameerabbasi@users.noreply.github.com> Date: Mon, 3 Jun 2024 08:53:57 +0200 Subject: [PATCH 04/13] Remove spurious function. --- sparse/numba_backend/__init__.py | 2 +- sparse/numba_backend/_common.py | 4 ---- 2 files changed, 1 insertion(+), 5 deletions(-) diff --git a/sparse/numba_backend/__init__.py b/sparse/numba_backend/__init__.py index 419003dd..74cb3d57 100644 --- a/sparse/numba_backend/__init__.py +++ b/sparse/numba_backend/__init__.py @@ -8,6 +8,7 @@ ceil, complex64, complex128, + conj, cos, cosh, divide, @@ -88,7 +89,6 @@ broadcast_to, concat, concatenate, - conj, dot, einsum, empty, diff --git a/sparse/numba_backend/_common.py b/sparse/numba_backend/_common.py index c01ac9a0..4614364b 100644 --- a/sparse/numba_backend/_common.py +++ b/sparse/numba_backend/_common.py @@ -2146,10 +2146,6 @@ def reshape(x, /, shape, *, copy=None): return x.reshape(shape=shape) -def conj(x, /): - return x.conj() - - def astype(x, dtype, /, *, copy=True): return x.astype(dtype, copy=copy) From cf0d2be8c68116880e51a1b4ac1ca40d3a7d0ef8 Mon Sep 17 00:00:00 2001 From: Hameer Abbasi <2190658+hameerabbasi@users.noreply.github.com> Date: Mon, 3 Jun 2024 11:13:35 +0200 Subject: [PATCH 05/13] Fix `real` and `imag`. --- ci/Numba-array-api-xfails.txt | 2 -- sparse/numba_backend/__init__.py | 4 ++-- sparse/numba_backend/_common.py | 8 ++++++++ 3 files changed, 10 insertions(+), 4 deletions(-) diff --git a/ci/Numba-array-api-xfails.txt b/ci/Numba-array-api-xfails.txt index c8747e9d..abf3af21 100644 --- a/ci/Numba-array-api-xfails.txt +++ b/ci/Numba-array-api-xfails.txt @@ -44,8 +44,6 @@ array_api_tests/test_indexing_functions.py::test_take array_api_tests/test_linalg.py::test_vecdot array_api_tests/test_manipulation_functions.py::test_concat array_api_tests/test_operators_and_elementwise_functions.py::test_ceil -array_api_tests/test_operators_and_elementwise_functions.py::test_imag -array_api_tests/test_operators_and_elementwise_functions.py::test_real array_api_tests/test_operators_and_elementwise_functions.py::test_trunc array_api_tests/test_searching_functions.py::test_argmax array_api_tests/test_searching_functions.py::test_argmin diff --git a/sparse/numba_backend/__init__.py b/sparse/numba_backend/__init__.py index 74cb3d57..4441d9b3 100644 --- a/sparse/numba_backend/__init__.py +++ b/sparse/numba_backend/__init__.py @@ -24,7 +24,6 @@ greater, greater_equal, iinfo, - imag, inf, int8, int16, @@ -48,7 +47,6 @@ not_equal, pi, positive, - real, remainder, sign, sin, @@ -97,6 +95,7 @@ eye, full, full_like, + imag, isfinite, isinf, isnan, @@ -112,6 +111,7 @@ pad, permute_dims, prod, + real, reshape, round, squeeze, diff --git a/sparse/numba_backend/_common.py b/sparse/numba_backend/_common.py index 4614364b..133ace18 100644 --- a/sparse/numba_backend/_common.py +++ b/sparse/numba_backend/_common.py @@ -2205,6 +2205,14 @@ def nonzero(x, /): return x.nonzero() +def imag(x, /): + return x.imag + + +def real(x, /): + return x.real + + def vecdot(x1, x2, /, *, axis=-1): """ Computes the (vector) dot product of two arrays. From 57bfbbf2ba537e5d1c73fdd0e4e4469ccf591bac Mon Sep 17 00:00:00 2001 From: Hameer Abbasi <2190658+hameerabbasi@users.noreply.github.com> Date: Mon, 3 Jun 2024 11:42:14 +0200 Subject: [PATCH 06/13] Fix a couple of small issues. --- ci/Numba-array-api-xfails.txt | 3 --- sparse/numba_backend/__init__.py | 2 +- sparse/numba_backend/_common.py | 4 ---- sparse/numba_backend/_coo/core.py | 2 ++ sparse/numba_backend/_sparse_array.py | 10 +++------- 5 files changed, 6 insertions(+), 15 deletions(-) diff --git a/ci/Numba-array-api-xfails.txt b/ci/Numba-array-api-xfails.txt index abf3af21..253cf663 100644 --- a/ci/Numba-array-api-xfails.txt +++ b/ci/Numba-array-api-xfails.txt @@ -47,7 +47,6 @@ array_api_tests/test_operators_and_elementwise_functions.py::test_ceil array_api_tests/test_operators_and_elementwise_functions.py::test_trunc array_api_tests/test_searching_functions.py::test_argmax array_api_tests/test_searching_functions.py::test_argmin -array_api_tests/test_searching_functions.py::test_nonzero_zerodim_error array_api_tests/test_set_functions.py::test_unique_all array_api_tests/test_set_functions.py::test_unique_inverse array_api_tests/test_signatures.py::test_func_signature[unique_all] @@ -63,8 +62,6 @@ array_api_tests/test_signatures.py::test_array_method_signature[__dlpack_device_ array_api_tests/test_signatures.py::test_array_method_signature[__setitem__] array_api_tests/test_sorting_functions.py::test_argsort array_api_tests/test_sorting_functions.py::test_sort -array_api_tests/test_special_cases.py::test_unary[isfinite(x_i is NaN) -> False] -array_api_tests/test_special_cases.py::test_empty_arrays[prod] array_api_tests/test_special_cases.py::test_nan_propagation[max] array_api_tests/test_special_cases.py::test_nan_propagation[mean] array_api_tests/test_special_cases.py::test_nan_propagation[min] diff --git a/sparse/numba_backend/__init__.py b/sparse/numba_backend/__init__.py index 4441d9b3..4658109b 100644 --- a/sparse/numba_backend/__init__.py +++ b/sparse/numba_backend/__init__.py @@ -29,6 +29,7 @@ int16, int32, int64, + isfinite, less, less_equal, log, @@ -96,7 +97,6 @@ full, full_like, imag, - isfinite, isinf, isnan, matmul, diff --git a/sparse/numba_backend/_common.py b/sparse/numba_backend/_common.py index 133ace18..1016ea82 100644 --- a/sparse/numba_backend/_common.py +++ b/sparse/numba_backend/_common.py @@ -2197,10 +2197,6 @@ def isnan(x, /): return x.isnan() -def isfinite(x, /): - return ~isinf(x) - - def nonzero(x, /): return x.nonzero() diff --git a/sparse/numba_backend/_coo/core.py b/sparse/numba_backend/_coo/core.py index 73cf2f9b..09cde274 100644 --- a/sparse/numba_backend/_coo/core.py +++ b/sparse/numba_backend/_coo/core.py @@ -1478,6 +1478,8 @@ def nonzero(self): (array([0, 1, 2, 3, 4]), array([0, 1, 2, 3, 4])) """ check_zero_fill_value(self) + if self.ndim == 0: + raise ValueError("`nonzero` is undefined for `self.ndim == 0`.") return tuple(self.coords) def asformat(self, format, **kwargs): diff --git a/sparse/numba_backend/_sparse_array.py b/sparse/numba_backend/_sparse_array.py index 242414c1..6192dae0 100644 --- a/sparse/numba_backend/_sparse_array.py +++ b/sparse/numba_backend/_sparse_array.py @@ -390,13 +390,9 @@ def reduce(self, method, axis=(0,), keepdims=False, **kwargs): """ axis = normalize_axis(axis, self.ndim) zero_reduce_result = method.reduce([self.fill_value, self.fill_value], **kwargs) - reduce_super_ufunc = None - - if not equivalent(zero_reduce_result, self.fill_value): - reduce_super_ufunc = _reduce_super_ufunc.get(method) - - if reduce_super_ufunc is None: - raise ValueError(f"Performing this reduction operation would produce a dense result: {method!s}") + reduce_super_ufunc = _reduce_super_ufunc.get(method) + if not equivalent(zero_reduce_result, self.fill_value) and reduce_super_ufunc is None: + raise ValueError(f"Performing this reduction operation would produce a dense result: {method!s}") if not isinstance(axis, tuple): axis = (axis,) From 0c73ba069d9a0774ebca53cb59244dae5d272596 Mon Sep 17 00:00:00 2001 From: Hameer Abbasi <2190658+hameerabbasi@users.noreply.github.com> Date: Mon, 3 Jun 2024 12:17:58 +0200 Subject: [PATCH 07/13] Fix `axis=None` for `concat`. --- .github/workflows/ci.yml | 2 +- ci/Numba-array-api-xfails.txt | 2 -- sparse/numba_backend/_coo/common.py | 1 + 3 files changed, 2 insertions(+), 3 deletions(-) diff --git a/.github/workflows/ci.yml b/.github/workflows/ci.yml index 025026d7..0a88c6d6 100644 --- a/.github/workflows/ci.yml +++ b/.github/workflows/ci.yml @@ -173,7 +173,7 @@ jobs: SPARSE_BACKEND: ${{ matrix.backend }} run: | cd ${GITHUB_WORKSPACE}/array-api-tests - pytest array_api_tests -v -c pytest.ini -n 4 --ci --max-examples=2 --derandomize --disable-deadline -o xfail_strict=True --xfails-file ${GITHUB_WORKSPACE}/ci/${{ matrix.backend }}-array-api-xfails.txt --skips-file ${GITHUB_WORKSPACE}/ci/${{ matrix.backend }}-array-api-skips.txt + pytest array_api_tests -v -c pytest.ini -n 4 --max-examples=2 --derandomize --disable-deadline -o xfail_strict=True --xfails-file ${GITHUB_WORKSPACE}/ci/${{ matrix.backend }}-array-api-xfails.txt --skips-file ${GITHUB_WORKSPACE}/ci/${{ matrix.backend }}-array-api-skips.txt on: # Trigger the workflow on push or pull request, # but only for the main branch diff --git a/ci/Numba-array-api-xfails.txt b/ci/Numba-array-api-xfails.txt index 253cf663..12b29209 100644 --- a/ci/Numba-array-api-xfails.txt +++ b/ci/Numba-array-api-xfails.txt @@ -40,9 +40,7 @@ array_api_tests/test_has_names.py::test_has_names[data_type-isdtype] array_api_tests/test_has_names.py::test_has_names[array_method-__dlpack__] array_api_tests/test_has_names.py::test_has_names[array_method-__dlpack_device__] array_api_tests/test_has_names.py::test_has_names[array_method-__setitem__] -array_api_tests/test_indexing_functions.py::test_take array_api_tests/test_linalg.py::test_vecdot -array_api_tests/test_manipulation_functions.py::test_concat array_api_tests/test_operators_and_elementwise_functions.py::test_ceil array_api_tests/test_operators_and_elementwise_functions.py::test_trunc array_api_tests/test_searching_functions.py::test_argmax diff --git a/sparse/numba_backend/_coo/common.py b/sparse/numba_backend/_coo/common.py index 2802395a..d497119f 100644 --- a/sparse/numba_backend/_coo/common.py +++ b/sparse/numba_backend/_coo/common.py @@ -160,6 +160,7 @@ def concatenate(arrays, axis=0): check_consistent_fill_value(arrays) if axis is None: + axis = 0 arrays = [x.flatten() for x in arrays] arrays = [x if isinstance(x, COO) else COO(x) for x in arrays] From 76d9a2774139dcfb923102d83b667b5ad35e10d3 Mon Sep 17 00:00:00 2001 From: Hameer Abbasi <2190658+hameerabbasi@users.noreply.github.com> Date: Mon, 3 Jun 2024 12:29:57 +0200 Subject: [PATCH 08/13] Fix invalid shapes accepted for `vecdot`. --- ci/Numba-array-api-xfails.txt | 1 + sparse/numba_backend/_common.py | 3 +++ 2 files changed, 4 insertions(+) diff --git a/ci/Numba-array-api-xfails.txt b/ci/Numba-array-api-xfails.txt index 12b29209..f3234ea6 100644 --- a/ci/Numba-array-api-xfails.txt +++ b/ci/Numba-array-api-xfails.txt @@ -40,6 +40,7 @@ array_api_tests/test_has_names.py::test_has_names[data_type-isdtype] array_api_tests/test_has_names.py::test_has_names[array_method-__dlpack__] array_api_tests/test_has_names.py::test_has_names[array_method-__dlpack_device__] array_api_tests/test_has_names.py::test_has_names[array_method-__setitem__] +array_api_tests/test_indexing_functions.py::test_take array_api_tests/test_linalg.py::test_vecdot array_api_tests/test_operators_and_elementwise_functions.py::test_ceil array_api_tests/test_operators_and_elementwise_functions.py::test_trunc diff --git a/sparse/numba_backend/_common.py b/sparse/numba_backend/_common.py index 1016ea82..99cbc21b 100644 --- a/sparse/numba_backend/_common.py +++ b/sparse/numba_backend/_common.py @@ -2225,6 +2225,9 @@ def vecdot(x1, x2, /, *, axis=-1): out : Union[SparseArray, numpy.ndarray] Sparse or 0-D array containing dot product. """ + if x1.shape[axis] != x2.shape[axis]: + raise ValueError("Shapes must match along `axis`.") + if np.issubdtype(x1.dtype, np.complexfloating): x1 = np.conjugate(x1) From c7a051d73a9310af8c3dc862ab5499916e9c9f36 Mon Sep 17 00:00:00 2001 From: Hameer Abbasi <2190658+hameerabbasi@users.noreply.github.com> Date: Mon, 3 Jun 2024 12:36:52 +0200 Subject: [PATCH 09/13] Attempted fix for docs. --- sparse/numba_backend/_common.py | 3 ++- 1 file changed, 2 insertions(+), 1 deletion(-) diff --git a/sparse/numba_backend/_common.py b/sparse/numba_backend/_common.py index 99cbc21b..cae7ffa6 100644 --- a/sparse/numba_backend/_common.py +++ b/sparse/numba_backend/_common.py @@ -2157,12 +2157,13 @@ def squeeze(x, /, axis=None): Parameters ---------- x : SparseArray + Input array. axis : int or tuple[int, ...], optional The singleton axes to remove. By default all singleton axes are removed. Returns ------- - SparseArray + output : SparseArray Array with singleton dimensions removed. """ return x.squeeze(axis=axis) From 4f406bbc7d417837219795a6d8d88012cafcebb7 Mon Sep 17 00:00:00 2001 From: Hameer Abbasi <2190658+hameerabbasi@users.noreply.github.com> Date: Mon, 3 Jun 2024 13:00:10 +0200 Subject: [PATCH 10/13] Fix tests for `vecdot`. --- sparse/numba_backend/tests/test_coo.py | 9 +++++++-- 1 file changed, 7 insertions(+), 2 deletions(-) diff --git a/sparse/numba_backend/tests/test_coo.py b/sparse/numba_backend/tests/test_coo.py index d8336916..8124f184 100644 --- a/sparse/numba_backend/tests/test_coo.py +++ b/sparse/numba_backend/tests/test_coo.py @@ -1904,7 +1904,7 @@ def data_rvs(size): s1 = sparse.random(shape1, density=density, data_rvs=data_rvs) s2 = sparse.random(shape2, density=density, data_rvs=data_rvs) - axis = rng.integers(max(s1.ndim, s2.ndim)) + axis = rng.integers(min(s1.ndim, s2.ndim)) x1 = s1.todense() x2 = s2.todense() @@ -1915,8 +1915,13 @@ def np_vecdot(x1, x2, /, *, axis=-1): return np.sum(x1 * x2, axis=axis) - expected = np_vecdot(x1, x2, axis=axis) + if shape1[axis] != shape2[axis]: + with pytest.raises(ValueError, match="Shapes must match along"): + sparse.vecdot(s1, s2, axis=axis) + return + actual = sparse.vecdot(s1, s2, axis=axis) + expected = np_vecdot(x1, x2, axis=axis) np.testing.assert_allclose(actual.todense(), expected) From c61825507aa60a0ff6851dbb5b6236d77fb9436d Mon Sep 17 00:00:00 2001 From: Hameer Abbasi <2190658+hameerabbasi@users.noreply.github.com> Date: Mon, 3 Jun 2024 13:00:58 +0200 Subject: [PATCH 11/13] Add missing doc pages. --- docs/generated/sparse.expand_dims.rst | 6 ++++++ docs/generated/sparse.flip.rst | 6 ++++++ docs/generated/sparse.matrix_transpose.rst | 6 ++++++ docs/generated/sparse.sort.rst | 6 ++++++ docs/generated/sparse.take.rst | 6 ++++++ docs/generated/sparse.var.rst | 6 ++++++ docs/generated/sparse.vecdot.rst | 6 ++++++ 7 files changed, 42 insertions(+) create mode 100644 docs/generated/sparse.expand_dims.rst create mode 100644 docs/generated/sparse.flip.rst create mode 100644 docs/generated/sparse.matrix_transpose.rst create mode 100644 docs/generated/sparse.sort.rst create mode 100644 docs/generated/sparse.take.rst create mode 100644 docs/generated/sparse.var.rst create mode 100644 docs/generated/sparse.vecdot.rst diff --git a/docs/generated/sparse.expand_dims.rst b/docs/generated/sparse.expand_dims.rst new file mode 100644 index 00000000..7f918642 --- /dev/null +++ b/docs/generated/sparse.expand_dims.rst @@ -0,0 +1,6 @@ +expand\_dims +============ + +.. currentmodule:: sparse + +.. autofunction:: expand_dims diff --git a/docs/generated/sparse.flip.rst b/docs/generated/sparse.flip.rst new file mode 100644 index 00000000..9c83383e --- /dev/null +++ b/docs/generated/sparse.flip.rst @@ -0,0 +1,6 @@ +flip +==== + +.. currentmodule:: sparse + +.. autofunction:: flip diff --git a/docs/generated/sparse.matrix_transpose.rst b/docs/generated/sparse.matrix_transpose.rst new file mode 100644 index 00000000..755521dd --- /dev/null +++ b/docs/generated/sparse.matrix_transpose.rst @@ -0,0 +1,6 @@ +matrix\_transpose +================= + +.. currentmodule:: sparse + +.. autofunction:: matrix_transpose diff --git a/docs/generated/sparse.sort.rst b/docs/generated/sparse.sort.rst new file mode 100644 index 00000000..cf1e71a0 --- /dev/null +++ b/docs/generated/sparse.sort.rst @@ -0,0 +1,6 @@ +sort +==== + +.. currentmodule:: sparse + +.. autofunction:: sort diff --git a/docs/generated/sparse.take.rst b/docs/generated/sparse.take.rst new file mode 100644 index 00000000..42f5ac85 --- /dev/null +++ b/docs/generated/sparse.take.rst @@ -0,0 +1,6 @@ +take +==== + +.. currentmodule:: sparse + +.. autofunction:: take diff --git a/docs/generated/sparse.var.rst b/docs/generated/sparse.var.rst new file mode 100644 index 00000000..f8badb51 --- /dev/null +++ b/docs/generated/sparse.var.rst @@ -0,0 +1,6 @@ +var +=== + +.. currentmodule:: sparse + +.. autofunction:: var diff --git a/docs/generated/sparse.vecdot.rst b/docs/generated/sparse.vecdot.rst new file mode 100644 index 00000000..995d8108 --- /dev/null +++ b/docs/generated/sparse.vecdot.rst @@ -0,0 +1,6 @@ +vecdot +====== + +.. currentmodule:: sparse + +.. autofunction:: vecdot From f02fc2ffc3984427708a236a750a89700fb9d997 Mon Sep 17 00:00:00 2001 From: Hameer Abbasi <2190658+hameerabbasi@users.noreply.github.com> Date: Mon, 3 Jun 2024 13:06:40 +0200 Subject: [PATCH 12/13] Fix doc pages. --- docs/generated/sparse.isfinite.rst | 6 ------ docs/generated/sparse.rst | 2 -- sparse/numba_backend/_coo/common.py | 3 +-- 3 files changed, 1 insertion(+), 10 deletions(-) delete mode 100644 docs/generated/sparse.isfinite.rst diff --git a/docs/generated/sparse.isfinite.rst b/docs/generated/sparse.isfinite.rst deleted file mode 100644 index b5f7bb96..00000000 --- a/docs/generated/sparse.isfinite.rst +++ /dev/null @@ -1,6 +0,0 @@ -isfinite -======== - -.. currentmodule:: sparse - -.. autofunction:: isfinite diff --git a/docs/generated/sparse.rst b/docs/generated/sparse.rst index 8c8e9b34..9592cc24 100644 --- a/docs/generated/sparse.rst +++ b/docs/generated/sparse.rst @@ -80,8 +80,6 @@ API full_like - isfinite - isinf isnan diff --git a/sparse/numba_backend/_coo/common.py b/sparse/numba_backend/_coo/common.py index d497119f..55e63cdb 100644 --- a/sparse/numba_backend/_coo/common.py +++ b/sparse/numba_backend/_coo/common.py @@ -1260,8 +1260,7 @@ def sort(x, /, *, axis=-1, descending=False, stable=False): If ``False``, the array must be sorted in ascending order (by value). Default: ``False``. stable : bool - Whether the sort is stable. Provided for compatibility with the Array API, only - ``False`` (the default) is currently supported. + Whether the sort is stable. Only ``False`` is supported currently. Returns ------- From 0f244eb334ba02c07b841083c6651a2af2df44f2 Mon Sep 17 00:00:00 2001 From: Hameer Abbasi <2190658+hameerabbasi@users.noreply.github.com> Date: Mon, 3 Jun 2024 16:17:33 +0200 Subject: [PATCH 13/13] Fix small `vecdot` inconsistencies and separate tests. --- ci/Numba-array-api-xfails.txt | 1 - sparse/numba_backend/_common.py | 3 +- sparse/numba_backend/_coo/core.py | 2 +- sparse/numba_backend/_sparse_array.py | 2 +- sparse/numba_backend/tests/test_coo.py | 48 ++++++++++++++++---------- 5 files changed, 33 insertions(+), 23 deletions(-) diff --git a/ci/Numba-array-api-xfails.txt b/ci/Numba-array-api-xfails.txt index f3234ea6..b7d1473c 100644 --- a/ci/Numba-array-api-xfails.txt +++ b/ci/Numba-array-api-xfails.txt @@ -68,4 +68,3 @@ array_api_tests/test_special_cases.py::test_nan_propagation[prod] array_api_tests/test_special_cases.py::test_nan_propagation[std] array_api_tests/test_special_cases.py::test_nan_propagation[sum] array_api_tests/test_special_cases.py::test_nan_propagation[var] -array_api_tests/test_statistical_functions.py::test_mean diff --git a/sparse/numba_backend/_common.py b/sparse/numba_backend/_common.py index cae7ffa6..ac396904 100644 --- a/sparse/numba_backend/_common.py +++ b/sparse/numba_backend/_common.py @@ -2226,7 +2226,8 @@ def vecdot(x1, x2, /, *, axis=-1): out : Union[SparseArray, numpy.ndarray] Sparse or 0-D array containing dot product. """ - if x1.shape[axis] != x2.shape[axis]: + ndmin = builtins.min((x1.ndim, x2.ndim)) + if not (-ndmin <= axis < ndmin) or x1.shape[axis] != x2.shape[axis]: raise ValueError("Shapes must match along `axis`.") if np.issubdtype(x1.dtype, np.complexfloating): diff --git a/sparse/numba_backend/_coo/core.py b/sparse/numba_backend/_coo/core.py index 09cde274..575cb483 100644 --- a/sparse/numba_backend/_coo/core.py +++ b/sparse/numba_backend/_coo/core.py @@ -687,7 +687,7 @@ def __str__(self): __repr__ = __str__ def _reduce_calc(self, method, axis, keepdims=False, **kwargs): - if axis[0] is None: + if axis == (None,): axis = tuple(range(self.ndim)) axis = tuple(a if a >= 0 else a + self.ndim for a in axis) neg_axis = tuple(ax for ax in range(self.ndim) if ax not in set(axis)) diff --git a/sparse/numba_backend/_sparse_array.py b/sparse/numba_backend/_sparse_array.py index 6192dae0..763a779e 100644 --- a/sparse/numba_backend/_sparse_array.py +++ b/sparse/numba_backend/_sparse_array.py @@ -795,7 +795,7 @@ def var(self, axis=None, dtype=None, out=None, ddof=0, keepdims=False): if dtype is None and issubclass(self.dtype.type, np.integer | np.bool_): dtype = np.dtype("f8") - arrmean = self.sum(axis, dtype=dtype, keepdims=True) + arrmean = self.sum(axis, dtype=dtype, keepdims=True)[...] np.divide(arrmean, rcount, out=arrmean) x = self - arrmean if issubclass(self.dtype.type, np.complexfloating): diff --git a/sparse/numba_backend/tests/test_coo.py b/sparse/numba_backend/tests/test_coo.py index 8124f184..ed1ef96f 100644 --- a/sparse/numba_backend/tests/test_coo.py +++ b/sparse/numba_backend/tests/test_coo.py @@ -1878,23 +1878,20 @@ def test_matrix_transpose(ndim, density): @pytest.mark.parametrize( - "shape1, shape2", + ("shape1", "shape2", "axis"), [ - ((2, 3, 4), (3, 4)), - ((3, 4), (2, 3, 4)), - ((3, 1, 4), (3, 2, 4)), - ((1, 3, 4), (3, 4)), - ((3, 4, 1), (3, 4, 2)), - ((1, 5), (5, 1)), - ((3, 1), (3, 4)), - ((3, 1), (1, 4)), - ((1, 4), (3, 4)), - ((2, 2, 2), (1, 1, 1)), + ((2, 3, 4), (3, 4), -2), + ((3, 4), (2, 3, 4), -1), + ((3, 1, 4), (3, 2, 4), 2), + ((1, 3, 4), (3, 4), -2), + ((3, 4, 1), (3, 4, 2), 0), + ((3, 1), (3, 4), -2), + ((1, 4), (3, 4), 1), ], ) @pytest.mark.parametrize("density", [0.0, 0.1, 0.25, 1.0]) @pytest.mark.parametrize("is_complex", [False, True]) -def test_vecdot(shape1, shape2, density, rng, is_complex): +def test_vecdot(shape1, shape2, axis, density, rng, is_complex): def data_rvs(size): data = rng.random(size) if is_complex: @@ -1904,8 +1901,6 @@ def data_rvs(size): s1 = sparse.random(shape1, density=density, data_rvs=data_rvs) s2 = sparse.random(shape2, density=density, data_rvs=data_rvs) - axis = rng.integers(min(s1.ndim, s2.ndim)) - x1 = s1.todense() x2 = s2.todense() @@ -1915,17 +1910,32 @@ def np_vecdot(x1, x2, /, *, axis=-1): return np.sum(x1 * x2, axis=axis) - if shape1[axis] != shape2[axis]: - with pytest.raises(ValueError, match="Shapes must match along"): - sparse.vecdot(s1, s2, axis=axis) - return - actual = sparse.vecdot(s1, s2, axis=axis) expected = np_vecdot(x1, x2, axis=axis) np.testing.assert_allclose(actual.todense(), expected) +@pytest.mark.parametrize( + ("shape1", "shape2", "axis"), + [ + ((2, 3, 4), (3, 4), 0), + ((3, 4), (2, 3, 4), 0), + ((3, 1, 4), (3, 2, 4), -2), + ((1, 3, 4), (3, 4), -3), + ((3, 4, 1), (3, 4, 2), -1), + ((3, 1), (3, 4), 1), + ((1, 4), (3, 4), -2), + ], +) +def test_vecdot_invalid_axis(shape1, shape2, axis): + s1 = sparse.random(shape1, density=0.5) + s2 = sparse.random(shape2, density=0.5) + + with pytest.raises(ValueError, match=r"Shapes must match along"): + sparse.vecdot(s1, s2, axis=axis) + + @pytest.mark.parametrize( ("func", "args", "kwargs"), [