Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Expand Array API coverage for Numba backend #691

Merged
merged 13 commits into from
Jun 3, 2024
37 changes: 0 additions & 37 deletions ci/Numba-array-api-xfails.txt
Original file line number Diff line number Diff line change
Expand Up @@ -37,68 +37,31 @@ array_api_tests/test_has_names.py::test_has_names[creation-linspace]
array_api_tests/test_has_names.py::test_has_names[creation-meshgrid]
array_api_tests/test_has_names.py::test_has_names[sorting-argsort]
array_api_tests/test_has_names.py::test_has_names[data_type-isdtype]
array_api_tests/test_has_names.py::test_has_names[elementwise-conj]
array_api_tests/test_has_names.py::test_has_names[array_method-__dlpack__]
array_api_tests/test_has_names.py::test_has_names[array_method-__dlpack_device__]
array_api_tests/test_has_names.py::test_has_names[array_method-__setitem__]
array_api_tests/test_has_names.py::test_has_names[array_method-to_device]
array_api_tests/test_has_names.py::test_has_names[array_attribute-device]
array_api_tests/test_has_names.py::test_has_names[array_attribute-mT]
array_api_tests/test_indexing_functions.py::test_take
array_api_tests/test_linalg.py::test_vecdot
array_api_tests/test_manipulation_functions.py::test_concat
array_api_tests/test_operators_and_elementwise_functions.py::test_add[__add__(x, s)]
array_api_tests/test_operators_and_elementwise_functions.py::test_add[__iadd__(x, s)]
array_api_tests/test_operators_and_elementwise_functions.py::test_bitwise_left_shift[__lshift__(x, s)]
array_api_tests/test_operators_and_elementwise_functions.py::test_bitwise_left_shift[__ilshift__(x, s)]
array_api_tests/test_operators_and_elementwise_functions.py::test_bitwise_right_shift[__rshift__(x, s)]
array_api_tests/test_operators_and_elementwise_functions.py::test_bitwise_right_shift[__irshift__(x, s)]
array_api_tests/test_operators_and_elementwise_functions.py::test_ceil
array_api_tests/test_operators_and_elementwise_functions.py::test_conj
array_api_tests/test_operators_and_elementwise_functions.py::test_divide[__truediv__(x, s)]
array_api_tests/test_operators_and_elementwise_functions.py::test_imag
array_api_tests/test_operators_and_elementwise_functions.py::test_multiply[__mul__(x, s)]
array_api_tests/test_operators_and_elementwise_functions.py::test_multiply[__imul__(x, s)]
array_api_tests/test_operators_and_elementwise_functions.py::test_pow[__pow__(x, s)]
array_api_tests/test_operators_and_elementwise_functions.py::test_pow[__ipow__(x, s)]
array_api_tests/test_operators_and_elementwise_functions.py::test_real
array_api_tests/test_operators_and_elementwise_functions.py::test_subtract[__sub__(x, s)]
array_api_tests/test_operators_and_elementwise_functions.py::test_subtract[__isub__(x, s)]
array_api_tests/test_operators_and_elementwise_functions.py::test_trunc
array_api_tests/test_searching_functions.py::test_argmax
array_api_tests/test_searching_functions.py::test_argmin
array_api_tests/test_searching_functions.py::test_nonzero_zerodim_error
array_api_tests/test_set_functions.py::test_unique_all
array_api_tests/test_set_functions.py::test_unique_inverse
array_api_tests/test_signatures.py::test_func_signature[unique_all]
array_api_tests/test_signatures.py::test_func_signature[unique_inverse]
array_api_tests/test_signatures.py::test_func_signature[arange]
array_api_tests/test_signatures.py::test_func_signature[empty]
array_api_tests/test_signatures.py::test_func_signature[empty_like]
array_api_tests/test_signatures.py::test_func_signature[eye]
array_api_tests/test_signatures.py::test_func_signature[from_dlpack]
array_api_tests/test_signatures.py::test_func_signature[full]
array_api_tests/test_signatures.py::test_func_signature[full_like]
array_api_tests/test_signatures.py::test_func_signature[linspace]
array_api_tests/test_signatures.py::test_func_signature[meshgrid]
array_api_tests/test_signatures.py::test_func_signature[ones]
array_api_tests/test_signatures.py::test_func_signature[ones_like]
array_api_tests/test_signatures.py::test_func_signature[zeros]
array_api_tests/test_signatures.py::test_func_signature[zeros_like]
array_api_tests/test_signatures.py::test_func_signature[broadcast_to]
array_api_tests/test_signatures.py::test_func_signature[squeeze]
array_api_tests/test_signatures.py::test_func_signature[argsort]
array_api_tests/test_signatures.py::test_func_signature[sort]
array_api_tests/test_signatures.py::test_func_signature[isdtype]
array_api_tests/test_signatures.py::test_func_signature[conj]
array_api_tests/test_signatures.py::test_array_method_signature[__dlpack__]
array_api_tests/test_signatures.py::test_array_method_signature[__dlpack_device__]
array_api_tests/test_signatures.py::test_array_method_signature[__setitem__]
array_api_tests/test_signatures.py::test_array_method_signature[to_device]
array_api_tests/test_sorting_functions.py::test_argsort
array_api_tests/test_sorting_functions.py::test_sort
array_api_tests/test_special_cases.py::test_unary[isfinite(x_i is NaN) -> False]
array_api_tests/test_special_cases.py::test_empty_arrays[prod]
array_api_tests/test_special_cases.py::test_nan_propagation[max]
array_api_tests/test_special_cases.py::test_nan_propagation[mean]
array_api_tests/test_special_cases.py::test_nan_propagation[min]
Expand Down
8 changes: 5 additions & 3 deletions sparse/numba_backend/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -8,6 +8,7 @@
ceil,
complex64,
complex128,
conj,
cos,
cosh,
divide,
Expand All @@ -23,12 +24,12 @@
greater,
greater_equal,
iinfo,
imag,
inf,
int8,
int16,
int32,
int64,
isfinite,
less,
less_equal,
log,
Expand All @@ -47,7 +48,6 @@
not_equal,
pi,
positive,
real,
remainder,
sign,
sin,
Expand Down Expand Up @@ -96,7 +96,7 @@
eye,
full,
full_like,
isfinite,
imag,
isinf,
isnan,
matmul,
Expand All @@ -111,6 +111,7 @@
pad,
permute_dims,
prod,
real,
reshape,
round,
squeeze,
Expand Down Expand Up @@ -201,6 +202,7 @@
"complex64",
"concat",
"concatenate",
"conj",
"cos",
"cosh",
"diagonal",
Expand Down
73 changes: 53 additions & 20 deletions sparse/numba_backend/_common.py
Original file line number Diff line number Diff line change
Expand Up @@ -29,6 +29,17 @@
return False


