Add MobileViT (#8)
* Add `MobileViT`

* Improve `MobileViT`
james77777778 authored Jan 12, 2024
1 parent 5fe9c62 commit ca90fa4
Showing 11 changed files with 804 additions and 280 deletions.
3 changes: 3 additions & 0 deletions kimm/blocks/__init__.py
@@ -1,3 +1,6 @@
from kimm.blocks.base_block import apply_activation
from kimm.blocks.base_block import apply_conv2d_block
from kimm.blocks.base_block import apply_se_block
from kimm.blocks.inverted_residual_block import apply_inverted_residual_block
from kimm.blocks.transformer_block import apply_mlp_block
from kimm.blocks.transformer_block import apply_transformer_block
77 changes: 77 additions & 0 deletions kimm/blocks/inverted_residual_block.py
@@ -0,0 +1,77 @@
from keras import layers

from kimm.blocks.base_block import apply_conv2d_block
from kimm.blocks.base_block import apply_se_block
from kimm.utils import make_divisible


def apply_inverted_residual_block(
inputs,
output_channels,
depthwise_kernel_size=3,
expansion_kernel_size=1,
pointwise_kernel_size=1,
strides=1,
expansion_ratio=1.0,
se_ratio=0.0,
activation="swish",
se_input_channels=None,
se_activation=None,
se_gate_activation="sigmoid",
se_make_divisible_number=None,
bn_epsilon=1e-5,
padding=None,
name="inverted_residual_block",
):
input_channels = inputs.shape[-1]
hidden_channels = make_divisible(input_channels * expansion_ratio)
has_skip = strides == 1 and input_channels == output_channels

x = inputs
# Point-wise expansion
x = apply_conv2d_block(
x,
hidden_channels,
expansion_kernel_size,
1,
activation=activation,
bn_epsilon=bn_epsilon,
padding=padding,
name=f"{name}_conv_pw",
)
# Depth-wise convolution
x = apply_conv2d_block(
x,
kernel_size=depthwise_kernel_size,
strides=strides,
activation=activation,
use_depthwise=True,
bn_epsilon=bn_epsilon,
padding=padding,
name=f"{name}_conv_dw",
)
# Squeeze-and-excitation
if se_ratio > 0:
x = apply_se_block(
x,
se_ratio,
activation=se_activation or activation,
gate_activation=se_gate_activation,
se_input_channels=se_input_channels,
make_divisible_number=se_make_divisible_number,
name=f"{name}_se",
)
# Point-wise linear projection
x = apply_conv2d_block(
x,
output_channels,
pointwise_kernel_size,
1,
activation=None,
bn_epsilon=bn_epsilon,
padding=padding,
name=f"{name}_conv_pwl",
)
if has_skip:
x = layers.Add()([x, inputs])
return x
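Usage sketch (illustrative, not part of this commit): the shared block plugs straight into the Keras functional API. The input shape and hyperparameters below are assumed values, not taken from any kimm model config.

import keras

from kimm.blocks import apply_inverted_residual_block

# Illustrative 16-channel NHWC feature map.
inputs = keras.Input(shape=(32, 32, 16))
x = apply_inverted_residual_block(
    inputs,
    output_channels=24,
    strides=2,
    expansion_ratio=4.0,
    se_ratio=0.25,  # > 0 enables the squeeze-and-excitation branch
    activation="swish",
    name="example_ir",
)
model = keras.Model(inputs, x)  # expected output shape: (None, 16, 16, 24)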
67 changes: 67 additions & 0 deletions kimm/blocks/transformer_block.py
@@ -0,0 +1,67 @@
from keras import layers

from kimm import layers as kimm_layers


def apply_mlp_block(
inputs,
hidden_dim,
output_dim=None,
activation="gelu",
normalization=None,
use_bias=True,
dropout_rate=0.0,
name="mlp_block",
):
input_dim = inputs.shape[-1]
output_dim = output_dim or input_dim

x = inputs
x = layers.Dense(hidden_dim, use_bias=use_bias, name=f"{name}_fc1")(x)
x = layers.Activation(activation, name=f"{name}_act")(x)
x = layers.Dropout(dropout_rate, name=f"{name}_drop1")(x)
if normalization is not None:
x = normalization(name=f"{name}_norm")(x)
x = layers.Dense(output_dim, use_bias=use_bias, name=f"{name}_fc2")(x)
x = layers.Dropout(dropout_rate, name=f"{name}_drop2")(x)
return x


def apply_transformer_block(
inputs,
dim,
num_heads,
mlp_ratio=4.0,
use_qkv_bias=False,
use_qk_norm=False,
projection_dropout_rate=0.0,
attention_dropout_rate=0.0,
activation="gelu",
name="transformer_block",
):
x = inputs
residual_1 = x

x = layers.LayerNormalization(epsilon=1e-6, name=f"{name}_norm1")(x)
x = kimm_layers.Attention(
dim,
num_heads,
use_qkv_bias,
use_qk_norm,
attention_dropout_rate,
projection_dropout_rate,
name=f"{name}_attn",
)(x)
x = layers.Add()([residual_1, x])

residual_2 = x
x = layers.LayerNormalization(epsilon=1e-6, name=f"{name}_norm2")(x)
x = apply_mlp_block(
x,
int(dim * mlp_ratio),
activation=activation,
dropout_rate=projection_dropout_rate,
name=f"{name}_mlp",
)
x = layers.Add()([residual_2, x])
return x
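Usage sketch (illustrative, not part of this commit): apply_transformer_block operates on a token sequence and exercises apply_mlp_block internally. The token count and embedding width below are assumed values.

import keras

from kimm.blocks import apply_transformer_block

# A sequence of 256 tokens with an embedding width of 144 (assumed values).
tokens = keras.Input(shape=(256, 144))
x = apply_transformer_block(
    tokens,
    dim=144,
    num_heads=4,
    mlp_ratio=2.0,
    use_qkv_bias=True,
    name="example_transformer",
)
model = keras.Model(tokens, x)  # output shape matches the input: (None, 256, 144)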
70 changes: 2 additions & 68 deletions kimm/models/efficientnet.py
@@ -8,6 +8,7 @@
from keras.src.applications import imagenet_utils

from kimm.blocks import apply_conv2d_block
from kimm.blocks import apply_inverted_residual_block
from kimm.blocks import apply_se_block
from kimm.models.feature_extractor import FeatureExtractor
from kimm.utils import make_divisible
@@ -130,73 +131,6 @@ def apply_depthwise_separation_block(
return x


def apply_inverted_residual_block(
inputs,
output_channels,
depthwise_kernel_size=3,
expansion_kernel_size=1,
pointwise_kernel_size=1,
strides=1,
expansion_ratio=1.0,
se_ratio=0.0,
activation="swish",
bn_epsilon=1e-5,
padding=None,
name="inverted_residual_block",
):
input_channels = inputs.shape[-1]
hidden_channels = make_divisible(input_channels * expansion_ratio)
has_skip = strides == 1 and input_channels == output_channels

x = inputs
# Point-wise expansion
x = apply_conv2d_block(
x,
hidden_channels,
expansion_kernel_size,
1,
activation=activation,
bn_epsilon=bn_epsilon,
padding=padding,
name=f"{name}_conv_pw",
)
# Depth-wise convolution
x = apply_conv2d_block(
x,
kernel_size=depthwise_kernel_size,
strides=strides,
activation=activation,
use_depthwise=True,
bn_epsilon=bn_epsilon,
padding=padding,
name=f"{name}_conv_dw",
)
# Squeeze-and-excitation
if se_ratio > 0:
x = apply_se_block(
x,
se_ratio,
activation=activation,
gate_activation="sigmoid",
se_input_channels=input_channels,
name=f"{name}_se",
)
# Point-wise linear projection
x = apply_conv2d_block(
x,
output_channels,
pointwise_kernel_size,
1,
activation=None,
bn_epsilon=bn_epsilon,
padding=padding,
name=f"{name}_conv_pwl",
)
if has_skip:
x = layers.Add()([x, inputs])
return x


def apply_edge_residual_block(
inputs,
output_channels,
@@ -271,7 +205,7 @@ def __init__(
classes: int = 1000,
classifier_activation: str = "softmax",
weights: typing.Optional[str] = None, # TODO: imagenet
-    config: typing.Union[str, typing.List] = "default",
+    config: typing.Union[str, typing.List] = "v1",
**kwargs,
):
_available_configs = [
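Note on this refactor (an illustrative reading, not part of the commit): the EfficientNet-local helper removed above hard-coded gate_activation="sigmoid" and se_input_channels=input_channels for its squeeze-and-excitation block; the shared helper exposes these as the se_gate_activation and se_input_channels arguments instead. A hedged sketch of a call that reproduces the old behaviour, with assumed channel counts and ratios:

import keras

from kimm.blocks import apply_inverted_residual_block

inputs = keras.Input(shape=(56, 56, 24))  # assumed feature-map shape
x = apply_inverted_residual_block(
    inputs,
    output_channels=40,
    depthwise_kernel_size=5,
    strides=2,
    expansion_ratio=6.0,
    se_ratio=0.25,
    activation="swish",
    se_input_channels=inputs.shape[-1],  # the removed helper always used the block's input channels
    se_gate_activation="sigmoid",        # also the shared default
    name="example_effnet_ir",
)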
52 changes: 2 additions & 50 deletions kimm/models/mobilenet_v2.py
@@ -8,6 +8,7 @@
from keras.src.applications import imagenet_utils

from kimm.blocks import apply_conv2d_block
from kimm.blocks import apply_inverted_residual_block
from kimm.models.feature_extractor import FeatureExtractor
from kimm.utils import make_divisible
from kimm.utils.model_registry import add_model_to_registry
@@ -58,55 +59,6 @@ def apply_depthwise_separation_block(
return x


def apply_inverted_residual_block(
inputs,
output_channels,
depthwise_kernel_size=3,
expansion_kernel_size=1,
pointwise_kernel_size=1,
strides=1,
expansion_ratio=1.0,
activation="relu6",
name="inverted_residual_block",
):
input_channels = inputs.shape[-1]
hidden_channels = make_divisible(input_channels * expansion_ratio)
has_skip = strides == 1 and input_channels == output_channels

x = inputs

# Point-wise expansion
x = apply_conv2d_block(
x,
hidden_channels,
expansion_kernel_size,
1,
activation=activation,
name=f"{name}_conv_pw",
)
# Depth-wise convolution
x = apply_conv2d_block(
x,
kernel_size=depthwise_kernel_size,
strides=strides,
activation=activation,
use_depthwise=True,
name=f"{name}_conv_dw",
)
# Point-wise linear projection
x = apply_conv2d_block(
x,
output_channels,
pointwise_kernel_size,
1,
activation=None,
name=f"{name}_conv_pwl",
)
if has_skip:
x = layers.Add()([x, inputs])
return x


class MobileNetV2(FeatureExtractor):
def __init__(
self,
@@ -189,7 +141,7 @@ def __init__(
)
elif block_type == "ir":
x = apply_inverted_residual_block(
-    x, c, k, 1, 1, s, e, name=name
+    x, c, k, 1, 1, s, e, activation="relu6", name=name
)
current_stride *= s
features[f"BLOCK{current_block_idx}_S{current_stride}"] = x
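Note on the call-site change above (an illustrative reading, not part of the commit): the shared block defaults to activation="swish", while the removed MobileNetV2-local helper used "relu6", so the activation is now passed explicitly. Expanded with keyword arguments and assumed stand-in values for the config entries c, k, s, and e, the updated call corresponds roughly to:

import keras

from kimm.blocks import apply_inverted_residual_block

inputs = keras.Input(shape=(112, 112, 16))  # assumed stage input
x = apply_inverted_residual_block(
    inputs,                   # x
    output_channels=24,       # c
    depthwise_kernel_size=3,  # k
    expansion_kernel_size=1,
    pointwise_kernel_size=1,
    strides=2,                # s
    expansion_ratio=6.0,      # e
    activation="relu6",       # the shared block otherwise defaults to "swish"
    name="example_mbv2_ir",   # name
)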