Skip to content

Commit

Permalink
add data_format selection support to ocr
Browse files Browse the repository at this point in the history
  • Loading branch information
PeiyuLau committed Jul 11, 2024
1 parent db0ad17 commit a826e2b
Show file tree
Hide file tree
Showing 17 changed files with 443 additions and 93 deletions.
2 changes: 2 additions & 0 deletions configs/cls/cls_mv3.yml
Original file line number Diff line number Diff line change
Expand Up @@ -23,10 +23,12 @@ Architecture:
name: MobileNetV3
scale: 0.35
model_name: small
data_format: NHWC
Neck:
Head:
name: ClsHead
class_dim: 2
data_format: NHWC

Loss:
name: ClsLoss
Expand Down
9 changes: 9 additions & 0 deletions configs/det/ch_PP-OCRv3/ch_PP-OCRv3_det_cml.yml
Original file line number Diff line number Diff line change
Expand Up @@ -35,13 +35,16 @@ Architecture:
scale: 0.5
model_name: large
disable_se: true
data_format: NHWC
Neck:
name: RSEFPN
out_channels: 96
shortcut: True
data_format: NHWC
Head:
name: DBHead
k: 50
data_format: NHWC
Student2:
pretrained:
model_type: det
Expand All @@ -52,13 +55,16 @@ Architecture:
scale: 0.5
model_name: large
disable_se: true
data_format: NHWC
Neck:
name: RSEFPN
out_channels: 96
shortcut: True
data_format: NHWC
Head:
name: DBHead
k: 50
data_format: NHWC
Teacher:
freeze_params: true
return_all_feats: false
Expand All @@ -68,13 +74,16 @@ Architecture:
name: ResNet_vd
in_channels: 3
layers: 50
data_format: NHWC
Neck:
name: LKPAN
out_channels: 256
data_format: NHWC
Head:
name: DBHead
kernel_list: [7,2,2]
k: 50
data_format: NHWC

Loss:
name: CombinedLoss
Expand Down
7 changes: 5 additions & 2 deletions configs/det/det_mv3_db.yml
Original file line number Diff line number Diff line change
Expand Up @@ -25,12 +25,15 @@ Architecture:
name: MobileNetV3
scale: 0.5
model_name: large
data_format: NHWC
Neck:
name: DBFPN
out_channels: 256
data_format: NHWC
Head:
name: DBHead
k: 50
data_format: NHWC

Loss:
name: DBLoss
Expand Down Expand Up @@ -64,7 +67,7 @@ Metric:
Train:
dataset:
name: SimpleDataSet
data_dir: ./train_data/icdar2015/text_localization/
data_dir: ./
label_file_list:
- ./train_data/icdar2015/text_localization/train_icdar2015_label.txt
ratio_list: [1.0]
Expand Down Expand Up @@ -107,7 +110,7 @@ Train:
Eval:
dataset:
name: SimpleDataSet
data_dir: ./train_data/icdar2015/text_localization/
data_dir: ./
label_file_list:
- ./train_data/icdar2015/text_localization/test_icdar2015_label.txt
transforms:
Expand Down
2 changes: 2 additions & 0 deletions configs/rec/PP-OCRv3/ch_PP-OCRv3_rec.yml
Original file line number Diff line number Diff line change
Expand Up @@ -44,6 +44,7 @@ Architecture:
last_conv_stride: [1, 2]
last_pool_type: avg
last_pool_kernel_size: [2, 2]
data_format: 'NHWC'
Head:
name: MultiHead
head_list:
Expand All @@ -59,6 +60,7 @@ Architecture:
- SARHead:
enc_dim: 512
max_text_length: *max_text_length
data_format: 'NHWC'

Loss:
name: MultiLoss
Expand Down
4 changes: 4 additions & 0 deletions configs/rec/PP-OCRv3/ch_PP-OCRv3_rec_distillation.yml
100644 → 100755
Original file line number Diff line number Diff line change
Expand Up @@ -54,6 +54,7 @@ Architecture:
last_conv_stride: [1, 2]
last_pool_type: avg
last_pool_kernel_size: [2, 2]
data_format: 'NHWC'
Head:
name: MultiHead
head_list:
Expand All @@ -69,6 +70,7 @@ Architecture:
- SARHead:
enc_dim: 512
max_text_length: *max_text_length
data_format: 'NHWC'
Student:
pretrained:
freeze_params: false
Expand All @@ -82,6 +84,7 @@ Architecture:
last_conv_stride: [1, 2]
last_pool_type: avg
last_pool_kernel_size: [2, 2]
data_format: 'NHWC'
Head:
name: MultiHead
head_list:
Expand All @@ -97,6 +100,7 @@ Architecture:
- SARHead:
enc_dim: 512
max_text_length: *max_text_length
data_format: 'NHWC'
Loss:
name: CombinedLoss
loss_config_list:
Expand Down
32 changes: 27 additions & 5 deletions ppocr/modeling/backbones/det_mobilenet_v3.py
Original file line number Diff line number Diff line change
Expand Up @@ -35,7 +35,13 @@ def make_divisible(v, divisor=8, min_value=None):

