Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Develop #23

Merged
merged 29 commits into from
Dec 24, 2023
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
Show all changes
29 commits
Select commit Hold shift + click to select a range
28fc3c3
Update cutout creation and annotation process
OriginalByteMe Nov 18, 2023
becc0ad
Add print statements to s3_handler.py
OriginalByteMe Nov 18, 2023
6fbe6a3
Refactor create_cutouts method
OriginalByteMe Nov 18, 2023
781f571
Add environment variable for CI/CD workflow
OriginalByteMe Nov 18, 2023
fd4c4e4
Cleanup functions a bit
OriginalByteMe Nov 18, 2023
4ad6323
Add .pylintrc file to disable pylint warnings
OriginalByteMe Nov 18, 2023
5bd62ee
move cutouts file to legacy code
OriginalByteMe Nov 18, 2023
b4a1469
Refactor cutout creation and add SAM checkpoint
OriginalByteMe Nov 18, 2023
8adeab4
Add import statement for Boto3Client
OriginalByteMe Nov 20, 2023
52e8ca2
Remove keep_warm parameter from create_all_cutouts
OriginalByteMe Nov 21, 2023
169bcff
Fix indentation in ci-cd.yml file
OriginalByteMe Nov 21, 2023
7e00025
Fix formatting and add warmup endpoint
OriginalByteMe Nov 21, 2023
6f0ef3c
Update SAM checkpoint paths and add accuracy level
OriginalByteMe Nov 21, 2023
7dfced8
Merge pull request #14 from OriginalByteMe/feature/optimizations
OriginalByteMe Nov 22, 2023
1991005
Update deployment order in ci-cd.yml
OriginalByteMe Nov 22, 2023
79feed8
Merge pull request #15 from OriginalByteMe/feature/optimizations
OriginalByteMe Nov 22, 2023
bb26ae0
Fix deployment command in ci-cd.yml
OriginalByteMe Nov 22, 2023
5f019b0
Merge pull request #16 from OriginalByteMe/feature/optimizations
OriginalByteMe Nov 22, 2023
7ce526d
Update SAM checkpoint path based on encoder
OriginalByteMe Dec 21, 2023
97e44ef
Improve code quality and readability
OriginalByteMe Dec 21, 2023
16acc83
Update app/grounded_cutouts.py
OriginalByteMe Dec 21, 2023
abc7bae
Merge pull request #18 from OriginalByteMe/fix/incorrect_encoder_version
OriginalByteMe Dec 21, 2023
c2d65f4
Refactor position + imports
OriginalByteMe Dec 21, 2023
4e52262
Clean up imports
OriginalByteMe Dec 21, 2023
e920dc7
Fix logging statements in cutout_handler
OriginalByteMe Dec 21, 2023
19da294
Hopefully fix deployment
OriginalByteMe Dec 21, 2023
b6c22db
Merge pull request #20 from OriginalByteMe/19-move-cutout-class-back-…
OriginalByteMe Dec 21, 2023
3b65f39
Update deployment configuration
OriginalByteMe Dec 21, 2023
cb1e222
Merge pull request #22 from OriginalByteMe/21-fix-github-action-failu…
OriginalByteMe Dec 21, 2023
File filter

Filter by extension

Filter by extension


Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
15 changes: 11 additions & 4 deletions .github/workflows/ci-cd.yml
Original file line number Diff line number Diff line change
Expand Up @@ -4,6 +4,7 @@ on:
push:
branches:
- main
- develop

jobs:
deploy:
Expand All @@ -27,12 +28,18 @@ jobs:
python -m pip install --upgrade pip
pip install modal

- name: Set environment
id: vars
run: |
if [[ "${{ github.ref }}" == "refs/heads/main" ]]; then
echo "::set-output name=environment::main"
else
echo "::set-output name=environment::dev"
fi
- name: Deploy cutout generator job
run: |
cd app
modal deploy grounded_cutouts.py
modal deploy --env=${{ steps.vars.outputs.environment }} app.cutout_handler::cutout_handler_stub
- name: Deploy s3_handler job
run: |
cd app/s3_handler
modal deploy app.py
modal deploy --env=${{ steps.vars.outputs.environment }} app.s3_handler.app

