Skip to content

Commit

Permalink
Fix megapose unitest for GPU only -> torefactor
Browse files Browse the repository at this point in the history
  • Loading branch information
MedericFourmy authored and nim65s committed Jan 12, 2024
1 parent f14e6cb commit 00e8954
Showing 1 changed file with 3 additions and 2 deletions.
5 changes: 3 additions & 2 deletions tests/test_megapose_inference.py
Original file line number Diff line number Diff line change
Expand Up @@ -17,6 +17,7 @@ class TestMegaPoseInference(unittest.TestCase):
def test_megapose_pipeline(self):
"""Run detector from CosyPose with coarse and refiner from MegaPose."""
observation = TestCosyPoseInference._load_crackers_example_observation()
observation = observation.cuda()

detector = TestCosyPoseInference._load_detector()
detections = detector.get_detections(observation=observation)
Expand All @@ -30,7 +31,7 @@ def test_megapose_pipeline(self):
)

model_info = NAMED_MODELS["megapose-1.0-RGB"]
pose_estimator = load_named_model("megapose-1.0-RGB", object_dataset).to("cpu")
pose_estimator = load_named_model("megapose-1.0-RGB", object_dataset).to("cuda")
# let's limit the grid, 278 is the most promising one, 477 the least one
pose_estimator._SO3_grid = pose_estimator._SO3_grid[[278, 477]]
preds, data = pose_estimator.run_inference_pipeline(
Expand All @@ -45,7 +46,7 @@ def test_megapose_pipeline(self):
self.assertEqual(len(preds), 1)
self.assertEqual(preds.infos.label[0], "ycbv-obj_000002")

pose = pin.SE3(preds.poses[0].numpy())
pose = pin.SE3(preds.poses[0].cpu().numpy())
exp_pose = pin.SE3(
pin.exp3(np.array([1.44, 1.19, -0.91])),
np.array([0, 0, 0.52]),
Expand Down

0 comments on commit 00e8954

Please sign in to comment.