Skip to content

Commit

Permalink
Merge branch 'main' into add-pgvector-datastore
Browse files Browse the repository at this point in the history
  • Loading branch information
anakin87 authored Jan 5, 2024
2 parents 825f73c + 3e314ea commit 339c6d8
Show file tree
Hide file tree
Showing 13 changed files with 70 additions and 72 deletions.
11 changes: 7 additions & 4 deletions integrations/cohere/src/cohere_haystack/chat/chat_generator.py
Original file line number Diff line number Diff line change
Expand Up @@ -68,11 +68,14 @@ def __init__(
"""
cohere_import.check()

api_key = api_key or os.environ.get("COHERE_API_KEY")
# we check whether api_key is None or an empty string
if not api_key:
api_key = os.environ.get("COHERE_API_KEY")
if not api_key:
error = "CohereChatGenerator needs an API key to run. Either provide it as init parameter or set the env var COHERE_API_KEY." # noqa: E501
raise ValueError(error)
msg = (
"CohereChatGenerator expects an API key. "
"Set the COHERE_API_KEY environment variable (recommended) or pass it explicitly."
)
raise ValueError(msg)

if not api_base_url:
api_base_url = cohere.COHERE_API_URL
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -78,15 +78,14 @@ def __init__(
:param embedding_separator: Separator used to concatenate the meta fields to the Document text.
"""

if api_key is None:
try:
api_key = os.environ["COHERE_API_KEY"]
except KeyError as error_msg:
msg = (
"CohereDocumentEmbedder expects an Cohere API key. Please provide one by setting the environment "
"variable COHERE_API_KEY (recommended) or by passing it explicitly."
)
raise ValueError(msg) from error_msg
api_key = api_key or os.environ.get("COHERE_API_KEY")
# we check whether api_key is None or an empty string
if not api_key:
msg = (
"CohereDocumentEmbedder expects an API key. "
"Set the COHERE_API_KEY environment variable (recommended) or pass it explicitly."
)
raise ValueError(msg)

self.api_key = api_key
self.model_name = model_name
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -67,15 +67,14 @@ def __init__(
:param timeout: Request timeout in seconds, defaults to `120`.
"""

if api_key is None:
try:
api_key = os.environ["COHERE_API_KEY"]
except KeyError as error_msg:
msg = (
"CohereTextEmbedder expects an Cohere API key. Please provide one by setting the environment "
"variable COHERE_API_KEY (recommended) or by passing it explicitly."
)
raise ValueError(msg) from error_msg
api_key = api_key or os.environ.get("COHERE_API_KEY")
# we check whether api_key is None or an empty string
if not api_key:
msg = (
"CohereTextEmbedder expects an API key. "
"Set the COHERE_API_KEY environment variable (recommended) or pass it explicitly."
)
raise ValueError(msg)

self.api_key = api_key
self.model_name = model_name
Expand Down
8 changes: 4 additions & 4 deletions integrations/cohere/src/cohere_haystack/generator.py
Original file line number Diff line number Diff line change
Expand Up @@ -73,12 +73,12 @@ def __init__(
- 'logit_bias': Used to prevent the model from generating unwanted tokens or to incentivize it to include
desired tokens. The format is {token_id: bias} where bias is a float between -10 and 10.
"""
if not api_key:
api_key = os.environ.get("COHERE_API_KEY")
api_key = api_key or os.environ.get("COHERE_API_KEY")
# we check whether api_key is None or an empty string
if not api_key:
msg = (
"CohereGenerator needs an API key to run."
"Either provide it as init parameter or set the env var COHERE_API_KEY."
"CohereGenerator expects an API key. "
"Set the COHERE_API_KEY environment variable (recommended) or pass it explicitly."
)
raise ValueError(msg)

Expand Down
4 changes: 2 additions & 2 deletions integrations/cohere/tests/test_cohere_chat_generator.py
Original file line number Diff line number Diff line change
Expand Up @@ -65,7 +65,7 @@ def test_init_default(self):
@pytest.mark.unit
def test_init_fail_wo_api_key(self, monkeypatch):
monkeypatch.delenv("COHERE_API_KEY", raising=False)
with pytest.raises(ValueError, match=r"^CohereChatGenerator needs an API key to run. (.+)$"):
with pytest.raises(ValueError):
CohereChatGenerator()

@pytest.mark.unit
Expand Down Expand Up @@ -167,7 +167,7 @@ def test_from_dict_fail_wo_env_var(self, monkeypatch):
"generation_kwargs": {"max_tokens": 10, "some_test_param": "test-params"},
},
}
with pytest.raises(ValueError, match=r"^CohereChatGenerator needs an API key to run. (.+)$"):
with pytest.raises(ValueError):
CohereChatGenerator.from_dict(data)

