diff --git a/README.md b/README.md index 5bb32d7..9b0bbd2 100644 --- a/README.md +++ b/README.md @@ -7,4 +7,10 @@ CUDA_VISIBLE_DEVICES= KERAS_BACKEND=tensorflow pytest ``` +## Work in Progress + +- Test pretrained weights +- Add `MobileNet100V3SmallMinimal` pretrained weights +- Add `MobileNet100V3LargeMinimal` pretrained weights + ## Acknowledgments diff --git a/kimm/blocks/base_block.py b/kimm/blocks/base_block.py index 13ae79d..c38c897 100644 --- a/kimm/blocks/base_block.py +++ b/kimm/blocks/base_block.py @@ -18,8 +18,8 @@ def apply_activation(x, activation=None, name="activation"): def apply_conv2d_block( inputs, - filters, - kernel_size, + filters=None, + kernel_size=None, strides=1, groups=1, activation=None, @@ -28,6 +28,10 @@ def apply_conv2d_block( bn_epsilon=1e-5, name="conv2d_block", ): + if kernel_size is None: + raise ValueError( + f"kernel_size must be passed. Received: kernel_size={kernel_size}" + ) x = inputs padding = "same" @@ -80,12 +84,12 @@ def apply_se_block( x = inputs x = layers.GlobalAveragePooling2D(keepdims=True, name=f"{name}_mean")(x) x = layers.Conv2D( - se_channels, 1, use_bias=True, name=f"{name}_reduce_conv2d" + se_channels, 1, use_bias=True, name=f"{name}_conv_reduce" )(x) - x = apply_activation(x, activation, name=f"{name}_act") + x = apply_activation(x, activation, name=f"{name}_act1") x = layers.Conv2D( - input_channels, 1, use_bias=True, name=f"{name}_expand_conv2d" + input_channels, 1, use_bias=True, name=f"{name}_conv_expand" )(x) - x = apply_activation(x, gate_activation, name=f"{name}_gate_act") + x = apply_activation(x, gate_activation, name=f"{name}_gate") out = layers.Multiply(name=name)([ori_x, x]) return out diff --git a/kimm/models/ghostnet.py b/kimm/models/ghostnet.py index e50f50f..6d93757 100644 --- a/kimm/models/ghostnet.py +++ b/kimm/models/ghostnet.py @@ -290,42 +290,40 @@ def __init__( x = apply_conv2d_block( x, stem_channels, 3, 2, activation="relu", name="conv_stem" ) - features["S2"] = x + features["STEM_S2"] = x # blocks total_layer_idx = 0 - block_idx = 0 - net_stride = 2 - for cfg in config: - layer_idx = 0 - strides = 1 - for kernel_size, expand_size, channels, se_ratio, strides in cfg: - output_channels = make_divisible(channels * width, 4) - hidden_channels = make_divisible(expand_size * width, 4) + current_stride = 2 + for current_block_idx, cfg in enumerate(config): + for current_layer_idx, (k, e, c, se, s) in enumerate(cfg): + output_channels = make_divisible(c * width, 4) + hidden_channels = make_divisible(e * width, 4) use_attention = False if version == "v2" and total_layer_idx > 1: use_attention = True + name = f"blocks_{current_block_idx}_{current_layer_idx}" x = apply_ghost_bottleneck( x, hidden_channels, output_channels, - kernel_size, - strides, - se_ratio=se_ratio, + k, + s, + se_ratio=se, use_attention=use_attention, - name=f"blocks{block_idx}_{layer_idx}", + name=name, ) - layer_idx += 1 total_layer_idx += 1 - if strides > 1: - net_stride *= strides - # add feature - features[f"S{net_stride}"] = x - block_idx += 1 + current_stride *= s + features[f"BLOCK{current_block_idx}_S{current_stride}"] = x # post stages conv block - output_channels = make_divisible(expand_size * width, 4) + output_channels = make_divisible(e * width, 4) x = apply_conv2d_block( - x, output_channels, 1, activation="relu", name=f"blocks{block_idx}" + x, + output_channels, + 1, + activation="relu", + name=f"blocks_{current_block_idx+1}", ) if include_top: @@ -366,8 +364,14 @@ def __init__( @staticmethod def available_feature_keys(): - # 
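The squeeze-and-excitation layer names above were switched to `conv_reduce`/`conv_expand` (with `act1`/`gate` activations) so they line up with timm's parameter names during weight conversion. Below is a minimal, self-contained sketch of the pattern `apply_se_block` implements; the `se_channels` derivation lives outside this hunk, so the value used here is illustrative only.

```python
# Minimal sketch of the squeeze-and-excitation pattern used by
# apply_se_block above. Layer names mirror the renamed ones
# (conv_reduce / act1 / conv_expand / gate); se_channels=16 is
# illustrative, since its derivation is outside this hunk.
import keras
from keras import layers


def se_block(x, se_channels, name="se"):
    ori_x = x
    x = layers.GlobalAveragePooling2D(keepdims=True, name=f"{name}_mean")(x)
    x = layers.Conv2D(se_channels, 1, use_bias=True, name=f"{name}_conv_reduce")(x)
    x = layers.Activation("relu", name=f"{name}_act1")(x)
    x = layers.Conv2D(ori_x.shape[-1], 1, use_bias=True, name=f"{name}_conv_expand")(x)
    x = layers.Activation("hard_sigmoid", name=f"{name}_gate")(x)
    return layers.Multiply(name=name)([ori_x, x])


inputs = layers.Input(shape=(28, 28, 40))
outputs = se_block(inputs, se_channels=16)  # output shape: (None, 28, 28, 40)
model = keras.Model(inputs, outputs)
```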
predefined for better UX - return [f"S{2**i}" for i in range(1, 6)] + feature_keys = ["STEM_S2"] + feature_keys.extend( + [ + f"BLOCK{i}_S{j}" + for i, j in zip(range(9), [2, 4, 4, 8, 8, 16, 16, 32, 32]) + ] + ) + return feature_keys def get_config(self): config = super().get_config() diff --git a/kimm/models/ghostnet_test.py b/kimm/models/ghostnet_test.py index 5771c57..4c687da 100644 --- a/kimm/models/ghostnet_test.py +++ b/kimm/models/ghostnet_test.py @@ -25,11 +25,14 @@ def test_ghostnet_feature_extractor(self, model_class): y = model.predict(x) self.assertIsInstance(y, dict) - self.assertEqual(list(y["S2"].shape), [1, 112, 112, 16]) - self.assertEqual(list(y["S4"].shape), [1, 56, 56, 24]) - self.assertEqual(list(y["S8"].shape), [1, 28, 28, 40]) - self.assertEqual(list(y["S16"].shape), [1, 14, 14, 80]) - self.assertEqual(list(y["S32"].shape), [1, 7, 7, 160]) + self.assertAllEqual( + list(y.keys()), model_class.available_feature_keys() + ) + self.assertEqual(list(y["STEM_S2"].shape), [1, 112, 112, 16]) + self.assertEqual(list(y["BLOCK1_S4"].shape), [1, 56, 56, 24]) + self.assertEqual(list(y["BLOCK3_S8"].shape), [1, 28, 28, 40]) + self.assertEqual(list(y["BLOCK5_S16"].shape), [1, 14, 14, 80]) + self.assertEqual(list(y["BLOCK7_S32"].shape), [1, 7, 7, 160]) @parameterized.named_parameters([(GhostNet100V2.__name__, GhostNet100V2)]) def test_ghostnetv2_base(self, model_class): @@ -49,8 +52,11 @@ def test_ghostnetv2_feature_extractor(self, model_class): y = model.predict(x) self.assertIsInstance(y, dict) - self.assertEqual(list(y["S2"].shape), [1, 112, 112, 16]) - self.assertEqual(list(y["S4"].shape), [1, 56, 56, 24]) - self.assertEqual(list(y["S8"].shape), [1, 28, 28, 40]) - self.assertEqual(list(y["S16"].shape), [1, 14, 14, 80]) - self.assertEqual(list(y["S32"].shape), [1, 7, 7, 160]) + self.assertAllEqual( + list(y.keys()), model_class.available_feature_keys() + ) + self.assertEqual(list(y["STEM_S2"].shape), [1, 112, 112, 16]) + self.assertEqual(list(y["BLOCK1_S4"].shape), [1, 56, 56, 24]) + self.assertEqual(list(y["BLOCK3_S8"].shape), [1, 28, 28, 40]) + self.assertEqual(list(y["BLOCK5_S16"].shape), [1, 14, 14, 80]) + self.assertEqual(list(y["BLOCK7_S32"].shape), [1, 7, 7, 160]) diff --git a/kimm/models/mobilenet_v2.py b/kimm/models/mobilenet_v2.py new file mode 100644 index 0000000..e0a6db8 --- /dev/null +++ b/kimm/models/mobilenet_v2.py @@ -0,0 +1,446 @@ +import math +import typing + +import keras +from keras import backend +from keras import layers +from keras import utils +from keras.src.applications import imagenet_utils + +from kimm.blocks import apply_conv2d_block +from kimm.models.feature_extractor import FeatureExtractor +from kimm.utils import make_divisible + +DEFAULT_CONFIG = [ + # type, repeat, kernel_size, strides, expansion_ratio, channels + ["ds", 1, 3, 1, 1, 16], + ["ir", 2, 3, 2, 6, 24], + ["ir", 3, 3, 2, 6, 32], + ["ir", 4, 3, 2, 6, 64], + ["ir", 3, 3, 1, 6, 96], + ["ir", 3, 3, 2, 6, 160], + ["ir", 1, 3, 1, 6, 320], +] + + +def apply_depthwise_separation_block( + inputs, + output_channels, + depthwise_kernel_size=3, + pointwise_kernel_size=1, + strides=1, + activation="relu6", + name="depthwise_separation_block", +): + input_channels = inputs.shape[-1] + has_skip = strides == 1 and input_channels == output_channels + + x = inputs + x = apply_conv2d_block( + x, + kernel_size=depthwise_kernel_size, + strides=strides, + activation=activation, + use_depthwise=True, + name=f"{name}_conv_dw", + ) + x = apply_conv2d_block( + x, + output_channels, + pointwise_kernel_size, + 1, + 
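GhostNet rounds widened channel counts with `make_divisible(..., 4)`, and the MobileNet files below lean on the same helper with its default divisor. The diff does not show its body; a sketch following the usual timm/MobileNet convention (round to the nearest multiple, never shrinking by more than roughly 10%) would look like:

```python
# Hedged sketch of make_divisible (kimm.utils); the actual body is not in
# this diff, so this follows the common timm convention.
def make_divisible(v, divisor=8, min_value=None, round_limit=0.9):
    min_value = min_value or divisor
    new_v = max(min_value, int(v + divisor / 2) // divisor * divisor)
    # Bump up one step if rounding down removed more than ~10%.
    if new_v < round_limit * v:
        new_v += divisor
    return new_v


print(make_divisible(16 * 0.5))     # 8  -> MobileNet050V2 stem
print(make_divisible(24 * 1.4))     # 32 (33.6 rounds down within the limit)
print(make_divisible(80 * 1.0, 4))  # 80 -> GhostNet-style divisor of 4
```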
activation=None, + name=f"{name}_conv_pw", + ) + if has_skip: + x = layers.Add()([x, inputs]) + return x + + +def apply_inverted_residual_block( + inputs, + output_channels, + depthwise_kernel_size=3, + expansion_kernel_size=1, + pointwise_kernel_size=1, + strides=1, + expansion_ratio=1.0, + activation="relu6", + name="inverted_residual_block", +): + input_channels = inputs.shape[-1] + hidden_channels = make_divisible(input_channels * expansion_ratio) + has_skip = strides == 1 and input_channels == output_channels + + x = inputs + + # Point-wise expansion + x = apply_conv2d_block( + x, + hidden_channels, + expansion_kernel_size, + 1, + activation=activation, + name=f"{name}_conv_pw", + ) + # Depth-wise convolution + x = apply_conv2d_block( + x, + kernel_size=depthwise_kernel_size, + strides=strides, + activation=activation, + use_depthwise=True, + name=f"{name}_conv_dw", + ) + # Point-wise linear projection + x = apply_conv2d_block( + x, + output_channels, + pointwise_kernel_size, + 1, + activation=None, + name=f"{name}_conv_pwl", + ) + if has_skip: + x = layers.Add()([x, inputs]) + return x + + +class MobileNetV2(FeatureExtractor): + def __init__( + self, + width: float = 1.0, + depth: float = 1.0, + fix_stem_and_head_channels: bool = False, + input_tensor: keras.KerasTensor = None, + input_shape: typing.Optional[typing.Sequence[int]] = None, + include_preprocessing: bool = True, + include_top: bool = True, + pooling: typing.Optional[str] = None, + dropout_rate: float = 0.0, + classes: int = 1000, + classifier_activation: str = "softmax", + weights: typing.Optional[str] = None, # TODO: imagenet + config: typing.Union[str, typing.List] = "default", + **kwargs, + ): + if config == "default": + config = DEFAULT_CONFIG + + # Prepare feature extraction + features = {} + + # Determine proper input shape + input_shape = imagenet_utils.obtain_input_shape( + input_shape, + default_size=224, + min_size=32, + data_format=backend.image_data_format(), + require_flatten=include_top, + weights=weights, + ) + + if input_tensor is None: + img_input = layers.Input(shape=input_shape) + else: + if not backend.is_keras_tensor(input_tensor): + img_input = layers.Input(tensor=input_tensor, shape=input_shape) + else: + img_input = input_tensor + + x = img_input + + # [0, 255] to [0, 1] and apply ImageNet mean and variance + if include_preprocessing: + x = layers.Rescaling(scale=1.0 / 255.0)(x) + x = layers.Normalization( + mean=[0.485, 0.456, 0.406], variance=[0.229, 0.224, 0.225] + )(x) + + # stem + stem_channel = ( + 32 if fix_stem_and_head_channels else make_divisible(32 * width) + ) + x = apply_conv2d_block( + x, + stem_channel, + 3, + 2, + activation="relu6", + name="conv_stem", + ) + features["STEM_S2"] = x + + # blocks + current_stride = 2 + for current_block_idx, cfg in enumerate(config): + block_type, r, k, s, e, c = cfg + c = make_divisible(c * width) + # no depth multiplier at first and last block + if current_block_idx not in (0, len(config) - 1): + r = int(math.ceil(r * depth)) + for current_layer_idx in range(r): + s = s if current_layer_idx == 0 else 1 + name = f"blocks_{current_block_idx}_{current_layer_idx}" + if block_type == "ds": + x = apply_depthwise_separation_block( + x, c, k, 1, s, name=name + ) + elif block_type == "ir": + x = apply_inverted_residual_block( + x, c, k, 1, 1, s, e, name=name + ) + current_stride *= s + features[f"BLOCK{current_block_idx}_S{current_stride}"] = x + + # last conv + if fix_stem_and_head_channels: + head_channels = 1280 + else: + head_channels = max(1280, 
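A hedged usage sketch of the two block helpers defined above, wired by hand outside the model class. Importing them from `kimm.models.mobilenet_v2` is an assumption about what the module exports; the shapes follow the `["ir", 2, 3, 2, 6, 24]` row of `DEFAULT_CONFIG`.

```python
import keras
from keras import layers

# Assumed imports; these helpers are module-level in mobilenet_v2.py above.
from kimm.models.mobilenet_v2 import (
    apply_depthwise_separation_block,
    apply_inverted_residual_block,
)

inputs = layers.Input(shape=(112, 112, 32))
# "ds" row: no expansion, 32 -> 16 channels, stride 1 (the residual skip
# needs matching channels and stride 1, so there is none here).
x = apply_depthwise_separation_block(inputs, 16, 3, 1, 1, name="demo_ds")
# "ir" row: expansion_ratio=6 widens 16 -> make_divisible(16 * 6) = 96
# hidden channels, then projects down to 24 with stride 2.
x = apply_inverted_residual_block(x, 24, 3, 1, 1, 2, 6.0, name="demo_ir")
model = keras.Model(inputs, x)
print(model.output_shape)  # (None, 56, 56, 24)
```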
make_divisible(1280 * width)) + x = apply_conv2d_block( + x, head_channels, 1, 1, activation="relu6", name="conv_head" + ) + + if include_top: + x = layers.GlobalAveragePooling2D(name="avg_pool")(x) + x = layers.Dropout(rate=dropout_rate, name="conv_head_dropout")(x) + x = layers.Dense( + classes, activation=classifier_activation, name="classifier" + )(x) + else: + if pooling == "avg": + x = layers.GlobalAveragePooling2D(name="avg_pool")(x) + elif pooling == "max": + x = layers.GlobalMaxPooling2D(name="max_pool")(x) + + # Ensure that the model takes into account + # any potential predecessors of `input_tensor`. + if input_tensor is not None: + inputs = utils.get_source_inputs(input_tensor) + else: + inputs = img_input + + super().__init__(inputs=inputs, outputs=x, features=features, **kwargs) + + # All references to `self` below this line + self.width = width + self.depth = depth + self.fix_stem_and_head_channels = fix_stem_and_head_channels + self.include_preprocessing = include_preprocessing + self.include_top = include_top + self.pooling = pooling + self.dropout_rate = dropout_rate + self.classes = classes + self.classifier_activation = classifier_activation + self._weights = weights # `self.weights` is been used internally + self.config = config + + @staticmethod + def available_feature_keys(): + feature_keys = ["STEM_S2"] + feature_keys.extend( + [ + f"BLOCK{i}_S{j}" + for i, j in zip(range(7), [2, 4, 8, 16, 16, 32, 32]) + ] + ) + return feature_keys + + def get_config(self): + config = super().get_config() + config.update( + { + "width": self.width, + "input_shape": self.input_shape[1:], + "include_preprocessing": self.include_preprocessing, + "include_top": self.include_top, + "pooling": self.pooling, + "dropout_rate": self.dropout_rate, + "classes": self.classes, + "classifier_activation": self.classifier_activation, + "weights": self._weights, + "config": self.config, + } + ) + return config + + +""" +Model Definition +""" + + +class MobileNet050V2(MobileNetV2): + def __init__( + self, + input_tensor: keras.KerasTensor = None, + input_shape: typing.Optional[typing.Sequence[int]] = None, + include_preprocessing: bool = True, + include_top: bool = True, + pooling: typing.Optional[str] = None, + dropout_rate: float = 0.0, + classes: int = 1000, + classifier_activation: str = "softmax", + weights: typing.Optional[str] = None, # TODO: imagenet + config: typing.Union[str, typing.List] = "default", + name: str = "MobileNet050V2", + **kwargs, + ): + super().__init__( + 0.5, + 1.0, + False, + input_tensor, + input_shape, + include_preprocessing, + include_top, + pooling, + dropout_rate, + classes, + classifier_activation, + weights, + config, + name=name, + **kwargs, + ) + + +class MobileNet100V2(MobileNetV2): + def __init__( + self, + input_tensor: keras.KerasTensor = None, + input_shape: typing.Optional[typing.Sequence[int]] = None, + include_preprocessing: bool = True, + include_top: bool = True, + pooling: typing.Optional[str] = None, + dropout_rate: float = 0.0, + classes: int = 1000, + classifier_activation: str = "softmax", + weights: typing.Optional[str] = None, # TODO: imagenet + config: typing.Union[str, typing.List] = "default", + name: str = "MobileNet100V2", + **kwargs, + ): + super().__init__( + 1.0, + 1.0, + False, + input_tensor, + input_shape, + include_preprocessing, + include_top, + pooling, + dropout_rate, + classes, + classifier_activation, + weights, + config, + name=name, + **kwargs, + ) + + +class MobileNet110V2(MobileNetV2): + def __init__( + self, + input_tensor: 
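With the renamed keys, feature extraction reads as below. This is a sketch mirroring the tests further down; `as_feature_extractor` comes from the `FeatureExtractor` base class, and the printed shapes match the test assertions.

```python
import numpy as np

from kimm.models.mobilenet_v2 import MobileNet100V2

model = MobileNet100V2(as_feature_extractor=True)
print(model.available_feature_keys())
# ['STEM_S2', 'BLOCK0_S2', 'BLOCK1_S4', 'BLOCK2_S8',
#  'BLOCK3_S16', 'BLOCK4_S16', 'BLOCK5_S32', 'BLOCK6_S32']
features = model.predict(np.random.uniform(size=(1, 224, 224, 3)) * 255.0)
print(features["BLOCK5_S32"].shape)  # (1, 7, 7, 160)
```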
keras.KerasTensor = None, + input_shape: typing.Optional[typing.Sequence[int]] = None, + include_preprocessing: bool = True, + include_top: bool = True, + pooling: typing.Optional[str] = None, + dropout_rate: float = 0.0, + classes: int = 1000, + classifier_activation: str = "softmax", + weights: typing.Optional[str] = None, # TODO: imagenet + config: typing.Union[str, typing.List] = "default", + name: str = "MobileNet110V2", + **kwargs, + ): + super().__init__( + 1.1, + 1.2, + True, + input_tensor, + input_shape, + include_preprocessing, + include_top, + pooling, + dropout_rate, + classes, + classifier_activation, + weights, + config, + name=name, + **kwargs, + ) + + +class MobileNet120V2(MobileNetV2): + def __init__( + self, + input_tensor: keras.KerasTensor = None, + input_shape: typing.Optional[typing.Sequence[int]] = None, + include_preprocessing: bool = True, + include_top: bool = True, + pooling: typing.Optional[str] = None, + dropout_rate: float = 0.0, + classes: int = 1000, + classifier_activation: str = "softmax", + weights: typing.Optional[str] = None, # TODO: imagenet + config: typing.Union[str, typing.List] = "default", + name: str = "MobileNet120V2", + **kwargs, + ): + super().__init__( + 1.2, + 1.4, + True, + input_tensor, + input_shape, + include_preprocessing, + include_top, + pooling, + dropout_rate, + classes, + classifier_activation, + weights, + config, + name=name, + **kwargs, + ) + + +class MobileNet140V2(MobileNetV2): + def __init__( + self, + input_tensor: keras.KerasTensor = None, + input_shape: typing.Optional[typing.Sequence[int]] = None, + include_preprocessing: bool = True, + include_top: bool = True, + pooling: typing.Optional[str] = None, + dropout_rate: float = 0.0, + classes: int = 1000, + classifier_activation: str = "softmax", + weights: typing.Optional[str] = None, # TODO: imagenet + config: typing.Union[str, typing.List] = "default", + name: str = "MobileNet140V2", + **kwargs, + ): + super().__init__( + 1.4, + 1.0, + False, + input_tensor, + input_shape, + include_preprocessing, + include_top, + pooling, + dropout_rate, + classes, + classifier_activation, + weights, + config, + name=name, + **kwargs, + ) diff --git a/kimm/models/mobilenet_v2_test.py b/kimm/models/mobilenet_v2_test.py new file mode 100644 index 0000000..5045068 --- /dev/null +++ b/kimm/models/mobilenet_v2_test.py @@ -0,0 +1,56 @@ +from absl.testing import parameterized +from keras import random +from keras.src import testing + +from kimm.models.mobilenet_v2 import MobileNet050V2 +from kimm.models.mobilenet_v2 import MobileNet100V2 +from kimm.utils import make_divisible + + +class MobileNetV2Test(testing.TestCase, parameterized.TestCase): + @parameterized.named_parameters( + [ + (MobileNet050V2.__name__, MobileNet050V2), + (MobileNet100V2.__name__, MobileNet100V2), + ] + ) + def test_mobilenet_v2_base(self, model_class): + # TODO: test the correctness of the real image + x = random.uniform([1, 224, 224, 3]) * 255.0 + model = model_class() + + y = model.predict(x) + + self.assertEqual(y.shape, (1, 1000)) + + @parameterized.named_parameters( + [ + (MobileNet050V2.__name__, MobileNet050V2, 0.5), + (MobileNet100V2.__name__, MobileNet100V2, 1.0), + ] + ) + def test_mobilenet_v2_feature_extractor(self, model_class, width): + x = random.uniform([1, 224, 224, 3]) * 255.0 + model = model_class(as_feature_extractor=True) + + y = model.predict(x) + + self.assertIsInstance(y, dict) + self.assertAllEqual( + list(y.keys()), model_class.available_feature_keys() + ) + self.assertEqual( + 
list(y["STEM_S2"].shape), [1, 112, 112, make_divisible(32 * width)] + ) + self.assertEqual( + list(y["BLOCK1_S4"].shape), [1, 56, 56, make_divisible(24 * width)] + ) + self.assertEqual( + list(y["BLOCK2_S8"].shape), [1, 28, 28, make_divisible(32 * width)] + ) + self.assertEqual( + list(y["BLOCK3_S16"].shape), [1, 14, 14, make_divisible(64 * width)] + ) + self.assertEqual( + list(y["BLOCK5_S32"].shape), [1, 7, 7, make_divisible(160 * width)] + ) diff --git a/kimm/models/mobilenet_v3.py b/kimm/models/mobilenet_v3.py new file mode 100644 index 0000000..997d229 --- /dev/null +++ b/kimm/models/mobilenet_v3.py @@ -0,0 +1,629 @@ +import math +import typing + +import keras +from keras import backend +from keras import layers +from keras import utils +from keras.src.applications import imagenet_utils + +from kimm.blocks import apply_conv2d_block +from kimm.blocks import apply_se_block +from kimm.models.feature_extractor import FeatureExtractor +from kimm.utils import make_divisible + +DEFAULT_SMALL_CONFIG = [ + # type, repeat, kernel_size, strides, expansion_ratio, channels, se_ratio, + # activation + # stage0 + [["ds", 1, 3, 2, 1.0, 16, 0.25, "relu"]], + # stage1 + [ + ["ir", 1, 3, 2, 4.5, 24, 0.0, "relu"], + ["ir", 1, 3, 1, 3.67, 24, 0.0, "relu"], + ], + # stage2 + [ + ["ir", 1, 5, 2, 4.0, 40, 0.25, "hard_swish"], + ["ir", 2, 5, 1, 6.0, 40, 0.25, "hard_swish"], + ], + # stage3 + [["ir", 2, 5, 1, 3.0, 48, 0.25, "hard_swish"]], + # stage4 + [["ir", 3, 5, 2, 6.0, 96, 0.25, "hard_swish"]], + # stage5 + [["cn", 1, 1, 1, 1.0, 576, 0.0, "hard_swish"]], +] +DEFAULT_LARGE_CONFIG = [ + # type, repeat, kernel_size, strides, expansion_ratio, channels, se_ratio, + # activation + # stage0 + [["ds", 1, 3, 1, 1.0, 16, 0.0, "relu"]], + # stage1 + [ + ["ir", 1, 3, 2, 4.0, 24, 0.0, "relu"], + ["ir", 1, 3, 1, 3.0, 24, 0.0, "relu"], + ], + # stage2 + [["ir", 3, 5, 2, 3.0, 40, 0.25, "relu"]], + # stage3 + [ + ["ir", 1, 3, 2, 6.0, 80, 0.0, "hard_swish"], + ["ir", 1, 3, 1, 2.5, 80, 0.0, "hard_swish"], + ["ir", 2, 3, 1, 2.3, 80, 0.0, "hard_swish"], + ], + # stage4 + [["ir", 2, 3, 1, 6.0, 112, 0.25, "hard_swish"]], + # stage5 + [["ir", 3, 5, 2, 6.0, 160, 0.25, "hard_swish"]], + # stage6 + [["cn", 1, 1, 1, 1.0, 960, 0.0, "hard_swish"]], +] + + +def apply_depthwise_separation_block( + inputs, + output_channels, + depthwise_kernel_size=3, + pointwise_kernel_size=1, + strides=1, + se_ratio=0.0, + activation="relu", + name="depthwise_separation_block", +): + input_channels = inputs.shape[-1] + has_skip = strides == 1 and input_channels == output_channels + + x = inputs + x = apply_conv2d_block( + x, + kernel_size=depthwise_kernel_size, + strides=strides, + activation=activation, + use_depthwise=True, + name=f"{name}_conv_dw", + ) + if se_ratio > 0: + x = apply_se_block( + x, + se_ratio, + activation="relu", + gate_activation="hard_sigmoid", + make_divisible_number=8, + name=f"{name}_se", + ) + x = apply_conv2d_block( + x, + output_channels, + pointwise_kernel_size, + 1, + activation=None, + name=f"{name}_conv_pw", + ) + if has_skip: + x = layers.Add()([x, inputs]) + return x + + +def apply_inverted_residual_block( + inputs, + output_channels, + depthwise_kernel_size=3, + expansion_kernel_size=1, + pointwise_kernel_size=1, + strides=1, + expansion_ratio=1.0, + se_ratio=0.0, + activation="relu", + name="inverted_residual_block", +): + input_channels = inputs.shape[-1] + hidden_channels = make_divisible(input_channels * expansion_ratio) + has_skip = strides == 1 and input_channels == output_channels + + x = inputs + + # 
Point-wise expansion + x = apply_conv2d_block( + x, + hidden_channels, + expansion_kernel_size, + 1, + activation=activation, + name=f"{name}_conv_pw", + ) + # Depth-wise convolution + x = apply_conv2d_block( + x, + kernel_size=depthwise_kernel_size, + strides=strides, + activation=activation, + use_depthwise=True, + name=f"{name}_conv_dw", + ) + # Squeeze-and-excitation + if se_ratio > 0: + x = apply_se_block( + x, + se_ratio, + activation="relu", + gate_activation="hard_sigmoid", + make_divisible_number=8, + name=f"{name}_se", + ) + # Point-wise linear projection + x = apply_conv2d_block( + x, + output_channels, + pointwise_kernel_size, + 1, + activation=None, + name=f"{name}_conv_pwl", + ) + if has_skip: + x = layers.Add()([x, inputs]) + return x + + +class MobileNetV3(FeatureExtractor): + def __init__( + self, + width: float = 1.0, + depth: float = 1.0, + fix_stem_and_head_channels: bool = False, + input_tensor: keras.KerasTensor = None, + input_shape: typing.Optional[typing.Sequence[int]] = None, + include_preprocessing: bool = True, + include_top: bool = True, + pooling: typing.Optional[str] = None, + dropout_rate: float = 0.0, + classes: int = 1000, + classifier_activation: str = "softmax", + weights: typing.Optional[str] = None, # TODO: imagenet + config: typing.Union[str, typing.List] = "large", + minimal: bool = False, + **kwargs, + ): + if config == "small": + config = DEFAULT_SMALL_CONFIG + conv_head_channels = 1024 + elif config == "large": + config = DEFAULT_LARGE_CONFIG + conv_head_channels = 1280 + if minimal: + force_activation = "relu" + force_kernel_size = 3 + no_se = True + else: + force_activation = None + force_kernel_size = None + no_se = False + + # Prepare feature extraction + features = {} + + # Determine proper input shape + input_shape = imagenet_utils.obtain_input_shape( + input_shape, + default_size=224, + min_size=32, + data_format=backend.image_data_format(), + require_flatten=include_top, + weights=weights, + ) + + if input_tensor is None: + img_input = layers.Input(shape=input_shape) + else: + if not backend.is_keras_tensor(input_tensor): + img_input = layers.Input(tensor=input_tensor, shape=input_shape) + else: + img_input = input_tensor + + x = img_input + + # [0, 255] to [0, 1] and apply ImageNet mean and variance + if include_preprocessing: + x = layers.Rescaling(scale=1.0 / 255.0)(x) + x = layers.Normalization( + mean=[0.485, 0.456, 0.406], variance=[0.229, 0.224, 0.225] + )(x) + + # stem + stem_channel = ( + 16 if fix_stem_and_head_channels else make_divisible(16 * width) + ) + x = apply_conv2d_block( + x, + stem_channel, + 3, + 2, + activation=force_activation or "hard_swish", + name="conv_stem", + ) + features["STEM_S2"] = x + + # blocks + current_stride = 2 + for current_stage_idx, cfg in enumerate(config): + for current_block_idx, sub_cfg in enumerate(cfg): + block_type, r, k, s, e, c, se, act = sub_cfg + + # override default config + if force_activation is not None: + act = force_activation + if force_kernel_size is not None: + k = force_kernel_size if k > force_kernel_size else k + if no_se: + se = 0.0 + + c = make_divisible(c * width) + # no depth multiplier at first and last block + if current_block_idx not in (0, len(config) - 1): + r = int(math.ceil(r * depth)) + for current_layer_idx in range(r): + s = s if current_layer_idx == 0 else 1 + name = ( + f"blocks_{current_stage_idx}_" + f"{current_block_idx + current_layer_idx}" + ) + if block_type == "ds": + x = apply_depthwise_separation_block( + x, c, k, 1, s, se, act, name=name + ) + elif 
block_type == "ir": + x = apply_inverted_residual_block( + x, c, k, 1, 1, s, e, se, act, name=name + ) + elif block_type == "cn": + x = apply_conv2d_block( + x, + filters=c, + kernel_size=k, + strides=s, + activation=act, + name=name, + ) + current_stride *= s + features[f"BLOCK{current_stage_idx}_S{current_stride}"] = x + + if include_top: + x = layers.GlobalAveragePooling2D(name="avg_pool", keepdims=True)(x) + if fix_stem_and_head_channels: + conv_head_channels = conv_head_channels + else: + conv_head_channels = max( + conv_head_channels, + make_divisible(conv_head_channels * width), + ) + x = layers.Conv2D( + conv_head_channels, 1, 1, use_bias=True, name="conv_head" + )(x) + x = layers.Activation( + force_activation or "hard_swish", name="act2" + )(x) + x = layers.Flatten()(x) + x = layers.Dropout(rate=dropout_rate, name="conv_head_dropout")(x) + x = layers.Dense( + classes, activation=classifier_activation, name="classifier" + )(x) + else: + if pooling == "avg": + x = layers.GlobalAveragePooling2D(name="avg_pool")(x) + elif pooling == "max": + x = layers.GlobalMaxPooling2D(name="max_pool")(x) + + # Ensure that the model takes into account + # any potential predecessors of `input_tensor`. + if input_tensor is not None: + inputs = utils.get_source_inputs(input_tensor) + else: + inputs = img_input + + super().__init__(inputs=inputs, outputs=x, features=features, **kwargs) + + # All references to `self` below this line + self.width = width + self.depth = depth + self.fix_stem_and_head_channels = fix_stem_and_head_channels + self.include_preprocessing = include_preprocessing + self.include_top = include_top + self.pooling = pooling + self.dropout_rate = dropout_rate + self.classes = classes + self.classifier_activation = classifier_activation + self._weights = weights # `self.weights` is been used internally + self.config = config + + @staticmethod + def available_feature_keys(): + raise NotImplementedError() + + def get_config(self): + config = super().get_config() + config.update( + { + "width": self.width, + "input_shape": self.input_shape[1:], + "include_preprocessing": self.include_preprocessing, + "include_top": self.include_top, + "pooling": self.pooling, + "dropout_rate": self.dropout_rate, + "classes": self.classes, + "classifier_activation": self.classifier_activation, + "weights": self._weights, + "config": self.config, + } + ) + return config + + +""" +Model Definition +""" + + +class MobileNet050V3Small(MobileNetV3): + def __init__( + self, + input_tensor: keras.KerasTensor = None, + input_shape: typing.Optional[typing.Sequence[int]] = None, + include_preprocessing: bool = True, + include_top: bool = True, + pooling: typing.Optional[str] = None, + dropout_rate: float = 0.0, + classes: int = 1000, + classifier_activation: str = "softmax", + weights: typing.Optional[str] = None, # TODO: imagenet + config: typing.Union[str, typing.List] = "small", + name: str = "MobileNet050V3Small", + **kwargs, + ): + super().__init__( + 0.5, + 1.0, + True, + input_tensor, + input_shape, + include_preprocessing, + include_top, + pooling, + dropout_rate, + classes, + classifier_activation, + weights, + config, + name=name, + **kwargs, + ) + + @staticmethod + def available_feature_keys(): + feature_keys = ["STEM_S2"] + feature_keys.extend( + [f"BLOCK{i}_S{j}" for i, j in zip(range(6), [4, 8, 16, 16, 32, 32])] + ) + return feature_keys + + +class MobileNet075V3Small(MobileNetV3): + def __init__( + self, + input_tensor: keras.KerasTensor = None, + input_shape: typing.Optional[typing.Sequence[int]] = 
None, + include_preprocessing: bool = True, + include_top: bool = True, + pooling: typing.Optional[str] = None, + dropout_rate: float = 0.0, + classes: int = 1000, + classifier_activation: str = "softmax", + weights: typing.Optional[str] = None, # TODO: imagenet + config: typing.Union[str, typing.List] = "small", + name: str = "MobileNet075V3Small", + **kwargs, + ): + super().__init__( + 0.75, + 1.0, + False, + input_tensor, + input_shape, + include_preprocessing, + include_top, + pooling, + dropout_rate, + classes, + classifier_activation, + weights, + config, + name=name, + **kwargs, + ) + + @staticmethod + def available_feature_keys(): + feature_keys = ["STEM_S2"] + feature_keys.extend( + [f"BLOCK{i}_S{j}" for i, j in zip(range(6), [4, 8, 16, 16, 32, 32])] + ) + return feature_keys + + +class MobileNet100V3Small(MobileNetV3): + def __init__( + self, + input_tensor: keras.KerasTensor = None, + input_shape: typing.Optional[typing.Sequence[int]] = None, + include_preprocessing: bool = True, + include_top: bool = True, + pooling: typing.Optional[str] = None, + dropout_rate: float = 0.0, + classes: int = 1000, + classifier_activation: str = "softmax", + weights: typing.Optional[str] = None, # TODO: imagenet + config: typing.Union[str, typing.List] = "small", + name: str = "MobileNet100V3Small", + **kwargs, + ): + super().__init__( + 1.0, + 1.0, + False, + input_tensor, + input_shape, + include_preprocessing, + include_top, + pooling, + dropout_rate, + classes, + classifier_activation, + weights, + config, + name=name, + **kwargs, + ) + + @staticmethod + def available_feature_keys(): + feature_keys = ["STEM_S2"] + feature_keys.extend( + [f"BLOCK{i}_S{j}" for i, j in zip(range(6), [4, 8, 16, 16, 32, 32])] + ) + return feature_keys + + +class MobileNet100V3SmallMinimal(MobileNetV3): + def __init__( + self, + input_tensor: keras.KerasTensor = None, + input_shape: typing.Optional[typing.Sequence[int]] = None, + include_preprocessing: bool = True, + include_top: bool = True, + pooling: typing.Optional[str] = None, + dropout_rate: float = 0.0, + classes: int = 1000, + classifier_activation: str = "softmax", + weights: typing.Optional[str] = None, # TODO: imagenet + config: typing.Union[str, typing.List] = "small", + name: str = "MobileNet100V3SmallMinimal", + **kwargs, + ): + super().__init__( + 1.0, + 1.0, + False, + input_tensor, + input_shape, + include_preprocessing, + include_top, + pooling, + dropout_rate, + classes, + classifier_activation, + weights, + config, + minimal=True, + name=name, + **kwargs, + ) + + @staticmethod + def available_feature_keys(): + feature_keys = ["STEM_S2"] + feature_keys.extend( + [f"BLOCK{i}_S{j}" for i, j in zip(range(6), [4, 8, 16, 16, 32, 32])] + ) + return feature_keys + + +class MobileNet100V3Large(MobileNetV3): + def __init__( + self, + input_tensor: keras.KerasTensor = None, + input_shape: typing.Optional[typing.Sequence[int]] = None, + include_preprocessing: bool = True, + include_top: bool = True, + pooling: typing.Optional[str] = None, + dropout_rate: float = 0.0, + classes: int = 1000, + classifier_activation: str = "softmax", + weights: typing.Optional[str] = None, # TODO: imagenet + config: typing.Union[str, typing.List] = "large", + name: str = "MobileNet100V3Large", + **kwargs, + ): + super().__init__( + 1.0, + 1.0, + False, + input_tensor, + input_shape, + include_preprocessing, + include_top, + pooling, + dropout_rate, + classes, + classifier_activation, + weights, + config, + name=name, + **kwargs, + ) + + @staticmethod + def 
available_feature_keys(): + feature_keys = ["STEM_S2"] + feature_keys.extend( + [ + f"BLOCK{i}_S{j}" + for i, j in zip(range(7), [2, 4, 8, 16, 16, 32, 32]) + ] + ) + return feature_keys + + +class MobileNet100V3LargeMinimal(MobileNetV3): + def __init__( + self, + input_tensor: keras.KerasTensor = None, + input_shape: typing.Optional[typing.Sequence[int]] = None, + include_preprocessing: bool = True, + include_top: bool = True, + pooling: typing.Optional[str] = None, + dropout_rate: float = 0.0, + classes: int = 1000, + classifier_activation: str = "softmax", + weights: typing.Optional[str] = None, # TODO: imagenet + config: typing.Union[str, typing.List] = "large", + name: str = "MobileNet100V3LargeMinimal", + **kwargs, + ): + super().__init__( + 1.0, + 1.0, + False, + input_tensor, + input_shape, + include_preprocessing, + include_top, + pooling, + dropout_rate, + classes, + classifier_activation, + weights, + config, + minimal=True, + name=name, + **kwargs, + ) + + @staticmethod + def available_feature_keys(): + feature_keys = ["STEM_S2"] + feature_keys.extend( + [ + f"BLOCK{i}_S{j}" + for i, j in zip(range(7), [2, 4, 8, 16, 16, 32, 32]) + ] + ) + return feature_keys diff --git a/kimm/models/mobilenet_v3_test.py b/kimm/models/mobilenet_v3_test.py new file mode 100644 index 0000000..9d40fa8 --- /dev/null +++ b/kimm/models/mobilenet_v3_test.py @@ -0,0 +1,83 @@ +from absl.testing import parameterized +from keras import random +from keras.src import testing + +from kimm.models.mobilenet_v3 import MobileNet100V3Large +from kimm.models.mobilenet_v3 import MobileNet100V3Small +from kimm.utils import make_divisible + + +class MobileNetV2Test(testing.TestCase, parameterized.TestCase): + @parameterized.named_parameters( + [ + (MobileNet100V3Small.__name__, MobileNet100V3Small), + (MobileNet100V3Large.__name__, MobileNet100V3Large), + ] + ) + def test_mobilenet_v3_base(self, model_class): + # TODO: test the correctness of the real image + x = random.uniform([1, 224, 224, 3]) * 255.0 + model = model_class() + + y = model.predict(x) + + self.assertEqual(y.shape, (1, 1000)) + + @parameterized.named_parameters( + [ + (MobileNet100V3Small.__name__, MobileNet100V3Small, 1.0), + (MobileNet100V3Large.__name__, MobileNet100V3Large, 1.0), + ] + ) + def test_mobilenet_v3_feature_extractor(self, model_class, width): + x = random.uniform([1, 224, 224, 3]) * 255.0 + model = model_class(as_feature_extractor=True) + + y = model.predict(x) + + self.assertIsInstance(y, dict) + self.assertAllEqual( + list(y.keys()), model_class.available_feature_keys() + ) + if "Small" in model_class.__name__: + self.assertEqual( + list(y["STEM_S2"].shape), + [1, 112, 112, make_divisible(16 * width)], + ) + self.assertEqual( + list(y["BLOCK0_S4"].shape), + [1, 56, 56, make_divisible(16 * width)], + ) + self.assertEqual( + list(y["BLOCK1_S8"].shape), + [1, 28, 28, make_divisible(24 * width)], + ) + self.assertEqual( + list(y["BLOCK2_S16"].shape), + [1, 14, 14, make_divisible(40 * width)], + ) + self.assertEqual( + list(y["BLOCK4_S32"].shape), + [1, 7, 7, make_divisible(96 * width)], + ) + else: + self.assertEqual( + list(y["STEM_S2"].shape), + [1, 112, 112, make_divisible(16 * width)], + ) + self.assertEqual( + list(y["BLOCK1_S4"].shape), + [1, 56, 56, make_divisible(24 * width)], + ) + self.assertEqual( + list(y["BLOCK2_S8"].shape), + [1, 28, 28, make_divisible(40 * width)], + ) + self.assertEqual( + list(y["BLOCK3_S16"].shape), + [1, 14, 14, make_divisible(80 * width)], + ) + self.assertEqual( + list(y["BLOCK5_S32"].shape), + [1, 7, 
7, make_divisible(160 * width)], + ) diff --git a/kimm/models/resnet.py b/kimm/models/resnet.py index 17f2b16..8464feb 100644 --- a/kimm/models/resnet.py +++ b/kimm/models/resnet.py @@ -161,7 +161,7 @@ def __init__( x = apply_conv2d_block( x, stem_channels, 7, 2, activation="relu", name="conv_stem" ) - features["S2"] = x + features["STEM_S2"] = x # max pooling x = layers.ZeroPadding2D(padding=1)(x) @@ -169,27 +169,24 @@ def __init__( # stages output_channels = [64, 128, 256, 512] - net_stride = 4 - stage_idx = 0 - for c, n in zip(output_channels, num_blocks): - stride = 1 if stage_idx == 0 else 2 - net_stride *= stride + current_stride = 4 + for current_stage_idx, (c, n) in enumerate( + zip(output_channels, num_blocks) + ): + stride = 1 if current_stage_idx == 0 else 2 + current_stride *= stride # blocks - for block_idx in range(n): - stride = stride if block_idx == 0 else 1 + for current_block_idx in range(n): + stride = stride if current_block_idx == 0 else 1 + name = f"layer{current_stage_idx + 1}_{current_block_idx}" if block_fn == "basic": - x = apply_basic_block( - x, c, stride, name=f"layer{stage_idx + 1}_{block_idx}" - ) + x = apply_basic_block(x, c, stride, name=name) elif block_fn == "bottleneck": - x = apply_bottleneck_block( - x, c, stride, name=f"layer{stage_idx + 1}_{block_idx}" - ) + x = apply_bottleneck_block(x, c, stride, name=name) else: raise NotImplementedError # add feature - features[f"S{net_stride}"] = x - stage_idx += 1 + features[f"BLOCK{current_stage_idx}_S{current_stride}"] = x if include_top: x = layers.GlobalAveragePooling2D(name="avg_pool", keepdims=True)(x) @@ -226,8 +223,11 @@ def __init__( @staticmethod def available_feature_keys(): - # predefined for better UX - return [f"S{2**i}" for i in range(1, 6)] + feature_keys = ["STEM_S2"] + feature_keys.extend( + [f"BLOCK{i}_S{j}" for i, j in zip(range(4), [4, 8, 16, 32])] + ) + return feature_keys def get_config(self): config = super().get_config() diff --git a/kimm/models/resnet_test.py b/kimm/models/resnet_test.py index 1bc9962..e2be162 100644 --- a/kimm/models/resnet_test.py +++ b/kimm/models/resnet_test.py @@ -29,8 +29,19 @@ def test_resnet_feature_extractor(self, model_class, expansion): y = model.predict(x) self.assertIsInstance(y, dict) - self.assertEqual(list(y["S2"].shape), [1, 112, 112, 64]) - self.assertEqual(list(y["S4"].shape), [1, 56, 56, 64 * expansion]) - self.assertEqual(list(y["S8"].shape), [1, 28, 28, 128 * expansion]) - self.assertEqual(list(y["S16"].shape), [1, 14, 14, 256 * expansion]) - self.assertEqual(list(y["S32"].shape), [1, 7, 7, 512 * expansion]) + self.assertAllEqual( + list(y.keys()), model_class.available_feature_keys() + ) + self.assertEqual(list(y["STEM_S2"].shape), [1, 112, 112, 64]) + self.assertEqual( + list(y["BLOCK0_S4"].shape), [1, 56, 56, 64 * expansion] + ) + self.assertEqual( + list(y["BLOCK1_S8"].shape), [1, 28, 28, 128 * expansion] + ) + self.assertEqual( + list(y["BLOCK2_S16"].shape), [1, 14, 14, 256 * expansion] + ) + self.assertEqual( + list(y["BLOCK3_S32"].shape), [1, 7, 7, 512 * expansion] + ) diff --git a/kimm/models/vision_transformer.py b/kimm/models/vision_transformer.py index 6da7323..6762e21 100644 --- a/kimm/models/vision_transformer.py +++ b/kimm/models/vision_transformer.py @@ -158,7 +158,7 @@ def __init__( )(x) x = layers.Reshape((-1, embed_dim))(x) x = kimm_layers.PositionEmbedding(name="postition_embedding")(x) - features["Depth0"] = x + features["EMBEDDING"] = x x = layers.Dropout(pos_dropout_rate, name="pos_dropout")(x) for i in range(depth): @@ 
-175,7 +175,7 @@ def __init__( ), name=f"blocks_{i}", ) - features[f"Depth{i + 1}"] = x + features[f"BLOCK{i}"] = x x = layers.LayerNormalization(epsilon=1e-6, name="norm")(x) if include_top: @@ -216,7 +216,7 @@ def __init__( @staticmethod def available_feature_keys(): - raise NotImplementedError + raise NotImplementedError() def get_config(self): config = super().get_config() @@ -288,6 +288,12 @@ def __init__( **kwargs, ) + @staticmethod + def available_feature_keys(): + feature_keys = ["EMBEDDING"] + feature_keys.extend([f"BLOCK{i}" for i in range(12)]) + return feature_keys + class VisionTransformerTiny32(VisionTransformer): def __init__( @@ -330,6 +336,12 @@ def __init__( **kwargs, ) + @staticmethod + def available_feature_keys(): + feature_keys = ["EMBEDDING"] + feature_keys.extend([f"BLOCK{i}" for i in range(12)]) + return feature_keys + class VisionTransformerSmall16(VisionTransformer): def __init__( @@ -372,6 +384,12 @@ def __init__( **kwargs, ) + @staticmethod + def available_feature_keys(): + feature_keys = ["EMBEDDING"] + feature_keys.extend([f"BLOCK{i}" for i in range(12)]) + return feature_keys + class VisionTransformerSmall32(VisionTransformer): def __init__( @@ -414,6 +432,12 @@ def __init__( **kwargs, ) + @staticmethod + def available_feature_keys(): + feature_keys = ["EMBEDDING"] + feature_keys.extend([f"BLOCK{i}" for i in range(12)]) + return feature_keys + class VisionTransformerBase16(VisionTransformer): def __init__( @@ -456,6 +480,12 @@ def __init__( **kwargs, ) + @staticmethod + def available_feature_keys(): + feature_keys = ["EMBEDDING"] + feature_keys.extend([f"BLOCK{i}" for i in range(12)]) + return feature_keys + class VisionTransformerBase32(VisionTransformer): def __init__( @@ -498,6 +528,12 @@ def __init__( **kwargs, ) + @staticmethod + def available_feature_keys(): + feature_keys = ["EMBEDDING"] + feature_keys.extend([f"BLOCK{i}" for i in range(12)]) + return feature_keys + class VisionTransformerLarge16(VisionTransformer): def __init__( @@ -540,6 +576,12 @@ def __init__( **kwargs, ) + @staticmethod + def available_feature_keys(): + feature_keys = ["EMBEDDING"] + feature_keys.extend([f"BLOCK{i}" for i in range(24)]) + return feature_keys + class VisionTransformerLarge32(VisionTransformer): def __init__( @@ -581,3 +623,9 @@ def __init__( name=name, **kwargs, ) + + @staticmethod + def available_feature_keys(): + feature_keys = ["EMBEDDING"] + feature_keys.extend([f"BLOCK{i}" for i in range(24)]) + return feature_keys diff --git a/kimm/models/vision_transformer_test.py b/kimm/models/vision_transformer_test.py index 8206113..54f4807 100644 --- a/kimm/models/vision_transformer_test.py +++ b/kimm/models/vision_transformer_test.py @@ -37,11 +37,14 @@ def test_vision_transformer_feature_extractor( y = model.predict(x) self.assertIsInstance(y, dict) + self.assertAllEqual( + list(y.keys()), model_class.available_feature_keys() + ) if patch_size == 16: - self.assertEqual(list(y["Depth0"].shape), [1, 577, 192]) + self.assertEqual(list(y["BLOCK0"].shape), [1, 577, 192]) elif patch_size == 32: - self.assertEqual(list(y["Depth0"].shape), [1, 145, 192]) + self.assertEqual(list(y["BLOCK0"].shape), [1, 145, 192]) if patch_size == 16: - self.assertEqual(list(y["Depth5"].shape), [1, 577, 192]) + self.assertEqual(list(y["BLOCK5"].shape), [1, 577, 192]) elif patch_size == 32: - self.assertEqual(list(y["Depth5"].shape), [1, 145, 192]) + self.assertEqual(list(y["BLOCK5"].shape), [1, 145, 192]) diff --git a/tools/convert_ghostnet_from_timm.py b/tools/convert_ghostnet_from_timm.py 
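The ViT feature keys are now depth-indexed transformer blocks plus the patch embedding. The token counts asserted in the test above also pin down the default input size: 577 = (384/16)² + 1 and 145 = (384/32)² + 1, i.e. 384×384 inputs with one class token. A hedged usage sketch:

```python
import numpy as np

from kimm.models.vision_transformer import VisionTransformerTiny32

model = VisionTransformerTiny32(as_feature_extractor=True)
print(model.available_feature_keys())
# ['EMBEDDING', 'BLOCK0', 'BLOCK1', ..., 'BLOCK11']
features = model.predict(np.random.uniform(size=(1, 384, 384, 3)) * 255.0)
print(features["BLOCK5"].shape)  # (1, 145, 192): (384/32)**2 + 1 tokens
```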
index b50dd35..05aaddf 100644 --- a/tools/convert_ghostnet_from_timm.py +++ b/tools/convert_ghostnet_from_timm.py @@ -50,6 +50,16 @@ keras_model ) + # for torch_name, (_, keras_name) in zip( + # trainable_state_dict.keys(), trainable_weights + # ): + # print(f"{torch_name} {keras_name}") + + # print(len(trainable_state_dict.keys())) + # print(len(trainable_weights)) + + # exit() + """ Assign weights """ @@ -61,7 +71,6 @@ torch_name = torch_name.replace("conv.stem.conv2d", "conv_stem") torch_name = torch_name.replace("conv.stem.bn", "bn1") # blocks - torch_name = torch_name.replace("blocks", "blocks.") torch_name = torch_name.replace("primary.conv.conv2d", "primary_conv.0") torch_name = torch_name.replace("primary.conv.bn", "primary_conv.1") torch_name = torch_name.replace( @@ -77,8 +86,8 @@ torch_name = torch_name.replace("shortcut2.conv2d", "shortcut.2") torch_name = torch_name.replace("shortcut2.bn", "shortcut.3") # se - torch_name = torch_name.replace("se.reduce.conv2d", "se.conv_reduce") - torch_name = torch_name.replace("se.expand.conv2d", "se.conv_expand") + torch_name = torch_name.replace("se.conv.reduce", "se.conv_reduce") + torch_name = torch_name.replace("se.conv.expand", "se.conv_expand") # short conv (GhostNetV2) torch_name = torch_name.replace("short.conv1.conv2d", "short_conv.0") torch_name = torch_name.replace("short.conv1.bn", "short_conv.1") @@ -116,7 +125,9 @@ raise ValueError( "Can't find the corresponding torch weights. The shape is " f"mismatched. Got keras_name={keras_name}, " - f"torch_name={torch_name}" + f"keras_weight shape={keras_weight.shape}, " + f"torch_name={torch_name}, " + f"torch_weights shape={torch_weights.shape}" ) """ diff --git a/tools/convert_mobilenet_v2_from_timm.py b/tools/convert_mobilenet_v2_from_timm.py new file mode 100644 index 0000000..d0f702f --- /dev/null +++ b/tools/convert_mobilenet_v2_from_timm.py @@ -0,0 +1,137 @@ +""" +pip install torch torchvision --index-url https://download.pytorch.org/whl/cpu +pip install timm +""" +import keras +import numpy as np +import timm +import torch + +from kimm.models import mobilenet_v2 +from kimm.utils.timm_utils import assign_weights +from kimm.utils.timm_utils import is_same_weights +from kimm.utils.timm_utils import separate_keras_weights +from kimm.utils.timm_utils import separate_torch_state_dict + +timm_model_names = [ + "mobilenetv2_050.lamb_in1k", + "mobilenetv2_100.ra_in1k", + "mobilenetv2_110d.ra_in1k", + "mobilenetv2_120d.ra_in1k", + "mobilenetv2_140.ra_in1k", +] +keras_model_classes = [ + mobilenet_v2.MobileNet050V2, + mobilenet_v2.MobileNet100V2, + mobilenet_v2.MobileNet110V2, + mobilenet_v2.MobileNet120V2, + mobilenet_v2.MobileNet140V2, +] + +for timm_model_name, keras_model_class in zip( + timm_model_names, keras_model_classes +): + """ + Prepare timm model and keras model + """ + input_shape = [224, 224, 3] + torch_model = timm.create_model(timm_model_name, pretrained=True) + torch_model = torch_model.eval() + trainable_state_dict, non_trainable_state_dict = separate_torch_state_dict( + torch_model.state_dict() + ) + keras_model = keras_model_class( + input_shape=input_shape, + include_preprocessing=False, + classifier_activation="linear", + ) + trainable_weights, non_trainable_weights = separate_keras_weights( + keras_model + ) + + # for torch_name, (_, keras_name) in zip( + # trainable_state_dict.keys(), trainable_weights + # ): + # print(f"{torch_name} {keras_name}") + + # print(len(trainable_state_dict.keys())) + # print(len(trainable_weights)) + + # exit() + + """ + Assign weights 
+ """ + for keras_weight, keras_name in trainable_weights + non_trainable_weights: + keras_name: str + torch_name = keras_name + torch_name = torch_name.replace("_", ".") + # stem + torch_name = torch_name.replace("conv.stem.conv2d", "conv_stem") + torch_name = torch_name.replace("conv.stem.bn", "bn1") + # blocks + if "blocks.0.0" in torch_name: + # depthwise separation block + torch_name = torch_name.replace("conv.dw.dwconv2d", "conv_dw") + torch_name = torch_name.replace("conv.dw.bn", "bn1") + torch_name = torch_name.replace("conv.pw.conv2d", "conv_pw") + torch_name = torch_name.replace("conv.pw.bn", "bn2") + else: + # inverted residual block + torch_name = torch_name.replace("conv.pw.conv2d", "conv_pw") + torch_name = torch_name.replace("conv.pw.bn", "bn1") + torch_name = torch_name.replace("conv.dw.dwconv2d", "conv_dw") + torch_name = torch_name.replace("conv.dw.bn", "bn2") + torch_name = torch_name.replace("conv.pwl.conv2d", "conv_pwl") + torch_name = torch_name.replace("conv.pwl.bn", "bn3") + # conv head + torch_name = torch_name.replace("conv.head.conv2d", "conv_head") + torch_name = torch_name.replace("conv.head.bn", "bn2") + + # weights naming mapping + torch_name = torch_name.replace("kernel", "weight") # conv2d + torch_name = torch_name.replace("gamma", "weight") # bn + torch_name = torch_name.replace("beta", "bias") # bn + torch_name = torch_name.replace("moving.mean", "running_mean") # bn + torch_name = torch_name.replace("moving.variance", "running_var") # bn + + # assign weights + if torch_name in trainable_state_dict: + torch_weights = trainable_state_dict[torch_name].numpy() + elif torch_name in non_trainable_state_dict: + torch_weights = non_trainable_state_dict[torch_name].numpy() + else: + raise ValueError( + "Can't find the corresponding torch weights. " + f"Got keras_name={keras_name}, torch_name={torch_name}" + ) + if is_same_weights(keras_name, keras_weight, torch_name, torch_weights): + assign_weights(keras_name, keras_weight, torch_weights) + else: + raise ValueError( + "Can't find the corresponding torch weights. The shape is " + f"mismatched. 
Got keras_name={keras_name}, " + f"keras_weight shape={keras_weight.shape}, " + f"torch_name={torch_name}, " + f"torch_weights shape={torch_weights.shape}" + ) + + """ + Verify model outputs + """ + np.random.seed(2023) + keras_data = np.random.uniform(size=[1] + input_shape).astype("float32") + torch_data = torch.from_numpy(np.transpose(keras_data, [0, 3, 1, 2])) + torch_y = torch_model(torch_data) + keras_y = keras_model(keras_data, training=False) + torch_y = torch_y.detach().cpu().numpy() + keras_y = keras.ops.convert_to_numpy(keras_y) + np.testing.assert_allclose(torch_y, keras_y, atol=2e-5) + print(f"{keras_model_class.__name__}: output matched!") + + """ + Save converted model + """ + export_path = f"exported/{keras_model.name.lower()}_imagenet.keras" + keras_model.save(export_path) + print(f"Export to {export_path}") diff --git a/tools/convert_mobilenet_v3_from_timm.py b/tools/convert_mobilenet_v3_from_timm.py new file mode 100644 index 0000000..45063ae --- /dev/null +++ b/tools/convert_mobilenet_v3_from_timm.py @@ -0,0 +1,150 @@ +""" +pip install torch torchvision --index-url https://download.pytorch.org/whl/cpu +pip install timm +""" +import keras +import numpy as np +import timm +import torch + +from kimm.models import mobilenet_v3 +from kimm.utils.timm_utils import assign_weights +from kimm.utils.timm_utils import is_same_weights +from kimm.utils.timm_utils import separate_keras_weights +from kimm.utils.timm_utils import separate_torch_state_dict + +timm_model_names = [ + "mobilenetv3_small_050.lamb_in1k", + "mobilenetv3_small_075.lamb_in1k", + # "tf_mobilenetv3_small_minimal_100.in1k", + "mobilenetv3_small_100.lamb_in1k", + "mobilenetv3_large_100.miil_in21k_ft_in1k", + # "tf_mobilenetv3_large_minimal_100.in1k", +] +keras_model_classes = [ + mobilenet_v3.MobileNet050V3Small, + mobilenet_v3.MobileNet075V3Small, + # mobilenet_v3.MobileNet100V3SmallMinimal, + mobilenet_v3.MobileNet100V3Small, + mobilenet_v3.MobileNet100V3Large, + # mobilenet_v3.MobileNet100V3LargeMinimal, +] + +for timm_model_name, keras_model_class in zip( + timm_model_names, keras_model_classes +): + """ + Prepare timm model and keras model + """ + input_shape = [224, 224, 3] + torch_model = timm.create_model(timm_model_name, pretrained=True) + torch_model = torch_model.eval() + trainable_state_dict, non_trainable_state_dict = separate_torch_state_dict( + torch_model.state_dict() + ) + keras_model = keras_model_class( + input_shape=input_shape, + include_preprocessing=False, + classifier_activation="linear", + ) + trainable_weights, non_trainable_weights = separate_keras_weights( + keras_model + ) + + # for torch_name, (_, keras_name) in zip( + # trainable_state_dict.keys(), trainable_weights + # ): + # print(f"{torch_name} {keras_name}") + + # print(len(trainable_state_dict.keys())) + # print(len(trainable_weights)) + + # exit() + + """ + Assign weights + """ + for keras_weight, keras_name in trainable_weights + non_trainable_weights: + keras_name: str + torch_name = keras_name + torch_name = torch_name.replace("_", ".") + # stem + torch_name = torch_name.replace("conv.stem.conv2d", "conv_stem") + torch_name = torch_name.replace("conv.stem.bn", "bn1") + # blocks + if "blocks.0.0" in torch_name: + # depthwise separation block + torch_name = torch_name.replace("conv.dw.dwconv2d", "conv_dw") + torch_name = torch_name.replace("conv.dw.bn", "bn1") + torch_name = torch_name.replace("conv.pw.conv2d", "conv_pw") + torch_name = torch_name.replace("conv.pw.bn", "bn2") + else: + # inverted residual block + torch_name = 
torch_name.replace("conv.pw.conv2d", "conv_pw") + torch_name = torch_name.replace("conv.pw.bn", "bn1") + torch_name = torch_name.replace("conv.dw.dwconv2d", "conv_dw") + torch_name = torch_name.replace("conv.dw.bn", "bn2") + torch_name = torch_name.replace("conv.pwl.conv2d", "conv_pwl") + torch_name = torch_name.replace("conv.pwl.bn", "bn3") + # se + torch_name = torch_name.replace("se.conv.reduce", "se.conv_reduce") + torch_name = torch_name.replace("se.conv.expand", "se.conv_expand") + # last conv block + if "Small" in keras_model_class.__name__: + if "blocks.5.0" in torch_name: + torch_name = torch_name.replace("conv2d", "conv") + torch_name = torch_name.replace("bn", "bn1") + if "Large" in keras_model_class.__name__: + if "blocks.6.0" in torch_name: + torch_name = torch_name.replace("conv2d", "conv") + torch_name = torch_name.replace("bn", "bn1") + # conv head + torch_name = torch_name.replace("conv.head", "conv_head") + + # weights naming mapping + torch_name = torch_name.replace("kernel", "weight") # conv2d + torch_name = torch_name.replace("gamma", "weight") # bn + torch_name = torch_name.replace("beta", "bias") # bn + torch_name = torch_name.replace("moving.mean", "running_mean") # bn + torch_name = torch_name.replace("moving.variance", "running_var") # bn + + # assign weights + if torch_name in trainable_state_dict: + torch_weights = trainable_state_dict[torch_name].numpy() + elif torch_name in non_trainable_state_dict: + torch_weights = non_trainable_state_dict[torch_name].numpy() + else: + raise ValueError( + "Can't find the corresponding torch weights. " + f"Got keras_name={keras_name}, torch_name={torch_name}" + ) + if is_same_weights(keras_name, keras_weight, torch_name, torch_weights): + assign_weights(keras_name, keras_weight, torch_weights) + else: + raise ValueError( + "Can't find the corresponding torch weights. The shape is " + f"mismatched. Got keras_name={keras_name}, " + f"keras_weight shape={keras_weight.shape}, " + f"torch_name={torch_name}, " + f"torch_weights shape={torch_weights.shape}" + ) + + """ + Verify model outputs + """ + np.random.seed(2023) + keras_data = np.random.uniform(size=[1] + input_shape).astype("float32") + torch_data = torch.from_numpy(np.transpose(keras_data, [0, 3, 1, 2])) + torch_y = torch_model(torch_data) + keras_y = keras_model(keras_data, training=False) + torch_y = torch_y.detach().cpu().numpy() + keras_y = keras.ops.convert_to_numpy(keras_y) + np.testing.assert_allclose(torch_y, keras_y, atol=2e-5) + print(f"{keras_model_class.__name__}: output matched!") + + """ + Save converted model + """ + export_path = f"exported/{keras_model.name.lower()}_imagenet.keras" + keras_model.save(export_path) + print(f"Export to {export_path}") diff --git a/tools/convert_resnet_from_timm.py b/tools/convert_resnet_from_timm.py index 980565d..eadd28c 100644 --- a/tools/convert_resnet_from_timm.py +++ b/tools/convert_resnet_from_timm.py @@ -102,7 +102,9 @@ raise ValueError( "Can't find the corresponding torch weights. The shape is " f"mismatched. Got keras_name={keras_name}, " - f"torch_name={torch_name}" + f"keras_weight shape={keras_weight.shape}, " + f"torch_name={torch_name}, " + f"torch_weights shape={torch_weights.shape}" ) """
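For reference, here is one weight traced through the substitution chain used by the MobileNet converters above. The substitutions are taken from the scripts; the exact flattened Keras weight path is an assumption about what `separate_keras_weights` yields.

```python
# Worked example of the keras -> timm name mapping for an
# inverted-residual projection conv (hedged: the flattened Keras
# path is assumed; the substitution chain is from the script).
keras_name = "blocks_2_1_conv_pwl_conv2d_kernel"
torch_name = keras_name.replace("_", ".")
# -> "blocks.2.1.conv.pwl.conv2d.kernel"
torch_name = torch_name.replace("conv.pwl.conv2d", "conv_pwl")
# -> "blocks.2.1.conv_pwl.kernel"
torch_name = torch_name.replace("kernel", "weight")
print(torch_name)  # -> "blocks.2.1.conv_pwl.weight" (timm's parameter name)
```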