From c4b0bb6fce70c12fd63d706154783d56fe3ed9ab Mon Sep 17 00:00:00 2001
From: rockerBOO
Date: Thu, 23 Jan 2025 10:39:01 -0500
Subject: [PATCH 1/3] Add pin_memory to DataLoader and update ImageInfo to
 support pin_memory

---
 fine_tune.py                          |  1 +
 finetune/make_captions.py             |  1 +
 finetune/make_captions_by_git.py      |  1 +
 finetune/prepare_buckets_latents.py   |  1 +
 finetune/tag_images_by_wd14_tagger.py |  1 +
 flux_train.py                         |  1 +
 flux_train_control_net.py             |  1 +
 library/train_util.py                 | 27 +++++++++++++++++++++++++++
 sd3_train.py                          |  1 +
 sdxl_train.py                         |  1 +
 sdxl_train_control_net.py             |  1 +
 sdxl_train_control_net_lllite.py      |  1 +
 sdxl_train_control_net_lllite_old.py  |  1 +
 train_db.py                           |  1 +
 train_network.py                      |  1 +
 train_textual_inversion.py            |  1 +
 train_textual_inversion_XTI.py        |  1 +
 17 files changed, 43 insertions(+)

diff --git a/fine_tune.py b/fine_tune.py
index 176087065..7fbc58778 100644
--- a/fine_tune.py
+++ b/fine_tune.py
@@ -242,6 +242,7 @@ def fn_recursive_set_mem_eff(module: torch.nn.Module):
         shuffle=True,
         collate_fn=collator,
         num_workers=n_workers,
+        pin_memory=args.pin_memory,
         persistent_workers=args.persistent_data_loader_workers,
     )
 
diff --git a/finetune/make_captions.py b/finetune/make_captions.py
index 489bdbcce..ded9a7476 100644
--- a/finetune/make_captions.py
+++ b/finetune/make_captions.py
@@ -126,6 +126,7 @@ def run_batch(path_imgs):
         batch_size=args.batch_size,
         shuffle=False,
         num_workers=args.max_data_loader_n_workers,
+        pin_memory=args.pin_memory,
         collate_fn=collate_fn_remove_corrupted,
         drop_last=False,
     )
diff --git a/finetune/make_captions_by_git.py b/finetune/make_captions_by_git.py
index edeebadf3..babdaea52 100644
--- a/finetune/make_captions_by_git.py
+++ b/finetune/make_captions_by_git.py
@@ -113,6 +113,7 @@ def run_batch(path_imgs):
         dataset,
         batch_size=args.batch_size,
         shuffle=False,
+        pin_memory=args.pin_memory,
         num_workers=args.max_data_loader_n_workers,
         collate_fn=collate_fn_remove_corrupted,
         drop_last=False,
diff --git a/finetune/prepare_buckets_latents.py b/finetune/prepare_buckets_latents.py
index 019c737a6..77b829716 100644
--- a/finetune/prepare_buckets_latents.py
+++ b/finetune/prepare_buckets_latents.py
@@ -122,6 +122,7 @@ def process_batch(is_last):
         dataset,
         batch_size=1,
         shuffle=False,
+        pin_memory=args.pin_memory,
         num_workers=args.max_data_loader_n_workers,
         collate_fn=collate_fn_remove_corrupted,
         drop_last=False,
diff --git a/finetune/tag_images_by_wd14_tagger.py b/finetune/tag_images_by_wd14_tagger.py
index cbc3d2d6b..c786e8a6d 100644
--- a/finetune/tag_images_by_wd14_tagger.py
+++ b/finetune/tag_images_by_wd14_tagger.py
@@ -335,6 +335,7 @@ def run_batch(path_imgs):
         dataset,
         batch_size=args.batch_size,
         shuffle=False,
+        pin_memory=args.pin_memory,
         num_workers=args.max_data_loader_n_workers,
         collate_fn=collate_fn_remove_corrupted,
         drop_last=False,
diff --git a/flux_train.py b/flux_train.py
index fced3bef9..4e1b0b4a3 100644
--- a/flux_train.py
+++ b/flux_train.py
@@ -397,6 +397,7 @@ def train(args):
         shuffle=True,
         collate_fn=collator,
         num_workers=n_workers,
+        pin_memory=args.pin_memory,
         persistent_workers=args.persistent_data_loader_workers,
     )
 
diff --git a/flux_train_control_net.py b/flux_train_control_net.py
index 9d36a41d3..6a5151545 100644
--- a/flux_train_control_net.py
+++ b/flux_train_control_net.py
@@ -398,6 +398,7 @@ def train(args):
         shuffle=True,
         collate_fn=collator,
         num_workers=n_workers,
+        pin_memory=args.pin_memory,
         persistent_workers=args.persistent_data_loader_workers,
     )
 
diff --git a/library/train_util.py b/library/train_util.py
index 72b5b24db..1e6fe3b82 100644
--- a/library/train_util.py
+++ b/library/train_util.py
@@ -176,6 +176,19 @@ def __init__(self, image_key: str, num_repeats: int, caption: str, is_reg: bool,
         self.alpha_mask: Optional[torch.Tensor] = None  # alpha mask can be flipped in runtime
 
+    @staticmethod
+    def _pin_tensor(tensor):
+        return tensor.pin_memory() if tensor is not None else tensor
+
+    def pin_memory(self):
+        self.latents = self._pin_tensor(self.latents)
+        self.latents_flipped = self._pin_tensor(self.latents_flipped)
+        self.text_encoder_outputs1 = self._pin_tensor(self.text_encoder_outputs1)
+        self.text_encoder_outputs2 = self._pin_tensor(self.text_encoder_outputs2)
+        self.text_encoder_pool2 = self._pin_tensor(self.text_encoder_pool2)
+        self.alpha_mask = self._pin_tensor(self.alpha_mask)
+        return self
+
 
 class BucketManager:
     def __init__(self, no_upscale, max_reso, min_size, max_size, reso_steps) -> None:
@@ -2036,6 +2049,11 @@ def load_dreambooth_dir(subset: DreamBoothSubset):
 
         self.num_reg_images = num_reg_images
 
+    def pin_memory(self):
+        for key in self.image_data.keys():
+            if hasattr(self.image_data[key], 'pin_memory') and callable(self.image_data[key].pin_memory):
+                self.image_data[key].pin_memory()
+
 
 class FineTuningDataset(BaseDataset):
     def __init__(
@@ -3734,6 +3752,11 @@ def add_training_arguments(parser: argparse.ArgumentParser, support_dreambooth:
         action="store_true",
         help="persistent DataLoader workers (useful for reduce time gap between epoch, but may use more memory) / DataLoader のワーカーを持続させる (エポック間の時間差を少なくするのに有効だが、より多くのメモリを消費する可能性がある)",
     )
+    parser.add_argument(
+        "--pin_memory",
+        action="store_true",
+        help="Pin memory for faster GPU loading / GPU の読み込みを高速化するためのピンメモリ",
+    )
     parser.add_argument("--seed", type=int, default=None, help="random seed for training / 学習時の乱数のseed")
     parser.add_argument(
         "--gradient_checkpointing", action="store_true", help="enable gradient checkpointing / gradient checkpointingを有効にする"
@@ -6379,6 +6402,10 @@ def __call__(self, examples):
         dataset.set_current_step(self.current_step.value)
         return examples[0]
 
+    def pin_memory(self):
+        if hasattr(self.dataset, 'pin_memory') and callable(self.dataset.pin_memory):
+            self.dataset.pin_memory()
+
 
 class LossRecorder:
     def __init__(self):
diff --git a/sd3_train.py b/sd3_train.py
index 120455e7b..116e49886 100644
--- a/sd3_train.py
+++ b/sd3_train.py
@@ -498,6 +498,7 @@ def train(args):
         shuffle=True,
         collate_fn=collator,
         num_workers=n_workers,
+        pin_memory=args.pin_memory,
         persistent_workers=args.persistent_data_loader_workers,
     )
 
diff --git a/sdxl_train.py b/sdxl_train.py
index b9d529243..2b60ebba8 100644
--- a/sdxl_train.py
+++ b/sdxl_train.py
@@ -430,6 +430,7 @@ def fn_recursive_set_mem_eff(module: torch.nn.Module):
         shuffle=True,
         collate_fn=collator,
         num_workers=n_workers,
+        pin_memory=args.pin_memory,
         persistent_workers=args.persistent_data_loader_workers,
     )
 
diff --git a/sdxl_train_control_net.py b/sdxl_train_control_net.py
index ffbf03cab..32c9996a5 100644
--- a/sdxl_train_control_net.py
+++ b/sdxl_train_control_net.py
@@ -281,6 +281,7 @@ def unwrap_model(model):
         shuffle=True,
         collate_fn=collator,
         num_workers=n_workers,
+        pin_memory=args.pin_memory,
         persistent_workers=args.persistent_data_loader_workers,
     )
 
diff --git a/sdxl_train_control_net_lllite.py b/sdxl_train_control_net_lllite.py
index 365059b75..d74ed99f9 100644
--- a/sdxl_train_control_net_lllite.py
+++ b/sdxl_train_control_net_lllite.py
@@ -272,6 +272,7 @@ def train(args):
         shuffle=True,
         collate_fn=collator,
         num_workers=n_workers,
+        pin_memory=args.pin_memory,
         persistent_workers=args.persistent_data_loader_workers,
     )
 
diff --git a/sdxl_train_control_net_lllite_old.py b/sdxl_train_control_net_lllite_old.py
index 5b372befc..098f7f561 100644
--- a/sdxl_train_control_net_lllite_old.py
+++ b/sdxl_train_control_net_lllite_old.py
@@ -220,6 +220,7 @@ def train(args):
         shuffle=True,
         collate_fn=collator,
         num_workers=n_workers,
+        pin_memory=args.pin_memory,
         persistent_workers=args.persistent_data_loader_workers,
     )
 
diff --git a/train_db.py b/train_db.py
index ad21f8d1b..1b5ec198d 100644
--- a/train_db.py
+++ b/train_db.py
@@ -210,6 +210,7 @@ def train(args):
         shuffle=True,
         collate_fn=collator,
         num_workers=n_workers,
+        pin_memory=args.pin_memory,
         persistent_workers=args.persistent_data_loader_workers,
     )
 
diff --git a/train_network.py b/train_network.py
index 5e82b307c..7e1665d51 100644
--- a/train_network.py
+++ b/train_network.py
@@ -577,6 +577,7 @@ def train(self, args):
         shuffle=True,
         collate_fn=collator,
         num_workers=n_workers,
+        pin_memory=args.pin_memory,
         persistent_workers=args.persistent_data_loader_workers,
     )
 
diff --git a/train_textual_inversion.py b/train_textual_inversion.py
index 65da4859b..14a548a02 100644
--- a/train_textual_inversion.py
+++ b/train_textual_inversion.py
@@ -408,6 +408,7 @@ def train(self, args):
         shuffle=True,
         collate_fn=collator,
         num_workers=n_workers,
+        pin_memory=args.pin_memory,
         persistent_workers=args.persistent_data_loader_workers,
     )
 
diff --git a/train_textual_inversion_XTI.py b/train_textual_inversion_XTI.py
index 2a2b42310..f63dac862 100644
--- a/train_textual_inversion_XTI.py
+++ b/train_textual_inversion_XTI.py
@@ -316,6 +316,7 @@ def train(args):
         shuffle=True,
         collate_fn=collator,
         num_workers=n_workers,
+        pin_memory=args.pin_memory,
         persistent_workers=args.persistent_data_loader_workers,
     )

From 50d8daa7d8711b6a181909246aaca88b8411e080 Mon Sep 17 00:00:00 2001
From: rockerBOO
Date: Thu, 23 Jan 2025 11:02:29 -0500
Subject: [PATCH 2/3] Set Accelerate dataloader_config to non_blocking if
 pin_memory is enabled

---
 library/train_util.py | 5 ++++-
 1 file changed, 4 insertions(+), 1 deletion(-)

diff --git a/library/train_util.py b/library/train_util.py
index 1e6fe3b82..9711dd56d 100644
--- a/library/train_util.py
+++ b/library/train_util.py
@@ -23,7 +23,7 @@
     Tuple,
     Union
 )
-from accelerate import Accelerator, InitProcessGroupKwargs, DistributedDataParallelKwargs, PartialState
+from accelerate import Accelerator, InitProcessGroupKwargs, DistributedDataParallelKwargs, PartialState, DataLoaderConfiguration
 import glob
 import math
 import os
@@ -5299,6 +5299,8 @@ def prepare_accelerator(args: argparse.Namespace):
     kwargs_handlers = [i for i in kwargs_handlers if i is not None]
     deepspeed_plugin = deepspeed_utils.prepare_deepspeed_plugin(args)
 
+    dataloader_config = DataLoaderConfiguration(non_blocking=args.pin_memory)
+
     accelerator = Accelerator(
         gradient_accumulation_steps=args.gradient_accumulation_steps,
         mixed_precision=args.mixed_precision,
@@ -5307,6 +5309,7 @@ def prepare_accelerator(args: argparse.Namespace):
         kwargs_handlers=kwargs_handlers,
         dynamo_backend=dynamo_backend,
         deepspeed_plugin=deepspeed_plugin,
+        dataloader_config=dataloader_config,
     )
     print("accelerator device:", accelerator.device)
     return accelerator

From 03b35be3876eb8eece1858be5b855fefcec4179d Mon Sep 17 00:00:00 2001
From: rockerBOO
Date: Thu, 23 Jan 2025 12:45:37 -0500
Subject: [PATCH 3/3] Add pin_memory to finetune scripts

---
 finetune/make_captions.py             | 5 +++++
 finetune/make_captions_by_git.py      | 5 +++++
 finetune/prepare_buckets_latents.py   | 5 +++++
 finetune/tag_images_by_wd14_tagger.py | 5 +++++
 4 files changed, 20 insertions(+)

diff --git a/finetune/make_captions.py b/finetune/make_captions.py
index ded9a7476..cc9a1444f 100644
--- a/finetune/make_captions.py
+++ b/finetune/make_captions.py
@@ -188,6 +188,11 @@ def setup_parser() -> argparse.ArgumentParser:
         default=None,
         help="enable image reading by DataLoader with this number of workers (faster) / DataLoaderによる画像読み込みを有効にしてこのワーカー数を適用する(読み込みを高速化)",
     )
+    parser.add_argument(
+        "--pin_memory",
+        action="store_true",
+        help="Pin memory for faster GPU loading / GPU の読み込みを高速化するためのピンメモリ",
+    )
     parser.add_argument("--num_beams", type=int, default=1, help="num of beams in beam search /beam search時のビーム数(多いと精度が上がるが時間がかかる)")
     parser.add_argument("--top_p", type=float, default=0.9, help="top_p in Nucleus sampling / Nucleus sampling時のtop_p")
     parser.add_argument("--max_length", type=int, default=75, help="max length of caption / captionの最大長")
diff --git a/finetune/make_captions_by_git.py b/finetune/make_captions_by_git.py
index babdaea52..c4c61257d 100644
--- a/finetune/make_captions_by_git.py
+++ b/finetune/make_captions_by_git.py
@@ -165,6 +165,11 @@ def setup_parser() -> argparse.ArgumentParser:
         default=None,
         help="enable image reading by DataLoader with this number of workers (faster) / DataLoaderによる画像読み込みを有効にしてこのワーカー数を適用する(読み込みを高速化)",
     )
+    parser.add_argument(
+        "--pin_memory",
+        action="store_true",
+        help="Pin memory for faster GPU loading / GPU の読み込みを高速化するためのピンメモリ",
+    )
     parser.add_argument("--max_length", type=int, default=50, help="max length of caption / captionの最大長")
     parser.add_argument(
         "--remove_words",
diff --git a/finetune/prepare_buckets_latents.py b/finetune/prepare_buckets_latents.py
index 77b829716..ef536db0a 100644
--- a/finetune/prepare_buckets_latents.py
+++ b/finetune/prepare_buckets_latents.py
@@ -224,6 +224,11 @@ def setup_parser() -> argparse.ArgumentParser:
         default=None,
         help="enable image reading by DataLoader with this number of workers (faster) / DataLoaderによる画像読み込みを有効にしてこのワーカー数を適用する(読み込みを高速化)",
     )
+    parser.add_argument(
+        "--pin_memory",
+        action="store_true",
+        help="Pin memory for faster GPU loading / GPU の読み込みを高速化するためのピンメモリ",
+    )
     parser.add_argument(
         "--max_resolution",
         type=str,
diff --git a/finetune/tag_images_by_wd14_tagger.py b/finetune/tag_images_by_wd14_tagger.py
index c786e8a6d..6ed595dec 100644
--- a/finetune/tag_images_by_wd14_tagger.py
+++ b/finetune/tag_images_by_wd14_tagger.py
@@ -410,6 +410,11 @@ def setup_parser() -> argparse.ArgumentParser:
         default=None,
         help="enable image reading by DataLoader with this number of workers (faster) / DataLoaderによる画像読み込みを有効にしてこのワーカー数を適用する(読み込みを高速化)",
     )
+    parser.add_argument(
+        "--pin_memory",
+        action="store_true",
+        help="Pin memory for faster GPU loading / GPU の読み込みを高速化するためのピンメモリ",
+    )
     parser.add_argument(
         "--caption_extention",
         type=str,
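
Usage note (editor's sketch, not part of the patches): the two changes work as a pair. --pin_memory makes each DataLoader hand back batches in page-locked host RAM, and DataLoaderConfiguration(non_blocking=True) lets Accelerate issue the host-to-device copy asynchronously from that pinned memory. A minimal standalone illustration of the same pattern in plain PyTorch, assuming a CUDA device; the dataset, shapes, and batch size below are illustrative only:

    import torch
    from torch.utils.data import DataLoader, TensorDataset

    if __name__ == "__main__":
        # Illustrative data only: 256 fake latents of shape (4, 64, 64).
        dataset = TensorDataset(torch.randn(256, 4, 64, 64))

        # pin_memory=True is what --pin_memory toggles: batches are staged in
        # page-locked (pinned) host memory instead of ordinary pageable RAM.
        loader = DataLoader(dataset, batch_size=8, num_workers=2, pin_memory=True)

        device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
        for (batch,) in loader:
            # From pinned memory, non_blocking=True makes the host-to-device
            # copy asynchronous, so it can overlap with compute; this mirrors
            # DataLoaderConfiguration(non_blocking=True) on the Accelerate side.
            batch = batch.to(device, non_blocking=True)
            # ... training step would run here ...

Pinning only pays off when batches actually cross the PCIe bus every step; it costs extra host RAM, which is why the flag is opt-in.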