feat: Text detection in auto strategy (#209)
* add: Layout detection in auto strategy

* fix: strategy

* add: fix sample size for strategy estimation

* fix: auto strategy based on text detection

* add: Megaparse config

* add: remove print
chloedia authored Jan 3, 2025
1 parent 7b7fb40 commit 03c7ada
Showing 13 changed files with 336 additions and 41 deletions.
2 changes: 1 addition & 1 deletion libs/megaparse/src/megaparse/examples/parse_file.py
@@ -7,7 +7,7 @@ def main():
parser = UnstructuredParser()
megaparse = MegaParse(parser=parser)

file_path = "./tests/pdf/ocr/0168123.pdf"
file_path = "./tests/pdf/ocr/0168126.pdf"

parsed_file = megaparse.load(file_path)
print(f"\n----- File Response : {file_path} -----\n")
15 changes: 13 additions & 2 deletions libs/megaparse/src/megaparse/megaparse.py
Expand Up @@ -4,6 +4,7 @@
from pathlib import Path
from typing import IO, BinaryIO

from megaparse_sdk.config import MegaParseConfig
from megaparse_sdk.schema.extensions import FileExtension
from megaparse_sdk.schema.parser_config import StrategyEnum

@@ -18,6 +19,8 @@


class MegaParse:
config: MegaParseConfig = MegaParseConfig()

def __init__(
self,
parser: BaseParser = UnstructuredParser(strategy=StrategyEnum.FAST),
@@ -129,9 +132,17 @@ def _select_parser(
if self.strategy != StrategyEnum.AUTO or file_extension != FileExtension.PDF:
return self.parser
if file:
local_strategy = determine_strategy(file=file)
local_strategy = determine_strategy(
file=file,
threshold_pages_ocr=self.config.auto_document_threshold,
threshold_per_page=self.config.auto_page_threshold,
)
if file_path:
local_strategy = determine_strategy(file=file_path)
local_strategy = determine_strategy(
file=file_path,
threshold_pages_ocr=self.config.auto_document_threshold,
threshold_per_page=self.config.auto_page_threshold,
)

if local_strategy == StrategyEnum.HI_RES:
return self.ocr_parser
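A minimal sketch of how the new AUTO path is meant to be exercised, assuming MegaParse accepts a strategy keyword (only self.strategy is visible in this diff) and inferring the import paths from the file layout above; the two thresholds come from the new MegaParseConfig fields:

from megaparse import MegaParse                      # assumed import path
from megaparse.parser.unstructured_parser import UnstructuredParser  # assumed import path
from megaparse_sdk.schema.parser_config import StrategyEnum

megaparse = MegaParse(
    parser=UnstructuredParser(strategy=StrategyEnum.FAST),
    strategy=StrategyEnum.AUTO,                      # assumed constructor keyword
)

# For PDF inputs, _select_parser() now calls determine_strategy() with
# config.auto_document_threshold and config.auto_page_threshold, and switches to
# the OCR parser only when enough pages are judged to need HI_RES.
parsed_file = megaparse.load("./tests/pdf/ocr/0168126.pdf")
print(parsed_file)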
3 changes: 2 additions & 1 deletion libs/megaparse/src/megaparse/parser/doctr_parser.py
@@ -6,7 +6,8 @@
import onnxruntime as rt
from megaparse_sdk.schema.extensions import FileExtension
from onnxtr.io import DocumentFile
from onnxtr.models import EngineConfig, ocr_predictor
from onnxtr.models import ocr_predictor
from onnxtr.models.engine import EngineConfig

from megaparse.parser.base import BaseParser

96 changes: 67 additions & 29 deletions libs/megaparse/src/megaparse/parser/strategy.py
@@ -1,30 +1,38 @@
import logging
import random
from pathlib import Path
from typing import BinaryIO

import numpy as np
import pypdfium2 as pdfium
from megaparse_sdk.schema.parser_config import StrategyEnum
from onnxtr.io import DocumentFile
from onnxtr.models import detection_predictor
from pypdfium2._helpers.page import PdfPage
from pypdfium2._helpers.pageobjects import PdfImage
from pypdfium2._helpers.textpage import PdfTextPage

from megaparse.predictor.doctr_layout_detector import LayoutPredictor
from megaparse.predictor.models.base import PageLayout

logger = logging.getLogger("megaparse")


def get_strategy_page(page: PdfPage, threshold_image_page: float) -> StrategyEnum:
total_page_area = page.get_width() * page.get_height()
total_image_area = 0
images_coords = []
def get_strategy_page(
pdfium_page: PdfPage, onnxtr_page: PageLayout, threshold: float
) -> StrategyEnum:
# assert (
# p_width == onnxtr_page.dimensions[1]
# and p_height == onnxtr_page.dimensions[0]
# ), "Page dimensions do not match"
text_coords = []
# Get all the images in the page
for obj in page.get_objects():
if isinstance(obj, PdfImage):
images_coords.append(obj.get_pos())
elif obj.type == 2:
images_coords.append(obj.get_pos())

canva = np.zeros((int(page.get_height()), int(page.get_width())))
for coords in images_coords:
for obj in pdfium_page.get_objects():
if obj.type == 1:
text_coords.append(obj.get_pos())

p_width, p_height = int(pdfium_page.get_width()), int(pdfium_page.get_height())

pdfium_canva = np.zeros((int(p_height), int(p_width)))

for coords in text_coords:
# (left,bottom,right, top)
# 0---l--------------R-> y
# |
@@ -33,35 +41,65 @@ def get_strategy_page(page: PdfPage, threshold_image_page: float) -> StrategyEnum
# T (x1,y1)
# ^
# x
x0, y0, x1, y1 = coords[1], coords[0], coords[3], coords[2]
p_width, p_height = int(page.get_width()), int(page.get_height())
x0, y0, x1, y1 = (
p_height - coords[3],
coords[0],
p_height - coords[1],
coords[2],
)
x0 = max(0, min(p_height, int(x0)))
y0 = max(0, min(p_width, int(y0)))
x1 = max(0, min(p_height, int(x1)))
y1 = max(0, min(p_width, int(y1)))
canva[x0:x1, y0:y1] = 1
# Get the total area of the images
total_image_area = np.sum(canva)
pdfium_canva[x0:x1, y0:y1] = 1

if total_image_area / total_page_area > threshold_image_page:
onnxtr_canva = np.zeros((int(p_height), int(p_width)))
for block in onnxtr_page.bboxes:
x0, y0 = block.bbox[0]
x1, y1 = block.bbox[1]
x0 = max(0, min(int(x0 * p_width), int(p_width)))
y0 = max(0, min(int(y0 * p_height), int(p_height)))
x1 = max(0, min(int(x1 * p_width), int(p_width)))
y1 = max(0, min(int(y1 * p_height), int(p_height)))
onnxtr_canva[y0:y1, x0:x1] = 1

intersection = np.logical_and(pdfium_canva, onnxtr_canva)
union = np.logical_or(pdfium_canva, onnxtr_canva)
iou = np.sum(intersection) / np.sum(union)
if iou < threshold:
return StrategyEnum.HI_RES
return StrategyEnum.FAST


def determine_strategy(
file: str | Path | bytes | BinaryIO,
threshold_pages_ocr: float = 0.2,
threshold_image_page: float = 0.4,
file: str
| Path
| bytes, # FIXME : Careful here on removing BinaryIO (not handled by onnxtr)
threshold_pages_ocr: float,
threshold_per_page: float,
) -> StrategyEnum:
logger.info("Determining strategy...")
need_ocr = 0
document = pdfium.PdfDocument(file)
for page in document:
strategy = get_strategy_page(page, threshold_image_page=threshold_image_page)

onnxtr_document = DocumentFile.from_pdf(file)
det_predictor = detection_predictor()
layout_predictor = LayoutPredictor(det_predictor)

pdfium_document = pdfium.PdfDocument(file)

onnxtr_document_layout = layout_predictor(onnxtr_document)

for pdfium_page, onnxtr_page in zip(
pdfium_document, onnxtr_document_layout, strict=True
):
strategy = get_strategy_page(
pdfium_page, onnxtr_page, threshold=threshold_per_page
)
need_ocr += strategy == StrategyEnum.HI_RES

doc_need_ocr = (need_ocr / len(document)) > threshold_pages_ocr
document.close()
doc_need_ocr = (need_ocr / len(pdfium_document)) > threshold_pages_ocr
if isinstance(pdfium_document, pdfium.PdfDocument):
pdfium_document.close()

if doc_need_ocr:
logger.info("Using HI_RES strategy")
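The page-level decision above rasterizes the text boxes reported by pdfium and the text regions found by the ONNX detector onto two page-sized masks and compares them with an IoU; a low score means the PDF's text layer is missing text that is visually present, so the page needs OCR. A standalone numpy illustration of just that comparison, with made-up boxes and an illustrative threshold:

import numpy as np

p_width, p_height = 600, 800            # page size in points (example values)
pdfium_boxes = [(50, 600, 550, 750)]    # (left, bottom, right, top) in pdfium coords
detector_boxes = [((0.08, 0.06), (0.92, 0.25))]  # ((x0, y0), (x1, y1)), normalized

pdfium_mask = np.zeros((p_height, p_width))
for left, bottom, right, top in pdfium_boxes:
    # pdfium's origin is bottom-left, the mask's is top-left, so flip the y axis
    pdfium_mask[p_height - int(top):p_height - int(bottom), int(left):int(right)] = 1

detector_mask = np.zeros((p_height, p_width))
for (x0, y0), (x1, y1) in detector_boxes:
    detector_mask[int(y0 * p_height):int(y1 * p_height),
                  int(x0 * p_width):int(x1 * p_width)] = 1

intersection = np.logical_and(pdfium_mask, detector_mask).sum()
union = np.logical_or(pdfium_mask, detector_mask).sum()
iou = intersection / union
strategy = "HI_RES" if iou < 0.6 else "FAST"   # 0.6 is illustrative, not the library default
print(f"IoU = {iou:.2f} -> {strategy}")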
138 changes: 138 additions & 0 deletions libs/megaparse/src/megaparse/predictor/doctr_layout_detector.py
@@ -0,0 +1,138 @@
from typing import Any, List

import numpy as np
from megaparse.predictor.models.base import (
BlockLayout,
PageLayout,
BBOX,
Point2D,
BlockType,
)
from onnxtr.models.detection.predictor import DetectionPredictor
from onnxtr.models.engine import EngineConfig
from onnxtr.models.predictor.base import _OCRPredictor
from onnxtr.utils.geometry import detach_scores
from onnxtr.utils.repr import NestedObject


class LayoutPredictor(NestedObject, _OCRPredictor):
"""Implements an object able to localize and identify text elements in a set of documents
Args:
det_predictor: detection module
reco_predictor: recognition module
assume_straight_pages: if True, speeds up the inference by assuming you only pass straight pages
without rotated textual elements.
straighten_pages: if True, estimates the page general orientation based on the median line orientation.
Then, rotates page before passing it to the deep learning modules. The final predictions will be remapped
accordingly. Doing so will improve performances for documents with page-uniform rotations.
detect_orientation: if True, the estimated general page orientation will be added to the predictions for each
page. Doing so will slightly deteriorate the overall latency.
detect_language: if True, the language prediction will be added to the predictions for each
page. Doing so will slightly deteriorate the overall latency.
clf_engine_cfg: configuration of the orientation classification engine
**kwargs: keyword args of `DocumentBuilder`
"""

def __init__(
self,
det_predictor: DetectionPredictor,
assume_straight_pages: bool = True,
straighten_pages: bool = False,
preserve_aspect_ratio: bool = True,
symmetric_pad: bool = True,
detect_orientation: bool = False,
clf_engine_cfg: EngineConfig | None = None,
**kwargs: Any,
):
self.det_predictor = det_predictor
_OCRPredictor.__init__(
self,
assume_straight_pages,
straighten_pages,
preserve_aspect_ratio,
symmetric_pad,
detect_orientation,
clf_engine_cfg=clf_engine_cfg,
**kwargs,
)
self.detect_orientation = detect_orientation

def __call__(
self,
pages: list[np.ndarray],
**kwargs: Any,
) -> List[PageLayout]: # FIXME : Create new LayoutDocument class
"""Localize and identify text elements in a set of documents
Args:
pages: list of pages to be processed
Returns:
Document: the document object containing the text elements
"""
# Dimension check
if any(page.ndim != 3 for page in pages):
raise ValueError(
"incorrect input shape: all pages are expected to be multi-channel 2D images."
)

# Localize text elements
loc_preds, out_maps = self.det_predictor(pages, return_maps=True, **kwargs)

# Detect document rotation and rotate pages
seg_maps = [
np.where(
out_map > self.det_predictor.model.postprocessor.bin_thresh,
255,
0,
).astype(np.uint8)
for out_map in out_maps
]
if self.detect_orientation:
general_pages_orientations, origin_pages_orientations = (
self._get_orientations(pages, seg_maps)
)
else:
general_pages_orientations = None
origin_pages_orientations = None
if self.straighten_pages:
pages = self._straighten_pages(
pages, seg_maps, general_pages_orientations, origin_pages_orientations
)

# forward again to get predictions on straight pages
loc_preds = self.det_predictor(pages, **kwargs) # type: ignore[assignment]

# Detach objectness scores from loc_preds
loc_preds, objectness_scores = detach_scores(loc_preds) # type: ignore[arg-type]

# Apply hooks to loc_preds if any
for hook in self.hooks:
loc_preds = hook(loc_preds)

all_pages_layouts = []
for page_index, (page, loc_pred, objectness_score) in enumerate(
zip(pages, loc_preds, objectness_scores, strict=True)
):
block_layouts = []
for bbox, score in zip(loc_pred, objectness_score, strict=True):
block_layouts.append(
BlockLayout(
bbox=BBOX(bbox[:2].tolist(), bbox[2:].tolist()),
objectness_score=score,
block_type=BlockType.TEXT,
)
)
all_pages_layouts.append(
PageLayout(
bboxes=block_layouts,
page_index=page_index,
dimensions=page.shape[:2],
orientation=general_pages_orientations[page_index]
if general_pages_orientations is not None
else 0,
)
)

return all_pages_layouts
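
determine_strategy() in strategy.py shows the intended call pattern for this new class; a short usage sketch along the same lines, with a placeholder PDF path:

from onnxtr.io import DocumentFile
from onnxtr.models import detection_predictor

from megaparse.predictor.doctr_layout_detector import LayoutPredictor

pages = DocumentFile.from_pdf("document.pdf")        # placeholder path; yields one numpy array per page
layout_predictor = LayoutPredictor(detection_predictor())

page_layouts = layout_predictor(pages)               # one PageLayout per page
for layout in page_layouts:
    print(layout.page_index, layout.dimensions, len(layout.bboxes))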
