Showing 11 changed files with 804 additions and 280 deletions.
kimm/blocks/__init__.py
@@ -1,3 +1,6 @@
from kimm.blocks.base_block import apply_activation
from kimm.blocks.base_block import apply_conv2d_block
from kimm.blocks.base_block import apply_se_block
from kimm.blocks.inverted_residual_block import apply_inverted_residual_block
from kimm.blocks.transformer_block import apply_mlp_block
from kimm.blocks.transformer_block import apply_transformer_block
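
With these re-exports, the block builders resolve directly from kimm.blocks. A minimal sketch of the resulting import surface (inferred from the lines above, not part of this commit):

# Illustrative only: import paths taken from the __init__.py diff above.
from kimm.blocks import (
    apply_conv2d_block,
    apply_inverted_residual_block,
    apply_mlp_block,
    apply_se_block,
    apply_transformer_block,
)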
kimm/blocks/inverted_residual_block.py
@@ -0,0 +1,77 @@
from keras import layers

from kimm.blocks.base_block import apply_conv2d_block
from kimm.blocks.base_block import apply_se_block
from kimm.utils import make_divisible


def apply_inverted_residual_block(
    inputs,
    output_channels,
    depthwise_kernel_size=3,
    expansion_kernel_size=1,
    pointwise_kernel_size=1,
    strides=1,
    expansion_ratio=1.0,
    se_ratio=0.0,
    activation="swish",
    se_input_channels=None,
    se_activation=None,
    se_gate_activation="sigmoid",
    se_make_divisible_number=None,
    bn_epsilon=1e-5,
    padding=None,
    name="inverted_residual_block",
):
    input_channels = inputs.shape[-1]
    hidden_channels = make_divisible(input_channels * expansion_ratio)
    has_skip = strides == 1 and input_channels == output_channels

    x = inputs
    # Point-wise expansion
    x = apply_conv2d_block(
        x,
        hidden_channels,
        expansion_kernel_size,
        1,
        activation=activation,
        bn_epsilon=bn_epsilon,
        padding=padding,
        name=f"{name}_conv_pw",
    )
    # Depth-wise convolution
    x = apply_conv2d_block(
        x,
        kernel_size=depthwise_kernel_size,
        strides=strides,
        activation=activation,
        use_depthwise=True,
        bn_epsilon=bn_epsilon,
        padding=padding,
        name=f"{name}_conv_dw",
    )
    # Squeeze-and-excitation
    if se_ratio > 0:
        x = apply_se_block(
            x,
            se_ratio,
            activation=se_activation or activation,
            gate_activation=se_gate_activation,
            se_input_channels=se_input_channels,
            make_divisible_number=se_make_divisible_number,
            name=f"{name}_se",
        )
    # Point-wise linear projection
    x = apply_conv2d_block(
        x,
        output_channels,
        pointwise_kernel_size,
        1,
        activation=None,
        bn_epsilon=bn_epsilon,
        padding=padding,
        name=f"{name}_conv_pwl",
    )
    if has_skip:
        x = layers.Add()([x, inputs])
    return x
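
For reference, a minimal usage sketch of the new block. The shapes, channel counts, and the from-kimm.blocks import path are illustrative assumptions based on the re-exports above, not part of this commit:

# Illustrative sketch: assumes the kimm package from this commit is installed.
import keras
from kimm.blocks import apply_inverted_residual_block

inputs = keras.Input(shape=(32, 32, 24))  # NHWC feature map
# Expand 24 channels by 4x, run a strided 3x3 depthwise conv with SE,
# then project to 40 output channels (MobileNetV3-style stage transition).
x = apply_inverted_residual_block(
    inputs,
    output_channels=40,
    depthwise_kernel_size=3,
    strides=2,
    expansion_ratio=4.0,
    se_ratio=0.25,
    name="stage1_block0",
)
model = keras.Model(inputs, x)
model.summary()

Because strides=2 and the channel count changes, has_skip is False here and no residual Add is applied.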
kimm/blocks/transformer_block.py
@@ -0,0 +1,67 @@
from keras import layers

from kimm import layers as kimm_layers


def apply_mlp_block(
    inputs,
    hidden_dim,
    output_dim=None,
    activation="gelu",
    normalization=None,
    use_bias=True,
    dropout_rate=0.0,
    name="mlp_block",
):
    input_dim = inputs.shape[-1]
    output_dim = output_dim or input_dim

    x = inputs
    x = layers.Dense(hidden_dim, use_bias=use_bias, name=f"{name}_fc1")(x)
    x = layers.Activation(activation, name=f"{name}_act")(x)
    x = layers.Dropout(dropout_rate, name=f"{name}_drop1")(x)
    if normalization is not None:
        x = normalization(name=f"{name}_norm")(x)
    x = layers.Dense(output_dim, use_bias=use_bias, name=f"{name}_fc2")(x)
    x = layers.Dropout(dropout_rate, name=f"{name}_drop2")(x)
    return x


def apply_transformer_block(
    inputs,
    dim,
    num_heads,
    mlp_ratio=4.0,
    use_qkv_bias=False,
    use_qk_norm=False,
    projection_dropout_rate=0.0,
    attention_dropout_rate=0.0,
    activation="gelu",
    name="transformer_block",
):
    x = inputs
    residual_1 = x

    x = layers.LayerNormalization(epsilon=1e-6, name=f"{name}_norm1")(x)
    x = kimm_layers.Attention(
        dim,
        num_heads,
        use_qkv_bias,
        use_qk_norm,
        attention_dropout_rate,
        projection_dropout_rate,
        name=f"{name}_attn",
    )(x)
    x = layers.Add()([residual_1, x])

    residual_2 = x
    x = layers.LayerNormalization(epsilon=1e-6, name=f"{name}_norm2")(x)
    x = apply_mlp_block(
        x,
        int(dim * mlp_ratio),
        activation=activation,
        dropout_rate=projection_dropout_rate,
        name=f"{name}_mlp",
    )
    x = layers.Add()([residual_2, x])
    return x
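
A minimal usage sketch of the pre-norm transformer block added above. The sequence length, embedding dimension, and head count are illustrative assumptions (ViT-Tiny-like), and the sketch assumes kimm.layers.Attention is the multi-head self-attention layer referenced in the diff:

# Illustrative sketch: assumes the kimm package from this commit is installed.
import keras
from kimm.blocks import apply_transformer_block

inputs = keras.Input(shape=(197, 192))  # 197 tokens, embedding dim 192
x = inputs
for i in range(2):  # stack two pre-norm transformer blocks
    x = apply_transformer_block(
        x,
        dim=192,
        num_heads=3,
        mlp_ratio=4.0,
        use_qkv_bias=True,
        name=f"block{i}",
    )
model = keras.Model(inputs, x)
model.summary()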