-
Notifications
You must be signed in to change notification settings - Fork 0
Commit
This commit does not belong to any branch on this repository, and may belong to a fork outside of the repository.
Merge pull request #20 from OriginalByteMe/19-move-cutout-class-back-…
…into-seperate-file 19 move cutout class back into seperate file
- Loading branch information
Showing
13 changed files
with
473 additions
and
396 deletions.
There are no files selected for viewing
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,10 @@ | ||
# repos: | ||
# - repo: https://github.com/psf/black | ||
# hooks: | ||
# - id: black | ||
# language_version: python3.8 # Should match the version of Python you're using | ||
|
||
# - repo: https://github.com/pycqa/isort | ||
# hooks: | ||
# - id: isort | ||
# language_version: python3.8 # Should match the version of Python you're using |
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,36 @@ | ||
from modal import Image, Mount, Stub | ||
|
||
s3_handler_image = Image.debian_slim().pip_install("boto3", "botocore") | ||
|
||
cutout_generator_image = ( | ||
Image.from_registry("nvcr.io/nvidia/pytorch:22.12-py3") | ||
.pip_install( | ||
"segment-anything", "opencv-python", "botocore", "boto3", "fastapi", "starlette" | ||
) | ||
.run_commands( | ||
"apt-get update", | ||
"apt-get install -y git wget libgl1-mesa-glx libglib2.0-0", | ||
"git clone https://github.com/IDEA-Research/GroundingDINO.git", | ||
"pip install -q -e GroundingDINO/", | ||
"mkdir -p /weights", | ||
"mkdir -p /data", | ||
"pip uninstall -y supervision", | ||
"pip uninstall -y opencv-python", | ||
"pip install opencv-python==4.8.0.74", | ||
"pip install -q supervision==0.6.0", | ||
"wget -q https://github.com/IDEA-Research/GroundingDINO/releases/download/v0.1.0-alpha/groundingdino_swint_ogc.pth -P weights/", | ||
"wget -q https://dl.fbaipublicfiles.com/segment_anything/sam_vit_h_4b8939.pth -P weights/", | ||
"wget -q https://dl.fbaipublicfiles.com/segment_anything/sam_vit_l_0b3195.pth -P weights/", | ||
"wget -q https://dl.fbaipublicfiles.com/segment_anything/sam_vit_b_01ec64.pth -P weights/", | ||
) | ||
) | ||
|
||
local_packages = Mount.from_local_python_packages( | ||
"app.cutout_handler.dino", | ||
"app.cutout_handler.segment", | ||
"app.cutout_handler.s3_handler", | ||
"app.cutout_handler.grounded_cutouts", | ||
) | ||
|
||
cutout_handler_stub = Stub(image=cutout_generator_image, name="cutout_generator") | ||
s3_handler_stub = Stub(image=s3_handler_image, name="s3_handler") |
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,3 @@ | ||
from app.common import cutout_handler_stub, s3_handler_stub | ||
|
||
from .server import cutout_app |
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,151 @@ | ||
import io | ||
import logging | ||
import os | ||
from typing import Dict, List | ||
|
||
import cv2 | ||
import numpy as np | ||
import supervision as sv | ||
from fastapi import Body, FastAPI | ||
from fastapi.middleware.cors import CORSMiddleware | ||
from modal import Image, Mount, Secret, Stub, asgi_app, method | ||
from starlette.requests import Request | ||
|
||
from .dino import Dino | ||
from .s3_handler import Boto3Client | ||
from .segment import Segmenter | ||
|
||
# ====================== | ||
# Constants | ||
# ====================== | ||
HOME = os.path.abspath(os.path.join(os.getcwd(), os.pardir)) | ||
GROUNDING_DINO_CONFIG_PATH = os.path.join( | ||
HOME, "GroundingDINO/groundingdino/config/GroundingDINO_SwinT_OGC.py" | ||
) | ||
GROUNDING_DINO_CHECKPOINT_PATH = os.path.join( | ||
HOME, "weights", "groundingdino_swint_ogc.pth" | ||
) | ||
SAM_CHECKPOINT_PATH_HIGH = os.path.join(HOME, "weights", "sam_vit_h_4b8939.pth") | ||
SAM_CHECKPOINT_PATH_MID = os.path.join(HOME, "weights", "sam_vit_l_0b3195.pth") | ||
SAM_CHECKPOINT_PATH_LOW = os.path.join(HOME, "weights", "sam_vit_b_01ec64.pth") | ||
|
||
# ====================== | ||
# Logging | ||
# ====================== | ||
logger = logging.getLogger(__name__) | ||
logger.setLevel(logging.DEBUG) | ||
|
||
c_handler = logging.StreamHandler() | ||
c_handler.setLevel(logging.DEBUG) | ||
|
||
c_format = logging.Formatter("%(name)s - %(levelname)s - %(message)s") | ||
c_handler.setFormatter(c_format) | ||
|
||
logger.addHandler(c_handler) | ||
|
||
|
||
class CutoutCreator: | ||
def __init__( | ||
self, | ||
classes: str, | ||
grounding_dino_config_path: str, | ||
grounding_dino_checkpoint_path: str, | ||
encoder_version: str = "vit_b", | ||
): | ||
self.classes = classes | ||
self.grounding_dino_config_path = grounding_dino_config_path | ||
self.grounding_dino_checkpoint_path = grounding_dino_checkpoint_path | ||
self.encoder_version = encoder_version | ||
self.HOME = os.path.abspath(os.path.join(os.getcwd(), os.pardir)) | ||
self.s3 = Boto3Client() | ||
self.dino = Dino( | ||
classes=self.classes, | ||
box_threshold=0.35, | ||
text_threshold=0.25, | ||
model_config_path=self.grounding_dino_config_path, | ||
model_checkpoint_path=self.grounding_dino_checkpoint_path, | ||
) | ||
self.mask_annotator = sv.MaskAnnotator() | ||
|
||
encoder_checkpoint_paths = { | ||
"vit_b": SAM_CHECKPOINT_PATH_LOW, | ||
"vit_l": SAM_CHECKPOINT_PATH_MID, | ||
"vit_h": SAM_CHECKPOINT_PATH_HIGH, | ||
} | ||
|
||
self.sam_checkpoint_path = encoder_checkpoint_paths.get(self.encoder_version) | ||
self.segment = Segmenter( | ||
sam_encoder_version=self.encoder_version, | ||
sam_checkpoint_path=self.sam_checkpoint_path, | ||
) | ||
|
||
def create_annotated_image(self, image, image_name, detections: Dict[str, list]): | ||
"""Create a highlighted annotated image from an image and detections. | ||
Args: | ||
image (File): Image to be used for creating the annotated image. | ||
image_name (string): name of image | ||
detections (Dict[str, list]): annotations for the image | ||
""" | ||
annotated_image = self.mask_annotator.annotate( | ||
scene=image, detections=detections | ||
) | ||
# Convert annotated image to bytes | ||
img_bytes = io.BytesIO() | ||
Image.fromarray(np.uint8(annotated_image)).save(img_bytes, format="PNG") | ||
img_bytes.seek(0) | ||
# Upload bytes to S3 | ||
self.s3.upload_to_s3(img_bytes.read(), "cutouts", f"{image_name}_annotated.png") | ||
|
||
def create_cutouts(self, image_name): | ||
"""Create cutouts from an image and upload them to S3. | ||
Args: | ||
image_name (string): name of image | ||
""" | ||
|
||
# Define paths | ||
data_path = os.path.join(HOME, "data") | ||
cutouts_path = os.path.join(HOME, "cutouts") | ||
|
||
# Download image from s3 | ||
image_path = self.s3.download_from_s3(data_path, image_name) | ||
if image_path is None: | ||
logger.error(f"Failed to download image {image_name} from S3") | ||
return | ||
|
||
# Check if image exists | ||
if not os.path.exists(image_path): | ||
logger.error(f"Image {image_name} not found in folder {image_path}") | ||
return | ||
|
||
# Create cutouts directory if it doesn't exist | ||
os.makedirs(cutouts_path, exist_ok=True) | ||
|
||
# Read image | ||
image = cv2.imread(image_path) | ||
|
||
# Predict and segment image | ||
detections = self.dino.predict(image) | ||
masks = self.segment.segment(image, detections.xyxy) | ||
|
||
# Apply each mask to the image | ||
for i, mask in enumerate(masks): | ||
# Ensure the mask is a boolean array | ||
mask = mask.astype(bool) | ||
|
||
# Apply the mask to create a cutout | ||
cutout = np.zeros_like(image) | ||
cutout[mask] = image[mask] | ||
|
||
# Save the cutout | ||
cutout_name = f"{image_name}_cutout_{i}.png" | ||
cutout_path = os.path.join(cutouts_path, cutout_name) | ||
cv2.imwrite(cutout_path, cutout) | ||
|
||
# Upload the cutout to S3 | ||
with open(cutout_path, "rb") as f: | ||
self.s3.upload_to_s3(f.read(), "cutouts", f"{image_name}/{cutout_name}") | ||
|
||
# Create annotated image | ||
# self.create_annotated_image(image, f"{image_name}_{i}", detections) |
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Oops, something went wrong.