-
Notifications
You must be signed in to change notification settings - Fork 395
/
train_mvdiffusion_mixed.py
859 lines (736 loc) · 38.1 KB
/
train_mvdiffusion_mixed.py
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
161
162
163
164
165
166
167
168
169
170
171
172
173
174
175
176
177
178
179
180
181
182
183
184
185
186
187
188
189
190
191
192
193
194
195
196
197
198
199
200
201
202
203
204
205
206
207
208
209
210
211
212
213
214
215
216
217
218
219
220
221
222
223
224
225
226
227
228
229
230
231
232
233
234
235
236
237
238
239
240
241
242
243
244
245
246
247
248
249
250
251
252
253
254
255
256
257
258
259
260
261
262
263
264
265
266
267
268
269
270
271
272
273
274
275
276
277
278
279
280
281
282
283
284
285
286
287
288
289
290
291
292
293
294
295
296
297
298
299
300
301
302
303
304
305
306
307
308
309
310
311
312
313
314
315
316
317
318
319
320
321
322
323
324
325
326
327
328
329
330
331
332
333
334
335
336
337
338
339
340
341
342
343
344
345
346
347
348
349
350
351
352
353
354
355
356
357
358
359
360
361
362
363
364
365
366
367
368
369
370
371
372
373
374
375
376
377
378
379
380
381
382
383
384
385
386
387
388
389
390
391
392
393
394
395
396
397
398
399
400
401
402
403
404
405
406
407
408
409
410
411
412
413
414
415
416
417
418
419
420
421
422
423
424
425
426
427
428
429
430
431
432
433
434
435
436
437
438
439
440
441
442
443
444
445
446
447
448
449
450
451
452
453
454
455
456
457
458
459
460
461
462
463
464
465
466
467
468
469
470
471
472
473
474
475
476
477
478
479
480
481
482
483
484
485
486
487
488
489
490
491
492
493
494
495
496
497
498
499
500
501
502
503
504
505
506
507
508
509
510
511
512
513
514
515
516
517
518
519
520
521
522
523
524
525
526
527
528
529
530
531
532
533
534
535
536
537
538
539
540
541
542
543
544
545
546
547
548
549
550
551
552
553
554
555
556
557
558
559
560
561
562
563
564
565
566
567
568
569
570
571
572
573
574
575
576
577
578
579
580
581
582
583
584
585
586
587
588
589
590
591
592
593
594
595
596
597
598
599
600
601
602
603
604
605
606
607
608
609
610
611
612
613
614
615
616
617
618
619
620
621
622
623
624
625
626
627
628
629
630
631
632
633
634
635
636
637
638
639
640
641
642
643
644
645
646
647
648
649
650
651
652
653
654
655
656
657
658
659
660
661
662
663
664
665
666
667
668
669
670
671
672
673
674
675
676
677
678
679
680
681
682
683
684
685
686
687
688
689
690
691
692
693
694
695
696
697
698
699
700
701
702
703
704
705
706
707
708
709
710
711
712
713
714
715
716
717
718
719
720
721
722
723
724
725
726
727
728
729
730
731
732
733
734
735
736
737
738
739
740
741
742
743
744
745
746
747
748
749
750
751
752
753
754
755
756
757
758
759
760
761
762
763
764
765
766
767
768
769
770
771
772
773
774
775
776
777
778
779
780
781
782
783
784
785
786
787
788
789
790
791
792
793
794
795
796
797
798
799
800
801
802
803
804
805
806
807
808
809
810
811
812
813
814
815
816
817
818
819
820
821
822
823
824
825
826
827
828
829
830
831
832
833
834
835
836
837
838
839
840
841
842
843
844
845
846
847
848
849
850
851
852
853
854
855
856
857
858
859
from comet_ml import Experiment
from comet_ml.integration.pytorch import log_model
import argparse
import datetime
import logging
import inspect
import math
import os
from typing import Dict, Optional, Tuple, List
from omegaconf import OmegaConf
from PIL import Image
import cv2
import numpy as np
from dataclasses import dataclass
from packaging import version
import shutil
from collections import defaultdict
import torch
import torch.nn.functional as F
import torch.utils.checkpoint
import torchvision.transforms.functional as TF
from torchvision.transforms import InterpolationMode
from torchvision.utils import make_grid, save_image
import transformers
import accelerate
from accelerate import Accelerator
from accelerate.logging import get_logger
from accelerate.utils import ProjectConfiguration, set_seed
import diffusers
from diffusers import AutoencoderKL, DDPMScheduler, DDIMScheduler, StableDiffusionPipeline
from diffusers.optimization import get_scheduler
from diffusers.training_utils import EMAModel
from diffusers.utils import check_min_version, deprecate, is_wandb_available
from diffusers.utils.import_utils import is_xformers_available
from tqdm.auto import tqdm
from transformers import CLIPTextModel, CLIPTokenizer
from transformers import CLIPImageProcessor, CLIPVisionModelWithProjection
from mv_diffusion_30.models.unet_mv2d_condition import UNetMV2DConditionModel
from mv_diffusion_30.data.objaverse_dataset import ObjaverseDataset as MVDiffusionDataset
from mv_diffusion_30.pipelines.pipeline_mvdiffusion_image import MVDiffusionImagePipeline
from einops import rearrange
import time
# Module-level logger created via accelerate's get_logger; its calls accept the
# ``main_process_only`` keyword (used in main() to log the accelerator state on every process).
logger = get_logger(__name__, log_level="INFO")
@dataclass
class TrainingConfig:
    """Structured schema for the YAML training configuration.

    The ``__main__`` entry point merges the YAML file onto ``OmegaConf.structured(TrainingConfig)``
    and passes the result to ``main``. No field declares a default here, so the config file
    presumably has to supply every one of them.
    """
    # Base model repo/dir; scheduler, image_encoder, feature_extractor, vae and (by default) unet
    # are loaded from its subfolders.
    pretrained_model_name_or_path: str
    # If set, the UNet is loaded from this path with from_pretrained instead of from_pretrained_2d.
    pretrained_unet_path: Optional[str]
    # Revision passed to every from_pretrained call that accepts one.
    revision: Optional[str]
    # Keyword arguments for MVDiffusionDataset. main() reads/writes the pred_ortho / pred_persp keys.
    train_dataset: Dict
    validation_dataset: Dict
    validation_train_dataset: Dict
    # Root directory for config.yaml, checkpoints, unet-<step> snapshots, pipeckpts, vis and logs.
    output_dir: str
    # Seeds accelerate's set_seed and the torch.Generator objects; None leaves RNGs unseeded.
    seed: Optional[int]
    train_batch_size: int
    validation_batch_size: int
    validation_train_batch_size: int
    # Number of optimizer steps after which training stops.
    max_train_steps: int
    gradient_accumulation_steps: int
    # Calls unet.enable_gradient_checkpointing() when true.
    gradient_checkpointing: bool
    learning_rate: float
    # When true, learning_rate is multiplied by accumulation steps * batch size * num processes.
    scale_lr: bool
    # Scheduler name passed to diffusers.optimization.get_scheduler.
    lr_scheduler: str
    lr_warmup_steps: int
    # Min-SNR loss weighting gamma; None uses plain mean MSE.
    snr_gamma: Optional[float]
    # Use bitsandbytes AdamW8bit instead of torch.optim.AdamW.
    use_8bit_adam: bool
    # Sets torch.backends.cuda.matmul.allow_tf32.
    allow_tf32: bool
    # Maintain an EMAModel copy of the UNet (used for validation and the final pipeline).
    use_ema: bool
    dataloader_num_workers: int
    adam_beta1: float
    adam_beta2: float
    adam_weight_decay: float
    adam_epsilon: float
    # Gradient-clipping norm; None disables clipping.
    max_grad_norm: Optional[float]
    # NOTE(review): not read in this file; the loss uses noise_scheduler.config.prediction_type.
    prediction_type: Optional[str]
    # Sub-directories of output_dir for accelerate logging and validation image grids.
    logging_dir: str
    vis_dir: str
    # Passed to Accelerator; overwritten with the accelerator's value when it is fp16/bf16.
    mixed_precision: Optional[str]
    # Tracker backend passed to Accelerator(log_with=...).
    report_to: Optional[str]
    # Overridden by the LOCAL_RANK environment variable when that is set.
    local_rank: int
    # Save accelerator state and a unet-<step> snapshot every this many optimizer steps.
    checkpointing_steps: int
    # NOTE(review): not read in this file.
    checkpoints_total_limit: Optional[int]
    # Checkpoint directory name/path, or "latest".
    resume_from_checkpoint: Optional[str]
    # NOTE(review): not read in this file (its use in log_validation is commented out).
    enable_xformers_memory_efficient_attention: bool
    # Run validation every this many optimizer steps.
    validation_steps: int
    # Also run validation at global_step == 1.
    validation_sanity_check: bool
    # Project name passed to accelerator.init_trackers.
    tracker_project_name: str
    # Module-name suffixes whose parameters are trainable; None trains the whole UNet.
    trainable_modules: Optional[list]
    # Enables conditioning dropout (at condition_drop_rate) during training.
    use_classifier_free_guidance: bool
    condition_drop_rate: float
    # Multiply the conditioning-image latents by vae.config.scaling_factor.
    scale_input_latents: bool
    # Extra kwargs for the MVDiffusionImagePipeline constructor / its __call__ during validation.
    pipe_kwargs: Dict
    pipe_validation_kwargs: Dict
    # Extra kwargs for UNetMV2DConditionModel.from_pretrained(_2d).
    unet_from_pretrained_kwargs: Dict
    # One validation sample grid is produced per guidance scale.
    validation_guidance_scales: List[float]
    # nrow for torchvision make_grid when saving validation grids.
    validation_grid_nrow: int
    # LR multiplier for UNet parameters whose name contains 'class_embedding'.
    camera_embedding_lr_mult: float
    # Group size for sharing a diffusion timestep across views during training.
    num_views: int
    # Only 'e_de_da_sincos' (sin/cos of the embedding) is handled by the training loop.
    camera_embedding_type: str
    # Selects the target images ('normal' -> normals_out) and whether task embeddings are appended.
    pred_type: str
    # Conditioning-dropout strategy: 'drop_as_a_whole' or 'drop_independent'.
    drop_type: str
