'DiffusionTransformer' object has no attribute 'cf_predict_start' #16

Open
songxueXS opened this issue Jun 22, 2022 · 3 comments

Comments

@songxueXS

songxueXS commented Jun 22, 2022

Hi author, thanks for sharing your inspiring work!
I'm trying to re-implement Improved VQ-Diffusion and ran into the following error:

Traceback (most recent call last):
  File "/env/anaconda3-pytorch1.7-python3.7-cuda11/lib/python3.7/site-packages/torch/multiprocessing/spawn.py", line 19, in _wrap
    fn(i, *args)
  File "/code/VQ-Diffusion-Improved_VQ-Diffusion/image_synthesis/distributed/launch.py", line 93, in distributed_worker
    fn(local_rank, *args)
  File "/code/VQ-Diffusion-Improved_VQ-Diffusion/train.py", line 168, in main_worker
    solver.train()
  File "/code/VQ-Diffusion-Improved_VQ-Diffusion/image_synthesis/engine/solver.py", line 546, in train
    self.train_epoch()
  File "/code/VQ-Diffusion-Improved_VQ-Diffusion/image_synthesis/engine/solver.py", line 463, in train_epoch
    self.sample(batch, phase='train', step_type='iteration')
  File "/code/VQ-Diffusion-Improved_VQ-Diffusion/image_synthesis/engine/solver.py", line 198, in sample
    samples = model.sample(batch=batch, step=self.last_iter)
  File "/env/anaconda3-pytorch1.7-python3.7-cuda11/lib/python3.7/site-packages/torch/autograd/grad_mode.py", line 26, in decorate_context
    return func(*args, **kwargs)
  File "/code/VQ-Diffusion-Improved_VQ-Diffusion/image_synthesis/modeling/models/dalle.py", line 303, in sample
    **kwargs)
  File "/code/VQ-Diffusion-Improved_VQ-Diffusion/image_synthesis/modeling/transformers/diffusion_transformer.py", line 592, in sample
    log_z, sampled = self.p_sample(log_z, cond_emb, t, sampled, self.n_sample[diffusion_index]) # log_z is log_onehot
  File "/env/anaconda3-pytorch1.7-python3.7-cuda11/lib/python3.7/site-packages/torch/autograd/grad_mode.py", line 26, in decorate_context
    return func(*args, **kwargs)
  File "/code/VQ-Diffusion-Improved_VQ-Diffusion/image_synthesis/modeling/transformers/diffusion_transformer.py", line 278, in p_sample
    model_log_prob, log_x_recon = self.p_pred(log_x, cond_emb, t)
  File "/code/VQ-Diffusion-Improved_VQ-Diffusion/image_synthesis/modeling/transformers/diffusion_transformer.py", line 267, in p_pred
    log_x_recon = self.cf_predict_start(log_x, cond_emb, t)
  File "/env/anaconda3-pytorch1.7-python3.7-cuda11/lib/python3.7/site-packages/torch/nn/modules/module.py", line 779, in __getattr__
    type(self).__name__, name))
torch.nn.modules.module.ModuleAttributeError: 'DiffusionTransformer' object has no attribute 'cf_predict_start'

@koutilya-pnvr

log_x_recon = self.cf_predict_start(log_x, cond_emb, t)

More precisely, this is the code line from the Improved_VQ-Diffusion branch. The cf_predict_start function is not defined in the DiffusionTransformer class. Is it the same as the one from the dalle class?

def cf_predict_start(log_x_t, cond_emb, t):

@songxueXS
Author

Thank you very much for your reply, but they are not exactly the same: 'def cf_predict_start(log_x_t, cond_emb, t)' in dalle.py depends on many other parameters. Could you provide a complete version?
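For context on why the dalle.py version needs more than log_x_t, cond_emb and t: classifier-free guidance generally combines two predictions, one with the condition and one without it, using a guidance scale. The snippet below is a generic, pure-Python sketch of that combination for a single token's log-probabilities. It is not the repository's cf_predict_start, and the names `guided_log_probs`, `cond`, `uncond` and `scale` are hypothetical.

```python
import math
from typing import List


def log_softmax(scores: List[float]) -> List[float]:
    """Turn unnormalized log-scores into log-probabilities (log-sum-exp trick)."""
    m = max(scores)
    lse = m + math.log(sum(math.exp(x - m) for x in scores))
    return [x - lse for x in scores]


def guided_log_probs(cond: List[float], uncond: List[float], scale: float) -> List[float]:
    """Generic classifier-free guidance over one token's log-probabilities (sketch).

    cond:   log p(x_0 | x_t, y), prediction with the condition (hypothetical input)
    uncond: log p(x_0 | x_t),    prediction with an empty condition (hypothetical input)
    scale:  guidance scale; 1.0 returns the conditional prediction unchanged
    """
    # Move from the unconditional prediction toward (and past) the conditional one.
    mixed = [u + scale * (c - u) for c, u in zip(cond, uncond)]
    # The mix is no longer a normalized distribution, so renormalize in log space.
    return log_softmax(mixed)


if __name__ == "__main__":
    cond = log_softmax([2.0, 1.0, 0.0])
    uncond = log_softmax([1.0, 1.0, 1.0])
    for s in (1.0, 3.0):
        probs = [math.exp(v) for v in guided_log_probs(cond, uncond, s)]
        print(s, [round(p, 3) for p in probs])
```

A larger scale sharpens the distribution toward tokens the condition favors, which is why a guided predictor needs the extra unconditional prediction and scale that a plain predict_start call does not.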

@tzco
Collaborator

tzco commented Jun 30, 2022

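Sorry for the late reply. The cf_predict_start function is defined inside generate_content in dalle.py for classifier-free sampling. Training does not need classifier-free sampling, so that code is never executed, and DiffusionTransformer itself has no cf_predict_start; that is why the sampling step run during training fails. Adding a placeholder method to DiffusionTransformer in diffusion_transformer.py should solve it: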

def cf_predict_start(self, log_x_t, cond_emb, t):
    # Training-time placeholder: plain conditional prediction, no classifier-free guidance
    return self.predict_start(log_x_t, cond_emb, t)
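A minimal pure-Python sketch of why such a placeholder can coexist with the classifier-free version, assuming (the thread does not show this) that generate_content attaches its nested cf_predict_start to the transformer instance before sampling. ToyDiffusionTransformer and attach_guided_predictor are hypothetical stand-ins, not the repository's classes, and no real model is involved.

```python
class ToyDiffusionTransformer:
    """Hypothetical stand-in for DiffusionTransformer (no PyTorch, no real model)."""

    def predict_start(self, log_x_t, cond_emb, t):
        # Stand-in for the network's conditional prediction of x_0.
        return ("cond", log_x_t, cond_emb, t)

    def cf_predict_start(self, log_x_t, cond_emb, t):
        # Placeholder from the fix: plain conditional prediction, no guidance.
        return self.predict_start(log_x_t, cond_emb, t)

    def p_pred(self, log_x, cond_emb, t):
        # The sampling path always calls cf_predict_start, even during training.
        return self.cf_predict_start(log_x, cond_emb, t)


def attach_guided_predictor(transformer, scale):
    """Hypothetical analogue of generate_content preparing classifier-free sampling."""

    def cf_predict_start(log_x_t, cond_emb, t):
        # A plain function stored on an instance is not bound, so it takes no `self`.
        return ("guided", scale, transformer.predict_start(log_x_t, cond_emb, t))

    # An instance attribute shadows the method of the same name defined on the class.
    transformer.cf_predict_start = cf_predict_start


model = ToyDiffusionTransformer()
print(model.p_pred("x", "emb", 5)[0])  # prints: cond (placeholder, training-time sampling)
attach_guided_predictor(model, 3.0)
print(model.p_pred("x", "emb", 5)[0])  # prints: guided (attached version wins)
```

Because Python finds an attribute stored on the instance before a method defined on the class, the placeholder only matters on code paths where nothing has been attached, such as the sampling step run during training. Under the same assumption, this would also explain why the nested function in dalle.py is declared without a self parameter while the placeholder takes one.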
