From 1510650656cb2dd5e30cf6b6c6cc90e43e1c5768 Mon Sep 17 00:00:00 2001 From: MedericFourmy Date: Tue, 17 Oct 2023 13:50:33 -0400 Subject: [PATCH 1/9] In bop formatting, select the coarse estimate with highest score --- .../megapose/evaluation/bop.py | 49 +++++++++++++++++++ .../scripts/run_full_megapose_eval.py | 9 ++-- 2 files changed, 54 insertions(+), 4 deletions(-) diff --git a/happypose/pose_estimators/megapose/evaluation/bop.py b/happypose/pose_estimators/megapose/evaluation/bop.py index 0729e167..c879ad86 100644 --- a/happypose/pose_estimators/megapose/evaluation/bop.py +++ b/happypose/pose_estimators/megapose/evaluation/bop.py @@ -36,6 +36,7 @@ from happypose.pose_estimators.megapose.evaluation.eval_config import BOPEvalConfig from happypose.toolbox.datasets.scene_dataset import ObjectData from happypose.toolbox.inference.utils import make_detections_from_object_data +from happypose.toolbox.utils.tensor_collection import PandasTensorCollection device = torch.device('cuda' if torch.cuda.is_available() else 'cpu') @@ -141,12 +142,60 @@ def convert_results_to_coco(results_path, out_json_path, detection_method): return + + +def filter_pose_estimates( + data_TCO: PandasTensorCollection, + top_K: int, + group_cols: list[str], + filter_field: str, + ascending: bool = False +) -> PandasTensorCollection: + """Filter the pose estimates by retaining only the top-K coarse model scores. + + Retain only the top_K estimates corresponding to each hypothesis_id + + Args: + top_K: how many estimates to retain + filter_field: The field to filter estimates by + """ + + # TODO: refactor with definition in pose_estimator.py + + df = data_TCO.infos + + # Logic from https://stackoverflow.com/a/40629420 + df = df.sort_values(filter_field, ascending=ascending).groupby(group_cols).head(top_K) + + data_TCO_filtered = data_TCO[df.index.tolist()] + + return data_TCO_filtered + + +def get_best_coarse_predictions(coarse_preds: PandasTensorCollection): + group_cols = ["scene_id", "view_id", "label", "instance_id"] + coarse_preds = filter_pose_estimates(coarse_preds, top_K=1, group_cols=group_cols, filter_field='coarse_score', ascending=False) + coarse_preds.infos = coarse_preds.infos.rename(columns={'coarse_score': 'pose_score'}) + return coarse_preds + + def convert_results_to_bop( results_path: Path, out_csv_path: Path, method: str, use_pose_score: bool = True ): + """ + results_path: path to file storing a pickled dictionary, + with a "predictions" key storing all results of a given evaluation + out_csv_path: path where bop format csv is saved + method: key to one of the available method predictions + use_pose_score: if true, uses the score obtained from the pose estimator, otherwise from the detector + """ + predictions = torch.load(results_path)["predictions"] predictions = predictions[method] + if method=='coarse': + predictions = get_best_coarse_predictions(predictions) + print("Predictions from:", results_path) print("Method:", method) print("Number of predictions: ", len(predictions)) diff --git a/happypose/pose_estimators/megapose/scripts/run_full_megapose_eval.py b/happypose/pose_estimators/megapose/scripts/run_full_megapose_eval.py index 70313440..8a6a4c7f 100644 --- a/happypose/pose_estimators/megapose/scripts/run_full_megapose_eval.py +++ b/happypose/pose_estimators/megapose/scripts/run_full_megapose_eval.py @@ -141,12 +141,13 @@ def run_full_eval(cfg: FullEvalConfig) -> None: if not cfg.skip_inference: eval_out = run_eval(eval_cfg) - # If we are skpping the inference mimic the output that run_eval + 
            # If we are skipping the inference, mimic the output that run_eval
            # would have produced so that we can run the bop_eval
            else:  # Otherwise hack the output so we can run the BOP eval
                if get_rank() == 0:
                    results_dir = get_save_dir(eval_cfg)
-                    pred_keys = ["refiner/final"]
+                    # pred_keys = ["refiner/final"]
+                    pred_keys = ["coarse"]
                    if eval_cfg.inference.run_depth_refiner:
                        pred_keys.append("depth_refiner")
                    eval_out = {
@@ -162,7 +163,7 @@ def run_full_eval(cfg: FullEvalConfig) -> None:

        # Run the bop eval for each type of prediction
        if cfg.run_bop_eval and get_rank() == 0:
-            bop_eval_keys = set(("refiner/final", "depth_refiner"))
+            bop_eval_keys = set(("coarse", "refiner/final", "depth_refiner"))
            bop_eval_keys = bop_eval_keys.intersection(set(eval_out["pred_keys"]))

            for method in bop_eval_keys:
@@ -175,7 +176,7 @@ def run_full_eval(cfg: FullEvalConfig) -> None:
                    split="test",
                    eval_dir=eval_out["save_dir"] / "bop_evaluation",
                    method=method,
-                    convert_only=False,
+                    convert_only=eval_cfg.convert_only,
                )
                bop_eval_cfgs.append(bop_eval_cfg)

From a13c6b1f0918f8f591b389b1a3983d1688edbe53 Mon Sep 17 00:00:00 2001
From: MedericFourmy
Date: Tue, 17 Oct 2023 13:52:17 -0400
Subject: [PATCH 2/9] convert_only option in main run eval script

---
 happypose/pose_estimators/megapose/evaluation/eval_config.py | 1 +
 1 file changed, 1 insertion(+)

diff --git a/happypose/pose_estimators/megapose/evaluation/eval_config.py b/happypose/pose_estimators/megapose/evaluation/eval_config.py
index 18b628e5..db885ea8 100644
--- a/happypose/pose_estimators/megapose/evaluation/eval_config.py
+++ b/happypose/pose_estimators/megapose/evaluation/eval_config.py
@@ -90,6 +90,7 @@ class FullEvalConfig(EvalConfig):
    ds_names: Optional[List[str]] = None
    run_bop_eval: bool = True
    modelnet_categories: Optional[List[str]] = None
+   convert_only: bool = False

 @dataclass

From b9f296ba32705a7f510aefa8542dace73175dbe6 Mon Sep 17 00:00:00 2001
From: MedericFourmy
Date: Wed, 18 Oct 2023 03:39:50 -0400
Subject: [PATCH 3/9] Remove unused variable

---
 .../megapose/scripts/run_full_megapose_eval.py | 11 -----------
 1 file changed, 11 deletions(-)

diff --git a/happypose/pose_estimators/megapose/scripts/run_full_megapose_eval.py b/happypose/pose_estimators/megapose/scripts/run_full_megapose_eval.py
index 8a6a4c7f..1c88d8c0 100644
--- a/happypose/pose_estimators/megapose/scripts/run_full_megapose_eval.py
+++ b/happypose/pose_estimators/megapose/scripts/run_full_megapose_eval.py
@@ -54,17 +54,6 @@
 logger = get_logger(__name__)

-BOP_DATASET_NAMES = [
-    "lm",
-    "lmo",
-    "tless",
-    "tudl",
-    "icbin",
-    "itodd",
-    "hb",
-    "ycbv",
-    # 'hope',
-]

 BOP_TEST_DATASETS = [
     "lmo.bop19",

From fb18b14285b23e55a314ba777a5cc250011bbf5b Mon Sep 17 00:00:00 2001
From: MedericFourmy
Date: Wed, 18 Oct 2023 04:00:30 -0400
Subject: [PATCH 4/9] Add coarse eval choice in main megapose run evaluation script

---
 .../pose_estimators/megapose/evaluation/eval_config.py | 2 +-
 .../megapose/scripts/run_full_megapose_eval.py | 10 ++++++----
 2 files changed, 7 insertions(+), 5 deletions(-)

diff --git a/happypose/pose_estimators/megapose/evaluation/eval_config.py b/happypose/pose_estimators/megapose/evaluation/eval_config.py
index db885ea8..0757686a 100644
--- a/happypose/pose_estimators/megapose/evaluation/eval_config.py
+++ b/happypose/pose_estimators/megapose/evaluation/eval_config.py
@@ -89,7 +89,7 @@ class FullEvalConfig(EvalConfig):
    detection_coarse_types: Optional[List] = None
    ds_names: Optional[List[str]] = None
    run_bop_eval: bool = True
-   modelnet_categories: Optional[List[str]] =
None + eval_coarse_also: bool = False convert_only: bool = False diff --git a/happypose/pose_estimators/megapose/scripts/run_full_megapose_eval.py b/happypose/pose_estimators/megapose/scripts/run_full_megapose_eval.py index 1c88d8c0..f33f6b8b 100644 --- a/happypose/pose_estimators/megapose/scripts/run_full_megapose_eval.py +++ b/happypose/pose_estimators/megapose/scripts/run_full_megapose_eval.py @@ -118,7 +118,7 @@ def run_full_eval(cfg: FullEvalConfig) -> None: # create the EvalConfig objects that we will call `run_eval` on eval_configs: Dict[str, EvalConfig] = dict() - for (detection_type, coarse_estimation_type) in cfg.detection_coarse_types: + for detection_type, coarse_estimation_type in cfg.detection_coarse_types: name, cfg_ = create_eval_cfg(cfg, detection_type, coarse_estimation_type, ds_name) eval_configs[name] = cfg_ @@ -135,8 +135,7 @@ def run_full_eval(cfg: FullEvalConfig) -> None: else: # Otherwise hack the output so we can run the BOP eval if get_rank() == 0: results_dir = get_save_dir(eval_cfg) - # pred_keys = ["refiner/final"] - pred_keys = ["coarse"] + pred_keys = ["coarse", "refiner/final"] if eval_cfg.inference.run_depth_refiner: pred_keys.append("depth_refiner") eval_out = { @@ -151,8 +150,11 @@ def run_full_eval(cfg: FullEvalConfig) -> None: # Run the bop eval for each type of prediction if cfg.run_bop_eval and get_rank() == 0: + bop_eval_keys = set(("refiner/final", "depth_refiner")) + if cfg.eval_coarse_also: + bop_eval_keys.add('coarse') - bop_eval_keys = set(("coarse", "refiner/final", "depth_refiner")) + # Remove from evaluation predictions that were not produced at inference time bop_eval_keys = bop_eval_keys.intersection(set(eval_out["pred_keys"])) for method in bop_eval_keys: From cec03a1fa96526278754fdea0dec0f77b245929e Mon Sep 17 00:00:00 2001 From: MedericFourmy Date: Mon, 23 Oct 2023 10:13:09 -0400 Subject: [PATCH 5/9] Refactor filter_pose_estimates --- .../megapose/evaluation/bop.py | 34 +-------------- .../megapose/inference/pose_estimator.py | 43 ++++++------------- happypose/toolbox/utils/tensor_collection.py | 29 +++++++++++++ .../megapose_estimator_visualization.ipynb | 18 +++++--- 4 files changed, 55 insertions(+), 69 deletions(-) diff --git a/happypose/pose_estimators/megapose/evaluation/bop.py b/happypose/pose_estimators/megapose/evaluation/bop.py index c879ad86..519e9429 100644 --- a/happypose/pose_estimators/megapose/evaluation/bop.py +++ b/happypose/pose_estimators/megapose/evaluation/bop.py @@ -36,7 +36,7 @@ from happypose.pose_estimators.megapose.evaluation.eval_config import BOPEvalConfig from happypose.toolbox.datasets.scene_dataset import ObjectData from happypose.toolbox.inference.utils import make_detections_from_object_data -from happypose.toolbox.utils.tensor_collection import PandasTensorCollection +from happypose.toolbox.utils.tensor_collection import PandasTensorCollection, filter_top_pose_estimates device = torch.device('cuda' if torch.cuda.is_available() else 'cpu') @@ -142,39 +142,9 @@ def convert_results_to_coco(results_path, out_json_path, detection_method): return - - -def filter_pose_estimates( - data_TCO: PandasTensorCollection, - top_K: int, - group_cols: list[str], - filter_field: str, - ascending: bool = False -) -> PandasTensorCollection: - """Filter the pose estimates by retaining only the top-K coarse model scores. 
- - Retain only the top_K estimates corresponding to each hypothesis_id - - Args: - top_K: how many estimates to retain - filter_field: The field to filter estimates by - """ - - # TODO: refactor with definition in pose_estimator.py - - df = data_TCO.infos - - # Logic from https://stackoverflow.com/a/40629420 - df = df.sort_values(filter_field, ascending=ascending).groupby(group_cols).head(top_K) - - data_TCO_filtered = data_TCO[df.index.tolist()] - - return data_TCO_filtered - - def get_best_coarse_predictions(coarse_preds: PandasTensorCollection): group_cols = ["scene_id", "view_id", "label", "instance_id"] - coarse_preds = filter_pose_estimates(coarse_preds, top_K=1, group_cols=group_cols, filter_field='coarse_score', ascending=False) + coarse_preds = filter_top_pose_estimates(coarse_preds, top_K=1, group_cols=group_cols, filter_field='coarse_score', ascending=False) coarse_preds.infos = coarse_preds.infos.rename(columns={'coarse_score': 'pose_score'}) return coarse_preds diff --git a/happypose/pose_estimators/megapose/inference/pose_estimator.py b/happypose/pose_estimators/megapose/inference/pose_estimator.py index e54af605..8a242c17 100644 --- a/happypose/pose_estimators/megapose/inference/pose_estimator.py +++ b/happypose/pose_estimators/megapose/inference/pose_estimator.py @@ -49,7 +49,7 @@ from happypose.toolbox.lib3d.cosypose_ops import TCO_init_from_boxes_autodepth_with_R from happypose.toolbox.utils import transform_utils from happypose.toolbox.utils.logging import get_logger -from happypose.toolbox.utils.tensor_collection import PandasTensorCollection +from happypose.toolbox.utils.tensor_collection import PandasTensorCollection, filter_top_pose_estimates from happypose.toolbox.utils.timer import Timer logger = get_logger(__name__) @@ -591,9 +591,13 @@ def run_inference_pipeline( timing_str += f"coarse={coarse_extra_data['time']:.2f}, " # Extract top-K coarse hypotheses - data_TCO_filtered = self.filter_pose_estimates( - data_TCO_coarse, top_K=n_pose_hypotheses, filter_field="coarse_logit" + data_TCO_filtered = filter_top_pose_estimates( + data_TCO_coarse, + top_K=n_pose_hypotheses, + group_cols=["batch_im_id", "label", "instance_id"], + filter_field="coarse_logit" ) + else: data_TCO_coarse = coarse_estimates coarse_extra_data = None @@ -618,8 +622,11 @@ def run_inference_pipeline( timing_str += f"scoring={scoring_extra_data['time']:.2f}, " # Extract the highest scoring pose estimate for each instance_id - data_TCO_final_scored = self.filter_pose_estimates( - data_TCO_scored, top_K=1, filter_field="pose_logit" + data_TCO_final_scored = self.filter_top_pose_estimates( + data_TCO_scored, + top_K=1, + group_cols=["batch_im_id", "label", "instance_id"], + filter_field="pose_logit" ) # Optionally run ICP or TEASER++ @@ -649,29 +656,3 @@ def run_inference_pipeline( extra_data["depth_refiner"] = {"preds": data_TCO_depth_refiner} return data_TCO_final, extra_data - - def filter_pose_estimates( - self, - data_TCO: PoseEstimatesType, - top_K: int, - filter_field: str, - ascending: bool = False, - ) -> PoseEstimatesType: - """Filter the pose estimates by retaining only the top-K coarse model scores. 
- - Retain only the top_K estimates corresponding to each hypothesis_id - - Args: - top_K: how many estimates to retain - filter_field: The field to filter estimates by - """ - - df = data_TCO.infos - - group_cols = ["batch_im_id", "label", "instance_id"] - # Logic from https://stackoverflow.com/a/40629420 - df = df.sort_values(filter_field, ascending=ascending).groupby(group_cols).head(top_K) - - data_TCO_filtered = data_TCO[df.index.tolist()] - - return data_TCO_filtered \ No newline at end of file diff --git a/happypose/toolbox/utils/tensor_collection.py b/happypose/toolbox/utils/tensor_collection.py index f0c35c83..2c005d23 100644 --- a/happypose/toolbox/utils/tensor_collection.py +++ b/happypose/toolbox/utils/tensor_collection.py @@ -194,3 +194,32 @@ def __setstate__(self, state): self.__init__(state["infos"], **state["tensors"]) self.meta = state["meta"] return + + + +def filter_top_pose_estimates( + data_TCO: PandasTensorCollection, + top_K: int, + group_cols: list[str], + filter_field: str, + ascending: bool = False +) -> PandasTensorCollection: + """Filter the pose estimates by retaining only the top-K coarse model scores. + + Retain only the top_K estimates corresponding to each hypothesis_id + + Args: + top_K: how many estimates to retain + group_cols: group of columns among which sorting should be done + filter_field: the field to filter estimates by + ascending: should filter_field + """ + + df = data_TCO.infos + + # Logic from https://stackoverflow.com/a/40629420 + df = df.sort_values(filter_field, ascending=ascending).groupby(group_cols).head(top_K) + + data_TCO_filtered = data_TCO[df.index.tolist()] + + return data_TCO_filtered diff --git a/notebooks/megapose/megapose_estimator_visualization.ipynb b/notebooks/megapose/megapose_estimator_visualization.ipynb index 9a557e5e..be23a8b5 100644 --- a/notebooks/megapose/megapose_estimator_visualization.ipynb +++ b/notebooks/megapose/megapose_estimator_visualization.ipynb @@ -243,6 +243,8 @@ "metadata": {}, "outputs": [], "source": [ + "from happypose.toolbox.utils.tensor_collection import filter_pose_estimates\n", + "\n", "# Options for inference\n", "use_gt_detections = True # Note, if you aren't using gt_detections then this should be false\n", "n_refiner_iterations = 5\n", @@ -294,10 +296,11 @@ " f\"model_time={extra_data['model_time']:.2f}, render_time={extra_data['render_time']:.2f}\")\n", " \n", " # Extract top-K coarse hypotheses\n", - " data_TCO_filtered = pose_estimator.filter_pose_estimates(data_TCO_coarse, \n", - " top_K=n_pose_hypotheses, \n", - " filter_field='coarse_logit')\n", - " \n", + " data_TCO_filtered = filter_pose_estimates(data_TCO_coarse,\n", + " top_K=n_pose_hypotheses, \n", + " group_cols=[\"batch_im_id\", \"label\", \"instance_id\"], \n", + " filter_field='coarse_logit')\n", + "\n", " # Refine the top_K coarse hypotheses\n", " preds, extra_data = pose_estimator.forward_refiner(observation_tensor, data_TCO_filtered, \n", " n_iterations=n_refiner_iterations, keep_all_outputs=True)\n", @@ -311,7 +314,10 @@ " data_TCO_scored, extra_data = pose_estimator.forward_scoring_model(observation_tensor, data_TCO_refined)\n", "\n", " # Extract the highest scoring pose estimate for each instance_id\n", - " data_TCO_final = pose_estimator.filter_pose_estimates(data_TCO_scored, top_K=1, filter_field='pose_logit')\n", + " data_TCO_final = filter_pose_estimates(data_TCO_scored, \n", + " top_K=1, \n", + " group_cols=[\"batch_im_id\", \"label\", \"instance_id\"], \n", + " filter_field='pose_logit')\n", " \n", " \n", " if 
run_depth_refiner:\n", @@ -969,4 +975,4 @@ }, "nbformat": 4, "nbformat_minor": 5 -} \ No newline at end of file +} From 625c378b6148c604e281b0b8b6795c727b2237ee Mon Sep 17 00:00:00 2001 From: MedericFourmy Date: Thu, 26 Oct 2023 13:34:12 -0400 Subject: [PATCH 6/9] Forgot one filter_top_pose_estimates occurence --- happypose/pose_estimators/megapose/inference/pose_estimator.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/happypose/pose_estimators/megapose/inference/pose_estimator.py b/happypose/pose_estimators/megapose/inference/pose_estimator.py index 8a242c17..8f05e703 100644 --- a/happypose/pose_estimators/megapose/inference/pose_estimator.py +++ b/happypose/pose_estimators/megapose/inference/pose_estimator.py @@ -622,7 +622,7 @@ def run_inference_pipeline( timing_str += f"scoring={scoring_extra_data['time']:.2f}, " # Extract the highest scoring pose estimate for each instance_id - data_TCO_final_scored = self.filter_top_pose_estimates( + data_TCO_final_scored = filter_top_pose_estimates( data_TCO_scored, top_K=1, group_cols=["batch_im_id", "label", "instance_id"], From fcf75a9fb9ff679a36df9bc62ff0ba1692728e37 Mon Sep 17 00:00:00 2001 From: MedericFourmy Date: Thu, 26 Oct 2023 13:34:51 -0400 Subject: [PATCH 7/9] Use metadata to obtain different total runtimes for coarse, refiner, depth_refiner --- .../megapose/evaluation/prediction_runner.py | 60 +++++++++++++------ 1 file changed, 42 insertions(+), 18 deletions(-) diff --git a/happypose/pose_estimators/megapose/evaluation/prediction_runner.py b/happypose/pose_estimators/megapose/evaluation/prediction_runner.py index 4c4e0ec3..a7166349 100644 --- a/happypose/pose_estimators/megapose/evaluation/prediction_runner.py +++ b/happypose/pose_estimators/megapose/evaluation/prediction_runner.py @@ -135,7 +135,6 @@ def run_inference_pipeline( coarse_estimates.infos["instance_id"] = 0 run_detector = False - t = time.time() preds, extra_data = pose_estimator.run_inference_pipeline( obs_tensor, detections=detections, @@ -147,7 +146,6 @@ def run_inference_pipeline( bsz_images=self.inference_cfg.bsz_images, bsz_objects=self.inference_cfg.bsz_objects, ) - elapsed = time.time() - t # TODO (lmanuelli): Process this into a dict with keys like # - 'refiner/iteration=1` @@ -156,25 +154,32 @@ def run_inference_pipeline( # Note: Since we support multi-hypotheses we need to potentially # go back and extract out the 'refiner/iteration=1`, `refiner/iteration=5` things for the ones that were actually the highest scoring at the end. 
- all_preds = dict() - data_TCO_refiner = extra_data["refiner"]["preds"] - all_preds = { "final": preds, - f"refiner/iteration={self.inference_cfg.n_refiner_iterations}": data_TCO_refiner, - "refiner/final": data_TCO_refiner, + f"refiner/iteration={self.inference_cfg.n_refiner_iterations}": extra_data["refiner"]["preds"], + "refiner/final": extra_data["refiner"]["preds"], "coarse": extra_data["coarse"]["preds"], + "coarse_filter": extra_data["coarse_filter"]["preds"], + } + + # Only keep necessary metadata + del extra_data['coarse']['data']['TCO'] + all_preds_data = { + 'coarse': extra_data['coarse']['data'], + 'refiner': extra_data['refiner']['data'], + 'scoring': extra_data['scoring'], } if self.inference_cfg.run_depth_refiner: - all_preds[f"depth_refiner"] = extra_data["depth_refiner"]["preds"] + all_preds["depth_refiner"] = extra_data["depth_refiner"]["preds"] + all_preds_data["depth_refiner"] = extra_data["depth_refiner"]["data"] for k, v in all_preds.items(): if "mask" in v.tensors: breakpoint() v.delete_tensor("mask") - return all_preds + return all_preds, all_preds_data def get_predictions(self, pose_estimator: PoseEstimator) -> Dict[str, PoseEstimatesType]: """Runs predictions @@ -237,7 +242,7 @@ def get_predictions(self, pose_estimator: PoseEstimator) -> Dict[str, PoseEstima cuda_timer = CudaTimer() cuda_timer.start() with torch.no_grad(): - all_preds = self.run_inference_pipeline( + all_preds, all_preds_data = self.run_inference_pipeline( pose_estimator, obs_tensor, gt_detections, sam_detections, initial_estimates=initial_data ) cuda_timer.end() @@ -246,15 +251,34 @@ def get_predictions(self, pose_estimator: PoseEstimator) -> Dict[str, PoseEstima total_duration = duration + dt_det # Add metadata to the predictions for later evaluation - for k, v in all_preds.items(): - v.infos['time'] = total_duration - v.infos['scene_id'] = scene_id - v.infos['view_id'] = view_id - predictions_list[k].append(v) - - # Concatenate the lists of PandasTensorCollections + for pred_name, pred in all_preds.items(): + pred.infos['time'] = dt_det + compute_pose_est_total_time(all_preds_data, pred_name) + pred.infos['scene_id'] = scene_id + pred.infos['view_id'] = view_id + predictions_list[pred_name].append(pred) + + # Concatenate the lists of PandasTensorCollections predictions = dict() for k, v in predictions_list.items(): predictions[k] = tc.concatenate(v) - return predictions \ No newline at end of file + return predictions + + +def compute_pose_est_total_time(all_preds_data: dict, pred_name: str): + # all_preds_data: dict_keys(['final', 'refiner/iteration=5', 'refiner/final', 'coarse', 'coarse_filter']) # optionally 'depth_refiner' + dt_coarse = all_preds_data['coarse']['time'] + dt_coarse_refiner = dt_coarse + all_preds_data['refiner']['time'] + if 'depth_refiner' in all_preds_data: + dt_coarse_refiner_depth = dt_coarse_refiner + all_preds_data['depth_refiner']['time'] + + if pred_name.startswith('coarse'): + return dt_coarse + elif pred_name.startswith('refiner'): + return dt_coarse_refiner + elif pred_name == 'depth_refiner': + return dt_coarse_refiner_depth + elif pred_name == 'final': + return dt_coarse_refiner_depth if 'depth_refiner' in all_preds_data else dt_coarse_refiner + else: + raise ValueError(f'{pred_name} extra data not in {all_preds_data.keys()}') \ No newline at end of file From ee460eefc4a6189ea6393c42efce57f48de98718 Mon Sep 17 00:00:00 2001 From: Guilhem Saurel Date: Fri, 27 Oct 2023 16:02:25 +0200 Subject: [PATCH 8/9] apply black on 4 files --- 
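Note: the per-group top-K selection that patches 1 and 5 centralize in filter_top_pose_estimates is the standard pandas sort/groupby/head idiom (the Stack Overflow link cited in the patch). A minimal standalone sketch follows; the column names and grouping come from the patches in this series, while the tiny DataFrame and its values are illustrative assumptions only.

    # Minimal sketch of the per-group top-K filtering performed by the new helper.
    # Toy values below are assumptions; only column names/grouping mirror the patches.
    import pandas as pd

    infos = pd.DataFrame(
        {
            "scene_id": [1, 1, 1, 2],
            "view_id": [10, 10, 10, 20],
            "label": ["obj_01", "obj_01", "obj_01", "obj_02"],
            "instance_id": [0, 0, 0, 0],
            "coarse_score": [0.2, 0.9, 0.5, 0.7],
        }
    )

    group_cols = ["scene_id", "view_id", "label", "instance_id"]
    top_K = 1

    # Sort by score, then keep the first (best) row of each group.
    best = (
        infos.sort_values("coarse_score", ascending=False)
        .groupby(group_cols)
        .head(top_K)
    )
    print(best)
    # In the patched code the surviving row indices are then used to slice the
    # PandasTensorCollection, e.g. data_TCO[best.index.tolist()].
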
.../megapose/evaluation/eval_config.py | 2 - .../megapose/evaluation/prediction_runner.py | 112 ++++++++++++------ .../megapose/inference/pose_estimator.py | 62 ++++++---- .../scripts/run_full_megapose_eval.py | 27 +++-- 4 files changed, 126 insertions(+), 77 deletions(-) diff --git a/happypose/pose_estimators/megapose/evaluation/eval_config.py b/happypose/pose_estimators/megapose/evaluation/eval_config.py index 0757686a..3a81d58e 100644 --- a/happypose/pose_estimators/megapose/evaluation/eval_config.py +++ b/happypose/pose_estimators/megapose/evaluation/eval_config.py @@ -84,7 +84,6 @@ class EvalConfig: @dataclass class FullEvalConfig(EvalConfig): - # Full eval detection_coarse_types: Optional[List] = None ds_names: Optional[List[str]] = None @@ -95,7 +94,6 @@ class FullEvalConfig(EvalConfig): @dataclass class BOPEvalConfig: - results_path: str dataset: str split: str diff --git a/happypose/pose_estimators/megapose/evaluation/prediction_runner.py b/happypose/pose_estimators/megapose/evaluation/prediction_runner.py index a7166349..cd0a9bf3 100644 --- a/happypose/pose_estimators/megapose/evaluation/prediction_runner.py +++ b/happypose/pose_estimators/megapose/evaluation/prediction_runner.py @@ -40,17 +40,19 @@ ObservationTensor, PoseEstimatesType, ) -from happypose.pose_estimators.megapose.config import ( - BOP_DS_DIR -) +from happypose.pose_estimators.megapose.config import BOP_DS_DIR from happypose.pose_estimators.megapose.evaluation.bop import ( get_sam_detections, - load_sam_predictions + load_sam_predictions, ) from happypose.pose_estimators.megapose.training.utils import CudaTimer from happypose.toolbox.datasets.samplers import DistributedSceneSampler -from happypose.toolbox.datasets.scene_dataset import SceneDataset, SceneObservation, ObjectData +from happypose.toolbox.datasets.scene_dataset import ( + SceneDataset, + SceneObservation, + ObjectData, +) from happypose.toolbox.utils.distributed import get_rank, get_tmp_dir, get_world_size from happypose.toolbox.utils.logging import get_logger @@ -62,7 +64,7 @@ logger = get_logger(__name__) -device = torch.device('cuda' if torch.cuda.is_available() else 'cpu') +device = torch.device("cuda" if torch.cuda.is_available() else "cpu") class PredictionRunner: @@ -73,13 +75,14 @@ def __init__( batch_size: int = 1, n_workers: int = 4, ) -> None: - self.inference_cfg = inference_cfg self.rank = get_rank() self.world_size = get_world_size() self.tmp_dir = get_tmp_dir() - sampler = DistributedSceneSampler(scene_ds, num_replicas=self.world_size, rank=self.rank) + sampler = DistributedSceneSampler( + scene_ds, num_replicas=self.world_size, rank=self.rank + ) self.sampler = sampler self.scene_ds = scene_ds dataloader = DataLoader( @@ -125,13 +128,17 @@ def run_inference_pipeline( run_detector = True else: - raise ValueError(f"Unknown detection type {self.inference_cfg.detection_type}") + raise ValueError( + f"Unknown detection type {self.inference_cfg.detection_type}" + ) coarse_estimates = None if self.inference_cfg.coarse_estimation_type == "external": # TODO (ylabbe): This is hacky, clean this for modelnet eval. 
coarse_estimates = initial_estimates - coarse_estimates = happypose.toolbox.inference.utils.add_instance_id(coarse_estimates) + coarse_estimates = happypose.toolbox.inference.utils.add_instance_id( + coarse_estimates + ) coarse_estimates.infos["instance_id"] = 0 run_detector = False @@ -156,18 +163,20 @@ def run_inference_pipeline( all_preds = { "final": preds, - f"refiner/iteration={self.inference_cfg.n_refiner_iterations}": extra_data["refiner"]["preds"], + f"refiner/iteration={self.inference_cfg.n_refiner_iterations}": extra_data[ + "refiner" + ]["preds"], "refiner/final": extra_data["refiner"]["preds"], "coarse": extra_data["coarse"]["preds"], "coarse_filter": extra_data["coarse_filter"]["preds"], } # Only keep necessary metadata - del extra_data['coarse']['data']['TCO'] + del extra_data["coarse"]["data"]["TCO"] all_preds_data = { - 'coarse': extra_data['coarse']['data'], - 'refiner': extra_data['refiner']['data'], - 'scoring': extra_data['scoring'], + "coarse": extra_data["coarse"]["data"], + "refiner": extra_data["refiner"]["data"], + "scoring": extra_data["scoring"], } if self.inference_cfg.run_depth_refiner: @@ -181,7 +190,9 @@ def run_inference_pipeline( return all_preds, all_preds_data - def get_predictions(self, pose_estimator: PoseEstimator) -> Dict[str, PoseEstimatesType]: + def get_predictions( + self, pose_estimator: PoseEstimator + ) -> Dict[str, PoseEstimatesType]: """Runs predictions Returns: A dict with keys @@ -202,15 +213,17 @@ def get_predictions(self, pose_estimator: PoseEstimator) -> Dict[str, PoseEstima ###### # Temporary solution if self.inference_cfg.detection_type == "sam": - df_all_dets, df_targets = load_sam_predictions(self.scene_ds.ds_dir.name, self.scene_ds.ds_dir) + df_all_dets, df_targets = load_sam_predictions( + self.scene_ds.ds_dir.name, self.scene_ds.ds_dir + ) for n, data in enumerate(tqdm(self.dataloader)): # data is a dict rgb = data["rgb"] depth = data["depth"] K = data["cameras"].K - im_info = data['im_infos'][0] - scene_id, view_id = im_info['scene_id'], im_info['view_id'] + im_info = data["im_infos"][0] + scene_id, view_id = im_info["scene_id"], im_info["view_id"] # Dirty but avoids creating error when running with real detector dt_det = 0 @@ -220,8 +233,13 @@ def get_predictions(self, pose_estimator: PoseEstimator) -> Dict[str, PoseEstima ###### # Temporary solution if self.inference_cfg.detection_type == "sam": - # We assume a unique image ("view") associated with a unique scene_id is - sam_detections = get_sam_detections(data=data, df_all_dets=df_all_dets, df_targets=df_targets, dt_det=dt_det) + # We assume a unique image ("view") associated with a unique scene_id is + sam_detections = get_sam_detections( + data=data, + df_all_dets=df_all_dets, + df_targets=df_targets, + dt_det=dt_det, + ) else: sam_detections = None gt_detections = data["gt_detections"].cuda() @@ -236,14 +254,22 @@ def get_predictions(self, pose_estimator: PoseEstimator) -> Dict[str, PoseEstima if n == 0: with torch.no_grad(): self.run_inference_pipeline( - pose_estimator, obs_tensor, gt_detections, sam_detections, initial_estimates=initial_data + pose_estimator, + obs_tensor, + gt_detections, + sam_detections, + initial_estimates=initial_data, ) cuda_timer = CudaTimer() cuda_timer.start() with torch.no_grad(): all_preds, all_preds_data = self.run_inference_pipeline( - pose_estimator, obs_tensor, gt_detections, sam_detections, initial_estimates=initial_data + pose_estimator, + obs_tensor, + gt_detections, + sam_detections, + initial_estimates=initial_data, ) 
cuda_timer.end() duration = cuda_timer.elapsed() @@ -252,11 +278,13 @@ def get_predictions(self, pose_estimator: PoseEstimator) -> Dict[str, PoseEstima # Add metadata to the predictions for later evaluation for pred_name, pred in all_preds.items(): - pred.infos['time'] = dt_det + compute_pose_est_total_time(all_preds_data, pred_name) - pred.infos['scene_id'] = scene_id - pred.infos['view_id'] = view_id + pred.infos["time"] = dt_det + compute_pose_est_total_time( + all_preds_data, pred_name + ) + pred.infos["scene_id"] = scene_id + pred.infos["view_id"] = view_id predictions_list[pred_name].append(pred) - + # Concatenate the lists of PandasTensorCollections predictions = dict() for k, v in predictions_list.items(): @@ -267,18 +295,24 @@ def get_predictions(self, pose_estimator: PoseEstimator) -> Dict[str, PoseEstima def compute_pose_est_total_time(all_preds_data: dict, pred_name: str): # all_preds_data: dict_keys(['final', 'refiner/iteration=5', 'refiner/final', 'coarse', 'coarse_filter']) # optionally 'depth_refiner' - dt_coarse = all_preds_data['coarse']['time'] - dt_coarse_refiner = dt_coarse + all_preds_data['refiner']['time'] - if 'depth_refiner' in all_preds_data: - dt_coarse_refiner_depth = dt_coarse_refiner + all_preds_data['depth_refiner']['time'] - - if pred_name.startswith('coarse'): + dt_coarse = all_preds_data["coarse"]["time"] + dt_coarse_refiner = dt_coarse + all_preds_data["refiner"]["time"] + if "depth_refiner" in all_preds_data: + dt_coarse_refiner_depth = ( + dt_coarse_refiner + all_preds_data["depth_refiner"]["time"] + ) + + if pred_name.startswith("coarse"): return dt_coarse - elif pred_name.startswith('refiner'): + elif pred_name.startswith("refiner"): return dt_coarse_refiner - elif pred_name == 'depth_refiner': - return dt_coarse_refiner_depth - elif pred_name == 'final': - return dt_coarse_refiner_depth if 'depth_refiner' in all_preds_data else dt_coarse_refiner + elif pred_name == "depth_refiner": + return dt_coarse_refiner_depth + elif pred_name == "final": + return ( + dt_coarse_refiner_depth + if "depth_refiner" in all_preds_data + else dt_coarse_refiner + ) else: - raise ValueError(f'{pred_name} extra data not in {all_preds_data.keys()}') \ No newline at end of file + raise ValueError(f"{pred_name} extra data not in {all_preds_data.keys()}") diff --git a/happypose/pose_estimators/megapose/inference/pose_estimator.py b/happypose/pose_estimators/megapose/inference/pose_estimator.py index 8f05e703..5018901a 100644 --- a/happypose/pose_estimators/megapose/inference/pose_estimator.py +++ b/happypose/pose_estimators/megapose/inference/pose_estimator.py @@ -49,12 +49,16 @@ from happypose.toolbox.lib3d.cosypose_ops import TCO_init_from_boxes_autodepth_with_R from happypose.toolbox.utils import transform_utils from happypose.toolbox.utils.logging import get_logger -from happypose.toolbox.utils.tensor_collection import PandasTensorCollection, filter_top_pose_estimates +from happypose.toolbox.utils.tensor_collection import ( + PandasTensorCollection, + filter_top_pose_estimates, +) from happypose.toolbox.utils.timer import Timer logger = get_logger(__name__) -device = torch.device('cuda' if torch.cuda.is_available() else 'cpu') +device = torch.device("cuda" if torch.cuda.is_available() else "cpu") + class PoseEstimator(PoseEstimationModule): """Performs inference for pose estimation.""" @@ -69,7 +73,6 @@ def __init__( bsz_images: int = 256, SO3_grid_size: int = 576, ) -> None: - super().__init__() self.coarse_model = coarse_model self.refiner_model = refiner_model @@ 
-90,7 +93,9 @@ def __init__( self.cfg = self.coarse_model.cfg self.mesh_db = self.coarse_model.mesh_db else: - raise ValueError("At least one of refiner_model or " " coarse_model must be specified.") + raise ValueError( + "At least one of refiner_model or " " coarse_model must be specified." + ) self.eval() @@ -153,7 +158,7 @@ def forward_refiner( model_time = 0.0 - for (batch_idx, (batch_ids,)) in enumerate(dl): + for batch_idx, (batch_ids,) in enumerate(dl): data_TCO_input_ = data_TCO_input[batch_ids] df_ = data_TCO_input_.infos TCO_input_ = data_TCO_input_.poses @@ -218,7 +223,8 @@ def forward_refiner( } logger.debug( - f"Pose prediction on {B} poses (n_iterations={n_iterations}):" f" {timer.stop()}" + f"Pose prediction on {B} poses (n_iterations={n_iterations}):" + f" {timer.stop()}" ) return preds, extra_data @@ -231,7 +237,6 @@ def forward_scoring_model( cuda_timer: bool = False, return_debug_data: bool = False, ) -> Tuple[PoseEstimatesType, dict]: - """Score the estimates using the coarse model. @@ -312,9 +317,7 @@ def forward_scoring_model( elapsed = time.time() - start_time - timing_str = ( - f"time: {elapsed:.2f}, model_time: {model_time:.2f}, render_time: {render_time:.2f}" - ) + timing_str = f"time: {elapsed:.2f}, model_time: {model_time:.2f}, render_time: {render_time:.2f}" extra_data = { "render_time": render_time, @@ -384,7 +387,6 @@ def forward_coarse_model( TCO_init = [] for (batch_ids,) in dl: - # b = bsz_images df_ = df_hypotheses.iloc[batch_ids.cpu().numpy()] @@ -418,7 +420,7 @@ def forward_coarse_model( ) del points_ - + out_ = coarse_model.forward_coarse( images=images_, K=K_, @@ -472,9 +474,7 @@ def forward_coarse_model( elapsed = time.time() - start_time - timing_str = ( - f"time: {elapsed:.2f}, model_time: {model_time:.2f}, render_time: {render_time:.2f}" - ) + timing_str = f"time: {elapsed:.2f}, model_time: {model_time:.2f}, render_time: {render_time:.2f}" extra_data = { "render_time": render_time, @@ -512,7 +512,9 @@ def run_depth_refiner( depth = observation.depth K = observation.K - refined_preds, extra_data = self.depth_refiner.refine_poses(predictions, depth=depth, K=K) + refined_preds, extra_data = self.depth_refiner.refine_poses( + predictions, depth=depth, K=K + ) return refined_preds, extra_data @@ -592,10 +594,10 @@ def run_inference_pipeline( # Extract top-K coarse hypotheses data_TCO_filtered = filter_top_pose_estimates( - data_TCO_coarse, - top_K=n_pose_hypotheses, - group_cols=["batch_im_id", "label", "instance_id"], - filter_field="coarse_logit" + data_TCO_coarse, + top_K=n_pose_hypotheses, + group_cols=["batch_im_id", "label", "instance_id"], + filter_field="coarse_logit", ) else: @@ -623,16 +625,18 @@ def run_inference_pipeline( # Extract the highest scoring pose estimate for each instance_id data_TCO_final_scored = filter_top_pose_estimates( - data_TCO_scored, - top_K=1, + data_TCO_scored, + top_K=1, group_cols=["batch_im_id", "label", "instance_id"], - filter_field="pose_logit" + filter_field="pose_logit", ) # Optionally run ICP or TEASER++ if run_depth_refiner: depth_refiner_start = time.time() - data_TCO_depth_refiner, _ = self.run_depth_refiner(observation, data_TCO_final_scored) + data_TCO_depth_refiner, _ = self.run_depth_refiner( + observation, data_TCO_final_scored + ) data_TCO_final = data_TCO_depth_refiner depth_refiner_time = time.time() - depth_refiner_start timing_str += f"depth refiner={depth_refiner_time:.2f}" @@ -646,9 +650,15 @@ def run_inference_pipeline( extra_data: dict = dict() extra_data["coarse"] = {"preds": 
data_TCO_coarse, "data": coarse_extra_data} extra_data["coarse_filter"] = {"preds": data_TCO_filtered} - extra_data["refiner_all_hypotheses"] = {"preds": preds, "data": refiner_extra_data} + extra_data["refiner_all_hypotheses"] = { + "preds": preds, + "data": refiner_extra_data, + } extra_data["scoring"] = {"preds": data_TCO_scored, "data": scoring_extra_data} - extra_data["refiner"] = {"preds": data_TCO_final_scored, "data": refiner_extra_data} + extra_data["refiner"] = { + "preds": data_TCO_final_scored, + "data": refiner_extra_data, + } extra_data["timing_str"] = timing_str extra_data["time"] = timer.elapsed() diff --git a/happypose/pose_estimators/megapose/scripts/run_full_megapose_eval.py b/happypose/pose_estimators/megapose/scripts/run_full_megapose_eval.py index f33f6b8b..42f1c7d5 100644 --- a/happypose/pose_estimators/megapose/scripts/run_full_megapose_eval.py +++ b/happypose/pose_estimators/megapose/scripts/run_full_megapose_eval.py @@ -15,7 +15,6 @@ """ - # Standard Library import copy import os @@ -47,9 +46,17 @@ HardwareConfig, ) -from happypose.pose_estimators.megapose.evaluation.evaluation import get_save_dir, generate_save_key, run_eval +from happypose.pose_estimators.megapose.evaluation.evaluation import ( + get_save_dir, + generate_save_key, + run_eval, +) from happypose.pose_estimators.megapose.evaluation.bop import run_evaluation -from happypose.toolbox.utils.distributed import get_rank, get_world_size, init_distributed_mode +from happypose.toolbox.utils.distributed import ( + get_rank, + get_world_size, + init_distributed_mode, +) from happypose.toolbox.utils.logging import get_logger, set_logging_level logger = get_logger(__name__) @@ -66,7 +73,9 @@ ] -MODELNET_TEST_DATASETS = [f"modelnet.{category}.test" for category in MODELNET_TEST_CATEGORIES] +MODELNET_TEST_DATASETS = [ + f"modelnet.{category}.test" for category in MODELNET_TEST_CATEGORIES +] def create_eval_cfg( @@ -75,7 +84,6 @@ def create_eval_cfg( coarse_estimation_type: str, ds_name: str, ) -> Tuple[str, EvalConfig]: - cfg = copy.deepcopy(cfg) cfg.inference.detection_type = detection_type @@ -101,7 +109,6 @@ def create_eval_cfg( def run_full_eval(cfg: FullEvalConfig) -> None: - bop_eval_cfgs = [] init_distributed_mode() @@ -115,17 +122,17 @@ def run_full_eval(cfg: FullEvalConfig) -> None: # Iterate over each dataset for ds_name in cfg.ds_names: - # create the EvalConfig objects that we will call `run_eval` on eval_configs: Dict[str, EvalConfig] = dict() for detection_type, coarse_estimation_type in cfg.detection_coarse_types: - name, cfg_ = create_eval_cfg(cfg, detection_type, coarse_estimation_type, ds_name) + name, cfg_ = create_eval_cfg( + cfg, detection_type, coarse_estimation_type, ds_name + ) eval_configs[name] = cfg_ # For each eval_cfg run the evaluation. 
# Note that the results get saved to disk for save_key, eval_cfg in eval_configs.items(): - # Run the inference if not cfg.skip_inference: eval_out = run_eval(eval_cfg) @@ -152,7 +159,7 @@ def run_full_eval(cfg: FullEvalConfig) -> None: if cfg.run_bop_eval and get_rank() == 0: bop_eval_keys = set(("refiner/final", "depth_refiner")) if cfg.eval_coarse_also: - bop_eval_keys.add('coarse') + bop_eval_keys.add("coarse") # Remove from evaluation predictions that were not produced at inference time bop_eval_keys = bop_eval_keys.intersection(set(eval_out["pred_keys"])) From e5f4c2a0505ec19d7c00e5f5fa48dc7356a878eb Mon Sep 17 00:00:00 2001 From: Guilhem Saurel Date: Fri, 27 Oct 2023 16:07:54 +0200 Subject: [PATCH 9/9] prepare for merge with next-pack --- .../megapose/evaluation/eval_config.py | 9 +- .../megapose/evaluation/prediction_runner.py | 91 ++++++++--------- .../megapose/inference/pose_estimator.py | 98 +++++++++---------- .../scripts/run_full_megapose_eval.py | 45 ++++----- 4 files changed, 111 insertions(+), 132 deletions(-) diff --git a/happypose/pose_estimators/megapose/evaluation/eval_config.py b/happypose/pose_estimators/megapose/evaluation/eval_config.py index 3a81d58e..398e52f3 100644 --- a/happypose/pose_estimators/megapose/evaluation/eval_config.py +++ b/happypose/pose_estimators/megapose/evaluation/eval_config.py @@ -17,7 +17,7 @@ # Standard Library from dataclasses import dataclass -from typing import List, Optional +from typing import Optional # MegaPose from happypose.pose_estimators.megapose.inference.types import InferenceConfig @@ -48,7 +48,8 @@ class EvalConfig: 2. If `run_id` is None, then use `config_id`, `run_comment`and `run_postfix` to create a `run_id` - In 2., the parameters of the config are set-up using the function `update_cfg_with_config_id`. + In 2., the parameters of the config are set-up using the function + `update_cfg_with_config_id`. 
""" # Network @@ -85,8 +86,8 @@ class EvalConfig: @dataclass class FullEvalConfig(EvalConfig): # Full eval - detection_coarse_types: Optional[List] = None - ds_names: Optional[List[str]] = None + detection_coarse_types: Optional[list] = None + ds_names: Optional[list[str]] = None run_bop_eval: bool = True eval_coarse_also: bool = False convert_only: bool = False diff --git a/happypose/pose_estimators/megapose/evaluation/prediction_runner.py b/happypose/pose_estimators/megapose/evaluation/prediction_runner.py index cd0a9bf3..eba04e3a 100644 --- a/happypose/pose_estimators/megapose/evaluation/prediction_runner.py +++ b/happypose/pose_estimators/megapose/evaluation/prediction_runner.py @@ -16,51 +16,35 @@ # Standard Library -import time from collections import defaultdict -from typing import Dict, Optional -from pathlib import Path - - -# Third Party -import numpy as np -import torch -from torch.utils.data import DataLoader -from tqdm import tqdm +from typing import Optional # MegaPose import happypose.pose_estimators.megapose import happypose.toolbox.utils.tensor_collection as tc -from happypose.pose_estimators.megapose.inference.pose_estimator import ( - PoseEstimator, + +# Third Party +import torch +from happypose.pose_estimators.megapose.evaluation.bop import ( + get_sam_detections, + load_sam_predictions, ) +from happypose.pose_estimators.megapose.inference.pose_estimator import PoseEstimator from happypose.pose_estimators.megapose.inference.types import ( DetectionsType, InferenceConfig, ObservationTensor, PoseEstimatesType, ) -from happypose.pose_estimators.megapose.config import BOP_DS_DIR -from happypose.pose_estimators.megapose.evaluation.bop import ( - get_sam_detections, - load_sam_predictions, -) - from happypose.pose_estimators.megapose.training.utils import CudaTimer from happypose.toolbox.datasets.samplers import DistributedSceneSampler -from happypose.toolbox.datasets.scene_dataset import ( - SceneDataset, - SceneObservation, - ObjectData, -) -from happypose.toolbox.utils.distributed import get_rank, get_tmp_dir, get_world_size -from happypose.toolbox.utils.logging import get_logger - +from happypose.toolbox.datasets.scene_dataset import SceneDataset, SceneObservation # Temporary -from happypose.toolbox.inference.utils import make_detections_from_object_data -import pandas as pd -import json +from happypose.toolbox.utils.distributed import get_rank, get_tmp_dir, get_world_size +from happypose.toolbox.utils.logging import get_logger +from torch.utils.data import DataLoader +from tqdm import tqdm logger = get_logger(__name__) @@ -81,7 +65,9 @@ def __init__( self.tmp_dir = get_tmp_dir() sampler = DistributedSceneSampler( - scene_ds, num_replicas=self.world_size, rank=self.rank + scene_ds, + num_replicas=self.world_size, + rank=self.rank, ) self.sampler = sampler self.scene_ds = scene_ds @@ -104,12 +90,13 @@ def run_inference_pipeline( gt_detections: DetectionsType, sam_detections: DetectionsType, initial_estimates: Optional[PoseEstimatesType] = None, - ) -> Dict[str, PoseEstimatesType]: + ) -> dict[str, PoseEstimatesType]: """Runs inference pipeline, extracts the results. Returns: A dict with keys - 'final': final preds - - 'refiner/final': preds at final refiner iteration (before depth refinement) + - 'refiner/final': preds at final refiner iteration (before depth + refinement) - 'depth_refinement': preds after depth refinement. 
@@ -128,16 +115,15 @@ def run_inference_pipeline( run_detector = True else: - raise ValueError( - f"Unknown detection type {self.inference_cfg.detection_type}" - ) + msg = f"Unknown detection type {self.inference_cfg.detection_type}" + raise ValueError(msg) coarse_estimates = None if self.inference_cfg.coarse_estimation_type == "external": # TODO (ylabbe): This is hacky, clean this for modelnet eval. coarse_estimates = initial_estimates coarse_estimates = happypose.toolbox.inference.utils.add_instance_id( - coarse_estimates + coarse_estimates, ) coarse_estimates.infos["instance_id"] = 0 run_detector = False @@ -159,13 +145,13 @@ def run_inference_pipeline( # - 'refiner/iteration=5` # - `depth_refiner` # Note: Since we support multi-hypotheses we need to potentially - # go back and extract out the 'refiner/iteration=1`, `refiner/iteration=5` things for the ones that were actually the highest scoring at the end. + # go back and extract out the 'refiner/iteration=1`, `refiner/iteration=5` + # things for the ones that were actually the highest scoring at the end. + ref_str = f"refiner/iteration={self.inference_cfg.n_refiner_iterations}" all_preds = { "final": preds, - f"refiner/iteration={self.inference_cfg.n_refiner_iterations}": extra_data[ - "refiner" - ]["preds"], + ref_str: extra_data["refiner"]["preds"], "refiner/final": extra_data["refiner"]["preds"], "coarse": extra_data["coarse"]["preds"], "coarse_filter": extra_data["coarse_filter"]["preds"], @@ -183,17 +169,18 @@ def run_inference_pipeline( all_preds["depth_refiner"] = extra_data["depth_refiner"]["preds"] all_preds_data["depth_refiner"] = extra_data["depth_refiner"]["data"] - for k, v in all_preds.items(): + for _k, v in all_preds.items(): if "mask" in v.tensors: - breakpoint() + # breakpoint() v.delete_tensor("mask") return all_preds, all_preds_data def get_predictions( - self, pose_estimator: PoseEstimator - ) -> Dict[str, PoseEstimatesType]: - """Runs predictions + self, + pose_estimator: PoseEstimator, + ) -> dict[str, PoseEstimatesType]: + """Runs predictions. 
Returns: A dict with keys - 'refiner/iteration=1` @@ -204,7 +191,6 @@ def get_predictions( """ - predictions_list = defaultdict(list) ###### @@ -214,7 +200,8 @@ def get_predictions( # Temporary solution if self.inference_cfg.detection_type == "sam": df_all_dets, df_targets = load_sam_predictions( - self.scene_ds.ds_dir.name, self.scene_ds.ds_dir + self.scene_ds.ds_dir.name, + self.scene_ds.ds_dir, ) for n, data in enumerate(tqdm(self.dataloader)): @@ -274,19 +261,20 @@ def get_predictions( cuda_timer.end() duration = cuda_timer.elapsed() - total_duration = duration + dt_det + duration + dt_det # Add metadata to the predictions for later evaluation for pred_name, pred in all_preds.items(): pred.infos["time"] = dt_det + compute_pose_est_total_time( - all_preds_data, pred_name + all_preds_data, + pred_name, ) pred.infos["scene_id"] = scene_id pred.infos["view_id"] = view_id predictions_list[pred_name].append(pred) # Concatenate the lists of PandasTensorCollections - predictions = dict() + predictions = {} for k, v in predictions_list.items(): predictions[k] = tc.concatenate(v) @@ -315,4 +303,5 @@ def compute_pose_est_total_time(all_preds_data: dict, pred_name: str): else dt_coarse_refiner ) else: - raise ValueError(f"{pred_name} extra data not in {all_preds_data.keys()}") + msg = f"{pred_name} extra data not in {all_preds_data.keys()}" + raise ValueError(msg) diff --git a/happypose/pose_estimators/megapose/inference/pose_estimator.py b/happypose/pose_estimators/megapose/inference/pose_estimator.py index 5018901a..853ab633 100644 --- a/happypose/pose_estimators/megapose/inference/pose_estimator.py +++ b/happypose/pose_estimators/megapose/inference/pose_estimator.py @@ -1,5 +1,4 @@ -""" -Copyright (c) 2022 Inria & NVIDIA CORPORATION & AFFILIATES. All rights reserved. +"""Copyright (c) 2022 Inria & NVIDIA CORPORATION & AFFILIATES. All rights reserved. Licensed under the Apache License, Version 2.0 (the "License"); you may not use this file except in compliance with the License. 
@@ -20,8 +19,7 @@ # Standard Library import time from collections import defaultdict -from dataclasses import dataclass -from typing import Any, Optional, Tuple +from typing import Any # Third Party import numpy as np @@ -30,16 +28,10 @@ from torch.utils.data import DataLoader, TensorDataset # MegaPose -import happypose.pose_estimators.megapose as megapose import happypose.toolbox.inference.utils import happypose.toolbox.utils.tensor_collection as tc -from happypose.pose_estimators.megapose.inference.depth_refiner import ( - DepthRefiner, -) -from happypose.pose_estimators.megapose.training.utils import ( - CudaTimer, - SimpleTimer, -) +from happypose.pose_estimators.megapose.inference.depth_refiner import DepthRefiner +from happypose.pose_estimators.megapose.training.utils import CudaTimer, SimpleTimer from happypose.toolbox.inference.pose_estimator import PoseEstimationModule from happypose.toolbox.inference.types import ( DetectionsType, @@ -65,10 +57,10 @@ class PoseEstimator(PoseEstimationModule): def __init__( self, - refiner_model: Optional[torch.nn.Module] = None, - coarse_model: Optional[torch.nn.Module] = None, - detector_model: Optional[torch.nn.Module] = None, - depth_refiner: Optional[DepthRefiner] = None, + refiner_model: torch.nn.Module | None = None, + coarse_model: torch.nn.Module | None = None, + detector_model: torch.nn.Module | None = None, + depth_refiner: DepthRefiner | None = None, bsz_objects: int = 8, bsz_images: int = 256, SO3_grid_size: int = 576, @@ -93,9 +85,8 @@ def __init__( self.cfg = self.coarse_model.cfg self.mesh_db = self.coarse_model.mesh_db else: - raise ValueError( - "At least one of refiner_model or " " coarse_model must be specified." - ) + msg = "At least one of refiner_model or coarse_model must be specified." + raise ValueError(msg) self.eval() @@ -103,7 +94,7 @@ def __init__( self.keep_all_coarse_outputs = False self.refiner_outputs = None self.coarse_outputs = None - self.debug_dict: dict = dict() + self.debug_dict: dict = {} def load_SO3_grid(self, grid_size: int) -> None: """Loads the SO(3) grid.""" @@ -119,14 +110,14 @@ def forward_refiner( keep_all_outputs: bool = False, cuda_timer: bool = False, **refiner_kwargs, - ) -> Tuple[dict, dict]: + ) -> tuple[dict, dict]: """Runs the refiner model for the specified number of iterations. - Will actually use the batched_model_predictions to stay within batch size limit. - Returns: + Returns + ------- (preds, extra_data) preds: @@ -139,7 +130,6 @@ def forward_refiner( A dict containing additional information such as timing """ - timer = Timer() timer.start() @@ -224,7 +214,7 @@ def forward_refiner( logger.debug( f"Pose prediction on {B} poses (n_iterations={n_iterations}):" - f" {timer.stop()}" + f" {timer.stop()}", ) return preds, extra_data @@ -236,15 +226,13 @@ def forward_scoring_model( data_TCO: PoseEstimatesType, cuda_timer: bool = False, return_debug_data: bool = False, - ) -> Tuple[PoseEstimatesType, dict]: + ) -> tuple[PoseEstimatesType, dict]: """Score the estimates using the coarse model. - Adds the 'pose_score' field to data_TCO.infos Modifies PandasTensorCollection in-place. 
""" - start_time = time.time() assert self.coarse_model is not None @@ -295,7 +283,7 @@ def forward_scoring_model( images_crop_list.append(out_["images_crop"]) renders_list.append(out_["renders"]) - debug_data = dict() + debug_data = {} # Combine together the data from the different batches logits = torch.cat(logits_list) @@ -304,8 +292,8 @@ def forward_scoring_model( images_crop: torch.tensor = torch.cat(images_crop_list) renders: torch.tensor = torch.cat(renders_list) - H = images_crop.shape[2] - W = images_crop.shape[3] + images_crop.shape[2] + images_crop.shape[3] debug_data = { "images_crop": images_crop, @@ -317,7 +305,10 @@ def forward_scoring_model( elapsed = time.time() - start_time - timing_str = f"time: {elapsed:.2f}, model_time: {model_time:.2f}, render_time: {render_time:.2f}" + timing_str = ( + f"time: {elapsed:.2f}, model_time: {model_time:.2f}, " + f"render_time: {render_time:.2f}" + ) extra_data = { "render_time": render_time, @@ -340,13 +331,12 @@ def forward_coarse_model( detections: DetectionsType, cuda_timer: bool = False, return_debug_data: bool = False, - ) -> Tuple[PoseEstimatesType, dict]: + ) -> tuple[PoseEstimatesType, dict]: """Generates pose hypotheses and scores them with the coarse model. - Generates coarse hypotheses using the SO(3) grid. - Scores them using the coarse model. """ - start_time = time.time() happypose.toolbox.inference.types.assert_detections_valid(detections) @@ -455,7 +445,7 @@ def forward_coarse_model( TCO = torch.cat(TCO_init) TCO_reshape = TCO.reshape([B, M, 4, 4]) - debug_data = dict() + debug_data = {} if return_debug_data: images_crop = torch.cat(images_crop_list) @@ -474,7 +464,10 @@ def forward_coarse_model( elapsed = time.time() - start_time - timing_str = f"time: {elapsed:.2f}, model_time: {model_time:.2f}, render_time: {render_time:.2f}" + timing_str = ( + f"time: {elapsed:.2f}, model_time: {model_time:.2f}, " + f"render_time: {render_time:.2f}" + ) extra_data = { "render_time": render_time, @@ -499,21 +492,22 @@ def forward_detection_model( **kwargs: Any, ) -> DetectionsType: """Runs the detector.""" - return self.detector_model.get_detections(observation, *args, **kwargs) def run_depth_refiner( self, observation: ObservationTensor, predictions: PoseEstimatesType, - ) -> Tuple[PoseEstimatesType, dict]: + ) -> tuple[PoseEstimatesType, dict]: """Runs the depth refiner.""" assert self.depth_refiner is not None, "You must specify a depth refiner" depth = observation.depth K = observation.K refined_preds, extra_data = self.depth_refiner.refine_poses( - predictions, depth=depth, K=K + predictions, + depth=depth, + K=K, ) return refined_preds, extra_data @@ -522,18 +516,18 @@ def run_depth_refiner( def run_inference_pipeline( self, observation: ObservationTensor, - detections: Optional[DetectionsType] = None, - run_detector: Optional[bool] = None, + detections: DetectionsType | None = None, + run_detector: bool | None = None, n_refiner_iterations: int = 5, n_pose_hypotheses: int = 1, keep_all_refiner_outputs: bool = False, - detection_filter_kwargs: Optional[dict] = None, + detection_filter_kwargs: dict | None = None, run_depth_refiner: bool = False, - bsz_images: Optional[int] = None, - bsz_objects: Optional[int] = None, + bsz_images: int | None = None, + bsz_objects: int | None = None, cuda_timer: bool = False, - coarse_estimates: Optional[PoseEstimatesType] = None, - ) -> Tuple[PoseEstimatesType, dict]: + coarse_estimates: PoseEstimatesType | None = None, + ) -> tuple[PoseEstimatesType, dict]: """Runs the entire pose estimation 
pipeline. Performs the following steps @@ -545,13 +539,13 @@ def run_inference_pipeline( 5. Score refined hypotheses 6. Select highest scoring refined hypotheses. - Returns: + Returns + ------- data_TCO_final: final predictions data: Dict containing additional data about the different steps in the pipeline. """ - timing_str = "" timer = SimpleTimer() timer.start() @@ -581,7 +575,8 @@ def run_inference_pipeline( # Filter detections if detection_filter_kwargs is not None: detections = happypose.toolbox.inference.utils.filter_detections( - detections, **detection_filter_kwargs + detections, + **detection_filter_kwargs, ) # Run the coarse estimator using detections @@ -635,7 +630,8 @@ def run_inference_pipeline( if run_depth_refiner: depth_refiner_start = time.time() data_TCO_depth_refiner, _ = self.run_depth_refiner( - observation, data_TCO_final_scored + observation, + data_TCO_final_scored, ) data_TCO_final = data_TCO_depth_refiner depth_refiner_time = time.time() - depth_refiner_start @@ -647,7 +643,7 @@ def run_inference_pipeline( timer.stop() timing_str = f"total={timer.elapsed():.2f}, {timing_str}" - extra_data: dict = dict() + extra_data: dict = {} extra_data["coarse"] = {"preds": data_TCO_coarse, "data": coarse_extra_data} extra_data["coarse_filter"] = {"preds": data_TCO_filtered} extra_data["refiner_all_hypotheses"] = { diff --git a/happypose/pose_estimators/megapose/scripts/run_full_megapose_eval.py b/happypose/pose_estimators/megapose/scripts/run_full_megapose_eval.py index 42f1c7d5..e2ce6cb0 100644 --- a/happypose/pose_estimators/megapose/scripts/run_full_megapose_eval.py +++ b/happypose/pose_estimators/megapose/scripts/run_full_megapose_eval.py @@ -1,5 +1,4 @@ -""" -Copyright (c) 2022 Inria & NVIDIA CORPORATION & AFFILIATES. All rights reserved. +"""Copyright (c) 2022 Inria & NVIDIA CORPORATION & AFFILIATES. All rights reserved. Licensed under the Apache License, Version 2.0 (the "License"); you may not use this file except in compliance with the License. 
@@ -19,39 +18,26 @@ import copy import os from pathlib import Path -from typing import Dict, Optional, Tuple - -# Third Party -from omegaconf import OmegaConf # MegaPose -from happypose.pose_estimators.megapose.bop_config import ( - PBR_COARSE, - PBR_DETECTORS, - PBR_REFINER, - SYNT_REAL_COARSE, - SYNT_REAL_DETECTORS, - SYNT_REAL_REFINER, -) +from happypose.pose_estimators.megapose.bop_config import PBR_DETECTORS from happypose.pose_estimators.megapose.config import ( DEBUG_RESULTS_DIR, - EXP_DIR, MODELNET_TEST_CATEGORIES, RESULTS_DIR, ) +from happypose.pose_estimators.megapose.evaluation.bop import run_evaluation from happypose.pose_estimators.megapose.evaluation.eval_config import ( BOPEvalConfig, EvalConfig, FullEvalConfig, HardwareConfig, ) - from happypose.pose_estimators.megapose.evaluation.evaluation import ( - get_save_dir, generate_save_key, + get_save_dir, run_eval, ) -from happypose.pose_estimators.megapose.evaluation.bop import run_evaluation from happypose.toolbox.utils.distributed import ( get_rank, get_world_size, @@ -59,6 +45,9 @@ ) from happypose.toolbox.utils.logging import get_logger, set_logging_level +# Third Party +from omegaconf import OmegaConf + logger = get_logger(__name__) @@ -83,7 +72,7 @@ def create_eval_cfg( detection_type: str, coarse_estimation_type: str, ds_name: str, -) -> Tuple[str, EvalConfig]: +) -> tuple[str, EvalConfig]: cfg = copy.deepcopy(cfg) cfg.inference.detection_type = detection_type @@ -101,7 +90,8 @@ def create_eval_cfg( elif detection_type == "sam": pass else: - raise ValueError(f"Unknown detector type {cfg.detector_type}") + msg = f"Unknown detector type {cfg.detector_type}" + raise ValueError(msg) name = generate_save_key(detection_type, coarse_estimation_type) @@ -123,16 +113,19 @@ def run_full_eval(cfg: FullEvalConfig) -> None: # Iterate over each dataset for ds_name in cfg.ds_names: # create the EvalConfig objects that we will call `run_eval` on - eval_configs: Dict[str, EvalConfig] = dict() + eval_configs: dict[str, EvalConfig] = {} for detection_type, coarse_estimation_type in cfg.detection_coarse_types: name, cfg_ = create_eval_cfg( - cfg, detection_type, coarse_estimation_type, ds_name + cfg, + detection_type, + coarse_estimation_type, + ds_name, ) eval_configs[name] = cfg_ # For each eval_cfg run the evaluation. # Note that the results get saved to disk - for save_key, eval_cfg in eval_configs.items(): + for _save_key, eval_cfg in eval_configs.items(): # Run the inference if not cfg.skip_inference: eval_out = run_eval(eval_cfg) @@ -152,12 +145,12 @@ def run_full_eval(cfg: FullEvalConfig) -> None: } assert Path( - eval_out["results_path"] + eval_out["results_path"], ).is_file(), f"The file {eval_out['results_path']} doesn't exist" # Run the bop eval for each type of prediction if cfg.run_bop_eval and get_rank() == 0: - bop_eval_keys = set(("refiner/final", "depth_refiner")) + bop_eval_keys = {"refiner/final", "depth_refiner"} if cfg.eval_coarse_also: bop_eval_keys.add("coarse") @@ -165,7 +158,7 @@ def run_full_eval(cfg: FullEvalConfig) -> None: bop_eval_keys = bop_eval_keys.intersection(set(eval_out["pred_keys"])) for method in bop_eval_keys: - if not "bop19" in ds_name: + if "bop19" not in ds_name: continue bop_eval_cfg = BOPEvalConfig(