Add DenseNet, InceptionV3 and refactor BaseModel (#11)
* Fix export name

* Add `DenseNet`

* Cleanup

* Add `InceptionV3`

* Refactor `BaseModel`

* Refactor `BaseModel`

* Simplify `build_preprocessing` and `build_top`

* Simplify code

* Format

* Mark serialization tests and skip them by default
james77777778 authored Jan 17, 2024
1 parent ce979af commit d7804ac
Showing 40 changed files with 2,142 additions and 1,245 deletions.
28 changes: 27 additions & 1 deletion conftest.py
@@ -1,7 +1,18 @@
import os

+import pytest


-def pytest_configure():
+def pytest_addoption(parser):
+    parser.addoption(
+        "--run_serialization",
+        action="store_true",
+        default=False,
+        help="run serialization tests",
+    )
+
+
+def pytest_configure(config):
    import tensorflow as tf

    # disable tensorflow gpu memory preallocation
@@ -12,3 +23,18 @@ def pytest_configure():
    # disable jax gpu memory preallocation
    # https://jax.readthedocs.io/en/latest/gpu_memory_allocation.html
    os.environ["XLA_PYTHON_CLIENT_PREALLOCATE"] = "false"

+    config.addinivalue_line(
+        "markers", "serialization: mark test as a serialization test"
+    )
+
+
+def pytest_collection_modifyitems(config, items):
+    run_serialization_tests = config.getoption("--run_serialization")
+    skip_serialization = pytest.mark.skipif(
+        not run_serialization_tests,
+        reason="need --run_serialization option to run",
+    )
+    for item in items:
+        if "serialization" in item.name:
+            item.add_marker(skip_serialization)
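With these hooks in place, any test whose name contains "serialization" is still collected but skipped under a plain `pytest` run, and only executes when the flag is passed. Note the skip is keyed on the test name, not the registered marker. A minimal sketch of an affected test (the test name and body here are hypothetical):

def test_sample_model_serialization():
    # An expensive save/load round-trip would live here. By default this
    # test is skipped; it runs under `pytest --run_serialization`.
    assert True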
6 changes: 5 additions & 1 deletion kimm/blocks/base_block.py
@@ -34,6 +34,8 @@ def apply_conv2d_block(
        raise ValueError(
            f"kernel_size must be passed. Received: kernel_size={kernel_size}"
        )
+    if isinstance(kernel_size, int):
+        kernel_size = [kernel_size, kernel_size]
    input_channels = inputs.shape[-1]
    has_skip = add_skip and strides == 1 and input_channels == filters
    x = inputs
@@ -42,7 +44,9 @@
padding = "same"
if strides > 1:
padding = "valid"
x = layers.ZeroPadding2D(kernel_size // 2, name=f"{name}_pad")(x)
x = layers.ZeroPadding2D(
(kernel_size[0] // 2, kernel_size[1] // 2), name=f"{name}_pad"
)(x)

if not use_depthwise:
x = layers.Conv2D(
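Normalizing an integer `kernel_size` to a two-element list lets the explicit padding handle rectangular kernels such as the 1x7 and 7x1 convolutions in InceptionV3. A standalone sketch of the same padding arithmetic, separate from the library code:

from keras import layers

kernel_size = (1, 7)  # rectangular kernel, as in Inception-style branches
pad = (kernel_size[0] // 2, kernel_size[1] // 2)  # per-dimension half padding
inputs = layers.Input(shape=(224, 224, 3))
x = layers.ZeroPadding2D(pad)(inputs)
x = layers.Conv2D(32, kernel_size, strides=2, padding="valid")(x)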
4 changes: 2 additions & 2 deletions kimm/blocks/inverted_residual_block.py
@@ -15,7 +15,7 @@ def apply_inverted_residual_block(
    expansion_ratio=1.0,
    se_ratio=0.0,
    activation="swish",
-    se_input_channels=None,
+    se_channels=None,
    se_activation=None,
    se_gate_activation="sigmoid",
    se_make_divisible_number=None,
@@ -57,7 +57,7 @@ def apply_inverted_residual_block(
            se_ratio,
            activation=se_activation or activation,
            gate_activation=se_gate_activation,
-            se_input_channels=se_input_channels,
+            se_input_channels=se_channels,
            make_divisible_number=se_make_divisible_number,
            name=f"{name}_se",
        )
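Callers of this block now pass `se_channels`, which is still forwarded to `apply_se_block`'s `se_input_channels` argument, as the hunk above shows. A hypothetical call site (argument names other than those visible in the signature above are assumptions, and any further required arguments are elided):

x = apply_inverted_residual_block(
    x,
    expansion_ratio=4.0,
    se_ratio=0.25,
    se_channels=32,  # formerly `se_input_channels`
    name="blocks_0_0",
)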
15 changes: 0 additions & 15 deletions kimm/layers/attention.py
@@ -118,18 +118,3 @@ def get_config(self):
            }
        )
        return config


if __name__ == "__main__":
    from keras import models
    from keras import random

    inputs = layers.Input(shape=[197, 768])
    outputs = Attention(768)(inputs)

    model = models.Model(inputs, outputs)
    model.summary()

    inputs = random.uniform([1, 197, 768])
    outputs = model(inputs)
    print(outputs.shape)
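This inline `__main__` smoke test is removed (as are the ones in `layer_scale.py` and `position_embedding.py` below); under the new conftest setup, an equivalent check belongs in a regular test module. A hypothetical pytest version of the removed block:

# hypothetical kimm/layers/attention_test.py
from keras import layers
from keras import models
from keras import random

from kimm.layers.attention import Attention

def test_attention_output_shape():
    inputs = layers.Input(shape=[197, 768])
    outputs = Attention(768)(inputs)
    model = models.Model(inputs, outputs)
    # Attention preserves the (tokens, dim) shape here
    assert model(random.uniform([1, 197, 768])).shape == (1, 197, 768)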
15 changes: 0 additions & 15 deletions kimm/layers/layer_scale.py
@@ -35,18 +35,3 @@ def get_config(self):
            }
        )
        return config


if __name__ == "__main__":
    from keras import models
    from keras import random

    inputs = layers.Input(shape=[197, 768])
    outputs = LayerScale(768)(inputs)

    model = models.Model(inputs, outputs)
    model.summary()

    inputs = random.uniform([1, 197, 768])
    outputs = model(inputs)
    print(outputs.shape)
22 changes: 0 additions & 22 deletions kimm/layers/position_embedding.py
@@ -38,25 +38,3 @@ def compute_output_shape(self, input_shape):

    def get_config(self):
        return super().get_config()


if __name__ == "__main__":
    from keras import models
    from keras import random

    inputs = layers.Input([224, 224, 3])
    x = layers.Conv2D(
        768,
        16,
        16,
        use_bias=True,
    )(inputs)
    x = layers.Reshape((-1, 768))(x)
    outputs = PositionEmbedding()(x)

    model = models.Model(inputs, outputs)
    model.summary()

    inputs = random.uniform([1, 224, 224, 3])
    outputs = model(inputs)
    print(outputs.shape)
2 changes: 1 addition & 1 deletion kimm/models/__init__.py
@@ -1,5 +1,5 @@
+from kimm.models.base_model import BaseModel
from kimm.models.efficientnet import *  # noqa:F403
-from kimm.models.feature_extractor import FeatureExtractor
from kimm.models.ghostnet import *  # noqa:F403
from kimm.models.mobilenet_v2 import *  # noqa:F403
from kimm.models.mobilenet_v3 import *  # noqa:F403
157 changes: 157 additions & 0 deletions kimm/models/base_model.py
@@ -0,0 +1,157 @@
import abc
import typing

from keras import KerasTensor
from keras import backend
from keras import layers
from keras import models
from keras.src.applications import imagenet_utils


class BaseModel(models.Model):
    def __init__(
        self,
        inputs,
        outputs,
        features: typing.Optional[typing.Dict[str, KerasTensor]] = None,
        feature_keys: typing.Optional[typing.List[str]] = None,
        **kwargs,
    ):
        self.as_feature_extractor = kwargs.pop("as_feature_extractor", False)
        self.feature_keys = feature_keys
        if self.as_feature_extractor:
            if features is None:
                raise ValueError(
                    "`features` must be set when "
                    f"`as_feature_extractor=True`. Got features={features}"
                )
            if self.feature_keys is None:
                self.feature_keys = list(features.keys())
            filtered_features = {}
            for k in self.feature_keys:
                if k not in features:
                    raise KeyError(
                        f"'{k}' is not a key of `features`. Available keys "
                        f"are: {list(features.keys())}"
                    )
                filtered_features[k] = features[k]
            super().__init__(
                inputs=inputs, outputs=filtered_features, **kwargs
            )
        else:
            del features
            super().__init__(inputs=inputs, outputs=outputs, **kwargs)

    def parse_kwargs(
        self, kwargs: typing.Dict[str, typing.Any], default_size: int = 224
    ):
        result = {
            "input_tensor": kwargs.pop("input_tensor", None),
            "input_shape": kwargs.pop("input_shape", None),
            "include_preprocessing": kwargs.pop("include_preprocessing", True),
            "include_top": kwargs.pop("include_top", True),
            "pooling": kwargs.pop("pooling", None),
            "dropout_rate": kwargs.pop("dropout_rate", 0.0),
            "classes": kwargs.pop("classes", 1000),
            "classifier_activation": kwargs.pop(
                "classifier_activation", "softmax"
            ),
            "weights": kwargs.pop("weights", "imagenet"),
            "default_size": kwargs.pop("default_size", default_size),
        }
        return result

    def determine_input_tensor(
        self,
        input_tensor=None,
        input_shape=None,
        default_size=224,
        min_size=32,
        require_flatten=False,
        static_shape=False,
    ):
        """Determine the input tensor by the arguments."""
        input_shape = imagenet_utils.obtain_input_shape(
            input_shape,
            default_size=default_size,
            min_size=min_size,
            data_format="channels_last",  # always channels_last
            require_flatten=require_flatten or static_shape,
            weights=None,
        )

        if input_tensor is None:
            x = layers.Input(shape=input_shape)
        else:
            if not backend.is_keras_tensor(input_tensor):
                x = layers.Input(tensor=input_tensor, shape=input_shape)
            else:
                x = input_tensor
        return x

    def build_preprocessing(self, inputs, mode="imagenet"):
        if mode == "imagenet":
            # [0, 255] to [0, 1], then apply the ImageNet mean and variance
            x = layers.Rescaling(scale=1.0 / 255.0)(inputs)
            x = layers.Normalization(
                mean=[0.485, 0.456, 0.406], variance=[0.229, 0.224, 0.225]
            )(x)
        elif mode == "0_1":
            # [0, 255] to [0, 1]
            x = layers.Rescaling(scale=1.0 / 255.0)(inputs)
        elif mode == "-1_1":
            # [0, 255] to [-1, 1]
            x = layers.Rescaling(scale=1.0 / 127.5, offset=-1.0)(inputs)
        else:
            raise ValueError(
                "`mode` must be one of ('imagenet', '0_1', '-1_1'). "
                f"Received: mode={mode}"
            )
        return x

    def build_top(self, inputs, classes, classifier_activation, dropout_rate):
        x = layers.GlobalAveragePooling2D(name="avg_pool")(inputs)
        x = layers.Dropout(rate=dropout_rate, name="head_dropout")(x)
        x = layers.Dense(
            classes, activation=classifier_activation, name="classifier"
        )(x)
        return x

    def add_references(self, parsed_kwargs: typing.Dict[str, typing.Any]):
        self.include_preprocessing = parsed_kwargs["include_preprocessing"]
        self.include_top = parsed_kwargs["include_top"]
        self.pooling = parsed_kwargs["pooling"]
        self.dropout_rate = parsed_kwargs["dropout_rate"]
        self.classes = parsed_kwargs["classes"]
        self.classifier_activation = parsed_kwargs["classifier_activation"]
        # `self.weights` is already used internally by Keras
        self._weights = parsed_kwargs["weights"]

    @staticmethod
    @abc.abstractmethod
    def available_feature_keys():
        # TODO: add docstring
        raise NotImplementedError

    def get_config(self):
        # Don't chain to super here. The default `get_config()` for
        # functional models is nested and cannot be passed to `BaseModel`.
        config = {
            # models.Model
            "name": self.name,
            "trainable": self.trainable,
            # feature extractor
            "as_feature_extractor": self.as_feature_extractor,
            "feature_keys": self.feature_keys,
            # common
            "input_shape": self.input_shape[1:],
            "include_preprocessing": self.include_preprocessing,
            "include_top": self.include_top,
            "pooling": self.pooling,
            "dropout_rate": self.dropout_rate,
            "classes": self.classes,
            "classifier_activation": self.classifier_activation,
            "weights": self._weights,
        }
        return config

    def fix_config(self, config: typing.Dict):
        return config
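A concrete architecture wires these helpers together in its constructor. The sketch below is a hypothetical minimal subclass (`TinyNet`, its single conv stem, and the `STEM` feature key are invented for illustration; the real models in this PR such as `DenseNet` and `InceptionV3` follow the same pattern with their own backbones):

from keras import layers

class TinyNet(BaseModel):
    def __init__(self, **kwargs):
        parsed_kwargs = self.parse_kwargs(kwargs)
        img_input = self.determine_input_tensor(
            parsed_kwargs["input_tensor"],
            parsed_kwargs["input_shape"],
            parsed_kwargs["default_size"],
        )
        x = img_input
        if parsed_kwargs["include_preprocessing"]:
            x = self.build_preprocessing(x, "imagenet")
        x = layers.Conv2D(32, 3, 2, padding="same", name="stem")(x)
        features = {"STEM": x}
        if parsed_kwargs["include_top"]:
            x = self.build_top(
                x,
                parsed_kwargs["classes"],
                parsed_kwargs["classifier_activation"],
                parsed_kwargs["dropout_rate"],
            )
        # `parse_kwargs` pops its keys, so remaining kwargs (e.g.
        # `as_feature_extractor`) pass through to `BaseModel.__init__`.
        super().__init__(
            inputs=img_input, outputs=x, features=features, **kwargs
        )
        self.add_references(parsed_kwargs)

    @staticmethod
    def available_feature_keys():
        return ["STEM"]

# Classification mode returns class scores; feature-extractor mode returns
# a dict of endpoints, e.g.:
#   TinyNet(as_feature_extractor=True, feature_keys=["STEM"])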
@@ -3,10 +3,10 @@
from keras import random
from keras.src import testing

-from kimm.models.feature_extractor import FeatureExtractor
+from kimm.models.base_model import BaseModel


-class SampleModel(FeatureExtractor):
+class SampleModel(BaseModel):
    def __init__(self, **kwargs):
        inputs = layers.Input(shape=[224, 224, 3])

@@ -34,7 +34,7 @@ def get_config(self):
        return super().get_config()


-class GhostNetTest(testing.TestCase, parameterized.TestCase):
+class BaseModelTest(testing.TestCase, parameterized.TestCase):
    def test_feature_extractor(self):
        x = random.uniform([1, 224, 224, 3])

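The renamed test exercises the dict-of-features path in `BaseModel.__init__`. A sketch of what it checks (the exact feature keys of `SampleModel` are not shown in this excerpt, so the printout is illustrative):

model = SampleModel(as_feature_extractor=True)
y = model(x, training=False)
# `y` maps each requested feature key to a tensor
print(list(y.keys()))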