Skip to content

Commit

Permalink
Refactor filter_pose_estimates
Browse files Browse the repository at this point in the history
  • Loading branch information
MedericFourmy committed Oct 23, 2023
1 parent fb18b14 commit cec03a1
Show file tree
Hide file tree
Showing 4 changed files with 55 additions and 69 deletions.
34 changes: 2 additions & 32 deletions happypose/pose_estimators/megapose/evaluation/bop.py
Original file line number Diff line number Diff line change
Expand Up @@ -36,7 +36,7 @@
from happypose.pose_estimators.megapose.evaluation.eval_config import BOPEvalConfig
from happypose.toolbox.datasets.scene_dataset import ObjectData
from happypose.toolbox.inference.utils import make_detections_from_object_data
from happypose.toolbox.utils.tensor_collection import PandasTensorCollection
from happypose.toolbox.utils.tensor_collection import PandasTensorCollection, filter_top_pose_estimates

device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')

Expand Down Expand Up @@ -142,39 +142,9 @@ def convert_results_to_coco(results_path, out_json_path, detection_method):
return




def filter_pose_estimates(
data_TCO: PandasTensorCollection,
top_K: int,
group_cols: list[str],
filter_field: str,
ascending: bool = False
) -> PandasTensorCollection:
"""Filter the pose estimates by retaining only the top-K coarse model scores.
Retain only the top_K estimates corresponding to each hypothesis_id
Args:
top_K: how many estimates to retain
filter_field: The field to filter estimates by
"""

# TODO: refactor with definition in pose_estimator.py

df = data_TCO.infos

# Logic from https://stackoverflow.com/a/40629420
df = df.sort_values(filter_field, ascending=ascending).groupby(group_cols).head(top_K)

data_TCO_filtered = data_TCO[df.index.tolist()]

return data_TCO_filtered


def get_best_coarse_predictions(coarse_preds: PandasTensorCollection):
group_cols = ["scene_id", "view_id", "label", "instance_id"]
coarse_preds = filter_pose_estimates(coarse_preds, top_K=1, group_cols=group_cols, filter_field='coarse_score', ascending=False)
coarse_preds = filter_top_pose_estimates(coarse_preds, top_K=1, group_cols=group_cols, filter_field='coarse_score', ascending=False)
coarse_preds.infos = coarse_preds.infos.rename(columns={'coarse_score': 'pose_score'})
return coarse_preds

Expand Down
43 changes: 12 additions & 31 deletions happypose/pose_estimators/megapose/inference/pose_estimator.py
Original file line number Diff line number Diff line change
Expand Up @@ -49,7 +49,7 @@
from happypose.toolbox.lib3d.cosypose_ops import TCO_init_from_boxes_autodepth_with_R
from happypose.toolbox.utils import transform_utils
from happypose.toolbox.utils.logging import get_logger
from happypose.toolbox.utils.tensor_collection import PandasTensorCollection
from happypose.toolbox.utils.tensor_collection import PandasTensorCollection, filter_top_pose_estimates
from happypose.toolbox.utils.timer import Timer

