
Stage 2 training BUG: following keys are missing #175

Open

RenieWell opened this issue May 31, 2024 · 8 comments

Comments

@RenieWell

Thanks for sharing this work with us!

I have trained the stage 1 model, with the expected performance, using the UNet weights from the pretrained model pretrained_model_name_or_path: 'lambdalabs/sd-image-variations-diffusers'.

But after training, the model trained in stage 1 can't be loaded by the stage 2 code, and I get the following error:
load pre-trained unet from ./outputs/wonder3D-mix-vanila/checkpoint/
Traceback (most recent call last):
  File "/data/.conda/envs/marigold/lib/python3.10/runpy.py", line 196, in _run_module_as_main
    return _run_code(code, main_globals, None,
  File "/data/.conda/envs/marigold/lib/python3.10/runpy.py", line 86, in _run_code
    exec(code, run_globals)
  File "/data/.vscode-server/extensions/ms-python.debugpy-2024.6.0-linux-x64/bundled/libs/debugpy/adapter/../../debugpy/launcher/../../debugpy/__main__.py", line 39, in <module>
    cli.main()
  File "/data/.vscode-server/extensions/ms-python.debugpy-2024.6.0-linux-x64/bundled/libs/debugpy/adapter/../../debugpy/launcher/../../debugpy/../debugpy/server/cli.py", line 430, in main
    run()
  File "/data/.vscode-server/extensions/ms-python.debugpy-2024.6.0-linux-x64/bundled/libs/debugpy/adapter/../../debugpy/launcher/../../debugpy/../debugpy/server/cli.py", line 284, in run_file
    runpy.run_path(target, run_name="__main__")
  File "/data/.vscode-server/extensions/ms-python.debugpy-2024.6.0-linux-x64/bundled/libs/debugpy/_vendored/pydevd/_pydevd_bundle/pydevd_runpy.py", line 321, in run_path
    return _run_module_code(code, init_globals, run_name,
  File "/data/.vscode-server/extensions/ms-python.debugpy-2024.6.0-linux-x64/bundled/libs/debugpy/_vendored/pydevd/_pydevd_bundle/pydevd_runpy.py", line 135, in _run_module_code
    _run_code(code, mod_globals, init_globals,
  File "/data/.vscode-server/extensions/ms-python.debugpy-2024.6.0-linux-x64/bundled/libs/debugpy/_vendored/pydevd/_pydevd_bundle/pydevd_runpy.py", line 124, in _run_code
    exec(code, run_globals)
  File "/data/Wonder3D/train_mvdiffusion_joint.py", line 773, in <module>
    main(cfg)
  File "/data/Wonder3D/train_mvdiffusion_joint.py", line 251, in main
    unet = UNetMV2DConditionModel.from_pretrained(cfg.pretrained_unet_path, subfolder="unet", revision=cfg.revision, **cfg.unet_from_pretrained_kwargs)
  File "/data/.conda/envs/marigold/lib/python3.10/site-packages/diffusers/models/modeling_utils.py", line 604, in from_pretrained
    raise ValueError(
ValueError: Cannot load <class 'mvdiffusion.models.unet_mv2d_condition.UNetMV2DConditionModel'> from ./outputs/wonder3D-mix-vanila/checkpoint/ because the following keys are missing: up_blocks.3.attentions.0.transformer_blocks.0.attn_joint_mid.to_out.0.weight, up_blocks.2.attentions.2.transformer_blocks.0.attn_joint_mid.to_v.weight, mid_block.attentions.0.transformer_blocks.0.attn_joint_mid.to_out.0.bias, up_blocks.2.attentions.2.transformer_blocks.0.attn_joint_mid.to_k.weight, down_blocks.1.attentions.0.transformer_blocks.0.attn_joint_mid.to_out.0.weight, mid_block.attentions.0.transformer_blocks.0.norm_joint_mid.bias, down_blocks.2.attentions.0.transformer_blocks.0.attn_joint_mid.to_out.0.weight, down_blocks.1.attentions.1.transformer_blocks.0.attn_joint_mid.to_out.0.weight, up_blocks.2.attentions.0.transformer_blocks.0.norm_joint_mid.bias, down_blocks.0.attentions.1.transformer_blocks.0.norm_joint_mid.weight, down_blocks.1.attentions.0.transformer_blocks.0.attn_joint_mid.to_out.0.bias, down_blocks.0.attentions.0.transformer_blocks.0.attn_joint_mid.to_out.0.bias, down_blocks.2.attentions.1.transformer_blocks.0.attn_joint_mid.to_q.weight, up_blocks.2.attentions.0.transformer_blocks.0.norm_joint_mid.weight, down_blocks.1.attentions.0.transformer_blocks.0.attn_joint_mid.to_k.weight, up_blocks.1.attentions.0.transformer_blocks.0.norm_joint_mid.bias, down_blocks.2.attentions.0.transformer_blocks.0.norm_joint_mid.weight, up_blocks.2.attentions.2.transformer_blocks.0.attn_joint_mid.to_q.weight, 
down_blocks.0.attentions.0.transformer_blocks.0.attn_joint_mid.to_out.0.weight, up_blocks.2.attentions.0.transformer_blocks.0.attn_joint_mid.to_v.weight, up_blocks.3.attentions.1.transformer_blocks.0.attn_joint_mid.to_v.weight, up_blocks.1.attentions.2.transformer_blocks.0.attn_joint_mid.to_q.weight, up_blocks.2.attentions.0.transformer_blocks.0.attn_joint_mid.to_out.0.weight, down_blocks.2.attentions.0.transformer_blocks.0.norm_joint_mid.bias, up_blocks.3.attentions.0.transformer_blocks.0.norm_joint_mid.weight, up_blocks.3.attentions.2.transformer_blocks.0.attn_joint_mid.to_q.weight, down_blocks.1.attentions.0.transformer_blocks.0.attn_joint_mid.to_v.weight, up_blocks.1.attentions.1.transformer_blocks.0.norm_joint_mid.bias, up_blocks.2.attentions.1.transformer_blocks.0.attn_joint_mid.to_k.weight, up_blocks.2.attentions.2.transformer_blocks.0.attn_joint_mid.to_out.0.bias, up_blocks.3.attentions.1.transformer_blocks.0.norm_joint_mid.weight, down_blocks.1.attentions.1.transformer_blocks.0.norm_joint_mid.bias, up_blocks.2.attentions.1.transformer_blocks.0.attn_joint_mid.to_v.weight, up_blocks.1.attentions.2.transformer_blocks.0.norm_joint_mid.bias, up_blocks.2.attentions.0.transformer_blocks.0.attn_joint_mid.to_q.weight, up_blocks.2.attentions.0.transformer_blocks.0.attn_joint_mid.to_k.weight, up_blocks.3.attentions.1.transformer_blocks.0.attn_joint_mid.to_out.0.bias, up_blocks.3.attentions.1.transformer_blocks.0.attn_joint_mid.to_k.weight, up_blocks.3.attentions.2.transformer_blocks.0.attn_joint_mid.to_out.0.weight, up_blocks.1.attentions.1.transformer_blocks.0.attn_joint_mid.to_v.weight, up_blocks.1.attentions.0.transformer_blocks.0.attn_joint_mid.to_k.weight, down_blocks.1.attentions.0.transformer_blocks.0.attn_joint_mid.to_q.weight, down_blocks.0.attentions.1.transformer_blocks.0.norm_joint_mid.bias, down_blocks.0.attentions.0.transformer_blocks.0.attn_joint_mid.to_v.weight, mid_block.attentions.0.transformer_blocks.0.attn_joint_mid.to_q.weight, up_blocks.3.attentions.1.transformer_blocks.0.attn_joint_mid.to_out.0.weight, down_blocks.0.attentions.1.transformer_blocks.0.attn_joint_mid.to_q.weight, up_blocks.2.attentions.1.transformer_blocks.0.attn_joint_mid.to_out.0.bias, up_blocks.1.attentions.1.transformer_blocks.0.attn_joint_mid.to_out.0.weight, down_blocks.2.attentions.0.transformer_blocks.0.attn_joint_mid.to_out.0.bias, down_blocks.0.attentions.0.transformer_blocks.0.attn_joint_mid.to_k.weight, down_blocks.0.attentions.1.transformer_blocks.0.attn_joint_mid.to_out.0.bias, down_blocks.2.attentions.1.transformer_blocks.0.norm_joint_mid.weight, up_blocks.1.attentions.1.transformer_blocks.0.attn_joint_mid.to_out.0.bias, down_blocks.0.attentions.0.transformer_blocks.0.norm_joint_mid.weight, down_blocks.2.attentions.1.transformer_blocks.0.attn_joint_mid.to_out.0.bias, up_blocks.1.attentions.1.transformer_blocks.0.attn_joint_mid.to_q.weight, up_blocks.1.attentions.1.transformer_blocks.0.norm_joint_mid.weight, up_blocks.1.attentions.0.transformer_blocks.0.norm_joint_mid.weight, down_blocks.2.attentions.1.transformer_blocks.0.attn_joint_mid.to_out.0.weight, down_blocks.2.attentions.1.transformer_blocks.0.attn_joint_mid.to_k.weight, up_blocks.2.attentions.2.transformer_blocks.0.norm_joint_mid.weight, up_blocks.3.attentions.1.transformer_blocks.0.norm_joint_mid.bias, down_blocks.1.attentions.0.transformer_blocks.0.norm_joint_mid.bias, down_blocks.0.attentions.1.transformer_blocks.0.attn_joint_mid.to_v.weight, up_blocks.1.attentions.2.transformer_blocks.0.attn_joint_mid.to_k.weight, 
down_blocks.1.attentions.1.transformer_blocks.0.attn_joint_mid.to_v.weight, up_blocks.2.attentions.1.transformer_blocks.0.attn_joint_mid.to_q.weight, mid_block.attentions.0.transformer_blocks.0.attn_joint_mid.to_v.weight, up_blocks.2.attentions.1.transformer_blocks.0.norm_joint_mid.bias, up_blocks.2.attentions.2.transformer_blocks.0.norm_joint_mid.bias, up_blocks.3.attentions.2.transformer_blocks.0.attn_joint_mid.to_k.weight, down_blocks.0.attentions.1.transformer_blocks.0.attn_joint_mid.to_out.0.weight, down_blocks.1.attentions.0.transformer_blocks.0.norm_joint_mid.weight, up_blocks.3.attentions.2.transformer_blocks.0.norm_joint_mid.bias, up_blocks.2.attentions.1.transformer_blocks.0.attn_joint_mid.to_out.0.weight, up_blocks.3.attentions.1.transformer_blocks.0.attn_joint_mid.to_q.weight, down_blocks.1.attentions.1.transformer_blocks.0.attn_joint_mid.to_k.weight, down_blocks.2.attentions.1.transformer_blocks.0.norm_joint_mid.bias, up_blocks.3.attentions.2.transformer_blocks.0.attn_joint_mid.to_v.weight, down_blocks.2.attentions.0.transformer_blocks.0.attn_joint_mid.to_q.weight, up_blocks.3.attentions.0.transformer_blocks.0.norm_joint_mid.bias, down_blocks.1.attentions.1.transformer_blocks.0.attn_joint_mid.to_out.0.bias, up_blocks.3.attentions.2.transformer_blocks.0.attn_joint_mid.to_out.0.bias, up_blocks.1.attentions.0.transformer_blocks.0.attn_joint_mid.to_q.weight, up_blocks.1.attentions.0.transformer_blocks.0.attn_joint_mid.to_v.weight, up_blocks.2.attentions.0.transformer_blocks.0.attn_joint_mid.to_out.0.bias, mid_block.attentions.0.transformer_blocks.0.attn_joint_mid.to_out.0.weight, up_blocks.1.attentions.0.transformer_blocks.0.attn_joint_mid.to_out.0.weight, down_blocks.0.attentions.1.transformer_blocks.0.attn_joint_mid.to_k.weight, down_blocks.0.attentions.0.transformer_blocks.0.norm_joint_mid.bias, up_blocks.2.attentions.2.transformer_blocks.0.attn_joint_mid.to_out.0.weight, mid_block.attentions.0.transformer_blocks.0.norm_joint_mid.weight, up_blocks.1.attentions.2.transformer_blocks.0.norm_joint_mid.weight, up_blocks.1.attentions.2.transformer_blocks.0.attn_joint_mid.to_out.0.bias, up_blocks.2.attentions.1.transformer_blocks.0.norm_joint_mid.weight, up_blocks.3.attentions.0.transformer_blocks.0.attn_joint_mid.to_v.weight, up_blocks.1.attentions.0.transformer_blocks.0.attn_joint_mid.to_out.0.bias, down_blocks.2.attentions.0.transformer_blocks.0.attn_joint_mid.to_k.weight, down_blocks.2.attentions.1.transformer_blocks.0.attn_joint_mid.to_v.weight, down_blocks.1.attentions.1.transformer_blocks.0.norm_joint_mid.weight, mid_block.attentions.0.transformer_blocks.0.attn_joint_mid.to_k.weight, up_blocks.3.attentions.0.transformer_blocks.0.attn_joint_mid.to_k.weight, up_blocks.3.attentions.2.transformer_blocks.0.norm_joint_mid.weight, up_blocks.1.attentions.1.transformer_blocks.0.attn_joint_mid.to_k.weight, down_blocks.2.attentions.0.transformer_blocks.0.attn_joint_mid.to_v.weight, up_blocks.1.attentions.2.transformer_blocks.0.attn_joint_mid.to_out.0.weight, up_blocks.3.attentions.0.transformer_blocks.0.attn_joint_mid.to_q.weight, down_blocks.0.attentions.0.transformer_blocks.0.attn_joint_mid.to_q.weight, up_blocks.3.attentions.0.transformer_blocks.0.attn_joint_mid.to_out.0.bias, down_blocks.1.attentions.1.transformer_blocks.0.attn_joint_mid.to_q.weight, up_blocks.1.attentions.2.transformer_blocks.0.attn_joint_mid.to_v.weight. 
Please make sure to pass low_cpu_mem_usage=False and device_map=None if you want to randomly initialize those weights or else make sure your checkpoint file is correct.

According to my understanding, the missing weights are supposed to be trained in stage 2, so they cannot be loaded before stage 2 training starts. Do you have any ideas about this?

@SunzeY

SunzeY commented Jun 8, 2024

similar problem!

@bbbbubble

same problem. @RenieWell @SunzeY @xxlong0 Have you solved this?

@bbbbubble

bbbbubble commented Jun 27, 2024

If I add low_cpu_mem_usage=False to the from_pretrained() call, or use the from_pretrained_2d() function instead, training starts successfully, but it is hard to converge...
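For reference, the first variant looks roughly like this (a sketch only: the cfg fields are the ones train_mvdiffusion_joint.py already passes, and low_cpu_mem_usage / device_map are the standard diffusers from_pretrained arguments named in the error message):

unet = UNetMV2DConditionModel.from_pretrained(
    cfg.pretrained_unet_path,
    subfolder="unet",
    revision=cfg.revision,
    low_cpu_mem_usage=False,  # allow diffusers to randomly initialize the missing keys
    device_map=None,
    **cfg.unet_from_pretrained_kwargs,
)

With this, the attn_joint_mid / norm_joint_mid layers start from random initialization, so they still have to be learned from scratch during stage 2.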

[image attachment]

@yyuezhi

yyuezhi commented Aug 22, 2024

Hi, I am running into the same problem. Have you solved this? @bbbbubble @SunzeY @RenieWell

@mengxuyiGit

Maybe try replacing "from_pretrained" with "from_pretrained_2d" in

unet = UNetMV2DConditionModel.from_pretrained(cfg.pretrained_unet_path, subfolder="unet", revision=cfg.revision, **cfg.unet_from_pretrained_kwargs)
?
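That is, roughly (a sketch; it assumes from_pretrained_2d accepts the same arguments as the original call, which is what the suggestion implies):

unet = UNetMV2DConditionModel.from_pretrained_2d(cfg.pretrained_unet_path, subfolder="unet", revision=cfg.revision, **cfg.unet_from_pretrained_kwargs)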

@liuyifan22

@bbbbubble @mengxuyiGit @RenieWell @SunzeY

Hello, if the problem persists, maybe you can try my solution:
change the cd_attention_mid attribute in ./configs/train/stage1-mix-6views-lvis.yaml from false to true,

and change line 237 in ./train_mvdiffusion_image.py to

unet = UNetMV2DConditionModel.from_pretrained(cfg.pretrained_model_name_or_path, subfolder="unet", revision=cfg.revision, **cfg.unet_from_pretrained_kwargs).

It works for me. It seems the problem is not with the second stage but with the first. Please let me know if anything comes up.
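If you go this route, a quick sanity check (a sketch, not something from the repo) is to confirm that the stage 1 UNet actually contains the joint-mid layers before training, since those are exactly the parameter names stage 2 later expects to find in the checkpoint:

# hypothetical check: unet is the UNetMV2DConditionModel built with cd_attention_mid enabled
joint_mid_params = [name for name, _ in unet.named_parameters()
                    if "attn_joint_mid" in name or "norm_joint_mid" in name]
print(f"found {len(joint_mid_params)} joint-mid parameters")  # should be non-zero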

@bbbbubble

> @bbbbubble @mengxuyiGit @RenieWell @SunzeY
>
> Hello, if the problem persists, maybe you can try my solution: change the cd_attention_mid attribute in ./configs/train/stage1-mix-6views-lvis.yaml from false to true,
>
> and change line 237 in ./train_mvdiffusion_image.py to
>
> unet = UNetMV2DConditionModel.from_pretrained(cfg.pretrained_model_name_or_path, subfolder="unet", revision=cfg.revision, **cfg.unet_from_pretrained_kwargs).
>
> It works for me. It seems the problem is not with the second stage but with the first. Please let me know if anything comes up.

This doesn't seem to work. The "missing keys" error then shows up in the first stage:

[rank0]: ValueError: Cannot load <class 'mvdiffusion.models.unet_mv2d_condition.UNetMV2DConditionModel'> from /home/azhe.cp/avatar_utils/sd-image-variations-diffusers because the following keys are missing:
[rank0]: mid_block.attentions.0.transformer_blocks.0.attn_joint_mid.to_q.weight, mid_block.attentions.0.transformer_blocks.0.attn_joint_mid.to_v.weight, mid_block.attentions.0.transformer_blocks.0.attn_joint_mid.to_out.0.weight, mid_block.attentions.0.transformer_blocks.0.norm_joint_mid.bias, mid_block.attentions.0.transformer_blocks.0.attn_joint_mid.to_k.weight, mid_block.attentions.0.transformer_blocks.0.norm_joint_mid.weight, mid_block.attentions.0.transformer_blocks.0.attn_joint_mid.to_out.0.bias.
[rank0]: Please make sure to pass low_cpu_mem_usage=False and device_map=None if you want to randomly initialize those weights or else make sure your checkpoint file is correct.

@liuyifan22

Sorry, I was using Wonder3D's released checkpoint as the input for my stage 1 training, and that one fits.
If you want to train from lambdalabs/sd-image-variations-diffusers, my method will not work. Sorry for the confusion.
