Is your feature request related to a problem? Please describe.
EMAModel in Diffusers is not plumbed to interact well with PEFT LoRAs, which leaves users to implement their own.
The idea has been thrown around that LoRA does not benefit from EMA, and research papers have suggested as much. However, my curiosity was piqued, and after a bit of effort I managed to make it work.
Here is a pull request for SimpleTuner where I've updated my EMAModel implementation to behave more like an nn.Module, allowing EMAModel to be passed into more processes without "funny business".
This spot in the save hooks was hardcoded to take the class name following Diffusers convention, but we could take a more dynamic approach, perhaps in a training_utils helper method.
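A more dynamic approach might look something like the sketch below; the helper name, its placement, and the module-unwrapping detail are all assumptions on my part, not existing Diffusers API:

```python
def model_save_subfolder(model) -> str:
    # Hypothetical training_utils helper: derive the save-hook subfolder
    # from the model instance instead of hardcoding class names.
    # Unwrap accelerate/compile wrappers if present (they expose .module).
    inner = getattr(model, "module", model)
    return type(inner).__name__.lower()
```

This would let the save hooks work for transformer- and unet-based pipelines alike without a hardcoded class-name branch.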
Just a bit further down, at L208 in the save hooks, I did something I'm not really 100% happy with, but users were asking for it:
For my own trainer's convenience, I save a copy of the EMA model in a simple, loadable state_dict format so that it can be loaded on resume.
Additionally, we save a second copy of the EMA in the PEFT LoRA format so that it can be loaded by pipelines.
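The first copy is straightforward; a minimal sketch of the idea is below (the function name and file name are illustrative assumptions, not what SimpleTuner actually writes):

```python
import os

import torch


def save_ema_for_resume(ema_model, output_dir):
    # Copy 1: a plain state_dict that can simply be torch.load()ed
    # and fed back into the EMA model when training resumes.
    path = os.path.join(output_dir, "ema_model.pt")
    torch.save(ema_model.state_dict(), path)
    return path
```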
The tricky part is the second copy of the EMA model, which gets saved in the standard LoRA format:
```python
if self.args.use_ema:
    # we'll temporarily overwrite the LoRA parameters with the EMA parameters to save it.
    logger.info("Saving EMA model to disk.")
    trainable_parameters = [
        p for p in self._primary_model().parameters() if p.requires_grad
    ]
    self.ema_model.store(trainable_parameters)
    self.ema_model.copy_to(trainable_parameters)
    if self.transformer is not None:
        self.pipeline_class.save_lora_weights(
            os.path.join(output_dir, "ema"),
            transformer_lora_layers=convert_state_dict_to_diffusers(
                get_peft_model_state_dict(self._primary_model())
            ),
        )
    elif self.unet is not None:
        self.pipeline_class.save_lora_weights(
            os.path.join(output_dir, "ema"),
            unet_lora_layers=convert_state_dict_to_diffusers(
                get_peft_model_state_dict(self._primary_model())
            ),
        )
    self.ema_model.restore(trainable_parameters)
```
This could probably be done more nicely with a trainable_parameters() method on the model classes where appropriate. The state-dict conversion decorations are probably required for now, but ideally this could be simplified so that newcomers do not have to dig into and understand so many moving pieces.
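Such a helper could be as small as the sketch below; the method name and its placement on the model classes are assumptions here, not an existing API:

```python
def trainable_parameters(model):
    # Hypothetical helper: return only the parameters that receive
    # gradients -- for a PEFT LoRA, that's just the adapter weights.
    return [p for p in model.parameters() if p.requires_grad]
```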
For quantised training, the EMA model has to be quantised in the same way as the trained model.
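The point can be illustrated with a toy quantisation round-trip (this stands in for a real backend; no actual quantisation library is being used):

```python
def quantise(weights, scale=127.0):
    # Toy symmetric "int8-style" round-trip, standing in for whatever
    # quantisation backend the trained model went through.
    return [round(w * scale) / scale for w in weights]


# The EMA shadow weights must pass through the same quantisation as the
# trained weights; otherwise copying EMA weights into the quantised model
# would mix precisions.
trained_q = quantise([0.1234, -0.5678])
ema_q = quantise([0.1200, -0.5600])
```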
The validations were kind of a pain, but I wanted to make it possible to load and unload the EMA weights repeatedly during the process, so that each prompt can be validated against both the checkpoint and the EMA weights. Here is my method for enabling (and, just below, disabling) the EMA model at inference time.
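Stripped to its essentials, the enable/disable toggle is a store/copy_to/restore round-trip; the toy class below sketches the pattern (it is a stand-in to show the flow, not Diffusers' EMAModel):

```python
class TinyEMA:
    """Toy stand-in for an EMA model, just to show the validation toggle."""

    def __init__(self, shadow):
        self.shadow = list(shadow)  # the EMA ("shadow") weights
        self._backup = None

    def store(self, params):
        # stash the trained weights before overwriting them
        self._backup = list(params)

    def copy_to(self, params):
        # swap the EMA weights into the live parameter list
        params[:] = self.shadow

    def restore(self, params):
        # put the trained weights back after validation
        params[:] = self._backup


def enable_ema(ema, params):
    ema.store(params)
    ema.copy_to(params)


def disable_ema(ema, params):
    ema.restore(params)
```

Because restore() round-trips back to the stored weights, the toggle can run once per validation prompt without drifting.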
However, the effect is really nice; here you see the starting SD 3.5M on the left, the trained LoRA in the centre, and EMA on the right.
These samples are from 60,000 steps of training a rank-128 PEFT LoRA on all of the attn layers of the SD 3.5 Medium model, using ~120,000 high-quality photos.
While it's not a cure-all for training problems, the EMA model has outperformed the trained checkpoint throughout the entire duration of training.
It'd be worth considering someday including EMA support for LoRA, along with the related improvements for saving/loading EMA weights on adapters, so that users can get better results from the training examples. I don't think the validation changes are strictly needed, but they could be done in a non-intrusive way, more nicely than I have done here.
I think for now we can refer users to SimpleTuner for this. Also, it's perhaps subjective, but I don't necessarily find the EMA results to be better than the results without it.