adding blip2 confirmation support (#10)
naokiyokoyamabd authored Sep 7, 2023
1 parent e38b145 commit e48f1e2
Showing 7 changed files with 68 additions and 26 deletions.
38 changes: 36 additions & 2 deletions zsos/policy/base_objectnav_policy.py
@@ -13,6 +13,7 @@
 from zsos.obs_transformers.utils import image_resize
 from zsos.policy.utils.pointnav_policy import WrappedPointNavResNetPolicy
 from zsos.utils.geometry_utils import rho_theta
+from zsos.vlm.blip2 import BLIP2Client
 from zsos.vlm.coco_classes import COCO_CLASSES
 from zsos.vlm.grounding_dino import GroundingDINOClient, ObjectDetections
 from zsos.vlm.sam import MobileSAMClient
@@ -49,6 +50,10 @@ def __init__(
         max_obstacle_height: float = 0.88,
         agent_radius: float = 0.18,
         obstacle_map_area_threshold: float = 1.5,
+        use_vqa: bool = True,
+        vqa_prompt: str = "Is this ",
+        coco_threshold: float = 0.6,
+        non_coco_threshold: float = 0.4,
         *args,
         **kwargs,
     ):
@@ -60,6 +65,10 @@ def __init__(
             port=os.environ.get("YOLOV7_PORT", 12184)
         )
         self._mobile_sam = MobileSAMClient(port=os.environ.get("SAM_PORT", 12183))
+        if use_vqa:
+            self._vqa = BLIP2Client(port=os.environ.get("BLIP2_PORT", 12185))
+        else:
+            self._vqa = None
         self._pointnav_policy = WrappedPointNavResNetPolicy(pointnav_policy_path)
         self._object_map: ObjectPointCloudMap = ObjectPointCloudMap(
             erosion_size=object_map_erosion_size
@@ -68,6 +77,9 @@ def __init__(
         self._det_conf_threshold = det_conf_threshold
         self._pointnav_stop_radius = pointnav_stop_radius
         self._visualize = visualize
+        self._vqa_prompt = vqa_prompt
+        self._coco_threshold = coco_threshold
+        self._non_coco_threshold = non_coco_threshold
 
         self._num_steps = 0
         self._did_reset = False
@@ -221,7 +233,7 @@ def _get_policy_info(self, detections: ObjectDetections) -> Dict[str, Any]:
     def _get_object_detections(self, img: np.ndarray) -> ObjectDetections:
         if self._target_object in COCO_CLASSES:
             detections = self._coco_object_detector.predict(img)
-            self._det_conf_threshold = 0.6
+            self._det_conf_threshold = self._coco_threshold
         else:
             detections = self._object_detector.predict(img)
             detections.phrases = [
@@ -232,7 +244,7 @@ def _get_object_detections(self, img: np.ndarray) -> ObjectDetections:
             detections.phrases = [
                 p.replace("dining table", "table") for p in detections.phrases
             ]
-            self._det_conf_threshold = 0.4
+            self._det_conf_threshold = self._non_coco_threshold
         if self._detect_target_only:
             detections.filter_by_class([self._target_object])
         detections.filter_by_conf(self._det_conf_threshold)
@@ -321,6 +333,24 @@ def _update_object_map(
                 [width, height, width, height]
             )
             object_mask = self._mobile_sam.segment_bbox(rgb, bbox_denorm.tolist())
+
+            # If VQA is enabled, use the BLIP2 model to visually confirm that the
+            # segmented region actually contains the detected object class.
+            if self._vqa is not None:
+                contours, _ = cv2.findContours(
+                    object_mask, cv2.RETR_TREE, cv2.CHAIN_APPROX_SIMPLE
+                )
+                annotated_rgb = cv2.drawContours(
+                    rgb.copy(), contours, -1, (255, 0, 0), 2
+                )
+                question = f"Question: {self._vqa_prompt}"
+                if not detections.phrases[idx].endswith("ing"):
+                    question += "a "
+                question += detections.phrases[idx] + "? Answer:"
+                answer = self._vqa.ask(annotated_rgb, question)
+                if not answer.lower().startswith("yes"):
+                    continue
+
             self._object_masks[object_mask > 0] = 1
             self._object_map.update_map(
                 detections.phrases[idx],
@@ -373,6 +403,10 @@ class ZSOSConfig:
     text_prompt: str = "Seems like there is a target_object ahead."
     min_obstacle_height: float = 0.61
     max_obstacle_height: float = 0.88
+    use_vqa: bool = True
+    vqa_prompt: str = "Is this "
+    coco_threshold: float = 0.6
+    non_coco_threshold: float = 0.4
 
     @classmethod
     @property
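For orientation, the confirmation step added to _update_object_map above amounts to the following minimal sketch, pulled out into a hypothetical standalone helper (confirm_detection and its arguments are illustrative stand-ins for the policy's attributes, not part of the repository):

    import cv2


    def confirm_detection(vqa, rgb, object_mask, phrase, vqa_prompt="Is this "):
        # Outline the segmented object so the VQA model knows which region the
        # question refers to.
        contours, _ = cv2.findContours(
            object_mask, cv2.RETR_TREE, cv2.CHAIN_APPROX_SIMPLE
        )
        annotated_rgb = cv2.drawContours(rgb.copy(), contours, -1, (255, 0, 0), 2)

        # Build a prompt like "Question: Is this a couch? Answer:".
        question = f"Question: {vqa_prompt}"
        if not phrase.endswith("ing"):
            question += "a "
        question += phrase + "? Answer:"

        # Keep the detection only if the VQA model answers "yes".
        answer = vqa.ask(annotated_rgb, question)
        return answer.lower().startswith("yes")

In the policy itself, a "no" answer simply skips updating the object map for that detection.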
2 changes: 1 addition & 1 deletion zsos/policy/habitat_policies.py
@@ -27,7 +27,7 @@
 MP3D_ID_TO_NAME = [
     "chair",
     "table",
-    "framed photo",  # "picture",
+    "framed photograph",  # "picture",
     "cabinet",
     "pillow",  # "cushion",
     "couch",  # "sofa",
22 changes: 12 additions & 10 deletions zsos/vlm/blip2.py
@@ -39,21 +39,23 @@ def ask(self, image, prompt=None) -> str:
         """
         pil_img = Image.fromarray(image)
-        processed_image = (
-            self.vis_processors["eval"](pil_img).unsqueeze(0).to(self.device)
-        )
-
-        if prompt is None or prompt == "":
-            out = self.model.generate({"image": processed_image})[0]
-        else:
-            out = self.model.generate({"image": processed_image, "prompt": prompt})[0]
+        with torch.inference_mode():
+            processed_image = (
+                self.vis_processors["eval"](pil_img).unsqueeze(0).to(self.device)
+            )
+            if prompt is None or prompt == "":
+                out = self.model.generate({"image": processed_image})[0]
+            else:
+                out = self.model.generate({"image": processed_image, "prompt": prompt})[
+                    0
+                ]
 
         return out
 
 
 class BLIP2Client:
-    def __init__(self, url: str = "http://localhost:8070/blip2"):
-        self.url = url
+    def __init__(self, port: int = 12185):
+        self.url = f"http://localhost:{port}/blip2"
 
     def ask(self, image: np.ndarray, prompt: Optional[str] = None) -> str:
         if prompt is None:
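Assuming a BLIP2 server process is already listening on BLIP2_PORT (the server side is outside this diff), the reworked client can be exercised roughly like this sketch; the image here is only a placeholder:

    import os

    import numpy as np

    from zsos.vlm.blip2 import BLIP2Client

    # The client now takes a port instead of a full URL, matching the other
    # model clients in zsos/vlm.
    vqa = BLIP2Client(port=int(os.environ.get("BLIP2_PORT", 12185)))

    image = np.zeros((480, 640, 3), dtype=np.uint8)  # placeholder RGB frame
    answer = vqa.ask(image, "Question: Is this a chair? Answer:")  # plain string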
6 changes: 4 additions & 2 deletions zsos/vlm/blip2itm.py
@@ -42,8 +42,10 @@ def cosine(self, image: np.ndarray, txt: str) -> float:
         pil_img = Image.fromarray(image)
         img = self.vis_processors["eval"](pil_img).unsqueeze(0).to(self.device)
         txt = self.text_processors["eval"](txt)
-
-        cosine = self.model({"image": img, "text_input": txt}, match_head="itc").item()
+        with torch.inference_mode():
+            cosine = self.model(
+                {"image": img, "text_input": txt}, match_head="itc"
+            ).item()
 
         return cosine

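The same wrapping in torch.inference_mode() is applied to grounding_dino.py and sam.py below. In isolation, the pattern looks like this generic sketch (the toy model is illustrative, not from the repo):

    import torch

    model = torch.nn.Linear(4, 2)
    x = torch.randn(1, 4)

    # inference_mode() disables autograd tracking for everything computed inside
    # the block, avoiding gradient bookkeeping for inference-only forward passes.
    with torch.inference_mode():
        y = model(x)

    assert not y.requires_grad  # outputs inside inference_mode never track gradients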
4 changes: 2 additions & 2 deletions zsos/vlm/classes.txt
@@ -1,7 +1,7 @@
-picture
+framed photograph
 cabinet
 pillow
 dresser
 nightstand
 sink
 stool
 towel
15 changes: 8 additions & 7 deletions zsos/vlm/grounding_dino.py
@@ -56,13 +56,14 @@ def predict(self, image: np.ndarray, visualize: bool = False) -> ObjectDetection
         image_transformed = F.normalize(
             image_tensor, mean=[0.485, 0.456, 0.406], std=[0.229, 0.224, 0.225]
         )
-        boxes, logits, phrases = predict(
-            model=self.model,
-            image=image_transformed,
-            caption=self.classes,
-            box_threshold=self.box_threshold,
-            text_threshold=self.text_threshold,
-        )
+        with torch.inference_mode():
+            boxes, logits, phrases = predict(
+                model=self.model,
+                image=image_transformed,
+                caption=self.classes,
+                box_threshold=self.box_threshold,
+                text_threshold=self.text_threshold,
+            )
         detections = ObjectDetections(boxes, logits, phrases, image_source=image)
 
         classes = self.classes.split(" . ")
7 changes: 5 additions & 2 deletions zsos/vlm/sam.py
@@ -44,8 +44,11 @@ def segment_bbox(self, image: np.ndarray, bbox: List[int]) -> np.ndarray:
             is the same size as the bbox, cropped out of the image.
         """
-        self.predictor.set_image(image)
-        masks, _, _ = self.predictor.predict(box=np.array(bbox), multimask_output=False)
+        with torch.inference_mode():
+            self.predictor.set_image(image)
+            masks, _, _ = self.predictor.predict(
+                box=np.array(bbox), multimask_output=False
+            )
 
         return masks[0]

