diff --git a/README.md b/README.md index e027b3c0..50c8f36c 100644 --- a/README.md +++ b/README.md @@ -93,7 +93,7 @@ You could also install the source code which allows you to modify the behavior o Currently, our package support: - `OpenAIModel`, `AzureOpenAIModel`, `ClaudeModel`, `VLLMClient`, `TGIClient`, `TogetherClient`, `OllamaClient`, `GoogleModel`, `DeepSeekModel`, `GroqModel` as language model components -- `YouRM`, `BingSearch`, `VectorRM`, `SerperRM`, `BraveRM`, `SearXNG`, `DuckDuckGoSearchRM`, `TavilySearchRM`, `GoogleSearch` as retrieval module components +- `YouRM`, `BingSearch`, `VectorRM`, `SerperRM`, `BraveRM`, `SearXNG`, `DuckDuckGoSearchRM`, `TavilySearchRM`, `GoogleSearch`, and `AzureAISearch` as retrieval module components :star2: **PRs for integrating more language models into [knowledge_storm/lm.py](knowledge_storm/lm.py) and search engines/retrievers into [knowledge_storm/rm.py](knowledge_storm/rm.py) are highly appreciated!** diff --git a/examples/storm_examples/run_storm_wiki_gpt.py b/examples/storm_examples/run_storm_wiki_gpt.py index ac079681..b1740a12 100644 --- a/examples/storm_examples/run_storm_wiki_gpt.py +++ b/examples/storm_examples/run_storm_wiki_gpt.py @@ -20,10 +20,11 @@ """ import os + from argparse import ArgumentParser from knowledge_storm import STORMWikiRunnerArguments, STORMWikiRunner, STORMWikiLMConfigs from knowledge_storm.lm import OpenAIModel, AzureOpenAIModel -from knowledge_storm.rm import YouRM, BingSearch, BraveRM, SerperRM, DuckDuckGoSearchRM, TavilySearchRM, SearXNG +from knowledge_storm.rm import YouRM, BingSearch, BraveRM, SerperRM, DuckDuckGoSearchRM, TavilySearchRM, SearXNG, AzureAISearch from knowledge_storm.utils import load_api_key @@ -72,6 +73,7 @@ def main(args): # STORM is a knowledge curation system which consumes information from the retrieval module. # Currently, the information source is the Internet and we use search engine API as the retrieval module. + match args.retriever: case 'bing': rm = BingSearch(bing_search_api=os.getenv('BING_SEARCH_API_KEY'), k=engine_args.search_top_k) @@ -87,8 +89,10 @@ def main(args): rm = TavilySearchRM(tavily_search_api_key=os.getenv('TAVILY_API_KEY'), k=engine_args.search_top_k, include_raw_content=True) case 'searxng': rm = SearXNG(searxng_api_key=os.getenv('SEARXNG_API_KEY'), k=engine_args.search_top_k) + case 'azure_ai_search': + rm = AzureAISearch(azure_ai_search_api_key=os.getenv('AZURE_AI_SEARCH_API_KEY'), k=engine_args.search_top_k) case _: - raise ValueError(f'Invalid retriever: {args.retriever}. Choose either "bing", "you", "brave", "duckduckgo", "serper", "tavily", or "searxng"') + raise ValueError(f'Invalid retriever: {args.retriever}. Choose either "bing", "you", "brave", "duckduckgo", "serper", "tavily", "searxng", or "azure_ai_search"') runner = STORMWikiRunner(engine_args, lm_configs, rm) @@ -113,7 +117,7 @@ def main(args): help='Maximum number of threads to use. The information seeking part and the article generation' 'part can speed up by using multiple threads. Consider reducing it if keep getting ' '"Exceed rate limit" error when calling LM API.') - parser.add_argument('--retriever', type=str, choices=['bing', 'you', 'brave', 'serper', 'duckduckgo', 'tavily', 'searxng'], + parser.add_argument('--retriever', type=str, choices=['bing', 'you', 'brave', 'serper', 'duckduckgo', 'tavily', 'searxng', 'azure_ai_search'], help='The search engine API to use for retrieving information.') # stage of the pipeline parser.add_argument('--do-research', action='store_true', @@ -138,4 +142,4 @@ def main(args): parser.add_argument('--remove-duplicate', action='store_true', help='If True, remove duplicate content from the article.') - main(parser.parse_args()) \ No newline at end of file + main(parser.parse_args()) diff --git a/knowledge_storm/rm.py b/knowledge_storm/rm.py index 7f029e79..ec57d79a 100644 --- a/knowledge_storm/rm.py +++ b/knowledge_storm/rm.py @@ -1093,3 +1093,136 @@ def forward( collected_results.append(r) return collected_results + + +class AzureAISearch(dspy.Retrieve): + """Retrieve information from custom queries using Azure AI Search. + + General Documentation: https://learn.microsoft.com/en-us/azure/search/search-create-service-portal. + Python Documentation: https://learn.microsoft.com/en-us/python/api/overview/azure/search-documents-readme?view=azure-python. + """ + + def __init__( + self, + azure_ai_search_api_key=None, + azure_ai_search_url=None, + azure_ai_search_index_name=None, + k=3, + is_valid_source: Callable = None, + ): + """ + Params: + azure_ai_search_api_key: Azure AI Search API key. Check out https://learn.microsoft.com/en-us/azure/search/search-security-api-keys?tabs=rest-use%2Cportal-find%2Cportal-query + "API key" section + azure_ai_search_url: Custom Azure AI Search Endpoint URL. Check out https://learn.microsoft.com/en-us/azure/search/search-create-service-portal#name-the-service + azure_ai_search_index_name: Custom Azure AI Search Index Name. Check out https://learn.microsoft.com/en-us/azure/search/search-how-to-create-search-index?tabs=portal + k: Number of top results to retrieve. + is_valid_source: Optional function to filter valid sources. + min_char_count: Minimum character count for the article to be considered valid. + snippet_chunk_size: Maximum character count for each snippet. + webpage_helper_max_threads: Maximum number of threads to use for webpage helper. + """ + super().__init__(k=k) + + try: + from azure.core.credentials import AzureKeyCredential + from azure.search.documents import SearchClient + except ImportError as err: + raise ImportError( + "AzureAISearch requires `pip install azure-search-documents`." + ) from err + + if not azure_ai_search_api_key and not os.environ.get( + "AZURE_AI_SEARCH_API_KEY" + ): + raise RuntimeError( + "You must supply azure_ai_search_api_key or set environment variable AZURE_AI_SEARCH_API_KEY" + ) + elif azure_ai_search_api_key: + self.azure_ai_search_api_key = azure_ai_search_api_key + else: + self.azure_ai_search_api_key = os.environ["AZURE_AI_SEARCH_API_KEY"] + + if not azure_ai_search_url and not os.environ.get("AZURE_AI_SEARCH_URL"): + raise RuntimeError( + "You must supply azure_ai_search_url or set environment variable AZURE_AI_SEARCH_URL" + ) + elif azure_ai_search_url: + self.azure_ai_search_url = azure_ai_search_url + else: + self.azure_ai_search_url = os.environ["AZURE_AI_SEARCH_URL"] + + if not azure_ai_search_index_name and not os.environ.get( + "AZURE_AI_SEARCH_INDEX_NAME" + ): + raise RuntimeError( + "You must supply azure_ai_search_index_name or set environment variable AZURE_AI_SEARCH_INDEX_NAME" + ) + elif azure_ai_search_index_name: + self.azure_ai_search_index_name = azure_ai_search_index_name + else: + self.azure_ai_search_index_name = os.environ["AZURE_AI_SEARCH_INDEX_NAME"] + + self.usage = 0 + + # If not None, is_valid_source shall be a function that takes a URL and returns a boolean. + if is_valid_source: + self.is_valid_source = is_valid_source + else: + self.is_valid_source = lambda x: True + + def get_usage_and_reset(self): + usage = self.usage + self.usage = 0 + + return {"AzureAISearch": usage} + + def forward( + self, query_or_queries: Union[str, List[str]], exclude_urls: List[str] = [] + ): + """Search with Azure Open AI for self.k top passages for query or queries + + Args: + query_or_queries (Union[str, List[str]]): The query or queries to search for. + exclude_urls (List[str]): A list of urls to exclude from the search results. + + Returns: + a list of Dicts, each dict has keys of 'description', 'snippets' (list of strings), 'title', 'url' + """ + try: + from azure.core.credentials import AzureKeyCredential + from azure.search.documents import SearchClient + except ImportError as err: + raise ImportError( + "AzureAISearch requires `pip install azure-search-documents`." + ) from err + queries = ( + [query_or_queries] + if isinstance(query_or_queries, str) + else query_or_queries + ) + self.usage += len(queries) + collected_results = [] + + client = SearchClient( + self.azure_ai_search_url, + self.azure_ai_search_index_name, + AzureKeyCredential(self.azure_ai_search_api_key), + ) + for query in queries: + try: + # https://learn.microsoft.com/en-us/python/api/azure-search-documents/azure.search.documents.searchclient?view=azure-python#azure-search-documents-searchclient-search + results = client.search(search_text=query, top=1) + + for result in results: + document = { + "url": result["metadata_storage_path"], + "title": result["title"], + "description": "N/A", + "snippets": [result["chunk"]], + } + collected_results.append(document) + except Exception as e: + logging.error(f"Error occurs when searching query {query}: {e}") + + return collected_results