Merge pull request #7 from OriginalByteMe/develop
Develop
OriginalByteMe authored Nov 15, 2023
2 parents 803dcc4 + 3fe3daa commit 0599417
Showing 15 changed files with 285 additions and 554 deletions.
9 changes: 0 additions & 9 deletions .beamignore

This file was deleted.

4 changes: 3 additions & 1 deletion .gitignore
@@ -14,4 +14,6 @@ yolo*
__pycache__/
*.jpg
*.png
cutout_generator-*
cutout_generator-*
Terraform/.terraform

76 changes: 76 additions & 0 deletions README.md
@@ -0,0 +1,76 @@
# AI Image Cutout Maker

AI Image Cutout Maker is a project that uses artificial intelligence to automatically create cutouts from images, simplifying what is otherwise a time-consuming manual task.

This project uses the Segment Anything and Grounding DINO models to detect subjects in an image and cut them out. These models are hosted on Modal, which allows us to leverage GPU acceleration for faster and more efficient processing.

The cutouts are then stored in an Amazon S3 bucket, providing a scalable and secure storage solution. This setup allows us to handle large volumes of images and serve them quickly and efficiently.

## Project Structure

The project is structured as follows:

- `app/`: This directory contains the main application code.
  - `cutout.py`: This script handles the process of creating cutouts from images using the Segment Anything and Grounding DINO models.
  - `dino.py`: This script is responsible for interacting with the Grounding DINO model.
  - `segment.py`: This script is used for interacting with the Segment Anything model.
  - `s3_handler.py`: This script handles interactions with Amazon S3, such as uploading and downloading images.
  - `grounded_cutouts.py`: This script ...
- `.venv/`: This directory contains the virtual environment for the project.
- `modal_utils/`: This directory contains utility functions used throughout the project.
  - `grpc_utils.py`: This script handles the gRPC connections in the project.


## Purpose of the Project

The purpose of this project is to automate the creation of cutouts from images. By using artificial intelligence, we can produce accurate cutouts far faster than would be possible manually, which is useful in graphic design, image editing, and more. There is also a social media angle: turning the items cut out of an image into stickers is a popular trend on platforms such as Instagram and TikTok.

## How to Use

As mentioned above, this project is hosted on Modal, which means it uses the Modal API. You will need a Modal account and the Modal CLI installed; you can find instructions on how to set this up [here](https://docs.modal.ai/docs/getting-started). I'm also planning to dockerize this project down the line so that it can be used without the Modal CLI.
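Once the app is deployed, the web endpoint can be called over HTTP. Below is a minimal sketch of such a call; the endpoint URL is a placeholder (the real one is printed by `modal deploy`), and it assumes `image_name` is accepted as a query parameter and that the image has already been uploaded to the S3 bucket.

```python
import requests

# Placeholder URL -- the real one is printed when the Modal app is deployed
ENDPOINT = "https://<your-workspace>--cutout-generator-main.modal.run"

# Ask the backend to generate cutouts for an image that already exists in the bucket
response = requests.get(ENDPOINT, params={"image_name": "dog.jpeg"}, timeout=300)
print(response.status_code, response.text)
```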


## Workflow diagram for how the BE processes and returns values
```mermaid
sequenceDiagram
    actor User
    User ->> CutoutFE: Upload image
    CutoutFE ->> S3: Call to push image to bucket
    CutoutFE ->> CutoutBE: Call to create cutout
    CutoutBE ->>+ Modal: Spin up instance with container
    Modal ->>- CutoutBE: Expose that container to CutoutBE via a URL
    CutoutBE ->>+ S3: Download image
    S3 -->>- CutoutBE: Send image back
    Note over CutoutBE: Processes cutouts + other outputs
    CutoutBE -->> S3: Upload cutouts to bucket
    Note over CutoutBE: Generates pre-signed URLs for the processed files
    Note over CutoutFE: If need be, the FE can also create a list of pre-signed URLs via S3
    CutoutBE -->>+ CutoutFE: Return list of pre-signed URLs
    CutoutFE -->>+ User: Display processed images
    User ->> S3: Download images via URL
```
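For context, the pre-signed URL step could look roughly like the sketch below. The bucket name and key layout mirror the backend code (`cutout-image-store`, `cutouts/<image_name>/<cutout_name>`), while the file names and count are purely illustrative.

```python
import boto3

s3 = boto3.client("s3")

def presign(bucket: str, key: str, expires_in: int = 3600) -> str:
    """Return a temporary download URL for a single processed cutout."""
    return s3.generate_presigned_url(
        "get_object",
        Params={"Bucket": bucket, "Key": key},
        ExpiresIn=expires_in,
    )

# e.g. hand the FE one URL per cutout generated from an uploaded image
urls = [
    presign("cutout-image-store", f"cutouts/dog.jpeg/dog.jpeg_cutout_{i}.png")
    for i in range(3)
]
```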


## To-Do

Here are some tasks that are on our roadmap:

- Dockerize the project: We plan to create a Dockerfile for this project to make it easier to set up and run in any environment.
- API Documentation: We will be writing comprehensive API documentation to make it easier for developers to understand and use our API.
- Improve error handling: We aim to improve our error handling to make our API more robust and reliable.
- Add more AI models: We are planning to integrate more AI models to improve the accuracy and versatility of our image cutout creation.
- Optimize performance: We will be working on optimizing the performance of our API, particularly in terms of processing speed and resource usage.

Please note that this is not an exhaustive list and the roadmap may change based on project needs and priorities.

## Contributing

This is a personal project, so it isn't really set up for outside contributions, but feel free to fork the repository and make any changes you want. If you have any questions, feel free to reach out to me at my [email](mailto:[email protected]).

## License

This project is licensed under the terms of the [MIT License](LICENSE).

18 changes: 18 additions & 0 deletions Terraform/backend.tf
@@ -0,0 +1,18 @@
terraform {
  required_providers {
    aws = {
      source = "hashicorp/aws"
    }
  }
  backend "s3" {
    bucket  = "noah-terraform-remote-state"
    key     = "modal/cutout-gen-config"
    region  = "ap-southeast-1"
    profile = "noahTest"
  }
}

provider "aws" {
  region  = "ap-southeast-1"
  profile = "noahTest"
}
21 changes: 21 additions & 0 deletions Terraform/s3.tf
@@ -0,0 +1,21 @@
module "s3_bucket" {
source = "terraform-aws-modules/s3-bucket/aws"

bucket = "cutout-image-store"
acl = "private"

control_object_ownership = true
object_ownership = "ObjectWriter"

lifecycle_rule = [
{
id = "expire"
status = "Enabled"
enabled = true

expiration = {
days = 1
}
}
]
}
22 changes: 0 additions & 22 deletions app.py

This file was deleted.

38 changes: 38 additions & 0 deletions app/cutout.py
@@ -0,0 +1,38 @@
import cv2
import numpy as np
import os


class CutoutCreator:
    def __init__(self, image_folder):
        self.image_folder = image_folder

    def create_cutouts(self, image_name, masks, output_folder, bucket_name, s3):
        # Load the image
        image_path = os.path.join(self.image_folder, image_name)
        for item in os.listdir(self.image_folder):
            print("Item: ", item)
        if not os.path.exists(image_path):
            print(f"Image {image_name} not found in folder {self.image_folder}")
            return

        image = cv2.imread(image_path)

        # Make sure the output folder exists before writing cutouts (defensive; it may already be a mounted volume)
        os.makedirs(output_folder, exist_ok=True)

        # Apply each mask to the image
        for i, mask in enumerate(masks):
            # Ensure the mask is a boolean array
            mask = mask.astype(bool)

            # Apply the mask to create a cutout
            cutout = np.zeros_like(image)
            cutout[mask] = image[mask]

            # Save the cutout
            cutout_name = f"{image_name}_cutout_{i}.png"
            cutout_path = os.path.join(output_folder, cutout_name)
            cv2.imwrite(cutout_path, cutout)

            # Upload the cutout to S3
            with open(cutout_path, "rb") as f:
                s3.upload_to_s3(bucket_name, f.read(), f"cutouts/{image_name}/{cutout_name}")

31 changes: 31 additions & 0 deletions app/dino.py
@@ -0,0 +1,31 @@
import torch
from groundingdino.util.inference import Model
from typing import List


class Dino:
    def __init__(self, classes, box_threshold, text_threshold, model_config_path, model_checkpoint_path):
        self.classes = classes
        self.box_threshold = box_threshold
        self.text_threshold = text_threshold
        self.device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
        self.grounding_dino_model = Model(
            model_config_path=model_config_path,
            model_checkpoint_path=model_checkpoint_path,
        )

    def enhance_class_name(self, class_names: List[str]) -> List[str]:
        return [f"all {class_name}s" for class_name in class_names]

    def predict(self, image):
        detections = self.grounding_dino_model.predict_with_classes(
            image=image,
            classes=self.enhance_class_name(class_names=self.classes),
            box_threshold=self.box_threshold,
            text_threshold=self.text_threshold,
        )
        # Drop detections whose phrase could not be matched to a known class
        detections = detections[detections.class_id != None]
        return detections

# Example usage
# dino = Dino(classes=['person', 'nose', 'chair', 'shoe', 'ear', 'hat'],
# box_threshold=0.35,
# text_threshold=0.25,
# model_config_path=GROUNDING_DINO_CONFIG_PATH,
# model_checkpoint_path=GROUNDING_DINO_CHECKPOINT_PATH)
71 changes: 71 additions & 0 deletions app/grounded_cutouts.py
@@ -0,0 +1,71 @@
import os
import modal
import cv2

stub = modal.Stub(name="cutout_generator")

img_volume = modal.NetworkFileSystem.persisted("image-storage-vol")
cutout_volume = modal.NetworkFileSystem.persisted("cutout-storage-vol")
local_packages = modal.Mount.from_local_python_packages("cutout", "dino", "segment", "s3_handler")
cutout_generator_image = (
    modal.Image.from_registry("nvcr.io/nvidia/pytorch:22.12-py3")
    .pip_install("segment-anything", "opencv-python", "botocore", "boto3")
    .run_commands(
        "apt-get update",
        "apt-get install -y git wget libgl1-mesa-glx libglib2.0-0",
        "echo $CUDA_HOME",
        "git clone https://github.com/IDEA-Research/GroundingDINO.git",
        "pip install -q -e GroundingDINO/",
        "mkdir -p /weights",
        "mkdir -p /data",
        "pip uninstall -y supervision",
        "pip uninstall -y opencv-python",
        "pip install opencv-python==4.8.0.74",
        "pip install -q supervision==0.6.0",
        "wget -q https://github.com/IDEA-Research/GroundingDINO/releases/download/v0.1.0-alpha/groundingdino_swint_ogc.pth -P weights/",
        "wget -q https://dl.fbaipublicfiles.com/segment_anything/sam_vit_h_4b8939.pth -P weights/",
        "wget -q https://media.roboflow.com/notebooks/examples/dog.jpeg -P data/",
        "ls -F",
        "ls -F GroundingDINO/groundingdino/config",
        "ls -F GroundingDINO/groundingdino/models/GroundingDINO/",
    )
)
# Setup paths
HOME = os.path.abspath(os.path.join(os.getcwd(), os.pardir))
GROUNDING_DINO_CONFIG_PATH = os.path.join(HOME, "GroundingDINO/groundingdino/config/GroundingDINO_SwinT_OGC.py")
GROUNDING_DINO_CHECKPOINT_PATH = os.path.join(HOME, "weights", "groundingdino_swint_ogc.pth")
SAM_CHECKPOINT_PATH = os.path.join(HOME, "weights", "sam_vit_h_4b8939.pth")

@stub.function(
    image=cutout_generator_image,
    mounts=[local_packages],
    gpu="T4",
    secret=modal.Secret.from_name("my-aws-secret"),
    network_file_systems={"/images": img_volume, "/cutouts": cutout_volume},
)
@modal.web_endpoint()
def main(image_name):
    # Import relevant classes
    from dino import Dino
    from segment import Segmenter
    from cutout import CutoutCreator
    from s3_handler import Boto3Client

    SOURCE_IMAGE_PATH = os.path.join(HOME, "data", image_name)
    print(f"SOURCE_IMAGE_PATH: {SOURCE_IMAGE_PATH}")
    SAVE_IMAGE_PATH = os.path.join(HOME, "data")
    OUTPUT_CUTOUT_PATH = os.path.join(HOME, "cutouts")

    dino = Dino(
        classes=["person", "nose", "chair", "shoe", "ear", "hat"],
        box_threshold=0.35,
        text_threshold=0.25,
        model_config_path=GROUNDING_DINO_CONFIG_PATH,
        model_checkpoint_path=GROUNDING_DINO_CHECKPOINT_PATH,
    )
    segment = Segmenter(sam_encoder_version="vit_h", sam_checkpoint_path=SAM_CHECKPOINT_PATH)
    cutout = CutoutCreator(image_folder=SAVE_IMAGE_PATH)
    s3 = Boto3Client()

    # Fetch the source image from S3
    s3.download_from_s3(SAVE_IMAGE_PATH, "cutout-image-store", f"images/{image_name}")
    image = cv2.imread(SOURCE_IMAGE_PATH)

    # Run the Grounding DINO model on the image
    detections = dino.predict(image)

    # Segment the detected boxes with SAM and attach the masks to the detections
    detections.mask = segment.segment(image, detections.xyxy)

    # Create cutouts from the masks and upload them to S3
    cutout.create_cutouts(image_name, detections.mask, OUTPUT_CUTOUT_PATH, "cutout-image-store", s3)

    return "Success"

10 changes: 5 additions & 5 deletions s3FileHandler.py → app/s3_handler.py
@@ -7,20 +7,20 @@ class Boto3Client:
     def __init__(self):
         self.s3 = boto3.client(
             "s3",
-            endpoint_url="https://13583f5ff84f5693a4a859a769743849.r2.cloudflarestorage.com",
             aws_access_key_id=os.environ["AWS_ACCESS_KEY_ID"],
             aws_secret_access_key=os.environ["AWS_SECRET_ACCESS_KEY"],
-            region_name="auto",
+            region_name=os.environ["AWS_REGION"],
         )

-    def download_from_s3(bucket_name, key):
+    def download_from_s3(self, save_path, bucket_name, key):
         s3_client = boto3.client("s3")

-        file_path = os.path.join(os.getcwd(), key)
+        file_name = key.split("/")[-1]
+        file_path = os.path.join(save_path, file_name)
         try:
             s3_client.download_file(bucket_name, key, file_path)
         except ClientError as e:
-            print(e)
+            print("BOTO error: ",e)
             return None

         return file_path
22 changes: 22 additions & 0 deletions app/segment.py
@@ -0,0 +1,22 @@
import numpy as np
import torch
from segment_anything import sam_model_registry, SamPredictor


class Segmenter:
    def __init__(self, sam_encoder_version: str, sam_checkpoint_path: str):
        self.device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
        self.sam = sam_model_registry[sam_encoder_version](checkpoint=sam_checkpoint_path).to(device=self.device)
        self.sam_predictor = SamPredictor(self.sam)

    def segment(self, image: np.ndarray, xyxy: np.ndarray) -> np.ndarray:
        self.sam_predictor.set_image(image)
        result_masks = []
        for box in xyxy:
            # Predict several candidate masks per box and keep the highest-scoring one
            masks, scores, logits = self.sam_predictor.predict(
                box=box,
                multimask_output=True,
            )
            index = np.argmax(scores)
            result_masks.append(masks[index])
        return np.array(result_masks)
