Feature to improve training quality via detection of out-of-tolerance latent mean/std values #2010
This is a new feature, currently only for Flux LoRA training (though it could later be extended to full fine-tuning as well). It analyses the latents of training images and checks that their mean (average) and standard deviation are close to 0.0 and 1.0 respectively.
It was requested as a feature here: std/mean detection code.
Flux is a diffusion model that starts from Gaussian noise with a mean of 0.0 and a standard deviation of 1.0. Training images whose latents share those statistics may therefore train better, since the network does not also have to learn to adjust the mean and std — something current diffusion models don't seem to be good at.
The default tolerances for detection can be set via `--latent_threshold_warn_levels=mean,std_max`, e.g. `--latent_threshold_warn_levels=0.15,1.40`. The test can be disabled using `--latent_threshold_warn_levels=disable`.
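A minimal sketch of how a `mean,std_max`-style value for this flag could be parsed — the flag name comes from this PR, but the function name and validation rules here are assumptions, not the actual implementation:

```python
def parse_latent_threshold_warn_levels(value: str):
    """Return (mean_tol, std_max), or None when checking is disabled.

    Hypothetical helper: accepts "disable" or a "mean,std_max" pair,
    mirroring the --latent_threshold_warn_levels examples above.
    """
    if value == "disable":
        return None
    mean_tol_str, std_max_str = value.split(",")
    mean_tol, std_max = float(mean_tol_str), float(std_max_str)
    # Sanity checks (assumed): tolerance must be positive, and std_max
    # must exceed 1.0 so that the derived lower bound 1/std_max is below 1.0.
    if mean_tol <= 0 or std_max <= 1.0:
        raise ValueError("expected mean tolerance > 0 and std_max > 1.0")
    return mean_tol, std_max

print(parse_latent_threshold_warn_levels("0.15,1.40"))  # (0.15, 1.4)
print(parse_latent_threshold_warn_levels("disable"))    # None
```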
If images do not pass the threshold test, then a warning message appears like this:
The std_max value sets the upper limit for the standard deviation. A lower limit is also set to 1.0 / std_max. For example, a std_max value of 1.40 also creates a lower threshold of around 0.714.
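The symmetric upper/lower bound described above can be sketched as follows (illustrative only — function and parameter names are assumptions, not the PR's code):

```python
def latent_out_of_tolerance(mean: float, std: float,
                            mean_tol: float = 0.15, std_max: float = 1.40) -> bool:
    """Return True when a latent's statistics fall outside the tolerances.

    The lower std bound is derived from std_max, as described above:
    std_min = 1.0 / std_max, e.g. 1 / 1.40 ≈ 0.714.
    """
    std_min = 1.0 / std_max
    return abs(mean) > mean_tol or not (std_min <= std <= std_max)

assert not latent_out_of_tolerance(0.05, 1.10)  # within tolerance
assert latent_out_of_tolerance(0.30, 1.00)      # mean too far from 0
assert latent_out_of_tolerance(0.00, 0.60)      # std below 1/1.40 ≈ 0.714
```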
Sometimes it's not obvious why the mean and std values are not near 0 and 1. In that case, a parameter `--latent_threshold_visualizer` can be passed, which shows the latent average values in a window. (This has been tested on Ubuntu Linux; could someone please try it on Windows? It should probably work.)

Various changes that can move latent means/stds towards 0,1 are possible, e.g.:
Or even:
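The specific examples above aren't reproduced here, but as a generic illustration of the idea (an assumption on my part, not one of those examples): values can be standardized to mean 0 / std 1 directly. In practice one would more likely adjust the source image (e.g. brightness/contrast) and re-encode, since rescaling the latents themselves changes the decoded image.

```python
from statistics import fmean, pstdev

def standardize(values: list[float]) -> list[float]:
    """Shift and scale values so they have mean 0 and (population) std 1."""
    m = fmean(values)
    s = pstdev(values)
    return [(v - m) / s for v in values]

latent = [0.9, 1.3, 2.1, 0.5, 1.7]  # toy "latent": mean 1.3, std ≈ 0.566
fixed = standardize(latent)
print(round(fmean(fixed), 6), round(pstdev(fixed), 6))  # 0.0 1.0
```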
One thing I've found in the few days I've been training with images closer to mean/std 0,1 is that I've had to raise the alpha value of my training and also reduce the LR. I think that might be because the 'gravity well' keeping the model close to base-model quality is stronger when it isn't disrupted by training images outside the 0,1 distribution.
Edit: I need to check that the test for `args.latent_threshold_warn_levels` in `train_network.py` doesn't break e.g. SDXL training, which won't have that option.