From 05c0e2edc3a99e09707c68ade9657ffcd8c60fa4 Mon Sep 17 00:00:00 2001 From: centos Cloud User Date: Mon, 29 Jan 2024 15:29:31 +0100 Subject: [PATCH] first working version of pose training --- .../cosypose/scripts/run_cosypose_eval.py | 7 ++-- .../cosypose/scripts/run_pose_training.py | 2 +- .../cosypose/cosypose/training/train_pose.py | 35 +++++++++---------- 3 files changed, 21 insertions(+), 23 deletions(-) diff --git a/happypose/pose_estimators/cosypose/cosypose/scripts/run_cosypose_eval.py b/happypose/pose_estimators/cosypose/cosypose/scripts/run_cosypose_eval.py index 848bf112..53b706ef 100644 --- a/happypose/pose_estimators/cosypose/cosypose/scripts/run_cosypose_eval.py +++ b/happypose/pose_estimators/cosypose/cosypose/scripts/run_cosypose_eval.py @@ -23,7 +23,7 @@ RESULTS_DIR, ) from happypose.pose_estimators.cosypose.cosypose.datasets.bop import remap_bop_targets -from happypose.pose_estimators.cosypose.cosypose.datasets.datasets_cfg import ( +from happypose.toolbox.datasets.datasets_cfg import ( make_object_dataset, make_scene_dataset, ) @@ -50,9 +50,8 @@ from happypose.pose_estimators.cosypose.cosypose.integrated.pose_predictor import ( CoarseRefinePosePredictor, ) -from happypose.pose_estimators.cosypose.cosypose.lib3d.rigid_mesh_database import ( - MeshDataBase, -) +from happypose.toolbox.lib3d.rigid_mesh_database import MeshDataBase + from happypose.pose_estimators.cosypose.cosypose.rendering.bullet_batch_renderer import ( # noqa: E501 BulletBatchRenderer, ) diff --git a/happypose/pose_estimators/cosypose/cosypose/scripts/run_pose_training.py b/happypose/pose_estimators/cosypose/cosypose/scripts/run_pose_training.py index b9f52d2a..937ca2e7 100644 --- a/happypose/pose_estimators/cosypose/cosypose/scripts/run_pose_training.py +++ b/happypose/pose_estimators/cosypose/cosypose/scripts/run_pose_training.py @@ -71,7 +71,7 @@ def make_cfg(args): # Training cfg.batch_size = 16 cfg.epoch_size = 115200 - cfg.n_epochs = 700 + cfg.n_epochs = 3 cfg.n_dataloader_workers = N_WORKERS # Method diff --git a/happypose/pose_estimators/cosypose/cosypose/training/train_pose.py b/happypose/pose_estimators/cosypose/cosypose/training/train_pose.py index 9d57094c..fe0e84cb 100644 --- a/happypose/pose_estimators/cosypose/cosypose/training/train_pose.py +++ b/happypose/pose_estimators/cosypose/cosypose/training/train_pose.py @@ -28,6 +28,10 @@ from happypose.pose_estimators.cosypose.cosypose.integrated.pose_estimator import ( PoseEstimator, ) +from happypose.pose_estimators.megapose.evaluation.meters.modelnet_meters import ( + ModelNetErrorMeter, +) + from happypose.pose_estimators.cosypose.cosypose.scripts.run_cosypose_eval import ( get_pose_meters, load_pix2pose_results, @@ -89,7 +93,7 @@ def save_checkpoint(model): logger.info(test_dict) -def make_eval_bundle(args, model_training): +def make_eval_bundle(args, model_training, mesh_db): eval_bundle = {} model_training.cfg = args @@ -217,10 +221,9 @@ def load_model(run_id): ) # Evaluation - meters = get_pose_meters(scene_ds, ds_name) - meters = {k.split("_")[0]: v for k, v in meters.items()} - list(iter(pred_runner.sampler)) - print(scene_ds.frame_index) + meters = { + "modelnet": ModelNetErrorMeter(mesh_db, sample_n_points=None), + } # scene_ds_ids = np.concatenate( # scene_ds.frame_index.loc[mv_group_ids, "scene_ds_ids"].values # ) @@ -335,16 +338,12 @@ def make_datasets(dataset_names): n_workers=args.n_rendering_workers, preload_cache=False, ) - mesh_db = ( - MeshDataBase.from_object_ds(object_ds) - .batched(n_sym=args.n_symmetries_batch) - .cuda() - .float() - ) + mesh_db = MeshDataBase.from_object_ds(object_ds) + mesh_db_batched = mesh_db.batched(n_sym=args.n_symmetries_batch).cuda().float() - model = create_model_pose(cfg=args, renderer=renderer, mesh_db=mesh_db).cuda() + model = create_model_pose(cfg=args, renderer=renderer, mesh_db=mesh_db_batched).cuda() - eval_bundle = make_eval_bundle(args, model) + eval_bundle = make_eval_bundle(args, model, mesh_db) if args.resume_run_id: resume_dir = EXP_DIR / args.resume_run_id @@ -413,7 +412,7 @@ def lambd(batch): model=model, cfg=args, n_iterations=args.n_iterations, - mesh_db=mesh_db, + mesh_db=mesh_db_batched, input_generator=args.TCO_input_generator, ) @@ -476,9 +475,9 @@ def test(): if epoch % args.val_epoch_interval == 0: validation() - test_dict = None - if epoch % args.test_epoch_interval == 0: - test_dict = test() + #test_dict = None + #if epoch % args.test_epoch_interval == 0: + # test_dict = test() log_dict = {} log_dict.update( @@ -507,6 +506,6 @@ def test(): model=model, epoch=epoch, log_dict=log_dict, - test_dict=test_dict, + test_dict=None, ) dist.barrier()