Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Adding Differentiable Binarization model from PaddleOCR to Keras3 #1739

Open
wants to merge 33 commits into
base: master
Choose a base branch
from
Open
Show file tree
Hide file tree
Changes from all commits
Commits
Show all changes
33 commits
Select commit Hold shift + click to select a range
49f6bb1
Add `DifferentialBinarization` model
gowthamkpr Sep 13, 2024
5b4e011
Added tests for `DifferentialBinarization` losses
gowthamkpr Oct 22, 2024
12ab81c
Moved `DifferentialBinarization` to keras_hub
gowthamkpr Oct 22, 2024
e68512c
Renamed to `differential_binarization.py`
gowthamkpr Oct 22, 2024
0c3235c
Refactorings for `DifferentialBinarization`
gowthamkpr Oct 22, 2024
6797231
More refactorings
gowthamkpr Oct 22, 2024
4845b6a
Fix tests
gowthamkpr Oct 22, 2024
83edf9a
Add preprocessor and image converter
gowthamkpr Oct 29, 2024
f15b7b9
Add presets
gowthamkpr Oct 29, 2024
392dbff
Run formatting script
gowthamkpr Oct 29, 2024
db70eb5
Impl additional tests
gowthamkpr Oct 29, 2024
18fcbfb
Fixed formatting
gowthamkpr Oct 29, 2024
898235d
Removed copyright statements
gowthamkpr Oct 29, 2024
eaec868
Fix tests, run `api_gen.sh`
gowthamkpr Oct 29, 2024
21b6312
Merge branch 'master' into diffbin
gowthamkpr Oct 29, 2024
9fb6e65
Addressed comments
gowthamkpr Nov 11, 2024
83b66ed
Merge with local branch
gowthamkpr Nov 11, 2024
e4a334d
Fixed torch and jax tests
gowthamkpr Nov 13, 2024
49d6f6d
Improved code readability
gowthamkpr Nov 14, 2024
d96b899
Improved/added docstrings
gowthamkpr Nov 22, 2024
2f27981
Added `ImageTextDetector` task
gowthamkpr Nov 25, 2024
66afeb9
Run `api_gen.sh`
gowthamkpr Nov 25, 2024
1d91e76
Fix tensor/array usage
gowthamkpr Dec 17, 2024
e999ad9
Sync with master (new linter)
gowthamkpr Dec 17, 2024
3fd3b6f
Shorten docstring
gowthamkpr Dec 17, 2024
af934f5
Renamed to DiffBin
gowthamkpr Dec 18, 2024
c111bd1
Fixed docstring
gowthamkpr Jan 22, 2025
738786c
Rename `DiffBinOCR` -> `DiffBinImageTextDetector`
gowthamkpr Jan 22, 2025
c06d6cb
Merge branch 'keras-hub' into diffbin3
gowthamkpr Jan 27, 2025
f047012
`diffbin_ocr_test.py` -> `diffbin_textdetector_test.py`
gowthamkpr Jan 27, 2025
26c16e4
Added weight conversion script
gowthamkpr Jan 27, 2025
464482c
Fixed formatting
gowthamkpr Jan 27, 2025
ca83afe
Corrected a few comments
gowthamkpr Jan 27, 2025
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
3 changes: 3 additions & 0 deletions keras_hub/api/layers/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -45,6 +45,9 @@
from keras_hub.src.models.densenet.densenet_image_converter import (
DenseNetImageConverter,
)
from keras_hub.src.models.diffbin.diffbin_image_converter import (
DiffBinImageConverter,
)
from keras_hub.src.models.efficientnet.efficientnet_image_converter import (
EfficientNetImageConverter,
)
Expand Down
8 changes: 8 additions & 0 deletions keras_hub/api/models/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -107,6 +107,13 @@
from keras_hub.src.models.densenet.densenet_image_classifier_preprocessor import (
DenseNetImageClassifierPreprocessor,
)
from keras_hub.src.models.diffbin.diffbin_backbone import DiffBinBackbone
from keras_hub.src.models.diffbin.diffbin_preprocessor import (
DiffBinPreprocessor,
)
from keras_hub.src.models.diffbin.diffbin_textdetector import (
DiffBinImageTextDetector,
)
from keras_hub.src.models.distil_bert.distil_bert_backbone import (
DistilBertBackbone,
)
Expand Down Expand Up @@ -201,6 +208,7 @@
from keras_hub.src.models.image_segmenter_preprocessor import (
ImageSegmenterPreprocessor,
)
from keras_hub.src.models.image_text_detector import ImageTextDetector
from keras_hub.src.models.image_to_image import ImageToImage
from keras_hub.src.models.inpaint import Inpaint
from keras_hub.src.models.llama.llama_backbone import LlamaBackbone
Expand Down
5 changes: 5 additions & 0 deletions keras_hub/src/models/diffbin/__init__.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,5 @@
from keras_hub.src.models.diffbin.diffbin_backbone import DiffBinBackbone
from keras_hub.src.models.diffbin.diffbin_presets import backbone_presets
from keras_hub.src.utils.preset_utils import register_presets

