Skip to content

Commit

Permalink
Add model_type to constructor
Browse files Browse the repository at this point in the history
  • Loading branch information
MedericFourmy committed Oct 14, 2024
1 parent 797de37 commit 5c40a8d
Showing 1 changed file with 18 additions and 14 deletions.
Original file line number Diff line number Diff line change
Expand Up @@ -11,7 +11,7 @@
- Deprecate this class when possible
"""

from typing import Union
from typing import Tuple, Union

import torch

Expand All @@ -24,6 +24,7 @@
from happypose.pose_estimators.cosypose.cosypose.training.pose_models_cfg import (
load_model_cosypose,
)
from happypose.pose_estimators.megapose.inference.detector import Detector
from happypose.toolbox.datasets.datasets_cfg import make_object_dataset
from happypose.toolbox.datasets.object_dataset import RigidObjectDataset
from happypose.toolbox.inference.utils import load_detector
Expand Down Expand Up @@ -77,34 +78,37 @@ class CosyPoseWrapper:
def __init__(
self,
dataset_name: str,
model_type: str = "pbr",
object_dataset: Union[None, RigidObjectDataset] = None,
n_workers=8,
gpu_renderer=False,
renderer_type: str = "panda3d",
n_workers: int = 8,
) -> None:
"""
inputs:
- dataset_name: hope|tless|ycbv
- object_dataset: None or already existing rigid object dataset. If None, will use dataset_name
to build one
- n_workers: how many processes will be spun in the batch renderer
- renderer_type: 'panda3d'|'bullet'
:param dataset_name: str, name of the dataset on which model was trained, hope|tless|ycbv
:param model_type: str, type of NN model (depends on training data), hope|tless|ycbv
:param object_dataset: None or already existing rigid object dataset. If None, will use dataset_name to build one
:param n_workers: how many processes will be spun in the batch renderer
:param renderer_type: 'panda3d'|'bullet'
"""

self.dataset_name = dataset_name
self.object_dataset = object_dataset
self.detector, self.pose_predictor = self.get_model(
dataset_name, n_workers, renderer_type
dataset_name, model_type, n_workers, renderer_type
)

def get_model(self, dataset_name, n_workers, renderer_type, model_type="pbr"):
def get_model(
self, dataset_name, model_type, n_workers, renderer_type
) -> Tuple[Detector, PoseEstimator]:
"""
Return CosyPose detector and pose estimator objects for a given dataset.
:param dataset_name: str, name of the dataset on which model was trained
:param n_workers: int, number of workers used in the renderer
:param dataset_name: str, name of the dataset on which model was trained, hope|tless|ycbv
:param model_type: str, type of NN model (depends on training data), hope|tless|ycbv
:param renderer_type: str, which renderer to use, "panda3d" and "bullet" supported
:param n_workers: int, number of workers used in the renderer
:param model_type: str, what training data was used, "pbr" and "synth+real" supported
:return: tuple (detector)
:return: tuple (Detector,PoseEstimator)
"""
try:
mids = AVAILABLE_MODELS[dataset_name][model_type]
Expand Down

0 comments on commit 5c40a8d

Please sign in to comment.