@@ -69,25 +69,72 @@ def validate_lmm_parameters(
69
69
70
70
def from_paligemma(
    result: str, resolution_wh: Tuple[int, int], classes: Optional[List[str]] = None
) -> Tuple[np.ndarray, Optional[np.ndarray], np.ndarray, Optional[np.ndarray]]:
    """
    Parse PaliGemma output containing detection or segmentation results.

    Segmentation entries (four ``<locNNNN>`` tokens followed by sixteen
    ``<segNNN>`` tokens and a class name) take precedence: if at least one is
    found, only segmentation entries are returned. Otherwise plain detection
    entries (four ``<locNNNN>`` tokens followed by a class name) are parsed.

    Args:
        result (str): Model output string containing loc and optional seg
            tokens.
        resolution_wh (Tuple[int, int]): Target resolution (width, height)
            that the 0-1024 loc values are scaled to.
        classes (Optional[List[str]]): Class names to keep. Entries whose
            name is not in this list are dropped.

    Returns:
        xyxy (np.ndarray): Bounding boxes of shape ``(N, 4)`` in
            ``(x_min, y_min, x_max, y_max)`` order.
        class_id (Optional[np.ndarray]): Index of each class name in
            ``classes``; ``None`` when ``classes`` is not given or nothing
            was parsed.
        class_name (np.ndarray): Stripped class names.
        mask (Optional[np.ndarray]): Boolean array of shape ``(N, h, w)`` for
            segmentation output, ``None`` otherwise.
            NOTE(review): the seg tokens are not decoded here, so every mask
            is an all-False placeholder - confirm whether decoding is
            expected to happen elsewhere.
    """
    w, h = resolution_wh

    loc_tokens = r"<loc(\d{4})><loc(\d{4})><loc(\d{4})><loc(\d{4})>"
    segmentation_pattern = re.compile(
        loc_tokens + r"\s*" + r"<seg(\d{3})>" * 16 + r"\s+([\w\s\-]+)"
    )
    # The lookbehind rejects a match that starts in the middle of a longer
    # run of loc tokens.
    detection_pattern = re.compile(
        r"(?<!<loc\d{4}>)" + loc_tokens + r" ([\w\s\-]+)"
    )

    has_masks = True
    found = segmentation_pattern.findall(result)
    if not found:
        has_masks = False
        found = detection_pattern.findall(result)

    if not found:
        return (
            np.empty((0, 4), dtype=float),
            None,
            np.array([], dtype=str),
            None,
        )

    matches = np.array(found)
    # The model emits (y_min, x_min, y_max, x_max); reorder to xyxy and
    # rescale from the 1024-bin grid to pixel coordinates.
    xyxy = matches[:, [1, 0, 3, 2]].astype(int) / 1024 * np.array([w, h, w, h])
    # The class name is always the last captured group in both patterns.
    class_name = np.char.strip(matches[:, -1].astype(str))
    masks = np.zeros((len(matches), h, w), dtype=bool) if has_masks else None

    class_id = None
    if classes is not None:
        # Map each class to its first index, mirroring ``list.index``.
        first_index = {}
        for index, name in enumerate(classes):
            first_index.setdefault(name, index)

        keep = np.array(
            [name in first_index for name in class_name.tolist()], dtype=bool
        )
        xyxy, class_name = xyxy[keep], class_name[keep]
        if masks is not None:
            masks = masks[keep]
        # Explicit dtype keeps class_id integral even when nothing survives
        # the filter (np.array([]) would otherwise default to float64).
        class_id = np.array(
            [first_index[name] for name in class_name.tolist()], dtype=int
        )

    return xyxy, class_id, class_name, masks
91
138
92
139
93
140
def from_florence_2 (
0 commit comments