def _check_device(func):
@wraps(func)
def wrapped(*args, **kwargs):
device = kwargs.get("device", None)
if device not in {"cpu", None}:
raise ValueError("Device must be `'cpu'` or `None`.")
return func(*args, **kwargs)

return wrapped


def _is_sparse(x):
"""
Tests if the supplied argument is a SciPy sparse object, or one from this library.
Expand Down Expand Up @@ -1533,7 +1544,8 @@
concat = concatenate


def eye(N, M=None, k=0, dtype=float, format="coo", **kwargs):
@_check_device
def eye(N, M=None, k=0, dtype=float, format="coo", *, device=None, **kwargs):
"""Return a 2-D array in the specified format with ones on the diagonal and zeros elsewhere.

Parameters
Expand Down Expand Up @@ -1595,7 +1607,8 @@
return COO(coords, data=data, shape=(N, M), has_duplicates=False, sorted=True).asformat(format, **kwargs)


def full(shape, fill_value, dtype=None, format="coo", order="C", **kwargs):
@_check_device
def full(shape, fill_value, dtype=None, format="coo", order="C", *, device=None, **kwargs):
"""Return a SparseArray of given shape and type, filled with `fill_value`.

Parameters
Expand Down Expand Up @@ -1649,7 +1662,8 @@
).asformat(format, **kwargs)


def full_like(a, fill_value, dtype=None, shape=None, format=None, **kwargs):
@_check_device
def full_like(a, fill_value, dtype=None, shape=None, format=None, *, device=None, **kwargs):
"""Return a full array with the same shape and type as a given array.

Parameters
Expand Down Expand Up @@ -1692,7 +1706,7 @@
)


def zeros(shape, dtype=float, format="coo", **kwargs):
def zeros(shape, dtype=float, format="coo", *, device=None, **kwargs):
"""Return a SparseArray of given shape and type, filled with zeros.

Parameters
Expand Down Expand Up @@ -1721,10 +1735,10 @@
array([[0, 0],
[0, 0]])
"""
return full(shape, fill_value=0, dtype=np.dtype(dtype), format=format, **kwargs)
return full(shape, fill_value=0, dtype=np.dtype(dtype), format=format, device=device, **kwargs)


