Updated guidance bypass mechanism to use built-in Flux.params.guidance_embed bool
stepfunction83 committed Jan 23, 2025
1 parent a768d53 commit 1ade582
Showing 1 changed file with 3 additions and 23 deletions.
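Background for the change: the upstream Flux model already gates its guidance embedding on a params.guidance_embed flag, so toggling that flag is enough to skip guidance, with no need to monkey-patch the embedder's forward method. A paraphrased sketch of the relevant gate inside Flux.forward (for context only; not part of this diff, and exact variable names may differ by version):

    # Paraphrased from the upstream Flux model's forward pass (illustrative, not from this commit)
    if self.params.guidance_embed:
        if guidance is None:
            raise ValueError("Didn't get guidance strength for guidance distilled model.")
        # with params.guidance_embed = False, this branch is skipped entirely
        vec = vec + self.guidance_in(timestep_embedding(guidance, 256))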
26 changes: 3 additions & 23 deletions library/flux_utils.py
@@ -24,32 +24,13 @@
 MODEL_NAME_DEV = "dev"
 MODEL_NAME_SCHNELL = "schnell"
 
-def guidance_embed_bypass_forward(self, timestep, guidance, pooled_projection):
-    timesteps_proj = self.time_proj(timestep)
-    timesteps_emb = self.timestep_embedder(
-        timesteps_proj.to(dtype=pooled_projection.dtype))  # (N, D)
-    pooled_projections = self.text_embedder(pooled_projection)
-    conditioning = timesteps_emb + pooled_projections
-    return conditioning
-
-# bypass the forward function
+# bypass guidance
 def bypass_flux_guidance(transformer):
-    if hasattr(transformer.time_text_embed, '_bfg_orig_forward'):
-        return
-    # dont bypass if it doesnt have the guidance embedding
-    if not hasattr(transformer.time_text_embed, 'guidance_embedder'):
-        return
-    transformer.time_text_embed._bfg_orig_forward = transformer.time_text_embed.forward
-    transformer.time_text_embed.forward = partial(
-        guidance_embed_bypass_forward, transformer.time_text_embed
-    )
+    transformer.params.guidance_embed = False
 
-# restore the forward function
 def restore_flux_guidance(transformer):
-    if not hasattr(transformer.time_text_embed, '_bfg_orig_forward'):
-        return
-    transformer.time_text_embed.forward = transformer.time_text_embed._bfg_orig_forward
-    del transformer.time_text_embed._bfg_orig_forward
+    transformer.params.guidance_embed = True
 
 def analyze_checkpoint_state(ckpt_path: str) -> Tuple[bool, bool, Tuple[int, int], List[str]]:
     """
@@ -86,7 +67,6 @@ def analyze_checkpoint_state(ckpt_path: str) -> Tuple[bool, bool, Tuple[int, int], List[str]]:
 
     is_diffusers = "transformer_blocks.0.attn.add_k_proj.bias" in keys
     is_schnell = not ("guidance_in.in_layer.bias" in keys or "time_text_embed.guidance_embedder.linear_1.bias" in keys)
-    # is_schnell = True
 
     # check number of double and single blocks
     if not is_diffusers:
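
With this simplification, callers wrap a forward pass by toggling the flag rather than swapping methods. A minimal usage sketch, assuming a loaded Flux model instance named flux_model (the call and its argument names are illustrative assumptions, not part of this commit):

    # Illustrative only: flux_model and the forward arguments are assumptions
    bypass_flux_guidance(flux_model)       # sets flux_model.params.guidance_embed = False
    try:
        pred = flux_model(img, img_ids, txt, txt_ids, timesteps=timesteps, y=pooled, guidance=None)
    finally:
        restore_flux_guidance(flux_model)  # sets flux_model.params.guidance_embed = True

The try/finally ensures the flag is restored even if the forward pass raises, mirroring the bypass/restore pairing the functions provide.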
