Merge pull request #7 from OriginalByteMe/develop
Showing 15 changed files with 285 additions and 554 deletions.
@@ -14,4 +14,6 @@ yolo*
__pycache__/
*.jpg
*.png
cutout_generator-*
Terraform/.terraform
@@ -0,0 +1,76 @@
# AI Image Cutout Maker

AI Image Cutout Maker is a project that uses artificial intelligence to automatically create cutouts from images. This project is designed to simplify the process of creating cutouts, which can be a time-consuming task when done manually.

This project utilizes the Segment Anything and Grounding Dino AI models to detect subjects in an image and cut them out. These models are hosted on Modal, which allows us to leverage GPU acceleration for faster and more efficient processing.

The cutouts are then stored in an Amazon S3 bucket, providing a scalable and secure storage solution. This setup allows us to handle large volumes of images and serve them quickly and efficiently.

## Project Structure

The project is structured as follows:

- `app/`: This directory contains the main application code.
  - `cutout.py`: This script handles the process of creating cutouts from images using the Segment Anything and Grounding Dino AI models.
  - `dino.py`: This script is responsible for interacting with the Grounding Dino AI model.
  - `segment.py`: This script is used for interacting with the Segment Anything AI model.
  - `s3_handler.py`: This script handles interactions with Amazon S3, such as uploading and downloading images.
  - `grounded_cutouts.py`: This script ...
- `.venv/`: This directory contains the virtual environment for the project.
- `modal_utils/`: This directory contains utility functions used throughout the project.
  - `grpc_utils.py`: This script handles the gRPC connections in the project.

## Purpose of the Project

The purpose of this project is to automate the process of creating cutouts from images. By using artificial intelligence, we can create accurate cutouts much faster than would be possible manually. This can be useful in a variety of applications, such as graphic design and image editing. There is also a social media angle: creating stickers out of items cut out of an image is a popular trend on platforms such as Instagram and TikTok.

## How to Use

As mentioned above, this project is hosted through Modal, which means it uses the Modal API. You will need a Modal account and the Modal CLI installed; you can find instructions on how to do this [here](https://docs.modal.ai/docs/getting-started). I'm also planning to dockerize this project down the line so that it can be used without the Modal CLI.
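Once deployed, the web endpoint is called over HTTP with the image name as a query parameter. A minimal sketch of building that request URL (the endpoint host below is a placeholder, not the real deployment; the actual URL is printed by `modal deploy` and differs per account):

```python
from urllib.parse import urlencode

# Placeholder endpoint host -- substitute the URL printed by `modal deploy`
ENDPOINT = "https://example--cutout-generator-main.modal.run"

def build_request_url(image_name: str) -> str:
    """Build the GET URL that triggers cutout generation for an uploaded image."""
    return f"{ENDPOINT}?{urlencode({'image_name': image_name})}"

url = build_request_url("dog.jpeg")
```

A plain `GET` on that URL then runs the whole pipeline and returns once the cutouts are uploaded.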
## Workflow diagram for how the backend (BE) processes and returns values
```mermaid
sequenceDiagram
actor User
User ->> CutoutFE: Upload Image
CutoutFE ->> S3: Call to push image to bucket
CutoutFE ->> CutoutBE: Call to create cutout
CutoutBE ->>+ Modal: Spin up instance with container
Modal ->>- CutoutBE: Allow that container to be used by CutoutBE URL
CutoutBE ->>+ S3: Download image
S3 -->>- CutoutBE: Send Image back
Note over CutoutBE: Processes cutouts + other diagrams
CutoutBE -->> S3: Upload cutouts to bucket
Note over CutoutBE: Generates pre-signed URLs of processed files
Note over CutoutFE: If need be, the FE can also create a list of pre-signed URLs via S3
CutoutBE -->>+ CutoutFE: Return list of pre-signed URLs
CutoutFE -->>+ User: Display processed images
User ->> S3: Download images via URL
```
## To-Do

Here are some tasks that are on our roadmap:

- Dockerize the project: We plan to create a Dockerfile for this project to make it easier to set up and run in any environment.
- API documentation: We will be writing comprehensive API documentation to make it easier for developers to understand and use our API.
- Improve error handling: We aim to improve our error handling to make our API more robust and reliable.
- Add more AI models: We are planning to integrate more AI models to improve the accuracy and versatility of our image cutout creation.
- Optimize performance: We will be working on optimizing the performance of our API, particularly in terms of processing speed and resource usage.

Please note that this is not an exhaustive list, and the roadmap may change based on project needs and priorities.

## Contributing

This is a personal project, so it isn't really geared toward contributions, but feel free to fork the repository and make any changes you want. If you have any questions, feel free to reach out to me at my [email](mailto:[email protected]).

## License

This project is licensed under the terms of the [MIT License](LICENSE).
@@ -0,0 +1,18 @@
terraform {
  required_providers {
    aws = {
      source = "hashicorp/aws"
    }
  }
  backend "s3" {
    bucket  = "noah-terraform-remote-state"
    key     = "modal/cutout-gen-config"
    region  = "ap-southeast-1"
    profile = "noahTest"
  }
}

provider "aws" {
  region  = "ap-southeast-1"
  profile = "noahTest"
}
@@ -0,0 +1,21 @@
module "s3_bucket" {
  source = "terraform-aws-modules/s3-bucket/aws"

  bucket = "cutout-image-store"
  acl    = "private"

  control_object_ownership = true
  object_ownership         = "ObjectWriter"

  lifecycle_rule = [
    {
      id      = "expire"
      status  = "Enabled"
      enabled = true

      expiration = {
        days = 1
      }
    }
  ]
}
@@ -0,0 +1,38 @@
import os

import cv2
import numpy as np


class CutoutCreator:
    def __init__(self, image_folder):
        self.image_folder = image_folder

    def create_cutouts(self, image_name, masks, output_folder, bucket_name, s3):
        # Load the image
        image_path = os.path.join(self.image_folder, image_name)
        if not os.path.exists(image_path):
            print(f"Image {image_name} not found in folder {self.image_folder}")
            return

        image = cv2.imread(image_path)

        # Make sure the output folder exists before writing cutouts into it
        os.makedirs(output_folder, exist_ok=True)

        # Apply each mask to the image
        for i, mask in enumerate(masks):
            # Ensure the mask is a boolean array
            mask = mask.astype(bool)

            # Apply the mask to create a cutout
            cutout = np.zeros_like(image)
            cutout[mask] = image[mask]

            # Save the cutout
            cutout_name = f"{image_name}_cutout_{i}.png"
            cutout_path = os.path.join(output_folder, cutout_name)
            cv2.imwrite(cutout_path, cutout)

            # Upload the cutout to S3
            with open(cutout_path, "rb") as f:
                s3.upload_to_s3(bucket_name, f.read(), f"cutouts/{image_name}/{cutout_name}")
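To make the masking step concrete, here is a tiny self-contained illustration (toy data, not repo code) of how a boolean mask carves a cutout out of an image array, using the same two lines as `create_cutouts`:

```python
import numpy as np

# Toy 4x4 single-channel "image"; real inputs are HxWx3 uint8 arrays
image = np.arange(16, dtype=np.uint8).reshape(4, 4)

# Boolean mask keeping only the central 2x2 region
mask = np.zeros((4, 4), dtype=bool)
mask[1:3, 1:3] = True

# Same pattern as create_cutouts: black canvas, copy masked pixels over
cutout = np.zeros_like(image)
cutout[mask] = image[mask]
```

Every pixel outside the mask stays zero (black), so the saved PNG shows only the detected subject.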
@@ -0,0 +1,31 @@
from typing import List

import torch
from groundingdino.util.inference import Model


class Dino:
    def __init__(self, classes, box_threshold, text_threshold, model_config_path, model_checkpoint_path):
        self.classes = classes
        self.box_threshold = box_threshold
        self.text_threshold = text_threshold
        self.device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
        self.grounding_dino_model = Model(model_config_path=model_config_path, model_checkpoint_path=model_checkpoint_path)

    def enhance_class_name(self, class_names: List[str]) -> List[str]:
        # Pluralized prompts ("all persons") tend to ground better than bare nouns
        return [f"all {class_name}s" for class_name in class_names]

    def predict(self, image):
        detections = self.grounding_dino_model.predict_with_classes(
            image=image,
            classes=self.enhance_class_name(class_names=self.classes),
            box_threshold=self.box_threshold,
            text_threshold=self.text_threshold,
        )
        # Drop detections that could not be matched to any prompted class
        detections = detections[detections.class_id != None]  # noqa: E711
        return detections


# Example usage
# dino = Dino(classes=['person', 'nose', 'chair', 'shoe', 'ear', 'hat'],
#             box_threshold=0.35,
#             text_threshold=0.25,
#             model_config_path=GROUNDING_DINO_CONFIG_PATH,
#             model_checkpoint_path=GROUNDING_DINO_CHECKPOINT_PATH)
@@ -0,0 +1,71 @@
import os

import cv2
import modal

stub = modal.Stub(name="cutout_generator")

img_volume = modal.NetworkFileSystem.persisted("image-storage-vol")
cutout_volume = modal.NetworkFileSystem.persisted("cutout-storage-vol")
local_packages = modal.Mount.from_local_python_packages("cutout", "dino", "segment", "s3_handler")
cutout_generator_image = modal.Image.from_registry("nvcr.io/nvidia/pytorch:22.12-py3").pip_install(
    "segment-anything", "opencv-python", "botocore", "boto3"
).run_commands(
    "apt-get update",
    "apt-get install -y git wget libgl1-mesa-glx libglib2.0-0",
    "echo $CUDA_HOME",
    "git clone https://github.com/IDEA-Research/GroundingDINO.git",
    "pip install -q -e GroundingDINO/",
    "mkdir -p /weights",
    "mkdir -p /data",
    "pip uninstall -y supervision",
    "pip uninstall -y opencv-python",
    "pip install opencv-python==4.8.0.74",
    "pip install -q supervision==0.6.0",
    "wget -q https://github.com/IDEA-Research/GroundingDINO/releases/download/v0.1.0-alpha/groundingdino_swint_ogc.pth -P weights/",
    "wget -q https://dl.fbaipublicfiles.com/segment_anything/sam_vit_h_4b8939.pth -P weights/",
    "wget -q https://media.roboflow.com/notebooks/examples/dog.jpeg -P data/",
    "ls -F",
    "ls -F GroundingDINO/groundingdino/config",
    "ls -F GroundingDINO/groundingdino/models/GroundingDINO/",
)

# Setup paths
HOME = os.path.abspath(os.path.join(os.getcwd(), os.pardir))
GROUNDING_DINO_CONFIG_PATH = os.path.join(HOME, "GroundingDINO/groundingdino/config/GroundingDINO_SwinT_OGC.py")
GROUNDING_DINO_CHECKPOINT_PATH = os.path.join(HOME, "weights", "groundingdino_swint_ogc.pth")
SAM_CHECKPOINT_PATH = os.path.join(HOME, "weights", "sam_vit_h_4b8939.pth")


@stub.function(
    image=cutout_generator_image,
    mounts=[local_packages],
    gpu="T4",
    secret=modal.Secret.from_name("my-aws-secret"),
    network_file_systems={"/images": img_volume, "/cutouts": cutout_volume},
)
@modal.web_endpoint()
def main(image_name):
    # Import relevant classes (resolvable inside the container image)
    from cutout import CutoutCreator
    from dino import Dino
    from s3_handler import Boto3Client
    from segment import Segmenter

    SOURCE_IMAGE_PATH = os.path.join(HOME, "data", image_name)
    print(f"SOURCE_IMAGE_PATH: {SOURCE_IMAGE_PATH}")
    SAVE_IMAGE_PATH = os.path.join(HOME, "data")
    OUTPUT_CUTOUT_PATH = os.path.join(HOME, "cutouts")

    dino = Dino(
        classes=['person', 'nose', 'chair', 'shoe', 'ear', 'hat'],
        box_threshold=0.35,
        text_threshold=0.25,
        model_config_path=GROUNDING_DINO_CONFIG_PATH,
        model_checkpoint_path=GROUNDING_DINO_CHECKPOINT_PATH,
    )
    segment = Segmenter(sam_encoder_version="vit_h", sam_checkpoint_path=SAM_CHECKPOINT_PATH)
    cutout = CutoutCreator(image_folder=SAVE_IMAGE_PATH)
    s3 = Boto3Client()

    s3.download_from_s3(SAVE_IMAGE_PATH, "cutout-image-store", f"images/{image_name}")

    image = cv2.imread(SOURCE_IMAGE_PATH)

    # Run the DINO model on the image, then segment each detected box
    detections = dino.predict(image)
    detections.mask = segment.segment(image, detections.xyxy)

    cutout.create_cutouts(image_name, detections.mask, OUTPUT_CUTOUT_PATH, "cutout-image-store", s3)

    return "Success"
@@ -0,0 +1,22 @@
import numpy as np
import torch
from segment_anything import sam_model_registry, SamPredictor


class Segmenter:
    def __init__(self, sam_encoder_version: str, sam_checkpoint_path: str):
        self.device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
        self.sam = sam_model_registry[sam_encoder_version](checkpoint=sam_checkpoint_path).to(device=self.device)
        self.sam_predictor = SamPredictor(self.sam)

    def segment(self, image: np.ndarray, xyxy: np.ndarray) -> np.ndarray:
        self.sam_predictor.set_image(image)
        result_masks = []
        for box in xyxy:
            masks, scores, logits = self.sam_predictor.predict(
                box=box,
                multimask_output=True,
            )
            # Keep only the highest-scoring of SAM's candidate masks per box
            index = np.argmax(scores)
            result_masks.append(masks[index])
        return np.array(result_masks)
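The `np.argmax(scores)` line above picks, for each box, the best of the candidate masks SAM returns under `multimask_output=True`. A minimal illustration with made-up masks and scores:

```python
import numpy as np

# Three toy 2x2 candidate masks for one box, with SAM-style quality scores
masks = np.array([
    [[True, False], [False, False]],
    [[True, True],  [True, False]],
    [[False, False], [False, True]],
])
scores = np.array([0.30, 0.95, 0.55])

index = np.argmax(scores)  # index of the highest-confidence candidate
best_mask = masks[index]
```

Collecting `best_mask` per box and stacking them reproduces the `(num_boxes, H, W)` array that `segment` returns.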