Skip to content

Commit

Permalink
gh-439: a framework for testing array API compatibility
Browse files Browse the repository at this point in the history
  • Loading branch information
Saransh-cpp committed Nov 26, 2024
1 parent cb89b88 commit 487989f
Show file tree
Hide file tree
Showing 4 changed files with 166 additions and 1 deletion.
15 changes: 14 additions & 1 deletion .github/workflows/test.yml
Original file line number Diff line number Diff line change
Expand Up @@ -56,10 +56,23 @@ jobs:
env:
FORCE_COLOR: 1

- name: Run tests and generate coverage report
- name: Run NumPy tests and generate coverage report
run: nox -s coverage-${{ matrix.python-version }} --verbose
env:
FORCE_COLOR: 1
GLASS_ARRAY_BACKEND: numpy

- name: Run array API strict tests
run: nox -s doctests-${{ matrix.python-version }} --verbose
env:
FORCE_COLOR: 1
GLASS_ARRAY_BACKEND: array_api_strict

- name: Run JAX tests
run: nox -s doctests-${{ matrix.python-version }} --verbose
env:
FORCE_COLOR: 1
GLASS_ARRAY_BACKEND: jax

- name: Coveralls requires XML report
run: coverage xml
Expand Down
53 changes: 53 additions & 0 deletions CONTRIBUTING.md
Original file line number Diff line number Diff line change
Expand Up @@ -64,6 +64,42 @@ following way -
python -m pytest --cov --doctest-plus
```

### Array API tests

One can specify a particular array backend for testing by setting the
`GLASS_ARRAY_BACKEND` environment variable. The default array backend is NumPy.
_GLASS_ can be tested with every supported array library available in the
environment by setting `GLASS_ARRAY_BACKEND` to `all`. The testing framework
only installs NumPy automatically; hence, remaining array libraries should
either be installed manually or developers should use `Nox`.

```bash
# run tests using numpy
python -m pytest
GLASS_ARRAY_BACKEND=numpy python -m pytest
# run tests using array_api_strict (should be installed manually)
GLASS_ARRAY_BACKEND=array_api_strict python -m pytest
# run tests using jax (should be installed manually)
GLASS_ARRAY_BACKEND=jax python -m pytest
# run tests using every supported array library available in the environment
GLASS_ARRAY_BACKEND=all python -m pytest
```

Moreover, one can mark a test to be compatible with the array API standard by
decorating it with `@array_api_compatible`. This will `parameterize` the test to
run on every array library specified through `GLASS_ARRAY_BACKEND` -

```py
import types
from tests.conftest import array_api_compatible


@array_api_compatible
def test_something(xp: types.ModuleType):
# use `xp.` to access the array library functionality
...
```

## Documenting

_GLASS_'s documentation is mainly written in the form of
Expand Down Expand Up @@ -173,6 +209,23 @@ syntax -
nox -s tests-3.11
```

One can specify a particular array backend for testing by setting the
`GLASS_ARRAY_BACKEND` environment variable. The default array backend is NumPy.
_GLASS_ can be tested with every supported array library by setting
`GLASS_ARRAY_BACKEND` to `all`.

```bash
# run tests using numpy
nox -s tests-3.11
GLASS_ARRAY_BACKEND=numpy nox -s tests-3.11
# run tests using array_api_strict
GLASS_ARRAY_BACKEND=array_api_strict nox -s tests-3.11
# run tests using jax
GLASS_ARRAY_BACKEND=jax nox -s tests-3.11
# run tests using every supported array library
GLASS_ARRAY_BACKEND=all nox -s tests-3.11
```

The following command can be used to deploy the docs on `localhost` -

