Skip to content

Commit

Permalink
use new load_image primitive
Browse files Browse the repository at this point in the history
  • Loading branch information
capjamesg committed Dec 5, 2023
1 parent 008ea5d commit c899164
Show file tree
Hide file tree
Showing 3 changed files with 13 additions and 17 deletions.
2 changes: 1 addition & 1 deletion autodistill_grounding_dino/__init__.py
Original file line number Diff line number Diff line change
@@ -1,3 +1,3 @@
from autodistill_grounding_dino.grounding_dino_model import GroundingDINO

__version__ = "0.1.1"
__version__ = "0.1.2"
24 changes: 11 additions & 13 deletions autodistill_grounding_dino/grounding_dino_model.py
Original file line number Diff line number Diff line change
Expand Up @@ -5,43 +5,42 @@
os.environ["TOKENIZERS_PARALLELISM"] = "false"

import torch
import cv2

torch.use_deterministic_algorithms(False)

import supervision as sv
from groundingdino.util.inference import Model
from autodistill.detection import CaptionOntology, DetectionBaseModel
from autodistill.helpers import load_image
from groundingdino.util.inference import Model

from autodistill_grounding_dino.helpers import (
combine_detections,
load_grounding_dino,
)
from autodistill_grounding_dino.helpers import (combine_detections,
load_grounding_dino)

HOME = os.path.expanduser("~")
DEVICE = torch.device("cuda" if torch.cuda.is_available() else "cpu")


@dataclass
class GroundingDINO(DetectionBaseModel):
ontology: CaptionOntology
grounding_dino_model: Model
box_threshold: float
text_threshold: float

def __init__(self, ontology: CaptionOntology, box_threshold=0.35, text_threshold=0.25):
def __init__(
self, ontology: CaptionOntology, box_threshold=0.35, text_threshold=0.25
):
self.ontology = ontology
self.grounding_dino_model = load_grounding_dino()
self.box_threshold = box_threshold
self.text_threshold = text_threshold

def predict(self, input: str) -> sv.Detections:
image = cv2.imread(input)
image = load_image(input, return_format="cv2")

# GroundingDINO predictions
detections_list = []

for i, description in enumerate(self.ontology.prompts()):
# detect objects
for _, description in enumerate(self.ontology.prompts()):
detections = self.grounding_dino_model.predict_with_classes(
image=image,
classes=[description],
Expand All @@ -55,5 +54,4 @@ def predict(self, input: str) -> sv.Detections:
detections_list, overwrite_class_ids=range(len(detections_list))
)

# separate in supervision to combine detections and override class_ids
return detections
return detections
4 changes: 1 addition & 3 deletions autodistill_grounding_dino/helpers.py
Original file line number Diff line number Diff line change
Expand Up @@ -2,11 +2,10 @@
import urllib.request

import numpy as np
import supervision as sv
import torch
from groundingdino.util.inference import Model

import supervision as sv

DEVICE = torch.device("cuda" if torch.cuda.is_available() else "cpu")

if not torch.cuda.is_available():
Expand Down Expand Up @@ -104,4 +103,3 @@ def load_grounding_dino():
)

return grounding_dino_model

0 comments on commit c899164

Please sign in to comment.