diff --git a/happypose/pose_estimators/cosypose/cosypose/utils/cosypose_wrapper.py b/happypose/pose_estimators/cosypose/cosypose/utils/cosypose_wrapper.py index 2e6f447e..13688c1b 100644 --- a/happypose/pose_estimators/cosypose/cosypose/utils/cosypose_wrapper.py +++ b/happypose/pose_estimators/cosypose/cosypose/utils/cosypose_wrapper.py @@ -11,7 +11,7 @@ - Deprecate this class when possible """ -from typing import Union +from typing import Tuple, Union import torch @@ -24,6 +24,7 @@ from happypose.pose_estimators.cosypose.cosypose.training.pose_models_cfg import ( load_model_cosypose, ) +from happypose.pose_estimators.megapose.inference.detector import Detector from happypose.toolbox.datasets.datasets_cfg import make_object_dataset from happypose.toolbox.datasets.object_dataset import RigidObjectDataset from happypose.toolbox.inference.utils import load_detector @@ -77,34 +78,37 @@ class CosyPoseWrapper: def __init__( self, dataset_name: str, + model_type: str = "pbr", object_dataset: Union[None, RigidObjectDataset] = None, - n_workers=8, - gpu_renderer=False, renderer_type: str = "panda3d", + n_workers: int = 8, ) -> None: """ - inputs: - - dataset_name: hope|tless|ycbv - - object_dataset: None or already existing rigid object dataset. If None, will use dataset_name - to build one - - n_workers: how many processes will be spun in the batch renderer - - renderer_type: 'panda3d'|'bullet' + :param dataset_name: str, name of the dataset on which model was trained, hope|tless|ycbv + :param model_type: str, type of NN model (depends on training data), hope|tless|ycbv + :param object_dataset: None or already existing rigid object dataset. If None, will use dataset_name to build one + :param n_workers: how many processes will be spun in the batch renderer + :param renderer_type: 'panda3d'|'bullet' """ self.dataset_name = dataset_name self.object_dataset = object_dataset self.detector, self.pose_predictor = self.get_model( - dataset_name, n_workers, renderer_type + dataset_name, model_type, n_workers, renderer_type ) - def get_model(self, dataset_name, n_workers, renderer_type, model_type="pbr"): + def get_model( + self, dataset_name, model_type, n_workers, renderer_type + ) -> Tuple[Detector, PoseEstimator]: """ + Return CosyPose detector and pose estimator objects for a given dataset. - :param dataset_name: str, name of the dataset on which model was trained - :param n_workers: int, number of workers used in the renderer + :param dataset_name: str, name of the dataset on which model was trained, hope|tless|ycbv + :param model_type: str, type of NN model (depends on training data), hope|tless|ycbv :param renderer_type: str, which renderer to use, "panda3d" and "bullet" supported + :param n_workers: int, number of workers used in the renderer :param model_type: str, what training data was used, "pbr" and "synth+real" supported - :return: tuple (detector) + :return: tuple (Detector,PoseEstimator) """ try: mids = AVAILABLE_MODELS[dataset_name][model_type]