mlx - more conv #20807

Merged
merged 3 commits on Jan 24, 2025

Changes from all commits
82 changes: 82 additions & 0 deletions keras/src/backend/common/backend_utils.py
@@ -187,6 +187,88 @@ def compute_conv_transpose_padding_args_for_torch(
return torch_paddings, torch_output_paddings


def _convert_conv_transpose_padding_args_from_keras_to_mlx(
    kernel_size, stride, dilation_rate, padding, output_padding
):
    # Size the dilated kernel effectively covers along this axis.
    effective_k_size = (kernel_size - 1) * dilation_rate + 1
if padding == "valid":
output_padding = (
max(effective_k_size, stride) - effective_k_size
if output_padding is None
else output_padding
)
pad_left = effective_k_size - 1
pad_right = effective_k_size - 1 + output_padding
elif padding == "same":
if output_padding is None:
total_pad = stride + effective_k_size - 2
else:
total_pad = (
effective_k_size + effective_k_size % 2 - 2 + output_padding
)
pad_left = min(total_pad // 2 + total_pad % 2, effective_k_size - 1)
pad_right = total_pad - pad_left
else:
raise ValueError(f"Invalid padding value: {padding}")
return pad_left, pad_right
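
A quick sanity check on the arithmetic above: conv_transpose below runs MLX's
conv_general with input_dilation=strides and stride=1, so the output length
per axis is (n - 1) * stride + pad_left + pad_right - effective_k_size + 2.
The pads chosen here therefore yield n * stride for "same" and
(n - 1) * stride + effective_k_size for "valid". The sketch below re-derives
two concrete cases; it duplicates the math rather than importing the private
helper, so it runs standalone:

def _pads(kernel_size, stride, dilation_rate, padding, output_padding=None):
    # Same arithmetic as the helper above, duplicated for a standalone check.
    effective_k = (kernel_size - 1) * dilation_rate + 1
    if padding == "valid":
        if output_padding is None:
            output_padding = max(effective_k, stride) - effective_k
        return effective_k - 1, effective_k - 1 + output_padding
    if output_padding is None:
        total_pad = stride + effective_k - 2
    else:
        total_pad = effective_k + effective_k % 2 - 2 + output_padding
    pad_left = min(total_pad // 2 + total_pad % 2, effective_k - 1)
    return pad_left, total_pad - pad_left

assert _pads(3, 2, 1, "same") == (2, 1)   # length n maps to 2n outputs
assert _pads(3, 2, 1, "valid") == (2, 2)  # length n maps to 2(n - 1) + 3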


def compute_conv_transpose_padding_args_for_mlx(
padding,
num_spatial_dims,
kernel_spatial_shape,
dilation_rate,
strides,
output_padding,
):
start_paddings = []
end_paddings = []
for i in range(num_spatial_dims):
kernel_size_i = kernel_spatial_shape[i]
stride_i = strides[i]
dilation_rate_i = dilation_rate[i]
output_padding_i = None if output_padding is None else output_padding[i]
pad_left, pad_right = (
            _convert_conv_transpose_padding_args_from_keras_to_mlx(
kernel_size_i,
stride_i,
dilation_rate_i,
padding,
output_padding_i,
)
)
start_paddings.append(pad_left)
end_paddings.append(pad_right)
return (start_paddings, end_paddings)
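
Hypothetical usage of the wrapper, assuming a Keras checkout that includes
this PR is on the import path:

from keras.src.backend.common.backend_utils import (
    compute_conv_transpose_padding_args_for_mlx,
)

starts, ends = compute_conv_transpose_padding_args_for_mlx(
    padding="same",
    num_spatial_dims=2,
    kernel_spatial_shape=(3, 3),
    dilation_rate=(1, 1),
    strides=(2, 2),
    output_padding=None,
)
print(starts, ends)  # [2, 2] [1, 1], matching the scalar helper above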


def compute_transpose_padding_args_for_mlx(
padding,
input_spatial_shape,
kernel_spatial_shape,
dilation_rate,
strides,
):
if padding == "valid":
return 0
elif padding == "same":
start_paddings = []
end_paddings = []
for dim_size, k_size, d_rate, s in zip(
input_spatial_shape, kernel_spatial_shape, dilation_rate, strides
):
out_size = (dim_size + s - 1) // s
effective_k_size = (k_size - 1) * d_rate + 1
total_pad = max(0, (out_size - 1) * s + effective_k_size - dim_size)
pad_start = total_pad // 2
pad_end = total_pad - pad_start
start_paddings.append(pad_start)
end_paddings.append(pad_end)
return (start_paddings, end_paddings)
else:
raise ValueError(f"Invalid padding value: {padding}")
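
Despite its name, this helper serves the forward conv and depthwise_conv
paths below, and its "same" branch reproduces the usual
out_size = ceil(dim_size / stride) rule per axis. A one-axis sketch of the
numbers, assuming dim_size=7, kernel_size=3, dilation_rate=2, stride=2:

dim, k, d, s = 7, 3, 2, 2
out = (dim + s - 1) // s                     # ceil(7 / 2) = 4 output positions
eff_k = (k - 1) * d + 1                      # dilated kernel spans 5 inputs
total = max(0, (out - 1) * s + eff_k - dim)  # 6 + 5 - 7 = 4 padding columns
print(total // 2, total - total // 2)        # split as (2, 2)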


def _get_output_shape_given_tf_padding(
input_size, kernel_size, strides, padding, output_padding, dilation_rate
):
157 changes: 135 additions & 22 deletions keras/src/backend/mlx/nn.py
@@ -3,6 +3,12 @@

from keras.src.backend import standardize_data_format
from keras.src.backend import standardize_dtype
from keras.src.backend.common.backend_utils import (
compute_conv_transpose_padding_args_for_mlx,
)
from keras.src.backend.common.backend_utils import (
compute_transpose_padding_args_for_mlx,
)
from keras.src.backend.config import epsilon
from keras.src.backend.mlx.core import convert_to_tensor
from keras.src.backend.mlx.core import to_mlx_dtype
@@ -148,25 +154,15 @@ def conv(
# mlx expects kernel with (out_channels, spatial..., in_channels)
kernel = kernel.transpose(-1, *range(kernel.ndim - 2), -2)

-    if padding == "valid":
-        mlx_padding = 0
-    elif padding == "same":
-        kernel_spatial_shape = kernel.shape[1:-1]
-        start_paddings = []
-        end_paddings = []
-        for dim_size, k_size, d_rate, s in zip(
-            inputs.shape[1:-1], kernel_spatial_shape, dilation_rate, strides
-        ):
-            out_size = (dim_size + s - 1) // s
-            effective_k_size = (k_size - 1) * d_rate + 1
-            total_pad = max(0, (out_size - 1) * s + effective_k_size - dim_size)
-            pad_start = total_pad // 2
-            pad_end = total_pad - pad_start
-            start_paddings.append(pad_start)
-            end_paddings.append(pad_end)
-        mlx_padding = (start_paddings, end_paddings)
-    else:
-        raise ValueError(f"Invalid padding value: {padding}")
kernel_spatial_shape = kernel.shape[1:-1]
input_spatial_shape = inputs.shape[1:-1]
mlx_padding = compute_transpose_padding_args_for_mlx(
padding,
input_spatial_shape,
kernel_spatial_shape,
dilation_rate,
strides,
)

channels = inputs.shape[-1]
kernel_in_channels = kernel.shape[-1]
@@ -202,7 +198,53 @@ def depthwise_conv(
data_format=None,
dilation_rate=1,
):
-    raise NotImplementedError("MLX backend doesn't support depthwise conv yet")
inputs = convert_to_tensor(inputs)
kernel = convert_to_tensor(kernel)
data_format = standardize_data_format(data_format)
num_spatial_dims = inputs.ndim - 2

strides = standardize_tuple(strides, num_spatial_dims, "strides")
dilation_rate = standardize_tuple(
dilation_rate, num_spatial_dims, "dilation_rate"
)

if data_format == "channels_first":
# mlx expects channels_last
inputs = inputs.transpose(0, *range(2, inputs.ndim), 1)

feature_group_count = inputs.shape[-1]

# reshape first for depthwise conv, then transpose to expected mlx format
kernel = kernel.reshape(
        *kernel.shape[:-2], 1, feature_group_count * kernel.shape[-1]
)
# mlx expects kernel with (out_channels, spatial..., in_channels)
kernel = kernel.transpose(-1, *range(kernel.ndim - 2), -2)

kernel_spatial_shape = kernel.shape[1:-1]
input_spatial_shape = inputs.shape[1:-1]
mlx_padding = compute_transpose_padding_args_for_mlx(
padding,
input_spatial_shape,
kernel_spatial_shape,
dilation_rate,
strides,
)

result = mx.conv_general(
inputs,
kernel,
stride=strides,
padding=mlx_padding,
kernel_dilation=dilation_rate,
input_dilation=1,
groups=feature_group_count,
flip=False,
)
if data_format == "channels_first":
result = result.transpose(0, -1, *range(1, result.ndim - 1))

return result
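
The reshape/transpose pair above implements the standard depthwise trick: a
Keras depthwise kernel of shape (*spatial, in_channels, depth_multiplier)
becomes an MLX kernel of shape (in_channels * depth_multiplier, *spatial, 1),
and groups=in_channels gives every output channel exactly one input channel.
A shape-only sketch with assumed sizes (3x3 kernel, 8 channels, multiplier 2):

keras_shape = (3, 3, 8, 2)             # (*spatial, in_channels, multiplier)
spatial, cin, mult = keras_shape[:-2], keras_shape[-2], keras_shape[-1]
reshaped = (*spatial, 1, cin * mult)   # (3, 3, 1, 16)
mlx_shape = (cin * mult, *spatial, 1)  # (16, 3, 3, 1), used with groups=8
assert mlx_shape == (16, 3, 3, 1)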


def separable_conv(
@@ -214,7 +256,23 @@
data_format=None,
dilation_rate=1,
):
-    raise NotImplementedError("MLX backend doesn't support separable conv yet")
data_format = standardize_data_format(data_format)
depthwise_conv_output = depthwise_conv(
inputs,
depthwise_kernel,
strides,
padding,
data_format,
dilation_rate,
)
return conv(
depthwise_conv_output,
pointwise_kernel,
strides=1,
padding="valid",
data_format=data_format,
dilation_rate=dilation_rate,
)
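
An end-to-end shape check; hypothetical in that it assumes a Keras build that
contains this PR with mlx installed, and channels_last inputs:

import numpy as np

from keras.src.backend.mlx import nn as mlx_nn

x = np.random.normal(size=(1, 8, 8, 4)).astype("float32")
dw = np.random.normal(size=(3, 3, 4, 2)).astype("float32")   # multiplier 2
pw = np.random.normal(size=(1, 1, 8, 16)).astype("float32")  # 8 -> 16 filters
y = mlx_nn.separable_conv(x, dw, pw, strides=1, padding="same")
print(y.shape)  # (1, 8, 8, 16): depthwise widens 4 -> 8, pointwise maps to 16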


def conv_transpose(
@@ -226,7 +284,62 @@
data_format=None,
dilation_rate=1,
):
-    raise NotImplementedError("MLX backend doesn't support conv transpose yet")
inputs = convert_to_tensor(inputs)
kernel = convert_to_tensor(kernel)
data_format = standardize_data_format(data_format)
num_spatial_dims = inputs.ndim - 2

strides = standardize_tuple(strides, num_spatial_dims, "strides")
dilation_rate = standardize_tuple(
dilation_rate, num_spatial_dims, "dilation_rate"
)
if output_padding is not None:
output_padding = standardize_tuple(
output_padding, num_spatial_dims, "output_padding"
)

if data_format == "channels_first":
# mlx expects channels_last
inputs = inputs.transpose(0, *range(2, inputs.ndim), 1)

# mlx expects kernel with (out_channels, spatial..., in_channels)
kernel = kernel.transpose(-2, *range(kernel.ndim - 2), -1)
kernel_spatial_shape = kernel.shape[1:-1]

mlx_padding = compute_conv_transpose_padding_args_for_mlx(
padding,
num_spatial_dims,
kernel_spatial_shape,
dilation_rate,
strides,
output_padding,
)

channels = inputs.shape[-1]
kernel_in_channels = kernel.shape[-1]
if channels % kernel_in_channels > 0:
raise ValueError(
"The number of input channels must be evenly divisible by "
f"kernel's in_channels. Received input channels {channels} and "
f"kernel in_channels {kernel_in_channels}. "
)
groups = channels // kernel_in_channels

result = mx.conv_general(
inputs,
kernel,
stride=1, # stride is handled by input_dilation
padding=mlx_padding,
kernel_dilation=dilation_rate,
input_dilation=strides,
groups=groups,
flip=True,
)

if data_format == "channels_first":
result = result.transpose(0, -1, *range(1, result.ndim - 1))

return result
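
A transposed-convolution shape check under the same assumptions (recall from
the transpose above that Keras stores conv_transpose kernels as
(*spatial, out_channels, in_channels)):

import numpy as np

from keras.src.backend.mlx import nn as mlx_nn

x = np.random.normal(size=(1, 5, 5, 3)).astype("float32")
k = np.random.normal(size=(3, 3, 8, 3)).astype("float32")  # 8 output filters
y = mlx_nn.conv_transpose(x, k, strides=2, padding="same")
print(y.shape)  # (1, 10, 10, 8): "same" upsamples spatial dims by the stride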


def one_hot(x, num_classes, axis=-1, dtype="float32", sparse=False):