Skip to content

Commit

Permalink
code improve
Browse files Browse the repository at this point in the history
  • Loading branch information
max-unfinity committed Nov 21, 2024
1 parent 14512c3 commit 666e51e
Showing 1 changed file with 11 additions and 6 deletions.
17 changes: 11 additions & 6 deletions rtdetrv2_pytorch/src/solver/det_solver.py
Original file line number Diff line number Diff line change
Expand Up @@ -59,13 +59,10 @@ def fit(self, ):
checkpoint_paths = [self.output_dir / 'last.pth']
# extra checkpoint before LR drop and every 100 epochs
if (epoch + 1) % args.checkpoint_freq == 0:
checkpoint_paths.append(self.output_dir / f'checkpoint{epoch:04}.pth')
checkpoint_paths.append(self.output_dir / f'checkpoint{epoch + 1:04}.pth')
for checkpoint_path in checkpoint_paths:
state_dict = self.state_dict()
if not self.cfg.yaml_cfg['save_optimizer'] and "optimizer" in state_dict:
state_dict.pop("optimizer")
if not self.cfg.yaml_cfg['save_ema'] and "ema" in state_dict:
state_dict.pop("model") # keep ema as a model
self._strip_state_dict(state_dict)
dist_utils.save_on_master(state_dict, checkpoint_path)

module = self.ema.module if self.ema else self.model
Expand All @@ -92,7 +89,9 @@ def fit(self, ):
best_stat[k] = test_stats[k][0]

if best_stat['epoch'] == epoch and self.output_dir:
dist_utils.save_on_master(self.state_dict(), self.output_dir / 'best.pth')
state_dict = self.state_dict()
self._strip_state_dict(state_dict)
dist_utils.save_on_master(state_dict, self.output_dir / 'best.pth')

print(f'best_stat: {best_stat}')

Expand Down Expand Up @@ -134,3 +133,9 @@ def val(self, ):
dist_utils.save_on_master(coco_evaluator.coco_eval["bbox"].eval, self.output_dir / "eval.pth")

return

def _strip_state_dict(self, state_dict):
if not self.cfg.yaml_cfg['save_optimizer'] and "optimizer" in state_dict:
state_dict.pop("optimizer")
if not self.cfg.yaml_cfg['save_ema'] and "ema" in state_dict:
state_dict.pop("model") # keep ema as a model

0 comments on commit 666e51e

Please sign in to comment.