MLX Backend #19571
PyTest Output:

```
=========================== test session starts ===========================
platform darwin -- Python 3.12.2, pytest-8.1.1, pluggy-1.4.0 -- /Users/kartheek/erlang-ws/github-ws/latest/keras/.venv/bin/python3.12
cachedir: .pytest_cache
rootdir: /Users/kartheek/erlang-ws/github-ws/latest/keras
configfile: pyproject.toml
plugins: cov-5.0.0
collected 6 items

keras/src/ops/operation_test.py::OperationTest::test_autoconfig PASSED        [ 16%]
keras/src/ops/operation_test.py::OperationTest::test_eager_call PASSED        [ 33%]
keras/src/ops/operation_test.py::OperationTest::test_input_conversion FAILED  [ 50%]
keras/src/ops/operation_test.py::OperationTest::test_serialization PASSED     [ 66%]
keras/src/ops/operation_test.py::OperationTest::test_symbolic_call PASSED     [ 83%]
keras/src/ops/operation_test.py::OperationTest::test_valid_naming PASSED      [100%]

================================== FAILURES ===================================
_____________________ OperationTest.test_input_conversion _____________________

self = <keras.src.ops.operation_test.OperationTest testMethod=test_input_conversion>

    def test_input_conversion(self):
        x = np.ones((2,))
        y = np.ones((2,))
        z = knp.ones((2,))  # mix
        if backend.backend() == "torch":
            z = z.cpu()
        op = OpWithMultipleInputs()
>       out = op(x, y, z)

keras/src/ops/operation_test.py:152:
keras/src/utils/traceback_utils.py:113: in error_handler
    return fn(*args, **kwargs)
keras/src/ops/operation.py:56: in __call__
    return self.call(*args, **kwargs)

self = <Operation name=op_with_multiple_inputs>, x = array([1., 1.]), y = array([1., 1.])
z = <[ValueError('item can only be called on arrays of size 1.') raised in repr()] array object at 0x13f7450c0>

    def call(self, x, y, z=None):
        # `z` has to be put first due to the order of operations issue with
        # torch backend.
>       return 3 * z + x + 2 * y
E       ValueError: Cannot perform addition on an mlx.core.array and ndarray

keras/src/ops/operation_test.py:14: ValueError
========================= short test summary info =========================
FAILED keras/src/ops/operation_test.py::OperationTest::test_input_conversion - ValueError: Cannot perform addition on an mlx.core.array and ndarray
======================= 1 failed, 5 passed in 0.13s =======================
```

Any ideas how to fix this test case? `add(mx_array, numpy_array)` works, but it fails when using the `+` operator. Should we skip this test for the mlx backend? |
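Until the operator coercion is resolved upstream, one possible workaround is to normalize all inputs to a single array type before the op's `call` runs, so `+` dispatches on one array class. A minimal backend-agnostic sketch of that pattern (the wrapper name is hypothetical, and `np.asarray` stands in for `mx.array`):

```python
import numpy as np

# Hypothetical sketch: pass every positional input through the backend's
# converter before arithmetic, so mixed numpy/backend inputs never meet
# in a binary operator. For mlx the converter would be mx.array;
# np.asarray stands in here.
def with_converted_inputs(fn, convert):
    def wrapped(*args):
        return fn(*(convert(a) for a in args))
    return wrapped

op = with_converted_inputs(lambda x, y, z: 3 * z + x + 2 * y, np.asarray)
out = op(np.ones((2,)), np.ones((2,)), [1.0, 1.0])  # mixed input types are fine
```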
It's not fixable on our side; we should file an issue with the MLX repo. |
Thank you for the list. I am doing keras/backend/mlx/nn.py:conv |
I am working on segment_sum, segment_max, max_pool and avg_pool. Thank you. |
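For reference, the expected semantics of the segment ops can be sketched in plain numpy (this is only a reference sketch of the behavior, not the mlx implementation, which may differ in API details):

```python
import numpy as np

# Reference semantics for segment_sum / segment_max: rows of `data` are
# scattered into `num_segments` buckets selected by `segment_ids`.
def segment_sum(data, segment_ids, num_segments):
    out = np.zeros((num_segments,) + data.shape[1:], dtype=data.dtype)
    np.add.at(out, segment_ids, data)  # unbuffered scatter-add per segment
    return out

def segment_max(data, segment_ids, num_segments):
    out = np.full((num_segments,) + data.shape[1:], -np.inf, dtype=data.dtype)
    np.maximum.at(out, segment_ids, data)  # scatter-max per segment
    return out

data = np.array([1.0, 2.0, 3.0, 4.0])
ids = np.array([0, 0, 1, 1])
print(segment_sum(data, ids, 2))  # [3. 7.]
print(segment_max(data, ids, 2))  # [2. 4.]
```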
I want to take a stab at |
Thank you @yrahul3910 , please go ahead with adding |
FAILED keras/src/ops/numpy_test.py::NumpyDtypeTest::test_tensordot_('int16', 'bool') - ValueError: [matmul] Only real floating point types are supported but int16 and bool were provided which results in int16, which is not a real floating point type. @fchollet How do we handle this? We could cast integer arguments to float32 when both are integers, with a float32 result. If we go this route, we have to modify test cases in |
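The proposed workaround could look like the following sketch (the function name is hypothetical, and numpy stands in for mlx here; whether the result should stay float32 or be cast back per Keras dtype promotion is exactly the open question):

```python
import numpy as np

# Sketch of the dtype workaround: when the backend matmul only supports
# real floating point types, upcast integer/bool operands to float32
# before contracting, yielding a float32 result.
def tensordot_with_upcast(a, b, axes=1):
    if not np.issubdtype(a.dtype, np.floating):
        a = a.astype("float32")
    if not np.issubdtype(b.dtype, np.floating):
        b = b.astype("float32")
    return np.tensordot(a, b, axes=axes)

out = tensordot_with_upcast(
    np.ones((2, 3), dtype="int16"), np.ones((3, 4), dtype="bool")
)
# out has shape (2, 4) and dtype float32
```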
Just want to let you all know some updates to MLX as of 0.16.1 that may be useful here:
Are there any high priority items we can fix or add to help move this along? |
Thank you @awni , we need some help in moving this forward. I will make a list and get back to you in a day or two. |
I'd like to pick up on this issue (first time contributor) starting with |
I'm going to start with the "easy" stuff already implemented in mlx, and I'll start in
|
Sounds great! Let us know how we can help on the MLX side. |
@awni Thank you! I'll keep you updated as I progress. Right now, would it be possible to get stft and istft implemented on the mlx side? It looks like it was started here: ml-explore/mlx#1004. I saw this implementation too (without an inverse): https://github.com/nuniz/mlx_stft |
Please note, the nn and rnn namespaces are the most important for getting mlx to work with typical workflows. |
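For anyone picking up stft, the real-input case can be sketched via framing plus rfft (a minimal reference in numpy; the window, hop, and padding conventions here are assumptions, and the eventual mlx API may well differ):

```python
import numpy as np

# Minimal real-STFT reference: slice the signal into overlapping frames,
# apply a Hann window, and take the real FFT of each frame.
def stft(x, frame_length=256, hop=128):
    window = np.hanning(frame_length)
    n_frames = 1 + (len(x) - frame_length) // hop
    frames = np.stack(
        [x[i * hop : i * hop + frame_length] * window for i in range(n_frames)]
    )
    return np.fft.rfft(frames, axis=-1)  # shape: (n_frames, frame_length // 2 + 1)

spec = stft(np.sin(np.linspace(0, 100, 1024)))
```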
I'm going to hold off on math.qr for now; mlx currently only supports square matrices (and no option for the complete or reduced factorization). I have a PR for fft2, rfft, and irfft (and a fix to fft); if that looks good I'll start looking at the rnn namespace. It looked like the backend implementations for rnn.gru and rnn.lstm were only implemented for tensorflow, for cudnn-specific speedups with tf. So I think it's safe to follow similarly to jax and torch? |
Right, unless mlx actually exposes some cudnn bindings for these |
I'm going to start working through I hope that's okay, but I'm going to start with |
@awni Would it be possible to get support for non-square matrices implemented in |
Yes please open an issue about it, it should be straightforward to get it working |
If the
|
I have a pull request for the remaining convolutional functions, if those look good I'll continue! Fadi asked to work on |
@awni Would it be pretty straightforward to implement singular value norms in I can open an issue for it too! |
It's in the PR, but for reference:
I'm going to continue with |
pooling functionality added in #20814 |
Once these latest two PRs are merged, I'd like to try merging the master branch into mlx (if that's okay) |
Merged the Keras master branch into mlx. Going to add new functions and check tests starting with |
I'm currently working through getting the layer tests to pass. Updates to |
Issue for tracking and coordinating mlx backend work:

mlx.math
- fft
- fft2
- rfft
- irfft
- stft
- istft
- logsumexp (mlx - add missing convert_to_tensor #19578)
- qr
- segment_sum (mlx - implement segment_sum and segment_max #19652)
- segment_max (mlx - implement segment_sum and segment_max #19652)
- erfinv (feat(math): support erfinv on mlx #19628)

mlx.numpy
- einsum
- bincount
- nonzero
- cross
- vdot
- nan_to_num
- copy
- roll
- median (Implemented median(...) function. #19568, Implement missing functions in mlx backend #19574)
- meshgrid (Implement missing functions in mlx backend #19574)
- conjugate
- arctan2 (Added arctan2 operation #19759)
- quantile
- imag
- real
- select
- argpartition (mlx - add argpartition to numpy #19680)
- slogdet
- vectorize
- correlate
- diag (mlx - fix diag and diagonal in numpy #19714)
- diagonal (mlx - fix diag and diagonal in numpy #19714)

mlx.image
- rgb_to_grayscale (mlx - add rgb_to_grayscale #19609)
- resize (mlx - image.resize add crop_to_aspect_ratio argument #19699)

mlx.nn
- max_pool
- avg_pool
- conv
- depthwise_conv
- separable_conv
- conv_transpose
- ctc_loss

mlx.rnn
- rnn
- lstm
- gru

mlx.linalg
- cholesky
- det
- eig
- eigh
- inv
- lu_factor
- norm (mlx - add linalg.norm #19698)
- qr
- solve
- solve_triangular
- svd

mlx.core
- np.ndarray of bfloat16 using ml_dtypes is being interpreted as complex64 (ml-explore/mlx#1075)