def zeros_like(a, dtype=None, shape=None, format=None, **kwargs):
def zeros_like(a, dtype=None, shape=None, format=None, *, device=None, **kwargs):
"""Return a SparseArray of zeros with the same shape and type as ``a``.

Parameters
Expand All @@ -1750,10 +1764,10 @@
array([[0, 0, 0],
[0, 0, 0]])
"""
return full_like(a, fill_value=0, dtype=dtype, shape=shape, format=format, **kwargs)
return full_like(a, fill_value=0, dtype=dtype, shape=shape, format=format, device=device, **kwargs)


def ones(shape, dtype=float, format="coo", **kwargs):
def ones(shape, dtype=float, format="coo", *, device=None, **kwargs):
"""Return a SparseArray of given shape and type, filled with ones.

Parameters
Expand Down Expand Up @@ -1782,10 +1796,10 @@
array([[1, 1],
[1, 1]])
"""
return full(shape, fill_value=1, dtype=np.dtype(dtype), format=format, **kwargs)
return full(shape, fill_value=1, dtype=np.dtype(dtype), format=format, device=device, **kwargs)


def ones_like(a, dtype=None, shape=None, format=None, **kwargs):
def ones_like(a, dtype=None, shape=None, format=None, *, device=None, **kwargs):
"""Return a SparseArray of ones with the same shape and type as ``a``.

Parameters
Expand All @@ -1811,18 +1825,18 @@
array([[1, 1, 1],
[1, 1, 1]])
"""
return full_like(a, fill_value=1, dtype=dtype, shape=shape, format=format, **kwargs)
return full_like(a, fill_value=1, dtype=dtype, shape=shape, format=format, device=device, **kwargs)


def empty(shape, dtype=float, format="coo", **kwargs):
return full(shape, fill_value=0, dtype=np.dtype(dtype), format=format, **kwargs)
def empty(shape, dtype=float, format="coo", *, device=None, **kwargs):
return full(shape, fill_value=0, dtype=np.dtype(dtype), format=format, device=device, **kwargs)


empty.__doc__ = zeros.__doc__


def empty_like(a, dtype=None, shape=None, format=None, **kwargs):
return full_like(a, fill_value=0, dtype=dtype, shape=shape, format=format, **kwargs)
def empty_like(a, dtype=None, shape=None, format=None, *, device=None, **kwargs):
return full_like(a, fill_value=0, dtype=dtype, shape=shape, format=format, device=device, **kwargs)


empty_like.__doc__ = zeros_like.__doc__
Expand Down Expand Up @@ -2004,7 +2018,8 @@
raise ValueError(f"invalid format: {format}")


def asarray(obj, /, *, dtype=None, format="coo", device=None, copy=False):
@_check_device
def asarray(obj, /, *, dtype=None, format="coo", copy=False, device=None):
"""
Convert the input to a sparse array.

Expand Down Expand Up @@ -2066,6 +2081,7 @@
we want to flag it and dispatch to NumPy.
"""

@wraps(func)
def wrapper_func(*args, **kwargs):
x = args[0]
if isinstance(x, np.ndarray | np.number):
Expand Down Expand Up @@ -2136,6 +2152,19 @@

@_support_numpy
def squeeze(x, /, axis=None):
"""Remove singleton dimensions from array.

Parameters
----------
x : SparseArray
axis : int or tuple[int, ...], optional
The singleton axes to remove. By default all singleton axes are removed.

Returns
-------
SparseArray
Array with singleton dimensions removed.
"""
return x.squeeze(axis=axis)


Expand Down Expand Up @@ -2168,14 +2197,18 @@
return x.isnan()


def isfinite(x, /):
return ~isinf(x)


def nonzero(x, /):
return x.nonzero()


def imag(x, /):
return x.imag

Check warning on line 2205 in sparse/numba_backend/_common.py

View check run for this annotation

Codecov / codecov/patch

sparse/numba_backend/_common.py#L2205

Added line #L2205 was not covered by tests


def real(x, /):
return x.real

Check warning on line 2209 in sparse/numba_backend/_common.py

View check run for this annotation

Codecov / codecov/patch

sparse/numba_backend/_common.py#L2209

Added line #L2209 was not covered by tests


