-
Notifications
You must be signed in to change notification settings - Fork 0
New issue
Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.
By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.
Already on GitHub? Sign in to your account
Develop #23
Develop #23
Changes from all commits
28fc3c3
becc0ad
6fbe6a3
781f571
fd4c4e4
4ad6323
5bd62ee
b4a1469
8adeab4
52e8ca2
169bcff
7e00025
6f0ef3c
7dfced8
1991005
79feed8
bb26ae0
5f019b0
7ce526d
97e44ef
16acc83
abc7bae
c2d65f4
4e52262
e920dc7
19da294
b6c22db
3b65f39
cb1e222
File filter
Filter by extension
Conversations
Jump to
Diff view
Diff view
There are no files selected for viewing
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,10 @@ | ||
# repos: | ||
# - repo: https://github.com/psf/black | ||
# hooks: | ||
# - id: black | ||
# language_version: python3.8 # Should match the version of Python you're using | ||
|
||
# - repo: https://github.com/pycqa/isort | ||
# hooks: | ||
# - id: isort | ||
# language_version: python3.8 # Should match the version of Python you're using |
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,7 @@ | ||
[MESSAGES CONTROL] | ||
disable= | ||
missing-docstring, | ||
import-outside-toplevel, | ||
line-too-long, | ||
attribute-defined-outside-init, | ||
no-member |
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,36 @@ | ||
from modal import Image, Mount, Stub | ||
|
||
s3_handler_image = Image.debian_slim().pip_install("boto3", "botocore") | ||
|
||
cutout_generator_image = ( | ||
Image.from_registry("nvcr.io/nvidia/pytorch:22.12-py3") | ||
.pip_install( | ||
"segment-anything", "opencv-python", "botocore", "boto3", "fastapi", "starlette" | ||
) | ||
.run_commands( | ||
"apt-get update", | ||
"apt-get install -y git wget libgl1-mesa-glx libglib2.0-0", | ||
"git clone https://github.com/IDEA-Research/GroundingDINO.git", | ||
"pip install -q -e GroundingDINO/", | ||
"mkdir -p /weights", | ||
"mkdir -p /data", | ||
"pip uninstall -y supervision", | ||
"pip uninstall -y opencv-python", | ||
"pip install opencv-python==4.8.0.74", | ||
"pip install -q supervision==0.6.0", | ||
"wget -q https://github.com/IDEA-Research/GroundingDINO/releases/download/v0.1.0-alpha/groundingdino_swint_ogc.pth -P weights/", | ||
"wget -q https://dl.fbaipublicfiles.com/segment_anything/sam_vit_h_4b8939.pth -P weights/", | ||
"wget -q https://dl.fbaipublicfiles.com/segment_anything/sam_vit_l_0b3195.pth -P weights/", | ||
"wget -q https://dl.fbaipublicfiles.com/segment_anything/sam_vit_b_01ec64.pth -P weights/", | ||
) | ||
) | ||
|
||
local_packages = Mount.from_local_python_packages( | ||
"app.cutout_handler.dino", | ||
"app.cutout_handler.segment", | ||
"app.cutout_handler.s3_handler", | ||
"app.cutout_handler.grounded_cutouts", | ||
) | ||
|
||
cutout_handler_stub = Stub(image=cutout_generator_image, name="cutout_generator") | ||
s3_handler_stub = Stub(image=s3_handler_image, name="s3_handler") |
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,3 @@ | ||
from app.common import cutout_handler_stub, s3_handler_stub | ||
|
||
from .server import cutout_app |
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,151 @@ | ||
import io | ||
import logging | ||
import os | ||
from typing import Dict, List | ||
|
||
import cv2 | ||
import numpy as np | ||
import supervision as sv | ||
from fastapi import Body, FastAPI | ||
from fastapi.middleware.cors import CORSMiddleware | ||
from modal import Image, Mount, Secret, Stub, asgi_app, method | ||
from starlette.requests import Request | ||
|
||
from .dino import Dino | ||
from .s3_handler import Boto3Client | ||
from .segment import Segmenter | ||
|
||
# ====================== | ||
# Constants | ||
# ====================== | ||
HOME = os.path.abspath(os.path.join(os.getcwd(), os.pardir)) | ||
GROUNDING_DINO_CONFIG_PATH = os.path.join( | ||
HOME, "GroundingDINO/groundingdino/config/GroundingDINO_SwinT_OGC.py" | ||
) | ||
GROUNDING_DINO_CHECKPOINT_PATH = os.path.join( | ||
HOME, "weights", "groundingdino_swint_ogc.pth" | ||
) | ||
SAM_CHECKPOINT_PATH_HIGH = os.path.join(HOME, "weights", "sam_vit_h_4b8939.pth") | ||
SAM_CHECKPOINT_PATH_MID = os.path.join(HOME, "weights", "sam_vit_l_0b3195.pth") | ||
SAM_CHECKPOINT_PATH_LOW = os.path.join(HOME, "weights", "sam_vit_b_01ec64.pth") | ||
|
||
# ====================== | ||
# Logging | ||
# ====================== | ||
logger = logging.getLogger(__name__) | ||
logger.setLevel(logging.DEBUG) | ||
|
||
c_handler = logging.StreamHandler() | ||
c_handler.setLevel(logging.DEBUG) | ||
|
||
c_format = logging.Formatter("%(name)s - %(levelname)s - %(message)s") | ||
c_handler.setFormatter(c_format) | ||
|
||
logger.addHandler(c_handler) | ||
|
||
|
||
class CutoutCreator: | ||
def __init__( | ||
self, | ||
classes: str, | ||
grounding_dino_config_path: str, | ||
grounding_dino_checkpoint_path: str, | ||
encoder_version: str = "vit_b", | ||
): | ||
self.classes = classes | ||
self.grounding_dino_config_path = grounding_dino_config_path | ||
self.grounding_dino_checkpoint_path = grounding_dino_checkpoint_path | ||
self.encoder_version = encoder_version | ||
self.HOME = os.path.abspath(os.path.join(os.getcwd(), os.pardir)) | ||
self.s3 = Boto3Client() | ||
self.dino = Dino( | ||
classes=self.classes, | ||
box_threshold=0.35, | ||
text_threshold=0.25, | ||
model_config_path=self.grounding_dino_config_path, | ||
model_checkpoint_path=self.grounding_dino_checkpoint_path, | ||
) | ||
self.mask_annotator = sv.MaskAnnotator() | ||
|
||
encoder_checkpoint_paths = { | ||
"vit_b": SAM_CHECKPOINT_PATH_LOW, | ||
"vit_l": SAM_CHECKPOINT_PATH_MID, | ||
"vit_h": SAM_CHECKPOINT_PATH_HIGH, | ||
} | ||
|
||
self.sam_checkpoint_path = encoder_checkpoint_paths.get(self.encoder_version) | ||
self.segment = Segmenter( | ||
sam_encoder_version=self.encoder_version, | ||
sam_checkpoint_path=self.sam_checkpoint_path, | ||
) | ||
|
||
def create_annotated_image(self, image, image_name, detections: Dict[str, list]): | ||
"""Create a highlighted annotated image from an image and detections. | ||
|
||
Args: | ||
image (File): Image to be used for creating the annotated image. | ||
image_name (string): name of image | ||
detections (Dict[str, list]): annotations for the image | ||
""" | ||
annotated_image = self.mask_annotator.annotate( | ||
scene=image, detections=detections | ||
) | ||
# Convert annotated image to bytes | ||
img_bytes = io.BytesIO() | ||
Image.fromarray(np.uint8(annotated_image)).save(img_bytes, format="PNG") | ||
img_bytes.seek(0) | ||
# Upload bytes to S3 | ||
self.s3.upload_to_s3(img_bytes.read(), "cutouts", f"{image_name}_annotated.png") | ||
|
||
def create_cutouts(self, image_name): | ||
"""Create cutouts from an image and upload them to S3. | ||
|
||
Args: | ||
image_name (string): name of image | ||
""" | ||
|
||
# Define paths | ||
data_path = os.path.join(HOME, "data") | ||
cutouts_path = os.path.join(HOME, "cutouts") | ||
|
||
# Download image from s3 | ||
image_path = self.s3.download_from_s3(data_path, image_name) | ||
if image_path is None: | ||
logger.error(f"Failed to download image {image_name} from S3") | ||
return | ||
|
||
# Check if image exists | ||
if not os.path.exists(image_path): | ||
logger.error(f"Image {image_name} not found in folder {image_path}") | ||
return | ||
|
||
# Create cutouts directory if it doesn't exist | ||
os.makedirs(cutouts_path, exist_ok=True) | ||
|
||
# Read image | ||
image = cv2.imread(image_path) | ||
|
||
# Predict and segment image | ||
detections = self.dino.predict(image) | ||
masks = self.segment.segment(image, detections.xyxy) | ||
|
||
# Apply each mask to the image | ||
for i, mask in enumerate(masks): | ||
# Ensure the mask is a boolean array | ||
mask = mask.astype(bool) | ||
|
||
# Apply the mask to create a cutout | ||
cutout = np.zeros_like(image) | ||
cutout[mask] = image[mask] | ||
|
||
# Save the cutout | ||
cutout_name = f"{image_name}_cutout_{i}.png" | ||
cutout_path = os.path.join(cutouts_path, cutout_name) | ||
cv2.imwrite(cutout_path, cutout) | ||
|
||
# Upload the cutout to S3 | ||
with open(cutout_path, "rb") as f: | ||
self.s3.upload_to_s3(f.read(), "cutouts", f"{image_name}/{cutout_name}") | ||
|
||
# Create annotated image | ||
# self.create_annotated_image(image, f"{image_name}_{i}", detections) |
Original file line number | Diff line number | Diff line change | ||||
---|---|---|---|---|---|---|
@@ -1,9 +1,18 @@ | ||||||
import numpy as np | ||||||
import torch | ||||||
from segment_anything import sam_model_registry, SamPredictor | ||||||
from app.common import cutout_handler_stub, cutout_generator_image | ||||||
|
||||||
|
||||||
with cutout_generator_image.imports(): | ||||||
import torch | ||||||
from segment_anything import SamPredictor, sam_model_registry | ||||||
import numpy as np | ||||||
|
||||||
|
||||||
cutout_handler_stub.cls() | ||||||
|
||||||
|
||||||
class Segmenter: | ||||||
import numpy as np | ||||||
|
||||||
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. The import statement for - import numpy as np
+import numpy as np Committable suggestion
Suggested change
|
||||||
def __init__( | ||||||
self, | ||||||
sam_encoder_version: str, | ||||||
|
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
The addition of print statements for debugging is useful for development, but consider using logging instead of print statements for better control over the output and its format.
Committable suggestion