From 0cf2213d7df2955dfe8fcc20f4858e7ced46d6da Mon Sep 17 00:00:00 2001 From: Vladimir Petrik Date: Mon, 18 Sep 2023 14:20:03 +0200 Subject: [PATCH] evaluate coarse on two viewpoints only and fix typos --- happypose/toolbox/utils/transform_utils.py | 4 ++-- tests/test_megapose_inference.py | 9 +++++++-- 2 files changed, 9 insertions(+), 4 deletions(-) diff --git a/happypose/toolbox/utils/transform_utils.py b/happypose/toolbox/utils/transform_utils.py index 411a46d8..19d82c67 100644 --- a/happypose/toolbox/utils/transform_utils.py +++ b/happypose/toolbox/utils/transform_utils.py @@ -34,8 +34,8 @@ def load_SO3_grid(resolution): Returns: rotmats: [N,3,3] """ - meagpose_dir = PROJECT_DIR / "happypose" / "pose_estimators" / "megapose" - data_fname = meagpose_dir / f"src/megapose/data/data_{resolution}.qua" + megapose_dir = PROJECT_DIR / "happypose" / "pose_estimators" / "megapose" + data_fname = megapose_dir / f"src/megapose/data/data_{resolution}.qua" assert data_fname.is_file(), f"File {data_fname} not found" diff --git a/tests/test_megapose_inference.py b/tests/test_megapose_inference.py index 1a1b0ecd..2138e988 100644 --- a/tests/test_megapose_inference.py +++ b/tests/test_megapose_inference.py @@ -13,7 +13,7 @@ class TestMegaPoseInference(unittest.TestCase): """Unit tests for MegaPose inference example.""" - def test_meagpose_pipeline(self): + def test_megapose_pipeline(self): """Run detector from CosyPose with coarse and refiner from MegaPose""" observation = TestCosyPoseInference._load_crackers_example_observation() @@ -30,10 +30,15 @@ def test_meagpose_pipeline(self): model_info = NAMED_MODELS["megapose-1.0-RGB"] pose_estimator = load_named_model("megapose-1.0-RGB", object_dataset).to("cpu") - preds, _ = pose_estimator.run_inference_pipeline( + # let's limit the grid, 278 is the most promising one, 477 the least one + pose_estimator._SO3_grid = pose_estimator._SO3_grid[[278, 477]] + preds, data = pose_estimator.run_inference_pipeline( observation, detections=detections, **model_info["inference_parameters"] ) + scores = data["coarse"]["data"]["logits"] + self.assertGreater(scores[0], scores[1]) # 278 is better than 477 + self.assertEqual(len(preds), 1) self.assertEqual(preds.infos.label[0], "ycbv-obj_000002")