diff --git a/src/peft/mapping.py b/src/peft/mapping.py
index 8436c384db..af818f2d57 100644
--- a/src/peft/mapping.py
+++ b/src/peft/mapping.py
@@ -13,6 +13,10 @@
 # See the License for the specific language governing permissions and
 # limitations under the License.

+from __future__ import annotations
+
+from typing import TYPE_CHECKING, Any, Dict
+
 from .peft_model import (
     PeftModel,
     PeftModelForCausalLM,
@@ -32,6 +36,12 @@
 from .utils import PromptLearningConfig


+if TYPE_CHECKING:
+    from transformers import PreTrainedModel
+
+    from .utils.config import PeftConfig
+
+
 MODEL_TYPE_TO_PEFT_MODEL_MAPPING = {
     "SEQ_CLS": PeftModelForSequenceClassification,
     "SEQ_2_SEQ_LM": PeftModelForSeq2SeqLM,
@@ -50,7 +60,7 @@
 }


-def get_peft_config(config_dict):
+def get_peft_config(config_dict: Dict[str, Any]):
     """
     Returns a Peft config object from a dictionary.

@@ -61,7 +71,7 @@ def get_peft_config(config_dict):
     return PEFT_TYPE_TO_CONFIG_MAPPING[config_dict["peft_type"]](**config_dict)


-def _prepare_prompt_learning_config(peft_config, model_config):
+def _prepare_prompt_learning_config(peft_config: PeftConfig, model_config: Dict[str, Any]):
     if peft_config.num_layers is None:
         if "num_hidden_layers" in model_config:
             num_layers = model_config["num_hidden_layers"]
@@ -103,7 +113,7 @@ def _prepare_prompt_learning_config(peft_config, model_config):
     return peft_config


-def get_peft_model(model, peft_config, adapter_name="default") -> PeftModel:
+def get_peft_model(model: PreTrainedModel, peft_config: PeftConfig, adapter_name: str = "default") -> PeftModel:
     """
     Returns a Peft model object from a model and a config.

diff --git a/src/peft/peft_model.py b/src/peft/peft_model.py
index 548dc5c47a..9f63fd8f47 100644
--- a/src/peft/peft_model.py
+++ b/src/peft/peft_model.py
@@ -13,12 +13,14 @@
 # See the License for the specific language governing permissions and
 # limitations under the License.

+from __future__ import annotations
+
 import inspect
 import os
 import warnings
 from contextlib import contextmanager
 from copy import deepcopy
-from typing import Optional
+from typing import Any, Dict, Optional, Union

 import torch
 from accelerate import dispatch_model, infer_auto_device_map
@@ -94,7 +96,7 @@ class PeftModel(PushToHubMixin, torch.nn.Module):
             in the base model if using [`PromptLearningConfig`].
     """

-    def __init__(self, model, peft_config: PeftConfig, adapter_name="default"):
+    def __init__(self, model: PreTrainedModel, peft_config: PeftConfig, adapter_name: str = "default"):
         super().__init__()
         self.base_model = model
         self.config = self.base_model.config
@@ -115,7 +117,7 @@ def __init__(self, model, peft_config: PeftConfig, adapter_name="default"):
         if getattr(model, "is_gradient_checkpointing", True):
             model = self._prepare_model_for_gradient_checkpointing(model)

-    def save_pretrained(self, save_directory, safe_serialization=False, **kwargs):
+    def save_pretrained(self, save_directory: str, safe_serialization: bool = False, **kwargs: Any):
         r"""
         This function saves the adapter model and the adapter configuration files to a directory, so that it can be
         reloaded using the [`LoraModel.from_pretrained`] class method, and also used by the [`LoraModel.push_to_hub`]
@@ -162,7 +164,13 @@ def save_pretrained(self, save_directory, safe_serialization=False, **kwargs):

     @classmethod
     def from_pretrained(
-        cls, model, model_id, adapter_name="default", is_trainable=False, config: Optional[PeftConfig] = None, **kwargs
+        cls,
+        model: PreTrainedModel,
+        model_id: Union[str, os.PathLike],
+        adapter_name: str = "default",
+        is_trainable: bool = False,
+        config: Optional[PeftConfig] = None,
+        **kwargs: Any,
     ):
         r"""
         Instantiate a [`LoraModel`] from a pretrained Lora configuration and weights.
@@ -223,7 +231,7 @@ def from_pretrained(
         model.load_adapter(model_id, adapter_name, is_trainable=is_trainable, **kwargs)
         return model

-    def _setup_prompt_encoder(self, adapter_name):
+    def _setup_prompt_encoder(self, adapter_name: str):
         config = self.peft_config[adapter_name]
         self.prompt_encoder = torch.nn.ModuleDict({})
         self.prompt_tokens = {}
@@ -258,7 +266,7 @@ def _setup_prompt_encoder(self, adapter_name):
             config.num_virtual_tokens * config.num_transformer_submodules
         ).long()

-    def _prepare_model_for_gradient_checkpointing(self, model):
+    def _prepare_model_for_gradient_checkpointing(self, model: PreTrainedModel):
         r"""
         Prepares the model for gradient checkpointing if necessary
         """
@@ -273,7 +281,7 @@ def make_inputs_require_grad(module, input, output):
             model.get_input_embeddings().register_forward_hook(make_inputs_require_grad)
         return model

-    def get_prompt_embedding_to_save(self, adapter_name):
+    def get_prompt_embedding_to_save(self, adapter_name: str):
         """
         Returns the prompt embedding to save when saving the model. Only applicable when `peft_config.peft_type !=
         PeftType.LORA`.
@@ -287,7 +295,7 @@ def get_prompt_embedding_to_save(self, adapter_name):
         prompt_embeddings = prompt_encoder(prompt_tokens)
         return prompt_embeddings[0].detach().cpu()

-    def get_prompt(self, batch_size):
+    def get_prompt(self, batch_size: int):
         """
         Returns the virtual prompts to use for Peft. Only applicable when `peft_config.peft_type != PeftType.LORA`.
         """
@@ -356,7 +364,7 @@ def __getattr__(self, name: str):
         except AttributeError:
             return getattr(self.base_model, name)

-    def forward(self, *args, **kwargs):
+    def forward(self, *args: Any, **kwargs: Any):
         """
         Forward pass of the model.
         """
@@ -386,7 +394,7 @@ def get_base_model(self):
         """
         return self.base_model if isinstance(self.active_peft_config, PromptLearningConfig) else self.base_model.model

-    def add_adapter(self, adapter_name, peft_config):
+    def add_adapter(self, adapter_name: str, peft_config: PeftConfig):
         if peft_config.peft_type != self.peft_type:
             raise ValueError(
                 f"Cannot combine adapters with different peft types. "
@@ -409,7 +417,7 @@ def set_additional_trainable_modules(self, peft_config, adapter_name):
             _set_trainable(self, adapter_name)

     @classmethod
-    def _split_kwargs(cls, kwargs):
+    def _split_kwargs(cls, kwargs: Dict[str, Any]):
         hf_hub_download_kwargs = {}
         other_kwargs = {}

@@ -421,7 +429,7 @@ def _split_kwargs(cls, kwargs):

         return hf_hub_download_kwargs, other_kwargs

-    def load_adapter(self, model_id, adapter_name, is_trainable=False, **kwargs):
+    def load_adapter(self, model_id: str, adapter_name: str, is_trainable: bool = False, **kwargs: Any):
         from .mapping import PEFT_TYPE_TO_CONFIG_MAPPING

         hf_hub_download_kwargs, kwargs = self._split_kwargs(kwargs)
@@ -535,7 +543,7 @@ def load_adapter(self, model_id, adapter_name, is_trainable=False, **kwargs):
             self.eval()
         return load_result

-    def set_adapter(self, adapter_name):
+    def set_adapter(self, adapter_name: str):
         """
         Sets the active adapter.
         """
@@ -550,7 +558,7 @@ def set_adapter(self, adapter_name):
     def active_peft_config(self):
         return self.peft_config[self.active_adapter]

-    def create_or_update_model_card(self, output_dir):
+    def create_or_update_model_card(self, output_dir: str):
         """
         Updates or create model card to include information about peft:
         1. Adds `peft` library tag