From 2667d6bc6077a8db176b257bf178a1df45153214 Mon Sep 17 00:00:00 2001 From: Anushree Bannadabhavi Date: Fri, 24 May 2024 05:40:41 -0400 Subject: [PATCH] fix: add support for custom mapping in ElasticsearchDocumentStore (#721) * Add custom mapping in ElasticsearchDocumentStore init * Update docstrings and add test * Fix linting * Fix retrievers tests --------- Co-authored-by: Silvano Cerza --- .../elasticsearch/document_store.py | 45 ++++++++++++------- .../tests/test_bm25_retriever.py | 1 + .../tests/test_document_store.py | 35 ++++++++++++++- .../tests/test_embedding_retriever.py | 1 + 4 files changed, 64 insertions(+), 18 deletions(-) diff --git a/integrations/elasticsearch/src/haystack_integrations/document_stores/elasticsearch/document_store.py b/integrations/elasticsearch/src/haystack_integrations/document_stores/elasticsearch/document_store.py index 12407a3dd..75af3df56 100644 --- a/integrations/elasticsearch/src/haystack_integrations/document_stores/elasticsearch/document_store.py +++ b/integrations/elasticsearch/src/haystack_integrations/document_stores/elasticsearch/document_store.py @@ -63,6 +63,7 @@ def __init__( self, *, hosts: Optional[Hosts] = None, + custom_mapping: Optional[Dict[str, Any]] = None, index: str = "default", embedding_similarity_function: Literal["cosine", "dot_product", "l2_norm", "max_inner_product"] = "cosine", **kwargs, @@ -82,6 +83,7 @@ def __init__( [reference](https://elasticsearch-py.readthedocs.io/en/stable/api.html#module-elasticsearch) :param hosts: List of hosts running the Elasticsearch client. + :param custom_mapping: Custom mapping for the index. If not provided, a default mapping will be used. :param index: Name of index in Elasticsearch. :param embedding_similarity_function: The similarity function used to compare Documents embeddings. This parameter only takes effect if the index does not yet exist and is created. @@ -98,29 +100,37 @@ def __init__( ) self._index = index self._embedding_similarity_function = embedding_similarity_function + self._custom_mapping = custom_mapping self._kwargs = kwargs # Check client connection, this will raise if not connected self._client.info() - # configure mapping for the embedding field - mappings = { - "properties": { - "embedding": {"type": "dense_vector", "index": True, "similarity": embedding_similarity_function}, - "content": {"type": "text"}, - }, - "dynamic_templates": [ - { - "strings": { - "path_match": "*", - "match_mapping_type": "string", - "mapping": { - "type": "keyword", - }, + if self._custom_mapping and not isinstance(self._custom_mapping, Dict): + msg = "custom_mapping must be a dictionary" + raise ValueError(msg) + + if self._custom_mapping: + mappings = self._custom_mapping + else: + # Configure mapping for the embedding field if none is provided + mappings = { + "properties": { + "embedding": {"type": "dense_vector", "index": True, "similarity": embedding_similarity_function}, + "content": {"type": "text"}, + }, + "dynamic_templates": [ + { + "strings": { + "path_match": "*", + "match_mapping_type": "string", + "mapping": { + "type": "keyword", + }, + } } - } - ], - } + ], + } # Create the index if it doesn't exist if not self._client.indices.exists(index=index): @@ -139,6 +149,7 @@ def to_dict(self) -> Dict[str, Any]: return default_to_dict( self, hosts=self._hosts, + custom_mapping=self._custom_mapping, index=self._index, embedding_similarity_function=self._embedding_similarity_function, **self._kwargs, diff --git a/integrations/elasticsearch/tests/test_bm25_retriever.py b/integrations/elasticsearch/tests/test_bm25_retriever.py index dd88cd0a8..eebb39183 100644 --- a/integrations/elasticsearch/tests/test_bm25_retriever.py +++ b/integrations/elasticsearch/tests/test_bm25_retriever.py @@ -28,6 +28,7 @@ def test_to_dict(_mock_elasticsearch_client): "document_store": { "init_parameters": { "hosts": "some fake host", + "custom_mapping": None, "index": "default", "embedding_similarity_function": "cosine", }, diff --git a/integrations/elasticsearch/tests/test_document_store.py b/integrations/elasticsearch/tests/test_document_store.py index 308486a78..da33dfc91 100644 --- a/integrations/elasticsearch/tests/test_document_store.py +++ b/integrations/elasticsearch/tests/test_document_store.py @@ -4,7 +4,7 @@ import random from typing import List -from unittest.mock import patch +from unittest.mock import Mock, patch import pytest from elasticsearch.exceptions import BadRequestError # type: ignore[import-not-found] @@ -23,6 +23,7 @@ def test_to_dict(_mock_elasticsearch_client): "type": "haystack_integrations.document_stores.elasticsearch.document_store.ElasticsearchDocumentStore", "init_parameters": { "hosts": "some hosts", + "custom_mapping": None, "index": "default", "embedding_similarity_function": "cosine", }, @@ -35,6 +36,7 @@ def test_from_dict(_mock_elasticsearch_client): "type": "haystack_integrations.document_stores.elasticsearch.document_store.ElasticsearchDocumentStore", "init_parameters": { "hosts": "some hosts", + "custom_mapping": None, "index": "default", "embedding_similarity_function": "cosine", }, @@ -42,6 +44,7 @@ def test_from_dict(_mock_elasticsearch_client): document_store = ElasticsearchDocumentStore.from_dict(data) assert document_store._hosts == "some hosts" assert document_store._index == "default" + assert document_store._custom_mapping is None assert document_store._embedding_similarity_function == "cosine" @@ -280,3 +283,33 @@ def test_write_documents_different_embedding_sizes_fail(self, document_store: El with pytest.raises(DocumentStoreError): document_store.write_documents(docs) + + @patch("haystack_integrations.document_stores.elasticsearch.document_store.Elasticsearch") + def test_init_with_custom_mapping(self, mock_elasticsearch): + custom_mapping = { + "properties": { + "embedding": {"type": "dense_vector", "index": True, "similarity": "dot_product"}, + "content": {"type": "text"}, + }, + "dynamic_templates": [ + { + "strings": { + "path_match": "*", + "match_mapping_type": "string", + "mapping": { + "type": "keyword", + }, + } + } + ], + } + mock_client = Mock( + indices=Mock(create=Mock(), exists=Mock(return_value=False)), + ) + mock_elasticsearch.return_value = mock_client + + ElasticsearchDocumentStore(hosts="some hosts", custom_mapping=custom_mapping) + mock_client.indices.create.assert_called_once_with( + index="default", + mappings=custom_mapping, + ) diff --git a/integrations/elasticsearch/tests/test_embedding_retriever.py b/integrations/elasticsearch/tests/test_embedding_retriever.py index f632c3655..ab63799e4 100644 --- a/integrations/elasticsearch/tests/test_embedding_retriever.py +++ b/integrations/elasticsearch/tests/test_embedding_retriever.py @@ -29,6 +29,7 @@ def test_to_dict(_mock_elasticsearch_client): "document_store": { "init_parameters": { "hosts": "some fake host", + "custom_mapping": None, "index": "default", "embedding_similarity_function": "cosine", },