Added functionality for fine-tuning the PET heads only #517
base: main
Conversation
Thank you @MarkusFas! I think we also need to add checks to ensure that LoRA and heads fine-tuning are not enabled simultaneously.
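For example, such a check could look like the following sketch. It assumes the FITTING_SCHEME hypers object used elsewhere in the trainer and that USE_LORA_PEFT is the flag that enables LoRA; the exact names and error message are illustrative only.

# Sketch: reject simultaneous LoRA and heads fine-tuning.
if FITTING_SCHEME.USE_LORA_PEFT and FITTING_SCHEME.FINETUNING == "heads":
    raise ValueError(
        "LoRA and heads fine-tuning cannot be enabled at the same time"
    )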
- added new options for specifying FT
- new Wrapper class
if model.is_lora_applied and not FITTING_SCHEME.USE_LORA_PEFT:
if FITTING_SCHEME.FINETUNING not in [None, "lora", "heads"]:
    raise ValueError("Finetuning only allows 'lora' or 'heads' option")
if model.ft_type == "lora" and not FITTING_SCHEME.FINETUNING == "lora":
I think we need to add a similar check for the "heads" fine-tuning. The idea behind this check is to avoid cases where fine-tuning is activated only in the checkpoint, but not explicitly stated in the hypers. In that case the user might think that normal full training happens, while only the fine-tunable weights will be adapted.
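As a sketch, the "heads" analogue could mirror the existing LoRA consistency check above (the error message is illustrative):

# Sketch: heads counterpart of the LoRA checkpoint-vs-hypers check.
if model.ft_type == "heads" and not FITTING_SCHEME.FINETUNING == "heads":
    raise ValueError(
        "The checkpoint was fine-tuned on heads only, but heads "
        "fine-tuning is not enabled in the hypers"
    )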
I thought that when you load a checkpoint it should not matter whether it was trained fully or fine-tuned up to that point; that's why I left the check out. If you disagree, I will add the check, sure!
src/metatrain/pet/trainer.py
Outdated
}
elif model.ft_type == "lora":
    ft_state_dict = {
        "ft_type": "lora",
I think this is not optimal. A proper implementation of this logic could look like this:
if model.ft_type is not None:
    ft_state_dict = {"ft_type": model.ft_type}
    if model.ft_type == 'lora':
        ft_state_dict.update({
            "lora_rank": model.pet.model.rank,
            "lora_alpha": model.pet.model.alpha,
        })
else:
    ft_state_dict = None
Will change it, thanks!
def __init__(
    self,
    model: torch.nn.Module,
    **kwargs,
I think we might have misunderstood each other. When I said it would be cool to have a dict with options as an argument to the model, I didn't mean that you need to hide these options in kwargs. I think it will work better as an explicit class signature:
class FinetuneWrapper(torch.nn.Module):
    def __init__(
        self,
        model: torch.nn.Module,
        ft_type: str,
        lora_rank: Optional[int] = None,
        lora_alpha: Optional[float] = None
    ):
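For illustration, instantiation with this signature could then look like the sketch below; the variable name and the LoRA hyperparameter values are hypothetical.

# `pet_model` stands for an already-constructed PET model (hypothetical name).
wrapper = FinetuneWrapper(pet_model, ft_type="heads")
# ... or, for LoRA fine-tuning with explicit hyperparameters:
wrapper = FinetuneWrapper(pet_model, ft_type="lora", lora_rank=4, lora_alpha=0.5)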
Ok! Will change it!
self.hypers = model.hypers
self.hidden_dim = model.hypers.TRANSFORMER_D_MODEL
self.num_hidden_layers = model.hypers.N_GNN_LAYERS * model.hypers.N_TRANS_LAYERS
if kwargs["ft_type"] == "heads":
That's how you can avoid these kwargs["ft_type"] lookups.
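A minimal sketch of how the body above could consume the explicit argument instead of kwargs, assuming the signature suggested earlier (the trailing body is elided):

def __init__(
    self,
    model: torch.nn.Module,
    ft_type: str,
    lora_rank: Optional[int] = None,
    lora_alpha: Optional[float] = None,
):
    super().__init__()
    self.hypers = model.hypers
    self.hidden_dim = model.hypers.TRANSFORMER_D_MODEL
    self.num_hidden_layers = model.hypers.N_GNN_LAYERS * model.hypers.N_TRANS_LAYERS
    # explicit argument, no kwargs["ft_type"] lookup needed
    if ft_type == "heads":
        ...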
Added an option to train only the head layers of PET.
Invoked by FINETUNE_HEADS = True in options.yaml.
📚 Documentation preview 📚: https://metatrain--517.org.readthedocs.build/en/517/