Skip to content

Commit

Permalink
Debug CosyPose training: cap train/validation loops at 5 batches per epoch; accept list-valued rotations in Transform
Browse files Browse the repository at this point in the history
  • Loading branch information
centos Cloud User committed Jan 24, 2024
1 parent f5f1faf commit 31be268
Show file tree
Hide file tree
Showing 2 changed files with 42 additions and 33 deletions.
70 changes: 38 additions & 32 deletions happypose/pose_estimators/cosypose/cosypose/training/train_pose.py
Original file line number Diff line number Diff line change
Expand Up @@ -422,44 +422,50 @@ def train_epoch():
iterator = tqdm(ds_iter_train, ncols=80)
t = time.time()
for n, sample in enumerate(iterator):
if n > 0:
meters_time["data"].add(time.time() - t)

optimizer.zero_grad()

t = time.time()
loss = h(data=sample, meters=meters_train)
meters_time["forward"].add(time.time() - t)
iterator.set_postfix(loss=loss.item())
meters_train["loss_total"].add(loss.item())

t = time.time()
loss.backward()
total_grad_norm = torch.nn.utils.clip_grad_norm_(
model.parameters(),
max_norm=args.clip_grad_norm,
norm_type=2,
)
meters_train["grad_norm"].add(torch.as_tensor(total_grad_norm).item())

optimizer.step()
meters_time["backward"].add(time.time() - t)
meters_time["memory"].add(
torch.cuda.max_memory_allocated() / 1024.0**2,
)

if epoch < args.n_epochs_warmup:
lr_scheduler_warmup.step()
t = time.time()
if n < 5:
if n > 0:
meters_time["data"].add(time.time() - t)

optimizer.zero_grad()

t = time.time()
loss = h(data=sample, meters=meters_train)
meters_time["forward"].add(time.time() - t)
iterator.set_postfix(loss=loss.item())
meters_train["loss_total"].add(loss.item())

t = time.time()
loss.backward()
total_grad_norm = torch.nn.utils.clip_grad_norm_(
model.parameters(),
max_norm=args.clip_grad_norm,
norm_type=2,
)
meters_train["grad_norm"].add(torch.as_tensor(total_grad_norm).item())

optimizer.step()
meters_time["backward"].add(time.time() - t)
meters_time["memory"].add(
torch.cuda.max_memory_allocated() / 1024.0**2,
)

if epoch < args.n_epochs_warmup:
lr_scheduler_warmup.step()
t = time.time()
else:
continue
if epoch >= args.n_epochs_warmup:
lr_scheduler.step()

@torch.no_grad()
def validation():
    """Run one validation pass, recording the per-batch loss in ``meters_val``.

    Uses the enclosing scope's ``model``, ``ds_iter_val``, ``h`` (the loss
    closure) and ``meters_val``; gradients are disabled via ``torch.no_grad``.

    NOTE(review): the loop is capped at 5 batches — this looks like a
    temporary debugging limit ("debugging cosypose training"); confirm it is
    removed (or gated behind a config flag) before a real training run.
    """
    model.eval()
    for n, sample in enumerate(tqdm(ds_iter_val, ncols=80)):
        # Stop outright instead of the previous `else: continue`, which
        # uselessly spun through the remaining loader batches doing nothing.
        if n >= 5:
            break
        loss = h(data=sample, meters=meters_val)
        meters_val["loss_total"].add(loss.item())

@torch.no_grad()
def test():
Expand Down
5 changes: 4 additions & 1 deletion happypose/toolbox/lib3d/transform.py
Original file line number Diff line number Diff line change
Expand Up @@ -74,7 +74,10 @@ def __init__(
else:
rotation_np = rotation
else:
raise ValueError
if isinstance(rotation, list):
rotation_np = np.array(rotation)
else:
raise ValueError

if rotation_np.size == 4:
quaternion_xyzw = rotation_np.flatten().tolist()
Expand Down

0 comments on commit 31be268

Please sign in to comment.