diff --git a/.beamignore b/.beamignore deleted file mode 100644 index 02f00e4..0000000 --- a/.beamignore +++ /dev/null @@ -1,9 +0,0 @@ -Images/* -masks/* -cutouts/* -Good_Cutouts/* -.venv/* -.conda/* -yolo* -sam_vit* -cutouts_* \ No newline at end of file diff --git a/.gitignore b/.gitignore index eb4d2c1..9088a5d 100644 --- a/.gitignore +++ b/.gitignore @@ -14,4 +14,6 @@ yolo* __pycache__/ *.jpg *.png -cutout_generator-* \ No newline at end of file +cutout_generator-* +Terraform/.terraform + diff --git a/README.md b/README.md new file mode 100644 index 0000000..1db8c90 --- /dev/null +++ b/README.md @@ -0,0 +1,76 @@ +# AI Image Cutout Maker + +AI Image Cutout Maker is a project that uses artificial intelligence to automatically create cutouts from images. This project is designed to simplify the process of creating cutouts, which can be a time-consuming task if done manually. + +This project utilizes the power of Segment Anything and Grounding Dino AI models to detect subjects in an image and cut them out. These models are hosted on Modal, which allows us to leverage GPU acceleration for faster and more efficient processing. + +The cutouts are then stored in an Amazon S3 bucket, providing a scalable and secure storage solution. This setup allows us to handle large volumes of images and serve them quickly and efficiently. + +## Project Structure + +The project is structured as follows: + +- `app/`: This directory contains the main application code. + - `cutout.py`: This script handles the process of creating cutouts from images using the Segment Anything and Grounding Dino AI models. + - `dino.py`: This script is responsible for interacting with the Grounding Dino AI model. + - `segment.py`: This script is used for interacting with the Segment Anything AI model. + - `s3_handler.py`: This script handles interactions with Amazon S3, such as uploading and downloading images. + - `grounded_cutouts.py`: This script ... 
+ +- `.venv/`: This directory contains the virtual environment for the project. + +- `modal_utils/`: This directory contains utility functions used throughout the project. + +- `grpc_utils.py`: This script handles the gRPC connections in the project. + + +## Purpose of the Project + +The purpose of this project is to automate the process of creating cutouts from images. By using artificial intelligence, we can create accurate cutouts much faster than would be possible manually. This can be useful in a variety of applications, such as graphic design, image editing, and more. There is also the social media aspect of creating stickers out of items cut out of an image, which is a popular trend on social media platforms such as Instagram and TikTok. + +## How to Use + +As mentioned above, this project is hosted on Modal, which means that it uses the Modal API. You will therefore need a Modal account and the Modal CLI installed. You can find instructions on how to do this [here](https://docs.modal.ai/docs/getting-started). I am planning on dockerizing this project down the line so that it can be used without the Modal CLI. 
+ + +## Workflow diagram for how the backend (BE) processes and returns values +```mermaid +sequenceDiagram + actor User + User ->> CutoutFE: Upload Image + CutoutFE ->> S3: Call to push image to bucket + CutoutFE ->> CutoutBE: Call to create cutout + CutoutBE ->>+ Modal: Spin up instance with container + Modal ->>- CutoutBE: Allow that container to be used by CutoutBE URL + CutoutBE ->>+ S3: Download image + S3 -->>- CutoutBE: Send Image back + Note over CutoutBE: Processes cutouts + other diagrams + CutoutBE -->> S3: Upload cutouts to bucket + Note over CutoutBE: Generates Pre-signed URL's of processed files + Note over CutoutFE: If need be FE can also create list of presigned urls via s3 + CutoutBE -->>+ CutoutFE: Return list of presigned urls + CutoutFE -->>+ User: Display processed images + User ->> S3: Download images via url +``` + + +## To-Do + +Here are some tasks that are on our roadmap: + +- Dockerize the project: We plan to create a Dockerfile for this project to make it easier to set up and run in any environment. +- API Documentation: We will be writing comprehensive API documentation to make it easier for developers to understand and use our API. +- Improve error handling: We aim to improve our error handling to make our API more robust and reliable. +- Add more AI models: We are planning to integrate more AI models to improve the accuracy and versatility of our image cutout creation. +- Optimize performance: We will be working on optimizing the performance of our API, particularly in terms of processing speed and resource usage. + +Please note that this is not an exhaustive list and the roadmap may change based on project needs and priorities. + +## Contributing + +This is a personal project, so it isn't really geared toward contributions, but feel free to fork the repository and make any changes you want. 
If you have any questions, feel free to reach out to me at my [email](mailto:noahrijkaard@gmail.com) + +## License + +This project is licensed under the terms of the [MIT License](LICENSE). + diff --git a/Terraform/backend.tf b/Terraform/backend.tf new file mode 100644 index 0000000..f8702a0 --- /dev/null +++ b/Terraform/backend.tf @@ -0,0 +1,18 @@ +terraform { + required_providers { + aws = { + source = "hashicorp/aws" + } + } + backend "s3" { + bucket = "noah-terraform-remote-state" + key = "modal/cutout-gen-config" + region = "ap-southeast-1" + profile = "noahTest" + } +} + +provider "aws" { + region = "ap-southeast-1" + profile = "noahTest" +} diff --git a/Terraform/s3.tf b/Terraform/s3.tf new file mode 100644 index 0000000..560c13d --- /dev/null +++ b/Terraform/s3.tf @@ -0,0 +1,21 @@ +module "s3_bucket" { + source = "terraform-aws-modules/s3-bucket/aws" + + bucket = "cutout-image-store" + acl = "private" + + control_object_ownership = true + object_ownership = "ObjectWriter" + + lifecycle_rule = [ + { + id = "expire" + status = "Enabled" + enabled = true + + expiration = { + days = 1 + } + } + ] +} diff --git a/app.py b/app.py deleted file mode 100644 index a6f19a8..0000000 --- a/app.py +++ /dev/null @@ -1,22 +0,0 @@ -import beam - -app = beam.App( - name="cutout_generator", - cpu=1, - gpu="T4", - memory="16Gi", - python_version="python3.8", - python_packages="requirements.txt", - commands=["apt-get update && apt-get install -y ffmpeg"], -) - -app.Trigger.TaskQueue( - inputs={ - "image": beam.Types.Image(raw=False), - "name": beam.Types.String(), - "prompt": beam.Types.String(), - }, - handler="run.py:generate_cutout", -) -app.Output.Dir(path="generated_images", name="images") -app.Mount.PersistentVolume(path="./models", name="models") diff --git a/app/cutout.py b/app/cutout.py new file mode 100644 index 0000000..1847352 --- /dev/null +++ b/app/cutout.py @@ -0,0 +1,38 @@ +import cv2 +import numpy as np +import os + + +class CutoutCreator: + def __init__(self, 
image_folder): + self.image_folder = image_folder + + def create_cutouts(self, image_name, masks, output_folder,bucket_name, s3): + # Load the image + image_path = os.path.join(self.image_folder, image_name) + for item in os.listdir(self.image_folder): + print("Item: ",item) + if not os.path.exists(image_path): + print(f"Image {image_name} not found in folder {self.image_folder}") + return + + image = cv2.imread(image_path) + + # Apply each mask to the image + for i, mask in enumerate(masks): + # Ensure the mask is a boolean array + mask = mask.astype(bool) + + # Apply the mask to create a cutout + cutout = np.zeros_like(image) + cutout[mask] = image[mask] + + # Save the cutout + cutout_name = f"{image_name}_cutout_{i}.png" + cutout_path = os.path.join(output_folder, cutout_name) + cv2.imwrite(cutout_path, cutout) + + # Upload the cutout to S3 + with open(cutout_path, "rb") as f: + s3.upload_to_s3(bucket_name, f.read(), f"cutouts/{image_name}/{cutout_name}") + diff --git a/app/dino.py b/app/dino.py new file mode 100644 index 0000000..b46199b --- /dev/null +++ b/app/dino.py @@ -0,0 +1,31 @@ +import torch +from groundingdino.util.inference import Model +from typing import List + + +class Dino: + def __init__(self, classes, box_threshold, text_threshold, model_config_path, model_checkpoint_path): + self.classes = classes + self.box_threshold = box_threshold + self.text_threshold = text_threshold + self.device = torch.device('cuda' if torch.cuda.is_available() else 'cpu') + self.grounding_dino_model = Model(model_config_path=model_config_path, model_checkpoint_path=model_checkpoint_path) + + def enhance_class_name(self, class_names: List[str]) -> List[str]: + return [ + f"all {class_name}s" + for class_name + in class_names + ] + + def predict(self, image): + detections = self.grounding_dino_model.predict_with_classes(image=image, classes=self.enhance_class_name(class_names=self.classes), box_threshold=self.box_threshold, text_threshold=self.text_threshold) + 
detections = detections[detections.class_id != None] + return detections + +# Example usage +# dino = Dino(classes=['person', 'nose', 'chair', 'shoe', 'ear', 'hat'], +# box_threshold=0.35, +# text_threshold=0.25, +# model_config_path=GROUNDING_DINO_CONFIG_PATH, +# model_checkpoint_path=GROUNDING_DINO_CHECKPOINT_PATH) diff --git a/app/grounded_cutouts.py b/app/grounded_cutouts.py new file mode 100644 index 0000000..79faeb9 --- /dev/null +++ b/app/grounded_cutouts.py @@ -0,0 +1,71 @@ +import os +import modal +import cv2 + +stub = modal.Stub(name="cutout_generator") + +img_volume = modal.NetworkFileSystem.persisted("image-storage-vol") +cutout_volume = modal.NetworkFileSystem.persisted("cutout-storage-vol") +local_packages = modal.Mount.from_local_python_packages("cutout", "dino", "segment", "s3_handler") +cutout_generator_image = modal.Image.from_registry("nvcr.io/nvidia/pytorch:22.12-py3").pip_install( "segment-anything", "opencv-python", "botocore", "boto3").run_commands( + "apt-get update", + "apt-get install -y git wget libgl1-mesa-glx libglib2.0-0", + "echo $CUDA_HOME", + "git clone https://github.com/IDEA-Research/GroundingDINO.git", + "pip install -q -e GroundingDINO/", + "mkdir -p /weights", + "mkdir -p /data", + "pip uninstall -y supervision", + "pip uninstall -y opencv-python", +"pip install opencv-python==4.8.0.74", + "pip install -q supervision==0.6.0", + "wget -q https://github.com/IDEA-Research/GroundingDINO/releases/download/v0.1.0-alpha/groundingdino_swint_ogc.pth -P weights/", + "wget -q https://dl.fbaipublicfiles.com/segment_anything/sam_vit_h_4b8939.pth -P weights/", + "wget -q https://media.roboflow.com/notebooks/examples/dog.jpeg -P data/", + "ls -F", + "ls -F GroundingDINO/groundingdino/config", + "ls -F GroundingDINO/groundingdino/models/GroundingDINO/" +) +# Setup paths +HOME = os.path.abspath(os.path.join(os.getcwd(), os.pardir)) +GROUNDING_DINO_CONFIG_PATH = os.path.join(HOME, "GroundingDINO/groundingdino/config/GroundingDINO_SwinT_OGC.py") 
+GROUNDING_DINO_CHECKPOINT_PATH = os.path.join(HOME, "weights", "groundingdino_swint_ogc.pth") +SAM_CHECKPOINT_PATH = os.path.join(HOME, "weights", "sam_vit_h_4b8939.pth") + +@stub.function(image=cutout_generator_image, mounts=[local_packages], gpu="T4", secret=modal.Secret.from_name("my-aws-secret"), network_file_systems={"/images": img_volume, "/cutouts": cutout_volume}) +@modal.web_endpoint() +def main(image_name): + # Import relevant classes + from dino import Dino + from segment import Segmenter + from cutout import CutoutCreator + from s3_handler import Boto3Client + SOURCE_IMAGE_PATH = os.path.join(HOME, "data", image_name) + print(f"SOURCE_IMAGE_PATH: {SOURCE_IMAGE_PATH}") + SAVE_IMAGE_PATH = os.path.join(HOME, "data") + OUTPUT_CUTOUT_PATH = os.path.join(HOME, "cutouts") + dino = Dino(classes=['person', 'nose', 'chair', 'shoe', 'ear', 'hat'], + box_threshold=0.35, + text_threshold=0.25, + model_config_path=GROUNDING_DINO_CONFIG_PATH, + model_checkpoint_path=GROUNDING_DINO_CHECKPOINT_PATH) + + segment = Segmenter(sam_encoder_version="vit_h", sam_checkpoint_path=SAM_CHECKPOINT_PATH) + + cutout = CutoutCreator(image_folder=SAVE_IMAGE_PATH) + + s3 = Boto3Client() + + s3.download_from_s3(SAVE_IMAGE_PATH, "cutout-image-store", f"images/{image_name}") + + image = cv2.imread(SOURCE_IMAGE_PATH) + + # Run the DINO model on the image + detections = dino.predict(image) + + detections.mask = segment.segment(image, detections.xyxy) + + cutout.create_cutouts(image_name, detections.mask, OUTPUT_CUTOUT_PATH, "cutout-image-store", s3) + + return "Success" + diff --git a/s3FileHandler.py b/app/s3_handler.py similarity index 80% rename from s3FileHandler.py rename to app/s3_handler.py index d0742e6..f92f97b 100644 --- a/s3FileHandler.py +++ b/app/s3_handler.py @@ -7,20 +7,20 @@ class Boto3Client: def __init__(self): self.s3 = boto3.client( "s3", - endpoint_url="https://13583f5ff84f5693a4a859a769743849.r2.cloudflarestorage.com", 
aws_access_key_id=os.environ["AWS_ACCESS_KEY_ID"], aws_secret_access_key=os.environ["AWS_SECRET_ACCESS_KEY"], - region_name="auto", + region_name=os.environ["AWS_REGION"], ) - def download_from_s3(bucket_name, key): + def download_from_s3(self, save_path, bucket_name, key): s3_client = boto3.client("s3") - file_path = os.path.join(os.getcwd(), key) + file_name = key.split("/")[-1] + file_path = os.path.join(save_path, file_name) try: s3_client.download_file(bucket_name, key, file_path) except ClientError as e: - print(e) + print("BOTO error: ",e) return None return file_path diff --git a/app/segment.py b/app/segment.py new file mode 100644 index 0000000..3d61359 --- /dev/null +++ b/app/segment.py @@ -0,0 +1,22 @@ +import numpy as np +import torch +from segment_anything import sam_model_registry, SamPredictor + + +class Segmenter: + def __init__(self, sam_encoder_version: str, sam_checkpoint_path: str, ): + self.device = torch.device('cuda' if torch.cuda.is_available() else 'cpu') + self.sam = sam_model_registry[sam_encoder_version](checkpoint=sam_checkpoint_path).to(device=self.device) + self.sam_predictor = SamPredictor(self.sam) + + def segment(self, image: np.ndarray, xyxy: np.ndarray) -> np.ndarray: + self.sam_predictor.set_image(image) + result_masks = [] + for box in xyxy: + masks, scores, logits = self.sam_predictor.predict( + box=box, + multimask_output=True + ) + index = np.argmax(scores) + result_masks.append(masks[index]) + return np.array(result_masks) \ No newline at end of file diff --git a/cutoutHandler.py b/cutoutHandler.py deleted file mode 100644 index 88e56f6..0000000 --- a/cutoutHandler.py +++ /dev/null @@ -1,177 +0,0 @@ -import logging -import os - -import cv2 -import numpy as np -from PIL import Image -from pycocotools import mask as maskUtils - - -class CutoutHandler: - # Define the logging format - log_format = "%(asctime)s [%(levelname)s] %(message)s" - - # Define the color codes for each log level - color_codes = { - "DEBUG": "\033[32m", # 
Green - "INFO": "\033[34m", # Blue - "WARNING": "\033[33m", # Yellow - "ERROR": "\033[31m", # Red - "CRITICAL": "\033[35m", # Magenta - } - - # Define a custom logging formatter that adds color to the log messages - class ColoredFormatter(logging.Formatter): - def format(self, record): - levelname = record.levelname - if levelname in CutoutHandler.color_codes: - levelname_color = ( - f"{CutoutHandler.color_codes[levelname]}{levelname}\033[0m" - ) - record.levelname = levelname_color - return super().format(record) - - def __init__(self, model, predictor): - self.model = model - self.predictor = predictor - - # Create a logger object - self.logger = logging.getLogger(__name__) - - # Set the logging level to DEBUG - self.logger.setLevel(logging.DEBUG) - - # Create a console handler and set its formatter - console_handler = logging.StreamHandler() - console_handler.setFormatter( - CutoutHandler.ColoredFormatter(CutoutHandler.log_format) - ) - - # Add the console handler to the logger - self.logger.addHandler(console_handler) - - self.cutout_folder = "generated_images" - if not os.path.exists(self.cutout_folder): - os.makedirs(self.cutout_folder) - - def process_image(self, image: Image, name: str, prompt: str) -> list: - self.logger.debug(f"Processing image '{name}' with prompt '{prompt}'") - - # Convert image to rgb color space - image = image.convert("RGB") - # Convert the image to a numpy array - image_np = np.array(image) - print("array_shape: ", image_np.shape) - # Convert the image to RGB format - image_np = cv2.cvtColor(image_np, cv2.COLOR_BGR2RGB) - - # Run the model on the image - results = self.model(image_np) - # print(results) - # Set the image for the SegmentAnything predictor - self.predictor.set_image(image_np) - - masks_list = [] - - for result in results: - boxes = result.boxes.cpu().numpy() - for i, box in enumerate(boxes): - if result.names[int(box.cls[0])] == prompt: - self.logger.debug( - f"Found bounding box for prompt '{prompt}' in image 
'{name}'" - ) - # Get the mask from the SegmentAnything predictor - masks, scores, _ = self.predictor.predict( - point_coords=None, - point_labels=None, - box=box.xyxy[0].astype(int), - multimask_output=False, - ) - if scores[0] < 0.6: - self.logger.debug( - f"Skipping mask with score {scores[0]} for prompt '{prompt}' in image '{name}'" - ) - continue - for i, mask in enumerate(masks): - self.logger.debug( - f"Processing mask {i+1} of {len(masks)} for prompt '{prompt}' in image '{name}'" - ) - # Convert the mask numpy array to a binary mask - mask_binary = np.zeros( - (mask.shape[0], mask.shape[1]), dtype=np.uint8 - ) - mask_binary[mask > 0] = 1 - - # Convert the binary mask to a COCO RLE format - mask_rle = maskUtils.encode(np.asfortranarray(mask_binary)) - - # Extract the counts key from the mask RLE dictionary - counts = mask_rle["counts"] - - # Create a new dictionary with the required keys for COCO RLE format - mask_coco_rle = { - "size": [mask.shape[0], mask.shape[1]], - "counts": counts.decode("utf-8"), - } - - # Add the mask to the list of masks - masks_list.append(mask_coco_rle) - - self.logger.debug( - f"Processed {len(masks_list)} masks for prompt '{prompt}' in image '{name}'" - ) - - # Once finished, create cutouts from the masks - return self.create_cutout(image, name, masks_list) - - def create_cutout(self, image: np.ndarray, name: str, masks_list: list) -> list: - cutouts = [] - # Convert the image to a numpy array - image_np = np.array(image) - for i, mask in enumerate(masks_list): - self.logger.debug( - f"Generating cutout {i+1} of {len(masks_list)} for image '{name}'" - ) - - size = mask["size"] - counts = mask["counts"] - mask_decoded = maskUtils.decode({"size": size, "counts": counts}) - mask_binary = np.zeros((size[0], size[1]), dtype=np.uint8) - mask_binary[mask_decoded > 0] = 1 - - # Resize the mask to match the shape of the image - mask_resized = cv2.resize(mask_decoded, (image_np.shape[1], image_np.shape[0])) - - # Extract the cutout from 
the image using the mask - cutout = image * mask_resized[..., np.newaxis] - - # Create an alpha channel for the cutout image - alpha = np.zeros(cutout.shape[:2], dtype=np.uint8) - alpha[mask_resized > 0] = 255 - cutout = cv2.merge((cutout, alpha)) - - # Crop the cutout image to the bounding rectangle - x, y, w, h = cv2.boundingRect(mask_resized) - cutout = cutout[y : y + h, x : x + w] - - # Create a PIL Image from the cutout numpy array - cutout_pil = Image.fromarray(cutout) - - # Save the cutout to a file - cutout_filename = f"{name}_{i+1}.png" - cutout_path = os.path.join(self.cutout_folder, cutout_filename) - cutout_pil.save(cutout_path) - - self.logger.debug( - f"Saved cutout {i+1} of {len(masks_list)} to file '{cutout_path}' for image '{name}'" - ) - - # Add the cutout to the list of cutouts - cutouts.append(cutout_path) - self.logger.debug( - f"Added cutout {i+1} of {len(masks_list)} to the list of cutouts for image '{name}'" - ) - - self.logger.info(f"Generated {len(cutouts)} cutouts for image '{name}'") - - return cutouts diff --git a/get_started.py b/get_started.py deleted file mode 100644 index c897b77..0000000 --- a/get_started.py +++ /dev/null @@ -1,14 +0,0 @@ -import modal - -stub = modal.Stub("example-get-started") - - -@stub.function() -def square(x): - print("This code is running on a remote worker!") - return x**2 - - -@stub.local_entrypoint() -def main(): - print("the square is", square.call(42)) diff --git a/requirements.txt b/requirements.txt deleted file mode 100644 index 1d1d97b..0000000 --- a/requirements.txt +++ /dev/null @@ -1,110 +0,0 @@ -asttokens==2.2.1 -backcall==0.2.0 -boto3==1.26.165 -botocore==1.29.165 -cairocffi==1.6.0 -CairoSVG==2.7.0 -certifi==2023.5.7 -cffi==1.15.1 -charset-normalizer==3.1.0 -cmake==3.26.4 -comm==0.1.3 -contourpy==1.1.0 -croniter==1.4.1 -cssselect2==0.7.0 -cycler==0.11.0 -debugpy==1.6.7 -decorator==5.1.1 -defusedxml==0.7.1 -executing==1.2.0 -filelock==3.12.2 -fonttools==4.40.0 -gitdb==4.0.10 -GitPython==3.1.31 
-idna==3.4 -imageai==3.0.3 -jedi==0.18.2 -Jinja2==3.1.2 -jmespath==1.0.1 -jupyter_client==8.2.0 -jupyter_core==5.3.1 -kiwisolver==1.4.4 -lit==16.0.6 -markdown-it-py==3.0.0 -MarkupSafe==2.1.3 -marshmallow==3.18.0 -marshmallow-dataclass==8.5.14 -matplotlib==3.7.1 -matplotlib-inline==0.1.6 -mdurl==0.1.2 -meshio==5.3.4 -mpmath==1.3.0 -mypy==0.981 -mypy-extensions==1.0.0 -nest-asyncio==1.5.6 -networkx==3.1 -numpy==1.24.3 -nvidia-cublas-cu11==11.10.3.66 -nvidia-cuda-cupti-cu11==11.7.101 -nvidia-cuda-nvrtc-cu11==11.7.99 -nvidia-cuda-runtime-cu11==11.7.99 -nvidia-cudnn-cu11==8.5.0.96 -nvidia-cufft-cu11==10.9.0.58 -nvidia-curand-cu11==10.2.10.91 -nvidia-cusolver-cu11==11.4.0.1 -nvidia-cusparse-cu11==11.7.4.91 -nvidia-nccl-cu11==2.14.3 -nvidia-nvtx-cu11==11.7.91 -opencv-python==4.7.0.72 -packaging==23.1 -pandas==2.0.2 -parso==0.8.3 -pexpect==4.8.0 -pickleshare==0.7.5 -Pillow==9.5.0 -platformdirs==3.5.3 -pooch==1.7.0 -prompt-toolkit==3.0.38 -psutil==5.9.5 -ptyprocess==0.7.0 -pure-eval==0.2.2 -pycocotools==2.0.6 -pycparser==2.21 -Pygments==2.15.1 -pyparsing==3.0.9 -python-dateutil==2.8.2 -pytz==2023.3 -pyvista==0.39.1 -PyYAML==6.0 -pyzmq==25.1.0 -requests==2.31.0 -rich==13.4.2 -s3transfer==0.6.1 -scooby==0.7.2 -seaborn==0.12.2 -segment-anything @ git+https://github.com/facebookresearch/segment-anything.git@6fdee8f2727f4506cfbbe553e23b895e27956588 -six==1.16.0 -smmap==5.0.0 -stack-data==0.6.2 -svgwrite==1.4.3 -sympy==1.12 -tinycss2==1.2.1 -tomli==2.0.1 -torch==2.0.1 -torchaudio==2.0.2 -torchvision==0.15.2 -tornado==6.3.3 -tqdm==4.65.0 -traitlets==5.9.0 -trimesh==3.22.1 -triton==2.0.0 -typeguard==2.13.3 -typing-inspect==0.9.0 -typing_extensions==4.6.3 -tzdata==2023.3 -ultralytics==8.0.124 -urllib3==1.26.16 -validators==0.20.0 -vtk==9.2.6 -wcwidth==0.2.6 -webencodings==0.5.1 diff --git a/run.py b/run.py deleted file mode 100644 index bb956dc..0000000 --- a/run.py +++ /dev/null @@ -1,216 +0,0 @@ -import os -import json -import numpy as np -from pycocotools import mask as maskUtils 
-from segment_anything import sam_model_registry, SamPredictor -from cutoutHandler import CutoutHandler -from ultralytics import YOLO -from io import BytesIO -import base64 -import cv2 -from PIL import Image -from s3FileHandler import Boto3Client -import logging - -# Define the logging format -log_format = "%(asctime)s [%(levelname)s] %(message)s" -logging.basicConfig(format=log_format, level=logging.DEBUG) - -# Define the color codes for each log level -color_codes = { - "DEBUG": "\033[32m", # Green - "INFO": "\033[34m", # Blue - "WARNING": "\033[33m", # Yellow - "ERROR": "\033[31m", # Red - "CRITICAL": "\033[35m", # Magenta -} - - -# Define a custom logging formatter that adds color to the log messages -class ColoredFormatter(logging.Formatter): - def format(self, record): - levelname = record.levelname - if levelname in color_codes: - levelname_color = f"{color_codes[levelname]}{levelname}\033[0m" - record.levelname = levelname_color - return super().format(record) - - -# Create a logger object -logger = logging.getLogger(__name__) - -# Set the logging level to DEBUG -logger.setLevel(logging.DEBUG) - -# Create a console handler and set its formatter -console_handler = logging.StreamHandler() -console_handler.setFormatter(ColoredFormatter(log_format)) - -# Add the console handler to the logger -logger.addHandler(console_handler) - -client = Boto3Client() - -# LOAD SEGMENT ANYTHING MODEL -sam_checkpoint = "./models/sam_vit_h_4b8939.pth" -model_type = "vit_h" - -device = "cuda" - -sam = sam_model_registry[model_type](checkpoint=sam_checkpoint) -sam.to(device=device) - -predictor = SamPredictor(sam) - -# LOAD YOLO MODEL -model = YOLO("./models/yolov8x.pt") - -handler = CutoutHandler(model, predictor) - - -def generate_cutout(**inputs) -> None: - prompt = inputs["prompt"] - image = inputs["image"] - name = inputs["name"] - - logger.info( - f"Generating cutout for prompt '{prompt}' and image '{image}' with name '{name}'" - ) - - - # Process the image and append the 
cutouts to the list - # cutouts = process_image(image, name, prompt, model, predictor) - cutouts = handler.process_image(image, name, prompt) - - logger.info( - f"Generated {len(cutouts)} cutouts for prompt '{prompt}' and image '{image}' with name '{name}'" - ) - - -def process_image(image: Image, name: str, prompt: str, model, predictor) -> list: - logger.debug(f"Processing image '{name}' with prompt '{prompt}'") - - # Convert image to rgb color space - image = image.convert("RGB") - # Convert the image to a numpy array - image_np = np.array(image) - - # Convert the image to RGB format - image_np = cv2.cvtColor(image_np, cv2.COLOR_BGR2RGB) - - # Run the model on the image - results = model(image_np) - - # Set the image for the SegmentAnything predictor - predictor.set_image(image_np) - - masks_list = [] - - for result in results: - boxes = result.boxes.cpu().numpy() - for i, box in enumerate(boxes): - if result.names[int(box.cls[0])] == prompt: - logger.debug( - f"Found bounding box for prompt '{prompt}' in image '{name}'" - ) - # Get the mask from the SegmentAnything predictor - masks, scores, _ = predictor.predict( - point_coords=None, - point_labels=None, - box=box.xyxy[0].astype(int), - multimask_output=False, - ) - if scores[0] < 0.6: - logger.debug( - f"Skipping mask with score {scores[0]} for prompt '{prompt}' in image '{name}'" - ) - continue - for i, mask in enumerate(masks): - logger.debug( - f"Processing mask {i+1} of {len(masks)} for prompt '{prompt}' in image '{name}'" - ) - # Convert the mask numpy array to a binary mask - mask_binary = np.zeros( - (mask.shape[0], mask.shape[1]), dtype=np.uint8 - ) - mask_binary[mask > 0] = 1 - - # Convert the binary mask to a COCO RLE format - mask_rle = maskUtils.encode(np.asfortranarray(mask_binary)) - - # Extract the counts key from the mask RLE dictionary - counts = mask_rle["counts"] - - # Create a new dictionary with the required keys for COCO RLE format - mask_coco_rle = { - "size": [mask.shape[0], 
mask.shape[1]], - "counts": counts.decode("utf-8"), - } - - # Add the mask to the list of masks - masks_list.append(mask_coco_rle) - - logger.debug( - f"Processed {len(masks_list)} masks for prompt '{prompt}' in image '{name}'" - ) - - # Once finished, create cutouts from the masks - return create_cutout(image, name, masks_list) - - -def create_cutout(image: Image, name: str, masks_list: list) -> list: - cutouts = [] - - for i, mask in enumerate(masks_list): - logger.debug(f"Generating cutout {i+1} of {len(masks_list)} for image '{name}'") - - - size = mask["size"] - counts = mask["counts"] - mask_decoded = maskUtils.decode({"size": size, "counts": counts}) - mask_binary = np.zeros((size[0], size[1]), dtype=np.uint8) - mask_binary[mask_decoded > 0] = 1 - - # Resize the mask to match the shape of the image - mask_resized = cv2.resize(mask_decoded, (image_np.shape[1], image_np.shape[0])) - - # Extract the cutout from the image using the mask - cutout = image_np * mask_resized[..., np.newaxis] - - # Create an alpha channel for the cutout image - alpha = np.zeros(cutout.shape[:2], dtype=np.uint8) - alpha[mask_resized > 0] = 255 - cutout = cv2.merge((cutout, alpha)) - - # Crop the cutout image to the bounding rectangle - x, y, w, h = cv2.boundingRect(mask_resized) - cutout = cutout[y : y + h, x : x + w] - - # Create a PIL Image from the cutout numpy array - cutout_pil = Image.fromarray(cutout) - - # Save the cutout to a file - cutout_filename = f"{name}_{i+1}.png" - cutout_path = os.path.join("cutouts", cutout_filename) - cutout_pil.save(cutout_path) - - logger.debug( - f"Saved cutout {i+1} of {len(masks_list)} to file '{cutout_path}' for image '{name}'" - ) - - # Add the cutout to the list of cutouts - cutouts.append(cutout_path) - - return cutouts - - -def list_directories(path): - for root, dirs, files in os.walk(path): - for dir in dirs: - print(os.path.join(root, dir)) - - -def list_files(directory): - for filename in os.listdir(directory): - if 
os.path.isfile(os.path.join(directory, filename)): - print(filename)