This repository has been archived by the owner on Nov 13, 2024. It is now read-only.
-
Notifications
You must be signed in to change notification settings - Fork 125
Commit
This commit does not belong to any branch on this repository, and may belong to a fork outside of the repository.
Add
SentenceTransformersRecordEncoder
(#263)
* Add SentenceTransformerRecordEncoder based on pinecone_text * Add tests based on the Jina tests * Satisfy linter * Correct an import * Make query_encoder_name & device arguments explicit * Add "defaults to" docstring for optional query_encoder_name * Skipping SentenceTransformer system tests * Wrap __init__ in try-except RepositoryNotFoundError * [pyproj] Added 'huggingface_hub.utils' to mypy ignore list It doesn't seem that huggingface_hub is providing stubs at the moment * [pyproj] Minor bug after conflict fixing * [test] Skip SentenceTransformerRecordEncoder UT and ST if not installed The SentenceTransformerRecordEncoder depends on extra dependencies. It should only be tested if those extra dependencies are installed * linter --------- Co-authored-by: ilai <[email protected]>
- Loading branch information
1 parent
289eef2
commit 0ebb2e1
Showing
5 changed files
with
195 additions
and
1 deletion.
There are no files selected for viewing
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
57 changes: 57 additions & 0 deletions
57
src/canopy/knowledge_base/record_encoder/sentence_transformers.py
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,57 @@ | ||
from typing import Optional | ||
from pinecone_text.dense import SentenceTransformerEncoder | ||
from canopy.knowledge_base.record_encoder.dense import DenseRecordEncoder | ||
from huggingface_hub.utils import RepositoryNotFoundError | ||
|
||
|
||
class SentenceTransformerRecordEncoder(DenseRecordEncoder):
    """
    SentenceTransformerRecordEncoder is a type of DenseRecordEncoder that uses a Sentence Transformer model.
    The implementation uses the `SentenceTransformerEncoder` class from the `pinecone-text` library.
    For more information see: https://github.com/pinecone-io/pinecone-text
    """  # noqa: E501

    def __init__(self,
                 *,
                 model_name: str = "sentence-transformers/all-MiniLM-L6-v2",
                 query_encoder_name: Optional[str] = None,
                 batch_size: int = 400,
                 device: Optional[str] = None,
                 **kwargs) -> None:
        """
        Initialize the SentenceTransformerRecordEncoder

        Args:
            model_name: The name of the embedding model to use for encoding documents.
                        See https://huggingface.co/models?library=sentence-transformers
                        for all possible Sentence Transformer models.
            query_encoder_name: The name of the embedding model to use for encoding queries.
                        See https://huggingface.co/models?library=sentence-transformers
                        for all possible Sentence Transformer models.
                        Defaults to `model_name`.
            batch_size: The number of documents or queries to encode at once.
                        Defaults to 400.
            device: The local device to use for encoding, for example "cpu", "cuda" or "mps".
                        Defaults to "cuda" if cuda is available, otherwise to "cpu".
            **kwargs: Additional arguments to pass to the underlying `pinecone-text.SentenceTransformerEncoder`.

        Raises:
            RuntimeError: If the underlying encoder raises `RepositoryNotFoundError`
                        for the requested model(s).
            ImportError: If the underlying encoder raises `ImportError`, which is
                        reported as missing `torch` / `transformers` extras.
        """  # noqa: E501
        try:
            encoder = SentenceTransformerEncoder(
                document_encoder_name=model_name,
                query_encoder_name=query_encoder_name,
                device=device,
                **kwargs,
            )
        except RepositoryNotFoundError as e:
            # Surface an unknown model name as a readable error while keeping
            # the original exception attached as the cause.
            raise RuntimeError(
                "Your chosen Sentence Transformer model(s) could not be found. "
                f"Details: {e}"
            ) from e
        except ImportError as e:
            # Re-raise with installation instructions; chain explicitly so the
            # original import failure stays visible as the direct cause.
            raise ImportError(
                f"{self.__class__.__name__} requires the `torch` and `transformers` "
                "extra dependencies. Please install them using "
                "`pip install canopy-sdk[torch,transformers]`."
            ) from e
        super().__init__(dense_encoder=encoder, batch_size=batch_size)
61 changes: 61 additions & 0 deletions
61
tests/system/record_encoder/test_sentence_transformers_encoder.py
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,61 @@ | ||
import pytest | ||
|
||
from canopy.knowledge_base.models import KBDocChunk | ||
from canopy.knowledge_base.record_encoder.sentence_transformers import ( | ||
SentenceTransformerRecordEncoder | ||
) | ||
from canopy.models.data_models import Query | ||
|
||
# Four sample chunks shared by the parametrized encode tests below.
documents = [
    KBDocChunk(
        id=f"doc_1_{i}",
        text=f"Sample document {i}",
        document_id=f"doc_{i}",
        metadata={"test": i},
        source="doc_1",
    )
    for i in range(4)
]

# Four sample queries shared by the parametrized encode tests below.
queries = [
    Query(text=query_text)
    for query_text in (
        "Sample query 1",
        "Sample query 2",
        "Sample query 3",
        "Sample query 4",
    )
]
|
||
|
||
@pytest.fixture
def encoder():
    """Provide a SentenceTransformerRecordEncoder with batch_size=2.

    Skips the requesting test when constructing the encoder raises ImportError.
    """
    try:
        return SentenceTransformerRecordEncoder(batch_size=2)
    except ImportError:
        # pytest.skip raises, so nothing is returned on this path.
        pytest.skip(
            "`transformers` extra not installed. Skipping SentenceTransformer system "
            "tests"
        )
|
||
|
||
def test_dimension(encoder):
    """The encoder built by the fixture should report a dimension of 384."""
    expected_dimension = 384
    assert encoder.dimension == expected_dimension
|
||
|
||
@pytest.mark.parametrize(
    "items,function",
    [
        (documents, "encode_documents"),
        (queries, "encode_queries"),
        ([], "encode_documents"),
        ([], "encode_queries"),
    ],
)
def test_encode_documents(encoder, items, function):
    """Encoding returns one result per input, each sized to the encoder dimension."""
    encode = getattr(encoder, function)
    encoded_items = encode(items)

    assert len(encoded_items) == len(items)
    for encoded in encoded_items:
        assert len(encoded.values) == encoder.dimension
|
||
|
||
@pytest.mark.asyncio
@pytest.mark.parametrize("function,items",
                         [("aencode_documents", documents),
                          ("aencode_queries", queries)])
async def test_aencode_not_implemented(encoder, function, items):
    """Each async encode method named by `function` should raise NotImplementedError.

    The parameter names are listed in the same order as the tuples
    (method name first, inputs second), and the method is looked up by name
    so that both `aencode_documents` and `aencode_queries` are exercised
    with their matching inputs.
    """
    with pytest.raises(NotImplementedError):
        await getattr(encoder, function)(items)
74 changes: 74 additions & 0 deletions
74
tests/unit/record_encoder/test_sentence_transformers_encoder.py
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,74 @@ | ||
import pytest | ||
|
||
from canopy.knowledge_base.models import KBDocChunk | ||
from canopy.knowledge_base.record_encoder.sentence_transformers import ( | ||
SentenceTransformerRecordEncoder | ||
) | ||
from canopy.models.data_models import Query | ||
|
||
from unittest.mock import patch | ||
|
||
# Four sample chunks shared by the parametrized encode tests below.
documents = [
    KBDocChunk(
        id=f"doc_1_{i}",
        text=f"Sample document {i}",
        document_id=f"doc_{i}",
        metadata={"test": i},
        source="doc_1",
    )
    for i in range(4)
]

# Four sample queries shared by the parametrized encode tests below.
queries = [
    Query(text=query_text)
    for query_text in (
        "Sample query 1",
        "Sample query 2",
        "Sample query 3",
        "Sample query 4",
    )
]
|
||
|
||
@pytest.fixture
def encoder():
    """Provide a SentenceTransformerRecordEncoder with batch_size=2.

    Skips the requesting test when constructing the encoder raises ImportError.
    """
    try:
        return SentenceTransformerRecordEncoder(batch_size=2)
    except ImportError:
        # pytest.skip raises, so nothing is returned on this path.
        pytest.skip(
            "`transformers` extra not installed. Skipping SentenceTransformer unit "
            "tests"
        )
|
||
|
||
def test_dimension(encoder):
    """With encode_documents mocked to return a 3-element vector, dimension is 3."""
    patch_target = 'pinecone_text.dense.SentenceTransformerEncoder.encode_documents'
    with patch(patch_target) as mock_encode_documents:
        mock_encode_documents.return_value = [[0.1, 0.2, 0.3]]
        assert encoder.dimension == 3
|
||
|
||
def custom_encode(*args, **kwargs):
    """Return one fixed 3-element vector per item of the first positional argument.

    Used as a `side_effect` stand-in for the patched encoder methods; any
    further positional or keyword arguments are ignored.
    """
    vectors = []
    for _ in args[0]:
        # Build a fresh list per item so the results never alias each other.
        vectors.append([0.1, 0.2, 0.3])
    return vectors
|
||
|
||
@pytest.mark.parametrize(
    "items,function",
    [
        (documents, "encode_documents"),
        (queries, "encode_queries"),
        ([], "encode_documents"),
        ([], "encode_queries"),
    ],
)
def test_encode_documents(encoder, items, function):
    """With both encoder methods mocked, encoding yields one sized result per input."""
    documents_patch = patch(
        'pinecone_text.dense.SentenceTransformerEncoder.encode_documents',
        side_effect=custom_encode,
    )
    queries_patch = patch(
        'pinecone_text.dense.SentenceTransformerEncoder.encode_queries',
        side_effect=custom_encode,
    )
    with documents_patch, queries_patch:
        encode = getattr(encoder, function)
        encoded_items = encode(items)

        # Checked while the patches are active so that encoder.dimension is
        # evaluated against the mocked encoder as well.
        assert len(encoded_items) == len(items)
        for encoded in encoded_items:
            assert len(encoded.values) == encoder.dimension
|
||
|
||
@pytest.mark.asyncio
@pytest.mark.parametrize("function,items",
                         [("aencode_documents", documents),
                          ("aencode_queries", queries)])
async def test_aencode_not_implemented(encoder, function, items):
    """Each async encode method named by `function` should raise NotImplementedError.

    The parameter names are listed in the same order as the tuples
    (method name first, inputs second), and the method is looked up by name
    so that both `aencode_documents` and `aencode_queries` are exercised
    with their matching inputs.
    """
    with pytest.raises(NotImplementedError):
        await getattr(encoder, function)(items)