-
Notifications
You must be signed in to change notification settings - Fork 19.5k
New issue
Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.
By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.
Already on GitHub? Sign in to your account
Add max_pool and average_pool for MLX #20814
Conversation
Thanks for your pull request! It looks like this may be your first contribution to a Google open source project. Before we can look at your pull request, you'll need to sign a Contributor License Agreement (CLA). View this failed invocation of the CLA check for more information. For the most up to date status, view the checks section at the bottom of the pull request. |
keras/src/backend/mlx/nn.py
Outdated
raise ValueError( | ||
f"To extract sliding windows the window shapes and strides must " | ||
f"have the same number of spatial dimensions as the signal but " | ||
f"the signal has {len(spatial_dims)} dims and the window shape " |
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
Note: lines that don't have {}
s in them don't need to start with f
. This applies to all string lines in the PR.
In this error message, please format the argument information via name={value}
, e.g.
"input_shape={x.shape}, window_shape={window_shape}, strides={window_strides}"
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
Thanks for the PR!
keras/src/backend/mlx/nn.py
Outdated
raise ValueError( | ||
"To extract sliding windows, the lengths of window_shape and " | ||
"window_strides must be equal to the signal's spatial dimensions. " | ||
f"However, the signal has spatial_dims={len(spatial_dims)} while " |
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
For all args, rather than printing just len it's useful to print the full arg value
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
LGTM, thank you!
This PR implements support for max_pool and average_pool operations in the MLX backend.
To achieve this functionality the following functions were copied from mlx repository:
_non_overlapping_sliding_windows and _sliding_windows
The tests related to pooling in:
were passed