Skip to content

Commit

Permalink
fix resume checkpoint bug
Browse files Browse the repository at this point in the history
  • Loading branch information
will-jl944 committed Oct 11, 2021
1 parent 3ea22be commit ba6df9c
Showing 1 changed file with 3 additions and 0 deletions.
3 changes: 3 additions & 0 deletions paddlex/utils/checkpoint.py
Original file line number Diff line number Diff line change
Expand Up @@ -424,6 +424,9 @@ def load_pretrain_weights(model, pretrain_weights=None, model_name=None):
def load_optimizer(optimizer, state_dict_path):
logging.info("Loading optimizer from {}".format(state_dict_path))
optim_state_dict = paddle.load(state_dict_path)
for key in optimizer.state_dict().keys():
if key not in optim_state_dict.keys():
optim_state_dict[key] = optimizer.state_dict()[key]
if 'last_epoch' in optim_state_dict:
optim_state_dict.pop('last_epoch')
optimizer.set_state_dict(optim_state_dict)
Expand Down

0 comments on commit ba6df9c

Please sign in to comment.