From 247d27c51b49fef94c3497dadbf35517521e9c98 Mon Sep 17 00:00:00 2001 From: "alicja.kotyla" Date: Thu, 17 Oct 2024 10:33:44 +0200 Subject: [PATCH 1/6] Add LLMQueryRephraser --- .../src/ragbits/core/llms/__init__.py | 22 +++++++ .../examples/from_config_example.py | 1 + .../src/ragbits/document_search/_main.py | 2 +- .../retrieval/rephrasers/__init__.py | 5 +- .../retrieval/rephrasers/base.py | 21 +++++- .../retrieval/rephrasers/llm.py | 64 +++++++++++++++++++ .../retrieval/rephrasers/noop.py | 9 ++- .../rephrasers/prompt_query_rephraser.py | 22 +++++++ 8 files changed, 139 insertions(+), 7 deletions(-) create mode 100644 packages/ragbits-document-search/src/ragbits/document_search/retrieval/rephrasers/llm.py create mode 100644 packages/ragbits-document-search/src/ragbits/document_search/retrieval/rephrasers/prompt_query_rephraser.py diff --git a/packages/ragbits-core/src/ragbits/core/llms/__init__.py b/packages/ragbits-core/src/ragbits/core/llms/__init__.py index 92a6733d..f8fb685c 100644 --- a/packages/ragbits-core/src/ragbits/core/llms/__init__.py +++ b/packages/ragbits-core/src/ragbits/core/llms/__init__.py @@ -1,5 +1,27 @@ +import sys + from .base import LLM from .litellm import LiteLLM from .local import LocalLLM __all__ = ["LLM", "LiteLLM", "LocalLLM"] + + +module = sys.modules[__name__] + + +def get_llm(llm_config: dict) -> LLM: + """ + Initializes and returns an LLM object based on the provided LLM configuration. + + Args: + llm_config : A dictionary containing configuration details for the LLM. + + Returns: + An instance of the specified LLM class, initialized with the provided config + (if any) or default arguments. + """ + llm_type = llm_config["type"] + config = llm_config.get("config", {}) + + return getattr(module, llm_type)(**config) diff --git a/packages/ragbits-document-search/examples/from_config_example.py b/packages/ragbits-document-search/examples/from_config_example.py index 1599cf84..a229e10f 100644 --- a/packages/ragbits-document-search/examples/from_config_example.py +++ b/packages/ragbits-document-search/examples/from_config_example.py @@ -32,6 +32,7 @@ }, "reranker": {"type": "ragbits.document_search.retrieval.rerankers.noop:NoopReranker"}, "providers": {"txt": {"type": "DummyProvider"}}, + "rephraser": {"type": "LLMQueryRephraser", "config": {"llm": {"type": "LiteLLM"}}}, } diff --git a/packages/ragbits-document-search/src/ragbits/document_search/_main.py b/packages/ragbits-document-search/src/ragbits/document_search/_main.py index fe3f826c..6b8f7109 100644 --- a/packages/ragbits-document-search/src/ragbits/document_search/_main.py +++ b/packages/ragbits-document-search/src/ragbits/document_search/_main.py @@ -94,7 +94,7 @@ async def search(self, query: str, search_config: SearchConfig = SearchConfig()) Returns: A list of chunks. """ - queries = self.query_rephraser.rephrase(query) + queries = await self.query_rephraser.rephrase(query) elements = [] for rephrased_query in queries: search_vector = await self.embedder.embed_text([rephrased_query]) diff --git a/packages/ragbits-document-search/src/ragbits/document_search/retrieval/rephrasers/__init__.py b/packages/ragbits-document-search/src/ragbits/document_search/retrieval/rephrasers/__init__.py index b136d4f3..0ec89f25 100644 --- a/packages/ragbits-document-search/src/ragbits/document_search/retrieval/rephrasers/__init__.py +++ b/packages/ragbits-document-search/src/ragbits/document_search/retrieval/rephrasers/__init__.py @@ -4,9 +4,10 @@ from ragbits.core.utils.config_handling import get_cls_from_config from .base import QueryRephraser +from .llm import LLMQueryRephraser from .noop import NoopQueryRephraser -__all__ = ["NoopQueryRephraser", "QueryRephraser"] +__all__ = ["NoopQueryRephraser", "QueryRephraser", "LLMQueryRephraser"] module = sys.modules[__name__] @@ -29,4 +30,4 @@ def get_rephraser(rephraser_config: Optional[dict]) -> QueryRephraser: rephraser_cls = get_cls_from_config(rephraser_config["type"], module) config = rephraser_config.get("config", {}) - return rephraser_cls(**config) + return rephraser_cls.from_config(config) diff --git a/packages/ragbits-document-search/src/ragbits/document_search/retrieval/rephrasers/base.py b/packages/ragbits-document-search/src/ragbits/document_search/retrieval/rephrasers/base.py index a40b9f9b..bf8ff7d0 100644 --- a/packages/ragbits-document-search/src/ragbits/document_search/retrieval/rephrasers/base.py +++ b/packages/ragbits-document-search/src/ragbits/document_search/retrieval/rephrasers/base.py @@ -1,4 +1,7 @@ import abc +from typing import Optional + +from ragbits.core.llms.base import LLMOptions class QueryRephraser(abc.ABC): @@ -6,15 +9,29 @@ class QueryRephraser(abc.ABC): Rephrases a query. Can provide multiple rephrased queries from one sentence / question. """ - @staticmethod @abc.abstractmethod - def rephrase(query: str) -> list[str]: + async def rephrase(self, query: str, options: Optional[LLMOptions] = None) -> list[str]: """ Rephrase a query. Args: query: The query to rephrase. + options: OptionaL options to fine-tune the rephraser behavior. Returns: The rephrased queries. """ + + @classmethod + def from_config(cls, config: dict) -> "QueryRephraser": + """ + Create an instance of `QueryRephraser` from a configuration dictionary. + + Args: + config: A dictionary containing configuration settings for the rephraser. + + Returns: + An instance of the rephraser class initialized with the provided configuration. + """ + + return cls(**config) diff --git a/packages/ragbits-document-search/src/ragbits/document_search/retrieval/rephrasers/llm.py b/packages/ragbits-document-search/src/ragbits/document_search/retrieval/rephrasers/llm.py new file mode 100644 index 00000000..2521cda1 --- /dev/null +++ b/packages/ragbits-document-search/src/ragbits/document_search/retrieval/rephrasers/llm.py @@ -0,0 +1,64 @@ +from typing import Optional + +from ragbits.core.llms import get_llm +from ragbits.core.llms.base import LLM, LLMOptions +from ragbits.core.utils.config_handling import get_cls_from_config +from ragbits.document_search.retrieval import rephrasers +from ragbits.document_search.retrieval.rephrasers.base import QueryRephraser +from ragbits.document_search.retrieval.rephrasers.prompt_query_rephraser import QueryRephraserPrompt, _PromptInput + + +class LLMQueryRephraser(QueryRephraser): + """A rephraser class that uses a LLM to rephrase queries.""" + + def __init__(self, llm: LLM, prompt: Optional[QueryRephraserPrompt] = None): + """ + Initialize the LLMQueryRephraser with a LLM and an optional prompt. + + Args: + llm: A LLM instance to handle query rephrasing. + prompt: A prompt defining how the rephrasing should be done. + If not provided, the default `QueryRephraserPrompt` is used. + """ + + self._prompt = prompt or QueryRephraserPrompt + self._llm = llm + + async def rephrase(self, query: str, options: Optional[LLMOptions] = None) -> list[str]: + """ + Rephrase a given query using the LLM. + + Args: + query: The query to be rephrased. + options: OptionaL LLM options to fine-tune the generation behavior. + + Returns: + A list containing the rephrased query. + """ + + prompt = QueryRephraserPrompt(_PromptInput(query=query)) + response = await self._llm.generate(prompt, options=options) + + return [response] + + @classmethod + def from_config(cls, config: dict) -> "LLMQueryRephraser": + """ + Create an instance of `LLMQueryRephraser` from a configuration dictionary. + + Args: + config: A dictionary containing configuration settings for the rephraser. + + Returns: + An instance of the rephraser class initialized with the provided configuration. + """ + + llm = get_llm(config["llm"]) + + prompt_config = config.get("prompt") + + if prompt_config: + prompt = get_cls_from_config(prompt_config["type"], rephrasers) + return cls(llm=llm, prompt=prompt) + + return cls(llm=llm) diff --git a/packages/ragbits-document-search/src/ragbits/document_search/retrieval/rephrasers/noop.py b/packages/ragbits-document-search/src/ragbits/document_search/retrieval/rephrasers/noop.py index 8e6b92fd..85748ca2 100644 --- a/packages/ragbits-document-search/src/ragbits/document_search/retrieval/rephrasers/noop.py +++ b/packages/ragbits-document-search/src/ragbits/document_search/retrieval/rephrasers/noop.py @@ -1,3 +1,6 @@ +from typing import Optional + +from ragbits.core.llms.base import LLMOptions from ragbits.document_search.retrieval.rephrasers.base import QueryRephraser @@ -6,13 +9,15 @@ class NoopQueryRephraser(QueryRephraser): A no-op query paraphraser that does not change the query. """ - @staticmethod - def rephrase(query: str) -> list[str]: + async def rephrase( + self, query: str, options: Optional[LLMOptions] = None # pylint: disable=unused-argument + ) -> list[str]: """ Mock implementation which outputs the same query as in input. Args: query: The query to rephrase. + options: OptionaL options to fine-tune the rephraser behavior. Returns: The list with non-transformed query. diff --git a/packages/ragbits-document-search/src/ragbits/document_search/retrieval/rephrasers/prompt_query_rephraser.py b/packages/ragbits-document-search/src/ragbits/document_search/retrieval/rephrasers/prompt_query_rephraser.py new file mode 100644 index 00000000..80499b76 --- /dev/null +++ b/packages/ragbits-document-search/src/ragbits/document_search/retrieval/rephrasers/prompt_query_rephraser.py @@ -0,0 +1,22 @@ +import pydantic + +from ragbits.core.prompt.prompt import Prompt + + +class _PromptInput(pydantic.BaseModel): + query: str + + +class QueryRephraserPrompt(Prompt[_PromptInput, str]): + """ + A prompt class for generating a rephrased version of a user's query using a LLM. + """ + + user_prompt = "{{ query }}" + system_prompt = ( + "You are an expert in query rephrasing and clarity improvement. " + "Your task is to return a single paraphrased version of a user's query, " + "correcting any typos, handling abbreviations and improving clarity. " + "Focus on making the query more precise and readable while keeping its original intent.\n\n" + "Just return the rephrased query. No additional explanations are needed." + ) From 8154a205037707b73dfe3ec1da71904047eb4e22 Mon Sep 17 00:00:00 2001 From: "alicja.kotyla" Date: Thu, 17 Oct 2024 11:52:07 +0200 Subject: [PATCH 2/6] Fixes after review --- .../retrieval/rephrasers/base.py | 8 +++- .../retrieval/rephrasers/llm.py | 40 ++++++++++--------- .../retrieval/rephrasers/noop.py | 16 +++++++- 3 files changed, 42 insertions(+), 22 deletions(-) diff --git a/packages/ragbits-document-search/src/ragbits/document_search/retrieval/rephrasers/base.py b/packages/ragbits-document-search/src/ragbits/document_search/retrieval/rephrasers/base.py index bf8ff7d0..6b15a36c 100644 --- a/packages/ragbits-document-search/src/ragbits/document_search/retrieval/rephrasers/base.py +++ b/packages/ragbits-document-search/src/ragbits/document_search/retrieval/rephrasers/base.py @@ -2,6 +2,7 @@ from typing import Optional from ragbits.core.llms.base import LLMOptions +from ragbits.core.prompt import Prompt class QueryRephraser(abc.ABC): @@ -10,13 +11,16 @@ class QueryRephraser(abc.ABC): """ @abc.abstractmethod - async def rephrase(self, query: str, options: Optional[LLMOptions] = None) -> list[str]: + async def rephrase( + self, query: Optional[str] = None, prompt: Optional[Prompt] = None, options: Optional[LLMOptions] = None + ) -> list[str]: """ Rephrase a query. Args: query: The query to rephrase. - options: OptionaL options to fine-tune the rephraser behavior. + options: Optional configuration of the the rephraser behavior. + prompt: Optional prompt. Returns: The rephrased queries. diff --git a/packages/ragbits-document-search/src/ragbits/document_search/retrieval/rephrasers/llm.py b/packages/ragbits-document-search/src/ragbits/document_search/retrieval/rephrasers/llm.py index 2521cda1..a3d4929b 100644 --- a/packages/ragbits-document-search/src/ragbits/document_search/retrieval/rephrasers/llm.py +++ b/packages/ragbits-document-search/src/ragbits/document_search/retrieval/rephrasers/llm.py @@ -2,8 +2,7 @@ from ragbits.core.llms import get_llm from ragbits.core.llms.base import LLM, LLMOptions -from ragbits.core.utils.config_handling import get_cls_from_config -from ragbits.document_search.retrieval import rephrasers +from ragbits.core.prompt import Prompt from ragbits.document_search.retrieval.rephrasers.base import QueryRephraser from ragbits.document_search.retrieval.rephrasers.prompt_query_rephraser import QueryRephraserPrompt, _PromptInput @@ -11,33 +10,44 @@ class LLMQueryRephraser(QueryRephraser): """A rephraser class that uses a LLM to rephrase queries.""" - def __init__(self, llm: LLM, prompt: Optional[QueryRephraserPrompt] = None): + def __init__(self, llm: LLM): """ - Initialize the LLMQueryRephraser with a LLM and an optional prompt. + Initialize the LLMQueryRephraser with a LLM. Args: llm: A LLM instance to handle query rephrasing. - prompt: A prompt defining how the rephrasing should be done. - If not provided, the default `QueryRephraserPrompt` is used. """ - self._prompt = prompt or QueryRephraserPrompt self._llm = llm - async def rephrase(self, query: str, options: Optional[LLMOptions] = None) -> list[str]: + async def rephrase( + self, query: Optional[str] = None, prompt: Optional[Prompt] = None, options: Optional[LLMOptions] = None + ) -> list[str]: """ Rephrase a given query using the LLM. Args: - query: The query to be rephrased. - options: OptionaL LLM options to fine-tune the generation behavior. + query: The query to be rephrased. If not provided, a custom prompt must be given. + prompt: A prompt object defining how the rephrasing should be done. + If not provided, the default `QueryRephraserPrompt` is used, along with the provided query. + options: Optional settings for the LLM to control generation behavior. Returns: A list containing the rephrased query. + + Raises: + ValueError: If both `query` and `prompt` are None. """ - prompt = QueryRephraserPrompt(_PromptInput(query=query)) - response = await self._llm.generate(prompt, options=options) + if query is None and prompt is None: + raise ValueError("Either `query` or `prompt` must be provided.") + + if prompt is not None: + response = await self._llm.generate(prompt, options=options) + + else: + assert isinstance(query, str) + response = await self._llm.generate(QueryRephraserPrompt(_PromptInput(query=query)), options=options) return [response] @@ -55,10 +65,4 @@ def from_config(cls, config: dict) -> "LLMQueryRephraser": llm = get_llm(config["llm"]) - prompt_config = config.get("prompt") - - if prompt_config: - prompt = get_cls_from_config(prompt_config["type"], rephrasers) - return cls(llm=llm, prompt=prompt) - return cls(llm=llm) diff --git a/packages/ragbits-document-search/src/ragbits/document_search/retrieval/rephrasers/noop.py b/packages/ragbits-document-search/src/ragbits/document_search/retrieval/rephrasers/noop.py index 85748ca2..068aa62d 100644 --- a/packages/ragbits-document-search/src/ragbits/document_search/retrieval/rephrasers/noop.py +++ b/packages/ragbits-document-search/src/ragbits/document_search/retrieval/rephrasers/noop.py @@ -1,6 +1,7 @@ from typing import Optional from ragbits.core.llms.base import LLMOptions +from ragbits.core.prompt import Prompt from ragbits.document_search.retrieval.rephrasers.base import QueryRephraser @@ -10,16 +11,27 @@ class NoopQueryRephraser(QueryRephraser): """ async def rephrase( - self, query: str, options: Optional[LLMOptions] = None # pylint: disable=unused-argument + self, + query: Optional[str] = None, + prompt: Optional[Prompt] = None, # pylint: disable=unused-argument + options: Optional[LLMOptions] = None, # pylint: disable=unused-argument ) -> list[str]: """ Mock implementation which outputs the same query as in input. Args: query: The query to rephrase. - options: OptionaL options to fine-tune the rephraser behavior. + options: Optional configuration of the the rephraser behavior. + prompt: Optional prompt. Returns: The list with non-transformed query. + + Raises: + ValueError: If both `query` and `prompt` are None. """ + + if not isinstance(query, str): + raise ValueError("`query` must be provided.") + return [query] From 0bfac01743f438fbf6ce5710b34ae1b13a6ed319 Mon Sep 17 00:00:00 2001 From: "alicja.kotyla" Date: Mon, 21 Oct 2024 14:38:45 +0200 Subject: [PATCH 3/6] Fix after comment --- .../retrieval/rephrasers/base.py | 6 +--- .../retrieval/rephrasers/llm.py | 30 +++++++++---------- .../retrieval/rephrasers/noop.py | 5 +--- 3 files changed, 17 insertions(+), 24 deletions(-) diff --git a/packages/ragbits-document-search/src/ragbits/document_search/retrieval/rephrasers/base.py b/packages/ragbits-document-search/src/ragbits/document_search/retrieval/rephrasers/base.py index 6b15a36c..5fda9947 100644 --- a/packages/ragbits-document-search/src/ragbits/document_search/retrieval/rephrasers/base.py +++ b/packages/ragbits-document-search/src/ragbits/document_search/retrieval/rephrasers/base.py @@ -2,7 +2,6 @@ from typing import Optional from ragbits.core.llms.base import LLMOptions -from ragbits.core.prompt import Prompt class QueryRephraser(abc.ABC): @@ -11,16 +10,13 @@ class QueryRephraser(abc.ABC): """ @abc.abstractmethod - async def rephrase( - self, query: Optional[str] = None, prompt: Optional[Prompt] = None, options: Optional[LLMOptions] = None - ) -> list[str]: + async def rephrase(self, query: str, options: Optional[LLMOptions] = None) -> list[str]: """ Rephrase a query. Args: query: The query to rephrase. options: Optional configuration of the the rephraser behavior. - prompt: Optional prompt. Returns: The rephrased queries. diff --git a/packages/ragbits-document-search/src/ragbits/document_search/retrieval/rephrasers/llm.py b/packages/ragbits-document-search/src/ragbits/document_search/retrieval/rephrasers/llm.py index a3d4929b..f362003d 100644 --- a/packages/ragbits-document-search/src/ragbits/document_search/retrieval/rephrasers/llm.py +++ b/packages/ragbits-document-search/src/ragbits/document_search/retrieval/rephrasers/llm.py @@ -3,14 +3,16 @@ from ragbits.core.llms import get_llm from ragbits.core.llms.base import LLM, LLMOptions from ragbits.core.prompt import Prompt +from ragbits.core.utils.config_handling import get_cls_from_config +from ragbits.document_search.retrieval import rephrasers from ragbits.document_search.retrieval.rephrasers.base import QueryRephraser -from ragbits.document_search.retrieval.rephrasers.prompt_query_rephraser import QueryRephraserPrompt, _PromptInput +from ragbits.document_search.retrieval.rephrasers.prompt_query_rephraser import QueryRephraserPrompt class LLMQueryRephraser(QueryRephraser): """A rephraser class that uses a LLM to rephrase queries.""" - def __init__(self, llm: LLM): + def __init__(self, llm: LLM, prompt_strategy: Optional[type[Prompt]] = None): """ Initialize the LLMQueryRephraser with a LLM. @@ -19,17 +21,14 @@ def __init__(self, llm: LLM): """ self._llm = llm + self._prompt_strategy = prompt_strategy or QueryRephraserPrompt - async def rephrase( - self, query: Optional[str] = None, prompt: Optional[Prompt] = None, options: Optional[LLMOptions] = None - ) -> list[str]: + async def rephrase(self, query: str, options: Optional[LLMOptions] = None) -> list[str]: """ Rephrase a given query using the LLM. Args: query: The query to be rephrased. If not provided, a custom prompt must be given. - prompt: A prompt object defining how the rephrasing should be done. - If not provided, the default `QueryRephraserPrompt` is used, along with the provided query. options: Optional settings for the LLM to control generation behavior. Returns: @@ -39,15 +38,10 @@ async def rephrase( ValueError: If both `query` and `prompt` are None. """ - if query is None and prompt is None: - raise ValueError("Either `query` or `prompt` must be provided.") + prompt_inputs = self._prompt_strategy.input_type(query=query) # type: ignore + prompt = self._prompt_strategy(prompt_inputs) - if prompt is not None: - response = await self._llm.generate(prompt, options=options) - - else: - assert isinstance(query, str) - response = await self._llm.generate(QueryRephraserPrompt(_PromptInput(query=query)), options=options) + response = await self._llm.generate(prompt, options=options) return [response] @@ -64,5 +58,11 @@ def from_config(cls, config: dict) -> "LLMQueryRephraser": """ llm = get_llm(config["llm"]) + prompt_strategy = config.get("prompt_strategy") + + if prompt_strategy is not None: + prompt_strategy_cls = get_cls_from_config(prompt_strategy, rephrasers) + + return cls(llm=llm, prompt_strategy=prompt_strategy_cls) return cls(llm=llm) diff --git a/packages/ragbits-document-search/src/ragbits/document_search/retrieval/rephrasers/noop.py b/packages/ragbits-document-search/src/ragbits/document_search/retrieval/rephrasers/noop.py index 068aa62d..3d47331a 100644 --- a/packages/ragbits-document-search/src/ragbits/document_search/retrieval/rephrasers/noop.py +++ b/packages/ragbits-document-search/src/ragbits/document_search/retrieval/rephrasers/noop.py @@ -1,7 +1,6 @@ from typing import Optional from ragbits.core.llms.base import LLMOptions -from ragbits.core.prompt import Prompt from ragbits.document_search.retrieval.rephrasers.base import QueryRephraser @@ -12,8 +11,7 @@ class NoopQueryRephraser(QueryRephraser): async def rephrase( self, - query: Optional[str] = None, - prompt: Optional[Prompt] = None, # pylint: disable=unused-argument + query: str, options: Optional[LLMOptions] = None, # pylint: disable=unused-argument ) -> list[str]: """ @@ -22,7 +20,6 @@ async def rephrase( Args: query: The query to rephrase. options: Optional configuration of the the rephraser behavior. - prompt: Optional prompt. Returns: The list with non-transformed query. From 3a60c45ec7759dc119323cd8da03a2d15c026eb1 Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?Micha=C5=82=20Pstr=C4=85g?= Date: Wed, 23 Oct 2024 04:36:20 +0200 Subject: [PATCH 4/6] refactor --- examples/document-search/from_config.py | 13 ++++- .../src/ragbits/core/llms/__init__.py | 26 +++++++-- .../retrieval/rephrasers/__init__.py | 42 +++++++++----- .../retrieval/rephrasers/base.py | 13 ++--- .../retrieval/rephrasers/llm.py | 55 +++++++++---------- .../retrieval/rephrasers/noop.py | 17 +----- .../rephrasers/prompt_query_rephraser.py | 38 ++++++++++++- 7 files changed, 127 insertions(+), 77 deletions(-) diff --git a/examples/document-search/from_config.py b/examples/document-search/from_config.py index a229e10f..ac2d90d1 100644 --- a/examples/document-search/from_config.py +++ b/examples/document-search/from_config.py @@ -32,7 +32,18 @@ }, "reranker": {"type": "ragbits.document_search.retrieval.rerankers.noop:NoopReranker"}, "providers": {"txt": {"type": "DummyProvider"}}, - "rephraser": {"type": "LLMQueryRephraser", "config": {"llm": {"type": "LiteLLM"}}}, + "rephraser": { + "type": "LLMQueryRephraser", + "config": { + "llm": { + "type": "LiteLLM", + "config": { + "model_name": "gpt-4-turbo", + }, + }, + "prompt": "QueryRephraserPrompt", + }, + }, } diff --git a/packages/ragbits-core/src/ragbits/core/llms/__init__.py b/packages/ragbits-core/src/ragbits/core/llms/__init__.py index f8fb685c..c28584d9 100644 --- a/packages/ragbits-core/src/ragbits/core/llms/__init__.py +++ b/packages/ragbits-core/src/ragbits/core/llms/__init__.py @@ -10,18 +10,32 @@ module = sys.modules[__name__] -def get_llm(llm_config: dict) -> LLM: +def get_llm(config: dict) -> LLM: """ - Initializes and returns an LLM object based on the provided LLM configuration. + Initializes and returns an LLM object based on the provided configuration. Args: - llm_config : A dictionary containing configuration details for the LLM. + config : A dictionary containing configuration details for the LLM. Returns: An instance of the specified LLM class, initialized with the provided config (if any) or default arguments. + + Raises: + KeyError: If the configuration dictionary does not contain a "type" key. + ValueError: If the LLM class is not a subclass of LLM. """ - llm_type = llm_config["type"] - config = llm_config.get("config", {}) + llm_type = config["type"] + llm_config = config.get("config", {}) + default_options = llm_config.pop("default_options", None) + + llm_cls = getattr(module, llm_type) + + if not issubclass(llm_cls, LLM): + raise ValueError(f"Invalid LLM class: {llm_cls}") + + # We need to infer the options class from the LLM class. + # pylint: disable=protected-access + options = llm_cls._options_cls(**default_options) if default_options else None # type: ignore - return getattr(module, llm_type)(**config) + return llm_cls(**llm_config, default_options=options) diff --git a/packages/ragbits-document-search/src/ragbits/document_search/retrieval/rephrasers/__init__.py b/packages/ragbits-document-search/src/ragbits/document_search/retrieval/rephrasers/__init__.py index 0ec89f25..f751c31a 100644 --- a/packages/ragbits-document-search/src/ragbits/document_search/retrieval/rephrasers/__init__.py +++ b/packages/ragbits-document-search/src/ragbits/document_search/retrieval/rephrasers/__init__.py @@ -1,33 +1,47 @@ import sys -from typing import Optional from ragbits.core.utils.config_handling import get_cls_from_config - -from .base import QueryRephraser -from .llm import LLMQueryRephraser -from .noop import NoopQueryRephraser - -__all__ = ["NoopQueryRephraser", "QueryRephraser", "LLMQueryRephraser"] +from ragbits.document_search.retrieval.rephrasers.base import QueryRephraser +from ragbits.document_search.retrieval.rephrasers.llm import LLMQueryRephraser +from ragbits.document_search.retrieval.rephrasers.noop import NoopQueryRephraser +from ragbits.document_search.retrieval.rephrasers.prompt_query_rephraser import ( + QueryRephraserInput, + QueryRephraserPrompt, +) + +__all__ = [ + "get_rephraser", + "QueryRephraser", + "NoopQueryRephraser", + "LLMQueryRephraser", + "QueryRephraserPrompt", + "QueryRephraserInput", +] module = sys.modules[__name__] -def get_rephraser(rephraser_config: Optional[dict]) -> QueryRephraser: +def get_rephraser(config: dict | None = None) -> QueryRephraser: """ Initializes and returns a QueryRephraser object based on the provided configuration. Args: - rephraser_config: A dictionary containing configuration details for the QueryRephraser. + config: A dictionary containing configuration details for the QueryRephraser. Returns: An instance of the specified QueryRephraser class, initialized with the provided config (if any) or default arguments. - """ - if rephraser_config is None: + Raises: + KeyError: If the configuration dictionary does not contain a "type" key. + ValueError: If an invalid rephraser class is specified in the configuration. + """ + if config is None: return NoopQueryRephraser() - rephraser_cls = get_cls_from_config(rephraser_config["type"], module) - config = rephraser_config.get("config", {}) + rephraser_cls = get_cls_from_config(config["type"], module) + + if not issubclass(rephraser_cls, QueryRephraser): + raise ValueError(f"Invalid rephraser class: {rephraser_cls}") - return rephraser_cls.from_config(config) + return rephraser_cls.from_config(config.get("config", {})) diff --git a/packages/ragbits-document-search/src/ragbits/document_search/retrieval/rephrasers/base.py b/packages/ragbits-document-search/src/ragbits/document_search/retrieval/rephrasers/base.py index 5fda9947..3eba3820 100644 --- a/packages/ragbits-document-search/src/ragbits/document_search/retrieval/rephrasers/base.py +++ b/packages/ragbits-document-search/src/ragbits/document_search/retrieval/rephrasers/base.py @@ -1,22 +1,18 @@ -import abc -from typing import Optional +from abc import ABC, abstractmethod -from ragbits.core.llms.base import LLMOptions - -class QueryRephraser(abc.ABC): +class QueryRephraser(ABC): """ Rephrases a query. Can provide multiple rephrased queries from one sentence / question. """ - @abc.abstractmethod - async def rephrase(self, query: str, options: Optional[LLMOptions] = None) -> list[str]: + @abstractmethod + async def rephrase(self, query: str) -> list[str]: """ Rephrase a query. Args: query: The query to rephrase. - options: Optional configuration of the the rephraser behavior. Returns: The rephrased queries. @@ -33,5 +29,4 @@ def from_config(cls, config: dict) -> "QueryRephraser": Returns: An instance of the rephraser class initialized with the provided configuration. """ - return cls(**config) diff --git a/packages/ragbits-document-search/src/ragbits/document_search/retrieval/rephrasers/llm.py b/packages/ragbits-document-search/src/ragbits/document_search/retrieval/rephrasers/llm.py index f362003d..1435f777 100644 --- a/packages/ragbits-document-search/src/ragbits/document_search/retrieval/rephrasers/llm.py +++ b/packages/ragbits-document-search/src/ragbits/document_search/retrieval/rephrasers/llm.py @@ -1,49 +1,51 @@ -from typing import Optional +from typing import Any from ragbits.core.llms import get_llm -from ragbits.core.llms.base import LLM, LLMOptions +from ragbits.core.llms.base import LLM from ragbits.core.prompt import Prompt -from ragbits.core.utils.config_handling import get_cls_from_config -from ragbits.document_search.retrieval import rephrasers from ragbits.document_search.retrieval.rephrasers.base import QueryRephraser -from ragbits.document_search.retrieval.rephrasers.prompt_query_rephraser import QueryRephraserPrompt +from ragbits.document_search.retrieval.rephrasers.prompt_query_rephraser import ( + QueryRephraserInput, + QueryRephraserPrompt, + get_rephraser_prompt, +) class LLMQueryRephraser(QueryRephraser): - """A rephraser class that uses a LLM to rephrase queries.""" + """ + A rephraser class that uses a LLM to rephrase queries. + """ - def __init__(self, llm: LLM, prompt_strategy: Optional[type[Prompt]] = None): + def __init__(self, llm: LLM, prompt: type[Prompt[QueryRephraserInput, Any]] | None = None): """ Initialize the LLMQueryRephraser with a LLM. Args: llm: A LLM instance to handle query rephrasing. + prompt: The prompt to use for rephrasing queries. """ - self._llm = llm - self._prompt_strategy = prompt_strategy or QueryRephraserPrompt + self._prompt = prompt or QueryRephraserPrompt - async def rephrase(self, query: str, options: Optional[LLMOptions] = None) -> list[str]: + async def rephrase(self, query: str) -> list[str]: """ Rephrase a given query using the LLM. Args: query: The query to be rephrased. If not provided, a custom prompt must be given. - options: Optional settings for the LLM to control generation behavior. Returns: A list containing the rephrased query. Raises: - ValueError: If both `query` and `prompt` are None. + LLMConnectionError: If there is a connection error with the LLM API. + LLMStatusError: If the LLM API returns an error status code. + LLMResponseError: If the LLM API response is invalid. """ - - prompt_inputs = self._prompt_strategy.input_type(query=query) # type: ignore - prompt = self._prompt_strategy(prompt_inputs) - - response = await self._llm.generate(prompt, options=options) - - return [response] + input_data = self._prompt.input_type(query=query) # type: ignore + prompt = self._prompt(input_data) + response = await self._llm.generate(prompt) + return response if isinstance(response, list) else [response] @classmethod def from_config(cls, config: dict) -> "LLMQueryRephraser": @@ -55,14 +57,11 @@ def from_config(cls, config: dict) -> "LLMQueryRephraser": Returns: An instance of the rephraser class initialized with the provided configuration. - """ + Raises: + KeyError: If the configuration dictionary does not contain the required keys. + ValueError: If the prompt strategy class is not a subclass of `Prompt`. + """ llm = get_llm(config["llm"]) - prompt_strategy = config.get("prompt_strategy") - - if prompt_strategy is not None: - prompt_strategy_cls = get_cls_from_config(prompt_strategy, rephrasers) - - return cls(llm=llm, prompt_strategy=prompt_strategy_cls) - - return cls(llm=llm) + prompt_cls = get_rephraser_prompt(prompt) if (prompt := config.get("prompt")) else None + return cls(llm=llm, prompt=prompt_cls) diff --git a/packages/ragbits-document-search/src/ragbits/document_search/retrieval/rephrasers/noop.py b/packages/ragbits-document-search/src/ragbits/document_search/retrieval/rephrasers/noop.py index 3d47331a..b48bd7de 100644 --- a/packages/ragbits-document-search/src/ragbits/document_search/retrieval/rephrasers/noop.py +++ b/packages/ragbits-document-search/src/ragbits/document_search/retrieval/rephrasers/noop.py @@ -1,6 +1,3 @@ -from typing import Optional - -from ragbits.core.llms.base import LLMOptions from ragbits.document_search.retrieval.rephrasers.base import QueryRephraser @@ -9,26 +6,14 @@ class NoopQueryRephraser(QueryRephraser): A no-op query paraphraser that does not change the query. """ - async def rephrase( - self, - query: str, - options: Optional[LLMOptions] = None, # pylint: disable=unused-argument - ) -> list[str]: + async def rephrase(self, query: str) -> list[str]: """ Mock implementation which outputs the same query as in input. Args: query: The query to rephrase. - options: Optional configuration of the the rephraser behavior. Returns: The list with non-transformed query. - - Raises: - ValueError: If both `query` and `prompt` are None. """ - - if not isinstance(query, str): - raise ValueError("`query` must be provided.") - return [query] diff --git a/packages/ragbits-document-search/src/ragbits/document_search/retrieval/rephrasers/prompt_query_rephraser.py b/packages/ragbits-document-search/src/ragbits/document_search/retrieval/rephrasers/prompt_query_rephraser.py index 80499b76..1f1e0c90 100644 --- a/packages/ragbits-document-search/src/ragbits/document_search/retrieval/rephrasers/prompt_query_rephraser.py +++ b/packages/ragbits-document-search/src/ragbits/document_search/retrieval/rephrasers/prompt_query_rephraser.py @@ -1,13 +1,23 @@ -import pydantic +import sys +from typing import Any + +from pydantic import BaseModel from ragbits.core.prompt.prompt import Prompt +from ragbits.core.utils.config_handling import get_cls_from_config + +module = sys.modules[__name__] -class _PromptInput(pydantic.BaseModel): +class QueryRephraserInput(BaseModel): + """ + Input data for the query rephraser prompt. + """ + query: str -class QueryRephraserPrompt(Prompt[_PromptInput, str]): +class QueryRephraserPrompt(Prompt[QueryRephraserInput, str]): """ A prompt class for generating a rephrased version of a user's query using a LLM. """ @@ -20,3 +30,25 @@ class QueryRephraserPrompt(Prompt[_PromptInput, str]): "Focus on making the query more precise and readable while keeping its original intent.\n\n" "Just return the rephrased query. No additional explanations are needed." ) + + +def get_rephraser_prompt(prompt: str) -> type[Prompt[QueryRephraserInput, Any]]: + """ + Initializes and returns a QueryRephraser object based on the provided configuration. + + Args: + prompt: The prompt class to use for rephrasing queries. + + Returns: + An instance of the specified QueryRephraser class, initialized with the provided config + (if any) or default arguments. + + Raises: + ValueError: If the prompt class is not a subclass of `Prompt`. + """ + prompt_cls = get_cls_from_config(prompt, module) + + if not issubclass(prompt_cls, Prompt): + raise ValueError(f"Invalid rephraser prompt class: {prompt_cls}") + + return prompt_cls From a983756fcdf41960c6f963c43901d529db40cd7d Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?Micha=C5=82=20Pstr=C4=85g?= Date: Wed, 23 Oct 2024 04:37:26 +0200 Subject: [PATCH 5/6] rename prompts --- .../ragbits/document_search/retrieval/rephrasers/__init__.py | 5 +---- .../src/ragbits/document_search/retrieval/rephrasers/llm.py | 2 +- .../rephrasers/{prompt_query_rephraser.py => prompts.py} | 0 3 files changed, 2 insertions(+), 5 deletions(-) rename packages/ragbits-document-search/src/ragbits/document_search/retrieval/rephrasers/{prompt_query_rephraser.py => prompts.py} (100%) diff --git a/packages/ragbits-document-search/src/ragbits/document_search/retrieval/rephrasers/__init__.py b/packages/ragbits-document-search/src/ragbits/document_search/retrieval/rephrasers/__init__.py index f751c31a..d26afd20 100644 --- a/packages/ragbits-document-search/src/ragbits/document_search/retrieval/rephrasers/__init__.py +++ b/packages/ragbits-document-search/src/ragbits/document_search/retrieval/rephrasers/__init__.py @@ -4,10 +4,7 @@ from ragbits.document_search.retrieval.rephrasers.base import QueryRephraser from ragbits.document_search.retrieval.rephrasers.llm import LLMQueryRephraser from ragbits.document_search.retrieval.rephrasers.noop import NoopQueryRephraser -from ragbits.document_search.retrieval.rephrasers.prompt_query_rephraser import ( - QueryRephraserInput, - QueryRephraserPrompt, -) +from ragbits.document_search.retrieval.rephrasers.prompts import QueryRephraserInput, QueryRephraserPrompt __all__ = [ "get_rephraser", diff --git a/packages/ragbits-document-search/src/ragbits/document_search/retrieval/rephrasers/llm.py b/packages/ragbits-document-search/src/ragbits/document_search/retrieval/rephrasers/llm.py index 1435f777..bef9c519 100644 --- a/packages/ragbits-document-search/src/ragbits/document_search/retrieval/rephrasers/llm.py +++ b/packages/ragbits-document-search/src/ragbits/document_search/retrieval/rephrasers/llm.py @@ -4,7 +4,7 @@ from ragbits.core.llms.base import LLM from ragbits.core.prompt import Prompt from ragbits.document_search.retrieval.rephrasers.base import QueryRephraser -from ragbits.document_search.retrieval.rephrasers.prompt_query_rephraser import ( +from ragbits.document_search.retrieval.rephrasers.prompts import ( QueryRephraserInput, QueryRephraserPrompt, get_rephraser_prompt, diff --git a/packages/ragbits-document-search/src/ragbits/document_search/retrieval/rephrasers/prompt_query_rephraser.py b/packages/ragbits-document-search/src/ragbits/document_search/retrieval/rephrasers/prompts.py similarity index 100% rename from packages/ragbits-document-search/src/ragbits/document_search/retrieval/rephrasers/prompt_query_rephraser.py rename to packages/ragbits-document-search/src/ragbits/document_search/retrieval/rephrasers/prompts.py From fd4a48d79cc89ab220d7c171b8dc322145c5aa5a Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?Micha=C5=82=20Pstr=C4=85g?= Date: Wed, 23 Oct 2024 04:42:31 +0200 Subject: [PATCH 6/6] update docstring --- .../src/ragbits/document_search/retrieval/rephrasers/llm.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/packages/ragbits-document-search/src/ragbits/document_search/retrieval/rephrasers/llm.py b/packages/ragbits-document-search/src/ragbits/document_search/retrieval/rephrasers/llm.py index bef9c519..8398b2a5 100644 --- a/packages/ragbits-document-search/src/ragbits/document_search/retrieval/rephrasers/llm.py +++ b/packages/ragbits-document-search/src/ragbits/document_search/retrieval/rephrasers/llm.py @@ -60,7 +60,7 @@ def from_config(cls, config: dict) -> "LLMQueryRephraser": Raises: KeyError: If the configuration dictionary does not contain the required keys. - ValueError: If the prompt strategy class is not a subclass of `Prompt`. + ValueError: If the prompt class is not a subclass of `Prompt` or the LLM class is not a subclass of `LLM`. """ llm = get_llm(config["llm"]) prompt_cls = get_rephraser_prompt(prompt) if (prompt := config.get("prompt")) else None