-
Notifications
You must be signed in to change notification settings - Fork 19.5k
New issue
Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.
By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.
Already on GitHub? Sign in to your account
feat(ops): Add keras.ops.numpy.rot90 operation (#20723) #20745
feat(ops): Add keras.ops.numpy.rot90 operation (#20723) #20745
Conversation
Adds a new operation to rotate tensors by 90 degrees in the specified plane: - Implements rot90 operation in keras.ops.image module - Adds support for multiple rotations (k parameter) and custom axes - Matches numpy.rot90 behavior and API for consistency - Adds comprehensive test coverage including batch images support - Handles input validation for tensor dimensions and axes - Supports symbolic tensor execution The operation follows the same interface as numpy.rot90 and tf.image.rot90: rot90(array, k=1, axes=(0, 1))
Codecov ReportAttention: Patch coverage is
Additional details and impacted files@@ Coverage Diff @@
## master #20745 +/- ##
==========================================
- Coverage 81.95% 81.94% -0.01%
==========================================
Files 553 553
Lines 51458 51530 +72
Branches 7961 7977 +16
==========================================
+ Hits 42174 42228 +54
- Misses 7346 7361 +15
- Partials 1938 1941 +3
Flags with carried forward coverage won't be shown. Click here to find out more. ☔ View full report in Codecov by Sentry. |
Add implementations of rot90() for multiple backend frameworks: - JAX backend implementation - NumPy backend implementation - PyTorch backend implementation
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
Thanks for the PR!
rot90
is a numpy op: https://numpy.org/doc/stable/reference/generated/numpy.rot90.html
So please move the code to e.g. keras/src/backend/jax/numpy.py
(and so on) and export the op to keras.ops.rot90
as well as keras.ops.numpy.rot90
.
Move rot90 operation to numpy.py files in backend implementations since it's a numpy op (https://numpy.org/doc/stable/reference/generated/numpy.rot90.html). Now exported as both keras.ops.rot90 and keras.ops.numpy.rot90.
Resolved the 'Invalid dtype: object' error by explicitly using to avoid naming conflicts with the custom function.
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
Thanks for the update!
Replace tf.experimental.numpy.rot90 with core TF ops for XLA compatibility. Use convert_to_tensor for input handling.
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
Thanks for the contribution! LGTM
Adds a new operation to rotate tensors by 90 degrees in the specified plane:
keras.ops.rot90
fortf.image.rot90
#20723