Skip to content

Latest commit

 

History

History
79 lines (69 loc) · 3.3 KB

添加新算法.md

File metadata and controls

79 lines (69 loc) · 3.3 KB

添加新算法

Dataset

检测算法

不同的检测算法会有不同的图片预处理和label制作方式,添加新dataset的步骤如下

  1. torchocr/datasets/det_modules下添加算法的图片预处理和label制作方式, 每个处理步骤(module)用一个文件存储,module的形式如下
class ModuleName:
    def __init__(self, *args,**kwargs):
        pass
    def __call__(self, data: dict) -> dict:
        im = data['img']
        text_polys = data['text_polys']
        # 执行你的处理
        data['img'] = im
        data['text_polys'] = text_polys
        return data

算法的所有处理步骤由不同的module顺序执行而成,在config文件中按照列表的形式组合并执行。如:

'pre_processes': [{'type': 'IaaAugment', 'args': [{'type': 'Fliplr', 'args': {'p': 0.5}},
                                                  {'type': 'Affine', 'args': {'rotate': [-10, 10]}},
                                                  {'type': 'Resize', 'args': {'size': [0.5, 3]}}]},
                  {'type': 'EastRandomCropData', 'args': {'size': [640, 640], 'max_tries': 50, 'keep_ratio': True}},
                  {'type': 'MakeBorderMap', 'args': {'shrink_ratio': 0.4, 'thresh_min': 0.3, 'thresh_max': 0.7}},
                  {'type': 'MakeShrinkMap', 'args': {'shrink_ratio': 0.4, 'min_text_size': 8}}]

识别算法

对于attention和ctc系列算法,我们已经提供了内置的dataset,其他类型的需要在torchocr/datasets/RecDataSet.py 文件里添加一个dataset并在config文件中使用

网络

PytorchOCR将网络划分为三部分

  • backbone: 从图片中提取特征,如Resnet,MobileNetV3
  • neck: 对backbone输出的特征进行强化,如FPN,CRNN的RNN部分
  • head: 在neck输出特征的基础上进行完成算法的输出 backboneneck均需要out_channels属性以便后续组件构造网络。 若PytorchOCR已提供的组件中没有算法所需组件,就需要在对应的文件夹内实现新组件,一个文件夹存放一个组件, 然后将新组建在torchocr/networks/architectures/DetModel.pytorchocr/networks/architectures/RecModel.py进行导入并添加到对应的dict

各组件对应文件如下:

  • backbone: torchocr/networks/backbones
  • necks: torchocr/networks/necks
  • heads: torchocr/networks/heads

损失函数

损失函数的存文件夹为torchocr/networks/losses,损失函数的输出应该是一个dict,格式如下

{
    'loss':loss_value, # 总的loss,由l1,l2,l3,...,ln加权组成
    '其他的loss': value # 组成总loss的子loss
}

loss module 的形式如下

class ModuleName(nn.Module):
    def __init__(self, *args,**kwargs):
        pass

    def forward(self, pred, batch):
        """

        :param pred:
        :param batch: bach为一个dict{
                                    '其他计算loss所需的输入':'vaue'
                                    }
        :return:
        """
        # 计算loss
        loss_dict = {'loss':loss,'other_sub_loss':value}
        return loss_dict

配置文件

将配置文件里的对应地方换成新增的组件,那么新的网络就添加完成了,在测试性能无误后就可推送到PytorchOCR仓库