diff --git a/bindings/python/py_src/safetensors/paddle.py b/bindings/python/py_src/safetensors/paddle.py
index b242237a..cec36866 100644
--- a/bindings/python/py_src/safetensors/paddle.py
+++ b/bindings/python/py_src/safetensors/paddle.py
@@ -105,7 +105,7 @@ def load_file(filename: Union[str, os.PathLike], device="cpu") -> Dict[str, padd
     Args:
         filename (`str`, or `os.PathLike`)):
             The name of the file which contains the tensors
-        device (`Dict[str, any]`, *optional*, defaults to `cpu`):
+        device (`Union[Dict[str, any], str]`, *optional*, defaults to `cpu`):
             The device where the tensors need to be located after load.
             available options are all regular paddle device locations

diff --git a/bindings/python/py_src/safetensors/torch.py b/bindings/python/py_src/safetensors/torch.py
index 22915c98..5d98bac4 100644
--- a/bindings/python/py_src/safetensors/torch.py
+++ b/bindings/python/py_src/safetensors/torch.py
@@ -173,7 +173,7 @@ def save_model(
         raise ValueError(msg)


-def load_model(model: torch.nn.Module, filename: Union[str, os.PathLike], strict=True) -> Tuple[List[str], List[str]]:
+def load_model(model: torch.nn.Module, filename: Union[str, os.PathLike], strict: bool = True, device: Union[str, int] = "cpu") -> Tuple[List[str], List[str]]:
     """
     Loads a given filename onto a torch model.
     This method exists specifically to avoid tensor sharing issues which are
@@ -185,8 +185,11 @@ def load_model(model: torch.nn.Module, filename: Union[str, os.PathLike], strict
         filename (`str`, or `os.PathLike`):
             The filename location to load the file from.
         strict (`bool`, *optional*, defaults to True):
-            Wether to fail if you're missing keys or having unexpected ones
+            Whether to fail if you're missing keys or having unexpected ones.
             When false, the function simply returns missing and unexpected names.
+        device (`Union[str, int]`, *optional*, defaults to `cpu`):
+            The device where the tensors need to be located after load.
+            available options are all regular torch device locations.

     Returns:
         `(missing, unexpected): (List[str], List[str])`
@@ -194,7 +197,7 @@ def load_model(model: torch.nn.Module, filename: Union[str, os.PathLike], strict
            `unexpected` are names that are on the file, but weren't used during
            the load.
     """
-    state_dict = load_file(filename)
+    state_dict = load_file(filename, device=device)
     model_state_dict = model.state_dict()
     to_removes = _remove_duplicate_names(model_state_dict, preferred_names=state_dict.keys())
     missing, unexpected = model.load_state_dict(state_dict, strict=False)
@@ -281,16 +284,16 @@ def save_file(
     serialize_file(_flatten(tensors), filename, metadata=metadata)


-def load_file(filename: Union[str, os.PathLike], device="cpu") -> Dict[str, torch.Tensor]:
+def load_file(filename: Union[str, os.PathLike], device: Union[str, int] = "cpu") -> Dict[str, torch.Tensor]:
    """
    Loads a safetensors file into torch format.

    Args:
        filename (`str`, or `os.PathLike`):
            The name of the file which contains the tensors
-        device (`Dict[str, any]`, *optional*, defaults to `cpu`):
+        device (`Union[str, int]`, *optional*, defaults to `cpu`):
            The device where the tensors need to be located after load.
-            available options are all regular torch device locations
+            available options are all regular torch device locations.

    Returns:
        `Dict[str, torch.Tensor]`: dictionary that contains name as key, value as `torch.Tensor`
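
For context, the new `device` argument on `safetensors.torch.load_model` is simply forwarded to `load_file`, so the tensors can be materialized on the requested device instead of always landing on CPU before `load_state_dict`. A minimal usage sketch of the changed API (the tiny model and the checkpoint path are hypothetical, chosen only to illustrate the new argument):

```python
import torch
from safetensors.torch import load_model, save_model

# Hypothetical model and checkpoint path, purely for illustration.
model = torch.nn.Linear(4, 4)
save_model(model, "model.safetensors")

# New in this diff: load_model accepts a device (e.g. "cpu", "cuda:0", or an
# integer GPU index) and loads the tensors there directly.
missing, unexpected = load_model(model, "model.safetensors", strict=True, device="cpu")
assert not missing and not unexpected
```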