model.load_state_dict(checkpoint['state_dict']) error with pytorch 0.4.0 #26
Comments
I have the same problem too. @alexandrecc, do you have any solution so far?
You should be able to make something like this work:

```python
import re

import torch

# Code modified from the torchvision densenet source for loading pre-0.4 densenet weights.
checkpoint = torch.load('./model.pth.tar')
state_dict = checkpoint['state_dict']
remove_data_parallel = False  # Change if you don't want to use nn.DataParallel(model)

pattern = re.compile(
    r'^(.*denselayer\d+\.(?:norm|relu|conv))\.'
    r'((?:[12])\.(?:weight|bias|running_mean|running_var))$')
for key in list(state_dict.keys()):
    match = pattern.match(key)
    new_key = match.group(1) + match.group(2) if match else key
    new_key = new_key[7:] if remove_data_parallel else new_key
    state_dict[new_key] = state_dict[key]
    # Delete the old key only if it was modified.
    if match or remove_data_parallel:
        del state_dict[key]
```
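To see what the regex is doing, here is a hypothetical old-style key run through the same pattern (the key name is made up for illustration, following the shapes in the error message below):

```python
import re

# Same pattern as above: group 1 captures the layer path, group 2 the
# numbered parameter suffix; concatenating them drops the "." between
# e.g. "norm" and "1".
pattern = re.compile(
    r'^(.*denselayer\d+\.(?:norm|relu|conv))\.'
    r'((?:[12])\.(?:weight|bias|running_mean|running_var))$')

old_key = 'module.densenet121.features.denseblock1.denselayer1.norm.1.weight'
match = pattern.match(old_key)
new_key = match.group(1) + match.group(2)
print(new_key)
# module.densenet121.features.denseblock1.denselayer1.norm1.weight
```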
Thanks JasperJenkins... This worked, but then I received another error of the form:
Same here.
+1, seeing this error as well. Has torch implemented some way to ensure backwards compatibility when parsing older models? I can't seem to find anything, and I would rather not change the keys themselves since that seems quite error-prone.
Set num_workers=0 and try again to find the real issue.
It's late to mention this here, but you can just put the fix in with the checkpoint loading.
This worked for me:

```python
from collections import OrderedDict

state_dict = checkpoint['state_dict']
new_state_dict = OrderedDict()
for k, v in state_dict.items():
    # Strip the 'module.' prefix added by nn.DataParallel (reconstructed
    # step; the original comment omitted how new_state_dict was built).
    name = k[7:] if k.startswith('module.') else k
    new_state_dict[name] = v
model.load_state_dict(new_state_dict)
```
Did it work?

Yes, in my case it worked.
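The renaming the comment above relies on can be sketched on plain dict keys, without touching torch (key names here are invented for illustration):

```python
from collections import OrderedDict

# Toy stand-in for a state_dict saved from an nn.DataParallel-wrapped
# model: every key carries a leading 'module.' prefix.
state_dict = OrderedDict([
    ('module.densenet121.features.conv0.weight', 'w'),
    ('module.densenet121.classifier.0.bias', 'b'),
])

# Strip the prefix so the keys match a bare (unwrapped) model.
new_state_dict = OrderedDict(
    (k[len('module.'):] if k.startswith('module.') else k, v)
    for k, v in state_dict.items()
)
print(list(new_state_dict))
# ['densenet121.features.conv0.weight', 'densenet121.classifier.0.bias']
```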
I was running the code without any problem on pytorch 0.3.0.
I upgraded yesterday to pytorch 0.4.0 and now can't load the checkpoint file. I am on Ubuntu with Python 3.6 in a conda env.
I get this error:
```
RuntimeError                              Traceback (most recent call last)
in <module>()
    181 if __name__ == '__main__':
--> 182     main()

in main()
     39     print("=> loading checkpoint")
     40     checkpoint = torch.load(CKPT_PATH)
---> 41     model.load_state_dict(checkpoint['state_dict'])
     42     print("=> loaded checkpoint")
     43 else:

~/anaconda3/envs/fastai/lib/python3.6/site-packages/torch/nn/modules/module.py in load_state_dict(self, state_dict, strict)
    719         if len(error_msgs) > 0:
    720             raise RuntimeError('Error(s) in loading state_dict for {}:\n\t{}'.format(
--> 721                 self.__class__.__name__, "\n\t".join(error_msgs)))
    722
    723     def parameters(self):

RuntimeError: Error(s) in loading state_dict for DenseNet121:
        Missing key(s) in state_dict: "densenet121.features.conv0.weight", "densenet121.features.norm0.weight", "densenet121.features.norm0.bias", "densenet121.features.norm0.running_mean", "densenet121.features.norm0.running_var", "densenet121.features.denseblock1.denselayer1.norm1.weight", "densenet121.features.denseblock1.denselayer1.norm1.bias", "densenet121.features.denseblock1.denselayer1.norm1.running_mean",
        (entire network ...)
        "module.densenet121.features.denseblock4.denselayer16.conv.2.weight", "module.densenet121.features.norm5.weight", "module.densenet121.features.norm5.bias", "module.densenet121.features.norm5.running_mean", "module.densenet121.features.norm5.running_var", "module.densenet121.classifier.0.weight", "module.densenet121.classifier.0.bias".
```
It is likely related to this information about pytorch 0.4.0:
https://pytorch.org/2018/04/22/0_4_0-migration-guide.html
> New edge-case constraints on names of submodules, parameters, and buffers in nn.Module
> A name that is an empty string or contains "." is no longer permitted in module.add_module(name, value), module.add_parameter(name, value) or module.add_buffer(name, value), because such names may cause data to be lost in the state_dict. If you are loading a checkpoint for modules containing such names, please update the module definition and patch the state_dict before loading it.
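The renamings above follow from that constraint: state_dict keys are the "."-joined path of names down the module tree, so a dot inside a single name makes the flat key ambiguous. A minimal sketch (key names hypothetical, in the style of the error message):

```python
# Pre-0.4 torchvision DenseNet used layer names containing dots, e.g.
# "norm.1", while the 0.4 version renamed them to "norm1" to satisfy
# the constraint. Both flatten to keys that differ only by one ".":
old_key = 'features.denseblock1.denselayer1.norm.1.weight'
new_key = 'features.denseblock1.denselayer1.norm1.weight'

# Dropping the dot between the layer name and its index maps old to new,
# which is exactly what the regex patch in the comments above automates.
patched = old_key.replace('norm.1', 'norm1')
print(patched == new_key)  # True
```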