diff --git a/conftest.py b/conftest.py index 0127de4..fce612b 100644 --- a/conftest.py +++ b/conftest.py @@ -1,7 +1,18 @@ import os +import pytest -def pytest_configure(): + +def pytest_addoption(parser): + parser.addoption( + "--run_serialization", + action="store_true", + default=False, + help="run serialization tests", + ) + + +def pytest_configure(config): import tensorflow as tf # disable tensorflow gpu memory preallocation @@ -12,3 +23,18 @@ def pytest_configure(): # disable jax gpu memory preallocation # https://jax.readthedocs.io/en/latest/gpu_memory_allocation.html os.environ["XLA_PYTHON_CLIENT_PREALLOCATE"] = "false" + + config.addinivalue_line( + "markers", "serialization: mark test as a serialization test" + ) + + +def pytest_collection_modifyitems(config, items): + run_serialization_tests = config.getoption("--run_serialization") + skip_serialization = pytest.mark.skipif( + not run_serialization_tests, + reason="need --run_serialization option to run", + ) + for item in items: + if "serialization" in item.name: + item.add_marker(skip_serialization) diff --git a/kimm/blocks/base_block.py b/kimm/blocks/base_block.py index 2a428c6..54e3672 100644 --- a/kimm/blocks/base_block.py +++ b/kimm/blocks/base_block.py @@ -34,6 +34,8 @@ def apply_conv2d_block( raise ValueError( f"kernel_size must be passed. 
Received: kernel_size={kernel_size}" ) + if isinstance(kernel_size, int): + kernel_size = [kernel_size, kernel_size] input_channels = inputs.shape[-1] has_skip = add_skip and strides == 1 and input_channels == filters x = inputs @@ -42,7 +44,9 @@ def apply_conv2d_block( padding = "same" if strides > 1: padding = "valid" - x = layers.ZeroPadding2D(kernel_size // 2, name=f"{name}_pad")(x) + x = layers.ZeroPadding2D( + (kernel_size[0] // 2, kernel_size[1] // 2), name=f"{name}_pad" + )(x) if not use_depthwise: x = layers.Conv2D( diff --git a/kimm/blocks/inverted_residual_block.py b/kimm/blocks/inverted_residual_block.py index 1d58267..1043a90 100644 --- a/kimm/blocks/inverted_residual_block.py +++ b/kimm/blocks/inverted_residual_block.py @@ -15,7 +15,7 @@ def apply_inverted_residual_block( expansion_ratio=1.0, se_ratio=0.0, activation="swish", - se_input_channels=None, + se_channels=None, se_activation=None, se_gate_activation="sigmoid", se_make_divisible_number=None, @@ -57,7 +57,7 @@ def apply_inverted_residual_block( se_ratio, activation=se_activation or activation, gate_activation=se_gate_activation, - se_input_channels=se_input_channels, + se_input_channels=se_channels, make_divisible_number=se_make_divisible_number, name=f"{name}_se", ) diff --git a/kimm/layers/attention.py b/kimm/layers/attention.py index 1797020..7eb90af 100644 --- a/kimm/layers/attention.py +++ b/kimm/layers/attention.py @@ -118,18 +118,3 @@ def get_config(self): } ) return config - - -if __name__ == "__main__": - from keras import models - from keras import random - - inputs = layers.Input(shape=[197, 768]) - outputs = Attention(768)(inputs) - - model = models.Model(inputs, outputs) - model.summary() - - inputs = random.uniform([1, 197, 768]) - outputs = model(inputs) - print(outputs.shape) diff --git a/kimm/layers/layer_scale.py b/kimm/layers/layer_scale.py index 9f030c7..0afce2d 100644 --- a/kimm/layers/layer_scale.py +++ b/kimm/layers/layer_scale.py @@ -35,18 +35,3 @@ def get_config(self): 
} ) return config - - -if __name__ == "__main__": - from keras import models - from keras import random - - inputs = layers.Input(shape=[197, 768]) - outputs = LayerScale(768)(inputs) - - model = models.Model(inputs, outputs) - model.summary() - - inputs = random.uniform([1, 197, 768]) - outputs = model(inputs) - print(outputs.shape) diff --git a/kimm/layers/position_embedding.py b/kimm/layers/position_embedding.py index 167b9ae..82670f3 100644 --- a/kimm/layers/position_embedding.py +++ b/kimm/layers/position_embedding.py @@ -38,25 +38,3 @@ def compute_output_shape(self, input_shape): def get_config(self): return super().get_config() - - -if __name__ == "__main__": - from keras import models - from keras import random - - inputs = layers.Input([224, 224, 3]) - x = layers.Conv2D( - 768, - 16, - 16, - use_bias=True, - )(inputs) - x = layers.Reshape((-1, 768))(x) - outputs = PositionEmbedding()(x) - - model = models.Model(inputs, outputs) - model.summary() - - inputs = random.uniform([1, 224, 224, 3]) - outputs = model(inputs) - print(outputs.shape) diff --git a/kimm/models/__init__.py b/kimm/models/__init__.py index 2b60641..6be22a2 100644 --- a/kimm/models/__init__.py +++ b/kimm/models/__init__.py @@ -1,5 +1,5 @@ +from kimm.models.base_model import BaseModel from kimm.models.efficientnet import * # noqa:F403 -from kimm.models.feature_extractor import FeatureExtractor from kimm.models.ghostnet import * # noqa:F403 from kimm.models.mobilenet_v2 import * # noqa:F403 from kimm.models.mobilenet_v3 import * # noqa:F403 diff --git a/kimm/models/base_model.py b/kimm/models/base_model.py new file mode 100644 index 0000000..45204fd --- /dev/null +++ b/kimm/models/base_model.py @@ -0,0 +1,157 @@ +import abc +import typing + +from keras import KerasTensor +from keras import backend +from keras import layers +from keras import models +from keras.src.applications import imagenet_utils + + +class BaseModel(models.Model): + def __init__( + self, + inputs, + outputs, + features: 
typing.Optional[typing.Dict[str, KerasTensor]] = None, + feature_keys: typing.Optional[typing.List[str]] = None, + **kwargs, + ): + self.as_feature_extractor = kwargs.pop("as_feature_extractor", False) + self.feature_keys = feature_keys + if self.as_feature_extractor: + if features is None: + raise ValueError( + "`features` must be set when " + f"`as_feature_extractor=True`. Got features={features}" + ) + if self.feature_keys is None: + self.feature_keys = list(features.keys()) + filtered_features = {} + for k in self.feature_keys: + if k not in features: + raise KeyError( + f"'{k}' is not a key of `features`. Available keys " + f"are: {list(features.keys())}" + ) + filtered_features[k] = features[k] + super().__init__(inputs=inputs, outputs=filtered_features, **kwargs) + else: + del features + super().__init__(inputs=inputs, outputs=outputs, **kwargs) + + def parse_kwargs( + self, kwargs: typing.Dict[str, typing.Any], default_size: int = 224 + ): + result = { + "input_tensor": kwargs.pop("input_tensor", None), + "input_shape": kwargs.pop("input_shape", None), + "include_preprocessing": kwargs.pop("include_preprocessing", True), + "include_top": kwargs.pop("include_top", True), + "pooling": kwargs.pop("pooling", None), + "dropout_rate": kwargs.pop("dropout_rate", 0.0), + "classes": kwargs.pop("classes", 1000), + "classifier_activation": kwargs.pop( + "classifier_activation", "softmax" + ), + "weights": kwargs.pop("weights", "imagenet"), + "default_size": kwargs.pop("default_size", default_size), + } + return result + + def determine_input_tensor( + self, + input_tensor=None, + input_shape=None, + default_size=224, + min_size=32, + require_flatten=False, + static_shape=False, + ): + """Determine the input tensor by the arguments.""" + input_shape = imagenet_utils.obtain_input_shape( + input_shape, + default_size=default_size, + min_size=min_size, + data_format="channels_last", # always channels_last + require_flatten=require_flatten or static_shape, + weights=None, 
+ ) + + if input_tensor is None: + x = layers.Input(shape=input_shape) + else: + if not backend.is_keras_tensor(input_tensor): + x = layers.Input(tensor=input_tensor, shape=input_shape) + else: + x = input_tensor + return x + + def build_preprocessing(self, inputs, mode="imagenet"): + if mode == "imagenet": + # [0, 255] to [0, 1] and apply ImageNet mean and variance + x = layers.Rescaling(scale=1.0 / 255.0)(inputs) + x = layers.Normalization( + mean=[0.485, 0.456, 0.406], variance=[0.229, 0.224, 0.225] + )(x) + elif mode == "0_1": + # [0, 255] to [0, 1] + x = layers.Rescaling(scale=1.0 / 255.0)(inputs) + elif mode == "-1_1": + # [0, 255] to [-1, 1] + x = layers.Rescaling(scale=1.0 / 127.5, offset=-1.0)(inputs) + else: + raise ValueError( + "`mode` must be one of ('imagenet', '0_1', '-1_1'). " + f"Received: mode={mode}" + ) + return x + + def build_top(self, inputs, classes, classifier_activation, dropout_rate): + x = layers.GlobalAveragePooling2D(name="avg_pool")(inputs) + x = layers.Dropout(rate=dropout_rate, name="head_dropout")(x) + x = layers.Dense( + classes, activation=classifier_activation, name="classifier" + )(x) + return x + + def add_references(self, parsed_kwargs: typing.Dict[str, typing.Any]): + self.include_preprocessing = parsed_kwargs["include_preprocessing"] + self.include_top = parsed_kwargs["include_top"] + self.pooling = parsed_kwargs["pooling"] + self.dropout_rate = parsed_kwargs["dropout_rate"] + self.classes = parsed_kwargs["classes"] + self.classifier_activation = parsed_kwargs["classifier_activation"] + # `self.weights` is already used internally + self._weights = parsed_kwargs["weights"] + + @staticmethod + @abc.abstractmethod + def available_feature_keys(): + # TODO: add docstring + raise NotImplementedError + + def get_config(self): + # Don't chain to super here. The default `get_config()` for functional + # models is nested and cannot be passed to BaseModel. 
+ config = { + # models.Model + "name": self.name, + "trainable": self.trainable, + # feature extractor + "as_feature_extractor": self.as_feature_extractor, + "feature_keys": self.feature_keys, + # common + "input_shape": self.input_shape[1:], + "include_preprocessing": self.include_preprocessing, + "include_top": self.include_top, + "pooling": self.pooling, + "dropout_rate": self.dropout_rate, + "classes": self.classes, + "classifier_activation": self.classifier_activation, + "weights": self._weights, + } + return config + + def fix_config(self, config: typing.Dict): + return config diff --git a/kimm/models/feature_extractor_test.py b/kimm/models/base_model_test.py similarity index 93% rename from kimm/models/feature_extractor_test.py rename to kimm/models/base_model_test.py index 0987340..a0977ba 100644 --- a/kimm/models/feature_extractor_test.py +++ b/kimm/models/base_model_test.py @@ -3,10 +3,10 @@ from keras import random from keras.src import testing -from kimm.models.feature_extractor import FeatureExtractor +from kimm.models.base_model import BaseModel -class SampleModel(FeatureExtractor): +class SampleModel(BaseModel): def __init__(self, **kwargs): inputs = layers.Input(shape=[224, 224, 3]) @@ -34,7 +34,7 @@ def get_config(self): return super().get_config() -class GhostNetTest(testing.TestCase, parameterized.TestCase): +class BaseModelTest(testing.TestCase, parameterized.TestCase): def test_feature_extractor(self): x = random.uniform([1, 224, 224, 3]) diff --git a/kimm/models/densenet.py b/kimm/models/densenet.py new file mode 100644 index 0000000..479a110 --- /dev/null +++ b/kimm/models/densenet.py @@ -0,0 +1,321 @@ +import typing + +import keras +from keras import layers +from keras import utils + +from kimm.blocks import apply_conv2d_block +from kimm.models import BaseModel +from kimm.utils import add_model_to_registry + + +def apply_dense_layer( + inputs, growth_rate, expansion_ratio=4.0, name="dense_layer" +): + x = inputs + x = 
layers.BatchNormalization( + momentum=0.9, epsilon=1e-5, name=f"{name}_norm1" + )(x) + x = layers.ReLU()(x) + x = apply_conv2d_block( + x, + int(growth_rate * expansion_ratio), + 1, + 1, + activation="relu", + name=f"{name}_conv1", + ) + x = layers.Conv2D( + growth_rate, 3, 1, padding="same", use_bias=False, name=f"{name}_conv2" + )(x) + return x + + +def apply_dense_block( + inputs, num_layers, growth_rate, expansion_ratio=4.0, name="dense_block" +): + x = inputs + + features = [x] + for i in range(num_layers): + new_features = layers.Concatenate()(features) + new_features = apply_dense_layer( + new_features, + growth_rate, + expansion_ratio, + name=f"{name}_denselayer{i + 1}", + ) + features.append(new_features) + x = layers.Concatenate()(features) + return x + + +def apply_dense_transition_block( + inputs, output_channels, name="dense_transition_block" +): + x = inputs + x = layers.BatchNormalization( + momentum=0.9, epsilon=1e-5, name=f"{name}_norm" + )(x) + x = layers.ReLU()(x) + x = layers.Conv2D( + output_channels, 1, 1, "same", use_bias=False, name=f"{name}_conv" + )(x) + x = layers.AveragePooling2D(2, 2, name=f"{name}_pool")(x) + return x + + +class DenseNet(BaseModel): + def __init__( + self, + growth_rate: float = 32, + num_blocks: typing.Sequence[int] = [6, 12, 24, 16], + **kwargs, + ): + parsed_kwargs = self.parse_kwargs(kwargs) + img_input = self.determine_input_tensor( + parsed_kwargs["input_tensor"], + parsed_kwargs["input_shape"], + parsed_kwargs["default_size"], + ) + x = img_input + + if parsed_kwargs["include_preprocessing"]: + x = self.build_preprocessing(x, "imagenet") + + # Prepare feature extraction + features = {} + + # Stem block + stem_channel = growth_rate * 2 + x = apply_conv2d_block( + x, stem_channel, 7, 2, activation="relu", name="features_conv0" + ) + x = layers.ZeroPadding2D(1, name="features_pad0")(x) + x = layers.MaxPooling2D(3, 2, name="features_pool0")(x) + features["STEM_S4"] = x + + # Blocks + current_stride = 4 + 
input_channels = stem_channel + for current_block_idx, num_layers in enumerate(num_blocks): + x = apply_dense_block( + x, + num_layers, + growth_rate, + expansion_ratio=4.0, + name=f"features_denseblock{current_block_idx + 1}", + ) + input_channels = input_channels + num_layers * growth_rate + if current_block_idx != len(num_blocks) - 1: + current_stride *= 2 + x = apply_dense_transition_block( + x, + input_channels // 2, + name=f"features_transition{current_block_idx + 1}", + ) + input_channels = input_channels // 2 + + features[f"BLOCK{current_block_idx}_S{current_stride}"] = x + + # Final batch norm + x = layers.BatchNormalization( + momentum=0.9, epsilon=1e-5, name="features_norm5" + )(x) + x = layers.ReLU()(x) + + # Head + if parsed_kwargs["include_top"]: + x = self.build_top( + x, + parsed_kwargs["classes"], + parsed_kwargs["classifier_activation"], + parsed_kwargs["dropout_rate"], + ) + else: + if parsed_kwargs["pooling"] == "avg": + x = layers.GlobalAveragePooling2D(name="avg_pool")(x) + elif parsed_kwargs["pooling"] == "max": + x = layers.GlobalMaxPooling2D(name="max_pool")(x) + + # Ensure that the model takes into account + # any potential predecessors of `input_tensor`. 
+ if parsed_kwargs["input_tensor"] is not None: + inputs = utils.get_source_inputs(parsed_kwargs["input_tensor"]) + else: + inputs = img_input + + super().__init__(inputs=inputs, outputs=x, features=features, **kwargs) + + # All references to `self` below this line + self.add_references(parsed_kwargs) + self.growth_rate = growth_rate + self.num_blocks = num_blocks + + @staticmethod + def available_feature_keys(): + feature_keys = ["STEM_S4"] + feature_keys.extend( + [f"BLOCK{i}_S{j}" for i, j in zip(range(4), [8, 16, 32, 32])] + ) + return feature_keys + + def get_config(self): + config = super().get_config() + config.update( + {"growth_rate": self.growth_rate, "num_blocks": self.num_blocks} + ) + return config + + def fix_config(self, config: typing.Dict): + unused_kwargs = ["growth_rate", "num_blocks"] + for k in unused_kwargs: + config.pop(k, None) + return config + + +""" +Model Definition +""" + + +class DenseNet121(DenseNet): + def __init__( + self, + input_tensor: keras.KerasTensor = None, + input_shape: typing.Optional[typing.Sequence[int]] = None, + include_preprocessing: bool = True, + include_top: bool = True, + pooling: typing.Optional[str] = None, + dropout_rate: float = 0.0, + classes: int = 1000, + classifier_activation: str = "softmax", + weights: typing.Optional[str] = None, # TODO: imagenet + name: str = "DenseNet121", + **kwargs, + ): + kwargs = self.fix_config(kwargs) + super().__init__( + 32, + [6, 12, 24, 16], + input_tensor=input_tensor, + input_shape=input_shape, + include_preprocessing=include_preprocessing, + include_top=include_top, + pooling=pooling, + dropout_rate=dropout_rate, + classes=classes, + classifier_activation=classifier_activation, + weights=weights, + name=name, + default_size=288, + **kwargs, + ) + + +class DenseNet161(DenseNet): + def __init__( + self, + input_tensor: keras.KerasTensor = None, + input_shape: typing.Optional[typing.Sequence[int]] = None, + include_preprocessing: bool = True, + include_top: bool = True, + 
pooling: typing.Optional[str] = None, + dropout_rate: float = 0.0, + classes: int = 1000, + classifier_activation: str = "softmax", + weights: typing.Optional[str] = None, # TODO: imagenet + name: str = "DenseNet161", + **kwargs, + ): + kwargs = self.fix_config(kwargs) + super().__init__( + 48, + [6, 12, 36, 24], + input_tensor=input_tensor, + input_shape=input_shape, + include_preprocessing=include_preprocessing, + include_top=include_top, + pooling=pooling, + dropout_rate=dropout_rate, + classes=classes, + classifier_activation=classifier_activation, + weights=weights, + name=name, + default_size=224, + **kwargs, + ) + + +class DenseNet169(DenseNet): + def __init__( + self, + input_tensor: keras.KerasTensor = None, + input_shape: typing.Optional[typing.Sequence[int]] = None, + include_preprocessing: bool = True, + include_top: bool = True, + pooling: typing.Optional[str] = None, + dropout_rate: float = 0.0, + classes: int = 1000, + classifier_activation: str = "softmax", + weights: typing.Optional[str] = None, # TODO: imagenet + name: str = "DenseNet169", + **kwargs, + ): + kwargs = self.fix_config(kwargs) + super().__init__( + 32, + [6, 12, 32, 32], + input_tensor=input_tensor, + input_shape=input_shape, + include_preprocessing=include_preprocessing, + include_top=include_top, + pooling=pooling, + dropout_rate=dropout_rate, + classes=classes, + classifier_activation=classifier_activation, + weights=weights, + name=name, + default_size=224, + **kwargs, + ) + + +class DenseNet201(DenseNet): + def __init__( + self, + input_tensor: keras.KerasTensor = None, + input_shape: typing.Optional[typing.Sequence[int]] = None, + include_preprocessing: bool = True, + include_top: bool = True, + pooling: typing.Optional[str] = None, + dropout_rate: float = 0.0, + classes: int = 1000, + classifier_activation: str = "softmax", + weights: typing.Optional[str] = None, # TODO: imagenet + name: str = "DenseNet201", + **kwargs, + ): + kwargs = self.fix_config(kwargs) + 
super().__init__( + 32, + [6, 12, 48, 32], + input_tensor=input_tensor, + input_shape=input_shape, + include_preprocessing=include_preprocessing, + include_top=include_top, + pooling=pooling, + dropout_rate=dropout_rate, + classes=classes, + classifier_activation=classifier_activation, + weights=weights, + name=name, + default_size=224, + **kwargs, + ) + + +add_model_to_registry(DenseNet121, True) +add_model_to_registry(DenseNet161, True) +add_model_to_registry(DenseNet169, True) +add_model_to_registry(DenseNet201, True) diff --git a/kimm/models/densenet_test.py b/kimm/models/densenet_test.py new file mode 100644 index 0000000..95c50b3 --- /dev/null +++ b/kimm/models/densenet_test.py @@ -0,0 +1,52 @@ +import pytest +from absl.testing import parameterized +from keras import models +from keras import random +from keras.src import testing + +from kimm.models.densenet import DenseNet121 + + +class DenseNetTest(testing.TestCase, parameterized.TestCase): + @parameterized.named_parameters([(DenseNet121.__name__, DenseNet121)]) + def test_densenet_base(self, model_class): + # TODO: test the correctness of the real image + x = random.uniform([1, 224, 224, 3]) * 255.0 + model = model_class(input_shape=[224, 224, 3]) + + y = model(x, training=False) + + self.assertEqual(y.shape, (1, 1000)) + + @parameterized.named_parameters([(DenseNet121.__name__, DenseNet121)]) + def test_densenet_feature_extractor(self, model_class): + x = random.uniform([1, 224, 224, 3]) * 255.0 + model = model_class( + input_shape=[224, 224, 3], as_feature_extractor=True + ) + + y = model(x, training=False) + + self.assertIsInstance(y, dict) + self.assertAllEqual( + list(y.keys()), model_class.available_feature_keys() + ) + self.assertEqual(list(y["STEM_S4"].shape), [1, 56, 56, 64]) + self.assertEqual(list(y["BLOCK0_S8"].shape), [1, 28, 28, 128]) + self.assertEqual(list(y["BLOCK1_S16"].shape), [1, 14, 14, 256]) + self.assertEqual(list(y["BLOCK2_S32"].shape), [1, 7, 7, 512]) + 
self.assertEqual(list(y["BLOCK3_S32"].shape), [1, 7, 7, 1024]) + + @pytest.mark.serialization + @parameterized.named_parameters([(DenseNet121.__name__, DenseNet121, 224)]) + def test_densenet_serialization(self, model_class, image_size): + x = random.uniform([1, image_size, image_size, 3]) * 255.0 + temp_dir = self.get_temp_dir() + model1 = model_class(input_shape=[224, 224, 3]) + y1 = model1(x, training=False) + model1.save(temp_dir + "/model.keras") + + model2 = models.load_model(temp_dir + "/model.keras") + y2 = model2(x, training=False) + + self.assertAllClose(y1, y2) diff --git a/kimm/models/efficientnet.py b/kimm/models/efficientnet.py index 562c71b..bc5c7dc 100644 --- a/kimm/models/efficientnet.py +++ b/kimm/models/efficientnet.py @@ -2,16 +2,13 @@ import typing import keras -from keras import backend from keras import layers from keras import utils -from keras.src.applications import imagenet_utils from kimm.blocks import apply_conv2d_block from kimm.blocks import apply_depthwise_separation_block from kimm.blocks import apply_inverted_residual_block -from kimm.blocks import apply_se_block -from kimm.models.feature_extractor import FeatureExtractor +from kimm.models import BaseModel from kimm.utils import add_model_to_registry from kimm.utils import make_divisible @@ -90,7 +87,6 @@ def apply_edge_residual_block( pointwise_kernel_size=1, strides=1, expansion_ratio=1.0, - se_ratio=0.0, activation="swish", bn_epsilon=1e-5, padding=None, @@ -112,16 +108,6 @@ def apply_edge_residual_block( padding=padding, name=f"{name}_conv_exp", ) - # Squeeze-and-excitation - if se_ratio > 0: - x = apply_se_block( - x, - se_ratio, - activation=activation, - gate_activation="sigmoid", - se_input_channels=input_channels, - name=f"{name}_se", - ) # Point-wise linear projection x = apply_conv2d_block( x, @@ -138,7 +124,7 @@ def apply_edge_residual_block( return x -class EfficientNet(FeatureExtractor): +class EfficientNet(BaseModel): def __init__( self, width: float = 1.0, @@ 
-148,15 +134,6 @@ def __init__( fix_stem_and_head_channels: bool = False, fix_first_and_last_blocks: bool = False, activation="swish", - input_tensor: keras.KerasTensor = None, - input_shape: typing.Optional[typing.Sequence[int]] = None, - include_preprocessing: bool = True, - include_top: bool = True, - pooling: typing.Optional[str] = None, - dropout_rate: float = 0.0, - classes: int = 1000, - classifier_activation: str = "softmax", - weights: typing.Optional[str] = None, # TODO: imagenet config: typing.Union[str, typing.List] = "v1", **kwargs, ): @@ -189,7 +166,6 @@ def __init__( f"Received: config={config}" ) # TF default config - default_size = kwargs.pop("default_size", 224) bn_epsilon = kwargs.pop("bn_epsilon", 1e-5) padding = kwargs.pop("padding", None) # EfficientNetV2Base config @@ -197,35 +173,19 @@ def __init__( # TinyNet config round_fn = kwargs.pop("round_fn", math.ceil) - # Prepare feature extraction - features = {} - - # Determine proper input shape - input_shape = imagenet_utils.obtain_input_shape( - input_shape, - default_size=default_size, - min_size=32, - data_format=backend.image_data_format(), - require_flatten=include_top, - weights=weights, + parsed_kwargs = self.parse_kwargs(kwargs) + img_input = self.determine_input_tensor( + parsed_kwargs["input_tensor"], + parsed_kwargs["input_shape"], + parsed_kwargs["default_size"], ) - - if input_tensor is None: - img_input = layers.Input(shape=input_shape) - else: - if not backend.is_keras_tensor(input_tensor): - img_input = layers.Input(tensor=input_tensor, shape=input_shape) - else: - img_input = input_tensor - x = img_input - # [0, 255] to [0, 1] and apply ImageNet mean and variance - if include_preprocessing: - x = layers.Rescaling(scale=1.0 / 255.0)(x) - x = layers.Normalization( - mean=[0.485, 0.456, 0.406], variance=[0.229, 0.224, 0.225] - )(x) + if parsed_kwargs["include_preprocessing"]: + x = self.build_preprocessing(x, "imagenet") + + # Prepare feature extraction + features = {} # Stem block 
stem_channel = ( @@ -258,51 +218,25 @@ def __init__( r = int(round_fn(r * depth)) for current_layer_idx in range(r): s = s if current_layer_idx == 0 else 1 - common_kwargs = { + _kwargs = { "bn_epsilon": bn_epsilon, "padding": padding, "name": f"blocks_{current_block_idx}_{current_layer_idx}", + "activation": activation, } if block_type == "ds": x = apply_depthwise_separation_block( - x, - c, - k, - 1, - s, - se, - activation=activation, - se_activation=activation, - **common_kwargs, + x, c, k, 1, s, se, se_activation=activation, **_kwargs ) elif block_type == "ir": + se_c = x.shape[-1] x = apply_inverted_residual_block( - x, - c, - k, - 1, - 1, - s, - e, - se, - activation, - se_input_channels=x.shape[-1], - **common_kwargs, + x, c, k, 1, 1, s, e, se, se_channels=se_c, **_kwargs ) elif block_type == "cn": - x = apply_conv2d_block( - x, - filters=c, - kernel_size=k, - strides=s, - activation=activation, - add_skip=True, - **common_kwargs, - ) + x = apply_conv2d_block(x, c, k, s, add_skip=True, **_kwargs) elif block_type == "er": - x = apply_edge_residual_block( - x, c, k, 1, s, e, se, activation, **common_kwargs - ) + x = apply_edge_residual_block(x, c, k, 1, s, e, **_kwargs) current_stride *= s features[f"BLOCK{current_block_idx}_S{current_stride}"] = x @@ -322,28 +256,30 @@ def __init__( ) # Head - if include_top: - x = layers.GlobalAveragePooling2D(name="avg_pool")(x) - x = layers.Dropout(rate=dropout_rate, name="conv_head_dropout")(x) - x = layers.Dense( - classes, activation=classifier_activation, name="classifier" - )(x) + if parsed_kwargs["include_top"]: + x = self.build_top( + x, + parsed_kwargs["classes"], + parsed_kwargs["classifier_activation"], + parsed_kwargs["dropout_rate"], + ) else: - if pooling == "avg": + if parsed_kwargs["pooling"] == "avg": x = layers.GlobalAveragePooling2D(name="avg_pool")(x) - elif pooling == "max": + elif parsed_kwargs["pooling"] == "max": x = layers.GlobalMaxPooling2D(name="max_pool")(x) # Ensure that the model takes into 
account # any potential predecessors of `input_tensor`. - if input_tensor is not None: - inputs = utils.get_source_inputs(input_tensor) + if parsed_kwargs["input_tensor"] is not None: + inputs = utils.get_source_inputs(parsed_kwargs["input_tensor"]) else: inputs = img_input super().__init__(inputs=inputs, outputs=x, features=features, **kwargs) # All references to `self` below this line + self.add_references(parsed_kwargs) self.width = width self.depth = depth self.stem_channels = stem_channels @@ -351,13 +287,6 @@ def __init__( self.fix_stem_and_head_channels = fix_stem_and_head_channels self.fix_first_and_last_blocks = fix_first_and_last_blocks self.activation = activation - self.include_preprocessing = include_preprocessing - self.include_top = include_top - self.pooling = pooling - self.dropout_rate = dropout_rate - self.classes = classes - self.classifier_activation = classifier_activation - self._weights = weights # `self.weights` is been used internally self.config = config @staticmethod @@ -384,14 +313,6 @@ def get_config(self): "fix_stem_and_head_channels": self.fix_stem_and_head_channels, "fix_first_and_last_blocks": self.fix_first_and_last_blocks, "activation": self.activation, - "input_shape": self.input_shape[1:], - "include_preprocessing": self.include_preprocessing, - "include_top": self.include_top, - "pooling": self.pooling, - "dropout_rate": self.dropout_rate, - "classes": self.classes, - "classifier_activation": self.classifier_activation, - "weights": self._weights, "config": self.config, } ) @@ -443,16 +364,16 @@ def __init__( False, False, "swish", - input_tensor, - input_shape, - include_preprocessing, - include_top, - pooling, - dropout_rate, - classes, - classifier_activation, - weights, config, + input_tensor=input_tensor, + input_shape=input_shape, + include_preprocessing=include_preprocessing, + include_top=include_top, + pooling=pooling, + dropout_rate=dropout_rate, + classes=classes, + classifier_activation=classifier_activation, + 
weights=weights, name=name, default_size=224, bn_epsilon=1e-3, @@ -487,16 +408,16 @@ def __init__( False, False, "swish", - input_tensor, - input_shape, - include_preprocessing, - include_top, - pooling, - dropout_rate, - classes, - classifier_activation, - weights, config, + input_tensor=input_tensor, + input_shape=input_shape, + include_preprocessing=include_preprocessing, + include_top=include_top, + pooling=pooling, + dropout_rate=dropout_rate, + classes=classes, + classifier_activation=classifier_activation, + weights=weights, name=name, default_size=240, bn_epsilon=1e-3, @@ -531,16 +452,16 @@ def __init__( False, False, "swish", - input_tensor, - input_shape, - include_preprocessing, - include_top, - pooling, - dropout_rate, - classes, - classifier_activation, - weights, config, + input_tensor=input_tensor, + input_shape=input_shape, + include_preprocessing=include_preprocessing, + include_top=include_top, + pooling=pooling, + dropout_rate=dropout_rate, + classes=classes, + classifier_activation=classifier_activation, + weights=weights, name=name, default_size=260, bn_epsilon=1e-3, @@ -575,16 +496,16 @@ def __init__( False, False, "swish", - input_tensor, - input_shape, - include_preprocessing, - include_top, - pooling, - dropout_rate, - classes, - classifier_activation, - weights, config, + input_tensor=input_tensor, + input_shape=input_shape, + include_preprocessing=include_preprocessing, + include_top=include_top, + pooling=pooling, + dropout_rate=dropout_rate, + classes=classes, + classifier_activation=classifier_activation, + weights=weights, name=name, default_size=300, bn_epsilon=1e-3, @@ -619,16 +540,16 @@ def __init__( False, False, "swish", - input_tensor, - input_shape, - include_preprocessing, - include_top, - pooling, - dropout_rate, - classes, - classifier_activation, - weights, config, + input_tensor=input_tensor, + input_shape=input_shape, + include_preprocessing=include_preprocessing, + include_top=include_top, + pooling=pooling, + 
dropout_rate=dropout_rate, + classes=classes, + classifier_activation=classifier_activation, + weights=weights, name=name, default_size=380, bn_epsilon=1e-3, @@ -663,16 +584,16 @@ def __init__( False, False, "swish", - input_tensor, - input_shape, - include_preprocessing, - include_top, - pooling, - dropout_rate, - classes, - classifier_activation, - weights, config, + input_tensor=input_tensor, + input_shape=input_shape, + include_preprocessing=include_preprocessing, + include_top=include_top, + pooling=pooling, + dropout_rate=dropout_rate, + classes=classes, + classifier_activation=classifier_activation, + weights=weights, name=name, default_size=456, bn_epsilon=1e-3, @@ -707,16 +628,16 @@ def __init__( False, False, "swish", - input_tensor, - input_shape, - include_preprocessing, - include_top, - pooling, - dropout_rate, - classes, - classifier_activation, - weights, config, + input_tensor=input_tensor, + input_shape=input_shape, + include_preprocessing=include_preprocessing, + include_top=include_top, + pooling=pooling, + dropout_rate=dropout_rate, + classes=classes, + classifier_activation=classifier_activation, + weights=weights, name=name, default_size=528, bn_epsilon=1e-3, @@ -751,16 +672,16 @@ def __init__( False, False, "swish", - input_tensor, - input_shape, - include_preprocessing, - include_top, - pooling, - dropout_rate, - classes, - classifier_activation, - weights, config, + input_tensor=input_tensor, + input_shape=input_shape, + include_preprocessing=include_preprocessing, + include_top=include_top, + pooling=pooling, + dropout_rate=dropout_rate, + classes=classes, + classifier_activation=classifier_activation, + weights=weights, name=name, default_size=600, bn_epsilon=1e-3, @@ -795,16 +716,16 @@ def __init__( True, True, "relu6", - input_tensor, - input_shape, - include_preprocessing, - include_top, - pooling, - dropout_rate, - classes, - classifier_activation, - weights, config, + input_tensor=input_tensor, + input_shape=input_shape, + 
include_preprocessing=include_preprocessing, + include_top=include_top, + pooling=pooling, + dropout_rate=dropout_rate, + classes=classes, + classifier_activation=classifier_activation, + weights=weights, name=name, default_size=224, bn_epsilon=1e-3, @@ -839,16 +760,16 @@ def __init__( True, True, "relu6", - input_tensor, - input_shape, - include_preprocessing, - include_top, - pooling, - dropout_rate, - classes, - classifier_activation, - weights, config, + input_tensor=input_tensor, + input_shape=input_shape, + include_preprocessing=include_preprocessing, + include_top=include_top, + pooling=pooling, + dropout_rate=dropout_rate, + classes=classes, + classifier_activation=classifier_activation, + weights=weights, name=name, default_size=240, bn_epsilon=1e-3, @@ -883,16 +804,16 @@ def __init__( True, True, "relu6", - input_tensor, - input_shape, - include_preprocessing, - include_top, - pooling, - dropout_rate, - classes, - classifier_activation, - weights, config, + input_tensor=input_tensor, + input_shape=input_shape, + include_preprocessing=include_preprocessing, + include_top=include_top, + pooling=pooling, + dropout_rate=dropout_rate, + classes=classes, + classifier_activation=classifier_activation, + weights=weights, name=name, default_size=260, bn_epsilon=1e-3, @@ -927,16 +848,16 @@ def __init__( True, True, "relu6", - input_tensor, - input_shape, - include_preprocessing, - include_top, - pooling, - dropout_rate, - classes, - classifier_activation, - weights, config, + input_tensor=input_tensor, + input_shape=input_shape, + include_preprocessing=include_preprocessing, + include_top=include_top, + pooling=pooling, + dropout_rate=dropout_rate, + classes=classes, + classifier_activation=classifier_activation, + weights=weights, name=name, default_size=300, bn_epsilon=1e-3, @@ -971,16 +892,16 @@ def __init__( True, True, "relu6", - input_tensor, - input_shape, - include_preprocessing, - include_top, - pooling, - dropout_rate, - classes, - classifier_activation, 
- weights, config, + input_tensor=input_tensor, + input_shape=input_shape, + include_preprocessing=include_preprocessing, + include_top=include_top, + pooling=pooling, + dropout_rate=dropout_rate, + classes=classes, + classifier_activation=classifier_activation, + weights=weights, name=name, default_size=380, bn_epsilon=1e-3, @@ -1015,16 +936,16 @@ def __init__( False, False, "swish", - input_tensor, - input_shape, - include_preprocessing, - include_top, - pooling, - dropout_rate, - classes, - classifier_activation, - weights, config, + input_tensor=input_tensor, + input_shape=input_shape, + include_preprocessing=include_preprocessing, + include_top=include_top, + pooling=pooling, + dropout_rate=dropout_rate, + classes=classes, + classifier_activation=classifier_activation, + weights=weights, name=name, default_size=300, bn_epsilon=1e-3, @@ -1067,16 +988,16 @@ def __init__( False, False, "swish", - input_tensor, - input_shape, - include_preprocessing, - include_top, - pooling, - dropout_rate, - classes, - classifier_activation, - weights, config, + input_tensor=input_tensor, + input_shape=input_shape, + include_preprocessing=include_preprocessing, + include_top=include_top, + pooling=pooling, + dropout_rate=dropout_rate, + classes=classes, + classifier_activation=classifier_activation, + weights=weights, name=name, default_size=384, bn_epsilon=1e-3, @@ -1111,16 +1032,16 @@ def __init__( False, False, "swish", - input_tensor, - input_shape, - include_preprocessing, - include_top, - pooling, - dropout_rate, - classes, - classifier_activation, - weights, config, + input_tensor=input_tensor, + input_shape=input_shape, + include_preprocessing=include_preprocessing, + include_top=include_top, + pooling=pooling, + dropout_rate=dropout_rate, + classes=classes, + classifier_activation=classifier_activation, + weights=weights, name=name, default_size=384, bn_epsilon=1e-3, @@ -1155,16 +1076,16 @@ def __init__( False, False, "swish", - input_tensor, - input_shape, - 
include_preprocessing, - include_top, - pooling, - dropout_rate, - classes, - classifier_activation, - weights, config, + input_tensor=input_tensor, + input_shape=input_shape, + include_preprocessing=include_preprocessing, + include_top=include_top, + pooling=pooling, + dropout_rate=dropout_rate, + classes=classes, + classifier_activation=classifier_activation, + weights=weights, name=name, default_size=384, bn_epsilon=1e-3, @@ -1199,16 +1120,16 @@ def __init__( True, False, "swish", - input_tensor, - input_shape, - include_preprocessing, - include_top, - pooling, - dropout_rate, - classes, - classifier_activation, - weights, config, + input_tensor=input_tensor, + input_shape=input_shape, + include_preprocessing=include_preprocessing, + include_top=include_top, + pooling=pooling, + dropout_rate=dropout_rate, + classes=classes, + classifier_activation=classifier_activation, + weights=weights, name=name, default_size=192, bn_epsilon=1e-3, @@ -1251,16 +1172,16 @@ def __init__( True, False, "swish", - input_tensor, - input_shape, - include_preprocessing, - include_top, - pooling, - dropout_rate, - classes, - classifier_activation, - weights, config, + input_tensor=input_tensor, + input_shape=input_shape, + include_preprocessing=include_preprocessing, + include_top=include_top, + pooling=pooling, + dropout_rate=dropout_rate, + classes=classes, + classifier_activation=classifier_activation, + weights=weights, name=name, default_size=192, bn_epsilon=1e-3, @@ -1303,16 +1224,16 @@ def __init__( True, False, "swish", - input_tensor, - input_shape, - include_preprocessing, - include_top, - pooling, - dropout_rate, - classes, - classifier_activation, - weights, config, + input_tensor=input_tensor, + input_shape=input_shape, + include_preprocessing=include_preprocessing, + include_top=include_top, + pooling=pooling, + dropout_rate=dropout_rate, + classes=classes, + classifier_activation=classifier_activation, + weights=weights, name=name, default_size=208, bn_epsilon=1e-3, @@ 
-1356,16 +1277,16 @@ def __init__( True, False, "swish", - input_tensor, - input_shape, - include_preprocessing, - include_top, - pooling, - dropout_rate, - classes, - classifier_activation, - weights, config, + input_tensor=input_tensor, + input_shape=input_shape, + include_preprocessing=include_preprocessing, + include_top=include_top, + pooling=pooling, + dropout_rate=dropout_rate, + classes=classes, + classifier_activation=classifier_activation, + weights=weights, name=name, default_size=240, bn_epsilon=1e-3, @@ -1408,16 +1329,16 @@ def __init__( False, False, "swish", - input_tensor, - input_shape, - include_preprocessing, - include_top, - pooling, - dropout_rate, - classes, - classifier_activation, - weights, config, + input_tensor=input_tensor, + input_shape=input_shape, + include_preprocessing=include_preprocessing, + include_top=include_top, + pooling=pooling, + dropout_rate=dropout_rate, + classes=classes, + classifier_activation=classifier_activation, + weights=weights, name=name, default_size=192, round_fn=round, # tinynet config @@ -1450,16 +1371,16 @@ def __init__( True, False, "swish", - input_tensor, - input_shape, - include_preprocessing, - include_top, - pooling, - dropout_rate, - classes, - classifier_activation, - weights, config, + input_tensor=input_tensor, + input_shape=input_shape, + include_preprocessing=include_preprocessing, + include_top=include_top, + pooling=pooling, + dropout_rate=dropout_rate, + classes=classes, + classifier_activation=classifier_activation, + weights=weights, name=name, default_size=192, round_fn=round, # tinynet config @@ -1492,16 +1413,16 @@ def __init__( True, False, "swish", - input_tensor, - input_shape, - include_preprocessing, - include_top, - pooling, - dropout_rate, - classes, - classifier_activation, - weights, config, + input_tensor=input_tensor, + input_shape=input_shape, + include_preprocessing=include_preprocessing, + include_top=include_top, + pooling=pooling, + dropout_rate=dropout_rate, + 
classes=classes, + classifier_activation=classifier_activation, + weights=weights, name=name, default_size=188, round_fn=round, # tinynet config @@ -1534,16 +1455,16 @@ def __init__( True, False, "swish", - input_tensor, - input_shape, - include_preprocessing, - include_top, - pooling, - dropout_rate, - classes, - classifier_activation, - weights, config, + input_tensor=input_tensor, + input_shape=input_shape, + include_preprocessing=include_preprocessing, + include_top=include_top, + pooling=pooling, + dropout_rate=dropout_rate, + classes=classes, + classifier_activation=classifier_activation, + weights=weights, name=name, default_size=152, round_fn=round, # tinynet config @@ -1576,16 +1497,16 @@ def __init__( True, False, "swish", - input_tensor, - input_shape, - include_preprocessing, - include_top, - pooling, - dropout_rate, - classes, - classifier_activation, - weights, config, + input_tensor=input_tensor, + input_shape=input_shape, + include_preprocessing=include_preprocessing, + include_top=include_top, + pooling=pooling, + dropout_rate=dropout_rate, + classes=classes, + classifier_activation=classifier_activation, + weights=weights, name=name, default_size=106, round_fn=round, # tinynet config diff --git a/kimm/models/efficientnet_test.py b/kimm/models/efficientnet_test.py index 10837a9..570c4ed 100644 --- a/kimm/models/efficientnet_test.py +++ b/kimm/models/efficientnet_test.py @@ -1,3 +1,4 @@ +import pytest from absl.testing import parameterized from keras import models from keras import random @@ -123,6 +124,7 @@ def test_efficentnet_v2_feature_extractor(self, model_class, width): [1, 7, 7, make_divisible(192 * width)], ) + @pytest.mark.serialization @parameterized.named_parameters( [ (EfficientNetB0.__name__, EfficientNetB0, 224), diff --git a/kimm/models/feature_extractor.py b/kimm/models/feature_extractor.py deleted file mode 100644 index 827a5d6..0000000 --- a/kimm/models/feature_extractor.py +++ /dev/null @@ -1,54 +0,0 @@ -import abc -import typing 
- -from keras import KerasTensor -from keras import models - - -class FeatureExtractor(models.Model): - @staticmethod - @abc.abstractmethod - def available_feature_keys(): - return [] - - def __init__( - self, - inputs, - outputs, - features: typing.Optional[typing.Dict[str, KerasTensor]] = None, - feature_keys: typing.Optional[typing.List[str]] = None, - **kwargs, - ): - self.as_feature_extractor = kwargs.pop("as_feature_extractor", False) - self.feature_keys = feature_keys - if self.as_feature_extractor: - if features is None: - raise ValueError( - "`features` must be set when " - f"`as_feature_extractor=True`. Got features={features}" - ) - if self.feature_keys is None: - self.feature_keys = list(features.keys()) - filtered_features = {} - for k in self.feature_keys: - if k not in features: - raise KeyError( - f"'{k}' is not a key of `features`. Available keys " - f"are: {list(features.keys())}" - ) - filtered_features[k] = features[k] - super().__init__(inputs=inputs, outputs=filtered_features, **kwargs) - else: - del features - super().__init__(inputs=inputs, outputs=outputs, **kwargs) - - def get_config(self): - # Don't chain to super here. The default `get_config()` for functional - # models is nested and cannot be passed to FeatureExtractor. 
- config = { - "name": self.name, - "trainable": self.trainable, - "as_feature_extractor": self.as_feature_extractor, - "feature_keys": self.feature_keys, - } - return config diff --git a/kimm/models/ghostnet.py b/kimm/models/ghostnet.py index 072b23e..4b0b89b 100644 --- a/kimm/models/ghostnet.py +++ b/kimm/models/ghostnet.py @@ -1,15 +1,13 @@ import typing import keras -from keras import backend from keras import layers from keras import ops from keras import utils -from keras.src.applications import imagenet_utils from kimm.blocks import apply_conv2d_block from kimm.blocks import apply_se_block -from kimm.models.feature_extractor import FeatureExtractor +from kimm.models import BaseModel from kimm.utils import add_model_to_registry from kimm.utils import make_divisible @@ -231,19 +229,10 @@ def apply_ghost_bottleneck( return out -class GhostNet(FeatureExtractor): +class GhostNet(BaseModel): def __init__( self, width: float = 1.0, - input_tensor: keras.KerasTensor = None, - input_shape: typing.Optional[typing.Sequence[int]] = None, - include_preprocessing: bool = True, - include_top: bool = True, - pooling: typing.Optional[str] = None, - dropout_rate: float = 0.2, - classes: int = 1000, - classifier_activation: str = "softmax", - weights: typing.Optional[str] = None, # TODO: imagenet config: typing.Union[str, typing.List] = "default", version: str = "v1", **kwargs, @@ -262,35 +251,21 @@ def __init__( f"Received version={version}" ) - # Prepare feature extraction - features = {} - - # Determine proper input shape - input_shape = imagenet_utils.obtain_input_shape( - input_shape, - default_size=224, - min_size=32, - data_format=backend.image_data_format(), - require_flatten=include_top, - weights=weights, + parsed_kwargs = self.parse_kwargs(kwargs) + img_input = self.determine_input_tensor( + parsed_kwargs["input_tensor"], + parsed_kwargs["input_shape"], + parsed_kwargs["default_size"], + require_flatten=parsed_kwargs["include_top"], + static_shape=True if version == 
"v2" else False, ) - - if input_tensor is None: - img_input = layers.Input(shape=input_shape) - else: - if not backend.is_keras_tensor(input_tensor): - img_input = layers.Input(tensor=input_tensor, shape=input_shape) - else: - img_input = input_tensor - x = img_input - # [0, 255] to [0, 1] and apply ImageNet mean and variance - if include_preprocessing: - x = layers.Rescaling(scale=1.0 / 255.0)(x) - x = layers.Normalization( - mean=[0.485, 0.456, 0.406], variance=[0.229, 0.224, 0.225] - )(x) + if parsed_kwargs["include_preprocessing"]: + x = self.build_preprocessing(x, "imagenet") + + # Prepare feature extraction + features = {} # stem stem_channels = make_divisible(16 * width, 4) @@ -333,42 +308,49 @@ def __init__( name=f"blocks_{current_block_idx+1}", ) - if include_top: - x = layers.GlobalAveragePooling2D(name="avg_pool", keepdims=True)(x) - x = layers.Conv2D(1280, 1, 1, use_bias=True, name="conv_head")(x) - x = layers.ReLU(name="conv_head_relu")(x) - x = layers.Flatten()(x) - x = layers.Dropout(rate=dropout_rate, name="conv_head_dropout")(x) - x = layers.Dense( - classes, activation=classifier_activation, name="classifier" - )(x) + # Head + if parsed_kwargs["include_top"]: + x = self.build_top( + x, + parsed_kwargs["classes"], + parsed_kwargs["classifier_activation"], + parsed_kwargs["dropout_rate"], + ) else: - if pooling == "avg": + if parsed_kwargs["pooling"] == "avg": x = layers.GlobalAveragePooling2D(name="avg_pool")(x) - elif pooling == "max": + elif parsed_kwargs["pooling"] == "max": x = layers.GlobalMaxPooling2D(name="max_pool")(x) # Ensure that the model takes into account # any potential predecessors of `input_tensor`. 
- if input_tensor is not None: - inputs = utils.get_source_inputs(input_tensor) + if parsed_kwargs["input_tensor"] is not None: + inputs = utils.get_source_inputs(parsed_kwargs["input_tensor"]) else: inputs = img_input super().__init__(inputs=inputs, outputs=x, features=features, **kwargs) # All references to `self` below this line + self.add_references(parsed_kwargs) self.width = width - self.include_preprocessing = include_preprocessing - self.include_top = include_top - self.pooling = pooling - self.dropout_rate = dropout_rate - self.classes = classes - self.classifier_activation = classifier_activation - self._weights = weights # `self.weights` is been used internally self.config = config self.version = version + def build_top(self, inputs, classes, classifier_activation, dropout_rate): + x = layers.GlobalAveragePooling2D(name="avg_pool", keepdims=True)( + inputs + ) + x = layers.Conv2D( + 1280, 1, 1, use_bias=True, activation="relu", name="conv_head" + )(x) + x = layers.Flatten()(x) + x = layers.Dropout(rate=dropout_rate, name="conv_head_dropout")(x) + x = layers.Dense( + classes, activation=classifier_activation, name="classifier" + )(x) + return x + @staticmethod def available_feature_keys(): feature_keys = ["STEM_S2"] @@ -385,14 +367,6 @@ def get_config(self): config.update( { "width": self.width, - "input_shape": self.input_shape[1:], - "include_preprocessing": self.include_preprocessing, - "include_top": self.include_top, - "pooling": self.pooling, - "dropout_rate": self.dropout_rate, - "classes": self.classes, - "classifier_activation": self.classifier_activation, - "weights": self._weights, "config": self.config, "version": self.version, } @@ -400,7 +374,7 @@ def get_config(self): return config def fix_config(self, config): - unused_kwargs = ["width", "version"] + unused_kwargs = ["width", "config", "version"] for k in unused_kwargs: config.pop(k, None) return config @@ -430,17 +404,17 @@ def __init__( kwargs = self.fix_config(kwargs) super().__init__( 
0.5, - input_tensor, - input_shape, - include_preprocessing, - include_top, - pooling, - dropout_rate, - classes, - classifier_activation, - weights, config, "v1", + input_tensor=input_tensor, + input_shape=input_shape, + include_preprocessing=include_preprocessing, + include_top=include_top, + pooling=pooling, + dropout_rate=dropout_rate, + classes=classes, + classifier_activation=classifier_activation, + weights=weights, name=name, **kwargs, ) @@ -465,17 +439,17 @@ def __init__( kwargs = self.fix_config(kwargs) super().__init__( 1.0, - input_tensor, - input_shape, - include_preprocessing, - include_top, - pooling, - dropout_rate, - classes, - classifier_activation, - weights, config, "v1", + input_tensor=input_tensor, + input_shape=input_shape, + include_preprocessing=include_preprocessing, + include_top=include_top, + pooling=pooling, + dropout_rate=dropout_rate, + classes=classes, + classifier_activation=classifier_activation, + weights=weights, name=name, **kwargs, ) @@ -500,17 +474,17 @@ def __init__( kwargs = self.fix_config(kwargs) super().__init__( 1.3, - input_tensor, - input_shape, - include_preprocessing, - include_top, - pooling, - dropout_rate, - classes, - classifier_activation, - weights, config, "v1", + input_tensor=input_tensor, + input_shape=input_shape, + include_preprocessing=include_preprocessing, + include_top=include_top, + pooling=pooling, + dropout_rate=dropout_rate, + classes=classes, + classifier_activation=classifier_activation, + weights=weights, name=name, **kwargs, ) @@ -535,17 +509,17 @@ def __init__( kwargs = self.fix_config(kwargs) super().__init__( 1.0, - input_tensor, - input_shape, - include_preprocessing, - include_top, - pooling, - dropout_rate, - classes, - classifier_activation, - weights, config, "v2", + input_tensor=input_tensor, + input_shape=input_shape, + include_preprocessing=include_preprocessing, + include_top=include_top, + pooling=pooling, + dropout_rate=dropout_rate, + classes=classes, + 
classifier_activation=classifier_activation, + weights=weights, name=name, **kwargs, ) @@ -570,17 +544,17 @@ def __init__( kwargs = self.fix_config(kwargs) super().__init__( 1.3, - input_tensor, - input_shape, - include_preprocessing, - include_top, - pooling, - dropout_rate, - classes, - classifier_activation, - weights, config, "v2", + input_tensor=input_tensor, + input_shape=input_shape, + include_preprocessing=include_preprocessing, + include_top=include_top, + pooling=pooling, + dropout_rate=dropout_rate, + classes=classes, + classifier_activation=classifier_activation, + weights=weights, name=name, **kwargs, ) @@ -605,17 +579,17 @@ def __init__( kwargs = self.fix_config(kwargs) super().__init__( 1.6, - input_tensor, - input_shape, - include_preprocessing, - include_top, - pooling, - dropout_rate, - classes, - classifier_activation, - weights, config, "v2", + input_tensor=input_tensor, + input_shape=input_shape, + include_preprocessing=include_preprocessing, + include_top=include_top, + pooling=pooling, + dropout_rate=dropout_rate, + classes=classes, + classifier_activation=classifier_activation, + weights=weights, name=name, **kwargs, ) diff --git a/kimm/models/ghostnet_test.py b/kimm/models/ghostnet_test.py index c800c8e..6b3881f 100644 --- a/kimm/models/ghostnet_test.py +++ b/kimm/models/ghostnet_test.py @@ -1,3 +1,4 @@ +import pytest from absl.testing import parameterized from keras import models from keras import random @@ -62,6 +63,7 @@ def test_ghostnetv2_feature_extractor(self, model_class): self.assertEqual(list(y["BLOCK5_S16"].shape), [1, 14, 14, 80]) self.assertEqual(list(y["BLOCK7_S32"].shape), [1, 7, 7, 160]) + @pytest.mark.serialization @parameterized.named_parameters( [ (GhostNet100.__name__, GhostNet100, 224), diff --git a/kimm/models/inception_v3.py b/kimm/models/inception_v3.py new file mode 100644 index 0000000..b37428d --- /dev/null +++ b/kimm/models/inception_v3.py @@ -0,0 +1,336 @@ +import functools +import typing + +import keras +from 
keras import layers +from keras import utils + +from kimm.blocks import apply_conv2d_block +from kimm.models import BaseModel +from kimm.utils import add_model_to_registry + +_apply_conv2d_block = functools.partial( + apply_conv2d_block, activation="relu", bn_epsilon=1e-3, padding="valid" +) + + +def apply_inception_a_block(inputs, pool_channels, name="inception_a_block"): + x = inputs + + branch1x1 = _apply_conv2d_block(x, 64, 1, 1, name=f"{name}_branch1x1") + + branch5x5 = _apply_conv2d_block(x, 48, 1, 1, name=f"{name}_branch5x5_1") + branch5x5 = _apply_conv2d_block( + branch5x5, 64, 5, 1, padding=None, name=f"{name}_branch5x5_2" + ) + + branch3x3dbl = _apply_conv2d_block( + x, 64, 1, 1, name=f"{name}_branch3x3dbl_1" + ) + branch3x3dbl = _apply_conv2d_block( + branch3x3dbl, 96, 3, 1, padding=None, name=f"{name}_branch3x3dbl_2" + ) + branch3x3dbl = _apply_conv2d_block( + branch3x3dbl, 96, 3, 1, padding=None, name=f"{name}_branch3x3dbl_3" + ) + + branch_pool = layers.ZeroPadding2D(1)(x) + branch_pool = layers.AveragePooling2D(3, 1)(branch_pool) + branch_pool = _apply_conv2d_block( + branch_pool, + pool_channels, + 1, + 1, + activation="relu", + name=f"{name}_branch_pool", + ) + x = layers.Concatenate()([branch1x1, branch5x5, branch3x3dbl, branch_pool]) + return x + + +def apply_inception_b_block(inputs, name="incpetion_b_block"): + x = inputs + + branch3x3 = _apply_conv2d_block(x, 384, 3, 2, name=f"{name}_branch3x3") + + branch3x3dbl = _apply_conv2d_block( + x, 64, 1, 1, name=f"{name}_branch3x3dbl_1" + ) + branch3x3dbl = _apply_conv2d_block( + branch3x3dbl, 96, 3, 1, padding=None, name=f"{name}_branch3x3dbl_2" + ) + branch3x3dbl = _apply_conv2d_block( + branch3x3dbl, 96, 3, 2, name=f"{name}_branch3x3dbl_3" + ) + + branch_pool = layers.MaxPooling2D(3, 2, name=f"{name}_branch_pool")(x) + x = layers.Concatenate()([branch3x3, branch3x3dbl, branch_pool]) + return x + + +def apply_inception_c_block( + inputs, branch7x7_channels, name="inception_c_block" +): + c7 = 
branch7x7_channels + x = inputs + + branch1x1 = _apply_conv2d_block(x, 192, 1, 1, name=f"{name}_branch1x1") + + branch7x7 = _apply_conv2d_block(x, c7, 1, 1, name=f"{name}_branch7x7_1") + branch7x7 = _apply_conv2d_block( + branch7x7, c7, (1, 7), 1, padding=None, name=f"{name}_branch7x7_2" + ) + branch7x7 = _apply_conv2d_block( + branch7x7, 192, (7, 1), 1, padding=None, name=f"{name}_branch7x7_3" + ) + + branch7x7dbl = _apply_conv2d_block( + x, c7, 1, 1, name=f"{name}_branch7x7dbl_1" + ) + branch7x7dbl = _apply_conv2d_block( + branch7x7dbl, c7, (7, 1), 1, padding=None, name=f"{name}_branch7x7dbl_2" + ) + branch7x7dbl = _apply_conv2d_block( + branch7x7dbl, c7, (1, 7), 1, padding=None, name=f"{name}_branch7x7dbl_3" + ) + branch7x7dbl = _apply_conv2d_block( + branch7x7dbl, c7, (7, 1), 1, padding=None, name=f"{name}_branch7x7dbl_4" + ) + branch7x7dbl = _apply_conv2d_block( + branch7x7dbl, + 192, + (1, 7), + 1, + padding=None, + name=f"{name}_branch7x7dbl_5", + ) + + branch_pool = layers.ZeroPadding2D(1)(x) + branch_pool = layers.AveragePooling2D(3, 1)(branch_pool) + branch_pool = _apply_conv2d_block( + branch_pool, 192, 1, 1, name=f"{name}_branch_pool" + ) + x = layers.Concatenate()([branch1x1, branch7x7, branch7x7dbl, branch_pool]) + return x + + +def apply_inception_d_block(inputs, name="inception_d_block"): + x = inputs + + branch3x3 = _apply_conv2d_block(x, 192, 1, 1, name=f"{name}_branch3x3_1") + branch3x3 = _apply_conv2d_block( + branch3x3, 320, 3, 2, name=f"{name}_branch3x3_2" + ) + + branch7x7x3 = _apply_conv2d_block( + x, 192, 1, 1, name=f"{name}_branch7x7x3_1" + ) + branch7x7x3 = _apply_conv2d_block( + branch7x7x3, 192, (1, 7), 1, padding=None, name=f"{name}_branch7x7x3_2" + ) + branch7x7x3 = _apply_conv2d_block( + branch7x7x3, 192, (7, 1), 1, padding=None, name=f"{name}_branch7x7x3_3" + ) + branch7x7x3 = _apply_conv2d_block( + branch7x7x3, 192, 3, 2, name=f"{name}_branch7x7x3_4" + ) + + branch_pool = layers.MaxPooling2D(3, 2)(x) + x = 
layers.Concatenate()([branch3x3, branch7x7x3, branch_pool]) + return x + + +def apply_inception_e_block(inputs, name="inception_e_block"): + x = inputs + + branch1x1 = _apply_conv2d_block(x, 320, 1, 1, name=f"{name}_branch1x1") + + branch3x3 = _apply_conv2d_block(x, 384, 1, 1, name=f"{name}_branch3x3_1") + branch3x3 = [ + _apply_conv2d_block( + branch3x3, 384, (1, 3), 1, padding=None, name=f"{name}_branch3x3_2a" + ), + _apply_conv2d_block( + branch3x3, 384, (3, 1), 1, padding=None, name=f"{name}_branch3x3_2b" + ), + ] + branch3x3 = layers.Concatenate()(branch3x3) + + branch3x3dbl = _apply_conv2d_block( + x, 448, 1, 1, name=f"{name}_branch3x3dbl_1" + ) + branch3x3dbl = _apply_conv2d_block( + branch3x3dbl, 384, 3, 1, padding=None, name=f"{name}_branch3x3dbl_2" + ) + branch3x3dbl = [ + _apply_conv2d_block( + branch3x3dbl, + 384, + (1, 3), + 1, + padding=None, + name=f"{name}_branch3x3dbl_3a", + ), + _apply_conv2d_block( + branch3x3dbl, + 384, + (3, 1), + 1, + padding=None, + name=f"{name}_branch3x3dbl_3b", + ), + ] + branch3x3dbl = layers.Concatenate()(branch3x3dbl) + + branch_pool = layers.ZeroPadding2D(1)(x) + branch_pool = layers.AveragePooling2D(3, 1)(branch_pool) + branch_pool = _apply_conv2d_block( + branch_pool, 192, 1, 1, name=f"{name}_branch_pool" + ) + x = layers.Concatenate()([branch1x1, branch3x3, branch3x3dbl, branch_pool]) + return x + + +def apply_inception_aux_block(inputs, classes, name="inception_aux_block"): + x = inputs + + x = layers.AveragePooling2D(5, 3)(x) + x = _apply_conv2d_block(x, 128, 1, 1, name=f"{name}_conv0") + x = _apply_conv2d_block(x, 768, 5, 1, name=f"{name}_conv1") + x = layers.GlobalAveragePooling2D()(x) + x = layers.Dense(classes, use_bias=True, name=f"{name}_fc")(x) + return x + + +class InceptionV3Base(BaseModel): + def __init__(self, has_aux_logits=False, **kwargs): + parsed_kwargs = self.parse_kwargs(kwargs, default_size=299) + img_input = self.determine_input_tensor( + parsed_kwargs["input_tensor"], + 
parsed_kwargs["input_shape"], + parsed_kwargs["default_size"], + require_flatten=parsed_kwargs["include_top"], + ) + x = img_input + + if parsed_kwargs["include_preprocessing"]: + x = self.build_preprocessing(x, "imagenet") + + # Prepare feature extraction + features = {} + + # Stem block + x = _apply_conv2d_block(x, 32, 3, 2, name="Conv2d_1a_3x3") + x = _apply_conv2d_block(x, 32, 3, 1, name="Conv2d_2a_3x3") + x = _apply_conv2d_block(x, 64, 3, 1, padding=None, name="Conv2d_2b_3x3") + features["STEM_S2"] = x + + # Blocks + x = layers.MaxPooling2D(3, 2, name="Pool1")(x) + x = _apply_conv2d_block(x, 80, 1, 1, name="Conv2d_3b_1x1") + x = _apply_conv2d_block(x, 192, 3, 1, name="Conv2d_4a_3x3") + features["BLOCK0_S4"] = x + x = layers.MaxPooling2D(3, 2, name="Pool2")(x) + x = apply_inception_a_block(x, 32, "Mixed_5b") + x = apply_inception_a_block(x, 64, "Mixed_5c") + x = apply_inception_a_block(x, 64, "Mixed_5d") + features["BLOCK1_S8"] = x + + x = apply_inception_b_block(x, "Mixed_6a") + + x = apply_inception_c_block(x, 128, "Mixed_6b") + x = apply_inception_c_block(x, 160, "Mixed_6c") + x = apply_inception_c_block(x, 160, "Mixed_6d") + x = apply_inception_c_block(x, 192, "Mixed_6e") + features["BLOCK2_S16"] = x + + if has_aux_logits: + aux_logits = apply_inception_aux_block( + x, parsed_kwargs["classes"], "AuxLogits" + ) + + x = apply_inception_d_block(x, "Mixed_7a") + x = apply_inception_e_block(x, "Mixed_7b") + x = apply_inception_e_block(x, "Mixed_7c") + features["BLOCK3_S32"] = x + + # Head + if parsed_kwargs["include_top"]: + x = self.build_top( + x, + parsed_kwargs["classes"], + parsed_kwargs["classifier_activation"], + parsed_kwargs["dropout_rate"], + ) + else: + if parsed_kwargs["pooling"] == "avg": + x = layers.GlobalAveragePooling2D(name="avg_pool")(x) + elif parsed_kwargs["pooling"] == "max": + x = layers.GlobalMaxPooling2D(name="max_pool")(x) + + # Ensure that the model takes into account + # any potential predecessors of `input_tensor`. 
+ if parsed_kwargs["input_tensor"] is not None: + inputs = utils.get_source_inputs(parsed_kwargs["input_tensor"]) + else: + inputs = img_input + + if has_aux_logits: + x = [x, aux_logits] + super().__init__(inputs=inputs, outputs=x, features=features, **kwargs) + + # All references to `self` below this line + self.add_references(parsed_kwargs) + self.has_aux_logits = has_aux_logits + + @staticmethod + def available_feature_keys(): + feature_keys = ["STEM_S2"] + feature_keys.extend( + [f"BLOCK{i}_S{j}" for i, j in zip(range(4), [4, 8, 16, 32])] + ) + return feature_keys + + def get_config(self): + config = super().get_config() + config.update({"has_aux_logits": self.has_aux_logits}) + return config + + def fix_config(self, config: typing.Dict): + return config + + +class InceptionV3(InceptionV3Base): + def __init__( + self, + has_aux_logits: bool = False, + input_tensor: keras.KerasTensor = None, + input_shape: typing.Optional[typing.Sequence[int]] = None, + include_preprocessing: bool = True, + include_top: bool = True, + pooling: typing.Optional[str] = None, + dropout_rate: float = 0.0, + classes: int = 1000, + classifier_activation: str = "softmax", + weights: typing.Optional[str] = None, # TODO: imagenet + name: str = "InceptionV3", + **kwargs, + ): + kwargs = self.fix_config(kwargs) + super().__init__( + has_aux_logits, + input_tensor=input_tensor, + input_shape=input_shape, + include_preprocessing=include_preprocessing, + include_top=include_top, + pooling=pooling, + dropout_rate=dropout_rate, + classes=classes, + classifier_activation=classifier_activation, + weights=weights, + name=name, + **kwargs, + ) + + +add_model_to_registry(InceptionV3, True) diff --git a/kimm/models/inception_v3_test.py b/kimm/models/inception_v3_test.py new file mode 100644 index 0000000..af13e2f --- /dev/null +++ b/kimm/models/inception_v3_test.py @@ -0,0 +1,50 @@ +import pytest +from absl.testing import parameterized +from keras import models +from keras import random +from 
keras.src import testing + +from kimm.models.inception_v3 import InceptionV3 + + +class InceptionV3Test(testing.TestCase, parameterized.TestCase): + @parameterized.named_parameters([(InceptionV3.__name__, InceptionV3)]) + def test_inception_v3_base(self, model_class): + # TODO: test the correctness of the real image + x = random.uniform([1, 299, 299, 3]) * 255.0 + model = model_class() + + y = model(x, training=False) + + self.assertEqual(y.shape, (1, 1000)) + + @parameterized.named_parameters([(InceptionV3.__name__, InceptionV3)]) + def test_inception_v3_feature_extractor(self, model_class): + x = random.uniform([1, 299, 299, 3]) * 255.0 + model = model_class(as_feature_extractor=True) + + y = model(x, training=False) + + self.assertIsInstance(y, dict) + self.assertAllEqual( + list(y.keys()), model_class.available_feature_keys() + ) + self.assertEqual(list(y["STEM_S2"].shape), [1, 147, 147, 64]) + self.assertEqual(list(y["BLOCK0_S4"].shape), [1, 71, 71, 192]) + self.assertEqual(list(y["BLOCK1_S8"].shape), [1, 35, 35, 288]) + self.assertEqual(list(y["BLOCK2_S16"].shape), [1, 17, 17, 768]) + self.assertEqual(list(y["BLOCK3_S32"].shape), [1, 8, 8, 2048]) + + @pytest.mark.serialization + @parameterized.named_parameters([(InceptionV3.__name__, InceptionV3, 299)]) + def test_inception_v3_serialization(self, model_class, image_size): + x = random.uniform([1, image_size, image_size, 3]) * 255.0 + temp_dir = self.get_temp_dir() + model1 = model_class() + y1 = model1(x, training=False) + model1.save(temp_dir + "/model.keras") + + model2 = models.load_model(temp_dir + "/model.keras") + y2 = model2(x, training=False) + + self.assertAllClose(y1, y2) diff --git a/kimm/models/mobilenet_v2.py b/kimm/models/mobilenet_v2.py index 3da5f1c..a9ee436 100644 --- a/kimm/models/mobilenet_v2.py +++ b/kimm/models/mobilenet_v2.py @@ -2,15 +2,13 @@ import typing import keras -from keras import backend from keras import layers from keras import utils -from keras.src.applications import 
imagenet_utils from kimm.blocks import apply_conv2d_block from kimm.blocks import apply_depthwise_separation_block from kimm.blocks import apply_inverted_residual_block -from kimm.models.feature_extractor import FeatureExtractor +from kimm.models.base_model import BaseModel from kimm.utils import add_model_to_registry from kimm.utils import make_divisible @@ -26,21 +24,12 @@ ] -class MobileNetV2(FeatureExtractor): +class MobileNetV2(BaseModel): def __init__( self, width: float = 1.0, depth: float = 1.0, fix_stem_and_head_channels: bool = False, - input_tensor: keras.KerasTensor = None, - input_shape: typing.Optional[typing.Sequence[int]] = None, - include_preprocessing: bool = True, - include_top: bool = True, - pooling: typing.Optional[str] = None, - dropout_rate: float = 0.0, - classes: int = 1000, - classifier_activation: str = "softmax", - weights: typing.Optional[str] = None, # TODO: imagenet config: typing.Union[str, typing.List] = "default", **kwargs, ): @@ -53,35 +42,19 @@ def __init__( f"Received: config={config}" ) - # Prepare feature extraction - features = {} - - # Determine proper input shape - input_shape = imagenet_utils.obtain_input_shape( - input_shape, - default_size=224, - min_size=32, - data_format=backend.image_data_format(), - require_flatten=include_top, - weights=weights, + parsed_kwargs = self.parse_kwargs(kwargs) + img_input = self.determine_input_tensor( + parsed_kwargs["input_tensor"], + parsed_kwargs["input_shape"], + parsed_kwargs["default_size"], ) - - if input_tensor is None: - img_input = layers.Input(shape=input_shape) - else: - if not backend.is_keras_tensor(input_tensor): - img_input = layers.Input(tensor=input_tensor, shape=input_shape) - else: - img_input = input_tensor - x = img_input - # [0, 255] to [0, 1] and apply ImageNet mean and variance - if include_preprocessing: - x = layers.Rescaling(scale=1.0 / 255.0)(x) - x = layers.Normalization( - mean=[0.485, 0.456, 0.406], variance=[0.229, 0.224, 0.225] - )(x) + if 
parsed_kwargs["include_preprocessing"]: + x = self.build_preprocessing(x, "imagenet") + + # Prepare feature extraction + features = {} # stem stem_channel = ( @@ -110,13 +83,7 @@ def __init__( name = f"blocks_{current_block_idx}_{current_layer_idx}" if block_type == "ds": x = apply_depthwise_separation_block( - x, - c, - k, - 1, - s, - activation="relu6", - name=name, + x, c, k, 1, s, activation="relu6", name=name ) elif block_type == "ir": x = apply_inverted_residual_block( @@ -134,38 +101,34 @@ def __init__( x, head_channels, 1, 1, activation="relu6", name="conv_head" ) - if include_top: - x = layers.GlobalAveragePooling2D(name="avg_pool")(x) - x = layers.Dropout(rate=dropout_rate, name="conv_head_dropout")(x) - x = layers.Dense( - classes, activation=classifier_activation, name="classifier" - )(x) + # Head + if parsed_kwargs["include_top"]: + x = self.build_top( + x, + parsed_kwargs["classes"], + parsed_kwargs["classifier_activation"], + parsed_kwargs["dropout_rate"], + ) else: - if pooling == "avg": + if parsed_kwargs["pooling"] == "avg": x = layers.GlobalAveragePooling2D(name="avg_pool")(x) - elif pooling == "max": + elif parsed_kwargs["pooling"] == "max": x = layers.GlobalMaxPooling2D(name="max_pool")(x) # Ensure that the model takes into account # any potential predecessors of `input_tensor`. 
- if input_tensor is not None: - inputs = utils.get_source_inputs(input_tensor) + if parsed_kwargs["input_tensor"] is not None: + inputs = utils.get_source_inputs(parsed_kwargs["input_tensor"]) else: inputs = img_input super().__init__(inputs=inputs, outputs=x, features=features, **kwargs) # All references to `self` below this line + self.add_references(parsed_kwargs) self.width = width self.depth = depth self.fix_stem_and_head_channels = fix_stem_and_head_channels - self.include_preprocessing = include_preprocessing - self.include_top = include_top - self.pooling = pooling - self.dropout_rate = dropout_rate - self.classes = classes - self.classifier_activation = classifier_activation - self._weights = weights # `self.weights` is been used internally self.config = config @staticmethod @@ -186,21 +149,18 @@ def get_config(self): "width": self.width, "depth": self.depth, "fix_stem_and_head_channels": self.fix_stem_and_head_channels, - "input_shape": self.input_shape[1:], - "include_preprocessing": self.include_preprocessing, - "include_top": self.include_top, - "pooling": self.pooling, - "dropout_rate": self.dropout_rate, - "classes": self.classes, - "classifier_activation": self.classifier_activation, - "weights": self._weights, "config": self.config, } ) return config def fix_config(self, config): - unused_kwargs = ["width", "depth", "fix_stem_and_head_channels"] + unused_kwargs = [ + "width", + "depth", + "fix_stem_and_head_channels", + "config", + ] for k in unused_kwargs: config.pop(k, None) return config @@ -232,16 +192,16 @@ def __init__( 0.5, 1.0, False, - input_tensor, - input_shape, - include_preprocessing, - include_top, - pooling, - dropout_rate, - classes, - classifier_activation, - weights, config, + input_tensor=input_tensor, + input_shape=input_shape, + include_preprocessing=include_preprocessing, + include_top=include_top, + pooling=pooling, + dropout_rate=dropout_rate, + classes=classes, + classifier_activation=classifier_activation, + 
weights=weights, name=name, **kwargs, ) @@ -268,16 +228,16 @@ def __init__( 1.0, 1.0, False, - input_tensor, - input_shape, - include_preprocessing, - include_top, - pooling, - dropout_rate, - classes, - classifier_activation, - weights, config, + input_tensor=input_tensor, + input_shape=input_shape, + include_preprocessing=include_preprocessing, + include_top=include_top, + pooling=pooling, + dropout_rate=dropout_rate, + classes=classes, + classifier_activation=classifier_activation, + weights=weights, name=name, **kwargs, ) @@ -304,16 +264,16 @@ def __init__( 1.1, 1.2, True, - input_tensor, - input_shape, - include_preprocessing, - include_top, - pooling, - dropout_rate, - classes, - classifier_activation, - weights, config, + input_tensor=input_tensor, + input_shape=input_shape, + include_preprocessing=include_preprocessing, + include_top=include_top, + pooling=pooling, + dropout_rate=dropout_rate, + classes=classes, + classifier_activation=classifier_activation, + weights=weights, name=name, **kwargs, ) @@ -340,16 +300,16 @@ def __init__( 1.2, 1.4, True, - input_tensor, - input_shape, - include_preprocessing, - include_top, - pooling, - dropout_rate, - classes, - classifier_activation, - weights, config, + input_tensor=input_tensor, + input_shape=input_shape, + include_preprocessing=include_preprocessing, + include_top=include_top, + pooling=pooling, + dropout_rate=dropout_rate, + classes=classes, + classifier_activation=classifier_activation, + weights=weights, name=name, **kwargs, ) @@ -376,16 +336,16 @@ def __init__( 1.4, 1.0, False, - input_tensor, - input_shape, - include_preprocessing, - include_top, - pooling, - dropout_rate, - classes, - classifier_activation, - weights, config, + input_tensor=input_tensor, + input_shape=input_shape, + include_preprocessing=include_preprocessing, + include_top=include_top, + pooling=pooling, + dropout_rate=dropout_rate, + classes=classes, + classifier_activation=classifier_activation, + weights=weights, name=name, 
**kwargs, ) diff --git a/kimm/models/mobilenet_v2_test.py b/kimm/models/mobilenet_v2_test.py index 859f7e9..05454c0 100644 --- a/kimm/models/mobilenet_v2_test.py +++ b/kimm/models/mobilenet_v2_test.py @@ -1,3 +1,4 @@ +import pytest from absl.testing import parameterized from keras import models from keras import random @@ -56,6 +57,7 @@ def test_mobilenet_v2_feature_extractor(self, model_class, width): list(y["BLOCK5_S32"].shape), [1, 7, 7, make_divisible(160 * width)] ) + @pytest.mark.serialization @parameterized.named_parameters( [(MobileNet050V2.__name__, MobileNet050V2, 224)] ) diff --git a/kimm/models/mobilenet_v3.py b/kimm/models/mobilenet_v3.py index dfbef04..b55759d 100644 --- a/kimm/models/mobilenet_v3.py +++ b/kimm/models/mobilenet_v3.py @@ -2,15 +2,13 @@ import typing import keras -from keras import backend from keras import layers from keras import utils -from keras.src.applications import imagenet_utils from kimm.blocks import apply_conv2d_block from kimm.blocks import apply_depthwise_separation_block from kimm.blocks import apply_inverted_residual_block -from kimm.models.feature_extractor import FeatureExtractor +from kimm.models.base_model import BaseModel from kimm.utils import add_model_to_registry from kimm.utils import make_divisible @@ -82,21 +80,12 @@ ] -class MobileNetV3(FeatureExtractor): +class MobileNetV3(BaseModel): def __init__( self, width: float = 1.0, depth: float = 1.0, fix_stem_and_head_channels: bool = False, - input_tensor: keras.KerasTensor = None, - input_shape: typing.Optional[typing.Sequence[int]] = None, - include_preprocessing: bool = True, - include_top: bool = True, - pooling: typing.Optional[str] = None, - dropout_rate: float = 0.0, - classes: int = 1000, - classifier_activation: str = "softmax", - weights: typing.Optional[str] = None, # TODO: imagenet config: typing.Union[str, typing.List] = "large", minimal: bool = False, **kwargs, @@ -128,35 +117,19 @@ def __init__( bn_epsilon = kwargs.pop("bn_epsilon", 1e-5) padding = 
kwargs.pop("padding", None) - # Prepare feature extraction - features = {} - - # Determine proper input shape - input_shape = imagenet_utils.obtain_input_shape( - input_shape, - default_size=224, - min_size=32, - data_format=backend.image_data_format(), - require_flatten=include_top, - weights=weights, + parsed_kwargs = self.parse_kwargs(kwargs) + img_input = self.determine_input_tensor( + parsed_kwargs["input_tensor"], + parsed_kwargs["input_shape"], + parsed_kwargs["default_size"], ) - - if input_tensor is None: - img_input = layers.Input(shape=input_shape) - else: - if not backend.is_keras_tensor(input_tensor): - img_input = layers.Input(tensor=input_tensor, shape=input_shape) - else: - img_input = input_tensor - x = img_input - # [0, 255] to [0, 1] and apply ImageNet mean and variance - if include_preprocessing: - x = layers.Rescaling(scale=1.0 / 255.0)(x) - x = layers.Normalization( - mean=[0.485, 0.456, 0.406], variance=[0.229, 0.224, 0.225] - )(x) + if parsed_kwargs["include_preprocessing"]: + x = self.build_preprocessing(x, "imagenet") + + # Prepare feature extraction + features = {} # stem stem_channel = ( @@ -194,7 +167,7 @@ def __init__( r = int(math.ceil(r * depth)) for current_layer_idx in range(r): s = s if current_layer_idx == 0 else 1 - common_kwargs = { + _kwargs = { "bn_epsilon": bn_epsilon, "padding": padding, "name": ( @@ -216,7 +189,7 @@ def __init__( se_make_divisible_number=8, pw_activation=act if block_type == "dsa" else None, skip=False if block_type == "dsa" else True, - **common_kwargs, + **_kwargs, ) elif block_type == "ir": x = apply_inverted_residual_block( @@ -229,26 +202,20 @@ def __init__( e, se, act, - se_input_channels=None, se_activation="relu", se_gate_activation="hard_sigmoid", se_make_divisible_number=8, - **common_kwargs, + **_kwargs, ) elif block_type == "cn": x = apply_conv2d_block( - x, - filters=c, - kernel_size=k, - strides=s, - activation=act, - **common_kwargs, + x, c, k, s, activation=act, **_kwargs ) current_stride 
*= s features[f"BLOCK{current_stage_idx}_S{current_stride}"] = x - if include_top: - x = layers.GlobalAveragePooling2D(name="avg_pool", keepdims=True)(x) + # Head + if parsed_kwargs["include_top"]: if fix_stem_and_head_channels: conv_head_channels = conv_head_channels else: @@ -256,46 +223,61 @@ def __init__( conv_head_channels, make_divisible(conv_head_channels * width), ) - x = layers.Conv2D( - conv_head_channels, 1, 1, use_bias=True, name="conv_head" - )(x) - x = layers.Activation( - force_activation or "hard_swish", name="act2" - )(x) - x = layers.Flatten()(x) - x = layers.Dropout(rate=dropout_rate, name="conv_head_dropout")(x) - x = layers.Dense( - classes, activation=classifier_activation, name="classifier" - )(x) + head_activation = force_activation or "hard_swish" + x = self.build_top( + x, + parsed_kwargs["classes"], + parsed_kwargs["classifier_activation"], + parsed_kwargs["dropout_rate"], + conv_head_channels=conv_head_channels, + head_activation=head_activation, + ) else: - if pooling == "avg": + if parsed_kwargs["pooling"] == "avg": x = layers.GlobalAveragePooling2D(name="avg_pool")(x) - elif pooling == "max": + elif parsed_kwargs["pooling"] == "max": x = layers.GlobalMaxPooling2D(name="max_pool")(x) # Ensure that the model takes into account # any potential predecessors of `input_tensor`. 
- if input_tensor is not None: - inputs = utils.get_source_inputs(input_tensor) + if parsed_kwargs["input_tensor"] is not None: + inputs = utils.get_source_inputs(parsed_kwargs["input_tensor"]) else: inputs = img_input super().__init__(inputs=inputs, outputs=x, features=features, **kwargs) # All references to `self` below this line + self.add_references(parsed_kwargs) self.width = width self.depth = depth self.fix_stem_and_head_channels = fix_stem_and_head_channels - self.include_preprocessing = include_preprocessing - self.include_top = include_top - self.pooling = pooling - self.dropout_rate = dropout_rate - self.classes = classes - self.classifier_activation = classifier_activation - self._weights = weights # `self.weights` is been used internally self.config = config self.minimal = minimal + def build_top( + self, + inputs, + classes, + classifier_activation, + dropout_rate, + conv_head_channels, + head_activation, + ): + x = layers.GlobalAveragePooling2D(name="avg_pool", keepdims=True)( + inputs + ) + x = layers.Conv2D( + conv_head_channels, 1, 1, use_bias=True, name="conv_head" + )(x) + x = layers.Activation(head_activation, name="act2")(x) + x = layers.Flatten()(x) + x = layers.Dropout(rate=dropout_rate, name="conv_head_dropout")(x) + x = layers.Dense( + classes, activation=classifier_activation, name="classifier" + )(x) + return x + @staticmethod def available_feature_keys(): raise NotImplementedError() @@ -307,14 +289,6 @@ def get_config(self): "width": self.width, "depth": self.depth, "fix_stem_and_head_channels": self.fix_stem_and_head_channels, - "input_shape": self.input_shape[1:], - "include_preprocessing": self.include_preprocessing, - "include_top": self.include_top, - "pooling": self.pooling, - "dropout_rate": self.dropout_rate, - "classes": self.classes, - "classifier_activation": self.classifier_activation, - "weights": self._weights, "config": self.config, "minimal": self.minimal, } @@ -326,6 +300,7 @@ def fix_config(self, config): "width", 
"depth", "fix_stem_and_head_channels", + "config", "minimal", ] for k in unused_kwargs: @@ -359,16 +334,16 @@ def __init__( 0.5, 1.0, True, - input_tensor, - input_shape, - include_preprocessing, - include_top, - pooling, - dropout_rate, - classes, - classifier_activation, - weights, config, + input_tensor=input_tensor, + input_shape=input_shape, + include_preprocessing=include_preprocessing, + include_top=include_top, + pooling=pooling, + dropout_rate=dropout_rate, + classes=classes, + classifier_activation=classifier_activation, + weights=weights, name=name, **kwargs, ) @@ -403,16 +378,16 @@ def __init__( 0.75, 1.0, False, - input_tensor, - input_shape, - include_preprocessing, - include_top, - pooling, - dropout_rate, - classes, - classifier_activation, - weights, config, + input_tensor=input_tensor, + input_shape=input_shape, + include_preprocessing=include_preprocessing, + include_top=include_top, + pooling=pooling, + dropout_rate=dropout_rate, + classes=classes, + classifier_activation=classifier_activation, + weights=weights, name=name, **kwargs, ) @@ -447,16 +422,16 @@ def __init__( 1.0, 1.0, False, - input_tensor, - input_shape, - include_preprocessing, - include_top, - pooling, - dropout_rate, - classes, - classifier_activation, - weights, config, + input_tensor=input_tensor, + input_shape=input_shape, + include_preprocessing=include_preprocessing, + include_top=include_top, + pooling=pooling, + dropout_rate=dropout_rate, + classes=classes, + classifier_activation=classifier_activation, + weights=weights, name=name, **kwargs, ) @@ -492,17 +467,17 @@ def __init__( 1.0, 1.0, False, - input_tensor, - input_shape, - include_preprocessing, - include_top, - pooling, - dropout_rate, - classes, - classifier_activation, - weights, config, - minimal=True, + True, + input_tensor=input_tensor, + input_shape=input_shape, + include_preprocessing=include_preprocessing, + include_top=include_top, + pooling=pooling, + dropout_rate=dropout_rate, + classes=classes, + 
classifier_activation=classifier_activation, + weights=weights, name=name, bn_epsilon=1e-3, padding="same", @@ -539,16 +514,16 @@ def __init__( 1.0, 1.0, False, - input_tensor, - input_shape, - include_preprocessing, - include_top, - pooling, - dropout_rate, - classes, - classifier_activation, - weights, config, + input_tensor=input_tensor, + input_shape=input_shape, + include_preprocessing=include_preprocessing, + include_top=include_top, + pooling=pooling, + dropout_rate=dropout_rate, + classes=classes, + classifier_activation=classifier_activation, + weights=weights, name=name, **kwargs, ) @@ -587,17 +562,17 @@ def __init__( 1.0, 1.0, False, - input_tensor, - input_shape, - include_preprocessing, - include_top, - pooling, - dropout_rate, - classes, - classifier_activation, - weights, config, - minimal=True, + True, + input_tensor=input_tensor, + input_shape=input_shape, + include_preprocessing=include_preprocessing, + include_top=include_top, + pooling=pooling, + dropout_rate=dropout_rate, + classes=classes, + classifier_activation=classifier_activation, + weights=weights, name=name, bn_epsilon=1e-3, padding="same", @@ -638,16 +613,16 @@ def __init__( 0.35, 1.0, False, - input_tensor, - input_shape, - include_preprocessing, - include_top, - pooling, - dropout_rate, - classes, - classifier_activation, - weights, config, + input_tensor=input_tensor, + input_shape=input_shape, + include_preprocessing=include_preprocessing, + include_top=include_top, + pooling=pooling, + dropout_rate=dropout_rate, + classes=classes, + classifier_activation=classifier_activation, + weights=weights, name=name, **kwargs, ) @@ -683,16 +658,16 @@ def __init__( 0.5, 1.0, False, - input_tensor, - input_shape, - include_preprocessing, - include_top, - pooling, - dropout_rate, - classes, - classifier_activation, - weights, config, + input_tensor=input_tensor, + input_shape=input_shape, + include_preprocessing=include_preprocessing, + include_top=include_top, + pooling=pooling, + 
dropout_rate=dropout_rate, + classes=classes, + classifier_activation=classifier_activation, + weights=weights, name=name, **kwargs, ) @@ -728,16 +703,16 @@ def __init__( 0.75, 1.0, False, - input_tensor, - input_shape, - include_preprocessing, - include_top, - pooling, - dropout_rate, - classes, - classifier_activation, - weights, config, + input_tensor=input_tensor, + input_shape=input_shape, + include_preprocessing=include_preprocessing, + include_top=include_top, + pooling=pooling, + dropout_rate=dropout_rate, + classes=classes, + classifier_activation=classifier_activation, + weights=weights, name=name, **kwargs, ) @@ -773,16 +748,16 @@ def __init__( 1.0, 1.0, False, - input_tensor, - input_shape, - include_preprocessing, - include_top, - pooling, - dropout_rate, - classes, - classifier_activation, - weights, config, + input_tensor=input_tensor, + input_shape=input_shape, + include_preprocessing=include_preprocessing, + include_top=include_top, + pooling=pooling, + dropout_rate=dropout_rate, + classes=classes, + classifier_activation=classifier_activation, + weights=weights, name=name, **kwargs, ) @@ -818,16 +793,16 @@ def __init__( 1.5, 1.0, False, - input_tensor, - input_shape, - include_preprocessing, - include_top, - pooling, - dropout_rate, - classes, - classifier_activation, - weights, config, + input_tensor=input_tensor, + input_shape=input_shape, + include_preprocessing=include_preprocessing, + include_top=include_top, + pooling=pooling, + dropout_rate=dropout_rate, + classes=classes, + classifier_activation=classifier_activation, + weights=weights, name=name, **kwargs, ) diff --git a/kimm/models/mobilenet_v3_test.py b/kimm/models/mobilenet_v3_test.py index fea61a5..f7a31a9 100644 --- a/kimm/models/mobilenet_v3_test.py +++ b/kimm/models/mobilenet_v3_test.py @@ -1,3 +1,4 @@ +import pytest from absl.testing import parameterized from keras import models from keras import random @@ -6,6 +7,7 @@ from kimm.models.mobilenet_v3 import LCNet100 from 
kimm.models.mobilenet_v3 import MobileNet100V3Large from kimm.models.mobilenet_v3 import MobileNet100V3Small +from kimm.models.mobilenet_v3 import MobileNet100V3SmallMinimal from kimm.utils import make_divisible @@ -13,6 +15,7 @@ class MobileNetV3Test(testing.TestCase, parameterized.TestCase): @parameterized.named_parameters( [ (MobileNet100V3Small.__name__, MobileNet100V3Small), + (MobileNet100V3SmallMinimal.__name__, MobileNet100V3SmallMinimal), (MobileNet100V3Large.__name__, MobileNet100V3Large), (LCNet100.__name__, LCNet100), ] @@ -29,6 +32,11 @@ def test_mobilenet_v3_base(self, model_class): @parameterized.named_parameters( [ (MobileNet100V3Small.__name__, MobileNet100V3Small, 1.0), + ( + MobileNet100V3SmallMinimal.__name__, + MobileNet100V3SmallMinimal, + 1.0, + ), (MobileNet100V3Large.__name__, MobileNet100V3Large, 1.0), ] ) @@ -117,9 +125,15 @@ def test_lcnet_feature_extractor(self, model_class, width): [1, 7, 7, make_divisible(512 * width)], ) + @pytest.mark.serialization @parameterized.named_parameters( [ (MobileNet100V3Small.__name__, MobileNet100V3Small, 224), + ( + MobileNet100V3SmallMinimal.__name__, + MobileNet100V3SmallMinimal, + 224, + ), (MobileNet100V3Large.__name__, MobileNet100V3Large, 224), (LCNet100.__name__, LCNet100, 224), ] diff --git a/kimm/models/mobilevit.py b/kimm/models/mobilevit.py index 813a019..4649462 100644 --- a/kimm/models/mobilevit.py +++ b/kimm/models/mobilevit.py @@ -2,16 +2,14 @@ import typing import keras -from keras import backend from keras import layers from keras import ops from keras import utils -from keras.src.applications import imagenet_utils from kimm.blocks import apply_conv2d_block from kimm.blocks import apply_inverted_residual_block from kimm.blocks import apply_transformer_block -from kimm.models.feature_extractor import FeatureExtractor +from kimm.models.base_model import BaseModel from kimm.utils import add_model_to_registry from kimm.utils import make_divisible @@ -163,21 +161,12 @@ def 
apply_mobilevit_block( return x -class MobileViT(FeatureExtractor): +class MobileViT(BaseModel): def __init__( self, stem_channels: int = 16, head_channels: int = 640, activation="swish", - input_tensor: keras.KerasTensor = None, - input_shape: typing.Optional[typing.Sequence[int]] = None, - include_preprocessing: bool = True, - include_top: bool = True, - pooling: typing.Optional[str] = None, - dropout_rate: float = 0.1, - classes: int = 1000, - classifier_activation: str = "softmax", - weights: typing.Optional[str] = None, # TODO: imagenet config: str = "v1_s", **kwargs, ): @@ -194,35 +183,20 @@ def __init__( f"Received: config={config}" ) - # Prepare feature extraction - features = {} - - # Determine proper input shape - input_shape = imagenet_utils.obtain_input_shape( - input_shape, - default_size=256, - min_size=32, - data_format=backend.image_data_format(), - require_flatten=include_top, - weights=weights, + parsed_kwargs = self.parse_kwargs(kwargs, 256) + img_input = self.determine_input_tensor( + parsed_kwargs["input_tensor"], + parsed_kwargs["input_shape"], + parsed_kwargs["default_size"], + static_shape=True, ) - - if input_tensor is None: - img_input = layers.Input(shape=input_shape) - else: - if not backend.is_keras_tensor(input_tensor): - img_input = layers.Input(tensor=input_tensor, shape=input_shape) - else: - img_input = input_tensor - x = img_input - # [0, 255] to [0, 1] and apply ImageNet mean and variance - if include_preprocessing: - x = layers.Rescaling(scale=1.0 / 255.0)(x) - x = layers.Normalization( - mean=[0.485, 0.456, 0.406], variance=[0.229, 0.224, 0.225] - )(x) + if parsed_kwargs["include_preprocessing"]: + x = self.build_preprocessing(x, "imagenet") + + # Prepare feature extraction + features = {} # stem x = apply_conv2d_block( @@ -239,7 +213,7 @@ def __init__( k, c, s, - expansion_ratio, + e, transformer_dim, transformer_depth, patch_size, @@ -249,15 +223,7 @@ def __init__( s = s if current_layer_idx == 0 else 1 name = 
f"stages_{current_block_idx}_{current_layer_idx}" x = apply_inverted_residual_block( - x, - c, - k, - 1, - 1, - s, - expansion_ratio, - activation=activation, - name=name, + x, c, k, 1, 1, s, e, activation=activation, name=name ) current_stride *= s if block_type == "mobilevit": @@ -281,37 +247,34 @@ def __init__( x, head_channels, 1, 1, activation=activation, name="final_conv" ) - if include_top: - x = layers.GlobalAveragePooling2D(name="avg_pool")(x) - x = layers.Dropout(dropout_rate, name="head_drop")(x) - x = layers.Dense( - classes, activation=classifier_activation, name="head_fc" - )(x) + # Head + if parsed_kwargs["include_top"]: + x = self.build_top( + x, + parsed_kwargs["classes"], + parsed_kwargs["classifier_activation"], + parsed_kwargs["dropout_rate"], + ) else: - if pooling == "avg": + if parsed_kwargs["pooling"] == "avg": x = layers.GlobalAveragePooling2D(name="avg_pool")(x) - elif pooling == "max": + elif parsed_kwargs["pooling"] == "max": x = layers.GlobalMaxPooling2D(name="max_pool")(x) # Ensure that the model takes into account # any potential predecessors of `input_tensor`. 
- if input_tensor is not None: - inputs = utils.get_source_inputs(input_tensor) + if parsed_kwargs["input_tensor"] is not None: + inputs = utils.get_source_inputs(parsed_kwargs["input_tensor"]) else: inputs = img_input super().__init__(inputs=inputs, outputs=x, features=features, **kwargs) + # All references to `self` below this line + self.add_references(parsed_kwargs) self.stem_channels = stem_channels self.head_channels = head_channels self.activation = activation - self.include_preprocessing = include_preprocessing - self.include_top = include_top - self.pooling = pooling - self.dropout_rate = dropout_rate - self.classes = classes - self.classifier_activation = classifier_activation - self._weights = weights # `self.weights` is been used internally self.config = config @staticmethod @@ -329,21 +292,18 @@ def get_config(self): "stem_channels": self.stem_channels, "head_channels": self.head_channels, "activation": self.activation, - "input_shape": self.input_shape[1:], - "include_preprocessing": self.include_preprocessing, - "include_top": self.include_top, - "pooling": self.pooling, - "dropout_rate": self.dropout_rate, - "classes": self.classes, - "classifier_activation": self.classifier_activation, - "weights": self._weights, "config": self.config, } ) return config def fix_config(self, config): - unused_kwargs = ["stem_channels", "head_channels", "activation"] + unused_kwargs = [ + "stem_channels", + "head_channels", + "activation", + "config", + ] for k in unused_kwargs: config.pop(k, None) return config @@ -370,16 +330,16 @@ def __init__( 16, 640, "swish", - input_tensor, - input_shape, - include_preprocessing, - include_top, - pooling, - dropout_rate, - classes, - classifier_activation, - weights, config, + input_tensor=input_tensor, + input_shape=input_shape, + include_preprocessing=include_preprocessing, + include_top=include_top, + pooling=pooling, + dropout_rate=dropout_rate, + classes=classes, + classifier_activation=classifier_activation, + 
weights=weights, name=name, **kwargs, ) @@ -406,16 +366,16 @@ def __init__( 16, 384, "swish", - input_tensor, - input_shape, - include_preprocessing, - include_top, - pooling, - dropout_rate, - classes, - classifier_activation, - weights, config, + input_tensor=input_tensor, + input_shape=input_shape, + include_preprocessing=include_preprocessing, + include_top=include_top, + pooling=pooling, + dropout_rate=dropout_rate, + classes=classes, + classifier_activation=classifier_activation, + weights=weights, name=name, **kwargs, ) @@ -442,16 +402,16 @@ def __init__( 16, 320, "swish", - input_tensor, - input_shape, - include_preprocessing, - include_top, - pooling, - dropout_rate, - classes, - classifier_activation, - weights, config, + input_tensor=input_tensor, + input_shape=input_shape, + include_preprocessing=include_preprocessing, + include_top=include_top, + pooling=pooling, + dropout_rate=dropout_rate, + classes=classes, + classifier_activation=classifier_activation, + weights=weights, name=name, **kwargs, ) diff --git a/kimm/models/mobilevit_test.py b/kimm/models/mobilevit_test.py index fbd2939..1738fe9 100644 --- a/kimm/models/mobilevit_test.py +++ b/kimm/models/mobilevit_test.py @@ -1,3 +1,4 @@ +import pytest from absl.testing import parameterized from keras import models from keras import random @@ -46,6 +47,7 @@ def test_mobilevit_feature_extractor(self, model_class): self.assertEqual(list(y["BLOCK3_S16"].shape), [1, 16, 16, 80]) self.assertEqual(list(y["BLOCK4_S32"].shape), [1, 8, 8, 96]) + @pytest.mark.serialization @parameterized.named_parameters( [ (MobileViTS.__name__, MobileViTS, 256), diff --git a/kimm/models/resnet.py b/kimm/models/resnet.py index 8354412..7d882d3 100644 --- a/kimm/models/resnet.py +++ b/kimm/models/resnet.py @@ -1,13 +1,11 @@ import typing import keras -from keras import backend from keras import layers from keras import utils -from keras.src.applications import imagenet_utils from kimm.blocks import apply_conv2d_block -from 
kimm.models.feature_extractor import FeatureExtractor +from kimm.models.base_model import BaseModel from kimm.utils import add_model_to_registry @@ -106,56 +104,29 @@ def apply_bottleneck_block( return x -class ResNet(FeatureExtractor): +class ResNet(BaseModel): def __init__( - self, - block_fn: str, - num_blocks: typing.Sequence[int], - input_tensor: keras.KerasTensor = None, - input_shape: typing.Optional[typing.Sequence[int]] = None, - include_preprocessing: bool = True, - include_top: bool = True, - pooling: typing.Optional[str] = None, - dropout_rate: float = 0.0, - classes: int = 1000, - classifier_activation: str = "softmax", - weights: typing.Optional[str] = None, # TODO: imagenet - **kwargs, + self, block_fn: str, num_blocks: typing.Sequence[int], **kwargs ): if block_fn not in ("basic", "bottleneck"): raise ValueError( "`block_fn` must be one of ('basic', 'bottelneck'). " f"Received: block_fn={block_fn}" ) - # Prepare feature extraction - features = {} - # Determine proper input shape - input_shape = imagenet_utils.obtain_input_shape( - input_shape, - default_size=224, - min_size=32, - data_format=backend.image_data_format(), - require_flatten=include_top, - weights=weights, + parsed_kwargs = self.parse_kwargs(kwargs) + img_input = self.determine_input_tensor( + parsed_kwargs["input_tensor"], + parsed_kwargs["input_shape"], + parsed_kwargs["default_size"], ) - - if input_tensor is None: - img_input = layers.Input(shape=input_shape) - else: - if not backend.is_keras_tensor(input_tensor): - img_input = layers.Input(tensor=input_tensor, shape=input_shape) - else: - img_input = input_tensor - x = img_input - # [0, 255] to [0, 1] and apply ImageNet mean and variance - if include_preprocessing: - x = layers.Rescaling(scale=1.0 / 255.0)(x) - x = layers.Normalization( - mean=[0.485, 0.456, 0.406], variance=[0.229, 0.224, 0.225] - )(x) + if parsed_kwargs["include_preprocessing"]: + x = self.build_preprocessing(x, "imagenet") + + # Prepare feature extraction + 
features = {} # stem stem_channels = 64 @@ -189,38 +160,33 @@ def __init__( # add feature features[f"BLOCK{current_stage_idx}_S{current_stride}"] = x - if include_top: - x = layers.GlobalAveragePooling2D(name="avg_pool", keepdims=True)(x) - x = layers.Flatten()(x) - x = layers.Dropout(rate=dropout_rate, name="head_dropout")(x) - x = layers.Dense( - classes, activation=classifier_activation, name="fc" - )(x) + # Head + if parsed_kwargs["include_top"]: + x = self.build_top( + x, + parsed_kwargs["classes"], + parsed_kwargs["classifier_activation"], + parsed_kwargs["dropout_rate"], + ) else: - if pooling == "avg": + if parsed_kwargs["pooling"] == "avg": x = layers.GlobalAveragePooling2D(name="avg_pool")(x) - elif pooling == "max": + elif parsed_kwargs["pooling"] == "max": x = layers.GlobalMaxPooling2D(name="max_pool")(x) # Ensure that the model takes into account # any potential predecessors of `input_tensor`. - if input_tensor is not None: - inputs = utils.get_source_inputs(input_tensor) + if parsed_kwargs["input_tensor"] is not None: + inputs = utils.get_source_inputs(parsed_kwargs["input_tensor"]) else: inputs = img_input super().__init__(inputs=inputs, outputs=x, features=features, **kwargs) # All references to `self` below this line + self.add_references(parsed_kwargs) self.block_fn = block_fn self.num_blocks = num_blocks - self.include_preprocessing = include_preprocessing - self.include_top = include_top - self.pooling = pooling - self.dropout_rate = dropout_rate - self.classes = classes - self.classifier_activation = classifier_activation - self._weights = weights # `self.weights` is been used internally @staticmethod def available_feature_keys(): @@ -233,18 +199,7 @@ def available_feature_keys(): def get_config(self): config = super().get_config() config.update( - { - "block_fn": self.block_fn, - "num_blocks": self.num_blocks, - "input_shape": self.input_shape[1:], - "include_preprocessing": self.include_preprocessing, - "include_top": self.include_top, - 
"pooling": self.pooling, - "dropout_rate": self.dropout_rate, - "classes": self.classes, - "classifier_activation": self.classifier_activation, - "weights": self._weights, - } + {"block_fn": self.block_fn, "num_blocks": self.num_blocks} ) return config @@ -279,15 +234,15 @@ def __init__( super().__init__( "basic", [2, 2, 2, 2], - input_tensor, - input_shape, - include_preprocessing, - include_top, - pooling, - dropout_rate, - classes, - classifier_activation, - weights, + input_tensor=input_tensor, + input_shape=input_shape, + include_preprocessing=include_preprocessing, + include_top=include_top, + pooling=pooling, + dropout_rate=dropout_rate, + classes=classes, + classifier_activation=classifier_activation, + weights=weights, name=name, **kwargs, ) @@ -312,15 +267,15 @@ def __init__( super().__init__( "basic", [3, 4, 6, 3], - input_tensor, - input_shape, - include_preprocessing, - include_top, - pooling, - dropout_rate, - classes, - classifier_activation, - weights, + input_tensor=input_tensor, + input_shape=input_shape, + include_preprocessing=include_preprocessing, + include_top=include_top, + pooling=pooling, + dropout_rate=dropout_rate, + classes=classes, + classifier_activation=classifier_activation, + weights=weights, name=name, **kwargs, ) @@ -345,15 +300,15 @@ def __init__( super().__init__( "bottleneck", [3, 4, 6, 3], - input_tensor, - input_shape, - include_preprocessing, - include_top, - pooling, - dropout_rate, - classes, - classifier_activation, - weights, + input_tensor=input_tensor, + input_shape=input_shape, + include_preprocessing=include_preprocessing, + include_top=include_top, + pooling=pooling, + dropout_rate=dropout_rate, + classes=classes, + classifier_activation=classifier_activation, + weights=weights, name=name, **kwargs, ) @@ -378,15 +333,15 @@ def __init__( super().__init__( "bottleneck", [3, 4, 23, 3], - input_tensor, - input_shape, - include_preprocessing, - include_top, - pooling, - dropout_rate, - classes, - classifier_activation, 
- weights, + input_tensor=input_tensor, + input_shape=input_shape, + include_preprocessing=include_preprocessing, + include_top=include_top, + pooling=pooling, + dropout_rate=dropout_rate, + classes=classes, + classifier_activation=classifier_activation, + weights=weights, name=name, **kwargs, ) @@ -411,15 +366,15 @@ def __init__( super().__init__( "bottleneck", [3, 8, 36, 3], - input_tensor, - input_shape, - include_preprocessing, - include_top, - pooling, - dropout_rate, - classes, - classifier_activation, - weights, + input_tensor=input_tensor, + input_shape=input_shape, + include_preprocessing=include_preprocessing, + include_top=include_top, + pooling=pooling, + dropout_rate=dropout_rate, + classes=classes, + classifier_activation=classifier_activation, + weights=weights, name=name, **kwargs, ) diff --git a/kimm/models/resnet_test.py b/kimm/models/resnet_test.py index 478ef6b..47bf01d 100644 --- a/kimm/models/resnet_test.py +++ b/kimm/models/resnet_test.py @@ -1,3 +1,4 @@ +import pytest from absl.testing import parameterized from keras import models from keras import random @@ -47,6 +48,7 @@ def test_resnet_feature_extractor(self, model_class, expansion): list(y["BLOCK3_S32"].shape), [1, 7, 7, 512 * expansion] ) + @pytest.mark.serialization @parameterized.named_parameters( [(ResNet18.__name__, ResNet18, 224), (ResNet50.__name__, ResNet50, 224)] ) diff --git a/kimm/models/vision_transformer.py b/kimm/models/vision_transformer.py index ad40fa9..0fae02c 100644 --- a/kimm/models/vision_transformer.py +++ b/kimm/models/vision_transformer.py @@ -1,18 +1,16 @@ import typing import keras -from keras import backend from keras import layers from keras import utils -from keras.src.applications import imagenet_utils from kimm import layers as kimm_layers from kimm.blocks import apply_transformer_block -from kimm.models.feature_extractor import FeatureExtractor +from kimm.models.base_model import BaseModel from kimm.utils import add_model_to_registry -class 
VisionTransformer(FeatureExtractor): +class VisionTransformer(BaseModel): def __init__( self, patch_size: int, @@ -22,44 +20,28 @@ def __init__( mlp_ratio: float = 4.0, use_qkv_bias: bool = True, use_qk_norm: bool = False, - input_tensor: keras.KerasTensor = None, - input_shape: typing.Optional[typing.Sequence[int]] = None, - include_preprocessing: bool = True, - include_top: bool = True, - pooling: typing.Optional[str] = None, pos_dropout_rate: float = 0.0, - dropout_rate: float = 0.1, - classes: int = 1000, - classifier_activation: str = "softmax", - weights: typing.Optional[str] = None, # TODO: imagenet **kwargs, ): - # Prepare feature extraction - features = {} - - # Determine proper input shape - input_shape = imagenet_utils.obtain_input_shape( - input_shape, - default_size=384, - min_size=32, - data_format=backend.image_data_format(), - require_flatten=include_top, - weights=weights, + parsed_kwargs = self.parse_kwargs(kwargs, 384) + if parsed_kwargs["pooling"] is not None: + raise ValueError( + "`VisionTransformer` doesn't support `pooling`. 
" + f"Received: pooling={parsed_kwargs['pooling']}" + ) + img_input = self.determine_input_tensor( + parsed_kwargs["input_tensor"], + parsed_kwargs["input_shape"], + parsed_kwargs["default_size"], + static_shape=True, ) - - if input_tensor is None: - img_input = layers.Input(shape=input_shape) - else: - if not backend.is_keras_tensor(input_tensor): - img_input = layers.Input(tensor=input_tensor, shape=input_shape) - else: - img_input = input_tensor - x = img_input - # [0, 255] to [-1, 1] - if include_preprocessing: - x = layers.Rescaling(scale=1.0 / 127.5, offset=-1.0)(x) + if parsed_kwargs["include_preprocessing"]: + x = self.build_preprocessing(x, "-1_1") + + # Prepare feature extraction + features = {} # patch embedding x = layers.Conv2D( @@ -89,27 +71,26 @@ def __init__( features[f"BLOCK{i}"] = x x = layers.LayerNormalization(epsilon=1e-6, name="norm")(x) - if include_top: - x = x[:, 0] # class token - x = layers.Dropout(dropout_rate, name="head_drop")(x) - x = layers.Dense( - classes, activation=classifier_activation, name="head" - )(x) - else: - if pooling == "avg": - x = layers.GlobalAveragePooling2D(name="avg_pool")(x) - elif pooling == "max": - x = layers.GlobalMaxPooling2D(name="max_pool")(x) + # Head + if parsed_kwargs["include_top"]: + x = self.build_top( + x, + parsed_kwargs["classes"], + parsed_kwargs["classifier_activation"], + parsed_kwargs["dropout_rate"], + ) # Ensure that the model takes into account # any potential predecessors of `input_tensor`. 
- if input_tensor is not None: - inputs = utils.get_source_inputs(input_tensor) + if parsed_kwargs["input_tensor"] is not None: + inputs = utils.get_source_inputs(parsed_kwargs["input_tensor"]) else: inputs = img_input super().__init__(inputs=inputs, outputs=x, features=features, **kwargs) + # All references to `self` below this line + self.add_references(parsed_kwargs) self.patch_size = patch_size self.embed_dim = embed_dim self.depth = depth @@ -117,13 +98,15 @@ def __init__( self.mlp_ratio = mlp_ratio self.use_qkv_bias = use_qkv_bias self.use_qk_norm = use_qk_norm - self.include_preprocessing = include_preprocessing - self.include_top = include_top - self.pooling = pooling - self.dropout_rate = dropout_rate - self.classes = classes - self.classifier_activation = classifier_activation - self._weights = weights # `self.weights` is been used internally + self.pos_dropout_rate = pos_dropout_rate + + def build_top(self, inputs, classes, classifier_activation, dropout_rate): + x = inputs[:, 0] # class token + x = layers.Dropout(dropout_rate, name="head_drop")(x) + x = layers.Dense( + classes, activation=classifier_activation, name="head" + )(x) + return x @staticmethod def available_feature_keys(): @@ -140,14 +123,7 @@ def get_config(self): "mlp_ratio": self.mlp_ratio, "use_qkv_bias": self.use_qkv_bias, "use_qk_norm": self.use_qk_norm, - "input_shape": self.input_shape[1:], - "include_preprocessing": self.include_preprocessing, - "include_top": self.include_top, - "pooling": self.pooling, - "dropout_rate": self.dropout_rate, - "classes": self.classes, - "classifier_activation": self.classifier_activation, - "weights": self._weights, + "pos_dropout_rate": self.pos_dropout_rate, } ) return config @@ -161,6 +137,7 @@ def fix_config(self, config): "mlp_ratio", "use_qkv_bias", "use_qk_norm", + "pos_dropout_rate", ] for k in unused_kwargs: config.pop(k, None) @@ -200,16 +177,16 @@ def __init__( mlp_ratio, use_qkv_bias, use_qk_norm, - input_tensor, - input_shape, - 
include_preprocessing, - include_top, - pooling, pos_dropout_rate, - dropout_rate, - classes, - classifier_activation, - weights, + input_tensor=input_tensor, + input_shape=input_shape, + include_preprocessing=include_preprocessing, + include_top=include_top, + pooling=pooling, + dropout_rate=dropout_rate, + classes=classes, + classifier_activation=classifier_activation, + weights=weights, name=name, **kwargs, ) @@ -249,16 +226,16 @@ def __init__( mlp_ratio, use_qkv_bias, use_qk_norm, - input_tensor, - input_shape, - include_preprocessing, - include_top, - pooling, pos_dropout_rate, - dropout_rate, - classes, - classifier_activation, - weights, + input_tensor=input_tensor, + input_shape=input_shape, + include_preprocessing=include_preprocessing, + include_top=include_top, + pooling=pooling, + dropout_rate=dropout_rate, + classes=classes, + classifier_activation=classifier_activation, + weights=weights, name=name, **kwargs, ) @@ -298,16 +275,16 @@ def __init__( mlp_ratio, use_qkv_bias, use_qk_norm, - input_tensor, - input_shape, - include_preprocessing, - include_top, - pooling, pos_dropout_rate, - dropout_rate, - classes, - classifier_activation, - weights, + input_tensor=input_tensor, + input_shape=input_shape, + include_preprocessing=include_preprocessing, + include_top=include_top, + pooling=pooling, + dropout_rate=dropout_rate, + classes=classes, + classifier_activation=classifier_activation, + weights=weights, name=name, **kwargs, ) @@ -347,16 +324,16 @@ def __init__( mlp_ratio, use_qkv_bias, use_qk_norm, - input_tensor, - input_shape, - include_preprocessing, - include_top, - pooling, pos_dropout_rate, - dropout_rate, - classes, - classifier_activation, - weights, + input_tensor=input_tensor, + input_shape=input_shape, + include_preprocessing=include_preprocessing, + include_top=include_top, + pooling=pooling, + dropout_rate=dropout_rate, + classes=classes, + classifier_activation=classifier_activation, + weights=weights, name=name, **kwargs, ) @@ -396,16 
+373,16 @@ def __init__( mlp_ratio, use_qkv_bias, use_qk_norm, - input_tensor, - input_shape, - include_preprocessing, - include_top, - pooling, pos_dropout_rate, - dropout_rate, - classes, - classifier_activation, - weights, + input_tensor=input_tensor, + input_shape=input_shape, + include_preprocessing=include_preprocessing, + include_top=include_top, + pooling=pooling, + dropout_rate=dropout_rate, + classes=classes, + classifier_activation=classifier_activation, + weights=weights, name=name, **kwargs, ) @@ -445,16 +422,16 @@ def __init__( mlp_ratio, use_qkv_bias, use_qk_norm, - input_tensor, - input_shape, - include_preprocessing, - include_top, - pooling, pos_dropout_rate, - dropout_rate, - classes, - classifier_activation, - weights, + input_tensor=input_tensor, + input_shape=input_shape, + include_preprocessing=include_preprocessing, + include_top=include_top, + pooling=pooling, + dropout_rate=dropout_rate, + classes=classes, + classifier_activation=classifier_activation, + weights=weights, name=name, **kwargs, ) @@ -494,16 +471,16 @@ def __init__( mlp_ratio, use_qkv_bias, use_qk_norm, - input_tensor, - input_shape, - include_preprocessing, - include_top, - pooling, pos_dropout_rate, - dropout_rate, - classes, - classifier_activation, - weights, + input_tensor=input_tensor, + input_shape=input_shape, + include_preprocessing=include_preprocessing, + include_top=include_top, + pooling=pooling, + dropout_rate=dropout_rate, + classes=classes, + classifier_activation=classifier_activation, + weights=weights, name=name, **kwargs, ) @@ -543,16 +520,16 @@ def __init__( mlp_ratio, use_qkv_bias, use_qk_norm, - input_tensor, - input_shape, - include_preprocessing, - include_top, - pooling, pos_dropout_rate, - dropout_rate, - classes, - classifier_activation, - weights, + input_tensor=input_tensor, + input_shape=input_shape, + include_preprocessing=include_preprocessing, + include_top=include_top, + pooling=pooling, + dropout_rate=dropout_rate, + classes=classes, + 
classifier_activation=classifier_activation, + weights=weights, name=name, **kwargs, ) diff --git a/kimm/models/vision_transformer_test.py b/kimm/models/vision_transformer_test.py index 7d8ba97..74bdab7 100644 --- a/kimm/models/vision_transformer_test.py +++ b/kimm/models/vision_transformer_test.py @@ -1,3 +1,4 @@ +import pytest from absl.testing import parameterized from keras import models from keras import random @@ -50,6 +51,7 @@ def test_vision_transformer_feature_extractor( elif patch_size == 32: self.assertEqual(list(y["BLOCK5"].shape), [1, 145, 192]) + @pytest.mark.serialization @parameterized.named_parameters( [ (VisionTransformerTiny16.__name__, VisionTransformerTiny16, 384), diff --git a/kimm/utils/model_registry.py b/kimm/utils/model_registry.py index 36ba7d5..15a2517 100644 --- a/kimm/utils/model_registry.py +++ b/kimm/utils/model_registry.py @@ -32,11 +32,11 @@ def clear_registry(): def add_model_to_registry(model_cls, has_pretrained=False): - from kimm.models.feature_extractor import FeatureExtractor + from kimm.models.base_model import BaseModel support_feature = False available_feature_keys = [] - if issubclass(model_cls, FeatureExtractor): + if issubclass(model_cls, BaseModel): support_feature = True available_feature_keys = model_cls.available_feature_keys() for info in MODEL_REGISTRY: diff --git a/kimm/utils/model_registry_test.py b/kimm/utils/model_registry_test.py index 73c5961..f979811 100644 --- a/kimm/utils/model_registry_test.py +++ b/kimm/utils/model_registry_test.py @@ -1,7 +1,7 @@ from keras import models from keras.src import testing -from kimm.models.feature_extractor import FeatureExtractor +from kimm.models.base_model import BaseModel from kimm.utils.model_registry import MODEL_REGISTRY from kimm.utils.model_registry import add_model_to_registry from kimm.utils.model_registry import clear_registry @@ -12,7 +12,7 @@ class DummyModel(models.Model): pass -class DummyFeatureExtractor(FeatureExtractor): +class 
DummyFeatureExtractor(BaseModel): @staticmethod def available_feature_keys(): return ["A", "B", "C"] diff --git a/shell/export.sh b/shell/export.sh new file mode 100755 index 0000000..27c209a --- /dev/null +++ b/shell/export.sh @@ -0,0 +1,11 @@ +#!/bin/bash +export CUDA_VISIBLE_DEVICES= +export TF_CPP_MIN_LOG_LEVEL=3 +python3 -m tools.convert_densenet_from_timm && +python3 -m tools.convert_efficientnet_from_timm && +python3 -m tools.convert_ghostnet_from_timm && +python3 -m tools.convert_inception_v3_from_timm && +python3 -m tools.convert_mobilenet_v2_from_timm && +python3 -m tools.convert_mobilenet_v3_from_timm && +python3 -m tools.convert_mobilevit_from_timm && +echo "Export finished successfully!" diff --git a/tools/convert_densenet_from_timm.py b/tools/convert_densenet_from_timm.py new file mode 100644 index 0000000..5d882a7 --- /dev/null +++ b/tools/convert_densenet_from_timm.py @@ -0,0 +1,123 @@ +""" +pip install torch torchvision --index-url https://download.pytorch.org/whl/cpu +pip install timm +""" +import os + +import keras +import numpy as np +import timm +import torch + +from kimm.models import densenet +from kimm.utils.timm_utils import assign_weights +from kimm.utils.timm_utils import is_same_weights +from kimm.utils.timm_utils import separate_keras_weights +from kimm.utils.timm_utils import separate_torch_state_dict + +timm_model_names = [ + "densenet121.ra_in1k", + "densenet161.tv_in1k", + "densenet169.tv_in1k", + "densenet201.tv_in1k", +] +keras_model_classes = [ + densenet.DenseNet121, + densenet.DenseNet161, + densenet.DenseNet169, + densenet.DenseNet201, +] + +for timm_model_name, keras_model_class in zip( + timm_model_names, keras_model_classes +): + """ + Prepare timm model and keras model + """ + input_shape = [224, 224, 3] + torch_model = timm.create_model(timm_model_name, pretrained=True) + torch_model = torch_model.eval() + trainable_state_dict, non_trainable_state_dict = separate_torch_state_dict( + torch_model.state_dict() + ) + 
keras_model = keras_model_class( + input_shape=input_shape, + include_preprocessing=False, + classifier_activation="linear", + ) + trainable_weights, non_trainable_weights = separate_keras_weights( + keras_model + ) + + # for torch_name, (_, keras_name) in zip( + # trainable_state_dict.keys(), trainable_weights + # ): + # print(f"{torch_name} {keras_name}") + + # print(len(trainable_state_dict.keys())) + # print(len(trainable_weights)) + + # exit() + + """ + Assign weights + """ + for keras_weight, keras_name in trainable_weights + non_trainable_weights: + keras_name: str + torch_name = keras_name + torch_name = torch_name.replace("_", ".") + # stem + torch_name = torch_name.replace("conv0.conv2d", "conv0") + torch_name = torch_name.replace("conv0.bn", "norm0") + # blocks + torch_name = torch_name.replace("conv1.conv2d", "conv1") + torch_name = torch_name.replace("conv1.bn", "norm2") + + # weights naming mapping + torch_name = torch_name.replace("kernel", "weight") # conv2d + torch_name = torch_name.replace("gamma", "weight") # bn + torch_name = torch_name.replace("beta", "bias") # bn + torch_name = torch_name.replace("moving.mean", "running_mean") # bn + torch_name = torch_name.replace("moving.variance", "running_var") # bn + + # assign weights + if torch_name in trainable_state_dict: + torch_weights = trainable_state_dict[torch_name].numpy() + elif torch_name in non_trainable_state_dict: + torch_weights = non_trainable_state_dict[torch_name].numpy() + else: + raise ValueError( + "Can't find the corresponding torch weights. " + f"Got keras_name={keras_name}, torch_name={torch_name}" + ) + if is_same_weights(keras_name, keras_weight, torch_name, torch_weights): + assign_weights(keras_name, keras_weight, torch_weights) + else: + raise ValueError( + "Can't find the corresponding torch weights. The shape is " + f"mismatched. 
Got keras_name={keras_name}, " + f"keras_weight shape={keras_weight.shape}, " + f"torch_name={torch_name}, " + f"torch_weights shape={torch_weights.shape}" + ) + + """ + Verify model outputs + """ + np.random.seed(2023) + keras_data = np.random.uniform(size=[1] + input_shape).astype("float32") + torch_data = torch.from_numpy(np.transpose(keras_data, [0, 3, 1, 2])) + torch_y = torch_model(torch_data) + keras_y = keras_model(keras_data, training=False) + torch_y = torch_y.detach().cpu().numpy() + keras_y = keras.ops.convert_to_numpy(keras_y) + np.testing.assert_allclose(torch_y, keras_y, atol=1e-5) + print(f"{keras_model_class.__name__}: output matched!") + + """ + Save converted model + """ + os.makedirs("exported", exist_ok=True) + export_path = f"exported/{keras_model.name.lower()}_{timm_model_name}.keras" + keras_model.save(export_path) + print(f"Export to {export_path}") diff --git a/tools/convert_efficientnet_from_timm.py b/tools/convert_efficientnet_from_timm.py index ce014ac..3af0f24 100644 --- a/tools/convert_efficientnet_from_timm.py +++ b/tools/convert_efficientnet_from_timm.py @@ -2,6 +2,8 @@ pip install torch torchvision --index-url https://download.pytorch.org/whl/cpu pip install timm """ +import os + import keras import numpy as np import timm @@ -188,6 +190,7 @@ """ Save converted model """ - export_path = f"exported/{keras_model.name.lower()}_imagenet.keras" + os.makedirs("exported", exist_ok=True) + export_path = f"exported/{keras_model.name.lower()}_{timm_model_name}.keras" keras_model.save(export_path) print(f"Export to {export_path}") diff --git a/tools/convert_ghostnet_from_timm.py b/tools/convert_ghostnet_from_timm.py index 05aaddf..7ffd688 100644 --- a/tools/convert_ghostnet_from_timm.py +++ b/tools/convert_ghostnet_from_timm.py @@ -2,6 +2,8 @@ pip install torch torchvision --index-url https://download.pytorch.org/whl/cpu pip install timm """ +import os + import keras import numpy as np import timm @@ -146,6 +148,7 @@ """ Save converted model 
""" - export_path = f"exported/{keras_model.name.lower()}_imagenet.keras" + os.makedirs("exported", exist_ok=True) + export_path = f"exported/{keras_model.name.lower()}_{timm_model_name}.keras" keras_model.save(export_path) print(f"Export to {export_path}") diff --git a/tools/convert_inception_v3_from_timm.py b/tools/convert_inception_v3_from_timm.py new file mode 100644 index 0000000..f73d277 --- /dev/null +++ b/tools/convert_inception_v3_from_timm.py @@ -0,0 +1,150 @@ +""" +pip install torch torchvision --index-url https://download.pytorch.org/whl/cpu +pip install timm +""" +import os + +import keras +import numpy as np +import timm +import torch + +from kimm.models import inception_v3 +from kimm.utils.timm_utils import assign_weights +from kimm.utils.timm_utils import is_same_weights +from kimm.utils.timm_utils import separate_keras_weights +from kimm.utils.timm_utils import separate_torch_state_dict + +timm_model_names = [ + "inception_v3.gluon_in1k", +] +keras_model_classes = [ + inception_v3.InceptionV3, +] + +for timm_model_name, keras_model_class in zip( + timm_model_names, keras_model_classes +): + """ + Prepare timm model and keras model + """ + input_shape = [299, 299, 3] + torch_model = timm.create_model( + timm_model_name, pretrained=True, aux_logits=True + ) + torch_model = torch_model.eval() + trainable_state_dict, non_trainable_state_dict = separate_torch_state_dict( + torch_model.state_dict() + ) + keras_model = keras_model_class( + has_aux_logits=True, + input_shape=input_shape, + include_preprocessing=False, + classifier_activation="linear", + ) + trainable_weights, non_trainable_weights = separate_keras_weights( + keras_model + ) + + # for torch_name, (_, keras_name) in zip( + # trainable_state_dict.keys(), trainable_weights + # ): + # print(f"{torch_name} {keras_name}") + + # print(len(trainable_state_dict.keys())) + # print(len(trainable_weights)) + + # exit() + + """ + Preprocess + """ + new_dict = {} + old_keys = trainable_state_dict.keys() 
+ new_keys = [] + for k in old_keys: + new_key = k.replace("_", ".") + new_key = new_key.replace("running.mean", "running_mean") + new_key = new_key.replace("running.var", "running_var") + new_keys.append(new_key) + for k1, k2 in zip(trainable_state_dict.keys(), new_keys): + new_dict[k2] = trainable_state_dict[k1] + trainable_state_dict = new_dict + + new_dict = {} + old_keys = non_trainable_state_dict.keys() + new_keys = [] + for k in old_keys: + new_key = k.replace("_", ".") + new_key = new_key.replace("running.mean", "running_mean") + new_key = new_key.replace("running.var", "running_var") + new_keys.append(new_key) + for k1, k2 in zip(non_trainable_state_dict.keys(), new_keys): + new_dict[k2] = non_trainable_state_dict[k1] + non_trainable_state_dict = new_dict + + """ + Assign weights + """ + for keras_weight, keras_name in trainable_weights + non_trainable_weights: + keras_name: str + torch_name = keras_name + torch_name = torch_name.replace("_", ".") + # general + torch_name = torch_name.replace("conv2d", "conv") + # head + torch_name = torch_name.replace("classifier", "fc") + + # weights naming mapping + torch_name = torch_name.replace("kernel", "weight") # conv2d + torch_name = torch_name.replace("gamma", "weight") # bn + torch_name = torch_name.replace("beta", "bias") # bn + torch_name = torch_name.replace("moving.mean", "running_mean") # bn + torch_name = torch_name.replace("moving.variance", "running_var") # bn + + # assign weights + if torch_name in trainable_state_dict: + torch_weights = trainable_state_dict[torch_name].numpy() + elif torch_name in non_trainable_state_dict: + torch_weights = non_trainable_state_dict[torch_name].numpy() + else: + raise ValueError( + "Can't find the corresponding torch weights. 
" + f"Got keras_name={keras_name}, torch_name={torch_name}" + ) + if is_same_weights(keras_name, keras_weight, torch_name, torch_weights): + assign_weights(keras_name, keras_weight, torch_weights) + else: + raise ValueError( + "Can't find the corresponding torch weights. The shape is " + f"mismatched. Got keras_name={keras_name}, " + f"keras_weight shape={keras_weight.shape}, " + f"torch_name={torch_name}, " + f"torch_weights shape={torch_weights.shape}" + ) + + """ + Verify model outputs + """ + np.random.seed(2023) + keras_data = np.random.uniform(size=[1] + input_shape).astype("float32") + torch_data = torch.from_numpy(np.transpose(keras_data, [0, 3, 1, 2])) + torch_y = torch_model(torch_data) + torch_y, torch_y_aux = torch_y[0], torch_y[1] + keras_y = keras_model(keras_data, training=False) + keras_y, keras_y_aux = keras_y[0], keras_y[1] + torch_y = torch_y.detach().cpu().numpy() + torch_y_aux = torch_y_aux.detach().cpu().numpy() + keras_y = keras.ops.convert_to_numpy(keras_y) + keras_y_aux = keras.ops.convert_to_numpy(keras_y_aux) + np.testing.assert_allclose(torch_y, keras_y, atol=1e-5) + np.testing.assert_allclose(torch_y_aux, keras_y_aux, atol=1e-5) + print(f"{keras_model_class.__name__}: output matched!") + + """ + Save converted model + """ + os.makedirs("exported", exist_ok=True) + export_path = f"exported/{keras_model.name.lower()}_{timm_model_name}.keras" + keras_model.save(export_path) + print(f"Export to {export_path}") diff --git a/tools/convert_mobilenet_v2_from_timm.py b/tools/convert_mobilenet_v2_from_timm.py index d0f702f..ec55782 100644 --- a/tools/convert_mobilenet_v2_from_timm.py +++ b/tools/convert_mobilenet_v2_from_timm.py @@ -2,6 +2,8 @@ pip install torch torchvision --index-url https://download.pytorch.org/whl/cpu pip install timm """ +import os + import keras import numpy as np import timm @@ -132,6 +134,7 @@ """ Save converted model """ - export_path = f"exported/{keras_model.name.lower()}_imagenet.keras" + os.makedirs("exported", 
exist_ok=True) + export_path = f"exported/{keras_model.name.lower()}_{timm_model_name}.keras" keras_model.save(export_path) print(f"Export to {export_path}") diff --git a/tools/convert_mobilenet_v3_from_timm.py b/tools/convert_mobilenet_v3_from_timm.py index 4420913..0ba712c 100644 --- a/tools/convert_mobilenet_v3_from_timm.py +++ b/tools/convert_mobilenet_v3_from_timm.py @@ -2,6 +2,8 @@ pip install torch torchvision --index-url https://download.pytorch.org/whl/cpu pip install timm """ +import os + import keras import numpy as np import timm @@ -160,6 +162,7 @@ """ Save converted model """ - export_path = f"exported/{keras_model.name.lower()}_imagenet.keras" + os.makedirs("exported", exist_ok=True) + export_path = f"exported/{keras_model.name.lower()}_{timm_model_name}.keras" keras_model.save(export_path) print(f"Export to {export_path}") diff --git a/tools/convert_mobilevit_from_timm.py b/tools/convert_mobilevit_from_timm.py index 445d2a7..5000e74 100644 --- a/tools/convert_mobilevit_from_timm.py +++ b/tools/convert_mobilevit_from_timm.py @@ -2,6 +2,8 @@ pip install torch torchvision --index-url https://download.pytorch.org/whl/cpu pip install timm """ +import os + import keras import numpy as np import timm @@ -86,6 +88,8 @@ # final block torch_name = torch_name.replace("final.conv.conv2d", "final_conv.conv") torch_name = torch_name.replace("final.conv.bn", "final_conv.bn") + # head + torch_name = torch_name.replace("classifier", "head.fc") # weights naming mapping torch_name = torch_name.replace("kernel", "weight") # conv2d @@ -136,6 +140,7 @@ """ Save converted model """ - export_path = f"exported/{keras_model.name.lower()}_imagenet.keras" + os.makedirs("exported", exist_ok=True) + export_path = f"exported/{keras_model.name.lower()}_{timm_model_name}.keras" keras_model.save(export_path) print(f"Export to {export_path}") diff --git a/tools/convert_resnet_from_timm.py b/tools/convert_resnet_from_timm.py index eadd28c..4dde8bf 100644 --- 
a/tools/convert_resnet_from_timm.py +++ b/tools/convert_resnet_from_timm.py @@ -2,6 +2,8 @@ pip install torch torchvision --index-url https://download.pytorch.org/whl/cpu pip install timm """ +import os + import keras import numpy as np import timm @@ -78,6 +80,8 @@ torch_name = torch_name.replace("conv3.bn", "bn3") torch_name = torch_name.replace("downsample.conv2d", "downsample.0") torch_name = torch_name.replace("downsample.bn", "downsample.1") + # head + torch_name = torch_name.replace("classifier", "fc") # weights naming mapping torch_name = torch_name.replace("kernel", "weight") # conv2d @@ -123,6 +127,7 @@ """ Save converted model """ - export_path = f"exported/{keras_model.name.lower()}_imagenet.keras" + os.makedirs("exported", exist_ok=True) + export_path = f"exported/{keras_model.name.lower()}_{timm_model_name}.keras" keras_model.save(export_path) print(f"Export to {export_path}") diff --git a/tools/convert_vit_from_timm.py b/tools/convert_vit_from_timm.py index eeb0289..f0fd93d 100644 --- a/tools/convert_vit_from_timm.py +++ b/tools/convert_vit_from_timm.py @@ -2,6 +2,8 @@ pip install torch torchvision --index-url https://download.pytorch.org/whl/cpu pip install timm """ +import os + import keras import numpy as np import timm @@ -133,6 +135,7 @@ """ Save converted model """ - export_path = f"exported/{keras_model.name.lower()}_imagenet_384.keras" + os.makedirs("exported", exist_ok=True) + export_path = f"exported/{keras_model.name.lower()}_{timm_model_name}.keras" keras_model.save(export_path) print(f"Export to {export_path}")