add gradual latent
kohya-ss committed Nov 23, 2023
1 parent 6d6d862 commit 6849546
Showing 3 changed files with 595 additions and 2 deletions.
32 changes: 32 additions & 0 deletions README.md
@@ -1,3 +1,35 @@
# Gradual Latent について

latentのサイズを徐々に大きくしていくHires fixです。`sdxl_gen_img.py` に以下のオプションが追加されています。

- `--gradual_latent_timesteps` : latentのサイズを大きくし始めるタイムステップを指定します。デフォルトは None で、Gradual Latentを使用しません。
- `--gradual_latent_ratio` : latentの初期サイズを指定します。デフォルトは 0.5 で、デフォルトの latent サイズの半分のサイズから始めます。
- `--gradual_latent_ratio_step`: latentのサイズを大きくする割合を指定します。デフォルトは 0.125 で、latentのサイズを 0.625, 0.75, 0.875, 1.0 と徐々に大きくします。
- `--gradual_latent_ratio_every_n_steps`: latentのサイズを大きくする間隔を指定します。デフォルトは 3 で、3ステップごとに latent のサイズを大きくします。

それぞれのオプションは、プロンプトオプション `--glt`、`--glr`、`--gls`、`--gle` でも指定できます。

__サンプラーに `euler_a` を指定してください。__ 他のサンプラーでは動作しません。

`gen_img_diffusers.py` にも同様のオプションが追加されていますが、試した範囲ではどうやっても乱れた画像しか生成できませんでした。

# About Gradual Latent

Gradual Latent is a Hires fix that gradually increases the size of the latents. The following options have been added to `sdxl_gen_img.py`.

- `--gradual_latent_timesteps`: Specifies the timestep at which the latent starts being upscaled. The default is None, which disables Gradual Latent.
- `--gradual_latent_ratio`: Specifies the initial latent size ratio. The default is 0.5, i.e. generation starts at half the normal latent size.
- `--gradual_latent_ratio_step`: Specifies how much the ratio is increased each time. The default is 0.125, so the latent size grows through 0.625, 0.75, 0.875, and 1.0.
- `--gradual_latent_ratio_every_n_steps`: Specifies how often the latent size is increased. The default is 3, i.e. the latent is upscaled every 3 steps.

Each option can also be specified per prompt with the prompt options `--glt`, `--glr`, `--gls`, and `--gle`, as in the example below.
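For example, a prompt line might look like this (the values are illustrative; each prompt option maps to the corresponding command-line option above):

```
a photo of a cat --glt 650 --glr 0.5 --gls 0.125 --gle 3
```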

__Please specify `euler_a` for the sampler.__ It will not work with other samplers.

The same options have also been added to `gen_img_diffusers.py`, but in my testing it produced only distorted images no matter what I tried.

---

__SDXL is now supported. The sdxl branch has been merged into the main branch. If you update the repository, please follow the upgrade instructions. Also, the version of accelerate has been updated, so please run accelerate config again.__ The documentation for SDXL training is [here](./README.md#sdxl-training).

This repository contains training, generation and utility scripts for Stable Diffusion.
247 changes: 246 additions & 1 deletion gen_img_diffusers.py
@@ -484,6 +484,14 @@ def add_token_replacement_XTI(self, target_token_id, rep_token_ids):
    def set_control_nets(self, ctrl_nets):
        self.control_nets = ctrl_nets

    def set_gradual_latent(self, gradual_latent):
        if gradual_latent is None:
            print("gradual_latent is disabled")
            self.gradual_latent = None
        else:
            print(f"gradual_latent is enabled: {gradual_latent}")
            self.gradual_latent = gradual_latent  # (ratio, start_timesteps, every_n_steps, ratio_step)
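    # illustrative usage (added note, not part of this commit): the pipeline unpacks the tuple
    # in the order (ratio, start_timesteps, every_n_steps, ratio_step), e.g.
    #   pipe.set_gradual_latent((0.5, 650, 3, 0.125))
    # where 650 is the fallback start timestep used later in process_batch when none is given.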

    # region xformersとか使う部分:独自に書き換えるので関係なし

    def enable_xformers_memory_efficient_attention(self):
@@ -958,7 +966,41 @@ def __call__(
        else:
            text_emb_last = text_embeddings

        enable_gradual_latent = False
        if self.gradual_latent:
            if not hasattr(self.scheduler, "set_resized_size"):
                print("gradual_latent is not supported for this scheduler. Ignoring.")
                print(self.scheduler.__class__.__name__)
            else:
                enable_gradual_latent = True
                current_ratio, start_timesteps, every_n_steps, ratio_step = self.gradual_latent
                step_elapsed = 1000

                # first, we downscale the latents to the specified ratio / 最初に指定された比率にlatentsをダウンスケールする
                height, width = latents.shape[-2:]
                org_dtype = latents.dtype
                if org_dtype == torch.bfloat16:
                    latents = latents.float()
                latents = torch.nn.functional.interpolate(
                    latents, scale_factor=current_ratio, mode="bicubic", align_corners=False
                ).to(org_dtype)

        for i, t in enumerate(tqdm(timesteps)):
            resized_size = None
            if enable_gradual_latent:
                # gradually upscale the latents / latentsを徐々にアップスケールする
                if t < start_timesteps and current_ratio < 1.0 and step_elapsed >= every_n_steps:
                    print("upscale")
                    current_ratio = min(current_ratio + ratio_step, 1.0)
                    h = int(height * current_ratio)  # // 8 * 8
                    w = int(width * current_ratio)  # // 8 * 8
                    resized_size = (h, w)
                    self.scheduler.set_resized_size(resized_size)
                    step_elapsed = 0
                else:
                    self.scheduler.set_resized_size(None)
                step_elapsed += 1
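                # worked example (added note, not part of this commit): with the defaults from the README
                # (ratio 0.5, ratio_step 0.125, every_n_steps 3), denoising starts at half the latent size;
                # at the first step where t < start_timesteps the latent is upscaled to 0.625 of full size,
                # then to 0.75, 0.875 and 1.0 every 3 steps, after which set_resized_size(None) is passed
                # and the latent stays at full resolution.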

            # expand the latents if we are doing classifier free guidance
            latent_model_input = latents.repeat((num_latent_input, 1, 1, 1))
            latent_model_input = self.scheduler.scale_model_input(latent_model_input, t)
@@ -2112,6 +2154,133 @@ def replacer():
    return prompts


# endregion

# region Gradual Latent hires fix

import diffusers.schedulers.scheduling_euler_ancestral_discrete
from diffusers.schedulers.scheduling_euler_ancestral_discrete import EulerAncestralDiscreteSchedulerOutput


class EulerAncestralDiscreteSchedulerGL(EulerAncestralDiscreteScheduler):
    def __init__(self, *args, **kwargs):
        super().__init__(*args, **kwargs)
        self.resized_size = None

    def set_resized_size(self, size):
        self.resized_size = size

    def step(
        self,
        model_output: torch.FloatTensor,
        timestep: Union[float, torch.FloatTensor],
        sample: torch.FloatTensor,
        generator: Optional[torch.Generator] = None,
        return_dict: bool = True,
    ) -> Union[EulerAncestralDiscreteSchedulerOutput, Tuple]:
        """
        Predict the sample from the previous timestep by reversing the SDE. This function propagates the diffusion
        process from the learned model outputs (most often the predicted noise).
        Args:
            model_output (`torch.FloatTensor`):
                The direct output from learned diffusion model.
            timestep (`float`):
                The current discrete timestep in the diffusion chain.
            sample (`torch.FloatTensor`):
                A current instance of a sample created by the diffusion process.
            generator (`torch.Generator`, *optional*):
                A random number generator.
            return_dict (`bool`):
                Whether or not to return a
                [`~schedulers.scheduling_euler_ancestral_discrete.EulerAncestralDiscreteSchedulerOutput`] or tuple.
        Returns:
            [`~schedulers.scheduling_euler_ancestral_discrete.EulerAncestralDiscreteSchedulerOutput`] or `tuple`:
                If return_dict is `True`,
                [`~schedulers.scheduling_euler_ancestral_discrete.EulerAncestralDiscreteSchedulerOutput`] is returned,
                otherwise a tuple is returned where the first element is the sample tensor.
        """

        if isinstance(timestep, int) or isinstance(timestep, torch.IntTensor) or isinstance(timestep, torch.LongTensor):
            raise ValueError(
                (
                    "Passing integer indices (e.g. from `enumerate(timesteps)`) as timesteps to"
                    " `EulerDiscreteScheduler.step()` is not supported. Make sure to pass"
                    " one of the `scheduler.timesteps` as a timestep."
                ),
            )

        if not self.is_scale_input_called:
            logger.warning(
                "The `scale_model_input` function should be called before `step` to ensure correct denoising. "
                "See `StableDiffusionPipeline` for a usage example."
            )

        if self.step_index is None:
            self._init_step_index(timestep)

        sigma = self.sigmas[self.step_index]

        # 1. compute predicted original sample (x_0) from sigma-scaled predicted noise
        if self.config.prediction_type == "epsilon":
            pred_original_sample = sample - sigma * model_output
        elif self.config.prediction_type == "v_prediction":
            # * c_out + input * c_skip
            pred_original_sample = model_output * (-sigma / (sigma**2 + 1) ** 0.5) + (sample / (sigma**2 + 1))
        elif self.config.prediction_type == "sample":
            raise NotImplementedError("prediction_type not implemented yet: sample")
        else:
            raise ValueError(f"prediction_type given as {self.config.prediction_type} must be one of `epsilon`, or `v_prediction`")

        sigma_from = self.sigmas[self.step_index]
        sigma_to = self.sigmas[self.step_index + 1]
        sigma_up = (sigma_to**2 * (sigma_from**2 - sigma_to**2) / sigma_from**2) ** 0.5
        sigma_down = (sigma_to**2 - sigma_up**2) ** 0.5
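        # added note (not part of the original commit): by construction sigma_down**2 + sigma_up**2 == sigma_to**2,
        # the standard ancestral-sampling split -- the deterministic Euler step below moves the sample to noise
        # level sigma_down, and fresh noise scaled by sigma_up restores the total noise expected at sigma_to.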

        # 2. Convert to an ODE derivative
        derivative = (sample - pred_original_sample) / sigma

        dt = sigma_down - sigma

        prev_sample = sample + derivative * dt

        device = model_output.device
        if self.resized_size is None:
            noise = diffusers.schedulers.scheduling_euler_ancestral_discrete.randn_tensor(
                model_output.shape, dtype=model_output.dtype, device=device, generator=generator
            )
        else:
            print(
                "resized_size", self.resized_size, "model_output.shape", model_output.shape, "prev_sample.shape", prev_sample.shape
            )
            org_dtype = prev_sample.dtype
            if org_dtype == torch.bfloat16:
                prev_sample = prev_sample.float()

            prev_sample = torch.nn.functional.interpolate(
                prev_sample.float(), size=self.resized_size, mode="bicubic", align_corners=False
            ).to(dtype=org_dtype)

            noise = diffusers.schedulers.scheduling_euler_ancestral_discrete.randn_tensor(
                (model_output.shape[0], model_output.shape[1], self.resized_size[0], self.resized_size[1]),
                dtype=model_output.dtype,
                device=device,
                generator=generator,
            )

        prev_sample = prev_sample + noise * sigma_up
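        # added note (not part of the original commit): when resized_size is set, prev_sample has just been
        # resized with bicubic interpolation and the fresh noise is drawn directly at the new (h, w), so the
        # two tensors match and denoising continues at the larger latent resolution from this step on.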

        # upon completion increase step index by one
        self._step_index += 1

        if not return_dict:
            return (prev_sample,)

        return EulerAncestralDiscreteSchedulerOutput(prev_sample=prev_sample, pred_original_sample=pred_original_sample)


# endregion


@@ -2249,7 +2418,7 @@ def main(args):
        scheduler_cls = EulerDiscreteScheduler
        scheduler_module = diffusers.schedulers.scheduling_euler_discrete
    elif args.sampler == "euler_a" or args.sampler == "k_euler_a":
        scheduler_cls = EulerAncestralDiscreteSchedulerGL
        scheduler_module = diffusers.schedulers.scheduling_euler_ancestral_discrete
    elif args.sampler == "dpmsolver" or args.sampler == "dpmsolver++":
        scheduler_cls = DPMSolverMultistepScheduler
@@ -2505,6 +2674,16 @@ def __getattr__(self, item):
    if args.ds_depth_1 is not None:
        unet.set_deep_shrink(args.ds_depth_1, args.ds_timesteps_1, args.ds_depth_2, args.ds_timesteps_2, args.ds_ratio)

    # Gradual Latent
    if args.gradual_latent_ratio is not None:
        gradual_latent = (
            args.gradual_latent_ratio,
            args.gradual_latent_timesteps,
            args.gradual_latent_every_n_steps,
            args.gradual_latent_ratio_step,
        )
        pipe.set_gradual_latent(gradual_latent)

    # Extended Textual Inversion および Textual Inversionを処理する
    if args.XTI_embeddings:
        diffusers.models.UNet2DConditionModel.forward = unet_forward_XTI
@@ -3096,6 +3275,12 @@ def process_batch(batch: List[BatchData], highres_fix, highres_1st=False):
            ds_timesteps_2 = args.ds_timesteps_2
            ds_ratio = args.ds_ratio

            # Gradual Latent
            gl_timesteps = None  # means no override
            gl_ratio = args.gradual_latent_ratio
            gl_every_n_steps = args.gradual_latent_every_n_steps
            gl_ratio_step = args.gradual_latent_ratio_step

            prompt_args = raw_prompt.strip().split(" --")
            prompt = prompt_args[0]
            print(f"prompt {prompt_index+1}/{len(prompt_list)}: {prompt}")
@@ -3202,6 +3387,34 @@ def process_batch(batch: List[BatchData], highres_fix, highres_1st=False):
                        print(f"deep shrink ratio: {ds_ratio}")
                        continue

                    # Gradual Latent
                    m = re.match(r"glt ([\d\.]+)", parg, re.IGNORECASE)
                    if m:  # gradual latent timesteps
                        gl_timesteps = int(m.group(1))
                        print(f"gradual latent timesteps: {gl_timesteps}")
                        continue

                    m = re.match(r"glr ([\d\.]+)", parg, re.IGNORECASE)
                    if m:  # gradual latent ratio
                        gl_ratio = float(m.group(1))
                        gl_timesteps = gl_timesteps if gl_timesteps is not None else -1  # -1 means override
                        print(f"gradual latent ratio: {gl_ratio}")
                        continue

                    m = re.match(r"gle ([\d\.]+)", parg, re.IGNORECASE)
                    if m:  # gradual latent every n steps
                        gl_every_n_steps = int(m.group(1))
                        gl_timesteps = gl_timesteps if gl_timesteps is not None else -1  # -1 means override
                        print(f"gradual latent every n steps: {gl_every_n_steps}")
                        continue

                    m = re.match(r"gls ([\d\.]+)", parg, re.IGNORECASE)
                    if m:  # gradual latent ratio step
                        gl_ratio_step = float(m.group(1))
                        gl_timesteps = gl_timesteps if gl_timesteps is not None else -1  # -1 means override
                        print(f"gradual latent ratio step: {gl_ratio_step}")
                        continue

                except ValueError as ex:
                    print(f"Exception in parsing / 解析エラー: {parg}")
                    print(ex)
@@ -3212,6 +3425,12 @@ def process_batch(batch: List[BatchData], highres_fix, highres_1st=False):
ds_depth_1 = args.ds_depth_1 or 3
unet.set_deep_shrink(ds_depth_1, ds_timesteps_1, ds_depth_2, ds_timesteps_2, ds_ratio)

            # override Gradual Latent
            if gl_ratio is not None:
                if gl_timesteps is None:
                    gl_timesteps = args.gradual_latent_timesteps or 650
                pipe.set_gradual_latent((gl_ratio, gl_timesteps, gl_every_n_steps, gl_ratio_step))

            # prepare seed
            if seeds is not None:  # given in prompt
                # 数が足りないなら前のをそのまま使う
@@ -3585,6 +3804,32 @@ def setup_parser() -> argparse.ArgumentParser:
"--ds_ratio", type=float, default=0.5, help="Deep Shrink ratio for downsampling / Deep Shrinkのdownsampling比率"
)

    # gradual latent
    parser.add_argument(
        "--gradual_latent_timesteps",
        type=int,
        default=None,
        help="enable Gradual Latent hires fix and start upscaling at this timestep / Gradual Latent hires fixをこのtimestepsで有効にし、このtimestepsからアップスケーリングを適用する",
    )
    parser.add_argument(
        "--gradual_latent_ratio",
        type=float,
        default=0.5,
        help="initial latent size ratio for Gradual Latent, 0.5 means 1/2 / Gradual Latent hires fixをこのサイズ比率で有効にする、0.5は1/2を意味する",
    )
    parser.add_argument(
        "--gradual_latent_ratio_step",
        type=float,
        default=0.125,
        help="amount to increase the latent size ratio by for Gradual Latent / Gradual Latentのratioをどのくらいずつ上げるか",
    )
    parser.add_argument(
        "--gradual_latent_every_n_steps",
        type=int,
        default=3,
        help="interval in steps at which the latent size is increased for Gradual Latent / Gradual Latentでlatentsのサイズをこのステップごとに上げる",
    )
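    # illustrative invocation (added note, not part of this commit; flags other than the gradual latent ones
    # are assumed from this script's existing options -- the README recommends sdxl_gen_img.py, which takes
    # the same gradual latent options):
    #   python gen_img_diffusers.py --ckpt model.safetensors --sampler euler_a --steps 30 \
    #     --gradual_latent_timesteps 650 --gradual_latent_ratio 0.5 \
    #     --gradual_latent_ratio_step 0.125 --gradual_latent_every_n_steps 3 \
    #     --prompt "a photo of a cat"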

    return parser