@pytest.mark.unit
Expand Down
21 changes: 9 additions & 12 deletions integrations/jina/src/jina_haystack/document_embedder.py
Original file line number Diff line number Diff line change
Expand Up @@ -57,22 +57,19 @@ def __init__(
:param meta_fields_to_embed: List of meta fields that should be embedded along with the Document text.
:param embedding_separator: Separator used to concatenate the meta fields to the Document text.
"""
# if the user does not provide the API key, check if it is set in the module client
if api_key is None:
try:
api_key = os.environ["JINA_API_KEY"]
except KeyError as e:
msg = (
"JinaDocumentEmbedder expects a Jina API key. "
"Set the JINA_API_KEY environment variable (recommended) or pass it explicitly."
)
raise ValueError(msg) from e

api_key = api_key or os.environ.get("JINA_API_KEY")
# we check whether api_key is None or an empty string
if not api_key:
msg = (
"JinaDocumentEmbedder expects an API key. "
"Set the JINA_API_KEY environment variable (recommended) or pass it explicitly."
)
raise ValueError(msg)

self.model_name = model_name
self.prefix = prefix
self.suffix = suffix
self.prefix = prefix
self.suffix = suffix
self.batch_size = batch_size
self.progress_bar = progress_bar
self.meta_fields_to_embed = meta_fields_to_embed or []
Expand Down
19 changes: 9 additions & 10 deletions integrations/jina/src/jina_haystack/text_embedder.py
Original file line number Diff line number Diff line change
Expand Up @@ -47,16 +47,15 @@ def __init__(
:param prefix: A string to add to the beginning of each text.
:param suffix: A string to add to the end of each text.
"""
# if the user does not provide the API key, check if it is set in the module client
if api_key is None:
try:
api_key = os.environ["JINA_API_KEY"]
except KeyError as e:
msg = (
"JinaTextEmbedder expects a Jina API key. "
"Set the JINA_API_KEY environment variable (recommended) or pass it explicitly."
)
raise ValueError(msg) from e

api_key = api_key or os.environ.get("JINA_API_KEY")
# we check whether api_key is None or an empty string
if not api_key:
msg = (
"JinaTextEmbedder expects an API key. "
"Set the JINA_API_KEY environment variable (recommended) or pass it explicitly."
)
raise ValueError(msg)

self.model_name = model_name
self.prefix = prefix
Expand Down
2 changes: 1 addition & 1 deletion integrations/jina/tests/test_document_embedder.py
Original file line number Diff line number Diff line change
Expand Up @@ -58,7 +58,7 @@ def test_init_with_parameters(self):

def test_init_fail_wo_api_key(self, monkeypatch):
monkeypatch.delenv("JINA_API_KEY", raising=False)
with pytest.raises(ValueError, match="JinaDocumentEmbedder expects a Jina API key"):
with pytest.raises(ValueError):
JinaDocumentEmbedder()

def test_to_dict(self):
Expand Down
2 changes: 1 addition & 1 deletion integrations/jina/tests/test_text_embedder.py
Original file line number Diff line number Diff line change
Expand Up @@ -32,7 +32,7 @@ def test_init_with_parameters(self):

def test_init_fail_wo_api_key(self, monkeypatch):
monkeypatch.delenv("JINA_API_KEY", raising=False)
with pytest.raises(ValueError, match="JinaTextEmbedder expects a Jina API key"):
with pytest.raises(ValueError):
JinaTextEmbedder()

def test_to_dict(self):
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -60,7 +60,7 @@ def __init__(
api_key = api_key or os.environ.get("PINECONE_API_KEY")
if not api_key:
msg = (
"PineconeDocumentStore expects a Pinecone API key. "
"PineconeDocumentStore expects an API key. "
"Set the PINECONE_API_KEY environment variable (recommended) or pass it explicitly."
)
raise ValueError(msg)
Expand Down
6 changes: 3 additions & 3 deletions integrations/qdrant/src/qdrant_haystack/retriever.py
Original file line number Diff line number Diff line change
Expand Up @@ -6,7 +6,7 @@


@component
class QdrantRetriever:
class QdrantEmbeddingRetriever:
"""
A component for retrieving documents from an QdrantDocumentStore.
"""
Expand All @@ -20,7 +20,7 @@ def __init__(
return_embedding: bool = False, # noqa: FBT001, FBT002
):
"""
Create a QdrantRetriever component.
Create a QdrantEmbeddingRetriever component.
:param document_store: An instance of QdrantDocumentStore.
:param filters: A dictionary with filters to narrow down the search space. Default is None.
Expand Down Expand Up @@ -59,7 +59,7 @@ def to_dict(self) -> Dict[str, Any]:
return d

@classmethod
def from_dict(cls, data: Dict[str, Any]) -> "QdrantRetriever":
def from_dict(cls, data: Dict[str, Any]) -> "QdrantEmbeddingRetriever":
"""
Deserialize this component from a dictionary.
"""
Expand Down
14 changes: 7 additions & 7 deletions integrations/qdrant/tests/test_retriever.py
Original file line number Diff line number Diff line change
Expand Up @@ -7,24 +7,24 @@
)

from qdrant_haystack import QdrantDocumentStore
from qdrant_haystack.retriever import QdrantRetriever
from qdrant_haystack.retriever import QdrantEmbeddingRetriever


class TestQdrantRetriever(FilterableDocsFixtureMixin):
def test_init_default(self):
document_store = QdrantDocumentStore(location=":memory:", index="test")
retriever = QdrantRetriever(document_store=document_store)
retriever = QdrantEmbeddingRetriever(document_store=document_store)
assert retriever._document_store == document_store
assert retriever._filters is None
assert retriever._top_k == 10
assert retriever._return_embedding is False

def test_to_dict(self):
document_store = QdrantDocumentStore(location=":memory:", index="test")
retriever = QdrantRetriever(document_store=document_store)
retriever = QdrantEmbeddingRetriever(document_store=document_store)
res = retriever.to_dict()
assert res == {
"type": "qdrant_haystack.retriever.QdrantRetriever",
"type": "qdrant_haystack.retriever.QdrantEmbeddingRetriever",
"init_parameters": {
"document_store": {
"type": "qdrant_haystack.document_store.QdrantDocumentStore",
Expand Down Expand Up @@ -74,7 +74,7 @@ def test_to_dict(self):

def test_from_dict(self):
data = {
"type": "qdrant_haystack.retriever.QdrantRetriever",
"type": "qdrant_haystack.retriever.QdrantEmbeddingRetriever",
"init_parameters": {
"document_store": {
"init_parameters": {"location": ":memory:", "index": "test"},
Expand All @@ -86,7 +86,7 @@ def test_from_dict(self):
"return_embedding": True,
},
}
retriever = QdrantRetriever.from_dict(data)
retriever = QdrantEmbeddingRetriever.from_dict(data)
assert isinstance(retriever._document_store, QdrantDocumentStore)
assert retriever._document_store.index == "test"
assert retriever._filters is None
Expand All @@ -99,7 +99,7 @@ def test_run(self, filterable_docs: List[Document]):

document_store.write_documents(filterable_docs)

retriever = QdrantRetriever(document_store=document_store)
retriever = QdrantEmbeddingRetriever(document_store=document_store)

results: List[Document] = retriever.run(query_embedding=_random_embeddings(768))

Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -60,15 +60,16 @@ def __init__(
self.progress_bar = progress_bar

is_hosted_api = api_url == UNSTRUCTURED_HOSTED_API_URL
if api_key is None and is_hosted_api:
try:
api_key = os.environ["UNSTRUCTURED_API_KEY"]
except KeyError as e:
msg = (
"To use the hosted version of Unstructured, you need to set the environment variable "
"UNSTRUCTURED_API_KEY (recommended) or explictly pass the parameter api_key."
)
raise ValueError(msg) from e

api_key = api_key or os.environ.get("UNSTRUCTURED_API_KEY")
# we check whether api_key is None or an empty string
if is_hosted_api and not api_key:
msg = (
"To use the hosted version of Unstructured, you need to set the environment variable "
"UNSTRUCTURED_API_KEY (recommended) or explictly pass the parameter api_key."
)
raise ValueError(msg)

self.api_key = api_key

def to_dict(self) -> Dict[str, Any]:
Expand Down

0 comments on commit 339c6d8

Please sign in to comment.