diff --git a/.github/workflows/jax-tests.yml b/.github/workflows/jax-tests.yml index 468fe593..809826c0 100644 --- a/.github/workflows/jax-tests.yml +++ b/.github/workflows/jax-tests.yml @@ -39,8 +39,7 @@ jobs: - name: Install run: | python -m pip install --upgrade pip - python -m pip install -e '.[test]' 'jax[cpu]' - python -m pip uninstall -y lithops # tests don't run on Lithops + python -m pip install -e '.[test-jax]' - name: Install for Mac if: matrix.os == 'macos-14' @@ -51,7 +50,8 @@ jobs: run: | # exclude tests that rely on structured types since JAX doesn't support these # exclude tests that rely on randomness because JAX is picky about this. - pytest -k "not argmax and not argmin and not mean and not apply_reduction and not broadcast_trick and not groupby and not random" + # TODO(#494): Turn back on tests that do visualization when the "FileNotFound" error is fixed. These are "visualization", "plan_scaling", and "optimization". + pytest -k "not argmax and not argmin and not mean and not apply_reduction and not broadcast_trick and not groupby and not random and not visualization and not plan_scaling and not optimization" env: CUBED_BACKEND_ARRAY_API_MODULE: jax.numpy JAX_ENABLE_X64: matrix.os != 'macos-14' diff --git a/cubed/tests/test_core.py b/cubed/tests/test_core.py index 4ce48349..11a391ec 100644 --- a/cubed/tests/test_core.py +++ b/cubed/tests/test_core.py @@ -383,7 +383,7 @@ def test_reduction_not_enough_memory(tmp_path): def test_partial_reduce(spec): - a = xp.asarray(np.arange(242).reshape((11, 22)), chunks=(3, 4), spec=spec) + a = xp.asarray(np.arange(242, dtype=np.int32).reshape((11, 22)), chunks=(3, 4), spec=spec, dtype=xp.int32) b = partial_reduce(a, np.sum, split_every={0: 2}) c = partial_reduce(b, np.sum, split_every={0: 2}) assert_array_equal( @@ -468,13 +468,13 @@ def test_compute_multiple_different_specs(tmp_path): def test_visualize(tmp_path): - a = xp.asarray([[1, 2, 3], [4, 5, 6], [7, 8, 9]], dtype=xp.float64, chunks=(2, 2)) + a = xp.asarray([[1, 2, 3], [4, 5, 6], [7, 8, 9]], dtype=xp.float32, chunks=(2, 2)) b = cubed.random.random((3, 3), chunks=(2, 2)) c = xp.add(a, b) d = c.rechunk((3, 1)) e = c * 3 - f = xp.asarray([[1, 2, 3], [4, 5, 6], [7, 8, 9]], chunks=(2, 2)) + f = xp.asarray([[1, 2, 3], [4, 5, 6], [7, 8, 9]], chunks=(2, 2), dtype=xp.float32) g = f * 4 assert not (tmp_path / "e.dot").exists() diff --git a/pyproject.toml b/pyproject.toml index 9c28bc33..689d3688 100644 --- a/pyproject.toml +++ b/pyproject.toml @@ -110,6 +110,15 @@ test-modal = [ "pytest-mock", ] +test-jax = [ + "cubed[diagnostics]", # modal tests separate due to conflicting package reqs + "dill", + "numpy_groupies", + "pytest", + "pytest-cov", + "pytest-mock", + "jax", +] [project.urls] homepage = "https://github.com/cubed-dev/cubed" documentation = "https://tomwhite.github.io/cubed"