Skip to content

Commit

Permalink
fixing/refactoring cosy eval
Browse files Browse the repository at this point in the history
  • Loading branch information
MedericFourmy committed Apr 18, 2024
1 parent 35a7655 commit 1db2bd6
Show file tree
Hide file tree
Showing 4 changed files with 46 additions and 40 deletions.
2 changes: 1 addition & 1 deletion docs/book/cosypose/evaluate.md
Original file line number Diff line number Diff line change
Expand Up @@ -5,7 +5,7 @@ Please make sure you followed the steps relative to the evaluation in the main r
Please run the following command to evaluate on YCBV dataset

```
python -m happypose.pose_estimators.cosypose.cosypose.scripts.run_full_cosypose_eval_new detector_run_id=bop_pbr coarse_run_id=coarse-bop-ycbv-pbr--724183 refiner_run_id=refiner-bop-ycbv-pbr--604090 ds_names=["ycbv.bop19"] result_id=ycbv-debug detection_coarse_types=[["detector","S03_grid"]]
python -m happypose.pose_estimators.cosypose.cosypose.scripts.run_full_cosypose_eval_new detector_run_id=bop_pbr coarse_run_id=coarse-bop-ycbv-pbr--724183 refiner_run_id=refiner-bop-ycbv-pbr--604090 ds_names=["ycbv.bop19"] result_id=ycbv-debug detection_coarse_types=[["detector"]]
```

The other BOP datasets are supported as long as you download the correspond models.
Original file line number Diff line number Diff line change
Expand Up @@ -16,18 +16,11 @@
from happypose.pose_estimators.cosypose.cosypose.evaluation.prediction_runner import (
PredictionRunner,
)
from happypose.pose_estimators.cosypose.cosypose.integrated.detector import Detector
from happypose.pose_estimators.cosypose.cosypose.integrated.pose_estimator import (
PoseEstimator,
)

# Detection
from happypose.pose_estimators.cosypose.cosypose.training.detector_models_cfg import (
check_update_config as check_update_config_detector,
)
from happypose.pose_estimators.cosypose.cosypose.training.detector_models_cfg import (
create_model_detector,
)
from happypose.pose_estimators.cosypose.cosypose.training.pose_models_cfg import (
check_update_config as check_update_config_pose,
)
Expand All @@ -45,8 +38,8 @@

# Pose estimator
from happypose.toolbox.datasets.datasets_cfg import make_object_dataset
from happypose.toolbox.inference import load_detector
from happypose.toolbox.lib3d.rigid_mesh_database import MeshDataBase
from happypose.toolbox.renderer.panda3d_batch_renderer import Panda3dBatchRenderer
from happypose.toolbox.utils.distributed import get_rank, get_tmp_dir
from happypose.toolbox.utils.logging import get_logger

Expand All @@ -58,35 +51,50 @@
logger = get_logger(__name__)


def load_detector(run_id, ds_name):
run_dir = EXP_DIR / run_id
# cfg = yaml.load((run_dir / 'config.yaml').read_text(), Loader=yaml.FullLoader)
cfg = yaml.load((run_dir / "config.yaml").read_text(), Loader=yaml.UnsafeLoader)
cfg = check_update_config_detector(cfg)
label_to_category_id = cfg.label_to_category_id
model = create_model_detector(cfg, len(label_to_category_id))
ckpt = torch.load(run_dir / "checkpoint.pth.tar", map_location=device)
ckpt = ckpt["state_dict"]
model.load_state_dict(ckpt)
model = model.to(device).eval()
model.cfg = cfg
model.config = cfg
model = Detector(model, ds_name)
return model


def load_pose_models(coarse_run_id, refiner_run_id, n_workers):
# def load_detector(run_id, ds_name):
# run_dir = EXP_DIR / run_id
# # cfg = yaml.load((run_dir / 'config.yaml').read_text(), Loader=yaml.FullLoader)
# cfg = yaml.load((run_dir / "config.yaml").read_text(), Loader=yaml.UnsafeLoader)
# cfg = check_update_config_detector(cfg)
# label_to_category_id = cfg.label_to_category_id
# model = create_model_detector(cfg, len(label_to_category_id))
# ckpt = torch.load(run_dir / "checkpoint.pth.tar", map_location=device)
# ckpt = ckpt["state_dict"]
# model.load_state_dict(ckpt)
# model = model.to(device).eval()
# model.cfg = cfg
# model.config = cfg
# model = Detector(model, ds_name)
# return model


