Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Drop ultralytics dependency #231

Merged
merged 13 commits into from
Dec 31, 2024
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension


Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
2 changes: 1 addition & 1 deletion .github/workflows/builds.yml
Original file line number Diff line number Diff line change
Expand Up @@ -13,7 +13,7 @@ jobs:
fail-fast: false
matrix:
os: [ubuntu-latest]
python: [3.10]
python: ['3.10']

steps:
- uses: actions/checkout@v4
Expand Down
4 changes: 3 additions & 1 deletion pyproject.toml
Original file line number Diff line number Diff line change
Expand Up @@ -30,11 +30,13 @@ classifiers = [
]
dynamic = ["version"]
dependencies = [
"ultralytics==8.2.50",
"onnxruntime==1.18.1",
"ncnn==1.0.20240410",
"pyroclient @ git+https://github.com/pyronear/pyro-api.git@767be30a781b52b29d68579d543e3f45ac8c4713#egg=pyroclient&subdirectory=client",
"requests>=2.20.0,<3.0.0",
"tqdm>=4.62.0",
"huggingface_hub==0.23.1",
"pillow==11.0.0",
]

[project.optional-dependencies]
Expand Down
55 changes: 54 additions & 1 deletion pyroengine/utils.py
Original file line number Diff line number Diff line change
Expand Up @@ -4,10 +4,63 @@
# See LICENSE or go to <https://opensource.org/licenses/Apache-2.0> for full license details.


import cv2
import numpy as np
from tqdm import tqdm # type: ignore[import-untyped]

__all__ = ["nms", "DownloadProgressBar"]
__all__ = ["xywh2xyxy", "letterbox", "nms", "DownloadProgressBar"]


def xywh2xyxy(x: np.ndarray):
y = np.copy(x)
y[..., 0] = x[..., 0] - x[..., 2] / 2 # top left x
y[..., 1] = x[..., 1] - x[..., 3] / 2 # top left y
y[..., 2] = x[..., 0] + x[..., 2] / 2 # bottom right x
y[..., 3] = x[..., 1] + x[..., 3] / 2 # bottom right y
return y


def letterbox(
im: np.ndarray,
new_shape: tuple = (1024, 1024),
color: tuple = (114, 114, 114),
auto: bool = False,
stride: int = 32,
):
"""Letterbox image transform for yolo models
Args:
im (np.ndarray): Input image
new_shape (tuple, optional): Image size. Defaults to (1024, 1024).
color (tuple, optional): Pixel fill value for the area outside the transformed image.
Defaults to (114, 114, 114).
auto (bool, optional): auto padding. Defaults to False.
stride (int, optional): padding stride. Defaults to 32.
Returns:
np.ndarray: Output image
"""
# Resize and pad image while meeting stride-multiple constraints
im = np.array(im)
shape = im.shape[:2] # current shape [height, width]
if isinstance(new_shape, int):
new_shape = (new_shape, new_shape)
# Scale ratio (new / old)
r = min(new_shape[0] / shape[0], new_shape[1] / shape[1])
# Compute padding
new_unpad = int(round(shape[1] * r)), int(round(shape[0] * r))
dw, dh = new_shape[1] - new_unpad[0], new_shape[0] - new_unpad[1] # wh padding
if auto: # minimum rectangle
dw, dh = np.mod(dw, stride), np.mod(dh, stride) # wh padding
dw /= 2 # divide padding into 2 sides
dh /= 2
if shape[::-1] != new_unpad: # resize
im = cv2.resize(im, new_unpad, interpolation=cv2.INTER_LINEAR)
top, bottom = int(round(dh - 0.1)), int(round(dh + 0.1))
left, right = int(round(dw - 0.1)), int(round(dw + 0.1))
# add border
h, w = im.shape[:2]
im_b = np.zeros((h + top + bottom, w + left + right, 3)) + color
im_b[top : top + h, left : left + w, :] = im
return im_b.astype("uint8"), (left, top)


def box_iou(box1: np.ndarray, box2: np.ndarray, eps: float = 1e-7):
Expand Down
113 changes: 94 additions & 19 deletions pyroengine/vision.py
Original file line number Diff line number Diff line change
Expand Up @@ -8,15 +8,16 @@
import os
import platform
import shutil
from typing import Optional
from typing import Optional, Tuple
from urllib.request import urlretrieve

import ncnn
import numpy as np
import onnxruntime
from huggingface_hub import HfApi # type: ignore[import-untyped]
from PIL import Image
from ultralytics import YOLO # type: ignore[import-untyped]

from .utils import DownloadProgressBar
from .utils import DownloadProgressBar, letterbox, nms, xywh2xyxy

__all__ = ["Classifier"]

Expand Down Expand Up @@ -48,14 +49,17 @@ class Classifier:

def __init__(self, model_folder="data", imgsz=1024, conf=0.15, iou=0, format="ncnn", model_path=None) -> None:
if model_path is None:

