-
Notifications
You must be signed in to change notification settings - Fork 254
New issue
Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.
By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.
Already on GitHub? Sign in to your account
Adding Differential Binarization model from PaddleOCR to Keras3 #1739
Open
gowthamkpr
wants to merge
33
commits into
keras-team:master
Choose a base branch
from
gowthamkpr:diffbin
base: master
Could not load branches
Branch not found: {{ refName }}
Loading
Could not load tags
Nothing to show
Loading
Are you sure you want to change the base?
Some commits from the old base branch may be removed from the timeline,
and old review comments may become outdated.
Open
Changes from all commits
Commits
Show all changes
33 commits
Select commit
Hold shift + click to select a range
49f6bb1
Add `DifferentialBinarization` model
gowthamkpr 5b4e011
Added tests for `DifferentialBinarization` losses
gowthamkpr 12ab81c
Moved `DifferentialBinarization` to keras_hub
gowthamkpr e68512c
Renamed to `differential_binarization.py`
gowthamkpr 0c3235c
Refactorings for `DifferentialBinarization`
gowthamkpr 6797231
More refactorings
gowthamkpr 4845b6a
Fix tests
gowthamkpr 83edf9a
Add preprocessor and image converter
gowthamkpr f15b7b9
Add presets
gowthamkpr 392dbff
Run formatting script
gowthamkpr db70eb5
Impl additional tests
gowthamkpr 18fcbfb
Fixed formatting
gowthamkpr 898235d
Removed copyright statements
gowthamkpr eaec868
Fix tests, run `api_gen.sh`
gowthamkpr 21b6312
Merge branch 'master' into diffbin
gowthamkpr 9fb6e65
Addressed comments
gowthamkpr 83b66ed
Merge with local branch
gowthamkpr e4a334d
Fixed torch and jax tests
gowthamkpr 49d6f6d
Improved code readability
gowthamkpr d96b899
Improved/added docstrings
gowthamkpr 2f27981
Added `ImageTextDetector` task
gowthamkpr 66afeb9
Run `api_gen.sh`
gowthamkpr 1d91e76
Fix tensor/array usage
gowthamkpr e999ad9
Sync with master (new linter)
gowthamkpr 3fd3b6f
Shorten docstring
gowthamkpr af934f5
Renamed to DiffBin
gowthamkpr c111bd1
Fixed docstring
gowthamkpr 738786c
Rename `DiffBinOCR` -> `DiffBinImageTextDetector`
gowthamkpr c06d6cb
Merge branch 'keras-hub' into diffbin3
gowthamkpr f047012
`diffbin_ocr_test.py` -> `diffbin_textdetector_test.py`
gowthamkpr 26c16e4
Added weight conversion script
gowthamkpr 464482c
Fixed formatting
gowthamkpr ca83afe
Corrected a few comments
gowthamkpr File filter
Filter by extension
Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
There are no files selected for viewing
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,5 @@ | ||
from keras_hub.src.models.diffbin.diffbin_backbone import DiffBinBackbone | ||
from keras_hub.src.models.diffbin.diffbin_presets import backbone_presets | ||
from keras_hub.src.utils.preset_utils import register_presets | ||
|
||
register_presets(backbone_presets, DiffBinBackbone) |
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,220 @@ | ||
import keras | ||
from keras import layers | ||
|
||
from keras_hub.src.api_export import keras_hub_export | ||
from keras_hub.src.models.backbone import Backbone | ||
|
||
|
||
@keras_hub_export("keras_hub.models.DiffBinBackbone") | ||
class DiffBinBackbone(Backbone): | ||
"""Differentiable Binarization architecture for scene text detection. | ||
|
||
This class implements the Differentiable Binarization architecture for | ||
detecting text in natural images, described in | ||
[Real-time Scene Text Detection with Differentiable Binarization]( | ||
https://arxiv.org/abs/1911.08947). | ||
|
||
The backbone architecture in this class contains the feature pyramid | ||
network and model heads. | ||
|
||
Args: | ||
image_encoder: A `keras_hub.models.ResNetBackbone` instance. | ||
fpn_channels: int. The number of channels to output by the feature | ||
pyramid network. Defaults to 256. | ||
head_kernel_list: list of ints. The kernel sizes of probability map and | ||
threshold map heads. Defaults to [3, 2, 2]. | ||
dtype: `None` or str or `keras.mixed_precision.DTypePolicy`. The dtype | ||
to use for the model's computations and weights. | ||
""" | ||
|
||
def __init__( | ||
self, | ||
image_encoder, | ||
fpn_channels=256, | ||
head_kernel_list=[3, 2, 2], | ||
dtype=None, | ||
**kwargs, | ||
): | ||
# === Functional Model === | ||
inputs = image_encoder.input | ||
x = image_encoder.pyramid_outputs | ||
x = diffbin_fpn_model(x, out_channels=fpn_channels, dtype=dtype) | ||
|
||
probability_maps = diffbin_head( | ||
x, | ||
in_channels=fpn_channels, | ||
kernel_list=head_kernel_list, | ||
name="head_prob", | ||
) | ||
threshold_maps = diffbin_head( | ||
x, | ||
in_channels=fpn_channels, | ||
kernel_list=head_kernel_list, | ||
name="head_thresh", | ||
) | ||
|
||
outputs = { | ||
"probability_maps": probability_maps, | ||
"threshold_maps": threshold_maps, | ||
} | ||
|
||
super().__init__(inputs=inputs, outputs=outputs, dtype=dtype, **kwargs) | ||
|
||
# === Config === | ||
self.image_encoder = image_encoder | ||
self.fpn_channels = fpn_channels | ||
self.head_kernel_list = head_kernel_list | ||
|
||
def get_config(self): | ||
config = super().get_config() | ||
config["fpn_channels"] = self.fpn_channels | ||
config["head_kernel_list"] = self.head_kernel_list | ||
config["image_encoder"] = keras.layers.serialize(self.image_encoder) | ||
return config | ||
|
||
@classmethod | ||
def from_config(cls, config): | ||
config["image_encoder"] = keras.layers.deserialize( | ||
config["image_encoder"] | ||
) | ||
return cls(**config) | ||
|
||
|
||
def diffbin_fpn_model(inputs, out_channels, dtype=None): | ||
# lateral layers composing the FPN's bottom-up pathway using | ||
# pointwise convolutions of ResNet's pyramid outputs | ||
lateral_p2 = layers.Conv2D( | ||
out_channels, | ||
kernel_size=1, | ||
use_bias=False, | ||
name="neck_lateral_p2", | ||
dtype=dtype, | ||
)(inputs["P2"]) | ||
lateral_p3 = layers.Conv2D( | ||
out_channels, | ||
kernel_size=1, | ||
use_bias=False, | ||
name="neck_lateral_p3", | ||
dtype=dtype, | ||
)(inputs["P3"]) | ||
lateral_p4 = layers.Conv2D( | ||
out_channels, | ||
kernel_size=1, | ||
use_bias=False, | ||
name="neck_lateral_p4", | ||
dtype=dtype, | ||
)(inputs["P4"]) | ||
lateral_p5 = layers.Conv2D( | ||
out_channels, | ||
kernel_size=1, | ||
use_bias=False, | ||
name="neck_lateral_p5", | ||
dtype=dtype, | ||
)(inputs["P5"]) | ||
# top-down fusion pathway consisting of upsampling layers with | ||
# skip connections | ||
topdown_p5 = lateral_p5 | ||
topdown_p4 = layers.Add(name="neck_topdown_p4")( | ||
[ | ||
layers.UpSampling2D(dtype=dtype)(topdown_p5), | ||
lateral_p4, | ||
] | ||
) | ||
topdown_p3 = layers.Add(name="neck_topdown_p3")( | ||
[ | ||
layers.UpSampling2D(dtype=dtype)(topdown_p4), | ||
lateral_p3, | ||
] | ||
) | ||
topdown_p2 = layers.Add(name="neck_topdown_p2")( | ||
[ | ||
layers.UpSampling2D(dtype=dtype)(topdown_p3), | ||
lateral_p2, | ||
] | ||
) | ||
# construct merged feature maps for each pyramid level | ||
featuremap_p5 = layers.Conv2D( | ||
out_channels // 4, | ||
kernel_size=3, | ||
padding="same", | ||
use_bias=False, | ||
name="neck_featuremap_p5", | ||
dtype=dtype, | ||
)(topdown_p5) | ||
featuremap_p4 = layers.Conv2D( | ||
out_channels // 4, | ||
kernel_size=3, | ||
padding="same", | ||
use_bias=False, | ||
name="neck_featuremap_p4", | ||
dtype=dtype, | ||
)(topdown_p4) | ||
featuremap_p3 = layers.Conv2D( | ||
out_channels // 4, | ||
kernel_size=3, | ||
padding="same", | ||
use_bias=False, | ||
name="neck_featuremap_p3", | ||
dtype=dtype, | ||
)(topdown_p3) | ||
featuremap_p2 = layers.Conv2D( | ||
out_channels // 4, | ||
kernel_size=3, | ||
padding="same", | ||
use_bias=False, | ||
name="neck_featuremap_p2", | ||
dtype=dtype, | ||
)(topdown_p2) | ||
featuremap_p5 = layers.UpSampling2D((8, 8), dtype=dtype)(featuremap_p5) | ||
featuremap_p4 = layers.UpSampling2D((4, 4), dtype=dtype)(featuremap_p4) | ||
featuremap_p3 = layers.UpSampling2D((2, 2), dtype=dtype)(featuremap_p3) | ||
featuremap = layers.Concatenate(axis=-1, dtype=dtype)( | ||
[featuremap_p5, featuremap_p4, featuremap_p3, featuremap_p2] | ||
) | ||
return featuremap | ||
|
||
|
||
def diffbin_head(inputs, in_channels, kernel_list, name): | ||
x = layers.Conv2D( | ||
in_channels // 4, | ||
kernel_size=kernel_list[0], | ||
padding="same", | ||
use_bias=False, | ||
name=f"{name}_conv0_weights", | ||
)(inputs) | ||
x = layers.BatchNormalization( | ||
beta_initializer=keras.initializers.Constant(1e-4), | ||
gamma_initializer=keras.initializers.Constant(1.0), | ||
name=f"{name}_conv0_bn", | ||
)(x) | ||
x = layers.ReLU(name=f"{name}_conv0_relu")(x) | ||
x = layers.Conv2DTranspose( | ||
in_channels // 4, | ||
kernel_size=kernel_list[1], | ||
strides=2, | ||
padding="valid", | ||
bias_initializer=keras.initializers.RandomUniform( | ||
minval=-1.0 / (in_channels // 4 * 1.0) ** 0.5, | ||
maxval=1.0 / (in_channels // 4 * 1.0) ** 0.5, | ||
), | ||
name=f"{name}_conv1_weights", | ||
)(x) | ||
x = layers.BatchNormalization( | ||
beta_initializer=keras.initializers.Constant(1e-4), | ||
gamma_initializer=keras.initializers.Constant(1.0), | ||
name=f"{name}_conv1_bn", | ||
)(x) | ||
x = layers.ReLU(name=f"{name}_conv1_relu")(x) | ||
x = layers.Conv2DTranspose( | ||
1, | ||
kernel_size=kernel_list[2], | ||
strides=2, | ||
padding="valid", | ||
activation="sigmoid", | ||
bias_initializer=keras.initializers.RandomUniform( | ||
minval=-1.0 / (in_channels // 4 * 1.0) ** 0.5, | ||
maxval=1.0 / (in_channels // 4 * 1.0) ** 0.5, | ||
), | ||
name=f"{name}_conv2_weights", | ||
)(x) | ||
return x |
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,42 @@ | ||
from keras import ops | ||
|
||
from keras_hub.src.models.diffbin.diffbin_backbone import DiffBinBackbone | ||
from keras_hub.src.models.diffbin.diffbin_preprocessor import ( | ||
DiffBinPreprocessor, | ||
) | ||
from keras_hub.src.models.resnet.resnet_backbone import ResNetBackbone | ||
from keras_hub.src.tests.test_case import TestCase | ||
|
||
|
||
class DiffBinTest(TestCase): | ||
def setUp(self): | ||
self.images = ops.ones((2, 32, 32, 3)) | ||
self.image_encoder = ResNetBackbone( | ||
input_conv_filters=[4], | ||
input_conv_kernel_sizes=[7], | ||
stackwise_num_filters=[64, 4, 4, 4], | ||
stackwise_num_blocks=[3, 4, 6, 3], | ||
stackwise_num_strides=[1, 2, 2, 2], | ||
block_type="bottleneck_block", | ||
image_shape=(32, 32, 3), | ||
) | ||
self.preprocessor = DiffBinPreprocessor() | ||
self.init_kwargs = { | ||
"image_encoder": self.image_encoder, | ||
"fpn_channels": 16, | ||
"head_kernel_list": [3, 2, 2], | ||
} | ||
|
||
def test_backbone_basics(self): | ||
expected_output_shape = { | ||
"probability_maps": (2, 32, 32, 1), | ||
"threshold_maps": (2, 32, 32, 1), | ||
} | ||
self.run_backbone_test( | ||
cls=DiffBinBackbone, | ||
init_kwargs=self.init_kwargs, | ||
input_data=self.images, | ||
expected_output_shape=expected_output_shape, | ||
run_mixed_precision_check=False, | ||
run_quantization_check=False, | ||
) |
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,8 @@ | ||
from keras_hub.src.api_export import keras_hub_export | ||
from keras_hub.src.layers.preprocessing.image_converter import ImageConverter | ||
from keras_hub.src.models.diffbin.diffbin_backbone import DiffBinBackbone | ||
|
||
|
||
@keras_hub_export("keras_hub.layers.DiffBinImageConverter") | ||
class DiffBinImageConverter(ImageConverter): | ||
backbone_cls = DiffBinBackbone |
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,14 @@ | ||
from keras_hub.src.api_export import keras_hub_export | ||
from keras_hub.src.models.diffbin.diffbin_backbone import DiffBinBackbone | ||
from keras_hub.src.models.diffbin.diffbin_image_converter import ( | ||
DiffBinImageConverter, | ||
) | ||
from keras_hub.src.models.image_segmenter_preprocessor import ( | ||
ImageSegmenterPreprocessor, | ||
) | ||
|
||
|
||
@keras_hub_export("keras_hub.models.DiffBinPreprocessor") | ||
class DiffBinPreprocessor(ImageSegmenterPreprocessor): | ||
backbone_cls = DiffBinBackbone | ||
image_converter_cls = DiffBinImageConverter |
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,17 @@ | ||
"""Differentiable Binarization preset configurations.""" | ||
|
||
backbone_presets = { | ||
"diffbin_r50vd_icdar2015": { | ||
"metadata": { | ||
"description": ( | ||
"Differentiable Binarization using 50-layer" | ||
"ResNetVD trained on the ICDAR2015 dataset." | ||
), | ||
"params": 25482722, | ||
"official_name": "DifferentiableBinarization", | ||
"path": "diffbin", | ||
"model_card": "https://arxiv.org/abs/1911.08947", | ||
}, | ||
"kaggle_handle": "", # TODO | ||
} | ||
} |
Oops, something went wrong.
Oops, something went wrong.
Add this suggestion to a batch that can be applied as a single commit.
This suggestion is invalid because no changes were made to the code.
Suggestions cannot be applied while the pull request is closed.
Suggestions cannot be applied while viewing a subset of changes.
Only one suggestion per line can be applied in a batch.
Add this suggestion to a batch that can be applied as a single commit.
Applying suggestions on deleted lines is not supported.
You must change the existing code in this line in order to create a valid suggestion.
Outdated suggestions cannot be applied.
This suggestion has been applied or marked resolved.
Suggestions cannot be applied from pending reviews.
Suggestions cannot be applied on multi-line comments.
Suggestions cannot be applied while the pull request is queued to merge.
Suggestion cannot be applied right now. Please check back later.
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
if this model is going to be returning polygons - we might need to change the parent class
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
Yes. The functionality that we require here is pretty similar, though (resize the output mask in addition to the model's output). So basically, we'd have to copy
ImageSegmenterPreprocessor
.