10 changes: 10 additions & 0 deletions .pre-commit-config.yaml
Original file line number Diff line number Diff line change
@@ -0,0 +1,10 @@
# repos:
# - repo: https://github.com/psf/black
# hooks:
# - id: black
# language_version: python3.8 # Should match the version of Python you're using

# - repo: https://github.com/pycqa/isort
# hooks:
# - id: isort
# language_version: python3.8 # Should match the version of Python you're using
7 changes: 7 additions & 0 deletions .pylintrc
Original file line number Diff line number Diff line change
@@ -0,0 +1,7 @@
[MESSAGES CONTROL]
disable=
missing-docstring,
import-outside-toplevel,
line-too-long,
attribute-defined-outside-init,
no-member
36 changes: 36 additions & 0 deletions app/common/__init__.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,36 @@
from modal import Image, Mount, Stub

s3_handler_image = Image.debian_slim().pip_install("boto3", "botocore")

cutout_generator_image = (
Image.from_registry("nvcr.io/nvidia/pytorch:22.12-py3")
.pip_install(
"segment-anything", "opencv-python", "botocore", "boto3", "fastapi", "starlette"
)
.run_commands(
"apt-get update",
"apt-get install -y git wget libgl1-mesa-glx libglib2.0-0",
"git clone https://github.com/IDEA-Research/GroundingDINO.git",
"pip install -q -e GroundingDINO/",
"mkdir -p /weights",
"mkdir -p /data",
"pip uninstall -y supervision",
"pip uninstall -y opencv-python",
"pip install opencv-python==4.8.0.74",
"pip install -q supervision==0.6.0",
"wget -q https://github.com/IDEA-Research/GroundingDINO/releases/download/v0.1.0-alpha/groundingdino_swint_ogc.pth -P weights/",
"wget -q https://dl.fbaipublicfiles.com/segment_anything/sam_vit_h_4b8939.pth -P weights/",
"wget -q https://dl.fbaipublicfiles.com/segment_anything/sam_vit_l_0b3195.pth -P weights/",
"wget -q https://dl.fbaipublicfiles.com/segment_anything/sam_vit_b_01ec64.pth -P weights/",
)
)

local_packages = Mount.from_local_python_packages(
"app.cutout_handler.dino",
"app.cutout_handler.segment",
"app.cutout_handler.s3_handler",
"app.cutout_handler.grounded_cutouts",
)

cutout_handler_stub = Stub(image=cutout_generator_image, name="cutout_generator")
s3_handler_stub = Stub(image=s3_handler_image, name="s3_handler")
3 changes: 3 additions & 0 deletions app/cutout_handler/__init__.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,3 @@
from app.common import cutout_handler_stub, s3_handler_stub

from .server import cutout_app
15 changes: 11 additions & 4 deletions app/dino.py → app/cutout_handler/dino.py
Original file line number Diff line number Diff line change
@@ -1,11 +1,17 @@
import torch
from groundingdino.util.inference import Model
from typing import List

from app.common import cutout_handler_stub, cutout_generator_image

with cutout_generator_image.imports():
import torch
from groundingdino.util.inference import Model

cutout_handler_stub.cls()


class Dino:
""" A class for object detection using GroundingDINO.
"""
"""A class for object detection using GroundingDINO."""

