From 695de8eb2fa4f6cf24be96b667b80d8be7ea7f60 Mon Sep 17 00:00:00 2001 From: Anushree Bannadabhavi Date: Sun, 31 Mar 2024 16:38:58 -0400 Subject: [PATCH 1/4] Add cohere ranker --- haystack/components/rankers/__init__.py | 2 + haystack/components/rankers/cohere.py | 149 +++++++++ pyproject.toml | 1 + .../add-cohere-ranker-5e94f8e771916150.yaml | 5 + test/components/rankers/test_cohere.py | 314 ++++++++++++++++++ 5 files changed, 471 insertions(+) create mode 100644 haystack/components/rankers/cohere.py create mode 100644 releasenotes/notes/add-cohere-ranker-5e94f8e771916150.yaml create mode 100644 test/components/rankers/test_cohere.py diff --git a/haystack/components/rankers/__init__.py b/haystack/components/rankers/__init__.py index 282cf5cf2f..2c9ad0c477 100644 --- a/haystack/components/rankers/__init__.py +++ b/haystack/components/rankers/__init__.py @@ -1,3 +1,4 @@ +from haystack.components.rankers.cohere import CohereRanker from haystack.components.rankers.lost_in_the_middle import LostInTheMiddleRanker from haystack.components.rankers.meta_field import MetaFieldRanker from haystack.components.rankers.sentence_transformers_diversity import SentenceTransformersDiversityRanker @@ -8,4 +9,5 @@ "MetaFieldRanker", "SentenceTransformersDiversityRanker", "TransformersSimilarityRanker", + "CohereRanker", ] diff --git a/haystack/components/rankers/cohere.py b/haystack/components/rankers/cohere.py new file mode 100644 index 0000000000..9c01e3c13d --- /dev/null +++ b/haystack/components/rankers/cohere.py @@ -0,0 +1,149 @@ +from typing import Any, Dict, List, Optional + +import cohere + +from haystack import Document, component, default_from_dict, default_to_dict, logging +from haystack.utils import Secret, deserialize_secrets_inplace + +logger = logging.getLogger(__name__) + + +@component +class CohereRanker: + """ + Performs reranking of documents using Cohere reranking models for semantic search. + + Reranks retrieved documents based on semantic relevance to a query. 
+ Documents are indexed from most to least semantically relevant to the query. [Cohere reranker](https://docs.cohere.com/reference/rerank-1) + + Usage example: + ```python + from haystack import Document + from haystack.components.rankers import CohereRanker + + ranker = CohereRanker(model="rerank-english-v2.0", top_k=3) + + docs = [Document(content="Paris"), Document(content="Berlin")] + query = "What is the capital of germany?" + output = ranker.run(query=query, documents=docs) + docs = output["documents"] + ``` + """ + + def __init__( + self, + model: str = "rerank-english-v2.0", + top_k: int = 10, + api_key: Secret = Secret.from_env_var("CO_API_KEY", strict=False), + max_chunks_per_doc: Optional[int] = None, + meta_fields_to_embed: Optional[List[str]] = None, + meta_data_separator: str = "\n", + ): + """ + Creates an instance of the 'CohereRanker'. + + :param model: Cohere model name. Check the list of supported models in the [Cohere documentation](https://docs.cohere.com/docs/models). + :param top_k: The maximum number of documents to return. + :param api_key: Cohere API key. + :param max_chunks_per_doc: If your document exceeds 512 tokens, this determines the maximum number of + chunks a document can be split into. If None, the default of 10 is used. + For example, if your document is 6000 tokens, with the default of 10, the document will be split into 10 + chunks each of 512 tokens and the last 880 tokens will be disregarded. Check this [link](https://docs.cohere.com/docs/reranking-best-practices) for more information. + :param meta_fields_to_embed: List of meta fields that should be concatenated with the document content for reranking. + :param meta_data_separator: Separator used to join the selected meta fields and the document content when they are concatenated. 
+ """ + self.cohere_client = cohere.Client(api_key.resolve_value()) + self.model_name = model + self.api_key = api_key + self.top_k = top_k + self.max_chunks_per_doc = max_chunks_per_doc + self.meta_fields_to_embed = meta_fields_to_embed or [] + self.meta_data_separator = meta_data_separator + + def to_dict(self) -> Dict[str, Any]: + """ + Serializes the component to a dictionary. + + :returns: + Dictionary with serialized data. + """ + return default_to_dict( + self, + model=self.model_name, + api_key=self.api_key.to_dict() if self.api_key else None, + top_k=self.top_k, + max_chunks_per_doc=self.max_chunks_per_doc, + meta_fields_to_embed=self.meta_fields_to_embed, + meta_data_separator=self.meta_data_separator, + ) + + @classmethod + def from_dict(cls, data: Dict[str, Any]) -> "CohereRanker": + """ + Deserializes the component from a dictionary. + + :param data: + The dictionary to deserialize from. + :returns: + The deserialized component. + """ + deserialize_secrets_inplace(data["init_parameters"], keys=["api_key"]) + return default_from_dict(cls, data) + + def _prepare_cohere_input_docs(self, documents: List[Document]) -> List[str]: + """ + Prepare the input by concatenating the document text with the metadata fields specified. + :param documents: The list of Document objects. + + :return: A list of strings to be input to the cohere model. + """ + concatenated_input_list = [] + for doc in documents: + meta_values_to_embed = [ + str(doc.meta[key]) for key in self.meta_fields_to_embed if key in doc.meta and doc.meta[key] + ] + concatenated_input = self.meta_data_separator.join(meta_values_to_embed + [doc.content or ""]) + concatenated_input_list.append(concatenated_input) + + return concatenated_input_list + + @component.output_types(documents=List[Document]) + def run(self, query: str, documents: List[Document], top_k: Optional[int] = None): + """ + Use the Cohere Reranker to re-rank the list of documents based on the query. + + :param query: The query string. 
+ :param documents: List of Document objects to be re-ranked. + :param top_k: Optional. An integer to override the top_k set during initialization. + + :returns: A dictionary with the following key: + - `documents`: List of re-ranked Document objects. + + :raises ValueError: If the top_k value is less than or equal to 0. + """ + if top_k is None: + top_k = self.top_k + if top_k <= 0: + raise ValueError(f"top_k must be > 0, but got {top_k}") + + cohere_input_docs = self._prepare_cohere_input_docs(documents) + if len(cohere_input_docs) > 1000: + logger.warning( + "The Cohere reranking endpoint only supports 1000 documents. " + "The number of documents has been truncated to 1000 from %s.", + len(cohere_input_docs), + ) + cohere_input_docs = cohere_input_docs[:1000] + + response = self.cohere_client.rerank( + model=self.model_name, query=query, documents=cohere_input_docs, max_chunks_per_doc=self.max_chunks_per_doc + ) + + indices = [output.index for output in response.results] + scores = [output.relevance_score for output in response.results] + sorted_docs = [] + for idx, score in zip(indices, scores): + doc = documents[idx] + doc.score = score + sorted_docs.append(documents[idx]) + return {"documents": sorted_docs[:top_k]} diff --git a/pyproject.toml b/pyproject.toml index 454c77b375..5574a6e186 100644 --- a/pyproject.toml +++ b/pyproject.toml @@ -106,6 +106,7 @@ extra-dependencies = [ "spacy>=3.7,<3.8", # NamedEntityExtractor "spacy-curated-transformers>=0.2,<=0.3", # NamedEntityExtractor "en-core-web-trf @ https://github.com/explosion/spacy-models/releases/download/en_core_web_trf-3.7.3/en_core_web_trf-3.7.3-py3-none-any.whl", # NamedEntityExtractor + "cohere==5.1.7", # CohereRanker # Converters "pypdf", # PyPDFConverter diff --git a/releasenotes/notes/add-cohere-ranker-5e94f8e771916150.yaml b/releasenotes/notes/add-cohere-ranker-5e94f8e771916150.yaml new file mode 100644 index 0000000000..cd507d9ed0 --- /dev/null +++ 
b/releasenotes/notes/add-cohere-ranker-5e94f8e771916150.yaml @@ -0,0 +1,5 @@ +--- +features: + - | + Add 'CohereRanker'. + Performs reranking of documents using Cohere reranking models. diff --git a/test/components/rankers/test_cohere.py b/test/components/rankers/test_cohere.py new file mode 100644 index 0000000000..3f90bccc05 --- /dev/null +++ b/test/components/rankers/test_cohere.py @@ -0,0 +1,314 @@ +import os +from unittest.mock import MagicMock + +import pytest +from cohere.types.rerank_response import RerankResponse +from cohere.types.rerank_response_results_item import RerankResponseResultsItem + +from haystack import Document +from haystack.components.rankers import CohereRanker +from haystack.utils.auth import Secret + + +def mock_cohere_response(**kwargs): + id = "abcd-123hijk-xyz" + results = [ + RerankResponseResultsItem(document=None, index=2, relevance_score=0.98), + RerankResponseResultsItem(document=None, index=1, relevance_score=0.95), + RerankResponseResultsItem(document=None, index=0, relevance_score=0.12), + ] + response = RerankResponse(id=id, results=results) + return response + + +class TestCohereRanker: + def test_init_default(self, monkeypatch): + monkeypatch.setenv("CO_API_KEY", "test-api-key") + component = CohereRanker() + assert component.model_name == "rerank-english-v2.0" + assert component.top_k == 10 + assert component.api_key == Secret.from_env_var("CO_API_KEY", strict=False) + assert component.max_chunks_per_doc == None + assert component.meta_fields_to_embed == [] + assert component.meta_data_separator == "\n" + assert component.api_key.resolve_value() == "test-api-key" + + def test_init_fail_wo_api_key(self, monkeypatch): + monkeypatch.delenv("CO_API_KEY", raising=False) + with pytest.raises(Exception, match="The client must be instantiated be either passing in token or setting *"): + CohereRanker() + + def test_init_with_parameters(self, monkeypatch): + monkeypatch.setenv("CO_API_KEY", "test-api-key") + component = 
CohereRanker( + model="rerank-multilingual-v2.0", + top_k=5, + api_key=Secret.from_env_var("CO_API_KEY", strict=False), + max_chunks_per_doc=40, + meta_fields_to_embed=["meta_field_1", "meta_field_2"], + meta_data_separator=",", + ) + assert component.model_name == "rerank-multilingual-v2.0" + assert component.top_k == 5 + assert component.api_key == Secret.from_env_var("CO_API_KEY", strict=False) + assert component.max_chunks_per_doc == 40 + assert component.meta_fields_to_embed == ["meta_field_1", "meta_field_2"] + assert component.meta_data_separator == "," + assert component.api_key.resolve_value() == "test-api-key" + + def test_to_dict_default(self, monkeypatch): + monkeypatch.setenv("CO_API_KEY", "test-api-key") + component = CohereRanker() + data = component.to_dict() + assert data == { + "type": "haystack.components.rankers.cohere.CohereRanker", + "init_parameters": { + "model": "rerank-english-v2.0", + "api_key": {"env_vars": ["CO_API_KEY"], "strict": False, "type": "env_var"}, + "top_k": 10, + "max_chunks_per_doc": None, + "meta_fields_to_embed": [], + "meta_data_separator": "\n", + }, + } + + def test_to_dict_with_parameters(self, monkeypatch): + monkeypatch.setenv("CO_API_KEY", "test-api-key") + component = CohereRanker( + model="rerank-multilingual-v2.0", + top_k=2, + api_key=Secret.from_env_var("CO_API_KEY", strict=False), + max_chunks_per_doc=50, + meta_fields_to_embed=["meta_field_1", "meta_field_2"], + meta_data_separator=",", + ) + data = component.to_dict() + assert data == { + "type": "haystack.components.rankers.cohere.CohereRanker", + "init_parameters": { + "model": "rerank-multilingual-v2.0", + "api_key": {"env_vars": ["CO_API_KEY"], "strict": False, "type": "env_var"}, + "top_k": 2, + "max_chunks_per_doc": 50, + "meta_fields_to_embed": ["meta_field_1", "meta_field_2"], + "meta_data_separator": ",", + }, + } + + def test_from_dict(self, monkeypatch): + monkeypatch.setenv("CO_API_KEY", "fake-api-key") + data = { + "type": 
"haystack.components.rankers.cohere.CohereRanker", + "init_parameters": { + "model": "rerank-multilingual-v2.0", + "api_key": {"env_vars": ["CO_API_KEY"], "strict": False, "type": "env_var"}, + "top_k": 2, + "max_chunks_per_doc": 50, + "meta_fields_to_embed": ["meta_field_1", "meta_field_2"], + "meta_data_separator": ",", + }, + } + component = CohereRanker.from_dict(data) + assert component.model_name == "rerank-multilingual-v2.0" + assert component.top_k == 2 + assert component.api_key == Secret.from_env_var("CO_API_KEY", strict=False) + assert component.max_chunks_per_doc == 50 + assert component.meta_fields_to_embed == ["meta_field_1", "meta_field_2"] + assert component.meta_data_separator == "," + assert component.api_key.resolve_value() == "fake-api-key" + + def test_from_dict_fail_wo_env_var(self, monkeypatch): + monkeypatch.delenv("CO_API_KEY", raising=False) + data = { + "type": "haystack.components.rankers.cohere.CohereRanker", + "init_parameters": { + "model": "rerank-multilingual-v2.0", + "api_key": {"env_vars": ["CO_API_KEY"], "strict": False, "type": "env_var"}, + "top_k": 2, + "max_chunks_per_doc": 50, + }, + } + with pytest.raises(Exception, match="The client must be instantiated be either passing in token or setting *"): + CohereRanker.from_dict(data) + + def test_prepare_cohere_input_docs_default_separator(self): + component = CohereRanker(meta_fields_to_embed=["meta_field_1", "meta_field_2"]) + documents = [ + Document( + content=f"document number {i}", + meta={ + "meta_field_1": f"meta_value_1 {i}", + "meta_field_2": f"meta_value_2 {i+5}", + "meta_field_3": f"meta_value_3 {i+15}", + }, + ) + for i in range(5) + ] + + texts = component._prepare_cohere_input_docs(documents=documents) + + assert texts == [ + "meta_value_1 0\nmeta_value_2 5\ndocument number 0", + "meta_value_1 1\nmeta_value_2 6\ndocument number 1", + "meta_value_1 2\nmeta_value_2 7\ndocument number 2", + "meta_value_1 3\nmeta_value_2 8\ndocument number 3", + "meta_value_1 
4\nmeta_value_2 9\ndocument number 4", + ] + + def test_prepare_cohere_input_docs_custom_separator(self): + component = CohereRanker(meta_fields_to_embed=["meta_field_1", "meta_field_2"], meta_data_separator=" ") + documents = [ + Document( + content=f"document number {i}", + meta={ + "meta_field_1": f"meta_value_1 {i}", + "meta_field_2": f"meta_value_2 {i+5}", + "meta_field_3": f"meta_value_3 {i+15}", + }, + ) + for i in range(5) + ] + + texts = component._prepare_cohere_input_docs(documents=documents) + + assert texts == [ + "meta_value_1 0 meta_value_2 5 document number 0", + "meta_value_1 1 meta_value_2 6 document number 1", + "meta_value_1 2 meta_value_2 7 document number 2", + "meta_value_1 3 meta_value_2 8 document number 3", + "meta_value_1 4 meta_value_2 9 document number 4", + ] + + def test_prepare_cohere_input_docs_no_meta_data(self): + component = CohereRanker(meta_fields_to_embed=["meta_field_1", "meta_field_2"], meta_data_separator=" ") + documents = [Document(content=f"document number {i}") for i in range(5)] + + texts = component._prepare_cohere_input_docs(documents=documents) + + assert texts == [ + "document number 0", + "document number 1", + "document number 2", + "document number 3", + "document number 4", + ] + + def test_run_negative_topk_in_init(self, monkeypatch): + monkeypatch.setenv("CO_API_KEY", "fake-api-key") + ranker = CohereRanker(top_k=-2) + query = "test" + documents = [Document(content="doc1"), Document(content="doc2"), Document(content="doc3")] + with pytest.raises(ValueError, match="top_k must be > 0, but got *"): + ranker.run(query, documents) + + def test_run_zero_topk_in_init(self, monkeypatch): + monkeypatch.setenv("CO_API_KEY", "fake-api-key") + ranker = CohereRanker(top_k=0) + query = "test" + documents = [Document(content="doc1"), Document(content="doc2"), Document(content="doc3")] + with pytest.raises(ValueError, match="top_k must be > 0, but got *"): + ranker.run(query, documents) + + def 
test_run_negative_topk_in_run(self, monkeypatch): + monkeypatch.setenv("CO_API_KEY", "fake-api-key") + ranker = CohereRanker() + query = "test" + documents = [Document(content="doc1"), Document(content="doc2"), Document(content="doc3")] + with pytest.raises(ValueError, match="top_k must be > 0, but got *"): + ranker.run(query, documents, -3) + + def test_run_zero_topk_in_run(self, monkeypatch): + monkeypatch.setenv("CO_API_KEY", "fake-api-key") + ranker = CohereRanker() + query = "test" + documents = [Document(content="doc1"), Document(content="doc2"), Document(content="doc3")] + with pytest.raises(ValueError, match="top_k must be > 0, but got *"): + ranker.run(query, documents, 0) + + def test_run_documents_provided(self, monkeypatch): + monkeypatch.setenv("CO_API_KEY", "fake-api-key") + ranker = CohereRanker() + query = "test" + documents = [ + Document(id="abcd", content="doc1", meta={"meta_field": "meta_value_1"}), + Document(id="efgh", content="doc2", meta={"meta_field": "meta_value_2"}), + Document(id="ijkl", content="doc3", meta={"meta_field": "meta_value_3"}), + ] + ranker.cohere_client = MagicMock() + ranker.cohere_client.rerank = MagicMock(side_effect=mock_cohere_response) + + ranker_results = ranker.run(query, documents, 2) + + assert isinstance(ranker_results, dict) + reranked_docs = ranker_results["documents"] + assert reranked_docs == [ + Document(id="ijkl", content="doc3", meta={"meta_field": "meta_value_3"}, score=0.98), + Document(id="efgh", content="doc2", meta={"meta_field": "meta_value_2"}, score=0.95), + ] + + def test_run_topk_set_in_init(self, monkeypatch): + monkeypatch.setenv("CO_API_KEY", "fake-api-key") + ranker = CohereRanker(top_k=2) + query = "test" + documents = [ + Document(id="abcd", content="doc1"), + Document(id="efgh", content="doc2"), + Document(id="ijkl", content="doc3"), + ] + ranker.cohere_client = MagicMock() + ranker.cohere_client.rerank = MagicMock(side_effect=mock_cohere_response) + + ranker_results = ranker.run(query, 
documents) + + assert isinstance(ranker_results, dict) + reranked_docs = ranker_results["documents"] + assert reranked_docs == [ + Document(id="ijkl", content="doc3", score=0.98), + Document(id="efgh", content="doc2", score=0.95), + ] + + def test_run_topk_greater_than_docs(self, monkeypatch): + monkeypatch.setenv("CO_API_KEY", "fake-api-key") + ranker = CohereRanker() + query = "test" + documents = [ + Document(id="abcd", content="doc1"), + Document(id="efgh", content="doc2"), + Document(id="ijkl", content="doc3"), + ] + ranker.cohere_client = MagicMock() + ranker.cohere_client.rerank = MagicMock(side_effect=mock_cohere_response) + + ranker_results = ranker.run(query, documents, 5) + + assert isinstance(ranker_results, dict) + reranked_docs = ranker_results["documents"] + assert reranked_docs == [ + Document(id="ijkl", content="doc3", score=0.98), + Document(id="efgh", content="doc2", score=0.95), + Document(id="abcd", content="doc1", score=0.12), + ] + + @pytest.mark.skipif( + not os.environ.get("CO_API_KEY", None), + reason="Export an env var called CO_API_KEY containing the Cohere API key to run this test.", + ) + @pytest.mark.integration + def test_live_run(self): + component = CohereRanker() + documents = [ + Document(id="abcd", content="Paris is in France"), + Document(id="efgh", content="Berlin is in Germany"), + Document(id="ijkl", content="Lyon is in France"), + ] + + ranker_result = component.run("Cities in France", documents, 2) + expected_documents = [documents[0], documents[2]] + expected_documents_content = [doc.content for doc in expected_documents] + result_documents_contents = [doc.content for doc in ranker_result["documents"]] + + assert isinstance(ranker_result, dict) + assert isinstance(ranker_result["documents"], list) + assert len(ranker_result["documents"]) == 2 + assert all(isinstance(doc, Document) for doc in ranker_result["documents"]) + assert set(result_documents_contents) == set(expected_documents_content) From 
a4e1d55abf99b2fc9ea28f33f1d48ecf1268c764 Mon Sep 17 00:00:00 2001 From: Anushree Bannadabhavi Date: Sun, 31 Mar 2024 17:10:01 -0400 Subject: [PATCH 2/4] fix mypy issue --- haystack/components/rankers/cohere.py | 4 +--- 1 file changed, 1 insertion(+), 3 deletions(-) diff --git a/haystack/components/rankers/cohere.py b/haystack/components/rankers/cohere.py index 9c01e3c13d..0e93c10728 100644 --- a/haystack/components/rankers/cohere.py +++ b/haystack/components/rankers/cohere.py @@ -129,9 +129,7 @@ def run(self, query: str, documents: List[Document], top_k: Optional[int] = None cohere_input_docs = self._prepare_cohere_input_docs(documents) if len(cohere_input_docs) > 1000: logger.warning( - "The Cohere reranking endpoint only supports 1000 documents. " - "The number of documents has been truncated to 1000 from %s.", - len(cohere_input_docs), + f"The Cohere reranking endpoint only supports 1000 documents. The number of documents has been truncated to 1000 from {len(cohere_input_docs)}." ) cohere_input_docs = cohere_input_docs[:1000] From f94f0bf09db6e50654e3496cb04bceded3a2b653 Mon Sep 17 00:00:00 2001 From: Anushree Bannadabhavi Date: Sun, 31 Mar 2024 17:24:31 -0400 Subject: [PATCH 3/4] Fix unit test failure --- test/components/rankers/test_cohere.py | 9 ++++++--- 1 file changed, 6 insertions(+), 3 deletions(-) diff --git a/test/components/rankers/test_cohere.py b/test/components/rankers/test_cohere.py index 3f90bccc05..a4b40ec954 100644 --- a/test/components/rankers/test_cohere.py +++ b/test/components/rankers/test_cohere.py @@ -131,7 +131,8 @@ def test_from_dict_fail_wo_env_var(self, monkeypatch): with pytest.raises(Exception, match="The client must be instantiated be either passing in token or setting *"): CohereRanker.from_dict(data) - def test_prepare_cohere_input_docs_default_separator(self): + def test_prepare_cohere_input_docs_default_separator(self, monkeypatch): + monkeypatch.setenv("CO_API_KEY", "fake-api-key") component = 
CohereRanker(meta_fields_to_embed=["meta_field_1", "meta_field_2"]) documents = [ Document( @@ -155,7 +156,8 @@ def test_prepare_cohere_input_docs_default_separator(self): "meta_value_1 4\nmeta_value_2 9\ndocument number 4", ] - def test_prepare_cohere_input_docs_custom_separator(self): + def test_prepare_cohere_input_docs_custom_separator(self, monkeypatch): + monkeypatch.setenv("CO_API_KEY", "fake-api-key") component = CohereRanker(meta_fields_to_embed=["meta_field_1", "meta_field_2"], meta_data_separator=" ") documents = [ Document( @@ -179,7 +181,8 @@ def test_prepare_cohere_input_docs_custom_separator(self): "meta_value_1 4 meta_value_2 9 document number 4", ] - def test_prepare_cohere_input_docs_no_meta_data(self): + def test_prepare_cohere_input_docs_no_meta_data(self, monkeypatch): + monkeypatch.setenv("CO_API_KEY", "fake-api-key") component = CohereRanker(meta_fields_to_embed=["meta_field_1", "meta_field_2"], meta_data_separator=" ") documents = [Document(content=f"document number {i}") for i in range(5)] From 7fbe9bf3dddae8b08188d3cd75225038ca7a6da2 Mon Sep 17 00:00:00 2001 From: Anushree Bannadabhavi Date: Sun, 31 Mar 2024 18:30:33 -0400 Subject: [PATCH 4/4] Added cohere under dependencies in pyproject.toml --- haystack/components/rankers/cohere.py | 6 ++++-- pyproject.toml | 2 +- 2 files changed, 5 insertions(+), 3 deletions(-) diff --git a/haystack/components/rankers/cohere.py b/haystack/components/rankers/cohere.py index 0e93c10728..292afc9d7e 100644 --- a/haystack/components/rankers/cohere.py +++ b/haystack/components/rankers/cohere.py @@ -1,12 +1,14 @@ from typing import Any, Dict, List, Optional -import cohere - from haystack import Document, component, default_from_dict, default_to_dict, logging +from haystack.lazy_imports import LazyImport from haystack.utils import Secret, deserialize_secrets_inplace logger = logging.getLogger(__name__) +with LazyImport(message="Run 'pip install \"cohere==5.1.7\"'") as cohere_import: + import cohere + @component 
class CohereRanker: diff --git a/pyproject.toml b/pyproject.toml index 5574a6e186..467cf31523 100644 --- a/pyproject.toml +++ b/pyproject.toml @@ -62,6 +62,7 @@ dependencies = [ "requests", "numpy", "python-dateutil", + "cohere==5.1.7", # CohereRanker ] [tool.hatch.envs.default] @@ -106,7 +107,6 @@ extra-dependencies = [ "spacy>=3.7,<3.8", # NamedEntityExtractor "spacy-curated-transformers>=0.2,<=0.3", # NamedEntityExtractor "en-core-web-trf @ https://github.com/explosion/spacy-models/releases/download/en_core_web_trf-3.7.3/en_core_web_trf-3.7.3-py3-none-any.whl", # NamedEntityExtractor - "cohere==5.1.7", # CohereRanker # Converters "pypdf", # PyPDFConverter