diff --git a/cubed/array_api/creation_functions.py b/cubed/array_api/creation_functions.py index b1d1902fd..02b657055 100644 --- a/cubed/array_api/creation_functions.py +++ b/cubed/array_api/creation_functions.py @@ -218,18 +218,18 @@ def linspace( step=step, endpoint=endpoint, linspace_dtype=dtype, + device=device, ) -def _linspace(x, size, start, step, endpoint, linspace_dtype, block_id=None): +def _linspace(x, size, start, step, endpoint, linspace_dtype, device=None, block_id=None): bs = x.shape[0] i = block_id[0] adjusted_bs = bs - 1 if endpoint else bs - blockstart = start + (i * size * step) - blockstop = blockstart + (adjusted_bs * step) - return nxp.linspace( - blockstart, blockstop, bs, endpoint=endpoint, dtype=linspace_dtype - ) + float_ = default_dtypes(device=device)['real floating'] + blockstart = float_(start + (i * size * step)) + blockstop = float_(blockstart + float_(adjusted_bs * step)) + return nxp.linspace(blockstart, blockstop, bs, endpoint=endpoint, dtype=linspace_dtype) def meshgrid(*arrays, indexing="xy") -> List["Array"]: diff --git a/cubed/tests/test_array_api.py b/cubed/tests/test_array_api.py index 9cf825819..95cf8b315 100644 --- a/cubed/tests/test_array_api.py +++ b/cubed/tests/test_array_api.py @@ -113,17 +113,17 @@ def test_eye(spec, k): @pytest.mark.parametrize("endpoint", [True, False]) def test_linspace(spec, endpoint): - a = xp.linspace(6, 49, 50, endpoint=endpoint, chunks=5, spec=spec) - npa = np.linspace(6, 49, 50, endpoint=endpoint) - assert_allclose(a, npa) + a = xp.linspace(6, 49, 50, endpoint=endpoint, chunks=5, spec=spec, dtype=xp.float32) + npa = np.linspace(6, 49, 50, endpoint=endpoint, dtype=np.float32) + assert_allclose(a, npa, rtol=1e-5) - a = xp.linspace(1.4, 4.9, 13, endpoint=endpoint, chunks=5, spec=spec) - npa = np.linspace(1.4, 4.9, 13, endpoint=endpoint) - assert_allclose(a, npa) + a = xp.linspace(1.4, 4.9, 13, endpoint=endpoint, chunks=5, spec=spec, dtype=xp.float32) + npa = np.linspace(1.4, 4.9, 13, endpoint=endpoint, dtype=np.float32) + assert_allclose(a, npa, rtol=1e-5) - a = xp.linspace(0, 0, 0, endpoint=endpoint) - npa = np.linspace(0, 0, 0, endpoint=endpoint) - assert_allclose(a, npa) + a = xp.linspace(0, 0, 0, endpoint=endpoint, dtype=xp.float32) + npa = np.linspace(0, 0, 0, endpoint=endpoint, dtype=np.float32) + assert_allclose(a, npa, rtol=1e-5) def test_ones(spec, executor):