Skip to content

Commit

Permalink
fix to fft, implement fft2, rfft, and irfft for mlx (#20781)
Browse files Browse the repository at this point in the history
  • Loading branch information
acsweet authored Jan 18, 2025
1 parent 603affa commit 9b75b86
Showing 1 changed file with 40 additions and 8 deletions.
48 changes: 40 additions & 8 deletions keras/src/backend/mlx/math.py
Original file line number Diff line number Diff line change
Expand Up @@ -54,24 +54,56 @@ def extract_sequences(x, sequence_length, sequence_stride):
return x.reshape(*batch_shape, frames, sequence_length)


def _get_complex_tensor_from_tuple(x):
if not isinstance(x, (tuple, list)) or len(x) != 2:
raise ValueError(
"Input `x` should be a tuple of two tensors - real and imaginary."
f"Received: x={x}"
)
real, imag = x
real = convert_to_tensor(real)
imag = convert_to_tensor(imag)
# Check shapes.
if real.shape != imag.shape:
raise ValueError(
"Input `x` should be a tuple of two tensors - real and imaginary."
"Both the real and imaginary parts should have the same shape. "
f"Received: x[0].shape = {real.shape}, x[1].shape = {imag.shape}"
)
# Ensure dtype is float.
if not mx.issubdtype(real.dtype, mx.floating) or not mx.issubdtype(
imag.dtype, mx.floating
):
raise ValueError(
"At least one tensor in input `x` is not of type float."
f"Received: x={x}."
)
complex_input = mx.add(real, 1j * imag)
return complex_input


def fft(x):
x = convert_to_tensor(x)
return mx.fft(x)
x = _get_complex_tensor_from_tuple(x)
complex_output = mx.fft.fft(x)
return mx.real(complex_output), mx.imag(complex_output)


def fft2(x):
# TODO: https://ml-explore.github.io/mlx/build/html/python/fft.html#fft
raise NotImplementedError("fft not yet implemented in mlx")
x = _get_complex_tensor_from_tuple(x)
complex_output = mx.fft.fft2(x)
return mx.real(complex_output), mx.imag(complex_output)


def rfft(x, fft_length=None):
# TODO: https://ml-explore.github.io/mlx/build/html/python/fft.html#fft
raise NotImplementedError("fft not yet implemented in mlx")
x = convert_to_tensor(x)
complex_output = mx.fft.rfft(x, n=fft_length)
return mx.real(complex_output), mx.imag(complex_output)


def irfft(x, fft_length=None):
# TODO: https://ml-explore.github.io/mlx/build/html/python/fft.html#fft
raise NotImplementedError("fft not yet implemented in mlx")
x = _get_complex_tensor_from_tuple(x)
real_output = mx.fft.irfft(x, n=fft_length)
return real_output


def stft(
Expand Down

0 comments on commit 9b75b86

Please sign in to comment.