Skip to content

UNet 3+ model added #385

New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Open
wants to merge 2 commits into
base: master
Choose a base branch
from
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
1 change: 1 addition & 0 deletions README.md
Original file line number Diff line number Diff line change
Expand Up @@ -58,6 +58,7 @@ Following models are supported:
| vgg_unet | VGG 16 | U-Net |
| resnet50_unet | Resnet-50 | U-Net |
| mobilenet_unet | MobileNet | U-Net |
| unet3_plus | Vanilla CNN | U-Net 3+ |
| segnet | Vanilla CNN | Segnet |
| vgg_segnet | VGG 16 | Segnet |
| resnet50_segnet | Resnet-50 | Segnet |
Expand Down
3 changes: 3 additions & 0 deletions keras_segmentation/models/all_models.py
Original file line number Diff line number Diff line change
@@ -1,5 +1,6 @@
from . import pspnet
from . import unet
from . import unet3_plus
from . import segnet
from . import fcn
# Registry mapping a model-name string to the function that builds that model.
model_from_name = {}
Expand Down Expand Up @@ -35,6 +36,8 @@
model_from_name["resnet50_unet"] = unet.resnet50_unet
model_from_name["mobilenet_unet"] = unet.mobilenet_unet

# UNet 3+ builder, registered under the same name used in the README model table.
model_from_name["unet3_plus"] = unet3_plus.unet3_plus


model_from_name["segnet"] = segnet.segnet
model_from_name["vgg_segnet"] = segnet.vgg_segnet
Expand Down
203 changes: 203 additions & 0 deletions keras_segmentation/models/unet3_plus.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,203 @@
"""
@Author: Hamid Ali
@Date: 10/18/2022
@GitHub: https://github.com/hamidriasat
@Gmail: [email protected]
"""
import tensorflow as tf
import keras as k
from keras.layers import *
from .model_utils import get_segmentation_model
from .config import IMAGE_ORDERING

# Axis that holds the channels for the configured image ordering; this is the
# axis feature maps have to be concatenated along.
# NOTE: MERGE_AXIS stays undefined for any other IMAGE_ORDERING value.
if IMAGE_ORDERING == 'channels_last':
    MERGE_AXIS = -1
elif IMAGE_ORDERING == 'channels_first':
    MERGE_AXIS = 1

""" # Model Architecture """


def __conv_block(x, kernels, kernel_size=(3, 3), strides=(1, 1), padding='same',
                 is_bn=True, is_relu=True, n=2):
    """
    Apply ``n`` successive Conv2D layers, each optionally followed by
    batch normalisation and a ReLU activation.
    :param x: input feature map (symbolic Keras tensor)
    :param kernels: number of output filters of every convolution
    :param kernel_size: convolution kernel size, 3*3 by default
    :param strides: convolution strides
    :param padding: convolution padding mode
    :param is_bn: add BatchNormalization after each convolution
    :param is_relu: add a ReLU activation after each convolution
    :param n: how many conv(+bn)(+relu) stages to stack
    :return: output feature map of the last stage
    """
    # Normalise over the channel axis of the configured layout, so the block
    # agrees with the Input tensor the model builds for IMAGE_ORDERING.
    bn_axis = 1 if IMAGE_ORDERING == 'channels_first' else -1

    for _ in range(n):
        x = k.layers.Conv2D(filters=kernels, kernel_size=kernel_size,
                            padding=padding, strides=strides,
                            data_format=IMAGE_ORDERING,
                            # regularizer taken from the same keras package as
                            # the layers instead of mixing in tf.keras objects
                            kernel_regularizer=k.regularizers.l2(1e-4),
                            kernel_initializer=k.initializers.he_normal(seed=5))(x)
        if is_bn:
            x = k.layers.BatchNormalization(axis=bn_axis)(x)
        if is_relu:
            # Use a layer rather than calling the activation function on the
            # symbolic tensor, so the op is tracked as part of the model graph.
            x = k.layers.Activation('relu')(x)

    return x


