Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Add object detection pipeline #243

Open
wants to merge 21 commits into
base: main
Choose a base branch
from
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
7 changes: 7 additions & 0 deletions runner/app/main.py
Original file line number Diff line number Diff line change
Expand Up @@ -71,6 +71,10 @@ def load_pipeline(pipeline: str, model_id: str) -> any:
from app.pipelines.text_to_speech import TextToSpeechPipeline

return TextToSpeechPipeline(model_id)
case "object-detection":
from app.pipelines.object_detection import ObjectDetectionPipeline

return ObjectDetectionPipeline(model_id)
case _:
raise EnvironmentError(
f"{pipeline} is not a valid pipeline for model {model_id}"
Expand Down Expand Up @@ -121,6 +125,9 @@ def load_route(pipeline: str) -> any:
from app.routes import text_to_speech

return text_to_speech.router
case "object-detection":
from app.routes import object_detection
return object_detection.router
case _:
raise EnvironmentError(f"{pipeline} is not a valid pipeline")

Expand Down
119 changes: 119 additions & 0 deletions runner/app/pipelines/object_detection.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,119 @@
import logging
import os

import torch
from app.pipelines.base import Pipeline
from app.pipelines.utils import get_model_dir, get_torch_device
from huggingface_hub import file_download
from transformers import AutoImageProcessor, AutoModelForObjectDetection
from typing import List
from PIL import Image, ImageDraw, ImageFont

from app.utils.errors import InferenceError

logger = logging.getLogger(__name__)


def annotate_image(input_image, detections, labels, font_size, font):
    """Draw bounding boxes and their labels onto ``input_image``.

    The image is modified in place and is also returned for convenience.

    Args:
        input_image: Image handed to ``ImageDraw.Draw`` and drawn on.
        detections: Mapping whose ``"boxes"`` entry holds one
            ``(x1, y1, x2, y2)`` box per detection.
        labels: Label text for each box, in the same order as the boxes.
        font_size: Font size used to offset the label above its box.
        font: Font object forwarded to ``ImageDraw.text``.

    Returns:
        The same ``input_image`` object with the annotations drawn on it.
    """
    draw = ImageDraw.Draw(input_image)
    bounding_box_color = (255, 255, 0)  # Bright yellow for the bounding box.
    text_color = (0, 0, 0)  # Black for the label text.
    for box, label in zip(detections["boxes"], labels):
        x1, y1, x2, y2 = map(int, box)
        draw.rectangle([x1, y1, x2, y2], outline=bounding_box_color, width=3)
        # Place the label above the bounding box, but clamp it to the top edge
        # so labels of boxes near the top of the frame are not drawn off-canvas.
        label_y = max(0, y1 - font_size - 5)
        draw.text((x1, label_y), label, fill=text_color, font=font)
    return input_image


class ObjectDetectionPipeline(Pipeline):
    """Pipeline that detects objects in frames and draws them onto the frames.

    Loads an ``AutoModelForObjectDetection`` model together with its
    ``AutoImageProcessor`` and, when called, returns the annotated frames plus
    the labels and confidence scores of the detections in every frame.
    """

    def __init__(self, model_id: str):
        """Load the detection model, its image processor and the label font.

        Args:
            model_id: Hugging Face model ID to load from the model directory.
        """
        self.model_id = model_id
        kwargs = {}

        self.torch_device = get_torch_device()
        folder_name = file_download.repo_folder_name(
            repo_id=model_id, repo_type="model"
        )
        folder_path = os.path.join(get_model_dir(), folder_name)
        # Load fp16 variant if fp16 safetensors files are found in cache.
        has_fp16_variant = any(
            ".fp16.safetensors" in fname
            for _, _, files in os.walk(folder_path)
            for fname in files
        )
        if self.torch_device != "cpu" and has_fp16_variant:
            logger.info("ObjectDetectionPipeline loading fp16 variant for %s", model_id)

            kwargs["torch_dtype"] = torch.float16
            kwargs["variant"] = "fp16"

        # BFLOAT16 takes precedence over the fp16 dtype chosen above.
        if os.environ.get("BFLOAT16"):
            logger.info(
                "ObjectDetectionPipeline using bfloat16 precision for %s", model_id
            )
            kwargs["torch_dtype"] = torch.bfloat16

        self.object_detection_model = AutoModelForObjectDetection.from_pretrained(
            model_id,
            low_cpu_mem_usage=True,
            use_safetensors=True,
            cache_dir=get_model_dir(),
            **kwargs,
        ).to(self.torch_device)

        self.image_processor = AutoImageProcessor.from_pretrained(
            model_id, cache_dir=get_model_dir()
        )

        # Load a font (the default font is used here; a path to a TTF file
        # could be used instead).
        self.font_size = 24
        self.font = ImageFont.load_default(size=self.font_size)

    def __call__(
        self,
        frames: List[Image.Image],
        confidence_threshold: float = 0.6,
        **kwargs,
    ) -> tuple[List[Image.Image], List[List[float]], List[List[str]]]:
        """Run object detection on every frame and annotate the frames.

        Each frame is annotated in place, so the returned frames are the same
        objects that were passed in.

        Args:
            frames: Frames to run detection on.
            confidence_threshold: Score threshold passed to the image
                processor's post-processing to keep detections.
            **kwargs: Accepted for interface compatibility; not used.

        Returns:
            A tuple ``(annotated_frames, confidence_scores, labels)`` where the
            last two hold one list per frame, aligned with that frame's boxes.

        Raises:
            InferenceError: Wraps any exception raised during inference.
        """
        try:
            model = self.object_detection_model
            id2label = model.config.id2label

            annotated_frames = []
            confidence_scores_all_frames = []
            labels_all_frames = []

            for frame in frames:
                # Move inputs to the model device and cast floating point
                # tensors to the model dtype; otherwise float32 pixel values
                # clash with fp16/bf16 weights during the forward pass.
                inputs = self.image_processor(images=frame, return_tensors="pt").to(
                    self.torch_device, dtype=model.dtype
                )
                with torch.no_grad():
                    outputs = model(**inputs)

                # PIL reports (width, height); post-processing wants (height, width).
                target_sizes = torch.tensor([frame.size[::-1]])
                results = self.image_processor.post_process_object_detection(
                    outputs=outputs,
                    threshold=confidence_threshold,
                    target_sizes=target_sizes,
                )[0]

                # Upcast before leaving torch: numpy has no bfloat16 dtype.
                detections = {"boxes": results["boxes"].float().cpu().numpy()}

                # ``tolist`` yields plain Python numbers, so the scores and
                # label ids do not leak numpy scalar types to callers.
                final_labels = [
                    id2label[label_id] for label_id in results["labels"].tolist()
                ]
                confidence_scores = [
                    round(score, 3) for score in results["scores"].float().tolist()
                ]

                annotated_frame = annotate_image(
                    input_image=frame,
                    detections=detections,
                    labels=final_labels,
                    font_size=self.font_size,
                    font=self.font,
                )

                annotated_frames.append(annotated_frame)
                confidence_scores_all_frames.append(confidence_scores)
                labels_all_frames.append(final_labels)

            return annotated_frames, confidence_scores_all_frames, labels_all_frames

        except Exception as e:
            raise InferenceError(original_exception=e) from e

    def __str__(self) -> str:
        """Return a short description that includes the model ID."""
        return f"ObjectDetectionPipeline model_id={self.model_id}"
154 changes: 154 additions & 0 deletions runner/app/routes/object_detection.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,154 @@
import av
import logging
import os
from typing import Annotated, Dict, Tuple, Union
import time

import torch

from app.dependencies import get_pipeline
from app.pipelines.base import Pipeline
from app.routes.utils import (
HTTPError,
ObjectDetectionResponse,
file_exceeds_max_size,
handle_pipeline_exception,
http_error,
image_to_data_url,
)
from fastapi import APIRouter, Depends, File, Form, UploadFile, status
from fastapi.responses import JSONResponse
from fastapi.security import HTTPAuthorizationCredentials, HTTPBearer
from PIL import ImageFile

# Let Pillow load image files whose data is truncated instead of raising.
ImageFile.LOAD_TRUNCATED_IMAGES = True

# Router on which the object-detection endpoints below are registered.
router = APIRouter()

logger = logging.getLogger(__name__)

# Pipeline specific error handling configuration.
# Each entry maps an error name to a (message, HTTP status code) pair; it is
# handed to handle_pipeline_exception as custom_error_config.
# NOTE(review): keys look like exception class names - confirm against
# handle_pipeline_exception.
PIPELINE_ERROR_CONFIG: Dict[str, Tuple[Union[str, None], int]] = {
# Specific error types.
"OutOfMemoryError": (
"Out of memory error. Try reducing input image resolution.",
status.HTTP_500_INTERNAL_SERVER_ERROR,
)
}

# Per-status response documentation passed to the route decorators through
# their responses= argument.
RESPONSES = {
status.HTTP_200_OK: {
"content": {
"application/json": {
"schema": {
"x-speakeasy-name-override": "data",
}
}
},
},
status.HTTP_400_BAD_REQUEST: {"model": HTTPError},
status.HTTP_401_UNAUTHORIZED: {"model": HTTPError},
status.HTTP_413_REQUEST_ENTITY_TOO_LARGE: {"model": HTTPError},
status.HTTP_500_INTERNAL_SERVER_ERROR: {"model": HTTPError},
}

@router.post(
    "/object-detection",
    response_model=ObjectDetectionResponse,
    responses=RESPONSES,
    description="Generate annotated video(s) for object detection from the input video(s)",
    operation_id="genObjectDetection",
    summary="Object Detection",
    tags=["generate"],
    openapi_extra={"x-speakeasy-name-override": "objectDetection"},
)
@router.post(
    "/object-detection/",
    response_model=ObjectDetectionResponse,
    responses=RESPONSES,
    include_in_schema=False,
)
async def object_detection(
    video: Annotated[
        UploadFile, File(description="Uploaded video to transform with the pipeline.")
    ],
    confidence_threshold: Annotated[
        float, Form(description="Score threshold to keep object detection predictions.")
    ] = 0.6,
    model_id: Annotated[
        str,
        Form(description="Hugging Face model ID used for transformation."),
    ] = "",
    pipeline: Pipeline = Depends(get_pipeline),
    token: HTTPAuthorizationCredentials = Depends(HTTPBearer(auto_error=False)),
):
    """Decode the uploaded video, run object detection and return the results.

    Returns a 401 response when ``AUTH_TOKEN`` is set and the bearer token does
    not match, 400 when ``model_id`` differs from the configured pipeline
    model, and 413 when the upload exceeds 50 MiB. Decoding or pipeline errors
    are turned into a response by ``handle_pipeline_exception``. On success
    the annotated frames are returned as data URLs together with the
    stringified per-frame confidence scores and labels.
    """
    auth_token = os.environ.get("AUTH_TOKEN")
    if auth_token:
        if not token or token.credentials != auth_token:
            return JSONResponse(
                status_code=status.HTTP_401_UNAUTHORIZED,
                headers={"WWW-Authenticate": "Bearer"},
                content=http_error("Invalid bearer token"),
            )

    if model_id != "" and model_id != pipeline.model_id:
        return JSONResponse(
            status_code=status.HTTP_400_BAD_REQUEST,
            content=http_error(
                f"pipeline configured with {pipeline.model_id} but called with "
                f"{model_id}"
            ),
        )

    if file_exceeds_max_size(video, 50 * 1024 * 1024):
        return JSONResponse(
            status_code=status.HTTP_413_REQUEST_ENTITY_TOO_LARGE,
            content=http_error("File size exceeds limit"),
        )

    try:
        # The context manager closes the container even when decoding raises,
        # so a corrupt upload no longer leaks the open container.
        with av.open(video.file) as container:
            start = time.time()
            # Decode the first video stream and convert every frame to a PIL image.
            frames = [frame.to_image() for frame in container.decode(video=0)]
        logger.info(f"Decoded video in {time.time() - start:.2f} seconds")

        start = time.time()
        annotated_frames, confidence_scores_all_frames, labels_all_frames = pipeline(
            frames=frames,
            confidence_threshold=confidence_threshold,
        )
        logger.info(f"Detections processed in {time.time() - start:.2f} seconds")
    except Exception as e:
        if isinstance(e, torch.cuda.OutOfMemoryError):
            # Release cached GPU memory so later requests can still succeed.
            torch.cuda.empty_cache()
        logger.error(f"ObjectDetectionPipeline error: {e}")
        return handle_pipeline_exception(
            e,
            default_error_message="Object-detection pipeline error.",
            custom_error_config=PIPELINE_ERROR_CONFIG,
        )

    start = time.time()
    output_frames = [
        {
            "url": image_to_data_url(frame),
            "seed": 0,
            "nsfw": False,
        }
        for frame in annotated_frames
    ]
    logger.info(f"Annotated frames converted to data URLs in {time.time() - start:.2f} seconds, frame count: {len(output_frames)}")

    return {
        # The response model expects a list of frame lists; a single video is
        # processed per request, hence one inner list.
        "frames": [output_frames],
        "confidence_scores": str(confidence_scores_all_frames),
        "labels": str(labels_all_frames),
    }
12 changes: 12 additions & 0 deletions runner/app/routes/utils.py
Original file line number Diff line number Diff line change
Expand Up @@ -96,6 +96,18 @@ class LiveVideoToVideoResponse(BaseModel):
..., description="URL for updating the live video-to-video generation"
)

class ObjectDetectionResponse(BaseModel):
    """Response model for object detection."""

    # Nested list of Media entries holding the annotated frames.
    frames: List[List[Media]] = Field(
        ...,
        description="The generated annotated video frames.",
    )
    # Confidence scores, carried as a single string.
    confidence_scores: str = Field(
        ...,
        description=(
            "The model's confidence scores for each detected object in each frame."
        ),
    )
    # Labels, carried as a single string.
    labels: str = Field(
        ...,
        description="The model's labels for each detected object in each frame.",
    )


class APIError(BaseModel):
"""API error response model."""

Expand Down
3 changes: 3 additions & 0 deletions runner/dl_checkpoints.sh
Original file line number Diff line number Diff line change
Expand Up @@ -74,6 +74,9 @@ function download_all_models() {
# Download image-to-text models.
huggingface-cli download Salesforce/blip-image-captioning-large --include "*.safetensors" "*.json" --cache-dir models

# Download object-detection models.
huggingface-cli download PekingU/rtdetr_r50vd --include "*.safetensors" "*.json" --cache-dir models

# Custom pipeline models.
huggingface-cli download facebook/sam2-hiera-large --include "*.pt" "*.yaml" --cache-dir models

Expand Down
Loading