def log_validation(dataloader, vae, feature_extractor, image_encoder, unet, cfg: TrainingConfig, accelerator,
                   weight_dtype, global_step, name, save_dir):
    """Sample the multi-view pipeline on ``dataloader`` and save image grids to ``save_dir``.

    For every batch the conditioning images and ground-truth targets are collected, and the
    pipeline is sampled once per entry of ``cfg.validation_guidance_scales``. The following
    JPEG grids are written:

    * ``{global_step}-{name}-gt.jpg``: ground-truth target views
    * ``{global_step}-{name}-cond.jpg``: conditioning (input) views
    * ``{global_step}-{name}-sample_cfg{scale:.1f}.jpg``: predictions, one per guidance scale

    Args:
        dataloader: iterable of batch dicts holding ``imgs_in``, ``camera_embeddings`` and
            ``imgs_out`` (or ``normals_out`` when ``cfg.pred_type == 'normal'``); the mixed
            pred types also need ``task_embeddings``.
        vae, feature_extractor, image_encoder: frozen pipeline components.
        unet: the UNet, possibly wrapped by ``accelerator`` (it is unwrapped here).
        cfg: training configuration.
        accelerator: the ``Accelerator`` driving training (device and unwrapping).
        weight_dtype: not used here; kept for call-site compatibility.
        global_step: current optimizer step, used as the filename prefix.
        name: tag embedded in the output filenames.
        save_dir: existing directory that receives the grids.
    """
    logger.info(f"Running {name} ... ")
    pipeline = MVDiffusionImagePipeline(
        image_encoder=image_encoder, feature_extractor=feature_extractor, vae=vae,
        unet=accelerator.unwrap_model(unet), safety_checker=None,
        scheduler=DDIMScheduler.from_pretrained(cfg.pretrained_model_name_or_path, subfolder="scheduler"),
        **cfg.pipe_kwargs
    )
    pipeline.set_progress_bar_config(disable=True)

    # A fresh seeded generator per call keeps samples comparable across validation steps.
    generator = None if cfg.seed is None else torch.Generator(device=accelerator.device).manual_seed(cfg.seed)

    # pred types whose batches carry task embeddings that get appended to the camera embeddings.
    # The misspelled 'mixed_rgb_noraml_mask' is kept so existing configs keep working.
    task_pred_types = ('mixed_rgb_normal_depth', 'color', 'mixed_color_normal',
                       'mixed_rgb_noraml_mask', 'mixed_rgb_normal_mask')
    target_key = 'normals_out' if cfg.pred_type == 'normal' else 'imgs_out'

    images_cond, images_gt, images_pred = [], [], defaultdict(list)
    for batch in dataloader:
        # (B, Nv, 3, H, W)
        imgs_in, imgs_out = batch['imgs_in'], batch[target_key]
        # (B, Nv, Nce)
        camera_embeddings = batch['camera_embeddings']
        if cfg.pred_type in task_pred_types:
            camera_embeddings = torch.cat([camera_embeddings, batch['task_embeddings']], dim=-1)
        # (B*Nv, 3, H, W)
        imgs_in = rearrange(imgs_in, "B Nv C H W -> (B Nv) C H W")
        imgs_out = rearrange(imgs_out, "B Nv C H W -> (B Nv) C H W")
        # (B*Nv, Nce)
        camera_embeddings = rearrange(camera_embeddings, "B Nv Nce -> (B Nv) Nce")

        images_cond.append(imgs_in.cpu())
        images_gt.append(imgs_out.cpu())
        with torch.autocast("cuda"):
            # B*Nv images
            for guidance_scale in cfg.validation_guidance_scales:
                out = pipeline(
                    imgs_in, camera_embeddings, generator=generator, guidance_scale=guidance_scale,
                    output_type='pt', num_images_per_prompt=1, **cfg.pipe_validation_kwargs
                ).images
                # Park predictions on the CPU (as float32) so a long validation set does not
                # accumulate GPU memory; save_image works on CPU tensors anyway.
                images_pred[f"{name}-sample_cfg{guidance_scale:.1f}"].append(out.float().cpu())

    if not images_gt:
        logger.warning(f"{name}: validation dataloader produced no batches; nothing saved.")
        torch.cuda.empty_cache()
        return

    nrow = cfg.validation_grid_nrow
    images_gt_grid = make_grid(torch.cat(images_gt, dim=0), nrow=nrow, padding=0, value_range=(0, 1))
    save_image(images_gt_grid, os.path.join(save_dir, f"{global_step}-{name}-gt.jpg"))

    images_cond_grid = make_grid(torch.cat(images_cond, dim=0), nrow=nrow, padding=0, value_range=(0, 1))
    save_image(images_cond_grid, os.path.join(save_dir, f"{global_step}-{name}-cond.jpg"))

    for key, preds in images_pred.items():
        pred_grid = make_grid(torch.cat(preds, dim=0), nrow=nrow, padding=0, value_range=(0, 1))
        save_image(pred_grid, os.path.join(save_dir, f"{global_step}-{key}.jpg"))
    torch.cuda.empty_cache()
