Skip to content

Commit

Permalink
Merge pull request #112 from SilasMarvin/silas-update-for-newer-sente…
Browse files Browse the repository at this point in the history
…nce-transformers

Updated to work with the latest version of sentence transformers
  • Loading branch information
hongjin-su authored Apr 12, 2024
2 parents d92edb7 + 989b34d commit 5cca65e
Show file tree
Hide file tree
Showing 2 changed files with 16 additions and 2 deletions.
15 changes: 14 additions & 1 deletion InstructorEmbedding/instructor.py
Original file line number Diff line number Diff line change
Expand Up @@ -12,6 +12,8 @@
from torch import Tensor, nn
from tqdm.autonotebook import trange
from transformers import AutoConfig, AutoTokenizer
from sentence_transformers.util import disabled_tqdm
from huggingface_hub import snapshot_download


def batch_to_device(batch, target_device: str):
Expand Down Expand Up @@ -515,10 +517,21 @@ def smart_batching_collate(self, batch):

return batched_input_features, labels

def _load_sbert_model(self, model_path):
def _load_sbert_model(self, model_path, token = None, cache_folder = None, revision = None, trust_remote_code = False):
"""
Loads a full sentence-transformers model
"""
# Taken mostly from: https://github.com/UKPLab/sentence-transformers/blob/66e0ee30843dd411c64f37f65447bb38c7bf857a/sentence_transformers/util.py#L544
download_kwargs = {
"repo_id": model_path,
"revision": revision,
"library_name": "sentence-transformers",
"token": token,
"cache_dir": cache_folder,
"tqdm_class": disabled_tqdm,
}
model_path = snapshot_download(**download_kwargs)

# Check if the config_sentence_transformers.json file exists (exists since v2 of the framework)
config_sentence_transformers_json_path = os.path.join(
model_path, "config_sentence_transformers.json"
Expand Down
3 changes: 2 additions & 1 deletion requirements.txt
Original file line number Diff line number Diff line change
Expand Up @@ -10,4 +10,5 @@ sentence_transformers>=2.2.0
torch
tqdm
rich
tensorboard
tensorboard
huggingface-hub>=0.19.0

0 comments on commit 5cca65e

Please sign in to comment.