RuntimeError: Error(s) in loading state_dict for Generator: Missing key(s) in state_dict: #159

Open
goongzi-leean opened this issue Aug 2, 2022 · 0 comments


When I load the pre-trained AGGAN-Mod generator, I get this error:

RuntimeError: Error(s) in loading state_dict for Generator:
Missing key(s) in state_dict: "blocks.0.0.bn1.gain.weight", "blocks.0.0.bn1.bias.weight", "blocks.0.0.bn2.gain.weight", "blocks.0.0.bn2.bias.weight", "blocks.1.0.bn1.gain.weight", "blocks.1.0.bn1.bias.weight", "blocks.1.0.bn2.gain.weight", "blocks.1.0.bn2.bias.weight", "blocks.2.0.bn1.gain.weight", "blocks.2.0.bn1.bias.weight", "blocks.2.0.bn2.gain.weight", "blocks.2.0.bn2.bias.weight".
Unexpected key(s) in state_dict: "blocks.0.0.bn1.embed0.weight", "blocks.0.0.bn1.embed1.weight", "blocks.0.0.bn2.embed0.weight", "blocks.0.0.bn2.embed1.weight", "blocks.1.0.bn1.embed0.weight", "blocks.1.0.bn1.embed1.weight", "blocks.1.0.bn2.embed0.weight", "blocks.1.0.bn2.embed1.weight", "blocks.2.0.bn1.embed0.weight", "blocks.2.0.bn1.embed1.weight", "blocks.2.0.bn2.embed0.weight", "blocks.2.0.bn2.embed1.weight".
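To list exactly which parameters disagree without raising, `load_state_dict` can be called with `strict=False`, which returns the missing and unexpected keys. The sketch below reproduces the same kind of mismatch with two toy modules. `OldCBN` and `NewCBN` are hypothetical stand-ins that only copy the layer names and shapes from the module printouts in this issue; they are not the real StudioGAN classes.

```python
import torch.nn as nn


class OldCBN(nn.Module):
    """Hypothetical stand-in for the author's logged layout (Embedding layers)."""

    def __init__(self, num_classes: int = 10, num_features: int = 256):
        super().__init__()
        self.bn = nn.BatchNorm2d(num_features, eps=1e-4, momentum=0.1, affine=False)
        self.embed0 = nn.Embedding(num_classes, num_features)
        self.embed1 = nn.Embedding(num_classes, num_features)


class NewCBN(nn.Module):
    """Hypothetical stand-in for the layout the current code builds (Linear layers)."""

    def __init__(self, num_classes: int = 10, num_features: int = 256):
        super().__init__()
        self.bn = nn.BatchNorm2d(num_features, eps=1e-4, momentum=0.1, affine=False)
        self.gain = nn.Linear(num_classes, num_features, bias=False)
        self.bias = nn.Linear(num_classes, num_features, bias=False)


# Load the "old" weights into the "new" module; strict=False reports instead of raising.
old_state = OldCBN().state_dict()
result = NewCBN().load_state_dict(old_state, strict=False)
print("missing:", result.missing_keys)
print("unexpected:", result.unexpected_keys)
```

The missing list holds `gain.weight` and `bias.weight`, and the unexpected list holds `embed0.weight` and `embed1.weight`: the same per-layer pattern as in the traceback above.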

So I looked into why.
This is the structure of the generator I built:

Generator(
(linear0): Linear(in_features=128, out_features=4096, bias=True)
(blocks): ModuleList(
(0): ModuleList(
(0): GenBlock(
(bn1): ConditionalBatchNorm2d(
(bn): BatchNorm2d(256, eps=0.0001, momentum=0.1, affine=False, track_running_stats=True)
(gain): Linear(in_features=10, out_features=256, bias=False)
(bias): Linear(in_features=10, out_features=256, bias=False)

)
(bn2): ConditionalBatchNorm2d(
(bn): BatchNorm2d(256, eps=0.0001, momentum=0.1, affine=False, track_running_stats=True)
(gain): Linear(in_features=10, out_features=256, bias=False)
(bias): Linear(in_features=10, out_features=256, bias=False)
)
(activation): ReLU(inplace=True)
(conv2d0): Conv2d(256, 256, kernel_size=(1, 1), stride=(1, 1))
(conv2d1): Conv2d(256, 256, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1))
(conv2d2): Conv2d(256, 256, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1))
)
)

And this is the generator structure from the author's training log:

Generator(
(linear0): Linear(in_features=128, out_features=4096, bias=True)
(blocks): ModuleList(
(0): ModuleList(
(0): GenBlock(
(bn1): ConditionalBatchNorm2d(
(bn): BatchNorm2d(256, eps=0.0001, momentum=0.1, affine=False, track_running_stats=True)
(embed0): Embedding(10, 256)
(embed1): Embedding(10, 256)

)
(bn2): ConditionalBatchNorm2d(
(bn): BatchNorm2d(256, eps=0.0001, momentum=0.1, affine=False, track_running_stats=True)
(embed0): Embedding(10, 256)
(embed1): Embedding(10, 256)
)
(activation): ReLU(inplace=True)
(conv2d0): Conv2d(256, 256, kernel_size=(1, 1), stride=(1, 1))
(conv2d1): Conv2d(256, 256, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1))
(conv2d2): Conv2d(256, 256, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1))
)
)
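The two layouts hold the same number of values (10 × 256) but in transposed form: `nn.Embedding(10, 256).weight` has shape `(10, 256)`, while `nn.Linear(10, 256, bias=False).weight` has shape `(256, 10)`. Assuming the current code feeds these Linear layers a one-hot class label (an assumption suggested by `in_features=10`; check how the generator builds its conditioning input), looking up row `y` of an embedding gives the same result as a Linear layer whose weight is that embedding table transposed. A small check:

```python
import torch
import torch.nn as nn
import torch.nn.functional as F

torch.manual_seed(0)
num_classes, num_features = 10, 256

embed = nn.Embedding(num_classes, num_features)            # old layout: weight (10, 256)
linear = nn.Linear(num_classes, num_features, bias=False)  # new layout: weight (256, 10)
print(embed.weight.shape, linear.weight.shape)

# Copy the embedding table into the Linear layer, transposed.
with torch.no_grad():
    linear.weight.copy_(embed.weight.t())

labels = torch.tensor([0, 3, 9])
# Assumed conditioning input for the Linear version: one-hot labels.
one_hot = F.one_hot(labels, num_classes=num_classes).float()

# Embedding lookup vs. Linear applied to one-hot labels.
same = torch.allclose(embed(labels), linear(one_hot))
print(same)
```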

I then looked at ConditionalBatchNorm2d (in ops.py) in the latest code and found:

self.gain = MODULES.g_linear(in_features=in_features, out_features=out_features, bias=False)
self.bias = MODULES.g_linear(in_features=in_features, out_features=out_features, bias=False)

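but g_linear is set to ops.linear (in config.py), so gain and bias are built as Linear layers.

The author's checkpoint, however, was saved from a version of ConditionalBatchNorm2d that stored these parameters in Embedding layers named embed0 and embed1. That is why the gain/bias keys are reported as missing and the embed0/embed1 keys as unexpected.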

To load the author's pre-trained generator, ConditionalBatchNorm2d has to be modified; otherwise, you have to retrain the model. This applies to all conditional GANs that use ConditionalBatchNorm2d.
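As an alternative to editing ConditionalBatchNorm2d, the checkpoint itself could be converted before loading. The function below is a hypothetical sketch, not part of StudioGAN: it renames each `*.embed0.weight` key to `*.gain.weight` and each `*.embed1.weight` key to `*.bias.weight`, transposing each tensor from `(num_classes, num_features)` to `(num_features, num_classes)`. It assumes that `embed0` held the gain and `embed1` the bias, that the current forward pass uses these weights the same way the old one did, and that the conditioning input is a one-hot label. All three assumptions should be verified against both versions of ops.py before trusting the converted weights.

```python
import re
from typing import Dict

import torch

# Assumed pairing of old Embedding names to new Linear names (hypothetical).
_RENAMES = {"embed0": "gain", "embed1": "bias"}
_PATTERN = re.compile(r"^(.*)\.(embed0|embed1)\.weight$")


def convert_cbn_state_dict(old_state: Dict[str, torch.Tensor]) -> Dict[str, torch.Tensor]:
    """Rename and transpose conditional-BN embedding weights (illustrative sketch)."""
    new_state: Dict[str, torch.Tensor] = {}
    for key, value in old_state.items():
        match = _PATTERN.match(key)
        if match is None:
            new_state[key] = value  # every other key is copied unchanged
            continue
        prefix, old_name = match.groups()
        # Embedding weight is (num_classes, num_features); Linear expects the transpose.
        new_state[f"{prefix}.{_RENAMES[old_name]}.weight"] = value.t().contiguous()
    return new_state


# Usage with a toy state dict shaped like one conditional BN layer from the error above.
old = {
    "blocks.0.0.bn1.embed0.weight": torch.randn(10, 256),
    "blocks.0.0.bn1.embed1.weight": torch.randn(10, 256),
    "linear0.weight": torch.randn(4096, 128),
}
new = convert_cbn_state_dict(old)
print(sorted(new))
```

With a real checkpoint, the converted dict would then be passed to the generator's `load_state_dict`; how the generator weights are stored inside the checkpoint file is not covered here and would need to be checked.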

I hope the author can look into this problem.

Best!

Leean
