PyTorch-Lightning Version Update #104

Open · wants to merge 6 commits into main
11 changes: 11 additions & 0 deletions .gitignore
@@ -0,0 +1,11 @@
# Do not track log outputs if they appear here.
logs/**

# Let's not track GBs of binary data.
*.ckpt

# Do not track Python compilation artifacts.
**/__pycache__/**

# Do not track local installation artifacts.
latent_diffusion.egg-info/**
6 changes: 3 additions & 3 deletions README.md
@@ -61,18 +61,18 @@ python main.py --base configs/latent-diffusion/txt2img-1p4B-finetune.yaml
-t
--actual_resume /path/to/pretrained/model.ckpt
-n <run_name>
--gpus 0,
--accelerator gpu
--data_root /path/to/directory/with/images
--init_word <initialization_word>
```

where the initialization word should be a single-token rough description of the object (e.g., 'toy', 'painting', 'sculpture'). If the input is comprised of more than a single token, you will be prompted to replace it.

Please note that `init_word` is *not* the placeholder string that will later represent the concept. It is only used as a beggining point for the optimization scheme.
Please note that `init_word` is *not* the placeholder string that will later represent the concept. It is only used as a beginning point for the optimization scheme.

In the paper, we use 5k training iterations. However, some concepts (particularly styles) can converge much faster.

To run on multiple GPUs, provide a comma-delimited list of GPU indices to the --gpus argument (e.g., ``--gpus 0,3,7,8``)
With the above arguments, training will run on all available GPUs. You can provide a comma-delimited list of GPU indices to the --devices argument (e.g., ``--devices 0,`` or ``--devices 0,3,7,8``) to select specific GPUs, or a single integer (e.g., ``--devices 2``) to specify how many GPUs to use and let PyTorch Lightning decide which ones to allocate to the task.
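For reference, these CLI flags map onto PyTorch Lightning 1.7 `Trainer` arguments roughly as follows. This is a minimal illustrative sketch, not code taken from this repository's `main.py`:

```python
import pytorch_lightning as pl

# Explicitly request every visible GPU (the "all available GPUs" behaviour described above):
trainer = pl.Trainer(accelerator="gpu", devices=-1)

# Pin training to specific GPU indices (CLI: --devices 0,3,7,8):
trainer = pl.Trainer(accelerator="gpu", devices=[0, 3, 7, 8])

# Ask for two GPUs and let Lightning choose which ones (CLI: --devices 2):
trainer = pl.Trainer(accelerator="gpu", devices=2)
```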

Embeddings and output images will be saved in the log directory.

2 changes: 1 addition & 1 deletion environment.yaml
@@ -15,7 +15,7 @@ dependencies:
- pudb==2019.2
- imageio==2.14.1
- imageio-ffmpeg==0.4.7
- pytorch-lightning==1.5.9
- pytorch-lightning==1.7.7
- omegaconf==2.1.1
- test-tube>=0.7.5
- streamlit>=0.73.1
8 changes: 4 additions & 4 deletions ldm/models/autoencoder.py
@@ -151,14 +151,14 @@ def training_step(self, batch, batch_idx, optimizer_idx):
last_layer=self.get_last_layer(), split="train",
predicted_indices=ind)

self.log_dict(log_dict_ae, prog_bar=False, logger=True, on_step=True, on_epoch=True)
self.log_dict(log_dict_ae, prog_bar=False, logger=True, on_step=True, on_epoch=True, sync_dist=True)
return aeloss

if optimizer_idx == 1:
# discriminator
discloss, log_dict_disc = self.loss(qloss, x, xrec, optimizer_idx, self.global_step,
last_layer=self.get_last_layer(), split="train")
self.log_dict(log_dict_disc, prog_bar=False, logger=True, on_step=True, on_epoch=True)
self.log_dict(log_dict_disc, prog_bar=False, logger=True, on_step=True, on_epoch=True, sync_dist=True)
return discloss

def validation_step(self, batch, batch_idx):
@@ -356,7 +356,7 @@ def training_step(self, batch, batch_idx, optimizer_idx):
# train encoder+decoder+logvar
aeloss, log_dict_ae = self.loss(inputs, reconstructions, posterior, optimizer_idx, self.global_step,
last_layer=self.get_last_layer(), split="train")
self.log("aeloss", aeloss, prog_bar=True, logger=True, on_step=True, on_epoch=True)
self.log("aeloss", aeloss, prog_bar=True, logger=True, on_step=True, on_epoch=True, sync_dist=True)
self.log_dict(log_dict_ae, prog_bar=False, logger=True, on_step=True, on_epoch=False)
return aeloss

@@ -365,7 +365,7 @@ def training_step(self, batch, batch_idx, optimizer_idx):
discloss, log_dict_disc = self.loss(inputs, reconstructions, posterior, optimizer_idx, self.global_step,
last_layer=self.get_last_layer(), split="train")

self.log("discloss", discloss, prog_bar=True, logger=True, on_step=True, on_epoch=True)
self.log("discloss", discloss, prog_bar=True, logger=True, on_step=True, on_epoch=True, sync_dist=True)
self.log_dict(log_dict_disc, prog_bar=False, logger=True, on_step=True, on_epoch=False)
return discloss
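The `sync_dist=True` flags added throughout this file make Lightning reduce the logged values across processes under DDP, so epoch-level metrics are not just rank 0's local numbers. As a rough illustration of the pattern (a standalone sketch, not code from this repository):

```python
import torch
import pytorch_lightning as pl


class LoggingSketch(pl.LightningModule):
    """Minimal module showing the logging pattern used in this diff."""

    def __init__(self):
        super().__init__()
        self.layer = torch.nn.Linear(4, 1)

    def training_step(self, batch, batch_idx):
        x, y = batch
        loss = torch.nn.functional.mse_loss(self.layer(x), y)
        # sync_dist=True averages the value over all DDP ranks before it is logged,
        # at the cost of an extra all-reduce per logged metric.
        self.log("train/loss", loss, prog_bar=True, logger=True,
                 on_step=True, on_epoch=True, sync_dist=True)
        return loss

    def configure_optimizers(self):
        return torch.optim.Adam(self.parameters(), lr=1e-3)
```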

4 changes: 2 additions & 2 deletions ldm/models/diffusion/classifier.py
@@ -170,9 +170,9 @@ def write_logs(self, loss, logits, targets):
logits, targets, k=5, reduction="mean"
)

self.log_dict(log, prog_bar=False, logger=True, on_step=self.training, on_epoch=True)
self.log_dict(log, prog_bar=False, logger=True, on_step=self.training, on_epoch=True, sync_dist=True)
self.log('loss', log[f"{log_prefix}/loss"], prog_bar=True, logger=False)
self.log('global_step', self.global_step, logger=False, on_epoch=False, prog_bar=True)
self.log('global_step', torch.tensor(self.global_step, dtype=torch.float32), logger=False, on_epoch=False, prog_bar=True)
lr = self.optimizers().param_groups[0]['lr']
self.log('lr_abs', lr, on_step=True, logger=True, on_epoch=False, prog_bar=True)
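Casting `self.global_step` (a Python int) to a float32 tensor before logging appears to be aimed at Lightning 1.7's stricter handling of logged values, since a plain integer cannot be averaged or synced like a float metric. The pattern in isolation, with illustrative values only:

```python
import torch

global_step = 1234  # stand-in for self.global_step, which is a Python int

# Wrap the counter in a float32 tensor so it behaves like any other logged
# metric (it can be moved to the right device and reduced if needed).
logged_step = torch.tensor(global_step, dtype=torch.float32)

# Inside a LightningModule this would then be logged as usual, e.g.:
# self.log("global_step", logged_step, prog_bar=True, on_epoch=False)
```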

31 changes: 16 additions & 15 deletions ldm/models/diffusion/ddpm.py
@@ -351,12 +351,12 @@ def shared_step(self, batch):
def training_step(self, batch, batch_idx):
loss, loss_dict = self.shared_step(batch)

self.log_dict(loss_dict, prog_bar=True,
logger=True, on_step=True, on_epoch=True)

self.log("global_step", self.global_step,
self.log("global_step", torch.tensor(self.global_step, dtype=torch.float32),
prog_bar=True, logger=True, on_step=True, on_epoch=False)

self.log_dict(loss_dict, prog_bar=True,
logger=True, on_step=True, on_epoch=True, sync_dist=True)

if self.use_scheduler:
lr = self.optimizers().param_groups[0]['lr']
self.log('lr_abs', lr, prog_bar=True, logger=True, on_step=True, on_epoch=False)
@@ -369,8 +369,8 @@ def validation_step(self, batch, batch_idx):
with self.ema_scope():
_, loss_dict_ema = self.shared_step(batch)
loss_dict_ema = {key + '_ema': loss_dict_ema[key] for key in loss_dict_ema}
self.log_dict(loss_dict_no_ema, prog_bar=False, logger=True, on_step=False, on_epoch=True)
self.log_dict(loss_dict_ema, prog_bar=False, logger=True, on_step=False, on_epoch=True)
self.log_dict(loss_dict_no_ema, prog_bar=False, logger=True, on_step=False, on_epoch=True, sync_dist=True)
self.log_dict(loss_dict_ema, prog_bar=False, logger=True, on_step=False, on_epoch=True, sync_dist=True)

def on_train_batch_end(self, *args, **kwargs):
if self.use_ema:
@@ -505,7 +505,7 @@ def make_cond_schedule(self, ):

@rank_zero_only
@torch.no_grad()
def on_train_batch_start(self, batch, batch_idx, dataloader_idx):
def on_train_batch_start(self, batch, batch_idx):
# only for very first batch
if self.scale_by_std and self.current_epoch == 0 and self.global_step == 0 and batch_idx == 0 and not self.restarted_from_ckpt:
assert self.scale_factor == 1., 'rather not use custom rescaling and std-rescaling simultaneously'
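The new hook signature matches PyTorch Lightning 1.7, where `on_train_batch_start` no longer receives a `dataloader_idx` argument, so the old three-argument form would fail once Lightning invokes the hook with only two. A minimal sketch of the updated hook, separate from this repository's class:

```python
import pytorch_lightning as pl


class HookSketch(pl.LightningModule):
    # PyTorch Lightning >= 1.7: the per-batch training hooks take (batch, batch_idx) only.
    def on_train_batch_start(self, batch, batch_idx):
        # Returning -1 skips the rest of the current epoch; returning None continues normally.
        return None
```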
@@ -921,8 +921,8 @@ def forward(self, x, c, *args, **kwargs):

def _rescale_annotations(self, bboxes, crop_coordinates): # TODO: move to dataset
def rescale_bbox(bbox):
x0 = clamp((bbox[0] - crop_coordinates[0]) / crop_coordinates[2])
y0 = clamp((bbox[1] - crop_coordinates[1]) / crop_coordinates[3])
x0 = torch.clamp((bbox[0] - crop_coordinates[0]) / crop_coordinates[2])
y0 = torch.clamp((bbox[1] - crop_coordinates[1]) / crop_coordinates[3])
w = min(bbox[2] / crop_coordinates[2], 1 - x0)
h = min(bbox[3] / crop_coordinates[3], 1 - y0)
return x0, y0, w, h
@@ -1068,12 +1068,13 @@ def p_losses(self, x_start, cond, t, noise=None):
loss_simple = self.get_loss(model_output, target, mean=False).mean([1, 2, 3])
loss_dict.update({f'{prefix}/loss_simple': loss_simple.mean()})

logvar_t = self.logvar[t].to(self.device)
logvar = self.logvar.to(self.device)
logvar_t = logvar[t]
loss = loss_simple / torch.exp(logvar_t) + logvar_t
# loss = loss_simple / torch.exp(self.logvar) + self.logvar
# loss = loss_simple / torch.exp(logvar) + logvar
if self.learn_logvar:
loss_dict.update({f'{prefix}/loss_gamma': loss.mean()})
loss_dict.update({'logvar': self.logvar.data.mean()})
loss_dict.update({'logvar': logvar.data.mean()})

loss = self.l_simple_weight * loss.mean()
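Indexing `self.logvar` only after moving it to the module's device avoids a device-mismatch error when the timestep tensor `t` lives on the GPU while the buffer was created on the CPU. The pattern in isolation, with made-up sizes:

```python
import torch

device = "cuda" if torch.cuda.is_available() else "cpu"

logvar = torch.zeros(1000)                        # per-timestep buffer, created on the CPU
t = torch.randint(0, 1000, (8,), device=device)   # sampled timesteps live on `device`

# Indexing a CPU tensor with a CUDA index tensor raises a device-mismatch error,
# so move the buffer to the same device first and then gather the per-sample entries.
logvar_t = logvar.to(device)[t]                   # shape: (8,)
```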

@@ -1136,7 +1137,7 @@ def p_sample(self, x, c, t, clip_denoised=False, repeat_noise=False,
score_corrector=score_corrector, corrector_kwargs=corrector_kwargs)
if return_codebook_ids:
raise DeprecationWarning("Support dropped.")
model_mean, _, model_log_variance, logits = outputs
# model_mean, _, model_log_variance, logits = outputs
elif return_x0:
model_mean, _, model_log_variance, x0 = outputs
else:
@@ -1148,8 +1149,8 @@ def p_sample(self, x, c, t, clip_denoised=False, repeat_noise=False,
# no noise when t == 0
nonzero_mask = (1 - (t == 0).float()).reshape(b, *((1,) * (len(x.shape) - 1)))

if return_codebook_ids:
return model_mean + nonzero_mask * (0.5 * model_log_variance).exp() * noise, logits.argmax(dim=1)
# if return_codebook_ids:
# return model_mean + nonzero_mask * (0.5 * model_log_variance).exp() * noise, logits.argmax(dim=1)
if return_x0:
return model_mean + nonzero_mask * (0.5 * model_log_variance).exp() * noise, x0
else:
5 changes: 2 additions & 3 deletions ldm/modules/embedding_manager.py
@@ -1,5 +1,5 @@
import torch
from torch import nn
from torch import Tensor, nn

from ldm.data.personalized import per_img_token_list
from transformers import CLIPTokenizer
@@ -148,8 +148,7 @@ def embedding_parameters(self):
return self.string_to_param_dict.parameters()

def embedding_to_coarse_loss(self):

loss = 0.
loss = torch.zeros(1, requires_grad=True)
num_embeddings = len(self.initial_embeddings)

for key in self.initial_embeddings:
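Seeding the accumulator with `torch.zeros(1, requires_grad=True)` instead of the Python float `0.` guarantees the function always returns a tensor that participates in autograd, even if the loop contributes nothing. A self-contained illustration with hypothetical embeddings (not this repository's data):

```python
import torch

embeddings = {"*": torch.randn(1, 768, requires_grad=True)}
targets = {"*": torch.randn(1, 768)}

# Tensor accumulator: the result is a tensor even if `embeddings` were empty,
# and gradients flow back into each embedding through the added terms.
loss = torch.zeros(1, requires_grad=True)
for key, emb in embeddings.items():
    loss = loss + (emb - targets[key]).pow(2).mean()

loss.backward()  # single-element tensor, so backward() needs no explicit gradient
print(embeddings["*"].grad.shape)  # torch.Size([1, 768])
```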