From edd6de9b4a9ece45514208a0dfde854607023194 Mon Sep 17 00:00:00 2001 From: Lukasz Karlowski <116562347+karllu3@users.noreply.github.com> Date: Mon, 27 May 2024 16:05:19 +0200 Subject: [PATCH] feat: Add elastic store (#34) * add elastic search store * feat add vector search store * update documentation --- docs/how-to/custom_views_code.py | 1 - docs/how-to/pandas_views_code.py | 3 - docs/how-to/use_elastic_store.md | 103 +++++++++++++++++ docs/how-to/use_elastic_vector_store_code.py | 102 ++++++++++++++++ docs/how-to/use_elasticsearch_store_code.py | 106 +++++++++++++++++ docs/quickstart/quickstart2_code.py | 11 +- docs/quickstart/quickstart3_code.py | 22 ++-- docs/quickstart/quickstart_code.py | 1 - .../similarity/similarity_store/elastic.md | 8 ++ docs/roadmap.md | 2 +- mkdocs.yml | 2 + setup.cfg | 2 + src/dbally/similarity/__init__.py | 13 +++ .../similarity/elastic_vector_search.py | 102 ++++++++++++++++ src/dbally/similarity/elasticsearch_store.py | 109 ++++++++++++++++++ 15 files changed, 569 insertions(+), 18 deletions(-) create mode 100644 docs/how-to/use_elastic_store.md create mode 100644 docs/how-to/use_elastic_vector_store_code.py create mode 100644 docs/how-to/use_elasticsearch_store_code.py create mode 100644 docs/reference/similarity/similarity_store/elastic.md create mode 100644 src/dbally/similarity/elastic_vector_search.py create mode 100644 src/dbally/similarity/elasticsearch_store.py diff --git a/docs/how-to/custom_views_code.py b/docs/how-to/custom_views_code.py index 8e033783..c96a0934 100644 --- a/docs/how-to/custom_views_code.py +++ b/docs/how-to/custom_views_code.py @@ -1,6 +1,5 @@ # pylint: disable=missing-return-doc, missing-param-doc, missing-function-docstring, missing-class-docstring, missing-raises-doc import dbally -import os import asyncio from dataclasses import dataclass from typing import Iterable, Callable, Any diff --git a/docs/how-to/pandas_views_code.py b/docs/how-to/pandas_views_code.py index 8b6ce17b..f71a973d 100644 --- 
a/docs/how-to/pandas_views_code.py +++ b/docs/how-to/pandas_views_code.py @@ -1,9 +1,6 @@ # pylint: disable=missing-return-doc, missing-param-doc, missing-function-docstring, missing-class-docstring, missing-raises-doc import dbally -import os import asyncio -from dataclasses import dataclass -from typing import Iterable, Callable, Any import pandas as pd from dbally import decorators, DataFrameBaseView diff --git a/docs/how-to/use_elastic_store.md b/docs/how-to/use_elastic_store.md new file mode 100644 index 00000000..a79bb743 --- /dev/null +++ b/docs/how-to/use_elastic_store.md @@ -0,0 +1,103 @@ +# How-To Use Elastic to Store Similarity Index + +[ElasticStore](https://www.elastic.co/guide/en/elasticsearch/reference/current/mapping-store.html) can be used as a store in SimilarityIndex. In this guide, we will show you how to execute a similarity search using Elasticsearch. +In the example, the Elasticsearch engine is provided by the official Docker image. There are two approaches available to perform similarity searches: Elastic Search Store and Elastic Vector Search. +Elastic Search Store uses embeddings and kNN search to find similarities, while Elastic Vector Search, which performs semantic search, uses the ELSER (Elastic Learned Sparse EncodeR) model to encode and search the data. + + +## Prerequisites + +[Download and deploy the Elasticsearch Docker image](https://www.elastic.co/guide/en/elasticsearch/reference/current/docker.html). Please note that for Elastic Vector Search, the Elasticsearch Docker container requires at least 8GB of RAM and +[license activation](https://www.elastic.co/guide/en/kibana/current/managing-licenses.html) to use Machine Learning capabilities. 
+ + +```commandline +docker network create elastic +docker pull docker.elastic.co/elasticsearch/elasticsearch:8.13.4 +docker run --name es01 --net elastic -p 9200:9200 -it -m 2GB docker.elastic.co/elasticsearch/elasticsearch:8.13.4 +``` + +Copy the generated elastic password and enrollment token. These credentials are only shown the first time you start Elasticsearch. If you lose them, you can regenerate them with the `elasticsearch-reset-password` and `elasticsearch-create-enrollment-token` tools shipped in the container. Then copy the CA certificate out of the container and verify that you can connect to the cluster using the following commands. +```commandline +docker cp es01:/usr/share/elasticsearch/config/certs/http_ca.crt . +curl --cacert http_ca.crt -u elastic:$ELASTIC_PASSWORD https://localhost:9200 +``` + +To manage the Elasticsearch engine, create a Kibana container. +```commandline +docker run --name kib01 --net elastic -p 5601:5601 docker.elastic.co/kibana/kibana:8.13.4 +``` + +By default, the Kibana management dashboard is available at [http://localhost:5601/](http://localhost:5601/) + + +For vector search, it is necessary to enroll in an [appropriate subscription level](https://www.elastic.co/subscriptions) or trial version that supports machine learning. +Additionally, the [ELSER model must be downloaded](https://www.elastic.co/guide/en/machine-learning/current/ml-nlp-elser.html), which can be done through Kibana. Instructions can be found in the hosted Kibana instance under tabs: +
* downloading and deploying model - **Analytics -> Machine Learning -> Trained Model**, +
* vector search configuration - **Search -> Elastic Search -> Vector Search.** + + +* Install elasticsearch extension +```commandline +pip install dbally[elasticsearch] +``` + +## Implementing a SimilarityIndex + +To use similarity search it is required to define a data fetcher and a data store. + +### Data fetcher + +```python +class DummyCountryFetcher(SimilarityFetcher): + async def fetch(self): + return ["United States", "Canada", "Mexico"] +``` + +### Data store +Elastic store similarity search works on embeddings. To create the embeddings, an embedding client is passed as an argument. +You can use [one of dbally embedding clients][dbally.embeddings.EmbeddingClient], such as [LiteLLMEmbeddingClient][dbally.embeddings.LiteLLMEmbeddingClient]. + +```python +from dbally.embeddings.litellm import LiteLLMEmbeddingClient + +embedding_client=LiteLLMEmbeddingClient(api_key="your-api-key") +``` + +to define your [`ElasticsearchStore`][dbally.similarity.ElasticsearchStore]. + +```python +from dbally.similarity.elasticsearch_store import ElasticsearchStore + +data_store = ElasticsearchStore( + index_name="country_similarity", + host="https://localhost:9200", + ca_cert_path="path_to_cert/http_ca.crt", + http_user="elastic", + http_password="password", + embedding_client=embedding_client, +) + +``` + +After this setup, you can initialize the SimilarityIndex + +```python +from dbally.similarity import SimilarityIndex + +country_similarity = SimilarityIndex( + fetcher=DummyCountryFetcher(), + store=data_store +) +``` + +and [update it and find the closest matches in the same way as in built-in similarity indices](use_custom_similarity_store.md/#using-the-similar) + +You can then use this store to create a similarity index that maps user input to the closest matching value.
+To use Elastic Vector search download and deploy [ELSER v2](https://www.elastic.co/guide/en/machine-learning/current/ml-nlp-elser.html#elser-v2) model and create [ingest pipeline](https://www.elastic.co/guide/en/machine-learning/current/ml-nlp-elser.html#elasticsearch-ingest-pipeline). +Now you can use this index to map user input to the closest matching value. For example, a user may type 'United States' and our index would return 'USA'. + +## Links +* [Similarity Indexes](use_custom_similarity_store.md) +* [Example Elastic Search Store](use_elasticsearch_store_code.py) +* [Example Elastic Vector Search](use_elastic_vector_store_code.py) \ No newline at end of file diff --git a/docs/how-to/use_elastic_vector_store_code.py b/docs/how-to/use_elastic_vector_store_code.py new file mode 100644 index 00000000..c325fa2c --- /dev/null +++ b/docs/how-to/use_elastic_vector_store_code.py @@ -0,0 +1,102 @@ +# pylint: disable=missing-return-doc, missing-param-doc, missing-function-docstring +import os +import asyncio +from typing_extensions import Annotated + +import asyncclick as click +from dotenv import load_dotenv +import sqlalchemy +from sqlalchemy import create_engine +from sqlalchemy.ext.automap import automap_base + +import dbally +from dbally import decorators, SqlAlchemyBaseView +from dbally.audit.event_handlers.cli_event_handler import CLIEventHandler +from dbally.llms.litellm import LiteLLM +from dbally.similarity import SimpleSqlAlchemyFetcher, SimilarityIndex +from dbally.similarity.elastic_vector_search import ElasticVectorStore + +load_dotenv() +engine = create_engine("sqlite:///candidates.db") + + +Base = automap_base() +Base.prepare(autoload_with=engine) + +Candidate = Base.classes.candidates + +country_similarity = SimilarityIndex( + fetcher=SimpleSqlAlchemyFetcher( + engine, + table=Candidate, + column=Candidate.country, + ), + store=ElasticVectorStore( + index_name="country_vector_similarity", + host=os.environ["ELASTIC_STORE_CONNECTION_STRING"], + 
class CandidateView(SqlAlchemyBaseView):
    """
    A view for retrieving candidates from the database.
    """

    # NOTE: the method docstrings below are kept verbatim on purpose; they
    # describe each filter and must not drift from the published example.

    def get_select(self) -> sqlalchemy.Select:
        """
        Creates the initial SqlAlchemy select object, which will be used to build the query.
        """
        base_query = sqlalchemy.select(Candidate)
        return base_query

    @decorators.view_filter()
    def at_least_experience(self, years: int) -> sqlalchemy.ColumnElement:
        """
        Filters candidates with at least `years` of experience.
        """
        experience = Candidate.years_of_experience
        return experience >= years

    @decorators.view_filter()
    def senior_data_scientist_position(self) -> sqlalchemy.ColumnElement:
        """
        Filters candidates that can be considered for a senior data scientist position.
        """
        # Positions regarded as relevant background for the senior role.
        eligible_positions = ["Data Scientist", "Machine Learning Engineer", "Data Engineer"]
        holds_eligible_position = Candidate.position.in_(eligible_positions)
        has_enough_experience = Candidate.years_of_experience >= 3
        return sqlalchemy.and_(holds_eligible_position, has_enough_experience)

    @decorators.view_filter()
    def from_country(self, country: Annotated[str, country_similarity]) -> sqlalchemy.ColumnElement:
        """
        Filters candidates from a specific country.
        """
        # `country` is annotated with the similarity index declared at module level.
        candidate_country = Candidate.country
        return candidate_country == country
+ ) + + print(f"The generated SQL query is: {result.context.get('sql')}") + print() + print(f"Retrieved {len(result.results)} candidates:") + for candidate in result.results: + print(candidate) + + +if __name__ == "__main__": + asyncio.run(main()) diff --git a/docs/how-to/use_elasticsearch_store_code.py b/docs/how-to/use_elasticsearch_store_code.py new file mode 100644 index 00000000..39258c44 --- /dev/null +++ b/docs/how-to/use_elasticsearch_store_code.py @@ -0,0 +1,106 @@ +# pylint: disable=missing-return-doc, missing-param-doc, missing-function-docstring +import os +import asyncio +from typing_extensions import Annotated + +import asyncclick as click +from dotenv import load_dotenv +import sqlalchemy +from sqlalchemy import create_engine +from sqlalchemy.ext.automap import automap_base + +import dbally +from dbally import decorators, SqlAlchemyBaseView +from dbally.audit.event_handlers.cli_event_handler import CLIEventHandler +from dbally.similarity import SimpleSqlAlchemyFetcher, SimilarityIndex +from dbally.embeddings.litellm import LiteLLMEmbeddingClient +from dbally.llms.litellm import LiteLLM +from dbally.similarity.elasticsearch_store import ElasticsearchStore + +load_dotenv() +engine = create_engine("sqlite:///candidates.db") + + +Base = automap_base() +Base.prepare(autoload_with=engine) + +Candidate = Base.classes.candidates + +country_similarity = SimilarityIndex( + fetcher=SimpleSqlAlchemyFetcher( + engine, + table=Candidate, + column=Candidate.country, + ), + store=ElasticsearchStore( + index_name="country_similarity", + host=os.environ["ELASTIC_STORE_CONNECTION_STRING"], + ca_cert_path=os.environ["ELASTIC_CERT_PATH"], + http_user=os.environ["ELASTIC_AUTH_USER"], + http_password=os.environ["ELASTIC_USER_PASSWORD"], + embedding_client=LiteLLMEmbeddingClient( + api_key=os.environ["OPENAI_API_KEY"], + ), + ), +) + + +class CandidateView(SqlAlchemyBaseView): + """ + A view for retrieving candidates from the database. 
@click.command()
@click.argument("country", type=str, default="United States")
@click.argument("years_of_experience", type=str, default="2")
async def main(country="United States", years_of_experience="2"):
    """
    Refreshes the country similarity index, asks the collection a question and prints the result.

    Args:
        country: Country name interpolated into the natural-language question.
        years_of_experience: Minimum years of experience interpolated into the question.
    """
    # Populate the Elasticsearch-backed index exactly once before querying.
    # A second update would only repeat the fetch and the embedding requests.
    await country_similarity.update()

    llm = LiteLLM(model_name="gpt-3.5-turbo", api_key=os.environ["OPENAI_API_KEY"])
    collection = dbally.create_collection("recruitment", llm, event_handlers=[CLIEventHandler()])
    collection.add(CandidateView, lambda: CandidateView(engine))

    result = await collection.ask(
        f"Find someone from the {country} with more than {years_of_experience} years of experience."
    )

    print(f"The generated SQL query is: {result.context.get('sql')}")
    print()
    print(f"Retrieved {len(result.results)} candidates:")
    for candidate in result.results:
        print(candidate)
@@ -71,6 +75,7 @@ def from_country(self, country: Annotated[str, country_similarity]) -> sqlalchem """ return Candidate.country == country + async def main(): await country_similarity.update() diff --git a/docs/quickstart/quickstart3_code.py b/docs/quickstart/quickstart3_code.py index f0e385b1..2ef8f1df 100644 --- a/docs/quickstart/quickstart3_code.py +++ b/docs/quickstart/quickstart3_code.py @@ -14,7 +14,7 @@ from dbally.embeddings.litellm import LiteLLMEmbeddingClient from dbally.llms.litellm import LiteLLM -engine = create_engine('sqlite:///candidates.db') +engine = create_engine("sqlite:///candidates.db") Base = automap_base() Base.prepare(autoload_with=engine) @@ -23,7 +23,7 @@ country_similarity = SimilarityIndex( - fetcher=SimpleSqlAlchemyFetcher( + fetcher=SimpleSqlAlchemyFetcher( engine, table=Candidate, column=Candidate.country, @@ -38,10 +38,12 @@ ), ) + class CandidateView(SqlAlchemyBaseView): """ A view for retrieving candidates from the database. """ + def get_select(self) -> sqlalchemy.Select: """ Creates the initial SqlAlchemy select object, which will be used to build the query. 
@@ -73,13 +75,15 @@ def from_country(self, country: Annotated[str, country_similarity]) -> sqlalchem return Candidate.country == country -jobs_data = pd.DataFrame.from_records([ - {"title": "Data Scientist", "company": "Company A", "location": "New York", "salary": 100000}, - {"title": "Data Engineer", "company": "Company B", "location": "San Francisco", "salary": 120000}, - {"title": "Machine Learning Engineer", "company": "Company C", "location": "Berlin", "salary": 90000}, - {"title": "Data Scientist", "company": "Company D", "location": "London", "salary": 110000}, - {"title": "Data Scientist", "company": "Company E", "location": "Warsaw", "salary": 80000}, -]) +jobs_data = pd.DataFrame.from_records( + [ + {"title": "Data Scientist", "company": "Company A", "location": "New York", "salary": 100000}, + {"title": "Data Engineer", "company": "Company B", "location": "San Francisco", "salary": 120000}, + {"title": "Machine Learning Engineer", "company": "Company C", "location": "Berlin", "salary": 90000}, + {"title": "Data Scientist", "company": "Company D", "location": "London", "salary": 110000}, + {"title": "Data Scientist", "company": "Company E", "location": "Warsaw", "salary": 80000}, + ] +) class JobView(DataFrameBaseView): diff --git a/docs/quickstart/quickstart_code.py b/docs/quickstart/quickstart_code.py index 09c5924e..8cdbd446 100644 --- a/docs/quickstart/quickstart_code.py +++ b/docs/quickstart/quickstart_code.py @@ -1,6 +1,5 @@ # pylint: disable=missing-return-doc, missing-param-doc, missing-function-docstring import dbally -import os import asyncio import sqlalchemy diff --git a/docs/reference/similarity/similarity_store/elastic.md b/docs/reference/similarity/similarity_store/elastic.md new file mode 100644 index 00000000..52604f96 --- /dev/null +++ b/docs/reference/similarity/similarity_store/elastic.md @@ -0,0 +1,8 @@ +#ElasticStore + +!!! 
info + To see example of using ElasticStore visit: [How-To: Use Elastic Search to Store Similarity Index](../../../how-to/use_elastic_store.md) + + +::: dbally.similarity.ElasticsearchStore +::: dbally.similarity.ElasticVectorStore \ No newline at end of file diff --git a/docs/roadmap.md b/docs/roadmap.md index f590f35a..56ead50b 100644 --- a/docs/roadmap.md +++ b/docs/roadmap.md @@ -50,7 +50,7 @@ And many more, the full list can be found in the [LiteLLM documentation](https:/ - [x] FAISS - [x] Chroma -- [ ] Elasticsearch +- [x] Elasticsearch - [ ] Weaviate - [ ] Qdrant - [ ] VertexAI Vector Search diff --git a/mkdocs.yml b/mkdocs.yml index 5d42aa5e..637842cf 100644 --- a/mkdocs.yml +++ b/mkdocs.yml @@ -26,6 +26,7 @@ nav: - Using similarity indexes: - how-to/use_custom_similarity_fetcher.md - how-to/use_chromadb_store.md + - how-to/use_elastic_store.md - how-to/use_custom_similarity_store.md - how-to/update_similarity_indexes.md - how-to/log_runs_to_langsmith.md @@ -59,6 +60,7 @@ nav: - reference/similarity/similarity_store/index.md - reference/similarity/similarity_store/faiss.md - reference/similarity/similarity_store/chroma.md + - reference/similarity/similarity_store/elastic.md - Fetcher: - reference/similarity/similarity_fetcher/index.md - reference/similarity/similarity_fetcher/sqlalchemy.md diff --git a/setup.cfg b/setup.cfg index 62d3bef8..34e18232 100644 --- a/setup.cfg +++ b/setup.cfg @@ -63,6 +63,8 @@ benchmark = pydantic-core~=2.16.2 pydantic-settings~=2.0.3 psycopg2-binary~=2.9.9 +elasticsearch = + elasticsearch==8.13.1 [options.packages.find] where = src diff --git a/src/dbally/similarity/__init__.py b/src/dbally/similarity/__init__.py index 5ce05035..767fc6a1 100644 --- a/src/dbally/similarity/__init__.py +++ b/src/dbally/similarity/__init__.py @@ -4,6 +4,7 @@ from .store import SimilarityStore # depends on the faiss package + try: from .faiss_store import FaissStore except ImportError: @@ -14,6 +15,16 @@ except ImportError: pass +try: + from 
class ElasticVectorStore(SimilarityStore):
    """
    The Elastic Vector Store class uses the ELSER (Elastic Learned Sparse EncodeR) model on Elasticsearch to
    store and search data.
    """

    def __init__(
        self,
        index_name: str,
        host: str,
        http_user: str,
        http_password: str,
        ca_cert_path: str,
        *,
        model_id: str = ".elser_model_2",
        ingest_pipeline: str = "elser-ingest-pipeline",
    ) -> None:
        """
        Initializes the Elastic Vector Store.

        Args:
            index_name: The name of the index.
            host: The host address of the Elasticsearch instance.
            http_user: The username used for HTTP authentication.
            http_password: The password used for HTTP authentication.
            ca_cert_path: The path to the CA certificate for SSL/TLS verification.
            model_id: The id of the deployed ELSER model used to expand the query text.
            ingest_pipeline: The name of the ingest pipeline set as the index default pipeline.
        """
        super().__init__()
        # `basic_auth` replaces the deprecated `http_auth` keyword of the 8.x client.
        self.client = AsyncElasticsearch(
            hosts=host,
            basic_auth=(http_user, http_password),
            ca_certs=ca_cert_path,
        )
        self.index_name = index_name
        self.model_id = model_id
        self.ingest_pipeline = ingest_pipeline

    async def store(self, data: List[str]) -> None:
        """
        Stores the given data in an Elasticsearch store.

        The index is created on first use with the configured ingest pipeline as its default pipeline.

        Args:
            data: The data to store in the Elasticsearch index.
        """
        mappings = {
            "properties": {
                "column": {
                    "type": "text",
                },
                "column_embedding": {"type": "sparse_vector"},
            }
        }
        if not await self.client.indices.exists(index=self.index_name):
            await self.client.indices.create(
                index=self.index_name,
                mappings=mappings,
                settings={"index": {"default_pipeline": self.ingest_pipeline}},
            )

        # The document id is the SHA-256 of the text, so storing the same value again
        # overwrites the existing document instead of creating a duplicate.
        actions = [
            {
                "_index": self.index_name,
                "_id": sha256(value.encode("utf-8")).hexdigest(),
                "column": value,
            }
            for value in data
        ]
        await async_bulk(self.client, actions)

    async def find_similar(
        self,
        text: str,
    ) -> Optional[str]:
        """
        Finds the most similar stored text to the given input text.

        This function performs a search in the Elasticsearch index using text expansion to find
        the stored text that is most similar to the provided input text.

        Args:
            text: The input text for which to find a similar stored text.

        Returns:
            The most similar stored text if found, otherwise None.
        """
        response = await self.client.search(
            index=self.index_name,
            size=1,
            query={
                "text_expansion": {
                    "column_embedding": {
                        "model_id": self.model_id,
                        "model_text": text,
                    }
                }
            },
        )

        hits = response["hits"]["hits"]
        return hits[0]["_source"]["column"] if hits else None
class ElasticsearchStore(SimilarityStore):
    """
    The ElasticsearchStore class stores text embeddings and implements method to find the most similar values using
    knn algorithm.
    """

    def __init__(
        self,
        index_name: str,
        embedding_client: EmbeddingClient,
        host: str,
        http_user: str,
        http_password: str,
        ca_cert_path: str,
    ) -> None:
        """
        Initializes the ElasticsearchStore.

        Args:
            index_name: The name of the index.
            embedding_client: The client to use for creating text embeddings.
            host: The host address of the Elasticsearch instance.
            http_user: The username used for HTTP authentication.
            http_password: The password used for HTTP authentication.
            ca_cert_path: The path to the CA certificate for SSL/TLS verification.
        """
        super().__init__()
        # `basic_auth` replaces the deprecated `http_auth` keyword of the 8.x client.
        self.client = AsyncElasticsearch(
            hosts=host,
            basic_auth=(http_user, http_password),
            ca_certs=ca_cert_path,
        )
        self.index_name = index_name
        self.embedding_client = embedding_client

    async def store(self, data: List[str]) -> None:
        """
        Stores the data in an Elasticsearch index together with their embeddings.

        Args:
            data: The data to store.
        """
        mappings = {
            "properties": {
                "search_vector": {
                    "type": "dense_vector",
                    "index": True,
                    "similarity": "cosine",
                }
            }
        }

        if not await self.client.indices.exists(index=self.index_name):
            await self.client.indices.create(index=self.index_name, mappings=mappings)

        # Nothing to embed or index; avoid sending an empty request to the embedding client.
        if not data:
            return

        # Embed all values in a single request instead of one request per value.
        embeddings = await self.embedding_client.get_embeddings(data)

        # The document id is the SHA-256 of the text, so storing the same value again
        # overwrites the existing document instead of creating a duplicate.
        actions = [
            {
                "_index": self.index_name,
                "_id": sha256(value.encode("utf-8")).hexdigest(),
                "column": value,
                "search_vector": vector,
            }
            for value, vector in zip(data, embeddings)
        ]

        await async_bulk(self.client, actions)

    async def find_similar(
        self,
        text: str,
        k_closest: int = 5,
        num_candidates: int = 50,
    ) -> Optional[str]:
        """
        Finds the most similar text in the store or returns None if no similar text is found.

        Args:
            text: The text to find similar to.
            k_closest: The k nearest neighbours used by knn-search.
            num_candidates: The number of approximate nearest neighbor candidates on each shard.

        Returns:
            The most similar text or None if no similar text is found.
        """
        query_embedding = (await self.embedding_client.get_embeddings([text]))[0]

        # Restrict the search to this store's index; without `index` the request
        # would be executed against every index on the cluster.
        search_results = await self.client.search(
            index=self.index_name,
            size=1,  # only the best hit is returned to the caller
            knn={
                "field": "search_vector",
                "k": k_closest,
                "num_candidates": num_candidates,
                "query_vector": query_embedding,
            },
        )

        hits = search_results["hits"]["hits"]
        return hits[0]["_source"]["column"] if hits else None