if format == "ncnn":
if self.is_arm_architecture():
model = "yolov8s_ncnn_model.zip"
else:
logging.info("NCNN format is optimized for arm architecture only, switching to onnx")
model = "yolov8s.onnx"
elif format in ["onnx", "pt"]:
model = f"yolov8s.{format}"
if not self.is_arm_architecture():
logging.info("NCNN format is optimized for arm architecture only, switching to onnx is recommended")

model = "yolov8s_ncnn_model.zip"
self.format = "ncnn"

elif format == "onnx":
model = "yolov8s.onnx"
self.format = "onnx"

model_path = os.path.join(model_folder, model)
metadata_path = os.path.join(model_folder, METADATA_NAME)
Expand Down Expand Up @@ -88,7 +92,14 @@ def __init__(self, model_folder="data", imgsz=1024, conf=0.15, iou=0, format="nc
shutil.unpack_archive(model_path, model_folder)
model_path = file_name

self.model = YOLO(model_path, task="detect")
if self.format == "ncnn":
self.model = ncnn.Net()
self.model.load_param(os.path.join(model_path, "model.ncnn.param"))
self.model.load_model(os.path.join(model_path, "model.ncnn.bin"))

else:
self.ort_session = onnxruntime.InferenceSession(model_path)

self.imgsz = imgsz
self.conf = conf
self.iou = iou
Expand Down Expand Up @@ -126,20 +137,84 @@ def load_metadata(self, metadata_path):
return json.load(f)
return None

def prep_process(self, pil_img: Image.Image) -> Tuple[np.ndarray, Tuple[int, int]]:
"""Preprocess an image for inference

Args:
pil_img: A valid PIL image.

Returns:
A tuple containing:
- The resized and normalized image of shape (1, C, H, W).
- Padding information as a tuple of integers (pad_height, pad_width).
"""

np_img, pad = letterbox(np.array(pil_img), self.imgsz) # Applies letterbox resize with padding

if self.format == "ncnn":
np_img = ncnn.Mat.from_pixels(np_img, ncnn.Mat.PixelType.PIXEL_BGR, np_img.shape[1], np_img.shape[0])
mean = [0, 0, 0]
std = [1 / 255, 1 / 255, 1 / 255]
np_img.substract_mean_normalize(mean=mean, norm=std)

else:

np_img = np.expand_dims(np_img.astype("float"), axis=0) # Add batch dimension
np_img = np.ascontiguousarray(np_img.transpose((0, 3, 1, 2))) # Convert from BHWC to BCHW format
np_img = np_img.astype("float32") / 255 # Normalize to [0, 1]

return np_img, pad

def post_process(self, pred: np.ndarray, pad: int) -> Tuple[np.ndarray, Tuple[int, int]]:

# Drop low conf for speed-up
pred = pred[:, pred[-1, :] > self.conf]
# Post processing
pred = np.transpose(pred)
pred = xywh2xyxy(pred)
# Sort by confidence
pred = pred[pred[:, 4].argsort()]
pred = nms(pred)
pred = pred[::-1]

# Normalize preds
if len(pred) > 0:
# Remove padding
left_pad, top_pad = pad
pred[:, :4:2] -= left_pad
pred[:, 1:4:2] -= top_pad
pred[:, :4:2] /= self.imgsz - 2 * left_pad
pred[:, 1:4:2] /= self.imgsz - 2 * top_pad
pred = np.clip(pred, 0, 1)
pred = np.reshape(pred, (-1, 5))
else:
pred = np.zeros((0, 5)) # normalize output

return pred

def __call__(self, pil_img: Image.Image, occlusion_mask: Optional[np.ndarray] = None) -> np.ndarray:

results = self.model(pil_img, imgsz=self.imgsz, conf=self.conf, iou=self.iou, verbose=False)
y = np.concatenate(
(results[0].boxes.xyxyn.cpu().numpy(), results[0].boxes.conf.cpu().numpy().reshape((-1, 1))), axis=1
)
np_img, pad = self.prep_process(pil_img)

if self.format == "ncnn":

extractor = self.model.create_extractor()
extractor.set_light_mode(True)
extractor.input("in0", np_img)
pred = ncnn.Mat()
extractor.extract("out0", pred)
pred = np.asarray(pred)

else:
pred = self.ort_session.run(["output0"], {"images": np_img})[0][0]

y = np.reshape(y, (-1, 5))
pred = self.post_process(pred, pad)

# Remove prediction in occlusion mask
if occlusion_mask is not None:
hm, wm = occlusion_mask.shape
keep = []
for p in y.copy():
for p in pred.copy():
p[:4:2] *= wm
p[1:4:2] *= hm
p[:4:2] = np.clip(p[:4:2], 0, wm)
Expand All @@ -150,6 +225,6 @@ def __call__(self, pil_img: Image.Image, occlusion_mask: Optional[np.ndarray] =
else:
keep.append(False)

y = y[keep]
pred = pred[keep]

return y
return pred
Loading
Loading