-
Notifications
You must be signed in to change notification settings - Fork 0
Commit
This commit does not belong to any branch on this repository, and may belong to a fork outside of the repository.
Merge pull request #9 from OriginalByteMe/develop
Develop
- Loading branch information
Showing
7 changed files
with
592 additions
and
146 deletions.
There are no files selected for viewing
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -1,38 +1,99 @@ | ||
from typing import Dict | ||
import os | ||
import io | ||
import cv2 | ||
import numpy as np | ||
import os | ||
from s3_handler import Boto3Client | ||
from dino import Dino | ||
from segment import Segmenter | ||
from PIL import Image | ||
import supervision as sv | ||
|
||
HOME = os.path.abspath(os.path.join(os.getcwd(), os.pardir)) | ||
|
||
class CutoutCreator: | ||
def __init__(self, image_folder): | ||
self.image_folder = image_folder | ||
|
||
def create_cutouts(self, image_name, masks, output_folder,bucket_name, s3): | ||
# Load the image | ||
image_path = os.path.join(self.image_folder, image_name) | ||
for item in os.listdir(self.image_folder): | ||
print("Item: ",item) | ||
if not os.path.exists(image_path): | ||
print(f"Image {image_name} not found in folder {self.image_folder}") | ||
return | ||
|
||
image = cv2.imread(image_path) | ||
|
||
# Apply each mask to the image | ||
for i, mask in enumerate(masks): | ||
# Ensure the mask is a boolean array | ||
mask = mask.astype(bool) | ||
|
||
# Apply the mask to create a cutout | ||
cutout = np.zeros_like(image) | ||
cutout[mask] = image[mask] | ||
|
||
# Save the cutout | ||
cutout_name = f"{image_name}_cutout_{i}.png" | ||
cutout_path = os.path.join(output_folder, cutout_name) | ||
cv2.imwrite(cutout_path, cutout) | ||
|
||
# Upload the cutout to S3 | ||
with open(cutout_path, "rb") as f: | ||
s3.upload_to_s3(bucket_name, f.read(), f"cutouts/{image_name}/{cutout_name}") | ||
""" | ||
A class for creating cutouts from an image and uploading them to S3. | ||
Attributes: | ||
dino: A Dino object for object detection. | ||
s3: A Boto3Client object for uploading to S3. | ||
mask_annotator: A MaskAnnotator object for annotating images with masks. | ||
""" | ||
def __init__(self, classes: str, grounding_dino_config_path: str, grounding_dino_checkpoint_path: str): | ||
self.dino = Dino( | ||
classes=classes, | ||
box_threshold=0.35, | ||
text_threshold=0.25, | ||
model_config_path=grounding_dino_config_path, | ||
model_checkpoint_path=grounding_dino_checkpoint_path, | ||
) | ||
self.s3 = Boto3Client() | ||
self.mask_annotator = sv.MaskAnnotator() | ||
|
||
def create_annotated_image(self, image, image_name, detections: Dict[str, list]): | ||
"""Create a highlighted annotated image from an image and detections. | ||
Args: | ||
image (File): Image to be used for creating the annotated image. | ||
image_name (string): name of image | ||
detections (Dict[str, list]): annotations for the image | ||
""" | ||
annotated_image = self.mask_annotator.annotate( | ||
scene=image, detections=detections | ||
) | ||
# Convert annotated image to bytes | ||
img_bytes = io.BytesIO() | ||
Image.fromarray(np.uint8(annotated_image)).save(img_bytes, format="PNG") | ||
img_bytes.seek(0) | ||
# Upload bytes to S3 | ||
self.s3.upload_to_s3(img_bytes.read(), "cutouts", f"{image_name}_annotated.png") | ||
|
||
def create_cutouts(self, image_name, sam_checkpoint_path): | ||
"""Create cutouts from an image and upload them to S3. | ||
Args: | ||
image_name (string): name of image | ||
sam_checkpoint_path (string): path to sam checkpoint | ||
""" | ||
# Download image from s3 | ||
image_path = self.s3.download_from_s3(os.path.join(HOME, "data"), image_name) | ||
if not os.path.exists(os.path.join(HOME, "cutouts")): | ||
os.mkdir(os.path.join(HOME, "cutouts")) | ||
image = cv2.imread(image_path) | ||
segment = Segmenter( | ||
sam_encoder_version="vit_h", sam_checkpoint_path=sam_checkpoint_path | ||
) | ||
detections = self.dino.predict(image) | ||
|
||
masks = segment.segment(image, detections.xyxy) | ||
# Load the image | ||
# image_path = os.path.join(self.image_folder, image_name) | ||
# for item in os.listdir(self.image_folder): | ||
# print("Item: ",item) | ||
if not os.path.exists(image_path): | ||
print(f"Image {image_name} not found in folder {image_path}") | ||
return | ||
|
||
image = cv2.imread(image_path) | ||
|
||
# Apply each mask to the image | ||
for i, mask in enumerate(masks): | ||
# Ensure the mask is a boolean array | ||
mask = mask.astype(bool) | ||
|
||
# Apply the mask to create a cutout | ||
cutout = np.zeros_like(image) | ||
cutout[mask] = image[mask] | ||
|
||
# Save the cutout | ||
cutout_name = f"{image_name}_cutout_{i}.png" | ||
cutout_path = os.path.join(HOME, "cutouts", cutout_name) | ||
cv2.imwrite(cutout_path, cutout) | ||
|
||
# Upload the cutout to S3 | ||
with open(cutout_path, "rb") as f: | ||
self.s3.upload_to_s3(f.read(), "cutouts", f"{image_name}/{cutout_name}") | ||
|
||
# Create annotated image | ||
# self.create_annotated_image(image, f"{image_name}_{i}", detections) |
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Oops, something went wrong.