Skip to content

Commit

Permalink
Merge pull request #20 from OriginalByteMe/19-move-cutout-class-back-…
Browse files Browse the repository at this point in the history
…into-seperate-file

19 move cutout class back into separate file
  • Loading branch information
OriginalByteMe authored Dec 21, 2023
2 parents abc7bae + 19da294 commit b6c22db
Show file tree
Hide file tree
Showing 13 changed files with 473 additions and 396 deletions.
4 changes: 2 additions & 2 deletions .github/workflows/ci-cd.yml
Original file line number Diff line number Diff line change
Expand Up @@ -39,9 +39,9 @@ jobs:
- name: Deploy cutout generator job
run: |
cd app
modal deploy --env=${{ steps.vars.outputs.environment }} grounded_cutouts.py
modal deploy --env=${{ steps.vars.outputs.environment }} app.cutout_handler::cutout_handler_stub
- name: Deploy s3_handler job
run: |
cd app/s3_handler
modal deploy --env=${{ steps.vars.outputs.environment }} app.py
modal deploy --env=${{ steps.vars.outputs.environment }} app.s3_handler.app
10 changes: 10 additions & 0 deletions .pre-commit-config.yaml
Original file line number Diff line number Diff line change
@@ -0,0 +1,10 @@
# repos:
# - repo: https://github.com/psf/black
# hooks:
# - id: black
# language_version: python3.8 # Should match the version of Python you're using

# - repo: https://github.com/pycqa/isort
# hooks:
# - id: isort
# language_version: python3.8 # Should match the version of Python you're using
36 changes: 36 additions & 0 deletions app/common/__init__.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,36 @@
from modal import Image, Mount, Stub

# Container image for the S3 handler: Debian slim with boto3/botocore
# pip-installed.
s3_handler_image = Image.debian_slim().pip_install("boto3", "botocore")

# Container image for the cutout generator, built on top of the NVIDIA
# PyTorch registry image. The build steps below run in the order listed, so
# the order of the commands matters.
cutout_generator_image = (
    Image.from_registry("nvcr.io/nvidia/pytorch:22.12-py3")
    .pip_install(
        "segment-anything", "opencv-python", "botocore", "boto3", "fastapi", "starlette"
    )
    .run_commands(
        # System packages (git/wget plus GL/glib shared libraries).
        "apt-get update",
        "apt-get install -y git wget libgl1-mesa-glx libglib2.0-0",
        # GroundingDINO is cloned and installed in editable mode from the clone.
        "git clone https://github.com/IDEA-Research/GroundingDINO.git",
        "pip install -q -e GroundingDINO/",
        # NOTE(review): these create absolute /weights and /data directories,
        # while the wget calls below download into the *relative* `weights/`
        # directory. The two only coincide if the build working directory is
        # `/` -- confirm.
        "mkdir -p /weights",
        "mkdir -p /data",
        # Replace whatever supervision / opencv-python versions are present
        # at this point with explicitly pinned ones.
        "pip uninstall -y supervision",
        "pip uninstall -y opencv-python",
        "pip install opencv-python==4.8.0.74",
        "pip install -q supervision==0.6.0",
        # Model checkpoints: GroundingDINO SwinT-OGC and the three SAM
        # variants (vit_h, vit_l, vit_b).
        "wget -q https://github.com/IDEA-Research/GroundingDINO/releases/download/v0.1.0-alpha/groundingdino_swint_ogc.pth -P weights/",
        "wget -q https://dl.fbaipublicfiles.com/segment_anything/sam_vit_h_4b8939.pth -P weights/",
        "wget -q https://dl.fbaipublicfiles.com/segment_anything/sam_vit_l_0b3195.pth -P weights/",
        "wget -q https://dl.fbaipublicfiles.com/segment_anything/sam_vit_b_01ec64.pth -P weights/",
    )
)

# Mount built from the local Python modules of the cutout handler package.
# NOTE(review): this mount is not passed to either Stub below in this file;
# presumably it is attached elsewhere -- verify.
local_packages = Mount.from_local_python_packages(
    "app.cutout_handler.dino",
    "app.cutout_handler.segment",
    "app.cutout_handler.s3_handler",
    "app.cutout_handler.grounded_cutouts",
)

# One Stub per deployable: each is bound to its image and given a fixed name.
cutout_handler_stub = Stub(image=cutout_generator_image, name="cutout_generator")
s3_handler_stub = Stub(image=s3_handler_image, name="s3_handler")
3 changes: 3 additions & 0 deletions app/cutout_handler/__init__.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,3 @@
from app.common import cutout_handler_stub, s3_handler_stub

from .server import cutout_app
15 changes: 11 additions & 4 deletions app/dino.py → app/cutout_handler/dino.py
Original file line number Diff line number Diff line change
@@ -1,11 +1,17 @@
import torch
from groundingdino.util.inference import Model
from typing import List

from app.common import cutout_handler_stub, cutout_generator_image

with cutout_generator_image.imports():
import torch
from groundingdino.util.inference import Model

