Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

mlx - rnn #20786

Merged
merged 2 commits into from
Jan 20, 2025
Merged

mlx - rnn #20786

merged 2 commits into from
Jan 20, 2025

Conversation

acsweet
Copy link

@acsweet acsweet commented Jan 20, 2025

This addresses #19571

All tests in keras/src/layers/rnn/ (excluding those with convolutions) are passing!

A few notes:

  • Modified mlx.numpy.flip as slice did not work with mlx arrays
  • I tried [::-1] for simple flipping, but it failed to propagate gradients with the bidirectional rnn layers
  • I used jax's qr factorization in mlx.linalg.qr as mlx's implementation only works on square matrices for now (is numpy preferred in this case?)

Copy link
Collaborator

@fchollet fchollet left a comment

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Awesome -- thank you for the contribution!

@google-ml-butler google-ml-butler bot added kokoro:force-run ready to pull Ready to be merged into the codebase labels Jan 20, 2025
@fchollet fchollet merged commit 0ad4c78 into keras-team:mlx Jan 20, 2025
4 of 11 checks passed
@google-ml-butler google-ml-butler bot removed the ready to pull Ready to be merged into the codebase label Jan 20, 2025
@fchollet
Copy link
Collaborator

@acsweet do you have an email address?

@acsweet
Copy link
Author

acsweet commented Jan 21, 2025

@fchollet yes, it's andrewcsweet@gmail.com

@acsweet acsweet deleted the mlx-rnn branch February 2, 2025 11:27
Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment
Labels
Projects
None yet
Development

Successfully merging this pull request may close these issues.

4 participants