# Reconstructed from the two-commit patch series against knowledge_storm/rm.py
# ("[PATCH 1/2] Add ElasticSearchRM" followed by "[PATCH 2/2] Add API Key").
# This is the class as it stands after both patches, with review fixes applied.
class ElasticSearchRM(dspy.Retrieve):
    """Retrieve documents from an Elasticsearch index using a ``match`` query.

    Every query is matched against the ``body`` field of the configured
    index. Each hit is converted into a dict with the keys ``description``,
    ``snippets``, ``title`` and ``url``; the values are read from the hit's
    ``_source`` (``snippets`` is a one-element list holding the ``body``
    field). Missing ``_source`` fields default to an empty string.
    """

    def __init__(self, api_host=None, api_key=None, index_name=None, k=3):
        """Create the Elasticsearch client and remember the target index.

        Args:
            api_host: Host specification handed to ``Elasticsearch(...)``.
            api_key: API key handed to ``Elasticsearch(..., api_key=...)``.
            index_name: Name of the index that ``forward`` searches.
            k: Maximum number of hits requested per query.

        Raises:
            ImportError: If the ``elasticsearch`` package is not installed.
            RuntimeError: If ``api_host``, ``api_key`` or ``index_name`` is
                missing or empty.
        """
        super().__init__(k=k)
        # Kept explicitly on the instance so forward() can cap the hit count.
        self.k = k
        try:
            # Imported lazily so the dependency is only needed when this
            # retriever is actually instantiated.
            from elasticsearch import Elasticsearch
        except ImportError as err:
            raise ImportError(
                "Elasticsearch requires `pip install elasticsearch`."
            ) from err
        if not api_host:
            raise RuntimeError("You must supply api_host")
        if not api_key:
            raise RuntimeError("You must supply api_key")
        if not index_name:
            raise RuntimeError("You must supply index_name")
        self.es = Elasticsearch(api_host, api_key=api_key)
        self.index_name = index_name
        self.usage = 0

    def get_usage_and_reset(self):
        """Return the number of queries issued since the last reset, then reset it.

        Returns:
            A dict ``{"ElasticSearchRM": <query count>}``.
        """
        usage = self.usage
        self.usage = 0
        return {"ElasticSearchRM": usage}

    def forward(
        self,
        query_or_queries: Union[str, List[str]],
        exclude_urls: Union[List[str], None] = None,
    ):
        """Search the index for one or more queries and collect the hits.

        Args:
            query_or_queries: A single query string or a list of them.
            exclude_urls: URLs whose hits are dropped from the results.
                ``None`` (the default) excludes nothing.

        Returns:
            A list of dicts with the keys ``description``, ``snippets``,
            ``title`` and ``url``, in query order. A query whose search fails
            is logged and contributes no results.
        """
        queries = (
            [query_or_queries]
            if isinstance(query_or_queries, str)
            else query_or_queries
        )
        self.usage += len(queries)
        # Build the exclusion set once instead of scanning a list per hit.
        excluded = set(exclude_urls or ())
        collected_results = []
        for query in queries:
            search_body = {"query": {"match": {"body": query}}}
            if self.k is not None:
                # Without an explicit size the server-side default page size
                # is used and `k` would have no effect.
                search_body["size"] = self.k
            try:
                results = self.es.search(index=self.index_name, body=search_body)
                for hit in results["hits"]["hits"]:
                    source = hit["_source"]
                    # Read the URL with .get(): a document without a `url`
                    # field must not abort the remaining hits of this query.
                    url = source.get("url", "")
                    if url in excluded:
                        continue
                    collected_results.append(
                        {
                            "description": source.get("description", ""),
                            "snippets": [source.get("body", "")],
                            "title": source.get("title", ""),
                            "url": url,
                        }
                    )
            except Exception as e:
                # Best-effort retrieval: one failing query must not discard
                # the results already collected for the others.
                logging.error(f"Error occurs when searching query {query}: {e}")
        return collected_results