-
Notifications
You must be signed in to change notification settings - Fork 17
/
Copy pathfinetune.py
858 lines (755 loc) · 33.6 KB
/
finetune.py
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
161
162
163
164
165
166
167
168
169
170
171
172
173
174
175
176
177
178
179
180
181
182
183
184
185
186
187
188
189
190
191
192
193
194
195
196
197
198
199
200
201
202
203
204
205
206
207
208
209
210
211
212
213
214
215
216
217
218
219
220
221
222
223
224
225
226
227
228
229
230
231
232
233
234
235
236
237
238
239
240
241
242
243
244
245
246
247
248
249
250
251
252
253
254
255
256
257
258
259
260
261
262
263
264
265
266
267
268
269
270
271
272
273
274
275
276
277
278
279
280
281
282
283
284
285
286
287
288
289
290
291
292
293
294
295
296
297
298
299
300
301
302
303
304
305
306
307
308
309
310
311
312
313
314
315
316
317
318
319
320
321
322
323
324
325
326
327
328
329
330
331
332
333
334
335
336
337
338
339
340
341
342
343
344
345
346
347
348
349
350
351
352
353
354
355
356
357
358
359
360
361
362
363
364
365
366
367
368
369
370
371
372
373
374
375
376
377
378
379
380
381
382
383
384
385
386
387
388
389
390
391
392
393
394
395
396
397
398
399
400
401
402
403
404
405
406
407
408
409
410
411
412
413
414
415
416
417
418
419
420
421
422
423
424
425
426
427
428
429
430
431
432
433
434
435
436
437
438
439
440
441
442
443
444
445
446
447
448
449
450
451
452
453
454
455
456
457
458
459
460
461
462
463
464
465
466
467
468
469
470
471
472
473
474
475
476
477
478
479
480
481
482
483
484
485
486
487
488
489
490
491
492
493
494
495
496
497
498
499
500
501
502
503
504
505
506
507
508
509
510
511
512
513
514
515
516
517
518
519
520
521
522
523
524
525
526
527
528
529
530
531
532
533
534
535
536
537
538
539
540
541
542
543
544
545
546
547
548
549
550
551
552
553
554
555
556
557
558
559
560
561
562
563
564
565
566
567
568
569
570
571
572
573
574
575
576
577
578
579
580
581
582
583
584
585
586
587
588
589
590
591
592
593
594
595
596
597
598
599
600
601
602
603
604
605
606
607
608
609
610
611
612
613
614
615
616
617
618
619
620
621
622
623
624
625
626
627
628
629
630
631
632
633
634
635
636
637
638
639
640
641
642
643
644
645
646
647
648
649
650
651
652
653
654
655
656
657
658
659
660
661
662
663
664
665
666
667
668
669
670
671
672
673
674
675
676
677
678
679
680
681
682
683
684
685
686
687
688
689
690
691
692
693
694
695
696
697
698
699
700
701
702
703
704
705
706
707
708
709
710
711
712
713
714
715
716
717
718
719
720
721
722
723
724
725
726
727
728
729
730
731
732
733
734
735
736
737
738
739
740
741
742
743
744
745
746
747
748
749
750
751
752
753
754
755
756
757
758
759
760
761
762
763
764
765
766
767
768
769
770
771
772
773
774
775
776
777
778
779
780
781
782
783
784
785
786
787
788
789
790
791
792
793
794
795
796
797
798
799
800
801
802
803
804
805
806
807
808
809
810
811
812
813
814
815
816
817
818
819
820
821
822
823
824
825
826
827
828
829
830
831
832
833
834
835
836
837
838
839
840
841
842
843
844
845
846
847
848
849
850
851
852
853
854
855
856
857
858
# Copyright (c) Meta Platforms, Inc. and affiliates.
# All rights reserved.
# This source code is licensed under the license found in the
# LICENSE file in the root directory of this source tree.
"""
A minimal training script for Lumina-T2I using PyTorch FSDP.
"""
import argparse
from collections import OrderedDict, defaultdict
import contextlib
from copy import deepcopy
from datetime import datetime
import functools
from functools import partial
import json
import logging
import os
import random
import socket
from time import time
import warnings
import torch.nn.functional as F
from PIL import Image
# import cairosvg
from diffusers import AutoencoderKL
import fairscale.nn.model_parallel.initialize as fs_init
import numpy as np
import torch
import torch.distributed as dist
from torch.distributed.algorithms._checkpoint.checkpoint_wrapper import (
CheckpointImpl,
apply_activation_checkpointing,
checkpoint_wrapper,
)
from torch.distributed.fsdp import (
FullStateDictConfig,
FullyShardedDataParallel as FSDP,
MixedPrecision,
ShardingStrategy,
StateDictType,
)
from torch.distributed.fsdp.wrap import lambda_auto_wrap_policy
import torch.nn as nn
from torch.utils.data import DataLoader
from torch.utils.tensorboard import SummaryWriter
from torchvision import transforms
from transformers import AutoModel, AutoTokenizer
from data import DataNoReportException, ItemProcessor, MyDataset, read_general
from imgproc import generate_crop_size_list, to_rgb_if_rgba, var_center_crop
import models
from parallel import distributed_init, get_intra_node_process_group
from transport import create_transport
from util.misc import SmoothedValue
#############################################################################
# Data item Processor #
#############################################################################
class NonRGBError(DataNoReportException):
pass
class T2IItemProcessor(ItemProcessor):
def __init__(self, transform):
self.image_transform = transform
self.special_format_set = set()
def process_item(self, data_item, training_mode=False):
if "super_high_quality_caption" in data_item:
url = data_item["image_path"]
image = Image.open(read_general(url))
text = data_item["super_high_quality_caption"]
system_prompt = "You are an assistant designed to generate high-quality images with the highest degree of image-text alignment based on textual prompts. <Prompt Start> " # noqa
elif "path" in data_item:
url = data_item["path"]
image = Image.open(read_general(url))
text = data_item["prompt"]
system_prompt = "You are an assistant designed to generate high-quality images based on user prompts. <Prompt Start> " # noqa
else:
raise ValueError(f"Unrecognized item: {data_item}")
if image.mode.upper() != "RGB":
mode = image.mode.upper()
if mode not in self.special_format_set:
self.special_format_set.add(mode)
print(mode, url)
if mode == "RGBA":
image = to_rgb_if_rgba(image)
elif mode == "P" or mode == "L":
image = image.convert("RGB")
else:
raise NonRGBError()
image = self.image_transform(image)
if text is None or text.strip() == "":
text = ""
text = system_prompt + text
return image, text
#############################################################################
# Training Helper Functions #
#############################################################################
def apply_average_pool(latent, factor):
"""
Apply average pooling to downsample the latent.
Args:
latent (torch.Tensor): Latent tensor with shape (1, C, H, W).
factor (int): Downsampling factor.
Returns:
torch.Tensor: Downsampled latent tensor.
"""
return F.avg_pool2d(latent, kernel_size=factor, stride=factor)
def dataloader_collate_fn(samples):
image = [x[0] for x in samples]
caps = [x[1] for x in samples]
return image, caps
def get_train_sampler(dataset, rank, world_size, global_batch_size, max_steps, resume_step, seed):
sample_indices = torch.empty([max_steps * global_batch_size // world_size], dtype=torch.long)
epoch_id, fill_ptr, offs = 0, 0, 0
while fill_ptr < sample_indices.size(0):
g = torch.Generator()
g.manual_seed(seed + epoch_id)
epoch_sample_indices = torch.randperm(len(dataset), generator=g)
epoch_id += 1
epoch_sample_indices = epoch_sample_indices[(rank + offs) % world_size :: world_size]
offs = (offs + world_size - len(dataset) % world_size) % world_size
epoch_sample_indices = epoch_sample_indices[: sample_indices.size(0) - fill_ptr]
sample_indices[fill_ptr : fill_ptr + epoch_sample_indices.size(0)] = epoch_sample_indices
fill_ptr += epoch_sample_indices.size(0)
return sample_indices[resume_step * global_batch_size // world_size :].tolist()
@torch.no_grad()
def update_ema(ema_model, model, decay=0.95):
"""
Step the EMA model towards the current model.
"""
ema_params = OrderedDict(ema_model.named_parameters())
model_params = OrderedDict(model.named_parameters())
assert set(ema_params.keys()) == set(model_params.keys())
for name, param in model_params.items():
# TODO: Consider applying only to params that require_grad to avoid small numerical changes of pos_embed
ema_params[name].mul_(decay).add_(param.data, alpha=1 - decay)
def cleanup():
"""
End DDP training.
"""
dist.destroy_process_group()
def create_logger(logging_dir):
"""
Create a logger that writes to a log file and stdout.
"""
if dist.get_rank() == 0: # real logger
logging.basicConfig(
level=logging.INFO,
format="[\033[34m%(asctime)s\033[0m] %(message)s",
datefmt="%Y-%m-%d %H:%M:%S",
handlers=[
logging.StreamHandler(),
logging.FileHandler(f"{logging_dir}/log.txt"),
],
)
logger = logging.getLogger(__name__)
else: # dummy logger (does nothing)
logger = logging.getLogger(__name__)
logger.addHandler(logging.NullHandler())
return logger
def setup_lm_fsdp_sync(model: nn.Module) -> FSDP:
# LM FSDP always use FULL_SHARD among the node.
model = FSDP(
model,
auto_wrap_policy=functools.partial(
lambda_auto_wrap_policy,
lambda_fn=lambda m: m in list(model.layers),
),
process_group=get_intra_node_process_group(),
sharding_strategy=ShardingStrategy.FULL_SHARD,
mixed_precision=MixedPrecision(
param_dtype=next(model.parameters()).dtype,
),
device_id=torch.cuda.current_device(),
sync_module_states=True,
limit_all_gathers=True,
use_orig_params=True,
)
torch.cuda.synchronize()
return model
def setup_fsdp_sync(model: nn.Module, args: argparse.Namespace) -> FSDP:
model = FSDP(
model,
auto_wrap_policy=functools.partial(
lambda_auto_wrap_policy,
lambda_fn=lambda m: m in model.get_fsdp_wrap_module_list(),
),
process_group=fs_init.get_data_parallel_group(),
sharding_strategy={
"fsdp": ShardingStrategy.FULL_SHARD,
"sdp": ShardingStrategy.SHARD_GRAD_OP,
}[args.data_parallel],
mixed_precision=MixedPrecision(
param_dtype={
"fp32": torch.float,
"tf32": torch.float,
"bf16": torch.bfloat16,
"fp16": torch.float16,
}[args.precision],
reduce_dtype={
"fp32": torch.float,
"tf32": torch.float,
"bf16": torch.bfloat16,
"fp16": torch.float16,
}[args.grad_precision or args.precision],
),
device_id=torch.cuda.current_device(),
sync_module_states=True,
limit_all_gathers=True,
use_orig_params=True,
)
torch.cuda.synchronize()
return model
def setup_mixed_precision(args):
if args.precision == "tf32":
torch.backends.cuda.matmul.allow_tf32 = True
torch.backends.cudnn.allow_tf32 = True
elif args.precision in ["bf16", "fp16", "fp32"]:
pass
else:
raise NotImplementedError(f"Unknown precision: {args.precision}")
# Adapted from pipelines.StableDiffusionXLPipeline.encode_prompt
def encode_prompt(prompt_batch, text_encoder, tokenizer, proportion_empty_prompts, is_train=True):
captions = []
for caption in prompt_batch:
if random.random() < proportion_empty_prompts:
captions.append("")
elif isinstance(caption, str):
captions.append(caption)
elif isinstance(caption, (list, np.ndarray)):
# take a random caption if there are multiple
captions.append(random.choice(caption) if is_train else caption[0])
with torch.no_grad():
text_inputs = tokenizer(
captions,
padding=True,
pad_to_multiple_of=8,
max_length=256,
truncation=True,
return_tensors="pt",
)
text_input_ids = text_inputs.input_ids
prompt_masks = text_inputs.attention_mask
prompt_embeds = text_encoder(
input_ids=text_input_ids,
attention_mask=prompt_masks,
output_hidden_states=True,
).hidden_states[-2]
return prompt_embeds, prompt_masks
#############################################################################
# Training Loop #
#############################################################################
def main(args):
"""
Trains a new DiT model.
"""
assert torch.cuda.is_available(), "Training currently requires at least one GPU."
distributed_init(args)
dp_world_size = fs_init.get_data_parallel_world_size()
dp_rank = fs_init.get_data_parallel_rank()
rank = dist.get_rank()
device = rank % torch.cuda.device_count()
seed = args.global_seed * dist.get_world_size() + rank
torch.manual_seed(seed)
torch.cuda.set_device(device)
setup_mixed_precision(args)
print(f"Starting rank={rank}, seed={seed}, "
f"world_size={dist.get_world_size()}.")
# Setup an experiment folder:
os.makedirs(args.results_dir, exist_ok=True)
checkpoint_dir = os.path.join(args.results_dir, "checkpoints")
os.makedirs(checkpoint_dir, exist_ok=True)
if rank == 0:
logger = create_logger(args.results_dir)
logger.info(f"Experiment directory: {args.results_dir}")
tb_logger = SummaryWriter(
os.path.join(
args.results_dir, "tensorboard", datetime.now().strftime("%Y%m%d_%H%M%S_") + socket.gethostname()
)
)
else:
logger = create_logger(None)
tb_logger = None
logger.info("Training arguments: " + json.dumps(args.__dict__, indent=2))
logger.info(f"Setting-up language model: google/gemma-2-2b")
# create tokenizers
# Load the tokenizers
tokenizer = AutoTokenizer.from_pretrained("google/gemma-2-2b")
tokenizer.padding_side = "right"
# create text encoders
# text_encoder
text_encoder = AutoModel.from_pretrained(
"google/gemma-2-2b",
torch_dtype=torch.bfloat16,
).cuda()
text_encoder = setup_lm_fsdp_sync(text_encoder)
logger.info(f"text encoder: {type(text_encoder)}")
cap_feat_dim = text_encoder.config.hidden_size
# Create model:
model = models.__dict__[args.model](
in_channels=16,
qk_norm=args.qk_norm,
cap_feat_dim=cap_feat_dim,
)
logger.info(f"DiT Parameters: {model.parameter_count():,}")
model_patch_size = model.patch_size
if args.auto_resume and args.resume is None:
try:
existing_checkpoints = os.listdir(checkpoint_dir)
if len(existing_checkpoints) > 0:
existing_checkpoints.sort()
args.resume = os.path.join(checkpoint_dir, existing_checkpoints[-1])
except Exception:
pass
if args.resume is not None:
logger.info(f"Auto resuming from: {args.resume}")
# Note that parameter initialization is done within the DiT constructor
model_ema = deepcopy(model)
if args.resume:
if dp_rank == 0: # other ranks receive weights in setup_fsdp_sync
logger.info(f"Resuming model weights from: {args.resume}")
model.load_state_dict(
torch.load(
os.path.join(
args.resume,
f"consolidated.{0:02d}-of-{1:02d}.pth",
),
map_location="cpu",
),
strict=True,
)
logger.info(f"Resuming ema weights from: {args.resume}")
model_ema.load_state_dict(
torch.load(
os.path.join(
args.resume,
f"consolidated_ema.{0:02d}-of-{1:02d}.pth",
),
map_location="cpu",
),
strict=True,
)
elif args.init_from:
if dp_rank == 0:
logger.info(f"Initializing model weights from: {args.init_from}")
state_dict = torch.load(
os.path.join(
args.init_from,
f"consolidated.{0:02d}-of-{1:02d}.pth",
),
map_location="cpu",
)
size_mismatch_keys = []
model_state_dict = model.state_dict()
for k, v in state_dict.items():
if k in model_state_dict and model_state_dict[k].shape != v.shape:
size_mismatch_keys.append(k)
for k in size_mismatch_keys:
del state_dict[k]
del model_state_dict
missing_keys, unexpected_keys = model.load_state_dict(state_dict, strict=False)
missing_keys_ema, unexpected_keys_ema = model_ema.load_state_dict(state_dict, strict=False)
del state_dict
assert set(missing_keys) == set(missing_keys_ema)
assert set(unexpected_keys) == set(unexpected_keys_ema)
logger.info("Model initialization result:")
logger.info(f" Size mismatch keys: {size_mismatch_keys}")
logger.info(f" Missing keys: {missing_keys}")
logger.info(f" Unexpeected keys: {unexpected_keys}")
dist.barrier()
# checkpointing (part1, should be called before FSDP wrapping)
if args.checkpointing:
checkpointing_list = list(model.get_checkpointing_wrap_module_list())
checkpointing_list_ema = list(model_ema.get_checkpointing_wrap_module_list())
else:
checkpointing_list = []
checkpointing_list_ema = []
model = setup_fsdp_sync(model, args)
model_ema = setup_fsdp_sync(model_ema, args)
# checkpointing (part2, after FSDP wrapping)
if args.checkpointing:
logger.info("apply gradient checkpointing")
non_reentrant_wrapper = partial(
checkpoint_wrapper,
checkpoint_impl=CheckpointImpl.NO_REENTRANT,
)
apply_activation_checkpointing(
model,
checkpoint_wrapper_fn=non_reentrant_wrapper,
check_fn=lambda submodule: submodule in checkpointing_list,
)
apply_activation_checkpointing(
model_ema,
checkpoint_wrapper_fn=non_reentrant_wrapper,
check_fn=lambda submodule: submodule in checkpointing_list_ema,
)
logger.info(f"model:\n{model}\n")
vae = AutoencoderKL.from_pretrained("black-forest-labs/FLUX.1-dev", subfolder="vae", torch_dtype=torch.bfloat16).to(
device
)
logger.info("AdamW eps 1e-15 betas (0.9, 0.95)")
opt = torch.optim.AdamW(model.parameters(), lr=args.lr, weight_decay=args.wd, eps=1e-15, betas=(0.9, 0.95))
if args.resume:
opt_state_world_size = len(
[x for x in os.listdir(args.resume) if x.startswith("optimizer.") and x.endswith(".pth")]
)
assert opt_state_world_size == dist.get_world_size(), (
f"Resuming from a checkpoint with unmatched world size "
f"({dist.get_world_size()} vs. {opt_state_world_size}) "
f"is currently not supported."
)
logger.info(f"Resuming optimizer states from: {args.resume}")
opt.load_state_dict(
torch.load(
os.path.join(
args.resume,
f"optimizer.{dist.get_rank():05d}-of-" f"{dist.get_world_size():05d}.pth",
),
map_location="cpu",
)
)
for param_group in opt.param_groups:
param_group["lr"] = args.lr
param_group["weight_decay"] = args.wd
with open(os.path.join(args.resume, "resume_step.txt")) as f:
resume_step = int(f.read().strip())
else:
resume_step = 0
# Setup data:
data_collection = {}
for train_res in [1024]:
logger.info(f"Creating data for resolution {train_res}")
global_bsz = getattr(args, f"global_bsz_{train_res}")
local_bsz = global_bsz // dp_world_size # todo caution for sequence parallel
micro_bsz = getattr(args, f"micro_bsz_{train_res}")
assert global_bsz % dp_world_size == 0, "Batch size must be divisible by data parallel world size."
logger.info(f"Global bsz: {global_bsz} Local bsz: {local_bsz} Micro bsz: {micro_bsz}")
patch_size = 8 * model_patch_size
logger.info(f"patch size: {patch_size}")
max_num_patches = round((train_res / patch_size) ** 2)
logger.info(f"Limiting number of patches to {max_num_patches}.")
crop_size_list = generate_crop_size_list(max_num_patches, patch_size)
logger.info("List of crop sizes:")
for i in range(0, len(crop_size_list), 6):
logger.info(" " + "".join([f"{f'{w} x {h}':14s}" for w, h in crop_size_list[i : i + 6]]))
image_transform = transforms.Compose(
[
transforms.Lambda(functools.partial(var_center_crop, crop_size_list=crop_size_list, random_top_k=1)),
transforms.ToTensor(),
transforms.Normalize(mean=[0.5, 0.5, 0.5], std=[0.5, 0.5, 0.5], inplace=True),
]
)
dataset = MyDataset(
args.data_path,
item_processor=T2IItemProcessor(image_transform),
cache_on_disk=args.cache_data_on_disk,
)
num_samples = global_bsz * args.max_steps
logger.info(f"Dataset contains {len(dataset):,} images ({args.data_path})")
logger.info(f"Total # samples to consume: {num_samples:,} " f"({num_samples / len(dataset):.2f} epochs)")
sampler = get_train_sampler(
dataset,
dp_rank,
dp_world_size,
global_bsz,
args.max_steps,
resume_step,
args.global_seed + train_res * 100, # avoid same sampling for different resolutions
)
loader = DataLoader(
dataset,
batch_size=local_bsz,
sampler=sampler,
num_workers=args.num_workers,
pin_memory=True,
collate_fn=dataloader_collate_fn,
)
# default: 1000 steps, linear noise schedule
transport = create_transport(
"Linear",
"velocity",
None,
None,
None,
snr_type=args.snr_type,
do_shift=not args.no_shift,
seq_len=(train_res // 16) ** 2,
) # default: velocity;
data_collection[train_res] = {
"loader": loader,
"loader_iter": iter(loader),
"global_bsz": global_bsz,
"local_bsz": local_bsz,
"micro_bsz": micro_bsz,
"metrics": defaultdict(lambda: SmoothedValue(args.log_every)),
"transport": transport,
}
# Prepare models for training:
model.train()
# Variables for monitoring/logging purposes:
logger.info(f"Training for {args.max_steps:,} steps...")
for step in range(resume_step, args.max_steps):
start_time = time()
for train_res, data_pack in data_collection.items():
x, caps = next(data_pack["loader_iter"])
x = [img.to(device, non_blocking=True) for img in x]
with torch.no_grad():
vae_scale = {
"sdxl": 0.13025,
"sd3": 1.5305,
"ema": 0.18215,
"mse": 0.18215,
"cogvideox": 1.15258426,
"flux": 0.3611,
}["flux"]
vae_shift = {
"sdxl": 0.0,
"sd3": 0.0609,
"ema": 0.0,
"mse": 0.0,
"cogvideox": 0.0,
"flux": 0.1159,
}["flux"]
if step == resume_step:
warnings.warn(f"vae scale: {vae_scale} vae shift: {vae_shift}")
# Map input images to latent space + normalize latents:
for i, img in enumerate(x):
x[i] = (vae.encode(img[None].bfloat16()).latent_dist.mode()[0] - vae_shift) * vae_scale
x[i] = x[i].float()
with torch.no_grad():
cap_feats, cap_mask = encode_prompt(caps, text_encoder, tokenizer, args.caption_dropout_prob)
loss_item = 0.0
loss_256_item = 0.0
loss_1024_item = 0.0
opt.zero_grad()
# Number of bins, for loss recording
n_loss_bins = 20
# Create bins for t
loss_bins = torch.linspace(0.0, 1.0, n_loss_bins + 1, device="cuda")
loss_bins_256 = torch.linspace(0.0, 1.0, n_loss_bins + 1, device="cuda")
# Initialize occurrence and sum tensors
bin_occurrence = torch.zeros(n_loss_bins, device="cuda")
bin_occurrence_256 = torch.zeros(n_loss_bins, device="cuda")
bin_sum_loss = torch.zeros(n_loss_bins, device="cuda")
bin_sum_loss_256 = torch.zeros(n_loss_bins, device="cuda")
for mb_idx in range((data_pack["local_bsz"] - 1) // data_pack["micro_bsz"] + 1):
mb_st = mb_idx * data_pack["micro_bsz"]
mb_ed = min((mb_idx + 1) * data_pack["micro_bsz"], data_pack["local_bsz"])
last_mb = mb_ed == data_pack["local_bsz"]
x_mb = x[mb_st:mb_ed]
### muti resolution
x_mb_256 = [apply_average_pool(x, 4) for x in x_mb]
cap_feats_mb = cap_feats[mb_st:mb_ed]
cap_mask_mb = cap_mask[mb_st:mb_ed]
model_kwargs = dict(cap_feats=cap_feats_mb, cap_mask=cap_mask_mb)
with {
"bf16": torch.cuda.amp.autocast(dtype=torch.bfloat16),
"fp16": torch.cuda.amp.autocast(dtype=torch.float16),
"fp32": contextlib.nullcontext(),
"tf32": contextlib.nullcontext(),
}[args.precision]:
loss_dict = data_pack["transport"].training_losses(model, x_mb, model_kwargs)
loss_dict_256 = data_pack["transport"].training_losses(model, x_mb_256, model_kwargs)
loss_1024 = loss_dict["loss"].sum() / data_pack["local_bsz"]
loss_256 = loss_dict_256["loss"].sum() / data_pack["local_bsz"]
loss = loss_1024 + loss_256
loss_item += loss.item()
loss_1024_item += loss_1024.item()
loss_256_item += loss_256.item()
with model.no_sync() if args.data_parallel in ["sdp"] and not last_mb else contextlib.nullcontext():
loss.backward()
# for bin-wise loss recording
# Digitize t values to find which bin they belong to
bin_indices = torch.bucketize(loss_dict["t"].cuda(), loss_bins, right=True) - 1
detached_loss = loss_dict["loss"].detach()
bin_indices_256 = torch.bucketize(loss_dict_256["t"].cuda(), loss_bins, right=True) - 1
detached_loss_256 = loss_dict_256["loss"].detach()
# Iterate through each bin index to update occurrence and sum
for i in range(n_loss_bins):
mask = bin_indices == i # Mask for elements in the i-th bin
bin_occurrence[i] = bin_occurrence[i] + mask.sum() # Count occurrences in the i-th bin
bin_sum_loss[i] = bin_sum_loss[i] + detached_loss[mask].sum() # Sum loss values in the i-th bin
for i in range(n_loss_bins):
mask = bin_indices_256 == i
bin_occurrence_256[i] = bin_occurrence_256[i] + mask.sum()
bin_sum_loss_256[i] = bin_sum_loss_256[i] + detached_loss_256[mask].sum()
grad_norm = model.clip_grad_norm_(max_norm=args.grad_clip)
dist.all_reduce(bin_occurrence)
dist.all_reduce(bin_sum_loss)
dist.all_reduce(bin_occurrence_256)
dist.all_reduce(bin_sum_loss_256)
if tb_logger is not None:
tb_logger.add_scalar(f"{train_res}/loss", loss_item, step)
tb_logger.add_scalar(f"{train_res}/loss_256", loss_256_item, step)
tb_logger.add_scalar(f"{train_res}/loss_1024", loss_1024_item, step)
tb_logger.add_scalar(f"{train_res}/grad_norm", grad_norm, step)
tb_logger.add_scalar(f"{train_res}/lr", opt.param_groups[0]["lr"], step)
for i in range(n_loss_bins):
if bin_occurrence[i] > 0:
bin_avg_loss = (bin_sum_loss[i] / bin_occurrence[i]).item()
tb_logger.add_scalar(f"{train_res}/loss-bin{i+1}-{n_loss_bins}", bin_avg_loss, step)
for i in range(n_loss_bins):
if bin_occurrence_256[i] > 0:
bin_avg_loss = (bin_sum_loss_256[i] / bin_occurrence_256[i]).item()
tb_logger.add_scalar(f"{train_res}/loss_256-bin{i+1}-{n_loss_bins}", bin_avg_loss, step)
opt.step()
end_time = time()
# Log loss values:
metrics = data_pack["metrics"]
metrics["loss"].update(loss_item)
metrics["loss_1024"].update(loss_1024_item)
metrics["loss_256"].update(loss_256_item)
metrics["grad_norm"].update(grad_norm)
metrics["Secs/Step"].update(end_time - start_time)
metrics["Imgs/Sec"].update(data_pack["global_bsz"] / (end_time - start_time))
metrics["grad_norm"].update(grad_norm)
for i in range(n_loss_bins):
if bin_occurrence[i] > 0:
bin_avg_loss = (bin_sum_loss[i] / bin_occurrence[i]).item()
metrics[f"bin_1024_{i + 1:02}-{n_loss_bins}"].update(bin_avg_loss, int(bin_occurrence[i].item()))
for i in range(n_loss_bins):
if bin_occurrence_256[i] > 0:
bin_avg_loss = (bin_sum_loss_256[i] / bin_occurrence_256[i]).item()
metrics[f"bin_256_{i + 1:02}-{n_loss_bins}"].update(bin_avg_loss, int(bin_occurrence_256[i].item()))
if (step + 1) % args.log_every == 0:
# Measure training speed:
torch.cuda.synchronize()
logger.info(
f"Res{train_res}_{train_res//4}: (step{step + 1:07d}) "
+ f"lr{opt.param_groups[0]['lr']:.6f} "
+ " ".join([f"{key}:{str(metrics[key])}" for key in sorted(metrics.keys())])
)
start_time = time()
update_ema(model_ema, model)
# Save DiT checkpoint:
if step == 0 or (step + 1) % args.ckpt_every == 0 or (step + 1) == args.max_steps:
checkpoint_path = f"{checkpoint_dir}/{step + 1:07d}"
os.makedirs(checkpoint_path, exist_ok=True)
with FSDP.state_dict_type(
model,
StateDictType.FULL_STATE_DICT,
FullStateDictConfig(rank0_only=True, offload_to_cpu=True),
):
consolidated_model_state_dict = model.state_dict()
if fs_init.get_data_parallel_rank() == 0:
consolidated_fn = (
"consolidated."
f"{fs_init.get_model_parallel_rank():02d}-of-"
f"{fs_init.get_model_parallel_world_size():02d}"
".pth"
)
torch.save(
consolidated_model_state_dict,
os.path.join(checkpoint_path, consolidated_fn),
)
dist.barrier()
del consolidated_model_state_dict
logger.info(f"Saved consolidated to {checkpoint_path}.")
with FSDP.state_dict_type(
model_ema,
StateDictType.FULL_STATE_DICT,
FullStateDictConfig(rank0_only=True, offload_to_cpu=True),
):
consolidated_ema_state_dict = model_ema.state_dict()
if fs_init.get_data_parallel_rank() == 0:
consolidated_ema_fn = (
"consolidated_ema."
f"{fs_init.get_model_parallel_rank():02d}-of-"
f"{fs_init.get_model_parallel_world_size():02d}"
".pth"
)
torch.save(
consolidated_ema_state_dict,
os.path.join(checkpoint_path, consolidated_ema_fn),
)
dist.barrier()
del consolidated_ema_state_dict
logger.info(f"Saved consolidated_ema to {checkpoint_path}.")
with FSDP.state_dict_type(
model,
StateDictType.LOCAL_STATE_DICT,
):
opt_state_fn = f"optimizer.{dist.get_rank():05d}-of-" f"{dist.get_world_size():05d}.pth"
torch.save(opt.state_dict(), os.path.join(checkpoint_path, opt_state_fn))
dist.barrier()
logger.info(f"Saved optimizer to {checkpoint_path}.")
if dist.get_rank() == 0:
torch.save(args, os.path.join(checkpoint_path, "model_args.pth"))
with open(os.path.join(checkpoint_path, "resume_step.txt"), "w") as f:
print(step + 1, file=f)
dist.barrier()
logger.info(f"Saved training arguments to {checkpoint_path}.")
model.eval() # important! This disables randomized embedding dropout
# do any sampling/FID calculation/etc. with ema (or model) in eval mode ...
logger.info("Done!")
cleanup()
if __name__ == "__main__":
# Default args here will train DiT_Llama2_7B_patch2 with the
# hyperparameters we used in our paper (except training iters).
parser = argparse.ArgumentParser()
parser.add_argument("--data_path", type=str, required=True)
parser.add_argument("--cache_data_on_disk", default=False, action="store_true")
parser.add_argument("--results_dir", type=str, required=True)
parser.add_argument("--model", type=str, default="DiT_Llama2_7B_patch2")
parser.add_argument("--max_steps", type=int, default=100_000, help="Number of training steps.")
parser.add_argument("--global_bsz_1024", type=int, default=256)
parser.add_argument("--micro_bsz_1024", type=int, default=1)
parser.add_argument("--global_seed", type=int, default=0)
parser.add_argument("--num_workers", type=int, default=4)
parser.add_argument("--log_every", type=int, default=100)
parser.add_argument("--ckpt_every", type=int, default=50_000)
parser.add_argument("--master_port", type=int, default=18181)
parser.add_argument("--model_parallel_size", type=int, default=1)
parser.add_argument("--data_parallel", type=str, choices=["sdp", "fsdp"], default="fsdp")
parser.add_argument("--checkpointing", action="store_true")
parser.add_argument("--precision", choices=["fp32", "tf32", "fp16", "bf16"], default="bf16")
parser.add_argument("--grad_precision", choices=["fp32", "fp16", "bf16"])
parser.add_argument("--lr", type=float, default=1e-4, help="Learning rate.")
parser.add_argument(
"--no_auto_resume",
action="store_false",
dest="auto_resume",
help="Do NOT auto resume from the last checkpoint in --results_dir.",
)
parser.add_argument("--resume", type=str, help="Resume training from a checkpoint folder.")
parser.add_argument(
"--init_from",
type=str,
help="Initialize the model weights from a checkpoint folder. "
"Compared to --resume, this loads neither the optimizer states "
"nor the data loader states.",
)
parser.add_argument(
"--grad_clip", type=float, default=2.0, help="Clip the L2 norm of the gradients to the given value."
)
parser.add_argument(
"--wd",
type=float,
default=0.0,
help="Weight decay for the optimizer.",
)
parser.add_argument(
"--qk_norm",
action="store_true",
)
parser.add_argument(
"--caption_dropout_prob",
type=float,
default=0.1,
help="Randomly change the caption of a sample to a blank string with the given probability.",
)
parser.add_argument("--snr_type", type=str, default="uniform")
parser.add_argument(
"--no_shift",
action="store_true",
)
args = parser.parse_args()
main(args)