diff --git a/keras/src/backend/mlx/nn.py b/keras/src/backend/mlx/nn.py
index b40d16ebb1a..cfb79401925 100644
--- a/keras/src/backend/mlx/nn.py
+++ b/keras/src/backend/mlx/nn.py
@@ -1,4 +1,7 @@
 import builtins
+import math
+import operator
+from itertools import accumulate
 
 import mlx.core as mx
 import mlx.nn as nn
@@ -122,16 +125,172 @@ def log_softmax(x, axis=-1):
     return x - mx.logsumexp(x, axis=axis, keepdims=True)
 
 
+def _calculate_padding(input_shape, pool_size, strides):
+    ndim = len(input_shape)
+
+    padding = ()
+    for d in range(ndim):
+        pad = max(0, (pool_size[d] - 1) - ((input_shape[d] - 1) % strides[d]))
+        padding = padding + (pad,)
+
+    return [(p // 2, (p + 1) // 2) for p in padding]
+
+
+def _non_overlapping_sliding_windows(x, shape, window_shape):
+    # Compute the intermediate shape
+    new_shape = [shape[0]]
+    for s, w in zip(shape[1:], window_shape):
+        new_shape.append(s // w)
+        new_shape.append(w)
+    new_shape.append(shape[-1])
+
+    last_axis = len(new_shape) - 1
+    axis_order = [
+        0,
+        *range(1, last_axis, 2),
+        *range(2, last_axis, 2),
+        last_axis,
+    ]
+
+    x = x.reshape(new_shape)
+    x = x.transpose(axis_order)
+    return x
+
+
+def _sliding_windows(x, window_shape, window_strides):
+    if x.ndim < 3:
+        raise ValueError(
+            "To extract sliding windows, at least 1 spatial dimension "
+            f"(3 total) is needed but the input only has {x.ndim} dimension(s)."
+        )
+
+    spatial_dims = x.shape[1:-1]
+    if not (len(spatial_dims) == len(window_shape) == len(window_strides)):
+        raise ValueError(
+            "To extract sliding windows, the lengths of window_shape and "
+            "window_strides must be equal to the signal's spatial dimensions. "
+            f"However, the signal has spatial_dims={spatial_dims} while "
+            f"window_shape={window_shape} and window_strides={window_strides}."
+        )
+
+    shape = x.shape
+    if all(
+        window == stride and size % window == 0
+        for size, window, stride in zip(
+            spatial_dims, window_shape, window_strides
+        )
+    ):
+        return _non_overlapping_sliding_windows(x, shape, window_shape)
+
+    strides = list(
+        reversed(list(accumulate(reversed(shape + (1,)), operator.mul)))
+    )[1:]
+
+    # Compute the output shape
+    final_shape = [shape[0]]
+    final_shape += [
+        (size - window) // stride + 1
+        for size, window, stride in zip(
+            spatial_dims, window_shape, window_strides
+        )
+    ]
+    final_shape += window_shape
+    final_shape += [shape[-1]]
+
+    # Compute the output strides
+    final_strides = strides[:1]
+    final_strides += [
+        og_stride * stride
+        for og_stride, stride in zip(strides[1:-1], window_strides)
+    ]
+    final_strides += strides[1:-1]
+    final_strides += strides[-1:]  # should always be [1]
+
+    return mx.as_strided(x, final_shape, final_strides)
+
+
+def _pool(
+    inputs, pool_size, strides, padding, padding_value, data_format, pooling_fn
+):
+    if padding not in ("same", "valid"):
+        raise ValueError(
+            f"Invalid padding '{padding}', must be 'same' or 'valid'."
+        )
+
+    if data_format == "channels_first":
+        # mlx expects channels_last
+        inputs = inputs.transpose(0, *range(2, inputs.ndim), 1)
+
+    if padding == "same":
+        pads = _calculate_padding(inputs.shape[1:-1], pool_size, strides)
+
+        if any(p[1] > 0 for p in pads):
+            inputs = mx.pad(
+                inputs,
+                [(0, 0)] + pads + [(0, 0)],
+                constant_values=padding_value,
+            )
+
+    inputs = _sliding_windows(inputs, pool_size, strides)
+
+    axes = tuple(range(-len(pool_size) - 1, -1, 1))
+    result = pooling_fn(inputs, axes)
+
+    if data_format == "channels_first":
+        result = result.transpose(0, -1, *range(1, result.ndim - 1))
+    return result
+
+
 def max_pool(
     inputs, pool_size, strides=None, padding="valid", data_format=None
 ):
-    raise NotImplementedError("MLX backend doesn't support max pooling yet")
+    inputs = convert_to_tensor(inputs)
+    data_format = standardize_data_format(data_format)
+    num_spatial_dims = inputs.ndim - 2
+    pool_size = standardize_tuple(pool_size, num_spatial_dims, "pool_size")
+    strides = pool_size if strides is None else strides
+    strides = standardize_tuple(strides, num_spatial_dims, "strides")
+
+    return _pool(
+        inputs, pool_size, strides, padding, -mx.inf, data_format, mx.max
+    )
 
 
 def average_pool(
     inputs, pool_size, strides=None, padding="valid", data_format=None
 ):
-    raise NotImplementedError("MLX backend doesn't support average pooling yet")
+    inputs = convert_to_tensor(inputs)
+    data_format = standardize_data_format(data_format)
+    num_spatial_dims = inputs.ndim - 2
+    pool_size = standardize_tuple(pool_size, num_spatial_dims, "pool_size")
+    strides = pool_size if strides is None else strides
+    strides = standardize_tuple(strides, num_spatial_dims, "strides")
+
+    # Sum the values in each window
+    pooled = _pool(
+        inputs, pool_size, strides, padding, 0.0, data_format, mx.sum
+    )
+    if padding == "valid":
+        # No padding was added, so every window is full. Dividing the
+        # sum by the window size gives the average.
+        return pooled / math.prod(pool_size)
+    else:
+        # Create a tensor of ones of the same shape as the inputs and
+        # pool it with zero padding and sum as the pooling function.
+        # This yields a tensor of the same shape as the pooled tensor
+        # whose entries count the non-padded values in each window.
+        # Dividing pooled by window_counts then gives the average
+        # while ignoring the padded values.
+        window_counts = _pool(
+            mx.ones(inputs.shape, inputs.dtype),
+            pool_size,
+            strides,
+            padding,
+            0.0,
+            data_format,
+            mx.sum,
+        )
+        return pooled / window_counts
 
 
 def conv(
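
A quick way to check the "same"-padding rule used by _calculate_padding above: the total padding per spatial dimension is max(0, (pool - 1) - ((size - 1) % stride)), split as (pad // 2, (pad + 1) // 2) so the extra element goes at the end. The sketch below is plain Python, independent of MLX, and only re-derives that formula:

for size, pool, stride in [(6, 3, 2), (5, 2, 2), (4, 3, 1)]:
    pad = max(0, (pool - 1) - ((size - 1) % stride))
    lo, hi = pad // 2, (pad + 1) // 2
    out = (size + lo + hi - pool) // stride + 1
    assert out == -(-size // stride)  # "same" output is ceil(size / stride)
    print(f"size={size} pool={pool} stride={stride} -> pad=({lo}, {hi}) out={out}")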
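
A shape-level sanity check of the window extraction, assuming MLX is installed and the private helper is imported from this module. Overlapping windows take the mx.as_strided path, non-overlapping windows take the reshape/transpose fast path, and both append the window dimensions just before the channel axis:

import mlx.core as mx

from keras.src.backend.mlx.nn import _sliding_windows

x = mx.arange(36, dtype=mx.float32).reshape(1, 6, 6, 1)  # NHWC input

# Overlapping 3x3 windows with stride 2: (6 - 3) // 2 + 1 = 2 per spatial dim
print(_sliding_windows(x, (3, 3), (2, 2)).shape)  # (1, 2, 2, 3, 3, 1)

# Non-overlapping 2x2 windows hit the reshape/transpose fast path
print(_sliding_windows(x, (2, 2), (2, 2)).shape)  # (1, 3, 3, 2, 2, 1)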
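
End to end, the new kernels are reachable through keras.ops; the smoke test below assumes a Keras build in which KERAS_BACKEND=mlx resolves to this backend. The "same"-padding case also exercises the window_counts correction: padded zeros are excluded from the mean, so edge windows divide by the number of real elements rather than by prod(pool_size):

import os

os.environ["KERAS_BACKEND"] = "mlx"  # must be set before importing keras

import numpy as np

from keras import ops

x = np.arange(16, dtype="float32").reshape(1, 4, 4, 1)

print(ops.max_pool(x, pool_size=2, strides=2, padding="valid").shape)
# (1, 2, 2, 1)

# "same" padding: output is ceil(4 / 2) = 2 per spatial dim; the bottom and
# right windows average over fewer than 3 * 3 real elements.
print(ops.average_pool(x, pool_size=3, strides=2, padding="same").shape)
# (1, 2, 2, 1)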