From 08e85ca041364602e7e1ea2b2b6a3a50a9909a8e Mon Sep 17 00:00:00 2001
From: Elliot Maître
Date: Mon, 11 Sep 2023 16:33:23 +0200
Subject: [PATCH] training pose cosypose

---
 environment.yml                                 |  1 +
 .../cosypose/evaluation/runner_utils.py         |  3 +-
 .../cosypose/cosypose/models/pose.py            |  1 +
 .../cosypose/scripts/run_pose_training.py       |  6 +-
 .../cosypose/cosypose/training/train_pose.py    | 88 +++++++++++--------
 .../megapose/evaluation/prediction_runner.py    | 10 +--
 .../src/megapose/models/pose_rigid.py           |  4 +
 7 files changed, 68 insertions(+), 45 deletions(-)

diff --git a/environment.yml b/environment.yml
index dd8fec7f..b0256281 100644
--- a/environment.yml
+++ b/environment.yml
@@ -13,6 +13,7 @@ dependencies:
   - pytorch==1.11.0
   - torchvision==0.12.0
   - cudatoolkit==11.3.1
+  - pytorch3d
   - ipython
   - ipykernel
   - jupyterlab
diff --git a/happypose/pose_estimators/cosypose/cosypose/evaluation/runner_utils.py b/happypose/pose_estimators/cosypose/cosypose/evaluation/runner_utils.py
index 8ecebce9..5a128b25 100644
--- a/happypose/pose_estimators/cosypose/cosypose/evaluation/runner_utils.py
+++ b/happypose/pose_estimators/cosypose/cosypose/evaluation/runner_utils.py
@@ -11,8 +11,7 @@ def run_pred_eval(pred_runner, pred_kwargs, eval_runner, eval_preds=None):

     all_predictions = dict()
     for pred_prefix, pred_kwargs_n in pred_kwargs.items():
-        print("Prediction :", pred_prefix)
-        preds = pred_runner.get_predictions(**pred_kwargs_n)
+        preds = pred_runner.get_predictions(pred_kwargs_n['pose_predictor'])
         for preds_name, preds_n in preds.items():
             all_predictions[f'{pred_prefix}/{preds_name}'] = preds_n

diff --git a/happypose/pose_estimators/cosypose/cosypose/models/pose.py b/happypose/pose_estimators/cosypose/cosypose/models/pose.py
index 37df4a73..c15ff40b 100644
--- a/happypose/pose_estimators/cosypose/cosypose/models/pose.py
+++ b/happypose/pose_estimators/cosypose/cosypose/models/pose.py
@@ -134,6 +134,7 @@ def forward(self, images, K, labels, TCO, n_iterations=1):
                 K_crop=K_crop,
                 boxes_rend=boxes_rend,
                 boxes_crop=boxes_crop,
+                model_outputs=model_outputs
             )

             TCO_input = TCO_output
diff --git a/happypose/pose_estimators/cosypose/cosypose/scripts/run_pose_training.py b/happypose/pose_estimators/cosypose/cosypose/scripts/run_pose_training.py
index 0d7b7d43..5b2ba6bb 100644
--- a/happypose/pose_estimators/cosypose/cosypose/scripts/run_pose_training.py
+++ b/happypose/pose_estimators/cosypose/cosypose/scripts/run_pose_training.py
@@ -48,7 +48,7 @@ def make_cfg(args):
     cfg.n_pose_dims = 9
     cfg.n_rendering_workers = N_WORKERS
     cfg.refiner_run_id_for_test = None
-    cfg.coarse_run_id_for_test = None
+    cfg.coarse_run_id_for_test = 'coarse-bop-ycbv-pbr--724183'

     # Optimizer
     cfg.lr = 3e-4
@@ -58,9 +58,9 @@ def make_cfg(args):
     cfg.clip_grad_norm = 0.5

     # Training
-    cfg.batch_size = 32
+    cfg.batch_size = 16
     cfg.epoch_size = 115200
-    cfg.n_epochs = 700
+    cfg.n_epochs = 2
     cfg.n_dataloader_workers = N_WORKERS

     # Method
diff --git a/happypose/pose_estimators/cosypose/cosypose/training/train_pose.py b/happypose/pose_estimators/cosypose/cosypose/training/train_pose.py
index a71dffb0..458d85b2 100644
--- a/happypose/pose_estimators/cosypose/cosypose/training/train_pose.py
+++ b/happypose/pose_estimators/cosypose/cosypose/training/train_pose.py
@@ -18,6 +18,9 @@
 from happypose.toolbox.datasets.datasets_cfg import make_object_dataset, make_scene_dataset
 from happypose.pose_estimators.cosypose.cosypose.datasets.pose_dataset import PoseDataset
 from happypose.pose_estimators.cosypose.cosypose.datasets.samplers import PartialSampler, ListSampler
+from happypose.pose_estimators.megapose.src.megapose.inference.types import (
+    InferenceConfig
+)

 # Evaluation
 from happypose.pose_estimators.cosypose.cosypose.integrated.pose_estimator import (
@@ -48,6 +51,9 @@
 from happypose.pose_estimators.cosypose.cosypose.utils.distributed import get_world_size, get_rank, sync_model, init_distributed_mode, reduce_dict
 from torch.backends import cudnn

+# temp
+from torch.utils.data import Subset
+
 cudnn.benchmark = True
 logger = get_logger(__name__)

@@ -90,7 +96,7 @@ def load_model(run_id):
         if run_id is None:
             return None
         run_dir = EXP_DIR / run_id
-        cfg = yaml.load((run_dir / 'config.yaml').read_text(), Loader=yaml.FullLoader)
+        cfg = yaml.load((run_dir / 'config.yaml').read_text(), Loader=yaml.Loader)
         cfg = check_update_config(cfg)
         model = create_model_pose(cfg, renderer=model_training.renderer,
                                   mesh_db=model_training.mesh_db).cuda().eval()
@@ -99,7 +105,7 @@ def load_model(run_id):
         model.eval()
         model.cfg = cfg
         return model
-    
+
     if args.train_refiner:
         refiner_model = model_training
         coarse_model = load_model(args.coarse_run_id_for_test)
@@ -108,7 +114,7 @@ def load_model(run_id):
         refiner_model = load_model(args.refiner_run_id_for_test)
     else:
         raise ValueError
-    
+
     pose_estimator = PoseEstimator(
         refiner_model=refiner_model,
         coarse_model=coarse_model,
@@ -129,9 +135,19 @@ def load_model(run_id):
     #pred_runner = MultiviewPredictionRunner(scene_ds_pred, batch_size=1,
     #                                        n_workers=args.n_dataloader_workers, cache_data=False)

-    inference = {'detection_type': 'gt', 'coarse_estimation_type': 'S03_grid', 'SO3_grid_size': 576,
-                 'n_refiner_iterations': 5, 'n_pose_hypotheses': 5, 'run_depth_refiner': False,
-                 'depth_refiner': None, 'bsz_objects': 16, 'bsz_images': 288}
+    #inference = {'detection_type': 'gt', 'coarse_estimation_type': 'S03_grid', 'SO3_grid_size': 576,
+    #             'n_refiner_iterations': 5, 'n_pose_hypotheses': 5, 'run_depth_refiner': False,
+    #             'depth_refiner': None, 'bsz_objects': 16, 'bsz_images': 288}
+
+    inference = InferenceConfig(detection_type='gt',
+                                coarse_estimation_type='S03_grid',
+                                SO3_grid_size=576,
+                                n_refiner_iterations=5,
+                                n_pose_hypotheses=5,
+                                run_depth_refiner=False,
+                                depth_refiner=None,
+                                bsz_objects=16,
+                                bsz_images=288)

     pred_runner = PredictionRunner(
         scene_ds=scene_ds,
@@ -255,7 +271,7 @@ def make_datasets(dataset_names):

     scene_ds_train = make_datasets(args.train_ds_names)
     scene_ds_val = make_datasets(args.val_ds_names)
-    
+
     ds_kwargs = dict(
         resize=args.input_resize,
         rgb_augmentation=args.rgb_augmentation,
@@ -266,6 +282,7 @@ def make_datasets(dataset_names):
     ds_train = PoseDataset(scene_ds_train, **ds_kwargs)
     ds_val = PoseDataset(scene_ds_val, **ds_kwargs)

+
     train_sampler = PartialSampler(ds_train, epoch_size=args.epoch_size)
     ds_iter_train = DataLoader(ds_train, sampler=train_sampler, batch_size=args.batch_size,
                                num_workers=args.n_dataloader_workers, collate_fn=ds_train.collate_fn,
@@ -341,30 +358,31 @@ def train_epoch():
         iterator = tqdm(ds_iter_train, ncols=80)
         t = time.time()
         for n, sample in enumerate(iterator):
-            if n > 0:
-                meters_time['data'].add(time.time() - t)
-
-            optimizer.zero_grad()
-
-            t = time.time()
-            loss = h(data=sample, meters=meters_train)
-            meters_time['forward'].add(time.time() - t)
-            iterator.set_postfix(loss=loss.item())
-            meters_train['loss_total'].add(loss.item())
-
-            t = time.time()
-            loss.backward()
-            total_grad_norm = torch.nn.utils.clip_grad_norm_(
-                model.parameters(), max_norm=args.clip_grad_norm, norm_type=2)
-            meters_train['grad_norm'].add(torch.as_tensor(total_grad_norm).item())
-
-            optimizer.step()
-            meters_time['backward'].add(time.time() - t)
-            meters_time['memory'].add(torch.cuda.max_memory_allocated() / 1024. ** 2)
-
-            if epoch < args.n_epochs_warmup:
-                lr_scheduler_warmup.step()
-            t = time.time()
+            if n < 5:
+                if n > 0:
+                    meters_time['data'].add(time.time() - t)
+
+                optimizer.zero_grad()
+
+                t = time.time()
+                loss = h(data=sample, meters=meters_train)
+                meters_time['forward'].add(time.time() - t)
+                iterator.set_postfix(loss=loss.item())
+                meters_train['loss_total'].add(loss.item())
+
+                t = time.time()
+                loss.backward()
+                total_grad_norm = torch.nn.utils.clip_grad_norm_(
+                    model.parameters(), max_norm=args.clip_grad_norm, norm_type=2)
+                meters_train['grad_norm'].add(torch.as_tensor(total_grad_norm).item())
+
+                optimizer.step()
+                meters_time['backward'].add(time.time() - t)
+                meters_time['memory'].add(torch.cuda.max_memory_allocated() / 1024. ** 2)
+
+                if epoch < args.n_epochs_warmup:
+                    lr_scheduler_warmup.step()
+                t = time.time()

         if epoch >= args.n_epochs_warmup:
             lr_scheduler.step()
@@ -374,7 +392,7 @@ def validation():
         for sample in tqdm(ds_iter_val, ncols=80):
             loss = h(data=sample, meters=meters_val)
             meters_val['loss_total'].add(loss.item())
-    
+
     @torch.no_grad()
     def test():
         model.eval()
@@ -383,11 +401,11 @@ def test():
         train_epoch()
         if epoch % args.val_epoch_interval == 0:
             validation()
-        
+
         test_dict = None
         if epoch % args.test_epoch_interval == 0:
             test_dict = test()
-        
+
         log_dict = dict()
         log_dict.update({
             'grad_norm': meters_train['grad_norm'].mean,
@@ -405,7 +423,7 @@ def test():
         for string, meters in zip(('train', 'val'), (meters_train, meters_val)):
             for k in dict(meters).keys():
                 log_dict[f'{string}_{k}'] = meters[k].mean
-        
+        test_dict = {}
         log_dict = reduce_dict(log_dict)
         if get_rank() == 0:
             log(config=args, model=model, epoch=epoch,
diff --git a/happypose/pose_estimators/megapose/src/megapose/evaluation/prediction_runner.py b/happypose/pose_estimators/megapose/src/megapose/evaluation/prediction_runner.py
index 2fc853ef..ed1f6bc1 100644
--- a/happypose/pose_estimators/megapose/src/megapose/evaluation/prediction_runner.py
+++ b/happypose/pose_estimators/megapose/src/megapose/evaluation/prediction_runner.py
@@ -292,12 +292,12 @@ def get_predictions(self, pose_estimator: PoseEstimator) -> Dict[str, PoseEstimatesType]:

             # ############ RUN ONLY BEGINNING OF DATASET #
             # if n > 0:
-            # if n < 298:
+            if n > 10:
             # # if n != 582:
-            print('################')
-            print('Prediction runner SKIP')
-            print('################')
-            continue
+                print('################')
+                print('Prediction runner SKIP')
+                print('################')
+                continue
             # ############ RUN ONLY BEGINNING OF DATASET

             # Dirty but avoids creating error when running with real detector
diff --git a/happypose/pose_estimators/megapose/src/megapose/models/pose_rigid.py b/happypose/pose_estimators/megapose/src/megapose/models/pose_rigid.py
index 75a025f8..4beeba86 100644
--- a/happypose/pose_estimators/megapose/src/megapose/models/pose_rigid.py
+++ b/happypose/pose_estimators/megapose/src/megapose/models/pose_rigid.py
@@ -63,6 +63,10 @@ class PosePredictorOutputCosypose:
     K_crop: torch.Tensor
     boxes_rend: torch.Tensor
     boxes_crop: torch.Tensor
+    model_outputs: torch.Tensor
+
+    def __getitem__(self, item):
+        return getattr(self, item)

 @dataclass
 class PosePredictorOutput:
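Note on the final hunk: giving PosePredictorOutputCosypose a __getitem__ that delegates to getattr lets existing call sites that index model outputs like a dict (e.g. outputs['boxes_crop']) keep working once the plain dict is replaced by a dataclass. A minimal, self-contained sketch of the pattern follows; ExampleOutput and its fields are illustrative stand-ins, not names from the happypose codebase.

# Sketch of the dict-style dataclass access pattern used in pose_rigid.py.
# ExampleOutput is a hypothetical stand-in, not a happypose class.
from dataclasses import dataclass

import torch


@dataclass
class ExampleOutput:
    TCO_output: torch.Tensor
    K_crop: torch.Tensor

    def __getitem__(self, item):
        # Delegate indexing to attribute lookup so out['K_crop'] and
        # out.K_crop return the same object.
        return getattr(self, item)


out = ExampleOutput(TCO_output=torch.eye(4), K_crop=torch.eye(3))
assert out['K_crop'] is out.K_crop

Delegating to getattr keeps the dataclass fields as the single source of truth; the trade-off is that a missing key surfaces as AttributeError rather than KeyError, which callers written against the old dict interface should tolerate.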