From 28fc3c35dded32985f630e0e21dab16d72fda000 Mon Sep 17 00:00:00 2001 From: OriginalByteMe Date: Sat, 18 Nov 2023 12:57:05 +0800 Subject: [PATCH 01/23] Update cutout creation and annotation process --- app/grounded_cutouts.py | 180 ++++++++++++++++++++++++++++++++++++---- 1 file changed, 164 insertions(+), 16 deletions(-) diff --git a/app/grounded_cutouts.py b/app/grounded_cutouts.py index 562425d..f58d545 100644 --- a/app/grounded_cutouts.py +++ b/app/grounded_cutouts.py @@ -1,15 +1,15 @@ import os -from modal import asgi_app, Secret, Stub, Mount, Image -from fastapi import FastAPI, File, UploadFile, Body, HTTPException +from modal import asgi_app, Secret, Stub, Mount, Image, method +from fastapi import FastAPI, Body from fastapi.middleware.cors import CORSMiddleware from typing import List import logging -import json from starlette.requests import Request +from typing import Dict -#====================== +# ====================== # Logging -#====================== +# ====================== # Create a custom logger logger = logging.getLogger(__name__) logger.setLevel(logging.DEBUG) @@ -19,7 +19,7 @@ c_handler.setLevel(logging.DEBUG) # Create formatters and add it to handlers -c_format = logging.Formatter('%(name)s - %(levelname)s - %(message)s') +c_format = logging.Formatter("%(name)s - %(levelname)s - %(message)s") c_handler.setFormatter(c_format) # Add handlers to the logger @@ -43,12 +43,12 @@ allow_headers=["*"], ) -local_packages = Mount.from_local_python_packages( - "cutout", "dino", "segment", "s3_handler" -) +local_packages = Mount.from_local_python_packages("dino", "segment", "s3_handler") cutout_generator_image = ( Image.from_registry("nvcr.io/nvidia/pytorch:22.12-py3") - .pip_install("segment-anything", "opencv-python","botocore", "boto3", "fastapi", "starlette") + .pip_install( + "segment-anything", "opencv-python", "botocore", "boto3", "fastapi", "starlette" + ) .run_commands( "apt-get update", "apt-get install -y git wget libgl1-mesa-glx libglib2.0-0", @@ -80,6 +80,149 @@ SAM_CHECKPOINT_PATH = os.path.join(HOME, "weights", "sam_vit_h_4b8939.pth") +@stub.cls( + image=cutout_generator_image, + gpu="T4", + mounts=[local_packages], + secret=Secret.from_name("my-aws-secret"), + container_idle_timeout=300, +) +class CutoutCreator: + import cv2 + import numpy as np + import io + + def __init__( + self, + classes: str, + grounding_dino_config_path: str, + grounding_dino_checkpoint_path: str, + sam_checkpoint_path: str, + ): + self.classes = classes + self.grounding_dino_config_path = grounding_dino_config_path + self.grounding_dino_checkpoint_path = grounding_dino_checkpoint_path + self.sam_checkpoint_path = sam_checkpoint_path + self.HOME = os.path.abspath(os.path.join(os.getcwd(), os.pardir)) + + def __enter__(self): + from s3_handler import Boto3Client + from dino import Dino + from segment import Segmenter + import supervision as sv + + self.dino = Dino( + classes=self.classes, + box_threshold=0.35, + text_threshold=0.25, + model_config_path=self.grounding_dino_config_path, + model_checkpoint_path=self.grounding_dino_checkpoint_path, + ) + self.s3 = Boto3Client() + self.mask_annotator = sv.MaskAnnotator() + self.segment = Segmenter( + sam_encoder_version="vit_h", sam_checkpoint_path=self.sam_checkpoint_path + ) + + @method() + def create_annotated_image(self, image, image_name, detections: Dict[str, list]): + """Create a highlighted annotated image from an image and detections. + + Args: + image (File): Image to be used for creating the annotated image. + image_name (string): name of image + detections (Dict[str, list]): annotations for the image + """ + annotated_image = self.mask_annotator.annotate( + scene=image, detections=detections + ) + # Convert annotated image to bytes + img_bytes = io.BytesIO() + Image.fromarray(np.uint8(annotated_image)).save(img_bytes, format="PNG") + img_bytes.seek(0) + # Upload bytes to S3 + self.s3.upload_to_s3(img_bytes.read(), "cutouts", f"{image_name}_annotated.png") + + @method() + def create_cutouts(self, image_name, sam_checkpoint_path): + import cv2 + import numpy as np + import io + + """Create cutouts from an image and upload them to S3. + + Args: + image_name (string): name of image + sam_checkpoint_path (string): path to sam checkpoint + """ + # Download image from s3 + image_path = self.s3.download_from_s3( + os.path.join(self.HOME, "data"), image_name + ) + if not os.path.exists(os.path.join(self.HOME, "cutouts")): + os.mkdir(os.path.join(self.HOME, "cutouts")) + image = cv2.imread(image_path) + # segment = Segmenter( + # sam_encoder_version="vit_h", sam_checkpoint_path=sam_checkpoint_path + # ) + detections = self.dino.predict(image) + + masks = self.segment.segment(image, detections.xyxy) + # Load the image + # image_path = os.path.join(self.image_folder, image_name) + # for item in os.listdir(self.image_folder): + # print("Item: ",item) + if not os.path.exists(image_path): + print(f"Image {image_name} not found in folder {image_path}") + return + + image = cv2.imread(image_path) + + # Apply each mask to the image + for i, mask in enumerate(masks): + # Ensure the mask is a boolean array + mask = mask.astype(bool) + + # Apply the mask to create a cutout + cutout = np.zeros_like(image) + cutout[mask] = image[mask] + + # Save the cutout + cutout_name = f"{image_name}_cutout_{i}.png" + cutout_path = os.path.join(self.HOME, "cutouts", cutout_name) + cv2.imwrite(cutout_path, cutout) + + # Upload the cutout to S3 + with open(cutout_path, "rb") as f: + self.s3.upload_to_s3(f.read(), "cutouts", f"{image_name}/{cutout_name}") + + # Create annotated image + # self.create_annotated_image(image, f"{image_name}_{i}", detections) + + +@stub.local_entrypoint() +def main( + classes: str, + grounding_dino_config_path: str, + grounding_dino_checkpoint_path: str, + sam_checkpoint_path: str, +): + return CutoutCreator( + classes, + grounding_dino_config_path, + grounding_dino_checkpoint_path, + sam_checkpoint_path, + ) + +@app.get("/warmup") +async def warmup(): + """Warmup the container. + + Returns: + _type_: return message + """ + return "Warmed up!" + @app.post("/create-cutouts/{image_name}") async def create_cutouts(image_name: str, request: Request): """Create cutouts from an image and upload them to S3. @@ -95,8 +238,7 @@ async def create_cutouts(image_name: str, request: Request): data = await request.json() # Get the classes from the JSON data - classes = data.get('classes', []) - from cutout import CutoutCreator + classes = data.get("classes", []) from s3_handler import Boto3Client try: @@ -107,21 +249,25 @@ async def create_cutouts(image_name: str, request: Request): classes=classes, grounding_dino_checkpoint_path=GROUNDING_DINO_CHECKPOINT_PATH, grounding_dino_config_path=GROUNDING_DINO_CONFIG_PATH, + sam_checkpoint_path=SAM_CHECKPOINT_PATH, ) print(f"CREATING CUTOUTS FOR IMAGE {image_name}") - cutout.create_cutouts(image_name, SAM_CHECKPOINT_PATH) + cutout.create_cutouts.remote(image_name, SAM_CHECKPOINT_PATH) logger.info(f"Cutouts created for image {image_name}") urls = s3.generate_presigned_urls(f"cutouts/{image_name}") logger.info(f"Presigned URLs generated for cutouts of image {image_name}") return urls except Exception as e: - logger.error(f"An error occurred while creating cutouts for image {image_name}: {e}") + logger.error( + f"An error occurred while creating cutouts for image {image_name}: {e}" + ) raise - return urls @app.post("/create-cutouts") -async def create_all_cutouts(image_names: List[str] = Body(...), classes: List[str] = Body(...)): +async def create_all_cutouts( + image_names: List[str] = Body(...), classes: List[str] = Body(...) +): """Create cutouts from multiple images and upload them to S3. Args: @@ -148,11 +294,13 @@ async def create_all_cutouts(image_names: List[str] = Body(...), classes: List[s return result + @stub.function( image=cutout_generator_image, gpu="T4", mounts=[local_packages], secret=Secret.from_name("my-aws-secret"), + allow_concurrent_inputs=4 ) @asgi_app() def cutout_app(): From becc0ad3c905b0f85056c7704876daf97d191f58 Mon Sep 17 00:00:00 2001 From: OriginalByteMe Date: Sat, 18 Nov 2023 18:44:55 +0800 Subject: [PATCH 02/23] Add print statements to s3_handler.py --- app/s3_handler.py | 3 +++ 1 file changed, 3 insertions(+) diff --git a/app/s3_handler.py b/app/s3_handler.py index 2cd5bb5..fc88a43 100644 --- a/app/s3_handler.py +++ b/app/s3_handler.py @@ -20,6 +20,9 @@ def download_from_s3(self, save_path, image_name): s3_client.download_file( os.environ["CUTOUT_BUCKET"], f"images/{image_name}", file_path ) + print(f"Successfully downloaded {image_name} to {file_path}") + print("Directory contents:") + print(os.listdir(save_path)) except ClientError as e: print("BOTO error: ", e) print( From 6fbe6a3fcbea534f022c8d95c11aab468b06d9d2 Mon Sep 17 00:00:00 2001 From: OriginalByteMe Date: Sat, 18 Nov 2023 18:45:21 +0800 Subject: [PATCH 03/23] Refactor create_cutouts method --- app/grounded_cutouts.py | 34 +++++++++++++++++----------------- 1 file changed, 17 insertions(+), 17 deletions(-) diff --git a/app/grounded_cutouts.py b/app/grounded_cutouts.py index f58d545..5fe239a 100644 --- a/app/grounded_cutouts.py +++ b/app/grounded_cutouts.py @@ -6,6 +6,7 @@ import logging from starlette.requests import Request from typing import Dict +import io # ====================== # Logging @@ -144,7 +145,7 @@ def create_annotated_image(self, image, image_name, detections: Dict[str, list]) self.s3.upload_to_s3(img_bytes.read(), "cutouts", f"{image_name}_annotated.png") @method() - def create_cutouts(self, image_name, sam_checkpoint_path): + def create_cutouts(self, image_name): import cv2 import numpy as np import io @@ -157,27 +158,25 @@ def create_cutouts(self, image_name, sam_checkpoint_path): """ # Download image from s3 image_path = self.s3.download_from_s3( - os.path.join(self.HOME, "data"), image_name + os.path.join(HOME, "data"), image_name ) - if not os.path.exists(os.path.join(self.HOME, "cutouts")): - os.mkdir(os.path.join(self.HOME, "cutouts")) - image = cv2.imread(image_path) - # segment = Segmenter( - # sam_encoder_version="vit_h", sam_checkpoint_path=sam_checkpoint_path - # ) - detections = self.dino.predict(image) - - masks = self.segment.segment(image, detections.xyxy) - # Load the image - # image_path = os.path.join(self.image_folder, image_name) - # for item in os.listdir(self.image_folder): - # print("Item: ",item) + if image_path is None: + print(f"Failed to download image {image_name} from S3") + return if not os.path.exists(image_path): print(f"Image {image_name} not found in folder {image_path}") return + + if not os.path.exists(os.path.join(HOME, "cutouts")): + os.mkdir(os.path.join(HOME, "cutouts")) image = cv2.imread(image_path) + + detections = self.dino.predict(image) + + masks = self.segment.segment(image, detections.xyxy) + # Apply each mask to the image for i, mask in enumerate(masks): # Ensure the mask is a boolean array @@ -252,7 +251,7 @@ async def create_cutouts(image_name: str, request: Request): sam_checkpoint_path=SAM_CHECKPOINT_PATH, ) print(f"CREATING CUTOUTS FOR IMAGE {image_name}") - cutout.create_cutouts.remote(image_name, SAM_CHECKPOINT_PATH) + cutout.create_cutouts.remote(image_name) logger.info(f"Cutouts created for image {image_name}") urls = s3.generate_presigned_urls(f"cutouts/{image_name}") logger.info(f"Presigned URLs generated for cutouts of image {image_name}") @@ -300,7 +299,8 @@ async def create_all_cutouts( gpu="T4", mounts=[local_packages], secret=Secret.from_name("my-aws-secret"), - allow_concurrent_inputs=4 + container_idle_timeout=300, + keep_warm=1 ) @asgi_app() def cutout_app(): From 781f5710c36f411ef0cda1b8894f52340824fe2b Mon Sep 17 00:00:00 2001 From: OriginalByteMe Date: Sat, 18 Nov 2023 18:47:35 +0800 Subject: [PATCH 04/23] Add environment variable for CI/CD workflow --- .github/workflows/ci-cd.yml | 13 +++++++++++-- 1 file changed, 11 insertions(+), 2 deletions(-) diff --git a/.github/workflows/ci-cd.yml b/.github/workflows/ci-cd.yml index c713da1..09cd70b 100644 --- a/.github/workflows/ci-cd.yml +++ b/.github/workflows/ci-cd.yml @@ -4,6 +4,7 @@ on: push: branches: - main + - develop jobs: deploy: @@ -27,12 +28,20 @@ jobs: python -m pip install --upgrade pip pip install modal + - name: Set environment + id: vars + run: | + if [[ "${{ github.ref }}" == "refs/heads/main" ]]; then + echo "::set-output name=environment::prod" + else + echo "::set-output name=environment::dev" + fi - name: Deploy cutout generator job run: | cd app - modal deploy grounded_cutouts.py + modal deploy grounded_cutouts.py --env ${{ steps.vars.outputs.environment }} - name: Deploy s3_handler job run: | cd app/s3_handler - modal deploy app.py + modal deploy app.py --env ${{ steps.vars.outputs.environment }} \ No newline at end of file From fd4c4e45bec313eea2ac7f0b8067dd015185ffdd Mon Sep 17 00:00:00 2001 From: OriginalByteMe Date: Sat, 18 Nov 2023 18:55:46 +0800 Subject: [PATCH 05/23] Cleanup functions a bit --- app/grounded_cutouts.py | 108 +++++++++++++++++++++++----------------- 1 file changed, 61 insertions(+), 47 deletions(-) diff --git a/app/grounded_cutouts.py b/app/grounded_cutouts.py index 5fe239a..64670b1 100644 --- a/app/grounded_cutouts.py +++ b/app/grounded_cutouts.py @@ -1,38 +1,47 @@ import os -from modal import asgi_app, Secret, Stub, Mount, Image, method +import io +import logging +from typing import List, Dict from fastapi import FastAPI, Body from fastapi.middleware.cors import CORSMiddleware -from typing import List -import logging from starlette.requests import Request -from typing import Dict -import io +from modal import asgi_app, Secret, Stub, Mount, Image, method + +# ====================== +# Constants +# ====================== +HOME = os.path.abspath(os.path.join(os.getcwd(), os.pardir)) +GROUNDING_DINO_CONFIG_PATH = os.path.join( + HOME, "GroundingDINO/groundingdino/config/GroundingDINO_SwinT_OGC.py" +) +GROUNDING_DINO_CHECKPOINT_PATH = os.path.join( + HOME, "weights", "groundingdino_swint_ogc.pth" +) +SAM_CHECKPOINT_PATH = os.path.join(HOME, "weights", "sam_vit_h_4b8939.pth") # ====================== # Logging # ====================== -# Create a custom logger logger = logging.getLogger(__name__) logger.setLevel(logging.DEBUG) -# Create handlers c_handler = logging.StreamHandler() c_handler.setLevel(logging.DEBUG) -# Create formatters and add it to handlers c_format = logging.Formatter("%(name)s - %(levelname)s - %(message)s") c_handler.setFormatter(c_format) -# Add handlers to the logger logger.addHandler(c_handler) - +# ====================== +# FastAPI Setup +# ====================== app = FastAPI() stub = Stub(name="cutout_generator") origins = [ - "http://localhost:3000", # localdevelopment + "http://localhost:3000", # local development "https://cutouts.noahrijkaard.com", # main website ] @@ -44,6 +53,9 @@ allow_headers=["*"], ) +# ====================== +# Modal Image Setup +# ====================== local_packages = Mount.from_local_python_packages("dino", "segment", "s3_handler") cutout_generator_image = ( Image.from_registry("nvcr.io/nvidia/pytorch:22.12-py3") @@ -71,16 +83,6 @@ ) ) -HOME = os.path.abspath(os.path.join(os.getcwd(), os.pardir)) -GROUNDING_DINO_CONFIG_PATH = os.path.join( - HOME, "GroundingDINO/groundingdino/config/GroundingDINO_SwinT_OGC.py" -) -GROUNDING_DINO_CHECKPOINT_PATH = os.path.join( - HOME, "weights", "groundingdino_swint_ogc.pth" -) -SAM_CHECKPOINT_PATH = os.path.join(HOME, "weights", "sam_vit_h_4b8939.pth") - - @stub.cls( image=cutout_generator_image, gpu="T4", @@ -146,35 +148,37 @@ def create_annotated_image(self, image, image_name, detections: Dict[str, list]) @method() def create_cutouts(self, image_name): - import cv2 - import numpy as np - import io - """Create cutouts from an image and upload them to S3. Args: image_name (string): name of image - sam_checkpoint_path (string): path to sam checkpoint """ + import cv2 + import numpy as np + + # Define paths + data_path = os.path.join(HOME, "data") + cutouts_path = os.path.join(HOME, "cutouts") + # Download image from s3 - image_path = self.s3.download_from_s3( - os.path.join(HOME, "data"), image_name - ) + image_path = self.s3.download_from_s3(data_path, image_name) if image_path is None: print(f"Failed to download image {image_name} from S3") return + + # Check if image exists if not os.path.exists(image_path): print(f"Image {image_name} not found in folder {image_path}") return - - if not os.path.exists(os.path.join(HOME, "cutouts")): - os.mkdir(os.path.join(HOME, "cutouts")) - image = cv2.imread(image_path) + # Create cutouts directory if it doesn't exist + os.makedirs(cutouts_path, exist_ok=True) + # Read image + image = cv2.imread(image_path) + # Predict and segment image detections = self.dino.predict(image) - masks = self.segment.segment(image, detections.xyxy) # Apply each mask to the image @@ -188,7 +192,7 @@ def create_cutouts(self, image_name): # Save the cutout cutout_name = f"{image_name}_cutout_{i}.png" - cutout_path = os.path.join(self.HOME, "cutouts", cutout_name) + cutout_path = os.path.join(cutouts_path, cutout_name) cv2.imwrite(cutout_path, cutout) # Upload the cutout to S3 @@ -224,7 +228,8 @@ async def warmup(): @app.post("/create-cutouts/{image_name}") async def create_cutouts(image_name: str, request: Request): - """Create cutouts from an image and upload them to S3. + """ + Create cutouts from an image and upload them to S3. Args: image_name (str): Name of image to create cutouts from. @@ -233,16 +238,18 @@ async def create_cutouts(image_name: str, request: Request): Returns: _type_: _description_ """ - # Parse the request body as JSON - data = await request.json() + try: + # Log the start of the process + logger.info("Creating cutouts for image %s ", image_name) - # Get the classes from the JSON data - classes = data.get("classes", []) - from s3_handler import Boto3Client + # Parse the request body as JSON + data = await request.json() - try: - logger.info(f"Creating cutouts for image {image_name}") - logger.info(f"Classes: {classes}") + # Get the classes from the JSON data + classes = data.get("classes", []) + logger.info("Classes: %s", classes) + + # Initialize the S3 client and the CutoutCreator s3 = Boto3Client() cutout = CutoutCreator( classes=classes, @@ -250,15 +257,22 @@ async def create_cutouts(image_name: str, request: Request): grounding_dino_config_path=GROUNDING_DINO_CONFIG_PATH, sam_checkpoint_path=SAM_CHECKPOINT_PATH, ) + + # Create the cutouts print(f"CREATING CUTOUTS FOR IMAGE {image_name}") cutout.create_cutouts.remote(image_name) - logger.info(f"Cutouts created for image {image_name}") + logger.info("Cutouts created for image %s", image_name) + + # Generate presigned URLs for the cutouts urls = s3.generate_presigned_urls(f"cutouts/{image_name}") - logger.info(f"Presigned URLs generated for cutouts of image {image_name}") + logger.info("Presigned URLs generated for cutouts of image %s", image_name) + + # Return the URLs return urls except Exception as e: + # Log any errors that occur logger.error( - f"An error occurred while creating cutouts for image {image_name}: {e}" + "An error occurred while creating cutouts for image %s: %s", image_name, e ) raise From 4ad6323042ca393d1ba7c4dd85de85beaef03361 Mon Sep 17 00:00:00 2001 From: OriginalByteMe Date: Sat, 18 Nov 2023 19:02:44 +0800 Subject: [PATCH 06/23] Add .pylintrc file to disable pylint warnings --- .pylintrc | 7 +++++++ 1 file changed, 7 insertions(+) create mode 100644 .pylintrc diff --git a/.pylintrc b/.pylintrc new file mode 100644 index 0000000..7ab83fb --- /dev/null +++ b/.pylintrc @@ -0,0 +1,7 @@ +[MESSAGES CONTROL] +disable= + missing-docstring, + import-outside-toplevel, + line-too-long, + attribute-defined-outside-init, + no-member \ No newline at end of file From 5bd62ee5196712787b53ba1a4accf66aa40b696d Mon Sep 17 00:00:00 2001 From: OriginalByteMe Date: Sat, 18 Nov 2023 19:03:02 +0800 Subject: [PATCH 07/23] move cutouts file to legacy code --- app/cutout.py => legacy_code/cutouts.py | 0 1 file changed, 0 insertions(+), 0 deletions(-) rename app/cutout.py => legacy_code/cutouts.py (100%) diff --git a/app/cutout.py b/legacy_code/cutouts.py similarity index 100% rename from app/cutout.py rename to legacy_code/cutouts.py From b4a1469a00d07d007506efb99c2fef9c2b073e9d Mon Sep 17 00:00:00 2001 From: OriginalByteMe Date: Sat, 18 Nov 2023 19:03:09 +0800 Subject: [PATCH 08/23] Refactor cutout creation and add SAM checkpoint path --- app/grounded_cutouts.py | 5 ++--- 1 file changed, 2 insertions(+), 3 deletions(-) diff --git a/app/grounded_cutouts.py b/app/grounded_cutouts.py index 64670b1..cf0776d 100644 --- a/app/grounded_cutouts.py +++ b/app/grounded_cutouts.py @@ -93,7 +93,6 @@ class CutoutCreator: import cv2 import numpy as np - import io def __init__( self, @@ -290,7 +289,6 @@ async def create_all_cutouts( Returns: Dict[str, List[str]]: A dictionary where the keys are the image names and the values are the lists of presigned URLs for the cutouts. """ - from cutout import CutoutCreator from s3_handler import Boto3Client s3 = Boto3Client() @@ -298,11 +296,12 @@ async def create_all_cutouts( classes=classes, grounding_dino_checkpoint_path=GROUNDING_DINO_CHECKPOINT_PATH, grounding_dino_config_path=GROUNDING_DINO_CONFIG_PATH, + sam_checkpoint_path = SAM_CHECKPOINT_PATH, ) result = {} for image_name in image_names: - cutout.create_cutouts(image_name, SAM_CHECKPOINT_PATH) + cutout.create_cutouts(image_name) result[image_name] = s3.generate_presigned_urls(f"cutouts/{image_name}") return result From 8adeab4923fa5c84eb2eb4eb8eeb69b972f00ebb Mon Sep 17 00:00:00 2001 From: OriginalByteMe Date: Mon, 20 Nov 2023 11:27:51 +0800 Subject: [PATCH 09/23] Add import statement for Boto3Client --- app/grounded_cutouts.py | 1 + 1 file changed, 1 insertion(+) diff --git a/app/grounded_cutouts.py b/app/grounded_cutouts.py index cf0776d..91183b4 100644 --- a/app/grounded_cutouts.py +++ b/app/grounded_cutouts.py @@ -237,6 +237,7 @@ async def create_cutouts(image_name: str, request: Request): Returns: _type_: _description_ """ + from s3_handler import Boto3Client try: # Log the start of the process logger.info("Creating cutouts for image %s ", image_name) From 52e8ca2454875b342be92c97438324165aa8f07c Mon Sep 17 00:00:00 2001 From: OriginalByteMe Date: Tue, 21 Nov 2023 19:59:49 +0800 Subject: [PATCH 10/23] Remove keep_warm parameter from create_all_cutouts function --- app/grounded_cutouts.py | 1 - 1 file changed, 1 deletion(-) diff --git a/app/grounded_cutouts.py b/app/grounded_cutouts.py index 91183b4..0cf2913 100644 --- a/app/grounded_cutouts.py +++ b/app/grounded_cutouts.py @@ -314,7 +314,6 @@ async def create_all_cutouts( mounts=[local_packages], secret=Secret.from_name("my-aws-secret"), container_idle_timeout=300, - keep_warm=1 ) @asgi_app() def cutout_app(): From 169bcff760bea631879cd84a422ad4e46ba3a97c Mon Sep 17 00:00:00 2001 From: OriginalByteMe Date: Wed, 22 Nov 2023 01:01:32 +0800 Subject: [PATCH 11/23] Fix indentation in ci-cd.yml file --- .github/workflows/ci-cd.yml | 16 ++++++++-------- 1 file changed, 8 insertions(+), 8 deletions(-) diff --git a/.github/workflows/ci-cd.yml b/.github/workflows/ci-cd.yml index 09cd70b..093369e 100644 --- a/.github/workflows/ci-cd.yml +++ b/.github/workflows/ci-cd.yml @@ -28,14 +28,14 @@ jobs: python -m pip install --upgrade pip pip install modal - - name: Set environment - id: vars - run: | - if [[ "${{ github.ref }}" == "refs/heads/main" ]]; then - echo "::set-output name=environment::prod" - else - echo "::set-output name=environment::dev" - fi + - name: Set environment + id: vars + run: | + if [[ "${{ github.ref }}" == "refs/heads/main" ]]; then + echo "::set-output name=environment::prod" + else + echo "::set-output name=environment::dev" + fi - name: Deploy cutout generator job run: | cd app From 7e00025e5a5c77601af0b489e9ff4c82d058e48f Mon Sep 17 00:00:00 2001 From: OriginalByteMe Date: Wed, 22 Nov 2023 01:14:10 +0800 Subject: [PATCH 12/23] Fix formatting and add warmup endpoint --- app/grounded_cutouts.py | 14 +++++++++++++- 1 file changed, 13 insertions(+), 1 deletion(-) diff --git a/app/grounded_cutouts.py b/app/grounded_cutouts.py index 0cf2913..6fdaf80 100644 --- a/app/grounded_cutouts.py +++ b/app/grounded_cutouts.py @@ -83,6 +83,7 @@ ) ) + @stub.cls( image=cutout_generator_image, gpu="T4", @@ -216,6 +217,7 @@ def main( sam_checkpoint_path, ) + @app.get("/warmup") async def warmup(): """Warmup the container. @@ -223,8 +225,17 @@ async def warmup(): Returns: _type_: return message """ + # Spins up the container and loads the models, this will make it easier to create cutouts later + CutoutCreator( + classes=[], + grounding_dino_checkpoint_path=GROUNDING_DINO_CHECKPOINT_PATH, + grounding_dino_config_path=GROUNDING_DINO_CONFIG_PATH, + sam_checkpoint_path=SAM_CHECKPOINT_PATH, + ) + return "Warmed up!" + @app.post("/create-cutouts/{image_name}") async def create_cutouts(image_name: str, request: Request): """ @@ -238,6 +249,7 @@ async def create_cutouts(image_name: str, request: Request): _type_: _description_ """ from s3_handler import Boto3Client + try: # Log the start of the process logger.info("Creating cutouts for image %s ", image_name) @@ -297,7 +309,7 @@ async def create_all_cutouts( classes=classes, grounding_dino_checkpoint_path=GROUNDING_DINO_CHECKPOINT_PATH, grounding_dino_config_path=GROUNDING_DINO_CONFIG_PATH, - sam_checkpoint_path = SAM_CHECKPOINT_PATH, + sam_checkpoint_path=SAM_CHECKPOINT_PATH, ) result = {} From 6f0ef3c3bff4f16f675c16cf82c3ef2a9e80bf41 Mon Sep 17 00:00:00 2001 From: OriginalByteMe Date: Wed, 22 Nov 2023 01:30:54 +0800 Subject: [PATCH 13/23] Update SAM checkpoint paths and add accuracy level parameter --- app/grounded_cutouts.py | 29 +++++++++++++++++++---------- 1 file changed, 19 insertions(+), 10 deletions(-) diff --git a/app/grounded_cutouts.py b/app/grounded_cutouts.py index 6fdaf80..445dbce 100644 --- a/app/grounded_cutouts.py +++ b/app/grounded_cutouts.py @@ -17,7 +17,9 @@ GROUNDING_DINO_CHECKPOINT_PATH = os.path.join( HOME, "weights", "groundingdino_swint_ogc.pth" ) -SAM_CHECKPOINT_PATH = os.path.join(HOME, "weights", "sam_vit_h_4b8939.pth") +SAM_CHECKPOINT_PATH_HIGH = os.path.join(HOME, "weights", "sam_vit_h_4b8939.pth") +SAM_CHECKPOINT_PATH_MID = os.path.join(HOME, "weights", "sam_vit_l_0b3195.pth") +SAM_CHECKPOINT_PATH_LOW = os.path.join(HOME, "weights", "sam_vit_b_01ec64.pth") # ====================== # Logging @@ -65,7 +67,6 @@ .run_commands( "apt-get update", "apt-get install -y git wget libgl1-mesa-glx libglib2.0-0", - "echo $CUDA_HOME", "git clone https://github.com/IDEA-Research/GroundingDINO.git", "pip install -q -e GroundingDINO/", "mkdir -p /weights", @@ -76,10 +77,8 @@ "pip install -q supervision==0.6.0", "wget -q https://github.com/IDEA-Research/GroundingDINO/releases/download/v0.1.0-alpha/groundingdino_swint_ogc.pth -P weights/", "wget -q https://dl.fbaipublicfiles.com/segment_anything/sam_vit_h_4b8939.pth -P weights/", - "wget -q https://media.roboflow.com/notebooks/examples/dog.jpeg -P images/", - "ls -F", - "ls -F GroundingDINO/groundingdino/config", - "ls -F GroundingDINO/groundingdino/models/GroundingDINO/", + "wget -q https://dl.fbaipublicfiles.com/segment_anything/sam_vit_l_0b3195.pth -P weights/", + "wget -q https://dl.fbaipublicfiles.com/segment_anything/sam_vit_b_01ec64.pth -P weights/", ) ) @@ -230,7 +229,7 @@ async def warmup(): classes=[], grounding_dino_checkpoint_path=GROUNDING_DINO_CHECKPOINT_PATH, grounding_dino_config_path=GROUNDING_DINO_CONFIG_PATH, - sam_checkpoint_path=SAM_CHECKPOINT_PATH, + sam_checkpoint_path=SAM_CHECKPOINT_PATH_LOW, ) return "Warmed up!" @@ -257,9 +256,19 @@ async def create_cutouts(image_name: str, request: Request): # Parse the request body as JSON data = await request.json() - # Get the classes from the JSON data + # Get the classes and accuracy level from the JSON data classes = data.get("classes", []) + accuracy_level = data.get("accuracy_level", "low") logger.info("Classes: %s", classes) + logger.info("Accuracy level: %s", accuracy_level) + + # Select the SAM checkpoint path based on the accuracy level + if accuracy_level == "high": + sam_checkpoint_path = SAM_CHECKPOINT_PATH_HIGH + elif accuracy_level == "low": + sam_checkpoint_path = SAM_CHECKPOINT_PATH_LOW + else: # Default to mid if the accuracy level is not recognized + sam_checkpoint_path = SAM_CHECKPOINT_PATH_MID # Initialize the S3 client and the CutoutCreator s3 = Boto3Client() @@ -267,7 +276,7 @@ async def create_cutouts(image_name: str, request: Request): classes=classes, grounding_dino_checkpoint_path=GROUNDING_DINO_CHECKPOINT_PATH, grounding_dino_config_path=GROUNDING_DINO_CONFIG_PATH, - sam_checkpoint_path=SAM_CHECKPOINT_PATH, + sam_checkpoint_path=sam_checkpoint_path, ) # Create the cutouts @@ -288,7 +297,6 @@ async def create_cutouts(image_name: str, request: Request): ) raise - @app.post("/create-cutouts") async def create_all_cutouts( image_names: List[str] = Body(...), classes: List[str] = Body(...) @@ -326,6 +334,7 @@ async def create_all_cutouts( mounts=[local_packages], secret=Secret.from_name("my-aws-secret"), container_idle_timeout=300, + retries=1, ) @asgi_app() def cutout_app(): From 19910052353709c85a6e4fff176417bc7d5d54cc Mon Sep 17 00:00:00 2001 From: OriginalByteMe Date: Wed, 22 Nov 2023 21:55:47 +0800 Subject: [PATCH 14/23] Update deployment order in ci-cd.yml --- .github/workflows/ci-cd.yml | 4 ++-- 1 file changed, 2 insertions(+), 2 deletions(-) diff --git a/.github/workflows/ci-cd.yml b/.github/workflows/ci-cd.yml index 093369e..cc309f7 100644 --- a/.github/workflows/ci-cd.yml +++ b/.github/workflows/ci-cd.yml @@ -39,9 +39,9 @@ jobs: - name: Deploy cutout generator job run: | cd app - modal deploy grounded_cutouts.py --env ${{ steps.vars.outputs.environment }} + modal deploy --env ${{ steps.vars.outputs.environment }} grounded_cutouts.py - name: Deploy s3_handler job run: | cd app/s3_handler - modal deploy app.py --env ${{ steps.vars.outputs.environment }} + modal deploy --env ${{ steps.vars.outputs.environment }} app.py \ No newline at end of file From bb26ae0e01639eb9e79168557aaf7d3a6b408ad9 Mon Sep 17 00:00:00 2001 From: OriginalByteMe Date: Wed, 22 Nov 2023 21:58:19 +0800 Subject: [PATCH 15/23] Fix deployment command in ci-cd.yml --- .github/workflows/ci-cd.yml | 4 ++-- 1 file changed, 2 insertions(+), 2 deletions(-) diff --git a/.github/workflows/ci-cd.yml b/.github/workflows/ci-cd.yml index cc309f7..3b4673a 100644 --- a/.github/workflows/ci-cd.yml +++ b/.github/workflows/ci-cd.yml @@ -39,9 +39,9 @@ jobs: - name: Deploy cutout generator job run: | cd app - modal deploy --env ${{ steps.vars.outputs.environment }} grounded_cutouts.py + modal deploy --env=${{ steps.vars.outputs.environment }} grounded_cutouts.py - name: Deploy s3_handler job run: | cd app/s3_handler - modal deploy --env ${{ steps.vars.outputs.environment }} app.py + modal deploy --env=${{ steps.vars.outputs.environment }} app.py \ No newline at end of file From 7ce526d5493c662e714d77c5ba8c527eed8f5b60 Mon Sep 17 00:00:00 2001 From: OriginalByteMe Date: Thu, 21 Dec 2023 17:22:08 +0800 Subject: [PATCH 16/23] Update SAM checkpoint path based on encoder version --- app/grounded_cutouts.py | 33 ++++++++++++++++++++++----------- 1 file changed, 22 insertions(+), 11 deletions(-) diff --git a/app/grounded_cutouts.py b/app/grounded_cutouts.py index 445dbce..b15edc3 100644 --- a/app/grounded_cutouts.py +++ b/app/grounded_cutouts.py @@ -99,12 +99,12 @@ def __init__( classes: str, grounding_dino_config_path: str, grounding_dino_checkpoint_path: str, - sam_checkpoint_path: str, + encoder_version: str = "vit_b", ): self.classes = classes self.grounding_dino_config_path = grounding_dino_config_path self.grounding_dino_checkpoint_path = grounding_dino_checkpoint_path - self.sam_checkpoint_path = sam_checkpoint_path + self.encoder_version = encoder_version self.HOME = os.path.abspath(os.path.join(os.getcwd(), os.pardir)) def __enter__(self): @@ -122,8 +122,17 @@ def __enter__(self): ) self.s3 = Boto3Client() self.mask_annotator = sv.MaskAnnotator() + + if self.encoder_version == "vit_b": + self.sam_checkpoint_path = SAM_CHECKPOINT_PATH_LOW + elif self.encoder_version == "vit_l": + self.sam_checkpoint_path = SAM_CHECKPOINT_PATH_MID + elif self.encoder_version == "vit_h": + self.sam_checkpoint_path = SAM_CHECKPOINT_PATH_HIGH + self.segment = Segmenter( - sam_encoder_version="vit_h", sam_checkpoint_path=self.sam_checkpoint_path + sam_encoder_version=self.encoder_version, + sam_checkpoint_path=self.sam_checkpoint_path, ) @method() @@ -207,13 +216,13 @@ def main( classes: str, grounding_dino_config_path: str, grounding_dino_checkpoint_path: str, - sam_checkpoint_path: str, + encoder_version: str, ): return CutoutCreator( classes, grounding_dino_config_path, grounding_dino_checkpoint_path, - sam_checkpoint_path, + encoder_version, ) @@ -229,7 +238,7 @@ async def warmup(): classes=[], grounding_dino_checkpoint_path=GROUNDING_DINO_CHECKPOINT_PATH, grounding_dino_config_path=GROUNDING_DINO_CONFIG_PATH, - sam_checkpoint_path=SAM_CHECKPOINT_PATH_LOW, + encoder_version="vit_b", ) return "Warmed up!" @@ -264,11 +273,13 @@ async def create_cutouts(image_name: str, request: Request): # Select the SAM checkpoint path based on the accuracy level if accuracy_level == "high": - sam_checkpoint_path = SAM_CHECKPOINT_PATH_HIGH + encoder_version = "vit_h" + elif accuracy_level == "mid": + encoder_version = "vit_l" elif accuracy_level == "low": - sam_checkpoint_path = SAM_CHECKPOINT_PATH_LOW + encoder_version = "vit_b" else: # Default to mid if the accuracy level is not recognized - sam_checkpoint_path = SAM_CHECKPOINT_PATH_MID + encoder_version = "vit_b" # Initialize the S3 client and the CutoutCreator s3 = Boto3Client() @@ -276,7 +287,7 @@ async def create_cutouts(image_name: str, request: Request): classes=classes, grounding_dino_checkpoint_path=GROUNDING_DINO_CHECKPOINT_PATH, grounding_dino_config_path=GROUNDING_DINO_CONFIG_PATH, - sam_checkpoint_path=sam_checkpoint_path, + encoder_version=encoder_version, ) # Create the cutouts @@ -317,7 +328,7 @@ async def create_all_cutouts( classes=classes, grounding_dino_checkpoint_path=GROUNDING_DINO_CHECKPOINT_PATH, grounding_dino_config_path=GROUNDING_DINO_CONFIG_PATH, - sam_checkpoint_path=SAM_CHECKPOINT_PATH, + encoder_version="vit_b", ) result = {} From 97e44ef09a9c7643a2443f8c86fca9464bb7dcea Mon Sep 17 00:00:00 2001 From: OriginalByteMe Date: Thu, 21 Dec 2023 17:53:57 +0800 Subject: [PATCH 17/23] Improve code quality and readability --- app/grounded_cutouts.py | 27 +++++++++++++-------------- 1 file changed, 13 insertions(+), 14 deletions(-) diff --git a/app/grounded_cutouts.py b/app/grounded_cutouts.py index b15edc3..1c60e78 100644 --- a/app/grounded_cutouts.py +++ b/app/grounded_cutouts.py @@ -123,13 +123,13 @@ def __enter__(self): self.s3 = Boto3Client() self.mask_annotator = sv.MaskAnnotator() - if self.encoder_version == "vit_b": - self.sam_checkpoint_path = SAM_CHECKPOINT_PATH_LOW - elif self.encoder_version == "vit_l": - self.sam_checkpoint_path = SAM_CHECKPOINT_PATH_MID - elif self.encoder_version == "vit_h": - self.sam_checkpoint_path = SAM_CHECKPOINT_PATH_HIGH + encoder_checkpoint_paths = { + "vit_b": SAM_CHECKPOINT_PATH_LOW, + "vit_l": SAM_CHECKPOINT_PATH_MID, + "vit_h": SAM_CHECKPOINT_PATH_HIGH, + } + self.sam_checkpoint_path = encoder_checkpoint_paths.get(self.encoder_version) self.segment = Segmenter( sam_encoder_version=self.encoder_version, sam_checkpoint_path=self.sam_checkpoint_path, @@ -272,14 +272,12 @@ async def create_cutouts(image_name: str, request: Request): logger.info("Accuracy level: %s", accuracy_level) # Select the SAM checkpoint path based on the accuracy level - if accuracy_level == "high": - encoder_version = "vit_h" - elif accuracy_level == "mid": - encoder_version = "vit_l" - elif accuracy_level == "low": - encoder_version = "vit_b" - else: # Default to mid if the accuracy level is not recognized - encoder_version = "vit_b" + accuracy_encoder_versions = { + "high": "vit_h", + "mid": "vit_l", + "low": "vit_b", + } + encoder_version = accuracy_encoder_versions.get(accuracy_level, "vit_b") # Initialize the S3 client and the CutoutCreator s3 = Boto3Client() @@ -308,6 +306,7 @@ async def create_cutouts(image_name: str, request: Request): ) raise + @app.post("/create-cutouts") async def create_all_cutouts( image_names: List[str] = Body(...), classes: List[str] = Body(...) From 16acc83b05c1fd167e12de19f49171337bd6cc18 Mon Sep 17 00:00:00 2001 From: Noah Rijkaard Date: Thu, 21 Dec 2023 17:54:31 +0800 Subject: [PATCH 18/23] Update app/grounded_cutouts.py Co-authored-by: CodiumAI-Agent <137281646+CodiumAI-Agent@users.noreply.github.com> Signed-off-by: Noah Rijkaard --- app/grounded_cutouts.py | 8 ++++++++ 1 file changed, 8 insertions(+) diff --git a/app/grounded_cutouts.py b/app/grounded_cutouts.py index 1c60e78..3a9bed5 100644 --- a/app/grounded_cutouts.py +++ b/app/grounded_cutouts.py @@ -281,6 +281,14 @@ async def create_cutouts(image_name: str, request: Request): # Initialize the S3 client and the CutoutCreator s3 = Boto3Client() + """ + Create cutouts for an image. + + :param classes: The classes for the cutout + :param grounding_dino_config_path: The path to the DINO configuration + :param grounding_dino_checkpoint_path: The path to the DINO checkpoint + :param encoder_version: The version of the encoder based on the accuracy level + """ cutout = CutoutCreator( classes=classes, grounding_dino_checkpoint_path=GROUNDING_DINO_CHECKPOINT_PATH, From c2d65f4223850379f6afc35b47d65c93b6d2886e Mon Sep 17 00:00:00 2001 From: OriginalByteMe Date: Thu, 21 Dec 2023 21:06:11 +0800 Subject: [PATCH 19/23] Refactor position + imports --- .pre-commit-config.yaml | 10 + app/common/__init__.py | 36 +++ app/cutout_handler/__init__.py | 3 + app/{ => cutout_handler}/dino.py | 13 +- app/cutout_handler/grounded_cutouts.py | 151 ++++++++++ app/{ => cutout_handler}/s3_handler.py | 24 +- app/{ => cutout_handler}/segment.py | 13 +- app/cutout_handler/server.py | 191 +++++++++++++ app/grounded_cutouts.py | 364 ------------------------- app/s3_handler/app.py | 42 ++- app/s3_handler/s3_handler.py | 10 +- legacy_code/cutouts.py | 20 +- 12 files changed, 481 insertions(+), 396 deletions(-) create mode 100644 .pre-commit-config.yaml create mode 100644 app/common/__init__.py create mode 100644 app/cutout_handler/__init__.py rename app/{ => cutout_handler}/dino.py (89%) create mode 100644 app/cutout_handler/grounded_cutouts.py rename app/{ => cutout_handler}/s3_handler.py (86%) rename app/{ => cutout_handler}/segment.py (79%) create mode 100644 app/cutout_handler/server.py delete mode 100644 app/grounded_cutouts.py diff --git a/.pre-commit-config.yaml b/.pre-commit-config.yaml new file mode 100644 index 0000000..87d2be8 --- /dev/null +++ b/.pre-commit-config.yaml @@ -0,0 +1,10 @@ +# repos: +# - repo: https://github.com/psf/black +# hooks: +# - id: black +# language_version: python3.8 # Should match the version of Python you're using + +# - repo: https://github.com/pycqa/isort +# hooks: +# - id: isort +# language_version: python3.8 # Should match the version of Python you're using \ No newline at end of file diff --git a/app/common/__init__.py b/app/common/__init__.py new file mode 100644 index 0000000..1f6ba4d --- /dev/null +++ b/app/common/__init__.py @@ -0,0 +1,36 @@ +from modal import Image, Mount, Stub + +s3_handler_image = Image.debian_slim().pip_install("boto3", "botocore") + +cutout_generator_image = ( + Image.from_registry("nvcr.io/nvidia/pytorch:22.12-py3") + .pip_install( + "segment-anything", "opencv-python", "botocore", "boto3", "fastapi", "starlette" + ) + .run_commands( + "apt-get update", + "apt-get install -y git wget libgl1-mesa-glx libglib2.0-0", + "git clone https://github.com/IDEA-Research/GroundingDINO.git", + "pip install -q -e GroundingDINO/", + "mkdir -p /weights", + "mkdir -p /data", + "pip uninstall -y supervision", + "pip uninstall -y opencv-python", + "pip install opencv-python==4.8.0.74", + "pip install -q supervision==0.6.0", + "wget -q https://github.com/IDEA-Research/GroundingDINO/releases/download/v0.1.0-alpha/groundingdino_swint_ogc.pth -P weights/", + "wget -q https://dl.fbaipublicfiles.com/segment_anything/sam_vit_h_4b8939.pth -P weights/", + "wget -q https://dl.fbaipublicfiles.com/segment_anything/sam_vit_l_0b3195.pth -P weights/", + "wget -q https://dl.fbaipublicfiles.com/segment_anything/sam_vit_b_01ec64.pth -P weights/", + ) +) + +local_packages = Mount.from_local_python_packages( + "app.cutout_handler.dino", + "app.cutout_handler.segment", + "app.cutout_handler.s3_handler", + "app.cutout_handler.grounded_cutouts", +) + +cutout_handler_stub = Stub(image=cutout_generator_image, name="cutout_generator") +s3_handler_stub = Stub(image=s3_handler_image, name="s3_handler") diff --git a/app/cutout_handler/__init__.py b/app/cutout_handler/__init__.py new file mode 100644 index 0000000..ed7f30f --- /dev/null +++ b/app/cutout_handler/__init__.py @@ -0,0 +1,3 @@ +from app.common import cutout_handler_stub, s3_handler_stub + +from .server import cutout_app diff --git a/app/dino.py b/app/cutout_handler/dino.py similarity index 89% rename from app/dino.py rename to app/cutout_handler/dino.py index 1e1528f..fb5bbb1 100644 --- a/app/dino.py +++ b/app/cutout_handler/dino.py @@ -1,11 +1,13 @@ -import torch -from groundingdino.util.inference import Model from typing import List +from app.common import cutout_handler_stub + +cutout_handler_stub.cls() + class Dino: - """ A class for object detection using GroundingDINO. - """ + """A class for object detection using GroundingDINO.""" + def __init__( self, classes, @@ -14,6 +16,9 @@ def __init__( model_config_path, model_checkpoint_path, ): + import torch + from groundingdino.util.inference import Model + self.classes = classes self.box_threshold = box_threshold self.text_threshold = text_threshold diff --git a/app/cutout_handler/grounded_cutouts.py b/app/cutout_handler/grounded_cutouts.py new file mode 100644 index 0000000..dd1be39 --- /dev/null +++ b/app/cutout_handler/grounded_cutouts.py @@ -0,0 +1,151 @@ +import io +import logging +import os +from typing import Dict, List + +import cv2 +import numpy as np +import supervision as sv +from fastapi import Body, FastAPI +from fastapi.middleware.cors import CORSMiddleware +from modal import Image, Mount, Secret, Stub, asgi_app, method +from starlette.requests import Request + +from .dino import Dino +from .s3_handler import Boto3Client +from .segment import Segmenter + +# ====================== +# Constants +# ====================== +HOME = os.path.abspath(os.path.join(os.getcwd(), os.pardir)) +GROUNDING_DINO_CONFIG_PATH = os.path.join( + HOME, "GroundingDINO/groundingdino/config/GroundingDINO_SwinT_OGC.py" +) +GROUNDING_DINO_CHECKPOINT_PATH = os.path.join( + HOME, "weights", "groundingdino_swint_ogc.pth" +) +SAM_CHECKPOINT_PATH_HIGH = os.path.join(HOME, "weights", "sam_vit_h_4b8939.pth") +SAM_CHECKPOINT_PATH_MID = os.path.join(HOME, "weights", "sam_vit_l_0b3195.pth") +SAM_CHECKPOINT_PATH_LOW = os.path.join(HOME, "weights", "sam_vit_b_01ec64.pth") + +# ====================== +# Logging +# ====================== +logger = logging.getLogger(__name__) +logger.setLevel(logging.DEBUG) + +c_handler = logging.StreamHandler() +c_handler.setLevel(logging.DEBUG) + +c_format = logging.Formatter("%(name)s - %(levelname)s - %(message)s") +c_handler.setFormatter(c_format) + +logger.addHandler(c_handler) + + +class CutoutCreator: + def __init__( + self, + classes: str, + grounding_dino_config_path: str, + grounding_dino_checkpoint_path: str, + encoder_version: str = "vit_b", + ): + self.classes = classes + self.grounding_dino_config_path = grounding_dino_config_path + self.grounding_dino_checkpoint_path = grounding_dino_checkpoint_path + self.encoder_version = encoder_version + self.HOME = os.path.abspath(os.path.join(os.getcwd(), os.pardir)) + self.s3 = Boto3Client() + self.dino = Dino( + classes=self.classes, + box_threshold=0.35, + text_threshold=0.25, + model_config_path=self.grounding_dino_config_path, + model_checkpoint_path=self.grounding_dino_checkpoint_path, + ) + self.mask_annotator = sv.MaskAnnotator() + + encoder_checkpoint_paths = { + "vit_b": SAM_CHECKPOINT_PATH_LOW, + "vit_l": SAM_CHECKPOINT_PATH_MID, + "vit_h": SAM_CHECKPOINT_PATH_HIGH, + } + + self.sam_checkpoint_path = encoder_checkpoint_paths.get(self.encoder_version) + self.segment = Segmenter( + sam_encoder_version=self.encoder_version, + sam_checkpoint_path=self.sam_checkpoint_path, + ) + + def create_annotated_image(self, image, image_name, detections: Dict[str, list]): + """Create a highlighted annotated image from an image and detections. + + Args: + image (File): Image to be used for creating the annotated image. + image_name (string): name of image + detections (Dict[str, list]): annotations for the image + """ + annotated_image = self.mask_annotator.annotate( + scene=image, detections=detections + ) + # Convert annotated image to bytes + img_bytes = io.BytesIO() + Image.fromarray(np.uint8(annotated_image)).save(img_bytes, format="PNG") + img_bytes.seek(0) + # Upload bytes to S3 + self.s3.upload_to_s3(img_bytes.read(), "cutouts", f"{image_name}_annotated.png") + + def create_cutouts(self, image_name): + """Create cutouts from an image and upload them to S3. + + Args: + image_name (string): name of image + """ + + # Define paths + data_path = os.path.join(HOME, "data") + cutouts_path = os.path.join(HOME, "cutouts") + + # Download image from s3 + image_path = self.s3.download_from_s3(data_path, image_name) + if image_path is None: + print(f"Failed to download image {image_name} from S3") + return + + # Check if image exists + if not os.path.exists(image_path): + print(f"Image {image_name} not found in folder {image_path}") + return + + # Create cutouts directory if it doesn't exist + os.makedirs(cutouts_path, exist_ok=True) + + # Read image + image = cv2.imread(image_path) + + # Predict and segment image + detections = self.dino.predict(image) + masks = self.segment.segment(image, detections.xyxy) + + # Apply each mask to the image + for i, mask in enumerate(masks): + # Ensure the mask is a boolean array + mask = mask.astype(bool) + + # Apply the mask to create a cutout + cutout = np.zeros_like(image) + cutout[mask] = image[mask] + + # Save the cutout + cutout_name = f"{image_name}_cutout_{i}.png" + cutout_path = os.path.join(cutouts_path, cutout_name) + cv2.imwrite(cutout_path, cutout) + + # Upload the cutout to S3 + with open(cutout_path, "rb") as f: + self.s3.upload_to_s3(f.read(), "cutouts", f"{image_name}/{cutout_name}") + + # Create annotated image + # self.create_annotated_image(image, f"{image_name}_{i}", detections) diff --git a/app/s3_handler.py b/app/cutout_handler/s3_handler.py similarity index 86% rename from app/s3_handler.py rename to app/cutout_handler/s3_handler.py index fc88a43..e2359f5 100644 --- a/app/s3_handler.py +++ b/app/cutout_handler/s3_handler.py @@ -1,11 +1,15 @@ -import os -import boto3 import logging -from botocore.exceptions import ClientError, BotoCoreError, NoCredentialsError +import os + +from app.common import s3_handler_stub + +s3_handler_stub.cls() class Boto3Client: def __init__(self): + import boto3 + self.s3 = boto3.client( "s3", aws_access_key_id=os.environ["AWS_ACCESS_KEY_ID"], @@ -14,6 +18,9 @@ def __init__(self): ) def download_from_s3(self, save_path, image_name): + import boto3 + from botocore.exceptions import ClientError + s3_client = boto3.client("s3") file_path = os.path.join(save_path, image_name) try: @@ -33,6 +40,8 @@ def download_from_s3(self, save_path, image_name): return file_path def upload_to_s3(self, file_body, folder, image_name): + from botocore.exceptions import BotoCoreError, NoCredentialsError + try: self.s3.put_object( Body=file_body, @@ -51,6 +60,8 @@ def upload_to_s3(self, file_body, folder, image_name): raise def generate_presigned_urls(self, folder, expiration=3600): + from botocore.exceptions import ClientError + try: response = self.s3.list_objects_v2( Bucket=os.environ["CUTOUT_BUCKET"], Prefix=folder @@ -75,11 +86,16 @@ def generate_presigned_urls(self, folder, expiration=3600): return urls def generate_presigned_url_with_metadata(self, folder, key, expiration=3600): + from botocore.exceptions import ClientError + try: # Generate presigned URL url = self.s3.generate_presigned_url( "get_object", - Params={"Bucket": os.environ["CUTOUT_BUCKET"], "Key": f"{folder}/{key}"}, + Params={ + "Bucket": os.environ["CUTOUT_BUCKET"], + "Key": f"{folder}/{key}", + }, ExpiresIn=expiration, ) # Get object metadata diff --git a/app/segment.py b/app/cutout_handler/segment.py similarity index 79% rename from app/segment.py rename to app/cutout_handler/segment.py index 33e7a88..3eda724 100644 --- a/app/segment.py +++ b/app/cutout_handler/segment.py @@ -1,14 +1,19 @@ -import numpy as np -import torch -from segment_anything import sam_model_registry, SamPredictor +from app.common import cutout_handler_stub + +cutout_handler_stub.cls() class Segmenter: + import numpy as np + def __init__( self, sam_encoder_version: str, sam_checkpoint_path: str, ): + import torch + from segment_anything import SamPredictor, sam_model_registry + self.device = torch.device("cuda" if torch.cuda.is_available() else "cpu") self.sam = sam_model_registry[sam_encoder_version]( checkpoint=sam_checkpoint_path @@ -16,6 +21,8 @@ def __init__( self.sam_predictor = SamPredictor(self.sam) def segment(self, image: np.ndarray, xyxy: np.ndarray) -> np.ndarray: + import numpy as np + self.sam_predictor.set_image(image) result_masks = [] for box in xyxy: diff --git a/app/cutout_handler/server.py b/app/cutout_handler/server.py new file mode 100644 index 0000000..a982c17 --- /dev/null +++ b/app/cutout_handler/server.py @@ -0,0 +1,191 @@ +import logging +import os +from typing import Dict, List + +from fastapi import Body, FastAPI +from fastapi.middleware.cors import CORSMiddleware +from modal import Secret, asgi_app +from starlette.requests import Request + +from app.common import cutout_handler_stub, local_packages + +from .grounded_cutouts import CutoutCreator + +# ====================== +# Constants +# ====================== +HOME = os.path.abspath(os.path.join(os.getcwd(), os.pardir)) +GROUNDING_DINO_CONFIG_PATH = os.path.join( + HOME, "GroundingDINO/groundingdino/config/GroundingDINO_SwinT_OGC.py" +) +GROUNDING_DINO_CHECKPOINT_PATH = os.path.join( + HOME, "weights", "groundingdino_swint_ogc.pth" +) +SAM_CHECKPOINT_PATH_HIGH = os.path.join(HOME, "weights", "sam_vit_h_4b8939.pth") +SAM_CHECKPOINT_PATH_MID = os.path.join(HOME, "weights", "sam_vit_l_0b3195.pth") +SAM_CHECKPOINT_PATH_LOW = os.path.join(HOME, "weights", "sam_vit_b_01ec64.pth") + + +# ====================== +# Logging +# ====================== +logger = logging.getLogger(__name__) +logger.setLevel(logging.DEBUG) + +c_handler = logging.StreamHandler() +c_handler.setLevel(logging.DEBUG) + +c_format = logging.Formatter("%(name)s - %(levelname)s - %(message)s") +c_handler.setFormatter(c_format) + +logger.addHandler(c_handler) + + +# ====================== +# FastAPI Setup +# ====================== +app = FastAPI() + +# stub = Stub(name="cutout_generator") + +origins = [ + "http://localhost:3000", # local development + "https://cutouts.noahrijkaard.com", # main website +] + +app.add_middleware( + CORSMiddleware, + allow_origins=origins, + allow_credentials=True, + allow_methods=["*"], + allow_headers=["*"], +) + + +@app.get("/warmup") +async def warmup(): + """Warmup the container. + + Returns: + _type_: return message + """ + # Spins up the container and loads the models, this will make it easier to create cutouts later + CutoutCreator( + classes=[], + grounding_dino_checkpoint_path=GROUNDING_DINO_CHECKPOINT_PATH, + grounding_dino_config_path=GROUNDING_DINO_CONFIG_PATH, + encoder_version="vit_b", + ) + + return "Warmed up!" + + +@app.post("/create-cutouts/{image_name}") +async def create_cutouts(image_name: str, request: Request): + """ + Create cutouts from an image and upload them to S3. + + Args: + image_name (str): Name of image to create cutouts from. + classes (List[str], optional): A list of classes for the AI to detect for. Defaults to Body(...). + + Returns: + _type_: _description_ + """ + from s3_handler import Boto3Client + + try: + # Log the start of the process + logger.info("Creating cutouts for image %s ", image_name) + + # Parse the request body as JSON + data = await request.json() + + # Get the classes and accuracy level from the JSON data + classes = data.get("classes", []) + accuracy_level = data.get("accuracy_level", "low") + logger.info("Classes: %s", classes) + logger.info("Accuracy level: %s", accuracy_level) + + # Select the SAM checkpoint path based on the accuracy level + accuracy_encoder_versions = { + "high": "vit_h", + "mid": "vit_l", + "low": "vit_b", + } + encoder_version = accuracy_encoder_versions.get(accuracy_level, "vit_b") + + # Initialize the S3 client and the CutoutCreator + s3 = Boto3Client() + cutout = CutoutCreator( + classes=classes, + grounding_dino_checkpoint_path=GROUNDING_DINO_CHECKPOINT_PATH, + grounding_dino_config_path=GROUNDING_DINO_CONFIG_PATH, + encoder_version=encoder_version, + ) + + # Create the cutouts + print(f"CREATING CUTOUTS FOR IMAGE {image_name}") + cutout.create_cutouts(image_name) + logger.info("Cutouts created for image %s", image_name) + + # Generate presigned URLs for the cutouts + urls = s3.generate_presigned_urls(f"cutouts/{image_name}") + logger.info("Presigned URLs generated for cutouts of image %s", image_name) + + # Return the URLs + return urls + except Exception as e: + # Log any errors that occur + logger.error( + "An error occurred while creating cutouts for image %s: %s", image_name, e + ) + raise + + +@app.post("/create-cutouts") +async def create_all_cutouts( + image_names: List[str] = Body(...), classes: List[str] = Body(...) +): + """Create cutouts from multiple images and upload them to S3. + + Args: + image_names (List[str]): List of image names to create cutouts from. + classes (List[str], optional): A list of classes for the AI to detect for. Defaults to Body(...). + + Returns: + Dict[str, List[str]]: A dictionary where the keys are the image names and the values are the lists of presigned URLs for the cutouts. + """ + from s3_handler import Boto3Client + + s3 = Boto3Client() + cutout = CutoutCreator( + classes=classes, + grounding_dino_checkpoint_path=GROUNDING_DINO_CHECKPOINT_PATH, + grounding_dino_config_path=GROUNDING_DINO_CONFIG_PATH, + encoder_version="vit_b", + ) + + result = {} + for image_name in image_names: + cutout.create_cutouts(image_name) + result[image_name] = s3.generate_presigned_urls(f"cutouts/{image_name}") + + return result + + +@cutout_handler_stub.function( + gpu="T4", + mounts=[local_packages], + secret=Secret.from_name("my-aws-secret"), + container_idle_timeout=300, + retries=1, +) +@asgi_app() +def cutout_app(): + """Create a FastAPI app for creating cutouts. + + Returns: + FastAPI: FastAPI app for creating cutouts. + """ + return app diff --git a/app/grounded_cutouts.py b/app/grounded_cutouts.py deleted file mode 100644 index 3a9bed5..0000000 --- a/app/grounded_cutouts.py +++ /dev/null @@ -1,364 +0,0 @@ -import os -import io -import logging -from typing import List, Dict -from fastapi import FastAPI, Body -from fastapi.middleware.cors import CORSMiddleware -from starlette.requests import Request -from modal import asgi_app, Secret, Stub, Mount, Image, method - -# ====================== -# Constants -# ====================== -HOME = os.path.abspath(os.path.join(os.getcwd(), os.pardir)) -GROUNDING_DINO_CONFIG_PATH = os.path.join( - HOME, "GroundingDINO/groundingdino/config/GroundingDINO_SwinT_OGC.py" -) -GROUNDING_DINO_CHECKPOINT_PATH = os.path.join( - HOME, "weights", "groundingdino_swint_ogc.pth" -) -SAM_CHECKPOINT_PATH_HIGH = os.path.join(HOME, "weights", "sam_vit_h_4b8939.pth") -SAM_CHECKPOINT_PATH_MID = os.path.join(HOME, "weights", "sam_vit_l_0b3195.pth") -SAM_CHECKPOINT_PATH_LOW = os.path.join(HOME, "weights", "sam_vit_b_01ec64.pth") - -# ====================== -# Logging -# ====================== -logger = logging.getLogger(__name__) -logger.setLevel(logging.DEBUG) - -c_handler = logging.StreamHandler() -c_handler.setLevel(logging.DEBUG) - -c_format = logging.Formatter("%(name)s - %(levelname)s - %(message)s") -c_handler.setFormatter(c_format) - -logger.addHandler(c_handler) - -# ====================== -# FastAPI Setup -# ====================== -app = FastAPI() - -stub = Stub(name="cutout_generator") - -origins = [ - "http://localhost:3000", # local development - "https://cutouts.noahrijkaard.com", # main website -] - -app.add_middleware( - CORSMiddleware, - allow_origins=origins, - allow_credentials=True, - allow_methods=["*"], - allow_headers=["*"], -) - -# ====================== -# Modal Image Setup -# ====================== -local_packages = Mount.from_local_python_packages("dino", "segment", "s3_handler") -cutout_generator_image = ( - Image.from_registry("nvcr.io/nvidia/pytorch:22.12-py3") - .pip_install( - "segment-anything", "opencv-python", "botocore", "boto3", "fastapi", "starlette" - ) - .run_commands( - "apt-get update", - "apt-get install -y git wget libgl1-mesa-glx libglib2.0-0", - "git clone https://github.com/IDEA-Research/GroundingDINO.git", - "pip install -q -e GroundingDINO/", - "mkdir -p /weights", - "mkdir -p /data", - "pip uninstall -y supervision", - "pip uninstall -y opencv-python", - "pip install opencv-python==4.8.0.74", - "pip install -q supervision==0.6.0", - "wget -q https://github.com/IDEA-Research/GroundingDINO/releases/download/v0.1.0-alpha/groundingdino_swint_ogc.pth -P weights/", - "wget -q https://dl.fbaipublicfiles.com/segment_anything/sam_vit_h_4b8939.pth -P weights/", - "wget -q https://dl.fbaipublicfiles.com/segment_anything/sam_vit_l_0b3195.pth -P weights/", - "wget -q https://dl.fbaipublicfiles.com/segment_anything/sam_vit_b_01ec64.pth -P weights/", - ) -) - - -@stub.cls( - image=cutout_generator_image, - gpu="T4", - mounts=[local_packages], - secret=Secret.from_name("my-aws-secret"), - container_idle_timeout=300, -) -class CutoutCreator: - import cv2 - import numpy as np - - def __init__( - self, - classes: str, - grounding_dino_config_path: str, - grounding_dino_checkpoint_path: str, - encoder_version: str = "vit_b", - ): - self.classes = classes - self.grounding_dino_config_path = grounding_dino_config_path - self.grounding_dino_checkpoint_path = grounding_dino_checkpoint_path - self.encoder_version = encoder_version - self.HOME = os.path.abspath(os.path.join(os.getcwd(), os.pardir)) - - def __enter__(self): - from s3_handler import Boto3Client - from dino import Dino - from segment import Segmenter - import supervision as sv - - self.dino = Dino( - classes=self.classes, - box_threshold=0.35, - text_threshold=0.25, - model_config_path=self.grounding_dino_config_path, - model_checkpoint_path=self.grounding_dino_checkpoint_path, - ) - self.s3 = Boto3Client() - self.mask_annotator = sv.MaskAnnotator() - - encoder_checkpoint_paths = { - "vit_b": SAM_CHECKPOINT_PATH_LOW, - "vit_l": SAM_CHECKPOINT_PATH_MID, - "vit_h": SAM_CHECKPOINT_PATH_HIGH, - } - - self.sam_checkpoint_path = encoder_checkpoint_paths.get(self.encoder_version) - self.segment = Segmenter( - sam_encoder_version=self.encoder_version, - sam_checkpoint_path=self.sam_checkpoint_path, - ) - - @method() - def create_annotated_image(self, image, image_name, detections: Dict[str, list]): - """Create a highlighted annotated image from an image and detections. - - Args: - image (File): Image to be used for creating the annotated image. - image_name (string): name of image - detections (Dict[str, list]): annotations for the image - """ - annotated_image = self.mask_annotator.annotate( - scene=image, detections=detections - ) - # Convert annotated image to bytes - img_bytes = io.BytesIO() - Image.fromarray(np.uint8(annotated_image)).save(img_bytes, format="PNG") - img_bytes.seek(0) - # Upload bytes to S3 - self.s3.upload_to_s3(img_bytes.read(), "cutouts", f"{image_name}_annotated.png") - - @method() - def create_cutouts(self, image_name): - """Create cutouts from an image and upload them to S3. - - Args: - image_name (string): name of image - """ - import cv2 - import numpy as np - - # Define paths - data_path = os.path.join(HOME, "data") - cutouts_path = os.path.join(HOME, "cutouts") - - # Download image from s3 - image_path = self.s3.download_from_s3(data_path, image_name) - if image_path is None: - print(f"Failed to download image {image_name} from S3") - return - - # Check if image exists - if not os.path.exists(image_path): - print(f"Image {image_name} not found in folder {image_path}") - return - - # Create cutouts directory if it doesn't exist - os.makedirs(cutouts_path, exist_ok=True) - - # Read image - image = cv2.imread(image_path) - - # Predict and segment image - detections = self.dino.predict(image) - masks = self.segment.segment(image, detections.xyxy) - - # Apply each mask to the image - for i, mask in enumerate(masks): - # Ensure the mask is a boolean array - mask = mask.astype(bool) - - # Apply the mask to create a cutout - cutout = np.zeros_like(image) - cutout[mask] = image[mask] - - # Save the cutout - cutout_name = f"{image_name}_cutout_{i}.png" - cutout_path = os.path.join(cutouts_path, cutout_name) - cv2.imwrite(cutout_path, cutout) - - # Upload the cutout to S3 - with open(cutout_path, "rb") as f: - self.s3.upload_to_s3(f.read(), "cutouts", f"{image_name}/{cutout_name}") - - # Create annotated image - # self.create_annotated_image(image, f"{image_name}_{i}", detections) - - -@stub.local_entrypoint() -def main( - classes: str, - grounding_dino_config_path: str, - grounding_dino_checkpoint_path: str, - encoder_version: str, -): - return CutoutCreator( - classes, - grounding_dino_config_path, - grounding_dino_checkpoint_path, - encoder_version, - ) - - -@app.get("/warmup") -async def warmup(): - """Warmup the container. - - Returns: - _type_: return message - """ - # Spins up the container and loads the models, this will make it easier to create cutouts later - CutoutCreator( - classes=[], - grounding_dino_checkpoint_path=GROUNDING_DINO_CHECKPOINT_PATH, - grounding_dino_config_path=GROUNDING_DINO_CONFIG_PATH, - encoder_version="vit_b", - ) - - return "Warmed up!" - - -@app.post("/create-cutouts/{image_name}") -async def create_cutouts(image_name: str, request: Request): - """ - Create cutouts from an image and upload them to S3. - - Args: - image_name (str): Name of image to create cutouts from. - classes (List[str], optional): A list of classes for the AI to detect for. Defaults to Body(...). - - Returns: - _type_: _description_ - """ - from s3_handler import Boto3Client - - try: - # Log the start of the process - logger.info("Creating cutouts for image %s ", image_name) - - # Parse the request body as JSON - data = await request.json() - - # Get the classes and accuracy level from the JSON data - classes = data.get("classes", []) - accuracy_level = data.get("accuracy_level", "low") - logger.info("Classes: %s", classes) - logger.info("Accuracy level: %s", accuracy_level) - - # Select the SAM checkpoint path based on the accuracy level - accuracy_encoder_versions = { - "high": "vit_h", - "mid": "vit_l", - "low": "vit_b", - } - encoder_version = accuracy_encoder_versions.get(accuracy_level, "vit_b") - - # Initialize the S3 client and the CutoutCreator - s3 = Boto3Client() - """ - Create cutouts for an image. - - :param classes: The classes for the cutout - :param grounding_dino_config_path: The path to the DINO configuration - :param grounding_dino_checkpoint_path: The path to the DINO checkpoint - :param encoder_version: The version of the encoder based on the accuracy level - """ - cutout = CutoutCreator( - classes=classes, - grounding_dino_checkpoint_path=GROUNDING_DINO_CHECKPOINT_PATH, - grounding_dino_config_path=GROUNDING_DINO_CONFIG_PATH, - encoder_version=encoder_version, - ) - - # Create the cutouts - print(f"CREATING CUTOUTS FOR IMAGE {image_name}") - cutout.create_cutouts.remote(image_name) - logger.info("Cutouts created for image %s", image_name) - - # Generate presigned URLs for the cutouts - urls = s3.generate_presigned_urls(f"cutouts/{image_name}") - logger.info("Presigned URLs generated for cutouts of image %s", image_name) - - # Return the URLs - return urls - except Exception as e: - # Log any errors that occur - logger.error( - "An error occurred while creating cutouts for image %s: %s", image_name, e - ) - raise - - -@app.post("/create-cutouts") -async def create_all_cutouts( - image_names: List[str] = Body(...), classes: List[str] = Body(...) -): - """Create cutouts from multiple images and upload them to S3. - - Args: - image_names (List[str]): List of image names to create cutouts from. - classes (List[str], optional): A list of classes for the AI to detect for. Defaults to Body(...). - - Returns: - Dict[str, List[str]]: A dictionary where the keys are the image names and the values are the lists of presigned URLs for the cutouts. - """ - from s3_handler import Boto3Client - - s3 = Boto3Client() - cutout = CutoutCreator( - classes=classes, - grounding_dino_checkpoint_path=GROUNDING_DINO_CHECKPOINT_PATH, - grounding_dino_config_path=GROUNDING_DINO_CONFIG_PATH, - encoder_version="vit_b", - ) - - result = {} - for image_name in image_names: - cutout.create_cutouts(image_name) - result[image_name] = s3.generate_presigned_urls(f"cutouts/{image_name}") - - return result - - -@stub.function( - image=cutout_generator_image, - gpu="T4", - mounts=[local_packages], - secret=Secret.from_name("my-aws-secret"), - container_idle_timeout=300, - retries=1, -) -@asgi_app() -def cutout_app(): - """Create a FastAPI app for creating cutouts. - - Returns: - FastAPI: FastAPI app for creating cutouts. - """ - return app diff --git a/app/s3_handler/app.py b/app/s3_handler/app.py index f4656ce..e3081d9 100644 --- a/app/s3_handler/app.py +++ b/app/s3_handler/app.py @@ -1,9 +1,9 @@ -from modal import Mount, Image, Secret, Stub, asgi_app -import os import logging -from fastapi import FastAPI, File, UploadFile, Body, HTTPException -from fastapi.middleware.cors import CORSMiddleware +import os +from fastapi import Body, FastAPI, File, HTTPException, UploadFile +from fastapi.middleware.cors import CORSMiddleware +from modal import Image, Mount, Secret, Stub, asgi_app stub = Stub(name="s3_handler") @@ -22,6 +22,7 @@ allow_headers=["*"], ) + # ================================================ # API Endpoints # ================================================ @@ -35,9 +36,9 @@ async def upload_image_to_s3(image: UploadFile = File(...)): Returns: str: Message indicating whether the upload was successful. """ - from s3_handler import Boto3Client from botocore.exceptions import BotoCoreError, NoCredentialsError - + from s3_handler import Boto3Client + s3_client = Boto3Client() try: s3_client.upload_to_s3(image.file, "images", image.filename) @@ -46,9 +47,12 @@ async def upload_image_to_s3(image: UploadFile = File(...)): except BotoCoreError as e: raise HTTPException(status_code=500, detail=str(e)) from e except Exception as e: - raise HTTPException(status_code=500, detail="An error occurred while uploading the image") from e + raise HTTPException( + status_code=500, detail="An error occurred while uploading the image" + ) from e return {"message": "Image uploaded successfully", "status_code": 200} + @app.get("/generate-presigned-urls/{image_name}") async def generate_presigned_urls(image_name: str): """Generate presigned urls for the cutouts of an image. @@ -65,7 +69,7 @@ async def generate_presigned_urls(image_name: str): return s3_client.generate_presigned_urls(f"cutouts/{image_name}") -@app.get('/get-image/{image_name}') +@app.get("/get-image/{image_name}") async def get_image(image_name: str): """Get an image from S3. @@ -75,8 +79,8 @@ async def get_image(image_name: str): Returns: FileResponse: FileResponse object containing the image. """ - from s3_handler import Boto3Client from fastapi.responses import FileResponse + from s3_handler import Boto3Client s3_client = Boto3Client() data = s3_client.generate_presigned_url_with_metadata("images", image_name) @@ -86,13 +90,27 @@ async def get_image(image_name: str): @stub.function( - image=Image.debian_slim().pip_install("boto3", "fastapi", "starlette", "uvicorn", "python-multipart", "pydantic", "requests", "httpx", "httpcore", "httpx[http2]", "httpx[http1]"), mounts=[Mount.from_local_python_packages("s3_handler")], secret=Secret.from_name("my-aws-secret") + image=Image.debian_slim().pip_install( + "boto3", + "fastapi", + "starlette", + "uvicorn", + "python-multipart", + "pydantic", + "requests", + "httpx", + "httpcore", + "httpx[http2]", + "httpx[http1]", + ), + mounts=[Mount.from_local_python_packages("s3_handler")], + secret=Secret.from_name("my-aws-secret"), ) - @asgi_app() def main(): return app + # ================================= # Modal s3 functions # ================================= @@ -104,7 +122,7 @@ def main(): # from botocore.exceptions import ClientError, BotoCoreError, NoCredentialsError # s3_client = Boto3Client() - + # try: # s3_client.upload_to_s3(file_body, folder, image_name) # logging.info(f"Successfully uploaded {image_name} to {folder}") diff --git a/app/s3_handler/s3_handler.py b/app/s3_handler/s3_handler.py index 2cd5bb5..e989e4f 100644 --- a/app/s3_handler/s3_handler.py +++ b/app/s3_handler/s3_handler.py @@ -1,7 +1,8 @@ +import logging import os + import boto3 -import logging -from botocore.exceptions import ClientError, BotoCoreError, NoCredentialsError +from botocore.exceptions import BotoCoreError, ClientError, NoCredentialsError class Boto3Client: @@ -76,7 +77,10 @@ def generate_presigned_url_with_metadata(self, folder, key, expiration=3600): # Generate presigned URL url = self.s3.generate_presigned_url( "get_object", - Params={"Bucket": os.environ["CUTOUT_BUCKET"], "Key": f"{folder}/{key}"}, + Params={ + "Bucket": os.environ["CUTOUT_BUCKET"], + "Key": f"{folder}/{key}", + }, ExpiresIn=expiration, ) # Get object metadata diff --git a/legacy_code/cutouts.py b/legacy_code/cutouts.py index 054da01..1e66844 100644 --- a/legacy_code/cutouts.py +++ b/legacy_code/cutouts.py @@ -1,16 +1,18 @@ -from typing import Dict -import os import io +import os +from typing import Dict + import cv2 import numpy as np -from s3_handler import Boto3Client +import supervision as sv from dino import Dino -from segment import Segmenter from PIL import Image -import supervision as sv +from s3_handler import Boto3Client +from segment import Segmenter HOME = os.path.abspath(os.path.join(os.getcwd(), os.pardir)) + class CutoutCreator: """ A class for creating cutouts from an image and uploading them to S3. @@ -20,7 +22,13 @@ class CutoutCreator: s3: A Boto3Client object for uploading to S3. mask_annotator: A MaskAnnotator object for annotating images with masks. """ - def __init__(self, classes: str, grounding_dino_config_path: str, grounding_dino_checkpoint_path: str): + + def __init__( + self, + classes: str, + grounding_dino_config_path: str, + grounding_dino_checkpoint_path: str, + ): self.dino = Dino( classes=classes, box_threshold=0.35, From 4e5226233b85443590a67979b7cac0858865c9b7 Mon Sep 17 00:00:00 2001 From: OriginalByteMe Date: Thu, 21 Dec 2023 21:21:00 +0800 Subject: [PATCH 20/23] Clean up imports --- app/cutout_handler/dino.py | 8 +++++--- app/cutout_handler/s3_handler.py | 24 ++++++------------------ app/cutout_handler/segment.py | 14 ++++++++------ 3 files changed, 19 insertions(+), 27 deletions(-) diff --git a/app/cutout_handler/dino.py b/app/cutout_handler/dino.py index fb5bbb1..63aa479 100644 --- a/app/cutout_handler/dino.py +++ b/app/cutout_handler/dino.py @@ -1,6 +1,10 @@ from typing import List -from app.common import cutout_handler_stub +from app.common import cutout_handler_stub, cutout_generator_image + +with cutout_generator_image.imports(): + import torch + from groundingdino.util.inference import Model cutout_handler_stub.cls() @@ -16,8 +20,6 @@ def __init__( model_config_path, model_checkpoint_path, ): - import torch - from groundingdino.util.inference import Model self.classes = classes self.box_threshold = box_threshold diff --git a/app/cutout_handler/s3_handler.py b/app/cutout_handler/s3_handler.py index e2359f5..bb239f9 100644 --- a/app/cutout_handler/s3_handler.py +++ b/app/cutout_handler/s3_handler.py @@ -1,14 +1,14 @@ -import logging import os +import logging +from app.common import s3_handler_stub, s3_handler_image -from app.common import s3_handler_stub +with s3_handler_image.imports(): + import boto3 + from botocore.exceptions import ClientError, BotoCoreError, NoCredentialsError s3_handler_stub.cls() - - class Boto3Client: def __init__(self): - import boto3 self.s3 = boto3.client( "s3", @@ -18,9 +18,6 @@ def __init__(self): ) def download_from_s3(self, save_path, image_name): - import boto3 - from botocore.exceptions import ClientError - s3_client = boto3.client("s3") file_path = os.path.join(save_path, image_name) try: @@ -40,8 +37,6 @@ def download_from_s3(self, save_path, image_name): return file_path def upload_to_s3(self, file_body, folder, image_name): - from botocore.exceptions import BotoCoreError, NoCredentialsError - try: self.s3.put_object( Body=file_body, @@ -60,8 +55,6 @@ def upload_to_s3(self, file_body, folder, image_name): raise def generate_presigned_urls(self, folder, expiration=3600): - from botocore.exceptions import ClientError - try: response = self.s3.list_objects_v2( Bucket=os.environ["CUTOUT_BUCKET"], Prefix=folder @@ -86,16 +79,11 @@ def generate_presigned_urls(self, folder, expiration=3600): return urls def generate_presigned_url_with_metadata(self, folder, key, expiration=3600): - from botocore.exceptions import ClientError - try: # Generate presigned URL url = self.s3.generate_presigned_url( "get_object", - Params={ - "Bucket": os.environ["CUTOUT_BUCKET"], - "Key": f"{folder}/{key}", - }, + Params={"Bucket": os.environ["CUTOUT_BUCKET"], "Key": f"{folder}/{key}"}, ExpiresIn=expiration, ) # Get object metadata diff --git a/app/cutout_handler/segment.py b/app/cutout_handler/segment.py index 3eda724..802ab5a 100644 --- a/app/cutout_handler/segment.py +++ b/app/cutout_handler/segment.py @@ -1,4 +1,11 @@ -from app.common import cutout_handler_stub +from app.common import cutout_handler_stub, cutout_generator_image + + +with cutout_generator_image.imports(): + import torch + from segment_anything import SamPredictor, sam_model_registry + import numpy as np + cutout_handler_stub.cls() @@ -11,9 +18,6 @@ def __init__( sam_encoder_version: str, sam_checkpoint_path: str, ): - import torch - from segment_anything import SamPredictor, sam_model_registry - self.device = torch.device("cuda" if torch.cuda.is_available() else "cpu") self.sam = sam_model_registry[sam_encoder_version]( checkpoint=sam_checkpoint_path @@ -21,8 +25,6 @@ def __init__( self.sam_predictor = SamPredictor(self.sam) def segment(self, image: np.ndarray, xyxy: np.ndarray) -> np.ndarray: - import numpy as np - self.sam_predictor.set_image(image) result_masks = [] for box in xyxy: From e920dc74be33ddfcd41cc9e76aff2ae65e606949 Mon Sep 17 00:00:00 2001 From: OriginalByteMe Date: Thu, 21 Dec 2023 21:29:45 +0800 Subject: [PATCH 21/23] Fix logging statements in cutout_handler --- app/cutout_handler/grounded_cutouts.py | 4 ++-- app/cutout_handler/server.py | 4 ++-- 2 files changed, 4 insertions(+), 4 deletions(-) diff --git a/app/cutout_handler/grounded_cutouts.py b/app/cutout_handler/grounded_cutouts.py index dd1be39..b9286f4 100644 --- a/app/cutout_handler/grounded_cutouts.py +++ b/app/cutout_handler/grounded_cutouts.py @@ -111,12 +111,12 @@ def create_cutouts(self, image_name): # Download image from s3 image_path = self.s3.download_from_s3(data_path, image_name) if image_path is None: - print(f"Failed to download image {image_name} from S3") + logger.error(f"Failed to download image {image_name} from S3") return # Check if image exists if not os.path.exists(image_path): - print(f"Image {image_name} not found in folder {image_path}") + logger.error(f"Image {image_name} not found in folder {image_path}") return # Create cutouts directory if it doesn't exist diff --git a/app/cutout_handler/server.py b/app/cutout_handler/server.py index a982c17..66a605e 100644 --- a/app/cutout_handler/server.py +++ b/app/cutout_handler/server.py @@ -7,7 +7,7 @@ from modal import Secret, asgi_app from starlette.requests import Request -from app.common import cutout_handler_stub, local_packages +from app.common import cutout_handler_stub, local_packages, cutout_generator_image from .grounded_cutouts import CutoutCreator @@ -125,7 +125,7 @@ async def create_cutouts(image_name: str, request: Request): ) # Create the cutouts - print(f"CREATING CUTOUTS FOR IMAGE {image_name}") + logger.info(f"CREATING CUTOUTS FOR IMAGE {image_name}") cutout.create_cutouts(image_name) logger.info("Cutouts created for image %s", image_name) From 19da2948c8ad296e585d2193b7773135ce99dcc4 Mon Sep 17 00:00:00 2001 From: OriginalByteMe Date: Thu, 21 Dec 2023 21:29:59 +0800 Subject: [PATCH 22/23] Hopefully fix deployment --- .github/workflows/ci-cd.yml | 4 ++-- 1 file changed, 2 insertions(+), 2 deletions(-) diff --git a/.github/workflows/ci-cd.yml b/.github/workflows/ci-cd.yml index 3b4673a..816e1f6 100644 --- a/.github/workflows/ci-cd.yml +++ b/.github/workflows/ci-cd.yml @@ -39,9 +39,9 @@ jobs: - name: Deploy cutout generator job run: | cd app - modal deploy --env=${{ steps.vars.outputs.environment }} grounded_cutouts.py + modal deploy --env=${{ steps.vars.outputs.environment }} app.cutout_handler::cutout_handler_stub - name: Deploy s3_handler job run: | cd app/s3_handler - modal deploy --env=${{ steps.vars.outputs.environment }} app.py + modal deploy --env=${{ steps.vars.outputs.environment }} app.s3_handler.app \ No newline at end of file From 3b65f39d41e6eb91779a5af56ccd3647253a38a7 Mon Sep 17 00:00:00 2001 From: OriginalByteMe Date: Thu, 21 Dec 2023 21:32:50 +0800 Subject: [PATCH 23/23] Update deployment configuration --- .github/workflows/ci-cd.yml | 4 +--- 1 file changed, 1 insertion(+), 3 deletions(-) diff --git a/.github/workflows/ci-cd.yml b/.github/workflows/ci-cd.yml index 816e1f6..49d7de2 100644 --- a/.github/workflows/ci-cd.yml +++ b/.github/workflows/ci-cd.yml @@ -32,16 +32,14 @@ jobs: id: vars run: | if [[ "${{ github.ref }}" == "refs/heads/main" ]]; then - echo "::set-output name=environment::prod" + echo "::set-output name=environment::main" else echo "::set-output name=environment::dev" fi - name: Deploy cutout generator job run: | - cd app modal deploy --env=${{ steps.vars.outputs.environment }} app.cutout_handler::cutout_handler_stub - name: Deploy s3_handler job run: | - cd app/s3_handler modal deploy --env=${{ steps.vars.outputs.environment }} app.s3_handler.app \ No newline at end of file