Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Add max_pool and average_pool for MLX #20814

Merged
merged 3 commits into from
Jan 28, 2025
Merged

Add max_pool and average_pool for MLX #20814

merged 3 commits into from
Jan 28, 2025

Conversation

fbadine
Copy link
Contributor

@fbadine fbadine commented Jan 26, 2025

This PR implements support for max_pool and average_pool operations in the MLX backend.
To achieve this functionality the following functions were copied from mlx repository:
_non_overlapping_sliding_windows and _sliding_windows

The tests related to pooling in:

  • eras/src/ops/nn_test.py
  • keras/src/layers/pooling

were passed

Copy link

google-cla bot commented Jan 26, 2025

Thanks for your pull request! It looks like this may be your first contribution to a Google open source project. Before we can look at your pull request, you'll need to sign a Contributor License Agreement (CLA).

View this failed invocation of the CLA check for more information.

For the most up to date status, view the checks section at the bottom of the pull request.

@fbadine fbadine mentioned this pull request Jan 26, 2025
62 tasks
raise ValueError(
f"To extract sliding windows the window shapes and strides must "
f"have the same number of spatial dimensions as the signal but "
f"the signal has {len(spatial_dims)} dims and the window shape "
Copy link
Collaborator

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Note: lines that don't have {}s in them don't need to start with f. This applies to all string lines in the PR.

In this error message, please format the argument information via name={value}, e.g.

"input_shape={x.shape}, window_shape={window_shape}, strides={window_strides}"

Copy link
Collaborator

@fchollet fchollet left a comment

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Thanks for the PR!

raise ValueError(
"To extract sliding windows, the lengths of window_shape and "
"window_strides must be equal to the signal's spatial dimensions. "
f"However, the signal has spatial_dims={len(spatial_dims)} while "
Copy link
Collaborator

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

For all args, rather than printing just len it's useful to print the full arg value

Copy link
Collaborator

@fchollet fchollet left a comment

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

LGTM, thank you!

@google-ml-butler google-ml-butler bot added kokoro:force-run ready to pull Ready to be merged into the codebase labels Jan 28, 2025
@fchollet fchollet merged commit f0e9882 into keras-team:mlx Jan 28, 2025
1 of 7 checks passed
@google-ml-butler google-ml-butler bot removed ready to pull Ready to be merged into the codebase kokoro:force-run labels Jan 28, 2025
Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment
Labels
Projects
None yet
Development

Successfully merging this pull request may close these issues.

3 participants