diff --git a/InstructorEmbedding/instructor.py b/InstructorEmbedding/instructor.py index 0ba795c..9d4f90a 100644 --- a/InstructorEmbedding/instructor.py +++ b/InstructorEmbedding/instructor.py @@ -12,6 +12,8 @@ from torch import Tensor, nn from tqdm.autonotebook import trange from transformers import AutoConfig, AutoTokenizer +from sentence_transformers.util import disabled_tqdm +from huggingface_hub import snapshot_download def batch_to_device(batch, target_device: str): @@ -515,10 +517,21 @@ def smart_batching_collate(self, batch): return batched_input_features, labels - def _load_sbert_model(self, model_path): + def _load_sbert_model(self, model_path, token = None, cache_folder = None, revision = None, trust_remote_code = False): """ Loads a full sentence-transformers model """ + # Taken mostly from: https://github.com/UKPLab/sentence-transformers/blob/66e0ee30843dd411c64f37f65447bb38c7bf857a/sentence_transformers/util.py#L544 + download_kwargs = { + "repo_id": model_path, + "revision": revision, + "library_name": "sentence-transformers", + "token": token, + "cache_dir": cache_folder, + "tqdm_class": disabled_tqdm, + } + model_path = snapshot_download(**download_kwargs) + # Check if the config_sentence_transformers.json file exists (exists since v2 of the framework) config_sentence_transformers_json_path = os.path.join( model_path, "config_sentence_transformers.json" diff --git a/requirements.txt b/requirements.txt index 05f8986..a7fe466 100644 --- a/requirements.txt +++ b/requirements.txt @@ -10,4 +10,5 @@ sentence_transformers>=2.2.0 torch tqdm rich -tensorboard \ No newline at end of file +tensorboard +huggingface-hub>=0.19.0