def _checkpoint_step(dirname):
    """Return ``N`` for a directory named ``checkpoint-N``; ``None`` for any other name."""
    prefix, sep, suffix = dirname.partition("-")
    if prefix == "checkpoint" and sep and suffix.isdigit():
        return int(suffix)
    return None


def _resolve_resume_checkpoint(cfg):
    """Return the directory (relative to ``cfg.output_dir``) to resume from, or ``None``.

    ``"latest"`` prefers the rolling ``checkpoint`` directory this script writes, then the
    highest-numbered ``checkpoint-<N>`` directory. Any other value is reduced to its basename
    and looked up inside ``cfg.output_dir``. ``None`` is returned when the directory is missing.
    """
    if cfg.resume_from_checkpoint != "latest":
        # normpath drops a trailing separator, which would otherwise make basename() return ''.
        path = os.path.basename(os.path.normpath(cfg.resume_from_checkpoint))
    elif os.path.exists(os.path.join(cfg.output_dir, "checkpoint")):
        path = "checkpoint"
    else:
        entries = os.listdir(cfg.output_dir) if os.path.isdir(cfg.output_dir) else []
        numbered = [d for d in entries if _checkpoint_step(d) is not None]
        path = max(numbered, key=_checkpoint_step, default=None)
    if path is None or not os.path.isdir(os.path.join(cfg.output_dir, path)):
        return None
    return path


