From d6a53f91583c45a86f425a5f208651b5fa57781f Mon Sep 17 00:00:00 2001 From: crusaderky Date: Mon, 21 Apr 2025 11:18:24 +0100 Subject: [PATCH 01/12] Overhaul test_all --- array_api_compat/common/_aliases.py | 2 - array_api_compat/common/_helpers.py | 2 - array_api_compat/common/_linalg.py | 2 - array_api_compat/cupy/_aliases.py | 2 - array_api_compat/cupy/_typing.py | 1 - array_api_compat/dask/array/_aliases.py | 2 - array_api_compat/dask/array/fft.py | 4 +- array_api_compat/dask/array/linalg.py | 3 +- array_api_compat/numpy/_aliases.py | 4 +- array_api_compat/numpy/_typing.py | 1 - array_api_compat/numpy/fft.py | 11 +- array_api_compat/numpy/linalg.py | 7 +- array_api_compat/torch/_aliases.py | 2 - array_api_compat/torch/fft.py | 2 - array_api_compat/torch/linalg.py | 4 - tests/test_all.py | 308 +++++++++++++++++++----- 16 files changed, 262 insertions(+), 95 deletions(-) diff --git a/array_api_compat/common/_aliases.py b/array_api_compat/common/_aliases.py index 8ea9162a..44ef6834 100644 --- a/array_api_compat/common/_aliases.py +++ b/array_api_compat/common/_aliases.py @@ -720,8 +720,6 @@ def iinfo(type_: DType | Array, /, xp: Namespace) -> Any: "finfo", "iinfo", ] -_all_ignore = ["inspect", "array_namespace", "NamedTuple"] - def __dir__() -> list[str]: return __all__ diff --git a/array_api_compat/common/_helpers.py b/array_api_compat/common/_helpers.py index db3e4cd7..fdeeaf7b 100644 --- a/array_api_compat/common/_helpers.py +++ b/array_api_compat/common/_helpers.py @@ -1042,7 +1042,5 @@ def is_lazy_array(x: object) -> bool: "to_device", ] -_all_ignore = ["sys", "math", "inspect", "warnings"] - def __dir__() -> list[str]: return __all__ diff --git a/array_api_compat/common/_linalg.py b/array_api_compat/common/_linalg.py index 7ad87a1b..6fea96f0 100644 --- a/array_api_compat/common/_linalg.py +++ b/array_api_compat/common/_linalg.py @@ -225,8 +225,6 @@ def trace( 'matrix_transpose', 'svdvals', 'vecdot', 'vector_norm', 'diagonal', 'trace'] -_all_ignore = ['math', 'normalize_axis_tuple', 'get_xp', 'np', 'isdtype'] - def __dir__() -> list[str]: return __all__ diff --git a/array_api_compat/cupy/_aliases.py b/array_api_compat/cupy/_aliases.py index fd1460ae..01668939 100644 --- a/array_api_compat/cupy/_aliases.py +++ b/array_api_compat/cupy/_aliases.py @@ -160,5 +160,3 @@ def count_nonzero( 'atan2', 'atanh', 'bitwise_left_shift', 'bitwise_invert', 'bitwise_right_shift', 'bool', 'concat', 'count_nonzero', 'pow', 'sign'] - -_all_ignore = ['cp', 'get_xp'] diff --git a/array_api_compat/cupy/_typing.py b/array_api_compat/cupy/_typing.py index d8e49ca7..e5c202dc 100644 --- a/array_api_compat/cupy/_typing.py +++ b/array_api_compat/cupy/_typing.py @@ -1,7 +1,6 @@ from __future__ import annotations __all__ = ["Array", "DType", "Device"] -_all_ignore = ["cp"] from typing import TYPE_CHECKING diff --git a/array_api_compat/dask/array/_aliases.py b/array_api_compat/dask/array/_aliases.py index 9687a9cd..f5221ddd 100644 --- a/array_api_compat/dask/array/_aliases.py +++ b/array_api_compat/dask/array/_aliases.py @@ -369,8 +369,6 @@ def count_nonzero( "bitwise_left_shift", "bitwise_right_shift", "bitwise_invert", ] # fmt: skip __all__ += _aliases.__all__ -_all_ignore = ["array_namespace", "get_xp", "da", "np"] - def __dir__() -> list[str]: return __all__ diff --git a/array_api_compat/dask/array/fft.py b/array_api_compat/dask/array/fft.py index 3f40dffe..6b0cced9 100644 --- a/array_api_compat/dask/array/fft.py +++ b/array_api_compat/dask/array/fft.py @@ -18,4 +18,6 @@ rfftfreq = get_xp(da)(_fft.rfftfreq) __all__ = fft_all + ["fftfreq", "rfftfreq"] -_all_ignore = ["da", "fft_all", "get_xp", "warnings"] + +def __dir__() -> list[str]: + return __all__ diff --git a/array_api_compat/dask/array/linalg.py b/array_api_compat/dask/array/linalg.py index 0825386e..d15d6882 100644 --- a/array_api_compat/dask/array/linalg.py +++ b/array_api_compat/dask/array/linalg.py @@ -69,4 +69,5 @@ def svdvals(x: _Array) -> _Array: "cholesky", "matrix_rank", "matrix_norm", "svdvals", "vector_norm", "diagonal"] -_all_ignore = ['get_xp', 'da', 'linalg_all', 'warnings'] +def __dir__() -> list[str]: + return __all__ diff --git a/array_api_compat/numpy/_aliases.py b/array_api_compat/numpy/_aliases.py index d8792611..d74e465e 100644 --- a/array_api_compat/numpy/_aliases.py +++ b/array_api_compat/numpy/_aliases.py @@ -157,7 +157,7 @@ def count_nonzero( else: unstack = get_xp(np)(_aliases.unstack) -__all__ = [ +__all__ = _aliases.__all__ + [ "__array_namespace_info__", "asarray", "astype", @@ -176,8 +176,6 @@ def count_nonzero( "count_nonzero", "pow", ] -__all__ += _aliases.__all__ -_all_ignore = ["np", "get_xp"] def __dir__() -> list[str]: diff --git a/array_api_compat/numpy/_typing.py b/array_api_compat/numpy/_typing.py index e771c788..b5fa188c 100644 --- a/array_api_compat/numpy/_typing.py +++ b/array_api_compat/numpy/_typing.py @@ -23,7 +23,6 @@ Array: TypeAlias = np.ndarray __all__ = ["Array", "DType", "Device"] -_all_ignore = ["np"] def __dir__() -> list[str]: diff --git a/array_api_compat/numpy/fft.py b/array_api_compat/numpy/fft.py index 06875f00..a3470fb5 100644 --- a/array_api_compat/numpy/fft.py +++ b/array_api_compat/numpy/fft.py @@ -1,5 +1,4 @@ import numpy as np -from numpy.fft import __all__ as fft_all from numpy.fft import fft2, ifft2, irfft2, rfft2 from .._internal import get_xp @@ -21,15 +20,7 @@ ifftshift = get_xp(np)(_fft.ifftshift) -__all__ = ["rfft2", "irfft2", "fft2", "ifft2"] -__all__ += _fft.__all__ - +__all__ = _fft.__all__ + ["rfft2", "irfft2", "fft2", "ifft2"] def __dir__() -> list[str]: return __all__ - - -del get_xp -del np -del fft_all -del _fft diff --git a/array_api_compat/numpy/linalg.py b/array_api_compat/numpy/linalg.py index 2d3e731d..e99f0e39 100644 --- a/array_api_compat/numpy/linalg.py +++ b/array_api_compat/numpy/linalg.py @@ -120,7 +120,7 @@ def solve(x1: Array, x2: Array, /) -> Array: vector_norm = get_xp(np)(_linalg.vector_norm) -__all__ = [ +__all__ = _linalg.__all__ + [ "LinAlgError", "cond", "det", @@ -132,12 +132,11 @@ def solve(x1: Array, x2: Array, /) -> Array: "matrix_power", "multi_dot", "norm", + "solve", "tensorinv", "tensorsolve", + "vector_norm", ] -__all__ += _linalg.__all__ -__all__ += ["solve", "vector_norm"] - def __dir__() -> list[str]: return __all__ diff --git a/array_api_compat/torch/_aliases.py b/array_api_compat/torch/_aliases.py index 027a0261..bb1d3f14 100644 --- a/array_api_compat/torch/_aliases.py +++ b/array_api_compat/torch/_aliases.py @@ -841,5 +841,3 @@ def sign(x: Array, /) -> Array: 'unique_all', 'unique_counts', 'unique_inverse', 'unique_values', 'matmul', 'matrix_transpose', 'vecdot', 'tensordot', 'isdtype', 'take', 'take_along_axis', 'sign', 'finfo', 'iinfo', 'repeat'] - -_all_ignore = ['torch', 'get_xp'] diff --git a/array_api_compat/torch/fft.py b/array_api_compat/torch/fft.py index 50e6a0d0..b4ba7358 100644 --- a/array_api_compat/torch/fft.py +++ b/array_api_compat/torch/fft.py @@ -81,5 +81,3 @@ def ifftshift( "fftshift", "ifftshift", ] - -_all_ignore = ['torch'] diff --git a/array_api_compat/torch/linalg.py b/array_api_compat/torch/linalg.py index 70d72405..7d454511 100644 --- a/array_api_compat/torch/linalg.py +++ b/array_api_compat/torch/linalg.py @@ -113,9 +113,5 @@ def vector_norm( __all__ = linalg_all + ['outer', 'matmul', 'matrix_transpose', 'tensordot', 'cross', 'vecdot', 'solve', 'trace', 'vector_norm'] -_all_ignore = ['torch_linalg', 'sum'] - -del linalg_all - def __dir__() -> list[str]: return __all__ diff --git a/tests/test_all.py b/tests/test_all.py index 271cd189..ec262099 100644 --- a/tests/test_all.py +++ b/tests/test_all.py @@ -1,63 +1,259 @@ -""" -Test that files that define __all__ aren't missing any exports. +"""Test exported names""" -You can add names that shouldn't be exported to _all_ignore, like +import builtins -_all_ignore = ['sys'] +import numpy as np +import pytest -This is preferable to del-ing the names as this will break any name that is -used inside of a function. Note that names starting with an underscore are automatically ignored. -""" +from ._helpers import wrapped_libraries +NAMES = { + "": [ + # Inspection + "__array_api_version__", + "__array_namespace_info__", + # Constants + "e", + "inf", + "nan", + "newaxis", + "pi", + # Creation functions + "arange", + "asarray", + "empty", + "empty_like", + "eye", + "from_dlpack", + "full", + "full_like", + "linspace", + "meshgrid", + "ones", + "ones_like", + "tril", + "triu", + "zeros", + "zeros_like", + # Data Type Functions + "astype", + "can_cast", + "finfo", + "iinfo", + "isdtype", + "result_type", + # Data Types + "bool", + "int8", + "int16", + "int32", + "int64", + "uint8", + "uint16", + "uint32", + "uint64", + "float32", + "float64", + "complex64", + "complex128", + # Elementwise Functions + "abs", + "acos", + "acosh", + "add", + "asin", + "asinh", + "atan", + "atan2", + "atanh", + "bitwise_and", + "bitwise_left_shift", + "bitwise_invert", + "bitwise_or", + "bitwise_right_shift", + "bitwise_xor", + "ceil", + "clip", + "conj", + "copysign", + "cos", + "cosh", + "divide", + "equal", + "exp", + "expm1", + "floor", + "floor_divide", + "greater", + "greater_equal", + "hypot", + "imag", + "isfinite", + "isinf", + "isnan", + "less", + "less_equal", + "log", + "log1p", + "log2", + "log10", + "logaddexp", + "logical_and", + "logical_not", + "logical_or", + "logical_xor", + "maximum", + "minimum", + "multiply", + "negative", + "nextafter", + "not_equal", + "positive", + "pow", + "real", + "reciprocal", + "remainder", + "round", + "sign", + "signbit", + "sin", + "sinh", + "square", + "sqrt", + "subtract", + "tan", + "tanh", + "trunc", + # Indexing Functions + "take", + "take_along_axis", + # Linear Algebra Functions + "matmul", + "matrix_transpose", + "tensordot", + "vecdot", + # Manipulation Functions + "broadcast_arrays", + "broadcast_to", + "concat", + "expand_dims", + "flip", + "moveaxis", + "permute_dims", + "repeat", + "reshape", + "roll", + "squeeze", + "stack", + "tile", + "unstack", + # Searching Functions + "argmax", + "argmin", + "count_nonzero", + "nonzero", + "searchsorted", + "where", + # Set functions + "unique_all", + "unique_counts", + "unique_inverse", + "unique_values", + # Sorting Functions + "argsort", + "sort", + # Statistical Functions + "cumulative_prod", + "cumulative_sum", + "max", + "mean", + "min", + "prod", + "std", + "sum", + "var", + # Utility Functions + "all", + "any", + "diff", + ], + "fft": [ + "fft", + "ifft", + "fftn", + "ifftn", + "rfft", + "irfft", + "rfftn", + "irfftn", + "hfft", + "ihfft", + "fftfreq", + "rfftfreq", + "fftshift", + "ifftshift", + ], + "linalg": [ + "cholesky", + "cross", + "det", + "diagonal", + "eigh", + "eigvalsh", + "inv", + "matmul", + "matrix_norm", + "matrix_power", + "matrix_rank", + "matrix_transpose", + "outer", + "pinv", + "qr", + "slogdet", + "solve", + "svd", + "svdvals", + "tensordot", + "trace", + "vecdot", + "vector_norm", + ], +} -import sys +XFAILS = { + ("numpy", ""): ["from_dlpack"] if np.__version__ < "1.23" else [], + ("dask.array", ""): ["from_dlpack", "take_along_axis"], + ("dask.array", "linalg"): [ + "cross", + "det", + "eigh", + "eigvalsh", + "matrix_power", + "pinv", + "slogdet", + ], +} -from ._helpers import import_, wrapped_libraries -import pytest -import typing - -TYPING_NAMES = frozenset(( - "Array", - "Device", - "DType", - "Namespace", - "NestedSequence", - "SupportsBufferProtocol", -)) - -@pytest.mark.parametrize("library", ["common"] + wrapped_libraries) -def test_all(library): - if library == "common": - import array_api_compat.common # noqa: F401 - else: - import_(library, wrapper=True) - - # NB: iterate over a copy to avoid a "dictionary size changed" error - for mod_name in sys.modules.copy(): - if not mod_name.startswith('array_api_compat.' + library): - continue - - module = sys.modules[mod_name] - - # TODO: We should define __all__ in the __init__.py files and test it - # there too. - if not hasattr(module, '__all__'): - continue - - dir_names = [n for n in dir(module) if not n.startswith('_')] - if '__array_namespace_info__' in dir(module): - dir_names.append('__array_namespace_info__') - ignore_all_names = set(getattr(module, '_all_ignore', ())) - ignore_all_names |= set(dir(typing)) - ignore_all_names |= {"annotations"} - if not module.__name__.endswith("._typing"): - ignore_all_names |= TYPING_NAMES - dir_names = set(dir_names) - set(ignore_all_names) - all_names = module.__all__ - - if set(dir_names) != set(all_names): - extra_dir = set(dir_names) - set(all_names) - extra_all = set(all_names) - set(dir_names) - assert not extra_dir, f"Some dir() names not included in __all__ for {mod_name}: {extra_dir}" - assert not extra_all, f"Some __all__ names not in dir() for {mod_name}: {extra_all}" +@pytest.mark.parametrize("module", list(NAMES)) +@pytest.mark.parametrize("library", wrapped_libraries) +def test_dir(library, module): + """Test that dir() isn't missing any exports.""" + xp = pytest.importorskip(f"array_api_compat.{library}") + mod = getattr(xp, module) if module else xp + missing = set(NAMES[module]) - set(dir(mod)) + xfail = set(XFAILS.get((library, module), [])) + xpass = xfail - missing + fails = missing - xfail + assert not xpass, "Names in XFAILS are defined: %s" % xpass + assert not fails, "Missing exports: %s" % fails + + +@pytest.mark.parametrize( + "name", [name for name in NAMES[""] if hasattr(builtins, name)] +) +@pytest.mark.parametrize("library", wrapped_libraries) +def test_builtins_collision(library, name): + """Test that xp.bool is not accidentally builtins.bool, etc.""" + xp = pytest.importorskip(f"array_api_compat.{library}") + assert getattr(xp, name) is not getattr(builtins, name) From 2f01c20677a06126dc1ab5d049545cfced4dc58f Mon Sep 17 00:00:00 2001 From: crusaderky Date: Mon, 21 Apr 2025 12:23:32 +0100 Subject: [PATCH 02/12] cp --- array_api_compat/cupy/_aliases.py | 3 +++ 1 file changed, 3 insertions(+) diff --git a/array_api_compat/cupy/_aliases.py b/array_api_compat/cupy/_aliases.py index 01668939..b06bb9ae 100644 --- a/array_api_compat/cupy/_aliases.py +++ b/array_api_compat/cupy/_aliases.py @@ -160,3 +160,6 @@ def count_nonzero( 'atan2', 'atanh', 'bitwise_left_shift', 'bitwise_invert', 'bitwise_right_shift', 'bool', 'concat', 'count_nonzero', 'pow', 'sign'] + +def __dir__() -> list[str]: + return __all__ From 978053a022d5012f06b7975c0b8902d35b7bcd48 Mon Sep 17 00:00:00 2001 From: crusaderky Date: Mon, 21 Apr 2025 12:24:26 +0100 Subject: [PATCH 03/12] WIP test vs extra names --- tests/test_all.py | 53 +++++++++++++++++++++++++++++++++++++++++++++++ 1 file changed, 53 insertions(+) diff --git a/tests/test_all.py b/tests/test_all.py index ec262099..886661f5 100644 --- a/tests/test_all.py +++ b/tests/test_all.py @@ -249,6 +249,59 @@ def test_dir(library, module): assert not fails, "Missing exports: %s" % fails +@pytest.mark.parametrize("module", list(NAMES)) +@pytest.mark.parametrize("library", wrapped_libraries) +def test_compat_doesnt_hide_names(library, module): + """The base namespace can have more names than the ones explicitly exported + by array-api-compat. Test that we're not suppressing them. + """ + bare_xp = pytest.importorskip(library) + compat_xp = pytest.importorskip(f"array_api_compat.{library}") + bare_mod = getattr(bare_xp, module) if module else bare_xp + compat_mod = getattr(compat_xp, module) if module else compat_xp + aapi_names = set(NAMES[module]) + extra_names = { + name + for name in dir(bare_mod) + if not name.startswith("_") and name not in aapi_names + } + missing = extra_names - set(dir(compat_mod)) + + # These are spurious to begin with in the bare libraries + missing -= {"annotations", "importlib", "warnings", "operator", "sys", "Sequence"} + if module != "": + missing -= {"Array", "test"} + + assert not missing, "Non-Array API names have been hidden: %s" % missing + + +@pytest.mark.parametrize("module", list(NAMES)) +@pytest.mark.parametrize("library", wrapped_libraries) +def test_compat_spurious_names(library, module): + """Test that array-api-compat isn't adding non-Array API names + to the namespace. + """ + bare_xp = pytest.importorskip(library) + compat_xp = pytest.importorskip(f"array_api_compat.{library}") + bare_mod = getattr(bare_xp, module) if module else bare_xp + compat_mod = getattr(compat_xp, module) if module else compat_xp + aapi_names = set(NAMES[module]) + compat_spurious_names = ( + set(dir(compat_mod)) + - set(dir(bare_mod)) + - aapi_names + - {"__all__"} + ) + # Quietly ignore *Result dataclasses + compat_spurious_names = { + name for name in compat_spurious_names if not name.endswith("Result") + } + + assert not compat_spurious_names, ( + "array-api-compat is adding non-Array API names: %s" % compat_spurious_names + ) + + @pytest.mark.parametrize( "name", [name for name in NAMES[""] if hasattr(builtins, name)] ) From 3a7a1f7c2be4ba6883769d5ea4f8dbc08f0da630 Mon Sep 17 00:00:00 2001 From: crusaderky Date: Mon, 21 Apr 2025 12:39:57 +0100 Subject: [PATCH 04/12] Extra test --- tests/test_all.py | 25 +++++++++++++++++++++++++ 1 file changed, 25 insertions(+) diff --git a/tests/test_all.py b/tests/test_all.py index ec262099..721007e0 100644 --- a/tests/test_all.py +++ b/tests/test_all.py @@ -249,6 +249,31 @@ def test_dir(library, module): assert not fails, "Missing exports: %s" % fails +@pytest.mark.parametrize("module", list(NAMES)) +@pytest.mark.parametrize("library", wrapped_libraries) +def test_all(library, module): + """Test that __all__ isn't missing any exports.""" + modname = f"array_api_compat.{library}" + pytest.importorskip(modname) + if module: + modname += f".{module}" + + objs = {} + exec(f"from {modname} import *", objs) + + missing = set(NAMES[module]) - objs.keys() + xfail = set(XFAILS.get((library, module), [])) + + # FIXME are these canonically meant to be in __all__? + if module == "": + xfail |= {"__array_namespace_info__", "__array_api_version__"} + + xpass = xfail - missing + fails = missing - xfail + assert not xpass, "Names in XFAILS are defined: %s" % xpass + assert not fails, "Missing exports: %s" % fails + + @pytest.mark.parametrize( "name", [name for name in NAMES[""] if hasattr(builtins, name)] ) From c08918b6e11681cc24acb578d1ab834be3f4d179 Mon Sep 17 00:00:00 2001 From: crusaderky Date: Wed, 23 Apr 2025 15:34:18 +0100 Subject: [PATCH 05/12] lint --- array_api_compat/numpy/fft.py | 2 +- tests/test_all.py | 5 +---- 2 files changed, 2 insertions(+), 5 deletions(-) diff --git a/array_api_compat/numpy/fft.py b/array_api_compat/numpy/fft.py index a3470fb5..ee1f0d5d 100644 --- a/array_api_compat/numpy/fft.py +++ b/array_api_compat/numpy/fft.py @@ -20,7 +20,7 @@ ifftshift = get_xp(np)(_fft.ifftshift) -__all__ = _fft.__all__ + ["rfft2", "irfft2", "fft2", "ifft2"] +__all__ = _fft.__all__ + ["fft2", "ifft2", "irfft2", "rfft2"] def __dir__() -> list[str]: return __all__ diff --git a/tests/test_all.py b/tests/test_all.py index 0eb9752c..2b5f272e 100644 --- a/tests/test_all.py +++ b/tests/test_all.py @@ -312,10 +312,7 @@ def test_compat_spurious_names(library, module): compat_mod = getattr(compat_xp, module) if module else compat_xp aapi_names = set(NAMES[module]) compat_spurious_names = ( - set(dir(compat_mod)) - - set(dir(bare_mod)) - - aapi_names - - {"__all__"} + set(dir(compat_mod)) - set(dir(bare_mod)) - aapi_names - {"__all__"} ) # Quietly ignore *Result dataclasses compat_spurious_names = { From 20740afe3577a315e9937b753195993937c101c1 Mon Sep 17 00:00:00 2001 From: crusaderky Date: Wed, 23 Apr 2025 16:49:41 +0100 Subject: [PATCH 06/12] Rework test and imports --- array_api_compat/cupy/__init__.py | 13 +++- array_api_compat/cupy/_aliases.py | 3 +- array_api_compat/dask/array/__init__.py | 15 +++++ array_api_compat/dask/array/_aliases.py | 2 - array_api_compat/dask/array/fft.py | 9 +-- array_api_compat/dask/array/linalg.py | 4 +- array_api_compat/numpy/__init__.py | 19 ++++-- array_api_compat/numpy/_aliases.py | 2 - array_api_compat/numpy/fft.py | 6 +- array_api_compat/numpy/linalg.py | 22 ++----- array_api_compat/torch/__init__.py | 17 ++++- array_api_compat/torch/_aliases.py | 3 +- tests/test_all.py | 88 +++++++++---------------- 13 files changed, 100 insertions(+), 103 deletions(-) diff --git a/array_api_compat/cupy/__init__.py b/array_api_compat/cupy/__init__.py index 9a30f95d..af003c5a 100644 --- a/array_api_compat/cupy/__init__.py +++ b/array_api_compat/cupy/__init__.py @@ -1,3 +1,4 @@ +from typing import Final from cupy import * # noqa: F403 # from cupy import * doesn't overwrite these builtin names @@ -5,9 +6,19 @@ # These imports may overwrite names from the import * above. from ._aliases import * # noqa: F403 +from ._info import __array_namespace_info__ # noqa: F401 # See the comment in the numpy __init__.py __import__(__package__ + '.linalg') __import__(__package__ + '.fft') -__array_api_version__ = '2024.12' +__array_api_version__: Final = '2024.12' + +__all__ = sorted( + {name for name in globals() if not name.startswith("__")} + - {"Final", "_aliases", "_info", "_typing"} + | {"__array_api_version__", "__array_namespace_info__", "linalg", "fft"} +) + +def __dir__() -> list[str]: + return __all__ diff --git a/array_api_compat/cupy/_aliases.py b/array_api_compat/cupy/_aliases.py index b06bb9ae..b556e98b 100644 --- a/array_api_compat/cupy/_aliases.py +++ b/array_api_compat/cupy/_aliases.py @@ -7,7 +7,6 @@ from ..common import _aliases, _helpers from ..common._typing import NestedSequence, SupportsBufferProtocol from .._internal import get_xp -from ._info import __array_namespace_info__ from ._typing import Array, Device, DType bool = cp.bool_ @@ -155,7 +154,7 @@ def count_nonzero( else: unstack = get_xp(cp)(_aliases.unstack) -__all__ = _aliases.__all__ + ['__array_namespace_info__', 'asarray', 'astype', +__all__ = _aliases.__all__ + ['asarray', 'astype', 'acos', 'acosh', 'asin', 'asinh', 'atan', 'atan2', 'atanh', 'bitwise_left_shift', 'bitwise_invert', 'bitwise_right_shift', diff --git a/array_api_compat/dask/array/__init__.py b/array_api_compat/dask/array/__init__.py index 1e47b960..dc0e1404 100644 --- a/array_api_compat/dask/array/__init__.py +++ b/array_api_compat/dask/array/__init__.py @@ -1,12 +1,27 @@ from typing import Final +import dask.array as da from dask.array import * # noqa: F403 # These imports may overwrite names from the import * above. +from . import _aliases from ._aliases import * # noqa: F403 +from ._info import __array_namespace_info__ # noqa: F401 __array_api_version__: Final = "2024.12" # See the comment in the numpy __init__.py __import__(__package__ + '.linalg') __import__(__package__ + '.fft') + +def _make_all(base): + return sorted( + set(base) + | set(_aliases.__all__) + | {"__array_api_version__", "__array_namespace_info__", "linalg", "fft"} + ) + +__all__ = _make_all(da.__all__) + +def __dir__() -> list[str]: + return _make_all(dir(da)) diff --git a/array_api_compat/dask/array/_aliases.py b/array_api_compat/dask/array/_aliases.py index f5221ddd..ce15da3b 100644 --- a/array_api_compat/dask/array/_aliases.py +++ b/array_api_compat/dask/array/_aliases.py @@ -41,7 +41,6 @@ NestedSequence, SupportsBufferProtocol, ) -from ._info import __array_namespace_info__ isdtype = get_xp(np)(_aliases.isdtype) unstack = get_xp(da)(_aliases.unstack) @@ -355,7 +354,6 @@ def count_nonzero( __all__ = [ - "__array_namespace_info__", "count_nonzero", "bool", "int8", "int16", "int32", "int64", diff --git a/array_api_compat/dask/array/fft.py b/array_api_compat/dask/array/fft.py index 6b0cced9..35951863 100644 --- a/array_api_compat/dask/array/fft.py +++ b/array_api_compat/dask/array/fft.py @@ -1,13 +1,10 @@ from dask.array.fft import * # noqa: F403 # dask.array.fft doesn't have __all__. If it is added, replace this with -# -# from dask.array.fft import __all__ as linalg_all -_n = {} +# from dask.array.fft import __all__ as fft_all +_n: dict[str, object] = {} exec('from dask.array.fft import *', _n) -for k in ("__builtins__", "Sequence", "annotations", "warnings"): - _n.pop(k, None) fft_all = list(_n) -del _n, k +del _n from ...common import _fft from ..._internal import get_xp diff --git a/array_api_compat/dask/array/linalg.py b/array_api_compat/dask/array/linalg.py index d15d6882..6b8185c7 100644 --- a/array_api_compat/dask/array/linalg.py +++ b/array_api_compat/dask/array/linalg.py @@ -20,10 +20,8 @@ # from dask.array.linalg import __all__ as linalg_all _n = {} exec('from dask.array.linalg import *', _n) -for k in ('__builtins__', 'annotations', 'operator', 'warnings', 'Array'): - _n.pop(k, None) linalg_all = list(_n) -del _n, k +del _n EighResult = _linalg.EighResult QRResult = _linalg.QRResult diff --git a/array_api_compat/numpy/__init__.py b/array_api_compat/numpy/__init__.py index f7b558ba..87624764 100644 --- a/array_api_compat/numpy/__init__.py +++ b/array_api_compat/numpy/__init__.py @@ -1,6 +1,7 @@ # ruff: noqa: PLC0414 from typing import Final +import numpy as np from numpy import * # noqa: F403 # pyright: ignore[reportWildcardImportFromLibrary] # from numpy import * doesn't overwrite these builtin names @@ -10,7 +11,9 @@ from numpy import round as round # These imports may overwrite names from the import * above. +from . import _aliases from ._aliases import * # noqa: F403 +from ._info import __array_namespace_info__ # noqa: F401 # Don't know why, but we have to do an absolute import to import linalg. If we # instead do @@ -23,13 +26,15 @@ __import__(__package__ + ".fft") -from ..common._helpers import * # noqa: F403 from .linalg import matrix_transpose, vecdot # noqa: F401 -try: - # Used in asarray(). Not present in older versions. - from numpy import _CopyMode # noqa: F401 -except ImportError: - pass - __array_api_version__: Final = "2024.12" + +__all__ = sorted( + set(np.__all__) + | set(_aliases.__all__) + | {"__array_api_version__", "__array_namespace_info__", "linalg", "fft"} +) + +def __dir__() -> list[str]: + return __all__ diff --git a/array_api_compat/numpy/_aliases.py b/array_api_compat/numpy/_aliases.py index d74e465e..292bd510 100644 --- a/array_api_compat/numpy/_aliases.py +++ b/array_api_compat/numpy/_aliases.py @@ -9,7 +9,6 @@ from .._internal import get_xp from ..common import _aliases, _helpers from ..common._typing import NestedSequence, SupportsBufferProtocol -from ._info import __array_namespace_info__ from ._typing import Array, Device, DType if TYPE_CHECKING: @@ -158,7 +157,6 @@ def count_nonzero( unstack = get_xp(np)(_aliases.unstack) __all__ = _aliases.__all__ + [ - "__array_namespace_info__", "asarray", "astype", "acos", diff --git a/array_api_compat/numpy/fft.py b/array_api_compat/numpy/fft.py index ee1f0d5d..3a80808e 100644 --- a/array_api_compat/numpy/fft.py +++ b/array_api_compat/numpy/fft.py @@ -1,5 +1,5 @@ import numpy as np -from numpy.fft import fft2, ifft2, irfft2, rfft2 +from numpy.fft import * # noqa: F403 from .._internal import get_xp from ..common import _fft @@ -20,7 +20,7 @@ ifftshift = get_xp(np)(_fft.ifftshift) -__all__ = _fft.__all__ + ["fft2", "ifft2", "irfft2", "rfft2"] +__all__ = sorted(set(np.fft.__all__) | set(_fft.__all__)) def __dir__() -> list[str]: - return __all__ + return sorted(set(dir(np.fft)) | set(_fft.__all__)) diff --git a/array_api_compat/numpy/linalg.py b/array_api_compat/numpy/linalg.py index e99f0e39..a37463e7 100644 --- a/array_api_compat/numpy/linalg.py +++ b/array_api_compat/numpy/linalg.py @@ -7,22 +7,7 @@ import numpy as np -# intersection of `np.linalg.__all__` on numpy 1.22 and 2.2, minus `_linalg.__all__` -from numpy.linalg import ( - LinAlgError, - cond, - det, - eig, - eigvals, - eigvalsh, - inv, - lstsq, - matrix_power, - multi_dot, - norm, - tensorinv, - tensorsolve, -) +from numpy.linalg import * # noqa: F403 from .._internal import get_xp from ..common import _linalg @@ -120,7 +105,7 @@ def solve(x1: Array, x2: Array, /) -> Array: vector_norm = get_xp(np)(_linalg.vector_norm) -__all__ = _linalg.__all__ + [ +_all = [ "LinAlgError", "cond", "det", @@ -137,6 +122,7 @@ def solve(x1: Array, x2: Array, /) -> Array: "tensorsolve", "vector_norm", ] +__all__ = sorted(set(np.linalg.__all__) | set(_linalg.__all__) | set(_all)) def __dir__() -> list[str]: - return __all__ + return sorted(set(dir(np.linalg)) | set(_linalg.__all__) | set(_all)) diff --git a/array_api_compat/torch/__init__.py b/array_api_compat/torch/__init__.py index 69fd19ce..10bf7a74 100644 --- a/array_api_compat/torch/__init__.py +++ b/array_api_compat/torch/__init__.py @@ -1,6 +1,9 @@ +from typing import Final + from torch import * # noqa: F403 # Several names are not included in the above import * +_torch_all = set() import torch for n in dir(torch): if (n.startswith('_') @@ -10,13 +13,25 @@ or 'backward' in n): continue exec(f"{n} = torch.{n}") + _torch_all.add(n) del n # These imports may overwrite names from the import * above. +import _aliases from ._aliases import * # noqa: F403 +from ._info import __array_namespace_info__ # noqa: F401 # See the comment in the numpy __init__.py __import__(__package__ + '.linalg') __import__(__package__ + '.fft') -__array_api_version__ = '2024.12' +__array_api_version__: Final = '2024.12' + +__all__ = sorted( + set(_torch_all) + | set(_aliases.__all__) + | {"__array_api_version__", "__array_namespace_info__", "linalg", "fft"} +) + +def __dir__() -> list[str]: + return __all__ diff --git a/array_api_compat/torch/_aliases.py b/array_api_compat/torch/_aliases.py index bb1d3f14..95a64d82 100644 --- a/array_api_compat/torch/_aliases.py +++ b/array_api_compat/torch/_aliases.py @@ -9,7 +9,6 @@ from .._internal import get_xp from ..common import _aliases from ..common._typing import NestedSequence, SupportsBufferProtocol -from ._info import __array_namespace_info__ from ._typing import Array, Device, DType _int_dtypes = { @@ -824,7 +823,7 @@ def sign(x: Array, /) -> Array: return out -__all__ = ['__array_namespace_info__', 'asarray', 'result_type', 'can_cast', +__all__ = ['asarray', 'result_type', 'can_cast', 'permute_dims', 'bitwise_invert', 'newaxis', 'conj', 'add', 'atan2', 'bitwise_and', 'bitwise_left_shift', 'bitwise_or', 'bitwise_right_shift', 'bitwise_xor', 'copysign', 'count_nonzero', diff --git a/tests/test_all.py b/tests/test_all.py index 2b5f272e..f645ad7b 100644 --- a/tests/test_all.py +++ b/tests/test_all.py @@ -235,48 +235,37 @@ } -@pytest.mark.parametrize("module", list(NAMES)) -@pytest.mark.parametrize("library", wrapped_libraries) -def test_dir(library, module): - """Test that dir() isn't missing any exports.""" - xp = pytest.importorskip(f"array_api_compat.{library}") - mod = getattr(xp, module) if module else xp - missing = set(NAMES[module]) - set(dir(mod)) - xfail = set(XFAILS.get((library, module), [])) - xpass = xfail - missing - fails = missing - xfail - assert not xpass, "Names in XFAILS are defined: %s" % xpass - assert not fails, "Missing exports: %s" % fails +def all_names(mod): + """Return all names imported by `from mod import *`. + This is typically `__all__` but, if not defined, Python + implements automated fallbacks. + """ + objs = {} + exec(f"from {mod.__name__} import *", objs) + return list(objs) +@pytest.mark.parametrize("func", [all_names, dir]) @pytest.mark.parametrize("module", list(NAMES)) @pytest.mark.parametrize("library", wrapped_libraries) -def test_all(library, module): - """Test that __all__ isn't missing any exports.""" - modname = f"array_api_compat.{library}" - pytest.importorskip(modname) - if module: - modname += f".{module}" - - objs = {} - exec(f"from {modname} import *", objs) - - missing = set(NAMES[module]) - objs.keys() +def test_array_api_names(library, module, func): + """Test that __all__ and dir() aren't missing any exports + dictated by the Standard. + """ + xp = pytest.importorskip(f"array_api_compat.{library}") + mod = getattr(xp, module) if module else xp + missing = set(NAMES[module]) - set(func(mod)) xfail = set(XFAILS.get((library, module), [])) - - # FIXME are these canonically meant to be in __all__? - if module == "": - xfail |= {"__array_namespace_info__", "__array_api_version__"} - xpass = xfail - missing fails = missing - xfail - assert not xpass, "Names in XFAILS are defined: %s" % xpass - assert not fails, "Missing exports: %s" % fails + assert not xpass, f"Names in XFAILS are defined: {xpass}" + assert not fails, f"Missing exports: {fails}" +@pytest.mark.parametrize("func", [all_names, dir]) @pytest.mark.parametrize("module", list(NAMES)) @pytest.mark.parametrize("library", wrapped_libraries) -def test_compat_doesnt_hide_names(library, module): +def test_compat_doesnt_hide_names(library, module, func): """The base namespace can have more names than the ones explicitly exported by array-api-compat. Test that we're not suppressing them. """ @@ -284,43 +273,30 @@ def test_compat_doesnt_hide_names(library, module): compat_xp = pytest.importorskip(f"array_api_compat.{library}") bare_mod = getattr(bare_xp, module) if module else bare_xp compat_mod = getattr(compat_xp, module) if module else compat_xp - aapi_names = set(NAMES[module]) - extra_names = { - name - for name in dir(bare_mod) - if not name.startswith("_") and name not in aapi_names - } - missing = extra_names - set(dir(compat_mod)) - - # These are spurious to begin with in the bare libraries - missing -= {"annotations", "importlib", "warnings", "operator", "sys", "Sequence"} - if module != "": - missing -= {"Array", "test"} - assert not missing, "Non-Array API names have been hidden: %s" % missing + missing = set(func(bare_mod)) - set(func(compat_mod)) + missing = {name for name in missing if not name.startswith("_")} + assert not missing, f"Non-Array API names have been hidden: {missing}" +@pytest.mark.parametrize("func", [all_names, dir]) @pytest.mark.parametrize("module", list(NAMES)) @pytest.mark.parametrize("library", wrapped_libraries) -def test_compat_spurious_names(library, module): - """Test that array-api-compat isn't adding non-Array API names - to the namespace. +def test_compat_doesnt_add_names(library, module, func): + """Test that array-api-compat isn't adding names to the namespace + besides those defined by the Array API Standard. """ bare_xp = pytest.importorskip(library) compat_xp = pytest.importorskip(f"array_api_compat.{library}") bare_mod = getattr(bare_xp, module) if module else bare_xp compat_mod = getattr(compat_xp, module) if module else compat_xp + aapi_names = set(NAMES[module]) - compat_spurious_names = ( - set(dir(compat_mod)) - set(dir(bare_mod)) - aapi_names - {"__all__"} - ) + spurious = set(func(compat_mod)) - set(func(bare_mod)) - aapi_names - {"__all__"} # Quietly ignore *Result dataclasses - compat_spurious_names = { - name for name in compat_spurious_names if not name.endswith("Result") - } - - assert not compat_spurious_names, ( - "array-api-compat is adding non-Array API names: %s" % compat_spurious_names + spurious = {name for name in spurious if not name.endswith("Result")} + assert not spurious, ( + f"array-api-compat is adding non-Array API names: {spurious}" ) From 9345e5f338c8efb717e3f41c2d2f8fb6cd47e52b Mon Sep 17 00:00:00 2001 From: crusaderky Date: Wed, 23 Apr 2025 16:56:48 +0100 Subject: [PATCH 07/12] test --- tests/test_all.py | 22 ++++++++++++---------- 1 file changed, 12 insertions(+), 10 deletions(-) diff --git a/tests/test_all.py b/tests/test_all.py index f645ad7b..7b7068f7 100644 --- a/tests/test_all.py +++ b/tests/test_all.py @@ -245,6 +245,13 @@ def all_names(mod): return list(objs) +def get_mod(library, module, *, compat): + if compat: + library = f"array_api_compat.{library}" + xp = pytest.importorskip(library) + return getattr(xp, module) if module else xp + + @pytest.mark.parametrize("func", [all_names, dir]) @pytest.mark.parametrize("module", list(NAMES)) @pytest.mark.parametrize("library", wrapped_libraries) @@ -252,8 +259,7 @@ def test_array_api_names(library, module, func): """Test that __all__ and dir() aren't missing any exports dictated by the Standard. """ - xp = pytest.importorskip(f"array_api_compat.{library}") - mod = getattr(xp, module) if module else xp + mod = get_mod(library, module, compat=True) missing = set(NAMES[module]) - set(func(mod)) xfail = set(XFAILS.get((library, module), [])) xpass = xfail - missing @@ -269,10 +275,8 @@ def test_compat_doesnt_hide_names(library, module, func): """The base namespace can have more names than the ones explicitly exported by array-api-compat. Test that we're not suppressing them. """ - bare_xp = pytest.importorskip(library) - compat_xp = pytest.importorskip(f"array_api_compat.{library}") - bare_mod = getattr(bare_xp, module) if module else bare_xp - compat_mod = getattr(compat_xp, module) if module else compat_xp + bare_mod = get_mod(library, module, compat=False) + compat_mod = get_mod(library, module, compat=True) missing = set(func(bare_mod)) - set(func(compat_mod)) missing = {name for name in missing if not name.startswith("_")} @@ -286,10 +290,8 @@ def test_compat_doesnt_add_names(library, module, func): """Test that array-api-compat isn't adding names to the namespace besides those defined by the Array API Standard. """ - bare_xp = pytest.importorskip(library) - compat_xp = pytest.importorskip(f"array_api_compat.{library}") - bare_mod = getattr(bare_xp, module) if module else bare_xp - compat_mod = getattr(compat_xp, module) if module else compat_xp + bare_mod = get_mod(library, module, compat=False) + compat_mod = get_mod(library, module, compat=True) aapi_names = set(NAMES[module]) spurious = set(func(compat_mod)) - set(func(bare_mod)) - aapi_names - {"__all__"} From c7024c4378c16a835f730da01faae8801b376080 Mon Sep 17 00:00:00 2001 From: crusaderky Date: Wed, 23 Apr 2025 17:14:39 +0100 Subject: [PATCH 08/12] fixes --- array_api_compat/torch/__init__.py | 17 +++++++++++------ array_api_compat/torch/fft.py | 8 ++++++-- tests/test_all.py | 5 +++++ 3 files changed, 22 insertions(+), 8 deletions(-) diff --git a/array_api_compat/torch/__init__.py b/array_api_compat/torch/__init__.py index 10bf7a74..4d416c1a 100644 --- a/array_api_compat/torch/__init__.py +++ b/array_api_compat/torch/__init__.py @@ -3,21 +3,25 @@ from torch import * # noqa: F403 # Several names are not included in the above import * -_torch_all = set() +_torch_dir = set() import torch for n in dir(torch): if (n.startswith('_') or n.endswith('_') - or 'cuda' in n - or 'cpu' in n or 'backward' in n): continue exec(f"{n} = torch.{n}") - _torch_all.add(n) + _torch_dir.add(n) del n +# torch.__all__ is wildly incorrect +_n: dict[str, object] = {} +exec('from torch import *', _n) +_torch_all = set(_n) +del _n + # These imports may overwrite names from the import * above. -import _aliases +from . import _aliases from ._aliases import * # noqa: F403 from ._info import __array_namespace_info__ # noqa: F401 @@ -31,7 +35,8 @@ set(_torch_all) | set(_aliases.__all__) | {"__array_api_version__", "__array_namespace_info__", "linalg", "fft"} + | {"from_dlpack"} ) def __dir__() -> list[str]: - return __all__ + return sorted(set(__all__) | set(_torch_dir)) diff --git a/array_api_compat/torch/fft.py b/array_api_compat/torch/fft.py index b4ba7358..a8e41ec0 100644 --- a/array_api_compat/torch/fft.py +++ b/array_api_compat/torch/fft.py @@ -73,11 +73,15 @@ def ifftshift( return torch.fft.ifftshift(x, dim=axes, **kwargs) -__all__ = torch.fft.__all__ + [ +_all = { "fftn", "ifftn", "rfftn", "irfftn", "fftshift", "ifftshift", -] +} +__all__ = sorted(set(torch.fft.__all__) |_all) + +def __dir__() -> list[str]: + return sorted(set(dir(torch.fft)) | _all) diff --git a/tests/test_all.py b/tests/test_all.py index 7b7068f7..12be6258 100644 --- a/tests/test_all.py +++ b/tests/test_all.py @@ -12,6 +12,9 @@ # Inspection "__array_api_version__", "__array_namespace_info__", + # Submodules + "fft", + "linalg", # Constants "e", "inf", @@ -240,6 +243,8 @@ def all_names(mod): This is typically `__all__` but, if not defined, Python implements automated fallbacks. """ + # Note: this method also makes the test trip if a name is + # in __all__ but doesn't actually appear in the module. objs = {} exec(f"from {mod.__name__} import *", objs) return list(objs) From 9088c6888271cc6dc30fdeaa013bc26f21037d95 Mon Sep 17 00:00:00 2001 From: crusaderky Date: Wed, 23 Apr 2025 17:49:01 +0100 Subject: [PATCH 09/12] v3 --- array_api_compat/cupy/fft.py | 7 +++--- array_api_compat/cupy/linalg.py | 6 ++--- array_api_compat/dask/array/__init__.py | 22 ++++++++--------- array_api_compat/numpy/fft.py | 8 ++++-- array_api_compat/numpy/linalg.py | 7 ++++-- array_api_compat/torch/fft.py | 19 +++++++------- tests/test_all.py | 33 +++++++++++-------------- 7 files changed, 49 insertions(+), 53 deletions(-) diff --git a/array_api_compat/cupy/fft.py b/array_api_compat/cupy/fft.py index 307e0f72..06af566b 100644 --- a/array_api_compat/cupy/fft.py +++ b/array_api_compat/cupy/fft.py @@ -30,7 +30,6 @@ __all__ = fft_all + _fft.__all__ -del get_xp -del cp -del fft_all -del _fft +def __dir__() -> list[str]: + return __all__ + diff --git a/array_api_compat/cupy/linalg.py b/array_api_compat/cupy/linalg.py index 7fcdd498..cd94be84 100644 --- a/array_api_compat/cupy/linalg.py +++ b/array_api_compat/cupy/linalg.py @@ -43,7 +43,5 @@ __all__ = linalg_all + _linalg.__all__ -del get_xp -del cp -del linalg_all -del _linalg +def __dir__() -> list[str]: + return __all__ diff --git a/array_api_compat/dask/array/__init__.py b/array_api_compat/dask/array/__init__.py index dc0e1404..a022f236 100644 --- a/array_api_compat/dask/array/__init__.py +++ b/array_api_compat/dask/array/__init__.py @@ -1,27 +1,25 @@ from typing import Final -import dask.array as da from dask.array import * # noqa: F403 +# The above is missing a wealth of stuff +import dask.array as da +__all__ = [n for n in dir(da) if not n.startswith("_")] +globals().update({n: getattr(da, n) for n in __all__}) +del da + # These imports may overwrite names from the import * above. from . import _aliases from ._aliases import * # noqa: F403 from ._info import __array_namespace_info__ # noqa: F401 __array_api_version__: Final = "2024.12" +del Final # See the comment in the numpy __init__.py __import__(__package__ + '.linalg') __import__(__package__ + '.fft') -def _make_all(base): - return sorted( - set(base) - | set(_aliases.__all__) - | {"__array_api_version__", "__array_namespace_info__", "linalg", "fft"} - ) - -__all__ = _make_all(da.__all__) - -def __dir__() -> list[str]: - return _make_all(dir(da)) +__all__ += _aliases.__all__ +__all__ += ["__array_api_version__", "__array_namespace_info__", "linalg", "fft"] +__all__ = sorted(set(__all__)) diff --git a/array_api_compat/numpy/fft.py b/array_api_compat/numpy/fft.py index 3a80808e..339f66fc 100644 --- a/array_api_compat/numpy/fft.py +++ b/array_api_compat/numpy/fft.py @@ -1,6 +1,9 @@ import numpy as np from numpy.fft import * # noqa: F403 +__all__ = [n for n in dir(np.fft) if not n.startswith("_")] +globals().update({n: getattr(np.fft, n) for n in __all__}) + from .._internal import get_xp from ..common import _fft @@ -20,7 +23,8 @@ ifftshift = get_xp(np)(_fft.ifftshift) -__all__ = sorted(set(np.fft.__all__) | set(_fft.__all__)) +__all__ = sorted(set(__all__) | set(_fft.__all__)) def __dir__() -> list[str]: - return sorted(set(dir(np.fft)) | set(_fft.__all__)) + return __all__ + diff --git a/array_api_compat/numpy/linalg.py b/array_api_compat/numpy/linalg.py index a37463e7..e70e8364 100644 --- a/array_api_compat/numpy/linalg.py +++ b/array_api_compat/numpy/linalg.py @@ -16,6 +16,9 @@ from ._aliases import matmul, matrix_transpose, tensordot, vecdot # noqa: F401 from ._typing import Array +__all__ = [n for n in dir(np.linalg) if not n.startswith("_")] +globals().update({n: getattr(np.linalg, n) for n in __all__}) + cross = get_xp(np)(_linalg.cross) outer = get_xp(np)(_linalg.outer) EighResult = _linalg.EighResult @@ -122,7 +125,7 @@ def solve(x1: Array, x2: Array, /) -> Array: "tensorsolve", "vector_norm", ] -__all__ = sorted(set(np.linalg.__all__) | set(_linalg.__all__) | set(_all)) +__all__ = sorted(set(__all__) | set(_linalg.__all__) | set(_all)) def __dir__() -> list[str]: - return sorted(set(dir(np.linalg)) | set(_linalg.__all__) | set(_all)) + return __all__ diff --git a/array_api_compat/torch/fft.py b/array_api_compat/torch/fft.py index a8e41ec0..6c52e038 100644 --- a/array_api_compat/torch/fft.py +++ b/array_api_compat/torch/fft.py @@ -8,6 +8,10 @@ from ._typing import Array +# The above is missing a wealth of stuff +__all__ = [n for n in dir(torch.fft) if not n.startswith("_")] +globals().update({n: getattr(torch.fft, n) for n in __all__}) + # Several torch fft functions do not map axes to dim def fftn( @@ -73,15 +77,10 @@ def ifftshift( return torch.fft.ifftshift(x, dim=axes, **kwargs) -_all = { - "fftn", - "ifftn", - "rfftn", - "irfftn", - "fftshift", - "ifftshift", -} -__all__ = sorted(set(torch.fft.__all__) |_all) +__all__ = sorted( + set(__all__) + | {"fftn", "ifftn", "rfftn", "irfftn", "fftshift", "ifftshift"} +) def __dir__() -> list[str]: - return sorted(set(dir(torch.fft)) | _all) + return __all__ diff --git a/tests/test_all.py b/tests/test_all.py index 12be6258..4c1d0e3d 100644 --- a/tests/test_all.py +++ b/tests/test_all.py @@ -21,7 +21,7 @@ "nan", "newaxis", "pi", - # Creation functions + # Creation Functions "arange", "asarray", "empty", @@ -157,7 +157,7 @@ "nonzero", "searchsorted", "where", - # Set functions + # Set Functions "unique_all", "unique_counts", "unique_inverse", @@ -239,15 +239,13 @@ def all_names(mod): - """Return all names imported by `from mod import *`. - This is typically `__all__` but, if not defined, Python - implements automated fallbacks. - """ - # Note: this method also makes the test trip if a name is - # in __all__ but doesn't actually appear in the module. + """Return all names available in a module.""" objs = {} exec(f"from {mod.__name__} import *", objs) - return list(objs) + for n in dir(mod): + if not n.startswith("_") and hasattr(mod, n): + objs[n] = getattr(mod, n) + return set(objs) def get_mod(library, module, *, compat): @@ -257,15 +255,14 @@ def get_mod(library, module, *, compat): return getattr(xp, module) if module else xp -@pytest.mark.parametrize("func", [all_names, dir]) @pytest.mark.parametrize("module", list(NAMES)) @pytest.mark.parametrize("library", wrapped_libraries) -def test_array_api_names(library, module, func): - """Test that __all__ and dir() aren't missing any exports +def test_array_api_names(library, module): + """Test that __all__ isn't missing any exports dictated by the Standard. """ mod = get_mod(library, module, compat=True) - missing = set(NAMES[module]) - set(func(mod)) + missing = set(NAMES[module]) - all_names(mod) xfail = set(XFAILS.get((library, module), [])) xpass = xfail - missing fails = missing - xfail @@ -273,25 +270,23 @@ def test_array_api_names(library, module, func): assert not fails, f"Missing exports: {fails}" -@pytest.mark.parametrize("func", [all_names, dir]) @pytest.mark.parametrize("module", list(NAMES)) @pytest.mark.parametrize("library", wrapped_libraries) -def test_compat_doesnt_hide_names(library, module, func): +def test_compat_doesnt_hide_names(library, module): """The base namespace can have more names than the ones explicitly exported by array-api-compat. Test that we're not suppressing them. """ bare_mod = get_mod(library, module, compat=False) compat_mod = get_mod(library, module, compat=True) - missing = set(func(bare_mod)) - set(func(compat_mod)) + missing = all_names(bare_mod) - all_names(compat_mod) missing = {name for name in missing if not name.startswith("_")} assert not missing, f"Non-Array API names have been hidden: {missing}" -@pytest.mark.parametrize("func", [all_names, dir]) @pytest.mark.parametrize("module", list(NAMES)) @pytest.mark.parametrize("library", wrapped_libraries) -def test_compat_doesnt_add_names(library, module, func): +def test_compat_doesnt_add_names(library, module): """Test that array-api-compat isn't adding names to the namespace besides those defined by the Array API Standard. """ @@ -299,7 +294,7 @@ def test_compat_doesnt_add_names(library, module, func): compat_mod = get_mod(library, module, compat=True) aapi_names = set(NAMES[module]) - spurious = set(func(compat_mod)) - set(func(bare_mod)) - aapi_names - {"__all__"} + spurious = all_names(compat_mod) - all_names(bare_mod) - aapi_names # Quietly ignore *Result dataclasses spurious = {name for name in spurious if not name.endswith("Result")} assert not spurious, ( From 86dc4a05d22bd2d584740746116378d20faa1639 Mon Sep 17 00:00:00 2001 From: crusaderky Date: Wed, 23 Apr 2025 18:10:24 +0100 Subject: [PATCH 10/12] v4 --- array_api_compat/_internal.py | 15 ++++++++++++++- array_api_compat/dask/array/__init__.py | 19 ++++++++++--------- array_api_compat/dask/array/fft.py | 12 ++++-------- array_api_compat/dask/array/linalg.py | 25 +++++++++---------------- array_api_compat/numpy/__init__.py | 11 +++-------- array_api_compat/numpy/fft.py | 6 +++--- array_api_compat/numpy/linalg.py | 11 +++++------ array_api_compat/torch/__init__.py | 25 ++++--------------------- array_api_compat/torch/fft.py | 11 +++-------- array_api_compat/torch/linalg.py | 11 ++++------- 10 files changed, 59 insertions(+), 87 deletions(-) diff --git a/array_api_compat/_internal.py b/array_api_compat/_internal.py index cd8d939f..a9520fdf 100644 --- a/array_api_compat/_internal.py +++ b/array_api_compat/_internal.py @@ -2,6 +2,7 @@ Internal helpers """ +import importlib from collections.abc import Callable from functools import wraps from inspect import signature @@ -52,8 +53,20 @@ def wrapped_f(*args: object, **kwargs: object) -> object: return inner -__all__ = ["get_xp"] +def clone_module(mod_name: str, globals_: dict[str, object]) -> list[str]: + """Import everything from module, updating globals(). + Returns __all__. + """ + mod = importlib.import_module(mod_name) + all_ = [] + for n in dir(mod): + if not n.startswith("_") and hasattr(mod, n): + all_.append(n) + globals_[n] = getattr(mod, n) + return all_ + +__all__ = ["get_xp", "clone_module"] def __dir__() -> list[str]: return __all__ diff --git a/array_api_compat/dask/array/__init__.py b/array_api_compat/dask/array/__init__.py index a022f236..fb1e0b94 100644 --- a/array_api_compat/dask/array/__init__.py +++ b/array_api_compat/dask/array/__init__.py @@ -1,12 +1,8 @@ from typing import Final -from dask.array import * # noqa: F403 +from ..._internal import clone_module -# The above is missing a wealth of stuff -import dask.array as da -__all__ = [n for n in dir(da) if not n.startswith("_")] -globals().update({n: getattr(da, n) for n in __all__}) -del da +__all__ = clone_module("dask.array", globals()) # These imports may overwrite names from the import * above. from . import _aliases @@ -20,6 +16,11 @@ __import__(__package__ + '.linalg') __import__(__package__ + '.fft') -__all__ += _aliases.__all__ -__all__ += ["__array_api_version__", "__array_namespace_info__", "linalg", "fft"] -__all__ = sorted(set(__all__)) +__all__ = sorted( + set(__all__) + | set(_aliases.__all__) + | {"__array_api_version__", "__array_namespace_info__", "linalg", "fft"} +) + +def __dir__() -> list[str]: + return __all__ diff --git a/array_api_compat/dask/array/fft.py b/array_api_compat/dask/array/fft.py index 35951863..44b68e73 100644 --- a/array_api_compat/dask/array/fft.py +++ b/array_api_compat/dask/array/fft.py @@ -1,10 +1,6 @@ -from dask.array.fft import * # noqa: F403 -# dask.array.fft doesn't have __all__. If it is added, replace this with -# from dask.array.fft import __all__ as fft_all -_n: dict[str, object] = {} -exec('from dask.array.fft import *', _n) -fft_all = list(_n) -del _n +from ..._internal import clone_module + +__all__ = clone_module("dask.array.fft", globals()) from ...common import _fft from ..._internal import get_xp @@ -14,7 +10,7 @@ fftfreq = get_xp(da)(_fft.fftfreq) rfftfreq = get_xp(da)(_fft.rfftfreq) -__all__ = fft_all + ["fftfreq", "rfftfreq"] +__all__ += ["fftfreq", "rfftfreq"] def __dir__() -> list[str]: return __all__ diff --git a/array_api_compat/dask/array/linalg.py b/array_api_compat/dask/array/linalg.py index 6b8185c7..7c80620c 100644 --- a/array_api_compat/dask/array/linalg.py +++ b/array_api_compat/dask/array/linalg.py @@ -8,20 +8,13 @@ from dask.array import matmul, outer, tensordot # Exports -from dask.array.linalg import * # noqa: F403 - -from ..._internal import get_xp +from ..._internal import clone_module, get_xp from ...common import _linalg from ...common._typing import Array as _Array -from ._aliases import matrix_transpose, vecdot -# dask.array.linalg doesn't have __all__. If it is added, replace this with -# -# from dask.array.linalg import __all__ as linalg_all -_n = {} -exec('from dask.array.linalg import *', _n) -linalg_all = list(_n) -del _n +__all__ = clone_module("dask.array.linalg", globals()) + +from ._aliases import matrix_transpose, vecdot EighResult = _linalg.EighResult QRResult = _linalg.QRResult @@ -61,11 +54,11 @@ def svdvals(x: _Array) -> _Array: vector_norm = get_xp(da)(_linalg.vector_norm) diagonal = get_xp(da)(_linalg.diagonal) -__all__ = linalg_all + ["trace", "outer", "matmul", "tensordot", - "matrix_transpose", "vecdot", "EighResult", - "QRResult", "SlogdetResult", "SVDResult", "qr", - "cholesky", "matrix_rank", "matrix_norm", "svdvals", - "vector_norm", "diagonal"] +__all__ += ["trace", "outer", "matmul", "tensordot", + "matrix_transpose", "vecdot", "EighResult", + "QRResult", "SlogdetResult", "SVDResult", "qr", + "cholesky", "matrix_rank", "matrix_norm", "svdvals", + "vector_norm", "diagonal"] def __dir__() -> list[str]: return __all__ diff --git a/array_api_compat/numpy/__init__.py b/array_api_compat/numpy/__init__.py index 87624764..9812bebb 100644 --- a/array_api_compat/numpy/__init__.py +++ b/array_api_compat/numpy/__init__.py @@ -1,14 +1,9 @@ # ruff: noqa: PLC0414 from typing import Final -import numpy as np -from numpy import * # noqa: F403 # pyright: ignore[reportWildcardImportFromLibrary] +from .._internal import clone_module -# from numpy import * doesn't overwrite these builtin names -from numpy import abs as abs -from numpy import max as max -from numpy import min as min -from numpy import round as round +__all__ = clone_module("numpy", globals()) # These imports may overwrite names from the import * above. from . import _aliases @@ -31,7 +26,7 @@ __array_api_version__: Final = "2024.12" __all__ = sorted( - set(np.__all__) + set(__all__) | set(_aliases.__all__) | {"__array_api_version__", "__array_namespace_info__", "linalg", "fft"} ) diff --git a/array_api_compat/numpy/fft.py b/array_api_compat/numpy/fft.py index 339f66fc..a492feb8 100644 --- a/array_api_compat/numpy/fft.py +++ b/array_api_compat/numpy/fft.py @@ -1,8 +1,8 @@ import numpy as np -from numpy.fft import * # noqa: F403 -__all__ = [n for n in dir(np.fft) if not n.startswith("_")] -globals().update({n: getattr(np.fft, n) for n in __all__}) +from .._internal import clone_module + +__all__ = clone_module("numpy.fft", globals()) from .._internal import get_xp from ..common import _fft diff --git a/array_api_compat/numpy/linalg.py b/array_api_compat/numpy/linalg.py index e70e8364..328f2ea7 100644 --- a/array_api_compat/numpy/linalg.py +++ b/array_api_compat/numpy/linalg.py @@ -7,18 +7,17 @@ import numpy as np -from numpy.linalg import * # noqa: F403 - -from .._internal import get_xp +from .._internal import clone_module, get_xp from ..common import _linalg +from .._internal import clone_module + +__all__ = clone_module("numpy.linalg", globals()) + # These functions are in both the main and linalg namespaces from ._aliases import matmul, matrix_transpose, tensordot, vecdot # noqa: F401 from ._typing import Array -__all__ = [n for n in dir(np.linalg) if not n.startswith("_")] -globals().update({n: getattr(np.linalg, n) for n in __all__}) - cross = get_xp(np)(_linalg.cross) outer = get_xp(np)(_linalg.outer) EighResult = _linalg.EighResult diff --git a/array_api_compat/torch/__init__.py b/array_api_compat/torch/__init__.py index 4d416c1a..6cbb6ec2 100644 --- a/array_api_compat/torch/__init__.py +++ b/array_api_compat/torch/__init__.py @@ -1,24 +1,8 @@ from typing import Final -from torch import * # noqa: F403 +from .._internal import clone_module -# Several names are not included in the above import * -_torch_dir = set() -import torch -for n in dir(torch): - if (n.startswith('_') - or n.endswith('_') - or 'backward' in n): - continue - exec(f"{n} = torch.{n}") - _torch_dir.add(n) -del n - -# torch.__all__ is wildly incorrect -_n: dict[str, object] = {} -exec('from torch import *', _n) -_torch_all = set(_n) -del _n +__all__ = clone_module("torch", globals()) # These imports may overwrite names from the import * above. from . import _aliases @@ -32,11 +16,10 @@ __array_api_version__: Final = '2024.12' __all__ = sorted( - set(_torch_all) + set(__all__) | set(_aliases.__all__) | {"__array_api_version__", "__array_namespace_info__", "linalg", "fft"} - | {"from_dlpack"} ) def __dir__() -> list[str]: - return sorted(set(__all__) | set(_torch_dir)) + return __all__ diff --git a/array_api_compat/torch/fft.py b/array_api_compat/torch/fft.py index 6c52e038..242c92db 100644 --- a/array_api_compat/torch/fft.py +++ b/array_api_compat/torch/fft.py @@ -4,13 +4,11 @@ import torch import torch.fft -from torch.fft import * # noqa: F403 from ._typing import Array +from .._internal import clone_module -# The above is missing a wealth of stuff -__all__ = [n for n in dir(torch.fft) if not n.startswith("_")] -globals().update({n: getattr(torch.fft, n) for n in __all__}) +__all__ = clone_module("torch.fft", globals()) # Several torch fft functions do not map axes to dim @@ -77,10 +75,7 @@ def ifftshift( return torch.fft.ifftshift(x, dim=axes, **kwargs) -__all__ = sorted( - set(__all__) - | {"fftn", "ifftn", "rfftn", "irfftn", "fftshift", "ifftshift"} -) +__all__ += ["fftn", "ifftn", "rfftn", "irfftn", "fftshift", "ifftshift"] def __dir__() -> list[str]: return __all__ diff --git a/array_api_compat/torch/linalg.py b/array_api_compat/torch/linalg.py index 7d454511..e08e1324 100644 --- a/array_api_compat/torch/linalg.py +++ b/array_api_compat/torch/linalg.py @@ -3,12 +3,9 @@ import torch from typing import Optional, Union, Tuple -from torch.linalg import * # noqa: F403 +from .._internal import clone_module -# torch.linalg doesn't define __all__ -# from torch.linalg import __all__ as linalg_all -from torch import linalg as torch_linalg -linalg_all = [i for i in dir(torch_linalg) if not i.startswith('_')] +__all__ = clone_module("torch.linalg", globals()) # outer is implemented in torch but aren't in the linalg namespace from torch import outer @@ -110,8 +107,8 @@ def vector_norm( return out return torch.linalg.vector_norm(x, ord=ord, axis=axis, keepdim=keepdims, **kwargs) -__all__ = linalg_all + ['outer', 'matmul', 'matrix_transpose', 'tensordot', - 'cross', 'vecdot', 'solve', 'trace', 'vector_norm'] +__all__ += ['outer', 'matmul', 'matrix_transpose', 'tensordot', + 'cross', 'vecdot', 'solve', 'trace', 'vector_norm'] def __dir__() -> list[str]: return __all__ From 59e03eeefb60abdc83799f88688baeb0be488f7c Mon Sep 17 00:00:00 2001 From: crusaderky Date: Wed, 23 Apr 2025 18:12:12 +0100 Subject: [PATCH 11/12] nits --- array_api_compat/numpy/linalg.py | 2 -- array_api_compat/torch/linalg.py | 3 ++- 2 files changed, 2 insertions(+), 3 deletions(-) diff --git a/array_api_compat/numpy/linalg.py b/array_api_compat/numpy/linalg.py index 328f2ea7..ca540880 100644 --- a/array_api_compat/numpy/linalg.py +++ b/array_api_compat/numpy/linalg.py @@ -10,8 +10,6 @@ from .._internal import clone_module, get_xp from ..common import _linalg -from .._internal import clone_module - __all__ = clone_module("numpy.linalg", globals()) # These functions are in both the main and linalg namespaces diff --git a/array_api_compat/torch/linalg.py b/array_api_compat/torch/linalg.py index e08e1324..df94c351 100644 --- a/array_api_compat/torch/linalg.py +++ b/array_api_compat/torch/linalg.py @@ -1,6 +1,7 @@ from __future__ import annotations import torch +import torch.linalg from typing import Optional, Union, Tuple from .._internal import clone_module @@ -27,7 +28,7 @@ def cross(x1: Array, x2: Array, /, *, axis: int = -1) -> Array: if not (x1.shape[axis] == x2.shape[axis] == 3): raise ValueError(f"cross product axis must have size 3, got {x1.shape[axis]} and {x2.shape[axis]}") x1, x2 = torch.broadcast_tensors(x1, x2) - return torch_linalg.cross(x1, x2, dim=axis) + return torch.linalg.cross(x1, x2, dim=axis) def vecdot(x1: Array, x2: Array, /, *, axis: int = -1, **kwargs) -> Array: from ._aliases import isdtype From 1fad2936a1890ccd5446419ad5354f7c4171fa32 Mon Sep 17 00:00:00 2001 From: crusaderky Date: Wed, 23 Apr 2025 18:27:53 +0100 Subject: [PATCH 12/12] v5 --- array_api_compat/_internal.py | 13 +++++++++---- array_api_compat/numpy/__init__.py | 3 +++ tests/test_all.py | 7 +++---- 3 files changed, 15 insertions(+), 8 deletions(-) diff --git a/array_api_compat/_internal.py b/array_api_compat/_internal.py index a9520fdf..ebe9bc5d 100644 --- a/array_api_compat/_internal.py +++ b/array_api_compat/_internal.py @@ -58,12 +58,17 @@ def clone_module(mod_name: str, globals_: dict[str, object]) -> list[str]: Returns __all__. """ mod = importlib.import_module(mod_name) - all_ = [] + # Neither of these two methods is sufficient by itself, + # depending on various idiosyncrasies of the libraries we're wrapping. + objs = {} + exec(f"from {mod.__name__} import *", objs) + for n in dir(mod): if not n.startswith("_") and hasattr(mod, n): - all_.append(n) - globals_[n] = getattr(mod, n) - return all_ + objs[n] = getattr(mod, n) + + globals_.update(objs) + return list(objs) __all__ = ["get_xp", "clone_module"] diff --git a/array_api_compat/numpy/__init__.py b/array_api_compat/numpy/__init__.py index 9812bebb..c08aa14f 100644 --- a/array_api_compat/numpy/__init__.py +++ b/array_api_compat/numpy/__init__.py @@ -3,6 +3,9 @@ from .._internal import clone_module +# This needs to be loaded explicitly before cloning +import numpy.typing # noqa: F401 + __all__ = clone_module("numpy", globals()) # These imports may overwrite names from the import * above. diff --git a/tests/test_all.py b/tests/test_all.py index 4c1d0e3d..c36aef67 100644 --- a/tests/test_all.py +++ b/tests/test_all.py @@ -5,6 +5,8 @@ import numpy as np import pytest +from array_api_compat._internal import clone_module + from ._helpers import wrapped_libraries NAMES = { @@ -241,10 +243,7 @@ def all_names(mod): """Return all names available in a module.""" objs = {} - exec(f"from {mod.__name__} import *", objs) - for n in dir(mod): - if not n.startswith("_") and hasattr(mod, n): - objs[n] = getattr(mod, n) + clone_module(mod.__name__, objs) return set(objs)