def __dotProduct(seg, cls):
    """
    Scale ``seg`` by a per-sample, per-channel factor taken from ``cls``:
    ``out[b, h, w, n] = seg[b, h, w, n] * cls[b, n]``.
    :param seg: rank-4 tensor; its static shape is unpacked as (batch, H, W, N)
                -- assumes channels-last layout, TODO confirm with callers
    :param cls: rank-2 tensor indexed as (batch, N)
    :return: tensor reshaped back to (-1, H, W, N)
    """
    _, height, width, depth = k.backend.int_shape(seg)
    # Collapse the two spatial axes so a single einsum can broadcast cls over them.
    flattened = tf.reshape(seg, [-1, height * width, depth])
    scaled = tf.einsum("ijk,ik->ijk", flattened, cls)
    return tf.reshape(scaled, [-1, height, width, depth])


""" UNet_3Plus """


def __skip_branch(x, filters, pool=1, upsample=1):
    """ Resize a feature map (max-pool down or bilinear up) and reduce it to `filters` channels. """
    if pool > 1:
        x = k.layers.MaxPool2D(pool_size=(pool, pool),
                               data_format=IMAGE_ORDERING)(x)
    if upsample > 1:
        x = k.layers.UpSampling2D(size=(upsample, upsample),
                                  interpolation='bilinear',
                                  data_format=IMAGE_ORDERING)(x)
    return __conv_block(x, filters, n=1)


def __fuse_branches(branches, filters):
    """ Concatenate same-resolution branches along the channel axis and mix them with one conv block. """
    merged = k.layers.concatenate(branches, axis=MERGE_AXIS)
    return __conv_block(merged, filters, n=1)


