Skip to content

Feature to improve training quality via detection of out-of-tolerance latent mean/std values #2010

New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Open
wants to merge 1 commit into
base: sd3
Choose a base branch
from
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
148 changes: 148 additions & 0 deletions library/flux_train_utils.py
Original file line number Diff line number Diff line change
Expand Up @@ -550,6 +550,137 @@ def sd_saver(ckpt_file, epoch_no, global_step):

# endregion

# region Latent analysis

def _parse_latent_thresholds(thresholds_string):
    """Parse a 'mean,std' threshold string.

    Returns a ``(mean_thresh, std_thresh_min, std_thresh_max)`` tuple, or ``None``
    (after logging an error) when the string is malformed or out of range.
    """
    try:
        # Unpacking raises ValueError on the wrong number of parts, float() on non-numeric
        # text; both are reported the same way. float() already ignores surrounding whitespace.
        mean_text, std_text = thresholds_string.split(',')
        mean_thresh = float(mean_text)  # Magnitude
        std_thresh_max = float(std_text)  # Upper bound; the lower bound is 1.0 / std_thresh_max
    except ValueError:
        logger.error(f"latent_threshold_warn_levels was set to '{thresholds_string}', "
                     "Expected latent threshold warning string to either be in 'mean,std' format, or 'disable'")
        return None

    # 'not (x >= y)' rather than 'x < y' so that NaN thresholds are rejected as well
    if not mean_thresh >= 0.0:
        logger.error("Expected mean threshold warning level to be >= 0.0. (Mean values are "
                     "checked against the range [-mean threshold..+mean threshold])")
        return None

    if not std_thresh_max >= 1.0:
        logger.error("Expected std threshold warning level to be >= 1.0. (Std values are "
                     "automatically checked against a lower bound of '1.0 / std threshold warning level')")
        return None

    return mean_thresh, 1.0 / std_thresh_max, std_thresh_max


def _latent_warn_score(mean, std, mean_thresh, std_thresh_min, std_thresh_max):
    """Return how far out of tolerance a latent is (0.0 means it passed both checks)."""
    # NaN/inf statistics compare False against every bound and would otherwise pass silently
    if not (math.isfinite(mean) and math.isfinite(std)):
        return math.inf

    # Distance of |mean| beyond the symmetric [-mean_thresh, +mean_thresh] band
    warn_mean = max(abs(mean) - mean_thresh, 0.0)

    # log base 2 is not necessarily the ideal function, but it roughly balances an
    # out-of-threshold std against an out-of-threshold mean in terms of magnitude
    if std > std_thresh_max:
        warn_std = math.log2(std / std_thresh_max)  # Too large
    elif std < std_thresh_min:
        warn_std = math.log2(std_thresh_min / std) if std > 0.0 else math.inf  # Too small
    else:
        warn_std = 0.0

    return warn_mean + warn_std


def _load_unflipped_latent(image_info):
    """Load the unflipped latent array for ``image_info`` from its cached .npz file."""
    latent_name = f'latents_{image_info.bucket_reso[1] // 8}x{image_info.bucket_reso[0] // 8}.npy'
    with np.load(image_info.latents_npz) as latents:
        return latents[latent_name]


def _show_latent_visualization(latent, title):
    """Show the channel-averaged latent in a window: blue for negative values, red for positive."""
    try:
        import cv2
    except ImportError:
        logger.warning("latent_threshold_visualizer requires OpenCV (cv2), which could not be imported")
        return

    # Average over axis 0 (the channel axis) and clip to some reasonable range
    averaged_clipped = np.clip(np.mean(latent, axis=0), -1, 1)

    rgb = np.zeros((latent.shape[1], latent.shape[2], 3), dtype=np.uint8)

    # For negative values: Blue (fade from black to full blue)
    mask_neg = averaged_clipped < 0
    rgb[mask_neg, 2] = (np.abs(averaged_clipped[mask_neg]) * 255).astype(np.uint8)

    # For positive values: Red (fade from black to full red)
    mask_pos = averaged_clipped > 0
    rgb[mask_pos, 0] = (averaged_clipped[mask_pos] * 255).astype(np.uint8)

    # Scale up 8x both for clarity and to match the original image size
    scale_factor = 8
    scaled_rgb = cv2.resize(
        rgb,
        (rgb.shape[1] * scale_factor, rgb.shape[0] * scale_factor),  # (width, height)
        interpolation=cv2.INTER_NEAREST,
    )

    window_name = f"{title}: blue -'ve, red +'ve."
    try:
        cv2.imshow(window_name, cv2.cvtColor(scaled_rgb, cv2.COLOR_RGB2BGR))
        while True:  # Wait until window is closed or escape key is pressed
            key = cv2.waitKey(1) & 0xFF
            if key == 27 or cv2.getWindowProperty(window_name, cv2.WND_PROP_VISIBLE) < 1:
                break
        cv2.destroyAllWindows()
    except cv2.error as e:
        # e.g. an OpenCV build without GUI support; this is only a diagnostic aid
        logger.warning(f"Could not display the latent visualization window: {e}")


