Skip to content

Commit

Permalink
Merge pull request #9 from OriginalByteMe/develop
Browse files Browse the repository at this point in the history
Develop
  • Loading branch information
OriginalByteMe authored Nov 17, 2023
2 parents 0599417 + a045aad commit 8a2ed73
Show file tree
Hide file tree
Showing 7 changed files with 592 additions and 146 deletions.
125 changes: 93 additions & 32 deletions app/cutout.py
Original file line number Diff line number Diff line change
@@ -1,38 +1,99 @@
from typing import Dict
import os
import io
import cv2
import numpy as np
import os
from s3_handler import Boto3Client
from dino import Dino
from segment import Segmenter
from PIL import Image
import supervision as sv

HOME = os.path.abspath(os.path.join(os.getcwd(), os.pardir))

class CutoutCreator:
def __init__(self, image_folder):
self.image_folder = image_folder

def create_cutouts(self, image_name, masks, output_folder,bucket_name, s3):
# Load the image
image_path = os.path.join(self.image_folder, image_name)
for item in os.listdir(self.image_folder):
print("Item: ",item)
if not os.path.exists(image_path):
print(f"Image {image_name} not found in folder {self.image_folder}")
return

image = cv2.imread(image_path)

# Apply each mask to the image
for i, mask in enumerate(masks):
# Ensure the mask is a boolean array
mask = mask.astype(bool)

# Apply the mask to create a cutout
cutout = np.zeros_like(image)
cutout[mask] = image[mask]

# Save the cutout
cutout_name = f"{image_name}_cutout_{i}.png"
cutout_path = os.path.join(output_folder, cutout_name)
cv2.imwrite(cutout_path, cutout)

# Upload the cutout to S3
with open(cutout_path, "rb") as f:
s3.upload_to_s3(bucket_name, f.read(), f"cutouts/{image_name}/{cutout_name}")
"""
A class for creating cutouts from an image and uploading them to S3.
Attributes:
dino: A Dino object for object detection.
s3: A Boto3Client object for uploading to S3.
mask_annotator: A MaskAnnotator object for annotating images with masks.
"""
def __init__(self, classes: str, grounding_dino_config_path: str, grounding_dino_checkpoint_path: str):
self.dino = Dino(
classes=classes,
box_threshold=0.35,
text_threshold=0.25,
model_config_path=grounding_dino_config_path,
model_checkpoint_path=grounding_dino_checkpoint_path,
)
self.s3 = Boto3Client()
self.mask_annotator = sv.MaskAnnotator()

def create_annotated_image(self, image, image_name, detections: Dict[str, list]):
"""Create a highlighted annotated image from an image and detections.
Args:
image (File): Image to be used for creating the annotated image.
image_name (string): name of image
detections (Dict[str, list]): annotations for the image
"""
annotated_image = self.mask_annotator.annotate(
scene=image, detections=detections
)
# Convert annotated image to bytes
img_bytes = io.BytesIO()
Image.fromarray(np.uint8(annotated_image)).save(img_bytes, format="PNG")
img_bytes.seek(0)
# Upload bytes to S3
self.s3.upload_to_s3(img_bytes.read(), "cutouts", f"{image_name}_annotated.png")

def create_cutouts(self, image_name, sam_checkpoint_path):
"""Create cutouts from an image and upload them to S3.
Args:
image_name (string): name of image
sam_checkpoint_path (string): path to sam checkpoint
"""
# Download image from s3
image_path = self.s3.download_from_s3(os.path.join(HOME, "data"), image_name)
if not os.path.exists(os.path.join(HOME, "cutouts")):
os.mkdir(os.path.join(HOME, "cutouts"))
image = cv2.imread(image_path)
segment = Segmenter(
sam_encoder_version="vit_h", sam_checkpoint_path=sam_checkpoint_path
)
detections = self.dino.predict(image)

masks = segment.segment(image, detections.xyxy)
# Load the image
# image_path = os.path.join(self.image_folder, image_name)
# for item in os.listdir(self.image_folder):
# print("Item: ",item)
if not os.path.exists(image_path):
print(f"Image {image_name} not found in folder {image_path}")
return

image = cv2.imread(image_path)

# Apply each mask to the image
for i, mask in enumerate(masks):
# Ensure the mask is a boolean array
mask = mask.astype(bool)

# Apply the mask to create a cutout
cutout = np.zeros_like(image)
cutout[mask] = image[mask]

# Save the cutout
cutout_name = f"{image_name}_cutout_{i}.png"
cutout_path = os.path.join(HOME, "cutouts", cutout_name)
cv2.imwrite(cutout_path, cutout)

# Upload the cutout to S3
with open(cutout_path, "rb") as f:
self.s3.upload_to_s3(f.read(), "cutouts", f"{image_name}/{cutout_name}")

# Create annotated image
# self.create_annotated_image(image, f"{image_name}_{i}", detections)
68 changes: 49 additions & 19 deletions app/dino.py
Original file line number Diff line number Diff line change
Expand Up @@ -4,25 +4,55 @@


class Dino:
def __init__(self, classes, box_threshold, text_threshold, model_config_path, model_checkpoint_path):
self.classes = classes
self.box_threshold = box_threshold
self.text_threshold = text_threshold
self.device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
self.grounding_dino_model = Model(model_config_path=model_config_path, model_checkpoint_path=model_checkpoint_path)

def enhance_class_name(self, class_names: List[str]) -> List[str]:
return [
f"all {class_name}s"
for class_name
in class_names
]

def predict(self, image):
detections = self.grounding_dino_model.predict_with_classes(image=image, classes=self.enhance_class_name(class_names=self.classes), box_threshold=self.box_threshold, text_threshold=self.text_threshold)
detections = detections[detections.class_id != None]
return detections

""" A class for object detection using GroundingDINO.
"""
def __init__(
self,
classes,
box_threshold,
text_threshold,
model_config_path,
model_checkpoint_path,
):
self.classes = classes
self.box_threshold = box_threshold
self.text_threshold = text_threshold
self.device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
self.grounding_dino_model = Model(
model_config_path=model_config_path,
model_checkpoint_path=model_checkpoint_path,
)

def enhance_class_name(self, class_names: List[str]) -> List[str]:
"""Enhance class names for GroundingDINO.
Args:
class_names (List[str]): List of class names.
Returns:
List[str]: List of class names with "all" prepended.
"""
return [f"all {class_name}s" for class_name in class_names]

def predict(self, image):
"""Predict objects in an image.
Args:
image (File): Image to be used for object detection.
Returns:
Dict[str, list]: Dictionary of objects detected in the image.
"""
detections = self.grounding_dino_model.predict_with_classes(
image=image,
classes=self.enhance_class_name(class_names=self.classes),
box_threshold=self.box_threshold,
text_threshold=self.text_threshold,
)
detections = detections[detections.class_id != None]
return detections


# Example usage
# dino = Dino(classes=['person', 'nose', 'chair', 'shoe', 'ear', 'hat'],
# box_threshold=0.35,
Expand Down
Loading

0 comments on commit 8a2ed73

Please sign in to comment.