Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

feat(ops): Add keras.ops.numpy.rot90 operation (#20723) #20745

Merged
merged 5 commits into from
Jan 13, 2025
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
13 changes: 13 additions & 0 deletions keras/src/backend/jax/numpy.py
Original file line number Diff line number Diff line change
Expand Up @@ -15,6 +15,19 @@
from keras.src.backend.jax.core import convert_to_tensor


def rot90(array, k=1, axes=(0, 1)):
"""Rotate an array by 90 degrees in the specified plane."""
if array.ndim < 2:
raise ValueError(
f"Input array must have at least 2 dimensions. Received: array.ndim={array.ndim}"
)
if len(axes) != 2 or axes[0] == axes[1]:
raise ValueError(
f"Invalid axes: {axes}. Axes must be a tuple of two different dimensions."
)
return jnp.rot90(array, k=k, axes=axes)


@sparse.elementwise_binary_union(linear=True, use_sparsify=True)
def add(x1, x2):
x1 = convert_to_tensor(x1)
Expand Down
13 changes: 13 additions & 0 deletions keras/src/backend/numpy/numpy.py
Original file line number Diff line number Diff line change
Expand Up @@ -8,6 +8,19 @@
from keras.src.backend.numpy.core import convert_to_tensor


def rot90(array, k=1, axes=(0, 1)):
"""Rotate an array by 90 degrees in the specified plane."""
if array.ndim < 2:
raise ValueError(
f"Input array must have at least 2 dimensions. Received: array.ndim={array.ndim}"
)
if len(axes) != 2 or axes[0] == axes[1]:
raise ValueError(
f"Invalid axes: {axes}. Axes must be a tuple of two different dimensions."
)
return np.rot90(array, k=k, axes=axes)


def add(x1, x2):
if not isinstance(x1, (int, float)):
x1 = convert_to_tensor(x1)
Expand Down
43 changes: 43 additions & 0 deletions keras/src/backend/tensorflow/numpy.py
Original file line number Diff line number Diff line change
Expand Up @@ -23,6 +23,49 @@
from keras.src.backend.tensorflow.core import shape as shape_op


def rot90(array, k=1, axes=(0, 1)):
"""Rotate an array by 90 degrees in the specified plane."""
array = convert_to_tensor(array)

if array.shape.rank < 2:
raise ValueError(
f"Input array must have at least 2 dimensions. Received: array.ndim={array.shape.rank}"
)

if len(axes) != 2 or axes[0] == axes[1]:
raise ValueError(
f"Invalid axes: {axes}. Axes must be a tuple of two different dimensions."
)

k = k % 4
if k == 0:
return array

axes = tuple(axis if axis >= 0 else array.shape.rank + axis for axis in axes)

perm = [i for i in range(array.shape.rank) if i not in axes]
perm.extend(axes)
array = tf.transpose(array, perm)

shape = tf.shape(array)
non_rot_shape = shape[:-2]
rot_shape = shape[-2:]

array = tf.reshape(array, tf.concat([[-1], rot_shape], axis=0))

for _ in range(k):
array = tf.transpose(array, [0, 2, 1])
array = tf.reverse(array, axis=[1])
array = tf.reshape(array, tf.concat([non_rot_shape, rot_shape], axis=0))

inv_perm = [0] * len(perm)
for i, p in enumerate(perm):
inv_perm[p] = i
array = tf.transpose(array, inv_perm)

return array


@sparse.elementwise_binary_union(tf.sparse.add)
def add(x1, x2):
if not isinstance(x1, (int, float)):
Expand Down
34 changes: 34 additions & 0 deletions keras/src/backend/torch/numpy.py
Original file line number Diff line number Diff line change
Expand Up @@ -25,6 +25,40 @@
)


def rot90(array, k=1, axes=(0, 1)):
"""Rotate an array by 90 degrees in the specified plane using PyTorch.

Args:
array: Input tensor
k: Number of 90-degree rotations (default=1)
axes: Tuple of two axes that define the plane of rotation (default=(0,1))

Returns:
Rotated tensor
"""
array = convert_to_tensor(array)

if array.ndim < 2:
raise ValueError(
f"Input array must have at least 2 dimensions. Received: array.ndim={array.ndim}"
)
if len(axes) != 2 or axes[0] == axes[1]:
raise ValueError(
f"Invalid axes: {axes}. Axes must be a tuple of two different dimensions."
)

axes = tuple(axis if axis >= 0 else array.ndim + axis for axis in axes)

if not builtins.all(0 <= axis < array.ndim for axis in axes):
raise ValueError(f"Invalid axes {axes} for tensor with {array.ndim} dimensions")

rotated = torch.rot90(array, k=k, dims=axes)
if isinstance(array, np.ndarray):
rotated = rotated.cpu().numpy()

return rotated


def add(x1, x2):
x1 = convert_to_tensor(x1)
x2 = convert_to_tensor(x2)
Expand Down
63 changes: 63 additions & 0 deletions keras/src/ops/numpy.py
Original file line number Diff line number Diff line change
Expand Up @@ -16,6 +16,69 @@
from keras.src.ops.operation_utils import reduce_shape