def load_pose_models(coarse_run_id, refiner_run_id, n_workers, renderer_type="panda3d"):
run_dir = EXP_DIR / coarse_run_id
# cfg = yaml.load((run_dir / 'config.yaml').read_text(), Loader=yaml.FullLoader)
cfg = yaml.load((run_dir / "config.yaml").read_text(), Loader=yaml.UnsafeLoader)
cfg = check_update_config_pose(cfg)

object_dataset = make_object_dataset("ycbv")
renderer = Panda3dBatchRenderer(
object_dataset,
n_workers=n_workers,
preload_cache=False,
)
if renderer_type == "panda3d":
from happypose.toolbox.renderer.panda3d_batch_renderer import (
Panda3dBatchRenderer,
)

renderer = Panda3dBatchRenderer(
object_dataset,
n_workers=n_workers,
preload_cache=True,
)
elif renderer_type == "bullet":
from happypose.toolbox.renderer.bullet_batch_renderer import BulletBatchRenderer

renderer = BulletBatchRenderer(
object_dataset,
n_workers=n_workers,
preload_cache=True,
)
else:
raise ValueError(f"Renderer {renderer_type} not supported")

mesh_db = MeshDataBase.from_object_ds(object_dataset)
mesh_db_batched = mesh_db.batched().to(device)

Expand Down Expand Up @@ -173,7 +181,8 @@ def run_eval(
# Load detector model
if cfg.inference.detection_type == "detector":
assert cfg.detector_run_id is not None
detector_model = load_detector(cfg.detector_run_id, cfg.ds_name)
# detector_model = load_detector(cfg.detector_run_id, cfg.ds_name)
detector_model = load_detector(cfg.detector_run_id)
elif cfg.inference.detection_type == "gt":
detector_model = None
else:
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -67,13 +67,11 @@
def create_eval_cfg(
cfg: EvalConfig,
detection_type: str,
coarse_estimation_type: str,
ds_name: str,
) -> Tuple[str, EvalConfig]:
cfg = copy.deepcopy(cfg)

cfg.inference.detection_type = detection_type
cfg.inference.coarse_estimation_type = coarse_estimation_type
cfg.ds_name = ds_name

if detection_type == "detector":
Expand All @@ -88,7 +86,7 @@ def create_eval_cfg(
msg = f"Unknown detector type {cfg.detector_type}"
raise ValueError(msg)

name = generate_save_key(detection_type, coarse_estimation_type)
name = generate_save_key(detection_type, "cosycoarse")

return name, cfg

Expand All @@ -109,11 +107,10 @@ def run_full_eval(cfg: FullEvalConfig) -> None:
for ds_name in cfg.ds_names:
# create the EvalConfig objects that we will call `run_eval` on
eval_configs: Dict[str, EvalConfig] = {}
for detection_type, coarse_estimation_type in cfg.detection_coarse_types:
for detection_type in cfg.detection_coarse_types:
name, cfg_ = create_eval_cfg(
cfg,
detection_type,
coarse_estimation_type,
ds_name,
)
eval_configs[name] = cfg_
Expand Down
2 changes: 1 addition & 1 deletion happypose/pose_estimators/megapose/inference/types.py
Original file line number Diff line number Diff line change
Expand Up @@ -88,7 +88,7 @@ def assert_detections_valid(detections: DetectionsType) -> None:
@dataclass
class InferenceConfig:
# TODO: move detection_type outside of here
detection_type: str = "detector" # ['detector', 'gt']
detection_type: str = "detector" # ['detector', 'gt', 'exte']
coarse_estimation_type: str = "SO3_grid"
SO3_grid_size: int = 576
n_refiner_iterations: int = 5
Expand Down

0 comments on commit 1db2bd6

Please sign in to comment.