class MobileNetV3(nn.Layer):
def __init__(
self, in_channels=3, model_name="large", scale=0.5, disable_se=False, **kwargs
self,
in_channels=3,
model_name="large",
scale=0.5,
disable_se=False,
data_format="NCHW",
**kwargs,
):
"""
the MobilenetV3 backbone network for detection module.
Expand All @@ -46,6 +52,7 @@ def __init__(

self.disable_se = disable_se

self.nchw = data_format == "NCHW"
if model_name == "large":
cfg = [
# k, exp, c, se, nl, s,
Expand Down Expand Up @@ -102,6 +109,7 @@ def __init__(
groups=1,
if_act=True,
act="hardswish",
data_format=data_format,
)

self.stages = []
Expand All @@ -125,6 +133,7 @@ def __init__(
stride=s,
use_se=se,
act=nl,
data_format=data_format,
)
)
inplanes = make_divisible(scale * c)
Expand All @@ -139,6 +148,7 @@ def __init__(
groups=1,
if_act=True,
act="hardswish",
data_format=data_format,
)
)
self.stages.append(nn.Sequential(*block_list))
Expand All @@ -147,6 +157,8 @@ def __init__(
self.add_sublayer(sublayer=stage, name="stage{}".format(i))

def forward(self, x):
if not self.nchw:
x = x.transpose([0, 2, 3, 1])
x = self.conv(x)
out_list = []
for stage in self.stages:
Expand All @@ -166,6 +178,7 @@ def __init__(
groups=1,
if_act=True,
act=None,
data_format="NCHW",
):
super(ConvBNLayer, self).__init__()
self.if_act = if_act
Expand All @@ -178,9 +191,12 @@ def __init__(
padding=padding,
groups=groups,
bias_attr=False,
data_format=data_format,
)

self.bn = nn.BatchNorm(num_channels=out_channels, act=None)
self.bn = nn.BatchNorm(
num_channels=out_channels, act=None, data_layout=data_format
)

def forward(self, x):
x = self.conv(x)
Expand Down Expand Up @@ -210,6 +226,7 @@ def __init__(
stride,
use_se,
act=None,
data_format="NCHW",
):
super(ResidualUnit, self).__init__()
self.if_shortcut = stride == 1 and in_channels == out_channels
Expand All @@ -223,6 +240,7 @@ def __init__(
padding=0,
if_act=True,
act=act,
data_format=data_format,
)
self.bottleneck_conv = ConvBNLayer(
in_channels=mid_channels,
Expand All @@ -233,9 +251,10 @@ def __init__(
groups=mid_channels,
if_act=True,
act=act,
data_format=data_format,
)
if self.if_se:
self.mid_se = SEModule(mid_channels)
self.mid_se = SEModule(mid_channels, data_format=data_format)
self.linear_conv = ConvBNLayer(
in_channels=mid_channels,
out_channels=out_channels,
Expand All @@ -244,6 +263,7 @@ def __init__(
padding=0,
if_act=False,
act=None,
data_format=data_format,
)

def forward(self, inputs):
Expand All @@ -258,22 +278,24 @@ def forward(self, inputs):


class SEModule(nn.Layer):
def __init__(self, in_channels, reduction=4):
def __init__(self, in_channels, reduction=4, data_format="NCHW"):
super(SEModule, self).__init__()
self.avg_pool = nn.AdaptiveAvgPool2D(1)
self.avg_pool = nn.AdaptiveAvgPool2D(1, data_format=data_format)
self.conv1 = nn.Conv2D(
in_channels=in_channels,
out_channels=in_channels // reduction,
kernel_size=1,
stride=1,
padding=0,
data_format=data_format,
)
self.conv2 = nn.Conv2D(
in_channels=in_channels // reduction,
out_channels=in_channels,
kernel_size=1,
stride=1,
padding=0,
data_format=data_format,
)

def forward(self, inputs):
Expand Down
Loading

0 comments on commit a826e2b

Please sign in to comment.