Skip to content

Commit

Permalink
224-224 not 244-224, cleaner code for frontier selecting, filling depth
Browse files Browse the repository at this point in the history
  • Loading branch information
naokiyokoyamabd committed Sep 1, 2023
1 parent 7f1de6c commit 611e4c9
Show file tree
Hide file tree
Showing 4 changed files with 88 additions and 39 deletions.
38 changes: 38 additions & 0 deletions scripts/parse_jsons.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,38 @@
import argparse
import json
import os
from collections import Counter


def read_json_files(directory):
failure_causes = []
for filename in os.listdir(directory):
if filename.endswith(".json"):
with open(os.path.join(directory, filename), "r") as f:
data = json.load(f)
if "failure_cause" in data:
failure_causes.append(data["failure_cause"])
return failure_causes


def calculate_frequencies(failure_causes):
counter = Counter(failure_causes)
total = sum(counter.values())
for cause, count in counter.items():
percentage = (count / total) * 100
print(
f"Failure cause: {cause}, Frequency: {count}, Percentage: {percentage:.2f}%"
)


def main():
parser = argparse.ArgumentParser(description="Process some integers.")
parser.add_argument("directory", type=str, help="Directory to process")
args = parser.parse_args()

failure_causes = read_json_files(args.directory)
calculate_frequencies(failure_causes)


if __name__ == "__main__":
main()
4 changes: 2 additions & 2 deletions zsos/mapping/obstacle_map.py
Original file line number Diff line number Diff line change
Expand Up @@ -8,7 +8,6 @@
from zsos.mapping.base_map import BaseMap
from zsos.mapping.value_map import JSON_PATH, KWARGS_JSON
from zsos.utils.geometry_utils import extract_yaw, get_point_cloud, transform_points
from zsos.utils.img_utils import fill_small_holes


class ObstacleMap(BaseMap):
Expand Down Expand Up @@ -74,7 +73,8 @@ def update_map(
topdown_fov (float): The field of view of the depth camera projected onto
the topdown map.
"""
filled_depth = fill_small_holes(depth, 10000)
filled_depth = depth.copy()
filled_depth[depth == 0] = 1.0
scaled_depth = filled_depth * (max_depth - min_depth) + min_depth
mask = scaled_depth < max_depth
point_cloud_camera_frame = get_point_cloud(scaled_depth, mask, fx, fy)
Expand Down
4 changes: 2 additions & 2 deletions zsos/policy/base_objectnav_policy.py
Original file line number Diff line number Diff line change
Expand Up @@ -344,8 +344,8 @@ def _infer_depth(
class ZSOSConfig:
name: str = "HabitatITMPolicy"
pointnav_policy_path: str = "data/pointnav_weights.pth"
depth_image_shape: Tuple[int, int] = (244, 224)
det_conf_threshold: float = 0.6
depth_image_shape: Tuple[int, int] = (224, 224)
det_conf_threshold: float = 0.8
pointnav_stop_radius: float = 0.9
use_max_confidence: bool = False
object_map_erosion_size: int = 5
Expand Down
81 changes: 46 additions & 35 deletions zsos/policy/itm_policy.py
Original file line number Diff line number Diff line change
Expand Up @@ -28,7 +28,6 @@ class BaseITMPolicy(BaseObjectNavPolicy):
_acyclic_enforcer: AcyclicEnforcer = None # must be set by ._reset()
_last_value: float = float("-inf")
_last_frontier: np.ndarray = np.zeros(2)
_second_best_thresh: float = 0.9

@staticmethod
def _vis_reduce_fn(i):
Expand Down Expand Up @@ -83,49 +82,61 @@ def _get_best_frontier(
Returns:
Tuple[np.ndarray, float]: The best frontier and its value.
"""
# The points and values will be sorted in descending order
sorted_pts, sorted_values = self._sort_frontiers_by_value(
observations, frontiers
)
best_frontier, best_value = None, None

robot_xy = self._observations_cache["robot_xy"]

if self._last_value > 0.0:
closest_index = closest_point_within_threshold(
sorted_pts, self._last_frontier, threshold=0.5
)
else:
closest_index = -1

if (
closest_index != -1
and self._last_value
> sorted_values[closest_index] * self._second_best_thresh
):
best_frontier, best_value = (
sorted_pts[closest_index],
sorted_values[closest_index],
)
else:
for frontier, value in zip(sorted_pts, sorted_values):
cyclic = self._acyclic_enforcer.check_cyclic(robot_xy, frontier)
if not cyclic:
best_frontier, best_value = frontier, value
best_frontier_idx = None

# If there is a last point pursued, then we consider sticking to pursuing it
# if it is still in the list of frontiers and its current value is not much
# worse than self._last_value.
if not np.array_equal(self._last_frontier, np.zeros(2)):
curr_index = None

for idx, p in enumerate(sorted_pts):
if np.array_equal(p, self._last_frontier):
# Last point is still in the list of frontiers
curr_index = idx
break
print("Suppressed cyclic frontier.")

if best_frontier is None:
print("All frontiers are cyclic. Choosing the closest one.")
best_idx = max(
range(len(frontiers)),
key=lambda i: np.linalg.norm(frontiers[i] - robot_xy),
if curr_index is None:
closest_index = closest_point_within_threshold(
sorted_pts, self._last_frontier, threshold=0.5
)

best_frontier, best_value = (
frontiers[best_idx],
sorted_values[best_idx],
)
if closest_index != -1:
# There is a point close to the last point pursued
curr_index = closest_index

if curr_index is not None:
curr_value = sorted_values[curr_index]
if curr_value + 0.01 > self._last_value:
# The last point pursued is still in the list of frontiers and its
# value is not much worse than self._last_value
best_frontier_idx = curr_index

# If there is no last point pursued, then just take the best point, given that
# it is not cyclic.
if best_frontier_idx is None:
for idx, frontier in enumerate(sorted_pts):
cyclic = self._acyclic_enforcer.check_cyclic(robot_xy, frontier)
if cyclic:
print("Suppressed cyclic frontier.")
continue
best_frontier_idx = idx
break

if best_frontier_idx is None:
print("All frontiers are cyclic. Just choosing the closest one.")
best_frontier_idx = max(
range(len(frontiers)),
key=lambda i: np.linalg.norm(frontiers[i] - robot_xy),
)

best_frontier = sorted_pts[best_frontier_idx]
best_value = sorted_values[best_frontier_idx]
self._acyclic_enforcer.add_state_action(robot_xy, best_frontier)
self._last_value = best_value
self._last_frontier = best_frontier
Expand Down

0 comments on commit 611e4c9

Please sign in to comment.