
Added functionality for fine-tuning the PET heads only #517


Open · MarkusFas wants to merge 10 commits into main

Conversation


@MarkusFas commented on Mar 11, 2025

Added an option to train only the head layers of PET.
Enabled by setting FINETUNE_HEADS = True in options.yaml.
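
Conceptually, this freezes everything except the output heads. A minimal sketch of the idea (the attribute name `heads` is illustrative, not necessarily the actual PET module name):

# Sketch: freeze all weights, then re-enable gradients only for the head layers.
# `model.heads` is an illustrative name; the real PET head modules may differ.
for param in model.parameters():
    param.requires_grad = False
for param in model.heads.parameters():
    param.requires_grad = True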

Contributor (creator of pull-request) checklist

  • Tests updated (for new features and bugfixes)?
  • Documentation updated (for new features)?
  • Issue referenced (for PRs that solve an issue)?

Reviewer checklist

  • CHANGELOG updated with public API or any other important changes?

📚 Documentation preview 📚: https://metatrain--517.org.readthedocs.build/en/517/

@MarkusFas requested a review from abmazitov as a code owner on March 11, 2025 17:21
@abmazitov changed the title from "Markus" to "Added functionality for fine-tuning the PET heads only" on Mar 11, 2025
@abmazitov (Contributor)

Thank you @MarkusFas! I think we also need to add checks that LoRA and heads fine-tuning are not enabled simultaneously.
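
For example, something along these lines (just a sketch; the exact option names and their location in the hypers are an assumption):

# Sketch: reject ambiguous configurations where both fine-tuning modes are requested.
# USE_LORA_PEFT and FINETUNE_HEADS are the options discussed in this PR; their exact
# placement inside FITTING_SCHEME is assumed here.
if FITTING_SCHEME.USE_LORA_PEFT and FITTING_SCHEME.FINETUNE_HEADS:
    raise ValueError(
        "LoRA and heads fine-tuning cannot be enabled at the same time; "
        "please choose only one of them in options.yaml"
    )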

if model.is_lora_applied and not FITTING_SCHEME.USE_LORA_PEFT:
if FITTING_SCHEME.FINETUNING not in [None, "lora", "heads"]:
    raise ValueError("Finetuning only allows 'lora' or 'heads' option")
if model.ft_type == "lora" and not FITTING_SCHEME.FINETUNING == "lora":
Contributor:

I think we need to add a similar check for the "heads" fine-tuning. The idea behind this check is to avoid cases where fine-tuning is activated only in the checkpoint, but not explicitly stated in the hypers. In that case the user might think that normal full training is happening, while only the fine-tunable weights are actually adapted.
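
Something like this, mirroring the "lora" check above (sketch only):

# Sketch: fail loudly if the checkpoint was produced by heads-only fine-tuning
# but the hypers do not explicitly request it.
if model.ft_type == "heads" and not FITTING_SCHEME.FINETUNING == "heads":
    raise ValueError(
        "The loaded checkpoint was fine-tuned on the heads only; "
        "set FINETUNING to 'heads' to continue this kind of training"
    )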

Author:

I thought that when you load a checkpoint it should not matter whether it was trained fully or fine-tuned up to that point; that's why I left the check out. If you disagree, I will add the check, sure!

    }
elif model.ft_type == "lora":
    ft_state_dict = {
        "ft_type": "lora",
Contributor:

I think this is not optimal. A cleaner implementation of this logic could look like this:

if model.ft_type is not None:
    ft_state_dict = {"ft_type": model.ft_type}
    if model.ft_type == 'lora':
        ft_state_dict.update({
            "lora_rank": model.pet.model.rank,
            "lora_alpha": model.pet.model.alpha,
        })
else:
    ft_state_dict = None
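
On the loading side the same entry could then be consumed symmetrically, e.g. (sketch; the checkpoint key name is an assumption):

# Sketch: restore the fine-tuning configuration saved above.
# "ft_state_dict" as a checkpoint key is illustrative, not the actual key name.
ft_state_dict = checkpoint.get("ft_state_dict")
if ft_state_dict is not None:
    model.ft_type = ft_state_dict["ft_type"]
    if model.ft_type == "lora":
        lora_rank = ft_state_dict["lora_rank"]
        lora_alpha = ft_state_dict["lora_alpha"]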

Author:

Will change it, thanks!

def __init__(
    self,
    model: torch.nn.Module,
    **kwargs,
Contributor:

I think we might have misunderstood each other. When I said it would be cool to have a dict with options as an argument to the model, I didn't mean that you need to hide these options in kwargs. I think it will work better as an explicit class signature:

class FinetuneWrapper(torch.nn.Module):
    def __init__(
        self,
        model: torch.nn.Module,
        ft_type: str,
        lora_rank: Optional[int] = None,
        lora_alpha: Optional[float] = None
    ):
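
With this signature the call site is also self-documenting, e.g. (values are placeholders):

# Sketch of the call site with the explicit signature; rank/alpha values are placeholders.
wrapper = FinetuneWrapper(model, ft_type="lora", lora_rank=8, lora_alpha=0.5)
# or, for heads-only fine-tuning:
wrapper = FinetuneWrapper(model, ft_type="heads")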

Author:

Ok! Will change it!

self.hypers = model.hypers
self.hidden_dim = model.hypers.TRANSFORMER_D_MODEL
self.num_hidden_layers = model.hypers.N_GNN_LAYERS * model.hypers.N_TRANS_LAYERS
if kwargs["ft_type"] == "heads":
Contributor:

This way you can avoid these kwargs["ft_type"] lookups.
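
i.e. with the explicit signature the branch above simply becomes (sketch):

# Sketch: store the named argument once and branch on the attribute
# instead of looking it up in kwargs.
self.ft_type = ft_type
if self.ft_type == "heads":
    ...  # heads-only setup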
