Add LCNet and fix model serialization (#10)
* Merge `apply_depthwise_separation_block`
* Add `LCNet`
* Speed up gpu test
* Cleanup
* Update `add_model_to_registry`
* Fix model serialization
james77777778 authored Jan 15, 2024
1 parent 7a0f2e7 commit ce979af
Showing 22 changed files with 776 additions and 201 deletions.
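The serialization fix named in the commit message lands in model files not shown in this excerpt. As a rough sketch only, not the commit's actual change (all names below are hypothetical): Keras-style model serialization generally requires that a custom model expose its constructor arguments through a `get_config`/`from_config` round trip.

import keras

@keras.saving.register_keras_serializable()
class ExampleModel(keras.Model):  # hypothetical stand-in for a kimm model
    def __init__(self, width=32, **kwargs):
        super().__init__(**kwargs)
        self.width = width
        self.dense = keras.layers.Dense(width)

    def call(self, inputs):
        return self.dense(inputs)

    def get_config(self):
        # include every constructor argument so the model can be rebuilt
        config = super().get_config()
        config.update({"width": self.width})
        return config

# round trip: rebuild the model purely from its serialized config
restored = ExampleModel.from_config(ExampleModel(width=64).get_config())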
conftest.py (7 additions, 0 deletions)
@@ -2,6 +2,13 @@


def pytest_configure():
    import tensorflow as tf

    # disable tensorflow gpu memory preallocation
    physical_devices = tf.config.list_physical_devices("GPU")
    for device in physical_devices:
        tf.config.experimental.set_memory_growth(device, True)

    # disable jax gpu memory preallocation
    # https://jax.readthedocs.io/en/latest/gpu_memory_allocation.html
    # (set before any test imports jax, so it takes effect when the backend initializes)
    os.environ["XLA_PYTHON_CLIENT_PREALLOCATE"] = "false"
kimm/blocks/__init__.py (3 additions, 0 deletions)
@@ -1,6 +1,9 @@
from kimm.blocks.base_block import apply_activation
from kimm.blocks.base_block import apply_conv2d_block
from kimm.blocks.base_block import apply_se_block
from kimm.blocks.depthwise_separation_block import (
    apply_depthwise_separation_block,
)
from kimm.blocks.inverted_residual_block import apply_inverted_residual_block
from kimm.blocks.transformer_block import apply_mlp_block
from kimm.blocks.transformer_block import apply_transformer_block
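With this re-export in place, model definitions can share the single implementation consolidated by the "Merge `apply_depthwise_separation_block`" item above, importing it from the package instead of redefining it per model:

from kimm.blocks import apply_depthwise_separation_block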
kimm/blocks/depthwise_separation_block.py (59 additions, 0 deletions)
@@ -0,0 +1,59 @@
from keras import layers

from kimm.blocks.base_block import apply_conv2d_block
from kimm.blocks.base_block import apply_se_block


def apply_depthwise_separation_block(
    inputs,
    output_channels,
    depthwise_kernel_size=3,
    pointwise_kernel_size=1,
    strides=1,
    se_ratio=0.0,
    activation="swish",
    se_activation="relu",
    se_gate_activation="sigmoid",
    se_make_divisible_number=None,
    pw_activation=None,
    skip=True,
    bn_epsilon=1e-5,
    padding=None,
    name="depthwise_separation_block",
):
    input_channels = inputs.shape[-1]
    # the residual connection is only valid when spatial size and channel count are preserved
    has_skip = skip and (strides == 1 and input_channels == output_channels)

    x = inputs
    # depthwise conv + BN + activation
    x = apply_conv2d_block(
        x,
        kernel_size=depthwise_kernel_size,
        strides=strides,
        activation=activation,
        use_depthwise=True,
        bn_epsilon=bn_epsilon,
        padding=padding,
        name=f"{name}_conv_dw",
    )
    # optional squeeze-and-excitation
    if se_ratio > 0:
        x = apply_se_block(
            x,
            se_ratio,
            activation=se_activation,
            gate_activation=se_gate_activation,
            make_divisible_number=se_make_divisible_number,
            name=f"{name}_se",
        )
    # pointwise (1x1 by default) conv + BN, with no activation by default
    x = apply_conv2d_block(
        x,
        output_channels,
        pointwise_kernel_size,
        1,
        activation=pw_activation,
        bn_epsilon=bn_epsilon,
        padding=padding,
        name=f"{name}_conv_pw",
    )
    if has_skip:
        x = layers.Add()([x, inputs])
    return x
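A minimal usage sketch (the input shape and channel counts are illustrative, not from the commit): because the output channel count equals the input channel count and strides=1, the residual add is active here.

from keras import layers, models

from kimm.blocks import apply_depthwise_separation_block

inputs = layers.Input(shape=(32, 32, 16))
x = apply_depthwise_separation_block(
    inputs,
    output_channels=16,  # matches input channels with strides=1, so the skip path is used
    se_ratio=0.25,
)
model = models.Model(inputs, x)
model.summary()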