```bash
Expand Down
10 changes: 10 additions & 0 deletions noxfile.py
Original file line number Diff line number Diff line change
Expand Up @@ -2,6 +2,7 @@

from __future__ import annotations

import os
from pathlib import Path

import nox
Expand Down Expand Up @@ -29,6 +30,15 @@ def lint(session: nox.Session) -> None:
def tests(session: nox.Session) -> None:
"""Run the unit tests."""
session.install("-c", ".github/test-constraints.txt", "-e", ".[test]")

array_backend = os.environ.get("GLASS_ARRAY_BACKEND")
if array_backend == "array_api_strict":
session.install("array_api_strict>=2")
elif array_backend == "jax":
session.install("jax>=0.4.32")
elif array_backend == "all":
session.install("array_api_strict>=2", "jax>=0.4.32")

session.run(
"pytest",
*session.posargs,
Expand Down
89 changes: 89 additions & 0 deletions tests/conftest.py
Original file line number Diff line number Diff line change
@@ -1,11 +1,100 @@
import contextlib
import importlib.metadata
import os
import types

import numpy as np
import numpy.typing as npt
import packaging.version
import pytest

from cosmology import Cosmology

from glass import RadialWindow

# environment variable to specify array backends for testing
# can be:
# a particular array library (numpy, jax, array_api_strict, ...)
# all (try finding every supported array library available in the environment)
GLASS_ARRAY_BACKEND: str | bool = os.environ.get("GLASS_ARRAY_BACKEND", False)


def _check_version(lib: str, array_api_compliant_version: str) -> None:
"""
Check if installed library's version is compliant with the array API standard.
Parameters
----------
lib
name of the library.
array_api_compliant_version
version of the library compliant with the array API standard.
Raises
------
ImportError
If the installed version is not compliant with the array API standard.
"""
lib_version = packaging.version.Version(importlib.metadata.version(lib))
if lib_version < packaging.version.Version(array_api_compliant_version):
msg = f"{lib} must be >= {array_api_compliant_version}; found {lib_version}"
raise ImportError(msg)


def _import_and_add_numpy(xp_available_backends: dict[str, types.ModuleType]) -> None:
"""Add numpy to the backends dictionary."""
_check_version("numpy", "2.1.0")
xp_available_backends.update({"numpy": np})


def _import_and_add_array_api_strict(
xp_available_backends: dict[str, types.ModuleType],
) -> None:
"""Add array_api_strict to the backends dictionary."""
import array_api_strict

_check_version("array_api_strict", "2.0.0")
xp_available_backends.update({"array_api_strict": array_api_strict})
array_api_strict.set_array_api_strict_flags(api_version="2023.12")


def _import_and_add_jax(xp_available_backends: dict[str, types.ModuleType]) -> None:
"""Add jax to the backends dictionary."""
import jax

_check_version("jax", "0.4.32")
xp_available_backends.update({"jax.numpy": jax.numpy})
# enable 64 bit numbers
jax.config.update("jax_enable_x64", val=True)


# a dictionary with all array backends to test
xp_available_backends: dict[str, types.ModuleType] = {}

# if no backend passed, use numpy by default
if not GLASS_ARRAY_BACKEND or GLASS_ARRAY_BACKEND == "numpy":
_import_and_add_numpy(xp_available_backends)
elif GLASS_ARRAY_BACKEND == "array_api_strict":
_import_and_add_array_api_strict(xp_available_backends)
elif GLASS_ARRAY_BACKEND == "jax":
_import_and_add_jax(xp_available_backends)
# if all, try importing every backend
elif GLASS_ARRAY_BACKEND == "all":
with contextlib.suppress(ImportError):
_import_and_add_numpy(xp_available_backends)

with contextlib.suppress(ImportError):
_import_and_add_array_api_strict(xp_available_backends)

with contextlib.suppress(ImportError):
_import_and_add_jax(xp_available_backends)
else:
msg = f"unsupported array backend: {GLASS_ARRAY_BACKEND}"
raise ValueError(msg)

# use this as a decorator for tests involving array API compatible functions
array_api_compatible = pytest.mark.parametrize("xp", xp_available_backends.values())


@pytest.fixture(scope="session")
def cosmo() -> Cosmology:
Expand Down

0 comments on commit 487989f

Please sign in to comment.