logger = get_logger(__name__)
Expand Down Expand Up @@ -591,9 +591,13 @@ def run_inference_pipeline(
timing_str += f"coarse={coarse_extra_data['time']:.2f}, "

# Extract top-K coarse hypotheses
data_TCO_filtered = self.filter_pose_estimates(
data_TCO_coarse, top_K=n_pose_hypotheses, filter_field="coarse_logit"
data_TCO_filtered = filter_top_pose_estimates(
data_TCO_coarse,
top_K=n_pose_hypotheses,
group_cols=["batch_im_id", "label", "instance_id"],
filter_field="coarse_logit"
)

else:
data_TCO_coarse = coarse_estimates
coarse_extra_data = None
Expand All @@ -618,8 +622,11 @@ def run_inference_pipeline(
timing_str += f"scoring={scoring_extra_data['time']:.2f}, "

# Extract the highest scoring pose estimate for each instance_id
data_TCO_final_scored = self.filter_pose_estimates(
data_TCO_scored, top_K=1, filter_field="pose_logit"
data_TCO_final_scored = self.filter_top_pose_estimates(
data_TCO_scored,
top_K=1,
group_cols=["batch_im_id", "label", "instance_id"],
filter_field="pose_logit"
)

# Optionally run ICP or TEASER++
Expand Down Expand Up @@ -649,29 +656,3 @@ def run_inference_pipeline(
extra_data["depth_refiner"] = {"preds": data_TCO_depth_refiner}

return data_TCO_final, extra_data

def filter_pose_estimates(
self,
data_TCO: PoseEstimatesType,
top_K: int,
filter_field: str,
ascending: bool = False,
) -> PoseEstimatesType:
"""Filter the pose estimates by retaining only the top-K coarse model scores.
Retain only the top_K estimates corresponding to each hypothesis_id
Args:
top_K: how many estimates to retain
filter_field: The field to filter estimates by
"""

df = data_TCO.infos

group_cols = ["batch_im_id", "label", "instance_id"]
# Logic from https://stackoverflow.com/a/40629420
df = df.sort_values(filter_field, ascending=ascending).groupby(group_cols).head(top_K)

data_TCO_filtered = data_TCO[df.index.tolist()]

return data_TCO_filtered
29 changes: 29 additions & 0 deletions happypose/toolbox/utils/tensor_collection.py
Original file line number Diff line number Diff line change
Expand Up @@ -194,3 +194,32 @@ def __setstate__(self, state):
self.__init__(state["infos"], **state["tensors"])
self.meta = state["meta"]
return



def filter_top_pose_estimates(
data_TCO: PandasTensorCollection,
top_K: int,
group_cols: list[str],
filter_field: str,
ascending: bool = False
) -> PandasTensorCollection:
"""Filter the pose estimates by retaining only the top-K coarse model scores.
Retain only the top_K estimates corresponding to each hypothesis_id
Args:
top_K: how many estimates to retain
group_cols: group of columns among which sorting should be done
filter_field: the field to filter estimates by
ascending: should filter_field
"""

df = data_TCO.infos

# Logic from https://stackoverflow.com/a/40629420
df = df.sort_values(filter_field, ascending=ascending).groupby(group_cols).head(top_K)

data_TCO_filtered = data_TCO[df.index.tolist()]

return data_TCO_filtered
18 changes: 12 additions & 6 deletions notebooks/megapose/megapose_estimator_visualization.ipynb
Original file line number Diff line number Diff line change
Expand Up @@ -243,6 +243,8 @@
"metadata": {},
"outputs": [],
"source": [
"from happypose.toolbox.utils.tensor_collection import filter_pose_estimates\n",
"\n",
"# Options for inference\n",
"use_gt_detections = True # Note, if you aren't using gt_detections then this should be false\n",
"n_refiner_iterations = 5\n",
Expand Down Expand Up @@ -294,10 +296,11 @@
" f\"model_time={extra_data['model_time']:.2f}, render_time={extra_data['render_time']:.2f}\")\n",
" \n",
" # Extract top-K coarse hypotheses\n",
" data_TCO_filtered = pose_estimator.filter_pose_estimates(data_TCO_coarse, \n",
" top_K=n_pose_hypotheses, \n",
" filter_field='coarse_logit')\n",
" \n",
" data_TCO_filtered = filter_pose_estimates(data_TCO_coarse,\n",
" top_K=n_pose_hypotheses, \n",
" group_cols=[\"batch_im_id\", \"label\", \"instance_id\"], \n",
" filter_field='coarse_logit')\n",
"\n",
" # Refine the top_K coarse hypotheses\n",
" preds, extra_data = pose_estimator.forward_refiner(observation_tensor, data_TCO_filtered, \n",
" n_iterations=n_refiner_iterations, keep_all_outputs=True)\n",
Expand All @@ -311,7 +314,10 @@
" data_TCO_scored, extra_data = pose_estimator.forward_scoring_model(observation_tensor, data_TCO_refined)\n",
"\n",
" # Extract the highest scoring pose estimate for each instance_id\n",
" data_TCO_final = pose_estimator.filter_pose_estimates(data_TCO_scored, top_K=1, filter_field='pose_logit')\n",
" data_TCO_final = filter_pose_estimates(data_TCO_scored, \n",
" top_K=1, \n",
" group_cols=[\"batch_im_id\", \"label\", \"instance_id\"], \n",
" filter_field='pose_logit')\n",
" \n",
" \n",
" if run_depth_refiner:\n",
Expand Down Expand Up @@ -969,4 +975,4 @@
},
"nbformat": 4,
"nbformat_minor": 5
}
}

0 comments on commit cec03a1

Please sign in to comment.