diff --git a/TrainingCodes/dncnn_pytorch/main_train.py b/TrainingCodes/dncnn_pytorch/main_train.py index d4599f9b..7593cd1f 100644 --- a/TrainingCodes/dncnn_pytorch/main_train.py +++ b/TrainingCodes/dncnn_pytorch/main_train.py @@ -133,7 +133,8 @@ def log(*args, **kwargs): initial_epoch = findLastCheckpoint(save_dir=save_dir) # load the last model in matconvnet style if initial_epoch > 0: print('resuming by loading epoch %03d' % initial_epoch) - model.load_state_dict(torch.load(os.path.join(save_dir, 'model_%03d.pth' % initial_epoch))) + # model.load_state_dict(torch.load(os.path.join(save_dir, 'model_%03d.pth' % initial_epoch))) + model = torch.load(os.path.join(save_dir, 'model_%03d.pth' % initial_epoch)) model.train() # criterion = nn.MSELoss(reduction = 'sum') # PyTorch 0.4.1 criterion = sum_squared_error()