diff --git a/happypose/pose_estimators/cosypose/cosypose/utils/cosypose_wrapper.py b/happypose/pose_estimators/cosypose/cosypose/utils/cosypose_wrapper.py index 6f87a3e6..2e6f447e 100644 --- a/happypose/pose_estimators/cosypose/cosypose/utils/cosypose_wrapper.py +++ b/happypose/pose_estimators/cosypose/cosypose/utils/cosypose_wrapper.py @@ -32,6 +32,47 @@ device = torch.device("cuda" if torch.cuda.is_available() else "cpu") +# To download model weights for one of the datasets, run these commands: +# python -m happypose.toolbox.utils.download --cosypose_models= +# python -m happypose.toolbox.utils.download --cosypose_models= +# python -m happypose.toolbox.utils.download --cosypose_models= + +AVAILABLE_MODELS = { + "hope": { + "pbr": { + "detector_run_id": "detector-bop-hope-pbr--15246", + "coarse_run_id": "coarse-bop-hope-pbr--225203", + "refiner_run_id": "refiner-bop-hope-pbr--955392", + }, + "synth+real": {}, # no such model + }, + "tless": { + "pbr": { + "detector_run_id": "detector-bop-tless-pbr--873074", + "coarse_run_id": "coarse-bop-tless-pbr--506801", + "refiner_run_id": "refiner-bop-tless-pbr--233420", + }, + "synth+real": { + "detector_run_id": "detector-bop-tless-synt+real--452847", + "coarse_run_id": "coarse-bop-tless-synt+real--160982", + "refiner_run_id": "refiner-bop-tless-synt+real--881314", + }, + }, + "ycbv": { + "ycbv": { + "detector_run_id": "detector-bop-ycbv-pbr--970850", + "coarse_run_id": "coarse-bop-ycbv-pbr--724183", + "refiner_run_id": "refiner-bop-ycbv-pbr--604090", + }, + "synth+real": { + "detector_run_id": "detector-bop-ycbv-synt+real--292971", + "coarse_run_id": "coarse-bop-ycbv-synt+real--822463", + "refiner_run_id": "refiner-bop-ycbv-synt+real--631598", + }, + }, +} + + class CosyPoseWrapper: def __init__( self, @@ -56,38 +97,25 @@ def __init__( dataset_name, n_workers, renderer_type ) - def get_model(self, dataset_name, n_workers, renderer_type): - # load models - if dataset_name == "hope": - # HOPE setup - # python -m happypose.toolbox.utils.download --cosypose_models=detector-bop-hope-pbr--15246 - # python -m happypose.toolbox.utils.download --cosypose_models=coarse-bop-hope-pbr--225203 - # python -m happypose.toolbox.utils.download --cosypose_models=refiner-bop-hope-pbr--955392 - detector_run_id = "detector-bop-hope-pbr--15246" - coarse_run_id = "coarse-bop-hope-pbr--225203" - refiner_run_id = "refiner-bop-hope-pbr--955392" - elif dataset_name == "tless": - # TLESS setup - # python -m happypose.toolbox.utils.download --cosypose_models=detector-bop-tless-pbr--873074 - # python -m happypose.toolbox.utils.download --cosypose_models=coarse-bop-tless-pbr--506801 - # python -m happypose.toolbox.utils.download --cosypose_models=refiner-bop-tless-pbr--233420 - detector_run_id = "detector-bop-tless-pbr--873074" - coarse_run_id = "coarse-bop-tless-pbr--506801" - refiner_run_id = "refiner-bop-tless-pbr--233420" - elif dataset_name == "ycbv": - # YCBV setup - # python -m happypose.toolbox.utils.download --cosypose_models=detector-bop-ycbv-pbr--970850 - # python -m happypose.toolbox.utils.download --cosypose_models=coarse-bop-ycbv-pbr--724183 - # python -m happypose.toolbox.utils.download --cosypose_models=refiner-bop-ycbv-pbr--604090 - detector_run_id = "detector-bop-ycbv-pbr--970850" - coarse_run_id = "coarse-bop-ycbv-pbr--724183" - refiner_run_id = "refiner-bop-ycbv-pbr--604090" - else: - msg = f"Not prepared for {dataset_name} dataset" - raise ValueError(msg) - detector = load_detector(detector_run_id, device) + def get_model(self, dataset_name, n_workers, renderer_type, model_type="pbr"): + """ + + :param dataset_name: str, name of the dataset on which model was trained + :param n_workers: int, number of workers used in the renderer + :param renderer_type: str, which renderer to use, "panda3d" and "bullet" supported + :param model_type: str, what training data was used, "pbr" and "synth+real" supported + :return: tuple (detector) + """ + try: + mids = AVAILABLE_MODELS[dataset_name][model_type] + except KeyError: + raise KeyError( + f"{dataset_name}, {model_type} combination not supported when loading cosypose models" + ) + + detector = load_detector(mids["detector_run_id"], device) coarse_model, refiner_model = self.load_pose_models( - coarse_run_id, refiner_run_id, n_workers, renderer_type + mids["coarse_run_id"], mids["refiner_run_id"], n_workers, renderer_type ) pose_estimator = PoseEstimator(