style: tentatively add hints for some public functions (huggingface#614)
* style: tentatively add hints for some public functions

Signed-off-by: Aaron <[email protected]>

* fix: import annotations to evaluate to str

Signed-off-by: Aaron <[email protected]>

* fix: style

Signed-off-by: Aaron <[email protected]>

---------

Signed-off-by: Aaron <[email protected]>
aarnphm authored Jun 28, 2023
1 parent 563acf0 commit 86290e9
Showing 2 changed files with 35 additions and 17 deletions.
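
The second change in the series ("fix: import annotations to evaluate to str") leans on PEP 563: once a module starts with `from __future__ import annotations`, its annotations are stored as strings and never evaluated at runtime, so typing-only imports can sit behind `if TYPE_CHECKING:` without runtime cost or circular-import risk. A minimal sketch of that pattern; `heavy_package` and `HeavyModel` are hypothetical names, not part of this diff:

from __future__ import annotations  # PEP 563: annotations are kept as plain strings

from typing import TYPE_CHECKING

if TYPE_CHECKING:
    # Seen only by static type checkers (mypy, pyright); never imported at runtime.
    from heavy_package import HeavyModel  # hypothetical module, for illustration only


def describe(model: HeavyModel) -> str:
    # At runtime the annotation above is just the string "HeavyModel".
    return f"got a {type(model).__name__}"
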
16 changes: 13 additions & 3 deletions src/peft/mapping.py
@@ -13,6 +13,10 @@
# See the License for the specific language governing permissions and
# limitations under the License.

+from __future__ import annotations
+
+from typing import TYPE_CHECKING, Any, Dict

from .peft_model import (
PeftModel,
PeftModelForCausalLM,
@@ -32,6 +36,12 @@
from .utils import PromptLearningConfig


+if TYPE_CHECKING:
+    from transformers import PreTrainedModel
+
+    from .utils.config import PeftConfig


MODEL_TYPE_TO_PEFT_MODEL_MAPPING = {
"SEQ_CLS": PeftModelForSequenceClassification,
"SEQ_2_SEQ_LM": PeftModelForSeq2SeqLM,
@@ -50,7 +60,7 @@
}


-def get_peft_config(config_dict):
+def get_peft_config(config_dict: Dict[str, Any]):
"""
Returns a Peft config object from a dictionary.
@@ -61,7 +71,7 @@ def get_peft_config(config_dict):
return PEFT_TYPE_TO_CONFIG_MAPPING[config_dict["peft_type"]](**config_dict)


-def _prepare_prompt_learning_config(peft_config, model_config):
+def _prepare_prompt_learning_config(peft_config: PeftConfig, model_config: Dict[str, Any]):
if peft_config.num_layers is None:
if "num_hidden_layers" in model_config:
num_layers = model_config["num_hidden_layers"]
@@ -103,7 +113,7 @@ def _prepare_prompt_learning_config(peft_config, model_config):
return peft_config


-def get_peft_model(model, peft_config, adapter_name="default") -> PeftModel:
+def get_peft_model(model: PreTrainedModel, peft_config: PeftConfig, adapter_name: str = "default") -> PeftModel:
"""
Returns a Peft model object from a model and a config.
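
With the new hints in place, `get_peft_config` advertises that it takes a plain `Dict[str, Any]`, and `get_peft_model` that it expects a `transformers.PreTrainedModel` plus a `PeftConfig`. A rough usage sketch; the checkpoint name and LoRA hyperparameters are illustrative, not taken from this diff:

from transformers import AutoModelForSequenceClassification
from peft import get_peft_config, get_peft_model

# Dict[str, Any] -> PeftConfig; "peft_type" picks the concrete config class (here LoraConfig).
config_dict = {"peft_type": "LORA", "task_type": "SEQ_CLS", "r": 8, "lora_alpha": 16}
peft_config = get_peft_config(config_dict)

base_model = AutoModelForSequenceClassification.from_pretrained("bert-base-uncased")
peft_model = get_peft_model(base_model, peft_config, adapter_name="default")
peft_model.print_trainable_parameters()  # only the adapter (and classifier) parameters are trainable
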
36 changes: 22 additions & 14 deletions src/peft/peft_model.py
@@ -13,12 +13,14 @@
# See the License for the specific language governing permissions and
# limitations under the License.

+from __future__ import annotations

import inspect
import os
import warnings
from contextlib import contextmanager
from copy import deepcopy
-from typing import Optional
+from typing import Any, Dict, Optional, Union

import torch
from accelerate import dispatch_model, infer_auto_device_map
@@ -94,7 +96,7 @@ class PeftModel(PushToHubMixin, torch.nn.Module):
in the base model if using [`PromptLearningConfig`].
"""

-def __init__(self, model, peft_config: PeftConfig, adapter_name="default"):
+def __init__(self, model: PreTrainedModel, peft_config: PeftConfig, adapter_name: str = "default"):
super().__init__()
self.base_model = model
self.config = self.base_model.config
@@ -115,7 +117,7 @@ def __init__(self, model, peft_config: PeftConfig, adapter_name="default"):
if getattr(model, "is_gradient_checkpointing", True):
model = self._prepare_model_for_gradient_checkpointing(model)

-def save_pretrained(self, save_directory, safe_serialization=False, **kwargs):
+def save_pretrained(self, save_directory: str, safe_serialization: bool = False, **kwargs: Any):
r"""
This function saves the adapter model and the adapter configuration files to a directory, so that it can be
reloaded using the [`LoraModel.from_pretrained`] class method, and also used by the [`LoraModel.push_to_hub`]
@@ -162,7 +164,13 @@ def save_pretrained(self, save_directory, safe_serialization=False, **kwargs):

@classmethod
def from_pretrained(
-    cls, model, model_id, adapter_name="default", is_trainable=False, config: Optional[PeftConfig] = None, **kwargs
+    cls,
+    model: PreTrainedModel,
+    model_id: Union[str, os.PathLike],
+    adapter_name: str = "default",
+    is_trainable: bool = False,
+    config: Optional[PeftConfig] = None,
+    **kwargs: Any,
):
r"""
Instantiate a [`LoraModel`] from a pretrained Lora configuration and weights.
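
The widened `model_id: Union[str, os.PathLike]` makes explicit that `from_pretrained` accepts either a hub repo id or a local directory. A hedged save-and-reload sketch against the same API; the checkpoint name, LoRA settings, and output directory are placeholders chosen for illustration:

import os

from transformers import AutoModelForCausalLM
from peft import LoraConfig, PeftModel, get_peft_model

base = AutoModelForCausalLM.from_pretrained("gpt2")                      # illustrative base model
peft_model = get_peft_model(base, LoraConfig(task_type="CAUSAL_LM", r=8))

save_dir = os.path.join("outputs", "gpt2-lora")                          # placeholder directory
peft_model.save_pretrained(save_dir, safe_serialization=False)           # adapter weights + adapter_config.json

# model_id may be a local path (as here) or a hub repo id.
fresh_base = AutoModelForCausalLM.from_pretrained("gpt2")
restored = PeftModel.from_pretrained(fresh_base, save_dir, adapter_name="default", is_trainable=False)
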
@@ -223,7 +231,7 @@ def from_pretrained(
model.load_adapter(model_id, adapter_name, is_trainable=is_trainable, **kwargs)
return model

-def _setup_prompt_encoder(self, adapter_name):
+def _setup_prompt_encoder(self, adapter_name: str):
config = self.peft_config[adapter_name]
self.prompt_encoder = torch.nn.ModuleDict({})
self.prompt_tokens = {}
@@ -258,7 +266,7 @@
config.num_virtual_tokens * config.num_transformer_submodules
).long()

-def _prepare_model_for_gradient_checkpointing(self, model):
+def _prepare_model_for_gradient_checkpointing(self, model: PreTrainedModel):
r"""
Prepares the model for gradient checkpointing if necessary
"""
@@ -273,7 +281,7 @@ def make_inputs_require_grad(module, input, output):
model.get_input_embeddings().register_forward_hook(make_inputs_require_grad)
return model

-def get_prompt_embedding_to_save(self, adapter_name):
+def get_prompt_embedding_to_save(self, adapter_name: str):
"""
Returns the prompt embedding to save when saving the model. Only applicable when `peft_config.peft_type !=
PeftType.LORA`.
@@ -287,7 +295,7 @@ def get_prompt_embedding_to_save(self, adapter_name):
prompt_embeddings = prompt_encoder(prompt_tokens)
return prompt_embeddings[0].detach().cpu()

-def get_prompt(self, batch_size):
+def get_prompt(self, batch_size: int):
"""
Returns the virtual prompts to use for Peft. Only applicable when `peft_config.peft_type != PeftType.LORA`.
"""
@@ -356,7 +364,7 @@ def __getattr__(self, name: str):
except AttributeError:
return getattr(self.base_model, name)

-def forward(self, *args, **kwargs):
+def forward(self, *args: Any, **kwargs: Any):
"""
Forward pass of the model.
"""
@@ -386,7 +394,7 @@ def get_base_model(self):
"""
return self.base_model if isinstance(self.active_peft_config, PromptLearningConfig) else self.base_model.model

-def add_adapter(self, adapter_name, peft_config):
+def add_adapter(self, adapter_name: str, peft_config: PeftConfig):
if peft_config.peft_type != self.peft_type:
raise ValueError(
f"Cannot combine adapters with different peft types. "
@@ -409,7 +417,7 @@ def set_additional_trainable_modules(self, peft_config, adapter_name):
_set_trainable(self, adapter_name)

@classmethod
-def _split_kwargs(cls, kwargs):
+def _split_kwargs(cls, kwargs: Dict[str, Any]):
hf_hub_download_kwargs = {}
other_kwargs = {}

@@ -421,7 +429,7 @@

return hf_hub_download_kwargs, other_kwargs

-def load_adapter(self, model_id, adapter_name, is_trainable=False, **kwargs):
+def load_adapter(self, model_id: str, adapter_name: str, is_trainable: bool = False, **kwargs: Any):
from .mapping import PEFT_TYPE_TO_CONFIG_MAPPING

hf_hub_download_kwargs, kwargs = self._split_kwargs(kwargs)
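
`load_adapter` is the workhorse behind `from_pretrained`, and it also lets several adapters share one base model, with `set_adapter` switching the active one. A hedged sketch; both adapter locations below are placeholders:

from transformers import AutoModelForCausalLM
from peft import PeftModel

base = AutoModelForCausalLM.from_pretrained("gpt2")  # illustrative base model

# The first adapter arrives via from_pretrained ...
model = PeftModel.from_pretrained(base, "path/to/adapter-a", adapter_name="adapter_a")

# ... further ones via load_adapter; kwargs it does not consume itself are split off
# by _split_kwargs and passed to huggingface_hub's hf_hub_download when fetching from the Hub.
model.load_adapter("path/to/adapter-b", adapter_name="adapter_b", is_trainable=False)

model.set_adapter("adapter_b")  # adapter_b is now the active adapter
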
@@ -535,7 +543,7 @@ def load_adapter(self, model_id, adapter_name, is_trainable=False, **kwargs):
self.eval()
return load_result

-def set_adapter(self, adapter_name):
+def set_adapter(self, adapter_name: str):
"""
Sets the active adapter.
"""
@@ -550,7 +558,7 @@
def active_peft_config(self):
return self.peft_config[self.active_adapter]

-def create_or_update_model_card(self, output_dir):
+def create_or_update_model_card(self, output_dir: str):
"""
Updates or create model card to include information about peft:
1. Adds `peft` library tag
