[Add] GOT OCR 2.0 inference pipeline #831

Closed
wants to merge 2 commits into from
59 changes: 59 additions & 0 deletions paddlemix/examples/GOT_OCR_2_0/README.md
@@ -0,0 +1,59 @@
# GOT-OCR2.0

## 1. Model Introduction

[GOT-OCR2.0](https://arxiv.org/abs/2409.01704) is a general end-to-end OCR-2.0 model. It takes an image as input and outputs plain or formatted text, and additionally supports fine-grained OCR guided by a bounding box or a reference color, as well as multi-crop OCR. This repository provides a Paddle implementation of the `GOT-OCR2.0` model.


## 2. Environment Setup
- **python >= 3.10**
- **paddlepaddle-gpu: the develop version is required**
```bash
# Installation example
python -m pip install paddlepaddle-gpu==0.0.0.post118 -f https://www.paddlepaddle.org.cn/whl/linux/gpu/develop.html
```

- paddlenlp >= 3.0.0 (flash_attn is enabled by default; building from source is recommended)

> Note:
> Make sure the dependencies above are installed, otherwise the example will not run. In addition, install the custom ops under paddlemix/external_ops with `python setup.py install`. If the ops still cannot be found after installation, set PYTHONPATH as well, as sketched below.
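
A minimal sketch of the custom-op installation (the repo-root path in PYTHONPATH is an assumption; adjust it to your checkout):

```bash
# Build and install the custom ops shipped under paddlemix/external_ops
cd paddlemix/external_ops
python setup.py install

# If the ops still cannot be found afterwards, add the repo root to PYTHONPATH
export PYTHONPATH=$PYTHONPATH:/path/to/PaddleMIX
```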

## 3. Inference

1. plain-text OCR:
```bash
python paddlemix/examples/GOT_OCR_2_0/got_ocr2_0_infer.py --model_name_or_path /GOT_weights/ --image_file /an/image/file.png --ocr_type ocr
```

2. formatted-text OCR:
```bash
python paddlemix/examples/GOT_OCR_2_0/got_ocr2_0_infer.py --model_name_or_path /GOT_weights/ --image_file /an/image/file.png --ocr_type format
```

3. fine-grained OCR (guided by a bounding box or a reference color):
```bash
python paddlemix/examples/GOT_OCR_2_0/got_ocr2_0_infer.py --model_name_or_path /GOT_weights/ --image_file /an/image/file.png --ocr_type format/ocr --box [x1,y1,x2,y2]
```
```bash
python paddlemix/examples/GOT_OCR_2_0/got_ocr2_0_infer.py --model_name_or_path /GOT_weights/ --image_file /an/image/file.png --ocr_type format/ocr --color red/green/blue
```

4. multi-crop OCR:
```bash
python paddlemix/examples/GOT_OCR_2_0/got_ocr2_0_infer.py --model_name_or_path /GOT_weights/ --image_file /an/image/file.png --multi_crop --ocr_type format/ocr
```

5. render the formatted OCR results:
```bash
python paddlemix/examples/GOT_OCR_2_0/got_ocr2_0_infer.py --model_name_or_path /GOT_weights/ --image_file /an/image/file.png --ocr_type format --render
```
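
For reference, the same inference can be driven from Python; a minimal sketch mirroring `got_ocr2_0_infer.py` (the weights path and image path are placeholders):

```python
import paddle
from paddlenlp.transformers import QWenTokenizer

from paddlemix.models.GOT.model import GOTQwenForCausalLM

tokenizer = QWenTokenizer.from_pretrained("/GOT_weights/")
model = GOTQwenForCausalLM.from_pretrained(
    "/GOT_weights/", dtype=paddle.bfloat16, pad_token_id=tokenizer.eos_token_id
).eval()

with paddle.no_grad():
    # plain-text OCR; pass ocr_type="format" for formatted output
    res = model.chat(tokenizer, "/an/image/file.png", ocr_type="ocr")
print(res)
```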

## References
```BibTeX
@article{wei2024general,
title={General OCR Theory: Towards OCR-2.0 via a Unified End-to-end Model},
author={Wei, Haoran and Liu, Chenglong and Chen, Jinyue and Wang, Jia and Kong, Lingyu and Xu, Yanming and Ge, Zheng and Zhao, Liang and Sun, Jianjian and Peng, Yuang and others},
journal={arXiv preprint arXiv:2409.01704},
year={2024}
}
```
6 changes: 6 additions & 0 deletions paddlemix/examples/GOT_OCR_2_0/configs/demo_dataset.json
@@ -0,0 +1,6 @@
{
"synthdog_en": {
"images": "playground/data/synthdog-en/",
"annotations": "playground/opensource/synthdog_en.jsonl"
}
}
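
For context, each top-level key in this registry names a dataset and maps it to an image folder plus a JSONL annotation file. A minimal sketch of reading such a registry (`load_dataset_registry` is an illustrative helper, not part of this PR):

```python
import json

def load_dataset_registry(meta_path: str) -> dict:
    """Map each dataset name to its (images dir, annotations jsonl) pair."""
    with open(meta_path, "r", encoding="utf-8") as f:
        registry = json.load(f)
    return {name: (entry["images"], entry["annotations"]) for name, entry in registry.items()}

# load_dataset_registry("paddlemix/examples/GOT_OCR_2_0/configs/demo_dataset.json")
# -> {"synthdog_en": ("playground/data/synthdog-en/", "playground/opensource/synthdog_en.jsonl")}
```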
91 changes: 91 additions & 0 deletions paddlemix/examples/GOT_OCR_2_0/got_ocr2_0_infer.py
@@ -0,0 +1,91 @@
# Copyright (c) 2024 PaddlePaddle Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.

import argparse

import paddle
from paddlenlp.transformers import QWenTokenizer

from paddlemix.models.GOT.model import GOTQwenForCausalLM

parser = argparse.ArgumentParser()

parser.add_argument("--model_name_or_path", type=str, default="GOT-OCR2_0_pd", help="pretrained ckpt and tokenizer")
parser.add_argument("--image_file", type=str, default="yiyuan.jpeg")
parser.add_argument("--multi_crop", action="store_true")
parser.add_argument("--ocr_type", type=str, default="plain", choices=["ocr", "format"])
parser.add_argument("--box", type=str, default="")
parser.add_argument("--color", type=str, default="")
parser.add_argument("--render", action="store_true")

args = parser.parse_args()
model_name_or_path = args.model_name_or_path

tokenizer = QWenTokenizer.from_pretrained(model_name_or_path)
# The loaded tokenizer is a QWen PretrainedTokenizer with vocab_size=151851,
# model_max_len=8000, padding_side='right', truncation_side='right', and
# pad_token '<|endoftext|>'.
model = GOTQwenForCausalLM.from_pretrained(
    model_name_or_path, dtype=paddle.bfloat16, pad_token_id=tokenizer.eos_token_id
).eval()


# input test image
image_file = args.image_file
with paddle.no_grad():
    if args.multi_crop:
        # multi-crop OCR
        res = model.chat_crop(
            tokenizer, image_file, ocr_type=args.ocr_type, render=args.render, save_render_file="./demo.html"
        )
    else:
        # plain-text / formatted / fine-grained OCR, optionally rendered
        res = model.chat(
            tokenizer,
            image_file,
            ocr_type=args.ocr_type,
            ocr_box=args.box,
            ocr_color=args.color,
            render=args.render,
            save_render_file="./demo.html",
        )

# Equivalent direct calls, for reference:
#   plain-text OCR:    model.chat(tokenizer, image_file, ocr_type='ocr')
#   formatted OCR:     model.chat(tokenizer, image_file, ocr_type='format')
#   fine-grained OCR:  model.chat(tokenizer, image_file, ocr_type='ocr' or 'format', ocr_box=..., ocr_color=...)
#   multi-crop OCR:    model.chat_crop(tokenizer, image_file, ocr_type='ocr' or 'format')
#   rendered output:   model.chat(tokenizer, image_file, ocr_type='format', render=True, save_render_file='./demo.html')

print(res)
13 changes: 13 additions & 0 deletions paddlemix/models/GOT/__init__.py
@@ -0,0 +1,13 @@
# Copyright (c) 2024 PaddlePaddle Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
121 changes: 121 additions & 0 deletions paddlemix/models/GOT/data/__init__.py
@@ -0,0 +1,121 @@
# Copyright (c) 2024 PaddlePaddle Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.

from dataclasses import dataclass
from functools import partial
from typing import List, Union

import paddle
import paddlenlp
from paddle import Tensor

from paddlemix.models.GOT.data.conversation_dataset_qwen import ConversationDataset

from ..utils.constants import *

IGNORE_INDEX = -100


# helpers
def pad_sequence_paddle(sequences, padding_value=0):
    """
    Implement a function similar to PyTorch's pad_sequence in PaddlePaddle.

    Args:
    - sequences (list of Tensor): The list of sequences to be padded.
    - padding_value (float, optional): The value used for padding, default is 0.

    Returns:
    - Tensor: The result of padding all sequences to the same length.
    """
    # Calculate the maximum length
    max_len = max([seq.shape[0] for seq in sequences])

    # Pad sequences
    padded_sequences = []
    for seq in sequences:
        # Calculate the length to pad
        padding_len = max_len - seq.shape[0]

        # Create a padding tensor
        if padding_len > 0:
            padding_tensor = paddle.full([padding_len] + list(seq.shape[1:]), padding_value, dtype=seq.dtype)
            # Concatenate the original sequence and the padding tensor
            padded_seq = paddle.concat([seq, padding_tensor], axis=0)
        else:
            padded_seq = seq

        padded_sequences.append(padded_seq)

    # Stack the padded sequences to form a batch
    padded_batch = paddle.stack(padded_sequences, axis=0)
    return padded_batch
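
# Usage sketch (illustrative values):
#   a = paddle.to_tensor([1, 2, 3])
#   b = paddle.to_tensor([4, 5])
#   pad_sequence_paddle([a, b], padding_value=0)
#   -> Tensor of shape [2, 3]: [[1, 2, 3], [4, 5, 0]]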


def orig_pad_sequence(
    sequences: Union[Tensor, List[Tensor]],
    batch_first: bool = False,
    padding_value: float = 0.0,
) -> Tensor:
    if batch_first:
        return pad_sequence_paddle(sequences, padding_value)
    else:
        raise NotImplementedError("batch_first=False is not supported")


@dataclass
class DataCollatorForSupervisedDataset(object):
    tokenizer: paddlenlp.transformers.PretrainedTokenizer

    def __call__(self, instances):
        input_ids, labels = tuple([instance[key] for instance in instances] for key in ("input_ids", "labels"))
        images = [paddle.stack(instance["image"]) for instance in instances]
        images_high = [paddle.stack(instance["image_high"]) for instance in instances]
        images = list(zip(images, images_high))

        # batch_first is already bound here, so the calls below only pass padding_value
        pad_sequence = partial(orig_pad_sequence, batch_first=True)

        input_ids = pad_sequence(input_ids, padding_value=self.tokenizer.pad_token_id)

        labels = pad_sequence(labels, padding_value=IGNORE_INDEX)

        batch = dict(
            input_ids=input_ids,
            labels=labels,
            attention_mask=input_ids.not_equal(paddle.to_tensor(self.tokenizer.pad_token_id)),
            images=images,
        )
        return batch


def make_supervised_data_module(interleave, with_box, tokenizer, data_args):
    assert data_args.conversation_version == "mpt"

    train_dataset = ConversationDataset(
        tokenizer=tokenizer,
        # datasets=data_args.datasets,
        meta_path=data_args.meta_path,
        multimodal_cfg=dict(
            sep_image_conv_front=data_args.sep_image_conv_front,
            image_token_len=data_args.image_token_len,
            image_aspect_ratio=data_args.image_aspect_ratio,
            use_im_start_end=data_args.use_im_start_end,
            image_processor=data_args.image_processor,
            image_processor_high=data_args.image_processor_high,
            box_limit=data_args.box_limit,
        ),
    )
    data_collator = DataCollatorForSupervisedDataset(tokenizer=tokenizer)
    return dict(train_dataset=train_dataset, eval_dataset=None, data_collator=data_collator)
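
# Usage sketch (data_args fields follow the multimodal_cfg above; everything else
# here is illustrative, not part of this PR):
#   data_module = make_supervised_data_module(
#       interleave=False, with_box=False, tokenizer=tokenizer, data_args=data_args
#   )
#   loader = paddle.io.DataLoader(
#       data_module["train_dataset"], batch_size=2,
#       collate_fn=data_module["data_collator"],
#   )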
82 changes: 82 additions & 0 deletions paddlemix/models/GOT/data/base_dataset.py
@@ -0,0 +1,82 @@
# Copyright (c) 2024 PaddlePaddle Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.

# import copy
# import io
# import json
import logging

# from typing import Dict, List, Optional, Sequence, Tuple, Union
from typing import Dict

import paddle
import paddlenlp
from paddle.io import Dataset
from PIL import ImageFile # , Image

ImageFile.LOAD_TRUNCATED_IMAGES = True
# from ..utils.constants import *


class BaseDataset(Dataset):
    def __init__(self, datasets: str, tokenizer: paddlenlp.transformers.PretrainedTokenizer, multimodal_cfg: dict):
        super(BaseDataset, self).__init__()
        self.tokenizer = tokenizer
        self.multimodal_cfg = multimodal_cfg

        logging.warning(f"Using {multimodal_cfg['image_token_len']} tokens for representing the image")

    def image_processor(self, image):
        # processor = self.multimodal_cfg['image_processor']  # the first processor, usually the CLIP-pretrained ViT
        # The second processor is the designed image encoder (SAM/Swin/CNN).
        processor_high = self.multimodal_cfg["image_processor_high"]
        image_high = image.copy()

        # Note: the legacy Vary preprocessing for the first processor (the 'keep' and
        # 'pad' image_aspect_ratio modes, including expand2square padding) was dropped
        # here; only the high-resolution processor is applied.

        image_high = processor_high(image_high)

        return image_high

    def __len__(self):
        # self.list_data_dict is expected to be populated by subclasses
        return len(self.list_data_dict)

    def __getitem__(self, i) -> Dict[str, paddle.Tensor]:
        # Abstract: concrete datasets implement item construction
        raise NotImplementedError
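
# Subclassing sketch (illustrative only, not part of this PR): a concrete dataset
# populates self.list_data_dict and implements __getitem__, e.g.:
#   class JsonlOCRDataset(BaseDataset):
#       def __init__(self, datasets, tokenizer, multimodal_cfg):
#           super().__init__(datasets, tokenizer, multimodal_cfg)
#           self.list_data_dict = [json.loads(line) for line in open(datasets)]
#       def __getitem__(self, i):
#           record = self.list_data_dict[i]
#           ...  # tokenize the text and run self.image_processor on the image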