This repository has been archived by the owner on Nov 13, 2024. It is now read-only.
-
Notifications
You must be signed in to change notification settings - Fork 125
Commit
This commit does not belong to any branch on this repository, and may belong to a fork outside of the repository.
* add initial implementation of dense Cohere encoder * added cohere encoder test * add cohere api key to action envs * switch to english as default * [pyproject] Added cohere optional dependency So that users could use Canopy with Cohere embedding * [README] Added Cohere API key * Update action.yml and add --all-extras flag * Fix pinecone-text depedency Should be `^`, not hard coded single version. --------- Co-authored-by: DosticJelena <[email protected]> Co-authored-by: ilai <[email protected]> Co-authored-by: miararoy <[email protected]> Co-authored-by: igiloh-pinecone <[email protected]>
- Loading branch information
1 parent
230034b
commit 5815be7
Showing
7 changed files
with
116 additions
and
2 deletions.
There are no files selected for viewing
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -1,3 +1,4 @@ | ||
from .base import RecordEncoder | ||
from .cohere import CohereEncoder | ||
from .dense import DenseRecordEncoder | ||
from .openai import OpenAIRecordEncoder |
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,53 @@ | ||
from typing import List | ||
from pinecone_text.dense.cohere_encoder import CohereEncoder | ||
from canopy.knowledge_base.models import KBDocChunk, KBEncodedDocChunk, KBQuery | ||
from canopy.knowledge_base.record_encoder.dense import DenseRecordEncoder | ||
from canopy.models.data_models import Query | ||
|
||
|
||
class CohereRecordEncoder(DenseRecordEncoder): | ||
""" | ||
CohereRecordEncoder is a type of DenseRecordEncoder that uses the Cohere `embed` API. | ||
The implementation uses the `CohereEncoder` class from the `pinecone-text` library. | ||
For more information about see: https://github.com/pinecone-io/pinecone-text | ||
""" # noqa: E501 | ||
|
||
def __init__( | ||
self, | ||
*, | ||
model_name: str = "embed-english-v3.0", | ||
batch_size: int = 100, | ||
**kwargs, | ||
): | ||
""" | ||
Initialize the CohereRecordEncoder | ||
Args: | ||
model_name: The name of the Cohere embeddings model to use for encoding. See https://docs.cohere.com/reference/embed | ||
batch_size: The number of documents or queries to encode at once. | ||
Defaults to 400. | ||
**kwargs: Additional arguments to pass to the underlying `pinecone-text. CohereEncoder`. | ||
""" # noqa: E501 | ||
encoder = CohereEncoder(model_name, **kwargs) | ||
super().__init__(dense_encoder=encoder, batch_size=batch_size) | ||
|
||
def encode_documents(self, documents: List[KBDocChunk]) -> List[KBEncodedDocChunk]: | ||
""" | ||
Encode a list of documents, takes a list of KBDocChunk and returns a list of KBEncodedDocChunk. | ||
Args: | ||
documents: A list of KBDocChunk to encode. | ||
Returns: | ||
encoded chunks: A list of KBEncodedDocChunk, with the `values` field populated by the generated embeddings vector. | ||
""" # noqa: E501 | ||
return super().encode_documents(documents) | ||
|
||
async def _aencode_documents_batch( | ||
self, documents: List[KBDocChunk] | ||
) -> List[KBEncodedDocChunk]: | ||
raise NotImplementedError | ||
|
||
async def _aencode_queries_batch(self, queries: List[Query]) -> List[KBQuery]: | ||
raise NotImplementedError |
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,53 @@ | ||
import pytest | ||
|
||
from canopy.knowledge_base.models import KBDocChunk | ||
from canopy.knowledge_base.record_encoder.cohere import CohereRecordEncoder | ||
from canopy.models.data_models import Query | ||
|
||
|
||
documents = [KBDocChunk( | ||
id=f"doc_1_{i}", | ||
text=f"Sample document {i}", | ||
document_id=f"doc_{i}", | ||
metadata={"test": i}, | ||
source="doc_1", | ||
) | ||
for i in range(4) | ||
] | ||
|
||
queries = [Query(text="Sample query 1"), | ||
Query(text="Sample query 2"), | ||
Query(text="Sample query 3"), | ||
Query(text="Sample query 4")] | ||
|
||
|
||
@pytest.fixture | ||
def encoder(): | ||
return CohereRecordEncoder(batch_size=2) | ||
|
||
|
||
def test_dimension(encoder): | ||
assert encoder.dimension == 1024 | ||
|
||
|
||
@pytest.mark.parametrize("items,function", | ||
[(documents, "encode_documents"), | ||
(queries, "encode_queries"), | ||
([], "encode_documents"), | ||
([], "encode_queries")]) | ||
def test_encode_documents(encoder, items, function): | ||
|
||
encoded_documents = getattr(encoder, function)(items) | ||
|
||
assert len(encoded_documents) == len(items) | ||
assert all(len(encoded.values) == encoder.dimension | ||
for encoded in encoded_documents) | ||
|
||
|
||
@pytest.mark.asyncio | ||
@pytest.mark.parametrize("items,function", | ||
[("aencode_documents", documents), | ||
("aencode_queries", queries)]) | ||
async def test_aencode_not_implemented(encoder, function, items): | ||
with pytest.raises(NotImplementedError): | ||
await encoder.aencode_queries(items) |