Skip to content

TST: Redesign test_all #315

New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Open
wants to merge 15 commits into
base: main
Choose a base branch
from
20 changes: 19 additions & 1 deletion array_api_compat/_internal.py
Original file line number Diff line number Diff line change
Expand Up @@ -2,6 +2,7 @@
Internal helpers
"""

import importlib
from collections.abc import Callable
from functools import wraps
from inspect import signature
Expand Down Expand Up @@ -52,8 +53,25 @@ def wrapped_f(*args: object, **kwargs: object) -> object:
return inner


__all__ = ["get_xp"]
def clone_module(mod_name: str, globals_: dict[str, object]) -> list[str]:
"""Import everything from module, updating globals().
Returns __all__.
"""
mod = importlib.import_module(mod_name)
# Neither of these two methods is sufficient by itself,
# depending on various idiosyncrasies of the libraries we're wrapping.
objs = {}
exec(f"from {mod.__name__} import *", objs)

for n in dir(mod):
if not n.startswith("_") and hasattr(mod, n):
objs[n] = getattr(mod, n)

globals_.update(objs)
return list(objs)


__all__ = ["get_xp", "clone_module"]

def __dir__() -> list[str]:
return __all__
2 changes: 0 additions & 2 deletions array_api_compat/common/_aliases.py
Original file line number Diff line number Diff line change
Expand Up @@ -720,8 +720,6 @@ def iinfo(type_: DType | Array, /, xp: Namespace) -> Any:
"finfo",
"iinfo",
]
_all_ignore = ["inspect", "array_namespace", "NamedTuple"]


def __dir__() -> list[str]:
return __all__
2 changes: 0 additions & 2 deletions array_api_compat/common/_helpers.py
Original file line number Diff line number Diff line change
Expand Up @@ -1052,7 +1052,5 @@ def is_lazy_array(x: object) -> bool:
"to_device",
]

_all_ignore = ['lru_cache', 'sys', 'math', 'inspect', 'warnings']

def __dir__() -> list[str]:
return __all__
2 changes: 0 additions & 2 deletions array_api_compat/common/_linalg.py
Original file line number Diff line number Diff line change
Expand Up @@ -225,8 +225,6 @@ def trace(
'matrix_transpose', 'svdvals', 'vecdot', 'vector_norm', 'diagonal',
'trace']

_all_ignore = ['math', 'normalize_axis_tuple', 'get_xp', 'np', 'isdtype']


def __dir__() -> list[str]:
return __all__
13 changes: 12 additions & 1 deletion array_api_compat/cupy/__init__.py
Original file line number Diff line number Diff line change
@@ -1,13 +1,24 @@
from typing import Final
from cupy import * # noqa: F403

# from cupy import * doesn't overwrite these builtin names
from cupy import abs, max, min, round # noqa: F401

# These imports may overwrite names from the import * above.
from ._aliases import * # noqa: F403
from ._info import __array_namespace_info__ # noqa: F401

# See the comment in the numpy __init__.py
__import__(__package__ + '.linalg')
__import__(__package__ + '.fft')

__array_api_version__ = '2024.12'
__array_api_version__: Final = '2024.12'

__all__ = sorted(
{name for name in globals() if not name.startswith("__")}
- {"Final", "_aliases", "_info", "_typing"}
| {"__array_api_version__", "__array_namespace_info__", "linalg", "fft"}
)

def __dir__() -> list[str]:
return __all__
6 changes: 3 additions & 3 deletions array_api_compat/cupy/_aliases.py
Original file line number Diff line number Diff line change
Expand Up @@ -7,7 +7,6 @@
from ..common import _aliases, _helpers
from ..common._typing import NestedSequence, SupportsBufferProtocol
from .._internal import get_xp
from ._info import __array_namespace_info__
from ._typing import Array, Device, DType

bool = cp.bool_
Expand Down Expand Up @@ -141,10 +140,11 @@ def count_nonzero(
else:
unstack = get_xp(cp)(_aliases.unstack)

__all__ = _aliases.__all__ + ['__array_namespace_info__', 'asarray', 'astype',
__all__ = _aliases.__all__ + ['asarray', 'astype',
'acos', 'acosh', 'asin', 'asinh', 'atan',
'atan2', 'atanh', 'bitwise_left_shift',
'bitwise_invert', 'bitwise_right_shift',
'bool', 'concat', 'count_nonzero', 'pow', 'sign']

_all_ignore = ['cp', 'get_xp']
def __dir__() -> list[str]:
return __all__
1 change: 0 additions & 1 deletion array_api_compat/cupy/_typing.py
Original file line number Diff line number Diff line change
@@ -1,7 +1,6 @@
from __future__ import annotations

__all__ = ["Array", "DType", "Device"]
_all_ignore = ["cp"]

from typing import TYPE_CHECKING

Expand Down
7 changes: 3 additions & 4 deletions array_api_compat/cupy/fft.py
Original file line number Diff line number Diff line change
Expand Up @@ -30,7 +30,6 @@

__all__ = fft_all + _fft.__all__

del get_xp
del cp
del fft_all
del _fft
def __dir__() -> list[str]:
return __all__

6 changes: 2 additions & 4 deletions array_api_compat/cupy/linalg.py
Original file line number Diff line number Diff line change
Expand Up @@ -43,7 +43,5 @@

__all__ = linalg_all + _linalg.__all__

del get_xp
del cp
del linalg_all
del _linalg
def __dir__() -> list[str]:
return __all__
16 changes: 15 additions & 1 deletion array_api_compat/dask/array/__init__.py
Original file line number Diff line number Diff line change
@@ -1,12 +1,26 @@
from typing import Final

from dask.array import * # noqa: F403
from ..._internal import clone_module

__all__ = clone_module("dask.array", globals())

# These imports may overwrite names from the import * above.
from . import _aliases
from ._aliases import * # noqa: F403
from ._info import __array_namespace_info__ # noqa: F401

__array_api_version__: Final = "2024.12"
del Final

# See the comment in the numpy __init__.py
__import__(__package__ + '.linalg')
__import__(__package__ + '.fft')

__all__ = sorted(
set(__all__)
| set(_aliases.__all__)
| {"__array_api_version__", "__array_namespace_info__", "linalg", "fft"}
)

def __dir__() -> list[str]:
return __all__
4 changes: 0 additions & 4 deletions array_api_compat/dask/array/_aliases.py
Original file line number Diff line number Diff line change
Expand Up @@ -41,7 +41,6 @@
NestedSequence,
SupportsBufferProtocol,
)
from ._info import __array_namespace_info__

isdtype = get_xp(np)(_aliases.isdtype)
unstack = get_xp(da)(_aliases.unstack)
Expand Down Expand Up @@ -355,7 +354,6 @@ def count_nonzero(


__all__ = [
"__array_namespace_info__",
"count_nonzero",
"bool",
"int8", "int16", "int32", "int64",
Expand All @@ -369,8 +367,6 @@ def count_nonzero(
"bitwise_left_shift", "bitwise_right_shift", "bitwise_invert",
] # fmt: skip
__all__ += _aliases.__all__
_all_ignore = ["array_namespace", "get_xp", "da", "np"]


def __dir__() -> list[str]:
return __all__
19 changes: 7 additions & 12 deletions array_api_compat/dask/array/fft.py
Original file line number Diff line number Diff line change
@@ -1,13 +1,6 @@
from dask.array.fft import * # noqa: F403
# dask.array.fft doesn't have __all__. If it is added, replace this with
#
# from dask.array.fft import __all__ as linalg_all
_n = {}
exec('from dask.array.fft import *', _n)
for k in ("__builtins__", "Sequence", "annotations", "warnings"):
_n.pop(k, None)
fft_all = list(_n)
del _n, k
from ..._internal import clone_module

__all__ = clone_module("dask.array.fft", globals())

from ...common import _fft
from ..._internal import get_xp
Expand All @@ -17,5 +10,7 @@
fftfreq = get_xp(da)(_fft.fftfreq)
rfftfreq = get_xp(da)(_fft.rfftfreq)

__all__ = fft_all + ["fftfreq", "rfftfreq"]
_all_ignore = ["da", "fft_all", "get_xp", "warnings"]
__all__ += ["fftfreq", "rfftfreq"]

def __dir__() -> list[str]:
return __all__
30 changes: 11 additions & 19 deletions array_api_compat/dask/array/linalg.py
Original file line number Diff line number Diff line change
Expand Up @@ -8,22 +8,13 @@
from dask.array import matmul, outer, tensordot

# Exports
from dask.array.linalg import * # noqa: F403

from ..._internal import get_xp
from ..._internal import clone_module, get_xp
from ...common import _linalg
from ...common._typing import Array as _Array
from ._aliases import matrix_transpose, vecdot

# dask.array.linalg doesn't have __all__. If it is added, replace this with
#
# from dask.array.linalg import __all__ as linalg_all
_n = {}
exec('from dask.array.linalg import *', _n)
for k in ('__builtins__', 'annotations', 'operator', 'warnings', 'Array'):
_n.pop(k, None)
linalg_all = list(_n)
del _n, k
__all__ = clone_module("dask.array.linalg", globals())

from ._aliases import matrix_transpose, vecdot

EighResult = _linalg.EighResult
QRResult = _linalg.QRResult
Expand Down Expand Up @@ -63,10 +54,11 @@ def svdvals(x: _Array) -> _Array:
vector_norm = get_xp(da)(_linalg.vector_norm)
diagonal = get_xp(da)(_linalg.diagonal)

__all__ = linalg_all + ["trace", "outer", "matmul", "tensordot",
"matrix_transpose", "vecdot", "EighResult",
"QRResult", "SlogdetResult", "SVDResult", "qr",
"cholesky", "matrix_rank", "matrix_norm", "svdvals",
"vector_norm", "diagonal"]
__all__ += ["trace", "outer", "matmul", "tensordot",
"matrix_transpose", "vecdot", "EighResult",
"QRResult", "SlogdetResult", "SVDResult", "qr",
"cholesky", "matrix_rank", "matrix_norm", "svdvals",
"vector_norm", "diagonal"]

_all_ignore = ['get_xp', 'da', 'linalg_all', 'warnings']
def __dir__() -> list[str]:
return __all__
29 changes: 16 additions & 13 deletions array_api_compat/numpy/__init__.py
Original file line number Diff line number Diff line change
@@ -1,16 +1,17 @@
# ruff: noqa: PLC0414
from typing import Final

from numpy import * # noqa: F403 # pyright: ignore[reportWildcardImportFromLibrary]
from .._internal import clone_module

# from numpy import * doesn't overwrite these builtin names
from numpy import abs as abs
from numpy import max as max
from numpy import min as min
from numpy import round as round
# This needs to be loaded explicitly before cloning
import numpy.typing # noqa: F401

__all__ = clone_module("numpy", globals())

# These imports may overwrite names from the import * above.
from . import _aliases
from ._aliases import * # noqa: F403
from ._info import __array_namespace_info__ # noqa: F401

# Don't know why, but we have to do an absolute import to import linalg. If we
# instead do
Expand All @@ -23,13 +24,15 @@

__import__(__package__ + ".fft")

from ..common._helpers import * # noqa: F403
from .linalg import matrix_transpose, vecdot # noqa: F401

try:
# Used in asarray(). Not present in older versions.
from numpy import _CopyMode # noqa: F401
except ImportError:
pass

__array_api_version__: Final = "2024.12"

__all__ = sorted(
set(__all__)
| set(_aliases.__all__)
| {"__array_api_version__", "__array_namespace_info__", "linalg", "fft"}
)

def __dir__() -> list[str]:
return __all__
6 changes: 1 addition & 5 deletions array_api_compat/numpy/_aliases.py
Original file line number Diff line number Diff line change
Expand Up @@ -9,7 +9,6 @@
from .._internal import get_xp
from ..common import _aliases, _helpers
from ..common._typing import NestedSequence, SupportsBufferProtocol
from ._info import __array_namespace_info__
from ._typing import Array, Device, DType

if TYPE_CHECKING:
Expand Down Expand Up @@ -157,8 +156,7 @@ def count_nonzero(
else:
unstack = get_xp(np)(_aliases.unstack)

__all__ = [
"__array_namespace_info__",
__all__ = _aliases.__all__ + [
"asarray",
"astype",
"acos",
Expand All @@ -176,8 +174,6 @@ def count_nonzero(
"count_nonzero",
"pow",
]
__all__ += _aliases.__all__
_all_ignore = ["np", "get_xp"]


def __dir__() -> list[str]:
Expand Down
1 change: 0 additions & 1 deletion array_api_compat/numpy/_typing.py
Original file line number Diff line number Diff line change
Expand Up @@ -23,7 +23,6 @@
Array: TypeAlias = np.ndarray

__all__ = ["Array", "DType", "Device"]
_all_ignore = ["np"]


def __dir__() -> list[str]:
Expand Down
15 changes: 5 additions & 10 deletions array_api_compat/numpy/fft.py
Original file line number Diff line number Diff line change
@@ -1,6 +1,8 @@
import numpy as np
from numpy.fft import __all__ as fft_all
from numpy.fft import fft2, ifft2, irfft2, rfft2

from .._internal import clone_module

__all__ = clone_module("numpy.fft", globals())

from .._internal import get_xp
from ..common import _fft
Expand All @@ -21,15 +23,8 @@
ifftshift = get_xp(np)(_fft.ifftshift)


__all__ = ["rfft2", "irfft2", "fft2", "ifft2"]
__all__ += _fft.__all__

__all__ = sorted(set(__all__) | set(_fft.__all__))

def __dir__() -> list[str]:
return __all__


del get_xp
del np
del fft_all
del _fft
Loading
Loading