class ElasticSearchRM(dspy.Retrieve):
    """Retrieve documents from an Elasticsearch index via a full-text match.

    Each query is run as a ``match`` query against the ``body`` field of the
    configured index. Hits are converted into dictionaries with the keys
    ``description``, ``snippets``, ``title`` and ``url``, read from the
    corresponding ``_source`` fields (``snippets`` wraps ``body`` in a list).
    """

    def __init__(self, api_host=None, api_key=None, index_name=None, k=3):
        """Create the Elasticsearch client.

        Args:
            api_host: Address of the Elasticsearch host, passed to the client.
            api_key: API key passed to the client as ``api_key``.
            index_name: Name of the index that ``forward`` searches.
            k: Number of hits requested per query.

        Raises:
            ImportError: If the ``elasticsearch`` package is not installed.
            RuntimeError: If ``api_host``, ``api_key`` or ``index_name`` is
                missing or empty.
        """
        super().__init__(k=k)
        # Imported lazily so the module stays importable without the
        # optional `elasticsearch` dependency.
        try:
            from elasticsearch import Elasticsearch
        except ImportError as err:
            raise ImportError(
                "Elasticsearch requires `pip install elasticsearch`."
            ) from err
        if not api_host:
            raise RuntimeError("You must supply api_host")
        if not api_key:
            raise RuntimeError("You must supply api_key")
        if not index_name:
            raise RuntimeError("You must supply index_name")
        self.es = Elasticsearch(api_host, api_key=api_key)
        self.index_name = index_name
        # Number of queries issued since the last get_usage_and_reset() call.
        self.usage = 0

    def get_usage_and_reset(self):
        """Return the query count as ``{"ElasticSearchRM": n}`` and zero it."""
        usage = self.usage
        self.usage = 0
        return {"ElasticSearchRM": usage}

    def forward(
        self,
        query_or_queries: Union[str, List[str]],
        exclude_urls: Union[List[str], None] = None,
    ):
        """Search the index for one or more queries.

        Args:
            query_or_queries: A single query string or a list of them.
            exclude_urls: URLs whose hits are dropped from the results.
                ``None`` (the default) or an empty list excludes nothing.

        Returns:
            A list of dictionaries, one per retained hit across all queries,
            each with the keys ``description``, ``snippets``, ``title`` and
            ``url``. A query that raises is logged and contributes only the
            hits collected before the failure.
        """
        queries = (
            [query_or_queries]
            if isinstance(query_or_queries, str)
            else query_or_queries
        )
        excluded = exclude_urls or ()
        self.usage += len(queries)
        collected_results = []
        for query in queries:
            try:
                results = self.es.search(
                    index=self.index_name,
                    body={
                        "query": {"match": {"body": query}},
                        # Honor the configured top-k (self.k is set by
                        # dspy.Retrieve.__init__); without an explicit size
                        # Elasticsearch returns its own default page size.
                        "size": self.k,
                    },
                )
                for hit in results["hits"]["hits"]:
                    source = hit["_source"]
                    # Read `url` once with a default so a document lacking
                    # the field does not abort the rest of this query's hits.
                    url = source.get("url", "")
                    if url in excluded:
                        continue
                    collected_results.append(
                        {
                            "description": source.get("description", ""),
                            "snippets": [source.get("body", "")],
                            "title": source.get("title", ""),
                            "url": url,
                        }
                    )
            except Exception as e:
                # Best effort: one failing query must not discard the results
                # already gathered for the other queries.
                logging.error(f"Error occurs when searching query {query}: {e}")
        return collected_results