Elastic Search retriever module integration #534

Closed
85 changes: 85 additions & 0 deletions docs/api/retrieval_model_clients/ElasticSearch.md
@@ -0,0 +1,85 @@

# retrieve.elastic_rm

### Constructor

Initialize an instance of the `elastic_rm` class, which retrieves passages from an Elasticsearch index.

```python
elastic_rm(
    es_client: Elasticsearch,
    es_index: str,
    es_field: str,
    k: int = 3,
)
```

**Parameters:**
- `es_client` (_Elasticsearch_): An Elasticsearch client instance that has already been created and initialized (see Ref. 1).
- `es_index` (_str_): The name of the index to search.
- `es_field` (_str_): The name of the field to search.
- `k` (_int_, _optional_): The number of top passages to retrieve. Defaults to 3.

Ref. 1 - Connecting to Elastic Cloud -
https://www.elastic.co/guide/en/elasticsearch/client/python-api/current/connecting.html

### Methods

#### `forward(self, query: str, k: Optional[int] = None) -> dspy.Prediction`

Search the Elasticsearch index for the top `k` passages matching the given query, using a `match` query on the field set at initialization.

**Parameters:**
- `query` (_str_): The query string to search for.
- `k` (_Optional[int]_, _optional_): The number of results to retrieve. If not specified, defaults to the value set during initialization.

**Returns:**
- `dspy.Prediction`: Contains the retrieved passages as a list of strings in its `passages` attribute.

Example:
```python
Prediction(
passages=['Passage 1 Lorem Ipsum awesome', 'Passage 2 Lorem Ipsum Youppidoo', 'Passage 3 Lorem Ipsum Yassssss']
)
```
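
A minimal sketch of the request and response handling behind `forward`, written as two standalone helpers so it runs without a cluster. `build_match_query` and `extract_passages` are hypothetical names used for illustration only, and `sample_response` is a hand-written stand-in that mimics the `hits.hits[]._source` shape of a search response. It assumes passages are stored in a `text` field, as the module does.

```python
from typing import Any, Dict, List


def build_match_query(field: str, query: str) -> Dict[str, Any]:
    """Build a full-text match query body for a single field."""
    return {"query": {"match": {field: query}}}


def extract_passages(response: Dict[str, Any], k: int) -> List[str]:
    """Collect the "text" field of up to k hits from a search response."""
    hits = response["hits"]["hits"]
    return [hit["_source"]["text"] for hit in hits[:k]]


# Hand-written stand-in for a search response (not real cluster output).
sample_response = {
    "hits": {
        "hits": [
            {"_source": {"text": "Passage 1"}},
            {"_source": {"text": "Passage 2"}},
            {"_source": {"text": "Passage 3"}},
            {"_source": {"text": "Passage 4"}},
        ]
    }
}

body = build_match_query("text", "quantum computing")
passages = extract_passages(sample_response, k=3)
print(body)
print(passages)
```

Elasticsearch returns 10 hits by default when the request body sets no `size`, so with a query body like this one a `k` above 10 yields at most 10 passages.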

### Quick example: using Elasticsearch in a local environment

If your instance runs in the cloud, refer to the official documentation (see Ref. 1 above).

```python
import os

from elasticsearch import Elasticsearch

# Module path taken from the file added in this PR (dspy/retrieve/elasticsearch_rm.py)
from dspy.retrieve.elasticsearch_rm import elastic_rm

ELASTIC_PASSWORD = os.getenv('ELASTIC_PASSWORD')

# Create the client instance
es = Elasticsearch(
    "https://localhost:9200",
    ca_certs="http_ca.crt",  # Path to the CA certificate; generate one if you don't have it.
    basic_auth=("elastic", ELASTIC_PASSWORD)
)

# Check your connection
if es.ping():
    print("Connected to Elasticsearch cluster")
else:
    print("Could not connect to Elasticsearch")

# Index name you want to search
index_name = "wiki-summary"

retriever_model = elastic_rm(
    es_client=es,
    es_index=index_name,
    es_field="text",  # Name of the field to search; "text" is an assumed field name
    k=3
)

results = retriever_model("Explore the significance of quantum computing", k=3)

for passage in results.passages:
    print("Document:", passage, "\n")
```
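
The example above assumes the `wiki-summary` index already contains documents. A minimal sketch for loading a few test passages first, assuming each document stores its passage in a `text` field (the field the module reads from `_source`). `index_passages` is a hypothetical helper, and `RecordingClient` is a stand-in used here only so the sketch runs without a cluster; with a real cluster, pass the `es` client created above instead.

```python
from typing import Any, Dict, List


def index_passages(es_client: Any, index_name: str, passages: List[str]) -> int:
    """Index each passage as a document with a single "text" field."""
    for doc_id, passage in enumerate(passages):
        # With the 8.x Python client this maps to Elasticsearch.index(index=, id=, document=).
        es_client.index(index=index_name, id=str(doc_id), document={"text": passage})
    return len(passages)


class RecordingClient:
    """Stand-in that records index() calls instead of contacting a cluster."""

    def __init__(self) -> None:
        self.calls: List[Dict[str, Any]] = []

    def index(self, **kwargs: Any) -> None:
        self.calls.append(kwargs)


client = RecordingClient()
count = index_passages(
    client,
    "wiki-summary",
    ["Quantum computing uses qubits.", "Classical bits are 0 or 1."],
)
print(count, client.calls[0])
```

Newly indexed documents only become searchable after the index refreshes, so a search issued immediately after indexing may return nothing; calling `es.indices.refresh(index=index_name)` on a real client is one way to force that in a test setup.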
63 changes: 63 additions & 0 deletions dspy/retrieve/elasticsearch_rm.py
@@ -0,0 +1,63 @@
import dspy
from typing import Optional

class elastic_rm(dspy.Retrieve):
    def __init__(self, es_client, es_index, es_field, k=3):
        """
        A retrieval module that uses a simple Elasticsearch match query to return the top passages for a given query.
        Assumes that you have already instantiated your Elasticsearch client.

        The code has been tested with Elasticsearch 8.12.
        For more information on how to instantiate the client, please refer to the official documentation.
        Ref: https://www.elastic.co/guide/en/elasticsearch/client/python-api/current/connecting.html

        Args:
            es_client (Elasticsearch): An instance of the Elasticsearch client.
            es_index (str): The name of the index to search.
            es_field (str): The name of the field to search.
            k (Optional[int]): The number of context strings to return. Default is 3.
        """
        super().__init__()
        self.k = k
        self.es_index = es_index
        self.es_client = es_client
        self.field = es_field

    def forward(self, query, k: Optional[int] = None) -> dspy.Prediction:
        """Search Elasticsearch (local or cloud) for the top k passages for a query.

        Args:
            query (str): The query to search for.
            k (Optional[int]): The number of context strings to return; falls back to self.k when not given.

        Returns:
            dspy.Prediction: An object containing the retrieved passages.
        """
        k = k if k is not None else self.k

        passages = []

        index_name = self.es_index  # the name of the index of your Elasticsearch dump

        # Full-text match query against the configured field
        search_query = {
            "query": {
                "match": {
                    self.field: query
                }
            }
        }

        response = self.es_client.search(index=index_name, body=search_query)

        for hit in response['hits']['hits']:
            # Assumes each document stores its passage under a "text" field
            text = hit["_source"]["text"]
            passages.append(text)
            if len(passages) >= k:  # Stop once k documents are retrieved (uses the per-call k, not self.k)
                break

        return dspy.Prediction(passages=passages)