# Associate the DiffBin preset table with `DiffBinBackbone`. This is a
# module-level call, so it executes when this package is imported.
register_presets(backbone_presets, DiffBinBackbone)
220 changes: 220 additions & 0 deletions keras_hub/src/models/diffbin/diffbin_backbone.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,220 @@
import keras
from keras import layers

from keras_hub.src.api_export import keras_hub_export
from keras_hub.src.models.backbone import Backbone


@keras_hub_export("keras_hub.models.DiffBinBackbone")
class DiffBinBackbone(Backbone):
    """Differentiable Binarization architecture for scene text detection.

    This class implements the Differentiable Binarization architecture for
    detecting text in natural images, described in
    [Real-time Scene Text Detection with Differentiable Binarization](
    https://arxiv.org/abs/1911.08947).

    The backbone architecture in this class contains the feature pyramid
    network and model heads. Its outputs are a dict with the keys
    `"probability_maps"` and `"threshold_maps"`.

    Args:
        image_encoder: A `keras_hub.models.ResNetBackbone` instance.
        fpn_channels: int. The number of channels to output by the feature
            pyramid network. Defaults to 256.
        head_kernel_list: list of ints. The kernel sizes of probability map and
            threshold map heads. Defaults to `None`, which is resolved to
            `[3, 2, 2]`.
        dtype: `None` or str or `keras.mixed_precision.DTypePolicy`. The dtype
            to use for the model's computations and weights.
    """

    def __init__(
        self,
        image_encoder,
        fpn_channels=256,
        head_kernel_list=None,
        dtype=None,
        **kwargs,
    ):
        # The default is resolved here instead of in the signature so that no
        # list object is shared between instances; a caller-supplied sequence
        # is copied for the same reason.
        if head_kernel_list is None:
            head_kernel_list = [3, 2, 2]
        else:
            head_kernel_list = list(head_kernel_list)

        # === Functional Model ===
        inputs = image_encoder.input
        x = image_encoder.pyramid_outputs
        x = diffbin_fpn_model(x, out_channels=fpn_channels, dtype=dtype)

        # Both heads are built from the same fused feature map and differ
        # only in their layer-name prefix.
        probability_maps = diffbin_head(
            x,
            in_channels=fpn_channels,
            kernel_list=head_kernel_list,
            name="head_prob",
        )
        threshold_maps = diffbin_head(
            x,
            in_channels=fpn_channels,
            kernel_list=head_kernel_list,
            name="head_thresh",
        )

        outputs = {
            "probability_maps": probability_maps,
            "threshold_maps": threshold_maps,
        }

        super().__init__(inputs=inputs, outputs=outputs, dtype=dtype, **kwargs)

        # === Config ===
        self.image_encoder = image_encoder
        self.fpn_channels = fpn_channels
        self.head_kernel_list = head_kernel_list

    def get_config(self):
        """Return the constructor arguments as a serializable config dict."""
        config = super().get_config()
        config["fpn_channels"] = self.fpn_channels
        # Hand out a copy so edits to the config cannot alter this instance.
        config["head_kernel_list"] = list(self.head_kernel_list)
        config["image_encoder"] = keras.layers.serialize(self.image_encoder)
        return config

    @classmethod
    def from_config(cls, config):
        """Create a `DiffBinBackbone` from the output of `get_config()`.

        `config["image_encoder"]` may be either a serialized layer config
        (a dict) or an already constructed layer instance.
        """
        # Shallow copy: the caller's dict must not be modified in place.
        config = dict(config)
        if isinstance(config["image_encoder"], dict):
            config["image_encoder"] = keras.layers.deserialize(
                config["image_encoder"]
            )
        return cls(**config)


