Skip to content

Commit

Permalink
Update main_train.py
Browse files Browse the repository at this point in the history
  • Loading branch information
cszn authored Aug 30, 2018
1 parent 683a871 commit 5222842
Showing 1 changed file with 2 additions and 1 deletion.
3 changes: 2 additions & 1 deletion TrainingCodes/dncnn_pytorch/main_train.py
Original file line number Diff line number Diff line change
Expand Up @@ -133,7 +133,8 @@ def log(*args, **kwargs):
initial_epoch = findLastCheckpoint(save_dir=save_dir) # load the last model in matconvnet style
if initial_epoch > 0:
print('resuming by loading epoch %03d' % initial_epoch)
model.load_state_dict(torch.load(os.path.join(save_dir, 'model_%03d.pth' % initial_epoch)))
# model.load_state_dict(torch.load(os.path.join(save_dir, 'model_%03d.pth' % initial_epoch)))
model = torch.load(os.path.join(save_dir, 'model_%03d.pth' % initial_epoch))
model.train()
# criterion = nn.MSELoss(reduction = 'sum') # PyTorch 0.4.1
criterion = sum_squared_error()
Expand Down

0 comments on commit 5222842

Please sign in to comment.