Skip to content
This repository has been archived by the owner on Nov 13, 2024. It is now read-only.

Commit

Permalink
Add SentenceTransformersRecordEncoder (#263)
Browse files Browse the repository at this point in the history
* Add SentenceTransformerRecordEncoder based on pinecone_text

* Add tests based on the Jina tests

* Satisfy linter

* Correct an import

* Make query_encoder_name & device arguments explicit

* Add "defaults to" docstring for optional query_encoder_name

* Skipping SentenceTransformer system tests

* Wrap __init__ in try-except RepositoryNotFoundError

* [pyproj] Added 'huggingface_hub.utils' to mypy ignore list

It doesn't seem that huggingface_hub is providing stubs at the moment

* [pyproj] Minor bug after conflict fixing

* [test] Skip SentenceTransformerRecordEncoder UT and ST if not installed

The SentenceTransformerRecordEncoder depends on extra dependencies. It should only be tested if those extra dependencies are installed

* linter

---------

Co-authored-by: ilai <[email protected]>
  • Loading branch information
tomaarsen and igiloh-pinecone authored Feb 7, 2024
1 parent 289eef2 commit 0ebb2e1
Show file tree
Hide file tree
Showing 5 changed files with 195 additions and 1 deletion.
3 changes: 2 additions & 1 deletion pyproject.toml
Original file line number Diff line number Diff line change
Expand Up @@ -78,7 +78,8 @@ module = [
'pinecone',
'transformers.*',
'cohere.*',
'pinecone.grpc'
'pinecone.grpc',
'huggingface_hub.utils'
]
ignore_missing_imports = true

Expand Down
1 change: 1 addition & 0 deletions src/canopy/knowledge_base/record_encoder/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -5,4 +5,5 @@
from .anyscale import AnyscaleRecordEncoder
from .azure_openai import AzureOpenAIRecordEncoder
from .jina import JinaRecordEncoder
from .sentence_transformers import SentenceTransformerRecordEncoder
from .hybrid import HybridRecordEncoder
57 changes: 57 additions & 0 deletions src/canopy/knowledge_base/record_encoder/sentence_transformers.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,57 @@
from typing import Optional
from pinecone_text.dense import SentenceTransformerEncoder
from canopy.knowledge_base.record_encoder.dense import DenseRecordEncoder
from huggingface_hub.utils import RepositoryNotFoundError


class SentenceTransformerRecordEncoder(DenseRecordEncoder):
    """
    SentenceTransformerRecordEncoder is a type of DenseRecordEncoder that uses a Sentence Transformer model.
    The implementation uses the `SentenceTransformerEncoder` class from the `pinecone-text` library.
    For more information, see: https://github.com/pinecone-io/pinecone-text
    """  # noqa: E501

    def __init__(self,
                 *,
                 model_name: str = "sentence-transformers/all-MiniLM-L6-v2",
                 query_encoder_name: Optional[str] = None,
                 batch_size: int = 400,
                 device: Optional[str] = None,
                 **kwargs) -> None:
        """
        Initialize the SentenceTransformerRecordEncoder

        Args:
            model_name: The name of the embedding model to use for encoding documents.
                        See https://huggingface.co/models?library=sentence-transformers
                        for all possible Sentence Transformer models.
            query_encoder_name: The name of the embedding model to use for encoding queries.
                        See https://huggingface.co/models?library=sentence-transformers
                        for all possible Sentence Transformer models.
                        Defaults to `model_name`.
            batch_size: The number of documents or queries to encode at once.
                        Defaults to 400.
            device: The local device to use for encoding, for example "cpu", "cuda" or "mps".
                    Defaults to "cuda" if cuda is available, otherwise to "cpu".
            **kwargs: Additional arguments to pass to the underlying `pinecone-text.SentenceTransformerEncoder`.

        Raises:
            RuntimeError: If the chosen model name(s) cannot be found on the Hugging Face Hub.
            ImportError: If the `torch`/`transformers` extra dependencies are not installed.
        """  # noqa: E501
        try:
            encoder = SentenceTransformerEncoder(
                document_encoder_name=model_name,
                query_encoder_name=query_encoder_name,
                device=device,
                **kwargs,
            )
        except RepositoryNotFoundError as e:
            raise RuntimeError(
                "Your chosen Sentence Transformer model(s) could not be found. "
                f"Details: {str(e)}"
            ) from e
        except ImportError as e:
            # Chain the original error so the underlying missing-module
            # details are not lost from the traceback.
            raise ImportError(
                f"{self.__class__.__name__} requires the `torch` and `transformers` "
                "extra dependencies. Please install them using "
                "`pip install canopy-sdk[torch,transformers]`."
            ) from e
        super().__init__(dense_encoder=encoder, batch_size=batch_size)
61 changes: 61 additions & 0 deletions tests/system/record_encoder/test_sentence_transformers_encoder.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,61 @@
import pytest

from canopy.knowledge_base.models import KBDocChunk
from canopy.knowledge_base.record_encoder.sentence_transformers import (
SentenceTransformerRecordEncoder
)
from canopy.models.data_models import Query

# Fixture data: four sample chunks and four sample queries.
documents = [
    KBDocChunk(
        id=f"doc_1_{i}",
        text=f"Sample document {i}",
        document_id=f"doc_{i}",
        metadata={"test": i},
        source="doc_1",
    )
    for i in range(4)
]

queries = [Query(text=f"Sample query {n}") for n in range(1, 5)]


@pytest.fixture
def encoder():
    """Build the encoder, skipping the suite when the extra deps are absent."""
    try:
        return SentenceTransformerRecordEncoder(batch_size=2)
    except ImportError:
        pytest.skip(
            "`transformers` extra not installed. Skipping SentenceTransformer system "
            "tests"
        )


def test_dimension(encoder):
    # all-MiniLM-L6-v2 produces 384-dimensional embeddings.
    expected_dimension = 384
    assert encoder.dimension == expected_dimension


@pytest.mark.parametrize("items,function",
                         [(documents, "encode_documents"),
                          (queries, "encode_queries"),
                          ([], "encode_documents"),
                          ([], "encode_queries")])
def test_encode_documents(encoder, items, function):
    # Exercise both encoding entry points, including the empty-input case.
    encode = getattr(encoder, function)
    results = encode(items)

    assert len(results) == len(items)
    for encoded in results:
        assert len(encoded.values) == encoder.dimension


@pytest.mark.asyncio
@pytest.mark.parametrize("items,function",
                         [(documents, "aencode_documents"),
                          (queries, "aencode_queries")])
async def test_aencode_not_implemented(encoder, function, items):
    # Both async encode methods should raise until an async implementation
    # is provided. The parametrized `function` selects which one to call;
    # previously the tuples were ordered (function, items) — mismatching the
    # "items,function" names — and the body hardcoded `aencode_queries`, so
    # `aencode_documents` was never exercised.
    with pytest.raises(NotImplementedError):
        await getattr(encoder, function)(items)
74 changes: 74 additions & 0 deletions tests/unit/record_encoder/test_sentence_transformers_encoder.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,74 @@
import pytest

from canopy.knowledge_base.models import KBDocChunk
from canopy.knowledge_base.record_encoder.sentence_transformers import (
SentenceTransformerRecordEncoder
)
from canopy.models.data_models import Query

from unittest.mock import patch

# Fixture data: four sample chunks and four sample queries.
documents = [
    KBDocChunk(
        id=f"doc_1_{i}",
        text=f"Sample document {i}",
        document_id=f"doc_{i}",
        metadata={"test": i},
        source="doc_1",
    )
    for i in range(4)
]

queries = [Query(text=f"Sample query {n}") for n in range(1, 5)]


@pytest.fixture
def encoder():
    """Build the encoder, skipping the suite when the extra deps are absent."""
    try:
        return SentenceTransformerRecordEncoder(batch_size=2)
    except ImportError:
        pytest.skip(
            "`transformers` extra not installed. Skipping SentenceTransformer unit "
            "tests"
        )


def test_dimension(encoder):
    # Mock the underlying document encoder so the dimension is derived
    # from a fake 3-dim vector instead of a real model call.
    target = 'pinecone_text.dense.SentenceTransformerEncoder.encode_documents'
    with patch(target) as mock_encode_documents:
        mock_encode_documents.return_value = [[0.1, 0.2, 0.3]]
        assert encoder.dimension == 3


def custom_encode(*args, **kwargs):
    """Fake encoder side-effect: one fixed 3-dim vector per input item."""
    items = args[0]
    return [[0.1, 0.2, 0.3] for _ in items]


@pytest.mark.parametrize("items,function",
                         [(documents, "encode_documents"),
                          (queries, "encode_queries"),
                          ([], "encode_documents"),
                          ([], "encode_queries")])
def test_encode_documents(encoder, items, function):
    # Patch both underlying encode methods so no real model is invoked.
    with patch('pinecone_text.dense.SentenceTransformerEncoder.encode_documents',
               side_effect=custom_encode), \
         patch('pinecone_text.dense.SentenceTransformerEncoder.encode_queries',
               side_effect=custom_encode):
        results = getattr(encoder, function)(items)

        assert len(results) == len(items)
        for encoded in results:
            assert len(encoded.values) == encoder.dimension


@pytest.mark.asyncio
@pytest.mark.parametrize("items,function",
                         [(documents, "aencode_documents"),
                          (queries, "aencode_queries")])
async def test_aencode_not_implemented(encoder, function, items):
    # Both async encode methods should raise until an async implementation
    # is provided. The parametrized `function` selects which one to call;
    # previously the tuples were ordered (function, items) — mismatching the
    # "items,function" names — and the body hardcoded `aencode_queries`, so
    # `aencode_documents` was never exercised.
    with pytest.raises(NotImplementedError):
        await getattr(encoder, function)(items)

0 comments on commit 0ebb2e1

Please sign in to comment.