Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Add max_pool and average_pool for MLX #20814

Merged
merged 3 commits into from
Jan 28, 2025
Merged
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
163 changes: 161 additions & 2 deletions keras/src/backend/mlx/nn.py
Original file line number Diff line number Diff line change
@@ -1,4 +1,7 @@
import builtins
import math
import operator
from itertools import accumulate

import mlx.core as mx
import mlx.nn as nn
Expand Down Expand Up @@ -122,16 +125,172 @@ def log_softmax(x, axis=-1):
return x - mx.logsumexp(x, axis=axis, keepdims=True)


def _calculate_padding(input_shape, pool_size, strides):
ndim = len(input_shape)

padding = ()
for d in range(ndim):
pad = max(0, (pool_size[d] - 1) - ((input_shape[d] - 1) % strides[d]))
padding = padding + (pad,)

return [(p // 2, (p + 1) // 2) for p in padding]


def _non_overlapping_sliding_windows(x, shape, window_shape):
# Compute the intermediate shape
new_shape = [shape[0]]
for s, w in zip(shape[1:], window_shape):
new_shape.append(s // w)
new_shape.append(w)
new_shape.append(shape[-1])

last_axis = len(new_shape) - 1
axis_order = [
0,
*range(1, last_axis, 2),
*range(2, last_axis, 2),
last_axis,
]

x = x.reshape(new_shape)
x = x.transpose(axis_order)
return x


def _sliding_windows(x, window_shape, window_strides):
if x.ndim < 3:
raise ValueError(
"To extract sliding windows at least 1 spatial dimension "
f"(3 total) is needed but the input only has {x.ndim} dimension(s)."
)

spatial_dims = x.shape[1:-1]
if not (len(spatial_dims) == len(window_shape) == len(window_strides)):
raise ValueError(
"To extract sliding windows, the lengths of window_shape and "
"window_strides must be equal to the signal's spatial dimensions. "
f"However, the signal has spatial_dims={spatial_dims} while "
f"window_shape={window_shape} and window_strides={window_strides}."
)

shape = x.shape
if all(
window == stride and size % window == 0
for size, window, stride in zip(
spatial_dims, window_shape, window_strides
)
):
return _non_overlapping_sliding_windows(x, shape, window_shape)

strides = list(
reversed(list(accumulate(reversed(shape + (1,)), operator.mul)))
)[1:]

# Compute the output shape
final_shape = [shape[0]]
final_shape += [
(size - window) // stride + 1
for size, window, stride in zip(
spatial_dims, window_shape, window_strides
)
]
final_shape += window_shape
final_shape += [shape[-1]]

# Compute the output strides
final_strides = strides[:1]
final_strides += [
og_stride * stride
for og_stride, stride in zip(strides[1:-1], window_strides)
]
final_strides += strides[1:-1]
final_strides += strides[-1:] # should always be [1]

return mx.as_strided(x, final_shape, final_strides)


def _pool(
inputs, pool_size, strides, padding, padding_value, data_format, pooling_fn
):
if padding not in ("same", "valid"):
raise ValueError(
f"Invalid padding '{padding}', must be 'same' or 'valid'."
)

if data_format == "channels_first":
# mlx expects channels_last
inputs = inputs.transpose(0, *range(2, inputs.ndim), 1)

if padding == "same":
pads = _calculate_padding(inputs.shape[1:-1], pool_size, strides)

if any(p[1] > 0 for p in pads):
inputs = mx.pad(
inputs,
[(0, 0)] + pads + [(0, 0)],
constant_values=padding_value,
)

inputs = _sliding_windows(inputs, pool_size, strides)

axes = tuple(range(-len(pool_size) - 1, -1, 1))
result = pooling_fn(inputs, axes)

if data_format == "channels_first":
result = result.transpose(0, -1, *range(1, result.ndim - 1))
return result


def max_pool(
inputs, pool_size, strides=None, padding="valid", data_format=None
):
raise NotImplementedError("MLX backend doesn't support max pooling yet")
inputs = convert_to_tensor(inputs)
data_format = standardize_data_format(data_format)
num_spatial_dims = inputs.ndim - 2
pool_size = standardize_tuple(pool_size, num_spatial_dims, "pool_size")
strides = pool_size if strides is None else strides
strides = standardize_tuple(strides, num_spatial_dims, "strides")

return _pool(
inputs, pool_size, strides, padding, -mx.inf, data_format, mx.max
)


def average_pool(
inputs, pool_size, strides=None, padding="valid", data_format=None
):
raise NotImplementedError("MLX backend doesn't support average pooling yet")
inputs = convert_to_tensor(inputs)
data_format = standardize_data_format(data_format)
num_spatial_dims = inputs.ndim - 2
pool_size = standardize_tuple(pool_size, num_spatial_dims, "pool_size")
strides = pool_size if strides is None else strides
strides = standardize_tuple(strides, num_spatial_dims, "strides")

# Create a pool by applying the sum function in each window
pooled = _pool(
inputs, pool_size, strides, padding, 0.0, data_format, mx.sum
)
if padding == "valid":
# No padding needed. Divide by the size of the pool which gives
# the average
return pooled / math.prod(pool_size)
else:
# Create a tensor of ones of the same shape of inputs.
# Then create a pool, padding by zero and using sum as function.
# This will create a tensor of the smae dimensions as pooled tensor
# with values being the sum.
# By dividing pooled by windows_counts, we get the average while
# skipping the padded values.
window_counts = _pool(
mx.ones(inputs.shape, inputs.dtype),
pool_size,
strides,
padding,
0.0,
data_format,
mx.sum,
)
return pooled / window_counts


def conv(
Expand Down
Loading