MLX Backend #19571

Open
20 of 62 tasks
lkarthee opened this issue Apr 20, 2024 · 27 comments
Labels
stat:contributions welcome A pull request to fix this issue would be welcome.

Comments

@lkarthee
Contributor

lkarthee commented Apr 20, 2024

Issue for tracking and coordinating mlx backend work:

mlx.math

mlx.numpy

mlx.image

mlx.nn

  • max_pool
  • avg_pool
  • conv
  • depthwise_conv
  • separable_conv
  • conv_transpose
  • ctc_loss

mlx.rnn

  • rnn
  • lstm
  • gru

mlx.linalg

mlx.core

@lkarthee
Contributor Author

lkarthee commented Apr 20, 2024

PyTest Output
=========================================================================== test session starts ============================================================================
platform darwin -- Python 3.12.2, pytest-8.1.1, pluggy-1.4.0 -- /Users/kartheek/erlang-ws/github-ws/latest/keras/.venv/bin/python3.12
cachedir: .pytest_cache
rootdir: /Users/kartheek/erlang-ws/github-ws/latest/keras
configfile: pyproject.toml
plugins: cov-5.0.0
collected 6 items

keras/src/ops/operation_test.py::OperationTest::test_autoconfig PASSED                                                                                               [ 16%]
keras/src/ops/operation_test.py::OperationTest::test_eager_call PASSED                                                                                               [ 33%]
keras/src/ops/operation_test.py::OperationTest::test_input_conversion FAILED                                                                                         [ 50%]
keras/src/ops/operation_test.py::OperationTest::test_serialization PASSED                                                                                            [ 66%]
keras/src/ops/operation_test.py::OperationTest::test_symbolic_call PASSED                                                                                            [ 83%]
keras/src/ops/operation_test.py::OperationTest::test_valid_naming PASSED                                                                                             [100%]

================================================================================= FAILURES =================================================================================
___________________________________________________________________ OperationTest.test_input_conversion ____________________________________________________________________

self = <keras.src.ops.operation_test.OperationTest testMethod=test_input_conversion>

    def test_input_conversion(self):
        x = np.ones((2,))
        y = np.ones((2,))
        z = knp.ones((2,))  # mix
        if backend.backend() == "torch":
            z = z.cpu()
        op = OpWithMultipleInputs()
>       out = op(x, y, z)

keras/src/ops/operation_test.py:152:
_ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _
keras/src/utils/traceback_utils.py:113: in error_handler
    return fn(*args, **kwargs)
keras/src/ops/operation.py:56: in __call__
    return self.call(*args, **kwargs)
_ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _

self = <Operation name=op_with_multiple_inputs>, x = array([1., 1.]), y = array([1., 1.])
z = <[ValueError('item can only be called on arrays of size 1.') raised in repr()] array object at 0x13f7450c0>

    def call(self, x, y, z=None):
        # `z` has to be put first due to the order of operations issue with
        # torch backend.
>       return 3 * z + x + 2 * y
E       ValueError: Cannot perform addition on an mlx.core.array and ndarray

keras/src/ops/operation_test.py:14: ValueError
========================================================================= short test summary info ==========================================================================
FAILED keras/src/ops/operation_test.py::OperationTest::test_input_conversion - ValueError: Cannot perform addition on an mlx.core.array and ndarray
======================================================================= 1 failed, 5 passed in 0.13s ========================================================================

Any idea how to fix this test case? add(mx_array, numpy_array) works, but the same operation fails when using the + operator. Should we skip this test for the mlx backend?

@fchollet
Collaborator

Any idea how to fix this test case? add(mx_array, numpy_array) works, but the same operation fails when using the + operator. Should we skip this test for the mlx backend?

It's not fixable on our side; we should file an issue with the MLX repo. The + operator will hit array.__add__, which is on their side.
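Why the function form works but + does not comes down to Python's binary-operator protocol: for a + b, Python calls a.__add__(b) first and only tries b.__radd__(a) if that call returns NotImplemented. An __add__ that raises instead (as the traceback suggests mlx.core.array does when given an ndarray) never gives the right-hand operand a chance. The sketch below uses toy, hypothetical classes (RaisingArray, DeferringArray, HostArray) in place of the real MLX and NumPy types; it illustrates the protocol only and is not how either library is implemented.

```python
class HostArray:
    """Toy stand-in for the right-hand operand (e.g. a NumPy array)."""

    def __init__(self, values):
        self.values = list(values)

    def __radd__(self, other):
        # Python calls this only when other.__add__(self) returned NotImplemented.
        return HostArray(a + b for a, b in zip(other.values, self.values))


class RaisingArray:
    """Toy left-hand array whose __add__ raises on operand types it doesn't know."""

    def __init__(self, values):
        self.values = list(values)

    def __add__(self, other):
        if not isinstance(other, RaisingArray):
            raise ValueError(
                f"Cannot perform addition on a RaisingArray and {type(other).__name__}"
            )
        return RaisingArray(a + b for a, b in zip(self.values, other.values))


class DeferringArray(RaisingArray):
    """Same as RaisingArray, but returns NotImplemented so Python tries other.__radd__."""

    def __add__(self, other):
        if not isinstance(other, RaisingArray):
            return NotImplemented
        return DeferringArray(a + b for a, b in zip(self.values, other.values))


try:
    RaisingArray([1, 1]) + HostArray([1, 1])
except ValueError as err:
    print(err)  # the raise ends dispatch; HostArray.__radd__ is never tried

print((DeferringArray([1, 1]) + HostArray([1, 1])).values)  # [2, 2], computed by __radd__
```

In other words, whether the mixed-type + succeeds is decided inside the left operand's __add__, which is on the MLX side.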

@fchollet fchollet added the stat:contributions welcome A pull request to fix this issue would be welcome. label Apr 20, 2024
@Faisal-Alsrheed
Contributor

Thank you for the list.

I am working on:

keras/backend/mlx/nn.py:conv
keras/backend/mlx/nn.py:depthwise_conv
keras/backend/mlx/nn.py:separable_conv
keras/backend/mlx/nn.py:conv_transpose

@lkarthee
Contributor Author

I am working on segment_sum, segment_max, max_pool and avg_pool. Thank you.
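For reference, segment_sum and segment_max reduce the entries of data that share the same id in segment_ids into num_segments outputs. A minimal pure-Python sketch of those semantics for 1-D inputs (an illustration, not the MLX implementation; num_segments is required here for simplicity, and the -inf fill value for empty segments is an assumption):

```python
def segment_sum(data, segment_ids, num_segments):
    """Sum the entries of data that share a segment id (1-D, pure Python)."""
    out = [0] * num_segments
    for value, seg in zip(data, segment_ids):
        out[seg] += value
    return out


def segment_max(data, segment_ids, num_segments):
    """Max of the entries of data that share a segment id (1-D, pure Python)."""
    # Empty segments stay at -inf here; real backends may use a different fill value.
    out = [float("-inf")] * num_segments
    for value, seg in zip(data, segment_ids):
        out[seg] = max(out[seg], value)
    return out


print(segment_sum([1, 2, 3, 4], [0, 0, 1, 1], 2))  # [3, 7]
print(segment_max([1, 5, 3, 4], [0, 0, 1, 1], 3))  # [5, 4, -inf]
```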

@yrahul3910

I want to take a stab at arctan2 (first-time contributor, so I want to start small). I'm working with the mlx team to see if I can add the required support there first, and then I'll add the implementation here.

@lkarthee
Contributor Author

lkarthee commented May 2, 2024

Thank you @yrahul3910, please go ahead and add the arctan2 implementation.

@lkarthee
Contributor Author

lkarthee commented May 6, 2024

mx.matmul and mx.tensordot work only for bfloat16, float16, and float32.

FAILED keras/src/ops/numpy_test.py::NumpyDtypeTest::test_tensordot_('int16', 'bool') - ValueError: [matmul] Only real floating point types are supported but int16 and bool were provided which results in int16, which is not a real floating point type.

@fchollet How should we handle this? We could cast integer arguments to float32 when both are integers, in which case the result will be float32. If we go this route, we will have to modify the test cases in numpy_test.py for mlx. Do you have any suggestions?
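The casting approach proposed here can be sketched as a small dtype-selection step that runs before the matmul. The helper name matmul_cast_dtype is hypothetical, dtypes are plain strings to keep the sketch self-contained, and the supported set is taken from the comment above; how mixed integer/float inputs promote is assumed to be left to the backend's normal rules.

```python
# Floating types accepted by mx.matmul / mx.tensordot, per the comment above.
MATMUL_FLOAT_DTYPES = {"bfloat16", "float16", "float32"}


def matmul_cast_dtype(dtype_a, dtype_b):
    """Hypothetical helper: dtype to cast both operands to before calling matmul.

    Returns "float32" when neither operand is a supported floating type
    (e.g. int16 and bool), so the result would be float32 as proposed.
    Returns None when at least one operand is already a supported float,
    meaning no extra cast is added and normal promotion applies.
    """
    if dtype_a in MATMUL_FLOAT_DTYPES or dtype_b in MATMUL_FLOAT_DTYPES:
        return None
    return "float32"


print(matmul_cast_dtype("int16", "bool"))     # float32
print(matmul_cast_dtype("int32", "float16"))  # None
```

If the test expectations in numpy_test.py are adjusted for mlx, they would then expect float32 for integer-only inputs.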

@awni

awni commented Jul 25, 2024

Just want to let you all know about some updates in MLX as of 0.16.1 that may be useful here:

  • mx.einsum
  • mx.nan_to_num
  • mx.conjugate

Are there any high priority items we can fix or add to help move this along?

@lkarthee
Copy link
Contributor Author

Thank you @awni, we need some help in moving this forward. I will make a list and get back to you in a day or two.

@acsweet

acsweet commented Jan 15, 2025

I'd like to pick up this issue (first-time contributor), starting with fft, if that's okay.

@acsweet

acsweet commented Jan 17, 2025

I'm going to start with the "easy" functions that are already implemented in mlx, beginning in mlx.math with:

  • fft2
  • rfft
  • irfft
  • qr (I'll have to see how to handle the mode argument from Keras)

@awni

awni commented Jan 17, 2025

Sounds great! Let us know how we can help on the MLX side.

@acsweet

acsweet commented Jan 18, 2025

@awni Thank you! I'll keep you updated as I progress.

Right now, would it be possible to get stft and istft implemented on the mlx side? It looks like work on this was started in ml-explore/mlx#1004.
I also saw this implementation (without an inverse): https://github.com/nuniz/mlx_stft

@fchollet
Collaborator

fchollet commented Jan 18, 2025 via email

@acsweet
Copy link

acsweet commented Jan 18, 2025

I'm going to hold off on math.qr for now: mlx currently only supports square matrices and has no option to choose between the complete and reduced factorizations.
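For context on the mode argument: Keras's qr follows the NumPy convention of a "reduced" mode (the default) and a "complete" mode, and the two differ only for non-square inputs. A minimal sketch of the (Q, R) shapes each mode should produce for an m x n input, with k = min(m, n); the helper qr_output_shapes is hypothetical and only computes shapes, not the factorization itself.

```python
def qr_output_shapes(m, n, mode="reduced"):
    """Shapes of (Q, R) for the QR factorization of an m x n matrix (hypothetical helper)."""
    k = min(m, n)
    if mode == "reduced":
        return (m, k), (k, n)
    if mode == "complete":
        return (m, m), (m, n)
    raise ValueError(f"mode must be 'reduced' or 'complete', got {mode!r}")


print(qr_output_shapes(4, 2))              # ((4, 2), (2, 2))
print(qr_output_shapes(4, 2, "complete"))  # ((4, 4), (4, 2))
print(qr_output_shapes(3, 3))              # ((3, 3), (3, 3))
```

For square matrices both modes give the same shapes, so the missing mode option only starts to matter once non-square inputs are supported.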

I have a PR for fft2, rfft, and irfft (plus a fix to fft). If that looks good, I'll start looking at the rnn namespace.

It looks like rnn.gru and rnn.lstm only have backend implementations for tensorflow, to get cuDNN-specific speedups with tf. So I think it's safe to follow the same approach as jax and torch?
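If the mlx backend follows the JAX approach, its rnn module can simply signal that no fused kernel exists and let the Keras layers run their generic implementation. A minimal sketch, assuming (as in the JAX backend) that the LSTM and GRU layers catch NotImplementedError from the backend function and fall back; call_recurrent_layer is a hypothetical stand-in for that layer-side logic, not a Keras API.

```python
def lstm(*args, **kwargs):
    # No fused (cuDNN-style) kernel on this backend: ask the layer to fall back.
    raise NotImplementedError


def gru(*args, **kwargs):
    # Same for GRU.
    raise NotImplementedError


def call_recurrent_layer(fused_fn, generic_fn, *args, **kwargs):
    """Hypothetical layer-side dispatch: try the fused backend op, else fall back."""
    try:
        return fused_fn(*args, **kwargs)
    except NotImplementedError:
        return generic_fn(*args, **kwargs)


print(call_recurrent_layer(lstm, lambda x: f"generic step loop over {x}", "inputs"))
```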

@fchollet
Copy link
Collaborator

fchollet commented Jan 18, 2025 via email

@acsweet acsweet mentioned this issue Jan 20, 2025
@acsweet

acsweet commented Jan 20, 2025

I'm going to start working through mlx.nn now.

I hope that's okay, but I'm going to start with conv, and if @lkarthee or @Faisal-Alsrheed would like to jump back in, please do! Otherwise I'll keep working through the other functions.

@acsweet

acsweet commented Jan 20, 2025

@awni Would it be possible to get support for non-square matrices implemented in mlx.linalg.qr? I didn't see an open issue for it; I can open a feature request too.

@awni

awni commented Jan 20, 2025

Yes, please open an issue about it; it should be straightforward to get working.

@acsweet

acsweet commented Jan 22, 2025

If the conv implementation looks good, I think I'll get started on the other convolutional functions

  • depthwise_conv
  • separable_conv
  • conv_transpose

@acsweet

acsweet commented Jan 24, 2025

I have a pull request for the remaining convolutional functions; if those look good, I'll continue!

Fadi asked to work on max_pool and avg_pool, so I'm going to work on the remaining nn functions that are failing tests.

@acsweet

acsweet commented Jan 24, 2025

@awni Would it be pretty straightforward to implement singular value norms in linalg::matrix_norm?

I can open an issue for it too!
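For reference, the singular-value-based matrix norms in the NumPy convention are the spectral norm (ord=2, the largest singular value), ord=-2 (the smallest singular value), and the nuclear norm (ord="nuc", the sum of singular values). Once singular values are available from an SVD, these norms are simple reductions. A minimal pure-Python sketch, where the helper name is hypothetical and the SVD itself is assumed to be computed elsewhere:

```python
def norm_from_singular_values(singular_values, ord):
    """Matrix norm computed from precomputed singular values (hypothetical helper)."""
    if ord == 2:
        return max(singular_values)  # spectral norm
    if ord == -2:
        return min(singular_values)  # smallest singular value
    if ord == "nuc":
        return sum(singular_values)  # nuclear norm
    raise ValueError(f"unsupported ord: {ord!r}")


# diag(3, 1) has singular values 3 and 1.
s = [3.0, 1.0]
print(norm_from_singular_values(s, 2))      # 3.0
print(norm_from_singular_values(s, -2))     # 1.0
print(norm_from_singular_values(s, "nuc"))  # 4.0
```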

@acsweet

acsweet commented Jan 26, 2025

It's in the PR, but for reference:

I'm going to continue with the numpy and linalg implementations, focusing on getting the tests in keras/src/layers to pass.

@fbadine
Contributor

fbadine commented Jan 26, 2025

Pooling functionality was added in #20814.

@acsweet

acsweet commented Jan 27, 2025

Once these latest two PRs are merged, I'd like to try merging the master branch into mlx (if that's okay)

@acsweet

acsweet commented Jan 30, 2025

Merged the Keras master branch into mlx and patched a few files so that pytest works.

Next, I'm going to add new functions and check tests, starting with nn.py.
I'll also start adjusting related tests that should be skipped for mlx (e.g. float64, flash_attention).
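Marking a test as skipped for a single backend can be sketched with a conditional skip decorator. Keras's own tests generally express this with pytest's skipif marker and backend.backend(); to keep this sketch stdlib-only it uses unittest.skipIf and, as a simplifying assumption, reads the backend name from the KERAS_BACKEND environment variable. The test body is a placeholder.

```python
import os
import unittest

# Simplifying assumption: take the active backend from KERAS_BACKEND
# (Keras may also read it from its config file).
BACKEND = os.environ.get("KERAS_BACKEND", "tensorflow")


class Float64Test(unittest.TestCase):
    @unittest.skipIf(BACKEND == "mlx", "float64 is not supported by the mlx backend")
    def test_float64_sum(self):
        # Placeholder standing in for a real float64 test.
        self.assertEqual(sum([0.5, 0.25]), 0.75)
```

Run it with python -m unittest or pytest; with KERAS_BACKEND=mlx the test is reported as skipped instead of failing.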

@acsweet

acsweet commented Feb 3, 2025

I'm currently working on getting the layer tests (keras/src/layers) to pass, including updating ops functions as needed and skipping unsupported tests.

Updates to ops are mostly in math, nn, image, numpy, and core.


8 participants