Prior model preservation #505

Draft · wants to merge 2 commits into master

Conversation

dxqbYD (Contributor) commented Oct 11, 2024

This code can be used to preserve the prior model on prompts other than the trained captions. After several more tests, I think this is worth implementing as a fairly generic feature:

  • It does not require any regularization image data. It works even when using the same training data for the reg steps as for the regular training steps.
  • It does not require a regularization caption. An empty caption for the reg steps works, indicating that this can preserve all kinds of concepts, whatever you train on.
  • Additionally, it might improve training results on the trained captions, but I am not sure about this yet.

Let me know if I should provide more details here; for now they can be found on the OT Discord.
There is a feature request for SimpleTuner here: bghira/SimpleTuner#1031

This is a draft PR only to gauge interest in a full PR. It currently works only with batch size one, only for Flux, only for LoRA, and only for the transformer.

It could be implemented generically for all LoRAs. With major effort, it could also be implemented for full fine-tuning, but to avoid holding the full model in VRAM twice, the reg-step predictions would have to be pre-generated.
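For reference, the core of a reg step is simple. A minimal sketch, with illustrative names only (not the actual OneTrainer API) and assuming a plain MSE training loss:

```python
import torch
import torch.nn.functional as F

def step_loss(model, lora_modules, batch, is_reg_step: bool):
    if is_reg_step:
        # Prior preservation: the target is the prediction of the unmodified base
        # model, obtained with the LoRA temporarily switched off.
        with torch.no_grad():
            for lora in lora_modules:
                lora.enabled = False          # hypothetical per-module switch
            target = model(batch["noisy_latents"], batch["timestep"], batch["text_embeds"])
            for lora in lora_modules:
                lora.enabled = True
    else:
        target = batch["target"]              # the usual noise / flow-matching target

    prediction = model(batch["noisy_latents"], batch["timestep"], batch["text_embeds"])
    return F.mse_loss(prediction, target)
```

The only change compared to a normal step is where the target comes from; the loss itself stays the same.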

@FurkanGozukara

@dxqbYD can you add examples? Your examples are great. Even though I couldn't make it work, maybe it will once it is properly implemented :D

So, comparison examples and how you set up your concepts.

dxqbYD (Contributor, Author) commented Oct 14, 2024

Samples can be found in these SimpleTuner release notes: https://www.reddit.com/r/StableDiffusion/comments/1g2i13s/simpletuner_v112_now_with_masked_loss_training/

dxqbYD (Contributor, Author) commented Oct 19, 2024

kohya implementation: kohya-ss/sd-scripts#1710

Nerogar (Owner) commented Oct 20, 2024

This sounds like a really good idea to add as an option, but it definitely needs a more generic implementation. There are two issues to solve:

Dataset

How do we select the regularization samples during training? This also needs to work with a higher batch size than 1. Ideally it would mix regularization samples and normal training samples within the same batch.
"It does not require a regularization caption" I don't think this is strictly true. You need some kind of conditioning for the model. Not conditioning the model at all will probably significantly reduce the effect of this training method.
What do you think about adding a new flag to concepts that toggles this loss calculation for specific training samples? Then the user can decide whether to include captions or not, and which images to use.
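For illustration, a concept entry with such a flag could look roughly like this (shown as a Python dict; the field names are hypothetical and not OneTrainer's actual concept schema):

```python
# Hypothetical concept entry with the proposed toggle; field names are illustrative.
reg_concept = {
    "name": "regularization",
    "path": "training_images/",       # may even reuse the normal training images
    "prompt_source": "none",          # empty captions are reportedly enough
    "repeats": 0.5,                   # balance reg steps against normal steps
    "prior_preservation": True,       # proposed flag: loss is computed against the prior prediction
}
```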

Unhooking the LoRA

Each model has different sub-modules. So we need a generic method of disabling the LoRA for the prior result. A function in the model class to enable/disable all LoRAs could work well.
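A rough sketch of what that toggle could look like (class and method names are made up here, not existing OneTrainer code):

```python
from contextlib import contextmanager

class LoRAToggleMixin:
    def all_lora_modules(self):
        # Each model class returns whatever sub-modules carry LoRA weights
        # (transformer, text encoders, ...).
        raise NotImplementedError

    @contextmanager
    def lora_disabled(self):
        # Temporarily switch off every LoRA module, restoring them afterwards.
        modules = list(self.all_lora_modules())
        for m in modules:
            m.set_enabled(False)   # hypothetical per-module switch
        try:
            yield
        finally:
            for m in modules:
                m.set_enabled(True)
```

A training step could then wrap the prior prediction in `with model.lora_disabled(): ...` regardless of which model family is being trained.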

bghira commented Oct 20, 2024

How do you intend to mix regularisation and training samples in a single batch, @Nerogar? That seems non-trivial: the actual target is changed.

Nerogar (Owner) commented Oct 20, 2024

The only difference between prior preservation and normal training is the prediction target. So what I would do is basically this:

  1. Find the samples in the batch where the prior_preservation flag is set to True
  2. Calculate the prior prediction without the LoRA for those samples
  3. Replace the target of the batch in those samples with the prior prediction
  4. Calculate the loss without any modification
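A sketch of those four steps, with illustrative tensor and method names and assuming a plain MSE loss:

```python
import torch
import torch.nn.functional as F

def compute_loss(model, batch):
    target = batch["target"].clone()
    reg_mask = batch["prior_preservation"]                  # bool tensor, shape [batch_size]

    if reg_mask.any():
        # Steps 1-3: the prior prediction without the LoRA becomes the target for flagged samples.
        with torch.no_grad(), model.lora_disabled():        # generic toggle as discussed above
            prior = model(batch["noisy_latents"][reg_mask],
                          batch["timestep"][reg_mask],
                          batch["text_embeds"][reg_mask])
        target[reg_mask] = prior

    # Step 4: one forward pass and an unmodified loss over the whole mixed batch.
    prediction = model(batch["noisy_latents"], batch["timestep"], batch["text_embeds"])
    return F.mse_loss(prediction, target)
```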

bghira commented Oct 20, 2024

Yes, unfortunately it just doesn't have the same regularisation effect when done that way. Having an entire batch pull back toward the model is what works.

dxqbYD (Contributor, Author) commented Oct 20, 2024

> Yes, unfortunately it just doesn't have the same regularisation effect when done that way. Having an entire batch pull back toward the model is what works.

What are you basing this on?

What Nerogar describes above is what kohya has implemented. So if true, that would mean kohya's implementation doesn't work (as well).

bghira commented Oct 20, 2024

basing it on numerous tests we've run on a cluster of H100s over the last week

dxqbYD (Contributor, Author) commented Oct 20, 2024

> How do we select the regularization samples during training? This also needs to work with a higher batch size than 1. Ideally it would mix regularization samples and normal training samples within the same batch. "It does not require a regularization caption" I don't think this is strictly true. You need some kind of conditioning for the model. Not conditioning the model at all will probably significantly reduce the effect of this training method.

It isn't obvious that this would work without captions, but it does. You can see samples in the reddit link above. The right-most column is without captions.

> What do you think about adding a new flag to concepts that toggles this loss calculation for specific training samples? Then the user can decide whether to include captions or not, and which images to use.

Yes, agreed. Beyond captions, there are more use cases in favor of having it as a separate concept, for example balancing the regularisation using the number of repeats. In some of my tests, a 1:1 ratio was too much.

@bghira has also found, using his implementation in SimpleTuner, that even though it works with no external data, it works better with high-quality external data.

dxqbYD (Contributor, Author) commented Oct 20, 2024

> basing it on numerous tests we've run on a cluster of H100s over the last week

Okay, thanks. Any theory on why that would be? I don't see a theoretical reason for your finding that it works better in a separate batch:

  • reg gradients are tiny.
  • the regularisation described in the DreamBooth paper was always implemented in the same batch in the early scripts.
  • you could even argue that this type of contrastive training should work better in the same batch.

O-J1 (Collaborator) commented Oct 21, 2024

> basing it on numerous tests we've run on a cluster of H100s over the last week

Could you please provide some evidence of this? I.e. a large enough number of samples that you aren't falling victim to seed RNG.

It's important to get this right.

dxqbYD (Contributor, Author) commented Oct 21, 2024

>> basing it on numerous tests we've run on a cluster of H100s over the last week

> Could you please provide some evidence of this? I.e. a large enough number of samples that you aren't falling victim to seed RNG.

> It's important to get this right.

If this turns out to be right, I'd recommend implementing a feature in the OT concepts like:
"try to keep this concept separate from concept Y in batches"
and
"try to combine this concept with concept Y in batches"

It would influence how the batches are built, and the first option would be how ST builds batches.

This could be a useful feature on its own. For example, if you train 2 concepts, it can be beneficial to have 1 image of each concept in a batch, instead of the same concept twice, especially if the images in a concept are very similar.
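As a rough illustration of the "combine" option (just a sketch, nothing OneTrainer-specific): a batch builder could round-robin over concepts so that a batch mixes different concepts instead of repeating the same one.

```python
import random
from collections import defaultdict

def build_batches(samples, batch_size):
    """samples: list of (concept_name, sample) pairs."""
    pools = defaultdict(list)
    for concept, sample in samples:
        pools[concept].append(sample)
    for pool in pools.values():
        random.shuffle(pool)

    batches, batch = [], []
    while any(pools.values()):
        # Take at most one sample per concept before revisiting a concept.
        for concept in list(pools):
            if pools[concept]:
                batch.append((concept, pools[concept].pop()))
                if len(batch) == batch_size:
                    batches.append(batch)
                    batch = []
    if batch:
        batches.append(batch)
    return batches
```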

bghira commented Oct 21, 2024

I don't have time, sorry; do it however works best for your codebase.

@DriveHabits

Any update on this, @dxqbYD?

dxqbYD (Contributor, Author) commented Nov 13, 2024

> Any update on this, @dxqbYD?

Nothing usable for OneTrainer users yet.
There have been more interesting experiments beyond just preserving the prior knowledge of a separate prompt as above: it appears this can also be very useful when training a concept, to control what you don't want it to learn. The concept can then be mixed in by prompting, and even mixing with other independently trained LoRAs seems to work better that way.

I should mention that a paper proposing this technique was apparently published in April of this year; I just didn't know about it: https://arxiv.org/pdf/2404.07554
The authors pointed this out on the PR for kohya's implementation. They coined the term "Contrastive Adapter Training".

@FurkanGozukara

@dxqbYD so do we have it in kohya at the moment? I couldn't find it.