def __init__(
self,
classes,
Expand All @@ -14,6 +20,7 @@ def __init__(
model_config_path,
model_checkpoint_path,
):

self.classes = classes
self.box_threshold = box_threshold
self.text_threshold = text_threshold
Expand Down
151 changes: 151 additions & 0 deletions app/cutout_handler/grounded_cutouts.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,151 @@
import io
import logging
import os
from typing import Dict, List

import cv2
import numpy as np
import supervision as sv
from fastapi import Body, FastAPI
from fastapi.middleware.cors import CORSMiddleware
from modal import Image, Mount, Secret, Stub, asgi_app, method
from starlette.requests import Request

from .dino import Dino
from .s3_handler import Boto3Client
from .segment import Segmenter

# ======================
# Constants
# ======================
HOME = os.path.abspath(os.path.join(os.getcwd(), os.pardir))
GROUNDING_DINO_CONFIG_PATH = os.path.join(
HOME, "GroundingDINO/groundingdino/config/GroundingDINO_SwinT_OGC.py"
)
GROUNDING_DINO_CHECKPOINT_PATH = os.path.join(
HOME, "weights", "groundingdino_swint_ogc.pth"
)
SAM_CHECKPOINT_PATH_HIGH = os.path.join(HOME, "weights", "sam_vit_h_4b8939.pth")
SAM_CHECKPOINT_PATH_MID = os.path.join(HOME, "weights", "sam_vit_l_0b3195.pth")
SAM_CHECKPOINT_PATH_LOW = os.path.join(HOME, "weights", "sam_vit_b_01ec64.pth")

# ======================
# Logging
# ======================
logger = logging.getLogger(__name__)
logger.setLevel(logging.DEBUG)

c_handler = logging.StreamHandler()
c_handler.setLevel(logging.DEBUG)

c_format = logging.Formatter("%(name)s - %(levelname)s - %(message)s")
c_handler.setFormatter(c_format)

logger.addHandler(c_handler)


class CutoutCreator:
def __init__(
self,
classes: str,
grounding_dino_config_path: str,
grounding_dino_checkpoint_path: str,
encoder_version: str = "vit_b",
):
self.classes = classes
self.grounding_dino_config_path = grounding_dino_config_path
self.grounding_dino_checkpoint_path = grounding_dino_checkpoint_path
self.encoder_version = encoder_version
self.HOME = os.path.abspath(os.path.join(os.getcwd(), os.pardir))
self.s3 = Boto3Client()
self.dino = Dino(
classes=self.classes,
box_threshold=0.35,
text_threshold=0.25,
model_config_path=self.grounding_dino_config_path,
model_checkpoint_path=self.grounding_dino_checkpoint_path,
)
self.mask_annotator = sv.MaskAnnotator()

encoder_checkpoint_paths = {
"vit_b": SAM_CHECKPOINT_PATH_LOW,
"vit_l": SAM_CHECKPOINT_PATH_MID,
"vit_h": SAM_CHECKPOINT_PATH_HIGH,
}

self.sam_checkpoint_path = encoder_checkpoint_paths.get(self.encoder_version)
self.segment = Segmenter(
sam_encoder_version=self.encoder_version,
sam_checkpoint_path=self.sam_checkpoint_path,
)

def create_annotated_image(self, image, image_name, detections: Dict[str, list]):
"""Create a highlighted annotated image from an image and detections.

Args:
image (File): Image to be used for creating the annotated image.
image_name (string): name of image
detections (Dict[str, list]): annotations for the image
"""
annotated_image = self.mask_annotator.annotate(
scene=image, detections=detections
)
# Convert annotated image to bytes
img_bytes = io.BytesIO()
Image.fromarray(np.uint8(annotated_image)).save(img_bytes, format="PNG")
img_bytes.seek(0)
# Upload bytes to S3
self.s3.upload_to_s3(img_bytes.read(), "cutouts", f"{image_name}_annotated.png")

def create_cutouts(self, image_name):
"""Create cutouts from an image and upload them to S3.

Args:
image_name (string): name of image
"""

# Define paths
data_path = os.path.join(HOME, "data")
cutouts_path = os.path.join(HOME, "cutouts")

# Download image from s3
image_path = self.s3.download_from_s3(data_path, image_name)
if image_path is None:
logger.error(f"Failed to download image {image_name} from S3")
return

# Check if image exists
if not os.path.exists(image_path):
logger.error(f"Image {image_name} not found in folder {image_path}")
return

# Create cutouts directory if it doesn't exist
os.makedirs(cutouts_path, exist_ok=True)

# Read image
image = cv2.imread(image_path)

# Predict and segment image
detections = self.dino.predict(image)
masks = self.segment.segment(image, detections.xyxy)

# Apply each mask to the image
for i, mask in enumerate(masks):
# Ensure the mask is a boolean array
mask = mask.astype(bool)

# Apply the mask to create a cutout
cutout = np.zeros_like(image)
cutout[mask] = image[mask]

# Save the cutout
cutout_name = f"{image_name}_cutout_{i}.png"
cutout_path = os.path.join(cutouts_path, cutout_name)
cv2.imwrite(cutout_path, cutout)

# Upload the cutout to S3
with open(cutout_path, "rb") as f:
self.s3.upload_to_s3(f.read(), "cutouts", f"{image_name}/{cutout_name}")

# Create annotated image
# self.create_annotated_image(image, f"{image_name}_{i}", detections)
11 changes: 9 additions & 2 deletions app/s3_handler.py → app/cutout_handler/s3_handler.py
Original file line number Diff line number Diff line change
@@ -1,11 +1,15 @@
import os
import boto3
import logging
from botocore.exceptions import ClientError, BotoCoreError, NoCredentialsError
from app.common import s3_handler_stub, s3_handler_image

with s3_handler_image.imports():
import boto3
from botocore.exceptions import ClientError, BotoCoreError, NoCredentialsError

s3_handler_stub.cls()
class Boto3Client:
def __init__(self):

self.s3 = boto3.client(
"s3",
aws_access_key_id=os.environ["AWS_ACCESS_KEY_ID"],
Expand All @@ -20,6 +24,9 @@ def download_from_s3(self, save_path, image_name):
s3_client.download_file(
os.environ["CUTOUT_BUCKET"], f"images/{image_name}", file_path
)
print(f"Successfully downloaded {image_name} to {file_path}")
print("Directory contents:")
print(os.listdir(save_path))
Comment on lines +27 to +29
Copy link

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

The addition of print statements for debugging is useful for development, but consider using logging instead of print statements for better control over the output and its format.

-            print(f"Successfully downloaded {image_name} to {file_path}")
-            print("Directory contents:")
-            print(os.listdir(save_path))
+            logger.info(f"Successfully downloaded {image_name} to {file_path}")
+            logger.debug("Directory contents: %s", os.listdir(save_path))

Committable suggestion

IMPORTANT
Carefully review the code before committing. Ensure that it accurately replaces the highlighted code, contains no missing lines, and has no issues with indentation.

Suggested change
print(f"Successfully downloaded {image_name} to {file_path}")
print("Directory contents:")
print(os.listdir(save_path))
logger.info(f"Successfully downloaded {image_name} to {file_path}")
logger.debug("Directory contents: %s", os.listdir(save_path))

except ClientError as e:
print("BOTO error: ", e)
print(
Expand Down
15 changes: 12 additions & 3 deletions app/segment.py → app/cutout_handler/segment.py
Original file line number Diff line number Diff line change
@@ -1,9 +1,18 @@
import numpy as np
import torch
from segment_anything import sam_model_registry, SamPredictor
from app.common import cutout_handler_stub, cutout_generator_image


with cutout_generator_image.imports():
import torch
from segment_anything import SamPredictor, sam_model_registry
import numpy as np


cutout_handler_stub.cls()


class Segmenter:
import numpy as np

Copy link

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

The import statement for numpy should be at the module level, not within the class definition, to follow Python's convention and improve readability.

-    import numpy as np
+import numpy as np

Committable suggestion

IMPORTANT
Carefully review the code before committing. Ensure that it accurately replaces the highlighted code, contains no missing lines, and has no issues with indentation.

Suggested change
import numpy as np
import numpy as np

def __init__(
self,
sam_encoder_version: str,
Expand Down
Loading
Loading