mlx - more conv (#20807)
* depthwise_conv implementation

* implementation

* clean and implemented
acsweet authored Jan 24, 2025
1 parent b8338f7 commit e2e2288
Showing 2 changed files with 217 additions and 22 deletions.
82 changes: 82 additions & 0 deletions keras/src/backend/common/backend_utils.py
@@ -187,6 +187,88 @@ def compute_conv_transpose_padding_args_for_torch(
return torch_paddings, torch_output_paddings


def _convert_conv_transpose_padding_args_from_keras_to_mlx(
kernel_size, stride, dilation_rate, padding, output_padding
):
effective_k_size = (kernel_size - 1) * dilation_rate + 1
if padding == "valid":
output_padding = (
max(effective_k_size, stride) - effective_k_size
if output_padding is None
else output_padding
)
pad_left = effective_k_size - 1
pad_right = effective_k_size - 1 + output_padding
elif padding == "same":
if output_padding is None:
total_pad = stride + effective_k_size - 2
else:
total_pad = (
effective_k_size + effective_k_size % 2 - 2 + output_padding
)
pad_left = min(total_pad // 2 + total_pad % 2, effective_k_size - 1)
pad_right = total_pad - pad_left
else:
raise ValueError(f"Invalid padding value: {padding}")
return pad_left, pad_right
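
A quick worked example of the arithmetic above (values assumed for illustration, not taken from the commit's tests): with kernel_size=3, stride=2, dilation_rate=1 and padding="same", effective_k_size is 3 and total_pad is stride + effective_k_size - 2 = 3, giving pad_left = 2 and pad_right = 1.

# Illustrative check of the helper above; all argument values are assumed.
pad_left, pad_right = _convert_conv_transpose_padding_args_from_keras_to_mlx(
    kernel_size=3, stride=2, dilation_rate=1, padding="same", output_padding=None
)
assert (pad_left, pad_right) == (2, 1)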


def compute_conv_transpose_padding_args_for_mlx(
padding,
num_spatial_dims,
kernel_spatial_shape,
dilation_rate,
strides,
output_padding,
):
start_paddings = []
end_paddings = []
for i in range(num_spatial_dims):
kernel_size_i = kernel_spatial_shape[i]
stride_i = strides[i]
dilation_rate_i = dilation_rate[i]
output_padding_i = None if output_padding is None else output_padding[i]
pad_left, pad_right = (
            _convert_conv_transpose_padding_args_from_keras_to_mlx(
kernel_size_i,
stride_i,
dilation_rate_i,
padding,
output_padding_i,
)
)
start_paddings.append(pad_left)
end_paddings.append(pad_right)
return (start_paddings, end_paddings)
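
This wrapper applies the per-dimension helper to each spatial axis and collects the results in the (start_paddings, end_paddings) form that mx.conv_general consumes. A minimal sketch with assumed values:

# For a 3x3 kernel, stride 2 and "same" padding, each axis gets (2, 1).
start, end = compute_conv_transpose_padding_args_for_mlx(
    padding="same",
    num_spatial_dims=2,
    kernel_spatial_shape=(3, 3),
    dilation_rate=(1, 1),
    strides=(2, 2),
    output_padding=None,
)
assert (start, end) == ([2, 2], [1, 1])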


def compute_transpose_padding_args_for_mlx(
padding,
input_spatial_shape,
kernel_spatial_shape,
dilation_rate,
strides,
):
if padding == "valid":
return 0
elif padding == "same":
start_paddings = []
end_paddings = []
for dim_size, k_size, d_rate, s in zip(
input_spatial_shape, kernel_spatial_shape, dilation_rate, strides
):
out_size = (dim_size + s - 1) // s
effective_k_size = (k_size - 1) * d_rate + 1
total_pad = max(0, (out_size - 1) * s + effective_k_size - dim_size)
pad_start = total_pad // 2
pad_end = total_pad - pad_start
start_paddings.append(pad_start)
end_paddings.append(pad_end)
return (start_paddings, end_paddings)
else:
raise ValueError(f"Invalid padding value: {padding}")
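
Despite the "transpose" in its name, this helper serves the forward conv path: it is called from conv and depthwise_conv below to turn Keras's "same"/"valid" strings into the explicit per-axis paddings mx.conv_general expects. An illustrative check with assumed shapes, using the fact that "same" preserves ceil(input / stride):

# Length-5 axis, kernel 3, stride 2: output size ceil(5 / 2) = 3 needs
# total padding (3 - 1) * 2 + 3 - 5 = 2, split as (1, 1).
start, end = compute_transpose_padding_args_for_mlx(
    padding="same",
    input_spatial_shape=(5,),
    kernel_spatial_shape=(3,),
    dilation_rate=(1,),
    strides=(2,),
)
assert (start, end) == ([1], [1])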


def _get_output_shape_given_tf_padding(
input_size, kernel_size, strides, padding, output_padding, dilation_rate
):
157 changes: 135 additions & 22 deletions keras/src/backend/mlx/nn.py
@@ -3,6 +3,12 @@

from keras.src.backend import standardize_data_format
from keras.src.backend import standardize_dtype
from keras.src.backend.common.backend_utils import (
compute_conv_transpose_padding_args_for_mlx,
)
from keras.src.backend.common.backend_utils import (
compute_transpose_padding_args_for_mlx,
)
from keras.src.backend.config import epsilon
from keras.src.backend.mlx.core import convert_to_tensor
from keras.src.backend.mlx.core import to_mlx_dtype
@@ -148,25 +154,15 @@ def conv(
# mlx expects kernel with (out_channels, spatial..., in_channels)
kernel = kernel.transpose(-1, *range(kernel.ndim - 2), -2)

if padding == "valid":
mlx_padding = 0
elif padding == "same":
kernel_spatial_shape = kernel.shape[1:-1]
start_paddings = []
end_paddings = []
for dim_size, k_size, d_rate, s in zip(
inputs.shape[1:-1], kernel_spatial_shape, dilation_rate, strides
):
out_size = (dim_size + s - 1) // s
effective_k_size = (k_size - 1) * d_rate + 1
total_pad = max(0, (out_size - 1) * s + effective_k_size - dim_size)
pad_start = total_pad // 2
pad_end = total_pad - pad_start
start_paddings.append(pad_start)
end_paddings.append(pad_end)
mlx_padding = (start_paddings, end_paddings)
else:
raise ValueError(f"Invalid padding value: {padding}")
kernel_spatial_shape = kernel.shape[1:-1]
input_spatial_shape = inputs.shape[1:-1]
mlx_padding = compute_transpose_padding_args_for_mlx(
padding,
input_spatial_shape,
kernel_spatial_shape,
dilation_rate,
strides,
)

channels = inputs.shape[-1]
kernel_in_channels = kernel.shape[-1]
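
For reference, the transpose at the top of this hunk converts the Keras kernel layout (spatial..., in_channels, out_channels) into the (out_channels, spatial..., in_channels) layout mx.conv_general expects. A shape-only sketch, with sizes assumed for illustration:

import mlx.core as mx

k = mx.zeros((3, 3, 8, 16))  # Keras 2D kernel: (kh, kw, in_channels, out_channels)
k = k.transpose(-1, *range(k.ndim - 2), -2)
assert tuple(k.shape) == (16, 3, 3, 8)  # (out_channels, kh, kw, in_channels)
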
@@ -202,7 +198,53 @@ def depthwise_conv(
data_format=None,
dilation_rate=1,
):
raise NotImplementedError("MLX backend doesn't support depthwise conv yet")
inputs = convert_to_tensor(inputs)
kernel = convert_to_tensor(kernel)
data_format = standardize_data_format(data_format)
num_spatial_dims = inputs.ndim - 2

strides = standardize_tuple(strides, num_spatial_dims, "strides")
dilation_rate = standardize_tuple(
dilation_rate, num_spatial_dims, "dilation_rate"
)

if data_format == "channels_first":
# mlx expects channels_last
inputs = inputs.transpose(0, *range(2, inputs.ndim), 1)

feature_group_count = inputs.shape[-1]

# reshape first for depthwise conv, then transpose to expected mlx format
kernel = kernel.reshape(
        *kernel.shape[:-2], 1, feature_group_count * kernel.shape[-1]
)
# mlx expects kernel with (out_channels, spatial..., in_channels)
kernel = kernel.transpose(-1, *range(kernel.ndim - 2), -2)

kernel_spatial_shape = kernel.shape[1:-1]
input_spatial_shape = inputs.shape[1:-1]
mlx_padding = compute_transpose_padding_args_for_mlx(
padding,
input_spatial_shape,
kernel_spatial_shape,
dilation_rate,
strides,
)

result = mx.conv_general(
inputs,
kernel,
stride=strides,
padding=mlx_padding,
kernel_dilation=dilation_rate,
input_dilation=1,
groups=feature_group_count,
flip=False,
)
if data_format == "channels_first":
result = result.transpose(0, -1, *range(1, result.ndim - 1))

return result
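
Depthwise conv is expressed here as a grouped convolution with one group per input channel: the Keras depthwise kernel (spatial..., in_channels, channel_multiplier) is reshaped so every group convolves a single input channel, then transposed to the MLX layout. A shape-only sketch with assumed sizes:

import mlx.core as mx

in_channels, multiplier = 8, 2
k = mx.zeros((3, 3, in_channels, multiplier))              # Keras layout
k = k.reshape(*k.shape[:-2], 1, in_channels * multiplier)  # (3, 3, 1, 16)
k = k.transpose(-1, *range(k.ndim - 2), -2)                # (16, 3, 3, 1)
assert tuple(k.shape) == (in_channels * multiplier, 3, 3, 1)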


def separable_conv(
@@ -214,7 +256,23 @@
data_format=None,
dilation_rate=1,
):
raise NotImplementedError("MLX backend doesn't support separable conv yet")
data_format = standardize_data_format(data_format)
depthwise_conv_output = depthwise_conv(
inputs,
depthwise_kernel,
strides,
padding,
data_format,
dilation_rate,
)
return conv(
depthwise_conv_output,
pointwise_kernel,
strides=1,
padding="valid",
data_format=data_format,
dilation_rate=dilation_rate,
)
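
separable_conv chains the two primitives: the strided, dilated depthwise pass does the spatial filtering, and a stride-1 "valid" 1x1 conv then mixes channels. A usage sketch with assumed channels_last shapes:

import mlx.core as mx

x = mx.zeros((1, 32, 32, 8))   # NHWC input
dw = mx.zeros((3, 3, 8, 2))    # depthwise kernel, channel multiplier 2
pw = mx.zeros((1, 1, 16, 24))  # pointwise kernel: 8 * 2 = 16 channels in, 24 out
y = separable_conv(x, dw, pw, strides=1, padding="same")
assert tuple(y.shape) == (1, 32, 32, 24)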


def conv_transpose(
@@ -226,7 +284,62 @@
data_format=None,
dilation_rate=1,
):
raise NotImplementedError("MLX backend doesn't support conv transpose yet")
inputs = convert_to_tensor(inputs)
kernel = convert_to_tensor(kernel)
data_format = standardize_data_format(data_format)
num_spatial_dims = inputs.ndim - 2

strides = standardize_tuple(strides, num_spatial_dims, "strides")
dilation_rate = standardize_tuple(
dilation_rate, num_spatial_dims, "dilation_rate"
)
if output_padding is not None:
output_padding = standardize_tuple(
output_padding, num_spatial_dims, "output_padding"
)

if data_format == "channels_first":
# mlx expects channels_last
inputs = inputs.transpose(0, *range(2, inputs.ndim), 1)

# mlx expects kernel with (out_channels, spatial..., in_channels)
kernel = kernel.transpose(-2, *range(kernel.ndim - 2), -1)
kernel_spatial_shape = kernel.shape[1:-1]

mlx_padding = compute_conv_transpose_padding_args_for_mlx(
padding,
num_spatial_dims,
kernel_spatial_shape,
dilation_rate,
strides,
output_padding,
)

channels = inputs.shape[-1]
kernel_in_channels = kernel.shape[-1]
if channels % kernel_in_channels > 0:
raise ValueError(
"The number of input channels must be evenly divisible by "
f"kernel's in_channels. Received input channels {channels} and "
f"kernel in_channels {kernel_in_channels}. "
)
groups = channels // kernel_in_channels

result = mx.conv_general(
inputs,
kernel,
stride=1, # stride is handled by input_dilation
padding=mlx_padding,
kernel_dilation=dilation_rate,
input_dilation=strides,
groups=groups,
flip=True,
)

if data_format == "channels_first":
result = result.transpose(0, -1, *range(1, result.ndim - 1))

return result
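
The transposed convolution is realized as a regular convolution over a stride-dilated input with a flipped kernel (flip=True) and stride fixed to 1, which is why the padding arithmetic in the helper above differs from the forward case. With assumed shapes: for a length-5 input, stride 2, kernel 3 and "same" padding, the dilated input has length (5 - 1) * 2 + 1 = 9, the padding is (2, 1), and the output length is the expected 5 * 2 = 10.

import mlx.core as mx

x = mx.zeros((1, 5, 8))   # NLC input: length 5, 8 channels
k = mx.zeros((3, 16, 8))  # Keras layout: (kw, out_channels, in_channels)
y = conv_transpose(x, k, strides=2, padding="same")
assert tuple(y.shape) == (1, 10, 16)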


def one_hot(x, num_classes, axis=-1, dtype="float32", sparse=False):