def _build_validation_runs(cfg):
    """Build the validation dataloaders and return them as ``(dataloader, name)`` pairs.

    When the training set predicts both orthographic and perspective views (``pred_ortho``
    and ``pred_persp`` truthy in ``cfg.train_dataset``), each validation split is built twice:
    once ortho-only and once persp-only. This mutates ``cfg.validation_dataset`` and
    ``cfg.validation_train_dataset``, leaving them in the persp-only setting.
    """
    def make_loader(dataset_kwargs, batch_size):
        # The dataset is constructed immediately, so it captures the current flag values.
        return torch.utils.data.DataLoader(
            MVDiffusionDataset(**dataset_kwargs), batch_size=batch_size, shuffle=False,
            num_workers=cfg.dataloader_num_workers
        )

    cam_condition = (cfg.train_dataset.pred_ortho and cfg.train_dataset.pred_persp)
    if not cam_condition:
        return [
            (make_loader(cfg.validation_dataset, cfg.validation_batch_size), 'validation'),
            (make_loader(cfg.validation_train_dataset, cfg.validation_train_batch_size), 'validation_train'),
        ]

    cfg.validation_dataset.pred_ortho, cfg.validation_dataset.pred_persp = True, False
    cfg.validation_train_dataset.pred_ortho, cfg.validation_train_dataset.pred_persp = True, False
    validation_dataloader_ortho = make_loader(cfg.validation_dataset, cfg.validation_batch_size)
    validation_train_dataloader_ortho = make_loader(cfg.validation_train_dataset, cfg.validation_train_batch_size)

    cfg.validation_dataset.pred_ortho, cfg.validation_dataset.pred_persp = False, True
    cfg.validation_train_dataset.pred_ortho, cfg.validation_train_dataset.pred_persp = False, True
    validation_dataloader_persp = make_loader(cfg.validation_dataset, cfg.validation_batch_size)
    validation_train_dataloader_persp = make_loader(cfg.validation_train_dataset, cfg.validation_train_batch_size)

    return [
        (validation_dataloader_ortho, 'validation_ortho'),
        (validation_dataloader_persp, 'validation_persp'),
        (validation_train_dataloader_ortho, 'validation_train_ortho'),
        (validation_train_dataloader_persp, 'validation_train_persp'),
    ]


