EMA training for PEFT LoRAs #9998

Open · bghira opened this issue Nov 22, 2024 · 3 comments
Labels: enhancement (New feature or request)

@bghira (Contributor) commented Nov 22, 2024

Is your feature request related to a problem? Please describe.

EMAModel in Diffusers is not plumbed for interacting well with PEFT LoRAs, which leaves users to implement their own.
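
As a rough illustration of what users end up doing by hand: the closest thing today is pointing EMAModel at the adapter's trainable parameters yourself. A minimal sketch, assuming a `transformer` that already has a PEFT LoRA adapter attached:

    from diffusers.training_utils import EMAModel

    # `transformer` is assumed to be a diffusers model with a PEFT LoRA adapter
    # already attached; only the adapter weights have requires_grad=True.
    lora_parameters = [p for p in transformer.parameters() if p.requires_grad]
    ema_model = EMAModel(lora_parameters, decay=0.999)

    # In the training loop, after each optimizer step:
    ema_model.step(lora_parameters)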

The idea has been thrown around that LoRA does not benefit from EMA, and research papers have suggested as much. However, my curiosity was piqued, and while it took a bit, I managed to make it work.

Here is a pull request for SimpleTuner where I've updated my EMAModel implementation to behave more like an nn.Module, allowing the EMAModel to be passed into more processes without any funny business.
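
For context, the rough shape of that change (a hypothetical sketch, not the actual SimpleTuner code) is an EMA holder that subclasses nn.Module and keeps its shadow weights as registered buffers, so .to(), .state_dict(), and friends behave like any other module:

    import torch

    class EMAModuleSketch(torch.nn.Module):
        """Hypothetical nn.Module-style EMA holder: shadow weights live in
        registered buffers, so device moves and (de)serialization work the
        same way they do for any other module."""

        def __init__(self, parameters, decay: float = 0.999):
            super().__init__()
            self.decay = decay
            for i, p in enumerate(parameters):
                self.register_buffer(f"shadow_{i}", p.detach().clone())

        @torch.no_grad()
        def step(self, parameters):
            # shadow = decay * shadow + (1 - decay) * param
            for i, p in enumerate(parameters):
                getattr(self, f"shadow_{i}").lerp_(p.detach(), 1.0 - self.decay)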

This spot in the save hooks was hardcoded to check the class name, following the Diffusers convention, but a more dynamic approach could live in, say, a training_utils helper method.
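
Something along these lines could work as that helper (hypothetical name, not an existing diffusers API):

    def match_model_type(model, accelerator, candidates):
        """Hypothetical training_utils helper: return the key in `candidates`
        whose class matches the unwrapped model, instead of hardcoding one
        class name in the save hook."""
        unwrapped = accelerator.unwrap_model(model)
        for name, cls in candidates.items():
            if isinstance(unwrapped, cls):
                return name
        return None

The save hook could then pass something like {"transformer": SD3Transformer2DModel, "unet": UNet2DConditionModel} and branch on the returned key.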

Just a bit further down, at L208 in the save hooks, I did something I'm not really 100% happy with, but users were asking for it:

  • For my own trainer's convenience, I save a copy of the EMA model in a simple loadable state_dict format so that I can load it when resuming (see the sketch just below this list).
  • Additionally, we save a second copy of the EMA in the PEFT LoRA format so that it can be loaded by pipelines.
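
A minimal sketch of that first, resume-oriented copy (the file name, output_dir, and ema_model are assumptions here, and it assumes the EMA object exposes a state_dict()/load_state_dict() pair, as diffusers' EMAModel does):

    import os
    import torch

    # Save alongside the checkpoint so resume can pick it up.
    ema_path = os.path.join(output_dir, "ema_model.pt")
    torch.save(ema_model.state_dict(), ema_path)

    # On resume:
    ema_model.load_state_dict(torch.load(ema_path, map_location="cpu"))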

The tricky part is the second copy of the EMA model, which gets saved in the standard LoRA format:

        if self.args.use_ema:
            # we'll temporarily overwrite the LoRA parameters with the EMA parameters to save them.
            logger.info("Saving EMA model to disk.")
            trainable_parameters = [
                p
                for p in self._primary_model().parameters()
                if p.requires_grad
            ]
            self.ema_model.store(trainable_parameters)
            self.ema_model.copy_to(trainable_parameters)
            if self.transformer is not None:
                self.pipeline_class.save_lora_weights(
                    os.path.join(output_dir, "ema"),
                    transformer_lora_layers=convert_state_dict_to_diffusers(
                        get_peft_model_state_dict(self._primary_model())
                    ),
                )
            elif self.unet is not None:
                self.pipeline_class.save_lora_weights(
                    os.path.join(output_dir, "ema"),
                    unet_lora_layers=convert_state_dict_to_diffusers(
                        get_peft_model_state_dict(self._primary_model())
                    ),
                )
            self.ema_model.restore(trainable_parameters)

This could probably be done more nicely with a trainable_parameters() method on the model classes where appropriate.
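
A hypothetical sketch of what that could look like, written as a small mixin rather than a change to any particular model class:

    class TrainableParametersMixin:
        def trainable_parameters(self):
            """Yield only the parameters that receive gradients; for a PEFT
            LoRA that is just the adapter weights."""
            return (p for p in self.parameters() if p.requires_grad)

The save hook above would then collapse to something like self.ema_model.store(list(model.trainable_parameters())).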

I guess the state-dict conversion decorations are required for now, but it would be ideal if this could be simplified so that newcomers do not have to dig into and understand so many moving pieces.

For quantised training, the EMA model has to be quantised in the same way the trained model was.
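
A hedged sketch of what that looks like if the base weights were quantised with optimum-quanto (the ema_transformer name is an assumption, standing in for a full-model EMA copy):

    from optimum.quanto import freeze, qint8, quantize

    # Mirror the quantisation that was applied to the trained model onto the
    # EMA copy, so both sides see base weights in the same precision.
    quantize(ema_transformer, weights=qint8)
    freeze(ema_transformer)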

The validations were kind of a pain, but I wanted to make it possible to load and unload the EMA weights repeatedly during the run so that each prompt can be validated against both the checkpoint and the EMA weights. Here is my method for enabling (and, just below it, disabling) the EMA model at inference time.
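
In spirit it boils down to the store/copy_to/restore trio on EMAModel; a minimal sketch (not the linked SimpleTuner code):

    def enable_ema(ema_model, model):
        """Swap the EMA weights in for validation."""
        params = [p for p in model.parameters() if p.requires_grad]
        ema_model.store(params)    # stash the current LoRA weights
        ema_model.copy_to(params)  # overwrite them with the EMA weights

    def disable_ema(ema_model, model):
        """Put the original LoRA weights back after validation."""
        params = [p for p in model.parameters() if p.requires_grad]
        ema_model.restore(params)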

However, the effect is really nice; here you see the starting SD 3.5M on the left, the trained LoRA in the centre, and EMA on the right.

[Five comparison images: base SD 3.5 Medium output on the left, trained LoRA checkpoint in the centre, EMA weights on the right.]

These samples are from 60,000 steps of training a rank-128 PEFT LoRA on all of the attn layers of the SD 3.5 Medium model, using ~120,000 high-quality photos.

While it's not a cure-all for training problems, the EMA model has outperformed the trained checkpoint throughout the entire duration of training.

It would be a good idea to consider someday including EMA for LoRA, with the related improvements for saving/loading EMA weights on adapters, so that users can get better results from the training examples. I don't think the validation changes are needed, but they could be done in a non-intrusive way, more nicely than I have done here.

@bghira (Contributor, Author) commented Nov 22, 2024

cc @linoytsaban @sayakpaul for your interest perhaps

@sayakpaul (Member) commented

Thanks for the interesting thread.

I think for now we can refer users to SimpleTuner for this. Also, it's perhaps subjective, but I don't necessarily find the EMA results to be better than the ones without.

@bghira (Contributor, Author) commented Nov 23, 2024

yeah, the centre's outputs are actually entirely incoherent. I don't know why that would be preferred

@sayakpaul added the enhancement (New feature or request) label on Nov 23, 2024