Skip to content

Commit

Permalink
Basnet keras-3 migration compatibility with keras-hub (#2038)
Browse files Browse the repository at this point in the history
  • Loading branch information
laxmareddyp authored Jan 28, 2025
1 parent ba30e94 commit 6681f9e
Show file tree
Hide file tree
Showing 5 changed files with 1,428 additions and 809 deletions.
43 changes: 26 additions & 17 deletions examples/vision/basnet_segmentation.py
Original file line number Diff line number Diff line change
Expand Up @@ -43,7 +43,7 @@
from glob import glob
import matplotlib.pyplot as plt

import keras_cv
import keras_hub
import tensorflow as tf
import keras
from keras import layers, ops
Expand Down Expand Up @@ -228,15 +228,19 @@ def segmentation_head(x_input, out_classes, final_size):
return x


def get_resnet_block(_resnet, block_num):
"""Extract and return ResNet-34 block."""
resnet_layers = [3, 4, 6, 3] # ResNet-34 layer sizes at different block.
def get_resnet_block(resnet, block_num):
"""Extract and return a ResNet-34 block."""
extractor_levels = ["P2", "P3", "P4", "P5"]
num_blocks = resnet.stackwise_num_blocks
if block_num == 0:
x = resnet.get_layer("pool1_pool").output
else:
x = resnet.pyramid_outputs[extractor_levels[block_num - 1]]
y = resnet.get_layer(f"stack{block_num}_block{num_blocks[block_num]-1}_add").output
return keras.models.Model(
inputs=_resnet.get_layer(f"v2_stack_{block_num}_block1_1_conv").input,
outputs=_resnet.get_layer(
f"v2_stack_{block_num}_block{resnet_layers[block_num]}_add"
).output,
name=f"resnet34_block{block_num + 1}",
inputs=x,
outputs=y,
name=f"resnet_block{block_num + 1}",
)


Expand All @@ -262,8 +266,13 @@ def basnet_predict(input_shape, out_classes):
# -------------Encoder--------------
x = layers.Conv2D(filters, kernel_size=(3, 3), padding="same")(x_input)

resnet = keras_cv.models.ResNet34Backbone(
include_rescaling=False,
resnet = keras_hub.models.ResNetBackbone(
input_conv_filters=[64],
input_conv_kernel_sizes=[7],
stackwise_num_filters=[64, 128, 256, 512],
stackwise_num_blocks=[3, 4, 6, 3],
stackwise_num_strides=[1, 2, 2, 2],
block_type="basic_block",
)

encoder_blocks = []
Expand Down Expand Up @@ -307,7 +316,7 @@ def basnet_predict(input_shape, out_classes):
for decoder_block in decoder_blocks
]

return keras.models.Model(inputs=[x_input], outputs=decoder_blocks)
return keras.models.Model(inputs=x_input, outputs=decoder_blocks)


"""
Expand Down Expand Up @@ -352,7 +361,7 @@ def basnet_rrm(base_model, out_classes):
# ------------- refined = coarse + residual
x = layers.Add()([x_input, x]) # Add prediction + refinement output

return keras.models.Model(inputs=base_model.input[0], outputs=x)
return keras.models.Model(inputs=[base_model.input], outputs=[x])


"""
Expand All @@ -375,7 +384,7 @@ def __init__(self, input_shape, out_classes):

# Activations.
output = [layers.Activation("sigmoid")(x) for x in output]
super().__init__(inputs=predict_model.input[0], outputs=output)
super().__init__(inputs=predict_model.input, outputs=output)

self.smooth = 1.0e-9
# Binary Cross Entropy loss.
Expand Down Expand Up @@ -453,9 +462,9 @@ def compute_loss(self, x, y_true, y_pred, sample_weight=None, training=False):
trainings parameters please check given link.
"""

"""shell
!gdown 1OWKouuAQ7XpXZbWA3mmxDPrFGW71Axrg
"""
import gdown

gdown.download(id="1OWKouuAQ7XpXZbWA3mmxDPrFGW71Axrg", output="basnet_weights.h5")


def normalize_output(prediction):
Expand Down
Loading
Sorry, something went wrong. Reload?
Sorry, we cannot display this file.
Sorry, this file is invalid so it cannot be displayed.
Loading
Sorry, something went wrong. Reload?
Sorry, we cannot display this file.
Sorry, this file is invalid so it cannot be displayed.
177 changes: 88 additions & 89 deletions examples/vision/ipynb/basnet_segmentation.ipynb
Original file line number Diff line number Diff line change
Expand Up @@ -10,7 +10,7 @@
"\n",
"**Author:** [Hamid Ali](https://github.com/hamidriasat)<br>\n",
"**Date created:** 2023/05/30<br>\n",
"**Last modified:** 2024/10/02<br>\n",
"**Last modified:** 2025/01/24<br>\n",
"**Description:** Boundaries aware segmentation model trained on the DUTS dataset."
]
},
Expand Down Expand Up @@ -68,10 +68,12 @@
"from glob import glob\n",
"import matplotlib.pyplot as plt\n",
"\n",
"import keras_cv\n",
"import keras_hub\n",
"import tensorflow as tf\n",
"import keras\n",
"from keras import layers, ops"
"from keras import layers, ops\n",
"\n",
"keras.config.disable_traceback_filtering()"
]
},
{
Expand Down Expand Up @@ -117,10 +119,11 @@
},
"outputs": [],
"source": [
"DATA_DIR = keras.utils.get_file(\n",
"data_dir = keras.utils.get_file(\n",
" origin=\"http://saliencydetection.net/duts/download/DUTS-TE.zip\",\n",
" extract=True,\n",
")\n",
"data_dir = os.path.join(data_dir, \"DUTS-TE\")\n",
"\n",
"\n",
"def load_paths(path, split_ratio):\n",
Expand Down Expand Up @@ -159,7 +162,9 @@
" batch_x, batch_y = [], []\n",
" for i in range(idx * self.batch_size, (idx + 1) * self.batch_size):\n",
" x, y = self.preprocess(\n",
" self.image_paths[i], self.mask_paths[i], self.img_size, self.out_classes\n",
" self.image_paths[i],\n",
" self.mask_paths[i],\n",
" self.img_size,\n",
" )\n",
" batch_x.append(x)\n",
" batch_y.append(y)\n",
Expand All @@ -173,13 +178,13 @@
" x = (x / 255.0).astype(np.float32)\n",
" return x\n",
"\n",
" def preprocess(self, x_batch, y_batch, img_size, out_classes):\n",
" def preprocess(self, x_batch, y_batch, img_size):\n",
" images = self.read_image(x_batch, (img_size, img_size), mode=\"rgb\") # image\n",
" masks = self.read_image(y_batch, (img_size, img_size), mode=\"grayscale\") # mask\n",
" return images, masks\n",
"\n",
"\n",
"train_paths, val_paths = load_paths(DATA_DIR, TRAIN_SPLIT_RATIO)\n",
"train_paths, val_paths = load_paths(data_dir, TRAIN_SPLIT_RATIO)\n",
"\n",
"train_dataset = Dataset(\n",
" train_paths[0], train_paths[1], IMAGE_SIZE, OUT_CLASSES, BATCH_SIZE, shuffle=True\n",
Expand Down Expand Up @@ -318,17 +323,20 @@
" return x\n",
"\n",
"\n",
"def get_resnet_block(_resnet, block_num):\n",
" \"\"\"Extract and return ResNet-34 block.\"\"\"\n",
" resnet_layers = [3, 4, 6, 3] # ResNet-34 layer sizes at different block.\n",
"def get_resnet_block(resnet, block_num):\n",
" \"\"\"Extract and return a ResNet-34 block.\"\"\"\n",
" extractor_levels = [\"P2\", \"P3\", \"P4\", \"P5\"]\n",
" num_blocks = resnet.stackwise_num_blocks\n",
" if block_num == 0:\n",
" x = resnet.get_layer(\"pool1_pool\").output\n",
" else:\n",
" x = resnet.pyramid_outputs[extractor_levels[block_num - 1]]\n",
" y = resnet.get_layer(f\"stack{block_num}_block{num_blocks[block_num]-1}_add\").output\n",
" return keras.models.Model(\n",
" inputs=_resnet.get_layer(f\"v2_stack_{block_num}_block1_1_conv\").input,\n",
" outputs=_resnet.get_layer(\n",
" f\"v2_stack_{block_num}_block{resnet_layers[block_num]}_add\"\n",
" ).output,\n",
" name=f\"resnet34_block{block_num + 1}\",\n",
" )\n",
""
" inputs=x,\n",
" outputs=y,\n",
" name=f\"resnet_block{block_num + 1}\",\n",
" )\n"
]
},
{
Expand Down Expand Up @@ -366,8 +374,13 @@
" # -------------Encoder--------------\n",
" x = layers.Conv2D(filters, kernel_size=(3, 3), padding=\"same\")(x_input)\n",
"\n",
" resnet = keras_cv.models.ResNet34Backbone(\n",
" include_rescaling=False,\n",
" resnet = keras_hub.models.ResNetBackbone(\n",
" input_conv_filters=[64],\n",
" input_conv_kernel_sizes=[7],\n",
" stackwise_num_filters=[64, 128, 256, 512],\n",
" stackwise_num_blocks=[3, 4, 6, 3],\n",
" stackwise_num_strides=[1, 2, 2, 2],\n",
" block_type=\"basic_block\",\n",
" )\n",
"\n",
" encoder_blocks = []\n",
Expand Down Expand Up @@ -411,8 +424,7 @@
" for decoder_block in decoder_blocks\n",
" ]\n",
"\n",
" return keras.models.Model(inputs=[x_input], outputs=decoder_blocks)\n",
""
" return keras.models.Model(inputs=x_input, outputs=decoder_blocks)\n"
]
},
{
Expand Down Expand Up @@ -470,8 +482,7 @@
" # ------------- refined = coarse + residual\n",
" x = layers.Add()([x_input, x]) # Add prediction + refinement output\n",
"\n",
" return keras.models.Model(inputs=[base_model.input], outputs=[x])\n",
""
" return keras.models.Model(inputs=[base_model.input], outputs=[x])\n"
]
},
{
Expand All @@ -492,22 +503,56 @@
"outputs": [],
"source": [
"\n",
"def basnet(input_shape, out_classes):\n",
" \"\"\"BASNet, it's a combination of two modules\n",
" Prediction Module and Residual Refinement Module(RRM).\"\"\"\n",
"class BASNet(keras.Model):\n",
" def __init__(self, input_shape, out_classes):\n",
" \"\"\"BASNet, it's a combination of two modules\n",
" Prediction Module and Residual Refinement Module(RRM).\"\"\"\n",
"\n",
" # Prediction model.\n",
" predict_model = basnet_predict(input_shape, out_classes)\n",
" # Refinement model.\n",
" refine_model = basnet_rrm(predict_model, out_classes)\n",
"\n",
" output = refine_model.outputs # Combine outputs.\n",
" output.extend(predict_model.output)\n",
"\n",
" # Activations.\n",
" output = [layers.Activation(\"sigmoid\")(x) for x in output]\n",
" super().__init__(inputs=predict_model.input, outputs=output)\n",
"\n",
" self.smooth = 1.0e-9\n",
" # Binary Cross Entropy loss.\n",
" self.cross_entropy_loss = keras.losses.BinaryCrossentropy()\n",
" # Structural Similarity Index value.\n",
" self.ssim_value = tf.image.ssim\n",
" # Jaccard / IoU loss.\n",
" self.iou_value = self.calculate_iou\n",
"\n",
" def calculate_iou(\n",
" self,\n",
" y_true,\n",
" y_pred,\n",
" ):\n",
" \"\"\"Calculate intersection over union (IoU) between images.\"\"\"\n",
" intersection = ops.sum(ops.abs(y_true * y_pred), axis=[1, 2, 3])\n",
" union = ops.sum(y_true, [1, 2, 3]) + ops.sum(y_pred, [1, 2, 3])\n",
" union = union - intersection\n",
" return ops.mean((intersection + self.smooth) / (union + self.smooth), axis=0)\n",
"\n",
" # Prediction model.\n",
" predict_model = basnet_predict(input_shape, out_classes)\n",
" # Refinement model.\n",
" refine_model = basnet_rrm(predict_model, out_classes)\n",
" def compute_loss(self, x, y_true, y_pred, sample_weight=None, training=False):\n",
" total = 0.0\n",
" for y_pred_i in y_pred: # y_pred = refine_model.outputs + predict_model.output\n",
" cross_entropy_loss = self.cross_entropy_loss(y_true, y_pred_i)\n",
"\n",
" output = refine_model.outputs # Combine outputs.\n",
" output.extend(predict_model.output)\n",
" ssim_value = self.ssim_value(y_true, y_pred, max_val=1)\n",
" ssim_loss = ops.mean(1 - ssim_value + self.smooth, axis=0)\n",
"\n",
" output = [layers.Activation(\"sigmoid\")(_) for _ in output] # Activations.\n",
" iou_value = self.iou_value(y_true, y_pred)\n",
" iou_loss = 1 - iou_value\n",
"\n",
" return keras.models.Model(inputs=[predict_model.input], outputs=output)\n",
""
" # Add all three losses.\n",
" total += cross_entropy_loss + ssim_loss + iou_loss\n",
" return total\n"
]
},
{
Expand All @@ -532,53 +577,14 @@
"outputs": [],
"source": [
"\n",
"class BasnetLoss(keras.losses.Loss):\n",
" \"\"\"BASNet hybrid loss.\"\"\"\n",
"\n",
" def __init__(self, **kwargs):\n",
" super().__init__(name=\"basnet_loss\", **kwargs)\n",
" self.smooth = 1.0e-9\n",
"\n",
" # Binary Cross Entropy loss.\n",
" self.cross_entropy_loss = keras.losses.BinaryCrossentropy()\n",
" # Structural Similarity Index value.\n",
" self.ssim_value = tf.image.ssim\n",
" # Jaccard / IoU loss.\n",
" self.iou_value = self.calculate_iou\n",
"\n",
" def calculate_iou(\n",
" self,\n",
" y_true,\n",
" y_pred,\n",
" ):\n",
" \"\"\"Calculate intersection over union (IoU) between images.\"\"\"\n",
" intersection = ops.sum(ops.abs(y_true * y_pred), axis=[1, 2, 3])\n",
" union = ops.sum(y_true, [1, 2, 3]) + ops.sum(y_pred, [1, 2, 3])\n",
" union = union - intersection\n",
" return ops.mean((intersection + self.smooth) / (union + self.smooth), axis=0)\n",
"\n",
" def call(self, y_true, y_pred):\n",
" cross_entropy_loss = self.cross_entropy_loss(y_true, y_pred)\n",
"\n",
" ssim_value = self.ssim_value(y_true, y_pred, max_val=1)\n",
" ssim_loss = ops.mean(1 - ssim_value + self.smooth, axis=0)\n",
"\n",
" iou_value = self.iou_value(y_true, y_pred)\n",
" iou_loss = 1 - iou_value\n",
"\n",
" # Add all three losses.\n",
" return cross_entropy_loss + ssim_loss + iou_loss\n",
"\n",
"\n",
"basnet_model = basnet(\n",
"basnet_model = BASNet(\n",
" input_shape=[IMAGE_SIZE, IMAGE_SIZE, 3], out_classes=OUT_CLASSES\n",
") # Create model.\n",
"basnet_model.summary() # Show model summary.\n",
"\n",
"optimizer = keras.optimizers.Adam(learning_rate=1e-4, epsilon=1e-8)\n",
"# Compile model.\n",
"basnet_model.compile(\n",
" loss=BasnetLoss(),\n",
" optimizer=optimizer,\n",
" metrics=[keras.metrics.MeanAbsoluteError(name=\"mae\") for _ in basnet_model.outputs],\n",
")"
Expand Down Expand Up @@ -631,17 +637,10 @@
},
"outputs": [],
"source": [
"!!gdown 1OWKouuAQ7XpXZbWA3mmxDPrFGW71Axrg"
]
},
{
"cell_type": "code",
"execution_count": 0,
"metadata": {
"colab_type": "code"
},
"outputs": [],
"source": [
"import gdown\n",
"\n",
"gdown.download(id=\"1OWKouuAQ7XpXZbWA3mmxDPrFGW71Axrg\", output=\"basnet_weights.h5\")\n",
"\n",
"\n",
"def normalize_output(prediction):\n",
" max_value = np.max(prediction)\n",
Expand Down Expand Up @@ -686,7 +685,7 @@
"toc_visible": true
},
"kernelspec": {
"display_name": "Python 3",
"display_name": "evn1",
"language": "python",
"name": "python3"
},
Expand All @@ -700,9 +699,9 @@
"name": "python",
"nbconvert_exporter": "python",
"pygments_lexer": "ipython3",
"version": "3.7.0"
"version": "3.9.6"
}
},
"nbformat": 4,
"nbformat_minor": 0
}
}
Loading

0 comments on commit 6681f9e

Please sign in to comment.