Skip to content

Commit

Permalink
apply black format
Browse files Browse the repository at this point in the history
  • Loading branch information
TheoMF committed Sep 7, 2023
1 parent 4eeb210 commit c333aad
Showing 1 changed file with 30 additions and 19 deletions.
Original file line number Diff line number Diff line change
Expand Up @@ -24,7 +24,7 @@
CosyPoseWrapper,
)

#from happypose.pose_estimators.cosypose.cosypose.rendering.bullet_scene_renderer import BulletSceneRenderer
# from happypose.pose_estimators.cosypose.cosypose.rendering.bullet_scene_renderer import BulletSceneRenderer
from happypose.pose_estimators.cosypose.cosypose.visualization.singleview import (
render_prediction_wrt_camera,
)
Expand All @@ -51,13 +51,10 @@
########################






logger = get_logger(__name__)

device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")


def load_observation(
example_dir: Path,
Expand All @@ -70,7 +67,10 @@ def load_observation(

depth = None
if load_depth:
depth = np.array(Image.open(example_dir / "image_depth.png"), dtype=np.float32) / 1000
depth = (
np.array(Image.open(example_dir / "image_depth.png"), dtype=np.float32)
/ 1000
)
assert depth.shape[:2] == camera_data.resolution

return rgb, depth, camera_data
Expand All @@ -86,6 +86,7 @@ def load_observation_tensor(
observation.cuda()
return observation


def make_object_dataset(example_dir: Path) -> RigidObjectDataset:
rigid_objects = []
mesh_units = "mm"
Expand All @@ -98,21 +99,27 @@ def make_object_dataset(example_dir: Path) -> RigidObjectDataset:
assert not mesh_path, f"there multiple meshes in the {label} directory"
mesh_path = fn
assert mesh_path, f"couldnt find a obj or ply mesh for {label}"
rigid_objects.append(RigidObject(label=label, mesh_path=mesh_path, mesh_units=mesh_units))
rigid_objects.append(
RigidObject(label=label, mesh_path=mesh_path, mesh_units=mesh_units)
)
# TODO: fix mesh units
rigid_object_dataset = RigidObjectDataset(rigid_objects)
return rigid_object_dataset


def rendering(predictions, example_dir):
def rendering(predictions, example_dir, example_name):
object_dataset = make_object_dataset(example_dir)
# rendering
rgb, _, camera_data = load_observation(example_dir, load_depth=False)
camera_data.TWC = Transform(np.eye(4))
renderer = Panda3dSceneRenderer(object_dataset)
# Data necessary for image rendering
object_datas = [ObjectData(label="crackers", TWO=Transform(predictions.poses[0].numpy()))]
camera_data, object_datas = convert_scene_observation_to_panda3d(camera_data, object_datas)
object_datas = [
ObjectData(label=example_name, TWO=Transform(predictions.poses[0].numpy()))
]
camera_data, object_datas = convert_scene_observation_to_panda3d(
camera_data, object_datas
)
light_datas = [
Panda3dLightData(
light_type="ambient",
Expand All @@ -138,7 +145,7 @@ def save_predictions(example_dir, renderings):
# BulletSceneRenderer.render_scene: gets a "object list" (prediction like object), a list of camera infos (with Km pose, res) and renders
# a "camera observation" for each camera/viewpoint
# Actually, renders: rgb, mask, depth, near, far
#rgb_render = render_prediction_wrt_camera(renderer, preds, cam)
# rgb_render = render_prediction_wrt_camera(renderer, preds, cam)
mask = ~(rgb_render.sum(axis=-1) == 0)
alpha = 0.1
rgb_n_render = rgb.copy()
Expand All @@ -157,7 +164,9 @@ def save_predictions(example_dir, renderings):
rgb, renderings.rgb, dilate_iterations=1, color=(0, 255, 0)
)["img"]
fig_contour_overlay = plotter.plot_image(contour_overlay)
fig_all = gridplot([[fig_rgb, fig_contour_overlay, fig_mesh_overlay]], toolbar_location=None)
fig_all = gridplot(
[[fig_rgb, fig_contour_overlay, fig_mesh_overlay]], toolbar_location=None
)
vis_dir = example_dir / "visualizations"
vis_dir.mkdir(exist_ok=True)
export_png(fig_mesh_overlay, filename=vis_dir / "mesh_overlay.png")
Expand All @@ -181,23 +190,25 @@ def run_inference(
set_logging_level("info")
parser = argparse.ArgumentParser()
parser.add_argument("example_name")
parser.add_argument("--model", type=str, default="megapose-1.0-RGB-multi-hypothesis")
parser.add_argument(
"--model", type=str, default="megapose-1.0-RGB-multi-hypothesis"
)
parser.add_argument("--dataset", type=str, default="ycbv")
#parser.add_argument("--vis-detections", action="store_true")
# parser.add_argument("--vis-detections", action="store_true")
parser.add_argument("--run-inference", action="store_true", default=True)
#parser.add_argument("--vis-outputs", action="store_true")
# parser.add_argument("--vis-outputs", action="store_true")
args = parser.parse_args()

data_dir = os.getenv("MEGAPOSE_DATA_DIR")
assert data_dir
example_dir = Path(data_dir) / "examples" / args.example_name
dataset_to_use = args.dataset # tless or ycbv

#if args.vis_detections:
# if args.vis_detections:
# make_detections_visualization(example_dir)

if args.run_inference:
run_inference(example_dir, args.model, dataset_to_use)

#if args.vis_outputs:
# make_output_visualization(example_dir)
# if args.vis_outputs:
# make_output_visualization(example_dir)

0 comments on commit c333aad

Please sign in to comment.