Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

feat(ops): Add keras.ops.numpy.rot90 operation (#20723) #20745

Merged
merged 5 commits into from
Jan 13, 2025
Merged
Changes from 1 commit
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
Prev Previous commit
Next Next commit
Fix dtype conflict in PyTorch backend's rot90 function
Resolved the 'Invalid dtype: object' error by explicitly using  to avoid naming conflicts with the custom  function.
  • Loading branch information
harshaljanjani committed Jan 11, 2025
commit e9dfec290b4cfca38d167bf3c6298b55aaff447a
42 changes: 14 additions & 28 deletions keras/src/backend/torch/numpy.py
Original file line number Diff line number Diff line change
Expand Up @@ -36,43 +36,29 @@ def rot90(array, k=1, axes=(0, 1)):
Returns:
Rotated tensor
"""
x = convert_to_tensor(array)

if x.ndim < 2:
if not isinstance(array, (np.ndarray, torch.Tensor)):
harshaljanjani marked this conversation as resolved.
Show resolved Hide resolved
array = np.asarray(array)
if isinstance(array, np.ndarray):
array = torch.from_numpy(array)
if array.ndim < 2:
raise ValueError(
f"Input array must have at least 2 dimensions. Received: array.ndim={x.ndim}"
f"Input array must have at least 2 dimensions. Received: array.ndim={array.ndim}"
)
if len(axes) != 2 or axes[0] == axes[1]:
raise ValueError(
f"Invalid axes: {axes}. Axes must be a tuple of two different dimensions."
)

k = k % 4
if k == 0:
return x

axes = tuple(axis if axis >= 0 else x.ndim + axis for axis in axes)

if not all(0 <= axis < x.ndim for axis in axes):
raise ValueError(f"Invalid axes {axes} for tensor with {x.ndim} dimensions")
axes = tuple(axis if axis >= 0 else array.ndim + axis for axis in axes)
# fix: all() method conflict, explicitly use builtins.all()
if not builtins.all(0 <= axis < array.ndim for axis in axes):
raise ValueError(f"Invalid axes {axes} for tensor with {array.ndim} dimensions")

for _ in range(k):
perm = list(range(x.ndim))
for i, axis in enumerate(axes):
perm.remove(axis)
perm.append(axis)
x = x.permute(perm)

x = torch.flip(x, dims=[-1])
x = x.transpose(-2, -1)

perm = list(range(x.ndim))
for i, axis in enumerate(axes):
perm.remove(x.ndim - 2 + i)
perm.insert(axis, x.ndim - 2 + i)
x = x.permute(perm)
rotated = torch.rot90(array, k=k, dims=axes)
if isinstance(array, np.ndarray):
rotated = rotated.cpu().numpy()

return x
return rotated


def add(x1, x2):
Expand Down
Loading