Skip to content

Commit

Permalink
Added support for choosing the model size from tiny or base. Changed …
Browse files Browse the repository at this point in the history
…download url to be huggingface rather than github.
  • Loading branch information
djwessel committed Jan 11, 2024
1 parent c899164 commit 1eb1f71
Show file tree
Hide file tree
Showing 2 changed files with 26 additions and 11 deletions.
8 changes: 6 additions & 2 deletions autodistill_grounding_dino/grounding_dino_model.py
Original file line number Diff line number Diff line change
Expand Up @@ -28,10 +28,14 @@ class GroundingDINO(DetectionBaseModel):
text_threshold: float

def __init__(
self, ontology: CaptionOntology, box_threshold=0.35, text_threshold=0.25
self,
ontology: CaptionOntology,
box_threshold: float = 0.35,
text_threshold: float = 0.25,
model: str = "tiny",
):
self.ontology = ontology
self.grounding_dino_model = load_grounding_dino()
self.grounding_dino_model = load_grounding_dino(model)
self.box_threshold = box_threshold
self.text_threshold = text_threshold

Expand Down
29 changes: 20 additions & 9 deletions autodistill_grounding_dino/helpers.py
Original file line number Diff line number Diff line change
Expand Up @@ -8,6 +8,17 @@

DEVICE = torch.device("cuda" if torch.cuda.is_available() else "cpu")

MODELS = {
"tiny": {
"config": "GroundingDINO_SwinT_OGC.cfg.py",
"checkpoint": "groundingdino_swint_ogc.pth",
},
"base": {
"config": "GroundingDINO_SwinB.cfg.py",
"checkpoint": "groundingdino_swinb_cogcoor.pth",
},
}

if not torch.cuda.is_available():
print("WARNING: CUDA not available. GroundingDINO will run very slowly.")

Expand Down Expand Up @@ -63,17 +74,15 @@ def combine_detections(detections_list, overwrite_class_ids):
)


def load_grounding_dino():
def load_grounding_dino(model: str = "tiny"):
config = MODELS[model]["config"]
checkpoint = MODELS[model]["checkpoint"]
AUTODISTILL_CACHE_DIR = os.path.expanduser("~/.cache/autodistill")

GROUDNING_DINO_CACHE_DIR = os.path.join(AUTODISTILL_CACHE_DIR, "groundingdino")

GROUNDING_DINO_CONFIG_PATH = os.path.join(
GROUDNING_DINO_CACHE_DIR, "GroundingDINO_SwinT_OGC.py"
)
GROUNDING_DINO_CHECKPOINT_PATH = os.path.join(
GROUDNING_DINO_CACHE_DIR, "groundingdino_swint_ogc.pth"
)
GROUNDING_DINO_CONFIG_PATH = os.path.join(GROUDNING_DINO_CACHE_DIR, config)
GROUNDING_DINO_CHECKPOINT_PATH = os.path.join(GROUDNING_DINO_CACHE_DIR, checkpoint)

try:
print("trying to load grounding dino directly")
Expand All @@ -89,11 +98,13 @@ def load_grounding_dino():
os.makedirs(GROUDNING_DINO_CACHE_DIR)

if not os.path.exists(GROUNDING_DINO_CHECKPOINT_PATH):
url = "https://github.com/IDEA-Research/GroundingDINO/releases/download/v0.1.0-alpha/groundingdino_swint_ogc.pth"
url = f"https://huggingface.co/ShilongLiu/GroundingDINO/resolve/main/{checkpoint}"
urllib.request.urlretrieve(url, GROUNDING_DINO_CHECKPOINT_PATH)

if not os.path.exists(GROUNDING_DINO_CONFIG_PATH):
url = "https://raw.githubusercontent.com/roboflow/GroundingDINO/main/groundingdino/config/GroundingDINO_SwinT_OGC.py"
url = (
f"https://huggingface.co/ShilongLiu/GroundingDINO/resolve/main/{config}"
)
urllib.request.urlretrieve(url, GROUNDING_DINO_CONFIG_PATH)

grounding_dino_model = Model(
Expand Down

0 comments on commit 1eb1f71

Please sign in to comment.