From 23d1e2298ab620b421e063ec359f71ea6723b365 Mon Sep 17 00:00:00 2001 From: Amna Mubashar Date: Tue, 29 Oct 2024 13:54:12 +0100 Subject: [PATCH] Add a check for index schema --- .../azure_ai_search/example/document_store.py | 3 +- .../azure_ai_search/document_store.py | 39 ++++++++++--------- 2 files changed, 21 insertions(+), 21 deletions(-) diff --git a/integrations/azure_ai_search/example/document_store.py b/integrations/azure_ai_search/example/document_store.py index b3a87c64a..dfd3c8186 100644 --- a/integrations/azure_ai_search/example/document_store.py +++ b/integrations/azure_ai_search/example/document_store.py @@ -41,5 +41,4 @@ } results = document_store.filter_documents(filters) -for doc in results: - print(doc) +print(results) diff --git a/integrations/azure_ai_search/src/haystack_integrations/document_stores/azure_ai_search/document_store.py b/integrations/azure_ai_search/src/haystack_integrations/document_stores/azure_ai_search/document_store.py index 625112fb0..777efe20d 100644 --- a/integrations/azure_ai_search/src/haystack_integrations/document_stores/azure_ai_search/document_store.py +++ b/integrations/azure_ai_search/src/haystack_integrations/document_stores/azure_ai_search/document_store.py @@ -70,7 +70,7 @@ def __init__( api_key: Secret = Secret.from_env_var("AZURE_SEARCH_API_KEY", strict=False), # noqa: B008 azure_endpoint: Secret = Secret.from_env_var("AZURE_SEARCH_SERVICE_ENDPOINT", strict=False), # noqa: B008 index_name: str = "default", - embedding_dimension: Optional[int] = 768, + embedding_dimension: int = 768, metadata_fields: Optional[Dict[str, type]] = None, vector_search_configuration: VectorSearch = None, create_index: bool = True, @@ -100,12 +100,12 @@ def __init__( For more information on parameters, see the [official Azure AI Search documentation](https://learn.microsoft.com/en-us/azure/search/) """ - azure_endpoint = azure_endpoint or os.environ.get("AZURE_SEARCH_SERVICE_ENDPOINT") + azure_endpoint = azure_endpoint or os.environ.get("AZURE_SEARCH_SERVICE_ENDPOINT") or None if not azure_endpoint: msg = "Please provide an Azure endpoint or set the environment variable AZURE_OPENAI_ENDPOINT." raise ValueError(msg) - api_key = api_key or os.environ.get("AZURE_SEARCH_API_KEY") + api_key = api_key or os.environ.get("AZURE_SEARCH_API_KEY") or None self._client = None self._index_client = None @@ -144,8 +144,10 @@ def client(self) -> SearchClient: msg = f"Failed to authenticate with Azure Search: {error}" raise AzureAISearchDocumentStoreConfigError(msg) from error - # Get the search client, if index client is initialized - if self._index_client: + if self._index_client: # type: ignore # self._index_client is not None (verified in the run method) + # Get the search client, if index client is initialized + index_fields = self._index_client.get_index(self._index_name).fields + self._index_fields = [field.name for field in index_fields] self._client = self._index_client.get_search_client(self._index_name) else: msg = "Search Index Client is not initialized." @@ -178,8 +180,6 @@ def create_index(self, index_name: str, **kwargs) -> None: index_name = self._index_name if self._metadata_fields: default_fields.extend(self._create_metadata_index_fields(self._metadata_fields)) - - self._index_fields = default_fields index = SearchIndex( name=index_name, fields=default_fields, vector_search=self._vector_search_configuration, **kwargs ) @@ -247,7 +247,7 @@ def _convert_input_document(documents: Document): if not isinstance(document_dict["id"], str): msg = f"Document id {document_dict['id']} is not a string, " raise Exception(msg) - index_document = self._default_index_mapping(document_dict) + index_document = self._convert_haystack_documents_to_azure(document_dict) return index_document @@ -324,17 +324,17 @@ def _convert_search_result_to_documents(self, azure_docs: List[Dict[str, Any]]) Converts Azure search results to Haystack Documents. """ documents = [] - for azure_doc in azure_docs: + for azure_doc in azure_docs: embedding = azure_doc.get("embedding") if embedding == self._dummy_vector: embedding = None - # Filter out meta fields + # Anything besides default fields (id, content, and embedding) is considered metadata meta = { key: value for key, value in azure_doc.items() - if key not in ["id", "content", "embedding"] and not key.startswith("@") and value is not None + if key not in ["id", "content", "embedding"] and key in self._index_fields and value is not None } # Create the document with meta only if it's non-empty @@ -375,14 +375,11 @@ def _get_raw_documents_by_id(self, document_ids: List[str]): logger.warning(f"Document with ID {doc_id} not found.") return azure_documents - def _default_index_mapping(self, document: Dict[str, Any]) -> Dict[str, Any]: + def _convert_haystack_documents_to_azure(self, document: Dict[str, Any]) -> Dict[str, Any]: """Map the document keys to fields of search index""" - keys_to_remove = ["dataframe", "blob", "sparse_embedding", "score"] - index_document = {k: v for k, v in document.items() if k not in keys_to_remove} - metadata = index_document.pop("meta", None) - for key, value in metadata.items(): - index_document[key] = value + # Because Azure Search does not allow dynamic fields, we only include fields that are part of the schema + index_document = {k: v for k, v in {**document, **document.get("meta", {})}.items() if k in self._index_fields} if index_document["embedding"] is None: index_document["embedding"] = self._dummy_vector @@ -405,11 +402,15 @@ def _map_metadata_field_types(self, metadata: Dict[str, type]) -> Dict[str, str] metadata_field_mapping = {} for key, value_type in metadata.items(): + + # Azure Search index only allows field names starting with letters + field_name = next((key[i:] for i, char in enumerate(key) if char.isalpha()), key) + field_type = type_mapping.get(value_type) if not field_type: - error_message = f"Unsupported field type for key '{key}': {value_type}" + error_message = f"Unsupported field type for key '{field_name}': {value_type}" raise ValueError(error_message) - metadata_field_mapping[key] = field_type + metadata_field_mapping[field_name] = field_type return metadata_field_mapping