From f99fe281cbb6519b7b5f1199c570d496ad4df474 Mon Sep 17 00:00:00 2001 From: rockerBOO Date: Mon, 1 Apr 2024 15:38:26 -0400 Subject: [PATCH 01/97] Add LoRA+ support --- library/train_util.py | 2 ++ networks/dylora.py | 45 ++++++++++++++++++++++++++---------- networks/lora.py | 54 ++++++++++++++++++++++++++++--------------- train_network.py | 2 +- 4 files changed, 71 insertions(+), 32 deletions(-) diff --git a/library/train_util.py b/library/train_util.py index d2b69edb5..4e5ab7370 100644 --- a/library/train_util.py +++ b/library/train_util.py @@ -2789,6 +2789,8 @@ def add_optimizer_arguments(parser: argparse.ArgumentParser): default=1, help="Polynomial power for polynomial scheduler / polynomialスケジューラでのpolynomial power", ) + parser.add_argument("--loraplus_unet_lr_ratio", default=None, type=float, help="LoRA+ UNet learning rate ratio") + parser.add_argument("--loraplus_text_encoder_lr_ratio", default=None, type=float, help="LoRA+ text encoder learning rate ratio") def add_training_arguments(parser: argparse.ArgumentParser, support_dreambooth: bool): diff --git a/networks/dylora.py b/networks/dylora.py index 637f33450..a73ade8bd 100644 --- a/networks/dylora.py +++ b/networks/dylora.py @@ -406,27 +406,48 @@ def merge_to(self, text_encoder, unet, weights_sd, dtype, device): logger.info(f"weights are merged") """ - def prepare_optimizer_params(self, text_encoder_lr, unet_lr, default_lr): + # 二つのText Encoderに別々の学習率を設定できるようにするといいかも + def prepare_optimizer_params(self, text_encoder_lr, unet_lr, default_lr, unet_lora_plus_ratio=None, text_encoder_lora_plus_ratio=None): self.requires_grad_(True) all_params = [] - def enumerate_params(loras): - params = [] + def assemble_params(loras, lr, lora_plus_ratio): + param_groups = {"lora": {}, "plus": {}} for lora in loras: - params.extend(lora.parameters()) + for name, param in lora.named_parameters(): + if lora_plus_ratio is not None and "lora_up" in name: + param_groups["plus"][f"{lora.lora_name}.{name}"] = param + else: + param_groups["lora"][f"{lora.lora_name}.{name}"] = param + + # assigned_param_groups = "" + # for group in param_groups: + # assigned_param_groups += f"{group}\n {list(param_groups[group].keys())}\n\n" + # logger.info(assigned_param_groups) + + params = [] + for key in param_groups.keys(): + param_data = {"params": param_groups[key].values()} + if lr is not None: + if key == "plus": + param_data["lr"] = lr * lora_plus_ratio + else: + param_data["lr"] = lr + + if ("lr" in param_data) and (param_data["lr"] == 0): + continue + + params.append(param_data) + return params if self.text_encoder_loras: - param_data = {"params": enumerate_params(self.text_encoder_loras)} - if text_encoder_lr is not None: - param_data["lr"] = text_encoder_lr - all_params.append(param_data) + params = assemble_params(self.text_encoder_loras, text_encoder_lr, text_encoder_lora_plus_ratio) + all_params.extend(params) if self.unet_loras: - param_data = {"params": enumerate_params(self.unet_loras)} - if unet_lr is not None: - param_data["lr"] = unet_lr - all_params.append(param_data) + params = assemble_params(self.unet_loras, unet_lr, unet_lora_plus_ratio) + all_params.extend(params) return all_params diff --git a/networks/lora.py b/networks/lora.py index 948b30b0e..8d7619777 100644 --- a/networks/lora.py +++ b/networks/lora.py @@ -1035,21 +1035,43 @@ def get_lr_weight(self, lora: LoRAModule) -> float: return lr_weight # 二つのText Encoderに別々の学習率を設定できるようにするといいかも - def prepare_optimizer_params(self, text_encoder_lr, unet_lr, default_lr): + def 
prepare_optimizer_params(self, text_encoder_lr, unet_lr, default_lr, unet_lora_plus_ratio=None, text_encoder_lora_plus_ratio=None): self.requires_grad_(True) all_params = [] - def enumerate_params(loras): - params = [] + def assemble_params(loras, lr, lora_plus_ratio): + param_groups = {"lora": {}, "plus": {}} for lora in loras: - params.extend(lora.parameters()) + for name, param in lora.named_parameters(): + if lora_plus_ratio is not None and "lora_up" in name: + param_groups["plus"][f"{lora.lora_name}.{name}"] = param + else: + param_groups["lora"][f"{lora.lora_name}.{name}"] = param + + # assigned_param_groups = "" + # for group in param_groups: + # assigned_param_groups += f"{group}\n {list(param_groups[group].keys())}\n\n" + # logger.info(assigned_param_groups) + + params = [] + for key in param_groups.keys(): + param_data = {"params": param_groups[key].values()} + if lr is not None: + if key == "plus": + param_data["lr"] = lr * lora_plus_ratio + else: + param_data["lr"] = lr + + if ("lr" in param_data) and (param_data["lr"] == 0): + continue + + params.append(param_data) + return params if self.text_encoder_loras: - param_data = {"params": enumerate_params(self.text_encoder_loras)} - if text_encoder_lr is not None: - param_data["lr"] = text_encoder_lr - all_params.append(param_data) + params = assemble_params(self.text_encoder_loras, text_encoder_lr, text_encoder_lora_plus_ratio) + all_params.extend(params) if self.unet_loras: if self.block_lr: @@ -1063,21 +1085,15 @@ def enumerate_params(loras): # blockごとにパラメータを設定する for idx, block_loras in block_idx_to_lora.items(): - param_data = {"params": enumerate_params(block_loras)} - if unet_lr is not None: - param_data["lr"] = unet_lr * self.get_lr_weight(block_loras[0]) + params = assemble_params(block_loras, unet_lr * self.get_lr_weight(block_loras[0]), unet_lora_plus_ratio) elif default_lr is not None: - param_data["lr"] = default_lr * self.get_lr_weight(block_loras[0]) - if ("lr" in param_data) and (param_data["lr"] == 0): - continue - all_params.append(param_data) + params = assemble_params(block_loras, default_lr * self.get_lr_weight(block_loras[0]), unet_lora_plus_ratio) + all_params.extend(params) else: - param_data = {"params": enumerate_params(self.unet_loras)} - if unet_lr is not None: - param_data["lr"] = unet_lr - all_params.append(param_data) + params = assemble_params(self.unet_loras, unet_lr, unet_lora_plus_ratio) + all_params.extend(params) return all_params diff --git a/train_network.py b/train_network.py index e0fa69458..ba0c124d1 100644 --- a/train_network.py +++ b/train_network.py @@ -339,7 +339,7 @@ def train(self, args): # 後方互換性を確保するよ try: - trainable_params = network.prepare_optimizer_params(args.text_encoder_lr, args.unet_lr, args.learning_rate) + trainable_params = network.prepare_optimizer_params(args.text_encoder_lr, args.unet_lr, args.learning_rate, args.loraplus_text_encoder_lr_ratio, args.loraplus_unet_lr_ratio) except TypeError: accelerator.print( "Deprecated: use prepare_optimizer_params(text_encoder_lr, unet_lr, learning_rate) instead of prepare_optimizer_params(text_encoder_lr, unet_lr)" From c7691607ea1647864b5149c98434a27f23386c65 Mon Sep 17 00:00:00 2001 From: rockerBOO Date: Mon, 1 Apr 2024 15:43:04 -0400 Subject: [PATCH 02/97] Add LoRA-FA for LoRA+ --- networks/lora_fa.py | 58 +++++++++++++++++++++++++++++---------------- 1 file changed, 38 insertions(+), 20 deletions(-) diff --git a/networks/lora_fa.py b/networks/lora_fa.py index 919222ce8..fcc503e89 100644 --- a/networks/lora_fa.py +++ 
b/networks/lora_fa.py @@ -1033,22 +1033,43 @@ def get_lr_weight(self, lora: LoRAModule) -> float: return lr_weight # 二つのText Encoderに別々の学習率を設定できるようにするといいかも - def prepare_optimizer_params(self, text_encoder_lr, unet_lr, default_lr): + def prepare_optimizer_params(self, text_encoder_lr, unet_lr, default_lr, , unet_lora_plus_ratio=None, text_encoder_lora_plus_ratio=None): self.requires_grad_(True) all_params = [] - def enumerate_params(loras: List[LoRAModule]): - params = [] + def assemble_params(loras: List[LoRAModule], lr, lora_plus_ratio): + param_groups = {"lora": {}, "plus": {}} for lora in loras: - # params.extend(lora.parameters()) - params.extend(lora.get_trainable_params()) + for name, param in lora.get_trainable_named_params(): + if lora_plus_ratio is not None and "lora_up" in name: + param_groups["plus"][f"{lora.lora_name}.{name}"] = param + else: + param_groups["lora"][f"{lora.lora_name}.{name}"] = param + + # assigned_param_groups = "" + # for group in param_groups: + # assigned_param_groups += f"{group}\n {list(param_groups[group].keys())}\n\n" + # logger.info(assigned_param_groups) + + params = [] + for key in param_groups.keys(): + param_data = {"params": param_groups[key].values()} + if lr is not None: + if key == "plus": + param_data["lr"] = lr * lora_plus_ratio + else: + param_data["lr"] = lr + + if ("lr" in param_data) and (param_data["lr"] == 0): + continue + + params.append(param_data) + return params if self.text_encoder_loras: - param_data = {"params": enumerate_params(self.text_encoder_loras)} - if text_encoder_lr is not None: - param_data["lr"] = text_encoder_lr - all_params.append(param_data) + params = assemble_params(self.text_encoder_loras, text_encoder_lr, text_encoder_lora_plus_ratio) + all_params.extend(params) if self.unet_loras: if self.block_lr: @@ -1062,21 +1083,15 @@ def enumerate_params(loras: List[LoRAModule]): # blockごとにパラメータを設定する for idx, block_loras in block_idx_to_lora.items(): - param_data = {"params": enumerate_params(block_loras)} - if unet_lr is not None: - param_data["lr"] = unet_lr * self.get_lr_weight(block_loras[0]) + params = assemble_params(block_loras, unet_lr * self.get_lr_weight(block_loras[0]), unet_lora_plus_ratio) elif default_lr is not None: - param_data["lr"] = default_lr * self.get_lr_weight(block_loras[0]) - if ("lr" in param_data) and (param_data["lr"] == 0): - continue - all_params.append(param_data) + params = assemble_params(block_loras, default_lr * self.get_lr_weight(block_loras[0]), unet_lora_plus_ratio) + all_params.extend(params) else: - param_data = {"params": enumerate_params(self.unet_loras)} - if unet_lr is not None: - param_data["lr"] = unet_lr - all_params.append(param_data) + params = assemble_params(self.unet_loras, unet_lr, unet_lora_plus_ratio) + all_params.extend(params) return all_params @@ -1093,6 +1108,9 @@ def on_epoch_start(self, text_encoder, unet): def get_trainable_params(self): return self.parameters() + def get_trainable_named_params(self): + return self.named_parameters() + def save_weights(self, file, dtype, metadata): if metadata is not None and len(metadata) == 0: metadata = None From 1933ab4b4848b1f8b578c10f25bd050f5e246ac0 Mon Sep 17 00:00:00 2001 From: rockerBOO Date: Wed, 3 Apr 2024 12:46:34 -0400 Subject: [PATCH 03/97] Fix default_lr being applied --- networks/dylora.py | 21 ++++++++++++++++++--- networks/lora.py | 30 +++++++++++++++++++++++------- networks/lora_fa.py | 30 +++++++++++++++++++++++------- 3 files changed, 64 insertions(+), 17 deletions(-) diff --git a/networks/dylora.py 
b/networks/dylora.py index a73ade8bd..edc3e2229 100644 --- a/networks/dylora.py +++ b/networks/dylora.py @@ -407,7 +407,14 @@ def merge_to(self, text_encoder, unet, weights_sd, dtype, device): """ # 二つのText Encoderに別々の学習率を設定できるようにするといいかも - def prepare_optimizer_params(self, text_encoder_lr, unet_lr, default_lr, unet_lora_plus_ratio=None, text_encoder_lora_plus_ratio=None): + def prepare_optimizer_params( + self, + text_encoder_lr, + unet_lr, + default_lr, + unet_lora_plus_ratio=None, + text_encoder_lora_plus_ratio=None + ): self.requires_grad_(True) all_params = [] @@ -442,11 +449,19 @@ def assemble_params(loras, lr, lora_plus_ratio): return params if self.text_encoder_loras: - params = assemble_params(self.text_encoder_loras, text_encoder_lr, text_encoder_lora_plus_ratio) + params = assemble_params( + self.text_encoder_loras, + text_encoder_lr if text_encoder_lr is not None else default_lr, + text_encoder_lora_plus_ratio + ) all_params.extend(params) if self.unet_loras: - params = assemble_params(self.unet_loras, unet_lr, unet_lora_plus_ratio) + params = assemble_params( + self.unet_loras, + default_lr if unet_lr is None else unet_lr, + unet_lora_plus_ratio + ) all_params.extend(params) return all_params diff --git a/networks/lora.py b/networks/lora.py index 8d7619777..e082941e5 100644 --- a/networks/lora.py +++ b/networks/lora.py @@ -1035,7 +1035,14 @@ def get_lr_weight(self, lora: LoRAModule) -> float: return lr_weight # 二つのText Encoderに別々の学習率を設定できるようにするといいかも - def prepare_optimizer_params(self, text_encoder_lr, unet_lr, default_lr, unet_lora_plus_ratio=None, text_encoder_lora_plus_ratio=None): + def prepare_optimizer_params( + self, + text_encoder_lr, + unet_lr, + default_lr, + unet_lora_plus_ratio=None, + text_encoder_lora_plus_ratio=None + ): self.requires_grad_(True) all_params = [] @@ -1070,7 +1077,11 @@ def assemble_params(loras, lr, lora_plus_ratio): return params if self.text_encoder_loras: - params = assemble_params(self.text_encoder_loras, text_encoder_lr, text_encoder_lora_plus_ratio) + params = assemble_params( + self.text_encoder_loras, + text_encoder_lr if text_encoder_lr is not None else default_lr, + text_encoder_lora_plus_ratio + ) all_params.extend(params) if self.unet_loras: @@ -1085,14 +1096,19 @@ def assemble_params(loras, lr, lora_plus_ratio): # blockごとにパラメータを設定する for idx, block_loras in block_idx_to_lora.items(): - if unet_lr is not None: - params = assemble_params(block_loras, unet_lr * self.get_lr_weight(block_loras[0]), unet_lora_plus_ratio) - elif default_lr is not None: - params = assemble_params(block_loras, default_lr * self.get_lr_weight(block_loras[0]), unet_lora_plus_ratio) + params = assemble_params( + block_loras, + (unet_lr if unet_lr is not None else default_lr) * self.get_lr_weight(block_loras[0]), + unet_lora_plus_ratio + ) all_params.extend(params) else: - params = assemble_params(self.unet_loras, unet_lr, unet_lora_plus_ratio) + params = assemble_params( + self.unet_loras, + default_lr if unet_lr is None else unet_lr, + unet_lora_plus_ratio + ) all_params.extend(params) return all_params diff --git a/networks/lora_fa.py b/networks/lora_fa.py index fcc503e89..3f6774dd8 100644 --- a/networks/lora_fa.py +++ b/networks/lora_fa.py @@ -1033,7 +1033,14 @@ def get_lr_weight(self, lora: LoRAModule) -> float: return lr_weight # 二つのText Encoderに別々の学習率を設定できるようにするといいかも - def prepare_optimizer_params(self, text_encoder_lr, unet_lr, default_lr, , unet_lora_plus_ratio=None, text_encoder_lora_plus_ratio=None): + def prepare_optimizer_params( + self, + 
text_encoder_lr, + unet_lr, + default_lr, + unet_lora_plus_ratio=None, + text_encoder_lora_plus_ratio=None + ): self.requires_grad_(True) all_params = [] @@ -1068,7 +1075,11 @@ def assemble_params(loras: List[LoRAModule], lr, lora_plus_ratio): return params if self.text_encoder_loras: - params = assemble_params(self.text_encoder_loras, text_encoder_lr, text_encoder_lora_plus_ratio) + params = assemble_params( + self.text_encoder_loras, + text_encoder_lr if text_encoder_lr is not None else default_lr, + text_encoder_lora_plus_ratio + ) all_params.extend(params) if self.unet_loras: @@ -1083,14 +1094,19 @@ def assemble_params(loras: List[LoRAModule], lr, lora_plus_ratio): # blockごとにパラメータを設定する for idx, block_loras in block_idx_to_lora.items(): - if unet_lr is not None: - params = assemble_params(block_loras, unet_lr * self.get_lr_weight(block_loras[0]), unet_lora_plus_ratio) - elif default_lr is not None: - params = assemble_params(block_loras, default_lr * self.get_lr_weight(block_loras[0]), unet_lora_plus_ratio) + params = assemble_params( + block_loras, + (unet_lr if unet_lr is not None else default_lr) * self.get_lr_weight(block_loras[0]), + unet_lora_plus_ratio + ) all_params.extend(params) else: - params = assemble_params(self.unet_loras, unet_lr, unet_lora_plus_ratio) + params = assemble_params( + self.unet_loras, + default_lr if unet_lr is None else unet_lr, + unet_lora_plus_ratio + ) all_params.extend(params) return all_params From 75833e84a1c7e3c2fb0a9e3ce0fe3d8c1758a012 Mon Sep 17 00:00:00 2001 From: rockerBOO Date: Mon, 8 Apr 2024 19:23:02 -0400 Subject: [PATCH 04/97] Fix default LR, Add overall LoRA+ ratio, Add log `--loraplus_ratio` added for both TE and UNet Add log for lora+ --- library/train_util.py | 1 + networks/dylora.py | 24 ++++++------- networks/lora.py | 28 ++++++++-------- networks/lora_fa.py | 30 ++++++++--------- train_network.py | 78 ++++++++++++++++++++++++++++++++----------- 5 files changed, 101 insertions(+), 60 deletions(-) diff --git a/library/train_util.py b/library/train_util.py index 4e5ab7370..7c2bf6935 100644 --- a/library/train_util.py +++ b/library/train_util.py @@ -2789,6 +2789,7 @@ def add_optimizer_arguments(parser: argparse.ArgumentParser): default=1, help="Polynomial power for polynomial scheduler / polynomialスケジューラでのpolynomial power", ) + parser.add_argument("--loraplus_lr_ratio", default=None, type=float, help="LoRA+ learning rate ratio") parser.add_argument("--loraplus_unet_lr_ratio", default=None, type=float, help="LoRA+ UNet learning rate ratio") parser.add_argument("--loraplus_text_encoder_lr_ratio", default=None, type=float, help="LoRA+ text encoder learning rate ratio") diff --git a/networks/dylora.py b/networks/dylora.py index edc3e2229..dc5c7cb35 100644 --- a/networks/dylora.py +++ b/networks/dylora.py @@ -412,32 +412,32 @@ def prepare_optimizer_params( text_encoder_lr, unet_lr, default_lr, - unet_lora_plus_ratio=None, - text_encoder_lora_plus_ratio=None + unet_loraplus_ratio=None, + text_encoder_loraplus_ratio=None, + loraplus_ratio=None ): self.requires_grad_(True) all_params = [] - def assemble_params(loras, lr, lora_plus_ratio): + def assemble_params(loras, lr, ratio): param_groups = {"lora": {}, "plus": {}} for lora in loras: for name, param in lora.named_parameters(): - if lora_plus_ratio is not None and "lora_up" in name: + if ratio is not None and "lora_B" in name: param_groups["plus"][f"{lora.lora_name}.{name}"] = param else: param_groups["lora"][f"{lora.lora_name}.{name}"] = param - # assigned_param_groups = "" - # for group in 
param_groups: - # assigned_param_groups += f"{group}\n {list(param_groups[group].keys())}\n\n" - # logger.info(assigned_param_groups) - params = [] for key in param_groups.keys(): param_data = {"params": param_groups[key].values()} + + if len(param_data["params"]) == 0: + continue + if lr is not None: if key == "plus": - param_data["lr"] = lr * lora_plus_ratio + param_data["lr"] = lr * ratio else: param_data["lr"] = lr @@ -452,7 +452,7 @@ def assemble_params(loras, lr, lora_plus_ratio): params = assemble_params( self.text_encoder_loras, text_encoder_lr if text_encoder_lr is not None else default_lr, - text_encoder_lora_plus_ratio + text_encoder_loraplus_ratio or loraplus_ratio ) all_params.extend(params) @@ -460,7 +460,7 @@ def assemble_params(loras, lr, lora_plus_ratio): params = assemble_params( self.unet_loras, default_lr if unet_lr is None else unet_lr, - unet_lora_plus_ratio + unet_loraplus_ratio or loraplus_ratio ) all_params.extend(params) diff --git a/networks/lora.py b/networks/lora.py index e082941e5..6cb05bcb0 100644 --- a/networks/lora.py +++ b/networks/lora.py @@ -1040,32 +1040,32 @@ def prepare_optimizer_params( text_encoder_lr, unet_lr, default_lr, - unet_lora_plus_ratio=None, - text_encoder_lora_plus_ratio=None + unet_loraplus_ratio=None, + text_encoder_loraplus_ratio=None, + loraplus_ratio=None ): self.requires_grad_(True) all_params = [] - def assemble_params(loras, lr, lora_plus_ratio): + def assemble_params(loras, lr, ratio): param_groups = {"lora": {}, "plus": {}} for lora in loras: for name, param in lora.named_parameters(): - if lora_plus_ratio is not None and "lora_up" in name: + if ratio is not None and "lora_up" in name: param_groups["plus"][f"{lora.lora_name}.{name}"] = param else: param_groups["lora"][f"{lora.lora_name}.{name}"] = param - # assigned_param_groups = "" - # for group in param_groups: - # assigned_param_groups += f"{group}\n {list(param_groups[group].keys())}\n\n" - # logger.info(assigned_param_groups) - params = [] for key in param_groups.keys(): param_data = {"params": param_groups[key].values()} + + if len(param_data["params"]) == 0: + continue + if lr is not None: if key == "plus": - param_data["lr"] = lr * lora_plus_ratio + param_data["lr"] = lr * ratio else: param_data["lr"] = lr @@ -1080,7 +1080,7 @@ def assemble_params(loras, lr, lora_plus_ratio): params = assemble_params( self.text_encoder_loras, text_encoder_lr if text_encoder_lr is not None else default_lr, - text_encoder_lora_plus_ratio + text_encoder_loraplus_ratio or loraplus_ratio ) all_params.extend(params) @@ -1099,15 +1099,15 @@ def assemble_params(loras, lr, lora_plus_ratio): params = assemble_params( block_loras, (unet_lr if unet_lr is not None else default_lr) * self.get_lr_weight(block_loras[0]), - unet_lora_plus_ratio + unet_loraplus_ratio or loraplus_ratio ) all_params.extend(params) else: params = assemble_params( self.unet_loras, - default_lr if unet_lr is None else unet_lr, - unet_lora_plus_ratio + unet_lr if unet_lr is not None else default_lr, + unet_loraplus_ratio or loraplus_ratio ) all_params.extend(params) diff --git a/networks/lora_fa.py b/networks/lora_fa.py index 3f6774dd8..2eff86d6c 100644 --- a/networks/lora_fa.py +++ b/networks/lora_fa.py @@ -1038,32 +1038,32 @@ def prepare_optimizer_params( text_encoder_lr, unet_lr, default_lr, - unet_lora_plus_ratio=None, - text_encoder_lora_plus_ratio=None + unet_loraplus_ratio=None, + text_encoder_loraplus_ratio=None, + loraplus_ratio=None ): self.requires_grad_(True) all_params = [] - def assemble_params(loras: 
List[LoRAModule], lr, lora_plus_ratio): + def assemble_params(loras, lr, ratio): param_groups = {"lora": {}, "plus": {}} for lora in loras: - for name, param in lora.get_trainable_named_params(): - if lora_plus_ratio is not None and "lora_up" in name: + for name, param in lora.named_parameters(): + if ratio is not None and "lora_up" in name: param_groups["plus"][f"{lora.lora_name}.{name}"] = param else: param_groups["lora"][f"{lora.lora_name}.{name}"] = param - # assigned_param_groups = "" - # for group in param_groups: - # assigned_param_groups += f"{group}\n {list(param_groups[group].keys())}\n\n" - # logger.info(assigned_param_groups) - params = [] for key in param_groups.keys(): param_data = {"params": param_groups[key].values()} + + if len(param_data["params"]) == 0: + continue + if lr is not None: if key == "plus": - param_data["lr"] = lr * lora_plus_ratio + param_data["lr"] = lr * ratio else: param_data["lr"] = lr @@ -1078,7 +1078,7 @@ def assemble_params(loras: List[LoRAModule], lr, lora_plus_ratio): params = assemble_params( self.text_encoder_loras, text_encoder_lr if text_encoder_lr is not None else default_lr, - text_encoder_lora_plus_ratio + text_encoder_loraplus_ratio or loraplus_ratio ) all_params.extend(params) @@ -1097,15 +1097,15 @@ def assemble_params(loras: List[LoRAModule], lr, lora_plus_ratio): params = assemble_params( block_loras, (unet_lr if unet_lr is not None else default_lr) * self.get_lr_weight(block_loras[0]), - unet_lora_plus_ratio + unet_loraplus_ratio or loraplus_ratio ) all_params.extend(params) else: params = assemble_params( self.unet_loras, - default_lr if unet_lr is None else unet_lr, - unet_lora_plus_ratio + unet_lr if unet_lr is not None else default_lr, + unet_loraplus_ratio or loraplus_ratio ) all_params.extend(params) diff --git a/train_network.py b/train_network.py index ba0c124d1..43226fc47 100644 --- a/train_network.py +++ b/train_network.py @@ -66,34 +66,69 @@ def generate_step_logs( lrs = lr_scheduler.get_last_lr() - if args.network_train_text_encoder_only or len(lrs) <= 2: # not block lr (or single block) - if args.network_train_unet_only: - logs["lr/unet"] = float(lrs[0]) - elif args.network_train_text_encoder_only: - logs["lr/textencoder"] = float(lrs[0]) - else: - logs["lr/textencoder"] = float(lrs[0]) - logs["lr/unet"] = float(lrs[-1]) # may be same to textencoder - - if ( - args.optimizer_type.lower().startswith("DAdapt".lower()) or args.optimizer_type.lower() == "Prodigy".lower() - ): # tracking d*lr value of unet. 
- logs["lr/d*lr"] = ( - lr_scheduler.optimizers[-1].param_groups[0]["d"] * lr_scheduler.optimizers[-1].param_groups[0]["lr"] - ) - else: + if len(lrs) > 4: idx = 0 if not args.network_train_unet_only: logs["lr/textencoder"] = float(lrs[0]) idx = 1 for i in range(idx, len(lrs)): - logs[f"lr/group{i}"] = float(lrs[i]) + lora_plus = "" + group_id = i + + if args.loraplus_lr_ratio is not None or args.loraplus_unet_lr_ratio is not None: + lora_plus = '_lora+' if i % 2 == 1 else '' + group_id = int((i / 2) + (i % 2 + 0.5)) + + logs[f"lr/group{group_id}{lora_plus}"] = float(lrs[i]) if args.optimizer_type.lower().startswith("DAdapt".lower()) or args.optimizer_type.lower() == "Prodigy".lower(): - logs[f"lr/d*lr/group{i}"] = ( + logs[f"lr/d*lr/group{group_id}{lora_plus}"] = ( lr_scheduler.optimizers[-1].param_groups[i]["d"] * lr_scheduler.optimizers[-1].param_groups[i]["lr"] ) + else: + if args.network_train_text_encoder_only: + if args.loraplus_lr_ratio is not None or args.loraplus_text_encoder_lr_ratio is not None: + logs["lr/textencoder"] = float(lrs[0]) + logs["lr/textencoder_lora+"] = float(lrs[1]) + else: + logs["lr/textencoder"] = float(lrs[0]) + + elif args.network_train_unet_only: + if args.loraplus_lr_ratio is not None or args.loraplus_unet_lr_ratio is not None: + logs["lr/unet"] = float(lrs[0]) + logs["lr/unet_lora+"] = float(lrs[1]) + else: + logs["lr/unet"] = float(lrs[0]) + else: + if len(lrs) == 2: + if args.loraplus_text_encoder_lr_ratio is not None and args.loraplus_unet_lr_ratio is None: + logs["lr/textencoder"] = float(lrs[0]) + logs["lr/textencoder_lora+"] = float(lrs[1]) + elif args.loraplus_unet_lr_ratio is not None and args.loraplus_text_encoder_lr_ratio is None: + logs["lr/unet"] = float(lrs[0]) + logs["lr/unet_lora+"] = float(lrs[1]) + elif args.loraplus_unet_lr_ratio is None and args.loraplus_text_encoder_lr_ratio is None and args.loraplus_lr_ratio is not None: + logs["lr/all"] = float(lrs[0]) + logs["lr/all_lora+"] = float(lrs[1]) + else: + logs["lr/textencoder"] = float(lrs[0]) + logs["lr/unet"] = float(lrs[-1]) + elif len(lrs) == 4: + logs["lr/textencoder"] = float(lrs[0]) + logs["lr/textencoder_lora+"] = float(lrs[1]) + logs["lr/unet"] = float(lrs[2]) + logs["lr/unet_lora+"] = float(lrs[3]) + else: + logs["lr/all"] = float(lrs[0]) + + if ( + args.optimizer_type.lower().startswith("DAdapt".lower()) or args.optimizer_type.lower() == "Prodigy".lower() + ): # tracking d*lr value of unet. 
+ logs["lr/d*lr"] = ( + lr_scheduler.optimizers[-1].param_groups[0]["d"] * lr_scheduler.optimizers[-1].param_groups[0]["lr"] + ) + return logs def assert_extra_args(self, args, train_dataset_group): @@ -339,7 +374,7 @@ def train(self, args): # 後方互換性を確保するよ try: - trainable_params = network.prepare_optimizer_params(args.text_encoder_lr, args.unet_lr, args.learning_rate, args.loraplus_text_encoder_lr_ratio, args.loraplus_unet_lr_ratio) + trainable_params = network.prepare_optimizer_params(args.text_encoder_lr, args.unet_lr, args.learning_rate, args.loraplus_text_encoder_lr_ratio, args.loraplus_unet_lr_ratio, args.loraplus_lr_ratio) except TypeError: accelerator.print( "Deprecated: use prepare_optimizer_params(text_encoder_lr, unet_lr, learning_rate) instead of prepare_optimizer_params(text_encoder_lr, unet_lr)" @@ -348,6 +383,11 @@ def train(self, args): optimizer_name, optimizer_args, optimizer = train_util.get_optimizer(args, trainable_params) + if args.loraplus_lr_ratio is not None or args.loraplus_text_encoder_lr_ratio is not None or args.loraplus_unet_lr_ratio is not None: + assert ( + (optimizer_name != "Prodigy" and "DAdapt" not in optimizer_name) + ), "LoRA+ and Prodigy/DAdaptation is not supported" + # dataloaderを準備する # DataLoaderのプロセス数:0 は persistent_workers が使えないので注意 n_workers = min(args.max_data_loader_n_workers, os.cpu_count()) # cpu_count or max_data_loader_n_workers From 68467bdf4d76ba2c57289209b0ffd6ba599e2080 Mon Sep 17 00:00:00 2001 From: rockerBOO Date: Thu, 11 Apr 2024 17:33:19 -0400 Subject: [PATCH 05/97] Fix unset or invalid LR from making a param_group --- networks/dylora.py | 4 ++-- networks/lora.py | 5 +++-- networks/lora_fa.py | 4 ++-- 3 files changed, 7 insertions(+), 6 deletions(-) diff --git a/networks/dylora.py b/networks/dylora.py index dc5c7cb35..0546fc7ae 100644 --- a/networks/dylora.py +++ b/networks/dylora.py @@ -412,8 +412,8 @@ def prepare_optimizer_params( text_encoder_lr, unet_lr, default_lr, - unet_loraplus_ratio=None, text_encoder_loraplus_ratio=None, + unet_loraplus_ratio=None, loraplus_ratio=None ): self.requires_grad_(True) @@ -441,7 +441,7 @@ def assemble_params(loras, lr, ratio): else: param_data["lr"] = lr - if ("lr" in param_data) and (param_data["lr"] == 0): + if param_data.get("lr", None) == 0 or param_data.get("lr", None) is None: continue params.append(param_data) diff --git a/networks/lora.py b/networks/lora.py index 6cb05bcb0..d74608fea 100644 --- a/networks/lora.py +++ b/networks/lora.py @@ -1040,8 +1040,8 @@ def prepare_optimizer_params( text_encoder_lr, unet_lr, default_lr, - unet_loraplus_ratio=None, text_encoder_loraplus_ratio=None, + unet_loraplus_ratio=None, loraplus_ratio=None ): self.requires_grad_(True) @@ -1069,7 +1069,8 @@ def assemble_params(loras, lr, ratio): else: param_data["lr"] = lr - if ("lr" in param_data) and (param_data["lr"] == 0): + if param_data.get("lr", None) == 0 or param_data.get("lr", None) is None: + print("NO LR skipping!") continue params.append(param_data) diff --git a/networks/lora_fa.py b/networks/lora_fa.py index 2eff86d6c..9a608118a 100644 --- a/networks/lora_fa.py +++ b/networks/lora_fa.py @@ -1038,8 +1038,8 @@ def prepare_optimizer_params( text_encoder_lr, unet_lr, default_lr, - unet_loraplus_ratio=None, text_encoder_loraplus_ratio=None, + unet_loraplus_ratio=None, loraplus_ratio=None ): self.requires_grad_(True) @@ -1067,7 +1067,7 @@ def assemble_params(loras, lr, ratio): else: param_data["lr"] = lr - if ("lr" in param_data) and (param_data["lr"] == 0): + if param_data.get("lr", None) == 0 or 
param_data.get("lr", None) is None: continue params.append(param_data) From 4f203ce40d3a4647d52a2570a228e279dd04b321 Mon Sep 17 00:00:00 2001 From: 2kpr <96332338+2kpr@users.noreply.github.com> Date: Sun, 14 Apr 2024 09:56:58 -0500 Subject: [PATCH 06/97] Fused backward pass --- library/adafactor_fused.py | 106 +++++++++++++++++++++++++++++++++++++ library/train_util.py | 13 +++++ sdxl_train.py | 29 +++++++--- 3 files changed, 142 insertions(+), 6 deletions(-) create mode 100644 library/adafactor_fused.py diff --git a/library/adafactor_fused.py b/library/adafactor_fused.py new file mode 100644 index 000000000..bdfc32ced --- /dev/null +++ b/library/adafactor_fused.py @@ -0,0 +1,106 @@ +import math +import torch +from transformers import Adafactor + +@torch.no_grad() +def adafactor_step_param(self, p, group): + if p.grad is None: + return + grad = p.grad + if grad.dtype in {torch.float16, torch.bfloat16}: + grad = grad.float() + if grad.is_sparse: + raise RuntimeError("Adafactor does not support sparse gradients.") + + state = self.state[p] + grad_shape = grad.shape + + factored, use_first_moment = Adafactor._get_options(group, grad_shape) + # State Initialization + if len(state) == 0: + state["step"] = 0 + + if use_first_moment: + # Exponential moving average of gradient values + state["exp_avg"] = torch.zeros_like(grad) + if factored: + state["exp_avg_sq_row"] = torch.zeros(grad_shape[:-1]).to(grad) + state["exp_avg_sq_col"] = torch.zeros(grad_shape[:-2] + grad_shape[-1:]).to(grad) + else: + state["exp_avg_sq"] = torch.zeros_like(grad) + + state["RMS"] = 0 + else: + if use_first_moment: + state["exp_avg"] = state["exp_avg"].to(grad) + if factored: + state["exp_avg_sq_row"] = state["exp_avg_sq_row"].to(grad) + state["exp_avg_sq_col"] = state["exp_avg_sq_col"].to(grad) + else: + state["exp_avg_sq"] = state["exp_avg_sq"].to(grad) + + p_data_fp32 = p + if p.dtype in {torch.float16, torch.bfloat16}: + p_data_fp32 = p_data_fp32.float() + + state["step"] += 1 + state["RMS"] = Adafactor._rms(p_data_fp32) + lr = Adafactor._get_lr(group, state) + + beta2t = 1.0 - math.pow(state["step"], group["decay_rate"]) + update = (grad ** 2) + group["eps"][0] + if factored: + exp_avg_sq_row = state["exp_avg_sq_row"] + exp_avg_sq_col = state["exp_avg_sq_col"] + + exp_avg_sq_row.mul_(beta2t).add_(update.mean(dim=-1), alpha=(1.0 - beta2t)) + exp_avg_sq_col.mul_(beta2t).add_(update.mean(dim=-2), alpha=(1.0 - beta2t)) + + # Approximation of exponential moving average of square of gradient + update = Adafactor._approx_sq_grad(exp_avg_sq_row, exp_avg_sq_col) + update.mul_(grad) + else: + exp_avg_sq = state["exp_avg_sq"] + + exp_avg_sq.mul_(beta2t).add_(update, alpha=(1.0 - beta2t)) + update = exp_avg_sq.rsqrt().mul_(grad) + + update.div_((Adafactor._rms(update) / group["clip_threshold"]).clamp_(min=1.0)) + update.mul_(lr) + + if use_first_moment: + exp_avg = state["exp_avg"] + exp_avg.mul_(group["beta1"]).add_(update, alpha=(1 - group["beta1"])) + update = exp_avg + + if group["weight_decay"] != 0: + p_data_fp32.add_(p_data_fp32, alpha=(-group["weight_decay"] * lr)) + + p_data_fp32.add_(-update) + + if p.dtype in {torch.float16, torch.bfloat16}: + p.copy_(p_data_fp32) + + +@torch.no_grad() +def adafactor_step(self, closure=None): + """ + Performs a single optimization step + + Arguments: + closure (callable, optional): A closure that reevaluates the model + and returns the loss. 
+ """ + loss = None + if closure is not None: + loss = closure() + + for group in self.param_groups: + for p in group["params"]: + adafactor_step_param(self, p, group) + + return loss + +def patch_adafactor_fused(optimizer: Adafactor): + optimizer.step_param = adafactor_step_param.__get__(optimizer) + optimizer.step = adafactor_step.__get__(optimizer) diff --git a/library/train_util.py b/library/train_util.py index 15c23f3cc..46b55c03e 100644 --- a/library/train_util.py +++ b/library/train_util.py @@ -2920,6 +2920,11 @@ def add_optimizer_arguments(parser: argparse.ArgumentParser): default=1, help="Polynomial power for polynomial scheduler / polynomialスケジューラでのpolynomial power", ) + parser.add_argument( + "--fused_backward_pass", + action="store_true", + help="Combines backward pass and optimizer step to reduce VRAM usage / バックワードパスとオプティマイザステップを組み合わせてVRAMの使用量を削減します。", + ) def add_training_arguments(parser: argparse.ArgumentParser, support_dreambooth: bool): @@ -3846,6 +3851,14 @@ def get_optimizer(args, trainable_params): optimizer_type = "AdamW" optimizer_type = optimizer_type.lower() + if args.fused_backward_pass: + assert ( + optimizer_type == "Adafactor".lower() + ), "fused_backward_pass currently only works with optimizer_type Adafactor / fused_backward_passは現在optimizer_type Adafactorでのみ機能します" + assert ( + args.gradient_accumulation_steps == 1 + ), "fused_backward_pass does not work with gradient_accumulation_steps > 1 / fused_backward_passはgradient_accumulation_steps>1では機能しません" + # 引数を分解する optimizer_kwargs = {} if args.optimizer_args is not None and len(args.optimizer_args) > 0: diff --git a/sdxl_train.py b/sdxl_train.py index 46d7860be..3b28575ed 100644 --- a/sdxl_train.py +++ b/sdxl_train.py @@ -430,6 +430,20 @@ def fn_recursive_set_mem_eff(module: torch.nn.Module): text_encoder2 = accelerator.prepare(text_encoder2) optimizer, train_dataloader, lr_scheduler = accelerator.prepare(optimizer, train_dataloader, lr_scheduler) + if args.fused_backward_pass: + import library.adafactor_fused + library.adafactor_fused.patch_adafactor_fused(optimizer) + for param_group in optimizer.param_groups: + for parameter in param_group["params"]: + if parameter.requires_grad: + def __grad_hook(tensor: torch.Tensor, param_group=param_group): + if accelerator.sync_gradients and args.max_grad_norm != 0.0: + accelerator.clip_grad_norm_(tensor, args.max_grad_norm) + optimizer.step_param(tensor, param_group) + tensor.grad = None + + parameter.register_post_accumulate_grad_hook(__grad_hook) + # TextEncoderの出力をキャッシュするときにはCPUへ移動する if args.cache_text_encoder_outputs: # move Text Encoders for sampling images. 
Text Encoder doesn't work on CPU with fp16 @@ -619,13 +633,16 @@ def fn_recursive_set_mem_eff(module: torch.nn.Module): loss = train_util.conditional_loss(noise_pred.float(), target.float(), reduction="mean", loss_type=args.loss_type, huber_c=huber_c) accelerator.backward(loss) - if accelerator.sync_gradients and args.max_grad_norm != 0.0: - params_to_clip = [] - for m in training_models: - params_to_clip.extend(m.parameters()) - accelerator.clip_grad_norm_(params_to_clip, args.max_grad_norm) - optimizer.step() + if not args.fused_backward_pass: + if accelerator.sync_gradients and args.max_grad_norm != 0.0: + params_to_clip = [] + for m in training_models: + params_to_clip.extend(m.parameters()) + accelerator.clip_grad_norm_(params_to_clip, args.max_grad_norm) + + optimizer.step() + lr_scheduler.step() optimizer.zero_grad(set_to_none=True) From 64916a35b2378c4a8cdf3e9efeef8b8ab7ccb41c Mon Sep 17 00:00:00 2001 From: Zovjsra <4703michael@gmail.com> Date: Tue, 16 Apr 2024 16:40:08 +0800 Subject: [PATCH 07/97] add disable_mmap to args --- library/sdxl_model_util.py | 14 +++++++++----- library/sdxl_train_util.py | 9 +++++++-- 2 files changed, 16 insertions(+), 7 deletions(-) diff --git a/library/sdxl_model_util.py b/library/sdxl_model_util.py index f03f1bae5..e6fcb1f9c 100644 --- a/library/sdxl_model_util.py +++ b/library/sdxl_model_util.py @@ -1,4 +1,5 @@ import torch +import safetensors from accelerate import init_empty_weights from accelerate.utils.modeling import set_module_tensor_to_device from safetensors.torch import load_file, save_file @@ -163,17 +164,20 @@ def _load_state_dict_on_device(model, state_dict, device, dtype=None): raise RuntimeError("Error(s) in loading state_dict for {}:\n\t{}".format(model.__class__.__name__, "\n\t".join(error_msgs))) -def load_models_from_sdxl_checkpoint(model_version, ckpt_path, map_location, dtype=None): +def load_models_from_sdxl_checkpoint(model_version, ckpt_path, map_location, dtype=None, disable_mmap=False): # model_version is reserved for future use # dtype is used for full_fp16/bf16 integration. 
Text Encoder will remain fp32, because it runs on CPU when caching # Load the state dict if model_util.is_safetensors(ckpt_path): checkpoint = None - try: - state_dict = load_file(ckpt_path, device=map_location) - except: - state_dict = load_file(ckpt_path) # prevent device invalid Error + if(disable_mmap): + state_dict = safetensors.torch.load(open(ckpt_path, 'rb').read()) + else: + try: + state_dict = load_file(ckpt_path, device=map_location) + except: + state_dict = load_file(ckpt_path) # prevent device invalid Error epoch = None global_step = None else: diff --git a/library/sdxl_train_util.py b/library/sdxl_train_util.py index a29013e34..106c5b455 100644 --- a/library/sdxl_train_util.py +++ b/library/sdxl_train_util.py @@ -44,6 +44,7 @@ def load_target_model(args, accelerator, model_version: str, weight_dtype): weight_dtype, accelerator.device if args.lowram else "cpu", model_dtype, + args.disable_mmap_load_safetensors ) # work on low-ram device @@ -60,7 +61,7 @@ def load_target_model(args, accelerator, model_version: str, weight_dtype): def _load_target_model( - name_or_path: str, vae_path: Optional[str], model_version: str, weight_dtype, device="cpu", model_dtype=None + name_or_path: str, vae_path: Optional[str], model_version: str, weight_dtype, device="cpu", model_dtype=None, disable_mmap=False ): # model_dtype only work with full fp16/bf16 name_or_path = os.readlink(name_or_path) if os.path.islink(name_or_path) else name_or_path @@ -75,7 +76,7 @@ def _load_target_model( unet, logit_scale, ckpt_info, - ) = sdxl_model_util.load_models_from_sdxl_checkpoint(model_version, name_or_path, device, model_dtype) + ) = sdxl_model_util.load_models_from_sdxl_checkpoint(model_version, name_or_path, device, model_dtype, disable_mmap) else: # Diffusers model is loaded to CPU from diffusers import StableDiffusionXLPipeline @@ -332,6 +333,10 @@ def add_sdxl_training_arguments(parser: argparse.ArgumentParser): action="store_true", help="cache text encoder outputs to disk / text encoderの出力をディスクにキャッシュする", ) + parser.add_argument( + "--disable_mmap_load_safetensors", + action="store_true", + ) def verify_sdxl_training_args(args: argparse.Namespace, supportTextEncoderCaching: bool = True): From feefcf256e78a5f8d60c3a940f2be3b5c3ca335d Mon Sep 17 00:00:00 2001 From: Cauldrath Date: Thu, 18 Apr 2024 23:15:36 -0400 Subject: [PATCH 08/97] Display name of error latent file When trying to load stored latents, if an error occurs, this change will tell you what file failed to load Currently it will just tell you that something failed without telling you which file --- library/train_util.py | 21 ++++++++++++--------- 1 file changed, 12 insertions(+), 9 deletions(-) diff --git a/library/train_util.py b/library/train_util.py index 15c23f3cc..58527fa00 100644 --- a/library/train_util.py +++ b/library/train_util.py @@ -2123,18 +2123,21 @@ def is_disk_cached_latents_is_expected(reso, npz_path: str, flip_aug: bool): if not os.path.exists(npz_path): return False - npz = np.load(npz_path) - if "latents" not in npz or "original_size" not in npz or "crop_ltrb" not in npz: # old ver? - return False - if npz["latents"].shape[1:3] != expected_latents_size: - return False - - if flip_aug: - if "latents_flipped" not in npz: + try: + npz = np.load(npz_path) + if "latents" not in npz or "original_size" not in npz or "crop_ltrb" not in npz: # old ver? 
return False - if npz["latents_flipped"].shape[1:3] != expected_latents_size: + if npz["latents"].shape[1:3] != expected_latents_size: return False + if flip_aug: + if "latents_flipped" not in npz: + return False + if npz["latents_flipped"].shape[1:3] != expected_latents_size: + return False + except: + raise RuntimeError(f"Error loading file: {npz_path}") + return True From fc374375de4fc9efd10eb598fdc166a4b6d0ad17 Mon Sep 17 00:00:00 2001 From: Cauldrath Date: Thu, 18 Apr 2024 23:29:01 -0400 Subject: [PATCH 09/97] Allow negative learning rate This can be used to train away from a group of images you don't want As this moves the model away from a point instead of towards it, the change in the model is unbounded So, don't set it too low. -4e-7 seemed to work well. --- sdxl_train.py | 6 +++--- 1 file changed, 3 insertions(+), 3 deletions(-) diff --git a/sdxl_train.py b/sdxl_train.py index 46d7860be..1e6cec1a4 100644 --- a/sdxl_train.py +++ b/sdxl_train.py @@ -272,7 +272,7 @@ def fn_recursive_set_mem_eff(module: torch.nn.Module): # 学習を準備する:モデルを適切な状態にする if args.gradient_checkpointing: unet.enable_gradient_checkpointing() - train_unet = args.learning_rate > 0 + train_unet = args.learning_rate != 0 train_text_encoder1 = False train_text_encoder2 = False @@ -284,8 +284,8 @@ def fn_recursive_set_mem_eff(module: torch.nn.Module): text_encoder2.gradient_checkpointing_enable() lr_te1 = args.learning_rate_te1 if args.learning_rate_te1 is not None else args.learning_rate # 0 means not train lr_te2 = args.learning_rate_te2 if args.learning_rate_te2 is not None else args.learning_rate # 0 means not train - train_text_encoder1 = lr_te1 > 0 - train_text_encoder2 = lr_te2 > 0 + train_text_encoder1 = lr_te1 != 0 + train_text_encoder2 = lr_te2 != 0 # caching one text encoder output is not supported if not train_text_encoder1: From 2c9db5d9f2f6b57f15b9312139d0410ae8ae4f3c Mon Sep 17 00:00:00 2001 From: Maatra Date: Sat, 20 Apr 2024 14:11:43 +0100 Subject: [PATCH 10/97] passing filtered hyperparameters to accelerate --- fine_tune.py | 2 +- library/train_util.py | 14 ++++++++++++++ sdxl_train.py | 2 +- sdxl_train_control_net_lllite.py | 2 +- sdxl_train_control_net_lllite_old.py | 2 +- train_controlnet.py | 2 +- train_db.py | 2 +- train_network.py | 2 +- train_textual_inversion.py | 2 +- train_textual_inversion_XTI.py | 2 +- 10 files changed, 23 insertions(+), 9 deletions(-) diff --git a/fine_tune.py b/fine_tune.py index c7e6bbd2e..77a1a4f30 100644 --- a/fine_tune.py +++ b/fine_tune.py @@ -310,7 +310,7 @@ def fn_recursive_set_mem_eff(module: torch.nn.Module): init_kwargs["wandb"] = {"name": args.wandb_run_name} if args.log_tracker_config is not None: init_kwargs = toml.load(args.log_tracker_config) - accelerator.init_trackers("finetuning" if args.log_tracker_name is None else args.log_tracker_name, init_kwargs=init_kwargs) + accelerator.init_trackers("finetuning" if args.log_tracker_name is None else args.log_tracker_name, config=train_util.filter_sensitive_args(args), init_kwargs=init_kwargs) # For --sample_at_first train_util.sample_images(accelerator, args, 0, global_step, accelerator.device, vae, tokenizer, text_encoder, unet) diff --git a/library/train_util.py b/library/train_util.py index 15c23f3cc..40be2b05b 100644 --- a/library/train_util.py +++ b/library/train_util.py @@ -3378,6 +3378,20 @@ def add_masked_loss_arguments(parser: argparse.ArgumentParser): help="apply mask for calculating loss. conditioning_data_dir is required for dataset. 
/ 損失計算時にマスクを適用する。datasetにはconditioning_data_dirが必要", ) +def filter_sensitive_args(args: argparse.Namespace): + sensitive_args = ["wandb_api_key", "huggingface_token"] + sensitive_path_args = [ + "pretrained_model_name_or_path", + "vae", + "tokenizer_cache_dir", + "train_data_dir", + "conditioning_data_dir", + "reg_data_dir", + "output_dir", + "logging_dir", + ] + filtered_args = {k: v for k, v in vars(args).items() if k not in sensitive_args + sensitive_path_args} + return filtered_args # verify command line args for training def verify_command_line_training_args(args: argparse.Namespace): diff --git a/sdxl_train.py b/sdxl_train.py index 46d7860be..5a9aa214e 100644 --- a/sdxl_train.py +++ b/sdxl_train.py @@ -487,7 +487,7 @@ def fn_recursive_set_mem_eff(module: torch.nn.Module): init_kwargs["wandb"] = {"name": args.wandb_run_name} if args.log_tracker_config is not None: init_kwargs = toml.load(args.log_tracker_config) - accelerator.init_trackers("finetuning" if args.log_tracker_name is None else args.log_tracker_name, init_kwargs=init_kwargs) + accelerator.init_trackers("finetuning" if args.log_tracker_name is None else args.log_tracker_name, config=train_util.filter_sensitive_args(args), init_kwargs=init_kwargs) # For --sample_at_first sdxl_train_util.sample_images( diff --git a/sdxl_train_control_net_lllite.py b/sdxl_train_control_net_lllite.py index f89c3628f..770a1f3df 100644 --- a/sdxl_train_control_net_lllite.py +++ b/sdxl_train_control_net_lllite.py @@ -353,7 +353,7 @@ def train(args): if args.log_tracker_config is not None: init_kwargs = toml.load(args.log_tracker_config) accelerator.init_trackers( - "lllite_control_net_train" if args.log_tracker_name is None else args.log_tracker_name, init_kwargs=init_kwargs + "lllite_control_net_train" if args.log_tracker_name is None else args.log_tracker_name, config=train_util.filter_sensitive_args(args), init_kwargs=init_kwargs ) loss_recorder = train_util.LossRecorder() diff --git a/sdxl_train_control_net_lllite_old.py b/sdxl_train_control_net_lllite_old.py index e85e978c1..9490cf6f2 100644 --- a/sdxl_train_control_net_lllite_old.py +++ b/sdxl_train_control_net_lllite_old.py @@ -324,7 +324,7 @@ def train(args): if args.log_tracker_config is not None: init_kwargs = toml.load(args.log_tracker_config) accelerator.init_trackers( - "lllite_control_net_train" if args.log_tracker_name is None else args.log_tracker_name, init_kwargs=init_kwargs + "lllite_control_net_train" if args.log_tracker_name is None else args.log_tracker_name, config=train_util.filter_sensitive_args(args), init_kwargs=init_kwargs ) loss_recorder = train_util.LossRecorder() diff --git a/train_controlnet.py b/train_controlnet.py index f4c94e8d9..793f79c7d 100644 --- a/train_controlnet.py +++ b/train_controlnet.py @@ -344,7 +344,7 @@ def train(args): if args.log_tracker_config is not None: init_kwargs = toml.load(args.log_tracker_config) accelerator.init_trackers( - "controlnet_train" if args.log_tracker_name is None else args.log_tracker_name, init_kwargs=init_kwargs + "controlnet_train" if args.log_tracker_name is None else args.log_tracker_name, config=train_util.filter_sensitive_args(args), init_kwargs=init_kwargs ) loss_recorder = train_util.LossRecorder() diff --git a/train_db.py b/train_db.py index 1de504ed8..4f9018293 100644 --- a/train_db.py +++ b/train_db.py @@ -290,7 +290,7 @@ def train(args): init_kwargs["wandb"] = {"name": args.wandb_run_name} if args.log_tracker_config is not None: init_kwargs = toml.load(args.log_tracker_config) - 
accelerator.init_trackers("dreambooth" if args.log_tracker_name is None else args.log_tracker_name, init_kwargs=init_kwargs) + accelerator.init_trackers("dreambooth" if args.log_tracker_name is None else args.log_tracker_name, config=train_util.filter_sensitive_args(args), init_kwargs=init_kwargs) # For --sample_at_first train_util.sample_images(accelerator, args, 0, global_step, accelerator.device, vae, tokenizer, text_encoder, unet) diff --git a/train_network.py b/train_network.py index c99d37247..1dca437cf 100644 --- a/train_network.py +++ b/train_network.py @@ -753,7 +753,7 @@ def load_model_hook(models, input_dir): if args.log_tracker_config is not None: init_kwargs = toml.load(args.log_tracker_config) accelerator.init_trackers( - "network_train" if args.log_tracker_name is None else args.log_tracker_name, init_kwargs=init_kwargs + "network_train" if args.log_tracker_name is None else args.log_tracker_name, config=train_util.filter_sensitive_args(args), init_kwargs=init_kwargs ) loss_recorder = train_util.LossRecorder() diff --git a/train_textual_inversion.py b/train_textual_inversion.py index 10fce2677..56a387391 100644 --- a/train_textual_inversion.py +++ b/train_textual_inversion.py @@ -510,7 +510,7 @@ def train(self, args): if args.log_tracker_config is not None: init_kwargs = toml.load(args.log_tracker_config) accelerator.init_trackers( - "textual_inversion" if args.log_tracker_name is None else args.log_tracker_name, init_kwargs=init_kwargs + "textual_inversion" if args.log_tracker_name is None else args.log_tracker_name, config=train_util.filter_sensitive_args(args), init_kwargs=init_kwargs ) # function for saving/removing diff --git a/train_textual_inversion_XTI.py b/train_textual_inversion_XTI.py index ddd03d532..691785239 100644 --- a/train_textual_inversion_XTI.py +++ b/train_textual_inversion_XTI.py @@ -407,7 +407,7 @@ def train(args): if args.log_tracker_config is not None: init_kwargs = toml.load(args.log_tracker_config) accelerator.init_trackers( - "textual_inversion" if args.log_tracker_name is None else args.log_tracker_name, init_kwargs=init_kwargs + "textual_inversion" if args.log_tracker_name is None else args.log_tracker_name, config=train_util.filter_sensitive_args(args), init_kwargs=init_kwargs ) # function for saving/removing From 4477116a64bb6c363d0fd9fbf3e21bb813548dfe Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?=E9=9D=92=E9=BE=8D=E8=81=96=E8=80=85=40bdsqlsz?= Date: Sat, 20 Apr 2024 21:26:09 +0800 Subject: [PATCH 11/97] fix train controlnet --- library/train_util.py | 4 ++-- requirements.txt | 1 + train_controlnet.py | 8 ++++++-- 3 files changed, 9 insertions(+), 4 deletions(-) diff --git a/library/train_util.py b/library/train_util.py index 15c23f3cc..ecf3345fb 100644 --- a/library/train_util.py +++ b/library/train_util.py @@ -1982,8 +1982,8 @@ def make_buckets(self): self.bucket_manager = self.dreambooth_dataset_delegate.bucket_manager self.buckets_indices = self.dreambooth_dataset_delegate.buckets_indices - def cache_latents(self, vae, vae_batch_size=1, cache_to_disk=False, is_main_process=True): - return self.dreambooth_dataset_delegate.cache_latents(vae, vae_batch_size, cache_to_disk, is_main_process) + def cache_latents(self, vae, vae_batch_size=1, cache_to_disk=False, is_main_process=True, cache_file_suffix=".npz", divisor=8): + return self.dreambooth_dataset_delegate.cache_latents(vae, vae_batch_size, cache_to_disk, is_main_process, cache_file_suffix, divisor) def __len__(self): return self.dreambooth_dataset_delegate.__len__() diff --git 
a/requirements.txt b/requirements.txt index e99775b8a..9495dab2a 100644 --- a/requirements.txt +++ b/requirements.txt @@ -17,6 +17,7 @@ easygui==0.98.3 toml==0.10.2 voluptuous==0.13.1 huggingface-hub==0.20.1 +omegaconf==2.3.0 # for Image utils imagesize==1.4.1 # for BLIP captioning diff --git a/train_controlnet.py b/train_controlnet.py index f4c94e8d9..763041aa6 100644 --- a/train_controlnet.py +++ b/train_controlnet.py @@ -5,7 +5,7 @@ import random import time from multiprocessing import Value -from types import SimpleNamespace +from omegaconf import OmegaConf import toml from tqdm import tqdm @@ -148,8 +148,10 @@ def train(args): "in_channels": 4, "layers_per_block": 2, "mid_block_scale_factor": 1, + "mid_block_type": "UNetMidBlock2DCrossAttn", "norm_eps": 1e-05, "norm_num_groups": 32, + "num_attention_heads": [5, 10, 20, 20], "num_class_embeds": None, "only_cross_attention": False, "out_channels": 4, @@ -179,8 +181,10 @@ def train(args): "in_channels": 4, "layers_per_block": 2, "mid_block_scale_factor": 1, + "mid_block_type": "UNetMidBlock2DCrossAttn", "norm_eps": 1e-05, "norm_num_groups": 32, + "num_attention_heads": 8, "out_channels": 4, "sample_size": 64, "up_block_types": ["UpBlock2D", "CrossAttnUpBlock2D", "CrossAttnUpBlock2D", "CrossAttnUpBlock2D"], @@ -193,7 +197,7 @@ def train(args): "resnet_time_scale_shift": "default", "projection_class_embeddings_input_dim": None, } - unet.config = SimpleNamespace(**unet.config) + unet.config = OmegaConf.create(unet.config) controlnet = ControlNetModel.from_unet(unet) From b886d0a359526f5715f3ced05697d406a169055b Mon Sep 17 00:00:00 2001 From: Maatra Date: Sat, 20 Apr 2024 14:36:47 +0100 Subject: [PATCH 12/97] Cleaned typing to be in line with accelerate hyperparameters type resctrictions --- library/train_util.py | 15 ++++++++++++++- 1 file changed, 14 insertions(+), 1 deletion(-) diff --git a/library/train_util.py b/library/train_util.py index 40be2b05b..75b3420d9 100644 --- a/library/train_util.py +++ b/library/train_util.py @@ -3390,7 +3390,20 @@ def filter_sensitive_args(args: argparse.Namespace): "output_dir", "logging_dir", ] - filtered_args = {k: v for k, v in vars(args).items() if k not in sensitive_args + sensitive_path_args} + filtered_args = {} + for k, v in vars(args).items(): + # filter out sensitive values + if k not in sensitive_args + sensitive_path_args: + #Accelerate values need to have type `bool`,`str`, `float`, `int`, or `None`. 
+ if v is None or isinstance(v, bool) or isinstance(v, str) or isinstance(v, float) or isinstance(v, int): + filtered_args[k] = v + # accelerate does not support lists + elif isinstance(v, list): + filtered_args[k] = f"{v}" + # accelerate does not support objects + elif isinstance(v, object): + filtered_args[k] = f"{v}" + return filtered_args # verify command line args for training From 5cb145d13bd9fae307a8766f4088b95f01492580 Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?=E9=9D=92=E9=BE=8D=E8=81=96=E8=80=85=40bdsqlsz?= Date: Sat, 20 Apr 2024 21:56:24 +0800 Subject: [PATCH 13/97] Update train_util.py --- library/train_util.py | 4 ++-- 1 file changed, 2 insertions(+), 2 deletions(-) diff --git a/library/train_util.py b/library/train_util.py index ecf3345fb..15c23f3cc 100644 --- a/library/train_util.py +++ b/library/train_util.py @@ -1982,8 +1982,8 @@ def make_buckets(self): self.bucket_manager = self.dreambooth_dataset_delegate.bucket_manager self.buckets_indices = self.dreambooth_dataset_delegate.buckets_indices - def cache_latents(self, vae, vae_batch_size=1, cache_to_disk=False, is_main_process=True, cache_file_suffix=".npz", divisor=8): - return self.dreambooth_dataset_delegate.cache_latents(vae, vae_batch_size, cache_to_disk, is_main_process, cache_file_suffix, divisor) + def cache_latents(self, vae, vae_batch_size=1, cache_to_disk=False, is_main_process=True): + return self.dreambooth_dataset_delegate.cache_latents(vae, vae_batch_size, cache_to_disk, is_main_process) def __len__(self): return self.dreambooth_dataset_delegate.__len__() From 52652cba1a419cd72851c3882f1f877670d889c5 Mon Sep 17 00:00:00 2001 From: Kohya S Date: Sun, 21 Apr 2024 17:41:32 +0900 Subject: [PATCH 14/97] disable main process check for deepspeed #1247 --- train_network.py | 3 ++- 1 file changed, 2 insertions(+), 1 deletion(-) diff --git a/train_network.py b/train_network.py index c99d37247..3a5255160 100644 --- a/train_network.py +++ b/train_network.py @@ -474,7 +474,8 @@ def train(self, args): # before resuming make hook for saving/loading to save/load the network weights only def save_model_hook(models, weights, output_dir): # pop weights of other models than network to save only network weights - if accelerator.is_main_process: + # only main process or deepspeed https://github.com/huggingface/diffusers/issues/2606 + if accelerator.is_main_process or args.deepspeed: remove_indices = [] for i, model in enumerate(models): if not isinstance(model, type(accelerator.unwrap_model(network))): From 0540c33acac223b672da05e40edcfb3b6a35c0da Mon Sep 17 00:00:00 2001 From: Kohya S Date: Sun, 21 Apr 2024 17:45:29 +0900 Subject: [PATCH 15/97] pop weights if available #1247 --- train_network.py | 3 ++- 1 file changed, 2 insertions(+), 1 deletion(-) diff --git a/train_network.py b/train_network.py index 3a5255160..aad5a7194 100644 --- a/train_network.py +++ b/train_network.py @@ -481,7 +481,8 @@ def save_model_hook(models, weights, output_dir): if not isinstance(model, type(accelerator.unwrap_model(network))): remove_indices.append(i) for i in reversed(remove_indices): - weights.pop(i) + if len(weights) > i: + weights.pop(i) # print(f"save model hook: {len(weights)} weights will be saved") def load_model_hook(models, input_dir): From 040e26ff1d8f855f52cdfb62781e06284c5e9e34 Mon Sep 17 00:00:00 2001 From: Cauldrath Date: Sun, 21 Apr 2024 13:46:31 -0400 Subject: [PATCH 16/97] Regenerate failed file If a latent file fails to load, print out the path and the error, then return false to regenerate it --- library/train_util.py | 6 
++++-- 1 file changed, 4 insertions(+), 2 deletions(-) diff --git a/library/train_util.py b/library/train_util.py index 58527fa00..4168a41fb 100644 --- a/library/train_util.py +++ b/library/train_util.py @@ -2135,8 +2135,10 @@ def is_disk_cached_latents_is_expected(reso, npz_path: str, flip_aug: bool): return False if npz["latents_flipped"].shape[1:3] != expected_latents_size: return False - except: - raise RuntimeError(f"Error loading file: {npz_path}") + except Exception as e: + print(npz_path) + print(e) + return False return True From fdbb03c360777562e91ab1884ed7cf2c3d65611b Mon Sep 17 00:00:00 2001 From: frodo821 Date: Tue, 23 Apr 2024 14:29:05 +0900 Subject: [PATCH 17/97] removed unnecessary `torch` import on line 115 as per #1290 --- finetune/tag_images_by_wd14_tagger.py | 1 - 1 file changed, 1 deletion(-) diff --git a/finetune/tag_images_by_wd14_tagger.py b/finetune/tag_images_by_wd14_tagger.py index a327bbd61..b3f9cdd26 100644 --- a/finetune/tag_images_by_wd14_tagger.py +++ b/finetune/tag_images_by_wd14_tagger.py @@ -112,7 +112,6 @@ def main(args): # モデルを読み込む if args.onnx: - import torch import onnx import onnxruntime as ort From 969f82ab474024865d292afd96768e817c9374c1 Mon Sep 17 00:00:00 2001 From: Kohya S Date: Mon, 29 Apr 2024 20:04:25 +0900 Subject: [PATCH 18/97] move loraplus args from args to network_args, simplify log lr desc --- library/train_util.py | 3 -- networks/lora.py | 58 ++++++++++++++------- train_network.py | 114 ++++++++++++++++-------------------------- 3 files changed, 84 insertions(+), 91 deletions(-) diff --git a/library/train_util.py b/library/train_util.py index 048ed2ce3..15c23f3cc 100644 --- a/library/train_util.py +++ b/library/train_util.py @@ -2920,9 +2920,6 @@ def add_optimizer_arguments(parser: argparse.ArgumentParser): default=1, help="Polynomial power for polynomial scheduler / polynomialスケジューラでのpolynomial power", ) - parser.add_argument("--loraplus_lr_ratio", default=None, type=float, help="LoRA+ learning rate ratio") - parser.add_argument("--loraplus_unet_lr_ratio", default=None, type=float, help="LoRA+ UNet learning rate ratio") - parser.add_argument("--loraplus_text_encoder_lr_ratio", default=None, type=float, help="LoRA+ text encoder learning rate ratio") def add_training_arguments(parser: argparse.ArgumentParser, support_dreambooth: bool): diff --git a/networks/lora.py b/networks/lora.py index edbbdc0d8..b67c59bd5 100644 --- a/networks/lora.py +++ b/networks/lora.py @@ -490,6 +490,14 @@ def create_network( varbose=True, ) + loraplus_lr_ratio = kwargs.get("loraplus_lr_ratio", None) + loraplus_unet_lr_ratio = kwargs.get("loraplus_unet_lr_ratio", None) + loraplus_text_encoder_lr_ratio = kwargs.get("loraplus_text_encoder_lr_ratio", None) + loraplus_lr_ratio = float(loraplus_lr_ratio) if loraplus_lr_ratio is not None else None + loraplus_unet_lr_ratio = float(loraplus_unet_lr_ratio) if loraplus_unet_lr_ratio is not None else None + loraplus_text_encoder_lr_ratio = float(loraplus_text_encoder_lr_ratio) if loraplus_text_encoder_lr_ratio is not None else None + network.set_loraplus_lr_ratio(loraplus_lr_ratio, loraplus_unet_lr_ratio, loraplus_text_encoder_lr_ratio) + if up_lr_weight is not None or mid_lr_weight is not None or down_lr_weight is not None: network.set_block_lr_weight(up_lr_weight, mid_lr_weight, down_lr_weight) @@ -1033,18 +1041,27 @@ def get_lr_weight(self, lora: LoRAModule) -> float: return lr_weight + def set_loraplus_lr_ratio(self, loraplus_lr_ratio, loraplus_unet_lr_ratio, loraplus_text_encoder_lr_ratio): + self.loraplus_lr_ratio 
= loraplus_lr_ratio + self.loraplus_unet_lr_ratio = loraplus_unet_lr_ratio + self.loraplus_text_encoder_lr_ratio = loraplus_text_encoder_lr_ratio + # 二つのText Encoderに別々の学習率を設定できるようにするといいかも - def prepare_optimizer_params( - self, - text_encoder_lr, - unet_lr, - default_lr, - text_encoder_loraplus_ratio=None, - unet_loraplus_ratio=None, - loraplus_ratio=None - ): + def prepare_optimizer_params(self, text_encoder_lr, unet_lr, default_lr): + # TODO warn if optimizer is not compatible with LoRA+ (but it will cause error so we don't need to check it here?) + # if ( + # self.loraplus_lr_ratio is not None + # or self.loraplus_text_encoder_lr_ratio is not None + # or self.loraplus_unet_lr_ratio is not None + # ): + # assert ( + # optimizer_type.lower() != "prodigy" and "dadapt" not in optimizer_type.lower() + # ), "LoRA+ and Prodigy/DAdaptation is not supported / LoRA+とProdigy/DAdaptationの組み合わせはサポートされていません" + self.requires_grad_(True) + all_params = [] + lr_descriptions = [] def assemble_params(loras, lr, ratio): param_groups = {"lora": {}, "plus": {}} @@ -1056,6 +1073,7 @@ def assemble_params(loras, lr, ratio): param_groups["lora"][f"{lora.lora_name}.{name}"] = param params = [] + descriptions = [] for key in param_groups.keys(): param_data = {"params": param_groups[key].values()} @@ -1069,20 +1087,22 @@ def assemble_params(loras, lr, ratio): param_data["lr"] = lr if param_data.get("lr", None) == 0 or param_data.get("lr", None) is None: - print("NO LR skipping!") + logger.info("NO LR skipping!") continue params.append(param_data) + descriptions.append("plus" if key == "plus" else "") - return params + return params, descriptions if self.text_encoder_loras: - params = assemble_params( + params, descriptions = assemble_params( self.text_encoder_loras, text_encoder_lr if text_encoder_lr is not None else default_lr, - text_encoder_loraplus_ratio or loraplus_ratio + self.loraplus_text_encoder_lr_ratio or self.loraplus_lr_ratio, ) all_params.extend(params) + lr_descriptions.extend(["textencoder" + (" " + d if d else "") for d in descriptions]) if self.unet_loras: if self.block_lr: @@ -1096,22 +1116,24 @@ def assemble_params(loras, lr, ratio): # blockごとにパラメータを設定する for idx, block_loras in block_idx_to_lora.items(): - params = assemble_params( + params, descriptions = assemble_params( block_loras, (unet_lr if unet_lr is not None else default_lr) * self.get_lr_weight(block_loras[0]), - unet_loraplus_ratio or loraplus_ratio + self.loraplus_unet_lr_ratio or self.loraplus_lr_ratio, ) all_params.extend(params) + lr_descriptions.extend([f"unet_block{idx}" + (" " + d if d else "") for d in descriptions]) else: - params = assemble_params( + params, descriptions = assemble_params( self.unet_loras, unet_lr if unet_lr is not None else default_lr, - unet_loraplus_ratio or loraplus_ratio + self.loraplus_unet_lr_ratio or self.loraplus_lr_ratio, ) all_params.extend(params) + lr_descriptions.extend(["unet" + (" " + d if d else "") for d in descriptions]) - return all_params + return all_params, lr_descriptions def enable_gradient_checkpointing(self): # not supported diff --git a/train_network.py b/train_network.py index 9670490ae..c43241e8d 100644 --- a/train_network.py +++ b/train_network.py @@ -53,7 +53,15 @@ def __init__(self): # TODO 他のスクリプトと共通化する def generate_step_logs( - self, args: argparse.Namespace, current_loss, avr_loss, lr_scheduler, keys_scaled=None, mean_norm=None, maximum_norm=None + self, + args: argparse.Namespace, + current_loss, + avr_loss, + lr_scheduler, + lr_descriptions, + keys_scaled=None, + 
mean_norm=None, + maximum_norm=None, ): logs = {"loss/current": current_loss, "loss/average": avr_loss} @@ -63,68 +71,25 @@ def generate_step_logs( logs["max_norm/max_key_norm"] = maximum_norm lrs = lr_scheduler.get_last_lr() - - if len(lrs) > 4: - idx = 0 - if not args.network_train_unet_only: - logs["lr/textencoder"] = float(lrs[0]) - idx = 1 - - for i in range(idx, len(lrs)): - lora_plus = "" - group_id = i - - if args.loraplus_lr_ratio is not None or args.loraplus_unet_lr_ratio is not None: - lora_plus = '_lora+' if i % 2 == 1 else '' - group_id = int((i / 2) + (i % 2 + 0.5)) - - logs[f"lr/group{group_id}{lora_plus}"] = float(lrs[i]) - if args.optimizer_type.lower().startswith("DAdapt".lower()) or args.optimizer_type.lower() == "Prodigy".lower(): - logs[f"lr/d*lr/group{group_id}{lora_plus}"] = ( - lr_scheduler.optimizers[-1].param_groups[i]["d"] * lr_scheduler.optimizers[-1].param_groups[i]["lr"] - ) - - else: - if args.network_train_text_encoder_only: - if args.loraplus_lr_ratio is not None or args.loraplus_text_encoder_lr_ratio is not None: - logs["lr/textencoder"] = float(lrs[0]) - logs["lr/textencoder_lora+"] = float(lrs[1]) - else: - logs["lr/textencoder"] = float(lrs[0]) - - elif args.network_train_unet_only: - if args.loraplus_lr_ratio is not None or args.loraplus_unet_lr_ratio is not None: - logs["lr/unet"] = float(lrs[0]) - logs["lr/unet_lora+"] = float(lrs[1]) - else: - logs["lr/unet"] = float(lrs[0]) + for i, lr in enumerate(lrs): + if lr_descriptions is not None: + lr_desc = lr_descriptions[i] else: - if len(lrs) == 2: - if args.loraplus_text_encoder_lr_ratio is not None and args.loraplus_unet_lr_ratio is None: - logs["lr/textencoder"] = float(lrs[0]) - logs["lr/textencoder_lora+"] = float(lrs[1]) - elif args.loraplus_unet_lr_ratio is not None and args.loraplus_text_encoder_lr_ratio is None: - logs["lr/unet"] = float(lrs[0]) - logs["lr/unet_lora+"] = float(lrs[1]) - elif args.loraplus_unet_lr_ratio is None and args.loraplus_text_encoder_lr_ratio is None and args.loraplus_lr_ratio is not None: - logs["lr/all"] = float(lrs[0]) - logs["lr/all_lora+"] = float(lrs[1]) - else: - logs["lr/textencoder"] = float(lrs[0]) - logs["lr/unet"] = float(lrs[-1]) - elif len(lrs) == 4: - logs["lr/textencoder"] = float(lrs[0]) - logs["lr/textencoder_lora+"] = float(lrs[1]) - logs["lr/unet"] = float(lrs[2]) - logs["lr/unet_lora+"] = float(lrs[3]) + idx = i - (0 if args.network_train_unet_only else -1) + if idx == -1: + lr_desc = "textencoder" else: - logs["lr/all"] = float(lrs[0]) + if len(lrs) > 2: + lr_desc = f"group{idx}" + else: + lr_desc = "unet" + + logs[f"lr/{lr_desc}"] = lr - if ( - args.optimizer_type.lower().startswith("DAdapt".lower()) or args.optimizer_type.lower() == "Prodigy".lower() - ): # tracking d*lr value of unet. 
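# (editorial aside) The d*lr tracking here multiplies the adaptation scalar "d" that
# D-Adaptation/Prodigy keep in each param group by that group's learning rate. A tiny helper
# sketch of the same lookup; the helper name is illustrative and not part of this repository.
def d_times_lr(optimizer, group_index: int) -> float:
    group = optimizer.param_groups[group_index]
    return float(group["d"]) * float(group["lr"])

# usage sketch: logs[f"lr/d*lr/{lr_desc}"] = d_times_lr(lr_scheduler.optimizers[-1], i)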
- logs["lr/d*lr"] = ( - lr_scheduler.optimizers[-1].param_groups[0]["d"] * lr_scheduler.optimizers[-1].param_groups[0]["lr"] + if args.optimizer_type.lower().startswith("DAdapt".lower()) or args.optimizer_type.lower() == "Prodigy".lower(): + # tracking d*lr value + logs[f"lr/d*lr/{lr_desc}"] = ( + lr_scheduler.optimizers[-1].param_groups[i]["d"] * lr_scheduler.optimizers[-1].param_groups[i]["lr"] ) return logs @@ -358,6 +323,7 @@ def train(self, args): network.apply_to(text_encoder, unet, train_text_encoder, train_unet) if args.network_weights is not None: + # FIXME consider alpha of weights info = network.load_weights(args.network_weights) accelerator.print(f"load network weights from {args.network_weights}: {info}") @@ -373,20 +339,23 @@ def train(self, args): # 後方互換性を確保するよ try: - trainable_params = network.prepare_optimizer_params(args.text_encoder_lr, args.unet_lr, args.learning_rate, args.loraplus_text_encoder_lr_ratio, args.loraplus_unet_lr_ratio, args.loraplus_lr_ratio) + results = network.prepare_optimizer_params(args.text_encoder_lr, args.unet_lr, args.learning_rate) + if type(results) is tuple: + trainable_params = results[0] + lr_descriptions = results[1] + else: + trainable_params = results + lr_descriptions = None except TypeError: - accelerator.print( - "Deprecated: use prepare_optimizer_params(text_encoder_lr, unet_lr, learning_rate) instead of prepare_optimizer_params(text_encoder_lr, unet_lr)" - ) + # accelerator.print( + # "Deprecated: use prepare_optimizer_params(text_encoder_lr, unet_lr, learning_rate) instead of prepare_optimizer_params(text_encoder_lr, unet_lr)" + # ) trainable_params = network.prepare_optimizer_params(args.text_encoder_lr, args.unet_lr) + lr_descriptions = None + print(lr_descriptions) optimizer_name, optimizer_args, optimizer = train_util.get_optimizer(args, trainable_params) - if args.loraplus_lr_ratio is not None or args.loraplus_text_encoder_lr_ratio is not None or args.loraplus_unet_lr_ratio is not None: - assert ( - (optimizer_name != "Prodigy" and "DAdapt" not in optimizer_name) - ), "LoRA+ and Prodigy/DAdaptation is not supported" - # dataloaderを準備する # DataLoaderのプロセス数:0 は persistent_workers が使えないので注意 n_workers = min(args.max_data_loader_n_workers, os.cpu_count()) # cpu_count or max_data_loader_n_workers @@ -992,7 +961,9 @@ def remove_model(old_ckpt_name): progress_bar.set_postfix(**{**max_mean_logs, **logs}) if args.logging_dir is not None: - logs = self.generate_step_logs(args, current_loss, avr_loss, lr_scheduler, keys_scaled, mean_norm, maximum_norm) + logs = self.generate_step_logs( + args, current_loss, avr_loss, lr_scheduler, lr_descriptions, keys_scaled, mean_norm, maximum_norm + ) accelerator.log(logs, step=global_step) if global_step >= args.max_train_steps: @@ -1143,6 +1114,9 @@ def setup_parser() -> argparse.ArgumentParser: action="store_true", help="do not use fp16/bf16 VAE in mixed precision (use float VAE) / mixed precisionでも fp16/bf16 VAEを使わずfloat VAEを使う", ) + # parser.add_argument("--loraplus_lr_ratio", default=None, type=float, help="LoRA+ learning rate ratio") + # parser.add_argument("--loraplus_unet_lr_ratio", default=None, type=float, help="LoRA+ UNet learning rate ratio") + # parser.add_argument("--loraplus_text_encoder_lr_ratio", default=None, type=float, help="LoRA+ text encoder learning rate ratio") return parser From dbb7bb288e416dae56d2911077e2642ad0f4b20d Mon Sep 17 00:00:00 2001 From: Dave Lage Date: Thu, 2 May 2024 17:39:35 -0400 Subject: [PATCH 19/97] Fix caption_separator missing in subset schema --- 
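As an editorial aside, the schema entry added below lets each dataset subset carry its own
caption_separator. A minimal sketch of a dataset config exercising it, written as Python so it
can be run directly; the paths and other values are illustrative, only the caption_separator
key comes from this patch.

import toml

config_text = """
[[datasets]]
resolution = 512

  [[datasets.subsets]]
  image_dir = "/path/to/images"
  caption_extension = ".txt"
  caption_separator = ";"
"""

config = toml.loads(config_text)
print(config["datasets"][0]["subsets"][0]["caption_separator"])  # -> ";"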
library/config_util.py | 1 + 1 file changed, 1 insertion(+) diff --git a/library/config_util.py b/library/config_util.py index d75d03b03..0276acb1e 100644 --- a/library/config_util.py +++ b/library/config_util.py @@ -191,6 +191,7 @@ def __validate_and_convert_scalar_or_twodim(klass, value: Union[float, Sequence] "keep_tokens": int, "keep_tokens_separator": str, "secondary_separator": str, + "caption_separator": str, "enable_wildcard": bool, "token_warmup_min": int, "token_warmup_step": Any(float, int), From 8db0cadcee47005feef5be34cbfaac8b85fe8837 Mon Sep 17 00:00:00 2001 From: Dave Lage Date: Thu, 2 May 2024 18:08:28 -0400 Subject: [PATCH 20/97] Add caption_separator to output for subset --- library/config_util.py | 1 + 1 file changed, 1 insertion(+) diff --git a/library/config_util.py b/library/config_util.py index d75d03b03..97554bbef 100644 --- a/library/config_util.py +++ b/library/config_util.py @@ -523,6 +523,7 @@ def generate_dataset_group_by_blueprint(dataset_group_blueprint: DatasetGroupBlu shuffle_caption: {subset.shuffle_caption} keep_tokens: {subset.keep_tokens} keep_tokens_separator: {subset.keep_tokens_separator} + caption_separator: {subset.caption_separator} secondary_separator: {subset.secondary_separator} enable_wildcard: {subset.enable_wildcard} caption_dropout_rate: {subset.caption_dropout_rate} From 58c2d856ae6da6d6962cbfdd98c8a93eb790cbde Mon Sep 17 00:00:00 2001 From: Kohya S Date: Fri, 3 May 2024 22:18:20 +0900 Subject: [PATCH 21/97] support block dim/lr for sdxl --- networks/lora.py | 275 +++++++++++++++++++++++++++-------------------- train_network.py | 4 +- 2 files changed, 158 insertions(+), 121 deletions(-) diff --git a/networks/lora.py b/networks/lora.py index b67c59bd5..61b8cd5a7 100644 --- a/networks/lora.py +++ b/networks/lora.py @@ -12,6 +12,7 @@ import torch import re from library.utils import setup_logging +from library.sdxl_original_unet import SdxlUNet2DConditionModel setup_logging() import logging @@ -385,14 +386,14 @@ def to_out_forward(self, x): return out -def parse_block_lr_kwargs(nw_kwargs): +def parse_block_lr_kwargs(is_sdxl: bool, nw_kwargs: Dict) -> Optional[List[float]]: down_lr_weight = nw_kwargs.get("down_lr_weight", None) mid_lr_weight = nw_kwargs.get("mid_lr_weight", None) up_lr_weight = nw_kwargs.get("up_lr_weight", None) # 以上のいずれにも設定がない場合は無効としてNoneを返す if down_lr_weight is None and mid_lr_weight is None and up_lr_weight is None: - return None, None, None + return None # extract learning rate weight for each block if down_lr_weight is not None: @@ -401,18 +402,16 @@ def parse_block_lr_kwargs(nw_kwargs): down_lr_weight = [(float(s) if s else 0.0) for s in down_lr_weight.split(",")] if mid_lr_weight is not None: - mid_lr_weight = float(mid_lr_weight) + mid_lr_weight = [(float(s) if s else 0.0) for s in mid_lr_weight.split(",")] if up_lr_weight is not None: if "," in up_lr_weight: up_lr_weight = [(float(s) if s else 0.0) for s in up_lr_weight.split(",")] - down_lr_weight, mid_lr_weight, up_lr_weight = get_block_lr_weight( - down_lr_weight, mid_lr_weight, up_lr_weight, float(nw_kwargs.get("block_lr_zero_threshold", 0.0)) + return get_block_lr_weight( + is_sdxl, down_lr_weight, mid_lr_weight, up_lr_weight, float(nw_kwargs.get("block_lr_zero_threshold", 0.0)) ) - return down_lr_weight, mid_lr_weight, up_lr_weight - def create_network( multiplier: float, @@ -424,6 +423,9 @@ def create_network( neuron_dropout: Optional[float] = None, **kwargs, ): + # if unet is an instance of SdxlUNet2DConditionModel or subclass, set is_sdxl to True + is_sdxl = 
unet is not None and issubclass(unet.__class__, SdxlUNet2DConditionModel) + if network_dim is None: network_dim = 4 # default if network_alpha is None: @@ -441,21 +443,21 @@ def create_network( # block dim/alpha/lr block_dims = kwargs.get("block_dims", None) - down_lr_weight, mid_lr_weight, up_lr_weight = parse_block_lr_kwargs(kwargs) + block_lr_weight = parse_block_lr_kwargs(is_sdxl, kwargs) # 以上のいずれかに指定があればblockごとのdim(rank)を有効にする - if block_dims is not None or down_lr_weight is not None or mid_lr_weight is not None or up_lr_weight is not None: + if block_dims is not None or block_lr_weight is not None: block_alphas = kwargs.get("block_alphas", None) conv_block_dims = kwargs.get("conv_block_dims", None) conv_block_alphas = kwargs.get("conv_block_alphas", None) block_dims, block_alphas, conv_block_dims, conv_block_alphas = get_block_dims_and_alphas( - block_dims, block_alphas, network_dim, network_alpha, conv_block_dims, conv_block_alphas, conv_dim, conv_alpha + is_sdxl, block_dims, block_alphas, network_dim, network_alpha, conv_block_dims, conv_block_alphas, conv_dim, conv_alpha ) # remove block dim/alpha without learning rate block_dims, block_alphas, conv_block_dims, conv_block_alphas = remove_block_dims_and_alphas( - block_dims, block_alphas, conv_block_dims, conv_block_alphas, down_lr_weight, mid_lr_weight, up_lr_weight + is_sdxl, block_dims, block_alphas, conv_block_dims, conv_block_alphas, block_lr_weight ) else: @@ -488,6 +490,7 @@ def create_network( conv_block_dims=conv_block_dims, conv_block_alphas=conv_block_alphas, varbose=True, + is_sdxl=is_sdxl, ) loraplus_lr_ratio = kwargs.get("loraplus_lr_ratio", None) @@ -498,8 +501,8 @@ def create_network( loraplus_text_encoder_lr_ratio = float(loraplus_text_encoder_lr_ratio) if loraplus_text_encoder_lr_ratio is not None else None network.set_loraplus_lr_ratio(loraplus_lr_ratio, loraplus_unet_lr_ratio, loraplus_text_encoder_lr_ratio) - if up_lr_weight is not None or mid_lr_weight is not None or down_lr_weight is not None: - network.set_block_lr_weight(up_lr_weight, mid_lr_weight, down_lr_weight) + if block_lr_weight is not None: + network.set_block_lr_weight(block_lr_weight) return network @@ -509,9 +512,13 @@ def create_network( # block_dims, block_alphas は両方ともNoneまたは両方とも値が入っている # conv_dim, conv_alpha は両方ともNoneまたは両方とも値が入っている def get_block_dims_and_alphas( - block_dims, block_alphas, network_dim, network_alpha, conv_block_dims, conv_block_alphas, conv_dim, conv_alpha + is_sdxl, block_dims, block_alphas, network_dim, network_alpha, conv_block_dims, conv_block_alphas, conv_dim, conv_alpha ): - num_total_blocks = LoRANetwork.NUM_OF_BLOCKS * 2 + 1 + if not is_sdxl: + num_total_blocks = LoRANetwork.NUM_OF_BLOCKS * 2 + LoRANetwork.NUM_OF_MID_BLOCKS + else: + # 1+9+3+9+1=23, no LoRA for emb_layers (0) + num_total_blocks = 1 + LoRANetwork.SDXL_NUM_OF_BLOCKS * 2 + LoRANetwork.SDXL_NUM_OF_MID_BLOCKS + 1 def parse_ints(s): return [int(i) for i in s.split(",")] @@ -522,9 +529,10 @@ def parse_floats(s): # block_dimsとblock_alphasをパースする。必ず値が入る if block_dims is not None: block_dims = parse_ints(block_dims) - assert ( - len(block_dims) == num_total_blocks - ), f"block_dims must have {num_total_blocks} elements / block_dimsは{num_total_blocks}個指定してください" + assert len(block_dims) == num_total_blocks, ( + f"block_dims must have {num_total_blocks} elements but {len(block_dims)} elements are given" + + f" / block_dimsは{num_total_blocks}個指定してください(指定された個数: {len(block_dims)})" + ) else: logger.warning( f"block_dims is not specified. 
all dims are set to {network_dim} / block_dimsが指定されていません。すべてのdimは{network_dim}になります" @@ -575,15 +583,25 @@ def parse_floats(s): return block_dims, block_alphas, conv_block_dims, conv_block_alphas -# 層別学習率用に層ごとの学習率に対する倍率を定義する、外部から呼び出される可能性を考慮しておく +# 層別学習率用に層ごとの学習率に対する倍率を定義する、外部から呼び出せるようにclass外に出しておく +# 戻り値は block ごとの倍率のリスト def get_block_lr_weight( - down_lr_weight, mid_lr_weight, up_lr_weight, zero_threshold -) -> Tuple[List[float], List[float], List[float]]: + is_sdxl, + down_lr_weight: Union[str, List[float]], + mid_lr_weight: List[float], + up_lr_weight: Union[str, List[float]], + zero_threshold: float, +) -> Optional[List[float]]: # パラメータ未指定時は何もせず、今までと同じ動作とする if up_lr_weight is None and mid_lr_weight is None and down_lr_weight is None: - return None, None, None + return None - max_len = LoRANetwork.NUM_OF_BLOCKS # フルモデル相当でのup,downの層の数 + if not is_sdxl: + max_len_for_down_or_up = LoRANetwork.NUM_OF_BLOCKS + max_len_for_mid = LoRANetwork.NUM_OF_MID_BLOCKS + else: + max_len_for_down_or_up = LoRANetwork.SDXL_NUM_OF_BLOCKS + max_len_for_mid = LoRANetwork.SDXL_NUM_OF_MID_BLOCKS def get_list(name_with_suffix) -> List[float]: import math @@ -593,15 +611,18 @@ def get_list(name_with_suffix) -> List[float]: base_lr = float(tokens[1]) if len(tokens) > 1 else 0.0 if name == "cosine": - return [math.sin(math.pi * (i / (max_len - 1)) / 2) + base_lr for i in reversed(range(max_len))] + return [ + math.sin(math.pi * (i / (max_len_for_down_or_up - 1)) / 2) + base_lr + for i in reversed(range(max_len_for_down_or_up)) + ] elif name == "sine": - return [math.sin(math.pi * (i / (max_len - 1)) / 2) + base_lr for i in range(max_len)] + return [math.sin(math.pi * (i / (max_len_for_down_or_up - 1)) / 2) + base_lr for i in range(max_len_for_down_or_up)] elif name == "linear": - return [i / (max_len - 1) + base_lr for i in range(max_len)] + return [i / (max_len_for_down_or_up - 1) + base_lr for i in range(max_len_for_down_or_up)] elif name == "reverse_linear": - return [i / (max_len - 1) + base_lr for i in reversed(range(max_len))] + return [i / (max_len_for_down_or_up - 1) + base_lr for i in reversed(range(max_len_for_down_or_up))] elif name == "zeros": - return [0.0 + base_lr] * max_len + return [0.0 + base_lr] * max_len_for_down_or_up else: logger.error( "Unknown lr_weight argument %s is used. Valid arguments: / 不明なlr_weightの引数 %s が使われました。有効な引数:\n\tcosine, sine, linear, reverse_linear, zeros" @@ -614,20 +635,36 @@ def get_list(name_with_suffix) -> List[float]: if type(up_lr_weight) == str: up_lr_weight = get_list(up_lr_weight) - if (up_lr_weight != None and len(up_lr_weight) > max_len) or (down_lr_weight != None and len(down_lr_weight) > max_len): - logger.warning("down_weight or up_weight is too long. Parameters after %d-th are ignored." % max_len) - logger.warning("down_weightもしくはup_weightが長すぎます。%d個目以降のパラメータは無視されます。" % max_len) - up_lr_weight = up_lr_weight[:max_len] - down_lr_weight = down_lr_weight[:max_len] + if (up_lr_weight != None and len(up_lr_weight) > max_len_for_down_or_up) or ( + down_lr_weight != None and len(down_lr_weight) > max_len_for_down_or_up + ): + logger.warning("down_weight or up_weight is too long. Parameters after %d-th are ignored." % max_len_for_down_or_up) + logger.warning("down_weightもしくはup_weightが長すぎます。%d個目以降のパラメータは無視されます。" % max_len_for_down_or_up) + up_lr_weight = up_lr_weight[:max_len_for_down_or_up] + down_lr_weight = down_lr_weight[:max_len_for_down_or_up] + + if mid_lr_weight != None and len(mid_lr_weight) > max_len_for_mid: + logger.warning("mid_weight is too long. 
Parameters after %d-th are ignored." % max_len_for_mid) + logger.warning("mid_weightが長すぎます。%d個目以降のパラメータは無視されます。" % max_len_for_mid) + mid_lr_weight = mid_lr_weight[:max_len_for_mid] + + if (up_lr_weight != None and len(up_lr_weight) < max_len_for_down_or_up) or ( + down_lr_weight != None and len(down_lr_weight) < max_len_for_down_or_up + ): + logger.warning("down_weight or up_weight is too short. Parameters after %d-th are filled with 1." % max_len_for_down_or_up) + logger.warning( + "down_weightもしくはup_weightが短すぎます。%d個目までの不足したパラメータは1で補われます。" % max_len_for_down_or_up + ) - if (up_lr_weight != None and len(up_lr_weight) < max_len) or (down_lr_weight != None and len(down_lr_weight) < max_len): - logger.warning("down_weight or up_weight is too short. Parameters after %d-th are filled with 1." % max_len) - logger.warning("down_weightもしくはup_weightが短すぎます。%d個目までの不足したパラメータは1で補われます。" % max_len) + if down_lr_weight != None and len(down_lr_weight) < max_len_for_down_or_up: + down_lr_weight = down_lr_weight + [1.0] * (max_len_for_down_or_up - len(down_lr_weight)) + if up_lr_weight != None and len(up_lr_weight) < max_len_for_down_or_up: + up_lr_weight = up_lr_weight + [1.0] * (max_len_for_down_or_up - len(up_lr_weight)) - if down_lr_weight != None and len(down_lr_weight) < max_len: - down_lr_weight = down_lr_weight + [1.0] * (max_len - len(down_lr_weight)) - if up_lr_weight != None and len(up_lr_weight) < max_len: - up_lr_weight = up_lr_weight + [1.0] * (max_len - len(up_lr_weight)) + if mid_lr_weight != None and len(mid_lr_weight) < max_len_for_mid: + logger.warning("mid_weight is too short. Parameters after %d-th are filled with 1." % max_len_for_mid) + logger.warning("mid_weightが短すぎます。%d個目までの不足したパラメータは1で補われます。" % max_len_for_mid) + mid_lr_weight = mid_lr_weight + [1.0] * (max_len_for_mid - len(mid_lr_weight)) if (up_lr_weight != None) or (mid_lr_weight != None) or (down_lr_weight != None): logger.info("apply block learning rate / 階層別学習率を適用します。") @@ -635,72 +672,84 @@ def get_list(name_with_suffix) -> List[float]: down_lr_weight = [w if w > zero_threshold else 0 for w in down_lr_weight] logger.info(f"down_lr_weight (shallower -> deeper, 浅い層->深い層): {down_lr_weight}") else: + down_lr_weight = [1.0] * max_len_for_down_or_up logger.info("down_lr_weight: all 1.0, すべて1.0") if mid_lr_weight != None: - mid_lr_weight = mid_lr_weight if mid_lr_weight > zero_threshold else 0 + mid_lr_weight = [w if w > zero_threshold else 0 for w in mid_lr_weight] logger.info(f"mid_lr_weight: {mid_lr_weight}") else: - logger.info("mid_lr_weight: 1.0") + mid_lr_weight = [1.0] * max_len_for_mid + logger.info("mid_lr_weight: all 1.0, すべて1.0") if up_lr_weight != None: up_lr_weight = [w if w > zero_threshold else 0 for w in up_lr_weight] logger.info(f"up_lr_weight (deeper -> shallower, 深い層->浅い層): {up_lr_weight}") else: + up_lr_weight = [1.0] * max_len_for_down_or_up logger.info("up_lr_weight: all 1.0, すべて1.0") - return down_lr_weight, mid_lr_weight, up_lr_weight + lr_weight = down_lr_weight + mid_lr_weight + up_lr_weight + + if is_sdxl: + lr_weight = [1.0] + lr_weight + [1.0] # add 1.0 for emb_layers and out + + assert (not is_sdxl and len(lr_weight) == LoRANetwork.NUM_OF_BLOCKS * 2 + LoRANetwork.NUM_OF_MID_BLOCKS) or ( + is_sdxl and len(lr_weight) == 1 + LoRANetwork.SDXL_NUM_OF_BLOCKS * 2 + LoRANetwork.SDXL_NUM_OF_MID_BLOCKS + 1 + ), f"lr_weight length is invalid: {len(lr_weight)}" + + return lr_weight # lr_weightが0のblockをblock_dimsから除外する、外部から呼び出す可能性を考慮しておく def remove_block_dims_and_alphas( - block_dims, block_alphas, 
conv_block_dims, conv_block_alphas, down_lr_weight, mid_lr_weight, up_lr_weight + is_sdxl, block_dims, block_alphas, conv_block_dims, conv_block_alphas, block_lr_weight: Optional[List[float]] ): - # set 0 to block dim without learning rate to remove the block - if down_lr_weight != None: - for i, lr in enumerate(down_lr_weight): + if block_lr_weight is not None: + for i, lr in enumerate(block_lr_weight): if lr == 0: block_dims[i] = 0 if conv_block_dims is not None: conv_block_dims[i] = 0 - if mid_lr_weight != None: - if mid_lr_weight == 0: - block_dims[LoRANetwork.NUM_OF_BLOCKS] = 0 - if conv_block_dims is not None: - conv_block_dims[LoRANetwork.NUM_OF_BLOCKS] = 0 - if up_lr_weight != None: - for i, lr in enumerate(up_lr_weight): - if lr == 0: - block_dims[LoRANetwork.NUM_OF_BLOCKS + 1 + i] = 0 - if conv_block_dims is not None: - conv_block_dims[LoRANetwork.NUM_OF_BLOCKS + 1 + i] = 0 - return block_dims, block_alphas, conv_block_dims, conv_block_alphas # 外部から呼び出す可能性を考慮しておく -def get_block_index(lora_name: str) -> int: +def get_block_index(lora_name: str, is_sdxl: bool = False) -> int: block_idx = -1 # invalid lora name - - m = RE_UPDOWN.search(lora_name) - if m: - g = m.groups() - i = int(g[1]) - j = int(g[3]) - if g[2] == "resnets": - idx = 3 * i + j - elif g[2] == "attentions": - idx = 3 * i + j - elif g[2] == "upsamplers" or g[2] == "downsamplers": - idx = 3 * i + 2 - - if g[0] == "down": - block_idx = 1 + idx # 0に該当するLoRAは存在しない - elif g[0] == "up": - block_idx = LoRANetwork.NUM_OF_BLOCKS + 1 + idx - - elif "mid_block_" in lora_name: - block_idx = LoRANetwork.NUM_OF_BLOCKS # idx=12 + if not is_sdxl: + m = RE_UPDOWN.search(lora_name) + if m: + g = m.groups() + i = int(g[1]) + j = int(g[3]) + if g[2] == "resnets": + idx = 3 * i + j + elif g[2] == "attentions": + idx = 3 * i + j + elif g[2] == "upsamplers" or g[2] == "downsamplers": + idx = 3 * i + 2 + + if g[0] == "down": + block_idx = 1 + idx # 0に該当するLoRAは存在しない + elif g[0] == "up": + block_idx = LoRANetwork.NUM_OF_BLOCKS + 1 + idx + elif "mid_block_" in lora_name: + block_idx = LoRANetwork.NUM_OF_BLOCKS # idx=12 + else: + # copy from sdxl_train + if lora_name.startswith("lora_unet_"): + name = lora_name[len("lora_unet_") :] + if name.startswith("time_embed_") or name.startswith("label_emb_"): # No LoRA + block_idx = 0 # 0 + elif name.startswith("input_blocks_"): # 1-9 + block_idx = 1 + int(name.split("_")[2]) + elif name.startswith("middle_block_"): # 10-12 + block_idx = 10 + int(name.split("_")[2]) + elif name.startswith("output_blocks_"): # 13-21 + block_idx = 13 + int(name.split("_")[2]) + elif name.startswith("out_"): # 22, out, no LoRA + block_idx = 22 return block_idx @@ -742,15 +791,18 @@ def create_network_from_weights(multiplier, file, vae, text_encoder, unet, weigh ) # block lr - down_lr_weight, mid_lr_weight, up_lr_weight = parse_block_lr_kwargs(kwargs) - if up_lr_weight is not None or mid_lr_weight is not None or down_lr_weight is not None: - network.set_block_lr_weight(up_lr_weight, mid_lr_weight, down_lr_weight) + block_lr_weight = parse_block_lr_kwargs(kwargs) + if block_lr_weight is not None: + network.set_block_lr_weight(block_lr_weight) return network, weights_sd class LoRANetwork(torch.nn.Module): NUM_OF_BLOCKS = 12 # フルモデル相当でのup,downの層の数 + NUM_OF_MID_BLOCKS = 1 + SDXL_NUM_OF_BLOCKS = 9 # SDXLのモデルでのinput/outputの層の数 total=1(base) 9(input) + 3(mid) + 9(output) + 1(out) = 23 + SDXL_NUM_OF_MID_BLOCKS = 3 UNET_TARGET_REPLACE_MODULE = ["Transformer2DModel"] UNET_TARGET_REPLACE_MODULE_CONV2D_3X3 = ["ResnetBlock2D", 
"Downsample2D", "Upsample2D"] @@ -782,6 +834,7 @@ def __init__( modules_alpha: Optional[Dict[str, int]] = None, module_class: Type[object] = LoRAModule, varbose: Optional[bool] = False, + is_sdxl: Optional[bool] = False, ) -> None: """ LoRA network: すごく引数が多いが、パターンは以下の通り @@ -863,7 +916,7 @@ def create_modules( alpha = modules_alpha[lora_name] elif is_unet and block_dims is not None: # U-Netでblock_dims指定あり - block_idx = get_block_index(lora_name) + block_idx = get_block_index(lora_name, is_sdxl) if is_linear or is_conv2d_1x1: dim = block_dims[block_idx] alpha = block_alphas[block_idx] @@ -927,15 +980,13 @@ def create_modules( skipped = skipped_te + skipped_un if varbose and len(skipped) > 0: - logger.warning( + logger.warn( f"because block_lr_weight is 0 or dim (rank) is 0, {len(skipped)} LoRA modules are skipped / block_lr_weightまたはdim (rank)が0の為、次の{len(skipped)}個のLoRAモジュールはスキップされます:" ) for name in skipped: logger.info(f"\t{name}") - self.up_lr_weight: List[float] = None - self.down_lr_weight: List[float] = None - self.mid_lr_weight: float = None + self.block_lr_weight = None self.block_lr = False # assertion @@ -966,12 +1017,12 @@ def load_weights(self, file): def apply_to(self, text_encoder, unet, apply_text_encoder=True, apply_unet=True): if apply_text_encoder: - logger.info("enable LoRA for text encoder") + logger.info(f"enable LoRA for text encoder: {len(self.text_encoder_loras)} modules") else: self.text_encoder_loras = [] if apply_unet: - logger.info("enable LoRA for U-Net") + logger.info(f"enable LoRA for U-Net: {len(self.unet_loras)} modules") else: self.unet_loras = [] @@ -1012,34 +1063,14 @@ def merge_to(self, text_encoder, unet, weights_sd, dtype, device): logger.info(f"weights are merged") # 層別学習率用に層ごとの学習率に対する倍率を定義する 引数の順番が逆だがとりあえず気にしない - def set_block_lr_weight( - self, - up_lr_weight: List[float] = None, - mid_lr_weight: float = None, - down_lr_weight: List[float] = None, - ): + def set_block_lr_weight(self, block_lr_weight: Optional[List[float]]): self.block_lr = True - self.down_lr_weight = down_lr_weight - self.mid_lr_weight = mid_lr_weight - self.up_lr_weight = up_lr_weight - - def get_lr_weight(self, lora: LoRAModule) -> float: - lr_weight = 1.0 - block_idx = get_block_index(lora.lora_name) - if block_idx < 0: - return lr_weight - - if block_idx < LoRANetwork.NUM_OF_BLOCKS: - if self.down_lr_weight != None: - lr_weight = self.down_lr_weight[block_idx] - elif block_idx == LoRANetwork.NUM_OF_BLOCKS: - if self.mid_lr_weight != None: - lr_weight = self.mid_lr_weight - elif block_idx > LoRANetwork.NUM_OF_BLOCKS: - if self.up_lr_weight != None: - lr_weight = self.up_lr_weight[block_idx - LoRANetwork.NUM_OF_BLOCKS - 1] - - return lr_weight + self.block_lr_weight = block_lr_weight + + def get_lr_weight(self, block_idx: int) -> float: + if not self.block_lr or self.block_lr_weight is None: + return 1.0 + return self.block_lr_weight[block_idx] def set_loraplus_lr_ratio(self, loraplus_lr_ratio, loraplus_unet_lr_ratio, loraplus_text_encoder_lr_ratio): self.loraplus_lr_ratio = loraplus_lr_ratio @@ -1106,10 +1137,16 @@ def assemble_params(loras, lr, ratio): if self.unet_loras: if self.block_lr: + is_sdxl = False + for lora in self.unet_loras: + if "input_blocks" in lora.lora_name or "output_blocks" in lora.lora_name: + is_sdxl = True + break + # 学習率のグラフをblockごとにしたいので、blockごとにloraを分類 block_idx_to_lora = {} for lora in self.unet_loras: - idx = get_block_index(lora.lora_name) + idx = get_block_index(lora.lora_name, is_sdxl) if idx not in block_idx_to_lora: block_idx_to_lora[idx] = [] 
block_idx_to_lora[idx].append(lora) @@ -1118,7 +1155,7 @@ def assemble_params(loras, lr, ratio): for idx, block_loras in block_idx_to_lora.items(): params, descriptions = assemble_params( block_loras, - (unet_lr if unet_lr is not None else default_lr) * self.get_lr_weight(block_loras[0]), + (unet_lr if unet_lr is not None else default_lr) * self.get_lr_weight(idx), self.loraplus_unet_lr_ratio or self.loraplus_lr_ratio, ) all_params.extend(params) diff --git a/train_network.py b/train_network.py index c43241e8d..2976f7635 100644 --- a/train_network.py +++ b/train_network.py @@ -346,13 +346,13 @@ def train(self, args): else: trainable_params = results lr_descriptions = None - except TypeError: + except TypeError as e: + # logger.warning(f"{e}") # accelerator.print( # "Deprecated: use prepare_optimizer_params(text_encoder_lr, unet_lr, learning_rate) instead of prepare_optimizer_params(text_encoder_lr, unet_lr)" # ) trainable_params = network.prepare_optimizer_params(args.text_encoder_lr, args.unet_lr) lr_descriptions = None - print(lr_descriptions) optimizer_name, optimizer_args, optimizer = train_util.get_optimizer(args, trainable_params) From 52e64c69cf249a7bc4ca6f4eebe82bc1b70e617b Mon Sep 17 00:00:00 2001 From: Kohya S Date: Sat, 4 May 2024 18:43:52 +0900 Subject: [PATCH 22/97] add debug log --- train_network.py | 10 ++++++++++ 1 file changed, 10 insertions(+) diff --git a/train_network.py b/train_network.py index 2976f7635..feb455cea 100644 --- a/train_network.py +++ b/train_network.py @@ -354,6 +354,16 @@ def train(self, args): trainable_params = network.prepare_optimizer_params(args.text_encoder_lr, args.unet_lr) lr_descriptions = None + # if len(trainable_params) == 0: + # accelerator.print("no trainable parameters found / 学習可能なパラメータが見つかりませんでした") + # for params in trainable_params: + # for k, v in params.items(): + # if type(v) == float: + # pass + # else: + # v = len(v) + # accelerator.print(f"trainable_params: {k} = {v}") + optimizer_name, optimizer_args, optimizer = train_util.get_optimizer(args, trainable_params) # dataloaderを準備する From 7fe81502d04c1f68c85f276517e7144e6378c484 Mon Sep 17 00:00:00 2001 From: Kohya S Date: Mon, 6 May 2024 11:09:32 +0900 Subject: [PATCH 23/97] update loraplus on dylora/lofa_fa --- networks/dylora.py | 46 ++++++++++++++++++++++++--------------- networks/lora.py | 7 +++++- networks/lora_fa.py | 52 +++++++++++++++++++++++++++++++-------------- 3 files changed, 71 insertions(+), 34 deletions(-) diff --git a/networks/dylora.py b/networks/dylora.py index 0546fc7ae..0d1701ded 100644 --- a/networks/dylora.py +++ b/networks/dylora.py @@ -18,10 +18,13 @@ import torch from torch import nn from library.utils import setup_logging + setup_logging() import logging + logger = logging.getLogger(__name__) + class DyLoRAModule(torch.nn.Module): """ replaces forward method of the original Linear, instead of replacing the original Linear module. 
@@ -195,7 +198,7 @@ def create_network( conv_alpha = 1.0 else: conv_alpha = float(conv_alpha) - + if unit is not None: unit = int(unit) else: @@ -211,6 +214,16 @@ def create_network( unit=unit, varbose=True, ) + + loraplus_lr_ratio = kwargs.get("loraplus_lr_ratio", None) + loraplus_unet_lr_ratio = kwargs.get("loraplus_unet_lr_ratio", None) + loraplus_text_encoder_lr_ratio = kwargs.get("loraplus_text_encoder_lr_ratio", None) + loraplus_lr_ratio = float(loraplus_lr_ratio) if loraplus_lr_ratio is not None else None + loraplus_unet_lr_ratio = float(loraplus_unet_lr_ratio) if loraplus_unet_lr_ratio is not None else None + loraplus_text_encoder_lr_ratio = float(loraplus_text_encoder_lr_ratio) if loraplus_text_encoder_lr_ratio is not None else None + if loraplus_lr_ratio is not None or loraplus_unet_lr_ratio is not None or loraplus_text_encoder_lr_ratio is not None: + network.set_loraplus_lr_ratio(loraplus_lr_ratio, loraplus_unet_lr_ratio, loraplus_text_encoder_lr_ratio) + return network @@ -280,6 +293,10 @@ def __init__( self.alpha = alpha self.apply_to_conv = apply_to_conv + self.loraplus_lr_ratio = None + self.loraplus_unet_lr_ratio = None + self.loraplus_text_encoder_lr_ratio = None + if modules_dim is not None: logger.info("create LoRA network from weights") else: @@ -320,9 +337,9 @@ def create_modules(is_unet, root_module: torch.nn.Module, target_replace_modules lora = module_class(lora_name, child_module, self.multiplier, dim, alpha, unit) loras.append(lora) return loras - + text_encoders = text_encoder if type(text_encoder) == list else [text_encoder] - + self.text_encoder_loras = [] for i, text_encoder in enumerate(text_encoders): if len(text_encoders) > 1: @@ -331,7 +348,7 @@ def create_modules(is_unet, root_module: torch.nn.Module, target_replace_modules else: index = None logger.info("create LoRA for Text Encoder") - + text_encoder_loras = create_modules(False, text_encoder, DyLoRANetwork.TEXT_ENCODER_TARGET_REPLACE_MODULE) self.text_encoder_loras.extend(text_encoder_loras) @@ -346,6 +363,11 @@ def create_modules(is_unet, root_module: torch.nn.Module, target_replace_modules self.unet_loras = create_modules(True, unet, target_modules) logger.info(f"create LoRA for U-Net: {len(self.unet_loras)} modules.") + def set_loraplus_lr_ratio(self, loraplus_lr_ratio, loraplus_unet_lr_ratio, loraplus_text_encoder_lr_ratio): + self.loraplus_lr_ratio = loraplus_lr_ratio + self.loraplus_unet_lr_ratio = loraplus_unet_lr_ratio + self.loraplus_text_encoder_lr_ratio = loraplus_text_encoder_lr_ratio + def set_multiplier(self, multiplier): self.multiplier = multiplier for lora in self.text_encoder_loras + self.unet_loras: @@ -407,15 +429,7 @@ def merge_to(self, text_encoder, unet, weights_sd, dtype, device): """ # 二つのText Encoderに別々の学習率を設定できるようにするといいかも - def prepare_optimizer_params( - self, - text_encoder_lr, - unet_lr, - default_lr, - text_encoder_loraplus_ratio=None, - unet_loraplus_ratio=None, - loraplus_ratio=None - ): + def prepare_optimizer_params(self, text_encoder_lr, unet_lr, default_lr): self.requires_grad_(True) all_params = [] @@ -452,15 +466,13 @@ def assemble_params(loras, lr, ratio): params = assemble_params( self.text_encoder_loras, text_encoder_lr if text_encoder_lr is not None else default_lr, - text_encoder_loraplus_ratio or loraplus_ratio + self.loraplus_text_encoder_lr_ratio or self.loraplus_ratio, ) all_params.extend(params) if self.unet_loras: params = assemble_params( - self.unet_loras, - default_lr if unet_lr is None else unet_lr, - unet_loraplus_ratio or loraplus_ratio + 
self.unet_loras, default_lr if unet_lr is None else unet_lr, self.loraplus_unet_lr_ratio or self.loraplus_ratio ) all_params.extend(params) diff --git a/networks/lora.py b/networks/lora.py index 61b8cd5a7..6e5645577 100644 --- a/networks/lora.py +++ b/networks/lora.py @@ -499,7 +499,8 @@ def create_network( loraplus_lr_ratio = float(loraplus_lr_ratio) if loraplus_lr_ratio is not None else None loraplus_unet_lr_ratio = float(loraplus_unet_lr_ratio) if loraplus_unet_lr_ratio is not None else None loraplus_text_encoder_lr_ratio = float(loraplus_text_encoder_lr_ratio) if loraplus_text_encoder_lr_ratio is not None else None - network.set_loraplus_lr_ratio(loraplus_lr_ratio, loraplus_unet_lr_ratio, loraplus_text_encoder_lr_ratio) + if loraplus_lr_ratio is not None or loraplus_unet_lr_ratio is not None or loraplus_text_encoder_lr_ratio is not None: + network.set_loraplus_lr_ratio(loraplus_lr_ratio, loraplus_unet_lr_ratio, loraplus_text_encoder_lr_ratio) if block_lr_weight is not None: network.set_block_lr_weight(block_lr_weight) @@ -855,6 +856,10 @@ def __init__( self.rank_dropout = rank_dropout self.module_dropout = module_dropout + self.loraplus_lr_ratio = None + self.loraplus_unet_lr_ratio = None + self.loraplus_text_encoder_lr_ratio = None + if modules_dim is not None: logger.info(f"create LoRA network from weights") elif block_dims is not None: diff --git a/networks/lora_fa.py b/networks/lora_fa.py index 9a608118a..58bcb2206 100644 --- a/networks/lora_fa.py +++ b/networks/lora_fa.py @@ -15,8 +15,10 @@ import torch import re from library.utils import setup_logging + setup_logging() import logging + logger = logging.getLogger(__name__) RE_UPDOWN = re.compile(r"(up|down)_blocks_(\d+)_(resnets|upsamplers|downsamplers|attentions)_(\d+)_") @@ -504,6 +506,15 @@ def create_network( if up_lr_weight is not None or mid_lr_weight is not None or down_lr_weight is not None: network.set_block_lr_weight(up_lr_weight, mid_lr_weight, down_lr_weight) + loraplus_lr_ratio = kwargs.get("loraplus_lr_ratio", None) + loraplus_unet_lr_ratio = kwargs.get("loraplus_unet_lr_ratio", None) + loraplus_text_encoder_lr_ratio = kwargs.get("loraplus_text_encoder_lr_ratio", None) + loraplus_lr_ratio = float(loraplus_lr_ratio) if loraplus_lr_ratio is not None else None + loraplus_unet_lr_ratio = float(loraplus_unet_lr_ratio) if loraplus_unet_lr_ratio is not None else None + loraplus_text_encoder_lr_ratio = float(loraplus_text_encoder_lr_ratio) if loraplus_text_encoder_lr_ratio is not None else None + if loraplus_lr_ratio is not None or loraplus_unet_lr_ratio is not None or loraplus_text_encoder_lr_ratio is not None: + network.set_loraplus_lr_ratio(loraplus_lr_ratio, loraplus_unet_lr_ratio, loraplus_text_encoder_lr_ratio) + return network @@ -529,7 +540,9 @@ def parse_floats(s): len(block_dims) == num_total_blocks ), f"block_dims must have {num_total_blocks} elements / block_dimsは{num_total_blocks}個指定してください" else: - logger.warning(f"block_dims is not specified. all dims are set to {network_dim} / block_dimsが指定されていません。すべてのdimは{network_dim}になります") + logger.warning( + f"block_dims is not specified. 
all dims are set to {network_dim} / block_dimsが指定されていません。すべてのdimは{network_dim}になります" + ) block_dims = [network_dim] * num_total_blocks if block_alphas is not None: @@ -803,11 +816,17 @@ def __init__( self.rank_dropout = rank_dropout self.module_dropout = module_dropout + self.loraplus_lr_ratio = None + self.loraplus_unet_lr_ratio = None + self.loraplus_text_encoder_lr_ratio = None + if modules_dim is not None: logger.info(f"create LoRA network from weights") elif block_dims is not None: logger.info(f"create LoRA network from block_dims") - logger.info(f"neuron dropout: p={self.dropout}, rank dropout: p={self.rank_dropout}, module dropout: p={self.module_dropout}") + logger.info( + f"neuron dropout: p={self.dropout}, rank dropout: p={self.rank_dropout}, module dropout: p={self.module_dropout}" + ) logger.info(f"block_dims: {block_dims}") logger.info(f"block_alphas: {block_alphas}") if conv_block_dims is not None: @@ -815,9 +834,13 @@ def __init__( logger.info(f"conv_block_alphas: {conv_block_alphas}") else: logger.info(f"create LoRA network. base dim (rank): {lora_dim}, alpha: {alpha}") - logger.info(f"neuron dropout: p={self.dropout}, rank dropout: p={self.rank_dropout}, module dropout: p={self.module_dropout}") + logger.info( + f"neuron dropout: p={self.dropout}, rank dropout: p={self.rank_dropout}, module dropout: p={self.module_dropout}" + ) if self.conv_lora_dim is not None: - logger.info(f"apply LoRA to Conv2d with kernel size (3,3). dim (rank): {self.conv_lora_dim}, alpha: {self.conv_alpha}") + logger.info( + f"apply LoRA to Conv2d with kernel size (3,3). dim (rank): {self.conv_lora_dim}, alpha: {self.conv_alpha}" + ) # create module instances def create_modules( @@ -939,6 +962,11 @@ def create_modules( assert lora.lora_name not in names, f"duplicated lora name: {lora.lora_name}" names.add(lora.lora_name) + def set_loraplus_lr_ratio(self, loraplus_lr_ratio, loraplus_unet_lr_ratio, loraplus_text_encoder_lr_ratio): + self.loraplus_lr_ratio = loraplus_lr_ratio + self.loraplus_unet_lr_ratio = loraplus_unet_lr_ratio + self.loraplus_text_encoder_lr_ratio = loraplus_text_encoder_lr_ratio + def set_multiplier(self, multiplier): self.multiplier = multiplier for lora in self.text_encoder_loras + self.unet_loras: @@ -1033,15 +1061,7 @@ def get_lr_weight(self, lora: LoRAModule) -> float: return lr_weight # 二つのText Encoderに別々の学習率を設定できるようにするといいかも - def prepare_optimizer_params( - self, - text_encoder_lr, - unet_lr, - default_lr, - text_encoder_loraplus_ratio=None, - unet_loraplus_ratio=None, - loraplus_ratio=None - ): + def prepare_optimizer_params(self, text_encoder_lr, unet_lr, default_lr): self.requires_grad_(True) all_params = [] @@ -1078,7 +1098,7 @@ def assemble_params(loras, lr, ratio): params = assemble_params( self.text_encoder_loras, text_encoder_lr if text_encoder_lr is not None else default_lr, - text_encoder_loraplus_ratio or loraplus_ratio + self.loraplus_text_encoder_lr_ratio or self.loraplus_ratio, ) all_params.extend(params) @@ -1097,7 +1117,7 @@ def assemble_params(loras, lr, ratio): params = assemble_params( block_loras, (unet_lr if unet_lr is not None else default_lr) * self.get_lr_weight(block_loras[0]), - unet_loraplus_ratio or loraplus_ratio + self.loraplus_unet_lr_ratio or self.loraplus_ratio, ) all_params.extend(params) @@ -1105,7 +1125,7 @@ def assemble_params(loras, lr, ratio): params = assemble_params( self.unet_loras, unet_lr if unet_lr is not None else default_lr, - unet_loraplus_ratio or loraplus_ratio + self.loraplus_unet_lr_ratio or self.loraplus_ratio, ) 
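# (editorial aside) A standalone sketch of the LoRA+ grouping that assemble_params implements
# in this patch: parameters whose name contains "lora_up" get lr * ratio, the rest get the base
# lr. The dummy modules and names below are illustrative; in the real scripts the ratios are
# now supplied through --network_args (e.g. "loraplus_lr_ratio=16") and reach create_network
# as the kwargs read near the top of this patch.
import torch

def loraplus_param_groups(named_params, lr: float, ratio: float):
    base, plus = [], []
    for name, param in named_params:
        (plus if "lora_up" in name else base).append(param)
    return [
        {"params": base, "lr": lr},          # lora_down and any other weights
        {"params": plus, "lr": lr * ratio},  # lora_up ("B") weights trained faster
    ]

lora_down = torch.nn.Linear(320, 4, bias=False)
lora_up = torch.nn.Linear(4, 320, bias=False)
named = [
    ("lora_unet_example.lora_down.weight", lora_down.weight),
    ("lora_unet_example.lora_up.weight", lora_up.weight),
]
optimizer = torch.optim.AdamW(loraplus_param_groups(named, lr=1e-4, ratio=16))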
all_params.extend(params) From 3fd8cdc55d7d87ceca2dc1127a807a7ddafb15ae Mon Sep 17 00:00:00 2001 From: Kohya S Date: Mon, 6 May 2024 14:03:19 +0900 Subject: [PATCH 24/97] fix dylora loraplus --- networks/dylora.py | 4 ++-- 1 file changed, 2 insertions(+), 2 deletions(-) diff --git a/networks/dylora.py b/networks/dylora.py index 0d1701ded..d57e3d580 100644 --- a/networks/dylora.py +++ b/networks/dylora.py @@ -466,13 +466,13 @@ def assemble_params(loras, lr, ratio): params = assemble_params( self.text_encoder_loras, text_encoder_lr if text_encoder_lr is not None else default_lr, - self.loraplus_text_encoder_lr_ratio or self.loraplus_ratio, + self.loraplus_text_encoder_lr_ratio or self.loraplus_lr_ratio, ) all_params.extend(params) if self.unet_loras: params = assemble_params( - self.unet_loras, default_lr if unet_lr is None else unet_lr, self.loraplus_unet_lr_ratio or self.loraplus_ratio + self.unet_loras, default_lr if unet_lr is None else unet_lr, self.loraplus_unet_lr_ratio or self.loraplus_lr_ratio ) all_params.extend(params) From 017b82ebe33a2199c8f842c99905f59c54292f56 Mon Sep 17 00:00:00 2001 From: Kohya S Date: Mon, 6 May 2024 15:05:42 +0900 Subject: [PATCH 25/97] update help message for fused_backward_pass --- library/train_util.py | 3 ++- 1 file changed, 2 insertions(+), 1 deletion(-) diff --git a/library/train_util.py b/library/train_util.py index 46b55c03e..e3c0229a7 100644 --- a/library/train_util.py +++ b/library/train_util.py @@ -2923,7 +2923,8 @@ def add_optimizer_arguments(parser: argparse.ArgumentParser): parser.add_argument( "--fused_backward_pass", action="store_true", - help="Combines backward pass and optimizer step to reduce VRAM usage / バックワードパスとオプティマイザステップを組み合わせてVRAMの使用量を削減します。", + help="Combines backward pass and optimizer step to reduce VRAM usage. 
Only available in SDXL" + + " / バックワードパスとオプティマイザステップを組み合わせてVRAMの使用量を削減します。SDXLでのみ有効", ) From b56d5f7801dea45cdbbba8498544e8d2853ad6d6 Mon Sep 17 00:00:00 2001 From: Kohya S Date: Mon, 6 May 2024 21:35:39 +0900 Subject: [PATCH 26/97] add experimental option to fuse params to optimizer groups --- sdxl_train.py | 114 +++++++++++++++++++++++++++++++++++++++++++++----- 1 file changed, 104 insertions(+), 10 deletions(-) diff --git a/sdxl_train.py b/sdxl_train.py index 3b28575ed..c7eea2224 100644 --- a/sdxl_train.py +++ b/sdxl_train.py @@ -345,8 +345,8 @@ def fn_recursive_set_mem_eff(module: torch.nn.Module): # calculate number of trainable parameters n_params = 0 - for params in params_to_optimize: - for p in params["params"]: + for group in params_to_optimize: + for p in group["params"]: n_params += p.numel() accelerator.print(f"train unet: {train_unet}, text_encoder1: {train_text_encoder1}, text_encoder2: {train_text_encoder2}") @@ -355,7 +355,44 @@ def fn_recursive_set_mem_eff(module: torch.nn.Module): # 学習に必要なクラスを準備する accelerator.print("prepare optimizer, data loader etc.") - _, _, optimizer = train_util.get_optimizer(args, trainable_params=params_to_optimize) + + if args.fused_optimizer_groups: + # calculate total number of parameters + n_total_params = sum(len(params["params"]) for params in params_to_optimize) + params_per_group = math.ceil(n_total_params / args.fused_optimizer_groups) + + # split params into groups + grouped_params = [] + param_group = [] + param_group_lr = -1 + for group in params_to_optimize: + lr = group["lr"] + for p in group["params"]: + if lr != param_group_lr: + if param_group: + grouped_params.append({"params": param_group, "lr": param_group_lr}) + param_group = [] + param_group_lr = lr + param_group.append(p) + if len(param_group) == params_per_group: + grouped_params.append({"params": param_group, "lr": param_group_lr}) + param_group = [] + param_group_lr = -1 + if param_group: + grouped_params.append({"params": param_group, "lr": param_group_lr}) + + # prepare optimizers for each group + optimizers = [] + for group in grouped_params: + _, _, optimizer = train_util.get_optimizer(args, trainable_params=[group]) + optimizers.append(optimizer) + optimizer = optimizers[0] # avoid error in the following code + + print(len(grouped_params)) + logger.info(f"using {len(optimizers)} optimizers for fused optimizer groups") + + else: + _, _, optimizer = train_util.get_optimizer(args, trainable_params=params_to_optimize) # dataloaderを準備する # DataLoaderのプロセス数:0 は persistent_workers が使えないので注意 @@ -382,7 +419,11 @@ def fn_recursive_set_mem_eff(module: torch.nn.Module): train_dataset_group.set_max_train_steps(args.max_train_steps) # lr schedulerを用意する - lr_scheduler = train_util.get_scheduler_fix(args, optimizer, accelerator.num_processes) + if args.fused_optimizer_groups: + lr_schedulers = [train_util.get_scheduler_fix(args, optimizer, accelerator.num_processes) for optimizer in optimizers] + lr_scheduler = lr_schedulers[0] # avoid error in the following code + else: + lr_scheduler = train_util.get_scheduler_fix(args, optimizer, accelerator.num_processes) # 実験的機能:勾配も含めたfp16/bf16学習を行う モデル全体をfp16/bf16にする if args.full_fp16: @@ -432,10 +473,12 @@ def fn_recursive_set_mem_eff(module: torch.nn.Module): if args.fused_backward_pass: import library.adafactor_fused + library.adafactor_fused.patch_adafactor_fused(optimizer) for param_group in optimizer.param_groups: for parameter in param_group["params"]: if parameter.requires_grad: + def __grad_hook(tensor: torch.Tensor, 
param_group=param_group): if accelerator.sync_gradients and args.max_grad_norm != 0.0: accelerator.clip_grad_norm_(tensor, args.max_grad_norm) @@ -444,6 +487,36 @@ def __grad_hook(tensor: torch.Tensor, param_group=param_group): parameter.register_post_accumulate_grad_hook(__grad_hook) + elif args.fused_optimizer_groups: + for i in range(1, len(optimizers)): + optimizers[i] = accelerator.prepare(optimizers[i]) + lr_schedulers[i] = accelerator.prepare(lr_schedulers[i]) + + global optimizer_hooked_count + global num_parameters_per_group + global parameter_optimizer_map + optimizer_hooked_count = {} + num_parameters_per_group = [0] * len(optimizers) + parameter_optimizer_map = {} + for opt_idx, optimizer in enumerate(optimizers): + for param_group in optimizer.param_groups: + for parameter in param_group["params"]: + if parameter.requires_grad: + + def optimizer_hook(parameter: torch.Tensor): + if accelerator.sync_gradients and args.max_grad_norm != 0.0: + accelerator.clip_grad_norm_(parameter, args.max_grad_norm) + + i = parameter_optimizer_map[parameter] + optimizer_hooked_count[i] += 1 + if optimizer_hooked_count[i] == num_parameters_per_group[i]: + optimizers[i].step() + optimizers[i].zero_grad() + + parameter.register_post_accumulate_grad_hook(optimizer_hook) + parameter_optimizer_map[parameter] = opt_idx + num_parameters_per_group[opt_idx] += 1 + # TextEncoderの出力をキャッシュするときにはCPUへ移動する if args.cache_text_encoder_outputs: # move Text Encoders for sampling images. Text Encoder doesn't work on CPU with fp16 @@ -518,6 +591,10 @@ def __grad_hook(tensor: torch.Tensor, param_group=param_group): for step, batch in enumerate(train_dataloader): current_step.value = global_step + + if args.fused_optimizer_groups: + optimizer_hooked_count = {i: 0 for i in range(len(optimizers))} + with accelerator.accumulate(*training_models): if "latents" in batch and batch["latents"] is not None: latents = batch["latents"].to(accelerator.device).to(dtype=weight_dtype) @@ -596,7 +673,9 @@ def __grad_hook(tensor: torch.Tensor, param_group=param_group): # Sample noise, sample a random timestep for each image, and add noise to the latents, # with noise offset and/or multires noise if specified - noise, noisy_latents, timesteps, huber_c = train_util.get_noise_noisy_latents_and_timesteps(args, noise_scheduler, latents) + noise, noisy_latents, timesteps, huber_c = train_util.get_noise_noisy_latents_and_timesteps( + args, noise_scheduler, latents + ) noisy_latents = noisy_latents.to(weight_dtype) # TODO check why noisy_latents is not weight_dtype @@ -614,7 +693,9 @@ def __grad_hook(tensor: torch.Tensor, param_group=param_group): or args.masked_loss ): # do not mean over batch dimension for snr weight or scale v-pred loss - loss = train_util.conditional_loss(noise_pred.float(), target.float(), reduction="none", loss_type=args.loss_type, huber_c=huber_c) + loss = train_util.conditional_loss( + noise_pred.float(), target.float(), reduction="none", loss_type=args.loss_type, huber_c=huber_c + ) if args.masked_loss: loss = apply_masked_loss(loss, batch) loss = loss.mean([1, 2, 3]) @@ -630,11 +711,13 @@ def __grad_hook(tensor: torch.Tensor, param_group=param_group): loss = loss.mean() # mean over batch dimension else: - loss = train_util.conditional_loss(noise_pred.float(), target.float(), reduction="mean", loss_type=args.loss_type, huber_c=huber_c) + loss = train_util.conditional_loss( + noise_pred.float(), target.float(), reduction="mean", loss_type=args.loss_type, huber_c=huber_c + ) accelerator.backward(loss) - if not 
args.fused_backward_pass: + if not (args.fused_backward_pass or args.fused_optimizer_groups): if accelerator.sync_gradients and args.max_grad_norm != 0.0: params_to_clip = [] for m in training_models: @@ -642,9 +725,14 @@ def __grad_hook(tensor: torch.Tensor, param_group=param_group): accelerator.clip_grad_norm_(params_to_clip, args.max_grad_norm) optimizer.step() + elif args.fused_optimizer_groups: + for i in range(1, len(optimizers)): + lr_schedulers[i].step() lr_scheduler.step() - optimizer.zero_grad(set_to_none=True) + + if not (args.fused_backward_pass or args.fused_optimizer_groups): + optimizer.zero_grad(set_to_none=True) # Checks if the accelerator has performed an optimization step behind the scenes if accelerator.sync_gradients: @@ -753,7 +841,7 @@ def __grad_hook(tensor: torch.Tensor, param_group=param_group): accelerator.end_training() - if args.save_state or args.save_state_on_train_end: + if args.save_state or args.save_state_on_train_end: train_util.save_state_on_train_end(args, accelerator) del accelerator # この後メモリを使うのでこれは消す @@ -822,6 +910,12 @@ def setup_parser() -> argparse.ArgumentParser: help=f"learning rates for each block of U-Net, comma-separated, {UNET_NUM_BLOCKS_FOR_BLOCK_LR} values / " + f"U-Netの各ブロックの学習率、カンマ区切り、{UNET_NUM_BLOCKS_FOR_BLOCK_LR}個の値", ) + parser.add_argument( + "--fused_optimizer_groups", + type=int, + default=None, + help="number of optimizers for fused backward pass and optimizer step / fused backward passとoptimizer stepのためのoptimizer数", + ) return parser From 793aeb94da53565fb08c7b0b2538f2ade04824bb Mon Sep 17 00:00:00 2001 From: AngelBottomless Date: Tue, 7 May 2024 18:21:31 +0900 Subject: [PATCH 27/97] fix get_trainable_params in controlnet-llite training --- sdxl_train_control_net_lllite.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/sdxl_train_control_net_lllite.py b/sdxl_train_control_net_lllite.py index f89c3628f..6ad6e763c 100644 --- a/sdxl_train_control_net_lllite.py +++ b/sdxl_train_control_net_lllite.py @@ -477,7 +477,7 @@ def remove_model(old_ckpt_name): accelerator.backward(loss) if accelerator.sync_gradients and args.max_grad_norm != 0.0: - params_to_clip = unet.get_trainable_params() + params_to_clip = accelerator.unwrap_model(unet).get_trainable_params() accelerator.clip_grad_norm_(params_to_clip, args.max_grad_norm) optimizer.step() From 607e041f3de972f2c3030e7c8b43dfc3c2eb2d65 Mon Sep 17 00:00:00 2001 From: Kohya S Date: Sun, 12 May 2024 14:16:41 +0900 Subject: [PATCH 28/97] chore: Refactor optimizer group --- sdxl_train.py | 37 ++++++++++++++++++++++++++----------- 1 file changed, 26 insertions(+), 11 deletions(-) diff --git a/sdxl_train.py b/sdxl_train.py index c7eea2224..be2b7166e 100644 --- a/sdxl_train.py +++ b/sdxl_train.py @@ -357,27 +357,37 @@ def fn_recursive_set_mem_eff(module: torch.nn.Module): accelerator.print("prepare optimizer, data loader etc.") if args.fused_optimizer_groups: + # fused backward pass: https://pytorch.org/tutorials/intermediate/optimizer_step_in_backward_tutorial.html + # Instead of creating an optimizer for all parameters as in the tutorial, we create an optimizer for each group of parameters. + # This balances memory usage and management complexity. 
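# (editorial aside) A self-contained sketch of the "optimizer step in backward" tutorial
# approach that the comment above contrasts with: one optimizer per parameter, stepped from a
# post-accumulate-grad hook so each gradient is consumed and freed during backward rather than
# held until a global optimizer.step(). Requires PyTorch >= 2.1; the model, optimizer choice
# and names are illustrative, and the gradient clipping / Adafactor patching of the real
# implementation are omitted.
import torch

model = torch.nn.Linear(16, 16)  # stand-in for the trained model
per_param_opt = {p: torch.optim.SGD([p], lr=1e-3) for p in model.parameters() if p.requires_grad}

def fused_step_hook(param: torch.Tensor) -> None:
    opt = per_param_opt[param]
    opt.step()
    opt.zero_grad(set_to_none=True)

for p in per_param_opt:
    p.register_post_accumulate_grad_hook(fused_step_hook)

loss = model(torch.randn(4, 16)).sum()
loss.backward()  # optimizer steps run inside backward; each grad is cleared right away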
+ # calculate total number of parameters n_total_params = sum(len(params["params"]) for params in params_to_optimize) params_per_group = math.ceil(n_total_params / args.fused_optimizer_groups) - # split params into groups + # split params into groups, keeping the learning rate the same for all params in a group + # this will increase the number of groups if the learning rate is different for different params (e.g. U-Net and text encoders) grouped_params = [] param_group = [] param_group_lr = -1 for group in params_to_optimize: lr = group["lr"] for p in group["params"]: + # if the learning rate is different for different params, start a new group if lr != param_group_lr: if param_group: grouped_params.append({"params": param_group, "lr": param_group_lr}) param_group = [] param_group_lr = lr + param_group.append(p) + + # if the group has enough parameters, start a new group if len(param_group) == params_per_group: grouped_params.append({"params": param_group, "lr": param_group_lr}) param_group = [] param_group_lr = -1 + if param_group: grouped_params.append({"params": param_group, "lr": param_group_lr}) @@ -388,7 +398,6 @@ def fn_recursive_set_mem_eff(module: torch.nn.Module): optimizers.append(optimizer) optimizer = optimizers[0] # avoid error in the following code - print(len(grouped_params)) logger.info(f"using {len(optimizers)} optimizers for fused optimizer groups") else: @@ -420,6 +429,7 @@ def fn_recursive_set_mem_eff(module: torch.nn.Module): # lr schedulerを用意する if args.fused_optimizer_groups: + # prepare lr schedulers for each optimizer lr_schedulers = [train_util.get_scheduler_fix(args, optimizer, accelerator.num_processes) for optimizer in optimizers] lr_scheduler = lr_schedulers[0] # avoid error in the following code else: @@ -472,6 +482,7 @@ def fn_recursive_set_mem_eff(module: torch.nn.Module): optimizer, train_dataloader, lr_scheduler = accelerator.prepare(optimizer, train_dataloader, lr_scheduler) if args.fused_backward_pass: + # use fused optimizer for backward pass: other optimizers will be supported in the future import library.adafactor_fused library.adafactor_fused.patch_adafactor_fused(optimizer) @@ -488,16 +499,20 @@ def __grad_hook(tensor: torch.Tensor, param_group=param_group): parameter.register_post_accumulate_grad_hook(__grad_hook) elif args.fused_optimizer_groups: + # prepare for additional optimizers and lr schedulers for i in range(1, len(optimizers)): optimizers[i] = accelerator.prepare(optimizers[i]) lr_schedulers[i] = accelerator.prepare(lr_schedulers[i]) + # counters are used to determine when to step the optimizer global optimizer_hooked_count global num_parameters_per_group global parameter_optimizer_map + optimizer_hooked_count = {} num_parameters_per_group = [0] * len(optimizers) parameter_optimizer_map = {} + for opt_idx, optimizer in enumerate(optimizers): for param_group in optimizer.param_groups: for parameter in param_group["params"]: @@ -511,7 +526,7 @@ def optimizer_hook(parameter: torch.Tensor): optimizer_hooked_count[i] += 1 if optimizer_hooked_count[i] == num_parameters_per_group[i]: optimizers[i].step() - optimizers[i].zero_grad() + optimizers[i].zero_grad(set_to_none=True) parameter.register_post_accumulate_grad_hook(optimizer_hook) parameter_optimizer_map[parameter] = opt_idx @@ -593,7 +608,7 @@ def optimizer_hook(parameter: torch.Tensor): current_step.value = global_step if args.fused_optimizer_groups: - optimizer_hooked_count = {i: 0 for i in range(len(optimizers))} + optimizer_hooked_count = {i: 0 for i in range(len(optimizers))} # reset 
counter for each step with accelerator.accumulate(*training_models): if "latents" in batch and batch["latents"] is not None: @@ -725,14 +740,14 @@ def optimizer_hook(parameter: torch.Tensor): accelerator.clip_grad_norm_(params_to_clip, args.max_grad_norm) optimizer.step() - elif args.fused_optimizer_groups: - for i in range(1, len(optimizers)): - lr_schedulers[i].step() - - lr_scheduler.step() - - if not (args.fused_backward_pass or args.fused_optimizer_groups): + lr_scheduler.step() optimizer.zero_grad(set_to_none=True) + else: + # optimizer.step() and optimizer.zero_grad() are called in the optimizer hook + lr_scheduler.step() + if args.fused_optimizer_groups: + for i in range(1, len(optimizers)): + lr_schedulers[i].step() # Checks if the accelerator has performed an optimization step behind the scenes if accelerator.sync_gradients: From c1ba0b4356637c881ea99663fcce5943fc33fc56 Mon Sep 17 00:00:00 2001 From: Kohya S Date: Sun, 12 May 2024 14:21:10 +0900 Subject: [PATCH 29/97] update readme --- README.md | 8 ++++++++ 1 file changed, 8 insertions(+) diff --git a/README.md b/README.md index a7047a360..859a7618d 100644 --- a/README.md +++ b/README.md @@ -137,6 +137,14 @@ The majority of scripts is licensed under ASL 2.0 (including codes from Diffuser ## Change History +### Working in progress + +- Fixed some bugs when using DeepSpeed. Related [#1247] + + +- DeepSpeed 使用時のいくつかのバグを修正しました。関連 [#1247] + + ### Apr 7, 2024 / 2024-04-07: v0.8.7 - The default value of `huber_schedule` in Scheduled Huber Loss is changed from `exponential` to `snr`, which is expected to give better results. From f3d2cf22ff9ad49e7f8bd68494714fa3bedbd77d Mon Sep 17 00:00:00 2001 From: Kohya S Date: Sun, 12 May 2024 15:03:02 +0900 Subject: [PATCH 30/97] update README for fused optimizer --- README.md | 29 +++++++++++++++++++++++++++++ 1 file changed, 29 insertions(+) diff --git a/README.md b/README.md index 859a7618d..4fd97fb25 100644 --- a/README.md +++ b/README.md @@ -139,8 +139,37 @@ The majority of scripts is licensed under ASL 2.0 (including codes from Diffuser ### Working in progress +- Fused optimizer is available for SDXL training. PR [#1259](https://github.com/kohya-ss/sd-scripts/pull/1259) Thanks to 2kpr! + - The memory usage during training is significantly reduced by integrating the optimizer's backward pass with step. The training results are the same as before, but if you have plenty of memory, the speed will be slower. + - Specify the `--fused_backward_pass` option in `sdxl_train.py`. At this time, only AdaFactor is supported. Gradient accumulation is not available. + - Setting mixed precision to `no` seems to use less memory than `fp16` or `bf16`. + - Training is possible with a memory usage of about 17GB with a batch size of 1 and fp32. If you specify the `--full_bf16` option, you can further reduce the memory usage (but the accuracy will be lower). With the same memory usage as before, you can increase the batch size. + - PyTorch 2.1 or later is required because it uses the new API `Tensor.register_post_accumulate_grad_hook(hook)`. + - Mechanism: Normally, backward -> step is performed for each parameter, so all gradients need to be temporarily stored in memory. "Fuse backward and step" reduces memory usage by performing backward/step for each parameter and reflecting the gradient immediately. + +- Optimizer groups feature is added to SDXL training. PR [#1319](https://github.com/kohya-ss/sd-scripts/pull/1319) + - Memory usage is reduced by the same principle as Fused optimizer. 
The training results and speed are the same as Fused optimizer. + - Specify the number of groups like `--fused_optimizer_groups 10` in `sdxl_train.py`. Increasing the number of groups reduces memory usage but slows down training. Since the effect is limited to a certain number, it is recommended to specify 4-10. + - Any optimizer can be used, but optimizers that automatically calculate the learning rate (such as D-Adaptation and Prodigy) cannot be used. Gradient accumulation is not available. + - `--fused_optimizer_groups` cannot be used with `--fused_backward_pass`. When using AdaFactor, the memory usage is slightly larger than with Fused optimizer. PyTorch 2.1 or later is required. + - Mechanism: While Fused optimizer performs backward/step for individual parameters within the optimizer, optimizer groups reduce memory usage by grouping parameters and creating multiple optimizers to perform backward/step for each group. Fused optimizer requires implementation on the optimizer side, while optimizer groups are implemented only on the training script side. + - Fixed some bugs when using DeepSpeed. Related [#1247] +- SDXL の学習時に Fused optimizer が使えるようになりました。PR [#1259](https://github.com/kohya-ss/sd-scripts/pull/1259) 2kpr 氏に感謝します。 + - optimizer の backward pass に step を統合することで学習時のメモリ使用量を大きく削減します。学習結果は未適用時と同一ですが、メモリが潤沢にある場合は速度は遅くなります。 + - `sdxl_train.py` に `--fused_backward_pass` オプションを指定してください。現時点では optimizer は AdaFactor のみ対応しています。また gradient accumulation は使えません。 + - mixed precision は `no` のほうが `fp16` や `bf16` よりも使用メモリ量が少ないようです。 + - バッチサイズ 1、fp32 で 17GB 程度で学習可能なようです。`--full_bf16` オプションを指定するとさらに削減できます(精度は劣ります)。以前と同じメモリ使用量ではバッチサイズを増やせます。 + - PyTorch 2.1 以降の新 API `Tensor.register_post_accumulate_grad_hook(hook)` を使用しているため、PyTorch 2.1 以降が必要です。 + - 仕組み:通常は backward -> step の順で行うためすべての勾配を一時的にメモリに保持する必要があります。「backward と step の統合」はパラメータごとに backward/step を行って、勾配をすぐ反映することでメモリ使用量を削減します。 + +- SDXL の学習時に optimizer group 機能を追加しました。PR [#1319](https://github.com/kohya-ss/sd-scripts/pull/1319) + - Fused optimizer と同様の原理でメモリ使用量を削減します。学習結果や速度についても同様です。 + - `sdxl_train.py` に `--fused_optimizer_groups 10` のようにグループ数を指定してください。グループ数を増やすとメモリ使用量が削減されますが、速度は遅くなります。ある程度の数までしか効果がないため、4~10 程度を指定すると良いでしょう。 + - 任意の optimizer が使えますが、学習率を自動計算する optimizer (D-Adaptation や Prodigy など)は使えません。gradient accumulation は使えません。 + - `--fused_optimizer_groups` は `--fused_backward_pass` と併用できません。AdaFactor 使用時は Fused optimizer よりも若干メモリ使用量は大きくなります。PyTorch 2.1 以降が必要です。 + - 仕組み:Fused optimizer が optimizer 内で個別のパラメータについて backward/step を行っているのに対して、optimizer groups はパラメータをグループ化して複数の optimizer を作成し、それぞれ backward/step を行うことでメモリ使用量を削減します。Fused optimizer は optimizer 側の実装が必要ですが、optimizer groups は学習スクリプト側のみで実装されています。 - DeepSpeed 使用時のいくつかのバグを修正しました。関連 [#1247] From bee8cee7e8fbeecc05b1c80a1e9e8fadab3210a5 Mon Sep 17 00:00:00 2001 From: Kohya S Date: Sun, 12 May 2024 15:08:52 +0900 Subject: [PATCH 31/97] update README for fused optimizer --- README.md | 6 +++--- 1 file changed, 3 insertions(+), 3 deletions(-) diff --git a/README.md b/README.md index 4fd97fb25..9c7ecad99 100644 --- a/README.md +++ b/README.md @@ -145,7 +145,7 @@ The majority of scripts is licensed under ASL 2.0 (including codes from Diffuser - Setting mixed precision to `no` seems to use less memory than `fp16` or `bf16`. - Training is possible with a memory usage of about 17GB with a batch size of 1 and fp32. If you specify the `--full_bf16` option, you can further reduce the memory usage (but the accuracy will be lower). 
With the same memory usage as before, you can increase the batch size. - PyTorch 2.1 or later is required because it uses the new API `Tensor.register_post_accumulate_grad_hook(hook)`. - - Mechanism: Normally, backward -> step is performed for each parameter, so all gradients need to be temporarily stored in memory. "Fuse backward and step" reduces memory usage by performing backward/step for each parameter and reflecting the gradient immediately. + - Mechanism: Normally, backward -> step is performed for each parameter, so all gradients need to be temporarily stored in memory. "Fuse backward and step" reduces memory usage by performing backward/step for each parameter and reflecting the gradient immediately. The more parameters there are, the greater the effect, so it is not effective in other training scripts (LoRA, etc.) where the memory usage peak is elsewhere, and there are no plans to implement it in those training scripts. - Optimizer groups feature is added to SDXL training. PR [#1319](https://github.com/kohya-ss/sd-scripts/pull/1319) - Memory usage is reduced by the same principle as Fused optimizer. The training results and speed are the same as Fused optimizer. @@ -162,14 +162,14 @@ The majority of scripts is licensed under ASL 2.0 (including codes from Diffuser - mixed precision は `no` のほうが `fp16` や `bf16` よりも使用メモリ量が少ないようです。 - バッチサイズ 1、fp32 で 17GB 程度で学習可能なようです。`--full_bf16` オプションを指定するとさらに削減できます(精度は劣ります)。以前と同じメモリ使用量ではバッチサイズを増やせます。 - PyTorch 2.1 以降の新 API `Tensor.register_post_accumulate_grad_hook(hook)` を使用しているため、PyTorch 2.1 以降が必要です。 - - 仕組み:通常は backward -> step の順で行うためすべての勾配を一時的にメモリに保持する必要があります。「backward と step の統合」はパラメータごとに backward/step を行って、勾配をすぐ反映することでメモリ使用量を削減します。 + - 仕組み:通常は backward -> step の順で行うためすべての勾配を一時的にメモリに保持する必要があります。「backward と step の統合」はパラメータごとに backward/step を行って、勾配をすぐ反映することでメモリ使用量を削減します。パラメータ数が多いほど効果が大きいため、SDXL の学習以外(LoRA 等)ではほぼ効果がなく(メモリ使用量のピークが他の場所にあるため)、それらの学習スクリプトへの実装予定もありません。 - SDXL の学習時に optimizer group 機能を追加しました。PR [#1319](https://github.com/kohya-ss/sd-scripts/pull/1319) - Fused optimizer と同様の原理でメモリ使用量を削減します。学習結果や速度についても同様です。 - `sdxl_train.py` に `--fused_optimizer_groups 10` のようにグループ数を指定してください。グループ数を増やすとメモリ使用量が削減されますが、速度は遅くなります。ある程度の数までしか効果がないため、4~10 程度を指定すると良いでしょう。 - 任意の optimizer が使えますが、学習率を自動計算する optimizer (D-Adaptation や Prodigy など)は使えません。gradient accumulation は使えません。 - `--fused_optimizer_groups` は `--fused_backward_pass` と併用できません。AdaFactor 使用時は Fused optimizer よりも若干メモリ使用量は大きくなります。PyTorch 2.1 以降が必要です。 - - 仕組み:Fused optimizer が optimizer 内で個別のパラメータについて backward/step を行っているのに対して、optimizer groups はパラメータをグループ化して複数の optimizer を作成し、それぞれ backward/step を行うことでメモリ使用量を削減します。Fused optimizer は optimizer 側の実装が必要ですが、optimizer groups は学習スクリプト側のみで実装されています。 + - 仕組み:Fused optimizer が optimizer 内で個別のパラメータについて backward/step を行っているのに対して、optimizer groups はパラメータをグループ化して複数の optimizer を作成し、それぞれ backward/step を行うことでメモリ使用量を削減します。Fused optimizer は optimizer 側の実装が必要ですが、optimizer groups は学習スクリプト側のみで実装されています。やはり SDXL の学習でのみ効果があります。 - DeepSpeed 使用時のいくつかのバグを修正しました。関連 [#1247] From 1ffc0b330aa362a408e46e9a52784d72aa73d263 Mon Sep 17 00:00:00 2001 From: Kohya S Date: Sun, 12 May 2024 16:18:43 +0900 Subject: [PATCH 32/97] fix typo --- library/train_util.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/library/train_util.py b/library/train_util.py index e3c0229a7..b2de8a216 100644 --- a/library/train_util.py +++ b/library/train_util.py @@ -3093,7 +3093,7 @@ def add_training_arguments(parser: argparse.ArgumentParser, support_dreambooth: ) parser.add_argument("--seed", 
type=int, default=None, help="random seed for training / 学習時の乱数のseed") parser.add_argument( - "--gradient_checkpointing", action="store_true", help="enable gradient checkpointing / grandient checkpointingを有効にする" + "--gradient_checkpointing", action="store_true", help="enable gradient checkpointing / gradient checkpointingを有効にする" ) parser.add_argument( "--gradient_accumulation_steps", From 3c8193f64269fff68d16c1f38dedfde8715f70bb Mon Sep 17 00:00:00 2001 From: Kohya S Date: Sun, 12 May 2024 17:00:51 +0900 Subject: [PATCH 33/97] revert lora+ for lora_fa --- networks/lora_fa.py | 104 +++++++++++--------------------------------- 1 file changed, 25 insertions(+), 79 deletions(-) diff --git a/networks/lora_fa.py b/networks/lora_fa.py index 58bcb2206..919222ce8 100644 --- a/networks/lora_fa.py +++ b/networks/lora_fa.py @@ -15,10 +15,8 @@ import torch import re from library.utils import setup_logging - setup_logging() import logging - logger = logging.getLogger(__name__) RE_UPDOWN = re.compile(r"(up|down)_blocks_(\d+)_(resnets|upsamplers|downsamplers|attentions)_(\d+)_") @@ -506,15 +504,6 @@ def create_network( if up_lr_weight is not None or mid_lr_weight is not None or down_lr_weight is not None: network.set_block_lr_weight(up_lr_weight, mid_lr_weight, down_lr_weight) - loraplus_lr_ratio = kwargs.get("loraplus_lr_ratio", None) - loraplus_unet_lr_ratio = kwargs.get("loraplus_unet_lr_ratio", None) - loraplus_text_encoder_lr_ratio = kwargs.get("loraplus_text_encoder_lr_ratio", None) - loraplus_lr_ratio = float(loraplus_lr_ratio) if loraplus_lr_ratio is not None else None - loraplus_unet_lr_ratio = float(loraplus_unet_lr_ratio) if loraplus_unet_lr_ratio is not None else None - loraplus_text_encoder_lr_ratio = float(loraplus_text_encoder_lr_ratio) if loraplus_text_encoder_lr_ratio is not None else None - if loraplus_lr_ratio is not None or loraplus_unet_lr_ratio is not None or loraplus_text_encoder_lr_ratio is not None: - network.set_loraplus_lr_ratio(loraplus_lr_ratio, loraplus_unet_lr_ratio, loraplus_text_encoder_lr_ratio) - return network @@ -540,9 +529,7 @@ def parse_floats(s): len(block_dims) == num_total_blocks ), f"block_dims must have {num_total_blocks} elements / block_dimsは{num_total_blocks}個指定してください" else: - logger.warning( - f"block_dims is not specified. all dims are set to {network_dim} / block_dimsが指定されていません。すべてのdimは{network_dim}になります" - ) + logger.warning(f"block_dims is not specified. all dims are set to {network_dim} / block_dimsが指定されていません。すべてのdimは{network_dim}になります") block_dims = [network_dim] * num_total_blocks if block_alphas is not None: @@ -816,17 +803,11 @@ def __init__( self.rank_dropout = rank_dropout self.module_dropout = module_dropout - self.loraplus_lr_ratio = None - self.loraplus_unet_lr_ratio = None - self.loraplus_text_encoder_lr_ratio = None - if modules_dim is not None: logger.info(f"create LoRA network from weights") elif block_dims is not None: logger.info(f"create LoRA network from block_dims") - logger.info( - f"neuron dropout: p={self.dropout}, rank dropout: p={self.rank_dropout}, module dropout: p={self.module_dropout}" - ) + logger.info(f"neuron dropout: p={self.dropout}, rank dropout: p={self.rank_dropout}, module dropout: p={self.module_dropout}") logger.info(f"block_dims: {block_dims}") logger.info(f"block_alphas: {block_alphas}") if conv_block_dims is not None: @@ -834,13 +815,9 @@ def __init__( logger.info(f"conv_block_alphas: {conv_block_alphas}") else: logger.info(f"create LoRA network. 
base dim (rank): {lora_dim}, alpha: {alpha}") - logger.info( - f"neuron dropout: p={self.dropout}, rank dropout: p={self.rank_dropout}, module dropout: p={self.module_dropout}" - ) + logger.info(f"neuron dropout: p={self.dropout}, rank dropout: p={self.rank_dropout}, module dropout: p={self.module_dropout}") if self.conv_lora_dim is not None: - logger.info( - f"apply LoRA to Conv2d with kernel size (3,3). dim (rank): {self.conv_lora_dim}, alpha: {self.conv_alpha}" - ) + logger.info(f"apply LoRA to Conv2d with kernel size (3,3). dim (rank): {self.conv_lora_dim}, alpha: {self.conv_alpha}") # create module instances def create_modules( @@ -962,11 +939,6 @@ def create_modules( assert lora.lora_name not in names, f"duplicated lora name: {lora.lora_name}" names.add(lora.lora_name) - def set_loraplus_lr_ratio(self, loraplus_lr_ratio, loraplus_unet_lr_ratio, loraplus_text_encoder_lr_ratio): - self.loraplus_lr_ratio = loraplus_lr_ratio - self.loraplus_unet_lr_ratio = loraplus_unet_lr_ratio - self.loraplus_text_encoder_lr_ratio = loraplus_text_encoder_lr_ratio - def set_multiplier(self, multiplier): self.multiplier = multiplier for lora in self.text_encoder_loras + self.unet_loras: @@ -1065,42 +1037,18 @@ def prepare_optimizer_params(self, text_encoder_lr, unet_lr, default_lr): self.requires_grad_(True) all_params = [] - def assemble_params(loras, lr, ratio): - param_groups = {"lora": {}, "plus": {}} - for lora in loras: - for name, param in lora.named_parameters(): - if ratio is not None and "lora_up" in name: - param_groups["plus"][f"{lora.lora_name}.{name}"] = param - else: - param_groups["lora"][f"{lora.lora_name}.{name}"] = param - + def enumerate_params(loras: List[LoRAModule]): params = [] - for key in param_groups.keys(): - param_data = {"params": param_groups[key].values()} - - if len(param_data["params"]) == 0: - continue - - if lr is not None: - if key == "plus": - param_data["lr"] = lr * ratio - else: - param_data["lr"] = lr - - if param_data.get("lr", None) == 0 or param_data.get("lr", None) is None: - continue - - params.append(param_data) - + for lora in loras: + # params.extend(lora.parameters()) + params.extend(lora.get_trainable_params()) return params if self.text_encoder_loras: - params = assemble_params( - self.text_encoder_loras, - text_encoder_lr if text_encoder_lr is not None else default_lr, - self.loraplus_text_encoder_lr_ratio or self.loraplus_ratio, - ) - all_params.extend(params) + param_data = {"params": enumerate_params(self.text_encoder_loras)} + if text_encoder_lr is not None: + param_data["lr"] = text_encoder_lr + all_params.append(param_data) if self.unet_loras: if self.block_lr: @@ -1114,20 +1062,21 @@ def assemble_params(loras, lr, ratio): # blockごとにパラメータを設定する for idx, block_loras in block_idx_to_lora.items(): - params = assemble_params( - block_loras, - (unet_lr if unet_lr is not None else default_lr) * self.get_lr_weight(block_loras[0]), - self.loraplus_unet_lr_ratio or self.loraplus_ratio, - ) - all_params.extend(params) + param_data = {"params": enumerate_params(block_loras)} + + if unet_lr is not None: + param_data["lr"] = unet_lr * self.get_lr_weight(block_loras[0]) + elif default_lr is not None: + param_data["lr"] = default_lr * self.get_lr_weight(block_loras[0]) + if ("lr" in param_data) and (param_data["lr"] == 0): + continue + all_params.append(param_data) else: - params = assemble_params( - self.unet_loras, - unet_lr if unet_lr is not None else default_lr, - self.loraplus_unet_lr_ratio or self.loraplus_ratio, - ) - all_params.extend(params) + 
param_data = {"params": enumerate_params(self.unet_loras)} + if unet_lr is not None: + param_data["lr"] = unet_lr + all_params.append(param_data) return all_params @@ -1144,9 +1093,6 @@ def on_epoch_start(self, text_encoder, unet): def get_trainable_params(self): return self.parameters() - def get_trainable_named_params(self): - return self.named_parameters() - def save_weights(self, file, dtype, metadata): if metadata is not None and len(metadata) == 0: metadata = None From 44190416c6389d9ae9ffb18c28744be1259fc02c Mon Sep 17 00:00:00 2001 From: Kohya S Date: Sun, 12 May 2024 17:01:20 +0900 Subject: [PATCH 34/97] update docs etc. --- README.md | 26 ++++++++++++++++++++++++-- docs/train_network_README-ja.md | 11 +++++++---- networks/lora.py | 2 +- 3 files changed, 32 insertions(+), 7 deletions(-) diff --git a/README.md b/README.md index 9c7ecad99..b10da0f23 100644 --- a/README.md +++ b/README.md @@ -154,7 +154,18 @@ The majority of scripts is licensed under ASL 2.0 (including codes from Diffuser - `--fused_optimizer_groups` cannot be used with `--fused_backward_pass`. When using AdaFactor, the memory usage is slightly larger than with Fused optimizer. PyTorch 2.1 or later is required. - Mechanism: While Fused optimizer performs backward/step for individual parameters within the optimizer, optimizer groups reduce memory usage by grouping parameters and creating multiple optimizers to perform backward/step for each group. Fused optimizer requires implementation on the optimizer side, while optimizer groups are implemented only on the training script side. -- Fixed some bugs when using DeepSpeed. Related [#1247] +- LoRA+ is supported. PR [#1233](https://github.com/kohya-ss/sd-scripts/pull/1233) Thanks to rockerBOO! + - LoRA+ is a method to improve training speed by increasing the learning rate of the UP side (LoRA-B) of LoRA. Specify the multiple. The original paper recommends 16, but adjust as needed. Please see the PR for details. + - Specify `loraplus_lr_ratio` with `--network_args`. Example: `--network_args "loraplus_lr_ratio=16"` + - `loraplus_unet_lr_ratio` and `loraplus_lr_ratio` can be specified separately for U-Net and Text Encoder. + - Example: `--network_args "loraplus_unet_lr_ratio=16" "loraplus_text_encoder_lr_ratio=4"` or `--network_args "loraplus_lr_ratio=16" "loraplus_text_encoder_lr_ratio=4"` etc. + - `network_module` `networks.lora` and `networks.dylora` are available. + +- LoRA training in SDXL now supports block-wise learning rates and block-wise dim (rank). PR [#1331](https://github.com/kohya-ss/sd-scripts/pull/1331) + - Specify the learning rate and dim (rank) for each block. + - See [Block-wise learning rates in LoRA](./docs/train_network_README-ja.md#階層別学習率) for details (Japanese only). + +- Fixed some bugs when using DeepSpeed. 
Related [#1247](https://github.com/kohya-ss/sd-scripts/pull/1247) - SDXL の学習時に Fused optimizer が使えるようになりました。PR [#1259](https://github.com/kohya-ss/sd-scripts/pull/1259) 2kpr 氏に感謝します。 - optimizer の backward pass に step を統合することで学習時のメモリ使用量を大きく削減します。学習結果は未適用時と同一ですが、メモリが潤沢にある場合は速度は遅くなります。 @@ -171,7 +182,18 @@ The majority of scripts is licensed under ASL 2.0 (including codes from Diffuser - `--fused_optimizer_groups` は `--fused_backward_pass` と併用できません。AdaFactor 使用時は Fused optimizer よりも若干メモリ使用量は大きくなります。PyTorch 2.1 以降が必要です。 - 仕組み:Fused optimizer が optimizer 内で個別のパラメータについて backward/step を行っているのに対して、optimizer groups はパラメータをグループ化して複数の optimizer を作成し、それぞれ backward/step を行うことでメモリ使用量を削減します。Fused optimizer は optimizer 側の実装が必要ですが、optimizer groups は学習スクリプト側のみで実装されています。やはり SDXL の学習でのみ効果があります。 -- DeepSpeed 使用時のいくつかのバグを修正しました。関連 [#1247] +- LoRA+ がサポートされました。PR [#1233](https://github.com/kohya-ss/sd-scripts/pull/1233) rockerBOO 氏に感謝します。 + - LoRA の UP 側(LoRA-B)の学習率を上げることで学習速度の向上を図る手法です。倍数で指定します。元の論文では 16 が推奨されていますが、データセット等にもよりますので、適宜調整してください。PR もあわせてご覧ください。 + - `--network_args` で `loraplus_lr_ratio` を指定します。例:`--network_args "loraplus_lr_ratio=16"` + - `loraplus_unet_lr_ratio` と `loraplus_lr_ratio` で、U-Net および Text Encoder に個別の値を指定することも可能です。 + - 例:`--network_args "loraplus_unet_lr_ratio=16" "loraplus_text_encoder_lr_ratio=4"` または `--network_args "loraplus_lr_ratio=16" "loraplus_text_encoder_lr_ratio=4"` など + - `network_module` の `networks.lora` および `networks.dylora` で使用可能です。 + +- SDXL の LoRA で階層別学習率、階層別 dim (rank) をサポートしました。PR [#1331](https://github.com/kohya-ss/sd-scripts/pull/1331) + - ブロックごとに学習率および dim (rank) を指定することができます。 + - 詳細は [LoRA の階層別学習率](./docs/train_network_README-ja.md#階層別学習率) をご覧ください。 + +- DeepSpeed 使用時のいくつかのバグを修正しました。関連 [#1247](https://github.com/kohya-ss/sd-scripts/pull/1247) ### Apr 7, 2024 / 2024-04-07: v0.8.7 diff --git a/docs/train_network_README-ja.md b/docs/train_network_README-ja.md index 2205a7736..46085117c 100644 --- a/docs/train_network_README-ja.md +++ b/docs/train_network_README-ja.md @@ -181,16 +181,16 @@ python networks\extract_lora_from_dylora.py --model "foldername/dylora-model.saf 詳細は[PR #355](https://github.com/kohya-ss/sd-scripts/pull/355) をご覧ください。 -SDXLは現在サポートしていません。 - フルモデルの25個のブロックの重みを指定できます。最初のブロックに該当するLoRAは存在しませんが、階層別LoRA適用等との互換性のために25個としています。またconv2d3x3に拡張しない場合も一部のブロックにはLoRAが存在しませんが、記述を統一するため常に25個の値を指定してください。 +SDXL では down/up 9 個、middle 3 個の値を指定してください。 + `--network_args` で以下の引数を指定してください。 - `down_lr_weight` : U-Netのdown blocksの学習率の重みを指定します。以下が指定可能です。 - - ブロックごとの重み : `"down_lr_weight=0,0,0,0,0,0,1,1,1,1,1,1"` のように12個の数値を指定します。 + - ブロックごとの重み : `"down_lr_weight=0,0,0,0,0,0,1,1,1,1,1,1"` のように12個(SDXL では 9 個)の数値を指定します。 - プリセットからの指定 : `"down_lr_weight=sine"` のように指定します(サインカーブで重みを指定します)。sine, cosine, linear, reverse_linear, zeros が指定可能です。また `"down_lr_weight=cosine+.25"` のように `+数値` を追加すると、指定した数値を加算します(0.25~1.25になります)。 -- `mid_lr_weight` : U-Netのmid blockの学習率の重みを指定します。`"down_lr_weight=0.5"` のように数値を一つだけ指定します。 +- `mid_lr_weight` : U-Netのmid blockの学習率の重みを指定します。`"down_lr_weight=0.5"` のように数値を一つだけ指定します(SDXL の場合は 3 個)。 - `up_lr_weight` : U-Netのup blocksの学習率の重みを指定します。down_lr_weightと同様です。 - 指定を省略した部分は1.0として扱われます。また重みを0にするとそのブロックのLoRAモジュールは作成されません。 - `block_lr_zero_threshold` : 重みがこの値以下の場合、LoRAモジュールを作成しません。デフォルトは0です。 @@ -215,6 +215,9 @@ network_args = [ "block_lr_zero_threshold=0.1", "down_lr_weight=sine+.5", "mid_l フルモデルの25個のブロックのdim (rank)を指定できます。階層別学習率と同様に一部のブロックにはLoRAが存在しない場合がありますが、常に25個の値を指定してください。 +SDXL では 23 個の値を指定してください。一部のブロックにはLoRA が存在しませんが、`sdxl_train.py` 
の[階層別学習率](./train_SDXL-en.md) との互換性のためです。 +対応は、`0: time/label embed, 1-9: input blocks 0-8, 10-12: mid blocks 0-2, 13-21: output blocks 0-8, 22: out` です。 + `--network_args` で以下の引数を指定してください。 - `block_dims` : 各ブロックのdim (rank)を指定します。`"block_dims=2,2,2,2,4,4,4,4,6,6,6,6,8,6,6,6,6,4,4,4,4,2,2,2,2"` のように25個の数値を指定します。 diff --git a/networks/lora.py b/networks/lora.py index 6e5645577..00d21b0ed 100644 --- a/networks/lora.py +++ b/networks/lora.py @@ -985,7 +985,7 @@ def create_modules( skipped = skipped_te + skipped_un if varbose and len(skipped) > 0: - logger.warn( + logger.warning( f"because block_lr_weight is 0 or dim (rank) is 0, {len(skipped)} LoRA modules are skipped / block_lr_weightまたはdim (rank)が0の為、次の{len(skipped)}個のLoRAモジュールはスキップされます:" ) for name in skipped: From 9ddb4d7a0138722913f6f1a6f1bf30f7ff89bb5b Mon Sep 17 00:00:00 2001 From: Kohya S Date: Sun, 12 May 2024 17:55:08 +0900 Subject: [PATCH 35/97] update readme and help message etc. --- README.md | 8 ++++++++ library/sdxl_model_util.py | 6 ++++-- library/sdxl_train_util.py | 6 +++++- 3 files changed, 17 insertions(+), 3 deletions(-) diff --git a/README.md b/README.md index b10da0f23..ed91d6d7b 100644 --- a/README.md +++ b/README.md @@ -165,6 +165,10 @@ The majority of scripts is licensed under ASL 2.0 (including codes from Diffuser - Specify the learning rate and dim (rank) for each block. - See [Block-wise learning rates in LoRA](./docs/train_network_README-ja.md#階層別学習率) for details (Japanese only). +- An option `--disable_mmap_load_safetensors` is added to disable memory mapping when loading the model's .safetensors in SDXL. PR [#1266](https://github.com/kohya-ss/sd-scripts/pull/1266) Thanks to Zovjsra! + - It seems that the model file loading is faster in the WSL environment etc. + - Available in `sdxl_train.py`, `sdxl_train_network.py`, `sdxl_train_textual_inversion.py`, and `sdxl_train_control_net_lllite.py`. + - Fixed some bugs when using DeepSpeed. 
Related [#1247](https://github.com/kohya-ss/sd-scripts/pull/1247) - SDXL の学習時に Fused optimizer が使えるようになりました。PR [#1259](https://github.com/kohya-ss/sd-scripts/pull/1259) 2kpr 氏に感謝します。 @@ -193,6 +197,10 @@ The majority of scripts is licensed under ASL 2.0 (including codes from Diffuser - ブロックごとに学習率および dim (rank) を指定することができます。 - 詳細は [LoRA の階層別学習率](./docs/train_network_README-ja.md#階層別学習率) をご覧ください。 +- SDXL でモデルの .safetensors を読み込む際にメモリマッピングを無効化するオプション `--disable_mmap_load_safetensors` が追加されました。PR [#1266](https://github.com/kohya-ss/sd-scripts/pull/1266) Zovjsra 氏に感謝します。 + - WSL 環境等でモデルファイルの読み込みが高速化されるようです。 + - `sdxl_train.py`、`sdxl_train_network.py`、`sdxl_train_textual_inversion.py`、`sdxl_train_control_net_lllite.py` で使用可能です。 + - DeepSpeed 使用時のいくつかのバグを修正しました。関連 [#1247](https://github.com/kohya-ss/sd-scripts/pull/1247) diff --git a/library/sdxl_model_util.py b/library/sdxl_model_util.py index e6fcb1f9c..4fad78a1c 100644 --- a/library/sdxl_model_util.py +++ b/library/sdxl_model_util.py @@ -9,8 +9,10 @@ from library import model_util from library import sdxl_original_unet from .utils import setup_logging + setup_logging() import logging + logger = logging.getLogger(__name__) VAE_SCALE_FACTOR = 0.13025 @@ -171,8 +173,8 @@ def load_models_from_sdxl_checkpoint(model_version, ckpt_path, map_location, dty # Load the state dict if model_util.is_safetensors(ckpt_path): checkpoint = None - if(disable_mmap): - state_dict = safetensors.torch.load(open(ckpt_path, 'rb').read()) + if disable_mmap: + state_dict = safetensors.torch.load(open(ckpt_path, "rb").read()) else: try: state_dict = load_file(ckpt_path, device=map_location) diff --git a/library/sdxl_train_util.py b/library/sdxl_train_util.py index 106c5b455..b74bea91a 100644 --- a/library/sdxl_train_util.py +++ b/library/sdxl_train_util.py @@ -5,6 +5,7 @@ import torch from library.device_utils import init_ipex, clean_memory_on_device + init_ipex() from accelerate import init_empty_weights @@ -13,8 +14,10 @@ from library import model_util, sdxl_model_util, train_util, sdxl_original_unet from library.sdxl_lpw_stable_diffusion import SdxlStableDiffusionLongPromptWeightingPipeline from .utils import setup_logging + setup_logging() import logging + logger = logging.getLogger(__name__) TOKENIZER1_PATH = "openai/clip-vit-large-patch14" @@ -44,7 +47,7 @@ def load_target_model(args, accelerator, model_version: str, weight_dtype): weight_dtype, accelerator.device if args.lowram else "cpu", model_dtype, - args.disable_mmap_load_safetensors + args.disable_mmap_load_safetensors, ) # work on low-ram device @@ -336,6 +339,7 @@ def add_sdxl_training_arguments(parser: argparse.ArgumentParser): parser.add_argument( "--disable_mmap_load_safetensors", action="store_true", + help="disable mmap load for safetensors. 
Speed up model loading in WSL environment / safetensorsのmmapロードを無効にする。WSL環境等でモデル読み込みを高速化できる", ) From 3701507874c920e09e402980363702a91a67da3d Mon Sep 17 00:00:00 2001 From: Kohya S Date: Sun, 12 May 2024 20:56:56 +0900 Subject: [PATCH 36/97] raise original error if error is occured in checking latents --- library/train_util.py | 5 ++--- 1 file changed, 2 insertions(+), 3 deletions(-) diff --git a/library/train_util.py b/library/train_util.py index d157cdbcd..8a69f0bef 100644 --- a/library/train_util.py +++ b/library/train_util.py @@ -2136,9 +2136,8 @@ def is_disk_cached_latents_is_expected(reso, npz_path: str, flip_aug: bool): if npz["latents_flipped"].shape[1:3] != expected_latents_size: return False except Exception as e: - print(npz_path) - print(e) - return False + logger.error(f"Error loading file: {npz_path}") + raise e return True From 39b82f26e5f9df6518a4e32f4b91b4c46cc667fb Mon Sep 17 00:00:00 2001 From: Kohya S Date: Sun, 12 May 2024 20:58:45 +0900 Subject: [PATCH 37/97] update readme --- README.md | 4 ++++ 1 file changed, 4 insertions(+) diff --git a/README.md b/README.md index ed91d6d7b..245853415 100644 --- a/README.md +++ b/README.md @@ -169,6 +169,8 @@ The majority of scripts is licensed under ASL 2.0 (including codes from Diffuser - It seems that the model file loading is faster in the WSL environment etc. - Available in `sdxl_train.py`, `sdxl_train_network.py`, `sdxl_train_textual_inversion.py`, and `sdxl_train_control_net_lllite.py`. +- When there is an error in the cached latents file on disk, the file name is now displayed. PR [#1278](https://github.com/kohya-ss/sd-scripts/pull/1278) Thanks to Cauldrath! + - Fixed some bugs when using DeepSpeed. Related [#1247](https://github.com/kohya-ss/sd-scripts/pull/1247) - SDXL の学習時に Fused optimizer が使えるようになりました。PR [#1259](https://github.com/kohya-ss/sd-scripts/pull/1259) 2kpr 氏に感謝します。 @@ -201,6 +203,8 @@ The majority of scripts is licensed under ASL 2.0 (including codes from Diffuser - WSL 環境等でモデルファイルの読み込みが高速化されるようです。 - `sdxl_train.py`、`sdxl_train_network.py`、`sdxl_train_textual_inversion.py`、`sdxl_train_control_net_lllite.py` で使用可能です。 +- ディスクにキャッシュされた latents ファイルに何らかのエラーがあったとき、そのファイル名が表示されるようになりました。 PR [#1278](https://github.com/kohya-ss/sd-scripts/pull/1278) Cauldrath 氏に感謝します。 + - DeepSpeed 使用時のいくつかのバグを修正しました。関連 [#1247](https://github.com/kohya-ss/sd-scripts/pull/1247) From 16677da0d90ad9094a0301990b831a8dd6c0e957 Mon Sep 17 00:00:00 2001 From: Kohya S Date: Sun, 12 May 2024 22:15:07 +0900 Subject: [PATCH 38/97] fix create_network_from_weights doesn't work --- networks/lora.py | 5 ++++- 1 file changed, 4 insertions(+), 1 deletion(-) diff --git a/networks/lora.py b/networks/lora.py index 00d21b0ed..79dc6ec07 100644 --- a/networks/lora.py +++ b/networks/lora.py @@ -757,6 +757,9 @@ def get_block_index(lora_name: str, is_sdxl: bool = False) -> int: # Create network from weights for inference, weights are not loaded here (because can be merged) def create_network_from_weights(multiplier, file, vae, text_encoder, unet, weights_sd=None, for_inference=False, **kwargs): + # if unet is an instance of SdxlUNet2DConditionModel or subclass, set is_sdxl to True + is_sdxl = unet is not None and issubclass(unet.__class__, SdxlUNet2DConditionModel) + if weights_sd is None: if os.path.splitext(file)[1] == ".safetensors": from safetensors.torch import load_file, safe_open @@ -792,7 +795,7 @@ def create_network_from_weights(multiplier, file, vae, text_encoder, unet, weigh ) # block lr - block_lr_weight = parse_block_lr_kwargs(kwargs) + 
block_lr_weight = parse_block_lr_kwargs(is_sdxl, kwargs) if block_lr_weight is not None: network.set_block_lr_weight(block_lr_weight) From 589c2aa025d277497de32c2ceb8a9e76f4ca4bf2 Mon Sep 17 00:00:00 2001 From: Kohya S Date: Mon, 13 May 2024 21:20:37 +0900 Subject: [PATCH 39/97] update README --- README.md | 12 ++++++++++++ 1 file changed, 12 insertions(+) diff --git a/README.md b/README.md index 245853415..9d042a41b 100644 --- a/README.md +++ b/README.md @@ -171,6 +171,12 @@ The majority of scripts is licensed under ASL 2.0 (including codes from Diffuser - When there is an error in the cached latents file on disk, the file name is now displayed. PR [#1278](https://github.com/kohya-ss/sd-scripts/pull/1278) Thanks to Cauldrath! +- Fixed an error that occurs when specifying `--max_dataloader_n_workers` in `tag_images_by_wd14_tagger.py` when Onnx is not used. PR [#1291]( +https://github.com/kohya-ss/sd-scripts/pull/1291) issue [#1290]( +https://github.com/kohya-ss/sd-scripts/pull/1290) Thanks to frodo821! + +- Fixed a bug that `caption_separator` cannot be specified in the subset in the dataset settings .toml file. [#1312](https://github.com/kohya-ss/sd-scripts/pull/1312) and [#1313](https://github.com/kohya-ss/sd-scripts/pull/1312) Thanks to rockerBOO! + - Fixed some bugs when using DeepSpeed. Related [#1247](https://github.com/kohya-ss/sd-scripts/pull/1247) - SDXL の学習時に Fused optimizer が使えるようになりました。PR [#1259](https://github.com/kohya-ss/sd-scripts/pull/1259) 2kpr 氏に感謝します。 @@ -205,6 +211,12 @@ The majority of scripts is licensed under ASL 2.0 (including codes from Diffuser - ディスクにキャッシュされた latents ファイルに何らかのエラーがあったとき、そのファイル名が表示されるようになりました。 PR [#1278](https://github.com/kohya-ss/sd-scripts/pull/1278) Cauldrath 氏に感謝します。 +- `tag_images_by_wd14_tagger.py` で Onnx 未使用時に `--max_dataloader_n_workers` を指定するとエラーになる不具合が修正されました。 PR [#1291]( +https://github.com/kohya-ss/sd-scripts/pull/1291) issue [#1290]( +https://github.com/kohya-ss/sd-scripts/pull/1290) frodo821 氏に感謝します。 + +- データセット設定の .toml ファイルで、`caption_separator` が subset に指定できない不具合が修正されました。 PR [#1312](https://github.com/kohya-ss/sd-scripts/pull/1312) および [#1313](https://github.com/kohya-ss/sd-scripts/pull/1312) rockerBOO 氏に感謝します。 + - DeepSpeed 使用時のいくつかのバグを修正しました。関連 [#1247](https://github.com/kohya-ss/sd-scripts/pull/1247) From 153764a687d7553866335554d2b35ba89a123297 Mon Sep 17 00:00:00 2001 From: Kohya S Date: Wed, 15 May 2024 20:21:49 +0900 Subject: [PATCH 40/97] add prompt option '--f' for filename --- README.md | 3 +++ gen_img.py | 55 +++++++++++++++++++++++++++++++++++++++--------------- 2 files changed, 43 insertions(+), 15 deletions(-) diff --git a/README.md b/README.md index 9d042a41b..52d801217 100644 --- a/README.md +++ b/README.md @@ -179,6 +179,8 @@ https://github.com/kohya-ss/sd-scripts/pull/1290) Thanks to frodo821! - Fixed some bugs when using DeepSpeed. Related [#1247](https://github.com/kohya-ss/sd-scripts/pull/1247) +- Added a prompt option `--f` to `gen_imgs.py` to specify the file name when saving. 
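+  - For example, appending ` --f my_image_001.png` to a prompt line saves that image as `my_image_001.png` in `--outdir` (a `.webp` extension switches the output to WebP).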
+ - SDXL の学習時に Fused optimizer が使えるようになりました。PR [#1259](https://github.com/kohya-ss/sd-scripts/pull/1259) 2kpr 氏に感謝します。 - optimizer の backward pass に step を統合することで学習時のメモリ使用量を大きく削減します。学習結果は未適用時と同一ですが、メモリが潤沢にある場合は速度は遅くなります。 - `sdxl_train.py` に `--fused_backward_pass` オプションを指定してください。現時点では optimizer は AdaFactor のみ対応しています。また gradient accumulation は使えません。 @@ -219,6 +221,7 @@ https://github.com/kohya-ss/sd-scripts/pull/1290) frodo821 氏に感謝します - DeepSpeed 使用時のいくつかのバグを修正しました。関連 [#1247](https://github.com/kohya-ss/sd-scripts/pull/1247) +- `gen_imgs.py` のプロンプトオプションに、保存時のファイル名を指定する `--f` オプションを追加しました。 ### Apr 7, 2024 / 2024-04-07: v0.8.7 diff --git a/gen_img.py b/gen_img.py index 4fe898716..d0a8f8141 100644 --- a/gen_img.py +++ b/gen_img.py @@ -1435,6 +1435,7 @@ class BatchDataBase(NamedTuple): clip_prompt: str guide_image: Any raw_prompt: str + file_name: Optional[str] class BatchDataExt(NamedTuple): @@ -2316,7 +2317,7 @@ def scale_and_round(x): # このバッチの情報を取り出す ( return_latents, - (step_first, _, _, _, init_image, mask_image, _, guide_image, _), + (step_first, _, _, _, init_image, mask_image, _, guide_image, _, _), ( width, height, @@ -2339,6 +2340,7 @@ def scale_and_round(x): prompts = [] negative_prompts = [] raw_prompts = [] + filenames = [] start_code = torch.zeros((batch_size, *noise_shape), device=device, dtype=dtype) noises = [ torch.zeros((batch_size, *noise_shape), device=device, dtype=dtype) @@ -2371,7 +2373,7 @@ def scale_and_round(x): all_guide_images_are_same = True for i, ( _, - (_, prompt, negative_prompt, seed, init_image, mask_image, clip_prompt, guide_image, raw_prompt), + (_, prompt, negative_prompt, seed, init_image, mask_image, clip_prompt, guide_image, raw_prompt, filename), _, ) in enumerate(batch): prompts.append(prompt) @@ -2379,6 +2381,7 @@ def scale_and_round(x): seeds.append(seed) clip_prompts.append(clip_prompt) raw_prompts.append(raw_prompt) + filenames.append(filename) if init_image is not None: init_images.append(init_image) @@ -2478,8 +2481,8 @@ def scale_and_round(x): # save image highres_prefix = ("0" if highres_1st else "1") if highres_fix else "" ts_str = time.strftime("%Y%m%d%H%M%S", time.localtime()) - for i, (image, prompt, negative_prompts, seed, clip_prompt, raw_prompt) in enumerate( - zip(images, prompts, negative_prompts, seeds, clip_prompts, raw_prompts) + for i, (image, prompt, negative_prompts, seed, clip_prompt, raw_prompt, filename) in enumerate( + zip(images, prompts, negative_prompts, seeds, clip_prompts, raw_prompts, filenames) ): if highres_fix: seed -= 1 # record original seed @@ -2505,17 +2508,23 @@ def scale_and_round(x): metadata.add_text("crop-top", str(crop_top)) metadata.add_text("crop-left", str(crop_left)) - if args.use_original_file_name and init_images is not None: - if type(init_images) is list: - fln = os.path.splitext(os.path.basename(init_images[i % len(init_images)].filename))[0] + ".png" - else: - fln = os.path.splitext(os.path.basename(init_images.filename))[0] + ".png" - elif args.sequential_file_name: - fln = f"im_{highres_prefix}{step_first + i + 1:06d}.png" + if filename is not None: + fln = filename else: - fln = f"im_{ts_str}_{highres_prefix}{i:03d}_{seed}.png" + if args.use_original_file_name and init_images is not None: + if type(init_images) is list: + fln = os.path.splitext(os.path.basename(init_images[i % len(init_images)].filename))[0] + ".png" + else: + fln = os.path.splitext(os.path.basename(init_images.filename))[0] + ".png" + elif args.sequential_file_name: + fln = f"im_{highres_prefix}{step_first + i + 1:06d}.png" + 
else: + fln = f"im_{ts_str}_{highres_prefix}{i:03d}_{seed}.png" - image.save(os.path.join(args.outdir, fln), pnginfo=metadata) + if fln.endswith(".webp"): + image.save(os.path.join(args.outdir, fln), pnginfo=metadata, quality=100) # lossy + else: + image.save(os.path.join(args.outdir, fln), pnginfo=metadata) if not args.no_preview and not highres_1st and args.interactive: try: @@ -2562,6 +2571,7 @@ def scale_and_round(x): # repeat prompt for pi in range(args.images_per_prompt if len(raw_prompts) == 1 else len(raw_prompts)): raw_prompt = raw_prompts[pi] if len(raw_prompts) > 1 else raw_prompts[0] + filename = None if pi == 0 or len(raw_prompts) > 1: # parse prompt: if prompt is not changed, skip parsing @@ -2783,6 +2793,12 @@ def scale_and_round(x): logger.info(f"gradual latent unsharp params: {gl_unsharp_params}") continue + m = re.match(r"f (.+)", parg, re.IGNORECASE) + if m: # filename + filename = m.group(1) + logger.info(f"filename: {filename}") + continue + except ValueError as ex: logger.error(f"Exception in parsing / 解析エラー: {parg}") logger.error(f"{ex}") @@ -2873,7 +2889,16 @@ def scale_and_round(x): b1 = BatchData( False, BatchDataBase( - global_step, prompt, negative_prompt, seed, init_image, mask_image, clip_prompt, guide_image, raw_prompt + global_step, + prompt, + negative_prompt, + seed, + init_image, + mask_image, + clip_prompt, + guide_image, + raw_prompt, + filename, ), BatchDataExt( width, @@ -2916,7 +2941,7 @@ def setup_parser() -> argparse.ArgumentParser: parser = argparse.ArgumentParser() add_logging_arguments(parser) - + parser.add_argument( "--sdxl", action="store_true", help="load Stable Diffusion XL model / Stable Diffusion XLのモデルを読み込む" ) From 146edce6934beee050d8e73458dad794449a0ff4 Mon Sep 17 00:00:00 2001 From: Kohya S Date: Sat, 18 May 2024 11:05:04 +0900 Subject: [PATCH 41/97] support Diffusers' based SDXL LoRA key for inference --- networks/lora.py | 49 ++++++++++++++++++++++++++++++++++++++++++++++++ 1 file changed, 49 insertions(+) diff --git a/networks/lora.py b/networks/lora.py index 79dc6ec07..9f159f5db 100644 --- a/networks/lora.py +++ b/networks/lora.py @@ -755,6 +755,52 @@ def get_block_index(lora_name: str, is_sdxl: bool = False) -> int: return block_idx +def convert_diffusers_to_sai_if_needed(weights_sd): + # only supports U-Net LoRA modules + + found_up_down_blocks = False + for k in list(weights_sd.keys()): + if "down_blocks" in k: + found_up_down_blocks = True + break + if "up_blocks" in k: + found_up_down_blocks = True + break + if not found_up_down_blocks: + return + + from library.sdxl_model_util import make_unet_conversion_map + + unet_conversion_map = make_unet_conversion_map() + unet_conversion_map = {hf.replace(".", "_")[:-1]: sd.replace(".", "_")[:-1] for sd, hf in unet_conversion_map} + + # # add extra conversion + # unet_conversion_map["up_blocks_1_upsamplers_0"] = "lora_unet_output_blocks_2_2_conv" + + logger.info(f"Converting LoRA keys from Diffusers to SAI") + lora_unet_prefix = "lora_unet_" + for k in list(weights_sd.keys()): + if not k.startswith(lora_unet_prefix): + continue + + unet_module_name = k[len(lora_unet_prefix) :].split(".")[0] + + # search for conversion: this is slow because the algorithm is O(n^2), but the number of keys is small + for hf_module_name, sd_module_name in unet_conversion_map.items(): + if hf_module_name in unet_module_name: + new_key = ( + lora_unet_prefix + + unet_module_name.replace(hf_module_name, sd_module_name) + + k[len(lora_unet_prefix) + len(unet_module_name) :] + ) + weights_sd[new_key] = 
weights_sd.pop(k) + found = True + break + + if not found: + logger.warning(f"Key {k} is not found in unet_conversion_map") + + # Create network from weights for inference, weights are not loaded here (because can be merged) def create_network_from_weights(multiplier, file, vae, text_encoder, unet, weights_sd=None, for_inference=False, **kwargs): # if unet is an instance of SdxlUNet2DConditionModel or subclass, set is_sdxl to True @@ -768,6 +814,9 @@ def create_network_from_weights(multiplier, file, vae, text_encoder, unet, weigh else: weights_sd = torch.load(file, map_location="cpu") + # if keys are Diffusers based, convert to SAI based + convert_diffusers_to_sai_if_needed(weights_sd) + # get dim/alpha mapping modules_dim = {} modules_alpha = {} From 2f19175dfeb98e5ad93a633c79fa846d67210844 Mon Sep 17 00:00:00 2001 From: Kohya S Date: Sun, 19 May 2024 15:38:37 +0900 Subject: [PATCH 42/97] update README --- README.md | 4 ++-- 1 file changed, 2 insertions(+), 2 deletions(-) diff --git a/README.md b/README.md index 52d801217..b9852e0ad 100644 --- a/README.md +++ b/README.md @@ -179,7 +179,7 @@ https://github.com/kohya-ss/sd-scripts/pull/1290) Thanks to frodo821! - Fixed some bugs when using DeepSpeed. Related [#1247](https://github.com/kohya-ss/sd-scripts/pull/1247) -- Added a prompt option `--f` to `gen_imgs.py` to specify the file name when saving. +- Added a prompt option `--f` to `gen_imgs.py` to specify the file name when saving. Also, Diffusers-based keys for LoRA weights are now supported. - SDXL の学習時に Fused optimizer が使えるようになりました。PR [#1259](https://github.com/kohya-ss/sd-scripts/pull/1259) 2kpr 氏に感謝します。 - optimizer の backward pass に step を統合することで学習時のメモリ使用量を大きく削減します。学習結果は未適用時と同一ですが、メモリが潤沢にある場合は速度は遅くなります。 @@ -221,7 +221,7 @@ https://github.com/kohya-ss/sd-scripts/pull/1290) frodo821 氏に感謝します - DeepSpeed 使用時のいくつかのバグを修正しました。関連 [#1247](https://github.com/kohya-ss/sd-scripts/pull/1247) -- `gen_imgs.py` のプロンプトオプションに、保存時のファイル名を指定する `--f` オプションを追加しました。 +- `gen_imgs.py` のプロンプトオプションに、保存時のファイル名を指定する `--f` オプションを追加しました。また同スクリプトで Diffusers ベースのキーを持つ LoRA の重みに対応しました。 ### Apr 7, 2024 / 2024-04-07: v0.8.7 From e3ddd1fbbe4e00f49649f5aabd470b9dccf3019d Mon Sep 17 00:00:00 2001 From: Kohya S Date: Sun, 19 May 2024 16:26:10 +0900 Subject: [PATCH 43/97] update README and format code --- README.md | 4 ++++ sdxl_train_control_net_lllite.py | 9 +++++++-- 2 files changed, 11 insertions(+), 2 deletions(-) diff --git a/README.md b/README.md index b9852e0ad..5d035eb6f 100644 --- a/README.md +++ b/README.md @@ -177,6 +177,8 @@ https://github.com/kohya-ss/sd-scripts/pull/1290) Thanks to frodo821! - Fixed a bug that `caption_separator` cannot be specified in the subset in the dataset settings .toml file. [#1312](https://github.com/kohya-ss/sd-scripts/pull/1312) and [#1313](https://github.com/kohya-ss/sd-scripts/pull/1312) Thanks to rockerBOO! +- Fixed a potential bug in ControlNet-LLLite training. PR [#1322](https://github.com/kohya-ss/sd-scripts/pull/1322) Thanks to aria1th! + - Fixed some bugs when using DeepSpeed. Related [#1247](https://github.com/kohya-ss/sd-scripts/pull/1247) - Added a prompt option `--f` to `gen_imgs.py` to specify the file name when saving. Also, Diffusers-based keys for LoRA weights are now supported. 
@@ -219,6 +221,8 @@ https://github.com/kohya-ss/sd-scripts/pull/1290) frodo821 氏に感謝します - データセット設定の .toml ファイルで、`caption_separator` が subset に指定できない不具合が修正されました。 PR [#1312](https://github.com/kohya-ss/sd-scripts/pull/1312) および [#1313](https://github.com/kohya-ss/sd-scripts/pull/1312) rockerBOO 氏に感謝します。 +- ControlNet-LLLite 学習時の潜在バグが修正されました。 PR [#1322](https://github.com/kohya-ss/sd-scripts/pull/1322) aria1th 氏に感謝します。 + - DeepSpeed 使用時のいくつかのバグを修正しました。関連 [#1247](https://github.com/kohya-ss/sd-scripts/pull/1247) - `gen_imgs.py` のプロンプトオプションに、保存時のファイル名を指定する `--f` オプションを追加しました。また同スクリプトで Diffusers ベースのキーを持つ LoRA の重みに対応しました。 diff --git a/sdxl_train_control_net_lllite.py b/sdxl_train_control_net_lllite.py index 6ad6e763c..09b6d73be 100644 --- a/sdxl_train_control_net_lllite.py +++ b/sdxl_train_control_net_lllite.py @@ -15,6 +15,7 @@ import torch from library.device_utils import init_ipex, clean_memory_on_device + init_ipex() from torch.nn.parallel import DistributedDataParallel as DDP @@ -439,7 +440,9 @@ def remove_model(old_ckpt_name): # Sample noise, sample a random timestep for each image, and add noise to the latents, # with noise offset and/or multires noise if specified - noise, noisy_latents, timesteps, huber_c = train_util.get_noise_noisy_latents_and_timesteps(args, noise_scheduler, latents) + noise, noisy_latents, timesteps, huber_c = train_util.get_noise_noisy_latents_and_timesteps( + args, noise_scheduler, latents + ) noisy_latents = noisy_latents.to(weight_dtype) # TODO check why noisy_latents is not weight_dtype @@ -458,7 +461,9 @@ def remove_model(old_ckpt_name): else: target = noise - loss = train_util.conditional_loss(noise_pred.float(), target.float(), reduction="none", loss_type=args.loss_type, huber_c=huber_c) + loss = train_util.conditional_loss( + noise_pred.float(), target.float(), reduction="none", loss_type=args.loss_type, huber_c=huber_c + ) loss = loss.mean([1, 2, 3]) loss_weights = batch["loss_weights"] # 各sampleごとのweight From c68baae48033fe9794860518fe052dbf8def905e Mon Sep 17 00:00:00 2001 From: Kohya S Date: Sun, 19 May 2024 17:21:04 +0900 Subject: [PATCH 44/97] add `--log_config` option to enable/disable output training config --- README.md | 6 ++++++ fine_tune.py | 20 +++++++++++++++----- library/train_util.py | 16 +++++++++++++--- sdxl_train.py | 2 +- sdxl_train_control_net_lllite.py | 2 +- sdxl_train_control_net_lllite_old.py | 2 +- train_controlnet.py | 2 +- train_db.py | 2 +- train_network.py | 2 +- train_textual_inversion.py | 2 +- train_textual_inversion_XTI.py | 2 +- 11 files changed, 42 insertions(+), 16 deletions(-) diff --git a/README.md b/README.md index 5d035eb6f..cd7744598 100644 --- a/README.md +++ b/README.md @@ -165,6 +165,9 @@ The majority of scripts is licensed under ASL 2.0 (including codes from Diffuser - Specify the learning rate and dim (rank) for each block. - See [Block-wise learning rates in LoRA](./docs/train_network_README-ja.md#階層別学習率) for details (Japanese only). +- Training scripts can now output training settings to wandb or Tensor Board logs. Specify the `--log_config` option. PR [#1285](https://github.com/kohya-ss/sd-scripts/pull/1285) Thanks to ccharest93, plucked, rockerBOO, and VelocityRa! + - Some settings, such as API keys and directory specifications, are not output due to security issues. + - An option `--disable_mmap_load_safetensors` is added to disable memory mapping when loading the model's .safetensors in SDXL. PR [#1266](https://github.com/kohya-ss/sd-scripts/pull/1266) Thanks to Zovjsra! 
- It seems that the model file loading is faster in the WSL environment etc. - Available in `sdxl_train.py`, `sdxl_train_network.py`, `sdxl_train_textual_inversion.py`, and `sdxl_train_control_net_lllite.py`. @@ -209,6 +212,9 @@ https://github.com/kohya-ss/sd-scripts/pull/1290) Thanks to frodo821! - ブロックごとに学習率および dim (rank) を指定することができます。 - 詳細は [LoRA の階層別学習率](./docs/train_network_README-ja.md#階層別学習率) をご覧ください。 +- 各学習スクリプトで学習設定を wandb や Tensor Board などのログに出力できるようになりました。`--log_config` オプションを指定してください。PR [#1285](https://github.com/kohya-ss/sd-scripts/pull/1285) ccharest93 氏、plucked 氏、rockerBOO 氏および VelocityRa 氏に感謝します。 + - API キーや各種ディレクトリ指定など、一部の設定はセキュリティ上の問題があるため出力されません。 + - SDXL でモデルの .safetensors を読み込む際にメモリマッピングを無効化するオプション `--disable_mmap_load_safetensors` が追加されました。PR [#1266](https://github.com/kohya-ss/sd-scripts/pull/1266) Zovjsra 氏に感謝します。 - WSL 環境等でモデルファイルの読み込みが高速化されるようです。 - `sdxl_train.py`、`sdxl_train_network.py`、`sdxl_train_textual_inversion.py`、`sdxl_train_control_net_lllite.py` で使用可能です。 diff --git a/fine_tune.py b/fine_tune.py index 77a1a4f30..d865cd2de 100644 --- a/fine_tune.py +++ b/fine_tune.py @@ -310,7 +310,11 @@ def fn_recursive_set_mem_eff(module: torch.nn.Module): init_kwargs["wandb"] = {"name": args.wandb_run_name} if args.log_tracker_config is not None: init_kwargs = toml.load(args.log_tracker_config) - accelerator.init_trackers("finetuning" if args.log_tracker_name is None else args.log_tracker_name, config=train_util.filter_sensitive_args(args), init_kwargs=init_kwargs) + accelerator.init_trackers( + "finetuning" if args.log_tracker_name is None else args.log_tracker_name, + config=train_util.get_sanitized_config_or_none(args), + init_kwargs=init_kwargs, + ) # For --sample_at_first train_util.sample_images(accelerator, args, 0, global_step, accelerator.device, vae, tokenizer, text_encoder, unet) @@ -354,7 +358,9 @@ def fn_recursive_set_mem_eff(module: torch.nn.Module): # Sample noise, sample a random timestep for each image, and add noise to the latents, # with noise offset and/or multires noise if specified - noise, noisy_latents, timesteps, huber_c = train_util.get_noise_noisy_latents_and_timesteps(args, noise_scheduler, latents) + noise, noisy_latents, timesteps, huber_c = train_util.get_noise_noisy_latents_and_timesteps( + args, noise_scheduler, latents + ) # Predict the noise residual with accelerator.autocast(): @@ -368,7 +374,9 @@ def fn_recursive_set_mem_eff(module: torch.nn.Module): if args.min_snr_gamma or args.scale_v_pred_loss_like_noise_pred or args.debiased_estimation_loss: # do not mean over batch dimension for snr weight or scale v-pred loss - loss = train_util.conditional_loss(noise_pred.float(), target.float(), reduction="none", loss_type=args.loss_type, huber_c=huber_c) + loss = train_util.conditional_loss( + noise_pred.float(), target.float(), reduction="none", loss_type=args.loss_type, huber_c=huber_c + ) loss = loss.mean([1, 2, 3]) if args.min_snr_gamma: @@ -380,7 +388,9 @@ def fn_recursive_set_mem_eff(module: torch.nn.Module): loss = loss.mean() # mean over batch dimension else: - loss = train_util.conditional_loss(noise_pred.float(), target.float(), reduction="mean", loss_type=args.loss_type, huber_c=huber_c) + loss = train_util.conditional_loss( + noise_pred.float(), target.float(), reduction="mean", loss_type=args.loss_type, huber_c=huber_c + ) accelerator.backward(loss) if accelerator.sync_gradients and args.max_grad_norm != 0.0: @@ -471,7 +481,7 @@ def fn_recursive_set_mem_eff(module: torch.nn.Module): accelerator.end_training() - if 
is_main_process and (args.save_state or args.save_state_on_train_end): + if is_main_process and (args.save_state or args.save_state_on_train_end): train_util.save_state_on_train_end(args, accelerator) del accelerator # この後メモリを使うのでこれは消す diff --git a/library/train_util.py b/library/train_util.py index 84764263e..410471470 100644 --- a/library/train_util.py +++ b/library/train_util.py @@ -3180,6 +3180,7 @@ def add_training_arguments(parser: argparse.ArgumentParser, support_dreambooth: default=None, help="specify WandB API key to log in before starting training (optional). / WandB APIキーを指定して学習開始前にログインする(オプション)", ) + parser.add_argument("--log_config", action="store_true", help="log training configuration / 学習設定をログに出力する") parser.add_argument( "--noise_offset", @@ -3388,7 +3389,15 @@ def add_masked_loss_arguments(parser: argparse.ArgumentParser): help="apply mask for calculating loss. conditioning_data_dir is required for dataset. / 損失計算時にマスクを適用する。datasetにはconditioning_data_dirが必要", ) -def filter_sensitive_args(args: argparse.Namespace): + +def get_sanitized_config_or_none(args: argparse.Namespace): + # if `--log_config` is enabled, return args for logging. if not, return None. + # when `--log_config is enabled, filter out sensitive values from args + # if wandb is not enabled, the log is not exposed to the public, but it is fine to filter out sensitive values to be safe + + if not args.log_config: + return None + sensitive_args = ["wandb_api_key", "huggingface_token"] sensitive_path_args = [ "pretrained_model_name_or_path", @@ -3402,9 +3411,9 @@ def filter_sensitive_args(args: argparse.Namespace): ] filtered_args = {} for k, v in vars(args).items(): - # filter out sensitive values + # filter out sensitive values and convert to string if necessary if k not in sensitive_args + sensitive_path_args: - #Accelerate values need to have type `bool`,`str`, `float`, `int`, or `None`. + # Accelerate values need to have type `bool`,`str`, `float`, `int`, or `None`. 
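+            # e.g. flags (--gradient_checkpointing), numbers (--learning_rate) and plain strings are kept as-is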
if v is None or isinstance(v, bool) or isinstance(v, str) or isinstance(v, float) or isinstance(v, int): filtered_args[k] = v # accelerate does not support lists @@ -3416,6 +3425,7 @@ def filter_sensitive_args(args: argparse.Namespace): return filtered_args + # verify command line args for training def verify_command_line_training_args(args: argparse.Namespace): # if wandb is enabled, the command line is exposed to the public diff --git a/sdxl_train.py b/sdxl_train.py index 4c4e38721..11f9892a3 100644 --- a/sdxl_train.py +++ b/sdxl_train.py @@ -589,7 +589,7 @@ def optimizer_hook(parameter: torch.Tensor): init_kwargs["wandb"] = {"name": args.wandb_run_name} if args.log_tracker_config is not None: init_kwargs = toml.load(args.log_tracker_config) - accelerator.init_trackers("finetuning" if args.log_tracker_name is None else args.log_tracker_name, config=train_util.filter_sensitive_args(args), init_kwargs=init_kwargs) + accelerator.init_trackers("finetuning" if args.log_tracker_name is None else args.log_tracker_name, config=train_util.get_sanitized_config_or_none(args), init_kwargs=init_kwargs) # For --sample_at_first sdxl_train_util.sample_images( diff --git a/sdxl_train_control_net_lllite.py b/sdxl_train_control_net_lllite.py index b141965fa..301310901 100644 --- a/sdxl_train_control_net_lllite.py +++ b/sdxl_train_control_net_lllite.py @@ -354,7 +354,7 @@ def train(args): if args.log_tracker_config is not None: init_kwargs = toml.load(args.log_tracker_config) accelerator.init_trackers( - "lllite_control_net_train" if args.log_tracker_name is None else args.log_tracker_name, config=train_util.filter_sensitive_args(args), init_kwargs=init_kwargs + "lllite_control_net_train" if args.log_tracker_name is None else args.log_tracker_name, config=train_util.get_sanitized_config_or_none(args), init_kwargs=init_kwargs ) loss_recorder = train_util.LossRecorder() diff --git a/sdxl_train_control_net_lllite_old.py b/sdxl_train_control_net_lllite_old.py index 9490cf6f2..292a0463a 100644 --- a/sdxl_train_control_net_lllite_old.py +++ b/sdxl_train_control_net_lllite_old.py @@ -324,7 +324,7 @@ def train(args): if args.log_tracker_config is not None: init_kwargs = toml.load(args.log_tracker_config) accelerator.init_trackers( - "lllite_control_net_train" if args.log_tracker_name is None else args.log_tracker_name, config=train_util.filter_sensitive_args(args), init_kwargs=init_kwargs + "lllite_control_net_train" if args.log_tracker_name is None else args.log_tracker_name, config=train_util.get_sanitized_config_or_none(args), init_kwargs=init_kwargs ) loss_recorder = train_util.LossRecorder() diff --git a/train_controlnet.py b/train_controlnet.py index 793f79c7d..9994dd99c 100644 --- a/train_controlnet.py +++ b/train_controlnet.py @@ -344,7 +344,7 @@ def train(args): if args.log_tracker_config is not None: init_kwargs = toml.load(args.log_tracker_config) accelerator.init_trackers( - "controlnet_train" if args.log_tracker_name is None else args.log_tracker_name, config=train_util.filter_sensitive_args(args), init_kwargs=init_kwargs + "controlnet_train" if args.log_tracker_name is None else args.log_tracker_name, config=train_util.get_sanitized_config_or_none(args), init_kwargs=init_kwargs ) loss_recorder = train_util.LossRecorder() diff --git a/train_db.py b/train_db.py index 4f9018293..a5408cd3d 100644 --- a/train_db.py +++ b/train_db.py @@ -290,7 +290,7 @@ def train(args): init_kwargs["wandb"] = {"name": args.wandb_run_name} if args.log_tracker_config is not None: init_kwargs = 
toml.load(args.log_tracker_config) - accelerator.init_trackers("dreambooth" if args.log_tracker_name is None else args.log_tracker_name, config=train_util.filter_sensitive_args(args), init_kwargs=init_kwargs) + accelerator.init_trackers("dreambooth" if args.log_tracker_name is None else args.log_tracker_name, config=train_util.get_sanitized_config_or_none(args), init_kwargs=init_kwargs) # For --sample_at_first train_util.sample_images(accelerator, args, 0, global_step, accelerator.device, vae, tokenizer, text_encoder, unet) diff --git a/train_network.py b/train_network.py index 401a1c70e..38e4888e8 100644 --- a/train_network.py +++ b/train_network.py @@ -774,7 +774,7 @@ def load_model_hook(models, input_dir): if args.log_tracker_config is not None: init_kwargs = toml.load(args.log_tracker_config) accelerator.init_trackers( - "network_train" if args.log_tracker_name is None else args.log_tracker_name, config=train_util.filter_sensitive_args(args), init_kwargs=init_kwargs + "network_train" if args.log_tracker_name is None else args.log_tracker_name, config=train_util.get_sanitized_config_or_none(args), init_kwargs=init_kwargs ) loss_recorder = train_util.LossRecorder() diff --git a/train_textual_inversion.py b/train_textual_inversion.py index 56a387391..184607d1d 100644 --- a/train_textual_inversion.py +++ b/train_textual_inversion.py @@ -510,7 +510,7 @@ def train(self, args): if args.log_tracker_config is not None: init_kwargs = toml.load(args.log_tracker_config) accelerator.init_trackers( - "textual_inversion" if args.log_tracker_name is None else args.log_tracker_name, config=train_util.filter_sensitive_args(args), init_kwargs=init_kwargs + "textual_inversion" if args.log_tracker_name is None else args.log_tracker_name, config=train_util.get_sanitized_config_or_none(args), init_kwargs=init_kwargs ) # function for saving/removing diff --git a/train_textual_inversion_XTI.py b/train_textual_inversion_XTI.py index 691785239..8eed00fa1 100644 --- a/train_textual_inversion_XTI.py +++ b/train_textual_inversion_XTI.py @@ -407,7 +407,7 @@ def train(args): if args.log_tracker_config is not None: init_kwargs = toml.load(args.log_tracker_config) accelerator.init_trackers( - "textual_inversion" if args.log_tracker_name is None else args.log_tracker_name, config=train_util.filter_sensitive_args(args), init_kwargs=init_kwargs + "textual_inversion" if args.log_tracker_name is None else args.log_tracker_name, config=train_util.get_sanitized_config_or_none(args), init_kwargs=init_kwargs ) # function for saving/removing From e4d9e3c843f5d9bfbfe56bd44c8f6a04d370201e Mon Sep 17 00:00:00 2001 From: Kohya S Date: Sun, 19 May 2024 17:46:07 +0900 Subject: [PATCH 45/97] remove dependency for omegaconf #ref 1284 --- README.md | 4 ++++ requirements.txt | 1 - train_controlnet.py | 38 +++++++++++++++++++++++++++++++------- 3 files changed, 35 insertions(+), 8 deletions(-) diff --git a/README.md b/README.md index cd7744598..04769a4cf 100644 --- a/README.md +++ b/README.md @@ -167,6 +167,8 @@ The majority of scripts is licensed under ASL 2.0 (including codes from Diffuser - Training scripts can now output training settings to wandb or Tensor Board logs. Specify the `--log_config` option. PR [#1285](https://github.com/kohya-ss/sd-scripts/pull/1285) Thanks to ccharest93, plucked, rockerBOO, and VelocityRa! - Some settings, such as API keys and directory specifications, are not output due to security issues. + +- The ControlNet training script `train_controlnet.py` for SD1.5/2.x was not working, but it has been fixed. 
PR [#1284](https://github.com/kohya-ss/sd-scripts/pull/1284) Thanks to sdbds! - An option `--disable_mmap_load_safetensors` is added to disable memory mapping when loading the model's .safetensors in SDXL. PR [#1266](https://github.com/kohya-ss/sd-scripts/pull/1266) Thanks to Zovjsra! - It seems that the model file loading is faster in the WSL environment etc. @@ -215,6 +217,8 @@ https://github.com/kohya-ss/sd-scripts/pull/1290) Thanks to frodo821! - 各学習スクリプトで学習設定を wandb や Tensor Board などのログに出力できるようになりました。`--log_config` オプションを指定してください。PR [#1285](https://github.com/kohya-ss/sd-scripts/pull/1285) ccharest93 氏、plucked 氏、rockerBOO 氏および VelocityRa 氏に感謝します。 - API キーや各種ディレクトリ指定など、一部の設定はセキュリティ上の問題があるため出力されません。 +- SD1.5/2.x 用の ControlNet 学習スクリプト `train_controlnet.py` が動作しなくなっていたのが修正されました。PR [#1284](https://github.com/kohya-ss/sd-scripts/pull/1284) sdbds 氏に感謝します。 + - SDXL でモデルの .safetensors を読み込む際にメモリマッピングを無効化するオプション `--disable_mmap_load_safetensors` が追加されました。PR [#1266](https://github.com/kohya-ss/sd-scripts/pull/1266) Zovjsra 氏に感謝します。 - WSL 環境等でモデルファイルの読み込みが高速化されるようです。 - `sdxl_train.py`、`sdxl_train_network.py`、`sdxl_train_textual_inversion.py`、`sdxl_train_control_net_lllite.py` で使用可能です。 diff --git a/requirements.txt b/requirements.txt index 9495dab2a..e99775b8a 100644 --- a/requirements.txt +++ b/requirements.txt @@ -17,7 +17,6 @@ easygui==0.98.3 toml==0.10.2 voluptuous==0.13.1 huggingface-hub==0.20.1 -omegaconf==2.3.0 # for Image utils imagesize==1.4.1 # for BLIP captioning diff --git a/train_controlnet.py b/train_controlnet.py index 3a1fa9de6..c9ac6c5a8 100644 --- a/train_controlnet.py +++ b/train_controlnet.py @@ -5,7 +5,8 @@ import random import time from multiprocessing import Value -from omegaconf import OmegaConf + +# from omegaconf import OmegaConf import toml from tqdm import tqdm @@ -13,6 +14,7 @@ import torch from library import deepspeed_utils from library.device_utils import init_ipex, clean_memory_on_device + init_ipex() from torch.nn.parallel import DistributedDataParallel as DDP @@ -197,7 +199,23 @@ def train(args): "resnet_time_scale_shift": "default", "projection_class_embeddings_input_dim": None, } - unet.config = OmegaConf.create(unet.config) + # unet.config = OmegaConf.create(unet.config) + + # make unet.config iterable and accessible by attribute + class CustomConfig: + def __init__(self, **kwargs): + self.__dict__.update(kwargs) + + def __getattr__(self, name): + if name in self.__dict__: + return self.__dict__[name] + else: + raise AttributeError(f"'{self.__class__.__name__}' object has no attribute '{name}'") + + def __contains__(self, name): + return name in self.__dict__ + + unet.config = CustomConfig(**unet.config) controlnet = ControlNetModel.from_unet(unet) @@ -230,7 +248,7 @@ def train(args): ) vae.to("cpu") clean_memory_on_device(accelerator.device) - + accelerator.wait_for_everyone() if args.gradient_checkpointing: @@ -239,7 +257,7 @@ def train(args): # 学習に必要なクラスを準備する accelerator.print("prepare optimizer, data loader etc.") - trainable_params = controlnet.parameters() + trainable_params = list(controlnet.parameters()) _, _, optimizer = train_util.get_optimizer(args, trainable_params) @@ -348,7 +366,9 @@ def train(args): if args.log_tracker_config is not None: init_kwargs = toml.load(args.log_tracker_config) accelerator.init_trackers( - "controlnet_train" if args.log_tracker_name is None else args.log_tracker_name, config=train_util.get_sanitized_config_or_none(args), init_kwargs=init_kwargs + "controlnet_train" if args.log_tracker_name is None else 
args.log_tracker_name, + config=train_util.get_sanitized_config_or_none(args), + init_kwargs=init_kwargs, ) loss_recorder = train_util.LossRecorder() @@ -424,7 +444,9 @@ def remove_model(old_ckpt_name): ) # Sample a random timestep for each image - timesteps, huber_c = train_util.get_timesteps_and_huber_c(args, 0, noise_scheduler.config.num_train_timesteps, noise_scheduler, b_size, latents.device) + timesteps, huber_c = train_util.get_timesteps_and_huber_c( + args, 0, noise_scheduler.config.num_train_timesteps, noise_scheduler, b_size, latents.device + ) # Add noise to the latents according to the noise magnitude at each timestep # (this is the forward diffusion process) @@ -456,7 +478,9 @@ def remove_model(old_ckpt_name): else: target = noise - loss = train_util.conditional_loss(noise_pred.float(), target.float(), reduction="none", loss_type=args.loss_type, huber_c=huber_c) + loss = train_util.conditional_loss( + noise_pred.float(), target.float(), reduction="none", loss_type=args.loss_type, huber_c=huber_c + ) loss = loss.mean([1, 2, 3]) loss_weights = batch["loss_weights"] # 各sampleごとのweight From 4c798129b04955caad1c48405de168ff63a3809c Mon Sep 17 00:00:00 2001 From: Kohya S Date: Sun, 19 May 2024 19:00:32 +0900 Subject: [PATCH 46/97] update README --- README.md | 8 +++++++- 1 file changed, 7 insertions(+), 1 deletion(-) diff --git a/README.md b/README.md index 04769a4cf..d0f2d65b2 100644 --- a/README.md +++ b/README.md @@ -165,11 +165,14 @@ The majority of scripts is licensed under ASL 2.0 (including codes from Diffuser - Specify the learning rate and dim (rank) for each block. - See [Block-wise learning rates in LoRA](./docs/train_network_README-ja.md#階層別学習率) for details (Japanese only). +- Negative learning rates can now be specified during SDXL model training. PR [#1277](https://github.com/kohya-ss/sd-scripts/pull/1277) Thanks to Cauldrath! + - The model is trained to move away from the training images, so the model is easily collapsed. Use with caution. A value close to 0 is recommended. + - Training scripts can now output training settings to wandb or Tensor Board logs. Specify the `--log_config` option. PR [#1285](https://github.com/kohya-ss/sd-scripts/pull/1285) Thanks to ccharest93, plucked, rockerBOO, and VelocityRa! - Some settings, such as API keys and directory specifications, are not output due to security issues. - The ControlNet training script `train_controlnet.py` for SD1.5/2.x was not working, but it has been fixed. PR [#1284](https://github.com/kohya-ss/sd-scripts/pull/1284) Thanks to sdbds! - + - An option `--disable_mmap_load_safetensors` is added to disable memory mapping when loading the model's .safetensors in SDXL. PR [#1266](https://github.com/kohya-ss/sd-scripts/pull/1266) Thanks to Zovjsra! - It seems that the model file loading is faster in the WSL environment etc. - Available in `sdxl_train.py`, `sdxl_train_network.py`, `sdxl_train_textual_inversion.py`, and `sdxl_train_control_net_lllite.py`. @@ -214,6 +217,9 @@ https://github.com/kohya-ss/sd-scripts/pull/1290) Thanks to frodo821! 
- ブロックごとに学習率および dim (rank) を指定することができます。 - 詳細は [LoRA の階層別学習率](./docs/train_network_README-ja.md#階層別学習率) をご覧ください。 +- `sdxl_train.py` での SDXL モデル学習時に負の学習率が指定できるようになりました。PR [#1277](https://github.com/kohya-ss/sd-scripts/pull/1277) Cauldrath 氏に感謝します。 + - 学習画像から離れるように学習するため、モデルは容易に崩壊します。注意して使用してください。0 に近い値を推奨します。 + - 各学習スクリプトで学習設定を wandb や Tensor Board などのログに出力できるようになりました。`--log_config` オプションを指定してください。PR [#1285](https://github.com/kohya-ss/sd-scripts/pull/1285) ccharest93 氏、plucked 氏、rockerBOO 氏および VelocityRa 氏に感謝します。 - API キーや各種ディレクトリ指定など、一部の設定はセキュリティ上の問題があるため出力されません。 From febc5c59fad74dfcead9064033171a9c674e4870 Mon Sep 17 00:00:00 2001 From: Kohya S Date: Sun, 19 May 2024 19:03:43 +0900 Subject: [PATCH 47/97] update README --- README.md | 2 ++ 1 file changed, 2 insertions(+) diff --git a/README.md b/README.md index d0f2d65b2..838e4022c 100644 --- a/README.md +++ b/README.md @@ -167,6 +167,7 @@ The majority of scripts is licensed under ASL 2.0 (including codes from Diffuser - Negative learning rates can now be specified during SDXL model training. PR [#1277](https://github.com/kohya-ss/sd-scripts/pull/1277) Thanks to Cauldrath! - The model is trained to move away from the training images, so the model is easily collapsed. Use with caution. A value close to 0 is recommended. + - When specifying from the command line, use `=` like `--learning_rate=-1e-7`. - Training scripts can now output training settings to wandb or Tensor Board logs. Specify the `--log_config` option. PR [#1285](https://github.com/kohya-ss/sd-scripts/pull/1285) Thanks to ccharest93, plucked, rockerBOO, and VelocityRa! - Some settings, such as API keys and directory specifications, are not output due to security issues. @@ -219,6 +220,7 @@ https://github.com/kohya-ss/sd-scripts/pull/1290) Thanks to frodo821! 
- `sdxl_train.py` での SDXL モデル学習時に負の学習率が指定できるようになりました。PR [#1277](https://github.com/kohya-ss/sd-scripts/pull/1277) Cauldrath 氏に感謝します。 - 学習画像から離れるように学習するため、モデルは容易に崩壊します。注意して使用してください。0 に近い値を推奨します。 + - コマンドラインから指定する場合、`--learning_rate=-1e-7` のように`=` を使ってください。 - 各学習スクリプトで学習設定を wandb や Tensor Board などのログに出力できるようになりました。`--log_config` オプションを指定してください。PR [#1285](https://github.com/kohya-ss/sd-scripts/pull/1285) ccharest93 氏、plucked 氏、rockerBOO 氏および VelocityRa 氏に感謝します。 - API キーや各種ディレクトリ指定など、一部の設定はセキュリティ上の問題があるため出力されません。 From db6752901fc204686e460255797b188cb28611a5 Mon Sep 17 00:00:00 2001 From: u-haru <40634644+u-haru@users.noreply.github.com> Date: Sun, 19 May 2024 19:07:25 +0900 Subject: [PATCH 48/97] =?UTF-8?q?=E7=94=BB=E5=83=8F=E3=81=AE=E3=82=A2?= =?UTF-8?q?=E3=83=AB=E3=83=95=E3=82=A1=E3=83=81=E3=83=A3=E3=83=B3=E3=83=8D?= =?UTF-8?q?=E3=83=AB=E3=82=92loss=E3=81=AE=E3=83=9E=E3=82=B9=E3=82=AF?= =?UTF-8?q?=E3=81=A8=E3=81=97=E3=81=A6=E4=BD=BF=E7=94=A8=E3=81=99=E3=82=8B?= =?UTF-8?q?=E3=82=AA=E3=83=97=E3=82=B7=E3=83=A7=E3=83=B3=E3=82=92=E8=BF=BD?= =?UTF-8?q?=E5=8A=A0=20(#1223)?= MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit * Add alpha_mask parameter and apply masked loss * Fix type hint in trim_and_resize_if_required function * Refactor code to use keyword arguments in train_util.py * Fix alpha mask flipping logic * Fix alpha mask initialization * Fix alpha_mask transformation * Cache alpha_mask * Update alpha_masks to be on CPU * Set flipped_alpha_masks to Null if option disabled * Check if alpha_mask is None * Set alpha_mask to None if option disabled * Add description of alpha_mask option to docs --- docs/train_network_README-ja.md | 2 + docs/train_network_README-zh.md | 2 + library/config_util.py | 2 + library/custom_train_functions.py | 5 +- library/train_util.py | 203 ++++++++++++------------------ sdxl_train.py | 4 +- train_db.py | 4 +- train_network.py | 4 +- train_textual_inversion.py | 4 +- train_textual_inversion_XTI.py | 4 +- 10 files changed, 105 insertions(+), 129 deletions(-) diff --git a/docs/train_network_README-ja.md b/docs/train_network_README-ja.md index 46085117c..55c80c4b0 100644 --- a/docs/train_network_README-ja.md +++ b/docs/train_network_README-ja.md @@ -102,6 +102,8 @@ accelerate launch --num_cpu_threads_per_process 1 train_network.py * Text Encoderに関連するLoRAモジュールに、通常の学習率(--learning_rateオプションで指定)とは異なる学習率を使う時に指定します。Text Encoderのほうを若干低めの学習率(5e-5など)にしたほうが良い、という話もあるようです。 * `--network_args` * 複数の引数を指定できます。後述します。 +* `--alpha_mask` + * 画像のアルファ値をマスクとして使用します。透過画像を学習する際に使用します。[PR #1223](https://github.com/kohya-ss/sd-scripts/pull/1223) `--network_train_unet_only` と `--network_train_text_encoder_only` の両方とも未指定時(デフォルト)はText EncoderとU-Netの両方のLoRAモジュールを有効にします。 diff --git a/docs/train_network_README-zh.md b/docs/train_network_README-zh.md index ed7a0c4ef..830014f72 100644 --- a/docs/train_network_README-zh.md +++ b/docs/train_network_README-zh.md @@ -101,6 +101,8 @@ LoRA的模型将会被保存在通过`--output_dir`选项指定的文件夹中 * 当在Text Encoder相关的LoRA模块中使用与常规学习率(由`--learning_rate`选项指定)不同的学习率时,应指定此选项。可能最好将Text Encoder的学习率稍微降低(例如5e-5)。 * `--network_args` * 可以指定多个参数。将在下面详细说明。 +* `--alpha_mask` + * 使用图像的 Alpha 值作为遮罩。这在学习透明图像时使用。[PR #1223](https://github.com/kohya-ss/sd-scripts/pull/1223) 当未指定`--network_train_unet_only`和`--network_train_text_encoder_only`时(默认情况),将启用Text Encoder和U-Net的两个LoRA模块。 diff --git a/library/config_util.py b/library/config_util.py index 59f5f86d2..82baab83e 100644 --- a/library/config_util.py +++ b/library/config_util.py @@ -78,6 +78,7 @@ class 
BaseSubsetParams: caption_tag_dropout_rate: float = 0.0 token_warmup_min: int = 1 token_warmup_step: float = 0 + alpha_mask: bool = False @dataclass @@ -538,6 +539,7 @@ def generate_dataset_group_by_blueprint(dataset_group_blueprint: DatasetGroupBlu random_crop: {subset.random_crop} token_warmup_min: {subset.token_warmup_min}, token_warmup_step: {subset.token_warmup_step}, + alpha_mask: {subset.alpha_mask}, """ ), " ", diff --git a/library/custom_train_functions.py b/library/custom_train_functions.py index 406e0e36e..fad127405 100644 --- a/library/custom_train_functions.py +++ b/library/custom_train_functions.py @@ -479,9 +479,10 @@ def apply_noise_offset(latents, noise, noise_offset, adaptive_noise_scale): return noise -def apply_masked_loss(loss, batch): +def apply_masked_loss(loss, mask_image): # mask image is -1 to 1. we need to convert it to 0 to 1 - mask_image = batch["conditioning_images"].to(dtype=loss.dtype)[:, 0].unsqueeze(1) # use R channel + # mask_image = batch["conditioning_images"].to(dtype=loss.dtype)[:, 0].unsqueeze(1) # use R channel + mask_image = mask_image.to(dtype=loss.dtype) # resize to the same size as the loss mask_image = torch.nn.functional.interpolate(mask_image, size=loss.shape[2:], mode="area") diff --git a/library/train_util.py b/library/train_util.py index 410471470..20f8055dc 100644 --- a/library/train_util.py +++ b/library/train_util.py @@ -159,6 +159,9 @@ def __init__(self, image_key: str, num_repeats: int, caption: str, is_reg: bool, self.text_encoder_outputs1: Optional[torch.Tensor] = None self.text_encoder_outputs2: Optional[torch.Tensor] = None self.text_encoder_pool2: Optional[torch.Tensor] = None + self.alpha_mask: Optional[torch.Tensor] = None + self.alpha_mask_flipped: Optional[torch.Tensor] = None + self.use_alpha_mask: bool = False class BucketManager: @@ -379,6 +382,7 @@ def __init__( caption_suffix: Optional[str], token_warmup_min: int, token_warmup_step: Union[float, int], + alpha_mask: bool, ) -> None: self.image_dir = image_dir self.num_repeats = num_repeats @@ -403,6 +407,7 @@ def __init__( self.img_count = 0 + self.alpha_mask = alpha_mask class DreamBoothSubset(BaseSubset): def __init__( @@ -412,47 +417,13 @@ def __init__( class_tokens: Optional[str], caption_extension: str, cache_info: bool, - num_repeats, - shuffle_caption, - caption_separator: str, - keep_tokens, - keep_tokens_separator, - secondary_separator, - enable_wildcard, - color_aug, - flip_aug, - face_crop_aug_range, - random_crop, - caption_dropout_rate, - caption_dropout_every_n_epochs, - caption_tag_dropout_rate, - caption_prefix, - caption_suffix, - token_warmup_min, - token_warmup_step, + **kwargs, ) -> None: assert image_dir is not None, "image_dir must be specified / image_dirは指定が必須です" super().__init__( image_dir, - num_repeats, - shuffle_caption, - caption_separator, - keep_tokens, - keep_tokens_separator, - secondary_separator, - enable_wildcard, - color_aug, - flip_aug, - face_crop_aug_range, - random_crop, - caption_dropout_rate, - caption_dropout_every_n_epochs, - caption_tag_dropout_rate, - caption_prefix, - caption_suffix, - token_warmup_min, - token_warmup_step, + **kwargs, ) self.is_reg = is_reg @@ -473,47 +444,13 @@ def __init__( self, image_dir, metadata_file: str, - num_repeats, - shuffle_caption, - caption_separator, - keep_tokens, - keep_tokens_separator, - secondary_separator, - enable_wildcard, - color_aug, - flip_aug, - face_crop_aug_range, - random_crop, - caption_dropout_rate, - caption_dropout_every_n_epochs, - caption_tag_dropout_rate, - 
caption_prefix, - caption_suffix, - token_warmup_min, - token_warmup_step, + **kwargs, ) -> None: assert metadata_file is not None, "metadata_file must be specified / metadata_fileは指定が必須です" super().__init__( image_dir, - num_repeats, - shuffle_caption, - caption_separator, - keep_tokens, - keep_tokens_separator, - secondary_separator, - enable_wildcard, - color_aug, - flip_aug, - face_crop_aug_range, - random_crop, - caption_dropout_rate, - caption_dropout_every_n_epochs, - caption_tag_dropout_rate, - caption_prefix, - caption_suffix, - token_warmup_min, - token_warmup_step, + **kwargs, ) self.metadata_file = metadata_file @@ -531,47 +468,13 @@ def __init__( conditioning_data_dir: str, caption_extension: str, cache_info: bool, - num_repeats, - shuffle_caption, - caption_separator, - keep_tokens, - keep_tokens_separator, - secondary_separator, - enable_wildcard, - color_aug, - flip_aug, - face_crop_aug_range, - random_crop, - caption_dropout_rate, - caption_dropout_every_n_epochs, - caption_tag_dropout_rate, - caption_prefix, - caption_suffix, - token_warmup_min, - token_warmup_step, + **kwargs, ) -> None: assert image_dir is not None, "image_dir must be specified / image_dirは指定が必須です" super().__init__( image_dir, - num_repeats, - shuffle_caption, - caption_separator, - keep_tokens, - keep_tokens_separator, - secondary_separator, - enable_wildcard, - color_aug, - flip_aug, - face_crop_aug_range, - random_crop, - caption_dropout_rate, - caption_dropout_every_n_epochs, - caption_tag_dropout_rate, - caption_prefix, - caption_suffix, - token_warmup_min, - token_warmup_step, + **kwargs, ) self.conditioning_data_dir = conditioning_data_dir @@ -985,6 +888,8 @@ def cache_latents(self, vae, vae_batch_size=1, cache_to_disk=False, is_main_proc for info in tqdm(image_infos): subset = self.image_to_subset[info.image_key] + info.use_alpha_mask = subset.alpha_mask + if info.latents_npz is not None: # fine tuning dataset continue @@ -1088,8 +993,8 @@ def cache_text_encoder_outputs( def get_image_size(self, image_path): return imagesize.get(image_path) - def load_image_with_face_info(self, subset: BaseSubset, image_path: str): - img = load_image(image_path) + def load_image_with_face_info(self, subset: BaseSubset, image_path: str, alpha_mask=False): + img = load_image(image_path, alpha_mask) face_cx = face_cy = face_w = face_h = 0 if subset.face_crop_aug_range is not None: @@ -1166,6 +1071,7 @@ def __getitem__(self, index): input_ids_list = [] input_ids2_list = [] latents_list = [] + alpha_mask_list = [] images = [] original_sizes_hw = [] crop_top_lefts = [] @@ -1190,21 +1096,27 @@ def __getitem__(self, index): crop_ltrb = image_info.latents_crop_ltrb # calc values later if flipped if not flipped: latents = image_info.latents + alpha_mask = image_info.alpha_mask else: latents = image_info.latents_flipped - + alpha_mask = image_info.alpha_mask_flipped + image = None elif image_info.latents_npz is not None: # FineTuningDatasetまたはcache_latents_to_disk=Trueの場合 - latents, original_size, crop_ltrb, flipped_latents = load_latents_from_disk(image_info.latents_npz) + latents, original_size, crop_ltrb, flipped_latents, alpha_mask, flipped_alpha_mask = load_latents_from_disk(image_info.latents_npz) if flipped: latents = flipped_latents + alpha_mask = flipped_alpha_mask del flipped_latents + del flipped_alpha_mask latents = torch.FloatTensor(latents) + if alpha_mask is not None: + alpha_mask = torch.FloatTensor(alpha_mask) image = None else: # 画像を読み込み、必要ならcropする - img, face_cx, face_cy, face_w, face_h = 
self.load_image_with_face_info(subset, image_info.absolute_path) + img, face_cx, face_cy, face_w, face_h = self.load_image_with_face_info(subset, image_info.absolute_path, subset.alpha_mask) im_h, im_w = img.shape[0:2] if self.enable_bucket: @@ -1241,11 +1153,22 @@ def __getitem__(self, index): if flipped: img = img[:, ::-1, :].copy() # copy to avoid negative stride problem + if subset.alpha_mask: + if img.shape[2] == 4: + alpha_mask = img[:, :, 3] # [W,H] + else: + alpha_mask = np.full((im_w, im_h), 255, dtype=np.uint8) # [W,H] + alpha_mask = transforms.ToTensor()(alpha_mask) + else: + alpha_mask = None + img = img[:, :, :3] # remove alpha channel + latents = None image = self.image_transforms(img) # -1.0~1.0のtorch.Tensorになる images.append(image) latents_list.append(latents) + alpha_mask_list.append(alpha_mask) target_size = (image.shape[2], image.shape[1]) if image is not None else (latents.shape[2] * 8, latents.shape[1] * 8) @@ -1348,6 +1271,8 @@ def __getitem__(self, index): example["network_multipliers"] = torch.FloatTensor([self.network_multiplier] * len(captions)) + example["alpha_mask"] = torch.stack(alpha_mask_list) if alpha_mask_list[0] is not None else None + if self.debug_dataset: example["image_keys"] = bucket[image_index : image_index + self.batch_size] return example @@ -2145,7 +2070,7 @@ def is_disk_cached_latents_is_expected(reso, npz_path: str, flip_aug: bool): # 戻り値は、latents_tensor, (original_size width, original_size height), (crop left, crop top) def load_latents_from_disk( npz_path, -) -> Tuple[Optional[torch.Tensor], Optional[List[int]], Optional[List[int]], Optional[torch.Tensor]]: +) -> Tuple[Optional[torch.Tensor], Optional[List[int]], Optional[List[int]], Optional[torch.Tensor], Optional[torch.Tensor], Optional[torch.Tensor]]: npz = np.load(npz_path) if "latents" not in npz: raise ValueError(f"error: npz is old format. 
please re-generate {npz_path}") @@ -2154,13 +2079,19 @@ def load_latents_from_disk( original_size = npz["original_size"].tolist() crop_ltrb = npz["crop_ltrb"].tolist() flipped_latents = npz["latents_flipped"] if "latents_flipped" in npz else None - return latents, original_size, crop_ltrb, flipped_latents + alpha_mask = npz["alpha_mask"] if "alpha_mask" in npz else None + flipped_alpha_mask = npz["flipped_alpha_mask"] if "flipped_alpha_mask" in npz else None + return latents, original_size, crop_ltrb, flipped_latents, alpha_mask, flipped_alpha_mask -def save_latents_to_disk(npz_path, latents_tensor, original_size, crop_ltrb, flipped_latents_tensor=None): +def save_latents_to_disk(npz_path, latents_tensor, original_size, crop_ltrb, flipped_latents_tensor=None, alpha_mask=None, flipped_alpha_mask=None): kwargs = {} if flipped_latents_tensor is not None: kwargs["latents_flipped"] = flipped_latents_tensor.float().cpu().numpy() + if alpha_mask is not None: + kwargs["alpha_mask"] = alpha_mask.float().cpu().numpy() + if flipped_alpha_mask is not None: + kwargs["flipped_alpha_mask"] = flipped_alpha_mask.float().cpu().numpy() np.savez( npz_path, latents=latents_tensor.float().cpu().numpy(), @@ -2349,17 +2280,20 @@ def load_arbitrary_dataset(args, tokenizer) -> MinimalDataset: return train_dataset_group -def load_image(image_path): +def load_image(image_path, alpha=False): image = Image.open(image_path) if not image.mode == "RGB": - image = image.convert("RGB") + if alpha: + image = image.convert("RGBA") + else: + image = image.convert("RGB") img = np.array(image, np.uint8) return img # 画像を読み込む。戻り値はnumpy.ndarray,(original width, original height),(crop left, crop top, crop right, crop bottom) def trim_and_resize_if_required( - random_crop: bool, image: Image.Image, reso, resized_size: Tuple[int, int] + random_crop: bool, image: np.ndarray, reso, resized_size: Tuple[int, int] ) -> Tuple[np.ndarray, Tuple[int, int], Tuple[int, int, int, int]]: image_height, image_width = image.shape[0:2] original_size = (image_width, image_height) # size before resize @@ -2403,10 +2337,18 @@ def cache_batch_latents( latents_original_size and latents_crop_ltrb are also set """ images = [] + alpha_masks = [] for info in image_infos: - image = load_image(info.absolute_path) if info.image is None else np.array(info.image, np.uint8) + image = load_image(info.absolute_path, info.use_alpha_mask) if info.image is None else np.array(info.image, np.uint8) # TODO 画像のメタデータが壊れていて、メタデータから割り当てたbucketと実際の画像サイズが一致しない場合があるのでチェック追加要 image, original_size, crop_ltrb = trim_and_resize_if_required(random_crop, image, info.bucket_reso, info.resized_size) + if info.use_alpha_mask: + if image.shape[2] == 4: + alpha_mask = image[:, :, 3] # [W,H] + image = image[:, :, :3] + else: + alpha_mask = np.full_like(image[:, :, 0], 255, dtype=np.uint8) # [W,H] + alpha_masks.append(transforms.ToTensor()(alpha_mask)) image = IMAGE_TRANSFORMS(image) images.append(image) @@ -2419,25 +2361,37 @@ def cache_batch_latents( with torch.no_grad(): latents = vae.encode(img_tensors).latent_dist.sample().to("cpu") + if info.use_alpha_mask: + alpha_masks = torch.stack(alpha_masks, dim=0).to("cpu") + else: + alpha_masks = [None] * len(image_infos) + flipped_alpha_masks = [None] * len(image_infos) + if flip_aug: img_tensors = torch.flip(img_tensors, dims=[3]) with torch.no_grad(): flipped_latents = vae.encode(img_tensors).latent_dist.sample().to("cpu") + if info.use_alpha_mask: + flipped_alpha_masks = torch.flip(alpha_masks, dims=[3]) else: flipped_latents = [None] * 
len(latents) + flipped_alpha_masks = [None] * len(image_infos) - for info, latent, flipped_latent in zip(image_infos, latents, flipped_latents): + for info, latent, flipped_latent, alpha_mask, flipped_alpha_mask in zip(image_infos, latents, flipped_latents, alpha_masks, flipped_alpha_masks): # check NaN if torch.isnan(latents).any() or (flipped_latent is not None and torch.isnan(flipped_latent).any()): raise RuntimeError(f"NaN detected in latents: {info.absolute_path}") if cache_to_disk: - save_latents_to_disk(info.latents_npz, latent, info.latents_original_size, info.latents_crop_ltrb, flipped_latent) + save_latents_to_disk(info.latents_npz, latent, info.latents_original_size, info.latents_crop_ltrb, flipped_latent, alpha_mask, flipped_alpha_mask) else: info.latents = latent if flip_aug: info.latents_flipped = flipped_latent + info.alpha_mask = alpha_mask + info.alpha_mask_flipped = flipped_alpha_mask + if not HIGH_VRAM: clean_memory_on_device(vae.device) @@ -3683,6 +3637,11 @@ def add_dataset_arguments( default=0, help="tag length reaches maximum on N steps (or N*max_train_steps if N<1) / N(N<1ならN*max_train_steps)ステップでタグ長が最大になる。デフォルトは0(最初から最大)", ) + parser.add_argument( + "--alpha_mask", + action="store_true", + help="use alpha channel as mask for training / 画像のアルファチャンネルをlossのマスクに使用する", + ) parser.add_argument( "--dataset_class", diff --git a/sdxl_train.py b/sdxl_train.py index 7c71a5133..dcd06766b 100644 --- a/sdxl_train.py +++ b/sdxl_train.py @@ -712,7 +712,9 @@ def optimizer_hook(parameter: torch.Tensor): noise_pred.float(), target.float(), reduction="none", loss_type=args.loss_type, huber_c=huber_c ) if args.masked_loss: - loss = apply_masked_loss(loss, batch) + loss = apply_masked_loss(loss, batch["conditioning_images"][:, 0].unsqueeze(1)) + if "alpha_mask" in batch and batch["alpha_mask"] is not None: + loss = apply_masked_loss(loss, batch["alpha_mask"]) loss = loss.mean([1, 2, 3]) if args.min_snr_gamma: diff --git a/train_db.py b/train_db.py index a5408cd3d..c46900006 100644 --- a/train_db.py +++ b/train_db.py @@ -360,7 +360,9 @@ def train(args): loss = train_util.conditional_loss(noise_pred.float(), target.float(), reduction="none", loss_type=args.loss_type, huber_c=huber_c) if args.masked_loss: - loss = apply_masked_loss(loss, batch) + loss = apply_masked_loss(loss, batch["conditioning_images"][:, 0].unsqueeze(1)) + if "alpha_mask" in batch and batch["alpha_mask"] is not None: + loss = apply_masked_loss(loss, batch["alpha_mask"]) loss = loss.mean([1, 2, 3]) loss_weights = batch["loss_weights"] # 各sampleごとのweight diff --git a/train_network.py b/train_network.py index 38e4888e8..cd1677ad2 100644 --- a/train_network.py +++ b/train_network.py @@ -903,7 +903,9 @@ def remove_model(old_ckpt_name): noise_pred.float(), target.float(), reduction="none", loss_type=args.loss_type, huber_c=huber_c ) if args.masked_loss: - loss = apply_masked_loss(loss, batch) + loss = apply_masked_loss(loss, batch["conditioning_images"][:, 0].unsqueeze(1)) + if "alpha_mask" in batch and batch["alpha_mask"] is not None: + loss = apply_masked_loss(loss, batch["alpha_mask"]) loss = loss.mean([1, 2, 3]) loss_weights = batch["loss_weights"] # 各sampleごとのweight diff --git a/train_textual_inversion.py b/train_textual_inversion.py index 184607d1d..a9c2a1094 100644 --- a/train_textual_inversion.py +++ b/train_textual_inversion.py @@ -590,7 +590,9 @@ def remove_model(old_ckpt_name): loss = train_util.conditional_loss(noise_pred.float(), target.float(), reduction="none", loss_type=args.loss_type, huber_c=huber_c) if 
args.masked_loss: - loss = apply_masked_loss(loss, batch) + loss = apply_masked_loss(loss, batch["conditioning_images"][:, 0].unsqueeze(1)) + if "alpha_mask" in batch and batch["alpha_mask"] is not None: + loss = apply_masked_loss(loss, batch["alpha_mask"]) loss = loss.mean([1, 2, 3]) loss_weights = batch["loss_weights"] # 各sampleごとのweight diff --git a/train_textual_inversion_XTI.py b/train_textual_inversion_XTI.py index 8eed00fa1..959839cbb 100644 --- a/train_textual_inversion_XTI.py +++ b/train_textual_inversion_XTI.py @@ -475,7 +475,9 @@ def remove_model(old_ckpt_name): loss = train_util.conditional_loss(noise_pred.float(), target.float(), reduction="none", loss_type=args.loss_type, huber_c=huber_c) if args.masked_loss: - loss = apply_masked_loss(loss, batch) + loss = apply_masked_loss(loss, batch["conditioning_images"][:, 0].unsqueeze(1)) + if "alpha_mask" in batch and batch["alpha_mask"] is not None: + loss = apply_masked_loss(loss, batch["alpha_mask"]) loss = loss.mean([1, 2, 3]) loss_weights = batch["loss_weights"] # 各sampleごとのweight From f2dd43e198f4bc059f4790ada041fa8f2a305f25 Mon Sep 17 00:00:00 2001 From: Kohya S Date: Sun, 19 May 2024 19:23:59 +0900 Subject: [PATCH 49/97] revert kwargs to explicit declaration --- library/train_util.py | 158 +++++++++++++++++++++++++++++++++++++----- 1 file changed, 142 insertions(+), 16 deletions(-) diff --git a/library/train_util.py b/library/train_util.py index 20f8055dc..6cf285903 100644 --- a/library/train_util.py +++ b/library/train_util.py @@ -409,6 +409,7 @@ def __init__( self.alpha_mask = alpha_mask + class DreamBoothSubset(BaseSubset): def __init__( self, @@ -417,13 +418,47 @@ def __init__( class_tokens: Optional[str], caption_extension: str, cache_info: bool, - **kwargs, + num_repeats, + shuffle_caption, + caption_separator: str, + keep_tokens, + keep_tokens_separator, + secondary_separator, + enable_wildcard, + color_aug, + flip_aug, + face_crop_aug_range, + random_crop, + caption_dropout_rate, + caption_dropout_every_n_epochs, + caption_tag_dropout_rate, + caption_prefix, + caption_suffix, + token_warmup_min, + token_warmup_step, ) -> None: assert image_dir is not None, "image_dir must be specified / image_dirは指定が必須です" super().__init__( image_dir, - **kwargs, + num_repeats, + shuffle_caption, + caption_separator, + keep_tokens, + keep_tokens_separator, + secondary_separator, + enable_wildcard, + color_aug, + flip_aug, + face_crop_aug_range, + random_crop, + caption_dropout_rate, + caption_dropout_every_n_epochs, + caption_tag_dropout_rate, + caption_prefix, + caption_suffix, + token_warmup_min, + token_warmup_step, ) self.is_reg = is_reg @@ -444,13 +479,47 @@ def __init__( self, image_dir, metadata_file: str, - **kwargs, + num_repeats, + shuffle_caption, + caption_separator, + keep_tokens, + keep_tokens_separator, + secondary_separator, + enable_wildcard, + color_aug, + flip_aug, + face_crop_aug_range, + random_crop, + caption_dropout_rate, + caption_dropout_every_n_epochs, + caption_tag_dropout_rate, + caption_prefix, + caption_suffix, + token_warmup_min, + token_warmup_step, ) -> None: assert metadata_file is not None, "metadata_file must be specified / metadata_fileは指定が必須です" super().__init__( image_dir, - **kwargs, + num_repeats, + shuffle_caption, + caption_separator, + keep_tokens, + keep_tokens_separator, + secondary_separator, + enable_wildcard, + color_aug, + flip_aug, + face_crop_aug_range, + random_crop, + caption_dropout_rate, + caption_dropout_every_n_epochs, + caption_tag_dropout_rate, + caption_prefix, + 
caption_suffix, + token_warmup_min, + token_warmup_step, ) self.metadata_file = metadata_file @@ -468,13 +537,47 @@ def __init__( conditioning_data_dir: str, caption_extension: str, cache_info: bool, - **kwargs, + num_repeats, + shuffle_caption, + caption_separator, + keep_tokens, + keep_tokens_separator, + secondary_separator, + enable_wildcard, + color_aug, + flip_aug, + face_crop_aug_range, + random_crop, + caption_dropout_rate, + caption_dropout_every_n_epochs, + caption_tag_dropout_rate, + caption_prefix, + caption_suffix, + token_warmup_min, + token_warmup_step, ) -> None: assert image_dir is not None, "image_dir must be specified / image_dirは指定が必須です" super().__init__( image_dir, - **kwargs, + num_repeats, + shuffle_caption, + caption_separator, + keep_tokens, + keep_tokens_separator, + secondary_separator, + enable_wildcard, + color_aug, + flip_aug, + face_crop_aug_range, + random_crop, + caption_dropout_rate, + caption_dropout_every_n_epochs, + caption_tag_dropout_rate, + caption_prefix, + caption_suffix, + token_warmup_min, + token_warmup_step, ) self.conditioning_data_dir = conditioning_data_dir @@ -1100,10 +1203,12 @@ def __getitem__(self, index): else: latents = image_info.latents_flipped alpha_mask = image_info.alpha_mask_flipped - + image = None elif image_info.latents_npz is not None: # FineTuningDatasetまたはcache_latents_to_disk=Trueの場合 - latents, original_size, crop_ltrb, flipped_latents, alpha_mask, flipped_alpha_mask = load_latents_from_disk(image_info.latents_npz) + latents, original_size, crop_ltrb, flipped_latents, alpha_mask, flipped_alpha_mask = load_latents_from_disk( + image_info.latents_npz + ) if flipped: latents = flipped_latents alpha_mask = flipped_alpha_mask @@ -1116,7 +1221,9 @@ def __getitem__(self, index): image = None else: # 画像を読み込み、必要ならcropする - img, face_cx, face_cy, face_w, face_h = self.load_image_with_face_info(subset, image_info.absolute_path, subset.alpha_mask) + img, face_cx, face_cy, face_w, face_h = self.load_image_with_face_info( + subset, image_info.absolute_path, subset.alpha_mask + ) im_h, im_w = img.shape[0:2] if self.enable_bucket: @@ -1157,7 +1264,7 @@ def __getitem__(self, index): if img.shape[2] == 4: alpha_mask = img[:, :, 3] # [W,H] else: - alpha_mask = np.full((im_w, im_h), 255, dtype=np.uint8) # [W,H] + alpha_mask = np.full((im_w, im_h), 255, dtype=np.uint8) # [W,H] alpha_mask = transforms.ToTensor()(alpha_mask) else: alpha_mask = None @@ -2070,7 +2177,14 @@ def is_disk_cached_latents_is_expected(reso, npz_path: str, flip_aug: bool): # 戻り値は、latents_tensor, (original_size width, original_size height), (crop left, crop top) def load_latents_from_disk( npz_path, -) -> Tuple[Optional[torch.Tensor], Optional[List[int]], Optional[List[int]], Optional[torch.Tensor], Optional[torch.Tensor], Optional[torch.Tensor]]: +) -> Tuple[ + Optional[torch.Tensor], + Optional[List[int]], + Optional[List[int]], + Optional[torch.Tensor], + Optional[torch.Tensor], + Optional[torch.Tensor], +]: npz = np.load(npz_path) if "latents" not in npz: raise ValueError(f"error: npz is old format. 
please re-generate {npz_path}") @@ -2084,7 +2198,9 @@ def load_latents_from_disk( return latents, original_size, crop_ltrb, flipped_latents, alpha_mask, flipped_alpha_mask -def save_latents_to_disk(npz_path, latents_tensor, original_size, crop_ltrb, flipped_latents_tensor=None, alpha_mask=None, flipped_alpha_mask=None): +def save_latents_to_disk( + npz_path, latents_tensor, original_size, crop_ltrb, flipped_latents_tensor=None, alpha_mask=None, flipped_alpha_mask=None +): kwargs = {} if flipped_latents_tensor is not None: kwargs["latents_flipped"] = flipped_latents_tensor.float().cpu().numpy() @@ -2344,10 +2460,10 @@ def cache_batch_latents( image, original_size, crop_ltrb = trim_and_resize_if_required(random_crop, image, info.bucket_reso, info.resized_size) if info.use_alpha_mask: if image.shape[2] == 4: - alpha_mask = image[:, :, 3] # [W,H] + alpha_mask = image[:, :, 3] # [W,H] image = image[:, :, :3] else: - alpha_mask = np.full_like(image[:, :, 0], 255, dtype=np.uint8) # [W,H] + alpha_mask = np.full_like(image[:, :, 0], 255, dtype=np.uint8) # [W,H] alpha_masks.append(transforms.ToTensor()(alpha_mask)) image = IMAGE_TRANSFORMS(image) images.append(image) @@ -2377,13 +2493,23 @@ def cache_batch_latents( flipped_latents = [None] * len(latents) flipped_alpha_masks = [None] * len(image_infos) - for info, latent, flipped_latent, alpha_mask, flipped_alpha_mask in zip(image_infos, latents, flipped_latents, alpha_masks, flipped_alpha_masks): + for info, latent, flipped_latent, alpha_mask, flipped_alpha_mask in zip( + image_infos, latents, flipped_latents, alpha_masks, flipped_alpha_masks + ): # check NaN if torch.isnan(latents).any() or (flipped_latent is not None and torch.isnan(flipped_latent).any()): raise RuntimeError(f"NaN detected in latents: {info.absolute_path}") if cache_to_disk: - save_latents_to_disk(info.latents_npz, latent, info.latents_original_size, info.latents_crop_ltrb, flipped_latent, alpha_mask, flipped_alpha_mask) + save_latents_to_disk( + info.latents_npz, + latent, + info.latents_original_size, + info.latents_crop_ltrb, + flipped_latent, + alpha_mask, + flipped_alpha_mask, + ) else: info.latents = latent if flip_aug: From da6fea3d9779970a1c573bf26fe37c924efc68d8 Mon Sep 17 00:00:00 2001 From: Kohya S Date: Sun, 19 May 2024 21:26:18 +0900 Subject: [PATCH 50/97] simplify and update alpha mask to work with various cases --- finetune/prepare_buckets_latents.py | 33 +++++-- library/config_util.py | 2 + library/custom_train_functions.py | 15 ++- library/train_util.py | 147 +++++++++++++++------------- sdxl_train.py | 6 +- tools/cache_latents.py | 12 ++- train_db.py | 6 +- train_network.py | 10 +- train_textual_inversion.py | 6 +- train_textual_inversion_XTI.py | 6 +- 10 files changed, 139 insertions(+), 104 deletions(-) diff --git a/finetune/prepare_buckets_latents.py b/finetune/prepare_buckets_latents.py index 0389da388..019c737a6 100644 --- a/finetune/prepare_buckets_latents.py +++ b/finetune/prepare_buckets_latents.py @@ -11,6 +11,7 @@ import torch from library.device_utils import init_ipex, get_preferred_device + init_ipex() from torchvision import transforms @@ -18,8 +19,10 @@ import library.model_util as model_util import library.train_util as train_util from library.utils import setup_logging + setup_logging() import logging + logger = logging.getLogger(__name__) DEVICE = get_preferred_device() @@ -89,7 +92,9 @@ def main(args): # bucketのサイズを計算する max_reso = tuple([int(t) for t in args.max_resolution.split(",")]) - assert len(max_reso) == 2, f"illegal resolution (not 
'width,height') / 画像サイズに誤りがあります。'幅,高さ'で指定してください: {args.max_resolution}" + assert ( + len(max_reso) == 2 + ), f"illegal resolution (not 'width,height') / 画像サイズに誤りがあります。'幅,高さ'で指定してください: {args.max_resolution}" bucket_manager = train_util.BucketManager( args.bucket_no_upscale, max_reso, args.min_bucket_reso, args.max_bucket_reso, args.bucket_reso_steps @@ -107,7 +112,7 @@ def main(args): def process_batch(is_last): for bucket in bucket_manager.buckets: if (is_last and len(bucket) > 0) or len(bucket) >= args.batch_size: - train_util.cache_batch_latents(vae, True, bucket, args.flip_aug, False) + train_util.cache_batch_latents(vae, True, bucket, args.flip_aug, args.alpha_mask, False) bucket.clear() # 読み込みの高速化のためにDataLoaderを使うオプション @@ -208,7 +213,9 @@ def setup_parser() -> argparse.ArgumentParser: parser.add_argument("in_json", type=str, help="metadata file to input / 読み込むメタデータファイル") parser.add_argument("out_json", type=str, help="metadata file to output / メタデータファイル書き出し先") parser.add_argument("model_name_or_path", type=str, help="model name or path to encode latents / latentを取得するためのモデル") - parser.add_argument("--v2", action="store_true", help="not used (for backward compatibility) / 使用されません(互換性のため残してあります)") + parser.add_argument( + "--v2", action="store_true", help="not used (for backward compatibility) / 使用されません(互換性のため残してあります)" + ) parser.add_argument("--batch_size", type=int, default=1, help="batch size in inference / 推論時のバッチサイズ") parser.add_argument( "--max_data_loader_n_workers", @@ -231,10 +238,16 @@ def setup_parser() -> argparse.ArgumentParser: help="steps of resolution for buckets, divisible by 8 is recommended / bucketの解像度の単位、8で割り切れる値を推奨します", ) parser.add_argument( - "--bucket_no_upscale", action="store_true", help="make bucket for each image without upscaling / 画像を拡大せずbucketを作成します" + "--bucket_no_upscale", + action="store_true", + help="make bucket for each image without upscaling / 画像を拡大せずbucketを作成します", ) parser.add_argument( - "--mixed_precision", type=str, default="no", choices=["no", "fp16", "bf16"], help="use mixed precision / 混合精度を使う場合、その精度" + "--mixed_precision", + type=str, + default="no", + choices=["no", "fp16", "bf16"], + help="use mixed precision / 混合精度を使う場合、その精度", ) parser.add_argument( "--full_path", @@ -242,7 +255,15 @@ def setup_parser() -> argparse.ArgumentParser: help="use full path as image-key in metadata (supports multiple directories) / メタデータで画像キーをフルパスにする(複数の学習画像ディレクトリに対応)", ) parser.add_argument( - "--flip_aug", action="store_true", help="flip augmentation, save latents for flipped images / 左右反転した画像もlatentを取得、保存する" + "--flip_aug", + action="store_true", + help="flip augmentation, save latents for flipped images / 左右反転した画像もlatentを取得、保存する", + ) + parser.add_argument( + "--alpha_mask", + type=str, + default="", + help="save alpha mask for images for loss calculation / 損失計算用に画像のアルファマスクを保存する", ) parser.add_argument( "--skip_existing", diff --git a/library/config_util.py b/library/config_util.py index 82baab83e..964270dbb 100644 --- a/library/config_util.py +++ b/library/config_util.py @@ -214,11 +214,13 @@ def __validate_and_convert_scalar_or_twodim(klass, value: Union[float, Sequence] DB_SUBSET_DISTINCT_SCHEMA = { Required("image_dir"): str, "is_reg": bool, + "alpha_mask": bool, } # FT means FineTuning FT_SUBSET_DISTINCT_SCHEMA = { Required("metadata_file"): str, "image_dir": str, + "alpha_mask": bool, } CN_SUBSET_ASCENDABLE_SCHEMA = { "caption_extension": str, diff --git a/library/custom_train_functions.py b/library/custom_train_functions.py index fad127405..af5813a1d 
100644 --- a/library/custom_train_functions.py +++ b/library/custom_train_functions.py @@ -479,14 +479,19 @@ def apply_noise_offset(latents, noise, noise_offset, adaptive_noise_scale): return noise -def apply_masked_loss(loss, mask_image): - # mask image is -1 to 1. we need to convert it to 0 to 1 - # mask_image = batch["conditioning_images"].to(dtype=loss.dtype)[:, 0].unsqueeze(1) # use R channel - mask_image = mask_image.to(dtype=loss.dtype) +def apply_masked_loss(loss, batch): + if "conditioning_images" in batch: + # conditioning image is -1 to 1. we need to convert it to 0 to 1 + mask_image = batch["conditioning_images"].to(dtype=loss.dtype)[:, 0].unsqueeze(1) # use R channel + mask_image = mask_image / 2 + 0.5 + elif "alpha_masks" in batch and batch["alpha_masks"] is not None: + # alpha mask is 0 to 1 + mask_image = batch["alpha_masks"].to(dtype=loss.dtype) + else: + return loss # resize to the same size as the loss mask_image = torch.nn.functional.interpolate(mask_image, size=loss.shape[2:], mode="area") - mask_image = mask_image / 2 + 0.5 loss = loss * mask_image return loss diff --git a/library/train_util.py b/library/train_util.py index 6cf285903..e7a50f04d 100644 --- a/library/train_util.py +++ b/library/train_util.py @@ -159,9 +159,7 @@ def __init__(self, image_key: str, num_repeats: int, caption: str, is_reg: bool, self.text_encoder_outputs1: Optional[torch.Tensor] = None self.text_encoder_outputs2: Optional[torch.Tensor] = None self.text_encoder_pool2: Optional[torch.Tensor] = None - self.alpha_mask: Optional[torch.Tensor] = None - self.alpha_mask_flipped: Optional[torch.Tensor] = None - self.use_alpha_mask: bool = False + self.alpha_mask: Optional[torch.Tensor] = None # alpha mask can be flipped in runtime class BucketManager: @@ -364,6 +362,7 @@ class BaseSubset: def __init__( self, image_dir: Optional[str], + alpha_mask: Optional[bool], num_repeats: int, shuffle_caption: bool, caption_separator: str, @@ -382,9 +381,9 @@ def __init__( caption_suffix: Optional[str], token_warmup_min: int, token_warmup_step: Union[float, int], - alpha_mask: bool, ) -> None: self.image_dir = image_dir + self.alpha_mask = alpha_mask if alpha_mask is not None else False self.num_repeats = num_repeats self.shuffle_caption = shuffle_caption self.caption_separator = caption_separator @@ -407,8 +406,6 @@ def __init__( self.img_count = 0 - self.alpha_mask = alpha_mask - class DreamBoothSubset(BaseSubset): def __init__( @@ -418,6 +415,7 @@ def __init__( class_tokens: Optional[str], caption_extension: str, cache_info: bool, + alpha_mask: bool, num_repeats, shuffle_caption, caption_separator: str, @@ -441,6 +439,7 @@ def __init__( super().__init__( image_dir, + alpha_mask, num_repeats, shuffle_caption, caption_separator, @@ -479,6 +478,7 @@ def __init__( self, image_dir, metadata_file: str, + alpha_mask: bool, num_repeats, shuffle_caption, caption_separator, @@ -502,6 +502,7 @@ def __init__( super().__init__( image_dir, + alpha_mask, num_repeats, shuffle_caption, caption_separator, @@ -921,7 +922,7 @@ def make_buckets(self): logger.info(f"mean ar error (without repeats): {mean_img_ar_error}") # データ参照用indexを作る。このindexはdatasetのshuffleに用いられる - self.buckets_indices: List(BucketBatchIndex) = [] + self.buckets_indices: List[BucketBatchIndex] = [] for bucket_index, bucket in enumerate(self.bucket_manager.buckets): batch_count = int(math.ceil(len(bucket) / self.batch_size)) for batch_index in range(batch_count): @@ -991,8 +992,6 @@ def cache_latents(self, vae, vae_batch_size=1, cache_to_disk=False, is_main_proc 
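        # The subset that owns each image decides flip_aug / alpha_mask here: they are checked
        # against any existing .npz cache (when caching to disk) and passed on to
        # cache_batch_latents together with random_crop for the images that still need encoding.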
for info in tqdm(image_infos): subset = self.image_to_subset[info.image_key] - info.use_alpha_mask = subset.alpha_mask - if info.latents_npz is not None: # fine tuning dataset continue @@ -1002,7 +1001,9 @@ def cache_latents(self, vae, vae_batch_size=1, cache_to_disk=False, is_main_proc if not is_main_process: # store to info only continue - cache_available = is_disk_cached_latents_is_expected(info.bucket_reso, info.latents_npz, subset.flip_aug) + cache_available = is_disk_cached_latents_is_expected( + info.bucket_reso, info.latents_npz, subset.flip_aug, subset.alpha_mask + ) if cache_available: # do not add to batch continue @@ -1028,7 +1029,7 @@ def cache_latents(self, vae, vae_batch_size=1, cache_to_disk=False, is_main_proc # iterate batches: batch doesn't have image, image will be loaded in cache_batch_latents and discarded logger.info("caching latents...") for batch in tqdm(batches, smoothing=1, total=len(batches)): - cache_batch_latents(vae, cache_to_disk, batch, subset.flip_aug, subset.random_crop) + cache_batch_latents(vae, cache_to_disk, batch, subset.flip_aug, subset.alpha_mask, subset.random_crop) # weight_dtypeを指定するとText Encoderそのもの、およひ出力がweight_dtypeになる # SDXLでのみ有効だが、datasetのメソッドとする必要があるので、sdxl_train_util.pyではなくこちらに実装する @@ -1202,18 +1203,15 @@ def __getitem__(self, index): alpha_mask = image_info.alpha_mask else: latents = image_info.latents_flipped - alpha_mask = image_info.alpha_mask_flipped + alpha_mask = None if image_info.alpha_mask is None else torch.flip(image_info.alpha_mask, [1]) image = None elif image_info.latents_npz is not None: # FineTuningDatasetまたはcache_latents_to_disk=Trueの場合 - latents, original_size, crop_ltrb, flipped_latents, alpha_mask, flipped_alpha_mask = load_latents_from_disk( - image_info.latents_npz - ) + latents, original_size, crop_ltrb, flipped_latents, alpha_mask = load_latents_from_disk(image_info.latents_npz) if flipped: latents = flipped_latents - alpha_mask = flipped_alpha_mask + alpha_mask = None if alpha_mask is None else alpha_mask[:, ::-1].copy() # copy to avoid negative stride problem del flipped_latents - del flipped_alpha_mask latents = torch.FloatTensor(latents) if alpha_mask is not None: alpha_mask = torch.FloatTensor(alpha_mask) @@ -1255,23 +1253,28 @@ def __getitem__(self, index): # augmentation aug = self.aug_helper.get_augmentor(subset.color_aug) if aug is not None: - img = aug(image=img)["image"] + # augment RGB channels only + img_rgb = img[:, :, :3] + img_rgb = aug(image=img_rgb)["image"] + img[:, :, :3] = img_rgb if flipped: img = img[:, ::-1, :].copy() # copy to avoid negative stride problem if subset.alpha_mask: if img.shape[2] == 4: - alpha_mask = img[:, :, 3] # [W,H] + alpha_mask = img[:, :, 3] # [H,W] + alpha_mask = transforms.ToTensor()(alpha_mask) # 0-255 -> 0-1 else: - alpha_mask = np.full((im_w, im_h), 255, dtype=np.uint8) # [W,H] - alpha_mask = transforms.ToTensor()(alpha_mask) + alpha_mask = torch.ones((img.shape[0], img.shape[1]), dtype=torch.float32) else: alpha_mask = None + img = img[:, :, :3] # remove alpha channel latents = None image = self.image_transforms(img) # -1.0~1.0のtorch.Tensorになる + del img images.append(image) latents_list.append(latents) @@ -1361,6 +1364,23 @@ def __getitem__(self, index): example["text_encoder_outputs2_list"] = torch.stack(text_encoder_outputs2_list) example["text_encoder_pool2_list"] = torch.stack(text_encoder_pool2_list) + # if one of alpha_masks is not None, we need to replace None with ones + none_or_not = [x is None for x in alpha_mask_list] + if all(none_or_not): + 
example["alpha_masks"] = None + elif any(none_or_not): + for i in range(len(alpha_mask_list)): + if alpha_mask_list[i] is None: + if images[i] is not None: + alpha_mask_list[i] = torch.ones((images[i].shape[1], images[i].shape[2]), dtype=torch.float32) + else: + alpha_mask_list[i] = torch.ones( + (latents_list[i].shape[1] * 8, latents_list[i].shape[2] * 8), dtype=torch.float32 + ) + example["alpha_masks"] = torch.stack(alpha_mask_list) + else: + example["alpha_masks"] = torch.stack(alpha_mask_list) + if images[0] is not None: images = torch.stack(images) images = images.to(memory_format=torch.contiguous_format).float() @@ -1378,8 +1398,6 @@ def __getitem__(self, index): example["network_multipliers"] = torch.FloatTensor([self.network_multiplier] * len(captions)) - example["alpha_mask"] = torch.stack(alpha_mask_list) if alpha_mask_list[0] is not None else None - if self.debug_dataset: example["image_keys"] = bucket[image_index : image_index + self.batch_size] return example @@ -1393,6 +1411,7 @@ def get_item_for_caching(self, bucket, bucket_batch_size, image_index): resized_sizes = [] bucket_reso = None flip_aug = None + alpha_mask = None random_crop = None for image_key in bucket[image_index : image_index + bucket_batch_size]: @@ -1401,10 +1420,13 @@ def get_item_for_caching(self, bucket, bucket_batch_size, image_index): if flip_aug is None: flip_aug = subset.flip_aug + alpha_mask = subset.alpha_mask random_crop = subset.random_crop bucket_reso = image_info.bucket_reso else: + # TODO そもそも混在してても動くようにしたほうがいい assert flip_aug == subset.flip_aug, "flip_aug must be same in a batch" + assert alpha_mask == subset.alpha_mask, "alpha_mask must be same in a batch" assert random_crop == subset.random_crop, "random_crop must be same in a batch" assert bucket_reso == image_info.bucket_reso, "bucket_reso must be same in a batch" @@ -1441,6 +1463,7 @@ def get_item_for_caching(self, bucket, bucket_batch_size, image_index): example["absolute_paths"] = absolute_paths example["resized_sizes"] = resized_sizes example["flip_aug"] = flip_aug + example["alpha_mask"] = alpha_mask example["random_crop"] = random_crop example["bucket_reso"] = bucket_reso return example @@ -2149,7 +2172,7 @@ def disable_token_padding(self): dataset.disable_token_padding() -def is_disk_cached_latents_is_expected(reso, npz_path: str, flip_aug: bool): +def is_disk_cached_latents_is_expected(reso, npz_path: str, flip_aug: bool, alpha_mask: bool): expected_latents_size = (reso[1] // 8, reso[0] // 8) # bucket_resoはWxHなので注意 if not os.path.exists(npz_path): @@ -2167,6 +2190,12 @@ def is_disk_cached_latents_is_expected(reso, npz_path: str, flip_aug: bool): return False if npz["latents_flipped"].shape[1:3] != expected_latents_size: return False + + if alpha_mask: + if "alpha_mask" not in npz: + return False + if npz["alpha_mask"].shape[0:2] != reso: # HxW + return False except Exception as e: logger.error(f"Error loading file: {npz_path}") raise e @@ -2177,14 +2206,7 @@ def is_disk_cached_latents_is_expected(reso, npz_path: str, flip_aug: bool): # 戻り値は、latents_tensor, (original_size width, original_size height), (crop left, crop top) def load_latents_from_disk( npz_path, -) -> Tuple[ - Optional[torch.Tensor], - Optional[List[int]], - Optional[List[int]], - Optional[torch.Tensor], - Optional[torch.Tensor], - Optional[torch.Tensor], -]: +) -> Tuple[Optional[torch.Tensor], Optional[List[int]], Optional[List[int]], Optional[np.ndarray], Optional[np.ndarray]]: npz = np.load(npz_path) if "latents" not in npz: raise ValueError(f"error: npz is old 
format. please re-generate {npz_path}") @@ -2194,20 +2216,15 @@ def load_latents_from_disk( crop_ltrb = npz["crop_ltrb"].tolist() flipped_latents = npz["latents_flipped"] if "latents_flipped" in npz else None alpha_mask = npz["alpha_mask"] if "alpha_mask" in npz else None - flipped_alpha_mask = npz["flipped_alpha_mask"] if "flipped_alpha_mask" in npz else None - return latents, original_size, crop_ltrb, flipped_latents, alpha_mask, flipped_alpha_mask + return latents, original_size, crop_ltrb, flipped_latents, alpha_mask -def save_latents_to_disk( - npz_path, latents_tensor, original_size, crop_ltrb, flipped_latents_tensor=None, alpha_mask=None, flipped_alpha_mask=None -): +def save_latents_to_disk(npz_path, latents_tensor, original_size, crop_ltrb, flipped_latents_tensor=None, alpha_mask=None): kwargs = {} if flipped_latents_tensor is not None: kwargs["latents_flipped"] = flipped_latents_tensor.float().cpu().numpy() if alpha_mask is not None: - kwargs["alpha_mask"] = alpha_mask.float().cpu().numpy() - if flipped_alpha_mask is not None: - kwargs["flipped_alpha_mask"] = flipped_alpha_mask.float().cpu().numpy() + kwargs["alpha_mask"] = alpha_mask # ndarray np.savez( npz_path, latents=latents_tensor.float().cpu().numpy(), @@ -2398,10 +2415,11 @@ def load_arbitrary_dataset(args, tokenizer) -> MinimalDataset: def load_image(image_path, alpha=False): image = Image.open(image_path) - if not image.mode == "RGB": - if alpha: + if alpha: + if not image.mode == "RGBA": image = image.convert("RGBA") - else: + else: + if not image.mode == "RGB": image = image.convert("RGB") img = np.array(image, np.uint8) return img @@ -2441,7 +2459,7 @@ def trim_and_resize_if_required( def cache_batch_latents( - vae: AutoencoderKL, cache_to_disk: bool, image_infos: List[ImageInfo], flip_aug: bool, random_crop: bool + vae: AutoencoderKL, cache_to_disk: bool, image_infos: List[ImageInfo], flip_aug: bool, use_alpha_mask: bool, random_crop: bool ) -> None: r""" requires image_infos to have: absolute_path, bucket_reso, resized_size, latents_npz @@ -2453,49 +2471,43 @@ def cache_batch_latents( latents_original_size and latents_crop_ltrb are also set """ images = [] - alpha_masks = [] + alpha_masks: List[np.ndarray] = [] for info in image_infos: - image = load_image(info.absolute_path, info.use_alpha_mask) if info.image is None else np.array(info.image, np.uint8) + image = load_image(info.absolute_path, use_alpha_mask) if info.image is None else np.array(info.image, np.uint8) # TODO 画像のメタデータが壊れていて、メタデータから割り当てたbucketと実際の画像サイズが一致しない場合があるのでチェック追加要 image, original_size, crop_ltrb = trim_and_resize_if_required(random_crop, image, info.bucket_reso, info.resized_size) - if info.use_alpha_mask: + + info.latents_original_size = original_size + info.latents_crop_ltrb = crop_ltrb + + if use_alpha_mask: if image.shape[2] == 4: - alpha_mask = image[:, :, 3] # [W,H] - image = image[:, :, :3] + alpha_mask = image[:, :, 3] # [H,W] + alpha_mask = alpha_mask.astype(np.float32) / 255.0 else: - alpha_mask = np.full_like(image[:, :, 0], 255, dtype=np.uint8) # [W,H] - alpha_masks.append(transforms.ToTensor()(alpha_mask)) + alpha_mask = np.ones_like(image[:, :, 0], dtype=np.float32) + else: + alpha_mask = None + alpha_masks.append(alpha_mask) + + image = image[:, :, :3] # remove alpha channel if exists image = IMAGE_TRANSFORMS(image) images.append(image) - info.latents_original_size = original_size - info.latents_crop_ltrb = crop_ltrb - img_tensors = torch.stack(images, dim=0) img_tensors = img_tensors.to(device=vae.device, dtype=vae.dtype) with 
torch.no_grad(): latents = vae.encode(img_tensors).latent_dist.sample().to("cpu") - if info.use_alpha_mask: - alpha_masks = torch.stack(alpha_masks, dim=0).to("cpu") - else: - alpha_masks = [None] * len(image_infos) - flipped_alpha_masks = [None] * len(image_infos) - if flip_aug: img_tensors = torch.flip(img_tensors, dims=[3]) with torch.no_grad(): flipped_latents = vae.encode(img_tensors).latent_dist.sample().to("cpu") - if info.use_alpha_mask: - flipped_alpha_masks = torch.flip(alpha_masks, dims=[3]) else: flipped_latents = [None] * len(latents) - flipped_alpha_masks = [None] * len(image_infos) - for info, latent, flipped_latent, alpha_mask, flipped_alpha_mask in zip( - image_infos, latents, flipped_latents, alpha_masks, flipped_alpha_masks - ): + for info, latent, flipped_latent, alpha_mask in zip(image_infos, latents, flipped_latents, alpha_masks): # check NaN if torch.isnan(latents).any() or (flipped_latent is not None and torch.isnan(flipped_latent).any()): raise RuntimeError(f"NaN detected in latents: {info.absolute_path}") @@ -2508,15 +2520,12 @@ def cache_batch_latents( info.latents_crop_ltrb, flipped_latent, alpha_mask, - flipped_alpha_mask, ) else: info.latents = latent if flip_aug: info.latents_flipped = flipped_latent - info.alpha_mask = alpha_mask - info.alpha_mask_flipped = flipped_alpha_mask if not HIGH_VRAM: clean_memory_on_device(vae.device) diff --git a/sdxl_train.py b/sdxl_train.py index dcd06766b..9e20c60ca 100644 --- a/sdxl_train.py +++ b/sdxl_train.py @@ -711,10 +711,8 @@ def optimizer_hook(parameter: torch.Tensor): loss = train_util.conditional_loss( noise_pred.float(), target.float(), reduction="none", loss_type=args.loss_type, huber_c=huber_c ) - if args.masked_loss: - loss = apply_masked_loss(loss, batch["conditioning_images"][:, 0].unsqueeze(1)) - if "alpha_mask" in batch and batch["alpha_mask"] is not None: - loss = apply_masked_loss(loss, batch["alpha_mask"]) + if args.masked_loss or ("alpha_masks" in batch and batch["alpha_masks"] is not None): + loss = apply_masked_loss(loss, batch) loss = loss.mean([1, 2, 3]) if args.min_snr_gamma: diff --git a/tools/cache_latents.py b/tools/cache_latents.py index 347db27f7..b7c88121e 100644 --- a/tools/cache_latents.py +++ b/tools/cache_latents.py @@ -17,10 +17,13 @@ BlueprintGenerator, ) from library.utils import setup_logging + setup_logging() import logging + logger = logging.getLogger(__name__) + def cache_to_disk(args: argparse.Namespace) -> None: train_util.prepare_dataset_args(args, True) @@ -107,7 +110,7 @@ def cache_to_disk(args: argparse.Namespace) -> None: else: _, vae, _, _ = train_util.load_target_model(args, weight_dtype, accelerator) - if torch.__version__ >= "2.0.0": # PyTorch 2.0.0 以上対応のxformersなら以下が使える + if torch.__version__ >= "2.0.0": # PyTorch 2.0.0 以上対応のxformersなら以下が使える vae.set_use_memory_efficient_attention_xformers(args.xformers) vae.to(accelerator.device, dtype=vae_dtype) vae.requires_grad_(False) @@ -136,6 +139,7 @@ def cache_to_disk(args: argparse.Namespace) -> None: b_size = len(batch["images"]) vae_batch_size = b_size if args.vae_batch_size is None else args.vae_batch_size flip_aug = batch["flip_aug"] + alpha_mask = batch["alpha_mask"] random_crop = batch["random_crop"] bucket_reso = batch["bucket_reso"] @@ -154,14 +158,16 @@ def cache_to_disk(args: argparse.Namespace) -> None: image_info.latents_npz = os.path.splitext(absolute_path)[0] + ".npz" if args.skip_existing: - if train_util.is_disk_cached_latents_is_expected(image_info.bucket_reso, image_info.latents_npz, flip_aug): + if 
train_util.is_disk_cached_latents_is_expected( + image_info.bucket_reso, image_info.latents_npz, flip_aug, alpha_mask + ): logger.warning(f"Skipping {image_info.latents_npz} because it already exists.") continue image_infos.append(image_info) if len(image_infos) > 0: - train_util.cache_batch_latents(vae, True, image_infos, flip_aug, random_crop) + train_util.cache_batch_latents(vae, True, image_infos, flip_aug, alpha_mask, random_crop) accelerator.wait_for_everyone() accelerator.print(f"Finished caching latents for {len(train_dataset_group)} batches.") diff --git a/train_db.py b/train_db.py index c46900006..39d8ea6ed 100644 --- a/train_db.py +++ b/train_db.py @@ -359,10 +359,8 @@ def train(args): target = noise loss = train_util.conditional_loss(noise_pred.float(), target.float(), reduction="none", loss_type=args.loss_type, huber_c=huber_c) - if args.masked_loss: - loss = apply_masked_loss(loss, batch["conditioning_images"][:, 0].unsqueeze(1)) - if "alpha_mask" in batch and batch["alpha_mask"] is not None: - loss = apply_masked_loss(loss, batch["alpha_mask"]) + if args.masked_loss or ("alpha_masks" in batch and batch["alpha_masks"] is not None): + loss = apply_masked_loss(loss, batch) loss = loss.mean([1, 2, 3]) loss_weights = batch["loss_weights"] # 各sampleごとのweight diff --git a/train_network.py b/train_network.py index cd1677ad2..b272a6e1a 100644 --- a/train_network.py +++ b/train_network.py @@ -774,7 +774,9 @@ def load_model_hook(models, input_dir): if args.log_tracker_config is not None: init_kwargs = toml.load(args.log_tracker_config) accelerator.init_trackers( - "network_train" if args.log_tracker_name is None else args.log_tracker_name, config=train_util.get_sanitized_config_or_none(args), init_kwargs=init_kwargs + "network_train" if args.log_tracker_name is None else args.log_tracker_name, + config=train_util.get_sanitized_config_or_none(args), + init_kwargs=init_kwargs, ) loss_recorder = train_util.LossRecorder() @@ -902,10 +904,8 @@ def remove_model(old_ckpt_name): loss = train_util.conditional_loss( noise_pred.float(), target.float(), reduction="none", loss_type=args.loss_type, huber_c=huber_c ) - if args.masked_loss: - loss = apply_masked_loss(loss, batch["conditioning_images"][:, 0].unsqueeze(1)) - if "alpha_mask" in batch and batch["alpha_mask"] is not None: - loss = apply_masked_loss(loss, batch["alpha_mask"]) + if args.masked_loss or ("alpha_masks" in batch and batch["alpha_masks"] is not None): + loss = apply_masked_loss(loss, batch) loss = loss.mean([1, 2, 3]) loss_weights = batch["loss_weights"] # 各sampleごとのweight diff --git a/train_textual_inversion.py b/train_textual_inversion.py index a9c2a1094..ade077c36 100644 --- a/train_textual_inversion.py +++ b/train_textual_inversion.py @@ -589,10 +589,8 @@ def remove_model(old_ckpt_name): target = noise loss = train_util.conditional_loss(noise_pred.float(), target.float(), reduction="none", loss_type=args.loss_type, huber_c=huber_c) - if args.masked_loss: - loss = apply_masked_loss(loss, batch["conditioning_images"][:, 0].unsqueeze(1)) - if "alpha_mask" in batch and batch["alpha_mask"] is not None: - loss = apply_masked_loss(loss, batch["alpha_mask"]) + if args.masked_loss or ("alpha_masks" in batch and batch["alpha_masks"] is not None): + loss = apply_masked_loss(loss, batch) loss = loss.mean([1, 2, 3]) loss_weights = batch["loss_weights"] # 各sampleごとのweight diff --git a/train_textual_inversion_XTI.py b/train_textual_inversion_XTI.py index 959839cbb..efb59137b 100644 --- a/train_textual_inversion_XTI.py +++ 
b/train_textual_inversion_XTI.py @@ -474,10 +474,8 @@ def remove_model(old_ckpt_name): target = noise loss = train_util.conditional_loss(noise_pred.float(), target.float(), reduction="none", loss_type=args.loss_type, huber_c=huber_c) - if args.masked_loss: - loss = apply_masked_loss(loss, batch["conditioning_images"][:, 0].unsqueeze(1)) - if "alpha_mask" in batch and batch["alpha_mask"] is not None: - loss = apply_masked_loss(loss, batch["alpha_mask"]) + if args.masked_loss or ("alpha_masks" in batch and batch["alpha_masks"] is not None): + loss = apply_masked_loss(loss, batch) loss = loss.mean([1, 2, 3]) loss_weights = batch["loss_weights"] # 各sampleごとのweight From 00513b9b7066fc1307fbe26ad13ed39f3bceceb0 Mon Sep 17 00:00:00 2001 From: rockerBOO Date: Thu, 23 May 2024 22:27:12 -0400 Subject: [PATCH 51/97] Add LoRA+ LR Ratio info message to logger --- networks/dylora.py | 3 +++ networks/lora.py | 3 +++ 2 files changed, 6 insertions(+) diff --git a/networks/dylora.py b/networks/dylora.py index d57e3d580..b0925453c 100644 --- a/networks/dylora.py +++ b/networks/dylora.py @@ -368,6 +368,9 @@ def set_loraplus_lr_ratio(self, loraplus_lr_ratio, loraplus_unet_lr_ratio, lorap self.loraplus_unet_lr_ratio = loraplus_unet_lr_ratio self.loraplus_text_encoder_lr_ratio = loraplus_text_encoder_lr_ratio + logger.info(f"LoRA+ UNet LR Ratio: {self.loraplus_unet_lr_ratio or self.loraplus_lr_ratio}") + logger.info(f"LoRA+ Text Encoder LR Ratio: {self.loraplus_text_encoder_lr_ratio or self.loraplus_lr_ratio}") + def set_multiplier(self, multiplier): self.multiplier = multiplier for lora in self.text_encoder_loras + self.unet_loras: diff --git a/networks/lora.py b/networks/lora.py index 9f159f5db..82b8b5b47 100644 --- a/networks/lora.py +++ b/networks/lora.py @@ -1134,6 +1134,9 @@ def set_loraplus_lr_ratio(self, loraplus_lr_ratio, loraplus_unet_lr_ratio, lorap self.loraplus_unet_lr_ratio = loraplus_unet_lr_ratio self.loraplus_text_encoder_lr_ratio = loraplus_text_encoder_lr_ratio + logger.info(f"LoRA+ UNet LR Ratio: {self.loraplus_unet_lr_ratio or self.loraplus_lr_ratio}") + logger.info(f"LoRA+ Text Encoder LR Ratio: {self.loraplus_text_encoder_lr_ratio or self.loraplus_lr_ratio}") + # 二つのText Encoderに別々の学習率を設定できるようにするといいかも def prepare_optimizer_params(self, text_encoder_lr, unet_lr, default_lr): # TODO warn if optimizer is not compatible with LoRA+ (but it will cause error so we don't need to check it here?) 
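The two `logger.info` lines added above report the effective LoRA+ ratios rather than the raw arguments: a module-specific ratio takes precedence, and `None` falls back to the generic `loraplus_lr_ratio` through Python's `or`. Below is a minimal sketch of that resolution with illustrative values; the helper name and the numbers are not from this patch, and in LoRA+ the resolved ratio is what scales the learning rate applied to the LoRA up weights relative to the base rate.

```python
# Sketch of the fallback used by the log messages above: the specific ratio wins, else the generic one.
from typing import Optional


def resolve_loraplus_ratio(generic: Optional[float], specific: Optional[float]) -> Optional[float]:
    # mirrors `specific or generic`; note a ratio of 0.0 would also fall back to the generic value
    return specific or generic


loraplus_lr_ratio = 16.0              # e.g. --network_args "loraplus_lr_ratio=16"
loraplus_unet_lr_ratio = None         # not set -> falls back to 16.0
loraplus_text_encoder_lr_ratio = 4.0  # set explicitly

print(resolve_loraplus_ratio(loraplus_lr_ratio, loraplus_unet_lr_ratio))          # 16.0
print(resolve_loraplus_ratio(loraplus_lr_ratio, loraplus_text_encoder_lr_ratio))  # 4.0
```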
From e8cfd4ba1d4734c4dd37c9b5fdc0633378879d9b Mon Sep 17 00:00:00 2001 From: Kohya S Date: Sun, 26 May 2024 22:01:37 +0900 Subject: [PATCH 52/97] fix to work cond mask and alpha mask --- library/config_util.py | 3 ++- library/custom_train_functions.py | 4 +++- library/train_util.py | 12 ++++++++++++ 3 files changed, 17 insertions(+), 2 deletions(-) diff --git a/library/config_util.py b/library/config_util.py index 964270dbb..10b2457f3 100644 --- a/library/config_util.py +++ b/library/config_util.py @@ -78,7 +78,6 @@ class BaseSubsetParams: caption_tag_dropout_rate: float = 0.0 token_warmup_min: int = 1 token_warmup_step: float = 0 - alpha_mask: bool = False @dataclass @@ -87,11 +86,13 @@ class DreamBoothSubsetParams(BaseSubsetParams): class_tokens: Optional[str] = None caption_extension: str = ".caption" cache_info: bool = False + alpha_mask: bool = False @dataclass class FineTuningSubsetParams(BaseSubsetParams): metadata_file: Optional[str] = None + alpha_mask: bool = False @dataclass diff --git a/library/custom_train_functions.py b/library/custom_train_functions.py index af5813a1d..2a513dc5b 100644 --- a/library/custom_train_functions.py +++ b/library/custom_train_functions.py @@ -484,9 +484,11 @@ def apply_masked_loss(loss, batch): # conditioning image is -1 to 1. we need to convert it to 0 to 1 mask_image = batch["conditioning_images"].to(dtype=loss.dtype)[:, 0].unsqueeze(1) # use R channel mask_image = mask_image / 2 + 0.5 + # print(f"conditioning_image: {mask_image.shape}") elif "alpha_masks" in batch and batch["alpha_masks"] is not None: # alpha mask is 0 to 1 - mask_image = batch["alpha_masks"].to(dtype=loss.dtype) + mask_image = batch["alpha_masks"].to(dtype=loss.dtype).unsqueeze(1) # add channel dimension + # print(f"mask_image: {mask_image.shape}, {mask_image.mean()}") else: return loss diff --git a/library/train_util.py b/library/train_util.py index e7a50f04d..1f9f3c5df 100644 --- a/library/train_util.py +++ b/library/train_util.py @@ -561,6 +561,7 @@ def __init__( super().__init__( image_dir, + False, # alpha_mask num_repeats, shuffle_caption, caption_separator, @@ -1947,6 +1948,7 @@ def __init__( None, subset.caption_extension, subset.cache_info, + False, subset.num_repeats, subset.shuffle_caption, subset.caption_separator, @@ -2196,6 +2198,9 @@ def is_disk_cached_latents_is_expected(reso, npz_path: str, flip_aug: bool, alph return False if npz["alpha_mask"].shape[0:2] != reso: # HxW return False + else: + if "alpha_mask" in npz: + return False except Exception as e: logger.error(f"Error loading file: {npz_path}") raise e @@ -2296,6 +2301,13 @@ def debug_dataset(train_dataset, show_input_ids=False): if os.name == "nt": cv2.imshow("cond_img", cond_img) + if "alpha_masks" in example and example["alpha_masks"] is not None: + alpha_mask = example["alpha_masks"][j] + logger.info(f"alpha mask size: {alpha_mask.size()}") + alpha_mask = (alpha_mask[0].numpy() * 255.0).astype(np.uint8) + if os.name == "nt": + cv2.imshow("alpha_mask", alpha_mask) + if os.name == "nt": # only windows cv2.imshow("img", im) k = cv2.waitKey() From d50c1b3c5cfd590e43e832272a77bf8c84d371dd Mon Sep 17 00:00:00 2001 From: Dave Lage Date: Mon, 27 May 2024 01:11:01 -0400 Subject: [PATCH 53/97] Update issue link --- README.md | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/README.md b/README.md index 838e4022c..23e049354 100644 --- a/README.md +++ b/README.md @@ -237,7 +237,7 @@ https://github.com/kohya-ss/sd-scripts/pull/1290) Thanks to frodo821! 
https://github.com/kohya-ss/sd-scripts/pull/1291) issue [#1290]( https://github.com/kohya-ss/sd-scripts/pull/1290) frodo821 氏に感謝します。 -- データセット設定の .toml ファイルで、`caption_separator` が subset に指定できない不具合が修正されました。 PR [#1312](https://github.com/kohya-ss/sd-scripts/pull/1312) および [#1313](https://github.com/kohya-ss/sd-scripts/pull/1312) rockerBOO 氏に感謝します。 +- データセット設定の .toml ファイルで、`caption_separator` が subset に指定できない不具合が修正されました。 PR [#1312](https://github.com/kohya-ss/sd-scripts/pull/1312) および [#1313](https://github.com/kohya-ss/sd-scripts/pull/1313) rockerBOO 氏に感謝します。 - ControlNet-LLLite 学習時の潜在バグが修正されました。 PR [#1322](https://github.com/kohya-ss/sd-scripts/pull/1322) aria1th 氏に感謝します。 From a4c3155148e667f5235c2e3df52bad7fd8f95dc4 Mon Sep 17 00:00:00 2001 From: Kohya S Date: Mon, 27 May 2024 20:59:40 +0900 Subject: [PATCH 54/97] add doc for mask loss --- docs/masked_loss_README-ja.md | 40 +++++++++++++++++++++++++++++++++++ 1 file changed, 40 insertions(+) create mode 100644 docs/masked_loss_README-ja.md diff --git a/docs/masked_loss_README-ja.md b/docs/masked_loss_README-ja.md new file mode 100644 index 000000000..860532247 --- /dev/null +++ b/docs/masked_loss_README-ja.md @@ -0,0 +1,40 @@ +## マスクロスについて + +マスクロスは、入力画像のマスクで指定された部分だけ損失計算することで、画像の一部分だけを学習することができる機能です。 +たとえばキャラクタを学習したい場合、キャラクタ部分だけをマスクして学習することで、背景を無視して学習することができます。 + +マスクロスのマスクには、二種類の指定方法があります。 + +- マスク画像を用いる方法 +- 透明度(アルファチャネル)を使用する方法 + +なお、サンプルは [ずんずんPJイラスト/3Dデータ](https://zunko.jp/con_illust.html) の「AI画像モデル用学習データ」を使用しています。 + +### マスク画像を用いる方法 + +学習画像それぞれに対応するマスク画像を用意する方法です。学習画像と同じファイル名のマスク画像を用意し、それを学習画像と別のディレクトリに保存します。 + +マスク画像は、学習画像と同じサイズで、学習する部分を白、無視する部分を黒で描画します。グレースケールにも対応しています(127 ならロス重みが 0.5 になります)。なお、正確にはマスク画像の R チャネルが用いられます。 + +DreamBooth 方式の dataset で、`conditioning_data_dir` で指定したディレクトリにマスク画像を保存するしてください。ControlNet のデータセットと同じですので、詳細は [ControlNet-LLLite](train_lllite_README-ja.md#データセットの準備) を参照してください。 + +### 透明度(アルファチャネル)を使用する方法 + +学習画像の透明度(アルファチャネル)がマスクとして使用されます。透明度が 0 の部分は無視され、255 の部分は学習されます。半透明の場合は、その透明度に応じてロス重みが変化します(127 ならおおむね 0.5)。 + +学習時のスクリプトのオプション `--alpha_mask`、または dataset の設定ファイルの subset で、`alpha_mask` を指定してください。たとえば、以下のようになります。 + +```toml +[[datasets.subsets]] +image_dir = "/path/to/image/dir" +caption_extension = ".txt" +num_repeats = 8 +alpha_mask = true +``` + +## 学習時の注意事項 + +- 現時点では DreamBooth 方式の dataset のみ対応しています。 +- マスクは latents のサイズ、つまり 1/8 に縮小されてから適用されます。そのため、細かい部分(たとえばアホ毛やイヤリングなど)はうまく学習できない可能性があります。マスクをわずかに拡張するなどの工夫が必要かもしれません。 +- マスクロスを用いる場合、学習対象外の部分をキャプションに含める必要はないかもしれません。(要検証) +- `alpha_mask` の場合、マスクの有無を切り替えると latents キャッシュが自動的に再生成されます。 From 71ad3c0f45ba64bd5dc069addc8ef0fa94bf4e19 Mon Sep 17 00:00:00 2001 From: Kohya S <52813779+kohya-ss@users.noreply.github.com> Date: Mon, 27 May 2024 21:07:57 +0900 Subject: [PATCH 55/97] Update masked_loss_README-ja.md add sample images --- docs/masked_loss_README-ja.md | 11 ++++++++++- 1 file changed, 10 insertions(+), 1 deletion(-) diff --git a/docs/masked_loss_README-ja.md b/docs/masked_loss_README-ja.md index 860532247..5377a5aff 100644 --- a/docs/masked_loss_README-ja.md +++ b/docs/masked_loss_README-ja.md @@ -14,6 +14,11 @@ 学習画像それぞれに対応するマスク画像を用意する方法です。学習画像と同じファイル名のマスク画像を用意し、それを学習画像と別のディレクトリに保存します。 +- 学習画像 + ![image](https://github.com/kohya-ss/sd-scripts/assets/52813779/607c5116-5f62-47de-8b66-9c4a597f0441) +- マスク画像 + ![image](https://github.com/kohya-ss/sd-scripts/assets/52813779/53e9b0f8-a4bf-49ed-882d-4026f84e8450) + マスク画像は、学習画像と同じサイズで、学習する部分を白、無視する部分を黒で描画します。グレースケールにも対応しています(127 ならロス重みが 0.5 になります)。なお、正確にはマスク画像の R チャネルが用いられます。 DreamBooth 方式の dataset 
で、`conditioning_data_dir` で指定したディレクトリにマスク画像を保存するしてください。ControlNet のデータセットと同じですので、詳細は [ControlNet-LLLite](train_lllite_README-ja.md#データセットの準備) を参照してください。 @@ -22,7 +27,11 @@ DreamBooth 方式の dataset で、`conditioning_data_dir` で指定したディ 学習画像の透明度(アルファチャネル)がマスクとして使用されます。透明度が 0 の部分は無視され、255 の部分は学習されます。半透明の場合は、その透明度に応じてロス重みが変化します(127 ならおおむね 0.5)。 -学習時のスクリプトのオプション `--alpha_mask`、または dataset の設定ファイルの subset で、`alpha_mask` を指定してください。たとえば、以下のようになります。 +![image](https://github.com/kohya-ss/sd-scripts/assets/52813779/0baa129b-446a-4aac-b98c-7208efb0e75e) + +※それぞれの画像は透過PNG + +学習時のスクリプトのオプションに `--alpha_mask` を指定するか、dataset の設定ファイルの subset で、`alpha_mask` を指定してください。たとえば、以下のようになります。 ```toml [[datasets.subsets]] From fc85496f7e99b2bbbbd0246e0b0521780c55d859 Mon Sep 17 00:00:00 2001 From: Kohya S Date: Mon, 27 May 2024 21:25:06 +0900 Subject: [PATCH 56/97] update docs for masked loss --- README.md | 8 +++++ docs/masked_loss_README-ja.md | 10 ++++++- docs/masked_loss_README.md | 56 +++++++++++++++++++++++++++++++++++ 3 files changed, 73 insertions(+), 1 deletion(-) create mode 100644 docs/masked_loss_README.md diff --git a/README.md b/README.md index 23e049354..52c963392 100644 --- a/README.md +++ b/README.md @@ -161,6 +161,10 @@ The majority of scripts is licensed under ASL 2.0 (including codes from Diffuser - Example: `--network_args "loraplus_unet_lr_ratio=16" "loraplus_text_encoder_lr_ratio=4"` or `--network_args "loraplus_lr_ratio=16" "loraplus_text_encoder_lr_ratio=4"` etc. - `network_module` `networks.lora` and `networks.dylora` are available. +- The feature to use the transparency (alpha channel) of the image as a mask in the loss calculation has been added. PR [#1223](https://github.com/kohya-ss/sd-scripts/pull/1223) Thanks to u-haru! + - The transparent part is ignored during training. Specify the `--alpha_mask` option in the training script or specify `alpha_mask = true` in the dataset configuration file. + - See [About masked loss](./docs/masked_loss_README.md) for details. + - LoRA training in SDXL now supports block-wise learning rates and block-wise dim (rank). PR [#1331](https://github.com/kohya-ss/sd-scripts/pull/1331) - Specify the learning rate and dim (rank) for each block. - See [Block-wise learning rates in LoRA](./docs/train_network_README-ja.md#階層別学習率) for details (Japanese only). @@ -214,6 +218,10 @@ https://github.com/kohya-ss/sd-scripts/pull/1290) Thanks to frodo821! 
- 例:`--network_args "loraplus_unet_lr_ratio=16" "loraplus_text_encoder_lr_ratio=4"` または `--network_args "loraplus_lr_ratio=16" "loraplus_text_encoder_lr_ratio=4"` など - `network_module` の `networks.lora` および `networks.dylora` で使用可能です。 +- 画像の透明度(アルファチャネル)をロス計算時のマスクとして使用する機能が追加されました。PR [#1223](https://github.com/kohya-ss/sd-scripts/pull/1223) u-haru 氏に感謝します。 + - 透明部分が学習時に無視されるようになります。学習スクリプトに `--alpha_mask` オプションを指定するか、データセット設定ファイルに `alpha_mask = true` を指定してください。 + - 詳細は [マスクロスについて](./docs/masked_loss_README-ja.md) をご覧ください。 + - SDXL の LoRA で階層別学習率、階層別 dim (rank) をサポートしました。PR [#1331](https://github.com/kohya-ss/sd-scripts/pull/1331) - ブロックごとに学習率および dim (rank) を指定することができます。 - 詳細は [LoRA の階層別学習率](./docs/train_network_README-ja.md#階層別学習率) をご覧ください。 diff --git a/docs/masked_loss_README-ja.md b/docs/masked_loss_README-ja.md index 5377a5aff..58f042c3b 100644 --- a/docs/masked_loss_README-ja.md +++ b/docs/masked_loss_README-ja.md @@ -19,9 +19,17 @@ - マスク画像 ![image](https://github.com/kohya-ss/sd-scripts/assets/52813779/53e9b0f8-a4bf-49ed-882d-4026f84e8450) +```.toml +[[datasets.subsets]] +image_dir = "/path/to/a_zundamon" +caption_extension = ".txt" +conditioning_data_dir = "/path/to/a_zundamon_mask" +num_repeats = 8 +``` + マスク画像は、学習画像と同じサイズで、学習する部分を白、無視する部分を黒で描画します。グレースケールにも対応しています(127 ならロス重みが 0.5 になります)。なお、正確にはマスク画像の R チャネルが用いられます。 -DreamBooth 方式の dataset で、`conditioning_data_dir` で指定したディレクトリにマスク画像を保存するしてください。ControlNet のデータセットと同じですので、詳細は [ControlNet-LLLite](train_lllite_README-ja.md#データセットの準備) を参照してください。 +DreamBooth 方式の dataset で、`conditioning_data_dir` で指定したディレクトリにマスク画像を保存してください。ControlNet のデータセットと同じですので、詳細は [ControlNet-LLLite](train_lllite_README-ja.md#データセットの準備) を参照してください。 ### 透明度(アルファチャネル)を使用する方法 diff --git a/docs/masked_loss_README.md b/docs/masked_loss_README.md new file mode 100644 index 000000000..3ac5ad211 --- /dev/null +++ b/docs/masked_loss_README.md @@ -0,0 +1,56 @@ +## Masked Loss + +Masked loss is a feature that allows you to train only part of an image by calculating the loss only for the part specified by the mask of the input image. For example, if you want to train a character, you can train only the character part by masking it, ignoring the background. + +There are two ways to specify the mask for masked loss. + +- Using a mask image +- Using transparency (alpha channel) of the image + +The sample uses the "AI image model training data" from [ZunZunPJ Illustration/3D Data](https://zunko.jp/con_illust.html). + +### Using a mask image + +This is a method of preparing a mask image corresponding to each training image. Prepare a mask image with the same file name as the training image and save it in a different directory from the training image. + +- Training image + ![image](https://github.com/kohya-ss/sd-scripts/assets/52813779/607c5116-5f62-47de-8b66-9c4a597f0441) +- Mask image + ![image](https://github.com/kohya-ss/sd-scripts/assets/52813779/53e9b0f8-a4bf-49ed-882d-4026f84e8450) + +```.toml +[[datasets.subsets]] +image_dir = "/path/to/a_zundamon" +caption_extension = ".txt" +conditioning_data_dir = "/path/to/a_zundamon_mask" +num_repeats = 8 +``` + +The mask image is the same size as the training image, with the part to be trained drawn in white and the part to be ignored in black. It also supports grayscale (127 gives a loss weight of 0.5). The R channel of the mask image is used currently. + +Use the dataset in the DreamBooth method, and save the mask image in the directory specified by `conditioning_data_dir`. 
It is the same as the ControlNet dataset, so please refer to [ControlNet-LLLite](train_lllite_README.md#Preparing-the-dataset) for details. + +### Using transparency (alpha channel) of the image + +The transparency (alpha channel) of the training image is used as a mask. The part with transparency 0 is ignored, the part with transparency 255 is trained. For semi-transparent parts, the loss weight changes according to the transparency (127 gives a weight of about 0.5). + +![image](https://github.com/kohya-ss/sd-scripts/assets/52813779/0baa129b-446a-4aac-b98c-7208efb0e75e) + +※Each image is a transparent PNG + +Specify `--alpha_mask` in the training script options or specify `alpha_mask` in the subset of the dataset configuration file. For example, it will look like this. + +```toml +[[datasets.subsets]] +image_dir = "/path/to/image/dir" +caption_extension = ".txt" +num_repeats = 8 +alpha_mask = true +``` + +## Notes on training + +- At the moment, only the dataset in the DreamBooth method is supported. +- The mask is applied after the size is reduced to 1/8, which is the size of the latents. Therefore, fine details (such as ahoge or earrings) may not be learned well. Some dilations of the mask may be necessary. +- If using masked loss, it may not be necessary to include parts that are not to be trained in the caption. (To be verified) +- In the case of `alpha_mask`, the latents cache is automatically regenerated when the enable/disable state of the mask is switched. From b2363f1021955c049c98e65676efca130690c40f Mon Sep 17 00:00:00 2001 From: Kohaku-Blueleaf <59680068+KohakuBlueleaf@users.noreply.github.com> Date: Fri, 31 May 2024 12:20:20 +0800 Subject: [PATCH 57/97] Final implementation --- library/train_util.py | 11 ++++- train_network.py | 104 +++++++++++++++++++++++++++++++++++++++--- 2 files changed, 106 insertions(+), 9 deletions(-) diff --git a/library/train_util.py b/library/train_util.py index 1f9f3c5df..beb33bf82 100644 --- a/library/train_util.py +++ b/library/train_util.py @@ -657,8 +657,15 @@ def set_caching_mode(self, mode): def set_current_epoch(self, epoch): if not self.current_epoch == epoch: # epochが切り替わったらバケツをシャッフルする - self.shuffle_buckets() - self.current_epoch = epoch + if epoch > self.current_epoch: + logger.info("epoch is incremented. current_epoch: {}, epoch: {}".format(self.current_epoch, epoch)) + num_epochs = epoch - self.current_epoch + for _ in range(num_epochs): + self.current_epoch += 1 + self.shuffle_buckets() + else: + logger.warning("epoch is not incremented. 
current_epoch: {}, epoch: {}".format(self.current_epoch, epoch)) + self.current_epoch = epoch def set_current_step(self, step): self.current_step = step diff --git a/train_network.py b/train_network.py index b272a6e1a..76e6cd8a1 100644 --- a/train_network.py +++ b/train_network.py @@ -493,17 +493,24 @@ def train(self, args): # before resuming make hook for saving/loading to save/load the network weights only def save_model_hook(models, weights, output_dir): # pop weights of other models than network to save only network weights - # only main process or deepspeed https://github.com/huggingface/diffusers/issues/2606 - if accelerator.is_main_process or args.deepspeed: + if accelerator.is_main_process: remove_indices = [] for i, model in enumerate(models): if not isinstance(model, type(accelerator.unwrap_model(network))): remove_indices.append(i) for i in reversed(remove_indices): - if len(weights) > i: - weights.pop(i) + weights.pop(i) # print(f"save model hook: {len(weights)} weights will be saved") + # save current ecpoch and step + train_state_file = os.path.join(output_dir, "train_state.json") + # +1 is needed because the state is saved before current_step is set from global_step + logger.info(f"save train state to {train_state_file} at epoch {current_epoch.value} step {current_step.value+1}") + with open(train_state_file, "w", encoding="utf-8") as f: + json.dump({"current_epoch": current_epoch.value, "current_step": current_step.value + 1}, f) + + steps_from_state = None + def load_model_hook(models, input_dir): # remove models except network remove_indices = [] @@ -514,6 +521,15 @@ def load_model_hook(models, input_dir): models.pop(i) # print(f"load model hook: {len(models)} models will be loaded") + # load current epoch and step to + nonlocal steps_from_state + train_state_file = os.path.join(input_dir, "train_state.json") + if os.path.exists(train_state_file): + with open(train_state_file, "r", encoding="utf-8") as f: + data = json.load(f) + steps_from_state = data["current_step"] + logger.info(f"load train state from {train_state_file}: {data}") + accelerator.register_save_state_pre_hook(save_model_hook) accelerator.register_load_state_pre_hook(load_model_hook) @@ -757,7 +773,53 @@ def load_model_hook(models, input_dir): if key in metadata: minimum_metadata[key] = metadata[key] - progress_bar = tqdm(range(args.max_train_steps), smoothing=0, disable=not accelerator.is_local_main_process, desc="steps") + # calculate steps to skip when resuming or starting from a specific step + initial_step = 0 + if args.initial_epoch is not None or args.initial_step is not None: + # if initial_epoch or initial_step is specified, steps_from_state is ignored even when resuming + if steps_from_state is not None: + logger.warning( + "steps from the state is ignored because initial_step is specified / initial_stepが指定されているため、stateからのステップ数は無視されます" + ) + if args.initial_step is not None: + initial_step = args.initial_step + else: + # num steps per epoch is calculated by num_processes and gradient_accumulation_steps + initial_step = (args.initial_epoch - 1) * math.ceil( + len(train_dataloader) / accelerator.num_processes / args.gradient_accumulation_steps + ) + else: + # if initial_epoch and initial_step are not specified, steps_from_state is used when resuming + if steps_from_state is not None: + initial_step = steps_from_state + steps_from_state = None + + if initial_step > 0: + assert ( + args.max_train_steps > initial_step + ), f"max_train_steps should be greater than initial step / 
max_train_stepsは初期ステップより大きい必要があります: {args.max_train_steps} vs {initial_step}" + + progress_bar = tqdm( + range(args.max_train_steps - initial_step), smoothing=0, disable=not accelerator.is_local_main_process, desc="steps" + ) + + epoch_to_start = 0 + if initial_step > 0: + if args.skip_until_initial_step: + # if skip_until_initial_step is specified, load data and discard it to ensure the same data is used + if not args.resume: + logger.info( + f"initial_step is specified but not resuming. lr scheduler will be started from the beginning / initial_stepが指定されていますがresumeしていないため、lr schedulerは最初から始まります" + ) + logger.info(f"skipping {initial_step} steps / {initial_step}ステップをスキップします") + initial_step *= args.gradient_accumulation_steps + else: + # if not, only epoch no is skipped for informative purpose + epoch_to_start = initial_step // math.ceil( + len(train_dataloader) / args.gradient_accumulation_steps + ) + initial_step = 0 # do not skip + global_step = 0 noise_scheduler = DDPMScheduler( @@ -816,7 +878,11 @@ def remove_model(old_ckpt_name): self.sample_images(accelerator, args, 0, global_step, accelerator.device, vae, tokenizer, text_encoder, unet) # training loop - for epoch in range(num_train_epochs): + for skip_epoch in range(epoch_to_start): # skip epochs + logger.info(f"skipping epoch {skip_epoch+1} because initial_step (multiplied) is {initial_step}") + initial_step -= len(train_dataloader) + + for epoch in range(epoch_to_start, num_train_epochs): accelerator.print(f"\nepoch {epoch+1}/{num_train_epochs}") current_epoch.value = epoch + 1 @@ -824,7 +890,12 @@ def remove_model(old_ckpt_name): accelerator.unwrap_model(network).on_epoch_start(text_encoder, unet) - for step, batch in enumerate(train_dataloader): + skipped_dataloader = None + if initial_step > 0: + skipped_dataloader = accelerator.skip_first_batches(train_dataloader, initial_step-1) + initial_step = 1 + + for step, batch in enumerate(skipped_dataloader or train_dataloader): current_step.value = global_step with accelerator.accumulate(training_model): on_step_start(text_encoder, unet) @@ -1126,6 +1197,25 @@ def setup_parser() -> argparse.ArgumentParser: action="store_true", help="do not use fp16/bf16 VAE in mixed precision (use float VAE) / mixed precisionでも fp16/bf16 VAEを使わずfloat VAEを使う", ) + parser.add_argument( + "--skip_until_initial_step", + action="store_true", + help="skip training until initial_step is reached / initial_stepに到達するまで学習をスキップする", + ) + parser.add_argument( + "--initial_epoch", + type=int, + default=None, + help="initial epoch number, 1 means first epoch (same as not specifying). NOTE: initial_epoch/step doesn't affect to lr scheduler. Which means lr scheduler will start from 0 without `--resume`." + + " / 初期エポック数、1で最初のエポック(未指定時と同じ)。注意:initial_epoch/stepはlr schedulerに影響しないため、`--resume`しない場合はlr schedulerは0から始まる", + ) + parser.add_argument( + "--initial_step", + type=int, + default=None, + help="initial step number including all epochs, 0 means first step (same as not specifying). overwrites initial_epoch." 
+ + " / 初期ステップ数、全エポックを含むステップ数、0で最初のステップ(未指定時と同じ)。initial_epochを上書きする", + ) # parser.add_argument("--loraplus_lr_ratio", default=None, type=float, help="LoRA+ learning rate ratio") # parser.add_argument("--loraplus_unet_lr_ratio", default=None, type=float, help="LoRA+ UNet learning rate ratio") # parser.add_argument("--loraplus_text_encoder_lr_ratio", default=None, type=float, help="LoRA+ text encoder learning rate ratio") From 3eb27ced52e8bf522c7e490c3dacba1f8597f5b1 Mon Sep 17 00:00:00 2001 From: Kohaku-Blueleaf <59680068+KohakuBlueleaf@users.noreply.github.com> Date: Fri, 31 May 2024 12:24:15 +0800 Subject: [PATCH 58/97] Skip the final 1 step --- train_network.py | 4 ++++ 1 file changed, 4 insertions(+) diff --git a/train_network.py b/train_network.py index 76e6cd8a1..d1f02d530 100644 --- a/train_network.py +++ b/train_network.py @@ -897,6 +897,10 @@ def remove_model(old_ckpt_name): for step, batch in enumerate(skipped_dataloader or train_dataloader): current_step.value = global_step + if initial_step > 0: + initial_step -= 1 + continue + with accelerator.accumulate(training_model): on_step_start(text_encoder, unet) From e5bab69e3a8f3dc4afb1badba65b6c50ca2f36d8 Mon Sep 17 00:00:00 2001 From: Kohya S Date: Sun, 2 Jun 2024 21:11:40 +0900 Subject: [PATCH 59/97] fix alpha mask without disk cache closes #1351, ref #1339 --- library/train_util.py | 10 ++++++---- 1 file changed, 6 insertions(+), 4 deletions(-) diff --git a/library/train_util.py b/library/train_util.py index 1f9f3c5df..566f59279 100644 --- a/library/train_util.py +++ b/library/train_util.py @@ -1265,7 +1265,8 @@ def __getitem__(self, index): if subset.alpha_mask: if img.shape[2] == 4: alpha_mask = img[:, :, 3] # [H,W] - alpha_mask = transforms.ToTensor()(alpha_mask) # 0-255 -> 0-1 + alpha_mask = alpha_mask.astype(np.float32) / 255.0 # 0.0~1.0 + alpha_mask = torch.FloatTensor(alpha_mask) else: alpha_mask = torch.ones((img.shape[0], img.shape[1]), dtype=torch.float32) else: @@ -2211,7 +2212,7 @@ def is_disk_cached_latents_is_expected(reso, npz_path: str, flip_aug: bool, alph # 戻り値は、latents_tensor, (original_size width, original_size height), (crop left, crop top) def load_latents_from_disk( npz_path, -) -> Tuple[Optional[torch.Tensor], Optional[List[int]], Optional[List[int]], Optional[np.ndarray], Optional[np.ndarray]]: +) -> Tuple[Optional[np.ndarray], Optional[List[int]], Optional[List[int]], Optional[np.ndarray], Optional[np.ndarray]]: npz = np.load(npz_path) if "latents" not in npz: raise ValueError(f"error: npz is old format. 
please re-generate {npz_path}") @@ -2229,7 +2230,7 @@ def save_latents_to_disk(npz_path, latents_tensor, original_size, crop_ltrb, fli if flipped_latents_tensor is not None: kwargs["latents_flipped"] = flipped_latents_tensor.float().cpu().numpy() if alpha_mask is not None: - kwargs["alpha_mask"] = alpha_mask # ndarray + kwargs["alpha_mask"] = alpha_mask.float().cpu().numpy() np.savez( npz_path, latents=latents_tensor.float().cpu().numpy(), @@ -2496,8 +2497,9 @@ def cache_batch_latents( if image.shape[2] == 4: alpha_mask = image[:, :, 3] # [H,W] alpha_mask = alpha_mask.astype(np.float32) / 255.0 + alpha_mask = torch.FloatTensor(alpha_mask) # [H,W] else: - alpha_mask = np.ones_like(image[:, :, 0], dtype=np.float32) + alpha_mask = torch.ones_like(image[:, :, 0], dtype=torch.float32) # [H,W] else: alpha_mask = None alpha_masks.append(alpha_mask) From 4dbcef429b744d0cc101494802448b8c15f4f674 Mon Sep 17 00:00:00 2001 From: Kohya S Date: Tue, 4 Jun 2024 21:26:55 +0900 Subject: [PATCH 60/97] update for corner cases --- library/train_util.py | 3 +++ train_network.py | 23 ++++++++++++++--------- 2 files changed, 17 insertions(+), 9 deletions(-) diff --git a/library/train_util.py b/library/train_util.py index 102f9f03b..4736ff4ff 100644 --- a/library/train_util.py +++ b/library/train_util.py @@ -663,6 +663,7 @@ def set_current_epoch(self, epoch): for _ in range(num_epochs): self.current_epoch += 1 self.shuffle_buckets() + # self.current_epoch seem to be set to 0 again in the next epoch. it may be caused by skipped_dataloader? else: logger.warning("epoch is not incremented. current_epoch: {}, epoch: {}".format(self.current_epoch, epoch)) self.current_epoch = epoch @@ -5560,6 +5561,8 @@ def add(self, *, epoch: int, step: int, loss: float) -> None: if epoch == 0: self.loss_list.append(loss) else: + while len(self.loss_list) <= step: + self.loss_list.append(0.0) self.loss_total -= self.loss_list[step] self.loss_list[step] = loss self.loss_total += loss diff --git a/train_network.py b/train_network.py index d1f02d530..7ba073855 100644 --- a/train_network.py +++ b/train_network.py @@ -493,13 +493,15 @@ def train(self, args): # before resuming make hook for saving/loading to save/load the network weights only def save_model_hook(models, weights, output_dir): # pop weights of other models than network to save only network weights - if accelerator.is_main_process: + # only main process or deepspeed https://github.com/huggingface/diffusers/issues/2606 + if accelerator.is_main_process or args.deepspeed: remove_indices = [] for i, model in enumerate(models): if not isinstance(model, type(accelerator.unwrap_model(network))): remove_indices.append(i) for i in reversed(remove_indices): - weights.pop(i) + if len(weights) > i: + weights.pop(i) # print(f"save model hook: {len(weights)} weights will be saved") # save current ecpoch and step @@ -813,11 +815,12 @@ def load_model_hook(models, input_dir): ) logger.info(f"skipping {initial_step} steps / {initial_step}ステップをスキップします") initial_step *= args.gradient_accumulation_steps + + # set epoch to start to make initial_step less than len(train_dataloader) + epoch_to_start = initial_step // math.ceil(len(train_dataloader) / args.gradient_accumulation_steps) else: # if not, only epoch no is skipped for informative purpose - epoch_to_start = initial_step // math.ceil( - len(train_dataloader) / args.gradient_accumulation_steps - ) + epoch_to_start = initial_step // math.ceil(len(train_dataloader) / args.gradient_accumulation_steps) initial_step = 0 # do not skip global_step = 
0 @@ -878,9 +881,11 @@ def remove_model(old_ckpt_name): self.sample_images(accelerator, args, 0, global_step, accelerator.device, vae, tokenizer, text_encoder, unet) # training loop - for skip_epoch in range(epoch_to_start): # skip epochs - logger.info(f"skipping epoch {skip_epoch+1} because initial_step (multiplied) is {initial_step}") - initial_step -= len(train_dataloader) + if initial_step > 0: # only if skip_until_initial_step is specified + for skip_epoch in range(epoch_to_start): # skip epochs + logger.info(f"skipping epoch {skip_epoch+1} because initial_step (multiplied) is {initial_step}") + initial_step -= len(train_dataloader) + global_step = initial_step for epoch in range(epoch_to_start, num_train_epochs): accelerator.print(f"\nepoch {epoch+1}/{num_train_epochs}") @@ -892,7 +897,7 @@ def remove_model(old_ckpt_name): skipped_dataloader = None if initial_step > 0: - skipped_dataloader = accelerator.skip_first_batches(train_dataloader, initial_step-1) + skipped_dataloader = accelerator.skip_first_batches(train_dataloader, initial_step - 1) initial_step = 1 for step, batch in enumerate(skipped_dataloader or train_dataloader): From 4ecbac131aba3d121f9708b3ac2a1f4726b17dc0 Mon Sep 17 00:00:00 2001 From: Yuta Hayashibe Date: Wed, 5 Jun 2024 16:31:44 +0900 Subject: [PATCH 61/97] Bump crate-ci/typos from 1.19.0 to 1.21.0, fix typos, and updated _typos.toml (Close #1307) --- .github/workflows/typos.yml | 2 +- _typos.toml | 2 ++ library/ipex/attention.py | 2 +- 3 files changed, 4 insertions(+), 2 deletions(-) diff --git a/.github/workflows/typos.yml b/.github/workflows/typos.yml index e8b06483f..c81ff3210 100644 --- a/.github/workflows/typos.yml +++ b/.github/workflows/typos.yml @@ -18,4 +18,4 @@ jobs: - uses: actions/checkout@v4 - name: typos-action - uses: crate-ci/typos@v1.19.0 + uses: crate-ci/typos@v1.21.0 diff --git a/_typos.toml b/_typos.toml index ae9e06b18..bbf7728f4 100644 --- a/_typos.toml +++ b/_typos.toml @@ -2,6 +2,7 @@ # Instruction: https://github.com/marketplace/actions/typos-action#getting-started [default.extend-identifiers] +ddPn08="ddPn08" [default.extend-words] NIN="NIN" @@ -27,6 +28,7 @@ rik="rik" koo="koo" yos="yos" wn="wn" +hime="hime" [files] diff --git a/library/ipex/attention.py b/library/ipex/attention.py index d989ad53d..2bc62f65c 100644 --- a/library/ipex/attention.py +++ b/library/ipex/attention.py @@ -5,7 +5,7 @@ # pylint: disable=protected-access, missing-function-docstring, line-too-long -# ARC GPUs can't allocate more than 4GB to a single block so we slice the attetion layers +# ARC GPUs can't allocate more than 4GB to a single block so we slice the attention layers sdpa_slice_trigger_rate = float(os.environ.get('IPEX_SDPA_SLICE_TRIGGER_RATE', 4)) attention_slice_rate = float(os.environ.get('IPEX_ATTENTION_SLICE_RATE', 4)) From 58fb64819ab117e2b7bca6e87bae28901b616860 Mon Sep 17 00:00:00 2001 From: Kohya S Date: Sun, 9 Jun 2024 19:26:09 +0900 Subject: [PATCH 62/97] set static graph flag when DDP ref #1363 --- sdxl_train_control_net_lllite.py | 3 +++ 1 file changed, 3 insertions(+) diff --git a/sdxl_train_control_net_lllite.py b/sdxl_train_control_net_lllite.py index 301310901..5ff060a9f 100644 --- a/sdxl_train_control_net_lllite.py +++ b/sdxl_train_control_net_lllite.py @@ -289,6 +289,9 @@ def train(args): # acceleratorがなんかよろしくやってくれるらしい unet, optimizer, train_dataloader, lr_scheduler = accelerator.prepare(unet, optimizer, train_dataloader, lr_scheduler) + if isinstance(unet, DDP): + unet._set_static_graph() # avoid error for multiple use of the 
parameter + if args.gradient_checkpointing: unet.train() # according to TI example in Diffusers, train is required -> これオリジナルのU-Netしたので本当は外せる else: From 1a104dc75ee5733af8ba17cc9778b39e26673734 Mon Sep 17 00:00:00 2001 From: Kohya S Date: Sun, 9 Jun 2024 19:26:36 +0900 Subject: [PATCH 63/97] make forward/backward pathes same ref #1363 --- networks/control_net_lllite_for_train.py | 12 ++++-------- 1 file changed, 4 insertions(+), 8 deletions(-) diff --git a/networks/control_net_lllite_for_train.py b/networks/control_net_lllite_for_train.py index 65b3520cf..366451b7f 100644 --- a/networks/control_net_lllite_for_train.py +++ b/networks/control_net_lllite_for_train.py @@ -7,8 +7,10 @@ import torch from library import sdxl_original_unet from library.utils import setup_logging + setup_logging() import logging + logger = logging.getLogger(__name__) # input_blocksに適用するかどうか / if True, input_blocks are not applied @@ -103,19 +105,15 @@ def set_lllite(self, depth, cond_emb_dim, name, mlp_dim, dropout=None, multiplie add_lllite_modules(self, in_dim, depth, cond_emb_dim, mlp_dim) self.cond_image = None - self.cond_emb = None def set_cond_image(self, cond_image): self.cond_image = cond_image - self.cond_emb = None def forward(self, x): if not self.enabled: return super().forward(x) - if self.cond_emb is None: - self.cond_emb = self.lllite_conditioning1(self.cond_image) - cx = self.cond_emb + cx = self.lllite_conditioning1(self.cond_image) # make forward and backward compatible # reshape / b,c,h,w -> b,h*w,c n, c, h, w = cx.shape @@ -159,9 +157,7 @@ def forward(self, x): # , cond_image=None): if not self.enabled: return super().forward(x) - if self.cond_emb is None: - self.cond_emb = self.lllite_conditioning1(self.cond_image) - cx = self.cond_emb + cx = self.lllite_conditioning1(self.cond_image) cx = torch.cat([cx, self.down(x)], dim=1) cx = self.mid(cx) From 18d7597b0b39cc2204dfbdfdcbf0fead97414be1 Mon Sep 17 00:00:00 2001 From: Kohya S Date: Tue, 11 Jun 2024 19:51:30 +0900 Subject: [PATCH 64/97] update README --- README.md | 12 ++++++++++++ 1 file changed, 12 insertions(+) diff --git a/README.md b/README.md index 52c963392..25aba6397 100644 --- a/README.md +++ b/README.md @@ -178,6 +178,12 @@ The majority of scripts is licensed under ASL 2.0 (including codes from Diffuser - The ControlNet training script `train_controlnet.py` for SD1.5/2.x was not working, but it has been fixed. PR [#1284](https://github.com/kohya-ss/sd-scripts/pull/1284) Thanks to sdbds! +- `train_network.py` and `sdxl_train_network.py` now restore the order/position of data loading from DataSet when resuming training. PR [#1353](https://github.com/kohya-ss/sd-scripts/pull/1353) [#1359](https://github.com/kohya-ss/sd-scripts/pull/1359) Thanks to KohakuBlueleaf! + - This resolves the issue where the order of data loading from DataSet changes when resuming training. + - Specify the `--skip_until_initial_step` option to skip data loading until the specified step. If not specified, data loading starts from the beginning of the DataSet (same as before). + - If `--resume` is specified, the step saved in the state is used. + - Specify the `--initial_step` or `--initial_epoch` option to skip data loading until the specified step or epoch. Use these options in conjunction with `--skip_until_initial_step`. These options can be used without `--resume` (use them when resuming training with `--network_weights`). + - An option `--disable_mmap_load_safetensors` is added to disable memory mapping when loading the model's .safetensors in SDXL. 
PR [#1266](https://github.com/kohya-ss/sd-scripts/pull/1266) Thanks to Zovjsra! - It seems that the model file loading is faster in the WSL environment etc. - Available in `sdxl_train.py`, `sdxl_train_network.py`, `sdxl_train_textual_inversion.py`, and `sdxl_train_control_net_lllite.py`. @@ -235,6 +241,12 @@ https://github.com/kohya-ss/sd-scripts/pull/1290) Thanks to frodo821! - SD1.5/2.x 用の ControlNet 学習スクリプト `train_controlnet.py` が動作しなくなっていたのが修正されました。PR [#1284](https://github.com/kohya-ss/sd-scripts/pull/1284) sdbds 氏に感謝します。 +- `train_network.py` および `sdxl_train_network.py` で、学習再開時に DataSet の読み込み順についても復元できるようになりました。PR [#1353](https://github.com/kohya-ss/sd-scripts/pull/1353) [#1359](https://github.com/kohya-ss/sd-scripts/pull/1359) KohakuBlueleaf 氏に感謝します。 + - これにより、学習再開時に DataSet の読み込み順が変わってしまう問題が解消されます。 + - `--skip_until_initial_step` オプションを指定すると、指定したステップまで DataSet 読み込みをスキップします。指定しない場合の動作は変わりません(DataSet の最初から読み込みます) + - `--resume` オプションを指定すると、state に保存されたステップ数が使用されます。 + - `--initial_step` または `--initial_epoch` オプションを指定すると、指定したステップまたはエポックまで DataSet 読み込みをスキップします。これらのオプションは `--skip_until_initial_step` と併用してください。またこれらのオプションは `--resume` と併用しなくても使えます(`--network_weights` を用いた学習再開時などにお使いください )。 + - SDXL でモデルの .safetensors を読み込む際にメモリマッピングを無効化するオプション `--disable_mmap_load_safetensors` が追加されました。PR [#1266](https://github.com/kohya-ss/sd-scripts/pull/1266) Zovjsra 氏に感謝します。 - WSL 環境等でモデルファイルの読み込みが高速化されるようです。 - `sdxl_train.py`、`sdxl_train_network.py`、`sdxl_train_textual_inversion.py`、`sdxl_train_control_net_lllite.py` で使用可能です。 From 56bb81c9e6483b8b4d5b83639548855b8359f4b4 Mon Sep 17 00:00:00 2001 From: Kohya S Date: Wed, 12 Jun 2024 21:39:35 +0900 Subject: [PATCH 65/97] add grad_hook after restore state closes #1344 --- sdxl_train.py | 46 +++++++++++++++++++++++++--------------------- 1 file changed, 25 insertions(+), 21 deletions(-) diff --git a/sdxl_train.py b/sdxl_train.py index 9e20c60ca..ae92d6a3d 100644 --- a/sdxl_train.py +++ b/sdxl_train.py @@ -481,6 +481,26 @@ def fn_recursive_set_mem_eff(module: torch.nn.Module): text_encoder2 = accelerator.prepare(text_encoder2) optimizer, train_dataloader, lr_scheduler = accelerator.prepare(optimizer, train_dataloader, lr_scheduler) + # TextEncoderの出力をキャッシュするときにはCPUへ移動する + if args.cache_text_encoder_outputs: + # move Text Encoders for sampling images. Text Encoder doesn't work on CPU with fp16 + text_encoder1.to("cpu", dtype=torch.float32) + text_encoder2.to("cpu", dtype=torch.float32) + clean_memory_on_device(accelerator.device) + else: + # make sure Text Encoders are on GPU + text_encoder1.to(accelerator.device) + text_encoder2.to(accelerator.device) + + # 実験的機能:勾配も含めたfp16学習を行う PyTorchにパッチを当ててfp16でのgrad scaleを有効にする + if args.full_fp16: + # During deepseed training, accelerate not handles fp16/bf16|mixed precision directly via scaler. Let deepspeed engine do. + # -> But we think it's ok to patch accelerator even if deepspeed is enabled. + train_util.patch_accelerator_for_fp16_training(accelerator) + + # resumeする + train_util.resume_from_local_or_hf_if_specified(accelerator, args) + if args.fused_backward_pass: # use fused optimizer for backward pass: other optimizers will be supported in the future import library.adafactor_fused @@ -532,26 +552,6 @@ def optimizer_hook(parameter: torch.Tensor): parameter_optimizer_map[parameter] = opt_idx num_parameters_per_group[opt_idx] += 1 - # TextEncoderの出力をキャッシュするときにはCPUへ移動する - if args.cache_text_encoder_outputs: - # move Text Encoders for sampling images. 
Text Encoder doesn't work on CPU with fp16 - text_encoder1.to("cpu", dtype=torch.float32) - text_encoder2.to("cpu", dtype=torch.float32) - clean_memory_on_device(accelerator.device) - else: - # make sure Text Encoders are on GPU - text_encoder1.to(accelerator.device) - text_encoder2.to(accelerator.device) - - # 実験的機能:勾配も含めたfp16学習を行う PyTorchにパッチを当ててfp16でのgrad scaleを有効にする - if args.full_fp16: - # During deepseed training, accelerate not handles fp16/bf16|mixed precision directly via scaler. Let deepspeed engine do. - # -> But we think it's ok to patch accelerator even if deepspeed is enabled. - train_util.patch_accelerator_for_fp16_training(accelerator) - - # resumeする - train_util.resume_from_local_or_hf_if_specified(accelerator, args) - # epoch数を計算する num_update_steps_per_epoch = math.ceil(len(train_dataloader) / args.gradient_accumulation_steps) num_train_epochs = math.ceil(args.max_train_steps / num_update_steps_per_epoch) @@ -589,7 +589,11 @@ def optimizer_hook(parameter: torch.Tensor): init_kwargs["wandb"] = {"name": args.wandb_run_name} if args.log_tracker_config is not None: init_kwargs = toml.load(args.log_tracker_config) - accelerator.init_trackers("finetuning" if args.log_tracker_name is None else args.log_tracker_name, config=train_util.get_sanitized_config_or_none(args), init_kwargs=init_kwargs) + accelerator.init_trackers( + "finetuning" if args.log_tracker_name is None else args.log_tracker_name, + config=train_util.get_sanitized_config_or_none(args), + init_kwargs=init_kwargs, + ) # For --sample_at_first sdxl_train_util.sample_images( From 0b3e4f7ab62b7c93e66972b7bd2774b8fe679792 Mon Sep 17 00:00:00 2001 From: Kohya S Date: Tue, 25 Jun 2024 20:03:09 +0900 Subject: [PATCH 66/97] show file name if error in load_image ref #1385 --- library/train_util.py | 24 ++++++++++++++---------- 1 file changed, 14 insertions(+), 10 deletions(-) diff --git a/library/train_util.py b/library/train_util.py index 4736ff4ff..760be33eb 100644 --- a/library/train_util.py +++ b/library/train_util.py @@ -2434,16 +2434,20 @@ def load_arbitrary_dataset(args, tokenizer) -> MinimalDataset: return train_dataset_group -def load_image(image_path, alpha=False): - image = Image.open(image_path) - if alpha: - if not image.mode == "RGBA": - image = image.convert("RGBA") - else: - if not image.mode == "RGB": - image = image.convert("RGB") - img = np.array(image, np.uint8) - return img +def load_image(image_path, alpha=False): + try: + with Image.open(image_path) as image: + if alpha: + if not image.mode == "RGBA": + image = image.convert("RGBA") + else: + if not image.mode == "RGB": + image = image.convert("RGB") + img = np.array(image, np.uint8) + return img + except (IOError, OSError) as e: + logger.error(f"Error loading file: {image_path}") + raise e # 画像を読み込む。戻り値はnumpy.ndarray,(original width, original height),(crop left, crop top, crop right, crop bottom) From 87526942a67fd71bb775bc479b0a7449df516dd8 Mon Sep 17 00:00:00 2001 From: sdbds <865105819@qq.com> Date: Fri, 12 Jul 2024 22:56:38 +0800 Subject: [PATCH 67/97] judge image size for using diff interpolation --- library/train_util.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/library/train_util.py b/library/train_util.py index 15c23f3cc..74720fec6 100644 --- a/library/train_util.py +++ b/library/train_util.py @@ -2362,7 +2362,7 @@ def trim_and_resize_if_required( if image_width != resized_size[0] or image_height != resized_size[1]: # リサイズする - image = cv2.resize(image, resized_size, interpolation=cv2.INTER_AREA) # 
INTER_AREAでやりたいのでcv2でリサイズ + image = cv2.resize(image, resized_size, interpolation=cv2.INTER_AREA if image_width > resized_size[0] and image_height > resized_size[1] else cv2.INTER_LANCZOS4) image_height, image_width = image.shape[0:2] From 2e67978ee243a20f169ce76d7644bb1f9dec9bad Mon Sep 17 00:00:00 2001 From: Millie Date: Thu, 18 Jul 2024 11:52:58 -0700 Subject: [PATCH 68/97] Generate sample images without having CUDA (such as on Macs) --- library/train_util.py | 13 ++++++++----- 1 file changed, 8 insertions(+), 5 deletions(-) diff --git a/library/train_util.py b/library/train_util.py index 15c23f3cc..9b0397d7d 100644 --- a/library/train_util.py +++ b/library/train_util.py @@ -5229,7 +5229,7 @@ def sample_images_common( clean_memory_on_device(accelerator.device) torch.set_rng_state(rng_state) - if cuda_rng_state is not None: + if torch.cuda.is_available() and cuda_rng_state is not None: torch.cuda.set_rng_state(cuda_rng_state) vae.to(org_vae_device) @@ -5263,11 +5263,13 @@ def sample_image_inference( if seed is not None: torch.manual_seed(seed) - torch.cuda.manual_seed(seed) + if torch.cuda.is_available(): + torch.cuda.manual_seed(seed) else: # True random sample image generation torch.seed() - torch.cuda.seed() + if torch.cuda.is_available(): + torch.cuda.seed() scheduler = get_my_scheduler( sample_sampler=sampler_name, @@ -5302,8 +5304,9 @@ def sample_image_inference( controlnet_image=controlnet_image, ) - with torch.cuda.device(torch.cuda.current_device()): - torch.cuda.empty_cache() + if torch.cuda.is_available(): + with torch.cuda.device(torch.cuda.current_device()): + torch.cuda.empty_cache() image = pipeline.latents_to_image(latents)[0] From 1f16b80e88b1c4f05d49b4fc328d3b9b105ebcbe Mon Sep 17 00:00:00 2001 From: sdbds <865105819@qq.com> Date: Sat, 20 Jul 2024 21:35:24 +0800 Subject: [PATCH 69/97] Revert "judge image size for using diff interpolation" This reverts commit 87526942a67fd71bb775bc479b0a7449df516dd8. 
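The change being reverted here selected the OpenCV interpolation filter based on whether the bucket resize shrinks or enlarges the image; the follow-up patch below reintroduces the same idea but hands the enlarging case to Pillow's Lanczos resampling instead of `cv2.INTER_LANCZOS4`. A condensed sketch of that selection logic, assuming an RGB uint8 array and a (width, height) target; the helper name and standalone form are illustrative, not code from the repository.

```python
import cv2
import numpy as np
from PIL import Image


def resize_for_bucket(image, size):
    # size is (width, height); image is an RGB uint8 numpy array of shape (H, W, C)
    height, width = image.shape[:2]
    if width > size[0] and height > size[1]:
        # strictly downscaling: cv2.INTER_AREA reduces aliasing
        return cv2.resize(image, size, interpolation=cv2.INTER_AREA)
    # upscaling (or mixed): Pillow's Lanczos filter
    return np.array(Image.fromarray(image).resize(size, Image.LANCZOS), np.uint8)
```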
--- library/train_util.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/library/train_util.py b/library/train_util.py index 74720fec6..15c23f3cc 100644 --- a/library/train_util.py +++ b/library/train_util.py @@ -2362,7 +2362,7 @@ def trim_and_resize_if_required( if image_width != resized_size[0] or image_height != resized_size[1]: # リサイズする - image = cv2.resize(image, resized_size, interpolation=cv2.INTER_AREA if image_width > resized_size[0] and image_height > resized_size[1] else cv2.INTER_LANCZOS4) + image = cv2.resize(image, resized_size, interpolation=cv2.INTER_AREA) # INTER_AREAでやりたいのでcv2でリサイズ image_height, image_width = image.shape[0:2] From 9ca7a5b6cc99e25820a1aa6d02a779004d73bca0 Mon Sep 17 00:00:00 2001 From: sdbds <865105819@qq.com> Date: Sat, 20 Jul 2024 21:59:11 +0800 Subject: [PATCH 70/97] instead cv2 LANCZOS4 resize to pil resize --- finetune/tag_images_by_wd14_tagger.py | 8 +++++--- library/train_util.py | 11 ++++++----- library/utils.py | 14 +++++++++++++- tools/detect_face_rotate.py | 7 +++++-- tools/resize_images_to_resolution.py | 11 +++++++---- 5 files changed, 36 insertions(+), 15 deletions(-) diff --git a/finetune/tag_images_by_wd14_tagger.py b/finetune/tag_images_by_wd14_tagger.py index a327bbd61..6f5bdd36b 100644 --- a/finetune/tag_images_by_wd14_tagger.py +++ b/finetune/tag_images_by_wd14_tagger.py @@ -11,7 +11,7 @@ from tqdm import tqdm import library.train_util as train_util -from library.utils import setup_logging +from library.utils import setup_logging, pil_resize setup_logging() import logging @@ -42,8 +42,10 @@ def preprocess_image(image): pad_t = pad_y // 2 image = np.pad(image, ((pad_t, pad_y - pad_t), (pad_l, pad_x - pad_l), (0, 0)), mode="constant", constant_values=255) - interp = cv2.INTER_AREA if size > IMAGE_SIZE else cv2.INTER_LANCZOS4 - image = cv2.resize(image, (IMAGE_SIZE, IMAGE_SIZE), interpolation=interp) + if size > IMAGE_SIZE: + image = cv2.resize(image, (IMAGE_SIZE, IMAGE_SIZE), cv2.INTER_AREA) + else: + image = pil_resize(image, (IMAGE_SIZE, IMAGE_SIZE)) image = image.astype(np.float32) return image diff --git a/library/train_util.py b/library/train_util.py index 15c23f3cc..160e3b44b 100644 --- a/library/train_util.py +++ b/library/train_util.py @@ -71,7 +71,7 @@ import library.huggingface_util as huggingface_util import library.sai_model_spec as sai_model_spec import library.deepspeed_utils as deepspeed_utils -from library.utils import setup_logging +from library.utils import setup_logging, pil_resize setup_logging() import logging @@ -2028,9 +2028,7 @@ def __getitem__(self, index): # ), f"image size is small / 画像サイズが小さいようです: {image_info.absolute_path}" # resize to target if cond_img.shape[0] != target_size_hw[0] or cond_img.shape[1] != target_size_hw[1]: - cond_img = cv2.resize( - cond_img, (int(target_size_hw[1]), int(target_size_hw[0])), interpolation=cv2.INTER_LANCZOS4 - ) + cond_img=pil_resize(cond_img,(int(target_size_hw[1]), int(target_size_hw[0]))) if flipped: cond_img = cond_img[:, ::-1, :].copy() # copy to avoid negative stride @@ -2362,7 +2360,10 @@ def trim_and_resize_if_required( if image_width != resized_size[0] or image_height != resized_size[1]: # リサイズする - image = cv2.resize(image, resized_size, interpolation=cv2.INTER_AREA) # INTER_AREAでやりたいのでcv2でリサイズ + if image_width > resized_size[0] and image_height > resized_size[1]: + image = cv2.resize(image, resized_size, interpolation=cv2.INTER_AREA) # INTER_AREAでやりたいのでcv2でリサイズ + else: + image = pil_resize(image, resized_size) image_height, image_width = 
image.shape[0:2] diff --git a/library/utils.py b/library/utils.py index 3037c055d..a219f6cb7 100644 --- a/library/utils.py +++ b/library/utils.py @@ -7,7 +7,9 @@ from diffusers import EulerAncestralDiscreteScheduler import diffusers.schedulers.scheduling_euler_ancestral_discrete from diffusers.schedulers.scheduling_euler_ancestral_discrete import EulerAncestralDiscreteSchedulerOutput - +import cv2 +from PIL import Image +import numpy as np def fire_in_thread(f, *args, **kwargs): threading.Thread(target=f, args=args, kwargs=kwargs).start() @@ -78,7 +80,17 @@ def setup_logging(args=None, log_level=None, reset=False): logger = logging.getLogger(__name__) logger.info(msg_init) +def pil_resize(image, size, interpolation=Image.LANCZOS): + + pil_image = Image.fromarray(cv2.cvtColor(image, cv2.COLOR_BGR2RGB)) + + # use Pillow resize + resized_pil = pil_image.resize(size, interpolation) + + # return cv2 image + resized_cv2 = cv2.cvtColor(np.array(resized_pil), cv2.COLOR_RGB2BGR) + return resized_cv2 # TODO make inf_utils.py diff --git a/tools/detect_face_rotate.py b/tools/detect_face_rotate.py index bbc643edc..d2a4d9cfb 100644 --- a/tools/detect_face_rotate.py +++ b/tools/detect_face_rotate.py @@ -15,7 +15,7 @@ from anime_face_detector import create_detector from tqdm import tqdm import numpy as np -from library.utils import setup_logging +from library.utils import setup_logging, pil_resize setup_logging() import logging logger = logging.getLogger(__name__) @@ -172,7 +172,10 @@ def process(args): if scale != 1.0: w = int(w * scale + .5) h = int(h * scale + .5) - face_img = cv2.resize(face_img, (w, h), interpolation=cv2.INTER_AREA if scale < 1.0 else cv2.INTER_LANCZOS4) + if scale < 1.0: + face_img = cv2.resize(face_img, (w, h), interpolation=cv2.INTER_AREA) + else: + face_img = pil_resize(face_img, (w, h)) cx = int(cx * scale + .5) cy = int(cy * scale + .5) fw = int(fw * scale + .5) diff --git a/tools/resize_images_to_resolution.py b/tools/resize_images_to_resolution.py index b8069fc1d..0f9e00b1e 100644 --- a/tools/resize_images_to_resolution.py +++ b/tools/resize_images_to_resolution.py @@ -6,7 +6,7 @@ import math from PIL import Image import numpy as np -from library.utils import setup_logging +from library.utils import setup_logging, pil_resize setup_logging() import logging logger = logging.getLogger(__name__) @@ -24,9 +24,9 @@ def resize_images(src_img_folder, dst_img_folder, max_resolution="512x512", divi # Select interpolation method if interpolation == 'lanczos4': - cv2_interpolation = cv2.INTER_LANCZOS4 + pil_interpolation = Image.LANCZOS elif interpolation == 'cubic': - cv2_interpolation = cv2.INTER_CUBIC + pil_interpolation = Image.BICUBIC else: cv2_interpolation = cv2.INTER_AREA @@ -64,7 +64,10 @@ def resize_images(src_img_folder, dst_img_folder, max_resolution="512x512", divi new_width = int(img.shape[1] * math.sqrt(scale_factor)) # Resize image - img = cv2.resize(img, (new_width, new_height), interpolation=cv2_interpolation) + if cv2_interpolation: + img = cv2.resize(img, (new_width, new_height), interpolation=cv2_interpolation) + else: + img = pil_resize(img, (new_width, new_height), interpolation=pil_interpolation) else: new_height, new_width = img.shape[0:2] From 74f91c2ff71035db105b218128567e6b8fa6c80d Mon Sep 17 00:00:00 2001 From: Kohya S Date: Sun, 11 Aug 2024 21:54:10 +0900 Subject: [PATCH 71/97] correct option name closes #1446 --- docs/train_README-ja.md | 2 +- docs/train_README-zh.md | 2 +- 2 files changed, 2 insertions(+), 2 deletions(-) diff --git 
a/docs/train_README-ja.md b/docs/train_README-ja.md index d186bf243..cfa5a7d1c 100644 --- a/docs/train_README-ja.md +++ b/docs/train_README-ja.md @@ -648,7 +648,7 @@ masterpiece, best quality, 1boy, in business suit, standing at street, looking b 詳細については各自お調べください。 - 任意のスケジューラを使う場合、任意のオプティマイザと同様に、`--scheduler_args`でオプション引数を指定してください。 + 任意のスケジューラを使う場合、任意のオプティマイザと同様に、`--lr_scheduler_args`でオプション引数を指定してください。 ### オプティマイザの指定について diff --git a/docs/train_README-zh.md b/docs/train_README-zh.md index 7e00278c5..1bc47e0f5 100644 --- a/docs/train_README-zh.md +++ b/docs/train_README-zh.md @@ -582,7 +582,7 @@ masterpiece, best quality, 1boy, in business suit, standing at street, looking b 有关详细信息,请自行研究。 - 要使用任何调度程序,请像使用任何优化器一样使用“--scheduler_args”指定可选参数。 + 要使用任何调度程序,请像使用任何优化器一样使用“--lr_scheduler_args”指定可选参数。 ### 关于指定优化器 使用 --optimizer_args 选项指定优化器选项参数。可以以key=value的格式指定多个值。此外,您可以指定多个值,以逗号分隔。例如,要指定 AdamW 优化器的参数,``--optimizer_args weight_decay=0.01 betas=.9,.999``。 From afb971f9c36823040eaba3c9e02fdfa0928cd4ee Mon Sep 17 00:00:00 2001 From: Kohya S Date: Thu, 22 Aug 2024 21:33:15 +0900 Subject: [PATCH 72/97] fix SD1.5 LoRA extraction #1490 --- networks/lora.py | 11 +++++++++-- 1 file changed, 9 insertions(+), 2 deletions(-) diff --git a/networks/lora.py b/networks/lora.py index 82b8b5b47..6f33f1a1e 100644 --- a/networks/lora.py +++ b/networks/lora.py @@ -815,7 +815,8 @@ def create_network_from_weights(multiplier, file, vae, text_encoder, unet, weigh weights_sd = torch.load(file, map_location="cpu") # if keys are Diffusers based, convert to SAI based - convert_diffusers_to_sai_if_needed(weights_sd) + if is_sdxl: + convert_diffusers_to_sai_if_needed(weights_sd) # get dim/alpha mapping modules_dim = {} @@ -840,7 +841,13 @@ def create_network_from_weights(multiplier, file, vae, text_encoder, unet, weigh module_class = LoRAInfModule if for_inference else LoRAModule network = LoRANetwork( - text_encoder, unet, multiplier=multiplier, modules_dim=modules_dim, modules_alpha=modules_alpha, module_class=module_class + text_encoder, + unet, + multiplier=multiplier, + modules_dim=modules_dim, + modules_alpha=modules_alpha, + module_class=module_class, + is_sdxl=is_sdxl, ) # block lr From 1e8108fec9962333e4cf2a8db1dcedf657049900 Mon Sep 17 00:00:00 2001 From: liesen Date: Sat, 24 Aug 2024 01:38:17 +0300 Subject: [PATCH 73/97] Handle args.v_parameterization properly for MinSNR and changed prediction target --- sdxl_train.py | 8 ++++++-- 1 file changed, 6 insertions(+), 2 deletions(-) diff --git a/sdxl_train.py b/sdxl_train.py index 46d7860be..14b259657 100644 --- a/sdxl_train.py +++ b/sdxl_train.py @@ -590,7 +590,11 @@ def fn_recursive_set_mem_eff(module: torch.nn.Module): with accelerator.autocast(): noise_pred = unet(noisy_latents, timesteps, text_embedding, vector_embedding) - target = noise + if args.v_parameterization: + # v-parameterization training + target = noise_scheduler.get_velocity(latents, noise, timesteps) + else: + target = noise if ( args.min_snr_gamma @@ -606,7 +610,7 @@ def fn_recursive_set_mem_eff(module: torch.nn.Module): loss = loss.mean([1, 2, 3]) if args.min_snr_gamma: - loss = apply_snr_weight(loss, timesteps, noise_scheduler, args.min_snr_gamma) + loss = apply_snr_weight(loss, timesteps, noise_scheduler, args.min_snr_gamma, args.v_parameterization) if args.scale_v_pred_loss_like_noise_pred: loss = scale_v_prediction_loss_like_noise_prediction(loss, timesteps, noise_scheduler) if args.v_pred_like_loss: From d5c076cf9007f86f6dd1b9ecdfc5531336774b2f Mon Sep 17 00:00:00 2001 From: Kohya S Date: Sat, 24 
Aug 2024 21:21:39 +0900 Subject: [PATCH 74/97] update readme --- README.md | 1 + 1 file changed, 1 insertion(+) diff --git a/README.md b/README.md index 946df58f3..81a549378 100644 --- a/README.md +++ b/README.md @@ -139,6 +139,7 @@ The majority of scripts is licensed under ASL 2.0 (including codes from Diffuser ### Working in progress +- `--v_parameterization` is available in `sdxl_train.py`. The results are unpredictable, so use with caution. PR [#1505](https://github.com/kohya-ss/sd-scripts/pull/1505) Thanks to liesened! - Fused optimizer is available for SDXL training. PR [#1259](https://github.com/kohya-ss/sd-scripts/pull/1259) Thanks to 2kpr! - The memory usage during training is significantly reduced by integrating the optimizer's backward pass with step. The training results are the same as before, but if you have plenty of memory, the speed will be slower. - Specify the `--fused_backward_pass` option in `sdxl_train.py`. At this time, only AdaFactor is supported. Gradient accumulation is not available. From 0005867ba509d2e1a5674b267e8286b561c0ed71 Mon Sep 17 00:00:00 2001 From: Kohya S Date: Sat, 7 Sep 2024 10:45:18 +0900 Subject: [PATCH 75/97] update README, format code --- README.md | 5 +++++ library/train_util.py | 4 ++-- library/utils.py | 4 +++- 3 files changed, 10 insertions(+), 3 deletions(-) diff --git a/README.md b/README.md index 81a549378..16ab80e7a 100644 --- a/README.md +++ b/README.md @@ -139,7 +139,12 @@ The majority of scripts is licensed under ASL 2.0 (including codes from Diffuser ### Working in progress +- When enlarging images in the script (when the size of the training image is small and bucket_no_upscale is not specified), it has been changed to use Pillow's resize and LANCZOS interpolation instead of OpenCV2's resize and Lanczos4 interpolation. The quality of the image enlargement may be slightly improved. PR [#1426](https://github.com/kohya-ss/sd-scripts/pull/1426) Thanks to sdbds! + +- Sample image generation during training now works on non-CUDA devices. PR [#1433](https://github.com/kohya-ss/sd-scripts/pull/1433) Thanks to millie-v! + - `--v_parameterization` is available in `sdxl_train.py`. The results are unpredictable, so use with caution. PR [#1505](https://github.com/kohya-ss/sd-scripts/pull/1505) Thanks to liesened! + - Fused optimizer is available for SDXL training. PR [#1259](https://github.com/kohya-ss/sd-scripts/pull/1259) Thanks to 2kpr! - The memory usage during training is significantly reduced by integrating the optimizer's backward pass with step. The training results are the same as before, but if you have plenty of memory, the speed will be slower. - Specify the `--fused_backward_pass` option in `sdxl_train.py`. At this time, only AdaFactor is supported. Gradient accumulation is not available. 
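For reference, the resize behavior described in the notes above (implemented by `pil_resize` in library/utils.py and used from `trim_and_resize_if_required` in library/train_util.py) can be summarized in a small standalone sketch. This is only an illustration of the selection logic, assuming a BGR uint8 array as loaded by cv2; the helper name `resize_like_train_util` is not part of the repository.

import cv2
import numpy as np
from PIL import Image

def resize_like_train_util(image, size):
    # image: BGR uint8 ndarray, size: (width, height)
    height, width = image.shape[:2]
    if width > size[0] and height > size[1]:
        # shrinking: keep OpenCV's INTER_AREA, as before this change
        return cv2.resize(image, size, interpolation=cv2.INTER_AREA)
    # enlarging: round-trip through Pillow so LANCZOS resampling is used
    pil_image = Image.fromarray(cv2.cvtColor(image, cv2.COLOR_BGR2RGB))
    resized = pil_image.resize(size, Image.LANCZOS)
    return cv2.cvtColor(np.array(resized), cv2.COLOR_RGB2BGR)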
diff --git a/library/train_util.py b/library/train_util.py index 102d39ed7..1441e74f6 100644 --- a/library/train_util.py +++ b/library/train_util.py @@ -2094,7 +2094,7 @@ def __getitem__(self, index): # ), f"image size is small / 画像サイズが小さいようです: {image_info.absolute_path}" # resize to target if cond_img.shape[0] != target_size_hw[0] or cond_img.shape[1] != target_size_hw[1]: - cond_img=pil_resize(cond_img,(int(target_size_hw[1]), int(target_size_hw[0]))) + cond_img = pil_resize(cond_img, (int(target_size_hw[1]), int(target_size_hw[0]))) if flipped: cond_img = cond_img[:, ::-1, :].copy() # copy to avoid negative stride @@ -2432,7 +2432,7 @@ def load_arbitrary_dataset(args, tokenizer) -> MinimalDataset: return train_dataset_group -def load_image(image_path, alpha=False): +def load_image(image_path, alpha=False): try: with Image.open(image_path) as image: if alpha: diff --git a/library/utils.py b/library/utils.py index a219f6cb7..5b7e657b2 100644 --- a/library/utils.py +++ b/library/utils.py @@ -11,6 +11,7 @@ from PIL import Image import numpy as np + def fire_in_thread(f, *args, **kwargs): threading.Thread(target=f, args=args, kwargs=kwargs).start() @@ -80,8 +81,8 @@ def setup_logging(args=None, log_level=None, reset=False): logger = logging.getLogger(__name__) logger.info(msg_init) -def pil_resize(image, size, interpolation=Image.LANCZOS): +def pil_resize(image, size, interpolation=Image.LANCZOS): pil_image = Image.fromarray(cv2.cvtColor(image, cv2.COLOR_BGR2RGB)) # use Pillow resize @@ -92,6 +93,7 @@ def pil_resize(image, size, interpolation=Image.LANCZOS): return resized_cv2 + # TODO make inf_utils.py From fd68703f3795b3e9c75409ac5452807d056b928f Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?=E9=9D=92=E9=BE=8D=E8=81=96=E8=80=85=40bdsqlsz?= <865105819@qq.com> Date: Wed, 11 Sep 2024 20:25:45 +0800 Subject: [PATCH 76/97] Add New lr scheduler (#1393) * add new lr scheduler * fix bugs and use num_cycles / 2 * Update requirements.txt * add num_cycles for min lr * keep PIECEWISE_CONSTANT * allow use float with warmup or decay ratio. 
* Update train_util.py --- library/train_util.py | 80 ++++++++++++++++++++++++++++++++++++++----- requirements.txt | 6 ++-- 2 files changed, 75 insertions(+), 11 deletions(-) diff --git a/library/train_util.py b/library/train_util.py index c7b73ee37..340f6d640 100644 --- a/library/train_util.py +++ b/library/train_util.py @@ -42,7 +42,8 @@ from torchvision import transforms from transformers import CLIPTokenizer, CLIPTextModel, CLIPTextModelWithProjection import transformers -from diffusers.optimization import SchedulerType, TYPE_TO_SCHEDULER_FUNCTION +from diffusers.optimization import SchedulerType as DiffusersSchedulerType, TYPE_TO_SCHEDULER_FUNCTION as DIFFUSERS_TYPE_TO_SCHEDULER_FUNCTION +from transformers.optimization import SchedulerType, TYPE_TO_SCHEDULER_FUNCTION from diffusers import ( StableDiffusionPipeline, DDPMScheduler, @@ -2972,6 +2973,20 @@ def add_sd_models_arguments(parser: argparse.ArgumentParser): def add_optimizer_arguments(parser: argparse.ArgumentParser): + def int_or_float(value): + if value.endswith('%'): + try: + return float(value[:-1]) / 100.0 + except ValueError: + raise argparse.ArgumentTypeError(f"Value '{value}' is not a valid percentage") + try: + float_value = float(value) + if float_value >= 1: + return int(value) + return float(value) + except ValueError: + raise argparse.ArgumentTypeError(f"'{value}' is not an int or float") + parser.add_argument( "--optimizer_type", type=str, @@ -3024,9 +3039,15 @@ def add_optimizer_arguments(parser: argparse.ArgumentParser): ) parser.add_argument( "--lr_warmup_steps", - type=int, + type=int_or_float, + default=0, + help="Int number of steps for the warmup in the lr scheduler (default is 0) or float with ratio of train steps / 学習率のスケジューラをウォームアップするステップ数(デフォルト0)", + ) + parser.add_argument( + "--lr_decay_steps", + type=int_or_float, default=0, - help="Number of steps for the warmup in the lr scheduler (default is 0) / 学習率のスケジューラをウォームアップするステップ数(デフォルト0)", + help="Int number of steps for the decay in the lr scheduler (default is 0) or float with ratio of train steps", ) parser.add_argument( "--lr_scheduler_num_cycles", @@ -3046,6 +3067,18 @@ def add_optimizer_arguments(parser: argparse.ArgumentParser): help="Combines backward pass and optimizer step to reduce VRAM usage. Only available in SDXL" + " / バックワードパスとオプティマイザステップを組み合わせてVRAMの使用量を削減します。SDXLでのみ有効", ) + parser.add_argument( + "--lr_scheduler_timescale", + type=int, + default=None, + help="Inverse sqrt timescale for inverse sqrt scheduler,defaults to `num_warmup_steps`", + ) + parser.add_argument( + "--lr_scheduler_min_lr_ratio", + type=float, + default=None, + help="The minimum learning rate as a ratio of the initial learning rate for cosine with min lr scheduler and warmup decay scheduler", + ) def add_training_arguments(parser: argparse.ArgumentParser, support_dreambooth: bool): @@ -4293,10 +4326,14 @@ def get_scheduler_fix(args, optimizer: Optimizer, num_processes: int): Unified API to get any scheduler from its name. 
""" name = args.lr_scheduler - num_warmup_steps: Optional[int] = args.lr_warmup_steps num_training_steps = args.max_train_steps * num_processes # * args.gradient_accumulation_steps + num_warmup_steps: Optional[int] = int(args.lr_warmup_steps * num_training_steps) if isinstance(args.lr_warmup_steps, float) else args.lr_warmup_steps + num_decay_steps: Optional[int] = int(args.lr_decay_steps * num_training_steps) if isinstance(args.lr_decay_steps, float) else args.lr_decay_steps + num_stable_steps = num_training_steps - num_warmup_steps - num_decay_steps num_cycles = args.lr_scheduler_num_cycles power = args.lr_scheduler_power + timescale = args.lr_scheduler_timescale + min_lr_ratio = args.lr_scheduler_min_lr_ratio lr_scheduler_kwargs = {} # get custom lr_scheduler kwargs if args.lr_scheduler_args is not None and len(args.lr_scheduler_args) > 0: @@ -4332,13 +4369,13 @@ def wrap_check_needless_num_warmup_steps(return_vals): # logger.info(f"adafactor scheduler init lr {initial_lr}") return wrap_check_needless_num_warmup_steps(transformers.optimization.AdafactorSchedule(optimizer, initial_lr)) - name = SchedulerType(name) - schedule_func = TYPE_TO_SCHEDULER_FUNCTION[name] + name = SchedulerType(name) or DiffusersSchedulerType(name) + schedule_func = TYPE_TO_SCHEDULER_FUNCTION[name] or DIFFUSERS_TYPE_TO_SCHEDULER_FUNCTION[name] if name == SchedulerType.CONSTANT: return wrap_check_needless_num_warmup_steps(schedule_func(optimizer, **lr_scheduler_kwargs)) - if name == SchedulerType.PIECEWISE_CONSTANT: + if name == DiffusersSchedulerType.PIECEWISE_CONSTANT: return schedule_func(optimizer, **lr_scheduler_kwargs) # step_rules and last_epoch are given as kwargs # All other schedulers require `num_warmup_steps` @@ -4348,6 +4385,9 @@ def wrap_check_needless_num_warmup_steps(return_vals): if name == SchedulerType.CONSTANT_WITH_WARMUP: return schedule_func(optimizer, num_warmup_steps=num_warmup_steps, **lr_scheduler_kwargs) + if name == SchedulerType.INVERSE_SQRT: + return schedule_func(optimizer, num_warmup_steps=num_warmup_steps, timescale=timescale, **lr_scheduler_kwargs) + # All other schedulers require `num_training_steps` if num_training_steps is None: raise ValueError(f"{name} requires `num_training_steps`, please provide that argument.") @@ -4366,7 +4406,31 @@ def wrap_check_needless_num_warmup_steps(return_vals): optimizer, num_warmup_steps=num_warmup_steps, num_training_steps=num_training_steps, power=power, **lr_scheduler_kwargs ) - return schedule_func(optimizer, num_warmup_steps=num_warmup_steps, num_training_steps=num_training_steps, **lr_scheduler_kwargs) + if name == SchedulerType.COSINE_WITH_MIN_LR: + return schedule_func( + optimizer, + num_warmup_steps=num_warmup_steps, + num_training_steps=num_training_steps, + num_cycles=num_cycles / 2, + min_lr_rate=min_lr_ratio, + **lr_scheduler_kwargs, + ) + + # All other schedulers require `num_decay_steps` + if num_decay_steps is None: + raise ValueError(f"{name} requires `num_decay_steps`, please provide that argument.") + if name == SchedulerType.WARMUP_STABLE_DECAY: + return schedule_func( + optimizer, + num_warmup_steps=num_warmup_steps, + num_stable_steps=num_stable_steps, + num_decay_steps=num_decay_steps, + num_cycles=num_cycles / 2, + min_lr_ratio=min_lr_ratio if min_lr_ratio is not None else 0.0, + **lr_scheduler_kwargs, + ) + + return schedule_func(optimizer, num_warmup_steps=num_warmup_steps, num_training_steps=num_training_steps, num_decay_steps=num_decay_steps, **lr_scheduler_kwargs) def prepare_dataset_args(args: argparse.Namespace, 
support_metadata: bool): diff --git a/requirements.txt b/requirements.txt index 977c5cd91..d2a2fbb8a 100644 --- a/requirements.txt +++ b/requirements.txt @@ -1,5 +1,5 @@ -accelerate==0.25.0 -transformers==4.36.2 +accelerate==0.30.0 +transformers==4.41.2 diffusers[torch]==0.25.0 ftfy==6.1.1 # albumentations==1.3.0 @@ -16,7 +16,7 @@ altair==4.2.2 easygui==0.98.3 toml==0.10.2 voluptuous==0.13.1 -huggingface-hub==0.20.1 +huggingface-hub==0.23.3 # for Image utils imagesize==1.4.1 # for BLIP captioning From 6dbfd47a59cdb91be2077e1d0dec0f94698348dd Mon Sep 17 00:00:00 2001 From: Kohya S Date: Wed, 11 Sep 2024 21:44:36 +0900 Subject: [PATCH 77/97] Fix to work PIECEWISE_CONSTANT, update requirement.txt and README #1393 --- README.md | 9 ++++++ library/train_util.py | 66 ++++++++++++++++++++++++++++--------------- requirements.txt | 4 +-- 3 files changed, 54 insertions(+), 25 deletions(-) diff --git a/README.md b/README.md index 16ab80e7a..011141bf1 100644 --- a/README.md +++ b/README.md @@ -139,6 +139,15 @@ The majority of scripts is licensed under ASL 2.0 (including codes from Diffuser ### Working in progress +- __important__ The dependent libraries are updated. Please see [Upgrade](#upgrade) and update the libraries. + - transformers, accelerate and huggingface_hub are updated. + - If you encounter any issues, please report them. + +- en: The INVERSE_SQRT, COSINE_WITH_MIN_LR, and WARMUP_STABLE_DECAY learning rate schedules are now available in the transformers library. See PR [#1393](https://github.com/kohya-ss/sd-scripts/pull/1393) for details. Thanks to sdbds! + - See the [transformers documentation](https://huggingface.co/docs/transformers/v4.44.2/en/main_classes/optimizer_schedules#schedules) for details on each scheduler. + - `--lr_warmup_steps` and `--lr_decay_steps` can now be specified as a ratio of the number of training steps, not just the step value. Example: `--lr_warmup_steps=0.1` or `--lr_warmup_steps=10%`, etc. + +https://github.com/kohya-ss/sd-scripts/pull/1393 - When enlarging images in the script (when the size of the training image is small and bucket_no_upscale is not specified), it has been changed to use Pillow's resize and LANCZOS interpolation instead of OpenCV2's resize and Lanczos4 interpolation. The quality of the image enlargement may be slightly improved. PR [#1426](https://github.com/kohya-ss/sd-scripts/pull/1426) Thanks to sdbds! - Sample image generation during training now works on non-CUDA devices. PR [#1433](https://github.com/kohya-ss/sd-scripts/pull/1433) Thanks to millie-v! 
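For reference, the ratio-style values mentioned above are parsed by the `int_or_float` argparse type and resolved against the total number of training steps inside `get_scheduler_fix` (both added in the patches in this series). A minimal sketch of that behavior, with illustrative names that are not part of the library:

def parse_int_or_float(value):
    # "10%" -> 0.1, "0.1" -> 0.1, "500" -> 500
    if value.endswith("%"):
        return float(value[:-1]) / 100.0
    number = float(value)
    return int(number) if number >= 1 else number

def resolve_warmup_steps(lr_warmup_steps, total_training_steps):
    # a float below 1 is treated as a ratio of the total training steps
    if isinstance(lr_warmup_steps, float):
        return int(lr_warmup_steps * total_training_steps)
    return lr_warmup_steps

assert parse_int_or_float("10%") == 0.1
assert parse_int_or_float("500") == 500
assert resolve_warmup_steps(0.25, 8000) == 2000   # 25% warmup over an 8000-step run
assert resolve_warmup_steps(500, 8000) == 500     # plain step counts pass through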
diff --git a/library/train_util.py b/library/train_util.py index 340f6d640..e65760bae 100644 --- a/library/train_util.py +++ b/library/train_util.py @@ -42,7 +42,10 @@ from torchvision import transforms from transformers import CLIPTokenizer, CLIPTextModel, CLIPTextModelWithProjection import transformers -from diffusers.optimization import SchedulerType as DiffusersSchedulerType, TYPE_TO_SCHEDULER_FUNCTION as DIFFUSERS_TYPE_TO_SCHEDULER_FUNCTION +from diffusers.optimization import ( + SchedulerType as DiffusersSchedulerType, + TYPE_TO_SCHEDULER_FUNCTION as DIFFUSERS_TYPE_TO_SCHEDULER_FUNCTION, +) from transformers.optimization import SchedulerType, TYPE_TO_SCHEDULER_FUNCTION from diffusers import ( StableDiffusionPipeline, @@ -2974,7 +2977,7 @@ def add_sd_models_arguments(parser: argparse.ArgumentParser): def add_optimizer_arguments(parser: argparse.ArgumentParser): def int_or_float(value): - if value.endswith('%'): + if value.endswith("%"): try: return float(value[:-1]) / 100.0 except ValueError: @@ -3041,13 +3044,15 @@ def int_or_float(value): "--lr_warmup_steps", type=int_or_float, default=0, - help="Int number of steps for the warmup in the lr scheduler (default is 0) or float with ratio of train steps / 学習率のスケジューラをウォームアップするステップ数(デフォルト0)", + help="Int number of steps for the warmup in the lr scheduler (default is 0) or float with ratio of train steps" + " / 学習率のスケジューラをウォームアップするステップ数(デフォルト0)、または学習ステップの比率(1未満のfloat値の場合)", ) parser.add_argument( "--lr_decay_steps", type=int_or_float, default=0, - help="Int number of steps for the decay in the lr scheduler (default is 0) or float with ratio of train steps", + help="Int number of steps for the decay in the lr scheduler (default is 0) or float (<1) with ratio of train steps" + " / 学習率のスケジューラを減衰させるステップ数(デフォルト0)、または学習ステップの比率(1未満のfloat値の場合)", ) parser.add_argument( "--lr_scheduler_num_cycles", @@ -3071,13 +3076,16 @@ def int_or_float(value): "--lr_scheduler_timescale", type=int, default=None, - help="Inverse sqrt timescale for inverse sqrt scheduler,defaults to `num_warmup_steps`", + help="Inverse sqrt timescale for inverse sqrt scheduler,defaults to `num_warmup_steps`" + " / 逆平方根スケジューラのタイムスケール、デフォルトは`num_warmup_steps`", + , ) parser.add_argument( "--lr_scheduler_min_lr_ratio", type=float, default=None, - help="The minimum learning rate as a ratio of the initial learning rate for cosine with min lr scheduler and warmup decay scheduler", + help="The minimum learning rate as a ratio of the initial learning rate for cosine with min lr scheduler and warmup decay scheduler" + " / 初期学習率の比率としての最小学習率を指定する、cosine with min lr と warmup decay スケジューラ で有効", ) @@ -4327,8 +4335,12 @@ def get_scheduler_fix(args, optimizer: Optimizer, num_processes: int): """ name = args.lr_scheduler num_training_steps = args.max_train_steps * num_processes # * args.gradient_accumulation_steps - num_warmup_steps: Optional[int] = int(args.lr_warmup_steps * num_training_steps) if isinstance(args.lr_warmup_steps, float) else args.lr_warmup_steps - num_decay_steps: Optional[int] = int(args.lr_decay_steps * num_training_steps) if isinstance(args.lr_decay_steps, float) else args.lr_decay_steps + num_warmup_steps: Optional[int] = ( + int(args.lr_warmup_steps * num_training_steps) if isinstance(args.lr_warmup_steps, float) else args.lr_warmup_steps + ) + num_decay_steps: Optional[int] = ( + int(args.lr_decay_steps * num_training_steps) if isinstance(args.lr_decay_steps, float) else args.lr_decay_steps + ) num_stable_steps = num_training_steps - num_warmup_steps - num_decay_steps 
num_cycles = args.lr_scheduler_num_cycles power = args.lr_scheduler_power @@ -4369,15 +4381,17 @@ def wrap_check_needless_num_warmup_steps(return_vals): # logger.info(f"adafactor scheduler init lr {initial_lr}") return wrap_check_needless_num_warmup_steps(transformers.optimization.AdafactorSchedule(optimizer, initial_lr)) - name = SchedulerType(name) or DiffusersSchedulerType(name) - schedule_func = TYPE_TO_SCHEDULER_FUNCTION[name] or DIFFUSERS_TYPE_TO_SCHEDULER_FUNCTION[name] + if name == DiffusersSchedulerType.PIECEWISE_CONSTANT.value: + name = DiffusersSchedulerType(name) + schedule_func = DIFFUSERS_TYPE_TO_SCHEDULER_FUNCTION[name] + return schedule_func(optimizer, **lr_scheduler_kwargs) # step_rules and last_epoch are given as kwargs + + name = SchedulerType(name) + schedule_func = TYPE_TO_SCHEDULER_FUNCTION[name] if name == SchedulerType.CONSTANT: return wrap_check_needless_num_warmup_steps(schedule_func(optimizer, **lr_scheduler_kwargs)) - if name == DiffusersSchedulerType.PIECEWISE_CONSTANT: - return schedule_func(optimizer, **lr_scheduler_kwargs) # step_rules and last_epoch are given as kwargs - # All other schedulers require `num_warmup_steps` if num_warmup_steps is None: raise ValueError(f"{name} requires `num_warmup_steps`, please provide that argument.") @@ -4408,11 +4422,11 @@ def wrap_check_needless_num_warmup_steps(return_vals): if name == SchedulerType.COSINE_WITH_MIN_LR: return schedule_func( - optimizer, - num_warmup_steps=num_warmup_steps, - num_training_steps=num_training_steps, + optimizer, + num_warmup_steps=num_warmup_steps, + num_training_steps=num_training_steps, num_cycles=num_cycles / 2, - min_lr_rate=min_lr_ratio, + min_lr_rate=min_lr_ratio, **lr_scheduler_kwargs, ) @@ -4421,16 +4435,22 @@ def wrap_check_needless_num_warmup_steps(return_vals): raise ValueError(f"{name} requires `num_decay_steps`, please provide that argument.") if name == SchedulerType.WARMUP_STABLE_DECAY: return schedule_func( - optimizer, - num_warmup_steps=num_warmup_steps, - num_stable_steps=num_stable_steps, - num_decay_steps=num_decay_steps, - num_cycles=num_cycles / 2, + optimizer, + num_warmup_steps=num_warmup_steps, + num_stable_steps=num_stable_steps, + num_decay_steps=num_decay_steps, + num_cycles=num_cycles / 2, min_lr_ratio=min_lr_ratio if min_lr_ratio is not None else 0.0, **lr_scheduler_kwargs, ) - return schedule_func(optimizer, num_warmup_steps=num_warmup_steps, num_training_steps=num_training_steps, num_decay_steps=num_decay_steps, **lr_scheduler_kwargs) + return schedule_func( + optimizer, + num_warmup_steps=num_warmup_steps, + num_training_steps=num_training_steps, + num_decay_steps=num_decay_steps, + **lr_scheduler_kwargs, + ) def prepare_dataset_args(args: argparse.Namespace, support_metadata: bool): diff --git a/requirements.txt b/requirements.txt index d2a2fbb8a..15e6e58f1 100644 --- a/requirements.txt +++ b/requirements.txt @@ -1,5 +1,5 @@ accelerate==0.30.0 -transformers==4.41.2 +transformers==4.44.0 diffusers[torch]==0.25.0 ftfy==6.1.1 # albumentations==1.3.0 @@ -16,7 +16,7 @@ altair==4.2.2 easygui==0.98.3 toml==0.10.2 voluptuous==0.13.1 -huggingface-hub==0.23.3 +huggingface-hub==0.24.5 # for Image utils imagesize==1.4.1 # for BLIP captioning From c7c666b1829a7c1f3435558efa425b08b50fab41 Mon Sep 17 00:00:00 2001 From: Kohya S Date: Wed, 11 Sep 2024 22:12:31 +0900 Subject: [PATCH 78/97] fix typo --- library/train_util.py | 5 ++--- 1 file changed, 2 insertions(+), 3 deletions(-) diff --git a/library/train_util.py b/library/train_util.py index e65760bae..a46d94877 100644 
--- a/library/train_util.py +++ b/library/train_util.py @@ -3077,15 +3077,14 @@ def int_or_float(value): type=int, default=None, help="Inverse sqrt timescale for inverse sqrt scheduler,defaults to `num_warmup_steps`" - " / 逆平方根スケジューラのタイムスケール、デフォルトは`num_warmup_steps`", - , + + " / 逆平方根スケジューラのタイムスケール、デフォルトは`num_warmup_steps`", ) parser.add_argument( "--lr_scheduler_min_lr_ratio", type=float, default=None, help="The minimum learning rate as a ratio of the initial learning rate for cosine with min lr scheduler and warmup decay scheduler" - " / 初期学習率の比率としての最小学習率を指定する、cosine with min lr と warmup decay スケジューラ で有効", + + " / 初期学習率の比率としての最小学習率を指定する、cosine with min lr と warmup decay スケジューラ で有効", ) From 93d9fbf60761fc1158e37f45f0d0c142913d70f5 Mon Sep 17 00:00:00 2001 From: Kohya S Date: Fri, 13 Sep 2024 22:37:11 +0900 Subject: [PATCH 79/97] improve OFT implementation closes #944 --- README.md | 26 ++++++++- gen_img.py | 3 +- networks/check_lora_weights.py | 2 +- networks/oft.py | 96 +++++++++++++++++++++------------- 4 files changed, 89 insertions(+), 38 deletions(-) diff --git a/README.md b/README.md index 0130ccffc..def528a22 100644 --- a/README.md +++ b/README.md @@ -143,7 +143,31 @@ The majority of scripts is licensed under ASL 2.0 (including codes from Diffuser - transformers, accelerate and huggingface_hub are updated. - If you encounter any issues, please report them. -- en: The INVERSE_SQRT, COSINE_WITH_MIN_LR, and WARMUP_STABLE_DECAY learning rate schedules are now available in the transformers library. See PR [#1393](https://github.com/kohya-ss/sd-scripts/pull/1393) for details. Thanks to sdbds! +- Improvements in OFT (Orthogonal Finetuning) Implementation + 1. Optimization of Calculation Order: + - Changed the calculation order in the forward method from (Wx)R to W(xR). + - This has improved computational efficiency and processing speed. + 2. Correction of Bias Application: + - In the previous implementation, R was incorrectly applied to the bias. + - The new implementation now correctly handles bias by using F.conv2d and F.linear. + 3. Efficiency Enhancement in Matrix Operations: + - Introduced einsum in both the forward and merge_to methods. + - This has optimized matrix operations, resulting in further speed improvements. + 4. Proper Handling of Data Types: + - Improved to use torch.float32 during calculations and convert results back to the original data type. + - This maintains precision while ensuring compatibility with the original model. + 5. Unified Processing for Conv2d and Linear Layers: + - Implemented a consistent method for applying OFT to both layer types. + - These changes have made the OFT implementation more efficient and accurate, potentially leading to improved model performance and training stability. + + - Additional Information + * Recommended α value for OFT constraint: We recommend using α values between 1e-4 and 1e-2. This differs slightly from the original implementation of "(α\*out_dim\*out_dim)". Our implementation uses "(α\*out_dim)", hence we recommend higher values than the 1e-5 suggested in the original implementation. + + * Performance Improvement: Training speed has been improved by approximately 30%. + + * Inference Environment: This implementation is compatible with and operates within Stable Diffusion web UI (SD1/2 and SDXL). + +- The INVERSE_SQRT, COSINE_WITH_MIN_LR, and WARMUP_STABLE_DECAY learning rate schedules are now available in the transformers library. See PR [#1393](https://github.com/kohya-ss/sd-scripts/pull/1393) for details. 
Thanks to sdbds! - See the [transformers documentation](https://huggingface.co/docs/transformers/v4.44.2/en/main_classes/optimizer_schedules#schedules) for details on each scheduler. - `--lr_warmup_steps` and `--lr_decay_steps` can now be specified as a ratio of the number of training steps, not just the step value. Example: `--lr_warmup_steps=0.1` or `--lr_warmup_steps=10%`, etc. diff --git a/gen_img.py b/gen_img.py index d0a8f8141..59bcd5b09 100644 --- a/gen_img.py +++ b/gen_img.py @@ -86,7 +86,8 @@ """ -def replace_unet_modules(unet: diffusers.models.unet_2d_condition.UNet2DConditionModel, mem_eff_attn, xformers, sdpa): +# def replace_unet_modules(unet: diffusers.models.unets.unet_2d_condition.UNet2DConditionModel, mem_eff_attn, xformers, sdpa): +def replace_unet_modules(unet, mem_eff_attn, xformers, sdpa): if mem_eff_attn: logger.info("Enable memory efficient attention for U-Net") diff --git a/networks/check_lora_weights.py b/networks/check_lora_weights.py index 794659c94..f8eab53ba 100644 --- a/networks/check_lora_weights.py +++ b/networks/check_lora_weights.py @@ -18,7 +18,7 @@ def main(file): keys = list(sd.keys()) for key in keys: - if "lora_up" in key or "lora_down" in key: + if "lora_up" in key or "lora_down" in key or "lora_A" in key or "lora_B" in key or "oft_" in key: values.append((key, sd[key])) print(f"number of LoRA modules: {len(values)}") diff --git a/networks/oft.py b/networks/oft.py index 461a98698..6321def3b 100644 --- a/networks/oft.py +++ b/networks/oft.py @@ -4,13 +4,17 @@ import os from typing import Dict, List, Optional, Tuple, Type, Union from diffusers import AutoencoderKL +import einops from transformers import CLIPTextModel import numpy as np import torch +import torch.nn.functional as F import re from library.utils import setup_logging + setup_logging() import logging + logger = logging.getLogger(__name__) RE_UPDOWN = re.compile(r"(up|down)_blocks_(\d+)_(resnets|upsamplers|downsamplers|attentions)_(\d+)_") @@ -45,11 +49,16 @@ def __init__( if type(alpha) == torch.Tensor: alpha = alpha.detach().numpy() - self.constraint = alpha * out_dim + + # constraint in original paper is alpha * out_dim * out_dim, but we use alpha * out_dim for backward compatibility + # original alpha is 1e-6, so we use 1e-3 or 1e-4 for alpha + self.constraint = alpha * out_dim + self.register_buffer("alpha", torch.tensor(alpha)) self.block_size = out_dim // self.num_blocks self.oft_blocks = torch.nn.Parameter(torch.zeros(self.num_blocks, self.block_size, self.block_size)) + self.I = torch.eye(self.block_size).unsqueeze(0).repeat(self.num_blocks, 1, 1) # cpu self.out_dim = out_dim self.shape = org_module.weight.shape @@ -69,27 +78,36 @@ def get_weight(self, multiplier=None): norm_Q = torch.norm(block_Q.flatten()) new_norm_Q = torch.clamp(norm_Q, max=self.constraint) block_Q = block_Q * ((new_norm_Q + 1e-8) / (norm_Q + 1e-8)) - I = torch.eye(self.block_size, device=self.oft_blocks.device).unsqueeze(0).repeat(self.num_blocks, 1, 1) - block_R = torch.matmul(I + block_Q, (I - block_Q).inverse()) - block_R_weighted = self.multiplier * block_R + (1 - self.multiplier) * I - R = torch.block_diag(*block_R_weighted) - - return R + if self.I.device != block_Q.device: + self.I = self.I.to(block_Q.device) + I = self.I + block_R = torch.matmul(I + block_Q, (I - block_Q).float().inverse()) + block_R_weighted = self.multiplier * (block_R - I) + I + return block_R_weighted def forward(self, x, scale=None): - x = self.org_forward(x) if self.multiplier == 0.0: - return x - - R = 
self.get_weight().to(x.device, dtype=x.dtype) - if x.dim() == 4: - x = x.permute(0, 2, 3, 1) - x = torch.matmul(x, R) - x = x.permute(0, 3, 1, 2) - else: - x = torch.matmul(x, R) - return x + return self.org_forward(x) + org_module = self.org_module[0] + org_dtype = x.dtype + + R = self.get_weight().to(torch.float32) + W = org_module.weight.to(torch.float32) + + if len(W.shape) == 4: # Conv2d + W_reshaped = einops.rearrange(W, "(k n) ... -> k n ...", k=self.num_blocks, n=self.block_size) + RW = torch.einsum("k n m, k n ... -> k m ...", R, W_reshaped) + RW = einops.rearrange(RW, "k m ... -> (k m) ...") + result = F.conv2d( + x, RW.to(org_dtype), org_module.bias, org_module.stride, org_module.padding, org_module.dilation, org_module.groups + ) + else: # Linear + W_reshaped = einops.rearrange(W, "(k n) m -> k n m", k=self.num_blocks, n=self.block_size) + RW = torch.einsum("k n m, k n p -> k m p", R, W_reshaped) + RW = einops.rearrange(RW, "k m p -> (k m) p") + result = F.linear(x, RW.to(org_dtype), org_module.bias) + return result class OFTInfModule(OFTModule): @@ -115,18 +133,19 @@ def forward(self, x, scale=None): return self.org_forward(x) return super().forward(x, scale) - def merge_to(self, multiplier=None, sign=1): - R = self.get_weight(multiplier) * sign - + def merge_to(self, multiplier=None): # get org weight org_sd = self.org_module[0].state_dict() - org_weight = org_sd["weight"] - R = R.to(org_weight.device, dtype=org_weight.dtype) + org_weight = org_sd["weight"].to(torch.float32) - if org_weight.dim() == 4: - weight = torch.einsum("oihw, op -> pihw", org_weight, R) - else: - weight = torch.einsum("oi, op -> pi", org_weight, R) + R = self.get_weight(multiplier).to(torch.float32) + + weight = org_weight.reshape(self.num_blocks, self.block_size, -1) + weight = torch.einsum("k n m, k n ... -> k m ...", R, weight) + weight = weight.reshape(org_weight.shape) + + # convert back to original dtype + weight = weight.to(org_sd["weight"].dtype) # set weight to org_module org_sd["weight"] = weight @@ -145,8 +164,16 @@ def create_network( ): if network_dim is None: network_dim = 4 # default - if network_alpha is None: - network_alpha = 1.0 + if network_alpha is None: # should be set + logger.info( + "network_alpha is not set, use default value 1e-3 / network_alphaが設定されていないのでデフォルト値 1e-3 を使用します" + ) + network_alpha = 1e-3 + elif network_alpha >= 1: + logger.warning( + "network_alpha is too large (>=1, maybe default value is too large), please consider to set smaller value like 1e-3" + " / network_alphaが大きすぎるようです(>=1, デフォルト値が大きすぎる可能性があります)。1e-3のような小さな値を推奨" + ) enable_all_linear = kwargs.get("enable_all_linear", None) enable_conv = kwargs.get("enable_conv", None) @@ -190,12 +217,11 @@ def create_network_from_weights(multiplier, file, vae, text_encoder, unet, weigh else: if dim is None: dim = param.size()[0] - if has_conv2d is None and param.dim() == 4: + if has_conv2d is None and "in_layers_2" in name: has_conv2d = True - if all_linear is None: - if param.dim() == 3 and "attn" not in name: - all_linear = True - if dim is not None and alpha is not None and has_conv2d is not None: + if all_linear is None and "_ff_" in name: + all_linear = True + if dim is not None and alpha is not None and has_conv2d is not None and all_linear is not None: break if has_conv2d is None: has_conv2d = False @@ -241,7 +267,7 @@ def __init__( self.alpha = alpha logger.info( - f"create OFT network. 
num blocks: {self.dim}, constraint: {self.alpha}, multiplier: {self.multiplier}, enable_conv: {enable_conv}" + f"create OFT network. num blocks: {self.dim}, constraint: {self.alpha}, multiplier: {self.multiplier}, enable_conv: {enable_conv}, enable_all_linear: {enable_all_linear}" ) # create module instances From e7040669bc9a31706fe9fedec14978b05223f968 Mon Sep 17 00:00:00 2001 From: Maru-mee <151493593+Maru-mee@users.noreply.github.com> Date: Thu, 19 Sep 2024 15:47:06 +0900 Subject: [PATCH 80/97] Bug fix: alpha_mask load --- library/train_util.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/library/train_util.py b/library/train_util.py index a46d94877..5a8da90e1 100644 --- a/library/train_util.py +++ b/library/train_util.py @@ -2207,7 +2207,7 @@ def is_disk_cached_latents_is_expected(reso, npz_path: str, flip_aug: bool, alph if alpha_mask: if "alpha_mask" not in npz: return False - if npz["alpha_mask"].shape[0:2] != reso: # HxW + if (npz["alpha_mask"].shape[1], npz["alpha_mask"].shape[0]) != reso: # HxW => WxH != reso return False else: if "alpha_mask" in npz: From e1f23af1bc733a1a89c35cf1be1301006c744b4a Mon Sep 17 00:00:00 2001 From: recris Date: Sat, 21 Sep 2024 12:58:32 +0100 Subject: [PATCH 81/97] make timestep sampling behave in the standard way when huber loss is used --- library/train_util.py | 26 ++++++++++---------------- 1 file changed, 10 insertions(+), 16 deletions(-) diff --git a/library/train_util.py b/library/train_util.py index 5a8da90e1..72d2d8112 100644 --- a/library/train_util.py +++ b/library/train_util.py @@ -5124,34 +5124,27 @@ def save_sd_model_on_train_end_common( def get_timesteps_and_huber_c(args, min_timestep, max_timestep, noise_scheduler, b_size, device): - - # TODO: if a huber loss is selected, it will use constant timesteps for each batch - # as. 
In the future there may be a smarter way + timesteps = torch.randint(min_timestep, max_timestep, (b_size,), device='cpu') if args.loss_type == "huber" or args.loss_type == "smooth_l1": - timesteps = torch.randint(min_timestep, max_timestep, (1,), device="cpu") - timestep = timesteps.item() - if args.huber_schedule == "exponential": alpha = -math.log(args.huber_c) / noise_scheduler.config.num_train_timesteps - huber_c = math.exp(-alpha * timestep) + huber_c = torch.exp(-alpha * timesteps) elif args.huber_schedule == "snr": - alphas_cumprod = noise_scheduler.alphas_cumprod[timestep] + alphas_cumprod = torch.index_select(noise_scheduler.alphas_cumprod, 0, timesteps) sigmas = ((1.0 - alphas_cumprod) / alphas_cumprod) ** 0.5 huber_c = (1 - args.huber_c) / (1 + sigmas) ** 2 + args.huber_c elif args.huber_schedule == "constant": - huber_c = args.huber_c + huber_c = torch.full((b_size,), args.huber_c) else: raise NotImplementedError(f"Unknown Huber loss schedule {args.huber_schedule}!") - - timesteps = timesteps.repeat(b_size).to(device) + huber_c = huber_c.to(device) elif args.loss_type == "l2": - timesteps = torch.randint(min_timestep, max_timestep, (b_size,), device=device) - huber_c = 1 # may be anything, as it's not used + huber_c = None # may be anything, as it's not used else: raise NotImplementedError(f"Unknown loss type {args.loss_type}") - timesteps = timesteps.long() + timesteps = timesteps.long().to(device) return timesteps, huber_c @@ -5190,20 +5183,21 @@ def get_noise_noisy_latents_and_timesteps(args, noise_scheduler, latents): return noise, noisy_latents, timesteps, huber_c -# NOTE: if you're using the scheduled version, huber_c has to depend on the timesteps already def conditional_loss( - model_pred: torch.Tensor, target: torch.Tensor, reduction: str = "mean", loss_type: str = "l2", huber_c: float = 0.1 + model_pred: torch.Tensor, target: torch.Tensor, reduction: str, loss_type: str, huber_c: Optional[torch.Tensor] ): if loss_type == "l2": loss = torch.nn.functional.mse_loss(model_pred, target, reduction=reduction) elif loss_type == "huber": + huber_c = huber_c.view(-1, 1, 1, 1) loss = 2 * huber_c * (torch.sqrt((model_pred - target) ** 2 + huber_c**2) - huber_c) if reduction == "mean": loss = torch.mean(loss) elif reduction == "sum": loss = torch.sum(loss) elif loss_type == "smooth_l1": + huber_c = huber_c.view(-1, 1, 1, 1) loss = 2 * (torch.sqrt((model_pred - target) ** 2 + huber_c**2) - huber_c) if reduction == "mean": loss = torch.mean(loss) From 29177d2f0389bd13e3f12c95d463fb0e1c58f9a1 Mon Sep 17 00:00:00 2001 From: Kohya S Date: Mon, 23 Sep 2024 21:14:03 +0900 Subject: [PATCH 82/97] retain alpha in pil_resize backport #1619 --- library/utils.py | 15 +++++++++++---- 1 file changed, 11 insertions(+), 4 deletions(-) diff --git a/library/utils.py b/library/utils.py index 5b7e657b2..49d46a546 100644 --- a/library/utils.py +++ b/library/utils.py @@ -83,13 +83,20 @@ def setup_logging(args=None, log_level=None, reset=False): def pil_resize(image, size, interpolation=Image.LANCZOS): - pil_image = Image.fromarray(cv2.cvtColor(image, cv2.COLOR_BGR2RGB)) + has_alpha = image.shape[2] == 4 if len(image.shape) == 3 else False + + if has_alpha: + pil_image = Image.fromarray(cv2.cvtColor(image, cv2.COLOR_BGRA2RGBA)) + else: + pil_image = Image.fromarray(cv2.cvtColor(image, cv2.COLOR_BGR2RGB)) - # use Pillow resize resized_pil = pil_image.resize(size, interpolation) - # return cv2 image - resized_cv2 = cv2.cvtColor(np.array(resized_pil), cv2.COLOR_RGB2BGR) + # Convert back to cv2 format + if 
has_alpha: + resized_cv2 = cv2.cvtColor(np.array(resized_pil), cv2.COLOR_RGBA2BGRA) + else: + resized_cv2 = cv2.cvtColor(np.array(resized_pil), cv2.COLOR_RGB2BGR) return resized_cv2 From ab7b23187062db86d34fc82db95f7266a68ab5c4 Mon Sep 17 00:00:00 2001 From: sdbds <865105819@qq.com> Date: Wed, 25 Sep 2024 19:38:52 +0800 Subject: [PATCH 83/97] init --- library/train_util.py | 21 ++++++++++++++++++--- requirements.txt | 2 +- 2 files changed, 19 insertions(+), 4 deletions(-) diff --git a/library/train_util.py b/library/train_util.py index 5a8da90e1..bdf7774e4 100644 --- a/library/train_util.py +++ b/library/train_util.py @@ -2994,7 +2994,7 @@ def int_or_float(value): "--optimizer_type", type=str, default="", - help="Optimizer to use / オプティマイザの種類: AdamW (default), AdamW8bit, PagedAdamW, PagedAdamW8bit, PagedAdamW32bit, Lion8bit, PagedLion8bit, Lion, SGDNesterov, SGDNesterov8bit, DAdaptation(DAdaptAdamPreprint), DAdaptAdaGrad, DAdaptAdam, DAdaptAdan, DAdaptAdanIP, DAdaptLion, DAdaptSGD, AdaFactor", + help="Optimizer to use / オプティマイザの種類: AdamW (default), AdamW8bit, PagedAdamW, PagedAdamW8bit, PagedAdamW32bit, Lion8bit, PagedLion8bit, Lion, AdEMAMix8bit, PagedAdEMAMix8bit, SGDNesterov, SGDNesterov8bit, DAdaptation(DAdaptAdamPreprint), DAdaptAdaGrad, DAdaptAdam, DAdaptAdan, DAdaptAdanIP, DAdaptLion, DAdaptSGD, AdaFactor", ) # backward compatibility @@ -4032,7 +4032,7 @@ def task(): def get_optimizer(args, trainable_params): - # "Optimizer to use: AdamW, AdamW8bit, Lion, SGDNesterov, SGDNesterov8bit, PagedAdamW, PagedAdamW8bit, PagedAdamW32bit, Lion8bit, PagedLion8bit, DAdaptation(DAdaptAdamPreprint), DAdaptAdaGrad, DAdaptAdam, DAdaptAdan, DAdaptAdanIP, DAdaptLion, DAdaptSGD, Adafactor" + # "Optimizer to use: AdamW, AdamW8bit, Lion, SGDNesterov, SGDNesterov8bit, PagedAdamW, PagedAdamW8bit, PagedAdamW32bit, Lion8bit, PagedLion8bit, AdEMAMix8bit, PagedAdEMAMix8bit, DAdaptation(DAdaptAdamPreprint), DAdaptAdaGrad, DAdaptAdam, DAdaptAdan, DAdaptAdanIP, DAdaptLion, DAdaptSGD, Adafactor" optimizer_type = args.optimizer_type if args.use_8bit_adam: @@ -4141,7 +4141,22 @@ def get_optimizer(args, trainable_params): raise AttributeError( "No PagedLion8bit. The version of bitsandbytes installed seems to be old. Please install 0.39.0 or later. / PagedLion8bitが定義されていません。インストールされているbitsandbytesのバージョンが古いようです。0.39.0以上をインストールしてください" ) - + elif optimizer_type == "Ademamix8bit".lower(): + logger.info(f"use 8-bit Ademamix optimizer | {optimizer_kwargs}") + try: + optimizer_class = bnb.optim.AdEMAMix8bit + except AttributeError: + raise AttributeError( + "No Ademamix8bit. The version of bitsandbytes installed seems to be old. Please install 0.44.0 or later. / Ademamix8bitが定義されていません。インストールされているbitsandbytesのバージョンが古いようです。0.39.0以上をインストールしてください" + ) + elif optimizer_type == "PagedAdemamix8bit".lower(): + logger.info(f"use 8-bit PagedAdemamix optimizer | {optimizer_kwargs}") + try: + optimizer_class = bnb.optim.PagedAdEMAMix8bit + except AttributeError: + raise AttributeError( + "No PagedAdemamix8bit. The version of bitsandbytes installed seems to be old. Please install 0.44.0 or later. 
/ PagedAdemamix8bitが定義されていません。インストールされているbitsandbytesのバージョンが古いようです。0.39.0以上をインストールしてください" + ) optimizer = optimizer_class(trainable_params, lr=lr, **optimizer_kwargs) elif optimizer_type == "PagedAdamW".lower(): diff --git a/requirements.txt b/requirements.txt index 15e6e58f1..e6e1bf6fc 100644 --- a/requirements.txt +++ b/requirements.txt @@ -6,7 +6,7 @@ ftfy==6.1.1 opencv-python==4.8.1.78 einops==0.7.0 pytorch-lightning==1.9.0 -bitsandbytes==0.43.0 +bitsandbytes==0.44.0 prodigyopt==1.0 lion-pytorch==0.0.6 tensorboard From e74f58148c5994889463afa42bb6fc5d6447a75e Mon Sep 17 00:00:00 2001 From: Kohya S Date: Wed, 25 Sep 2024 20:55:50 +0900 Subject: [PATCH 84/97] update README --- README.md | 2 ++ 1 file changed, 2 insertions(+) diff --git a/README.md b/README.md index def528a22..9eabdaeef 100644 --- a/README.md +++ b/README.md @@ -143,6 +143,8 @@ The majority of scripts is licensed under ASL 2.0 (including codes from Diffuser - transformers, accelerate and huggingface_hub are updated. - If you encounter any issues, please report them. +- Fixed an issue where the timesteps in the batch were the same when using Huber loss. PR [#1628](https://github.com/kohya-ss/sd-scripts/pull/1628) Thanks to recris! + - Improvements in OFT (Orthogonal Finetuning) Implementation 1. Optimization of Calculation Order: - Changed the calculation order in the forward method from (Wx)R to W(xR). From 1beddd84e5c4db729a84356db227d981dc18cf8d Mon Sep 17 00:00:00 2001 From: sdbds <865105819@qq.com> Date: Wed, 25 Sep 2024 22:58:26 +0800 Subject: [PATCH 85/97] delete code for cleaning --- library/train_util.py | 17 +---------------- 1 file changed, 1 insertion(+), 16 deletions(-) diff --git a/library/train_util.py b/library/train_util.py index bdf7774e4..c4845c54b 100644 --- a/library/train_util.py +++ b/library/train_util.py @@ -4141,22 +4141,7 @@ def get_optimizer(args, trainable_params): raise AttributeError( "No PagedLion8bit. The version of bitsandbytes installed seems to be old. Please install 0.39.0 or later. / PagedLion8bitが定義されていません。インストールされているbitsandbytesのバージョンが古いようです。0.39.0以上をインストールしてください" ) - elif optimizer_type == "Ademamix8bit".lower(): - logger.info(f"use 8-bit Ademamix optimizer | {optimizer_kwargs}") - try: - optimizer_class = bnb.optim.AdEMAMix8bit - except AttributeError: - raise AttributeError( - "No Ademamix8bit. The version of bitsandbytes installed seems to be old. Please install 0.44.0 or later. / Ademamix8bitが定義されていません。インストールされているbitsandbytesのバージョンが古いようです。0.39.0以上をインストールしてください" - ) - elif optimizer_type == "PagedAdemamix8bit".lower(): - logger.info(f"use 8-bit PagedAdemamix optimizer | {optimizer_kwargs}") - try: - optimizer_class = bnb.optim.PagedAdEMAMix8bit - except AttributeError: - raise AttributeError( - "No PagedAdemamix8bit. The version of bitsandbytes installed seems to be old. Please install 0.44.0 or later. 
/ PagedAdemamix8bitが定義されていません。インストールされているbitsandbytesのバージョンが古いようです。0.39.0以上をインストールしてください" - ) + optimizer = optimizer_class(trainable_params, lr=lr, **optimizer_kwargs) elif optimizer_type == "PagedAdamW".lower(): From bf91bea2e4363e5b3e0db11f0955ab93a19a0452 Mon Sep 17 00:00:00 2001 From: Kohya S Date: Thu, 26 Sep 2024 20:51:40 +0900 Subject: [PATCH 86/97] fix flip_aug, alpha_mask, random_crop issue in caching --- README.md | 2 ++ library/train_util.py | 44 +++++++++++++++++++++++++++++++------------ 2 files changed, 34 insertions(+), 12 deletions(-) diff --git a/README.md b/README.md index 9eabdaeef..b67a2c4e1 100644 --- a/README.md +++ b/README.md @@ -143,6 +143,8 @@ The majority of scripts is licensed under ASL 2.0 (including codes from Diffuser - transformers, accelerate and huggingface_hub are updated. - If you encounter any issues, please report them. +- Fixed a bug in the cache of latents. When `flip_aug`, `alpha_mask`, and `random_crop` are different in multiple subsets in the dataset configuration file (.toml), the last subset is used instead of reflecting them correctly. + - Fixed an issue where the timesteps in the batch were the same when using Huber loss. PR [#1628](https://github.com/kohya-ss/sd-scripts/pull/1628) Thanks to recris! - Improvements in OFT (Orthogonal Finetuning) Implementation diff --git a/library/train_util.py b/library/train_util.py index 72d2d8112..a31d00c69 100644 --- a/library/train_util.py +++ b/library/train_util.py @@ -998,9 +998,26 @@ def cache_latents(self, vae, vae_batch_size=1, cache_to_disk=False, is_main_proc # sort by resolution image_infos.sort(key=lambda info: info.bucket_reso[0] * info.bucket_reso[1]) - # split by resolution - batches = [] - batch = [] + # split by resolution and some conditions + class Condition: + def __init__(self, reso, flip_aug, alpha_mask, random_crop): + self.reso = reso + self.flip_aug = flip_aug + self.alpha_mask = alpha_mask + self.random_crop = random_crop + + def __eq__(self, other): + return ( + self.reso == other.reso + and self.flip_aug == other.flip_aug + and self.alpha_mask == other.alpha_mask + and self.random_crop == other.random_crop + ) + + batches: List[Tuple[Condition, List[ImageInfo]]] = [] + batch: List[ImageInfo] = [] + current_condition = None + logger.info("checking cache validity...") for info in tqdm(image_infos): subset = self.image_to_subset[info.image_key] @@ -1021,28 +1038,31 @@ def cache_latents(self, vae, vae_batch_size=1, cache_to_disk=False, is_main_proc if cache_available: # do not add to batch continue - # if last member of batch has different resolution, flush the batch - if len(batch) > 0 and batch[-1].bucket_reso != info.bucket_reso: - batches.append(batch) + # if batch is not empty and condition is changed, flush the batch. 
Note that current_condition is not None if batch is not empty + condition = Condition(info.bucket_reso, subset.flip_aug, subset.alpha_mask, subset.random_crop) + if len(batch) > 0 and current_condition != condition: + batches.append((current_condition, batch)) batch = [] batch.append(info) + current_condition = condition # if number of data in batch is enough, flush the batch if len(batch) >= vae_batch_size: - batches.append(batch) + batches.append((current_condition, batch)) batch = [] + current_condition = None if len(batch) > 0: - batches.append(batch) + batches.append((current_condition, batch)) if cache_to_disk and not is_main_process: # if cache to disk, don't cache latents in non-main process, set to info only return # iterate batches: batch doesn't have image, image will be loaded in cache_batch_latents and discarded logger.info("caching latents...") - for batch in tqdm(batches, smoothing=1, total=len(batches)): - cache_batch_latents(vae, cache_to_disk, batch, subset.flip_aug, subset.alpha_mask, subset.random_crop) + for condition, batch in tqdm(batches, smoothing=1, total=len(batches)): + cache_batch_latents(vae, cache_to_disk, batch, condition.flip_aug, condition.alpha_mask, condition.random_crop) # weight_dtypeを指定するとText Encoderそのもの、およひ出力がweight_dtypeになる # SDXLでのみ有効だが、datasetのメソッドとする必要があるので、sdxl_train_util.pyではなくこちらに実装する @@ -2315,7 +2335,7 @@ def debug_dataset(train_dataset, show_input_ids=False): if "alpha_masks" in example and example["alpha_masks"] is not None: alpha_mask = example["alpha_masks"][j] logger.info(f"alpha mask size: {alpha_mask.size()}") - alpha_mask = (alpha_mask[0].numpy() * 255.0).astype(np.uint8) + alpha_mask = (alpha_mask.numpy() * 255.0).astype(np.uint8) if os.name == "nt": cv2.imshow("alpha_mask", alpha_mask) @@ -5124,7 +5144,7 @@ def save_sd_model_on_train_end_common( def get_timesteps_and_huber_c(args, min_timestep, max_timestep, noise_scheduler, b_size, device): - timesteps = torch.randint(min_timestep, max_timestep, (b_size,), device='cpu') + timesteps = torch.randint(min_timestep, max_timestep, (b_size,), device="cpu") if args.loss_type == "huber" or args.loss_type == "smooth_l1": if args.huber_schedule == "exponential": From a94bc84dec8e85e8a71217b4d2570a52c6779b73 Mon Sep 17 00:00:00 2001 From: Kohya S Date: Thu, 26 Sep 2024 21:37:31 +0900 Subject: [PATCH 87/97] fix to work bitsandbytes optimizers with full path #1640 --- library/train_util.py | 11 +++++++++-- 1 file changed, 9 insertions(+), 2 deletions(-) diff --git a/library/train_util.py b/library/train_util.py index b40945ab8..47c367683 100644 --- a/library/train_util.py +++ b/library/train_util.py @@ -3014,7 +3014,11 @@ def int_or_float(value): "--optimizer_type", type=str, default="", - help="Optimizer to use / オプティマイザの種類: AdamW (default), AdamW8bit, PagedAdamW, PagedAdamW8bit, PagedAdamW32bit, Lion8bit, PagedLion8bit, Lion, AdEMAMix8bit, PagedAdEMAMix8bit, SGDNesterov, SGDNesterov8bit, DAdaptation(DAdaptAdamPreprint), DAdaptAdaGrad, DAdaptAdam, DAdaptAdan, DAdaptAdanIP, DAdaptLion, DAdaptSGD, AdaFactor", + help="Optimizer to use / オプティマイザの種類: AdamW (default), AdamW8bit, PagedAdamW, PagedAdamW8bit, PagedAdamW32bit, " + "Lion8bit, PagedLion8bit, Lion, SGDNesterov, SGDNesterov8bit, " + "DAdaptation(DAdaptAdamPreprint), DAdaptAdaGrad, DAdaptAdam, DAdaptAdan, DAdaptAdanIP, DAdaptLion, DAdaptSGD, " + "AdaFactor. 
" + "Also, you can use any optimizer by specifying the full path to the class, like 'bitsandbytes.optim.AdEMAMix8bit' or 'bitsandbytes.optim.PagedAdEMAMix8bit'.", ) # backward compatibility @@ -4105,6 +4109,7 @@ def get_optimizer(args, trainable_params): lr = args.learning_rate optimizer = None + optimizer_class = None if optimizer_type == "Lion".lower(): try: @@ -4162,7 +4167,8 @@ def get_optimizer(args, trainable_params): "No PagedLion8bit. The version of bitsandbytes installed seems to be old. Please install 0.39.0 or later. / PagedLion8bitが定義されていません。インストールされているbitsandbytesのバージョンが古いようです。0.39.0以上をインストールしてください" ) - optimizer = optimizer_class(trainable_params, lr=lr, **optimizer_kwargs) + if optimizer_class is not None: + optimizer = optimizer_class(trainable_params, lr=lr, **optimizer_kwargs) elif optimizer_type == "PagedAdamW".lower(): logger.info(f"use PagedAdamW optimizer | {optimizer_kwargs}") @@ -4338,6 +4344,7 @@ def get_optimizer(args, trainable_params): optimizer_class = getattr(optimizer_module, optimizer_type) optimizer = optimizer_class(trainable_params, lr=lr, **optimizer_kwargs) + # for logging optimizer_name = optimizer_class.__module__ + "." + optimizer_class.__name__ optimizer_args = ",".join([f"{k}={v}" for k, v in optimizer_kwargs.items()]) From ce49ced699298aa885d9a64b969fe8c77f30893b Mon Sep 17 00:00:00 2001 From: Kohya S Date: Thu, 26 Sep 2024 21:37:40 +0900 Subject: [PATCH 88/97] update readme --- README.md | 5 ++++- 1 file changed, 4 insertions(+), 1 deletion(-) diff --git a/README.md b/README.md index b67a2c4e1..9f024c1c9 100644 --- a/README.md +++ b/README.md @@ -140,9 +140,12 @@ The majority of scripts is licensed under ASL 2.0 (including codes from Diffuser ### Working in progress - __important__ The dependent libraries are updated. Please see [Upgrade](#upgrade) and update the libraries. - - transformers, accelerate and huggingface_hub are updated. + - bitsandbytes, transformers, accelerate and huggingface_hub are updated. - If you encounter any issues, please report them. +- `bitsandbytes` is updated to 0.44.0. Now you can use `AdEMAMix8bit` and `PagedAdEMAMix8bit` in the training script. PR [#1640](https://github.com/kohya-ss/sd-scripts/pull/1640) Thanks to sdbds! + - There is no abbreviation, so please specify the full path like `--optimizer_type bitsandbytes.optim.AdEMAMix8bit` (not bnb but bitsandbytes). + - Fixed a bug in the cache of latents. When `flip_aug`, `alpha_mask`, and `random_crop` are different in multiple subsets in the dataset configuration file (.toml), the last subset is used instead of reflecting them correctly. - Fixed an issue where the timesteps in the batch were the same when using Huber loss. PR [#1628](https://github.com/kohya-ss/sd-scripts/pull/1628) Thanks to recris! 
From fe2aa32484a948f16955909e64c21da7fe1e4e0c Mon Sep 17 00:00:00 2001 From: Kohya S Date: Sun, 29 Sep 2024 09:49:25 +0900 Subject: [PATCH 89/97] adjust min/max bucket reso divisible by reso steps #1632 --- README.md | 2 ++ docs/config_README-en.md | 2 ++ docs/config_README-ja.md | 2 ++ fine_tune.py | 2 ++ library/train_util.py | 40 ++++++++++++++++++++++++++++++++------ train_controlnet.py | 2 ++ train_db.py | 2 ++ train_network.py | 2 +- train_textual_inversion.py | 2 +- 9 files changed, 48 insertions(+), 8 deletions(-) diff --git a/README.md b/README.md index 9f024c1c9..de5cddb92 100644 --- a/README.md +++ b/README.md @@ -143,6 +143,8 @@ The majority of scripts is licensed under ASL 2.0 (including codes from Diffuser - bitsandbytes, transformers, accelerate and huggingface_hub are updated. - If you encounter any issues, please report them. +- There was a bug where the min_bucket_reso/max_bucket_reso in the dataset configuration did not create the correct resolution bucket if it was not divisible by bucket_reso_steps. These values are now warned and automatically rounded to a divisible value. Thanks to Maru-mee for raising the issue. Related PR [#1632](https://github.com/kohya-ss/sd-scripts/pull/1632) + - `bitsandbytes` is updated to 0.44.0. Now you can use `AdEMAMix8bit` and `PagedAdEMAMix8bit` in the training script. PR [#1640](https://github.com/kohya-ss/sd-scripts/pull/1640) Thanks to sdbds! - There is no abbreviation, so please specify the full path like `--optimizer_type bitsandbytes.optim.AdEMAMix8bit` (not bnb but bitsandbytes). diff --git a/docs/config_README-en.md b/docs/config_README-en.md index 83bea329b..66a50dc09 100644 --- a/docs/config_README-en.md +++ b/docs/config_README-en.md @@ -128,6 +128,8 @@ These are options related to the configuration of the data set. They cannot be d * `batch_size` * This corresponds to the command-line argument `--train_batch_size`. +* `max_bucket_reso`, `min_bucket_reso` + * Specify the maximum and minimum resolutions of the bucket. It must be divisible by `bucket_reso_steps`. These settings are fixed per dataset. That means that subsets belonging to the same dataset will share these settings. For example, if you want to prepare datasets with different resolutions, you can define them as separate datasets as shown in the example above, and set different resolutions for each. 
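
The rounding described above can be summarized with a small standalone sketch (the helper name is illustrative; the training code also logs a warning whenever it adjusts a value): `min_bucket_reso` is rounded down and `max_bucket_reso` is rounded up to the nearest multiple of `bucket_reso_steps`.

```python
def adjust_bucket_reso(min_reso: int, max_reso: int, steps: int) -> tuple[int, int]:
    """Round min down and max up to the nearest multiple of `steps`."""
    if min_reso % steps != 0:
        min_reso -= min_reso % steps            # e.g. 250 -> 192 for steps=64
    if max_reso % steps != 0:
        max_reso += steps - max_reso % steps    # e.g. 1000 -> 1024 for steps=64
    return min_reso, max_reso


# adjust_bucket_reso(250, 1000, 64) == (192, 1024)
```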
diff --git a/docs/config_README-ja.md b/docs/config_README-ja.md index cc74c341b..0ed95e0eb 100644 --- a/docs/config_README-ja.md +++ b/docs/config_README-ja.md @@ -118,6 +118,8 @@ DreamBooth の手法と fine tuning の手法の両方とも利用可能な学 * `batch_size` * コマンドライン引数の `--train_batch_size` と同等です。 +* `max_bucket_reso`, `min_bucket_reso` + * bucketの最大、最小解像度を指定します。`bucket_reso_steps` で割り切れる必要があります。 これらの設定はデータセットごとに固定です。 つまり、データセットに所属するサブセットはこれらの設定を共有することになります。 diff --git a/fine_tune.py b/fine_tune.py index d865cd2de..b556672d2 100644 --- a/fine_tune.py +++ b/fine_tune.py @@ -91,6 +91,8 @@ def train(args): ds_for_collator = train_dataset_group if args.max_data_loader_n_workers == 0 else None collator = train_util.collator_class(current_epoch, current_step, ds_for_collator) + train_dataset_group.verify_bucket_reso_steps(64) + if args.debug_dataset: train_util.debug_dataset(train_dataset_group) return diff --git a/library/train_util.py b/library/train_util.py index 47c367683..0cb6383a4 100644 --- a/library/train_util.py +++ b/library/train_util.py @@ -653,6 +653,34 @@ def __init__( # caching self.caching_mode = None # None, 'latents', 'text' + def adjust_min_max_bucket_reso_by_steps( + self, resolution: Tuple[int, int], min_bucket_reso: int, max_bucket_reso: int, bucket_reso_steps: int + ) -> Tuple[int, int]: + # make min/max bucket reso to be multiple of bucket_reso_steps + if min_bucket_reso % bucket_reso_steps != 0: + adjusted_min_bucket_reso = min_bucket_reso - min_bucket_reso % bucket_reso_steps + logger.warning( + f"min_bucket_reso is adjusted to be multiple of bucket_reso_steps" + f" / min_bucket_resoがbucket_reso_stepsの倍数になるように調整されました: {min_bucket_reso} -> {adjusted_min_bucket_reso}" + ) + min_bucket_reso = adjusted_min_bucket_reso + if max_bucket_reso % bucket_reso_steps != 0: + adjusted_max_bucket_reso = max_bucket_reso + bucket_reso_steps - max_bucket_reso % bucket_reso_steps + logger.warning( + f"max_bucket_reso is adjusted to be multiple of bucket_reso_steps" + f" / max_bucket_resoがbucket_reso_stepsの倍数になるように調整されました: {max_bucket_reso} -> {adjusted_max_bucket_reso}" + ) + max_bucket_reso = adjusted_max_bucket_reso + + assert ( + min(resolution) >= min_bucket_reso + ), f"min_bucket_reso must be equal or less than resolution / min_bucket_resoは最小解像度より大きくできません。解像度を大きくするかmin_bucket_resoを小さくしてください" + assert ( + max(resolution) <= max_bucket_reso + ), f"max_bucket_reso must be equal or greater than resolution / max_bucket_resoは最大解像度より小さくできません。解像度を小さくするかmin_bucket_resoを大きくしてください" + + return min_bucket_reso, max_bucket_reso + def set_seed(self, seed): self.seed = seed @@ -1533,12 +1561,9 @@ def __init__( self.enable_bucket = enable_bucket if self.enable_bucket: - assert ( - min(resolution) >= min_bucket_reso - ), f"min_bucket_reso must be equal or less than resolution / min_bucket_resoは最小解像度より大きくできません。解像度を大きくするかmin_bucket_resoを小さくしてください" - assert ( - max(resolution) <= max_bucket_reso - ), f"max_bucket_reso must be equal or greater than resolution / max_bucket_resoは最大解像度より小さくできません。解像度を小さくするかmin_bucket_resoを大きくしてください" + min_bucket_reso, max_bucket_reso = self.adjust_min_max_bucket_reso_by_steps( + resolution, min_bucket_reso, max_bucket_reso, bucket_reso_steps + ) self.min_bucket_reso = min_bucket_reso self.max_bucket_reso = max_bucket_reso self.bucket_reso_steps = bucket_reso_steps @@ -1901,6 +1926,9 @@ def __init__( self.enable_bucket = enable_bucket if self.enable_bucket: + min_bucket_reso, max_bucket_reso = self.adjust_min_max_bucket_reso_by_steps( + resolution, min_bucket_reso, max_bucket_reso, 
bucket_reso_steps + ) self.min_bucket_reso = min_bucket_reso self.max_bucket_reso = max_bucket_reso self.bucket_reso_steps = bucket_reso_steps diff --git a/train_controlnet.py b/train_controlnet.py index c9ac6c5a8..6938c4bcc 100644 --- a/train_controlnet.py +++ b/train_controlnet.py @@ -107,6 +107,8 @@ def train(args): ds_for_collator = train_dataset_group if args.max_data_loader_n_workers == 0 else None collator = train_util.collator_class(current_epoch, current_step, ds_for_collator) + train_dataset_group.verify_bucket_reso_steps(64) + if args.debug_dataset: train_util.debug_dataset(train_dataset_group) return diff --git a/train_db.py b/train_db.py index 39d8ea6ed..2c7f02582 100644 --- a/train_db.py +++ b/train_db.py @@ -93,6 +93,8 @@ def train(args): if args.no_token_padding: train_dataset_group.disable_token_padding() + train_dataset_group.verify_bucket_reso_steps(64) + if args.debug_dataset: train_util.debug_dataset(train_dataset_group) return diff --git a/train_network.py b/train_network.py index 7ba073855..044ec3aa8 100644 --- a/train_network.py +++ b/train_network.py @@ -95,7 +95,7 @@ def generate_step_logs( return logs def assert_extra_args(self, args, train_dataset_group): - pass + train_dataset_group.verify_bucket_reso_steps(64) def load_target_model(self, args, weight_dtype, accelerator): text_encoder, vae, unet, _ = train_util.load_target_model(args, weight_dtype, accelerator) diff --git a/train_textual_inversion.py b/train_textual_inversion.py index ade077c36..96e7bd509 100644 --- a/train_textual_inversion.py +++ b/train_textual_inversion.py @@ -99,7 +99,7 @@ def __init__(self): self.is_sdxl = False def assert_extra_args(self, args, train_dataset_group): - pass + train_dataset_group.verify_bucket_reso_steps(64) def load_target_model(self, args, weight_dtype, accelerator): text_encoder, vae, unet, _ = train_util.load_target_model(args, weight_dtype, accelerator) From 1567549220b5936af0c534ca23656ecd2f4882f0 Mon Sep 17 00:00:00 2001 From: Kohya S Date: Sun, 29 Sep 2024 09:51:36 +0900 Subject: [PATCH 90/97] update help text #1632 --- library/train_util.py | 16 ++++++++++++++-- 1 file changed, 14 insertions(+), 2 deletions(-) diff --git a/library/train_util.py b/library/train_util.py index 0cb6383a4..422dceca2 100644 --- a/library/train_util.py +++ b/library/train_util.py @@ -3865,8 +3865,20 @@ def add_dataset_arguments( action="store_true", help="enable buckets for multi aspect ratio training / 複数解像度学習のためのbucketを有効にする", ) - parser.add_argument("--min_bucket_reso", type=int, default=256, help="minimum resolution for buckets / bucketの最小解像度") - parser.add_argument("--max_bucket_reso", type=int, default=1024, help="maximum resolution for buckets / bucketの最大解像度") + parser.add_argument( + "--min_bucket_reso", + type=int, + default=256, + help="minimum resolution for buckets, must be divisible by bucket_reso_steps " + " / bucketの最小解像度、bucket_reso_stepsで割り切れる必要があります", + ) + parser.add_argument( + "--max_bucket_reso", + type=int, + default=1024, + help="maximum resolution for buckets, must be divisible by bucket_reso_steps " + " / bucketの最大解像度、bucket_reso_stepsで割り切れる必要があります", + ) parser.add_argument( "--bucket_reso_steps", type=int, From 012e7e63a5b1acdf69c72eee4cb330a5a6defc41 Mon Sep 17 00:00:00 2001 From: Kohya S Date: Sun, 29 Sep 2024 23:18:16 +0900 Subject: [PATCH 91/97] fix to work linear/cosine scheduler closes #1651 ref #1393 --- library/train_util.py | 9 +++++++++ 1 file changed, 9 insertions(+) diff --git a/library/train_util.py b/library/train_util.py index 
422dceca2..27910dc90 100644 --- a/library/train_util.py +++ b/library/train_util.py @@ -4496,6 +4496,15 @@ def wrap_check_needless_num_warmup_steps(return_vals): **lr_scheduler_kwargs, ) + # these schedulers do not require `num_decay_steps` + if name == SchedulerType.LINEAR or name == SchedulerType.COSINE: + return schedule_func( + optimizer, + num_warmup_steps=num_warmup_steps, + num_training_steps=num_training_steps, + **lr_scheduler_kwargs, + ) + # All other schedulers require `num_decay_steps` if num_decay_steps is None: raise ValueError(f"{name} requires `num_decay_steps`, please provide that argument.") From 8fc30f820595f80ec3f09738cc4cf01f441c41b7 Mon Sep 17 00:00:00 2001 From: catboxanon <122327233+catboxanon@users.noreply.github.com> Date: Mon, 21 Oct 2024 07:34:33 -0400 Subject: [PATCH 92/97] Fix training for V-pred and ztSNR 1) Updates debiased estimation loss function for V-pred. 2) Prevents now-deprecated scaling of loss if ztSNR is enabled. --- fine_tune.py | 4 ++-- library/custom_train_functions.py | 7 +++++-- library/train_util.py | 5 +++++ sdxl_train.py | 4 ++-- sdxl_train_control_net_lllite.py | 4 ++-- sdxl_train_control_net_lllite_old.py | 4 ++-- train_db.py | 4 ++-- train_network.py | 4 ++-- train_textual_inversion.py | 4 ++-- train_textual_inversion_XTI.py | 4 ++-- 10 files changed, 26 insertions(+), 18 deletions(-) diff --git a/fine_tune.py b/fine_tune.py index b556672d2..19a35229f 100644 --- a/fine_tune.py +++ b/fine_tune.py @@ -383,10 +383,10 @@ def fn_recursive_set_mem_eff(module: torch.nn.Module): if args.min_snr_gamma: loss = apply_snr_weight(loss, timesteps, noise_scheduler, args.min_snr_gamma, args.v_parameterization) - if args.scale_v_pred_loss_like_noise_pred: + if args.scale_v_pred_loss_like_noise_pred and not args.zero_terminal_snr: loss = scale_v_prediction_loss_like_noise_prediction(loss, timesteps, noise_scheduler) if args.debiased_estimation_loss: - loss = apply_debiased_estimation(loss, timesteps, noise_scheduler) + loss = apply_debiased_estimation(loss, timesteps, noise_scheduler, args.v_parameterization) loss = loss.mean() # mean over batch dimension else: diff --git a/library/custom_train_functions.py b/library/custom_train_functions.py index 2a513dc5b..faf443048 100644 --- a/library/custom_train_functions.py +++ b/library/custom_train_functions.py @@ -96,10 +96,13 @@ def add_v_prediction_like_loss(loss, timesteps, noise_scheduler, v_pred_like_los return loss -def apply_debiased_estimation(loss, timesteps, noise_scheduler): +def apply_debiased_estimation(loss, timesteps, noise_scheduler, v_prediction=False): snr_t = torch.stack([noise_scheduler.all_snr[t] for t in timesteps]) # batch_size snr_t = torch.minimum(snr_t, torch.ones_like(snr_t) * 1000) # if timestep is 0, snr_t is inf, so limit it to 1000 - weight = 1 / torch.sqrt(snr_t) + if v_prediction: + weight = 1 / (snr_t + 1) + else: + weight = 1 / torch.sqrt(snr_t) loss = weight * loss return loss diff --git a/library/train_util.py b/library/train_util.py index 27910dc90..adb983d2f 100644 --- a/library/train_util.py +++ b/library/train_util.py @@ -3731,6 +3731,11 @@ def verify_training_args(args: argparse.Namespace): raise ValueError( "scale_v_pred_loss_like_noise_pred can be enabled only with v_parameterization / scale_v_pred_loss_like_noise_predはv_parameterizationが有効なときのみ有効にできます" ) + + if args.scale_v_pred_loss_like_noise_pred and args.zero_terminal_snr: + raise ValueError( + "zero_terminal_snr enabled. 
scale_v_pred_loss_like_noise_pred will not be used / zero_terminal_snrが有効です。scale_v_pred_loss_like_noise_predは使用されません" + ) if args.v_pred_like_loss and args.v_parameterization: raise ValueError( diff --git a/sdxl_train.py b/sdxl_train.py index e0a8f2b2e..44ee9233f 100644 --- a/sdxl_train.py +++ b/sdxl_train.py @@ -725,12 +725,12 @@ def optimizer_hook(parameter: torch.Tensor): if args.min_snr_gamma: loss = apply_snr_weight(loss, timesteps, noise_scheduler, args.min_snr_gamma, args.v_parameterization) - if args.scale_v_pred_loss_like_noise_pred: + if args.scale_v_pred_loss_like_noise_pred and not args.zero_terminal_snr: loss = scale_v_prediction_loss_like_noise_prediction(loss, timesteps, noise_scheduler) if args.v_pred_like_loss: loss = add_v_prediction_like_loss(loss, timesteps, noise_scheduler, args.v_pred_like_loss) if args.debiased_estimation_loss: - loss = apply_debiased_estimation(loss, timesteps, noise_scheduler) + loss = apply_debiased_estimation(loss, timesteps, noise_scheduler, args.v_parameterization) loss = loss.mean() # mean over batch dimension else: diff --git a/sdxl_train_control_net_lllite.py b/sdxl_train_control_net_lllite.py index 5ff060a9f..436f0e194 100644 --- a/sdxl_train_control_net_lllite.py +++ b/sdxl_train_control_net_lllite.py @@ -474,12 +474,12 @@ def remove_model(old_ckpt_name): if args.min_snr_gamma: loss = apply_snr_weight(loss, timesteps, noise_scheduler, args.min_snr_gamma, args.v_parameterization) - if args.scale_v_pred_loss_like_noise_pred: + if args.scale_v_pred_loss_like_noise_pred and not args.zero_terminal_snr: loss = scale_v_prediction_loss_like_noise_prediction(loss, timesteps, noise_scheduler) if args.v_pred_like_loss: loss = add_v_prediction_like_loss(loss, timesteps, noise_scheduler, args.v_pred_like_loss) if args.debiased_estimation_loss: - loss = apply_debiased_estimation(loss, timesteps, noise_scheduler) + loss = apply_debiased_estimation(loss, timesteps, noise_scheduler, args.v_parameterization) loss = loss.mean() # 平均なのでbatch_sizeで割る必要なし diff --git a/sdxl_train_control_net_lllite_old.py b/sdxl_train_control_net_lllite_old.py index 292a0463a..8fba9eba6 100644 --- a/sdxl_train_control_net_lllite_old.py +++ b/sdxl_train_control_net_lllite_old.py @@ -434,12 +434,12 @@ def remove_model(old_ckpt_name): if args.min_snr_gamma: loss = apply_snr_weight(loss, timesteps, noise_scheduler, args.min_snr_gamma, args.v_parameterization) - if args.scale_v_pred_loss_like_noise_pred: + if args.scale_v_pred_loss_like_noise_pred and not args.zero_terminal_snr: loss = scale_v_prediction_loss_like_noise_prediction(loss, timesteps, noise_scheduler) if args.v_pred_like_loss: loss = add_v_prediction_like_loss(loss, timesteps, noise_scheduler, args.v_pred_like_loss) if args.debiased_estimation_loss: - loss = apply_debiased_estimation(loss, timesteps, noise_scheduler) + loss = apply_debiased_estimation(loss, timesteps, noise_scheduler, args.v_parameterization) loss = loss.mean() # 平均なのでbatch_sizeで割る必要なし diff --git a/train_db.py b/train_db.py index 2c7f02582..d5a94a565 100644 --- a/train_db.py +++ b/train_db.py @@ -370,10 +370,10 @@ def train(args): if args.min_snr_gamma: loss = apply_snr_weight(loss, timesteps, noise_scheduler, args.min_snr_gamma, args.v_parameterization) - if args.scale_v_pred_loss_like_noise_pred: + if args.scale_v_pred_loss_like_noise_pred and not args.zero_terminal_snr: loss = scale_v_prediction_loss_like_noise_prediction(loss, timesteps, noise_scheduler) if args.debiased_estimation_loss: - loss = apply_debiased_estimation(loss, timesteps, 
noise_scheduler) + loss = apply_debiased_estimation(loss, timesteps, noise_scheduler, args.v_parameterization) loss = loss.mean() # 平均なのでbatch_sizeで割る必要なし diff --git a/train_network.py b/train_network.py index 044ec3aa8..790fbfc9d 100644 --- a/train_network.py +++ b/train_network.py @@ -993,12 +993,12 @@ def remove_model(old_ckpt_name): if args.min_snr_gamma: loss = apply_snr_weight(loss, timesteps, noise_scheduler, args.min_snr_gamma, args.v_parameterization) - if args.scale_v_pred_loss_like_noise_pred: + if args.scale_v_pred_loss_like_noise_pred and not args.zero_terminal_snr: loss = scale_v_prediction_loss_like_noise_prediction(loss, timesteps, noise_scheduler) if args.v_pred_like_loss: loss = add_v_prediction_like_loss(loss, timesteps, noise_scheduler, args.v_pred_like_loss) if args.debiased_estimation_loss: - loss = apply_debiased_estimation(loss, timesteps, noise_scheduler) + loss = apply_debiased_estimation(loss, timesteps, noise_scheduler, args.v_parameterization) loss = loss.mean() # 平均なのでbatch_sizeで割る必要なし diff --git a/train_textual_inversion.py b/train_textual_inversion.py index 96e7bd509..10b34db5e 100644 --- a/train_textual_inversion.py +++ b/train_textual_inversion.py @@ -598,12 +598,12 @@ def remove_model(old_ckpt_name): if args.min_snr_gamma: loss = apply_snr_weight(loss, timesteps, noise_scheduler, args.min_snr_gamma, args.v_parameterization) - if args.scale_v_pred_loss_like_noise_pred: + if args.scale_v_pred_loss_like_noise_pred and not args.zero_terminal_snr: loss = scale_v_prediction_loss_like_noise_prediction(loss, timesteps, noise_scheduler) if args.v_pred_like_loss: loss = add_v_prediction_like_loss(loss, timesteps, noise_scheduler, args.v_pred_like_loss) if args.debiased_estimation_loss: - loss = apply_debiased_estimation(loss, timesteps, noise_scheduler) + loss = apply_debiased_estimation(loss, timesteps, noise_scheduler, args.v_parameterization) loss = loss.mean() # 平均なのでbatch_sizeで割る必要なし diff --git a/train_textual_inversion_XTI.py b/train_textual_inversion_XTI.py index efb59137b..084b90c60 100644 --- a/train_textual_inversion_XTI.py +++ b/train_textual_inversion_XTI.py @@ -483,10 +483,10 @@ def remove_model(old_ckpt_name): loss = loss * loss_weights if args.min_snr_gamma: loss = apply_snr_weight(loss, timesteps, noise_scheduler, args.min_snr_gamma, args.v_parameterization) - if args.scale_v_pred_loss_like_noise_pred: + if args.scale_v_pred_loss_like_noise_pred and not args.zero_terminal_snr: loss = scale_v_prediction_loss_like_noise_prediction(loss, timesteps, noise_scheduler) if args.debiased_estimation_loss: - loss = apply_debiased_estimation(loss, timesteps, noise_scheduler) + loss = apply_debiased_estimation(loss, timesteps, noise_scheduler, args.v_parameterization) loss = loss.mean() # 平均なのでbatch_sizeで割る必要なし From e1b63c2249345e4f14c10cbb252da68157ac13b7 Mon Sep 17 00:00:00 2001 From: catboxanon <122327233+catboxanon@users.noreply.github.com> Date: Mon, 21 Oct 2024 08:12:53 -0400 Subject: [PATCH 93/97] Only add warning for deprecated scaling vpred loss function --- fine_tune.py | 2 +- library/train_util.py | 11 ++++++----- sdxl_train.py | 2 +- sdxl_train_control_net_lllite.py | 2 +- sdxl_train_control_net_lllite_old.py | 2 +- train_db.py | 2 +- train_network.py | 2 +- train_textual_inversion.py | 2 +- train_textual_inversion_XTI.py | 2 +- 9 files changed, 14 insertions(+), 13 deletions(-) diff --git a/fine_tune.py b/fine_tune.py index 19a35229f..c79f97d25 100644 --- a/fine_tune.py +++ b/fine_tune.py @@ -383,7 +383,7 @@ def fn_recursive_set_mem_eff(module: 
torch.nn.Module): if args.min_snr_gamma: loss = apply_snr_weight(loss, timesteps, noise_scheduler, args.min_snr_gamma, args.v_parameterization) - if args.scale_v_pred_loss_like_noise_pred and not args.zero_terminal_snr: + if args.scale_v_pred_loss_like_noise_pred: loss = scale_v_prediction_loss_like_noise_prediction(loss, timesteps, noise_scheduler) if args.debiased_estimation_loss: loss = apply_debiased_estimation(loss, timesteps, noise_scheduler, args.v_parameterization) diff --git a/library/train_util.py b/library/train_util.py index adb983d2f..f479dcc64 100644 --- a/library/train_util.py +++ b/library/train_util.py @@ -3727,15 +3727,16 @@ def verify_training_args(args: argparse.Namespace): if args.adaptive_noise_scale is not None and args.noise_offset is None: raise ValueError("adaptive_noise_scale requires noise_offset / adaptive_noise_scaleを使用するにはnoise_offsetが必要です") + if args.scale_v_pred_loss_like_noise_pred: + logger.warning( + f"scale_v_pred_loss_like_noise_pred is deprecated. it is suggested to use min_snr_gamma or debiased_estimation_loss" + + " / scale_v_pred_loss_like_noise_pred は非推奨です。min_snr_gammaまたはdebiased_estimation_lossを使用することをお勧めします" + ) + if args.scale_v_pred_loss_like_noise_pred and not args.v_parameterization: raise ValueError( "scale_v_pred_loss_like_noise_pred can be enabled only with v_parameterization / scale_v_pred_loss_like_noise_predはv_parameterizationが有効なときのみ有効にできます" ) - - if args.scale_v_pred_loss_like_noise_pred and args.zero_terminal_snr: - raise ValueError( - "zero_terminal_snr enabled. scale_v_pred_loss_like_noise_pred will not be used / zero_terminal_snrが有効です。scale_v_pred_loss_like_noise_predは使用されません" - ) if args.v_pred_like_loss and args.v_parameterization: raise ValueError( diff --git a/sdxl_train.py b/sdxl_train.py index 44ee9233f..b533b2749 100644 --- a/sdxl_train.py +++ b/sdxl_train.py @@ -725,7 +725,7 @@ def optimizer_hook(parameter: torch.Tensor): if args.min_snr_gamma: loss = apply_snr_weight(loss, timesteps, noise_scheduler, args.min_snr_gamma, args.v_parameterization) - if args.scale_v_pred_loss_like_noise_pred and not args.zero_terminal_snr: + if args.scale_v_pred_loss_like_noise_pred: loss = scale_v_prediction_loss_like_noise_prediction(loss, timesteps, noise_scheduler) if args.v_pred_like_loss: loss = add_v_prediction_like_loss(loss, timesteps, noise_scheduler, args.v_pred_like_loss) diff --git a/sdxl_train_control_net_lllite.py b/sdxl_train_control_net_lllite.py index 436f0e194..0e67cde5c 100644 --- a/sdxl_train_control_net_lllite.py +++ b/sdxl_train_control_net_lllite.py @@ -474,7 +474,7 @@ def remove_model(old_ckpt_name): if args.min_snr_gamma: loss = apply_snr_weight(loss, timesteps, noise_scheduler, args.min_snr_gamma, args.v_parameterization) - if args.scale_v_pred_loss_like_noise_pred and not args.zero_terminal_snr: + if args.scale_v_pred_loss_like_noise_pred: loss = scale_v_prediction_loss_like_noise_prediction(loss, timesteps, noise_scheduler) if args.v_pred_like_loss: loss = add_v_prediction_like_loss(loss, timesteps, noise_scheduler, args.v_pred_like_loss) diff --git a/sdxl_train_control_net_lllite_old.py b/sdxl_train_control_net_lllite_old.py index 8fba9eba6..4a01f9e2c 100644 --- a/sdxl_train_control_net_lllite_old.py +++ b/sdxl_train_control_net_lllite_old.py @@ -434,7 +434,7 @@ def remove_model(old_ckpt_name): if args.min_snr_gamma: loss = apply_snr_weight(loss, timesteps, noise_scheduler, args.min_snr_gamma, args.v_parameterization) - if args.scale_v_pred_loss_like_noise_pred and not args.zero_terminal_snr: + if 
args.scale_v_pred_loss_like_noise_pred: loss = scale_v_prediction_loss_like_noise_prediction(loss, timesteps, noise_scheduler) if args.v_pred_like_loss: loss = add_v_prediction_like_loss(loss, timesteps, noise_scheduler, args.v_pred_like_loss) diff --git a/train_db.py b/train_db.py index d5a94a565..e7cf3cde3 100644 --- a/train_db.py +++ b/train_db.py @@ -370,7 +370,7 @@ def train(args): if args.min_snr_gamma: loss = apply_snr_weight(loss, timesteps, noise_scheduler, args.min_snr_gamma, args.v_parameterization) - if args.scale_v_pred_loss_like_noise_pred and not args.zero_terminal_snr: + if args.scale_v_pred_loss_like_noise_pred: loss = scale_v_prediction_loss_like_noise_prediction(loss, timesteps, noise_scheduler) if args.debiased_estimation_loss: loss = apply_debiased_estimation(loss, timesteps, noise_scheduler, args.v_parameterization) diff --git a/train_network.py b/train_network.py index 790fbfc9d..7bf125dca 100644 --- a/train_network.py +++ b/train_network.py @@ -993,7 +993,7 @@ def remove_model(old_ckpt_name): if args.min_snr_gamma: loss = apply_snr_weight(loss, timesteps, noise_scheduler, args.min_snr_gamma, args.v_parameterization) - if args.scale_v_pred_loss_like_noise_pred and not args.zero_terminal_snr: + if args.scale_v_pred_loss_like_noise_pred: loss = scale_v_prediction_loss_like_noise_prediction(loss, timesteps, noise_scheduler) if args.v_pred_like_loss: loss = add_v_prediction_like_loss(loss, timesteps, noise_scheduler, args.v_pred_like_loss) diff --git a/train_textual_inversion.py b/train_textual_inversion.py index 10b34db5e..37349da7d 100644 --- a/train_textual_inversion.py +++ b/train_textual_inversion.py @@ -598,7 +598,7 @@ def remove_model(old_ckpt_name): if args.min_snr_gamma: loss = apply_snr_weight(loss, timesteps, noise_scheduler, args.min_snr_gamma, args.v_parameterization) - if args.scale_v_pred_loss_like_noise_pred and not args.zero_terminal_snr: + if args.scale_v_pred_loss_like_noise_pred: loss = scale_v_prediction_loss_like_noise_prediction(loss, timesteps, noise_scheduler) if args.v_pred_like_loss: loss = add_v_prediction_like_loss(loss, timesteps, noise_scheduler, args.v_pred_like_loss) diff --git a/train_textual_inversion_XTI.py b/train_textual_inversion_XTI.py index 084b90c60..fac0787b9 100644 --- a/train_textual_inversion_XTI.py +++ b/train_textual_inversion_XTI.py @@ -483,7 +483,7 @@ def remove_model(old_ckpt_name): loss = loss * loss_weights if args.min_snr_gamma: loss = apply_snr_weight(loss, timesteps, noise_scheduler, args.min_snr_gamma, args.v_parameterization) - if args.scale_v_pred_loss_like_noise_pred and not args.zero_terminal_snr: + if args.scale_v_pred_loss_like_noise_pred: loss = scale_v_prediction_loss_like_noise_prediction(loss, timesteps, noise_scheduler) if args.debiased_estimation_loss: loss = apply_debiased_estimation(loss, timesteps, noise_scheduler, args.v_parameterization) From 0e7c5929336173e30d7932c0706eaf61a7d396f4 Mon Sep 17 00:00:00 2001 From: catboxanon <122327233+catboxanon@users.noreply.github.com> Date: Tue, 22 Oct 2024 11:19:34 -0400 Subject: [PATCH 94/97] Remove scale_v_pred_loss_like_noise_pred deprecation https://github.com/kohya-ss/sd-scripts/pull/1715#issuecomment-2427876376 --- library/train_util.py | 6 ------ 1 file changed, 6 deletions(-) diff --git a/library/train_util.py b/library/train_util.py index f479dcc64..27910dc90 100644 --- a/library/train_util.py +++ b/library/train_util.py @@ -3727,12 +3727,6 @@ def verify_training_args(args: argparse.Namespace): if args.adaptive_noise_scale is not None and 
args.noise_offset is None: raise ValueError("adaptive_noise_scale requires noise_offset / adaptive_noise_scaleを使用するにはnoise_offsetが必要です") - if args.scale_v_pred_loss_like_noise_pred: - logger.warning( - f"scale_v_pred_loss_like_noise_pred is deprecated. it is suggested to use min_snr_gamma or debiased_estimation_loss" - + " / scale_v_pred_loss_like_noise_pred は非推奨です。min_snr_gammaまたはdebiased_estimation_lossを使用することをお勧めします" - ) - if args.scale_v_pred_loss_like_noise_pred and not args.v_parameterization: raise ValueError( "scale_v_pred_loss_like_noise_pred can be enabled only with v_parameterization / scale_v_pred_loss_like_noise_predはv_parameterizationが有効なときのみ有効にできます" From be14c062674973d0e4fee1eb4527e04707bb72b8 Mon Sep 17 00:00:00 2001 From: catboxanon <122327233+catboxanon@users.noreply.github.com> Date: Tue, 22 Oct 2024 12:13:51 -0400 Subject: [PATCH 95/97] Remove v-pred warnings Different model architectures, such as SDXL, can take advantage of v-pred. It doesn't make sense to include these warnings anymore. --- gen_img.py | 2 -- gen_img_diffusers.py | 2 -- library/train_util.py | 4 ---- 3 files changed, 8 deletions(-) diff --git a/gen_img.py b/gen_img.py index 59bcd5b09..9427a8940 100644 --- a/gen_img.py +++ b/gen_img.py @@ -1495,8 +1495,6 @@ def main(args): highres_fix = args.highres_fix_scale is not None # assert not highres_fix or args.image_path is None, f"highres_fix doesn't work with img2img / highres_fixはimg2imgと同時に使えません" - if args.v_parameterization and not args.v2: - logger.warning("v_parameterization should be with v2 / v1でv_parameterizationを使用することは想定されていません") if args.v2 and args.clip_skip is not None: logger.warning("v2 with clip_skip will be unexpected / v2でclip_skipを使用することは想定されていません") diff --git a/gen_img_diffusers.py b/gen_img_diffusers.py index 2c40f1a06..04db4e9b4 100644 --- a/gen_img_diffusers.py +++ b/gen_img_diffusers.py @@ -2216,8 +2216,6 @@ def main(args): highres_fix = args.highres_fix_scale is not None # assert not highres_fix or args.image_path is None, f"highres_fix doesn't work with img2img / highres_fixはimg2imgと同時に使えません" - if args.v_parameterization and not args.v2: - logger.warning("v_parameterization should be with v2 / v1でv_parameterizationを使用することは想定されていません") if args.v2 and args.clip_skip is not None: logger.warning("v2 with clip_skip will be unexpected / v2でclip_skipを使用することは想定されていません") diff --git a/library/train_util.py b/library/train_util.py index 27910dc90..100ef475d 100644 --- a/library/train_util.py +++ b/library/train_util.py @@ -3698,10 +3698,6 @@ def verify_training_args(args: argparse.Namespace): global HIGH_VRAM HIGH_VRAM = True - if args.v_parameterization and not args.v2: - logger.warning( - "v_parameterization should be with v2 not v1 or sdxl / v1やsdxlでv_parameterizationを使用することは想定されていません" - ) if args.v2 and args.clip_skip is not None: logger.warning("v2 with clip_skip will be unexpected / v2でclip_skipを使用することは想定されていません") From b1e6504007aca20d15155d5c9fe880fb5e0002b8 Mon Sep 17 00:00:00 2001 From: Kohya S Date: Fri, 25 Oct 2024 18:56:25 +0900 Subject: [PATCH 96/97] update README --- README.md | 3 +++ 1 file changed, 3 insertions(+) diff --git a/README.md b/README.md index de5cddb92..ce28d0049 100644 --- a/README.md +++ b/README.md @@ -143,6 +143,9 @@ The majority of scripts is licensed under ASL 2.0 (including codes from Diffuser - bitsandbytes, transformers, accelerate and huggingface_hub are updated. - If you encounter any issues, please report them. 
+- Fixed a bug where the loss weight was incorrect when `--debiased_estimation_loss` was specified with `--v_parameterization`. PR [#1715](https://github.com/kohya-ss/sd-scripts/pull/1715) Thanks to catboxanon! See [the PR](https://github.com/kohya-ss/sd-scripts/pull/1715) for details. + - Removed the warning when `--v_parameterization` is specified in SDXL and SD1.5. PR [#1717](https://github.com/kohya-ss/sd-scripts/pull/1717) + - There was a bug where the min_bucket_reso/max_bucket_reso in the dataset configuration did not create the correct resolution bucket if it was not divisible by bucket_reso_steps. These values are now warned and automatically rounded to a divisible value. Thanks to Maru-mee for raising the issue. Related PR [#1632](https://github.com/kohya-ss/sd-scripts/pull/1632) - `bitsandbytes` is updated to 0.44.0. Now you can use `AdEMAMix8bit` and `PagedAdEMAMix8bit` in the training script. PR [#1640](https://github.com/kohya-ss/sd-scripts/pull/1640) Thanks to sdbds! From 345daaa986cdbcaedb6840997390f3d86846d677 Mon Sep 17 00:00:00 2001 From: Kohya S Date: Fri, 17 Jan 2025 23:22:38 +0900 Subject: [PATCH 97/97] update README for merging --- README-ja.md | 8 ++++++-- README.md | 23 +++++++++++++++++++---- 2 files changed, 25 insertions(+), 6 deletions(-) diff --git a/README-ja.md b/README-ja.md index 27cc56c34..60249f61e 100644 --- a/README-ja.md +++ b/README-ja.md @@ -36,6 +36,8 @@ Python 3.10.6およびGitが必要です。 - Python 3.10.6: https://www.python.org/ftp/python/3.10.6/python-3.10.6-amd64.exe - git: https://git-scm.com/download/win +Python 3.10.x、3.11.x、3.12.xでも恐らく動作しますが、3.10.6でテストしています。 + PowerShellを使う場合、venvを使えるようにするためには以下の手順でセキュリティ設定を変更してください。 (venvに限らずスクリプトの実行が可能になりますので注意してください。) @@ -45,7 +47,7 @@ PowerShellを使う場合、venvを使えるようにするためには以下の ## Windows環境でのインストール -スクリプトはPyTorch 2.1.2でテストしています。PyTorch 2.0.1、1.12.1でも動作すると思われます。 +スクリプトはPyTorch 2.1.2でテストしています。PyTorch 2.2以降でも恐らく動作します。 (なお、python -m venv~の行で「python」とだけ表示された場合、py -m venv~のようにpythonをpyに変更してください。) @@ -67,10 +69,12 @@ accelerate config コマンドプロンプトでも同一です。 -注:`bitsandbytes==0.43.0`、`prodigyopt==1.0`、`lion-pytorch==0.0.6` は `requirements.txt` に含まれるようになりました。他のバージョンを使う場合は適宜インストールしてください。 +注:`bitsandbytes==0.44.0`、`prodigyopt==1.0`、`lion-pytorch==0.0.6` は `requirements.txt` に含まれるようになりました。他のバージョンを使う場合は適宜インストールしてください。 この例では PyTorch および xfomers は2.1.2/CUDA 11.8版をインストールします。CUDA 12.1版やPyTorch 1.12.1を使う場合は適宜書き換えください。たとえば CUDA 12.1版の場合は `pip install torch==2.1.2 torchvision==0.16.2 --index-url https://download.pytorch.org/whl/cu121` および `pip install xformers==0.0.23.post1 --index-url https://download.pytorch.org/whl/cu121` としてください。 +PyTorch 2.2以降を用いる場合は、`torch==2.1.2` と `torchvision==0.16.2` 、および `xformers==0.0.23.post1` を適宜変更してください。 + accelerate configの質問には以下のように答えてください。(bf16で学習する場合、最後の質問にはbf16と答えてください。) ```txt diff --git a/README.md b/README.md index 73564cb29..6beee5e3a 100644 --- a/README.md +++ b/README.md @@ -25,7 +25,7 @@ This repository contains the scripts for: The file does not contain requirements for PyTorch. Because the version of PyTorch depends on the environment, it is not included in the file. Please install PyTorch first according to the environment. See installation instructions below. -The scripts are tested with Pytorch 2.1.2. 2.0.1 and 1.12.1 is not tested but should work. +The scripts are tested with Pytorch 2.1.2. PyTorch 2.2 or later will work. Please install the appropriate version of PyTorch and xformers. 
## Links to usage documentation @@ -52,6 +52,8 @@ Python 3.10.6 and Git: - Python 3.10.6: https://www.python.org/ftp/python/3.10.6/python-3.10.6-amd64.exe - git: https://git-scm.com/download/win +Python 3.10.x, 3.11.x, and 3.12.x will work but not tested. + Give unrestricted script access to powershell so venv can work: - Open an administrator powershell window @@ -78,10 +80,12 @@ accelerate config If `python -m venv` shows only `python`, change `python` to `py`. -__Note:__ Now `bitsandbytes==0.43.0`, `prodigyopt==1.0` and `lion-pytorch==0.0.6` are included in the requirements.txt. If you'd like to use the another version, please install it manually. +Note: Now `bitsandbytes==0.44.0`, `prodigyopt==1.0` and `lion-pytorch==0.0.6` are included in the requirements.txt. If you'd like to use the another version, please install it manually. This installation is for CUDA 11.8. If you use a different version of CUDA, please install the appropriate version of PyTorch and xformers. For example, if you use CUDA 12, please install `pip install torch==2.1.2 torchvision==0.16.2 --index-url https://download.pytorch.org/whl/cu121` and `pip install xformers==0.0.23.post1 --index-url https://download.pytorch.org/whl/cu121`. +If you use PyTorch 2.2 or later, please change `torch==2.1.2` and `torchvision==0.16.2` and `xformers==0.0.23.post1` to the appropriate version. +
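
For reference, the debiased-estimation change from the earlier patches in this series reduces to the following standalone sketch (`snr_t` is assumed to be a tensor of per-sample SNR values, capped the same way the training code caps them):

```python
import torch


def debiased_estimation_weight(snr_t: torch.Tensor, v_prediction: bool) -> torch.Tensor:
    snr_t = torch.minimum(snr_t, torch.full_like(snr_t, 1000.0))  # cap SNR so t=0 does not yield inf
    if v_prediction:
        return 1.0 / (snr_t + 1.0)   # v-prediction weighting introduced by the fix
    return 1.0 / torch.sqrt(snr_t)   # original epsilon-prediction weighting
```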