def vecdot(x1, x2, /, *, axis=-1):
"""
Computes the (vector) dot product of two arrays.
Expand Down
10 changes: 10 additions & 0 deletions sparse/numba_backend/_compressed/compressed.py
Original file line number Diff line number Diff line change
Expand Up @@ -324,6 +324,16 @@
def T(self):
return self.transpose()

@property
def mT(self):
if self.ndim < 2:
raise ValueError("Cannot compute matrix transpose if `ndim < 2`.")

Check warning on line 330 in sparse/numba_backend/_compressed/compressed.py

View check run for this annotation

Codecov / codecov/patch

sparse/numba_backend/_compressed/compressed.py#L329-L330

Added lines #L329 - L330 were not covered by tests

axis = list(range(self.ndim))
axis[-1], axis[-2] = axis[-2], axis[-1]

Check warning on line 333 in sparse/numba_backend/_compressed/compressed.py

View check run for this annotation

Codecov / codecov/patch

sparse/numba_backend/_compressed/compressed.py#L332-L333

Added lines #L332 - L333 were not covered by tests

return self.transpose(axis)

Check warning on line 335 in sparse/numba_backend/_compressed/compressed.py

View check run for this annotation

Codecov / codecov/patch

sparse/numba_backend/_compressed/compressed.py#L335

Added line #L335 was not covered by tests

def __str__(self):
summary = (
f"<GCXS: shape={self.shape}, dtype={self.dtype}, nnz={self.nnz}, fill_value={self.fill_value}, "
Expand Down
15 changes: 8 additions & 7 deletions sparse/numba_backend/_coo/common.py
Original file line number Diff line number Diff line change
Expand Up @@ -1243,7 +1243,7 @@ def unique_values(x, /):
return values


def sort(x, /, *, axis=-1, descending=False):
def sort(x, /, *, axis=-1, descending=False, stable=False):
"""
Returns a sorted copy of an input array ``x``.

Expand All @@ -1258,6 +1258,9 @@ def sort(x, /, *, axis=-1, descending=False):
Sort order. If ``True``, the array must be sorted in descending order (by value).
If ``False``, the array must be sorted in ascending order (by value).
Default: ``False``.
stable : bool
Whether the sort is stable. Provided for compatibility with the Array API, only
``False`` (the default) is currently supported.

Returns
-------
Expand All @@ -1279,12 +1282,14 @@ def sort(x, /, *, axis=-1, descending=False):
array([ 2, 2, 1, 0, 0, -3])

"""

from .._common import moveaxis
from .core import COO

x = _validate_coo_input(x)

if stable:
raise ValueError("`stable=True` isn't currently supported.")

original_ndim = x.ndim
if x.ndim == 1:
x = x[None, :]
Expand Down Expand Up @@ -1357,11 +1362,7 @@ def _validate_coo_input(x: Any):

@numba.jit(nopython=True, nogil=True)
def _sort_coo(
coords: np.ndarray,
data: np.ndarray,
fill_value: float,
sort_axis_len: int,
descending: bool,
coords: np.ndarray, data: np.ndarray, fill_value: float, sort_axis_len: int, descending: bool
) -> tuple[np.ndarray, np.ndarray]:
assert coords.shape[0] == 2
group_coords = coords[0, :]
Expand Down
12 changes: 12 additions & 0 deletions sparse/numba_backend/_coo/core.py
Original file line number Diff line number Diff line change
Expand Up @@ -849,6 +849,16 @@
"""
return self.transpose(tuple(range(self.ndim))[::-1])

@property
def mT(self):
if self.ndim < 2:
raise ValueError("Cannot compute matrix transpose if `ndim < 2`.")

Check warning on line 855 in sparse/numba_backend/_coo/core.py

View check run for this annotation

Codecov / codecov/patch

sparse/numba_backend/_coo/core.py#L855

Added line #L855 was not covered by tests

axis = list(range(self.ndim))
axis[-1], axis[-2] = axis[-2], axis[-1]

return self.transpose(axis)

def swapaxes(self, axis1, axis2):
"""Returns array that has axes axis1 and axis2 swapped.

Expand Down Expand Up @@ -1468,6 +1478,8 @@
(array([0, 1, 2, 3, 4]), array([0, 1, 2, 3, 4]))
"""
check_zero_fill_value(self)
if self.ndim == 0:
raise ValueError("`nonzero` is undefined for `self.ndim == 0`.")

Check warning on line 1482 in sparse/numba_backend/_coo/core.py

View check run for this annotation

Codecov / codecov/patch

sparse/numba_backend/_coo/core.py#L1482

Added line #L1482 was not covered by tests
return tuple(self.coords)

def asformat(self, format, **kwargs):
Expand Down
Loading
Loading