diff --git a/examples/document-search/from_config.py b/examples/document-search/from_config.py index 89ab7e5b..2253d518 100644 --- a/examples/document-search/from_config.py +++ b/examples/document-search/from_config.py @@ -32,6 +32,18 @@ }, "reranker": {"type": "ragbits.document_search.retrieval.rerankers.noop:NoopReranker"}, "providers": {"txt": {"type": "DummyProvider"}}, + "rephraser": { + "type": "LLMQueryRephraser", + "config": { + "llm": { + "type": "LiteLLM", + "config": { + "model_name": "gpt-4-turbo", + }, + }, + "prompt": "QueryRephraserPrompt", + }, + }, } diff --git a/packages/ragbits-core/src/ragbits/core/llms/__init__.py b/packages/ragbits-core/src/ragbits/core/llms/__init__.py index 92a6733d..c28584d9 100644 --- a/packages/ragbits-core/src/ragbits/core/llms/__init__.py +++ b/packages/ragbits-core/src/ragbits/core/llms/__init__.py @@ -1,5 +1,41 @@ +import sys + from .base import LLM from .litellm import LiteLLM from .local import LocalLLM __all__ = ["LLM", "LiteLLM", "LocalLLM"] + + +module = sys.modules[__name__] + + +def get_llm(config: dict) -> LLM: + """ + Initializes and returns an LLM object based on the provided configuration. + + Args: + config : A dictionary containing configuration details for the LLM. + + Returns: + An instance of the specified LLM class, initialized with the provided config + (if any) or default arguments. + + Raises: + KeyError: If the configuration dictionary does not contain a "type" key. + ValueError: If the LLM class is not a subclass of LLM. + """ + llm_type = config["type"] + llm_config = config.get("config", {}) + default_options = llm_config.pop("default_options", None) + + llm_cls = getattr(module, llm_type) + + if not issubclass(llm_cls, LLM): + raise ValueError(f"Invalid LLM class: {llm_cls}") + + # We need to infer the options class from the LLM class. + # pylint: disable=protected-access + options = llm_cls._options_cls(**default_options) if default_options else None # type: ignore + + return llm_cls(**llm_config, default_options=options) diff --git a/packages/ragbits-document-search/src/ragbits/document_search/_main.py b/packages/ragbits-document-search/src/ragbits/document_search/_main.py index 8a72f174..1881c587 100644 --- a/packages/ragbits-document-search/src/ragbits/document_search/_main.py +++ b/packages/ragbits-document-search/src/ragbits/document_search/_main.py @@ -95,7 +95,7 @@ async def search(self, query: str, search_config: SearchConfig = SearchConfig()) Returns: A list of chunks. """ - queries = self.query_rephraser.rephrase(query) + queries = await self.query_rephraser.rephrase(query) elements = [] for rephrased_query in queries: search_vector = await self.embedder.embed_text([rephrased_query]) diff --git a/packages/ragbits-document-search/src/ragbits/document_search/retrieval/rephrasers/__init__.py b/packages/ragbits-document-search/src/ragbits/document_search/retrieval/rephrasers/__init__.py index b136d4f3..d26afd20 100644 --- a/packages/ragbits-document-search/src/ragbits/document_search/retrieval/rephrasers/__init__.py +++ b/packages/ragbits-document-search/src/ragbits/document_search/retrieval/rephrasers/__init__.py @@ -1,32 +1,44 @@ import sys -from typing import Optional from ragbits.core.utils.config_handling import get_cls_from_config - -from .base import QueryRephraser -from .noop import NoopQueryRephraser - -__all__ = ["NoopQueryRephraser", "QueryRephraser"] +from ragbits.document_search.retrieval.rephrasers.base import QueryRephraser +from ragbits.document_search.retrieval.rephrasers.llm import LLMQueryRephraser +from ragbits.document_search.retrieval.rephrasers.noop import NoopQueryRephraser +from ragbits.document_search.retrieval.rephrasers.prompts import QueryRephraserInput, QueryRephraserPrompt + +__all__ = [ + "get_rephraser", + "QueryRephraser", + "NoopQueryRephraser", + "LLMQueryRephraser", + "QueryRephraserPrompt", + "QueryRephraserInput", +] module = sys.modules[__name__] -def get_rephraser(rephraser_config: Optional[dict]) -> QueryRephraser: +def get_rephraser(config: dict | None = None) -> QueryRephraser: """ Initializes and returns a QueryRephraser object based on the provided configuration. Args: - rephraser_config: A dictionary containing configuration details for the QueryRephraser. + config: A dictionary containing configuration details for the QueryRephraser. Returns: An instance of the specified QueryRephraser class, initialized with the provided config (if any) or default arguments. - """ - if rephraser_config is None: + Raises: + KeyError: If the configuration dictionary does not contain a "type" key. + ValueError: If an invalid rephraser class is specified in the configuration. + """ + if config is None: return NoopQueryRephraser() - rephraser_cls = get_cls_from_config(rephraser_config["type"], module) - config = rephraser_config.get("config", {}) + rephraser_cls = get_cls_from_config(config["type"], module) + + if not issubclass(rephraser_cls, QueryRephraser): + raise ValueError(f"Invalid rephraser class: {rephraser_cls}") - return rephraser_cls(**config) + return rephraser_cls.from_config(config.get("config", {})) diff --git a/packages/ragbits-document-search/src/ragbits/document_search/retrieval/rephrasers/base.py b/packages/ragbits-document-search/src/ragbits/document_search/retrieval/rephrasers/base.py index a40b9f9b..3eba3820 100644 --- a/packages/ragbits-document-search/src/ragbits/document_search/retrieval/rephrasers/base.py +++ b/packages/ragbits-document-search/src/ragbits/document_search/retrieval/rephrasers/base.py @@ -1,14 +1,13 @@ -import abc +from abc import ABC, abstractmethod -class QueryRephraser(abc.ABC): +class QueryRephraser(ABC): """ Rephrases a query. Can provide multiple rephrased queries from one sentence / question. """ - @staticmethod - @abc.abstractmethod - def rephrase(query: str) -> list[str]: + @abstractmethod + async def rephrase(self, query: str) -> list[str]: """ Rephrase a query. @@ -18,3 +17,16 @@ def rephrase(query: str) -> list[str]: Returns: The rephrased queries. """ + + @classmethod + def from_config(cls, config: dict) -> "QueryRephraser": + """ + Create an instance of `QueryRephraser` from a configuration dictionary. + + Args: + config: A dictionary containing configuration settings for the rephraser. + + Returns: + An instance of the rephraser class initialized with the provided configuration. + """ + return cls(**config) diff --git a/packages/ragbits-document-search/src/ragbits/document_search/retrieval/rephrasers/llm.py b/packages/ragbits-document-search/src/ragbits/document_search/retrieval/rephrasers/llm.py new file mode 100644 index 00000000..8398b2a5 --- /dev/null +++ b/packages/ragbits-document-search/src/ragbits/document_search/retrieval/rephrasers/llm.py @@ -0,0 +1,67 @@ +from typing import Any + +from ragbits.core.llms import get_llm +from ragbits.core.llms.base import LLM +from ragbits.core.prompt import Prompt +from ragbits.document_search.retrieval.rephrasers.base import QueryRephraser +from ragbits.document_search.retrieval.rephrasers.prompts import ( + QueryRephraserInput, + QueryRephraserPrompt, + get_rephraser_prompt, +) + + +class LLMQueryRephraser(QueryRephraser): + """ + A rephraser class that uses a LLM to rephrase queries. + """ + + def __init__(self, llm: LLM, prompt: type[Prompt[QueryRephraserInput, Any]] | None = None): + """ + Initialize the LLMQueryRephraser with a LLM. + + Args: + llm: A LLM instance to handle query rephrasing. + prompt: The prompt to use for rephrasing queries. + """ + self._llm = llm + self._prompt = prompt or QueryRephraserPrompt + + async def rephrase(self, query: str) -> list[str]: + """ + Rephrase a given query using the LLM. + + Args: + query: The query to be rephrased. If not provided, a custom prompt must be given. + + Returns: + A list containing the rephrased query. + + Raises: + LLMConnectionError: If there is a connection error with the LLM API. + LLMStatusError: If the LLM API returns an error status code. + LLMResponseError: If the LLM API response is invalid. + """ + input_data = self._prompt.input_type(query=query) # type: ignore + prompt = self._prompt(input_data) + response = await self._llm.generate(prompt) + return response if isinstance(response, list) else [response] + + @classmethod + def from_config(cls, config: dict) -> "LLMQueryRephraser": + """ + Create an instance of `LLMQueryRephraser` from a configuration dictionary. + + Args: + config: A dictionary containing configuration settings for the rephraser. + + Returns: + An instance of the rephraser class initialized with the provided configuration. + + Raises: + KeyError: If the configuration dictionary does not contain the required keys. + ValueError: If the prompt class is not a subclass of `Prompt` or the LLM class is not a subclass of `LLM`. + """ + llm = get_llm(config["llm"]) + prompt_cls = get_rephraser_prompt(prompt) if (prompt := config.get("prompt")) else None + return cls(llm=llm, prompt=prompt_cls) diff --git a/packages/ragbits-document-search/src/ragbits/document_search/retrieval/rephrasers/noop.py b/packages/ragbits-document-search/src/ragbits/document_search/retrieval/rephrasers/noop.py index 8e6b92fd..b48bd7de 100644 --- a/packages/ragbits-document-search/src/ragbits/document_search/retrieval/rephrasers/noop.py +++ b/packages/ragbits-document-search/src/ragbits/document_search/retrieval/rephrasers/noop.py @@ -6,8 +6,7 @@ class NoopQueryRephraser(QueryRephraser): A no-op query paraphraser that does not change the query. """ - @staticmethod - def rephrase(query: str) -> list[str]: + async def rephrase(self, query: str) -> list[str]: """ Mock implementation which outputs the same query as in input. diff --git a/packages/ragbits-document-search/src/ragbits/document_search/retrieval/rephrasers/prompts.py b/packages/ragbits-document-search/src/ragbits/document_search/retrieval/rephrasers/prompts.py new file mode 100644 index 00000000..1f1e0c90 --- /dev/null +++ b/packages/ragbits-document-search/src/ragbits/document_search/retrieval/rephrasers/prompts.py @@ -0,0 +1,54 @@ +import sys +from typing import Any + +from pydantic import BaseModel + +from ragbits.core.prompt.prompt import Prompt +from ragbits.core.utils.config_handling import get_cls_from_config + +module = sys.modules[__name__] + + +class QueryRephraserInput(BaseModel): + """ + Input data for the query rephraser prompt. + """ + + query: str + + +class QueryRephraserPrompt(Prompt[QueryRephraserInput, str]): + """ + A prompt class for generating a rephrased version of a user's query using a LLM. + """ + + user_prompt = "{{ query }}" + system_prompt = ( + "You are an expert in query rephrasing and clarity improvement. " + "Your task is to return a single paraphrased version of a user's query, " + "correcting any typos, handling abbreviations and improving clarity. " + "Focus on making the query more precise and readable while keeping its original intent.\n\n" + "Just return the rephrased query. No additional explanations are needed." + ) + + +def get_rephraser_prompt(prompt: str) -> type[Prompt[QueryRephraserInput, Any]]: + """ + Initializes and returns a QueryRephraser object based on the provided configuration. + + Args: + prompt: The prompt class to use for rephrasing queries. + + Returns: + An instance of the specified QueryRephraser class, initialized with the provided config + (if any) or default arguments. + + Raises: + ValueError: If the prompt class is not a subclass of `Prompt`. + """ + prompt_cls = get_cls_from_config(prompt, module) + + if not issubclass(prompt_cls, Prompt): + raise ValueError(f"Invalid rephraser prompt class: {prompt_cls}") + + return prompt_cls