Commit

adding support for dynamically providing list of candidate categories to GDINO server
naokiyokoyamabd committed Sep 10, 2023
1 parent 81458ed commit 6297fae
Showing 1 changed file with 17 additions and 10 deletions.
27 changes: 17 additions & 10 deletions zsos/vlm/grounding_dino.py
@@ -1,4 +1,5 @@
 import os
+from typing import Optional
 
 import numpy as np
 import torch
@@ -43,15 +44,17 @@ def __init__(
         self.box_threshold = box_threshold
         self.text_threshold = text_threshold
 
-    def predict(self, image: np.ndarray, visualize: bool = False) -> ObjectDetections:
+    def predict(
+        self, image: np.ndarray, caption: Optional[str] = ""
+    ) -> ObjectDetections:
         """
         This function makes predictions on an input image tensor or numpy array using a
         pretrained model.
 
         Arguments:
             image (np.ndarray): An image in the form of a numpy array.
-            visualize (bool, optional): A flag indicating whether to visualize the
-                output data. Defaults to False.
+            caption (Optional[str]): A string containing the possible classes
+                separated by periods. If not provided, the default classes will be used.
 
         Returns:
             ObjectDetections: An instance of the ObjectDetections class containing the
@@ -62,17 +65,20 @@ def predict(self, image: np.ndarray, visualize: bool = False) -> ObjectDetections:
         image_transformed = F.normalize(
             image_tensor, mean=[0.485, 0.456, 0.406], std=[0.229, 0.224, 0.225]
         )
+        if caption == "":
+            caption = self.classes
+        print("Caption:", caption)
         with torch.inference_mode():
             boxes, logits, phrases = predict(
                 model=self.model,
                 image=image_transformed,
-                caption=self.classes,
+                caption=caption,
                 box_threshold=self.box_threshold,
                 text_threshold=self.text_threshold,
             )
         detections = ObjectDetections(boxes, logits, phrases, image_source=image)
 
-        classes = self.classes.split(" . ")
+        classes = caption[:-2].split(" . ")
         keep = torch.tensor(
             [p in classes for p in detections.phrases], dtype=torch.bool
         )
@@ -85,12 +91,13 @@ def predict(self, image: np.ndarray, visualize: bool = False) -> ObjectDetections:
 
 
 class GroundingDINOClient:
-    def __init__(self, port: int = 12181, classes: str = CLASSES):
+    def __init__(self, port: int = 12181):
         self.url = f"http://localhost:{port}/gdino"
-        self.classes = classes
 
-    def predict(self, image_numpy: np.ndarray) -> ObjectDetections:
-        response = send_request(self.url, image=image_numpy)
+    def predict(
+        self, image_numpy: np.ndarray, caption: Optional[str] = ""
+    ) -> ObjectDetections:
+        response = send_request(self.url, image=image_numpy, caption=caption)
         detections = ObjectDetections.from_json(response, image_source=image_numpy)
 
         return detections
@@ -108,7 +115,7 @@ def predict(self, image_numpy: np.ndarray) -> ObjectDetections:
     class GroundingDINOServer(ServerMixin, GroundingDINO):
         def process_payload(self, payload: dict) -> dict:
             image = str_to_image(payload["image"])
-            return self.predict(image).to_json()
+            return self.predict(image, caption=payload["caption"]).to_json()
 
     gdino = GroundingDINOServer()
     print("Model loaded!")