Add type hints (#25)
james77777778 authored Jan 20, 2024
1 parent a50003c commit c608f1c
Showing 17 changed files with 135 additions and 98 deletions.
4 changes: 2 additions & 2 deletions README.md
@@ -1,8 +1,6 @@
<!-- markdownlint-disable MD033 -->
<!-- markdownlint-disable MD041 -->

# Keras Image Models

<div align="center">
<img width="50%" src="https://github.com/james77777778/kimm/assets/20734616/b21db8f2-307b-4791-b93d-e913e45fb238" alt="KIMM">

@@ -11,6 +9,8 @@
[![codecov](https://codecov.io/gh/james77777778/kimm/graph/badge.svg?token=eEha1SR80D)](https://codecov.io/gh/james77777778/kimm)
</div>

# Keras Image Models

## Description

**K**eras **Im**age **M**odels (`kimm`) is a collection of image models, blocks and layers written in Keras 3. The goal is to offer SOTA models with pretrained weights in a user-friendly manner.
46 changes: 26 additions & 20 deletions kimm/blocks/base_block.py
@@ -1,9 +1,14 @@
import typing

from keras import layers

from kimm.utils import make_divisible


def apply_activation(x, activation=None, name="activation"):
def apply_activation(
inputs, activation: typing.Optional[str] = None, name: str = "activation"
):
x = inputs
if activation is not None:
if isinstance(activation, str):
x = layers.Activation(activation, name=name)(x)
@@ -18,16 +23,18 @@ def apply_activation(x, activation=None, name="activation"):

def apply_conv2d_block(
inputs,
filters=None,
kernel_size=None,
strides=1,
groups=1,
activation=None,
use_depthwise=False,
add_skip=False,
bn_momentum=0.9,
bn_epsilon=1e-5,
padding=None,
filters: typing.Optional[int] = None,
kernel_size: typing.Optional[
typing.Union[int, typing.Sequence[int]]
] = None,
strides: int = 1,
groups: int = 1,
activation: typing.Optional[str] = None,
use_depthwise: bool = False,
add_skip: bool = False,
bn_momentum: float = 0.9,
bn_epsilon: float = 1e-5,
padding: typing.Optional[typing.Literal["same", "valid"]] = None,
name="conv2d_block",
):
if kernel_size is None:
@@ -77,12 +84,12 @@ def apply_conv2d_block(

def apply_se_block(
inputs,
se_ratio=0.25,
activation="relu",
gate_activation="sigmoid",
make_divisible_number=None,
se_input_channels=None,
name="se_block",
se_ratio: float = 0.25,
activation: typing.Optional[str] = "relu",
gate_activation: typing.Optional[str] = "sigmoid",
make_divisible_number: typing.Optional[int] = None,
se_input_channels: typing.Optional[int] = None,
name: str = "se_block",
):
input_channels = inputs.shape[-1]
if se_input_channels is None:
@@ -94,7 +101,6 @@ def apply_se_block(
se_input_channels * se_ratio, make_divisible_number
)

ori_x = inputs
x = inputs
x = layers.GlobalAveragePooling2D(keepdims=True, name=f"{name}_mean")(x)
x = layers.Conv2D(
@@ -105,5 +111,5 @@
input_channels, 1, use_bias=True, name=f"{name}_conv_expand"
)(x)
x = apply_activation(x, gate_activation, name=f"{name}_gate")
out = layers.Multiply(name=name)([ori_x, x])
return out
x = layers.Multiply(name=name)([inputs, x])
return x
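The typed signatures above can be read directly from a call site. Below is a minimal usage sketch (not part of the commit) of `apply_conv2d_block` and `apply_se_block`; the input shape and argument values are illustrative assumptions.

```python
from keras import layers

from kimm.blocks.base_block import apply_conv2d_block, apply_se_block

# Build a small functional graph; shapes and values are placeholders.
inputs = layers.Input(shape=(32, 32, 16))
x = apply_conv2d_block(
    inputs,
    filters=32,          # typing.Optional[int]
    kernel_size=3,       # int or a sequence of ints
    strides=1,
    activation="swish",  # typing.Optional[str]
    padding="same",      # typing.Literal["same", "valid"] or None
    name="stem",
)
x = apply_se_block(x, se_ratio=0.25, name="stem_se")
```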
30 changes: 16 additions & 14 deletions kimm/blocks/depthwise_separation_block.py
@@ -1,3 +1,5 @@
import typing

from keras import layers

from kimm.blocks.base_block import apply_conv2d_block
@@ -6,20 +8,20 @@

def apply_depthwise_separation_block(
inputs,
output_channels,
depthwise_kernel_size=3,
pointwise_kernel_size=1,
strides=1,
se_ratio=0.0,
activation="swish",
se_activation="relu",
se_gate_activation="sigmoid",
se_make_divisible_number=None,
pw_activation=None,
skip=True,
bn_epsilon=1e-5,
padding=None,
name="depthwise_separation_block",
output_channels: int,
depthwise_kernel_size: int = 3,
pointwise_kernel_size: int = 1,
strides: int = 1,
se_ratio: float = 0.0,
activation: typing.Optional[str] = "swish",
se_activation: typing.Optional[str] = "relu",
se_gate_activation: typing.Optional[str] = "sigmoid",
se_make_divisible_number: typing.Optional[int] = None,
pw_activation: typing.Optional[str] = None,
skip: bool = True,
bn_epsilon: float = 1e-5,
padding: typing.Optional[typing.Literal["same", "valid"]] = None,
name: str = "depthwise_separation_block",
):
input_channels = inputs.shape[-1]
has_skip = skip and (strides == 1 and input_channels == output_channels)
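As a quick illustration of the narrowed hints (`output_channels: int`, `padding` limited to `"same"`/`"valid"`), here is a hedged call sketch; the import path follows the file shown above and the shapes are assumptions.

```python
from keras import layers

from kimm.blocks.depthwise_separation_block import (
    apply_depthwise_separation_block,
)

inputs = layers.Input(shape=(56, 56, 32))
x = apply_depthwise_separation_block(
    inputs,
    output_channels=32,  # int; with strides=1 and equal channels, the skip applies
    strides=1,
    se_ratio=0.25,       # float; 0.0 disables the squeeze-and-excitation path
    name="dsb_0",
)
```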
32 changes: 17 additions & 15 deletions kimm/blocks/inverted_residual_block.py
@@ -1,3 +1,5 @@
import typing

from keras import layers

from kimm.blocks.base_block import apply_conv2d_block
@@ -7,21 +9,21 @@

def apply_inverted_residual_block(
inputs,
output_channels,
depthwise_kernel_size=3,
expansion_kernel_size=1,
pointwise_kernel_size=1,
strides=1,
expansion_ratio=1.0,
se_ratio=0.0,
activation="swish",
se_channels=None,
se_activation=None,
se_gate_activation="sigmoid",
se_make_divisible_number=None,
bn_epsilon=1e-5,
padding=None,
name="inverted_residual_block",
output_channels: int,
depthwise_kernel_size: int = 3,
expansion_kernel_size: int = 1,
pointwise_kernel_size: int = 1,
strides: int = 1,
expansion_ratio: float = 1.0,
se_ratio: float = 0.0,
activation: str = "swish",
se_channels: typing.Optional[int] = None,
se_activation: typing.Optional[str] = None,
se_gate_activation: typing.Optional[str] = "sigmoid",
se_make_divisible_number: typing.Optional[int] = None,
bn_epsilon: float = 1e-5,
padding: typing.Optional[typing.Literal["same", "valid"]] = None,
name: str = "inverted_residual_block",
):
input_channels = inputs.shape[-1]
hidden_channels = make_divisible(input_channels * expansion_ratio)
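A similar sketch for the inverted residual (MBConv-style) block; the values below are assumptions chosen only to exercise the typed parameters.

```python
from keras import layers

from kimm.blocks.inverted_residual_block import apply_inverted_residual_block

inputs = layers.Input(shape=(28, 28, 64))
x = apply_inverted_residual_block(
    inputs,
    output_channels=128,
    strides=2,
    expansion_ratio=4.0,  # float: hidden width = make_divisible(64 * 4.0)
    se_ratio=0.25,
    name="irb_0",
)
```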
37 changes: 18 additions & 19 deletions kimm/blocks/transformer_block.py
@@ -1,18 +1,19 @@
import typing

from keras import layers

from kimm import layers as kimm_layers


def apply_mlp_block(
inputs,
hidden_dim,
output_dim=None,
activation="gelu",
normalization=None,
use_bias=True,
dropout_rate=0.0,
use_conv_mlp=False,
name="mlp_block",
hidden_dim: int,
output_dim: typing.Optional[int] = None,
activation: str = "gelu",
use_bias: bool = True,
dropout_rate: float = 0.0,
use_conv_mlp: bool = False,
name: str = "mlp_block",
):
input_dim = inputs.shape[-1]
output_dim = output_dim or input_dim
@@ -26,8 +27,6 @@ def apply_mlp_block(
x = layers.Dense(hidden_dim, use_bias=use_bias, name=f"{name}_fc1")(x)
x = layers.Activation(activation, name=f"{name}_act")(x)
x = layers.Dropout(dropout_rate, name=f"{name}_drop1")(x)
if normalization is not None:
x = normalization(name=f"{name}_norm")(x)
if use_conv_mlp:
x = layers.Conv2D(
output_dim, 1, use_bias=use_bias, name=f"{name}_fc2_conv2d"
@@ -40,15 +39,15 @@

def apply_transformer_block(
inputs,
dim,
num_heads,
mlp_ratio=4.0,
use_qkv_bias=False,
use_qk_norm=False,
projection_dropout_rate=0.0,
attention_dropout_rate=0.0,
activation="gelu",
name="transformer_block",
dim: int,
num_heads: int,
mlp_ratio: float = 4.0,
use_qkv_bias: bool = False,
use_qk_norm: bool = False,
projection_dropout_rate: float = 0.0,
attention_dropout_rate: float = 0.0,
activation: str = "gelu",
name: str = "transformer_block",
):
x = inputs
residual_1 = x
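For the ViT-style helpers, the new hints make it explicit that `dim` and `num_heads` are required ints. A hedged sketch follows; the token count and embedding width are assumptions, and `dim` must match the last axis of the input for the residual additions to work.

```python
from keras import layers

from kimm.blocks.transformer_block import apply_transformer_block

# (batch, tokens, dim) input, e.g. 196 patches plus a class token at width 384.
inputs = layers.Input(shape=(197, 384))
x = apply_transformer_block(
    inputs,
    dim=384,
    num_heads=6,
    mlp_ratio=4.0,      # float
    use_qkv_bias=True,  # bool
    name="block_0",
)
```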
14 changes: 7 additions & 7 deletions kimm/layers/attention.py
@@ -7,13 +7,13 @@
class Attention(layers.Layer):
def __init__(
self,
hidden_dim,
num_heads=8,
use_qkv_bias=False,
use_qk_norm=False,
attention_dropout_rate=0.0,
projection_dropout_rate=0.0,
name="attention",
hidden_dim: int,
num_heads: int = 8,
use_qkv_bias: bool = False,
use_qk_norm: bool = False,
attention_dropout_rate: float = 0.0,
projection_dropout_rate: float = 0.0,
name: str = "attention",
**kwargs,
):
super().__init__(**kwargs)
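The `Attention` layer can also be used on its own. A minimal sketch, assuming the module path follows the file shown above and that `hidden_dim` matches the input's last axis:

```python
import keras

from kimm.layers.attention import Attention

x = keras.random.normal((2, 197, 384))
attention = Attention(
    hidden_dim=384,     # int (required)
    num_heads=6,        # int
    use_qkv_bias=True,  # bool
)
y = attention(x)  # output keeps the input shape: (2, 197, 384)
```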
7 changes: 4 additions & 3 deletions kimm/layers/layer_scale.py
@@ -2,15 +2,16 @@
from keras import initializers
from keras import layers
from keras import ops
from keras.initializers import Initializer


@keras.saving.register_keras_serializable(package="kimm")
class LayerScale(layers.Layer):
def __init__(
self,
hidden_size,
initializer=initializers.Constant(1e-5),
name="layer_scale",
hidden_size: int,
initializer: Initializer = initializers.Constant(1e-5),
name: str = "layer_scale",
**kwargs,
):
super().__init__(**kwargs)
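`LayerScale` now hints its `initializer` parameter with the `keras.initializers.Initializer` base class. A small sketch, with an assumed input shape and an override of the default constant:

```python
import keras
from keras import initializers

from kimm.layers.layer_scale import LayerScale

x = keras.random.normal((2, 197, 384))
layer_scale = LayerScale(
    hidden_size=384,                          # int
    initializer=initializers.Constant(1e-6),  # any keras Initializer
)
y = layer_scale(x)  # per-channel scaling; shape is unchanged
```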
26 changes: 18 additions & 8 deletions kimm/models/base_model.py
@@ -92,12 +92,12 @@ def set_properties(

def determine_input_tensor(
self,
input_tensor=None,
input_shape=None,
default_size=224,
min_size=32,
require_flatten=False,
static_shape=False,
input_tensor: typing.Optional[KerasTensor] = None,
input_shape: typing.Optional[typing.Sequence[int]] = None,
default_size: int = 224,
min_size: int = 32,
require_flatten: bool = False,
static_shape: bool = False,
):
"""Determine the input tensor by the arguments."""
input_shape = imagenet_utils.obtain_input_shape(
@@ -118,7 +118,7 @@ def determine_input_tensor(
x = utils.get_source_inputs(input_tensor)
return x

def build_preprocessing(self, inputs, mode="imagenet"):
def build_preprocessing(
self,
inputs,
mode: typing.Literal["imagenet", "0_1", "-1_1"] = "imagenet",
):
if self._include_preprocessing is False:
return inputs
if mode == "imagenet":
@@ -140,7 +144,13 @@ def build_preprocessing(self, inputs, mode="imagenet"):
)
return x

def build_top(self, inputs, classes, classifier_activation, dropout_rate):
def build_top(
self,
inputs,
classes: int,
classifier_activation: str,
dropout_rate: float,
):
x = layers.GlobalAveragePooling2D(name="avg_pool")(inputs)
x = layers.Dropout(rate=dropout_rate, name="head_dropout")(x)
x = layers.Dense(
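`build_preprocessing` now restricts `mode` to `typing.Literal["imagenet", "0_1", "-1_1"]`, so a static type checker can flag unsupported strings before the model is built. A typing-only illustration (the `PreprocessingMode` alias below is hypothetical, not part of the commit):

```python
import typing

# Hypothetical alias mirroring the Literal hint on build_preprocessing.
PreprocessingMode = typing.Literal["imagenet", "0_1", "-1_1"]

def check_mode(mode: PreprocessingMode) -> PreprocessingMode:
    return mode

check_mode("imagenet")  # accepted
check_mode("0_255")     # flagged by mypy/pyright: not a member of the Literal
```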
2 changes: 1 addition & 1 deletion kimm/models/efficientnet.py
@@ -134,7 +134,7 @@ def __init__(
fix_stem_and_head_channels: bool = False,
fix_first_and_last_blocks: bool = False,
activation="swish",
config: typing.Union[str, typing.List] = "v1",
config: str = "v1",
**kwargs,
):
_available_configs = [
2 changes: 1 addition & 1 deletion kimm/models/ghostnet.py
@@ -234,7 +234,7 @@ def __init__(
self,
width: float = 1.0,
config: typing.Union[str, typing.List] = "default",
version: str = "v1",
version: typing.Literal["v1", "v2"] = "v1",
**kwargs,
):
_available_configs = ["default"]
2 changes: 1 addition & 1 deletion kimm/models/inception_v3.py
@@ -203,7 +203,7 @@ def apply_inception_aux_block(inputs, classes, name="inception_aux_block"):

@keras.saving.register_keras_serializable(package="kimm")
class InceptionV3Base(BaseModel):
def __init__(self, has_aux_logits=False, **kwargs):
def __init__(self, has_aux_logits: bool = False, **kwargs):
input_tensor = kwargs.pop("input_tensor", None)
self.set_properties(kwargs, 299)
inputs = self.determine_input_tensor(
2 changes: 1 addition & 1 deletion kimm/models/mobilenet_v2.py
@@ -29,7 +29,7 @@ def __init__(
width: float = 1.0,
depth: float = 1.0,
fix_stem_and_head_channels: bool = False,
config: typing.Union[str, typing.List] = "default",
config: typing.Literal["default"] = "default",
**kwargs,
):
_available_configs = ["default"]
2 changes: 1 addition & 1 deletion kimm/models/mobilenet_v3.py
@@ -86,7 +86,7 @@ def __init__(
width: float = 1.0,
depth: float = 1.0,
fix_stem_and_head_channels: bool = False,
config: typing.Union[str, typing.List] = "large",
config: typing.Literal["small", "large", "lcnet"] = "large",
minimal: bool = False,
**kwargs,
):
2 changes: 1 addition & 1 deletion kimm/models/mobilevit.py
@@ -166,7 +166,7 @@ def __init__(
self,
stem_channels: int = 16,
head_channels: int = 640,
activation="swish",
activation: str = "swish",
config: str = "v1_s",
**kwargs,
):
5 changes: 4 additions & 1 deletion kimm/models/resnet.py
@@ -106,7 +106,10 @@ def apply_bottleneck_block(
@keras.saving.register_keras_serializable(package="kimm")
class ResNet(BaseModel):
def __init__(
self, block_fn: str, num_blocks: typing.Sequence[int], **kwargs
self,
block_fn: typing.Literal["basic", "bottleneck"],
num_blocks: typing.Sequence[int],
**kwargs,
):
if block_fn not in ("basic", "bottleneck"):
raise ValueError(
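`block_fn` is narrowed from a plain `str` to `typing.Literal["basic", "bottleneck"]`, matching the existing runtime `ValueError` check. A hedged construction sketch; the stage depths below correspond to a ResNet-50-style layout, and passing `weights=None` to skip loading pretrained weights is an assumption about the `BaseModel` keyword arguments rather than something shown in this diff.

```python
from kimm.models.resnet import ResNet

model = ResNet(
    block_fn="bottleneck",    # typing.Literal["basic", "bottleneck"]
    num_blocks=[3, 4, 6, 3],  # typing.Sequence[int], ResNet-50-style stages
    weights=None,             # assumed kwarg; avoids downloading weights
)
# block_fn="bottleneck2" would be flagged by a type checker and still raises
# ValueError at runtime.
```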