Skip to content

Commit

Permalink
few fixes, run_inference return predictions attributes
Browse files Browse the repository at this point in the history
  • Loading branch information
TheoMF committed Sep 7, 2023
1 parent c333aad commit 8a8caea
Show file tree
Hide file tree
Showing 3 changed files with 21 additions and 10 deletions.
Original file line number Diff line number Diff line change
Expand Up @@ -178,12 +178,20 @@ def run_inference(
example_dir: Path,
model_name: str,
dataset_to_use: str,
example_name: str,
render: bool = True,
) -> None:
observation = load_observation_tensor(example_dir)
CosyPose = CosyPoseWrapper(dataset_name=dataset_to_use, n_workers=8)
predictions = CosyPose.inference(observation)
renderings = rendering(predictions, example_dir)
save_predictions(example_dir, renderings)
try:
observation = load_observation_tensor(example_dir)
CosyPose = CosyPoseWrapper(dataset_name=dataset_to_use, n_workers=8)
predictions = CosyPose.inference(observation)
if render:
renderings = rendering(predictions, example_dir, example_name)
save_predictions(example_dir, renderings)
return predictions.poses, predictions.infos
except AttributeError as err:
print("rien trouvé")
return []


if __name__ == "__main__":
Expand All @@ -193,10 +201,11 @@ def run_inference(
parser.add_argument(
"--model", type=str, default="megapose-1.0-RGB-multi-hypothesis"
)
parser.add_argument("--dataset", type=str, default="ycbv")
parser.add_argument("--dataset", type=str, default="tless")
# parser.add_argument("--vis-detections", action="store_true")
parser.add_argument("--run-inference", action="store_true", default=True)
# parser.add_argument("--vis-outputs", action="store_true")
parser.add_argument("--render", action="store_true", default=False)
args = parser.parse_args()

data_dir = os.getenv("MEGAPOSE_DATA_DIR")
Expand All @@ -208,7 +217,9 @@ def run_inference(
# make_detections_visualization(example_dir)

if args.run_inference:
run_inference(example_dir, args.model, dataset_to_use)
run_inference(
example_dir, args.model, dataset_to_use, args.example_name, args.render
)

# if args.vis_outputs:
# make_output_visualization(example_dir)
Original file line number Diff line number Diff line change
Expand Up @@ -12,7 +12,7 @@ class BaseScene:
_connected = False
_simulation_step = 1/240.

def connect(self, gpu_renderer=True, gui=False):
def connect(self, gpu_renderer=False, gui=False):
assert not self._connected, 'Already connected'
if gui:
self._client_id = pb.connect(pb.GUI, '--width=640 --height=480')
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -110,7 +110,7 @@ def load_pose_models(coarse_run_id, refiner_run_id, n_workers):
#renderer = BulletBatchRenderer(object_set=cfg.urdf_ds_name, n_workers=n_workers, gpu_renderer=gpu_renderer)
#

object_dataset = make_object_dataset("ycbv")
object_dataset = make_object_dataset("tless.cad")
mesh_db = MeshDataBase.from_object_ds(object_dataset)
renderer = Panda3dBatchRenderer(object_dataset, n_workers=n_workers, preload_cache=False)
mesh_db_batched = mesh_db.batched().to(device)
Expand Down Expand Up @@ -189,4 +189,4 @@ def inference(self, observation, coarse_guess=None):
n_coarse_iterations=0, n_refiner_iterations=4)
print("inference successfully.")
# result: this_batch_detections, final_preds
return final_preds.cpu()
return final_preds.cpu()

0 comments on commit 8a8caea

Please sign in to comment.