cutout_handler_stub.cls()


class Dino:
""" A class for object detection using GroundingDINO.
"""
"""A class for object detection using GroundingDINO."""

def __init__(
self,
classes,
Expand All @@ -14,6 +20,7 @@ def __init__(
model_config_path,
model_checkpoint_path,
):

self.classes = classes
self.box_threshold = box_threshold
self.text_threshold = text_threshold
Expand Down
151 changes: 151 additions & 0 deletions app/cutout_handler/grounded_cutouts.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,151 @@
import io
import logging
import os
from typing import Dict, List

import cv2
import numpy as np
import supervision as sv
from fastapi import Body, FastAPI
from fastapi.middleware.cors import CORSMiddleware
from modal import Image, Mount, Secret, Stub, asgi_app, method
from starlette.requests import Request

from .dino import Dino
from .s3_handler import Boto3Client
from .segment import Segmenter

# ======================
# Constants
# ======================
# HOME is the parent of the current working directory, resolved once when
# this module is imported. Every path below is derived from it.
HOME = os.path.abspath(os.path.join(os.getcwd(), os.pardir))
# GroundingDINO model config file inside the GroundingDINO checkout under HOME.
GROUNDING_DINO_CONFIG_PATH = os.path.join(
    HOME, "GroundingDINO/groundingdino/config/GroundingDINO_SwinT_OGC.py"
)
# GroundingDINO checkpoint under HOME/weights.
GROUNDING_DINO_CHECKPOINT_PATH = os.path.join(
    HOME, "weights", "groundingdino_swint_ogc.pth"
)
# SAM checkpoints under HOME/weights, one per encoder size
# (HIGH = vit_h, MID = vit_l, LOW = vit_b, per the file names).
SAM_CHECKPOINT_PATH_HIGH = os.path.join(HOME, "weights", "sam_vit_h_4b8939.pth")
SAM_CHECKPOINT_PATH_MID = os.path.join(HOME, "weights", "sam_vit_l_0b3195.pth")
SAM_CHECKPOINT_PATH_LOW = os.path.join(HOME, "weights", "sam_vit_b_01ec64.pth")

# ======================
# Logging
# ======================
# Module-level logger emitting everything from DEBUG upwards.
logger = logging.getLogger(__name__)
logger.setLevel(logging.DEBUG)

# Stream handler attached directly to this logger, also at DEBUG level.
c_handler = logging.StreamHandler()
c_handler.setLevel(logging.DEBUG)

# Output format: "<logger name> - <level> - <message>".
c_format = logging.Formatter("%(name)s - %(levelname)s - %(message)s")
c_handler.setFormatter(c_format)

logger.addHandler(c_handler)


class CutoutCreator:
    """Create per-object cutouts from an image stored in S3.

    The pipeline is: download the image through ``Boto3Client``, detect
    objects with ``Dino``, turn the detected boxes into masks with
    ``Segmenter``, then write one PNG per mask and upload it back to S3.
    """

    # Encoder names accepted by ``__init__``; each one maps to a SAM checkpoint.
    _ENCODER_VERSIONS = ("vit_b", "vit_l", "vit_h")

    def __init__(
        self,
        classes: str,
        grounding_dino_config_path: str,
        grounding_dino_checkpoint_path: str,
        encoder_version: str = "vit_b",
    ):
        """Build the S3 client, the detector and the segmenter.

        Args:
            classes: Class prompt(s) forwarded unchanged to ``Dino``.
                NOTE(review): annotated as ``str`` but presumably a list of
                class names -- confirm against ``Dino``.
            grounding_dino_config_path: Path to the GroundingDINO config file.
            grounding_dino_checkpoint_path: Path to the GroundingDINO weights.
            encoder_version: SAM encoder to use: ``"vit_b"``, ``"vit_l"`` or
                ``"vit_h"``.

        Raises:
            ValueError: If ``encoder_version`` is not one of the supported
                names.
        """
        # Fail fast, before any model is loaded: previously an unknown
        # version silently produced a ``None`` checkpoint path.
        if encoder_version not in self._ENCODER_VERSIONS:
            raise ValueError(
                f"Unsupported encoder_version {encoder_version!r}; "
                f"expected one of {', '.join(self._ENCODER_VERSIONS)}"
            )

        self.classes = classes
        self.grounding_dino_config_path = grounding_dino_config_path
        self.grounding_dino_checkpoint_path = grounding_dino_checkpoint_path
        self.encoder_version = encoder_version
        # Kept as a public attribute for backward compatibility.
        self.HOME = os.path.abspath(os.path.join(os.getcwd(), os.pardir))
        self.s3 = Boto3Client()
        self.dino = Dino(
            classes=self.classes,
            box_threshold=0.35,
            text_threshold=0.25,
            model_config_path=self.grounding_dino_config_path,
            model_checkpoint_path=self.grounding_dino_checkpoint_path,
        )
        self.mask_annotator = sv.MaskAnnotator()

        encoder_checkpoint_paths = {
            "vit_b": SAM_CHECKPOINT_PATH_LOW,
            "vit_l": SAM_CHECKPOINT_PATH_MID,
            "vit_h": SAM_CHECKPOINT_PATH_HIGH,
        }
        self.sam_checkpoint_path = encoder_checkpoint_paths[self.encoder_version]
        self.segment = Segmenter(
            sam_encoder_version=self.encoder_version,
            sam_checkpoint_path=self.sam_checkpoint_path,
        )

    def create_annotated_image(self, image, image_name, detections: Dict[str, list]):
        """Overlay the detection masks on ``image`` and upload the result as PNG.

        Args:
            image: Image array used as the annotation scene.
            image_name (str): Base name; the object is uploaded to the
                ``"cutouts"`` bucket as ``"<image_name>_annotated.png"``.
            detections: Detections passed straight to the mask annotator.
                NOTE(review): annotated as a dict, but presumably the
                detections object returned by ``Dino.predict`` -- confirm.
        """
        annotated_image = self.mask_annotator.annotate(
            scene=image, detections=detections
        )
        # ``Image`` in this module is ``modal.Image`` (a container-image
        # builder), which has no ``fromarray``; encode with OpenCV instead.
        encoded_ok, png_buffer = cv2.imencode(".png", np.uint8(annotated_image))
        if not encoded_ok:
            logger.error("Failed to encode annotated image for %s", image_name)
            return
        self.s3.upload_to_s3(
            png_buffer.tobytes(), "cutouts", f"{image_name}_annotated.png"
        )

    def create_cutouts(self, image_name):
        """Create one cutout per detected object and upload each to S3.

        The image is downloaded into ``HOME/data``. Every cutout is written
        to ``HOME/cutouts`` and uploaded to the ``"cutouts"`` bucket under
        ``"<image_name>/<image_name>_cutout_<i>.png"``. Failures are logged
        and end the call (or skip that cutout) instead of raising.

        Args:
            image_name (str): Name of the image to fetch from S3.
        """
        data_path = os.path.join(HOME, "data")
        cutouts_path = os.path.join(HOME, "cutouts")

        # Download image from S3.
        image_path = self.s3.download_from_s3(data_path, image_name)
        if image_path is None:
            logger.error("Failed to download image %s from S3", image_name)
            return

        if not os.path.exists(image_path):
            logger.error("Image %s not found in folder %s", image_name, image_path)
            return

        os.makedirs(cutouts_path, exist_ok=True)

        # cv2.imread signals an unreadable or undecodable file by returning
        # None rather than raising, so check before running the models.
        image = cv2.imread(image_path)
        if image is None:
            logger.error("Failed to read image %s from %s", image_name, image_path)
            return

        # Predict boxes, then segment them into masks.
        detections = self.dino.predict(image)
        masks = self.segment.segment(image, detections.xyxy)

        for i, mask in enumerate(masks):
            # Boolean mask: keep the masked pixels, leave everything else black.
            mask = mask.astype(bool)
            cutout = np.zeros_like(image)
            cutout[mask] = image[mask]

            cutout_name = f"{image_name}_cutout_{i}.png"
            cutout_path = os.path.join(cutouts_path, cutout_name)
            # cv2.imwrite reports failure through its return value; without
            # this check the open() below would fail on a missing file.
            if not cv2.imwrite(cutout_path, cutout):
                logger.error("Failed to write cutout %s", cutout_path)
                continue

            with open(cutout_path, "rb") as f:
                self.s3.upload_to_s3(f.read(), "cutouts", f"{image_name}/{cutout_name}")

            # NOTE: uploading an annotated image per cutout is currently
            # disabled:
            # self.create_annotated_image(image, f"{image_name}_{i}", detections)
8 changes: 6 additions & 2 deletions app/s3_handler.py → app/cutout_handler/s3_handler.py
Original file line number Diff line number Diff line change
@@ -1,11 +1,15 @@
import os
import boto3
import logging
from botocore.exceptions import ClientError, BotoCoreError, NoCredentialsError
from app.common import s3_handler_stub, s3_handler_image

with s3_handler_image.imports():
import boto3
from botocore.exceptions import ClientError, BotoCoreError, NoCredentialsError

s3_handler_stub.cls()
class Boto3Client:
def __init__(self):

self.s3 = boto3.client(
"s3",
aws_access_key_id=os.environ["AWS_ACCESS_KEY_ID"],
Expand Down
15 changes: 12 additions & 3 deletions app/segment.py → app/cutout_handler/segment.py
Original file line number Diff line number Diff line change
@@ -1,9 +1,18 @@
import numpy as np
import torch
from segment_anything import sam_model_registry, SamPredictor
from app.common import cutout_handler_stub, cutout_generator_image


with cutout_generator_image.imports():
import torch
from segment_anything import SamPredictor, sam_model_registry
import numpy as np


cutout_handler_stub.cls()


class Segmenter:
import numpy as np

def __init__(
self,
sam_encoder_version: str,
Expand Down
Loading

0 comments on commit b6c22db

Please sign in to comment.