def diffbin_fpn_model(inputs, out_channels, dtype=None):
    """Build the feature pyramid neck on top of the encoder's pyramid outputs.

    Args:
        inputs: Mapping that provides the entries `"P2"`, `"P3"`, `"P4"` and
            `"P5"`.
        out_channels: int. Number of filters of the lateral convolutions.
            Every per-level feature map has `out_channels // 4` filters.
        dtype: Passed as `dtype` to the convolution, upsampling and
            concatenation layers created here.

    Returns:
        The per-level feature maps concatenated along the last axis, in the
        order P5, P4, P3, P2.
    """

    def lateral(level_key, layer_name):
        # Pointwise (1x1) convolution of one pyramid output.
        conv = layers.Conv2D(
            out_channels,
            kernel_size=1,
            use_bias=False,
            name=layer_name,
            dtype=dtype,
        )
        return conv(inputs[level_key])

    def fuse(coarser, skip, layer_name):
        # Upsample the coarser level (default factor) and add the skip
        # connection from the lateral branch.
        upsampled = layers.UpSampling2D(dtype=dtype)(coarser)
        return layers.Add(name=layer_name)([upsampled, skip])

    def merged_featuremap(tensor, layer_name):
        # 3x3 convolution producing one level's share of the output channels.
        conv = layers.Conv2D(
            out_channels // 4,
            kernel_size=3,
            padding="same",
            use_bias=False,
            name=layer_name,
            dtype=dtype,
        )
        return conv(tensor)

    def enlarge(tensor, factor):
        return layers.UpSampling2D((factor, factor), dtype=dtype)(tensor)

    # Lateral branch.
    skip_p2 = lateral("P2", "neck_lateral_p2")
    skip_p3 = lateral("P3", "neck_lateral_p3")
    skip_p4 = lateral("P4", "neck_lateral_p4")
    skip_p5 = lateral("P5", "neck_lateral_p5")

    # Top-down pathway, starting from the P5 lateral output.
    fused_p5 = skip_p5
    fused_p4 = fuse(fused_p5, skip_p4, "neck_topdown_p4")
    fused_p3 = fuse(fused_p4, skip_p3, "neck_topdown_p3")
    fused_p2 = fuse(fused_p3, skip_p2, "neck_topdown_p2")

    # One merged feature map per pyramid level.
    map_p5 = merged_featuremap(fused_p5, "neck_featuremap_p5")
    map_p4 = merged_featuremap(fused_p4, "neck_featuremap_p4")
    map_p3 = merged_featuremap(fused_p3, "neck_featuremap_p3")
    map_p2 = merged_featuremap(fused_p2, "neck_featuremap_p2")

    # Scale P5/P4/P3 by 8/4/2 before concatenating them with the P2 map
    # (presumably each level halves the resolution of the previous one --
    # confirm against the encoder).
    map_p5 = enlarge(map_p5, 8)
    map_p4 = enlarge(map_p4, 4)
    map_p3 = enlarge(map_p3, 2)

    return layers.Concatenate(axis=-1, dtype=dtype)(
        [map_p5, map_p4, map_p3, map_p2]
    )


def diffbin_head(inputs, in_channels, kernel_list, name, dtype=None):
    """Build one prediction head (used for probability and threshold maps).

    The head is a convolution followed by two stride-2 transposed
    convolutions. The first two stages are each followed by batch
    normalization and ReLU; the last stage has a single filter and a sigmoid
    activation.

    Args:
        inputs: Input tensor of the head.
        in_channels: int. The two hidden stages use `in_channels // 4`
            filters, and the bias initializer range is derived from the same
            value.
        kernel_list: sequence of three ints. Kernel sizes of the convolution,
            the first transposed convolution and the second transposed
            convolution, in that order.
        name: str. Prefix for the names of all layers created here.
        dtype: `None` or str or `keras.mixed_precision.DTypePolicy`. Passed
            to every layer of the head, mirroring `diffbin_fpn_model`.
            Defaults to `None`, i.e. the layers' default dtype handling.

    Returns:
        The output tensor of the final, sigmoid-activated transposed
        convolution.
    """
    hidden_channels = in_channels // 4
    bias_bound = 1.0 / (hidden_channels * 1.0) ** 0.5

    def bias_initializer():
        # A fresh initializer per layer: the layers must not share one
        # initializer instance.
        return keras.initializers.RandomUniform(
            minval=-bias_bound,
            maxval=bias_bound,
        )

    def batch_norm(layer_name):
        return layers.BatchNormalization(
            beta_initializer=keras.initializers.Constant(1e-4),
            gamma_initializer=keras.initializers.Constant(1.0),
            name=layer_name,
            dtype=dtype,
        )

    x = layers.Conv2D(
        hidden_channels,
        kernel_size=kernel_list[0],
        padding="same",
        use_bias=False,
        name=f"{name}_conv0_weights",
        dtype=dtype,
    )(inputs)
    x = batch_norm(f"{name}_conv0_bn")(x)
    x = layers.ReLU(name=f"{name}_conv0_relu", dtype=dtype)(x)
    x = layers.Conv2DTranspose(
        hidden_channels,
        kernel_size=kernel_list[1],
        strides=2,
        padding="valid",
        bias_initializer=bias_initializer(),
        name=f"{name}_conv1_weights",
        dtype=dtype,
    )(x)
    x = batch_norm(f"{name}_conv1_bn")(x)
    x = layers.ReLU(name=f"{name}_conv1_relu", dtype=dtype)(x)
    x = layers.Conv2DTranspose(
        1,
        kernel_size=kernel_list[2],
        strides=2,
        padding="valid",
        activation="sigmoid",
        bias_initializer=bias_initializer(),
        name=f"{name}_conv2_weights",
        dtype=dtype,
    )(x)
    return x
