Skip to content

Commit

Permalink
first working version of pose training
Browse files Browse the repository at this point in the history
  • Loading branch information
centos Cloud User committed Jan 29, 2024
1 parent d42fcca commit 05c0e2e
Show file tree
Hide file tree
Showing 3 changed files with 21 additions and 23 deletions.
Original file line number Diff line number Diff line change
Expand Up @@ -23,7 +23,7 @@
RESULTS_DIR,
)
from happypose.pose_estimators.cosypose.cosypose.datasets.bop import remap_bop_targets
from happypose.pose_estimators.cosypose.cosypose.datasets.datasets_cfg import (
from happypose.toolbox.datasets.datasets_cfg import (
make_object_dataset,
make_scene_dataset,
)
Expand All @@ -50,9 +50,8 @@
from happypose.pose_estimators.cosypose.cosypose.integrated.pose_predictor import (
CoarseRefinePosePredictor,
)
from happypose.pose_estimators.cosypose.cosypose.lib3d.rigid_mesh_database import (
MeshDataBase,
)
from happypose.toolbox.lib3d.rigid_mesh_database import MeshDataBase

from happypose.pose_estimators.cosypose.cosypose.rendering.bullet_batch_renderer import ( # noqa: E501
BulletBatchRenderer,
)
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -71,7 +71,7 @@ def make_cfg(args):
# Training
cfg.batch_size = 16
cfg.epoch_size = 115200
cfg.n_epochs = 700
cfg.n_epochs = 3
cfg.n_dataloader_workers = N_WORKERS

# Method
Expand Down
35 changes: 17 additions & 18 deletions happypose/pose_estimators/cosypose/cosypose/training/train_pose.py
Original file line number Diff line number Diff line change
Expand Up @@ -28,6 +28,10 @@
from happypose.pose_estimators.cosypose.cosypose.integrated.pose_estimator import (
PoseEstimator,
)
from happypose.pose_estimators.megapose.evaluation.meters.modelnet_meters import (
ModelNetErrorMeter,
)

from happypose.pose_estimators.cosypose.cosypose.scripts.run_cosypose_eval import (
get_pose_meters,
load_pix2pose_results,
Expand Down Expand Up @@ -89,7 +93,7 @@ def save_checkpoint(model):
logger.info(test_dict)


def make_eval_bundle(args, model_training):
def make_eval_bundle(args, model_training, mesh_db):
eval_bundle = {}
model_training.cfg = args

Expand Down Expand Up @@ -217,10 +221,9 @@ def load_model(run_id):
)

# Evaluation
meters = get_pose_meters(scene_ds, ds_name)
meters = {k.split("_")[0]: v for k, v in meters.items()}
list(iter(pred_runner.sampler))
print(scene_ds.frame_index)
meters = {
"modelnet": ModelNetErrorMeter(mesh_db, sample_n_points=None),
}
# scene_ds_ids = np.concatenate(
# scene_ds.frame_index.loc[mv_group_ids, "scene_ds_ids"].values
# )
Expand Down Expand Up @@ -335,16 +338,12 @@ def make_datasets(dataset_names):
n_workers=args.n_rendering_workers,
preload_cache=False,
)
mesh_db = (
MeshDataBase.from_object_ds(object_ds)
.batched(n_sym=args.n_symmetries_batch)
.cuda()
.float()
)
mesh_db = MeshDataBase.from_object_ds(object_ds)
mesh_db_batched = mesh_db.batched(n_sym=args.n_symmetries_batch).cuda().float()

model = create_model_pose(cfg=args, renderer=renderer, mesh_db=mesh_db).cuda()
model = create_model_pose(cfg=args, renderer=renderer, mesh_db=mesh_db_batched).cuda()

eval_bundle = make_eval_bundle(args, model)
eval_bundle = make_eval_bundle(args, model, mesh_db)

if args.resume_run_id:
resume_dir = EXP_DIR / args.resume_run_id
Expand Down Expand Up @@ -413,7 +412,7 @@ def lambd(batch):
model=model,
cfg=args,
n_iterations=args.n_iterations,
mesh_db=mesh_db,
mesh_db=mesh_db_batched,
input_generator=args.TCO_input_generator,
)

Expand Down Expand Up @@ -476,9 +475,9 @@ def test():
if epoch % args.val_epoch_interval == 0:
validation()

test_dict = None
if epoch % args.test_epoch_interval == 0:
test_dict = test()
#test_dict = None
#if epoch % args.test_epoch_interval == 0:
# test_dict = test()

log_dict = {}
log_dict.update(
Expand Down Expand Up @@ -507,6 +506,6 @@ def test():
model=model,
epoch=epoch,
log_dict=log_dict,
test_dict=test_dict,
test_dict=None,
)
dist.barrier()

0 comments on commit 05c0e2e

Please sign in to comment.