From c333aad915b12756b6a7ffaa09ec4d9dee8921cf Mon Sep 17 00:00:00 2001
From: =?UTF-8?q?Th=C3=A9o=20Martinez?=
Date: Thu, 7 Sep 2023 11:19:46 +0200
Subject: [PATCH] apply black format

---
 .../scripts/run_inference_on_example.py      | 49 ++++++++++++-------
 1 file changed, 30 insertions(+), 19 deletions(-)

diff --git a/happypose/pose_estimators/cosypose/cosypose/scripts/run_inference_on_example.py b/happypose/pose_estimators/cosypose/cosypose/scripts/run_inference_on_example.py
index 641b5be9..5e95ef7d 100644
--- a/happypose/pose_estimators/cosypose/cosypose/scripts/run_inference_on_example.py
+++ b/happypose/pose_estimators/cosypose/cosypose/scripts/run_inference_on_example.py
@@ -24,7 +24,7 @@
     CosyPoseWrapper,
 )
 
-#from happypose.pose_estimators.cosypose.cosypose.rendering.bullet_scene_renderer import BulletSceneRenderer
+# from happypose.pose_estimators.cosypose.cosypose.rendering.bullet_scene_renderer import BulletSceneRenderer
 from happypose.pose_estimators.cosypose.cosypose.visualization.singleview import (
     render_prediction_wrt_camera,
 )
@@ -51,13 +51,10 @@
 ########################
 
 
-
-
-
-
 logger = get_logger(__name__)
 
-device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
+device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
+
 
 def load_observation(
     example_dir: Path,
@@ -70,7 +67,10 @@ def load_observation(
 
     depth = None
     if load_depth:
-        depth = np.array(Image.open(example_dir / "image_depth.png"), dtype=np.float32) / 1000
+        depth = (
+            np.array(Image.open(example_dir / "image_depth.png"), dtype=np.float32)
+            / 1000
+        )
         assert depth.shape[:2] == camera_data.resolution
 
     return rgb, depth, camera_data
@@ -86,6 +86,7 @@ def load_observation_tensor(
         observation.cuda()
     return observation
 
+
 def make_object_dataset(example_dir: Path) -> RigidObjectDataset:
     rigid_objects = []
     mesh_units = "mm"
@@ -98,21 +99,27 @@ def make_object_dataset(example_dir: Path) -> RigidObjectDataset:
                 assert not mesh_path, f"there multiple meshes in the {label} directory"
                 mesh_path = fn
         assert mesh_path, f"couldnt find a obj or ply mesh for {label}"
-        rigid_objects.append(RigidObject(label=label, mesh_path=mesh_path, mesh_units=mesh_units))
+        rigid_objects.append(
+            RigidObject(label=label, mesh_path=mesh_path, mesh_units=mesh_units)
+        )
         # TODO: fix mesh units
     rigid_object_dataset = RigidObjectDataset(rigid_objects)
     return rigid_object_dataset
 
 
-def rendering(predictions, example_dir):
+def rendering(predictions, example_dir, example_name):
     object_dataset = make_object_dataset(example_dir)
     # rendering
     rgb, _, camera_data = load_observation(example_dir, load_depth=False)
     camera_data.TWC = Transform(np.eye(4))
     renderer = Panda3dSceneRenderer(object_dataset)
     # Data necessary for image rendering
-    object_datas = [ObjectData(label="crackers", TWO=Transform(predictions.poses[0].numpy()))]
-    camera_data, object_datas = convert_scene_observation_to_panda3d(camera_data, object_datas)
+    object_datas = [
+        ObjectData(label=example_name, TWO=Transform(predictions.poses[0].numpy()))
+    ]
+    camera_data, object_datas = convert_scene_observation_to_panda3d(
+        camera_data, object_datas
+    )
     light_datas = [
         Panda3dLightData(
             light_type="ambient",
@@ -138,7 +145,7 @@ def save_predictions(example_dir, renderings):
     # BulletSceneRenderer.render_scene: gets a "object list" (prediction like object), a list of camera infos (with Km pose, res) and renders
     # a "camera observation" for each camera/viewpoint
     # Actually, renders: rgb, mask, depth, near, far
-    #rgb_render = render_prediction_wrt_camera(renderer, preds, cam)
+    # rgb_render = render_prediction_wrt_camera(renderer, preds, cam)
     mask = ~(rgb_render.sum(axis=-1) == 0)
     alpha = 0.1
     rgb_n_render = rgb.copy()
@@ -157,7 +164,9 @@ def save_predictions(example_dir, renderings):
         rgb, renderings.rgb, dilate_iterations=1, color=(0, 255, 0)
     )["img"]
     fig_contour_overlay = plotter.plot_image(contour_overlay)
-    fig_all = gridplot([[fig_rgb, fig_contour_overlay, fig_mesh_overlay]], toolbar_location=None)
+    fig_all = gridplot(
+        [[fig_rgb, fig_contour_overlay, fig_mesh_overlay]], toolbar_location=None
+    )
     vis_dir = example_dir / "visualizations"
     vis_dir.mkdir(exist_ok=True)
     export_png(fig_mesh_overlay, filename=vis_dir / "mesh_overlay.png")
@@ -181,11 +190,13 @@ def run_inference(
     set_logging_level("info")
     parser = argparse.ArgumentParser()
    parser.add_argument("example_name")
-    parser.add_argument("--model", type=str, default="megapose-1.0-RGB-multi-hypothesis")
+    parser.add_argument(
+        "--model", type=str, default="megapose-1.0-RGB-multi-hypothesis"
+    )
     parser.add_argument("--dataset", type=str, default="ycbv")
-    #parser.add_argument("--vis-detections", action="store_true")
+    # parser.add_argument("--vis-detections", action="store_true")
     parser.add_argument("--run-inference", action="store_true", default=True)
-    #parser.add_argument("--vis-outputs", action="store_true")
+    # parser.add_argument("--vis-outputs", action="store_true")
     args = parser.parse_args()
 
     data_dir = os.getenv("MEGAPOSE_DATA_DIR")
@@ -193,11 +204,11 @@
 
     example_dir = Path(data_dir) / "examples" / args.example_name
     dataset_to_use = args.dataset  # tless or ycbv
 
-    #if args.vis_detections:
+    # if args.vis_detections:
     #    make_detections_visualization(example_dir)
     if args.run_inference:
         run_inference(example_dir, args.model, dataset_to_use)
 
-    #if args.vis_outputs:
-    #    make_output_visualization(example_dir)
\ No newline at end of file
+    # if args.vis_outputs:
+    #    make_output_visualization(example_dir)