Skip to content

Commit

Permalink
dpctl.tensor.tile returns a scalar for 0D (scalar) input and empty…
Browse files Browse the repository at this point in the history
… `repetitions` (#1628)

Previously, this case would return a 1D array of size 1, which did not match Numpy or the array API spec's expected behavior
  • Loading branch information
ndgrigorian authored Apr 1, 2024
1 parent 65bb9ef commit 03c7615
Show file tree
Hide file tree
Showing 2 changed files with 12 additions and 20 deletions.
12 changes: 0 additions & 12 deletions dpctl/tensor/_manipulation_functions.py
Original file line number Diff line number Diff line change
Expand Up @@ -964,18 +964,6 @@ def tile(x, repetitions):
f"Expected tuple or integer type, got {type(repetitions)}."
)

# case of scalar
if x.size == 1:
if not repetitions:
# handle empty tuple
repetitions = (1,)
return dpt.full(
repetitions,
x,
dtype=x.dtype,
usm_type=x.usm_type,
sycl_queue=x.sycl_queue,
)
rep_dims = len(repetitions)
x_dims = x.ndim
if rep_dims < x_dims:
Expand Down
20 changes: 12 additions & 8 deletions dpctl/tests/test_usm_ndarray_manipulation.py
Original file line number Diff line number Diff line change
Expand Up @@ -1378,20 +1378,24 @@ def test_tile_size_1():

reps = 5
# test for 0d array
x = dpt.asarray(2, dtype="i4")
res = dpt.tile(x, reps)
x1 = dpt.asarray(2, dtype="i4")
res = dpt.tile(x1, reps)
assert dpt.all(res == dpt.full(reps, 2, dtype="i4"))

# test for 1d array with single element
x = dpt.asarray([2], dtype="i4")
res = dpt.tile(x, reps)
x2 = dpt.asarray([2], dtype="i4")
res = dpt.tile(x2, reps)
assert dpt.all(res == dpt.full(reps, 2, dtype="i4"))

# test empty reps returns copy of input
reps = ()
res = dpt.tile(x, reps)
assert x.shape == res.shape
assert x == res
# test for gh-1627 behavior
res = dpt.tile(x1, reps)
assert x1.shape == res.shape
assert x1 == res

res = dpt.tile(x2, reps)
assert x2.shape == res.shape
assert x2 == res


def test_tile_prepends_axes():
Expand Down

0 comments on commit 03c7615

Please sign in to comment.