Skip to content

Commit

Permalink
fix: add support for custom mapping in ElasticsearchDocumentStore (#721)
Browse files Browse the repository at this point in the history
* Add custom mapping in ElasticsearchDocumentStore init

* Update docstrings and add test

* Fix linting

* Fix retrievers tests

---------

Co-authored-by: Silvano Cerza <[email protected]>
  • Loading branch information
AnushreeBannadabhavi and silvanocerza authored May 24, 2024
1 parent 95daee3 commit 2667d6b
Show file tree
Hide file tree
Showing 4 changed files with 64 additions and 18 deletions.
Original file line number Diff line number Diff line change
Expand Up @@ -63,6 +63,7 @@ def __init__(
self,
*,
hosts: Optional[Hosts] = None,
custom_mapping: Optional[Dict[str, Any]] = None,
index: str = "default",
embedding_similarity_function: Literal["cosine", "dot_product", "l2_norm", "max_inner_product"] = "cosine",
**kwargs,
Expand All @@ -82,6 +83,7 @@ def __init__(
[reference](https://elasticsearch-py.readthedocs.io/en/stable/api.html#module-elasticsearch)
:param hosts: List of hosts running the Elasticsearch client.
:param custom_mapping: Custom mapping for the index. If not provided, a default mapping will be used.
:param index: Name of index in Elasticsearch.
:param embedding_similarity_function: The similarity function used to compare Documents embeddings.
This parameter only takes effect if the index does not yet exist and is created.
Expand All @@ -98,29 +100,37 @@ def __init__(
)
self._index = index
self._embedding_similarity_function = embedding_similarity_function
self._custom_mapping = custom_mapping
self._kwargs = kwargs

# Check client connection, this will raise if not connected
self._client.info()

# configure mapping for the embedding field
mappings = {
"properties": {
"embedding": {"type": "dense_vector", "index": True, "similarity": embedding_similarity_function},
"content": {"type": "text"},
},
"dynamic_templates": [
{
"strings": {
"path_match": "*",
"match_mapping_type": "string",
"mapping": {
"type": "keyword",
},
if self._custom_mapping and not isinstance(self._custom_mapping, Dict):
msg = "custom_mapping must be a dictionary"
raise ValueError(msg)

if self._custom_mapping:
mappings = self._custom_mapping
else:
# Configure mapping for the embedding field if none is provided
mappings = {
"properties": {
"embedding": {"type": "dense_vector", "index": True, "similarity": embedding_similarity_function},
"content": {"type": "text"},
},
"dynamic_templates": [
{
"strings": {
"path_match": "*",
"match_mapping_type": "string",
"mapping": {
"type": "keyword",
},
}
}
}
],
}
],
}

# Create the index if it doesn't exist
if not self._client.indices.exists(index=index):
Expand All @@ -139,6 +149,7 @@ def to_dict(self) -> Dict[str, Any]:
return default_to_dict(
self,
hosts=self._hosts,
custom_mapping=self._custom_mapping,
index=self._index,
embedding_similarity_function=self._embedding_similarity_function,
**self._kwargs,
Expand Down
1 change: 1 addition & 0 deletions integrations/elasticsearch/tests/test_bm25_retriever.py
Original file line number Diff line number Diff line change
Expand Up @@ -28,6 +28,7 @@ def test_to_dict(_mock_elasticsearch_client):
"document_store": {
"init_parameters": {
"hosts": "some fake host",
"custom_mapping": None,
"index": "default",
"embedding_similarity_function": "cosine",
},
Expand Down
35 changes: 34 additions & 1 deletion integrations/elasticsearch/tests/test_document_store.py
Original file line number Diff line number Diff line change
Expand Up @@ -4,7 +4,7 @@

import random
from typing import List
from unittest.mock import patch
from unittest.mock import Mock, patch

import pytest
from elasticsearch.exceptions import BadRequestError # type: ignore[import-not-found]
Expand All @@ -23,6 +23,7 @@ def test_to_dict(_mock_elasticsearch_client):
"type": "haystack_integrations.document_stores.elasticsearch.document_store.ElasticsearchDocumentStore",
"init_parameters": {
"hosts": "some hosts",
"custom_mapping": None,
"index": "default",
"embedding_similarity_function": "cosine",
},
Expand All @@ -35,13 +36,15 @@ def test_from_dict(_mock_elasticsearch_client):
"type": "haystack_integrations.document_stores.elasticsearch.document_store.ElasticsearchDocumentStore",
"init_parameters": {
"hosts": "some hosts",
"custom_mapping": None,
"index": "default",
"embedding_similarity_function": "cosine",
},
}
document_store = ElasticsearchDocumentStore.from_dict(data)
assert document_store._hosts == "some hosts"
assert document_store._index == "default"
assert document_store._custom_mapping is None
assert document_store._embedding_similarity_function == "cosine"


Expand Down Expand Up @@ -280,3 +283,33 @@ def test_write_documents_different_embedding_sizes_fail(self, document_store: El

with pytest.raises(DocumentStoreError):
document_store.write_documents(docs)

@patch("haystack_integrations.document_stores.elasticsearch.document_store.Elasticsearch")
def test_init_with_custom_mapping(self, mock_elasticsearch):
custom_mapping = {
"properties": {
"embedding": {"type": "dense_vector", "index": True, "similarity": "dot_product"},
"content": {"type": "text"},
},
"dynamic_templates": [
{
"strings": {
"path_match": "*",
"match_mapping_type": "string",
"mapping": {
"type": "keyword",
},
}
}
],
}
mock_client = Mock(
indices=Mock(create=Mock(), exists=Mock(return_value=False)),
)
mock_elasticsearch.return_value = mock_client

ElasticsearchDocumentStore(hosts="some hosts", custom_mapping=custom_mapping)
mock_client.indices.create.assert_called_once_with(
index="default",
mappings=custom_mapping,
)
Original file line number Diff line number Diff line change
Expand Up @@ -29,6 +29,7 @@ def test_to_dict(_mock_elasticsearch_client):
"document_store": {
"init_parameters": {
"hosts": "some fake host",
"custom_mapping": None,
"index": "default",
"embedding_similarity_function": "cosine",
},
Expand Down

0 comments on commit 2667d6b

Please sign in to comment.