Skip to content

Commit

Permalink
evaluate coarse on two viewpoints only and fix typos
Browse files Browse the repository at this point in the history
  • Loading branch information
petrikvladimir committed Sep 18, 2023
1 parent de16a78 commit 0cf2213
Show file tree
Hide file tree
Showing 2 changed files with 9 additions and 4 deletions.
4 changes: 2 additions & 2 deletions happypose/toolbox/utils/transform_utils.py
Original file line number Diff line number Diff line change
Expand Up @@ -34,8 +34,8 @@ def load_SO3_grid(resolution):
Returns:
rotmats: [N,3,3]
"""
meagpose_dir = PROJECT_DIR / "happypose" / "pose_estimators" / "megapose"
data_fname = meagpose_dir / f"src/megapose/data/data_{resolution}.qua"
megapose_dir = PROJECT_DIR / "happypose" / "pose_estimators" / "megapose"
data_fname = megapose_dir / f"src/megapose/data/data_{resolution}.qua"

assert data_fname.is_file(), f"File {data_fname} not found"

Expand Down
9 changes: 7 additions & 2 deletions tests/test_megapose_inference.py
Original file line number Diff line number Diff line change
Expand Up @@ -13,7 +13,7 @@
class TestMegaPoseInference(unittest.TestCase):
"""Unit tests for MegaPose inference example."""

def test_meagpose_pipeline(self):
def test_megapose_pipeline(self):
"""Run detector from CosyPose with coarse and refiner from MegaPose"""
observation = TestCosyPoseInference._load_crackers_example_observation()

Expand All @@ -30,10 +30,15 @@ def test_meagpose_pipeline(self):

model_info = NAMED_MODELS["megapose-1.0-RGB"]
pose_estimator = load_named_model("megapose-1.0-RGB", object_dataset).to("cpu")
preds, _ = pose_estimator.run_inference_pipeline(
# let's limit the grid, 278 is the most promising one, 477 the least one
pose_estimator._SO3_grid = pose_estimator._SO3_grid[[278, 477]]
preds, data = pose_estimator.run_inference_pipeline(
observation, detections=detections, **model_info["inference_parameters"]
)

scores = data["coarse"]["data"]["logits"]
self.assertGreater(scores[0], scores[1]) # 278 is better than 477

self.assertEqual(len(preds), 1)
self.assertEqual(preds.infos.label[0], "ycbv-obj_000002")

Expand Down

0 comments on commit 0cf2213

Please sign in to comment.