training pose cosypose
ElliotMaitre committed Sep 11, 2023
1 parent 693d9f2 commit 08e85ca
Showing 7 changed files with 68 additions and 45 deletions.
1 change: 1 addition & 0 deletions environment.yml
@@ -13,6 +13,7 @@ dependencies:
- pytorch==1.11.0
- torchvision==0.12.0
- cudatoolkit==11.3.1
- pytorch3d
- ipython
- ipykernel
- jupyterlab
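pytorch3d joins the pinned pytorch 1.11.0 / cudatoolkit 11.3.1 stack; its binaries are specific to the torch/CUDA pair they were built against, so the resolved environment is worth a quick check. A minimal sanity check, not part of the commit:

    import torch
    import pytorch3d  # must be a build matching the torch/CUDA pair above

    print(torch.__version__, torch.version.cuda)  # expect 1.11.0 and 11.3
    print(pytorch3d.__version__)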
@@ -11,8 +11,7 @@
def run_pred_eval(pred_runner, pred_kwargs, eval_runner, eval_preds=None):
all_predictions = dict()
for pred_prefix, pred_kwargs_n in pred_kwargs.items():
print("Prediction :", pred_prefix)
preds = pred_runner.get_predictions(**pred_kwargs_n)
preds = pred_runner.get_predictions(pred_kwargs_n['pose_predictor'])
for preds_name, preds_n in preds.items():
all_predictions[f'{pred_prefix}/{preds_name}'] = preds_n
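run_pred_eval no longer unpacks pred_kwargs_n as keyword arguments; each entry is now expected to carry a 'pose_predictor' object, which is handed directly to the prediction runner's get_predictions (whose signature, further down in this commit, takes a single pose_estimator). A sketch of the implied calling convention; the prefix name is a placeholder, not taken from the commit:

    # Hypothetical caller of run_pred_eval after this change.
    pred_kwargs = {
        "gt_detections/refiner": {"pose_predictor": pose_estimator},  # a PoseEstimator instance
    }
    run_pred_eval(pred_runner, pred_kwargs, eval_runner)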

1 change: 1 addition & 0 deletions happypose/pose_estimators/cosypose/cosypose/models/pose.py
@@ -134,6 +134,7 @@ def forward(self, images, K, labels, TCO, n_iterations=1):
K_crop=K_crop,
boxes_rend=boxes_rend,
boxes_crop=boxes_crop,
model_outputs=model_outputs
)

TCO_input = TCO_output
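The per-iteration output assembled here now also carries model_outputs, the raw network outputs, alongside the pose and crop fields (the matching field is added to PosePredictorOutputCosypose at the end of this commit). A sketch of reading it back; model and its inputs are placeholders, and it assumes forward() returns a dict of such records, one per refinement iteration:

    # Sketch only: pull the newly exposed raw outputs for each refinement iteration.
    outputs = model(images=images, K=K, labels=labels, TCO=TCO_init, n_iterations=1)
    for name, iter_out in outputs.items():
        raw = iter_out["model_outputs"]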
@@ -48,7 +48,7 @@ def make_cfg(args):
cfg.n_pose_dims = 9
cfg.n_rendering_workers = N_WORKERS
cfg.refiner_run_id_for_test = None
cfg.coarse_run_id_for_test = None
cfg.coarse_run_id_for_test = 'coarse-bop-ycbv-pbr--724183'

# Optimizer
cfg.lr = 3e-4
@@ -58,9 +58,9 @@
cfg.clip_grad_norm = 0.5

# Training
cfg.batch_size = 32
cfg.batch_size = 16
cfg.epoch_size = 115200
cfg.n_epochs = 700
cfg.n_epochs = 2
cfg.n_dataloader_workers = N_WORKERS

# Method
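This hunk turns the run from a full training configuration (batch 32, 700 epochs, no coarse model for the test phase) into what looks like a quick test setup: batch 16, 2 epochs, and the pretrained 'coarse-bop-ycbv-pbr--724183' run used as the coarse model during testing. With epoch_size left at 115200 samples, halving the batch size doubles the optimizer steps per epoch:

    # epoch_size / batch_size = optimizer steps per epoch
    print(115200 // 32, 115200 // 16)  # 3600 steps before, 7200 after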
88 changes: 53 additions & 35 deletions happypose/pose_estimators/cosypose/cosypose/training/train_pose.py
@@ -18,6 +18,9 @@
from happypose.toolbox.datasets.datasets_cfg import make_object_dataset, make_scene_dataset
from happypose.pose_estimators.cosypose.cosypose.datasets.pose_dataset import PoseDataset
from happypose.pose_estimators.cosypose.cosypose.datasets.samplers import PartialSampler, ListSampler
from happypose.pose_estimators.megapose.src.megapose.inference.types import (
InferenceConfig
)

# Evaluation
from happypose.pose_estimators.cosypose.cosypose.integrated.pose_estimator import (
@@ -48,6 +51,9 @@
from happypose.pose_estimators.cosypose.cosypose.utils.distributed import get_world_size, get_rank, sync_model, init_distributed_mode, reduce_dict
from torch.backends import cudnn

# temp
from torch.utils.data import Subset

cudnn.benchmark = True
logger = get_logger(__name__)
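Subset is imported under a '# temp' marker but is not used in any hunk shown here; its usual role is to restrict a dataset to a small slice for fast debugging runs. A sketch of that pattern, not code from this commit:

    from torch.utils.data import Subset

    ds_train_small = Subset(ds_train, range(64))  # first 64 samples only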

@@ -90,7 +96,7 @@ def load_model(run_id):
if run_id is None:
return None
run_dir = EXP_DIR / run_id
cfg = yaml.load((run_dir / 'config.yaml').read_text(), Loader=yaml.FullLoader)
cfg = yaml.load((run_dir / 'config.yaml').read_text(), Loader=yaml.Loader)
cfg = check_update_config(cfg)
model = create_model_pose(cfg, renderer=model_training.renderer,
mesh_db=model_training.mesh_db).cuda().eval()
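load_model now reads the saved config.yaml with yaml.Loader instead of yaml.FullLoader. FullLoader can only rebuild objects from modules that are already imported and (since PyYAML 5.4) refuses python/object/apply and python/object/new tags, while the legacy Loader behaves like UnsafeLoader and resolves them; a likely motivation, though the commit does not say, is that these configs are dumped from Python objects and carry such tags. A small illustration of the capability gap (fractions.Fraction is just an example, nothing from the commit):

    import yaml

    doc = "!!python/object/apply:fractions.Fraction [1, 3]"
    yaml.load(doc, Loader=yaml.Loader)      # Fraction(1, 3): Loader imports fractions and calls it
    yaml.load(doc, Loader=yaml.FullLoader)  # ConstructorError with PyYAML >= 5.4: apply tags refused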
@@ -99,7 +105,7 @@ def load_model(run_id):
model.eval()
model.cfg = cfg
return model

if args.train_refiner:
refiner_model = model_training
coarse_model = load_model(args.coarse_run_id_for_test)
@@ -108,7 +114,7 @@ def load_model(run_id):
refiner_model = load_model(args.refiner_run_id_for_test)
else:
raise ValueError

pose_estimator = PoseEstimator(
refiner_model=refiner_model,
coarse_model=coarse_model,
@@ -129,9 +135,19 @@ def load_model(run_id):
#pred_runner = MultiviewPredictionRunner(scene_ds_pred, batch_size=1,
# n_workers=args.n_dataloader_workers, cache_data=False)

inference = {'detection_type': 'gt', 'coarse_estimation_type': 'S03_grid', 'SO3_grid_size': 576,
'n_refiner_iterations': 5, 'n_pose_hypotheses': 5, 'run_depth_refiner': False,
'depth_refiner': None, 'bsz_objects': 16, 'bsz_images': 288}
#inference = {'detection_type': 'gt', 'coarse_estimation_type': 'S03_grid', 'SO3_grid_size': 576,
# 'n_refiner_iterations': 5, 'n_pose_hypotheses': 5, 'run_depth_refiner': False,
# 'depth_refiner': None, 'bsz_objects': 16, 'bsz_images': 288}

inference = InferenceConfig(detection_type= 'gt',
coarse_estimation_type='S03_grid',
SO3_grid_size= 576,
n_refiner_iterations=5,
n_pose_hypotheses= 5,
run_depth_refiner= False,
depth_refiner= None,
bsz_objects= 16,
bsz_images= 288)

pred_runner = PredictionRunner(
scene_ds=scene_ds,
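Just above, the inference settings move from a plain dict (kept commented out) to megapose's InferenceConfig dataclass, so a misspelled or unknown field fails at construction time rather than surfacing later as a missing key. A small sketch of the practical difference, assuming InferenceConfig is a regular dataclass:

    cfg_dict = {"n_refiner_iterations": 5}
    cfg_dict["n_refiner_iteration"]          # typo -> KeyError, but only when the key is read
    InferenceConfig(n_refiner_iteration=5)   # typo -> TypeError at construction time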
@@ -255,7 +271,7 @@ def make_datasets(dataset_names):

scene_ds_train = make_datasets(args.train_ds_names)
scene_ds_val = make_datasets(args.val_ds_names)

ds_kwargs = dict(
resize=args.input_resize,
rgb_augmentation=args.rgb_augmentation,
@@ -266,6 +282,7 @@ def make_datasets(dataset_names):
ds_train = PoseDataset(scene_ds_train, **ds_kwargs)
ds_val = PoseDataset(scene_ds_val, **ds_kwargs)


train_sampler = PartialSampler(ds_train, epoch_size=args.epoch_size)
ds_iter_train = DataLoader(ds_train, sampler=train_sampler, batch_size=args.batch_size,
num_workers=args.n_dataloader_workers, collate_fn=ds_train.collate_fn,
@@ -341,30 +358,31 @@ def train_epoch():
iterator = tqdm(ds_iter_train, ncols=80)
t = time.time()
for n, sample in enumerate(iterator):
if n > 0:
meters_time['data'].add(time.time() - t)

optimizer.zero_grad()

t = time.time()
loss = h(data=sample, meters=meters_train)
meters_time['forward'].add(time.time() - t)
iterator.set_postfix(loss=loss.item())
meters_train['loss_total'].add(loss.item())

t = time.time()
loss.backward()
total_grad_norm = torch.nn.utils.clip_grad_norm_(
model.parameters(), max_norm=args.clip_grad_norm, norm_type=2)
meters_train['grad_norm'].add(torch.as_tensor(total_grad_norm).item())

optimizer.step()
meters_time['backward'].add(time.time() - t)
meters_time['memory'].add(torch.cuda.max_memory_allocated() / 1024. ** 2)

if epoch < args.n_epochs_warmup:
lr_scheduler_warmup.step()
t = time.time()
if n<5:
if n > 0:
meters_time['data'].add(time.time() - t)

optimizer.zero_grad()

t = time.time()
loss = h(data=sample, meters=meters_train)
meters_time['forward'].add(time.time() - t)
iterator.set_postfix(loss=loss.item())
meters_train['loss_total'].add(loss.item())

t = time.time()
loss.backward()
total_grad_norm = torch.nn.utils.clip_grad_norm_(
model.parameters(), max_norm=args.clip_grad_norm, norm_type=2)
meters_train['grad_norm'].add(torch.as_tensor(total_grad_norm).item())

optimizer.step()
meters_time['backward'].add(time.time() - t)
meters_time['memory'].add(torch.cuda.max_memory_allocated() / 1024. ** 2)

if epoch < args.n_epochs_warmup:
lr_scheduler_warmup.step()
t = time.time()
if epoch >= args.n_epochs_warmup:
lr_scheduler.step()
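In the rewritten loop body every step is gated by if n<5, so each epoch only processes its first five batches; together with n_epochs = 2 in the config above, this reads as a temporary smoke-test setup rather than a change to the training recipe. A sketch of an equivalent cap on the work per epoch (not what the commit does) that also stops pulling batches once the budget is spent:

    import itertools

    for n, sample in enumerate(itertools.islice(iterator, 5)):
        ...  # same loop body as before, without the extra indentation level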

@@ -374,7 +392,7 @@ def validation():
for sample in tqdm(ds_iter_val, ncols=80):
loss = h(data=sample, meters=meters_val)
meters_val['loss_total'].add(loss.item())

@torch.no_grad()
def test():
model.eval()
@@ -383,11 +401,11 @@ def test():
train_epoch()
if epoch % args.val_epoch_interval == 0:
validation()

test_dict = None
if epoch % args.test_epoch_interval == 0:
test_dict = test()

log_dict = dict()
log_dict.update({
'grad_norm': meters_train['grad_norm'].mean,
@@ -405,7 +423,7 @@ def test():
for string, meters in zip(('train', 'val'), (meters_train, meters_val)):
for k in dict(meters).keys():
log_dict[f'{string}_{k}'] = meters[k].mean

test_dict = {}
log_dict = reduce_dict(log_dict)
if get_rank() == 0:
log(config=args, model=model, epoch=epoch,
@@ -292,12 +292,12 @@ def get_predictions(self, pose_estimator: PoseEstimator) -> Dict[str, PoseEstima

# ############ RUN ONLY BEGINNING OF DATASET
# # if n > 0:
# if n < 298:
if n > 10:
# # if n != 582:
# print('################')
# print('Prediction runner SKIP')
# print('################')
# continue
print('################')
print('Prediction runner SKIP')
print('################')
continue
# ############ RUN ONLY BEGINNING OF DATASET

# Dirty but avoids creating error when running with real detector
@@ -63,6 +63,10 @@ class PosePredictorOutputCosypose:
K_crop: torch.Tensor
boxes_rend: torch.Tensor
boxes_crop: torch.Tensor
model_outputs: torch.Tensor

def __getitem__(self, item):
return getattr(self, item)
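The added __getitem__ makes key-style indexing equivalent to attribute access on the dataclass, so code written against the old dict-shaped outputs keeps working unchanged:

    # Given `out`, a PosePredictorOutputCosypose instance:
    out.model_outputs        # attribute access
    out["model_outputs"]     # same tensor, routed through the new __getitem__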

@dataclass
class PosePredictorOutput:
