Skip to content

Commit

Permalink
better hub kwargs management (huggingface#712)
Browse files Browse the repository at this point in the history
  • Loading branch information
younesbelkada authored Jul 17, 2023
1 parent 71b326d commit e90dcc4
Show file tree
Hide file tree
Showing 2 changed files with 17 additions and 17 deletions.
24 changes: 11 additions & 13 deletions src/peft/peft_model.py
Original file line number Diff line number Diff line change
Expand Up @@ -246,8 +246,9 @@ def from_pretrained(
subfolder=kwargs.get("subfolder", None),
revision=kwargs.get("revision", None),
cache_dir=kwargs.get("cache_dir", None),
use_auth_token=kwargs.get("use_auth_token", None),
)
].from_pretrained(model_id, subfolder=kwargs.get("subfolder", None), **kwargs)
].from_pretrained(model_id, **kwargs)
elif isinstance(config, PeftConfig):
config.inference_mode = not is_trainable
else:
Expand Down Expand Up @@ -481,11 +482,12 @@ def set_additional_trainable_modules(self, peft_config, adapter_name):

@classmethod
def _split_kwargs(cls, kwargs: Dict[str, Any]):
_kwargs_not_in_hf_hub_download_signature = ("use_auth_token",)
hf_hub_download_kwargs = {}
other_kwargs = {}

for key, value in kwargs.items():
if key in inspect.signature(hf_hub_download).parameters:
if key in inspect.signature(hf_hub_download).parameters or key in _kwargs_not_in_hf_hub_download_signature:
hf_hub_download_kwargs[key] = value
else:
other_kwargs[key] = value
Expand All @@ -502,15 +504,11 @@ def load_adapter(self, model_id: str, adapter_name: str, is_trainable: bool = Fa
peft_config = PEFT_TYPE_TO_CONFIG_MAPPING[
PeftConfig._get_peft_type(
model_id,
subfolder=kwargs.get("subfolder", None),
revision=kwargs.get("revision", None),
cache_dir=kwargs.get("cache_dir", None),
**hf_hub_download_kwargs,
)
].from_pretrained(
model_id,
subfolder=kwargs.get("subfolder", None),
revision=kwargs.get("revision", None),
cache_dir=kwargs.get("cache_dir", None),
**hf_hub_download_kwargs,
)
if isinstance(peft_config, PromptLearningConfig) and is_trainable:
raise ValueError("Cannot set a prompt learning adapter to trainable when loading pretrained adapter.")
Expand All @@ -529,7 +527,10 @@ def load_adapter(self, model_id: str, adapter_name: str, is_trainable: bool = Fa
use_safetensors = False
else:
has_remote_safetensors_file = hub_file_exists(
model_id, SAFETENSORS_WEIGHTS_NAME, revision=kwargs.get("revision", None)
model_id,
SAFETENSORS_WEIGHTS_NAME,
revision=hf_hub_download_kwargs.get("revision", None),
repo_type=hf_hub_download_kwargs.get("repo_type", None),
)
use_safetensors = has_remote_safetensors_file

Expand All @@ -538,14 +539,11 @@ def load_adapter(self, model_id: str, adapter_name: str, is_trainable: bool = Fa
filename = hf_hub_download(
model_id,
SAFETENSORS_WEIGHTS_NAME,
subfolder=kwargs.get("subfolder", None),
**hf_hub_download_kwargs,
)
else:
try:
filename = hf_hub_download(
model_id, WEIGHTS_NAME, subfolder=kwargs.get("subfolder", None), **hf_hub_download_kwargs
)
filename = hf_hub_download(model_id, WEIGHTS_NAME, **hf_hub_download_kwargs)
except EntryNotFoundError:
raise ValueError(
f"Can't find weights for {model_id} in {model_id} or in the Hugging Face Hub. "
Expand Down
10 changes: 6 additions & 4 deletions src/peft/utils/config.py
Original file line number Diff line number Diff line change
Expand Up @@ -164,18 +164,20 @@ def _split_kwargs(cls, kwargs):
def _get_peft_type(
cls,
model_id,
subfolder: Optional[str] = None,
revision: Optional[str] = None,
cache_dir: Optional[str] = None,
**hf_hub_download_kwargs,
):
subfolder = hf_hub_download_kwargs.get("subfolder", None)

path = os.path.join(model_id, subfolder) if subfolder is not None else model_id

if os.path.isfile(os.path.join(path, CONFIG_NAME)):
config_file = os.path.join(path, CONFIG_NAME)
else:
try:
config_file = hf_hub_download(
model_id, CONFIG_NAME, subfolder=subfolder, revision=revision, cache_dir=cache_dir
model_id,
CONFIG_NAME,
**hf_hub_download_kwargs,
)
except Exception:
raise ValueError(f"Can't find '{CONFIG_NAME}' at '{model_id}'")
Expand Down

0 comments on commit e90dcc4

Please sign in to comment.