Skip to content

Commit

Permalink
refactor available models to allow synth+real
Browse files Browse the repository at this point in the history
  • Loading branch information
MedericFourmy committed Oct 10, 2024
1 parent 606c145 commit 797de37
Showing 1 changed file with 59 additions and 31 deletions.
Original file line number Diff line number Diff line change
Expand Up @@ -32,6 +32,47 @@
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")


# To download model weights for one of the datasets, run these commands:
# python -m happypose.toolbox.utils.download --cosypose_models=<detector_run_id>
# python -m happypose.toolbox.utils.download --cosypose_models=<coarse_run_id>
# python -m happypose.toolbox.utils.download --cosypose_models=<refiner_run_id>

AVAILABLE_MODELS = {
"hope": {
"pbr": {
"detector_run_id": "detector-bop-hope-pbr--15246",
"coarse_run_id": "coarse-bop-hope-pbr--225203",
"refiner_run_id": "refiner-bop-hope-pbr--955392",
},
"synth+real": {}, # no such model
},
"tless": {
"pbr": {
"detector_run_id": "detector-bop-tless-pbr--873074",
"coarse_run_id": "coarse-bop-tless-pbr--506801",
"refiner_run_id": "refiner-bop-tless-pbr--233420",
},
"synth+real": {
"detector_run_id": "detector-bop-tless-synt+real--452847",
"coarse_run_id": "coarse-bop-tless-synt+real--160982",
"refiner_run_id": "refiner-bop-tless-synt+real--881314",
},
},
"ycbv": {
"ycbv": {
"detector_run_id": "detector-bop-ycbv-pbr--970850",
"coarse_run_id": "coarse-bop-ycbv-pbr--724183",
"refiner_run_id": "refiner-bop-ycbv-pbr--604090",
},
"synth+real": {
"detector_run_id": "detector-bop-ycbv-synt+real--292971",
"coarse_run_id": "coarse-bop-ycbv-synt+real--822463",
"refiner_run_id": "refiner-bop-ycbv-synt+real--631598",
},
},
}


class CosyPoseWrapper:
def __init__(
self,
Expand All @@ -56,38 +97,25 @@ def __init__(
dataset_name, n_workers, renderer_type
)

def get_model(self, dataset_name, n_workers, renderer_type):
# load models
if dataset_name == "hope":
# HOPE setup
# python -m happypose.toolbox.utils.download --cosypose_models=detector-bop-hope-pbr--15246
# python -m happypose.toolbox.utils.download --cosypose_models=coarse-bop-hope-pbr--225203
# python -m happypose.toolbox.utils.download --cosypose_models=refiner-bop-hope-pbr--955392
detector_run_id = "detector-bop-hope-pbr--15246"
coarse_run_id = "coarse-bop-hope-pbr--225203"
refiner_run_id = "refiner-bop-hope-pbr--955392"
elif dataset_name == "tless":
# TLESS setup
# python -m happypose.toolbox.utils.download --cosypose_models=detector-bop-tless-pbr--873074
# python -m happypose.toolbox.utils.download --cosypose_models=coarse-bop-tless-pbr--506801
# python -m happypose.toolbox.utils.download --cosypose_models=refiner-bop-tless-pbr--233420
detector_run_id = "detector-bop-tless-pbr--873074"
coarse_run_id = "coarse-bop-tless-pbr--506801"
refiner_run_id = "refiner-bop-tless-pbr--233420"
elif dataset_name == "ycbv":
# YCBV setup
# python -m happypose.toolbox.utils.download --cosypose_models=detector-bop-ycbv-pbr--970850
# python -m happypose.toolbox.utils.download --cosypose_models=coarse-bop-ycbv-pbr--724183
# python -m happypose.toolbox.utils.download --cosypose_models=refiner-bop-ycbv-pbr--604090
detector_run_id = "detector-bop-ycbv-pbr--970850"
coarse_run_id = "coarse-bop-ycbv-pbr--724183"
refiner_run_id = "refiner-bop-ycbv-pbr--604090"
else:
msg = f"Not prepared for {dataset_name} dataset"
raise ValueError(msg)
detector = load_detector(detector_run_id, device)
def get_model(self, dataset_name, n_workers, renderer_type, model_type="pbr"):
"""
:param dataset_name: str, name of the dataset on which model was trained
:param n_workers: int, number of workers used in the renderer
:param renderer_type: str, which renderer to use, "panda3d" and "bullet" supported
:param model_type: str, what training data was used, "pbr" and "synth+real" supported
:return: tuple (detector)
"""
try:
mids = AVAILABLE_MODELS[dataset_name][model_type]
except KeyError:
raise KeyError(
f"{dataset_name}, {model_type} combination not supported when loading cosypose models"
)

detector = load_detector(mids["detector_run_id"], device)
coarse_model, refiner_model = self.load_pose_models(
coarse_run_id, refiner_run_id, n_workers, renderer_type
mids["coarse_run_id"], mids["refiner_run_id"], n_workers, renderer_type
)

pose_estimator = PoseEstimator(
Expand Down

0 comments on commit 797de37

Please sign in to comment.