diff --git a/integrations/astra/README.md b/integrations/astra/README.md index f8b6f7c31..f679b7207 100644 --- a/integrations/astra/README.md +++ b/integrations/astra/README.md @@ -24,8 +24,8 @@ pyenv local 3.9 Local install for the package `pip install -e .` To execute integration tests, add needed environment variables -`ASTRA_DB_API_ENDPOINT=` -`ASTRA_DB_APPLICATION_TOKEN=` +`ASTRA_DB_API_ENDPOINT="https://-.apps.astra.datastax.com"`, +`ASTRA_DB_APPLICATION_TOKEN="AstraCS:..."` and execute `python examples/example.py` @@ -34,10 +34,10 @@ Install requirements Export environment variables ``` -export ASTRA_DB_API_ENDPOINT= -export ASTRA_DB_APPLICATION_TOKEN= -export COLLECTION_NAME= -export OPENAI_API_KEY= +export ASTRA_DB_API_ENDPOINT="https://-.apps.astra.datastax.com" +export ASTRA_DB_APPLICATION_TOKEN="AstraCS:..." +export COLLECTION_NAME="my_collection" +export OPENAI_API_KEY="sk-..." ``` run the python examples @@ -59,19 +59,17 @@ from haystack.document_stores.types.policy import DuplicatePolicy Load in environment variables: ``` -api_endpoint = os.getenv("ASTRA_DB_API_ENDPOINT", "") -token = os.getenv("ASTRA_DB_APPLICATION_TOKEN", "") -collection_name = os.getenv("COLLECTION_NAME", "haystack_vector_search") +namespace = os.environ.get("ASTRA_DB_KEYSPACE") +collection_name = os.environ.get("COLLECTION_NAME", "haystack_vector_search") ``` -Create the Document Store object: +Create the Document Store object (API Endpoint and Token are read off the environment): ``` document_store = AstraDocumentStore( - api_endpoint=api_endpoint, - token=token, collection_name=collection_name, + namespace=namespace, duplicates_policy=DuplicatePolicy.SKIP, - embedding_dim=384, + embedding_dimension=384, ) ``` @@ -92,3 +90,31 @@ Add your AstraEmbeddingRetriever into the pipeline Add other components and connect them as desired. Then run your pipeline: `pipeline.run(...)` + +## Warnings about indexing + +When creating an Astra DB document store, you may see a warning similar to the following: + +> Astra DB collection '...' is detected as having indexing turned on for all fields (either created manually or by older versions of this plugin). This implies stricter limitations on the amount of text each string in a document can store. Consider indexing anew on a fresh collection to be able to store longer texts. + +or, + +> Astra DB collection '...' is detected as having the following indexing policy: {...}. This does not match the requested indexing policy for this object: {...}. In particular, there may be stricter limitations on the amount of text each string in a document can store. Consider indexing anew on a fresh collection to be able to store longer texts. + + +The reason for the warning is that the requested collection already exists on the database, and it is configured to [index all of its fields for search](https://docs.datastax.com/en/astra-db-serverless/api-reference/collections.html#the-indexing-option), possibly implicitly, by default. When the Haystack object tries to create it, it attempts to enforce, instead, an indexing policy tailored to the prospected usage: this is both to enable storing very long texts and to avoid indexing fields that will never be used in filtering a search (indexing those would also have a slight performance cost for writes). + +Typically there are two reasons why you may encounter the warning: + +1. you have created a collection by other means than letting this component do it for you: for example, through the Astra UI, or using AstraPy's `create_collection` method of class `Database` directly; +2. you have created the collection with an older version of the plugin. + +Keep in mind that this is a warning and your application will continue running just fine, as long as you don't store very long texts. +However, should you need to add to the document store, for example, a document with a very long textual content, you will get an indexing error from the database. + +### Remediation + +You have several options: + +- you can ignore the warning because you know your application will never need to store very long textual contents; +- if you can afford populating the collection anew, you can drop it and re-run the Haystack application: the collection will be created with the optimized indexing settings. **This is the recommended option, when possible**. diff --git a/integrations/astra/src/haystack_integrations/document_stores/astra/astra_client.py b/integrations/astra/src/haystack_integrations/document_stores/astra/astra_client.py index 01cc92a69..5a88a0fe9 100644 --- a/integrations/astra/src/haystack_integrations/document_stores/astra/astra_client.py +++ b/integrations/astra/src/haystack_integrations/document_stores/astra/astra_client.py @@ -74,12 +74,13 @@ def __init__( caller_version=integration_version, ) + indexing_options = {"indexing": {"deny": NON_INDEXED_FIELDS}} try: # Create and connect to the newly created collection self._astra_db_collection = self._astra_db.create_collection( collection_name=collection_name, dimension=embedding_dimension, - options={"indexing": {"deny": NON_INDEXED_FIELDS}}, + options=indexing_options, ) except APIRequestError: # possibly the collection is preexisting and has legacy @@ -98,11 +99,16 @@ def __init__( if "indexing" not in pre_col_options: warn( ( - f"Collection '{collection_name}' is detected as legacy" - " and has indexing turned on for all fields. This" - " implies stricter limitations on the amount of text" - " each entry can store. Consider reindexing anew on a" - " fresh collection to be able to store longer texts." + f"Astra DB collection '{collection_name}' is " + "detected as having indexing turned on for all " + "fields (either created manually or by older " + "versions of this plugin). This implies stricter " + "limitations on the amount of text each string in a " + "document can store. Consider indexing anew on a " + "fresh collection to be able to store longer texts. " + "See https://github.com/deepset-ai/haystack-core-" + "integrations/blob/main/integrations/astra/README" + ".md#warnings-about-indexing for more details." ), UserWarning, stacklevel=2, @@ -110,16 +116,22 @@ def __init__( self._astra_db_collection = self._astra_db.collection( collection_name=collection_name, ) - else: - options_json = json.dumps(pre_col_options["indexing"]) + elif pre_col_options["indexing"] != indexing_options["indexing"]: + detected_options_json = json.dumps(pre_col_options["indexing"]) + indexing_options_json = json.dumps(indexing_options["indexing"]) warn( ( - f"Collection '{collection_name}' has unexpected 'indexing'" - f" settings (options.indexing = {options_json})." - " This can result in odd behaviour when running " - " metadata filtering and/or unwarranted limitations" - " on storing long texts. Consider reindexing anew on a" - " fresh collection." + f"Astra DB collection '{collection_name}' is " + "detected as having the following indexing policy: " + f"{detected_options_json}. This does not match the requested " + f"indexing policy for this object: {indexing_options_json}. " + "In particular, there may be stricter " + "limitations on the amount of text each string in a " + "document can store. Consider indexing anew on a " + "fresh collection to be able to store longer texts. " + "See https://github.com/deepset-ai/haystack-core-" + "integrations/blob/main/integrations/astra/README" + ".md#warnings-about-indexing for more details." ), UserWarning, stacklevel=2, @@ -127,6 +139,9 @@ def __init__( self._astra_db_collection = self._astra_db.collection( collection_name=collection_name, ) + else: + # the collection mismatch lies elsewhere than the indexing + raise else: # other exception raise diff --git a/integrations/langfuse/README.md b/integrations/langfuse/README.md index 901ac122b..3b94a01a1 100644 --- a/integrations/langfuse/README.md +++ b/integrations/langfuse/README.md @@ -36,7 +36,7 @@ os.environ["LANGFUSE_HOST"] = "https://cloud.langfuse.com" os.environ["TOKENIZERS_PARALLELISM"] = "false" os.environ["HAYSTACK_CONTENT_TRACING_ENABLED"] = "true" -from haystack.components.builders import DynamicChatPromptBuilder +from haystack.components.builders import ChatPromptBuilder from haystack.components.generators.chat import OpenAIChatGenerator from haystack.dataclasses import ChatMessage from haystack import Pipeline @@ -46,7 +46,7 @@ from haystack_integrations.components.connectors.langfuse import LangfuseConnect if __name__ == "__main__": pipe = Pipeline() pipe.add_component("tracer", LangfuseConnector("Chat example")) - pipe.add_component("prompt_builder", DynamicChatPromptBuilder()) + pipe.add_component("prompt_builder", ChatPromptBuilder()) pipe.add_component("llm", OpenAIChatGenerator(model="gpt-3.5-turbo")) pipe.connect("prompt_builder.prompt", "llm.messages") @@ -57,7 +57,7 @@ if __name__ == "__main__": ] response = pipe.run( - data={"prompt_builder": {"template_variables": {"location": "Berlin"}, "prompt_source": messages}} + data={"prompt_builder": {"template_variables": {"location": "Berlin"}, "template": messages}} ) print(response["llm"]["replies"][0]) print(response["tracer"]["trace_url"]) diff --git a/integrations/langfuse/example/chat.py b/integrations/langfuse/example/chat.py index 99ed7a238..443d65a13 100644 --- a/integrations/langfuse/example/chat.py +++ b/integrations/langfuse/example/chat.py @@ -3,7 +3,7 @@ os.environ["HAYSTACK_CONTENT_TRACING_ENABLED"] = "true" from haystack import Pipeline -from haystack.components.builders import DynamicChatPromptBuilder +from haystack.components.builders import ChatPromptBuilder from haystack.components.generators.chat import OpenAIChatGenerator from haystack.dataclasses import ChatMessage from haystack_integrations.components.connectors.langfuse import LangfuseConnector @@ -12,7 +12,7 @@ pipe = Pipeline() pipe.add_component("tracer", LangfuseConnector("Chat example")) - pipe.add_component("prompt_builder", DynamicChatPromptBuilder()) + pipe.add_component("prompt_builder", ChatPromptBuilder()) pipe.add_component("llm", OpenAIChatGenerator(model="gpt-3.5-turbo")) pipe.connect("prompt_builder.prompt", "llm.messages") @@ -22,8 +22,6 @@ ChatMessage.from_user("Tell me about {{location}}"), ] - response = pipe.run( - data={"prompt_builder": {"template_variables": {"location": "Berlin"}, "prompt_source": messages}} - ) + response = pipe.run(data={"prompt_builder": {"template_variables": {"location": "Berlin"}, "template": messages}}) print(response["llm"]["replies"][0]) print(response["tracer"]["trace_url"]) diff --git a/integrations/langfuse/src/haystack_integrations/components/connectors/langfuse/langfuse_connector.py b/integrations/langfuse/src/haystack_integrations/components/connectors/langfuse/langfuse_connector.py index cfe150317..51703823e 100644 --- a/integrations/langfuse/src/haystack_integrations/components/connectors/langfuse/langfuse_connector.py +++ b/integrations/langfuse/src/haystack_integrations/components/connectors/langfuse/langfuse_connector.py @@ -40,6 +40,7 @@ class LangfuseConnector: # ... + @app.on_event("shutdown") async def shutdown_event(): tracer.actual_tracer.flush() @@ -53,27 +54,35 @@ async def shutdown_event(): os.environ["HAYSTACK_CONTENT_TRACING_ENABLED"] = "true" from haystack import Pipeline - from haystack.components.builders import DynamicChatPromptBuilder + from haystack.components.builders import ChatPromptBuilder from haystack.components.generators.chat import OpenAIChatGenerator from haystack.dataclasses import ChatMessage - from haystack_integrations.components.connectors.langfuse import LangfuseConnector + from haystack_integrations.components.connectors.langfuse import ( + LangfuseConnector, + ) if __name__ == "__main__": - pipe = Pipeline() pipe.add_component("tracer", LangfuseConnector("Chat example")) - pipe.add_component("prompt_builder", DynamicChatPromptBuilder()) + pipe.add_component("prompt_builder", ChatPromptBuilder()) pipe.add_component("llm", OpenAIChatGenerator(model="gpt-3.5-turbo")) pipe.connect("prompt_builder.prompt", "llm.messages") messages = [ - ChatMessage.from_system("Always respond in German even if some input data is in other languages."), + ChatMessage.from_system( + "Always respond in German even if some input data is in other languages." + ), ChatMessage.from_user("Tell me about {{location}}"), ] response = pipe.run( - data={"prompt_builder": {"template_variables": {"location": "Berlin"}, "prompt_source": messages}} + data={ + "prompt_builder": { + "template_variables": {"location": "Berlin"}, + "template": messages, + } + } ) print(response["llm"]["replies"][0]) print(response["tracer"]["trace_url"]) diff --git a/integrations/langfuse/tests/test_tracing.py b/integrations/langfuse/tests/test_tracing.py index 4fc1cd9ce..111d89dfd 100644 --- a/integrations/langfuse/tests/test_tracing.py +++ b/integrations/langfuse/tests/test_tracing.py @@ -9,7 +9,7 @@ import requests from haystack import Pipeline -from haystack.components.builders import DynamicChatPromptBuilder +from haystack.components.builders import ChatPromptBuilder from haystack.components.generators.chat import OpenAIChatGenerator from haystack.dataclasses import ChatMessage from requests.auth import HTTPBasicAuth @@ -26,7 +26,7 @@ def test_tracing_integration(): pipe = Pipeline() pipe.add_component("tracer", LangfuseConnector(name="Chat example", public=True)) # public so anyone can verify run - pipe.add_component("prompt_builder", DynamicChatPromptBuilder()) + pipe.add_component("prompt_builder", ChatPromptBuilder()) pipe.add_component("llm", OpenAIChatGenerator(model="gpt-3.5-turbo")) pipe.connect("prompt_builder.prompt", "llm.messages") @@ -36,9 +36,7 @@ def test_tracing_integration(): ChatMessage.from_user("Tell me about {{location}}"), ] - response = pipe.run( - data={"prompt_builder": {"template_variables": {"location": "Berlin"}, "prompt_source": messages}} - ) + response = pipe.run(data={"prompt_builder": {"template_variables": {"location": "Berlin"}, "template": messages}}) assert "Berlin" in response["llm"]["replies"][0].content assert response["tracer"]["trace_url"] url = "https://cloud.langfuse.com/api/public/traces/" diff --git a/integrations/mongodb_atlas/src/haystack_integrations/document_stores/mongodb_atlas/document_store.py b/integrations/mongodb_atlas/src/haystack_integrations/document_stores/mongodb_atlas/document_store.py index 88eec4aed..93eb87005 100644 --- a/integrations/mongodb_atlas/src/haystack_integrations/document_stores/mongodb_atlas/document_store.py +++ b/integrations/mongodb_atlas/src/haystack_integrations/document_stores/mongodb_atlas/document_store.py @@ -82,7 +82,6 @@ def __init__( msg = f'Invalid collection name: "{collection_name}". It can only contain letters, numbers, -, or _.' raise ValueError(msg) - self.resolved_connection_string = mongo_connection_string.resolve_value() self.mongo_connection_string = mongo_connection_string self.database_name = database_name @@ -95,7 +94,7 @@ def __init__( def connection(self) -> MongoClient: if self._connection is None: self._connection = MongoClient( - self.resolved_connection_string, driver=DriverInfo(name="MongoDBAtlasHaystackIntegration") + self.mongo_connection_string.resolve_value(), driver=DriverInfo(name="MongoDBAtlasHaystackIntegration") ) return self._connection diff --git a/integrations/nvidia/src/haystack_integrations/components/embedders/nvidia/document_embedder.py b/integrations/nvidia/src/haystack_integrations/components/embedders/nvidia/document_embedder.py index fad264a46..4cc805c01 100644 --- a/integrations/nvidia/src/haystack_integrations/components/embedders/nvidia/document_embedder.py +++ b/integrations/nvidia/src/haystack_integrations/components/embedders/nvidia/document_embedder.py @@ -2,12 +2,15 @@ from haystack import Document, component, default_from_dict, default_to_dict from haystack.utils import Secret, deserialize_secrets_inplace +from haystack_integrations.utils.nvidia import url_validation from tqdm import tqdm from ._nim_backend import NimBackend from .backend import EmbedderBackend from .truncate import EmbeddingTruncateMode +_DEFAULT_API_URL = "https://ai.api.nvidia.com/v1/retrieval/nvidia" + @component class NvidiaDocumentEmbedder: @@ -33,7 +36,7 @@ def __init__( self, model: str = "NV-Embed-QA", api_key: Optional[Secret] = Secret.from_env_var("NVIDIA_API_KEY"), - api_url: str = "https://ai.api.nvidia.com/v1/retrieval/nvidia", + api_url: str = _DEFAULT_API_URL, prefix: str = "", suffix: str = "", batch_size: int = 32, @@ -51,6 +54,7 @@ def __init__( API key for the NVIDIA NIM. :param api_url: Custom API URL for the NVIDIA NIM. + Format for API URL is http://host:port :param prefix: A string to add to the beginning of each text. :param suffix: @@ -71,7 +75,7 @@ def __init__( self.api_key = api_key self.model = model - self.api_url = api_url + self.api_url = url_validation(api_url, _DEFAULT_API_URL, ["v1/embeddings"]) self.prefix = prefix self.suffix = suffix self.batch_size = batch_size diff --git a/integrations/nvidia/src/haystack_integrations/components/embedders/nvidia/text_embedder.py b/integrations/nvidia/src/haystack_integrations/components/embedders/nvidia/text_embedder.py index 8923f8c81..e1a8c36dd 100644 --- a/integrations/nvidia/src/haystack_integrations/components/embedders/nvidia/text_embedder.py +++ b/integrations/nvidia/src/haystack_integrations/components/embedders/nvidia/text_embedder.py @@ -2,11 +2,14 @@ from haystack import component, default_from_dict, default_to_dict from haystack.utils import Secret, deserialize_secrets_inplace +from haystack_integrations.utils.nvidia import url_validation from ._nim_backend import NimBackend from .backend import EmbedderBackend from .truncate import EmbeddingTruncateMode +_DEFAULT_API_URL = "https://ai.api.nvidia.com/v1/retrieval/nvidia" + @component class NvidiaTextEmbedder: @@ -34,7 +37,7 @@ def __init__( self, model: str = "NV-Embed-QA", api_key: Optional[Secret] = Secret.from_env_var("NVIDIA_API_KEY"), - api_url: str = "https://ai.api.nvidia.com/v1/retrieval/nvidia", + api_url: str = _DEFAULT_API_URL, prefix: str = "", suffix: str = "", truncate: Optional[Union[EmbeddingTruncateMode, str]] = None, @@ -48,6 +51,7 @@ def __init__( API key for the NVIDIA NIM. :param api_url: Custom API URL for the NVIDIA NIM. + Format for API URL is http://host:port :param prefix: A string to add to the beginning of each text. :param suffix: @@ -59,7 +63,7 @@ def __init__( self.api_key = api_key self.model = model - self.api_url = api_url + self.api_url = url_validation(api_url, _DEFAULT_API_URL, ["v1/embeddings"]) self.prefix = prefix self.suffix = suffix diff --git a/integrations/nvidia/src/haystack_integrations/components/generators/nvidia/generator.py b/integrations/nvidia/src/haystack_integrations/components/generators/nvidia/generator.py index 7038e6251..6aea421dd 100644 --- a/integrations/nvidia/src/haystack_integrations/components/generators/nvidia/generator.py +++ b/integrations/nvidia/src/haystack_integrations/components/generators/nvidia/generator.py @@ -5,6 +5,7 @@ from haystack import component, default_from_dict, default_to_dict from haystack.utils.auth import Secret, deserialize_secrets_inplace +from haystack_integrations.utils.nvidia import url_validation from ._nim_backend import NimBackend from .backend import GeneratorBackend @@ -63,7 +64,7 @@ def __init__( to know the supported arguments. """ self._model = model - self._api_url = api_url + self._api_url = url_validation(api_url, _DEFAULT_API_URL, ["v1/chat/completions"]) self._api_key = api_key self._model_arguments = model_arguments or {} diff --git a/integrations/nvidia/src/haystack_integrations/utils/nvidia/__init__.py b/integrations/nvidia/src/haystack_integrations/utils/nvidia/__init__.py new file mode 100644 index 000000000..9863e4a38 --- /dev/null +++ b/integrations/nvidia/src/haystack_integrations/utils/nvidia/__init__.py @@ -0,0 +1,3 @@ +from .utils import url_validation + +__all__ = ["url_validation"] diff --git a/integrations/nvidia/src/haystack_integrations/utils/nvidia/utils.py b/integrations/nvidia/src/haystack_integrations/utils/nvidia/utils.py new file mode 100644 index 000000000..4f8e14b09 --- /dev/null +++ b/integrations/nvidia/src/haystack_integrations/utils/nvidia/utils.py @@ -0,0 +1,39 @@ +import warnings +from typing import List +from urllib.parse import urlparse, urlunparse + + +def url_validation(api_url: str, default_api_url: str, allowed_paths: List[str]) -> str: + """ + Validate and normalize an API URL. + + :param api_url: + The API URL to validate and normalize. + :param default_api_url: + The default API URL for comparison. + :param allowed_paths: + A list of allowed base paths that are valid if present in the URL. + :returns: + A normalized version of the API URL with '/v1' path appended, if needed. + :raises ValueError: + If the base URL path is not recognized or does not match expected format. + """ + ## Making sure /v1 in added to the url, followed by infer_path + result = urlparse(api_url) + expected_format = "Expected format is 'http://host:port'." + + if api_url == default_api_url: + return api_url + if result.path: + normalized_path = result.path.strip("/") + if normalized_path == "v1": + pass + elif normalized_path in allowed_paths: + warn_msg = f"{expected_format} Rest is ignored." + warnings.warn(warn_msg, stacklevel=2) + else: + err_msg = f"Base URL path is not recognized. {expected_format}" + raise ValueError(err_msg) + + base_url = urlunparse((result.scheme, result.netloc, "v1", "", "", "")) + return base_url diff --git a/integrations/nvidia/tests/test_base_url.py b/integrations/nvidia/tests/test_base_url.py new file mode 100644 index 000000000..072807685 --- /dev/null +++ b/integrations/nvidia/tests/test_base_url.py @@ -0,0 +1,64 @@ +import pytest +from haystack_integrations.components.embedders.nvidia import NvidiaDocumentEmbedder, NvidiaTextEmbedder +from haystack_integrations.components.generators.nvidia import NvidiaGenerator + + +@pytest.mark.parametrize( + "base_url", + [ + "http://localhost:8888/embeddings", + "http://0.0.0.0:8888/rankings", + "http://0.0.0.0:8888/v1/rankings", + "http://localhost:8888/chat/completions", + "http://localhost:8888/v1/chat/completions", + ], +) +@pytest.mark.parametrize( + "embedder", + [NvidiaDocumentEmbedder, NvidiaTextEmbedder], +) +def test_base_url_invalid_not_hosted(base_url: str, embedder) -> None: + with pytest.raises(ValueError): + embedder(api_url=base_url, model="x") + + +@pytest.mark.parametrize( + "base_url", + ["http://localhost:8080/v1/embeddings", "http://0.0.0.0:8888/v1/embeddings"], +) +@pytest.mark.parametrize( + "embedder", + [NvidiaDocumentEmbedder, NvidiaTextEmbedder], +) +def test_base_url_valid_embedder(base_url: str, embedder) -> None: + with pytest.warns(UserWarning): + embedder(api_url=base_url) + + +@pytest.mark.parametrize( + "base_url", + [ + "http://localhost:8080/v1/chat/completions", + "http://0.0.0.0:8888/v1/chat/completions", + ], +) +def test_base_url_valid_generator(base_url: str) -> None: + with pytest.warns(UserWarning): + NvidiaGenerator( + api_url=base_url, + model="mistralai/mixtral-8x7b-instruct-v0.1", + ) + + +@pytest.mark.parametrize( + "base_url", + [ + "http://localhost:8888/embeddings", + "http://0.0.0.0:8888/rankings", + "http://0.0.0.0:8888/v1/rankings", + "http://localhost:8888/chat/completions", + ], +) +def test_base_url_invalid_generator(base_url: str) -> None: + with pytest.raises(ValueError): + NvidiaGenerator(api_url=base_url, model="x") diff --git a/integrations/nvidia/tests/test_document_embedder.py b/integrations/nvidia/tests/test_document_embedder.py index 06587cd78..856ae4652 100644 --- a/integrations/nvidia/tests/test_document_embedder.py +++ b/integrations/nvidia/tests/test_document_embedder.py @@ -33,27 +33,28 @@ def test_init_default(self, monkeypatch): assert embedder.embedding_separator == "\n" def test_init_with_parameters(self): - embedder = NvidiaDocumentEmbedder( - api_key=Secret.from_token("fake-api-key"), - model="nvolveqa_40k", - api_url="https://ai.api.nvidia.com/v1/retrieval/nvidia/test", - prefix="prefix", - suffix="suffix", - batch_size=30, - progress_bar=False, - meta_fields_to_embed=["test_field"], - embedding_separator=" | ", - ) - - assert embedder.api_key == Secret.from_token("fake-api-key") - assert embedder.model == "nvolveqa_40k" - assert embedder.api_url == "https://ai.api.nvidia.com/v1/retrieval/nvidia/test" - assert embedder.prefix == "prefix" - assert embedder.suffix == "suffix" - assert embedder.batch_size == 30 - assert embedder.progress_bar is False - assert embedder.meta_fields_to_embed == ["test_field"] - assert embedder.embedding_separator == " | " + with pytest.raises(ValueError): + embedder = NvidiaDocumentEmbedder( + api_key=Secret.from_token("fake-api-key"), + model="nvolveqa_40k", + api_url="https://ai.api.nvidia.com/v1/retrieval/nvidia/test", + prefix="prefix", + suffix="suffix", + batch_size=30, + progress_bar=False, + meta_fields_to_embed=["test_field"], + embedding_separator=" | ", + ) + + assert embedder.api_key == Secret.from_token("fake-api-key") + assert embedder.model == "nvolveqa_40k" + assert embedder.api_url == "https://ai.api.nvidia.com/v1/retrieval/nvidia/test" + assert embedder.prefix == "prefix" + assert embedder.suffix == "suffix" + assert embedder.batch_size == 30 + assert embedder.progress_bar is False + assert embedder.meta_fields_to_embed == ["test_field"] + assert embedder.embedding_separator == " | " def test_init_fail_wo_api_key(self, monkeypatch): monkeypatch.delenv("NVIDIA_API_KEY", raising=False) @@ -99,7 +100,7 @@ def test_to_dict_with_custom_init_parameters(self, monkeypatch): "type": "haystack_integrations.components.embedders.nvidia.document_embedder.NvidiaDocumentEmbedder", "init_parameters": { "api_key": {"env_vars": ["NVIDIA_API_KEY"], "strict": True, "type": "env_var"}, - "api_url": "https://example.com", + "api_url": "https://example.com/v1", "model": "playground_nvolveqa_40k", "prefix": "prefix", "suffix": "suffix", @@ -130,7 +131,7 @@ def from_dict(self, monkeypatch): } component = NvidiaDocumentEmbedder.from_dict(data) assert component.model == "nvolveqa_40k" - assert component.api_url == "https://example.com" + assert component.api_url == "https://example.com/v1" assert component.prefix == "prefix" assert component.suffix == "suffix" assert component.batch_size == 32 diff --git a/integrations/nvidia/tests/test_generator.py b/integrations/nvidia/tests/test_generator.py index 60f83dc43..3ddeebe88 100644 --- a/integrations/nvidia/tests/test_generator.py +++ b/integrations/nvidia/tests/test_generator.py @@ -80,7 +80,7 @@ def test_to_dict_with_custom_init_parameters(self, monkeypatch): "type": "haystack_integrations.components.generators.nvidia.generator.NvidiaGenerator", "init_parameters": { "api_key": {"env_vars": ["NVIDIA_API_KEY"], "strict": True, "type": "env_var"}, - "api_url": "https://my.url.com", + "api_url": "https://my.url.com/v1", "model": "playground_nemotron_steerlm_8b", "model_arguments": { "temperature": 0.2, diff --git a/integrations/nvidia/tests/test_text_embedder.py b/integrations/nvidia/tests/test_text_embedder.py index 30f529534..42d60dee2 100644 --- a/integrations/nvidia/tests/test_text_embedder.py +++ b/integrations/nvidia/tests/test_text_embedder.py @@ -28,18 +28,19 @@ def test_init_default(self, monkeypatch): assert embedder.suffix == "" def test_init_with_parameters(self): - embedder = NvidiaTextEmbedder( - api_key=Secret.from_token("fake-api-key"), - model="nvolveqa_40k", - api_url="https://ai.api.nvidia.com/v1/retrieval/nvidia/test", - prefix="prefix", - suffix="suffix", - ) - assert embedder.api_key == Secret.from_token("fake-api-key") - assert embedder.model == "nvolveqa_40k" - assert embedder.api_url == "https://ai.api.nvidia.com/v1/retrieval/nvidia/test" - assert embedder.prefix == "prefix" - assert embedder.suffix == "suffix" + with pytest.raises(ValueError): + embedder = NvidiaTextEmbedder( + api_key=Secret.from_token("fake-api-key"), + model="nvolveqa_40k", + api_url="https://ai.api.nvidia.com/v1/retrieval/nvidia/test", + prefix="prefix", + suffix="suffix", + ) + assert embedder.api_key == Secret.from_token("fake-api-key") + assert embedder.model == "nvolveqa_40k" + assert embedder.api_url == "https://ai.api.nvidia.com/v1/retrieval/nvidia/test" + assert embedder.prefix == "prefix" + assert embedder.suffix == "suffix" def test_init_fail_wo_api_key(self, monkeypatch): monkeypatch.delenv("NVIDIA_API_KEY", raising=False) @@ -77,7 +78,7 @@ def test_to_dict_with_custom_init_parameters(self, monkeypatch): "type": "haystack_integrations.components.embedders.nvidia.text_embedder.NvidiaTextEmbedder", "init_parameters": { "api_key": {"env_vars": ["NVIDIA_API_KEY"], "strict": True, "type": "env_var"}, - "api_url": "https://example.com", + "api_url": "https://example.com/v1", "model": "nvolveqa_40k", "prefix": "prefix", "suffix": "suffix", @@ -100,7 +101,7 @@ def from_dict(self, monkeypatch): } component = NvidiaTextEmbedder.from_dict(data) assert component.model == "nvolveqa_40k" - assert component.api_url == "https://example.com" + assert component.api_url == "https://example.com/v1" assert component.prefix == "prefix" assert component.suffix == "suffix" assert component.truncate == "START" diff --git a/integrations/opensearch/CHANGELOG.md b/integrations/opensearch/CHANGELOG.md index dd1ddb86e..6509d1e0f 100644 --- a/integrations/opensearch/CHANGELOG.md +++ b/integrations/opensearch/CHANGELOG.md @@ -1,5 +1,11 @@ # Changelog +## [integrations/opensearch-v0.9.0] - 2024-08-01 + +### 🚀 Features + +- Support aws authentication with OpenSearchDocumentStore (#920) + ## [integrations/opensearch-v0.8.1] - 2024-07-15 ### 🚀 Features diff --git a/integrations/opensearch/pyproject.toml b/integrations/opensearch/pyproject.toml index aed34d503..842b46415 100644 --- a/integrations/opensearch/pyproject.toml +++ b/integrations/opensearch/pyproject.toml @@ -47,6 +47,7 @@ dependencies = [ "pytest-rerunfailures", "pytest-xdist", "haystack-pydoc-tools", + "boto3", ] [tool.hatch.envs.default.scripts] test = "pytest --reruns 3 --reruns-delay 30 -x {args:tests}" @@ -61,7 +62,7 @@ python = ["3.8", "3.9", "3.10", "3.11"] [tool.hatch.envs.lint] detached = true -dependencies = ["black>=23.1.0", "mypy>=1.0.0", "ruff>=0.0.243"] +dependencies = ["black>=23.1.0", "mypy>=1.0.0", "ruff>=0.0.243", "boto3"] [tool.hatch.envs.lint.scripts] typing = "mypy --install-types --non-interactive --explicit-package-bases {args:src/ tests}" style = ["ruff check {args:.}", "black --check --diff {args:.}"] @@ -154,5 +155,5 @@ minversion = "6.0" markers = ["unit: unit tests", "integration: integration tests"] [[tool.mypy.overrides]] -module = ["haystack.*", "haystack_integrations.*", "pytest.*", "opensearchpy.*"] +module = ["botocore.*", "boto3.*", "haystack.*", "haystack_integrations.*", "pytest.*", "opensearchpy.*"] ignore_missing_imports = true diff --git a/integrations/opensearch/src/haystack_integrations/document_stores/opensearch/auth.py b/integrations/opensearch/src/haystack_integrations/document_stores/opensearch/auth.py new file mode 100644 index 000000000..8249c16ca --- /dev/null +++ b/integrations/opensearch/src/haystack_integrations/document_stores/opensearch/auth.py @@ -0,0 +1,154 @@ +from dataclasses import dataclass, field, fields +from typing import Any, Dict, Optional + +from haystack import default_from_dict, default_to_dict +from haystack.document_stores.errors import DocumentStoreError +from haystack.lazy_imports import LazyImport +from haystack.utils.auth import Secret, deserialize_secrets_inplace +from opensearchpy import Urllib3AWSV4SignerAuth + +with LazyImport("Run 'pip install \"boto3\"' to install boto3.") as boto3_import: + import boto3 + from botocore.exceptions import BotoCoreError + + +AWS_CONFIGURATION_KEYS = [ + "aws_access_key_id", + "aws_secret_access_key", + "aws_session_token", + "aws_region_name", + "aws_profile_name", +] + + +class AWSConfigurationError(DocumentStoreError): + """Exception raised when AWS is not configured correctly""" + + +def _get_aws_session( + aws_access_key_id: Optional[str] = None, + aws_secret_access_key: Optional[str] = None, + aws_session_token: Optional[str] = None, + aws_region_name: Optional[str] = None, + aws_profile_name: Optional[str] = None, + **kwargs, +): + """ + Creates an AWS Session with the given parameters. + Checks if the provided AWS credentials are valid and can be used to connect to AWS. + + :param aws_access_key_id: AWS access key ID. + :param aws_secret_access_key: AWS secret access key. + :param aws_session_token: AWS session token. + :param aws_region_name: AWS region name. + :param aws_profile_name: AWS profile name. + :param kwargs: The kwargs passed down to the service client. Supported kwargs depend on the model chosen. + See https://docs.aws.amazon.com/bedrock/latest/userguide/model-parameters.html. + :raises AWSConfigurationError: If the provided AWS credentials are invalid. + :returns: The created AWS session. + """ + boto3_import.check() + try: + return boto3.Session( + aws_access_key_id=aws_access_key_id, + aws_secret_access_key=aws_secret_access_key, + aws_session_token=aws_session_token, + region_name=aws_region_name, + profile_name=aws_profile_name, + ) + except BotoCoreError as e: + provided_aws_config = {k: v for k, v in kwargs.items() if k in AWS_CONFIGURATION_KEYS} + msg = f"Failed to initialize the session with provided AWS credentials {provided_aws_config}" + raise AWSConfigurationError(msg) from e + + +@dataclass() +class AWSAuth: + """ + Auth credentials for AWS OpenSearch services. + + This class works as a thin wrapper around the `Urllib3AWSV4SignerAuth` class from the `opensearch-py` library. + It facilitates the creation of the `Urllib3AWSV4SignerAuth` by making use of Haystack secrets and taking care of + the necessary `Urllib3AWSV4SignerAuth` creation steps including boto3 Sessions and boto3 credentials. + """ + + aws_access_key_id: Optional[Secret] = field( + default_factory=lambda: Secret.from_env_var("AWS_ACCESS_KEY_ID", strict=False) + ) + aws_secret_access_key: Optional[Secret] = field( + default_factory=lambda: Secret.from_env_var("AWS_SECRET_ACCESS_KEY", strict=False) + ) + aws_session_token: Optional[Secret] = field( + default_factory=lambda: Secret.from_env_var("AWS_SESSION_TOKEN", strict=False) + ) + aws_region_name: Optional[Secret] = field( + default_factory=lambda: Secret.from_env_var("AWS_DEFAULT_REGION", strict=False) + ) + aws_profile_name: Optional[Secret] = field(default_factory=lambda: Secret.from_env_var("AWS_PROFILE", strict=False)) + aws_service: str = field(default="es") + + def __post_init__(self) -> None: + """ + Initializes the AWSAuth object. + """ + self._urllib3_aws_v4_signer_auth = self._get_urllib3_aws_v4_signer_auth() + + def to_dict(self) -> Dict[str, Any]: + """ + Converts the object to a dictionary representation for serialization. + """ + _fields = {} + for _field in fields(self): + field_value = getattr(self, _field.name) + if _field.type == Optional[Secret]: + _fields[_field.name] = field_value.to_dict() if field_value is not None else None + else: + _fields[_field.name] = field_value + + return default_to_dict(self, **_fields) + + @classmethod + def from_dict(cls, data: Dict[str, Any]) -> Optional["AWSAuth"]: + """ + Converts a dictionary representation to an AWSAuth object. + """ + init_parameters = data.get("init_parameters", {}) + deserialize_secrets_inplace( + init_parameters, + ["aws_access_key_id", "aws_secret_access_key", "aws_session_token", "aws_region_name", "aws_profile_name"], + ) + return default_from_dict(cls, data) + + def __call__(self, method: str, url: str, body: Any) -> Dict[str, str]: + """ + Signs the request and returns headers. + + This method is executed by Urllib3 when making a request to the OpenSearch service. + + :param method: HTTP method + :param url: URL + :param body: Body + """ + return self._urllib3_aws_v4_signer_auth(method, url, body) + + def _get_urllib3_aws_v4_signer_auth(self) -> Urllib3AWSV4SignerAuth: + def resolve_secret(secret: Optional[Secret]) -> Optional[str]: + return secret.resolve_value() if secret else None + + try: + region_name = resolve_secret(self.aws_region_name) + session = _get_aws_session( + aws_access_key_id=resolve_secret(self.aws_access_key_id), + aws_secret_access_key=resolve_secret(self.aws_secret_access_key), + aws_session_token=resolve_secret(self.aws_session_token), + aws_region_name=region_name, + aws_profile_name=resolve_secret(self.aws_profile_name), + ) + credentials = session.get_credentials() + return Urllib3AWSV4SignerAuth(credentials, region_name, self.aws_service) + except Exception as exception: + msg = ( + "Could not connect to AWS OpenSearch. Make sure the AWS environment is configured correctly. " + "See https://boto3.amazonaws.com/v1/documentation/api/latest/guide/quickstart.html#configuration" + ) + raise AWSConfigurationError(msg) from exception diff --git a/integrations/opensearch/src/haystack_integrations/document_stores/opensearch/document_store.py b/integrations/opensearch/src/haystack_integrations/document_stores/opensearch/document_store.py index 929e24b17..465897608 100644 --- a/integrations/opensearch/src/haystack_integrations/document_stores/opensearch/document_store.py +++ b/integrations/opensearch/src/haystack_integrations/document_stores/opensearch/document_store.py @@ -10,6 +10,7 @@ from haystack.document_stores.errors import DocumentStoreError, DuplicateDocumentError from haystack.document_stores.types import DuplicatePolicy from haystack.utils.filters import convert +from haystack_integrations.document_stores.opensearch.auth import AWSAuth from haystack_integrations.document_stores.opensearch.filters import normalize_filters from opensearchpy import OpenSearch from opensearchpy.helpers import bulk @@ -44,6 +45,10 @@ def __init__( mappings: Optional[Dict[str, Any]] = None, settings: Optional[Dict[str, Any]] = DEFAULT_SETTINGS, create_index: bool = True, + http_auth: Any = None, + use_ssl: Optional[bool] = None, + verify_certs: Optional[bool] = None, + timeout: Optional[int] = None, **kwargs, ): """ @@ -69,6 +74,16 @@ def __init__( :param settings: The settings of the index to be created. Please see the [official OpenSearch docs](https://opensearch.org/docs/latest/search-plugins/knn/knn-index/#index-settings) for more information. Defaults to {"index.knn": True} :param create_index: Whether to create the index if it doesn't exist. Defaults to True + :param http_auth: http_auth param passed to the underying connection class. + For basic authentication with default connection class `Urllib3HttpConnection` this can be + - a tuple of (username, password) + - a list of [username, password] + - a string of "username:password" + For AWS authentication with `Urllib3HttpConnection` pass an instance of `AWSAuth`. + Defaults to None + :param use_ssl: Whether to use SSL. Defaults to None + :param verify_certs: Whether to verify certificates. Defaults to None + :param timeout: Timeout in seconds. Defaults to None :param **kwargs: Optional arguments that ``OpenSearch`` takes. For the full list of supported kwargs, see the [official OpenSearch reference](https://opensearch-project.github.io/opensearch-py/api-ref/clients/opensearch_client.html) """ @@ -82,6 +97,10 @@ def __init__( self._mappings = mappings or self._get_default_mappings() self._settings = settings self._create_index = create_index + self._http_auth = http_auth + self._use_ssl = use_ssl + self._verify_certs = verify_certs + self._timeout = timeout self._kwargs = kwargs def _get_default_mappings(self) -> Dict[str, Any]: @@ -106,9 +125,14 @@ def _get_default_mappings(self) -> Dict[str, Any]: @property def client(self) -> OpenSearch: if not self._client: - self._client = OpenSearch(self._hosts, **self._kwargs) - # Check client connection, this will raise if not connected - self._client.info() # type:ignore + self._client = OpenSearch( + hosts=self._hosts, + http_auth=self._http_auth, + use_ssl=self._use_ssl, + verify_certs=self._verify_certs, + timeout=self._timeout, + **self._kwargs, + ) if self._client.indices.exists(index=self._index): # type:ignore logger.debug( @@ -170,6 +194,10 @@ def to_dict(self) -> Dict[str, Any]: settings=self._settings, create_index=self._create_index, return_embedding=self._return_embedding, + http_auth=self._http_auth.to_dict() if isinstance(self._http_auth, AWSAuth) else self._http_auth, + use_ssl=self._use_ssl, + verify_certs=self._verify_certs, + timeout=self._timeout, **self._kwargs, ) @@ -184,6 +212,10 @@ def from_dict(cls, data: Dict[str, Any]) -> "OpenSearchDocumentStore": :returns: Deserialized component. """ + if http_auth := data.get("init_parameters", {}).get("http_auth"): + if isinstance(http_auth, dict): + data["init_parameters"]["http_auth"] = AWSAuth.from_dict(http_auth) + return default_from_dict(cls, data) def count_documents(self) -> int: diff --git a/integrations/opensearch/tests/test_auth.py b/integrations/opensearch/tests/test_auth.py new file mode 100644 index 000000000..25bda7d66 --- /dev/null +++ b/integrations/opensearch/tests/test_auth.py @@ -0,0 +1,113 @@ +from unittest.mock import Mock, patch + +import pytest +from haystack_integrations.document_stores.opensearch.auth import AWSAuth +from opensearchpy import Urllib3AWSV4SignerAuth + + +class TestAWSAuth: + @pytest.fixture(autouse=True) + def mock_boto3_session(self): + with patch("boto3.Session") as mock_client: + yield mock_client + + @pytest.fixture(autouse=True) + def set_aws_env_variables(self, monkeypatch): + monkeypatch.setenv("AWS_ACCESS_KEY_ID", "some_fake_id") + monkeypatch.setenv("AWS_SECRET_ACCESS_KEY", "some_fake_key") + monkeypatch.setenv("AWS_SESSION_TOKEN", "some_fake_token") + monkeypatch.setenv("AWS_DEFAULT_REGION", "fake_region") + monkeypatch.setenv("AWS_PROFILE", "some_fake_profile") + + def test_init(self, mock_boto3_session): + aws_auth = AWSAuth() + assert isinstance(aws_auth._urllib3_aws_v4_signer_auth, Urllib3AWSV4SignerAuth) + mock_boto3_session.assert_called_with( + aws_access_key_id="some_fake_id", + aws_secret_access_key="some_fake_key", + aws_session_token="some_fake_token", + profile_name="some_fake_profile", + region_name="fake_region", + ) + + def test_to_dict(self): + aws_auth = AWSAuth() + res = aws_auth.to_dict() + assert res == { + "type": "haystack_integrations.document_stores.opensearch.auth.AWSAuth", + "init_parameters": { + "aws_access_key_id": {"type": "env_var", "env_vars": ["AWS_ACCESS_KEY_ID"], "strict": False}, + "aws_secret_access_key": { + "type": "env_var", + "env_vars": ["AWS_SECRET_ACCESS_KEY"], + "strict": False, + }, + "aws_session_token": {"type": "env_var", "env_vars": ["AWS_SESSION_TOKEN"], "strict": False}, + "aws_region_name": {"type": "env_var", "env_vars": ["AWS_DEFAULT_REGION"], "strict": False}, + "aws_profile_name": {"type": "env_var", "env_vars": ["AWS_PROFILE"], "strict": False}, + "aws_service": "es", + }, + } + + def test_from_dict(self): + data = { + "type": "haystack_integrations.document_stores.opensearch.auth.AWSAuth", + "init_parameters": { + "aws_access_key_id": {"type": "env_var", "env_vars": ["AWS_ACCESS_KEY_ID"], "strict": False}, + "aws_secret_access_key": { + "type": "env_var", + "env_vars": ["AWS_SECRET_ACCESS_KEY"], + "strict": False, + }, + "aws_session_token": {"type": "env_var", "env_vars": ["AWS_SESSION_TOKEN"], "strict": False}, + "aws_region_name": {"type": "env_var", "env_vars": ["AWS_DEFAULT_REGION"], "strict": False}, + "aws_profile_name": {"type": "env_var", "env_vars": ["AWS_PROFILE"], "strict": False}, + "aws_service": "es", + }, + } + aws_auth = AWSAuth.from_dict(data) + assert aws_auth.aws_access_key_id.resolve_value() == "some_fake_id" + assert aws_auth.aws_secret_access_key.resolve_value() == "some_fake_key" + assert aws_auth.aws_session_token.resolve_value() == "some_fake_token" + assert aws_auth.aws_region_name.resolve_value() == "fake_region" + assert aws_auth.aws_profile_name.resolve_value() == "some_fake_profile" + assert aws_auth.aws_service == "es" + assert isinstance(aws_auth._urllib3_aws_v4_signer_auth, Urllib3AWSV4SignerAuth) + + def test_from_dict_no_init_parameters(self): + data = {"type": "haystack_integrations.document_stores.opensearch.auth.AWSAuth"} + aws_auth = AWSAuth.from_dict(data) + assert aws_auth.aws_access_key_id.resolve_value() == "some_fake_id" + assert aws_auth.aws_secret_access_key.resolve_value() == "some_fake_key" + assert aws_auth.aws_session_token.resolve_value() == "some_fake_token" + assert aws_auth.aws_region_name.resolve_value() == "fake_region" + assert aws_auth.aws_profile_name.resolve_value() == "some_fake_profile" + assert aws_auth.aws_service == "es" + assert isinstance(aws_auth._urllib3_aws_v4_signer_auth, Urllib3AWSV4SignerAuth) + + def test_from_dict_disable_env_variables(self): + data = { + "type": "haystack_integrations.document_stores.opensearch.auth.AWSAuth", + "init_parameters": { + "aws_access_key_id": None, + "aws_secret_access_key": None, + "aws_session_token": None, + "aws_service": "aoss", + }, + } + aws_auth = AWSAuth.from_dict(data) + assert aws_auth.aws_access_key_id is None + assert aws_auth.aws_secret_access_key is None + assert aws_auth.aws_session_token is None + assert aws_auth.aws_region_name.resolve_value() == "fake_region" + assert aws_auth.aws_profile_name.resolve_value() == "some_fake_profile" + assert aws_auth.aws_service == "aoss" + assert isinstance(aws_auth._urllib3_aws_v4_signer_auth, Urllib3AWSV4SignerAuth) + + @patch("haystack_integrations.document_stores.opensearch.auth.AWSAuth._get_urllib3_aws_v4_signer_auth") + def test_call(self, _get_urllib3_aws_v4_signer_auth_mock): + signer_auth_mock = Mock(spec=Urllib3AWSV4SignerAuth) + _get_urllib3_aws_v4_signer_auth_mock.return_value = signer_auth_mock + aws_auth = AWSAuth() + aws_auth(method="GET", url="http://some.url", body="some body") + signer_auth_mock.assert_called_once_with("GET", "http://some.url", "some body") diff --git a/integrations/opensearch/tests/test_bm25_retriever.py b/integrations/opensearch/tests/test_bm25_retriever.py index 1cce2961c..c015d360a 100644 --- a/integrations/opensearch/tests/test_bm25_retriever.py +++ b/integrations/opensearch/tests/test_bm25_retriever.py @@ -54,6 +54,10 @@ def test_to_dict(_mock_opensearch_client): "settings": {"index.knn": True}, "return_embedding": False, "create_index": True, + "http_auth": None, + "use_ssl": None, + "verify_certs": None, + "timeout": None, }, "type": "haystack_integrations.document_stores.opensearch.document_store.OpenSearchDocumentStore", }, diff --git a/integrations/opensearch/tests/test_document_store.py b/integrations/opensearch/tests/test_document_store.py index f41cf071e..287c24f63 100644 --- a/integrations/opensearch/tests/test_document_store.py +++ b/integrations/opensearch/tests/test_document_store.py @@ -10,7 +10,9 @@ from haystack.document_stores.errors import DocumentStoreError, DuplicateDocumentError from haystack.document_stores.types import DuplicatePolicy from haystack.testing.document_store import DocumentStoreBaseTests +from haystack.utils.auth import Secret from haystack_integrations.document_stores.opensearch import OpenSearchDocumentStore +from haystack_integrations.document_stores.opensearch.auth import AWSAuth from haystack_integrations.document_stores.opensearch.document_store import DEFAULT_MAX_CHUNK_BYTES from opensearchpy.exceptions import RequestError @@ -37,6 +39,10 @@ def test_to_dict(_mock_opensearch_client): "settings": {"index.knn": True}, "return_embedding": False, "create_index": True, + "http_auth": None, + "use_ssl": None, + "verify_certs": None, + "timeout": None, }, } @@ -52,6 +58,11 @@ def test_from_dict(_mock_opensearch_client): "embedding_dim": 1536, "create_index": False, "return_embedding": True, + "aws_service": "es", + "http_auth": ("admin", "admin"), + "use_ssl": True, + "verify_certs": True, + "timeout": 60, }, } document_store = OpenSearchDocumentStore.from_dict(data) @@ -77,6 +88,10 @@ def test_from_dict(_mock_opensearch_client): assert document_store._settings == {"index.knn": True} assert document_store._return_embedding is True assert document_store._create_index is False + assert document_store._http_auth == ("admin", "admin") + assert document_store._use_ssl is True + assert document_store._verify_certs is True + assert document_store._timeout == 60 @patch("haystack_integrations.document_stores.opensearch.document_store.OpenSearch") @@ -96,6 +111,158 @@ def test_get_default_mappings(_mock_opensearch_client): } +class TestAuth: + @pytest.fixture(autouse=True) + def mock_boto3_session(self): + with patch("boto3.Session") as mock_client: + yield mock_client + + @patch("haystack_integrations.document_stores.opensearch.document_store.OpenSearch") + def test_init_with_basic_auth(self, _mock_opensearch_client): + document_store = OpenSearchDocumentStore(hosts="testhost", http_auth=("user", "pw")) + assert document_store.client + _mock_opensearch_client.assert_called_once() + assert _mock_opensearch_client.call_args[1]["http_auth"] == ("user", "pw") + + @patch("haystack_integrations.document_stores.opensearch.document_store.OpenSearch") + def test_init_without_auth(self, _mock_opensearch_client): + document_store = OpenSearchDocumentStore(hosts="testhost") + assert document_store.client + _mock_opensearch_client.assert_called_once() + assert _mock_opensearch_client.call_args[1]["http_auth"] is None + + @patch("haystack_integrations.document_stores.opensearch.document_store.OpenSearch") + def test_init_aws_auth(self, _mock_opensearch_client): + document_store = OpenSearchDocumentStore( + hosts="testhost", + http_auth=AWSAuth(aws_region_name=Secret.from_token("dummy-region")), + use_ssl=True, + verify_certs=True, + ) + assert document_store.client + _mock_opensearch_client.assert_called_once() + assert isinstance(_mock_opensearch_client.call_args[1]["http_auth"], AWSAuth) + assert _mock_opensearch_client.call_args[1]["use_ssl"] is True + assert _mock_opensearch_client.call_args[1]["verify_certs"] is True + + @patch("haystack_integrations.document_stores.opensearch.document_store.OpenSearch") + def test_from_dict_basic_auth(self, _mock_opensearch_client): + document_store = OpenSearchDocumentStore.from_dict( + { + "type": "haystack_integrations.document_stores.opensearch.document_store.OpenSearchDocumentStore", + "init_parameters": { + "hosts": "testhost", + "http_auth": ["user", "pw"], + "use_ssl": True, + "verify_certs": True, + }, + } + ) + assert document_store.client + _mock_opensearch_client.assert_called_once() + assert _mock_opensearch_client.call_args[1]["http_auth"] == ["user", "pw"] + + @patch("haystack_integrations.document_stores.opensearch.document_store.OpenSearch") + def test_from_dict_aws_auth(self, _mock_opensearch_client, monkeypatch: pytest.MonkeyPatch): + monkeypatch.setenv("AWS_DEFAULT_REGION", "dummy-region") + document_store = OpenSearchDocumentStore.from_dict( + { + "type": "haystack_integrations.document_stores.opensearch.document_store.OpenSearchDocumentStore", + "init_parameters": { + "hosts": "testhost", + "http_auth": { + "type": "haystack_integrations.document_stores.opensearch.auth.AWSAuth", + "init_parameters": {}, + }, + "use_ssl": True, + "verify_certs": True, + }, + } + ) + assert document_store.client + _mock_opensearch_client.assert_called_once() + assert isinstance(_mock_opensearch_client.call_args[1]["http_auth"], AWSAuth) + assert _mock_opensearch_client.call_args[1]["use_ssl"] is True + assert _mock_opensearch_client.call_args[1]["verify_certs"] is True + + @patch("haystack_integrations.document_stores.opensearch.document_store.OpenSearch") + def test_to_dict_basic_auth(self, _mock_opensearch_client): + document_store = OpenSearchDocumentStore(hosts="some hosts", http_auth=("user", "pw")) + res = document_store.to_dict() + assert res == { + "type": "haystack_integrations.document_stores.opensearch.document_store.OpenSearchDocumentStore", + "init_parameters": { + "embedding_dim": 768, + "hosts": "some hosts", + "index": "default", + "mappings": { + "dynamic_templates": [ + {"strings": {"mapping": {"type": "keyword"}, "match_mapping_type": "string"}} + ], + "properties": { + "content": {"type": "text"}, + "embedding": {"dimension": 768, "index": True, "type": "knn_vector"}, + }, + }, + "max_chunk_bytes": DEFAULT_MAX_CHUNK_BYTES, + "method": None, + "settings": {"index.knn": True}, + "return_embedding": False, + "create_index": True, + "http_auth": ("user", "pw"), + "use_ssl": None, + "verify_certs": None, + "timeout": None, + }, + } + + @patch("haystack_integrations.document_stores.opensearch.document_store.OpenSearch") + def test_to_dict_aws_auth(self, _mock_opensearch_client, monkeypatch: pytest.MonkeyPatch): + monkeypatch.setenv("AWS_DEFAULT_REGION", "dummy-region") + document_store = OpenSearchDocumentStore(hosts="some hosts", http_auth=AWSAuth()) + res = document_store.to_dict() + assert res == { + "type": "haystack_integrations.document_stores.opensearch.document_store.OpenSearchDocumentStore", + "init_parameters": { + "embedding_dim": 768, + "hosts": "some hosts", + "index": "default", + "mappings": { + "dynamic_templates": [ + {"strings": {"mapping": {"type": "keyword"}, "match_mapping_type": "string"}} + ], + "properties": { + "content": {"type": "text"}, + "embedding": {"dimension": 768, "index": True, "type": "knn_vector"}, + }, + }, + "max_chunk_bytes": DEFAULT_MAX_CHUNK_BYTES, + "method": None, + "settings": {"index.knn": True}, + "return_embedding": False, + "create_index": True, + "http_auth": { + "type": "haystack_integrations.document_stores.opensearch.auth.AWSAuth", + "init_parameters": { + "aws_access_key_id": {"type": "env_var", "env_vars": ["AWS_ACCESS_KEY_ID"], "strict": False}, + "aws_secret_access_key": { + "type": "env_var", + "env_vars": ["AWS_SECRET_ACCESS_KEY"], + "strict": False, + }, + "aws_session_token": {"type": "env_var", "env_vars": ["AWS_SESSION_TOKEN"], "strict": False}, + "aws_region_name": {"type": "env_var", "env_vars": ["AWS_DEFAULT_REGION"], "strict": False}, + "aws_profile_name": {"type": "env_var", "env_vars": ["AWS_PROFILE"], "strict": False}, + "aws_service": "es", + }, + }, + "use_ssl": None, + "verify_certs": None, + "timeout": None, + }, + } + + @pytest.mark.integration class TestDocumentStore(DocumentStoreBaseTests): """ diff --git a/integrations/opensearch/tests/test_embedding_retriever.py b/integrations/opensearch/tests/test_embedding_retriever.py index 38be08698..e52a099c8 100644 --- a/integrations/opensearch/tests/test_embedding_retriever.py +++ b/integrations/opensearch/tests/test_embedding_retriever.py @@ -69,6 +69,10 @@ def test_to_dict(_mock_opensearch_client): }, "return_embedding": False, "create_index": True, + "http_auth": None, + "use_ssl": None, + "verify_certs": None, + "timeout": None, }, "type": "haystack_integrations.document_stores.opensearch.document_store.OpenSearchDocumentStore", },