From 00e8954988e1bb8641a2286b842947742c5ee531 Mon Sep 17 00:00:00 2001 From: Mederic Fourmy Date: Tue, 19 Dec 2023 12:06:07 +0100 Subject: [PATCH] Fix megapose unitest for GPU only -> torefactor --- tests/test_megapose_inference.py | 5 +++-- 1 file changed, 3 insertions(+), 2 deletions(-) diff --git a/tests/test_megapose_inference.py b/tests/test_megapose_inference.py index bc5dae4a..78e963a5 100644 --- a/tests/test_megapose_inference.py +++ b/tests/test_megapose_inference.py @@ -17,6 +17,7 @@ class TestMegaPoseInference(unittest.TestCase): def test_megapose_pipeline(self): """Run detector from CosyPose with coarse and refiner from MegaPose.""" observation = TestCosyPoseInference._load_crackers_example_observation() + observation = observation.cuda() detector = TestCosyPoseInference._load_detector() detections = detector.get_detections(observation=observation) @@ -30,7 +31,7 @@ def test_megapose_pipeline(self): ) model_info = NAMED_MODELS["megapose-1.0-RGB"] - pose_estimator = load_named_model("megapose-1.0-RGB", object_dataset).to("cpu") + pose_estimator = load_named_model("megapose-1.0-RGB", object_dataset).to("cuda") # let's limit the grid, 278 is the most promising one, 477 the least one pose_estimator._SO3_grid = pose_estimator._SO3_grid[[278, 477]] preds, data = pose_estimator.run_inference_pipeline( @@ -45,7 +46,7 @@ def test_megapose_pipeline(self): self.assertEqual(len(preds), 1) self.assertEqual(preds.infos.label[0], "ycbv-obj_000002") - pose = pin.SE3(preds.poses[0].numpy()) + pose = pin.SE3(preds.poses[0].cpu().numpy()) exp_pose = pin.SE3( pin.exp3(np.array([1.44, 1.19, -0.91])), np.array([0, 0, 0.52]),