Skip to content

Commit

Permalink
I think I got the linspace tests passing. Need to check if array equa…
Browse files Browse the repository at this point in the history
…lity tolerances are OK.
  • Loading branch information
alxmrs committed Jul 23, 2024
1 parent 10b22b9 commit 2590e22
Show file tree
Hide file tree
Showing 2 changed files with 15 additions and 15 deletions.
12 changes: 6 additions & 6 deletions cubed/array_api/creation_functions.py
Original file line number Diff line number Diff line change
Expand Up @@ -218,18 +218,18 @@ def linspace(
step=step,
endpoint=endpoint,
linspace_dtype=dtype,
device=device,
)


def _linspace(x, size, start, step, endpoint, linspace_dtype, block_id=None):
def _linspace(x, size, start, step, endpoint, linspace_dtype, device=None, block_id=None):
bs = x.shape[0]
i = block_id[0]
adjusted_bs = bs - 1 if endpoint else bs
blockstart = start + (i * size * step)
blockstop = blockstart + (adjusted_bs * step)
return nxp.linspace(
blockstart, blockstop, bs, endpoint=endpoint, dtype=linspace_dtype
)
float_ = default_dtypes(device=device)['real floating']
blockstart = float_(start + (i * size * step))
blockstop = float_(blockstart + float_(adjusted_bs * step))
return nxp.linspace(blockstart, blockstop, bs, endpoint=endpoint, dtype=linspace_dtype)


def meshgrid(*arrays, indexing="xy") -> List["Array"]:
Expand Down
18 changes: 9 additions & 9 deletions cubed/tests/test_array_api.py
Original file line number Diff line number Diff line change
Expand Up @@ -113,17 +113,17 @@ def test_eye(spec, k):

@pytest.mark.parametrize("endpoint", [True, False])
def test_linspace(spec, endpoint):
a = xp.linspace(6, 49, 50, endpoint=endpoint, chunks=5, spec=spec)
npa = np.linspace(6, 49, 50, endpoint=endpoint)
assert_allclose(a, npa)
a = xp.linspace(6, 49, 50, endpoint=endpoint, chunks=5, spec=spec, dtype=xp.float32)
npa = np.linspace(6, 49, 50, endpoint=endpoint, dtype=np.float32)
assert_allclose(a, npa, rtol=1e-5)

a = xp.linspace(1.4, 4.9, 13, endpoint=endpoint, chunks=5, spec=spec)
npa = np.linspace(1.4, 4.9, 13, endpoint=endpoint)
assert_allclose(a, npa)
a = xp.linspace(1.4, 4.9, 13, endpoint=endpoint, chunks=5, spec=spec, dtype=xp.float32)
npa = np.linspace(1.4, 4.9, 13, endpoint=endpoint, dtype=np.float32)
assert_allclose(a, npa, rtol=1e-5)

a = xp.linspace(0, 0, 0, endpoint=endpoint)
npa = np.linspace(0, 0, 0, endpoint=endpoint)
assert_allclose(a, npa)
a = xp.linspace(0, 0, 0, endpoint=endpoint, dtype=xp.float32)
npa = np.linspace(0, 0, 0, endpoint=endpoint, dtype=np.float32)
assert_allclose(a, npa, rtol=1e-5)


def test_ones(spec, executor):
Expand Down

0 comments on commit 2590e22

Please sign in to comment.