diff --git a/library/flux_train_utils.py b/library/flux_train_utils.py index f7f06c5cf..10f3f5651 100644 --- a/library/flux_train_utils.py +++ b/library/flux_train_utils.py @@ -550,6 +550,137 @@ def sd_saver(ckpt_file, epoch_no, global_step): # endregion +# region Latent analysis + +def check_latent_means_and_stds_against_thresholds(thresholds_string, latent_threshold_visualizer, image_data): + + # Skip mean/std check? + if thresholds_string.lower() == "disable": + return + + # Thresholds should be in 'mean,std' format; split on comma + parts = thresholds_string.split(',') + if len(parts) != 2: + logger.error(f"latent_threshold_warn_levels was set to '{thresholds_string}', " + "Expected latent threshold warning string to either be in 'mean,std' format, or 'disable'") + return + + mean_thresh = float(parts[0].strip()) # Magnitude + std_thresh_max = float(parts[1].strip()) # This threshold is the max value. The min value of 1.0 / std_thresh is also tested against + + if std_thresh_max < 1.0: + logger.error("Expected std threshold warning level to be >= 1.0. (Std values are " + "automatically checked against a lower bound of '1.0 / std threshold warning level')") + return + + std_thresh_min = 1.0 / std_thresh_max + + # Start forming a list of results, one for each latent + results = [] + + # Load and check each latent in turn + logger.info('Checking latent means/stds:') + for image_filename in tqdm(image_data): + image_info = image_data[image_filename] + + # Load the latent + with np.load(image_info.latents_npz) as latents: + + latent_name = f'latents_{image_info.bucket_reso[1] // 8}x{image_info.bucket_reso[0] // 8}.npy' + latent = latents[latent_name] # Only checking the unflipped latent + + image_filename_no_path = image_filename.rsplit('/', 1)[-1] + + # Check mean + mean = np.average(latent) + if mean < -mean_thresh or mean > mean_thresh: + warn_mean = abs(mean - mean_thresh) # Out of tolerance + else: + warn_mean = 0 # Passed mean check + + # Check std + std = np.std(latent) + if std < std_thresh_min or std > std_thresh_max: + if std > std_thresh_max: + # log base 2 is not necessarily the ideal function, but hopefully it'll roughly + # balance an out-of-threshold std against an out-of-threshold mean in terms + # of magnitude + warn_std = math.log(std / std_thresh_max, 2) # Out of tolerance (too large) + else: + warn_std = math.log(std_thresh_min / std, 2) # Out of tolerance (too small) + else: + warn_std = 0 # Passed std check + + # The first element is how notable this latent's mean and std is considered to be + # for the list of 'most out-of-threshold results' to warn about + results += [[warn_mean + warn_std, mean, std, image_filename_no_path, image_filename]] + + # Sort the results into order of most notably out of threshold first + results.sort(key=lambda x: -x[0]) + + # List a few test failure image results + + for i, result in enumerate(results): + + if i >= 3: # Three results maximum + break + if result[0] == 0.0: # Fewer than 3 images that did not pass? + break + + if i == 0: + logger.warning("Images are being trained on that have out-of-tolerance latent mean or std values. " + "Training may improve if these images are changed/deleted. Remember to delete the images' _flux.npz " + "files by hand if you modify the image, as they will not necessarily be automatically regenerated. " + "Here is a list (of up to three) out-of-tolerance images: (Consider using --latent_threshold_visualizer " + "to diagnose)") + + print(f'Mean,std = [{result[1]:.3f}, {result[2]:.3f}]: {result[3]}') + + # Show one latent test failure result visually in a window? + if results[0][0] > 0.0 and latent_threshold_visualizer: + # Re-fetch the latent for the 'worst' test fail latent + image_info = image_data[results[0][4]] # Get image_info by image filename + with np.load(image_info.latents_npz) as latents: + latent_name = f'latents_{image_info.bucket_reso[1] // 8}x{image_info.bucket_reso[0] // 8}.npy' + latent = latents[latent_name] # Only show the unflipped latent + + # Average the latent's 16 channels together and clip to some reasonable range for the Flux AE. + averaged = np.mean(latent, axis=0) + averaged_clipped = np.clip(averaged, -1, 1) + + rgb = np.zeros((latent.shape[1], latent.shape[2], 3), dtype=np.uint8) + + # For negative values: Blue (fade from black to full blue) + mask_neg = averaged_clipped < 0 + blue_intensity = (np.abs(averaged_clipped[mask_neg]) * 255).astype(np.uint8) + rgb[mask_neg, 2] = blue_intensity + + # For positive values: Red (fade from black to full red) + mask_pos = averaged_clipped > 0 + red_intensity = (averaged_clipped[mask_pos] * 255).astype(np.uint8) + rgb[mask_pos, 0] = red_intensity + + # Scale up 8x both for clarity and to match the original image size + import cv2 + scale_factor = 8 + scaled_rgb = cv2.resize( + rgb, + (rgb.shape[1] * scale_factor, rgb.shape[0] * scale_factor), # (width, height) + interpolation=cv2.INTER_NEAREST + ) + + # Show the latent average image + window_name = f"{results[0][3]}: blue -'ve, red +'ve." + cv2.imshow(window_name, cv2.cvtColor(scaled_rgb, cv2.COLOR_RGB2BGR)) + while True: # Wait until window is closed or escape key is pressed + key = cv2.waitKey(1) & 0xFF + if key == 27 or cv2.getWindowProperty(window_name, cv2.WND_PROP_VISIBLE) < 1: + break + cv2.destroyAllWindows() + + pass + +# endregion def add_flux_train_arguments(parser: argparse.ArgumentParser): parser.add_argument( @@ -617,3 +748,20 @@ def add_flux_train_arguments(parser: argparse.ArgumentParser): default=3.0, help="Discrete flow shift for the Euler Discrete Scheduler, default is 3.0. / Euler Discrete Schedulerの離散フローシフト、デフォルトは3.0。", ) + + # Latent mean/std analysis tools + parser.add_argument( + "--latent_threshold_warn_levels", + type=str, + default="0.16,1.35", + help='Flux may train better if the training latents have a mean of 0.0 and an std of 1.0. ' + 'Set this parameter to "mean_thresh,std_thresh" to warn if tolerances are exceeded, or "disabled" to skip checks. ' + 'Mean is tested to be in [-mean_thresh..+mean_thresh] range, std in [1.0/std_thresh..std_thresh] range' + ) + parser.add_argument( + "--latent_threshold_visualizer", + action="store_true", + help="If --latent_threshold_warn_levels detects at least one out-of-threshold latent, one of them is " + "shown on screen with red/blue blocks to show +'ve / -'ve latent values. This can help to identify " + "why this image has mean and std values that differ significantly from 0.0 and 1.0" + ) diff --git a/train_network.py b/train_network.py index 2d279b3bf..0a3d0c7f1 100644 --- a/train_network.py +++ b/train_network.py @@ -753,6 +753,16 @@ def train(self, args): persistent_workers=args.persistent_data_loader_workers, ) + # Warn user if any latents have mean values that are further than a theshold level away + # from 0.0, or that have standard deviations outside a threshold scale from 1.0. + if args.latent_threshold_warn_levels is not None: + # (Flux only for now, but this could be updated to support e.g. SDXL or SD3) + from library.flux_train_utils import check_latent_means_and_stds_against_thresholds + check_latent_means_and_stds_against_thresholds( + args.latent_threshold_warn_levels, + args.latent_threshold_visualizer, + train_dataset_group.image_data) + # 学習ステップ数を計算する if args.max_train_epochs is not None: args.max_train_steps = args.max_train_epochs * math.ceil(