def main(
    cfg: TrainingConfig
):
    """Train the multi-view UNet described by ``cfg``.

    The run proceeds in these stages:

    1. Set up the ``Accelerator`` and logging.
    2. Load the frozen VAE and CLIP image encoder and the trainable ``UNetMV2DConditionModel``.
    3. Build the optimizer, LR schedule and dataloaders.
    4. Optionally restore a saved accelerator state.
    5. Run the denoising training loop with periodic checkpointing and validation.
    6. Save an ``MVDiffusionImagePipeline`` to ``<output_dir>/pipeckpts`` and log the UNet to Comet.
    """
    # override local_rank with envvar
    env_local_rank = int(os.environ.get("LOCAL_RANK", -1))
    if env_local_rank != -1 and env_local_rank != cfg.local_rank:
        cfg.local_rank = env_local_rank

    vis_dir = os.path.join(cfg.output_dir, cfg.vis_dir)
    logging_dir = os.path.join(cfg.output_dir, cfg.logging_dir)
    accelerator_project_config = ProjectConfiguration(project_dir=cfg.output_dir, logging_dir=logging_dir)
    accelerator = Accelerator(
        gradient_accumulation_steps=cfg.gradient_accumulation_steps,
        mixed_precision=cfg.mixed_precision,
        log_with=cfg.report_to,
        project_config=accelerator_project_config,
    )

    # Make one log on every process with the configuration for debugging.
    logging.basicConfig(
        format="%(asctime)s - %(levelname)s - %(name)s - %(message)s",
        datefmt="%m/%d/%Y %H:%M:%S",
        level=logging.INFO,
    )
    logger.info(accelerator.state, main_process_only=False)
    if accelerator.is_local_main_process:
        transformers.utils.logging.set_verbosity_warning()
        diffusers.utils.logging.set_verbosity_info()
    else:
        transformers.utils.logging.set_verbosity_error()
        diffusers.utils.logging.set_verbosity_error()

    # The training loop only knows how to encode this camera-embedding type; fail before
    # loading any weights rather than on the first batch.
    if cfg.camera_embedding_type != 'e_de_da_sincos':
        raise NotImplementedError(f"Unsupported camera_embedding_type: {cfg.camera_embedding_type}")

    # Generator for the conditioning-dropout masks. It must exist even without a seed, because
    # the dropout code always passes ``generator=`` (None means "use the global RNG").
    generator = None
    if cfg.seed is not None:
        set_seed(cfg.seed)
        generator = torch.Generator(device=accelerator.device).manual_seed(cfg.seed)

    # Handle the repository creation
    if accelerator.is_main_process:
        os.makedirs(cfg.output_dir, exist_ok=True)
        os.makedirs(vis_dir, exist_ok=True)
        OmegaConf.save(cfg, os.path.join(cfg.output_dir, 'config.yaml'))

    # Load scheduler, image encoder and models.
    noise_scheduler = DDPMScheduler.from_pretrained(cfg.pretrained_model_name_or_path, subfolder="scheduler")
    image_encoder = CLIPVisionModelWithProjection.from_pretrained(
        cfg.pretrained_model_name_or_path, subfolder="image_encoder", revision=cfg.revision)
    feature_extractor = CLIPImageProcessor.from_pretrained(
        cfg.pretrained_model_name_or_path, subfolder="feature_extractor", revision=cfg.revision)
    vae = AutoencoderKL.from_pretrained(cfg.pretrained_model_name_or_path, subfolder="vae", revision=cfg.revision)
    if cfg.pretrained_unet_path is None:
        unet = UNetMV2DConditionModel.from_pretrained_2d(
            cfg.pretrained_model_name_or_path, subfolder="unet", revision=cfg.revision,
            **cfg.unet_from_pretrained_kwargs)
    else:
        logger.info(f"load pre-trained unet from {cfg.pretrained_unet_path}")
        unet = UNetMV2DConditionModel.from_pretrained(
            cfg.pretrained_unet_path, revision=cfg.revision, **cfg.unet_from_pretrained_kwargs)

    if cfg.use_ema:
        ema_unet = EMAModel(unet.parameters(), model_cls=UNetMV2DConditionModel, model_config=unet.config)

    def compute_snr(timesteps):
        """
        Computes SNR as per https://github.com/TiankaiHang/Min-SNR-Diffusion-Training/blob/521b624bd70c67cee4bdf49225915f5945a872e3/guided_diffusion/gaussian_diffusion.py#L847-L849
        """
        alphas_cumprod = noise_scheduler.alphas_cumprod
        sqrt_alphas_cumprod = alphas_cumprod ** 0.5
        sqrt_one_minus_alphas_cumprod = (1.0 - alphas_cumprod) ** 0.5

        # Expand the tensors.
        # Adapted from https://github.com/TiankaiHang/Min-SNR-Diffusion-Training/blob/521b624bd70c67cee4bdf49225915f5945a872e3/guided_diffusion/gaussian_diffusion.py#L1026
        sqrt_alphas_cumprod = sqrt_alphas_cumprod.to(device=timesteps.device)[timesteps].float()
        while len(sqrt_alphas_cumprod.shape) < len(timesteps.shape):
            sqrt_alphas_cumprod = sqrt_alphas_cumprod[..., None]
        alpha = sqrt_alphas_cumprod.expand(timesteps.shape)

        sqrt_one_minus_alphas_cumprod = sqrt_one_minus_alphas_cumprod.to(device=timesteps.device)[timesteps].float()
        while len(sqrt_one_minus_alphas_cumprod.shape) < len(timesteps.shape):
            sqrt_one_minus_alphas_cumprod = sqrt_one_minus_alphas_cumprod[..., None]
        sigma = sqrt_one_minus_alphas_cumprod.expand(timesteps.shape)

        # Compute SNR.
        snr = (alpha / sigma) ** 2
        return snr

    # Freeze vae and image encoder; train the whole UNet or only the listed submodules.
    vae.requires_grad_(False)
    image_encoder.requires_grad_(False)
    if cfg.trainable_modules is None:
        unet.requires_grad_(True)
    else:
        unet.requires_grad_(False)
        trainable_suffixes = tuple(cfg.trainable_modules)
        for name, module in unet.named_modules():
            if name.endswith(trainable_suffixes):
                for param in module.parameters():
                    param.requires_grad = True

    # `accelerate` 0.16.0 will have better support for customized saving
    if version.parse(accelerate.__version__) >= version.parse("0.16.0"):
        # create custom saving & loading hooks so that `accelerator.save_state(...)` serializes in a nice format
        def save_model_hook(models, weights, output_dir):
            if cfg.use_ema:
                ema_unet.save_pretrained(os.path.join(output_dir, "unet_ema"))
            for model in models:
                model.save_pretrained(os.path.join(output_dir, "unet"))
                # make sure to pop weight so that corresponding model is not saved again
                weights.pop()

        def load_model_hook(models, input_dir):
            if cfg.use_ema:
                load_model = EMAModel.from_pretrained(os.path.join(input_dir, "unet_ema"), UNetMV2DConditionModel)
                ema_unet.load_state_dict(load_model.state_dict())
                ema_unet.to(accelerator.device)
                del load_model
            while models:
                # pop models so that they are not loaded again
                model = models.pop()
                # load diffusers style into model
                load_model = UNetMV2DConditionModel.from_pretrained(input_dir, subfolder="unet")
                model.register_to_config(**load_model.config)
                model.load_state_dict(load_model.state_dict())
                del load_model

        accelerator.register_save_state_pre_hook(save_model_hook)
        accelerator.register_load_state_pre_hook(load_model_hook)

    if cfg.gradient_checkpointing:
        unet.enable_gradient_checkpointing()

    # Enable TF32 for faster training on Ampere GPUs,
    # cf https://pytorch.org/docs/stable/notes/cuda.html#tensorfloat-32-tf32-on-ampere-devices
    if cfg.allow_tf32:
        torch.backends.cuda.matmul.allow_tf32 = True

    if cfg.scale_lr:
        cfg.learning_rate = (
            cfg.learning_rate * cfg.gradient_accumulation_steps * cfg.train_batch_size * accelerator.num_processes
        )

    # Initialize the optimizer
    if cfg.use_8bit_adam:
        try:
            import bitsandbytes as bnb
        except ImportError:
            raise ImportError(
                "Please install bitsandbytes to use 8-bit Adam. You can do so by running `pip install bitsandbytes`"
            )
        optimizer_cls = bnb.optim.AdamW8bit
    else:
        optimizer_cls = torch.optim.AdamW

    # Camera (class) embedding parameters get their own learning-rate multiplier.
    params, params_class_embedding = [], []
    for name, param in unet.named_parameters():
        if 'class_embedding' in name:
            params_class_embedding.append(param)
        else:
            params.append(param)
    optimizer = optimizer_cls(
        [
            {"params": params, "lr": cfg.learning_rate},
            {"params": params_class_embedding, "lr": cfg.learning_rate * cfg.camera_embedding_lr_mult}
        ],
        betas=(cfg.adam_beta1, cfg.adam_beta2),
        weight_decay=cfg.adam_weight_decay,
        eps=cfg.adam_epsilon,
    )

    lr_scheduler = get_scheduler(
        cfg.lr_scheduler,
        optimizer=optimizer,
        num_warmup_steps=cfg.lr_warmup_steps * accelerator.num_processes,
        num_training_steps=cfg.max_train_steps * accelerator.num_processes,
    )

    # Get the training dataset
    train_dataset = MVDiffusionDataset(
        **cfg.train_dataset
    )
    # DataLoaders creation:
    train_dataloader = torch.utils.data.DataLoader(
        train_dataset, batch_size=cfg.train_batch_size, shuffle=True, num_workers=cfg.dataloader_num_workers,
    )
    # (dataloader, name) pairs, in the order log_validation is run for them.
    validation_runs = _build_validation_runs(cfg)

    # Prepare everything with our `accelerator`.
    unet, optimizer, train_dataloader, lr_scheduler = accelerator.prepare(
        unet, optimizer, train_dataloader, lr_scheduler
    )

    if cfg.use_ema:
        ema_unet.to(accelerator.device)

    # For mixed precision training we cast all non-trainable weights (vae, image encoder) to half-precision
    # as these weights are only used for inference, keeping weights in full precision is not required.
    weight_dtype = torch.float32
    if accelerator.mixed_precision == "fp16":
        weight_dtype = torch.float16
        cfg.mixed_precision = accelerator.mixed_precision
    elif accelerator.mixed_precision == "bf16":
        weight_dtype = torch.bfloat16
        cfg.mixed_precision = accelerator.mixed_precision

    # Move image encoder and vae to gpu and cast to weight_dtype
    image_encoder.to(accelerator.device, dtype=weight_dtype)
    vae.to(accelerator.device, dtype=weight_dtype)

    # CLIP normalisation constants, kept in float32 (shape (3, 1, 1) for broadcasting over images).
    clip_image_mean = torch.as_tensor(feature_extractor.image_mean)[:, None, None].to(
        accelerator.device, dtype=torch.float32)
    clip_image_std = torch.as_tensor(feature_extractor.image_std)[:, None, None].to(
        accelerator.device, dtype=torch.float32)

    # We need to recalculate our total training steps as the size of the training dataloader may have changed.
    num_update_steps_per_epoch = math.ceil(len(train_dataloader) / cfg.gradient_accumulation_steps)
    num_train_epochs = math.ceil(cfg.max_train_steps / num_update_steps_per_epoch)

    # We need to initialize the trackers we use, and also store our configuration.
    # The trackers initializes automatically on the main process.
    if accelerator.is_main_process:
        tracker_config = {}
        accelerator.init_trackers(cfg.tracker_project_name, tracker_config)

    # Train!
    total_batch_size = cfg.train_batch_size * accelerator.num_processes * cfg.gradient_accumulation_steps

    logger.info("***** Running training *****")
    logger.info(f"  Num examples = {len(train_dataset)}")
    logger.info(f"  Num Epochs = {num_train_epochs}")
    logger.info(f"  Instantaneous batch size per device = {cfg.train_batch_size}")
    logger.info(f"  Total train batch size (w. parallel, distributed & accumulation) = {total_batch_size}")
    logger.info(f"  Gradient Accumulation steps = {cfg.gradient_accumulation_steps}")
    logger.info(f"  Total optimization steps = {cfg.max_train_steps}")
    global_step = 0
    first_epoch = 0

    # Potentially load in the weights and states from a previous save
    if cfg.resume_from_checkpoint:
        path = _resolve_resume_checkpoint(cfg)
        if path is None:
            accelerator.print(
                f"Checkpoint '{cfg.resume_from_checkpoint}' does not exist. Starting a new training run."
            )
            cfg.resume_from_checkpoint = None
        else:
            accelerator.print(f"Resuming from checkpoint {path}")
            accelerator.load_state(os.path.join(cfg.output_dir, path))
            # Only "checkpoint-<N>" directories encode their step. The rolling "checkpoint"
            # directory this script writes does not, so step counting restarts at 0 for it.
            global_step = _checkpoint_step(path) or 0
            resume_global_step = global_step * cfg.gradient_accumulation_steps
            first_epoch = global_step // num_update_steps_per_epoch
            resume_step = resume_global_step % (num_update_steps_per_epoch * cfg.gradient_accumulation_steps)

    # Only show the progress bar once on each machine.
    progress_bar = tqdm(range(global_step, cfg.max_train_steps), disable=not accelerator.is_local_main_process)
    progress_bar.set_description("Steps")

    experiment = None
    if accelerator.is_main_process:
        # Credentials come from Comet's standard environment variables; the placeholders
        # remain as fallbacks.
        experiment = Experiment(
            api_key=os.environ.get("COMET_API_KEY", "your_api_key"),
            project_name=os.environ.get("COMET_PROJECT_NAME", "your_project_name"),
            workspace=os.environ.get("COMET_WORKSPACE", "your_workspace")
        )
        # Log the hyper-parameters actually used by this run (learning rate is post-scale_lr).
        hyper_params = {
            "learning_rate": cfg.learning_rate,
            "steps": cfg.max_train_steps,
            "batch_size": total_batch_size
        }
        experiment.log_parameters(hyper_params)

    # pred types whose batches carry task embeddings that get appended to the camera embeddings.
    # The misspelled 'mixed_rgb_noraml_mask' is kept so existing configs keep working.
    task_pred_types = ('mixed_rgb_normal_depth', 'color', 'mixed_color_normal',
                       'mixed_rgb_noraml_mask', 'mixed_rgb_normal_mask')
    target_key = 'normals_out' if cfg.pred_type == 'normal' else 'imgs_out'

    for epoch in range(first_epoch, num_train_epochs):
        unet.train()
        train_loss = 0.0
        for step, batch in enumerate(train_dataloader):
            # Skip steps until we reach the resumed step
            if cfg.resume_from_checkpoint and epoch == first_epoch and step < resume_step:
                if step % cfg.gradient_accumulation_steps == 0:
                    progress_bar.update(1)
                continue

            with accelerator.accumulate(unet):
                # (B, Nv, 3, H, W)
                imgs_in, imgs_out = batch['imgs_in'], batch[target_key]
                bnm, Nv = imgs_in.shape[0], imgs_in.shape[1]

                # (B, Nv, Nce)
                camera_embeddings = batch['camera_embeddings']
                if cfg.pred_type in task_pred_types:
                    camera_embeddings = torch.cat([camera_embeddings, batch['task_embeddings']], dim=-1)

                # (B*Nv, 3, H, W)
                imgs_in = rearrange(imgs_in, "B Nv C H W -> (B Nv) C H W")
                imgs_out = rearrange(imgs_out, "B Nv C H W -> (B Nv) C H W")
                # (B*Nv, Nce)
                camera_embeddings = rearrange(camera_embeddings, "B Nv Nce -> (B Nv) Nce")
                # (B*Nv, Nce') -- 'e_de_da_sincos' (validated above): sin and cos of every component
                camera_embeddings = torch.cat([
                    torch.sin(camera_embeddings),
                    torch.cos(camera_embeddings)
                ], dim=-1)

                imgs_in = imgs_in.to(weight_dtype)
                imgs_out = imgs_out.to(weight_dtype)
                camera_embeddings = camera_embeddings.to(weight_dtype)

                # (B*Nv, 4, Hl, Wl)
                cond_vae_embeddings = vae.encode(imgs_in * 2.0 - 1.0).latent_dist.mode()
                if cfg.scale_input_latents:
                    cond_vae_embeddings = cond_vae_embeddings * vae.config.scaling_factor
                latents = vae.encode(imgs_out * 2.0 - 1.0).latent_dist.sample() * vae.config.scaling_factor

                # Resize + normalise on-GPU instead of running the (slow) PIL-based feature_extractor.
                imgs_in_proc = TF.resize(imgs_in,
                                         (feature_extractor.crop_size['height'], feature_extractor.crop_size['width']),
                                         interpolation=InterpolationMode.BICUBIC)
                # do the normalization in float32 to preserve precision
                imgs_in_proc = ((imgs_in_proc.float() - clip_image_mean) / clip_image_std).to(weight_dtype)

                # (B*Nv, 1, 768)
                image_embeddings = image_encoder(imgs_in_proc).image_embeds.unsqueeze(1)

                noise = torch.randn_like(latents)
                bsz = latents.shape[0]
                # same noise level (timestep) for different views of the same object
                timesteps = torch.randint(0, noise_scheduler.config.num_train_timesteps, (bsz // cfg.num_views,),
                                          device=latents.device).repeat_interleave(cfg.num_views)
                timesteps = timesteps.long()

                noisy_latents = noise_scheduler.add_noise(latents, noise, timesteps)

                # Conditioning dropout to support classifier-free guidance during inference. For more details
                # check out the section 3.2.1 of the original paper https://arxiv.org/abs/2211.09800.
                # With r = condition_drop_rate and p ~ U[0, 1): p < r drops only the CLIP embedding,
                # r <= p < 2r drops both, 2r <= p < 3r drops only the VAE image latents.
                if cfg.use_classifier_free_guidance and cfg.condition_drop_rate > 0.:
                    if cfg.drop_type == 'drop_as_a_whole':
                        # one draw per object, shared by all of its Nv views
                        random_p = torch.rand(bnm, device=latents.device, generator=generator)

                        image_mask_dtype = cond_vae_embeddings.dtype
                        image_mask = 1 - (
                            (random_p >= cfg.condition_drop_rate).to(image_mask_dtype)
                            * (random_p < 3 * cfg.condition_drop_rate).to(image_mask_dtype)
                        )
                        image_mask = image_mask.reshape(bnm, 1, 1, 1, 1).repeat(1, Nv, 1, 1, 1)
                        image_mask = rearrange(image_mask, "B Nv C H W -> (B Nv) C H W")
                        cond_vae_embeddings = image_mask * cond_vae_embeddings

                        clip_mask_dtype = image_embeddings.dtype
                        clip_mask = 1 - (
                            (random_p < 2 * cfg.condition_drop_rate).to(clip_mask_dtype)
                        )
                        clip_mask = clip_mask.reshape(bnm, 1, 1, 1).repeat(1, Nv, 1, 1)
                        clip_mask = rearrange(clip_mask, "B Nv M C -> (B Nv) M C")
                        image_embeddings = clip_mask * image_embeddings
                    elif cfg.drop_type == 'drop_independent':
                        # one draw per view
                        random_p = torch.rand(bsz, device=latents.device, generator=generator)

                        image_mask_dtype = cond_vae_embeddings.dtype
                        image_mask = 1 - (
                            (random_p >= cfg.condition_drop_rate).to(image_mask_dtype)
                            * (random_p < 3 * cfg.condition_drop_rate).to(image_mask_dtype)
                        )
                        image_mask = image_mask.reshape(bsz, 1, 1, 1)
                        cond_vae_embeddings = image_mask * cond_vae_embeddings

                        clip_mask_dtype = image_embeddings.dtype
                        clip_mask = 1 - (
                            (random_p < 2 * cfg.condition_drop_rate).to(clip_mask_dtype)
                        )
                        clip_mask = clip_mask.reshape(bsz, 1, 1)
                        image_embeddings = clip_mask * image_embeddings

                # (B*Nv, 8, Hl, Wl)
                latent_model_input = torch.cat([noisy_latents, cond_vae_embeddings], dim=1)

                model_pred = unet(
                    latent_model_input,
                    timesteps,
                    encoder_hidden_states=image_embeddings,
                    class_labels=camera_embeddings
                ).sample

                # Get the target for loss depending on the prediction type
                if noise_scheduler.config.prediction_type == "epsilon":
                    target = noise
                elif noise_scheduler.config.prediction_type == "v_prediction":
                    target = noise_scheduler.get_velocity(latents, noise, timesteps)
                else:
                    raise ValueError(f"Unknown prediction type {noise_scheduler.config.prediction_type}")

                if cfg.snr_gamma is None:
                    loss = F.mse_loss(model_pred.float(), target.float(), reduction="mean")
                else:
                    # Compute loss-weights as per Section 3.4 of https://arxiv.org/abs/2303.09556.
                    # Since we predict the noise instead of x_0, the original formulation is slightly changed.
                    # This is discussed in Section 4.2 of the same paper.
                    snr = compute_snr(timesteps)
                    mse_loss_weights = (
                        torch.stack([snr, cfg.snr_gamma * torch.ones_like(timesteps)], dim=1).min(dim=1)[0] / snr
                    )
                    # Mean over the non-batch dimensions, reweight per sample, then mean over the batch.
                    loss = F.mse_loss(model_pred.float(), target.float(), reduction="none")
                    loss = loss.mean(dim=list(range(1, len(loss.shape)))) * mse_loss_weights
                    loss = loss.mean()

                # Gather the losses across all processes for logging (if we use distributed training).
                avg_loss = accelerator.gather(loss.repeat(cfg.train_batch_size)).mean()
                train_loss += avg_loss.item() / cfg.gradient_accumulation_steps

                # Backpropagate
                accelerator.backward(loss)

                if accelerator.sync_gradients and cfg.max_grad_norm is not None:
                    accelerator.clip_grad_norm_(unet.parameters(), cfg.max_grad_norm)

                optimizer.step()
                lr_scheduler.step()
                optimizer.zero_grad()

            # Checks if the accelerator has performed an optimization step behind the scenes
            if accelerator.sync_gradients:
                if cfg.use_ema:
                    ema_unet.step(unet.parameters())
                progress_bar.update(1)
                global_step += 1
                accelerator.log({"train_loss": train_loss}, step=global_step)
                train_loss = 0.0

                if global_step % cfg.checkpointing_steps == 0:
                    if accelerator.is_main_process:
                        save_path = os.path.join(cfg.output_dir, "checkpoint")
                        accelerator.save_state(save_path)
                        # unwrap_model strips the DDP wrapper (and any compile wrapper) if present.
                        accelerator.unwrap_model(unet).save_pretrained(
                            os.path.join(cfg.output_dir, f"unet-{global_step}"))
                        logger.info(f"Saved state to {save_path}")

                if global_step % cfg.validation_steps == 0 or (cfg.validation_sanity_check and global_step == 1):
                    if accelerator.is_main_process:
                        if cfg.use_ema:
                            # Store the UNet parameters temporarily and load the EMA parameters to perform inference.
                            ema_unet.store(unet.parameters())
                            ema_unet.copy_to(unet.parameters())
                        for val_dataloader, val_name in validation_runs:
                            log_validation(
                                val_dataloader,
                                vae,
                                feature_extractor,
                                image_encoder,
                                unet,
                                cfg,
                                accelerator,
                                weight_dtype,
                                global_step,
                                val_name,
                                vis_dir
                            )
                        if cfg.use_ema:
                            # Switch back to the original UNet parameters.
                            ema_unet.restore(unet.parameters())

            logs = {"step_loss": loss.detach().item(), "lr": lr_scheduler.get_last_lr()[0]}
            progress_bar.set_postfix(**logs)

            if global_step >= cfg.max_train_steps:
                break

        # Stop the epoch loop too; otherwise further epochs would keep training past max_train_steps.
        if global_step >= cfg.max_train_steps:
            break

    # Create the pipeline using the trained modules and save it.
    accelerator.wait_for_everyone()
    if accelerator.is_main_process:
        unet = accelerator.unwrap_model(unet)
        if cfg.use_ema:
            ema_unet.copy_to(unet.parameters())
        pipeline = MVDiffusionImagePipeline(
            image_encoder=image_encoder, feature_extractor=feature_extractor, vae=vae, unet=unet, safety_checker=None,
            scheduler=DDIMScheduler.from_pretrained(cfg.pretrained_model_name_or_path, subfolder="scheduler"),
            **cfg.pipe_kwargs
        )
        os.makedirs(os.path.join(cfg.output_dir, "pipeckpts"), exist_ok=True)
        pipeline.save_pretrained(os.path.join(cfg.output_dir, "pipeckpts"))

    accelerator.end_training()
    if accelerator.is_main_process:
        log_model(experiment, model=unet, model_name="mv_depth_normal")
if __name__ == '__main__':
    # Entry point: read the YAML path from the command line, overlay it on the
    # TrainingConfig structured schema, and start training with the merged config.
    cli = argparse.ArgumentParser()
    cli.add_argument('--config', type=str, required=True)
    config_path = cli.parse_args().config
    main(OmegaConf.merge(OmegaConf.structured(TrainingConfig), OmegaConf.load(config_path)))