From e2e228838b91145c97c3c999809a9814c401a147 Mon Sep 17 00:00:00 2001
From: acsweet
Date: Thu, 23 Jan 2025 17:34:53 -0800
Subject: [PATCH] `mlx` - more conv (#20807)

* depthwise_conv implementation
* implementation
* clean and implemented
---
 keras/src/backend/common/backend_utils.py |  82 +++++++++++
 keras/src/backend/mlx/nn.py               | 157 +++++++++++++++++++---
 2 files changed, 217 insertions(+), 22 deletions(-)

diff --git a/keras/src/backend/common/backend_utils.py b/keras/src/backend/common/backend_utils.py
index 4be0d75d5f27..240d40a169e9 100644
--- a/keras/src/backend/common/backend_utils.py
+++ b/keras/src/backend/common/backend_utils.py
@@ -187,6 +187,88 @@ def compute_conv_transpose_padding_args_for_torch(
     return torch_paddings, torch_output_paddings
 
 
+def _convert_conv_transpose_padding_args_from_keras_to_mlx(
+    kernel_size, stride, dilation_rate, padding, output_padding
+):
+    effective_k_size = (kernel_size - 1) * dilation_rate + 1
+    if padding == "valid":
+        output_padding = (
+            max(effective_k_size, stride) - effective_k_size
+            if output_padding is None
+            else output_padding
+        )
+        pad_left = effective_k_size - 1
+        pad_right = effective_k_size - 1 + output_padding
+    elif padding == "same":
+        if output_padding is None:
+            total_pad = stride + effective_k_size - 2
+        else:
+            total_pad = (
+                effective_k_size + effective_k_size % 2 - 2 + output_padding
+            )
+        pad_left = min(total_pad // 2 + total_pad % 2, effective_k_size - 1)
+        pad_right = total_pad - pad_left
+    else:
+        raise ValueError(f"Invalid padding value: {padding}")
+    return pad_left, pad_right
+
+
+def compute_conv_transpose_padding_args_for_mlx(
+    padding,
+    num_spatial_dims,
+    kernel_spatial_shape,
+    dilation_rate,
+    strides,
+    output_padding,
+):
+    start_paddings = []
+    end_paddings = []
+    for i in range(num_spatial_dims):
+        kernel_size_i = kernel_spatial_shape[i]
+        stride_i = strides[i]
+        dilation_rate_i = dilation_rate[i]
+        output_padding_i = None if output_padding is None else output_padding[i]
+        pad_left, pad_right = (
+            _convert_conv_transpose_padding_args_from_keras_to_mlx(
+                kernel_size_i,
+                stride_i,
+                dilation_rate_i,
+                padding,
+                output_padding_i,
+            )
+        )
+        start_paddings.append(pad_left)
+        end_paddings.append(pad_right)
+    return (start_paddings, end_paddings)
+
+
+def compute_transpose_padding_args_for_mlx(
+    padding,
+    input_spatial_shape,
+    kernel_spatial_shape,
+    dilation_rate,
+    strides,
+):
+    if padding == "valid":
+        return 0
+    elif padding == "same":
+        start_paddings = []
+        end_paddings = []
+        for dim_size, k_size, d_rate, s in zip(
+            input_spatial_shape, kernel_spatial_shape, dilation_rate, strides
+        ):
+            out_size = (dim_size + s - 1) // s
+            effective_k_size = (k_size - 1) * d_rate + 1
+            total_pad = max(0, (out_size - 1) * s + effective_k_size - dim_size)
+            pad_start = total_pad // 2
+            pad_end = total_pad - pad_start
+            start_paddings.append(pad_start)
+            end_paddings.append(pad_end)
+        return (start_paddings, end_paddings)
+    else:
+        raise ValueError(f"Invalid padding value: {padding}")
+
+
 def _get_output_shape_given_tf_padding(
     input_size, kernel_size, strides, padding, output_padding, dilation_rate
 ):
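A quick sanity check of the two new helpers (a hypothetical snippet, not part of the patch; it assumes this Keras branch is importable):

from keras.src.backend.common.backend_utils import (
    compute_conv_transpose_padding_args_for_mlx,
    compute_transpose_padding_args_for_mlx,
)

# "same" conv: out_size = ceil(10 / 2) = 5 and the effective kernel size is
# (3 - 1) * 2 + 1 = 5, so total_pad = (5 - 1) * 2 + 5 - 10 = 3, split (1, 2).
assert compute_transpose_padding_args_for_mlx(
    "same", (10,), (3,), (2,), (2,)
) == ([1], [2])

# "same" conv transpose with kernel 3 and stride 2 (no dilation, no
# output_padding): total_pad = 2 + 3 - 2 = 3, pad_left = min(2, 2) = 2.
assert compute_conv_transpose_padding_args_for_mlx(
    "same", 1, (3,), (1,), (2,), None
) == ([2], [1])
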
diff --git a/keras/src/backend/mlx/nn.py b/keras/src/backend/mlx/nn.py
index abcf315c990d..f4235f7a0f4b 100644
--- a/keras/src/backend/mlx/nn.py
+++ b/keras/src/backend/mlx/nn.py
@@ -3,6 +3,12 @@
 from keras.src.backend import standardize_data_format
 from keras.src.backend import standardize_dtype
+from keras.src.backend.common.backend_utils import (
+    compute_conv_transpose_padding_args_for_mlx,
+)
+from keras.src.backend.common.backend_utils import (
+    compute_transpose_padding_args_for_mlx,
+)
 from keras.src.backend.config import epsilon
 from keras.src.backend.mlx.core import convert_to_tensor
 from keras.src.backend.mlx.core import to_mlx_dtype
@@ -148,25 +154,15 @@ def conv(
     # mlx expects kernel with (out_channels, spatial..., in_channels)
     kernel = kernel.transpose(-1, *range(kernel.ndim - 2), -2)
 
-    if padding == "valid":
-        mlx_padding = 0
-    elif padding == "same":
-        kernel_spatial_shape = kernel.shape[1:-1]
-        start_paddings = []
-        end_paddings = []
-        for dim_size, k_size, d_rate, s in zip(
-            inputs.shape[1:-1], kernel_spatial_shape, dilation_rate, strides
-        ):
-            out_size = (dim_size + s - 1) // s
-            effective_k_size = (k_size - 1) * d_rate + 1
-            total_pad = max(0, (out_size - 1) * s + effective_k_size - dim_size)
-            pad_start = total_pad // 2
-            pad_end = total_pad - pad_start
-            start_paddings.append(pad_start)
-            end_paddings.append(pad_end)
-        mlx_padding = (start_paddings, end_paddings)
-    else:
-        raise ValueError(f"Invalid padding value: {padding}")
+    kernel_spatial_shape = kernel.shape[1:-1]
+    input_spatial_shape = inputs.shape[1:-1]
+    mlx_padding = compute_transpose_padding_args_for_mlx(
+        padding,
+        input_spatial_shape,
+        kernel_spatial_shape,
+        dilation_rate,
+        strides,
+    )
 
     channels = inputs.shape[-1]
     kernel_in_channels = kernel.shape[-1]
@@ -202,7 +198,53 @@ def depthwise_conv(
     data_format=None,
     dilation_rate=1,
 ):
-    raise NotImplementedError("MLX backend doesn't support depthwise conv yet")
+    inputs = convert_to_tensor(inputs)
+    kernel = convert_to_tensor(kernel)
+    data_format = standardize_data_format(data_format)
+    num_spatial_dims = inputs.ndim - 2
+
+    strides = standardize_tuple(strides, num_spatial_dims, "strides")
+    dilation_rate = standardize_tuple(
+        dilation_rate, num_spatial_dims, "dilation_rate"
+    )
+
+    if data_format == "channels_first":
+        # mlx expects channels_last
+        inputs = inputs.transpose(0, *range(2, inputs.ndim), 1)
+
+    feature_group_count = inputs.shape[-1]
+
+    # reshape first for depthwise conv, then transpose to expected mlx format
+    kernel = kernel.reshape(
+        *iter(kernel.shape[:-2]), 1, feature_group_count * kernel.shape[-1]
+    )
+    # mlx expects kernel with (out_channels, spatial..., in_channels)
+    kernel = kernel.transpose(-1, *range(kernel.ndim - 2), -2)
+
+    kernel_spatial_shape = kernel.shape[1:-1]
+    input_spatial_shape = inputs.shape[1:-1]
+    mlx_padding = compute_transpose_padding_args_for_mlx(
+        padding,
+        input_spatial_shape,
+        kernel_spatial_shape,
+        dilation_rate,
+        strides,
+    )
+
+    result = mx.conv_general(
+        inputs,
+        kernel,
+        stride=strides,
+        padding=mlx_padding,
+        kernel_dilation=dilation_rate,
+        input_dilation=1,
+        groups=feature_group_count,
+        flip=False,
+    )
+    if data_format == "channels_first":
+        result = result.transpose(0, -1, *range(1, result.ndim - 1))
+
+    return result
 
 
 def separable_conv(
     inputs,
     depthwise_kernel,
     pointwise_kernel,
     strides=1,
     padding="valid",
     data_format=None,
     dilation_rate=1,
 ):
-    raise NotImplementedError("MLX backend doesn't support separable conv yet")
+    data_format = standardize_data_format(data_format)
+    depthwise_conv_output = depthwise_conv(
+        inputs,
+        depthwise_kernel,
+        strides,
+        padding,
+        data_format,
+        dilation_rate,
+    )
+    return conv(
+        depthwise_conv_output,
+        pointwise_kernel,
+        strides=1,
+        padding="valid",
+        data_format=data_format,
+        dilation_rate=dilation_rate,
+    )
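Usage sketch for the new depthwise_conv (hypothetical, not part of the patch; it assumes a Keras build with the mlx backend and the `mlx` package installed). The depthwise kernel's channel multiplier is folded into the output channels, with one group per input channel:

import mlx.core as mx

from keras.src.backend.mlx.nn import depthwise_conv

x = mx.zeros((1, 8, 8, 4))  # (batch, height, width, channels)
k = mx.zeros((3, 3, 4, 2))  # (kh, kw, channels, channel_multiplier)
y = depthwise_conv(
    x, k, strides=1, padding="same", data_format="channels_last"
)
assert y.shape == (1, 8, 8, 8)  # 4 channels * multiplier of 2

separable_conv above then simply chains this depthwise pass with a stride-1 pointwise conv.
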
 
 
 def conv_transpose(
     inputs,
     kernel,
     strides=1,
     padding="valid",
     output_padding=None,
     data_format=None,
     dilation_rate=1,
 ):
-    raise NotImplementedError("MLX backend doesn't support conv transpose yet")
+    inputs = convert_to_tensor(inputs)
+    kernel = convert_to_tensor(kernel)
+    data_format = standardize_data_format(data_format)
+    num_spatial_dims = inputs.ndim - 2
+
+    strides = standardize_tuple(strides, num_spatial_dims, "strides")
+    dilation_rate = standardize_tuple(
+        dilation_rate, num_spatial_dims, "dilation_rate"
+    )
+    if output_padding is not None:
+        output_padding = standardize_tuple(
+            output_padding, num_spatial_dims, "output_padding"
+        )
+
+    if data_format == "channels_first":
+        # mlx expects channels_last
+        inputs = inputs.transpose(0, *range(2, inputs.ndim), 1)
+
+    # mlx expects kernel with (out_channels, spatial..., in_channels)
+    kernel = kernel.transpose(-2, *range(kernel.ndim - 2), -1)
+    kernel_spatial_shape = kernel.shape[1:-1]
+
+    mlx_padding = compute_conv_transpose_padding_args_for_mlx(
+        padding,
+        num_spatial_dims,
+        kernel_spatial_shape,
+        dilation_rate,
+        strides,
+        output_padding,
+    )
+
+    channels = inputs.shape[-1]
+    kernel_in_channels = kernel.shape[-1]
+    if channels % kernel_in_channels > 0:
+        raise ValueError(
+            "The number of input channels must be evenly divisible by "
+            f"kernel's in_channels. Received input channels {channels} and "
+            f"kernel in_channels {kernel_in_channels}."
+        )
+    groups = channels // kernel_in_channels
+
+    result = mx.conv_general(
+        inputs,
+        kernel,
+        stride=1,  # stride is handled by input_dilation
+        padding=mlx_padding,
+        kernel_dilation=dilation_rate,
+        input_dilation=strides,
+        groups=groups,
+        flip=True,
+    )
+
+    if data_format == "channels_first":
+        result = result.transpose(0, -1, *range(1, result.ndim - 1))
+
+    return result
 
 
 def one_hot(x, num_classes, axis=-1, dtype="float32", sparse=False):
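A matching sketch for conv_transpose (again hypothetical, same assumptions as above). With "same" padding each spatial dimension should come out as input_size * stride, since the stride is realized as input_dilation on a stride-1, flipped-kernel convolution:

import mlx.core as mx

from keras.src.backend.mlx.nn import conv_transpose

x = mx.zeros((1, 5, 5, 3))  # (batch, height, width, channels)
k = mx.zeros((3, 3, 8, 3))  # (kh, kw, out_channels, in_channels)
y = conv_transpose(
    x, k, strides=2, padding="same", data_format="channels_last"
)
assert y.shape == (1, 10, 10, 8)  # spatial dims upsampled by the stride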