Commit: little improvements
anakin87 committed Oct 2, 2023
1 parent 50128e6 commit 217067c
Showing 4 changed files with 41 additions and 9 deletions.
File 1 of 4 (InstructorDocumentEmbedder)
@@ -17,7 +17,7 @@ def __init__(
         model_name_or_path: str = "hkunlp/instructor-base",
         device: Optional[str] = None,
         use_auth_token: Union[bool, str, None] = None,
-        instruction: str = "Represent the 'domain' 'text_type' for 'task_objective'",
+        instruction: str = "Represent the document",
         batch_size: int = 32,
         progress_bar: bool = True,
         normalize_embeddings: bool = False,
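For orientation (not part of the commit): a minimal usage sketch of the new document-embedder default. The import paths below are assumptions, and the 768-dimensional output matches the instructor-base checks added in the tests further down.

# Minimal sketch, not part of this commit. Both import paths are assumed;
# adjust them to wherever the package exposes these classes.
from instructor_embedders import InstructorDocumentEmbedder  # assumed import path
from haystack.preview import Document  # assumed import path

# Leaving `instruction` unset now embeds documents with the plain default
# "Represent the document" instead of the old placeholder template.
embedder = InstructorDocumentEmbedder(model_name_or_path="hkunlp/instructor-base", device="cpu")
embedder.warm_up()

result = embedder.run(documents=[Document(text="Parton energy loss in QCD matter")])
print(len(result["documents"][0].embedding))  # 768 for instructor-base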
File 2 of 4 (InstructorTextEmbedder)
@@ -16,7 +16,7 @@ def __init__(
         model_name_or_path: str = "hkunlp/instructor-base",
         device: Optional[str] = None,
         use_auth_token: Union[bool, str, None] = None,
-        instruction: str = "Represent the 'domain' 'text_type' for 'task_objective'",
+        instruction: str = "Represent the sentence",
         batch_size: int = 32,
         progress_bar: bool = True,
         normalize_embeddings: bool = False,
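A companion sketch for the text embedder (again not part of the commit, import path assumed), showing that a domain-specific instruction can still be passed explicitly to override the new default "Represent the sentence", as the integration tests below do.

# Minimal sketch, not part of this commit; import path assumed.
from instructor_embedders import InstructorTextEmbedder  # assumed import path

embedder = InstructorTextEmbedder(
    model_name_or_path="hkunlp/instructor-base",
    device="cpu",
    instruction="Represent the Science sentence for retrieval",  # overrides the default
)
embedder.warm_up()

result = embedder.run(text="Parton energy loss in QCD matter")
embedding = result["embedding"]  # a list of 768 floats for instructor-base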
File 3 of 4: components/instructor-embedders/tests/test_instructor_document_embedder.py
@@ -17,7 +17,7 @@ def test_init_default(self):
         assert embedder.model_name_or_path == "hkunlp/instructor-base"
         assert embedder.device == "cpu"
         assert embedder.use_auth_token is None
-        assert embedder.instruction == "Represent the 'domain' 'text_type' for 'task_objective'"
+        assert embedder.instruction == "Represent the document"
         assert embedder.batch_size == 32
         assert embedder.progress_bar is True
         assert embedder.normalize_embeddings is False
@@ -63,7 +63,7 @@ def test_to_dict(self):
                 "model_name_or_path": "hkunlp/instructor-base",
                 "device": "cpu",
                 "use_auth_token": None,
-                "instruction": "Represent the 'domain' 'text_type' for 'task_objective'",
+                "instruction": "Represent the document",
                 "batch_size": 32,
                 "progress_bar": True,
                 "normalize_embeddings": False,
@@ -194,7 +194,7 @@ def test_embed(self):
         """
         Test for checking output dimensions and embedding dimensions.
         """
-        embedder = InstructorDocumentEmbedder(model_name_or_path="hkunlp/instructor-base")
+        embedder = InstructorDocumentEmbedder(model_name_or_path="hkunlp/instructor-large")
         embedder.embedding_backend = MagicMock()
         embedder.embedding_backend.embed = lambda x, **kwargs: np.random.rand(len(x), 16).tolist()  # noqa: ARG005

@@ -257,3 +257,19 @@ def test_embed_metadata(self):
             show_progress_bar=True,
             normalize_embeddings=False,
         )
+
+    @pytest.mark.integration
+    def test_run(self):
+        embedder = InstructorDocumentEmbedder(model_name_or_path="hkunlp/instructor-base",
+                                              device="cpu",
+                                              instruction="Represent the Science document for retrieval")
+        embedder.warm_up()
+
+        doc = Document(text="Parton energy loss in QCD matter")
+
+        result = embedder.run(documents=[doc])
+        embedding = result['documents'][0].embedding
+
+        assert isinstance(embedding, list)
+        assert len(embedding) == 768
+        assert all(isinstance(emb, float) for emb in embedding)

CI annotation (GitHub Actions / test), Ruff Q000 at components/instructor-embedders/tests/test_instructor_document_embedder.py:271:28: Single quotes found but double quotes preferred.
File 4 of 4 (InstructorTextEmbedder tests)
@@ -16,7 +16,7 @@ def test_init_default(self):
         assert embedder.model_name_or_path == "hkunlp/instructor-base"
         assert embedder.device == "cpu"
         assert embedder.use_auth_token is None
-        assert embedder.instruction == "Represent the 'domain' 'text_type' for 'task_objective'"
+        assert embedder.instruction == "Represent the sentence"
         assert embedder.batch_size == 32
         assert embedder.progress_bar is True
         assert embedder.normalize_embeddings is False
@@ -56,7 +56,7 @@ def test_to_dict(self):
                 "model_name_or_path": "hkunlp/instructor-base",
                 "device": "cpu",
                 "use_auth_token": None,
-                "instruction": "Represent the 'domain' 'text_type' for 'task_objective'",
+                "instruction": "Represent the sentence",
                 "batch_size": 32,
                 "progress_bar": True,
                 "normalize_embeddings": False,
@@ -173,7 +173,7 @@ def test_embed(self):
         """
         Test for checking output dimensions and embedding dimensions.
         """
-        embedder = InstructorTextEmbedder(model_name_or_path="hkunlp/instructor-base")
+        embedder = InstructorTextEmbedder(model_name_or_path="hkunlp/instructor-large")
         embedder.embedding_backend = MagicMock()
         embedder.embedding_backend.embed = lambda x, **kwargs: np.random.rand(len(x), 16).tolist()  # noqa: ARG005

@@ -190,10 +190,26 @@ def test_run_wrong_incorrect_format(self):
         """
         Test for checking incorrect input format when creating embedding.
         """
-        embedder = InstructorTextEmbedder(model_name_or_path="hkunlp/instructor-base")
+        embedder = InstructorTextEmbedder(model_name_or_path="hkunlp/instructor-large")
         embedder.embedding_backend = MagicMock()

         list_integers_input = [1, 2, 3]

         with pytest.raises(TypeError, match="InstructorTextEmbedder expects a string as input"):
             embedder.run(text=list_integers_input)
+
+    @pytest.mark.integration
+    def test_run(self):
+        embedder = InstructorTextEmbedder(model_name_or_path="hkunlp/instructor-base",
+                                          device="cpu",
+                                          instruction="Represent the Science sentence for retrieval")
+        embedder.warm_up()
+
+        text = "Parton energy loss in QCD matter"
+
+        result = embedder.run(text=text)
+        embedding = result["embedding"]
+
+        assert isinstance(embedding, list)
+        assert len(embedding) == 768
+        assert all(isinstance(emb, float) for emb in embedding)