42 changes: 42 additions & 0 deletions keras_hub/src/models/diffbin/diffbin_backbone_test.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,42 @@
from keras import ops

from keras_hub.src.models.diffbin.diffbin_backbone import DiffBinBackbone
from keras_hub.src.models.diffbin.diffbin_preprocessor import (
DiffBinPreprocessor,
)
from keras_hub.src.models.resnet.resnet_backbone import ResNetBackbone
from keras_hub.src.tests.test_case import TestCase


class DiffBinTest(TestCase):
    """Basic backbone checks for `DiffBinBackbone` on a small ResNet encoder."""

    def setUp(self):
        # Batch of two 32x32 RGB images filled with ones.
        self.images = ops.ones((2, 32, 32, 3))

        # Small bottleneck ResNet used as the image encoder.
        encoder_config = dict(
            input_conv_filters=[4],
            input_conv_kernel_sizes=[7],
            stackwise_num_filters=[64, 4, 4, 4],
            stackwise_num_blocks=[3, 4, 6, 3],
            stackwise_num_strides=[1, 2, 2, 2],
            block_type="bottleneck_block",
            image_shape=(32, 32, 3),
        )
        self.image_encoder = ResNetBackbone(**encoder_config)
        self.preprocessor = DiffBinPreprocessor()
        self.init_kwargs = dict(
            image_encoder=self.image_encoder,
            fpn_channels=16,
            head_kernel_list=[3, 2, 2],
        )

    def test_backbone_basics(self):
        # Both maps are expected to have one channel at the input resolution.
        map_shape = (2, 32, 32, 1)
        self.run_backbone_test(
            cls=DiffBinBackbone,
            init_kwargs=self.init_kwargs,
            input_data=self.images,
            expected_output_shape={
                "probability_maps": map_shape,
                "threshold_maps": map_shape,
            },
            run_mixed_precision_check=False,
            run_quantization_check=False,
        )
8 changes: 8 additions & 0 deletions keras_hub/src/models/diffbin/diffbin_image_converter.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,8 @@
from keras_hub.src.api_export import keras_hub_export
from keras_hub.src.layers.preprocessing.image_converter import ImageConverter
from keras_hub.src.models.diffbin.diffbin_backbone import DiffBinBackbone


@keras_hub_export("keras_hub.layers.DiffBinImageConverter")
class DiffBinImageConverter(ImageConverter):
    """`ImageConverter` subclass associated with `DiffBinBackbone`.

    The class adds no behavior of its own; it only sets `backbone_cls`.
    """

    backbone_cls = DiffBinBackbone
14 changes: 14 additions & 0 deletions keras_hub/src/models/diffbin/diffbin_preprocessor.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,14 @@
from keras_hub.src.api_export import keras_hub_export
from keras_hub.src.models.diffbin.diffbin_backbone import DiffBinBackbone
from keras_hub.src.models.diffbin.diffbin_image_converter import (
DiffBinImageConverter,
)
from keras_hub.src.models.image_segmenter_preprocessor import (
ImageSegmenterPreprocessor,
)


@keras_hub_export("keras_hub.models.DiffBinPreprocessor")
class DiffBinPreprocessor(ImageSegmenterPreprocessor):
    """`ImageSegmenterPreprocessor` subclass bound to the DiffBin classes.

    The class adds no behavior of its own; it only sets `backbone_cls` and
    `image_converter_cls`.
    """

    # NOTE(review): raised in the pull-request discussion -- if this model is
    # going to return polygons, the parent class may need to change. The
    # functionality required here (resizing the output mask in addition to
    # the model's output) is pretty similar to `ImageSegmenterPreprocessor`,
    # so a dedicated parent would essentially be a copy of it.
    backbone_cls = DiffBinBackbone
    image_converter_cls = DiffBinImageConverter
17 changes: 17 additions & 0 deletions keras_hub/src/models/diffbin/diffbin_presets.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,17 @@
"""Differentiable Binarization preset configurations."""

# Preset table for the DiffBin backbone, keyed by preset name.
backbone_presets = {
    "diffbin_r50vd_icdar2015": {
        "metadata": {
            "description": (
                # Adjacent string literals are joined verbatim, so the first
                # piece has to carry the separating space.
                "Differentiable Binarization using 50-layer "
                "ResNetVD trained on the ICDAR2015 dataset."
            ),
            "params": 25482722,
            "official_name": "DifferentiableBinarization",
            "path": "diffbin",
            "model_card": "https://arxiv.org/abs/1911.08947",
        },
        "kaggle_handle": "",  # TODO
    }
}
Loading
Loading