Commit 8f3d748

Merge branch 'kohya-ss:dev' into dev

sdbds authored Jan 26, 2024
2 parents e0e9975 + cf790d8
Showing 22 changed files with 166 additions and 134 deletions.
15 changes: 14 additions & 1 deletion README.md
@@ -249,7 +249,7 @@ ControlNet-LLLite, a novel method for ControlNet with SDXL, is added. See [docum

 ## Change History

-### 作業中の内容 / Work in progress
+### Jan 23, 2024 / 2024/1/23: v0.8.2

 - [Experimental] The `--fp8_base` option is added to the training scripts for LoRA etc. The base model (U-Net, and Text Encoder when training modules for Text Encoder) can be trained with fp8. PR [#1057](https://github.com/kohya-ss/sd-scripts/pull/1057) Thanks to KohakuBlueleaf!
   - Please specify `--fp8_base` in `train_network.py` or `sdxl_train_network.py`.
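A minimal sketch of how the flag slots into a typical LoRA run (every argument except `--fp8_base` is an illustrative placeholder, not part of this commit):

    accelerate launch train_network.py \
        --pretrained_model_name_or_path /path/to/model.safetensors \
        --dataset_config /path/to/dataset.toml \
        --output_dir /path/to/output \
        --network_module networks.lora \
        --mixed_precision fp16 \
        --fp8_base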
@@ -262,6 +262,13 @@ ControlNet-LLLite, a novel method for ControlNet with SDXL, is added. See [docum
   - For example, if you train with state A as `1.0` and state B as `-1.0`, you may be able to generate by switching between state A and B depending on the LoRA application rate.
   - Also, if you prepare five states and train them as `0.2`, `0.4`, `0.6`, `0.8`, and `1.0`, you may be able to generate by switching the states smoothly depending on the application rate.
   - Please specify `network_multiplier` in `[[datasets]]` in `.toml` file.
+- Some options are added to `networks/extract_lora_from_models.py` to reduce the memory usage.
+  - `--load_precision` option can be used to specify the precision when loading the model. If the model is saved in fp16, you can reduce the memory usage by specifying `--load_precision fp16` without losing precision.
+  - `--load_original_model_to` option can be used to specify the device to load the original model. `--load_tuned_model_to` option can be used to specify the device to load the derived model. The default is `cpu` for both options, but you can specify `cuda` etc. You can reduce the memory usage by loading one of them to GPU. This option is available only for SDXL.
+
+- The gradient synchronization in LoRA training with multi-GPU is improved. PR [#1064](https://github.com/kohya-ss/sd-scripts/pull/1064) Thanks to KohakuBlueleaf!
+- The code for Intel IPEX support is improved. PR [#1060](https://github.com/kohya-ss/sd-scripts/pull/1060) Thanks to akx!
+- Fixed a bug in multi-GPU Textual Inversion training.

 - (Experimental) An option has been added to the training scripts for LoRA etc. to train with the weights of the base model (U-Net, and the Text Encoder as well when training Text Encoder modules) in fp8. PR [#1057](https://github.com/kohya-ss/sd-scripts/pull/1057) Thanks to KohakuBlueleaf!
   - Specify `--fp8_base` in `train_network.py` or `sdxl_train_network.py`.
@@ -273,6 +280,12 @@ ControlNet-LLLite, a novel method for ControlNet with SDXL, is added. See [docum
   - For example, if you train state A as `1.0` and state B as `-1.0`, you may be able to generate while switching between states A and B depending on the LoRA application rate.
   - Also, if you prepare five states and train them as `0.2`, `0.4`, `0.6`, `0.8`, and `1.0`, you may be able to generate while switching the states smoothly depending on the application rate.
   - Specify `network_multiplier` in `[[datasets]]` in the `.toml` file.
+- Several options have been added to `networks/extract_lora_from_models.py` to reduce memory usage.
+  - `--load_precision` specifies the precision at load time. If the model is saved in fp16, you can specify `--load_precision fp16` to reduce memory usage without losing precision.
+  - `--load_original_model_to` specifies the device to load the original model to, and `--load_tuned_model_to` the device to load the derived model to. Both default to `cpu`, but `cuda` etc. can be specified. Loading one of the two models to the GPU reduces memory usage. Effective only for SDXL.
+- Gradient synchronization in multi-GPU LoRA training has been improved. PR [#1064](https://github.com/kohya-ss/sd-scripts/pull/1064) Thanks to KohakuBlueleaf!
+- The code for Intel IPEX support has been improved. PR [#1060](https://github.com/kohya-ss/sd-scripts/pull/1060) Thanks to akx!
+- Fixed a bug in multi-GPU Textual Inversion training.

 - `.toml` example for network multiplier / ネットワーク適用率の `.toml` の記述例

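A minimal sketch of such a `.toml` (resolutions, batch sizes, and image paths are illustrative placeholders; the point is one `network_multiplier` per `[[datasets]]` entry):

    [[datasets]]
    resolution = 512
    batch_size = 4
    network_multiplier = 1.0

      [[datasets.subsets]]
      image_dir = "path/to/state_A_images"

    [[datasets]]
    resolution = 512
    batch_size = 4
    network_multiplier = -1.0

      [[datasets.subsets]]
      image_dir = "path/to/state_B_images"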
10 changes: 3 additions & 7 deletions XTI_hijack.py
@@ -1,11 +1,7 @@
 import torch
-try:
-    import intel_extension_for_pytorch as ipex
-    if torch.xpu.is_available():
-        from library.ipex import ipex_init
-        ipex_init()
-except Exception:
-    pass
+from library.ipex_interop import init_ipex
+
+init_ipex()
 from typing import Union, List, Optional, Dict, Any, Tuple
 from diffusers.models.unet_2d_condition import UNet2DConditionOutput

9 changes: 2 additions & 7 deletions fine_tune.py
@@ -11,15 +11,10 @@
 from tqdm import tqdm
 import torch

-try:
-    import intel_extension_for_pytorch as ipex
+from library.ipex_interop import init_ipex

-    if torch.xpu.is_available():
-        from library.ipex import ipex_init
+init_ipex()

-        ipex_init()
-except Exception:
-    pass
 from accelerate.utils import set_seed
 from diffusers import DDPMScheduler

9 changes: 2 additions & 7 deletions gen_img_diffusers.py
@@ -66,15 +66,10 @@
 import numpy as np
 import torch

-try:
-    import intel_extension_for_pytorch as ipex
+from library.ipex_interop import init_ipex

-    if torch.xpu.is_available():
-        from library.ipex import ipex_init
+init_ipex()

-        ipex_init()
-except Exception:
-    pass
 import torchvision
 from diffusers import (
     AutoencoderKL,
24 changes: 24 additions & 0 deletions library/ipex_interop.py
@@ -0,0 +1,24 @@
+import torch
+
+
+def init_ipex():
+    """
+    Try to import `intel_extension_for_pytorch`, and apply
+    the hijacks using `library.ipex.ipex_init`.
+
+    If IPEX is not installed, this function does nothing.
+    """
+    try:
+        import intel_extension_for_pytorch as ipex  # noqa
+    except ImportError:
+        return
+
+    try:
+        from library.ipex import ipex_init
+
+        if torch.xpu.is_available():
+            is_initialized, error_message = ipex_init()
+            if not is_initialized:
+                print("failed to initialize ipex:", error_message)
+    except Exception as e:
+        print("failed to initialize ipex:", e)
10 changes: 2 additions & 8 deletions library/model_util.py
@@ -5,15 +5,9 @@
 import os
 import torch

-try:
-    import intel_extension_for_pytorch as ipex
+from library.ipex_interop import init_ipex

-    if torch.xpu.is_available():
-        from library.ipex import ipex_init
-
-        ipex_init()
-except Exception:
-    pass
+init_ipex()
 import diffusers
 from transformers import CLIPTextModel, CLIPTokenizer, CLIPTextConfig, logging
 from diffusers import AutoencoderKL, DDIMScheduler, StableDiffusionPipeline  # , UNet2DConditionModel
2 changes: 1 addition & 1 deletion library/sdxl_lpw_stable_diffusion.py
@@ -925,7 +925,7 @@ def __call__(

         unet_dtype = self.unet.dtype
         dtype = unet_dtype
-        if dtype.itemsize == 1:  # fp8
+        if hasattr(dtype, "itemsize") and dtype.itemsize == 1:  # fp8
             dtype = torch.float16
         self.unet.to(dtype)
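The `hasattr` guard matters because `torch.dtype.itemsize` is only available in newer PyTorch releases; on older builds the bare attribute access would raise. A standalone sketch of the check (hypothetical helper name):

    import torch

    def is_fp8(dtype: torch.dtype) -> bool:
        # probe before comparing: older PyTorch dtypes have no `itemsize`;
        # among the float dtypes a U-Net can carry, only fp8 is 1 byte wide
        return hasattr(dtype, "itemsize") and dtype.itemsize == 1

    print(is_fp8(torch.float16))  # False: float16 is 2 bytes wide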

1 change: 1 addition & 0 deletions library/train_util.py
@@ -1774,6 +1774,7 @@ def __init__(
         tokenizer,
         max_token_length,
         resolution,
+        network_multiplier,
         enable_bucket,
         min_bucket_reso,
         max_bucket_reso,
83 changes: 72 additions & 11 deletions networks/extract_lora_from_models.py
@@ -43,6 +43,9 @@ def svd(
     clamp_quantile=0.99,
     min_diff=0.01,
     no_metadata=False,
+    load_precision=None,
+    load_original_model_to=None,
+    load_tuned_model_to=None,
 ):
     def str_to_dtype(p):
         if p == "float":
@@ -57,28 +60,51 @@ def str_to_dtype(p):
     if v_parameterization is None:
         v_parameterization = v2

+    load_dtype = str_to_dtype(load_precision) if load_precision else None
     save_dtype = str_to_dtype(save_precision)
+    work_device = "cpu"

     # load models
     if not sdxl:
         print(f"loading original SD model : {model_org}")
         text_encoder_o, _, unet_o = model_util.load_models_from_stable_diffusion_checkpoint(v2, model_org)
         text_encoders_o = [text_encoder_o]
+        if load_dtype is not None:
+            text_encoder_o = text_encoder_o.to(load_dtype)
+            unet_o = unet_o.to(load_dtype)
+
         print(f"loading tuned SD model : {model_tuned}")
         text_encoder_t, _, unet_t = model_util.load_models_from_stable_diffusion_checkpoint(v2, model_tuned)
         text_encoders_t = [text_encoder_t]
+        if load_dtype is not None:
+            text_encoder_t = text_encoder_t.to(load_dtype)
+            unet_t = unet_t.to(load_dtype)
+
         model_version = model_util.get_model_version_str_for_sd1_sd2(v2, v_parameterization)
     else:
+        device_org = load_original_model_to if load_original_model_to else "cpu"
+        device_tuned = load_tuned_model_to if load_tuned_model_to else "cpu"
+
         print(f"loading original SDXL model : {model_org}")
         text_encoder_o1, text_encoder_o2, _, unet_o, _, _ = sdxl_model_util.load_models_from_sdxl_checkpoint(
-            sdxl_model_util.MODEL_VERSION_SDXL_BASE_V1_0, model_org, "cpu"
+            sdxl_model_util.MODEL_VERSION_SDXL_BASE_V1_0, model_org, device_org
         )
         text_encoders_o = [text_encoder_o1, text_encoder_o2]
+        if load_dtype is not None:
+            text_encoder_o1 = text_encoder_o1.to(load_dtype)
+            text_encoder_o2 = text_encoder_o2.to(load_dtype)
+            unet_o = unet_o.to(load_dtype)
+
         print(f"loading original SDXL model : {model_tuned}")
         text_encoder_t1, text_encoder_t2, _, unet_t, _, _ = sdxl_model_util.load_models_from_sdxl_checkpoint(
-            sdxl_model_util.MODEL_VERSION_SDXL_BASE_V1_0, model_tuned, "cpu"
+            sdxl_model_util.MODEL_VERSION_SDXL_BASE_V1_0, model_tuned, device_tuned
         )
         text_encoders_t = [text_encoder_t1, text_encoder_t2]
+        if load_dtype is not None:
+            text_encoder_t1 = text_encoder_t1.to(load_dtype)
+            text_encoder_t2 = text_encoder_t2.to(load_dtype)
+            unet_t = unet_t.to(load_dtype)
+
         model_version = sdxl_model_util.MODEL_VERSION_SDXL_BASE_V1_0

     # create LoRA network to extract weights: Use dim (rank) as alpha
@@ -100,38 +126,54 @@ def str_to_dtype(p):
         lora_name = lora_o.lora_name
         module_o = lora_o.org_module
         module_t = lora_t.org_module
-        diff = module_t.weight - module_o.weight
+        diff = module_t.weight.to(work_device) - module_o.weight.to(work_device)
+
+        # clear weight to save memory
+        module_o.weight = None
+        module_t.weight = None

         # Text Encoder might be same
         if not text_encoder_different and torch.max(torch.abs(diff)) > min_diff:
             text_encoder_different = True
             print(f"Text encoder is different. {torch.max(torch.abs(diff))} > {min_diff}")

-        diff = diff.float()
         diffs[lora_name] = diff

+    # clear target Text Encoder to save memory
+    for text_encoder in text_encoders_t:
+        del text_encoder
+
     if not text_encoder_different:
         print("Text encoder is same. Extract U-Net only.")
         lora_network_o.text_encoder_loras = []
-        diffs = {}
+        diffs = {}  # clear diffs

     for i, (lora_o, lora_t) in enumerate(zip(lora_network_o.unet_loras, lora_network_t.unet_loras)):
         lora_name = lora_o.lora_name
         module_o = lora_o.org_module
         module_t = lora_t.org_module
-        diff = module_t.weight - module_o.weight
-        diff = diff.float()
+        diff = module_t.weight.to(work_device) - module_o.weight.to(work_device)

-        if args.device:
-            diff = diff.to(args.device)
+        # clear weight to save memory
+        module_o.weight = None
+        module_t.weight = None

         diffs[lora_name] = diff

+    # clear LoRA network, target U-Net to save memory
+    del lora_network_o
+    del lora_network_t
+    del unet_t
+
     # make LoRA with svd
     print("calculating by svd")
     lora_weights = {}
     with torch.no_grad():
         for lora_name, mat in tqdm(list(diffs.items())):
+            if args.device:
+                mat = mat.to(args.device)
+            mat = mat.to(torch.float)  # calc by float
+
             # if conv_dim is None, diffs do not include LoRAs for conv2d-3x3
             conv2d = len(mat.size()) == 4
             kernel_size = None if not conv2d else mat.size()[2:4]
@@ -171,8 +213,8 @@ def str_to_dtype(p):
                 U = U.reshape(out_dim, rank, 1, 1)
                 Vh = Vh.reshape(rank, in_dim, kernel_size[0], kernel_size[1])

-            U = U.to("cpu").contiguous()
-            Vh = Vh.to("cpu").contiguous()
+            U = U.to(work_device, dtype=save_dtype).contiguous()
+            Vh = Vh.to(work_device, dtype=save_dtype).contiguous()

             lora_weights[lora_name] = (U, Vh)
@@ -230,6 +272,13 @@ def setup_parser() -> argparse.ArgumentParser:
     parser.add_argument(
         "--sdxl", action="store_true", help="load Stable Diffusion SDXL base model / Stable Diffusion SDXL baseのモデルを読み込む"
     )
+    parser.add_argument(
+        "--load_precision",
+        type=str,
+        default=None,
+        choices=[None, "float", "fp16", "bf16"],
+        help="precision in loading, model default if omitted / 読み込み時に精度を変更して読み込む、省略時はモデルファイルによる"
+    )
     parser.add_argument(
         "--save_precision",
         type=str,
@@ -285,6 +334,18 @@ def setup_parser() -> argparse.ArgumentParser:
         help="do not save sai modelspec metadata (minimum ss_metadata for LoRA is saved) / "
         + "sai modelspecのメタデータを保存しない(LoRAの最低限のss_metadataは保存される)",
     )
+    parser.add_argument(
+        "--load_original_model_to",
+        type=str,
+        default=None,
+        help="location to load original model, cpu or cuda, cuda:0, etc, default is cpu, only for SDXL / 元モデル読み込み先、cpuまたはcuda、cuda:0など、省略時はcpu、SDXLのみ有効",
+    )
+    parser.add_argument(
+        "--load_tuned_model_to",
+        type=str,
+        default=None,
+        help="location to load tuned model, cpu or cuda, cuda:0, etc, default is cpu, only for SDXL / 派生モデル読み込み先、cpuまたはcuda、cuda:0など、省略時はcpu、SDXLのみ有効",
+    )

     return parser
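Putting the new options together, a hypothetical SDXL invocation (model paths, rank, and output name are placeholders, and `--save_to`/`--dim` are the script's usual arguments rather than part of this diff):

    python networks/extract_lora_from_models.py \
        --sdxl \
        --model_org sdxl_base.safetensors \
        --model_tuned sdxl_finetuned.safetensors \
        --save_to extracted_lora.safetensors \
        --dim 32 \
        --load_precision fp16 \
        --load_original_model_to cpu \
        --load_tuned_model_to cuda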

2 changes: 1 addition & 1 deletion requirements.txt
@@ -8,7 +8,7 @@ einops==0.6.1
 pytorch-lightning==1.9.0
 # bitsandbytes==0.39.1
 tensorboard==2.10.1
-safetensors==0.3.1
+safetensors==0.4.2
 # gradio==3.16.2
 altair==4.2.2
 easygui==0.98.3
9 changes: 2 additions & 7 deletions sdxl_gen_img.py
@@ -18,15 +18,10 @@
 import numpy as np
 import torch

-try:
-    import intel_extension_for_pytorch as ipex
+from library.ipex_interop import init_ipex

-    if torch.xpu.is_available():
-        from library.ipex import ipex_init
+init_ipex()

-        ipex_init()
-except Exception:
-    pass
 import torchvision
 from diffusers import (
     AutoencoderKL,
12 changes: 5 additions & 7 deletions sdxl_minimal_inference.py
@@ -9,13 +9,11 @@
 from einops import repeat
 import numpy as np
 import torch
-try:
-    import intel_extension_for_pytorch as ipex
-    if torch.xpu.is_available():
-        from library.ipex import ipex_init
-        ipex_init()
-except Exception:
-    pass
+
+from library.ipex_interop import init_ipex
+
+init_ipex()
+
 from tqdm import tqdm
 from transformers import CLIPTokenizer
 from diffusers import EulerDiscreteScheduler
9 changes: 2 additions & 7 deletions sdxl_train.py
@@ -11,15 +11,10 @@
 from tqdm import tqdm
 import torch

-try:
-    import intel_extension_for_pytorch as ipex
+from library.ipex_interop import init_ipex

-    if torch.xpu.is_available():
-        from library.ipex import ipex_init
+init_ipex()

-        ipex_init()
-except Exception:
-    pass
 from accelerate.utils import set_seed
 from diffusers import DDPMScheduler
 from library import sdxl_model_util
12 changes: 5 additions & 7 deletions sdxl_train_control_net_lllite.py
@@ -14,13 +14,11 @@

 from tqdm import tqdm
 import torch
-try:
-    import intel_extension_for_pytorch as ipex
-    if torch.xpu.is_available():
-        from library.ipex import ipex_init
-        ipex_init()
-except Exception:
-    pass
+
+from library.ipex_interop import init_ipex
+
+init_ipex()
+
 from torch.nn.parallel import DistributedDataParallel as DDP
 from accelerate.utils import set_seed
 import accelerate