diff --git a/integrations/azure_ai_search/src/haystack_integrations/document_stores/azure_ai_search/document_store.py b/integrations/azure_ai_search/src/haystack_integrations/document_stores/azure_ai_search/document_store.py index 411ae6288..625112fb0 100644 --- a/integrations/azure_ai_search/src/haystack_integrations/document_stores/azure_ai_search/document_store.py +++ b/integrations/azure_ai_search/src/haystack_integrations/document_stores/azure_ai_search/document_store.py @@ -70,7 +70,7 @@ def __init__( api_key: Secret = Secret.from_env_var("AZURE_SEARCH_API_KEY", strict=False), # noqa: B008 azure_endpoint: Secret = Secret.from_env_var("AZURE_SEARCH_SERVICE_ENDPOINT", strict=False), # noqa: B008 index_name: str = "default", - embedding_dimension: int = 768, + embedding_dimension: Optional[int] = 768, metadata_fields: Optional[Dict[str, type]] = None, vector_search_configuration: VectorSearch = None, create_index: bool = True, @@ -104,6 +104,7 @@ def __init__( if not azure_endpoint: msg = "Please provide an Azure endpoint or set the environment variable AZURE_OPENAI_ENDPOINT." raise ValueError(msg) + api_key = api_key or os.environ.get("AZURE_SEARCH_API_KEY") self._client = None @@ -122,15 +123,16 @@ def __init__( @property def client(self) -> SearchClient: - if isinstance(self._azure_endpoint, Secret): - self._azure_endpoint = self._azure_endpoint.resolve_value() + # resolve secrets for authentication + resolved_endpoint = ( + self._azure_endpoint.resolve_value() if isinstance(self._azure_endpoint, Secret) else self._azure_endpoint + ) + resolved_key = self._api_key.resolve_value() if isinstance(self._api_key, Secret) else self._api_key - if isinstance(self._api_key, Secret): - self._api_key = self._api_key.resolve_value() - credential = AzureKeyCredential(self._api_key) if self._api_key else DefaultAzureCredential() + credential = AzureKeyCredential(resolved_key) if resolved_key else DefaultAzureCredential() try: if not self._index_client: - self._index_client = SearchIndexClient(self._azure_endpoint, credential, **self._kwargs) + self._index_client = SearchIndexClient(resolved_endpoint, credential, **self._kwargs) if not self.index_exists(self._index_name): # Create a new index if it does not exist logger.debug( @@ -202,7 +204,7 @@ def to_dict(self) -> Dict[str, Any]: create_index=self._create_index, embedding_dimension=self._embedding_dimension, metadata_fields=self._metadata_fields, - vector_search_configuration=self._vector_search_configuration, + vector_search_configuration=self._vector_search_configuration.as_dict(), **self._kwargs, ) @@ -219,6 +221,8 @@ def from_dict(cls, data: Dict[str, Any]) -> "AzureAISearchDocumentStore": """ deserialize_secrets_inplace(data["init_parameters"], keys=["api_key", "azure_endpoint"]) + if (vector_search_configuration := data["init_parameters"].get("vector_search_configuration")) is not None: + data["init_parameters"]["vector_search_configuration"] = VectorSearch.from_dict(vector_search_configuration) return default_from_dict(cls, data) def count_documents(self, **kwargs: Any) -> int: diff --git a/integrations/azure_ai_search/tests/test_document_store.py b/integrations/azure_ai_search/tests/test_document_store.py index 754a8c0d0..50733a8a4 100644 --- a/integrations/azure_ai_search/tests/test_document_store.py +++ b/integrations/azure_ai_search/tests/test_document_store.py @@ -36,7 +36,18 @@ def test_to_dict(monkeypatch): "embedding_dimension": 768, "metadata_fields": None, "create_index": True, - "vector_search_configuration": DEFAULT_VECTOR_SEARCH, + "vector_search_configuration": { + "profiles": [ + {"name": "default-vector-config", "algorithm_configuration_name": "cosine-algorithm-config"} + ], + "algorithms": [ + { + "name": "cosine-algorithm-config", + "kind": "hnsw", + "parameters": {"m": 4, "ef_construction": 400, "ef_search": 500, "metric": "cosine"}, + } + ], + }, }, } diff --git a/integrations/azure_ai_search/tests/test_embedding_retriever.py b/integrations/azure_ai_search/tests/test_embedding_retriever.py index 4b0c92b99..af4b21478 100644 --- a/integrations/azure_ai_search/tests/test_embedding_retriever.py +++ b/integrations/azure_ai_search/tests/test_embedding_retriever.py @@ -52,7 +52,18 @@ def test_to_dict(): "create_index": True, "embedding_dimension": 768, "metadata_fields": None, - "vector_search_configuration": DEFAULT_VECTOR_SEARCH, + "vector_search_configuration": { + "profiles": [ + {"name": "default-vector-config", "algorithm_configuration_name": "cosine-algorithm-config"} + ], + "algorithms": [ + { + "name": "cosine-algorithm-config", + "kind": "hnsw", + "parameters": {"m": 4, "ef_construction": 400, "ef_search": 500, "metric": "cosine"}, + } + ], + }, "hosts": "some fake host", }, },