Add pin_memory to DataLoader and update ImageInfo to support it #1894

Draft · wants to merge 3 commits into base: sd3
1 change: 1 addition & 0 deletions fine_tune.py
@@ -242,6 +242,7 @@ def fn_recursive_set_mem_eff(module: torch.nn.Module):
shuffle=True,
collate_fn=collator,
num_workers=n_workers,
pin_memory=args.pin_memory,
persistent_workers=args.persistent_data_loader_workers,
)

6 changes: 6 additions & 0 deletions finetune/make_captions.py
@@ -126,6 +126,7 @@ def run_batch(path_imgs):
batch_size=args.batch_size,
shuffle=False,
num_workers=args.max_data_loader_n_workers,
pin_memory=args.pin_memory,
collate_fn=collate_fn_remove_corrupted,
drop_last=False,
)
@@ -187,6 +188,11 @@ def setup_parser() -> argparse.ArgumentParser:
default=None,
help="enable image reading by DataLoader with this number of workers (faster) / DataLoaderによる画像読み込みを有効にしてこのワーカー数を適用する(読み込みを高速化)",
)
parser.add_argument(
"--pin_memory",
action="store_true",
help="Pin memory for faster GPU loading / GPU の読み込みを高速化するためのピンメモリ",
)
parser.add_argument("--num_beams", type=int, default=1, help="num of beams in beam search /beam search時のビーム数(多いと精度が上がるが時間がかかる)")
parser.add_argument("--top_p", type=float, default=0.9, help="top_p in Nucleus sampling / Nucleus sampling時のtop_p")
parser.add_argument("--max_length", type=int, default=75, help="max length of caption / captionの最大長")
6 changes: 6 additions & 0 deletions finetune/make_captions_by_git.py
@@ -113,6 +113,7 @@ def run_batch(path_imgs):
dataset,
batch_size=args.batch_size,
shuffle=False,
pin_memory=args.pin_memory,
num_workers=args.max_data_loader_n_workers,
collate_fn=collate_fn_remove_corrupted,
drop_last=False,
@@ -164,6 +165,11 @@ def setup_parser() -> argparse.ArgumentParser:
default=None,
help="enable image reading by DataLoader with this number of workers (faster) / DataLoaderによる画像読み込みを有効にしてこのワーカー数を適用する(読み込みを高速化)",
)
parser.add_argument(
"--pin_memory",
action="store_true",
help="Pin memory for faster GPU loading / GPU の読み込みを高速化するためのピンメモリ",
)
parser.add_argument("--max_length", type=int, default=50, help="max length of caption / captionの最大長")
parser.add_argument(
"--remove_words",
6 changes: 6 additions & 0 deletions finetune/prepare_buckets_latents.py
@@ -122,6 +122,7 @@ def process_batch(is_last):
dataset,
batch_size=1,
shuffle=False,
pin_memory=args.pin_memory,
num_workers=args.max_data_loader_n_workers,
collate_fn=collate_fn_remove_corrupted,
drop_last=False,
@@ -223,6 +224,11 @@ def setup_parser() -> argparse.ArgumentParser:
default=None,
help="enable image reading by DataLoader with this number of workers (faster) / DataLoaderによる画像読み込みを有効にしてこのワーカー数を適用する(読み込みを高速化)",
)
parser.add_argument(
"--pin_memory",
action="store_true",
help="Pin memory for faster GPU loading / GPU の読み込みを高速化するためのピンメモリ",
)
parser.add_argument(
"--max_resolution",
type=str,
6 changes: 6 additions & 0 deletions finetune/tag_images_by_wd14_tagger.py
@@ -335,6 +335,7 @@ def run_batch(path_imgs):
dataset,
batch_size=args.batch_size,
shuffle=False,
pin_memory=args.pin_memory,
num_workers=args.max_data_loader_n_workers,
collate_fn=collate_fn_remove_corrupted,
drop_last=False,
@@ -409,6 +410,11 @@ def setup_parser() -> argparse.ArgumentParser:
default=None,
help="enable image reading by DataLoader with this number of workers (faster) / DataLoaderによる画像読み込みを有効にしてこのワーカー数を適用する(読み込みを高速化)",
)
parser.add_argument(
"--pin_memory",
action="store_true",
help="Pin memory for faster GPU loading / GPU の読み込みを高速化するためのピンメモリ",
)
parser.add_argument(
"--caption_extention",
type=str,
1 change: 1 addition & 0 deletions flux_train.py
@@ -397,6 +397,7 @@ def train(args):
shuffle=True,
collate_fn=collator,
num_workers=n_workers,
pin_memory=args.pin_memory,
persistent_workers=args.persistent_data_loader_workers,
)

1 change: 1 addition & 0 deletions flux_train_control_net.py
@@ -398,6 +398,7 @@ def train(args):
shuffle=True,
collate_fn=collator,
num_workers=n_workers,
pin_memory=args.pin_memory,
persistent_workers=args.persistent_data_loader_workers,
)

32 changes: 31 additions & 1 deletion library/train_util.py
@@ -23,7 +23,7 @@
Tuple,
Union
)
from accelerate import Accelerator, InitProcessGroupKwargs, DistributedDataParallelKwargs, PartialState
from accelerate import Accelerator, InitProcessGroupKwargs, DistributedDataParallelKwargs, PartialState, DataLoaderConfiguration
import glob
import math
import os
@@ -176,6 +176,19 @@ def __init__(self, image_key: str, num_repeats: int, caption: str, is_reg: bool,

self.alpha_mask: Optional[torch.Tensor] = None # alpha mask can be flipped in runtime

@staticmethod
def _pin_tensor(tensor):
return tensor.pin_memory() if tensor is not None else tensor

def pin_memory(self):
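# pin cached tensors into page-locked host memory so GPU copies can be non-blocking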
self.latents = self._pin_tensor(self.latents)
self.latents_flipped = self._pin_tensor(self.latents_flipped)
self.text_encoder_outputs1 = self._pin_tensor(self.text_encoder_outputs1)
self.text_encoder_outputs2 = self._pin_tensor(self.text_encoder_outputs2)
self.text_encoder_pool2 = self._pin_tensor(self.text_encoder_pool2)
self.alpha_mask = self._pin_tensor(self.alpha_mask)
return self


class BucketManager:
def __init__(self, no_upscale, max_reso, min_size, max_size, reso_steps) -> None:
@@ -2036,6 +2049,11 @@ def load_dreambooth_dir(subset: DreamBoothSubset):

self.num_reg_images = num_reg_images

def pin_memory(self):
    for image_info in self.image_data.values():
        if hasattr(image_info, 'pin_memory') and callable(image_info.pin_memory):
            image_info.pin_memory()


class FineTuningDataset(BaseDataset):
def __init__(
@@ -3734,6 +3752,11 @@ def add_training_arguments(parser: argparse.ArgumentParser, support_dreambooth:
action="store_true",
help="persistent DataLoader workers (useful for reduce time gap between epoch, but may use more memory) / DataLoader のワーカーを持続させる (エポック間の時間差を少なくするのに有効だが、より多くのメモリを消費する可能性がある)",
)
parser.add_argument(
"--pin_memory",
action="store_true",
help="Pin memory for faster GPU loading / GPU の読み込みを高速化するためのピンメモリ",
)
parser.add_argument("--seed", type=int, default=None, help="random seed for training / 学習時の乱数のseed")
parser.add_argument(
"--gradient_checkpointing", action="store_true", help="enable gradient checkpointing / gradient checkpointingを有効にする"
@@ -5276,6 +5299,8 @@ def prepare_accelerator(args: argparse.Namespace):
kwargs_handlers = [i for i in kwargs_handlers if i is not None]
deepspeed_plugin = deepspeed_utils.prepare_deepspeed_plugin(args)

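# non_blocking device transfers only overlap with compute when the source host memory is pinned, hence reuse of args.pin_memory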
dataloader_config = DataLoaderConfiguration(non_blocking=args.pin_memory)

accelerator = Accelerator(
gradient_accumulation_steps=args.gradient_accumulation_steps,
mixed_precision=args.mixed_precision,
@@ -5284,6 +5309,7 @@ def prepare_accelerator(args: argparse.Namespace):
kwargs_handlers=kwargs_handlers,
dynamo_backend=dynamo_backend,
deepspeed_plugin=deepspeed_plugin,
dataloader_config=dataloader_config,
)
print("accelerator device:", accelerator.device)
return accelerator
@@ -6379,6 +6405,10 @@ def __call__(self, examples):
dataset.set_current_step(self.current_step.value)
return examples[0]

def pin_memory(self):
    # delegate pinning to the wrapped dataset (see BaseDataset.pin_memory above)
    if hasattr(self.dataset, 'pin_memory') and callable(self.dataset.pin_memory):
        self.dataset.pin_memory()


class LossRecorder:
def __init__(self):
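Background for the train_util.py additions (editor's sketch, not part of the diff): when a DataLoader is created with pin_memory=True, PyTorch's pin-memory worker walks each yielded batch and, for any custom object that defines a pin_memory() method, pins it by calling that method; the methods added above follow that convention. A minimal, self-contained illustration of the mechanism, using the hypothetical names Sample and TinyDataset (pinning is a silent no-op on machines without CUDA):

import torch
from torch.utils.data import DataLoader, Dataset

class Sample:
    # custom batch element with its own pin_memory(), mirroring ImageInfo
    def __init__(self, latents: torch.Tensor):
        self.latents = latents

    def pin_memory(self):
        # move the tensor into page-locked host memory for fast async copies
        self.latents = self.latents.pin_memory()
        return self

class TinyDataset(Dataset):
    def __len__(self):
        return 4

    def __getitem__(self, idx):
        return Sample(torch.randn(8))

loader = DataLoader(
    TinyDataset(),
    batch_size=2,
    pin_memory=True,                 # ask the DataLoader to pin each batch
    collate_fn=lambda batch: batch,  # keep Sample objects as a plain list
)
for batch in loader:
    # with CUDA available, each element was pinned via Sample.pin_memory()
    print([s.latents.is_pinned() for s in batch])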
1 change: 1 addition & 0 deletions sd3_train.py
@@ -498,6 +498,7 @@ def train(args):
shuffle=True,
collate_fn=collator,
num_workers=n_workers,
pin_memory=args.pin_memory,
persistent_workers=args.persistent_data_loader_workers,
)

1 change: 1 addition & 0 deletions sdxl_train.py
@@ -430,6 +430,7 @@ def fn_recursive_set_mem_eff(module: torch.nn.Module):
shuffle=True,
collate_fn=collator,
num_workers=n_workers,
pin_memory=args.pin_memory,
persistent_workers=args.persistent_data_loader_workers,
)

1 change: 1 addition & 0 deletions sdxl_train_control_net.py
@@ -281,6 +281,7 @@ def unwrap_model(model):
shuffle=True,
collate_fn=collator,
num_workers=n_workers,
pin_memory=args.pin_memory,
persistent_workers=args.persistent_data_loader_workers,
)

1 change: 1 addition & 0 deletions sdxl_train_control_net_lllite.py
@@ -272,6 +272,7 @@ def train(args):
shuffle=True,
collate_fn=collator,
num_workers=n_workers,
pin_memory=args.pin_memory,
persistent_workers=args.persistent_data_loader_workers,
)

1 change: 1 addition & 0 deletions sdxl_train_control_net_lllite_old.py
@@ -220,6 +220,7 @@ def train(args):
shuffle=True,
collate_fn=collator,
num_workers=n_workers,
pin_memory=args.pin_memory,
persistent_workers=args.persistent_data_loader_workers,
)

1 change: 1 addition & 0 deletions train_db.py
@@ -210,6 +210,7 @@ def train(args):
shuffle=True,
collate_fn=collator,
num_workers=n_workers,
pin_memory=args.pin_memory,
persistent_workers=args.persistent_data_loader_workers,
)

1 change: 1 addition & 0 deletions train_network.py
@@ -577,6 +577,7 @@ def train(self, args):
shuffle=True,
collate_fn=collator,
num_workers=n_workers,
pin_memory=args.pin_memory,
persistent_workers=args.persistent_data_loader_workers,
)

1 change: 1 addition & 0 deletions train_textual_inversion.py
@@ -408,6 +408,7 @@ def train(self, args):
shuffle=True,
collate_fn=collator,
num_workers=n_workers,
pin_memory=args.pin_memory,
persistent_workers=args.persistent_data_loader_workers,
)

1 change: 1 addition & 0 deletions train_textual_inversion_XTI.py
@@ -316,6 +316,7 @@ def train(args):
shuffle=True,
collate_fn=collator,
num_workers=n_workers,
pin_memory=args.pin_memory,
persistent_workers=args.persistent_data_loader_workers,
)

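Usage sketch (editor's note, not part of the diff): the new flag is opt-in wherever it was added, for example:

python train_network.py --pin_memory --persistent_data_loader_workers [usual training arguments]
python finetune/make_captions.py --batch_size 8 --max_data_loader_n_workers 4 --pin_memory [image directory]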