def check_latent_means_and_stds_against_thresholds(thresholds_string, latent_threshold_visualizer, image_data):
    """Warn about cached latents whose mean or std falls outside the given tolerances.

    Args:
        thresholds_string: Either "disable"/"disabled" (case-insensitive) to skip the check,
            or "mean_thresh,std_thresh". Means are tested against
            [-mean_thresh, +mean_thresh] and stds against [1.0 / std_thresh, std_thresh].
        latent_threshold_visualizer: If truthy and at least one latent is out of tolerance,
            the worst offender is shown on screen in an OpenCV window.
        image_data: Mapping of image filename to an image-info object providing
            ``latents_npz`` (path of the cached latent file) and ``bucket_reso``.

    Returns:
        None. Up to three of the most out-of-tolerance images are reported through the logger.
        Malformed threshold strings are logged as errors and the check is skipped.
    """
    import os  # Local import: only needed here to shorten filenames for the report

    # Skip mean/std check? (the command-line help advertises "disabled", so accept both spellings)
    if thresholds_string.strip().lower() in ("disable", "disabled"):
        return

    thresholds = _parse_latent_thresholds(thresholds_string)
    if thresholds is None:
        return
    mean_thresh, std_thresh_min, std_thresh_max = thresholds

    # One entry per latent: (score, mean, std, short filename, image_data key)
    results = []
    skipped = 0

    logger.info('Checking latent means/stds:')
    for image_filename, image_info in tqdm(image_data.items()):
        # NOTE(review): assumes latents_npz is None/empty when no latent cache file exists
        # for this image - confirm against the image-info class.
        if not image_info.latents_npz:
            skipped += 1
            continue

        latent = _load_unflipped_latent(image_info)  # Only checking the unflipped latent

        # Accumulate in float64 so low-precision latents cannot overflow or lose accuracy
        mean = float(np.mean(latent, dtype=np.float64))
        std = float(np.std(latent, dtype=np.float64))

        score = _latent_warn_score(mean, std, mean_thresh, std_thresh_min, std_thresh_max)
        results.append((score, mean, std, os.path.basename(image_filename), image_filename))

    if skipped:
        logger.warning(f"Latent mean/std check skipped {skipped} image(s) that have no cached latent file")

    # Keep only failures, most notably out of threshold first
    failures = sorted((r for r in results if r[0] > 0.0), key=lambda r: r[0], reverse=True)
    if not failures:
        return

    logger.warning("Images are being trained on that have out-of-tolerance latent mean or std values. "
                   "Training may improve if these images are changed/deleted. Remember to delete the images' _flux.npz "
                   "files by hand if you modify the image, as they will not necessarily be automatically regenerated. "
                   "Here is a list (of up to three) out-of-tolerance images: (Consider using --latent_threshold_visualizer "
                   "to diagnose)")

    for _, mean, std, short_name, _ in failures[:3]:  # Three results maximum
        logger.warning(f'Mean,std = [{mean:.3f}, {std:.3f}]: {short_name}')

    # Show the 'worst' failing latent visually in a window?
    if latent_threshold_visualizer:
        _, _, _, worst_short_name, worst_key = failures[0]
        _show_latent_visualization(_load_unflipped_latent(image_data[worst_key]), worst_short_name)

# endregion

def add_flux_train_arguments(parser: argparse.ArgumentParser):
parser.add_argument(
Expand Down Expand Up @@ -617,3 +748,20 @@ def add_flux_train_arguments(parser: argparse.ArgumentParser):
default=3.0,
help="Discrete flow shift for the Euler Discrete Scheduler, default is 3.0. / Euler Discrete Schedulerの離散フローシフト、デフォルトは3.0。",
)

# Latent mean/std analysis tools
parser.add_argument(
"--latent_threshold_warn_levels",
type=str,
default="0.16,1.35",
help='Flux may train better if the training latents have a mean of 0.0 and an std of 1.0. '
'Set this parameter to "mean_thresh,std_thresh" to warn if tolerances are exceeded, or "disabled" to skip checks. '
'Mean is tested to be in [-mean_thresh..+mean_thresh] range, std in [1.0/std_thresh..std_thresh] range'
)
parser.add_argument(
"--latent_threshold_visualizer",
action="store_true",
help="If --latent_threshold_warn_levels detects at least one out-of-threshold latent, one of them is "
"shown on screen with red/blue blocks to show +'ve / -'ve latent values. This can help to identify "
"why this image has mean and std values that differ significantly from 0.0 and 1.0"
)
10 changes: 10 additions & 0 deletions train_network.py
Original file line number Diff line number Diff line change
Expand Up @@ -753,6 +753,16 @@ def train(self, args):
persistent_workers=args.persistent_data_loader_workers,
)

# Warn user if any latents have mean values that are further than a threshold level away
# from 0.0, or that have standard deviations outside a threshold scale from 1.0.
if args.latent_threshold_warn_levels is not None:
# (Flux only for now, but this could be updated to support e.g. SDXL or SD3)
from library.flux_train_utils import check_latent_means_and_stds_against_thresholds
check_latent_means_and_stds_against_thresholds(
args.latent_threshold_warn_levels,
args.latent_threshold_visualizer,
train_dataset_group.image_data)

# 学習ステップ数を計算する
if args.max_train_epochs is not None:
args.max_train_steps = args.max_train_epochs * math.ceil(
Expand Down