From c2d65f4223850379f6afc35b47d65c93b6d2886e Mon Sep 17 00:00:00 2001 From: OriginalByteMe Date: Thu, 21 Dec 2023 21:06:11 +0800 Subject: [PATCH 1/4] Refactor position + imports --- .pre-commit-config.yaml | 10 + app/common/__init__.py | 36 +++ app/cutout_handler/__init__.py | 3 + app/{ => cutout_handler}/dino.py | 13 +- app/cutout_handler/grounded_cutouts.py | 151 ++++++++++ app/{ => cutout_handler}/s3_handler.py | 24 +- app/{ => cutout_handler}/segment.py | 13 +- app/cutout_handler/server.py | 191 +++++++++++++ app/grounded_cutouts.py | 364 ------------------------- app/s3_handler/app.py | 42 ++- app/s3_handler/s3_handler.py | 10 +- legacy_code/cutouts.py | 20 +- 12 files changed, 481 insertions(+), 396 deletions(-) create mode 100644 .pre-commit-config.yaml create mode 100644 app/common/__init__.py create mode 100644 app/cutout_handler/__init__.py rename app/{ => cutout_handler}/dino.py (89%) create mode 100644 app/cutout_handler/grounded_cutouts.py rename app/{ => cutout_handler}/s3_handler.py (86%) rename app/{ => cutout_handler}/segment.py (79%) create mode 100644 app/cutout_handler/server.py delete mode 100644 app/grounded_cutouts.py diff --git a/.pre-commit-config.yaml b/.pre-commit-config.yaml new file mode 100644 index 0000000..87d2be8 --- /dev/null +++ b/.pre-commit-config.yaml @@ -0,0 +1,10 @@ +# repos: +# - repo: https://github.com/psf/black +# hooks: +# - id: black +# language_version: python3.8 # Should match the version of Python you're using + +# - repo: https://github.com/pycqa/isort +# hooks: +# - id: isort +# language_version: python3.8 # Should match the version of Python you're using \ No newline at end of file diff --git a/app/common/__init__.py b/app/common/__init__.py new file mode 100644 index 0000000..1f6ba4d --- /dev/null +++ b/app/common/__init__.py @@ -0,0 +1,36 @@ +from modal import Image, Mount, Stub + +s3_handler_image = Image.debian_slim().pip_install("boto3", "botocore") + +cutout_generator_image = ( + Image.from_registry("nvcr.io/nvidia/pytorch:22.12-py3") + .pip_install( + "segment-anything", "opencv-python", "botocore", "boto3", "fastapi", "starlette" + ) + .run_commands( + "apt-get update", + "apt-get install -y git wget libgl1-mesa-glx libglib2.0-0", + "git clone https://github.com/IDEA-Research/GroundingDINO.git", + "pip install -q -e GroundingDINO/", + "mkdir -p /weights", + "mkdir -p /data", + "pip uninstall -y supervision", + "pip uninstall -y opencv-python", + "pip install opencv-python==4.8.0.74", + "pip install -q supervision==0.6.0", + "wget -q https://github.com/IDEA-Research/GroundingDINO/releases/download/v0.1.0-alpha/groundingdino_swint_ogc.pth -P weights/", + "wget -q https://dl.fbaipublicfiles.com/segment_anything/sam_vit_h_4b8939.pth -P weights/", + "wget -q https://dl.fbaipublicfiles.com/segment_anything/sam_vit_l_0b3195.pth -P weights/", + "wget -q https://dl.fbaipublicfiles.com/segment_anything/sam_vit_b_01ec64.pth -P weights/", + ) +) + +local_packages = Mount.from_local_python_packages( + "app.cutout_handler.dino", + "app.cutout_handler.segment", + "app.cutout_handler.s3_handler", + "app.cutout_handler.grounded_cutouts", +) + +cutout_handler_stub = Stub(image=cutout_generator_image, name="cutout_generator") +s3_handler_stub = Stub(image=s3_handler_image, name="s3_handler") diff --git a/app/cutout_handler/__init__.py b/app/cutout_handler/__init__.py new file mode 100644 index 0000000..ed7f30f --- /dev/null +++ b/app/cutout_handler/__init__.py @@ -0,0 +1,3 @@ +from app.common import cutout_handler_stub, s3_handler_stub + +from .server import cutout_app diff --git a/app/dino.py b/app/cutout_handler/dino.py similarity index 89% rename from app/dino.py rename to app/cutout_handler/dino.py index 1e1528f..fb5bbb1 100644 --- a/app/dino.py +++ b/app/cutout_handler/dino.py @@ -1,11 +1,13 @@ -import torch -from groundingdino.util.inference import Model from typing import List +from app.common import cutout_handler_stub + +cutout_handler_stub.cls() + class Dino: - """ A class for object detection using GroundingDINO. - """ + """A class for object detection using GroundingDINO.""" + def __init__( self, classes, @@ -14,6 +16,9 @@ def __init__( model_config_path, model_checkpoint_path, ): + import torch + from groundingdino.util.inference import Model + self.classes = classes self.box_threshold = box_threshold self.text_threshold = text_threshold diff --git a/app/cutout_handler/grounded_cutouts.py b/app/cutout_handler/grounded_cutouts.py new file mode 100644 index 0000000..dd1be39 --- /dev/null +++ b/app/cutout_handler/grounded_cutouts.py @@ -0,0 +1,151 @@ +import io +import logging +import os +from typing import Dict, List + +import cv2 +import numpy as np +import supervision as sv +from fastapi import Body, FastAPI +from fastapi.middleware.cors import CORSMiddleware +from modal import Image, Mount, Secret, Stub, asgi_app, method +from starlette.requests import Request + +from .dino import Dino +from .s3_handler import Boto3Client +from .segment import Segmenter + +# ====================== +# Constants +# ====================== +HOME = os.path.abspath(os.path.join(os.getcwd(), os.pardir)) +GROUNDING_DINO_CONFIG_PATH = os.path.join( + HOME, "GroundingDINO/groundingdino/config/GroundingDINO_SwinT_OGC.py" +) +GROUNDING_DINO_CHECKPOINT_PATH = os.path.join( + HOME, "weights", "groundingdino_swint_ogc.pth" +) +SAM_CHECKPOINT_PATH_HIGH = os.path.join(HOME, "weights", "sam_vit_h_4b8939.pth") +SAM_CHECKPOINT_PATH_MID = os.path.join(HOME, "weights", "sam_vit_l_0b3195.pth") +SAM_CHECKPOINT_PATH_LOW = os.path.join(HOME, "weights", "sam_vit_b_01ec64.pth") + +# ====================== +# Logging +# ====================== +logger = logging.getLogger(__name__) +logger.setLevel(logging.DEBUG) + +c_handler = logging.StreamHandler() +c_handler.setLevel(logging.DEBUG) + +c_format = logging.Formatter("%(name)s - %(levelname)s - %(message)s") +c_handler.setFormatter(c_format) + +logger.addHandler(c_handler) + + +class CutoutCreator: + def __init__( + self, + classes: str, + grounding_dino_config_path: str, + grounding_dino_checkpoint_path: str, + encoder_version: str = "vit_b", + ): + self.classes = classes + self.grounding_dino_config_path = grounding_dino_config_path + self.grounding_dino_checkpoint_path = grounding_dino_checkpoint_path + self.encoder_version = encoder_version + self.HOME = os.path.abspath(os.path.join(os.getcwd(), os.pardir)) + self.s3 = Boto3Client() + self.dino = Dino( + classes=self.classes, + box_threshold=0.35, + text_threshold=0.25, + model_config_path=self.grounding_dino_config_path, + model_checkpoint_path=self.grounding_dino_checkpoint_path, + ) + self.mask_annotator = sv.MaskAnnotator() + + encoder_checkpoint_paths = { + "vit_b": SAM_CHECKPOINT_PATH_LOW, + "vit_l": SAM_CHECKPOINT_PATH_MID, + "vit_h": SAM_CHECKPOINT_PATH_HIGH, + } + + self.sam_checkpoint_path = encoder_checkpoint_paths.get(self.encoder_version) + self.segment = Segmenter( + sam_encoder_version=self.encoder_version, + sam_checkpoint_path=self.sam_checkpoint_path, + ) + + def create_annotated_image(self, image, image_name, detections: Dict[str, list]): + """Create a highlighted annotated image from an image and detections. + + Args: + image (File): Image to be used for creating the annotated image. + image_name (string): name of image + detections (Dict[str, list]): annotations for the image + """ + annotated_image = self.mask_annotator.annotate( + scene=image, detections=detections + ) + # Convert annotated image to bytes + img_bytes = io.BytesIO() + Image.fromarray(np.uint8(annotated_image)).save(img_bytes, format="PNG") + img_bytes.seek(0) + # Upload bytes to S3 + self.s3.upload_to_s3(img_bytes.read(), "cutouts", f"{image_name}_annotated.png") + + def create_cutouts(self, image_name): + """Create cutouts from an image and upload them to S3. + + Args: + image_name (string): name of image + """ + + # Define paths + data_path = os.path.join(HOME, "data") + cutouts_path = os.path.join(HOME, "cutouts") + + # Download image from s3 + image_path = self.s3.download_from_s3(data_path, image_name) + if image_path is None: + print(f"Failed to download image {image_name} from S3") + return + + # Check if image exists + if not os.path.exists(image_path): + print(f"Image {image_name} not found in folder {image_path}") + return + + # Create cutouts directory if it doesn't exist + os.makedirs(cutouts_path, exist_ok=True) + + # Read image + image = cv2.imread(image_path) + + # Predict and segment image + detections = self.dino.predict(image) + masks = self.segment.segment(image, detections.xyxy) + + # Apply each mask to the image + for i, mask in enumerate(masks): + # Ensure the mask is a boolean array + mask = mask.astype(bool) + + # Apply the mask to create a cutout + cutout = np.zeros_like(image) + cutout[mask] = image[mask] + + # Save the cutout + cutout_name = f"{image_name}_cutout_{i}.png" + cutout_path = os.path.join(cutouts_path, cutout_name) + cv2.imwrite(cutout_path, cutout) + + # Upload the cutout to S3 + with open(cutout_path, "rb") as f: + self.s3.upload_to_s3(f.read(), "cutouts", f"{image_name}/{cutout_name}") + + # Create annotated image + # self.create_annotated_image(image, f"{image_name}_{i}", detections) diff --git a/app/s3_handler.py b/app/cutout_handler/s3_handler.py similarity index 86% rename from app/s3_handler.py rename to app/cutout_handler/s3_handler.py index fc88a43..e2359f5 100644 --- a/app/s3_handler.py +++ b/app/cutout_handler/s3_handler.py @@ -1,11 +1,15 @@ -import os -import boto3 import logging -from botocore.exceptions import ClientError, BotoCoreError, NoCredentialsError +import os + +from app.common import s3_handler_stub + +s3_handler_stub.cls() class Boto3Client: def __init__(self): + import boto3 + self.s3 = boto3.client( "s3", aws_access_key_id=os.environ["AWS_ACCESS_KEY_ID"], @@ -14,6 +18,9 @@ def __init__(self): ) def download_from_s3(self, save_path, image_name): + import boto3 + from botocore.exceptions import ClientError + s3_client = boto3.client("s3") file_path = os.path.join(save_path, image_name) try: @@ -33,6 +40,8 @@ def download_from_s3(self, save_path, image_name): return file_path def upload_to_s3(self, file_body, folder, image_name): + from botocore.exceptions import BotoCoreError, NoCredentialsError + try: self.s3.put_object( Body=file_body, @@ -51,6 +60,8 @@ def upload_to_s3(self, file_body, folder, image_name): raise def generate_presigned_urls(self, folder, expiration=3600): + from botocore.exceptions import ClientError + try: response = self.s3.list_objects_v2( Bucket=os.environ["CUTOUT_BUCKET"], Prefix=folder @@ -75,11 +86,16 @@ def generate_presigned_urls(self, folder, expiration=3600): return urls def generate_presigned_url_with_metadata(self, folder, key, expiration=3600): + from botocore.exceptions import ClientError + try: # Generate presigned URL url = self.s3.generate_presigned_url( "get_object", - Params={"Bucket": os.environ["CUTOUT_BUCKET"], "Key": f"{folder}/{key}"}, + Params={ + "Bucket": os.environ["CUTOUT_BUCKET"], + "Key": f"{folder}/{key}", + }, ExpiresIn=expiration, ) # Get object metadata diff --git a/app/segment.py b/app/cutout_handler/segment.py similarity index 79% rename from app/segment.py rename to app/cutout_handler/segment.py index 33e7a88..3eda724 100644 --- a/app/segment.py +++ b/app/cutout_handler/segment.py @@ -1,14 +1,19 @@ -import numpy as np -import torch -from segment_anything import sam_model_registry, SamPredictor +from app.common import cutout_handler_stub + +cutout_handler_stub.cls() class Segmenter: + import numpy as np + def __init__( self, sam_encoder_version: str, sam_checkpoint_path: str, ): + import torch + from segment_anything import SamPredictor, sam_model_registry + self.device = torch.device("cuda" if torch.cuda.is_available() else "cpu") self.sam = sam_model_registry[sam_encoder_version]( checkpoint=sam_checkpoint_path @@ -16,6 +21,8 @@ def __init__( self.sam_predictor = SamPredictor(self.sam) def segment(self, image: np.ndarray, xyxy: np.ndarray) -> np.ndarray: + import numpy as np + self.sam_predictor.set_image(image) result_masks = [] for box in xyxy: diff --git a/app/cutout_handler/server.py b/app/cutout_handler/server.py new file mode 100644 index 0000000..a982c17 --- /dev/null +++ b/app/cutout_handler/server.py @@ -0,0 +1,191 @@ +import logging +import os +from typing import Dict, List + +from fastapi import Body, FastAPI +from fastapi.middleware.cors import CORSMiddleware +from modal import Secret, asgi_app +from starlette.requests import Request + +from app.common import cutout_handler_stub, local_packages + +from .grounded_cutouts import CutoutCreator + +# ====================== +# Constants +# ====================== +HOME = os.path.abspath(os.path.join(os.getcwd(), os.pardir)) +GROUNDING_DINO_CONFIG_PATH = os.path.join( + HOME, "GroundingDINO/groundingdino/config/GroundingDINO_SwinT_OGC.py" +) +GROUNDING_DINO_CHECKPOINT_PATH = os.path.join( + HOME, "weights", "groundingdino_swint_ogc.pth" +) +SAM_CHECKPOINT_PATH_HIGH = os.path.join(HOME, "weights", "sam_vit_h_4b8939.pth") +SAM_CHECKPOINT_PATH_MID = os.path.join(HOME, "weights", "sam_vit_l_0b3195.pth") +SAM_CHECKPOINT_PATH_LOW = os.path.join(HOME, "weights", "sam_vit_b_01ec64.pth") + + +# ====================== +# Logging +# ====================== +logger = logging.getLogger(__name__) +logger.setLevel(logging.DEBUG) + +c_handler = logging.StreamHandler() +c_handler.setLevel(logging.DEBUG) + +c_format = logging.Formatter("%(name)s - %(levelname)s - %(message)s") +c_handler.setFormatter(c_format) + +logger.addHandler(c_handler) + + +# ====================== +# FastAPI Setup +# ====================== +app = FastAPI() + +# stub = Stub(name="cutout_generator") + +origins = [ + "http://localhost:3000", # local development + "https://cutouts.noahrijkaard.com", # main website +] + +app.add_middleware( + CORSMiddleware, + allow_origins=origins, + allow_credentials=True, + allow_methods=["*"], + allow_headers=["*"], +) + + +@app.get("/warmup") +async def warmup(): + """Warmup the container. + + Returns: + _type_: return message + """ + # Spins up the container and loads the models, this will make it easier to create cutouts later + CutoutCreator( + classes=[], + grounding_dino_checkpoint_path=GROUNDING_DINO_CHECKPOINT_PATH, + grounding_dino_config_path=GROUNDING_DINO_CONFIG_PATH, + encoder_version="vit_b", + ) + + return "Warmed up!" + + +@app.post("/create-cutouts/{image_name}") +async def create_cutouts(image_name: str, request: Request): + """ + Create cutouts from an image and upload them to S3. + + Args: + image_name (str): Name of image to create cutouts from. + classes (List[str], optional): A list of classes for the AI to detect for. Defaults to Body(...). + + Returns: + _type_: _description_ + """ + from s3_handler import Boto3Client + + try: + # Log the start of the process + logger.info("Creating cutouts for image %s ", image_name) + + # Parse the request body as JSON + data = await request.json() + + # Get the classes and accuracy level from the JSON data + classes = data.get("classes", []) + accuracy_level = data.get("accuracy_level", "low") + logger.info("Classes: %s", classes) + logger.info("Accuracy level: %s", accuracy_level) + + # Select the SAM checkpoint path based on the accuracy level + accuracy_encoder_versions = { + "high": "vit_h", + "mid": "vit_l", + "low": "vit_b", + } + encoder_version = accuracy_encoder_versions.get(accuracy_level, "vit_b") + + # Initialize the S3 client and the CutoutCreator + s3 = Boto3Client() + cutout = CutoutCreator( + classes=classes, + grounding_dino_checkpoint_path=GROUNDING_DINO_CHECKPOINT_PATH, + grounding_dino_config_path=GROUNDING_DINO_CONFIG_PATH, + encoder_version=encoder_version, + ) + + # Create the cutouts + print(f"CREATING CUTOUTS FOR IMAGE {image_name}") + cutout.create_cutouts(image_name) + logger.info("Cutouts created for image %s", image_name) + + # Generate presigned URLs for the cutouts + urls = s3.generate_presigned_urls(f"cutouts/{image_name}") + logger.info("Presigned URLs generated for cutouts of image %s", image_name) + + # Return the URLs + return urls + except Exception as e: + # Log any errors that occur + logger.error( + "An error occurred while creating cutouts for image %s: %s", image_name, e + ) + raise + + +@app.post("/create-cutouts") +async def create_all_cutouts( + image_names: List[str] = Body(...), classes: List[str] = Body(...) +): + """Create cutouts from multiple images and upload them to S3. + + Args: + image_names (List[str]): List of image names to create cutouts from. + classes (List[str], optional): A list of classes for the AI to detect for. Defaults to Body(...). + + Returns: + Dict[str, List[str]]: A dictionary where the keys are the image names and the values are the lists of presigned URLs for the cutouts. + """ + from s3_handler import Boto3Client + + s3 = Boto3Client() + cutout = CutoutCreator( + classes=classes, + grounding_dino_checkpoint_path=GROUNDING_DINO_CHECKPOINT_PATH, + grounding_dino_config_path=GROUNDING_DINO_CONFIG_PATH, + encoder_version="vit_b", + ) + + result = {} + for image_name in image_names: + cutout.create_cutouts(image_name) + result[image_name] = s3.generate_presigned_urls(f"cutouts/{image_name}") + + return result + + +@cutout_handler_stub.function( + gpu="T4", + mounts=[local_packages], + secret=Secret.from_name("my-aws-secret"), + container_idle_timeout=300, + retries=1, +) +@asgi_app() +def cutout_app(): + """Create a FastAPI app for creating cutouts. + + Returns: + FastAPI: FastAPI app for creating cutouts. + """ + return app diff --git a/app/grounded_cutouts.py b/app/grounded_cutouts.py deleted file mode 100644 index 3a9bed5..0000000 --- a/app/grounded_cutouts.py +++ /dev/null @@ -1,364 +0,0 @@ -import os -import io -import logging -from typing import List, Dict -from fastapi import FastAPI, Body -from fastapi.middleware.cors import CORSMiddleware -from starlette.requests import Request -from modal import asgi_app, Secret, Stub, Mount, Image, method - -# ====================== -# Constants -# ====================== -HOME = os.path.abspath(os.path.join(os.getcwd(), os.pardir)) -GROUNDING_DINO_CONFIG_PATH = os.path.join( - HOME, "GroundingDINO/groundingdino/config/GroundingDINO_SwinT_OGC.py" -) -GROUNDING_DINO_CHECKPOINT_PATH = os.path.join( - HOME, "weights", "groundingdino_swint_ogc.pth" -) -SAM_CHECKPOINT_PATH_HIGH = os.path.join(HOME, "weights", "sam_vit_h_4b8939.pth") -SAM_CHECKPOINT_PATH_MID = os.path.join(HOME, "weights", "sam_vit_l_0b3195.pth") -SAM_CHECKPOINT_PATH_LOW = os.path.join(HOME, "weights", "sam_vit_b_01ec64.pth") - -# ====================== -# Logging -# ====================== -logger = logging.getLogger(__name__) -logger.setLevel(logging.DEBUG) - -c_handler = logging.StreamHandler() -c_handler.setLevel(logging.DEBUG) - -c_format = logging.Formatter("%(name)s - %(levelname)s - %(message)s") -c_handler.setFormatter(c_format) - -logger.addHandler(c_handler) - -# ====================== -# FastAPI Setup -# ====================== -app = FastAPI() - -stub = Stub(name="cutout_generator") - -origins = [ - "http://localhost:3000", # local development - "https://cutouts.noahrijkaard.com", # main website -] - -app.add_middleware( - CORSMiddleware, - allow_origins=origins, - allow_credentials=True, - allow_methods=["*"], - allow_headers=["*"], -) - -# ====================== -# Modal Image Setup -# ====================== -local_packages = Mount.from_local_python_packages("dino", "segment", "s3_handler") -cutout_generator_image = ( - Image.from_registry("nvcr.io/nvidia/pytorch:22.12-py3") - .pip_install( - "segment-anything", "opencv-python", "botocore", "boto3", "fastapi", "starlette" - ) - .run_commands( - "apt-get update", - "apt-get install -y git wget libgl1-mesa-glx libglib2.0-0", - "git clone https://github.com/IDEA-Research/GroundingDINO.git", - "pip install -q -e GroundingDINO/", - "mkdir -p /weights", - "mkdir -p /data", - "pip uninstall -y supervision", - "pip uninstall -y opencv-python", - "pip install opencv-python==4.8.0.74", - "pip install -q supervision==0.6.0", - "wget -q https://github.com/IDEA-Research/GroundingDINO/releases/download/v0.1.0-alpha/groundingdino_swint_ogc.pth -P weights/", - "wget -q https://dl.fbaipublicfiles.com/segment_anything/sam_vit_h_4b8939.pth -P weights/", - "wget -q https://dl.fbaipublicfiles.com/segment_anything/sam_vit_l_0b3195.pth -P weights/", - "wget -q https://dl.fbaipublicfiles.com/segment_anything/sam_vit_b_01ec64.pth -P weights/", - ) -) - - -@stub.cls( - image=cutout_generator_image, - gpu="T4", - mounts=[local_packages], - secret=Secret.from_name("my-aws-secret"), - container_idle_timeout=300, -) -class CutoutCreator: - import cv2 - import numpy as np - - def __init__( - self, - classes: str, - grounding_dino_config_path: str, - grounding_dino_checkpoint_path: str, - encoder_version: str = "vit_b", - ): - self.classes = classes - self.grounding_dino_config_path = grounding_dino_config_path - self.grounding_dino_checkpoint_path = grounding_dino_checkpoint_path - self.encoder_version = encoder_version - self.HOME = os.path.abspath(os.path.join(os.getcwd(), os.pardir)) - - def __enter__(self): - from s3_handler import Boto3Client - from dino import Dino - from segment import Segmenter - import supervision as sv - - self.dino = Dino( - classes=self.classes, - box_threshold=0.35, - text_threshold=0.25, - model_config_path=self.grounding_dino_config_path, - model_checkpoint_path=self.grounding_dino_checkpoint_path, - ) - self.s3 = Boto3Client() - self.mask_annotator = sv.MaskAnnotator() - - encoder_checkpoint_paths = { - "vit_b": SAM_CHECKPOINT_PATH_LOW, - "vit_l": SAM_CHECKPOINT_PATH_MID, - "vit_h": SAM_CHECKPOINT_PATH_HIGH, - } - - self.sam_checkpoint_path = encoder_checkpoint_paths.get(self.encoder_version) - self.segment = Segmenter( - sam_encoder_version=self.encoder_version, - sam_checkpoint_path=self.sam_checkpoint_path, - ) - - @method() - def create_annotated_image(self, image, image_name, detections: Dict[str, list]): - """Create a highlighted annotated image from an image and detections. - - Args: - image (File): Image to be used for creating the annotated image. - image_name (string): name of image - detections (Dict[str, list]): annotations for the image - """ - annotated_image = self.mask_annotator.annotate( - scene=image, detections=detections - ) - # Convert annotated image to bytes - img_bytes = io.BytesIO() - Image.fromarray(np.uint8(annotated_image)).save(img_bytes, format="PNG") - img_bytes.seek(0) - # Upload bytes to S3 - self.s3.upload_to_s3(img_bytes.read(), "cutouts", f"{image_name}_annotated.png") - - @method() - def create_cutouts(self, image_name): - """Create cutouts from an image and upload them to S3. - - Args: - image_name (string): name of image - """ - import cv2 - import numpy as np - - # Define paths - data_path = os.path.join(HOME, "data") - cutouts_path = os.path.join(HOME, "cutouts") - - # Download image from s3 - image_path = self.s3.download_from_s3(data_path, image_name) - if image_path is None: - print(f"Failed to download image {image_name} from S3") - return - - # Check if image exists - if not os.path.exists(image_path): - print(f"Image {image_name} not found in folder {image_path}") - return - - # Create cutouts directory if it doesn't exist - os.makedirs(cutouts_path, exist_ok=True) - - # Read image - image = cv2.imread(image_path) - - # Predict and segment image - detections = self.dino.predict(image) - masks = self.segment.segment(image, detections.xyxy) - - # Apply each mask to the image - for i, mask in enumerate(masks): - # Ensure the mask is a boolean array - mask = mask.astype(bool) - - # Apply the mask to create a cutout - cutout = np.zeros_like(image) - cutout[mask] = image[mask] - - # Save the cutout - cutout_name = f"{image_name}_cutout_{i}.png" - cutout_path = os.path.join(cutouts_path, cutout_name) - cv2.imwrite(cutout_path, cutout) - - # Upload the cutout to S3 - with open(cutout_path, "rb") as f: - self.s3.upload_to_s3(f.read(), "cutouts", f"{image_name}/{cutout_name}") - - # Create annotated image - # self.create_annotated_image(image, f"{image_name}_{i}", detections) - - -@stub.local_entrypoint() -def main( - classes: str, - grounding_dino_config_path: str, - grounding_dino_checkpoint_path: str, - encoder_version: str, -): - return CutoutCreator( - classes, - grounding_dino_config_path, - grounding_dino_checkpoint_path, - encoder_version, - ) - - -@app.get("/warmup") -async def warmup(): - """Warmup the container. - - Returns: - _type_: return message - """ - # Spins up the container and loads the models, this will make it easier to create cutouts later - CutoutCreator( - classes=[], - grounding_dino_checkpoint_path=GROUNDING_DINO_CHECKPOINT_PATH, - grounding_dino_config_path=GROUNDING_DINO_CONFIG_PATH, - encoder_version="vit_b", - ) - - return "Warmed up!" - - -@app.post("/create-cutouts/{image_name}") -async def create_cutouts(image_name: str, request: Request): - """ - Create cutouts from an image and upload them to S3. - - Args: - image_name (str): Name of image to create cutouts from. - classes (List[str], optional): A list of classes for the AI to detect for. Defaults to Body(...). - - Returns: - _type_: _description_ - """ - from s3_handler import Boto3Client - - try: - # Log the start of the process - logger.info("Creating cutouts for image %s ", image_name) - - # Parse the request body as JSON - data = await request.json() - - # Get the classes and accuracy level from the JSON data - classes = data.get("classes", []) - accuracy_level = data.get("accuracy_level", "low") - logger.info("Classes: %s", classes) - logger.info("Accuracy level: %s", accuracy_level) - - # Select the SAM checkpoint path based on the accuracy level - accuracy_encoder_versions = { - "high": "vit_h", - "mid": "vit_l", - "low": "vit_b", - } - encoder_version = accuracy_encoder_versions.get(accuracy_level, "vit_b") - - # Initialize the S3 client and the CutoutCreator - s3 = Boto3Client() - """ - Create cutouts for an image. - - :param classes: The classes for the cutout - :param grounding_dino_config_path: The path to the DINO configuration - :param grounding_dino_checkpoint_path: The path to the DINO checkpoint - :param encoder_version: The version of the encoder based on the accuracy level - """ - cutout = CutoutCreator( - classes=classes, - grounding_dino_checkpoint_path=GROUNDING_DINO_CHECKPOINT_PATH, - grounding_dino_config_path=GROUNDING_DINO_CONFIG_PATH, - encoder_version=encoder_version, - ) - - # Create the cutouts - print(f"CREATING CUTOUTS FOR IMAGE {image_name}") - cutout.create_cutouts.remote(image_name) - logger.info("Cutouts created for image %s", image_name) - - # Generate presigned URLs for the cutouts - urls = s3.generate_presigned_urls(f"cutouts/{image_name}") - logger.info("Presigned URLs generated for cutouts of image %s", image_name) - - # Return the URLs - return urls - except Exception as e: - # Log any errors that occur - logger.error( - "An error occurred while creating cutouts for image %s: %s", image_name, e - ) - raise - - -@app.post("/create-cutouts") -async def create_all_cutouts( - image_names: List[str] = Body(...), classes: List[str] = Body(...) -): - """Create cutouts from multiple images and upload them to S3. - - Args: - image_names (List[str]): List of image names to create cutouts from. - classes (List[str], optional): A list of classes for the AI to detect for. Defaults to Body(...). - - Returns: - Dict[str, List[str]]: A dictionary where the keys are the image names and the values are the lists of presigned URLs for the cutouts. - """ - from s3_handler import Boto3Client - - s3 = Boto3Client() - cutout = CutoutCreator( - classes=classes, - grounding_dino_checkpoint_path=GROUNDING_DINO_CHECKPOINT_PATH, - grounding_dino_config_path=GROUNDING_DINO_CONFIG_PATH, - encoder_version="vit_b", - ) - - result = {} - for image_name in image_names: - cutout.create_cutouts(image_name) - result[image_name] = s3.generate_presigned_urls(f"cutouts/{image_name}") - - return result - - -@stub.function( - image=cutout_generator_image, - gpu="T4", - mounts=[local_packages], - secret=Secret.from_name("my-aws-secret"), - container_idle_timeout=300, - retries=1, -) -@asgi_app() -def cutout_app(): - """Create a FastAPI app for creating cutouts. - - Returns: - FastAPI: FastAPI app for creating cutouts. - """ - return app diff --git a/app/s3_handler/app.py b/app/s3_handler/app.py index f4656ce..e3081d9 100644 --- a/app/s3_handler/app.py +++ b/app/s3_handler/app.py @@ -1,9 +1,9 @@ -from modal import Mount, Image, Secret, Stub, asgi_app -import os import logging -from fastapi import FastAPI, File, UploadFile, Body, HTTPException -from fastapi.middleware.cors import CORSMiddleware +import os +from fastapi import Body, FastAPI, File, HTTPException, UploadFile +from fastapi.middleware.cors import CORSMiddleware +from modal import Image, Mount, Secret, Stub, asgi_app stub = Stub(name="s3_handler") @@ -22,6 +22,7 @@ allow_headers=["*"], ) + # ================================================ # API Endpoints # ================================================ @@ -35,9 +36,9 @@ async def upload_image_to_s3(image: UploadFile = File(...)): Returns: str: Message indicating whether the upload was successful. """ - from s3_handler import Boto3Client from botocore.exceptions import BotoCoreError, NoCredentialsError - + from s3_handler import Boto3Client + s3_client = Boto3Client() try: s3_client.upload_to_s3(image.file, "images", image.filename) @@ -46,9 +47,12 @@ async def upload_image_to_s3(image: UploadFile = File(...)): except BotoCoreError as e: raise HTTPException(status_code=500, detail=str(e)) from e except Exception as e: - raise HTTPException(status_code=500, detail="An error occurred while uploading the image") from e + raise HTTPException( + status_code=500, detail="An error occurred while uploading the image" + ) from e return {"message": "Image uploaded successfully", "status_code": 200} + @app.get("/generate-presigned-urls/{image_name}") async def generate_presigned_urls(image_name: str): """Generate presigned urls for the cutouts of an image. @@ -65,7 +69,7 @@ async def generate_presigned_urls(image_name: str): return s3_client.generate_presigned_urls(f"cutouts/{image_name}") -@app.get('/get-image/{image_name}') +@app.get("/get-image/{image_name}") async def get_image(image_name: str): """Get an image from S3. @@ -75,8 +79,8 @@ async def get_image(image_name: str): Returns: FileResponse: FileResponse object containing the image. """ - from s3_handler import Boto3Client from fastapi.responses import FileResponse + from s3_handler import Boto3Client s3_client = Boto3Client() data = s3_client.generate_presigned_url_with_metadata("images", image_name) @@ -86,13 +90,27 @@ async def get_image(image_name: str): @stub.function( - image=Image.debian_slim().pip_install("boto3", "fastapi", "starlette", "uvicorn", "python-multipart", "pydantic", "requests", "httpx", "httpcore", "httpx[http2]", "httpx[http1]"), mounts=[Mount.from_local_python_packages("s3_handler")], secret=Secret.from_name("my-aws-secret") + image=Image.debian_slim().pip_install( + "boto3", + "fastapi", + "starlette", + "uvicorn", + "python-multipart", + "pydantic", + "requests", + "httpx", + "httpcore", + "httpx[http2]", + "httpx[http1]", + ), + mounts=[Mount.from_local_python_packages("s3_handler")], + secret=Secret.from_name("my-aws-secret"), ) - @asgi_app() def main(): return app + # ================================= # Modal s3 functions # ================================= @@ -104,7 +122,7 @@ def main(): # from botocore.exceptions import ClientError, BotoCoreError, NoCredentialsError # s3_client = Boto3Client() - + # try: # s3_client.upload_to_s3(file_body, folder, image_name) # logging.info(f"Successfully uploaded {image_name} to {folder}") diff --git a/app/s3_handler/s3_handler.py b/app/s3_handler/s3_handler.py index 2cd5bb5..e989e4f 100644 --- a/app/s3_handler/s3_handler.py +++ b/app/s3_handler/s3_handler.py @@ -1,7 +1,8 @@ +import logging import os + import boto3 -import logging -from botocore.exceptions import ClientError, BotoCoreError, NoCredentialsError +from botocore.exceptions import BotoCoreError, ClientError, NoCredentialsError class Boto3Client: @@ -76,7 +77,10 @@ def generate_presigned_url_with_metadata(self, folder, key, expiration=3600): # Generate presigned URL url = self.s3.generate_presigned_url( "get_object", - Params={"Bucket": os.environ["CUTOUT_BUCKET"], "Key": f"{folder}/{key}"}, + Params={ + "Bucket": os.environ["CUTOUT_BUCKET"], + "Key": f"{folder}/{key}", + }, ExpiresIn=expiration, ) # Get object metadata diff --git a/legacy_code/cutouts.py b/legacy_code/cutouts.py index 054da01..1e66844 100644 --- a/legacy_code/cutouts.py +++ b/legacy_code/cutouts.py @@ -1,16 +1,18 @@ -from typing import Dict -import os import io +import os +from typing import Dict + import cv2 import numpy as np -from s3_handler import Boto3Client +import supervision as sv from dino import Dino -from segment import Segmenter from PIL import Image -import supervision as sv +from s3_handler import Boto3Client +from segment import Segmenter HOME = os.path.abspath(os.path.join(os.getcwd(), os.pardir)) + class CutoutCreator: """ A class for creating cutouts from an image and uploading them to S3. @@ -20,7 +22,13 @@ class CutoutCreator: s3: A Boto3Client object for uploading to S3. mask_annotator: A MaskAnnotator object for annotating images with masks. """ - def __init__(self, classes: str, grounding_dino_config_path: str, grounding_dino_checkpoint_path: str): + + def __init__( + self, + classes: str, + grounding_dino_config_path: str, + grounding_dino_checkpoint_path: str, + ): self.dino = Dino( classes=classes, box_threshold=0.35, From 4e5226233b85443590a67979b7cac0858865c9b7 Mon Sep 17 00:00:00 2001 From: OriginalByteMe Date: Thu, 21 Dec 2023 21:21:00 +0800 Subject: [PATCH 2/4] Clean up imports --- app/cutout_handler/dino.py | 8 +++++--- app/cutout_handler/s3_handler.py | 24 ++++++------------------ app/cutout_handler/segment.py | 14 ++++++++------ 3 files changed, 19 insertions(+), 27 deletions(-) diff --git a/app/cutout_handler/dino.py b/app/cutout_handler/dino.py index fb5bbb1..63aa479 100644 --- a/app/cutout_handler/dino.py +++ b/app/cutout_handler/dino.py @@ -1,6 +1,10 @@ from typing import List -from app.common import cutout_handler_stub +from app.common import cutout_handler_stub, cutout_generator_image + +with cutout_generator_image.imports(): + import torch + from groundingdino.util.inference import Model cutout_handler_stub.cls() @@ -16,8 +20,6 @@ def __init__( model_config_path, model_checkpoint_path, ): - import torch - from groundingdino.util.inference import Model self.classes = classes self.box_threshold = box_threshold diff --git a/app/cutout_handler/s3_handler.py b/app/cutout_handler/s3_handler.py index e2359f5..bb239f9 100644 --- a/app/cutout_handler/s3_handler.py +++ b/app/cutout_handler/s3_handler.py @@ -1,14 +1,14 @@ -import logging import os +import logging +from app.common import s3_handler_stub, s3_handler_image -from app.common import s3_handler_stub +with s3_handler_image.imports(): + import boto3 + from botocore.exceptions import ClientError, BotoCoreError, NoCredentialsError s3_handler_stub.cls() - - class Boto3Client: def __init__(self): - import boto3 self.s3 = boto3.client( "s3", @@ -18,9 +18,6 @@ def __init__(self): ) def download_from_s3(self, save_path, image_name): - import boto3 - from botocore.exceptions import ClientError - s3_client = boto3.client("s3") file_path = os.path.join(save_path, image_name) try: @@ -40,8 +37,6 @@ def download_from_s3(self, save_path, image_name): return file_path def upload_to_s3(self, file_body, folder, image_name): - from botocore.exceptions import BotoCoreError, NoCredentialsError - try: self.s3.put_object( Body=file_body, @@ -60,8 +55,6 @@ def upload_to_s3(self, file_body, folder, image_name): raise def generate_presigned_urls(self, folder, expiration=3600): - from botocore.exceptions import ClientError - try: response = self.s3.list_objects_v2( Bucket=os.environ["CUTOUT_BUCKET"], Prefix=folder @@ -86,16 +79,11 @@ def generate_presigned_urls(self, folder, expiration=3600): return urls def generate_presigned_url_with_metadata(self, folder, key, expiration=3600): - from botocore.exceptions import ClientError - try: # Generate presigned URL url = self.s3.generate_presigned_url( "get_object", - Params={ - "Bucket": os.environ["CUTOUT_BUCKET"], - "Key": f"{folder}/{key}", - }, + Params={"Bucket": os.environ["CUTOUT_BUCKET"], "Key": f"{folder}/{key}"}, ExpiresIn=expiration, ) # Get object metadata diff --git a/app/cutout_handler/segment.py b/app/cutout_handler/segment.py index 3eda724..802ab5a 100644 --- a/app/cutout_handler/segment.py +++ b/app/cutout_handler/segment.py @@ -1,4 +1,11 @@ -from app.common import cutout_handler_stub +from app.common import cutout_handler_stub, cutout_generator_image + + +with cutout_generator_image.imports(): + import torch + from segment_anything import SamPredictor, sam_model_registry + import numpy as np + cutout_handler_stub.cls() @@ -11,9 +18,6 @@ def __init__( sam_encoder_version: str, sam_checkpoint_path: str, ): - import torch - from segment_anything import SamPredictor, sam_model_registry - self.device = torch.device("cuda" if torch.cuda.is_available() else "cpu") self.sam = sam_model_registry[sam_encoder_version]( checkpoint=sam_checkpoint_path @@ -21,8 +25,6 @@ def __init__( self.sam_predictor = SamPredictor(self.sam) def segment(self, image: np.ndarray, xyxy: np.ndarray) -> np.ndarray: - import numpy as np - self.sam_predictor.set_image(image) result_masks = [] for box in xyxy: From e920dc74be33ddfcd41cc9e76aff2ae65e606949 Mon Sep 17 00:00:00 2001 From: OriginalByteMe Date: Thu, 21 Dec 2023 21:29:45 +0800 Subject: [PATCH 3/4] Fix logging statements in cutout_handler --- app/cutout_handler/grounded_cutouts.py | 4 ++-- app/cutout_handler/server.py | 4 ++-- 2 files changed, 4 insertions(+), 4 deletions(-) diff --git a/app/cutout_handler/grounded_cutouts.py b/app/cutout_handler/grounded_cutouts.py index dd1be39..b9286f4 100644 --- a/app/cutout_handler/grounded_cutouts.py +++ b/app/cutout_handler/grounded_cutouts.py @@ -111,12 +111,12 @@ def create_cutouts(self, image_name): # Download image from s3 image_path = self.s3.download_from_s3(data_path, image_name) if image_path is None: - print(f"Failed to download image {image_name} from S3") + logger.error(f"Failed to download image {image_name} from S3") return # Check if image exists if not os.path.exists(image_path): - print(f"Image {image_name} not found in folder {image_path}") + logger.error(f"Image {image_name} not found in folder {image_path}") return # Create cutouts directory if it doesn't exist diff --git a/app/cutout_handler/server.py b/app/cutout_handler/server.py index a982c17..66a605e 100644 --- a/app/cutout_handler/server.py +++ b/app/cutout_handler/server.py @@ -7,7 +7,7 @@ from modal import Secret, asgi_app from starlette.requests import Request -from app.common import cutout_handler_stub, local_packages +from app.common import cutout_handler_stub, local_packages, cutout_generator_image from .grounded_cutouts import CutoutCreator @@ -125,7 +125,7 @@ async def create_cutouts(image_name: str, request: Request): ) # Create the cutouts - print(f"CREATING CUTOUTS FOR IMAGE {image_name}") + logger.info(f"CREATING CUTOUTS FOR IMAGE {image_name}") cutout.create_cutouts(image_name) logger.info("Cutouts created for image %s", image_name) From 19da2948c8ad296e585d2193b7773135ce99dcc4 Mon Sep 17 00:00:00 2001 From: OriginalByteMe Date: Thu, 21 Dec 2023 21:29:59 +0800 Subject: [PATCH 4/4] Hopefully fix deployment --- .github/workflows/ci-cd.yml | 4 ++-- 1 file changed, 2 insertions(+), 2 deletions(-) diff --git a/.github/workflows/ci-cd.yml b/.github/workflows/ci-cd.yml index 3b4673a..816e1f6 100644 --- a/.github/workflows/ci-cd.yml +++ b/.github/workflows/ci-cd.yml @@ -39,9 +39,9 @@ jobs: - name: Deploy cutout generator job run: | cd app - modal deploy --env=${{ steps.vars.outputs.environment }} grounded_cutouts.py + modal deploy --env=${{ steps.vars.outputs.environment }} app.cutout_handler::cutout_handler_stub - name: Deploy s3_handler job run: | cd app/s3_handler - modal deploy --env=${{ steps.vars.outputs.environment }} app.py + modal deploy --env=${{ steps.vars.outputs.environment }} app.s3_handler.app \ No newline at end of file