def __unet3_plus(n_classes, input_height=416, input_width=608, channels=3):
    """
    Create model and pass it to segmentation head.
    Shapes in the comments below are written for an H*W input
    (H = input_height, W = input_width) as height*width*channels.
    :param n_classes: number of output classes
    :param input_height: input image height, must be a multiple of 32
    :param input_width: input image width, must be a multiple of 32
    :param channels: number of input channels
    :return: image-segmentation-keras library compatible model
    :raises AssertionError: if height or width is not a multiple of 32
    :raises ValueError: if IMAGE_ORDERING is neither 'channels_first' nor 'channels_last'
    """
    # The four 2*2 poolings below only need a multiple of 16; the stricter
    # original check is kept so the set of accepted inputs does not change.
    assert input_height % 32 == 0, "input_height must be a multiple of 32"
    assert input_width % 32 == 0, "input_width must be a multiple of 32"

    if IMAGE_ORDERING == 'channels_first':
        input_shape = (channels, input_height, input_width)
    elif IMAGE_ORDERING == 'channels_last':
        input_shape = (input_height, input_width, channels)
    else:
        # Fail with a clear message instead of an unbound-variable error.
        raise ValueError("Unsupported IMAGE_ORDERING: %r" % (IMAGE_ORDERING,))
    img_input = Input(shape=input_shape, name="img_input")

    filters = [64, 128, 256, 512, 1024]

    """ Encoder"""
    # block 1 works at full resolution; every later block halves the
    # resolution with a 2*2 max-pool before its two convolutions.
    encoder = [__conv_block(img_input, filters[0])]  # H*W*64
    for block_filters in filters[1:]:
        pooled = k.layers.MaxPool2D(pool_size=(2, 2),
                                    data_format=IMAGE_ORDERING)(encoder[-1])
        encoder.append(__conv_block(pooled, block_filters))
    # e1: H*W*64          e2: H/2*W/2*128     e3: H/4*W/4*256
    # e4: H/8*W/8*512     e5: H/16*W/16*1024 (bottleneck)
    e1, e2, e3, e4, e5 = encoder

    """ Decoder """
    # Every decoder stage receives one branch per scale (5 in total); each
    # branch is squeezed to cat_channels before concatenation.
    cat_channels = filters[0]
    cat_blocks = len(filters)
    upsample_channels = cat_blocks * cat_channels  # 5 * 64 = 320

    """ d4 """  # resolution H/8*W/8
    d4 = __fuse_branches([
        __skip_branch(e1, cat_channels, pool=8),      # H*W*64          --> H/8*W/8*64
        __skip_branch(e2, cat_channels, pool=4),      # H/2*W/2*128     --> H/8*W/8*64
        __skip_branch(e3, cat_channels, pool=2),      # H/4*W/4*256     --> H/8*W/8*64
        __skip_branch(e4, cat_channels),              # H/8*W/8*512     --> H/8*W/8*64
        __skip_branch(e5, cat_channels, upsample=2),  # H/16*W/16*1024  --> H/8*W/8*64
    ], upsample_channels)                             # H/8*W/8*320     --> H/8*W/8*320

    """ d3 """  # resolution H/4*W/4
    d3 = __fuse_branches([
        __skip_branch(e1, cat_channels, pool=4),      # H*W*64          --> H/4*W/4*64
        __skip_branch(e2, cat_channels, pool=2),      # H/2*W/2*128     --> H/4*W/4*64
        __skip_branch(e3, cat_channels),              # H/4*W/4*256     --> H/4*W/4*64
        __skip_branch(d4, cat_channels, upsample=2),  # H/8*W/8*320     --> H/4*W/4*64
        __skip_branch(e5, cat_channels, upsample=4),  # H/16*W/16*1024  --> H/4*W/4*64
    ], upsample_channels)                             # H/4*W/4*320     --> H/4*W/4*320

    """ d2 """  # resolution H/2*W/2
    d2 = __fuse_branches([
        __skip_branch(e1, cat_channels, pool=2),      # H*W*64          --> H/2*W/2*64
        __skip_branch(e2, cat_channels),              # H/2*W/2*128     --> H/2*W/2*64
        __skip_branch(d3, cat_channels, upsample=2),  # H/4*W/4*320     --> H/2*W/2*64
        __skip_branch(d4, cat_channels, upsample=4),  # H/8*W/8*320     --> H/2*W/2*64
        __skip_branch(e5, cat_channels, upsample=8),  # H/16*W/16*1024  --> H/2*W/2*64
    ], upsample_channels)                             # H/2*W/2*320     --> H/2*W/2*320

    """ d1 """  # full resolution H*W
    d1 = __fuse_branches([
        __skip_branch(e1, cat_channels),               # H*W*64          --> H*W*64
        __skip_branch(d2, cat_channels, upsample=2),   # H/2*W/2*320     --> H*W*64
        __skip_branch(d3, cat_channels, upsample=4),   # H/4*W/4*320     --> H*W*64
        __skip_branch(d4, cat_channels, upsample=8),   # H/8*W/8*320     --> H*W*64
        __skip_branch(e5, cat_channels, upsample=16),  # H/16*W/16*1024  --> H*W*64
    ], upsample_channels)                              # H*W*320         --> H*W*320

    # last layer does not have batchnorm and relu: one conv with n_classes filters
    d = __conv_block(d1, n_classes, n=1, is_bn=False, is_relu=False)

    model = get_segmentation_model(img_input, d)

    return model


def unet3_plus(n_classes: int, input_height: int = 416, input_width: int = 608, channels: int = 3):
    """
    Public entry point: build the UNet3+ network and tag it with its model name.
    :param n_classes: number of output classes
    :param input_height: input image height
    :param input_width: input image width
    :param channels: number of input channels
    :return: image-segmentation-keras library compatible model whose
             ``model_name`` attribute is set to "unet3_plus"
    """
    build_options = dict(
        input_height=input_height,
        input_width=input_width,
        channels=channels,
    )
    network = __unet3_plus(n_classes, **build_options)
    network.model_name = "unet3_plus"
    return network


if __name__ == "__main__":
    # Manual smoke test: build the model with the default input size and
    # print its layer summary. This file uses relative imports, so it has to
    # be executed as a module of its package rather than as a plain script.
    OUTPUT_CHANNELS = 50

    __unet_3P = unet3_plus(n_classes=OUTPUT_CHANNELS)
    __unet_3P.summary()

    # Optional extras, disabled by default:
    # tf.keras.utils.plot_model(__unet_3P, show_layer_names=True, show_shapes=True)
    # __unet_3P.save("unet_3P.hdf5")
Loading