Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

DNM MAINT: Upgrade to array-api-compat 1.11 #120

Draft
wants to merge 5 commits into
base: main
Choose a base branch
from
Draft
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension


Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
86 changes: 42 additions & 44 deletions pixi.lock

Large diffs are not rendered by default.

5 changes: 3 additions & 2 deletions pyproject.toml
Original file line number Diff line number Diff line change
Expand Up @@ -26,7 +26,7 @@ classifiers = [
"Typing :: Typed",
]
dynamic = ["version"]
dependencies = ["array-api-compat>=1.10.0,<2"]
# dependencies = ["array-api-compat>=1.11.0,<2"] # DNM

[project.urls]
Homepage = "https://github.com/data-apis/array-api-extra"
Expand All @@ -48,10 +48,11 @@ platforms = ["linux-64", "osx-arm64", "win-64"]

[tool.pixi.dependencies]
python = ">=3.10,<3.14"
array-api-compat = ">=1.10.0,<2"
# array-api-compat = ">=1.11.0,<2" # DNM

[tool.pixi.pypi-dependencies]
array-api-extra = { path = ".", editable = true }
array-api-compat = { git = "https://github.com/data-apis/array-api-compat" } # DNM

[tool.pixi.feature.lint.dependencies]
typing-extensions = "*"
Expand Down
2 changes: 1 addition & 1 deletion src/array_api_extra/_lib/_funcs.py
Original file line number Diff line number Diff line change
Expand Up @@ -552,7 +552,7 @@ def nunique(x: Array, /, *, xp: ModuleType | None = None) -> Array:
_, counts = xp.unique_counts(x)
n = _compat.size(counts)
# FIXME https://github.com/data-apis/array-api-compat/pull/231
if n is None or math.isnan(n): # e.g. Dask, ndonnx
if n is None: # e.g. Dask, ndonnx
return xp.astype(counts, xp.bool).sum()
return xp.asarray(n, device=_compat.device(x))

Expand Down
3 changes: 3 additions & 0 deletions src/array_api_extra/_lib/_utils/_compat.py
Original file line number Diff line number Diff line change
Expand Up @@ -14,6 +14,7 @@
is_dask_namespace,
is_jax_array,
is_jax_namespace,
is_lazy_array,
is_numpy_array,
is_numpy_namespace,
is_pydata_sparse_array,
Expand All @@ -35,6 +36,7 @@
is_dask_namespace,
is_jax_array,
is_jax_namespace,
is_lazy_array,
is_numpy_array,
is_numpy_namespace,
is_pydata_sparse_array,
Expand All @@ -56,6 +58,7 @@
"is_dask_namespace",
"is_jax_array",
"is_jax_namespace",
"is_lazy_array",
"is_numpy_array",
"is_numpy_namespace",
"is_pydata_sparse_array",
Expand Down
1 change: 1 addition & 0 deletions src/array_api_extra/_lib/_utils/_compat.pyi
Original file line number Diff line number Diff line change
Expand Up @@ -32,5 +32,6 @@ def is_jax_array(x: object, /) -> bool: ...
def is_numpy_array(x: object, /) -> bool: ...
def is_pydata_sparse_array(x: object, /) -> bool: ...
def is_torch_array(x: object, /) -> bool: ...
def is_lazy_array(x: object, /) -> bool: ...
def is_writeable_array(x: object, /) -> bool: ...
def size(x: Array, /) -> int | None: ...
5 changes: 1 addition & 4 deletions tests/test_at.py
Original file line number Diff line number Diff line change
Expand Up @@ -11,7 +11,7 @@
from array_api_extra import at
from array_api_extra._lib import Backend
from array_api_extra._lib._at import _AtOp
from array_api_extra._lib._testing import xfail, xp_assert_equal
from array_api_extra._lib._testing import xp_assert_equal
from array_api_extra._lib._utils._compat import array_namespace, is_writeable_array
from array_api_extra._lib._utils._typing import Array, Index
from array_api_extra.testing import lazy_xp_function
Expand Down Expand Up @@ -181,7 +181,6 @@ def test_alternate_index_syntax():
def test_incompatible_dtype(
xp: ModuleType,
library: Backend,
request: pytest.FixtureRequest,
op: _AtOp,
copy: bool | None,
bool_mask: bool,
Expand Down Expand Up @@ -215,8 +214,6 @@ def test_incompatible_dtype(
z = at_op(x, idx, op, 1.1, copy=copy)

elif library is Backend.DASK:
if op in (_AtOp.MIN, _AtOp.MAX) and bool_mask:
xfail(request, reason="need array-api-compat 1.11")
z = at_op(x, idx, op, 1.1, copy=copy)

elif library is Backend.ARRAY_API_STRICT and op is not _AtOp.SET:
Expand Down
6 changes: 2 additions & 4 deletions tests/test_funcs.py
Original file line number Diff line number Diff line change
Expand Up @@ -31,16 +31,14 @@

lazy_xp_function(atleast_nd, static_argnames=("ndim", "xp"))
lazy_xp_function(cov, static_argnames="xp")
# FIXME .device attribute https://github.com/data-apis/array-api-compat/pull/238
lazy_xp_function(create_diagonal, jax_jit=False, static_argnames=("offset", "xp"))
lazy_xp_function(create_diagonal, static_argnames=("offset", "xp"))
lazy_xp_function(expand_dims, static_argnames=("axis", "xp"))
lazy_xp_function(kron, static_argnames="xp")
lazy_xp_function(nunique, static_argnames="xp")
lazy_xp_function(pad, static_argnames=("pad_width", "mode", "constant_values", "xp"))
# FIXME calls in1d which calls xp.unique_values without size
lazy_xp_function(setdiff1d, jax_jit=False, static_argnames=("assume_unique", "xp"))
# FIXME .device attribute https://github.com/data-apis/array-api-compat/pull/238
lazy_xp_function(sinc, jax_jit=False, static_argnames="xp")
lazy_xp_function(sinc, static_argnames="xp")


@pytest.mark.xfail_xp_backend(Backend.SPARSE, reason="no expand_dims")
Expand Down
2 changes: 2 additions & 0 deletions vendor_tests/test_vendor.py
Original file line number Diff line number Diff line change
Expand Up @@ -14,6 +14,7 @@ def test_vendor_compat():
is_dask_namespace,
is_jax_array,
is_jax_namespace,
is_lazy_array,
is_numpy_array,
is_numpy_namespace,
is_pydata_sparse_array,
Expand All @@ -35,6 +36,7 @@ def test_vendor_compat():
assert not is_dask_namespace(xp)
assert not is_jax_array(x)
assert not is_jax_namespace(xp)
assert not is_lazy_array(x)
assert not is_numpy_array(x)
assert not is_numpy_namespace(xp)
assert not is_pydata_sparse_array(x)
Expand Down