Skip to content

Commit

Permalink
Add a check for index schema
Browse files: browse the repository at this point in the history
  • Loading branch information
Amnah199 committed Oct 29, 2024
1 parent b9563ba commit 23d1e22
Show file tree
Hide file tree
Showing 2 changed files with 21 additions and 21 deletions.
3 changes: 1 addition & 2 deletions integrations/azure_ai_search/example/document_store.py
Original file line number Diff line number Diff line change
Expand Up @@ -41,5 +41,4 @@
}

results = document_store.filter_documents(filters)
for doc in results:
print(doc)
print(results)
Original file line number Diff line number Diff line change
Expand Up @@ -70,7 +70,7 @@ def __init__(
api_key: Secret = Secret.from_env_var("AZURE_SEARCH_API_KEY", strict=False), # noqa: B008
azure_endpoint: Secret = Secret.from_env_var("AZURE_SEARCH_SERVICE_ENDPOINT", strict=False), # noqa: B008
index_name: str = "default",
embedding_dimension: Optional[int] = 768,
embedding_dimension: int = 768,
metadata_fields: Optional[Dict[str, type]] = None,
vector_search_configuration: VectorSearch = None,
create_index: bool = True,
Expand Down Expand Up @@ -100,12 +100,12 @@ def __init__(
For more information on parameters, see the [official Azure AI Search documentation](https://learn.microsoft.com/en-us/azure/search/)
"""

azure_endpoint = azure_endpoint or os.environ.get("AZURE_SEARCH_SERVICE_ENDPOINT")
azure_endpoint = azure_endpoint or os.environ.get("AZURE_SEARCH_SERVICE_ENDPOINT") or None
if not azure_endpoint:
msg = "Please provide an Azure endpoint or set the environment variable AZURE_OPENAI_ENDPOINT."
raise ValueError(msg)

api_key = api_key or os.environ.get("AZURE_SEARCH_API_KEY")
api_key = api_key or os.environ.get("AZURE_SEARCH_API_KEY") or None

self._client = None
self._index_client = None
Expand Down Expand Up @@ -144,8 +144,10 @@ def client(self) -> SearchClient:
msg = f"Failed to authenticate with Azure Search: {error}"
raise AzureAISearchDocumentStoreConfigError(msg) from error

# Get the search client, if index client is initialized
if self._index_client:
if self._index_client: # type: ignore # self._index_client is not None (verified in the run method)
# Get the search client, if index client is initialized
index_fields = self._index_client.get_index(self._index_name).fields
self._index_fields = [field.name for field in index_fields]
self._client = self._index_client.get_search_client(self._index_name)
else:
msg = "Search Index Client is not initialized."
Expand Down Expand Up @@ -178,8 +180,6 @@ def create_index(self, index_name: str, **kwargs) -> None:
index_name = self._index_name
if self._metadata_fields:
default_fields.extend(self._create_metadata_index_fields(self._metadata_fields))

self._index_fields = default_fields
index = SearchIndex(
name=index_name, fields=default_fields, vector_search=self._vector_search_configuration, **kwargs
)
Expand Down Expand Up @@ -247,7 +247,7 @@ def _convert_input_document(documents: Document):
if not isinstance(document_dict["id"], str):
msg = f"Document id {document_dict['id']} is not a string, "
raise Exception(msg)
index_document = self._default_index_mapping(document_dict)
index_document = self._convert_haystack_documents_to_azure(document_dict)

return index_document

Expand Down Expand Up @@ -324,17 +324,17 @@ def _convert_search_result_to_documents(self, azure_docs: List[Dict[str, Any]])
Converts Azure search results to Haystack Documents.
"""
documents = []
for azure_doc in azure_docs:

for azure_doc in azure_docs:
embedding = azure_doc.get("embedding")
if embedding == self._dummy_vector:
embedding = None

# Filter out meta fields
# Anything besides default fields (id, content, and embedding) is considered metadata
meta = {
key: value
for key, value in azure_doc.items()
if key not in ["id", "content", "embedding"] and not key.startswith("@") and value is not None
if key not in ["id", "content", "embedding"] and key in self._index_fields and value is not None
}

# Create the document with meta only if it's non-empty
Expand Down Expand Up @@ -375,14 +375,11 @@ def _get_raw_documents_by_id(self, document_ids: List[str]):
logger.warning(f"Document with ID {doc_id} not found.")
return azure_documents

def _default_index_mapping(self, document: Dict[str, Any]) -> Dict[str, Any]:
def _convert_haystack_documents_to_azure(self, document: Dict[str, Any]) -> Dict[str, Any]:
"""Map the document keys to fields of search index"""

keys_to_remove = ["dataframe", "blob", "sparse_embedding", "score"]
index_document = {k: v for k, v in document.items() if k not in keys_to_remove}
metadata = index_document.pop("meta", None)
for key, value in metadata.items():
index_document[key] = value
# Because Azure Search does not allow dynamic fields, we only include fields that are part of the schema
index_document = {k: v for k, v in {**document, **document.get("meta", {})}.items() if k in self._index_fields}
if index_document["embedding"] is None:
index_document["embedding"] = self._dummy_vector

Expand All @@ -405,11 +402,15 @@ def _map_metadata_field_types(self, metadata: Dict[str, type]) -> Dict[str, str]
metadata_field_mapping = {}

for key, value_type in metadata.items():

# Azure Search index only allows field names starting with letters
field_name = next((key[i:] for i, char in enumerate(key) if char.isalpha()), key)

field_type = type_mapping.get(value_type)
if not field_type:
error_message = f"Unsupported field type for key '{key}': {value_type}"
error_message = f"Unsupported field type for key '{field_name}': {value_type}"
raise ValueError(error_message)
metadata_field_mapping[key] = field_type
metadata_field_mapping[field_name] = field_type

return metadata_field_mapping

Expand Down

0 comments on commit 23d1e22

Please sign in to comment.