class Rot90(Operation):
def __init__(self, k=1, axes=(0, 1)):
super().__init__()
self.k = k
self.axes = axes

def call(self, array):
return backend.numpy.rot90(array, k=self.k, axes=self.axes)

def compute_output_spec(self, array):
array_shape = list(array.shape)
if len(array_shape) < 2:
raise ValueError(
"Input array must have at least 2 dimensions. "
f"Received: array.shape={array_shape}"
)
if len(self.axes) != 2 or self.axes[0] == self.axes[1]:
raise ValueError(
f"Invalid axes: {self.axes}. Axes must be a tuple of two different dimensions."
)
axis1, axis2 = self.axes
array_shape[axis1], array_shape[axis2] = array_shape[axis2], array_shape[axis1]
return KerasTensor(shape=array_shape, dtype=array.dtype)


@keras_export(["keras.ops.rot90", "keras.ops.numpy.rot90"])
def rot90(array, k=1, axes=(0, 1)):
"""Rotate an array by 90 degrees in the plane specified by axes.

This function rotates an array counterclockwise by 90 degrees `k` times
in the plane specified by `axes`. Supports arrays of two or more dimensions.

Args:
array: Input array to rotate.
k: Number of times the array is rotated by 90 degrees.
axes: A tuple of two integers specifying the plane for rotation.

Returns:
Rotated array.

Examples:

>>> import numpy as np
>>> from keras import ops
>>> m = np.array([[1, 2], [3, 4]])
>>> rotated = ops.rot90(m)
>>> rotated
array([[2, 4],
[1, 3]])

>>> m = np.arange(8).reshape((2, 2, 2))
>>> rotated = ops.rot90(m, k=1, axes=(1, 2))
>>> rotated
array([[[1, 3],
[0, 2]],
[[5, 7],
[4, 6]]])
"""
if any_symbolic_tensors((array,)):
return Rot90(k=k, axes=axes).symbolic_call(array)
return backend.numpy.rot90(array, k=k, axes=axes)


def shape_equal(shape1, shape2, axis=None, allow_none=True):
"""Check if two shapes are equal.

Expand Down
67 changes: 67 additions & 0 deletions keras/src/ops/numpy_test.py
Original file line number Diff line number Diff line change
Expand Up @@ -19,6 +19,73 @@
from keras.src.testing.test_utils import named_product


class NumPyTestRot90(testing.TestCase):
def test_basic(self):
array = np.array([[1, 2], [3, 4]])
rotated = knp.rot90(array)
expected = np.array([[2, 4], [1, 3]])
assert np.array_equal(rotated, expected), f"Failed basic 2D test: {rotated}"

def test_multiple_k(self):
array = np.array([[1, 2], [3, 4]])

# k=2 (180 degrees rotation)
rotated = knp.rot90(array, k=2)
expected = np.array([[4, 3], [2, 1]])
assert np.array_equal(rotated, expected), f"Failed k=2 test: {rotated}"

# k=3 (270 degrees rotation)
rotated = knp.rot90(array, k=3)
expected = np.array([[3, 1], [4, 2]])
assert np.array_equal(rotated, expected), f"Failed k=3 test: {rotated}"

# k=4 (full rotation)
rotated = knp.rot90(array, k=4)
expected = array
assert np.array_equal(rotated, expected), f"Failed k=4 test: {rotated}"

def test_axes(self):
array = np.arange(8).reshape((2, 2, 2))
rotated = knp.rot90(array, k=1, axes=(1, 2))
expected = np.array([[[1, 3], [0, 2]], [[5, 7], [4, 6]]])
assert np.array_equal(rotated, expected), f"Failed custom axes test: {rotated}"

def test_single_image(self):
array = np.random.random((4, 4, 3))
rotated = knp.rot90(array, k=1, axes=(0, 1))
expected = np.rot90(array, k=1, axes=(0, 1))
assert np.allclose(rotated, expected), "Failed single image test"

def test_batch_images(self):
array = np.random.random((2, 4, 4, 3))
rotated = knp.rot90(array, k=1, axes=(1, 2))
expected = np.rot90(array, k=1, axes=(1, 2))
assert np.allclose(rotated, expected), "Failed batch images test"

def test_invalid_axes(self):
array = np.array([[1, 2], [3, 4]])
try:
knp.rot90(array, axes=(0, 0))
except ValueError as e:
assert (
"Invalid axes: (0, 0). Axes must be a tuple of two different dimensions."
in str(e)
), f"Failed invalid axes test: {e}"
else:
raise AssertionError("Failed to raise error for invalid axes")

def test_invalid_rank(self):
array = np.array([1, 2, 3]) # 1D array
try:
knp.rot90(array)
except ValueError as e:
assert (
"Input array must have at least 2 dimensions." in str(e)
), f"Failed invalid rank test: {e}"
else:
raise AssertionError("Failed to raise error for invalid input rank")


class NumpyTwoInputOpsDynamicShapeTest(testing.TestCase):
def test_add(self):
x = KerasTensor((None, 3))
Expand Down
Loading