Skip to content

Commit 5fdb3fb

Browse files
committed
feat(detections): ✨ paligemma segmentation support added
Signed-off-by: Onuralp SEZER <[email protected]>
1 parent a6e1f03 commit 5fdb3fb

File tree

2 files changed

+65
-17
lines changed

2 files changed

+65
-17
lines changed

supervision/detection/core.py

+3-2
Original file line numberDiff line numberDiff line change
@@ -840,9 +840,10 @@ def from_lmm(
840840

841841
if lmm == LMM.PALIGEMMA:
842842
assert isinstance(result, str)
843-
xyxy, class_id, class_name = from_paligemma(result, **kwargs)
843+
xyxy, class_id, class_name, mask = from_paligemma(result, **kwargs)
844844
data = {CLASS_NAME_DATA_FIELD: class_name}
845-
return cls(xyxy=xyxy, class_id=class_id, data=data)
845+
mask = mask if mask is not None else None
846+
return cls(xyxy=xyxy, class_id=class_id, mask=mask, data=data)
846847

847848
if lmm == LMM.FLORENCE_2:
848849
assert isinstance(result, dict)

supervision/detection/lmm.py

+62-15
Original file line numberDiff line numberDiff line change
@@ -69,25 +69,72 @@ def validate_lmm_parameters(
6969

7070
# Compiled once at import time instead of on every call.
# Four <locNNNN> tokens, sixteen <segNNN> tokens, then the class name.
_PALIGEMMA_SEGMENTATION_PATTERN = re.compile(
    r"<loc(\d{4})><loc(\d{4})><loc(\d{4})><loc(\d{4})>\s*"
    + r"<seg(\d{3})>" * 16
    + r"\s+([\w\s\-]+)"
)

# Four <locNNNN> tokens (not preceded by another loc token), one space, then
# the class name.
_PALIGEMMA_DETECTION_PATTERN = re.compile(
    r"(?<!<loc\d{4}>)<loc(\d{4})><loc(\d{4})><loc(\d{4})><loc(\d{4})> ([\w\s\-]+)"
)


def from_paligemma(
    result: str, resolution_wh: Tuple[int, int], classes: Optional[List[str]] = None
) -> Tuple[np.ndarray, Optional[np.ndarray], np.ndarray, Optional[np.ndarray]]:
    """
    Parse a PaliGemma output string into boxes, class ids, names and masks.

    The string is first searched for segmentation entries (four loc tokens
    followed by sixteen seg tokens and a class name). Only if none are found
    is it searched for plain detection entries (four loc tokens and a class
    name). The two kinds are never mixed in one result.

    Each group of four loc values is reordered from positions `[0, 1, 2, 3]`
    to `[1, 0, 3, 2]`, divided by 1024 and scaled by `[w, h, w, h]`.

    NOTE: seg tokens are matched but not decoded. For segmentation entries
    the returned masks are all-`False` placeholders of the right shape.

    Args:
        result (str): Model output string containing loc and optional seg
            tokens.
        resolution_wh (Tuple[int, int]): Target resolution `(width, height)`
            used to scale the boxes and size the masks.
        classes (Optional[List[str]]): If given, only entries whose class
            name is in this list are kept, and `class_id` holds the index of
            each kept name in the list.

    Returns:
        xyxy (np.ndarray): Array of shape `(n, 4)` with scaled box
            coordinates.
        class_id (Optional[np.ndarray]): Integer array of shape `(n,)` when
            `classes` is given (possibly empty), otherwise `None`.
        class_name (np.ndarray): Array of shape `(n,)` with stripped class
            names.
        mask (Optional[np.ndarray]): Boolean array of shape `(n, h, w)` when
            segmentation entries were found, otherwise `None`.
    """
    w, h = resolution_wh

    # Segmentation entries take precedence over plain detections.
    segmentation_matches = _PALIGEMMA_SEGMENTATION_PATTERN.findall(result)
    if segmentation_matches:
        matches = segmentation_matches
    else:
        matches = _PALIGEMMA_DETECTION_PATTERN.findall(result)

    if matches:
        fields = np.array(matches)
        # The class name is always the last captured group in both patterns.
        xyxy = fields[:, [1, 0, 3, 2]].astype(int) / 1024 * np.array([w, h, w, h])
        class_name = np.char.strip(fields[:, -1].astype(str))
    else:
        xyxy = np.empty((0, 4), dtype=float)
        class_name = np.array([], dtype=str)

    masks = None
    if segmentation_matches:
        # Placeholder masks: one all-False (h, w) plane per matched object.
        masks = np.zeros((len(class_name), h, w), dtype=bool)

    class_id = None
    if classes is not None:
        keep = np.array([name in classes for name in class_name], dtype=bool)
        xyxy, class_name = xyxy[keep], class_name[keep]
        if masks is not None:
            masks = masks[keep]
        # Explicit dtype so an empty selection does not become float64.
        class_id = np.array(
            [classes.index(name) for name in class_name], dtype=int
        )

    return xyxy, class_id, class_name, masks
91138

92139

93140
def from_florence_2(

0 commit comments

Comments
 (0)