-
Notifications
You must be signed in to change notification settings - Fork 0
/
Copy pathlang_sam.py
84 lines (73 loc) · 2.8 KB
/
lang_sam.py
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
import numpy as np
from PIL import Image
from lang_sam.models.gdino import GDINO
from lang_sam.models.sam import SAM
# from models.gdino import GDINO
# from models.sam import SAM
class LangSAM:
    """Text-prompted segmentation that chains a GDINO detector into a SAM model.

    ``predict`` first asks GDINO for boxes matching each text prompt, then asks
    SAM for masks on the images that produced at least one detection.
    """

    def __init__(self, sam_type="sam2.1_hiera_small", ckpt_path: str | None = None):
        """Instantiate and build both underlying models.

        Parameters:
            sam_type: Model identifier forwarded to ``SAM.build_model``.
            ckpt_path (str | None): Optional checkpoint path forwarded to
                ``SAM.build_model``.
        """
        self.sam_type = sam_type
        self.sam = SAM()
        self.sam.build_model(sam_type, ckpt_path)
        self.gdino = GDINO()
        self.gdino.build_model()

    def predict(
        self,
        images_pil: list[Image.Image],
        texts_prompt: list[str],
        box_threshold: float = 0.3,
        text_threshold: float = 0.25,
    ) -> list[dict]:
        """Predicts masks for given images and text prompts using GDINO and SAM models.

        Parameters:
            images_pil (list[Image.Image]): List of input images.
            texts_prompt (list[str]): List of text prompts corresponding to the images.
            box_threshold (float): Threshold for box predictions (forwarded to GDINO).
            text_threshold (float): Threshold for text predictions (forwarded to GDINO).

        Returns:
            list[dict]: One result per GDINO result, in the same order. Each dict
            keeps every key GDINO returned and adds the mask entries.

            Output format:
            [{
                "boxes": np.ndarray,
                "scores": np.ndarray,
                "masks": np.ndarray,
                "mask_scores": np.ndarray,
            }, ...]

            For an image without detections, "masks" and "mask_scores" are
            empty lists because SAM is never run on it; "boxes" and "scores"
            are still NumPy arrays.
        """
        gdino_results = self.gdino.predict(images_pil, texts_prompt, box_threshold, text_threshold)

        all_results = []
        # Parallel lists describing only the images that SAM has to process.
        sam_images = []
        sam_boxes = []
        sam_indices = []

        for idx, result in enumerate(gdino_results):
            processed_result = {
                **result,
                # Convert unconditionally so the value types do not depend on
                # whether anything was detected in this image.
                "boxes": result["boxes"].cpu().numpy(),
                "scores": result["scores"].cpu().numpy(),
                # Placeholders; overwritten below when SAM produces masks.
                "masks": [],
                "mask_scores": [],
            }

            if result["labels"]:
                sam_images.append(np.asarray(images_pil[idx]))
                sam_boxes.append(processed_result["boxes"])
                sam_indices.append(idx)

            all_results.append(processed_result)

        if sam_images:
            # NOTE: both printed counts are numbers of images (images queued for
            # SAM, then total results), not numbers of individual masks.
            print(f"Predicting {len(sam_boxes)} masks")
            masks, mask_scores, _ = self.sam.predict_batch(sam_images, xyxy=sam_boxes)
            # sam_indices maps each SAM output back to its slot in all_results.
            for idx, mask, score in zip(sam_indices, masks, mask_scores):
                all_results[idx].update(
                    {
                        "masks": mask,
                        "mask_scores": score,
                    }
                )
            print(f"Predicted {len(all_results)} masks")

        return all_results
if __name__ == "__main__":
    # Manual smoke run: build the default model, segment two bundled sample
    # images with one prompt each, and dump the raw result list.
    model = LangSAM()
    sample_images = [
        Image.open("./assets/food.jpg"),
        Image.open("./assets/car.jpeg"),
    ]
    sample_prompts = ["food", "car"]
    out = model.predict(sample_images, sample_prompts)
    print(out)