From 9b75b8667de2968a97b2f1a1eba78411b53bab9e Mon Sep 17 00:00:00 2001 From: acsweet <52804044+acsweet@users.noreply.github.com> Date: Sat, 18 Jan 2025 10:20:18 -0800 Subject: [PATCH] fix to fft, implement fft2, rfft, and irfft for mlx (#20781) --- keras/src/backend/mlx/math.py | 48 +++++++++++++++++++++++++++++------ 1 file changed, 40 insertions(+), 8 deletions(-) diff --git a/keras/src/backend/mlx/math.py b/keras/src/backend/mlx/math.py index 506a198083b..118dc3347c8 100644 --- a/keras/src/backend/mlx/math.py +++ b/keras/src/backend/mlx/math.py @@ -54,24 +54,56 @@ def extract_sequences(x, sequence_length, sequence_stride): return x.reshape(*batch_shape, frames, sequence_length) +def _get_complex_tensor_from_tuple(x): + if not isinstance(x, (tuple, list)) or len(x) != 2: + raise ValueError( + "Input `x` should be a tuple of two tensors - real and imaginary." + f"Received: x={x}" + ) + real, imag = x + real = convert_to_tensor(real) + imag = convert_to_tensor(imag) + # Check shapes. + if real.shape != imag.shape: + raise ValueError( + "Input `x` should be a tuple of two tensors - real and imaginary." + "Both the real and imaginary parts should have the same shape. " + f"Received: x[0].shape = {real.shape}, x[1].shape = {imag.shape}" + ) + # Ensure dtype is float. + if not mx.issubdtype(real.dtype, mx.floating) or not mx.issubdtype( + imag.dtype, mx.floating + ): + raise ValueError( + "At least one tensor in input `x` is not of type float." + f"Received: x={x}." + ) + complex_input = mx.add(real, 1j * imag) + return complex_input + + def fft(x): - x = convert_to_tensor(x) - return mx.fft(x) + x = _get_complex_tensor_from_tuple(x) + complex_output = mx.fft.fft(x) + return mx.real(complex_output), mx.imag(complex_output) def fft2(x): - # TODO: https://ml-explore.github.io/mlx/build/html/python/fft.html#fft - raise NotImplementedError("fft not yet implemented in mlx") + x = _get_complex_tensor_from_tuple(x) + complex_output = mx.fft.fft2(x) + return mx.real(complex_output), mx.imag(complex_output) def rfft(x, fft_length=None): - # TODO: https://ml-explore.github.io/mlx/build/html/python/fft.html#fft - raise NotImplementedError("fft not yet implemented in mlx") + x = convert_to_tensor(x) + complex_output = mx.fft.rfft(x, n=fft_length) + return mx.real(complex_output), mx.imag(complex_output) def irfft(x, fft_length=None): - # TODO: https://ml-explore.github.io/mlx/build/html/python/fft.html#fft - raise NotImplementedError("fft not yet implemented in mlx") + x = _get_complex_tensor_from_tuple(x) + real_output = mx.fft.irfft(x, n=fft_length) + return real_output def stft(