diff --git a/README.md b/README.md index 0d1dbea1..a999d6fe 100644 --- a/README.md +++ b/README.md @@ -31,7 +31,7 @@ This is a basic implementation of a db-ally view for an example HR application, ```python from dbally import decorators, SqlAlchemyBaseView, create_collection -from dbally.llm_client.openai_client import OpenAIClient +from dbally.llms.litellm import LiteLLM from sqlalchemy import create_engine class CandidateView(SqlAlchemyBaseView): @@ -53,7 +53,7 @@ class CandidateView(SqlAlchemyBaseView): return Candidate.country == country engine = create_engine('sqlite:///candidates.db') -llm = OpenAIClient(model_name="gpt-3.5-turbo") +llm = LiteLLM(model_name="gpt-3.5-turbo") my_collection = create_collection("collection_name", llm) my_collection.add(CandidateView, lambda: CandidateView(engine)) @@ -82,12 +82,12 @@ pip install dbally Additionally, you can install one of our extensions to use specific features. -* `dbally[openai]`: Use [OpenAI's models](https://platform.openai.com/docs/models) +* `dbally[litellm]`: Use [100+ LLMs](https://docs.litellm.ai/docs/providers) * `dbally[faiss]`: Use [Faiss](https://github.com/facebookresearch/faiss) indexes for similarity search * `dbally[langsmith]`: Use [LangSmith](https://www.langchain.com/langsmith) for query tracking ```bash -pip install dbally[openai,faiss,langsmith] +pip install dbally[litellm,faiss,langsmith] ``` ## License diff --git a/benchmark/dbally_benchmark/e2e_benchmark.py b/benchmark/dbally_benchmark/e2e_benchmark.py index 06d84287..430a8cd0 100644 --- a/benchmark/dbally_benchmark/e2e_benchmark.py +++ b/benchmark/dbally_benchmark/e2e_benchmark.py @@ -23,7 +23,7 @@ import dbally from dbally.collection import Collection from dbally.iql_generator.iql_prompt_template import default_iql_template -from dbally.llm_client.openai_client import OpenAIClient +from dbally.llms.litellm import LiteLLM from dbally.utils.errors import NoViewFoundError, UnsupportedQueryError from dbally.view_selection.view_selector_prompt_template import default_view_selector_template @@ -82,12 +82,12 @@ async def evaluate(cfg: DictConfig) -> Any: engine = create_engine(benchmark_cfg.pg_connection_string + f"/{cfg.db_name}") - llm_client = OpenAIClient( + llm = LiteLLM( model_name="gpt-4", api_key=benchmark_cfg.openai_api_key, ) - db = dbally.create_collection(cfg.db_name, llm_client) + db = dbally.create_collection(cfg.db_name, llm) for view_name in cfg.view_names: view = VIEW_REGISTRY[ViewName(view_name)] diff --git a/benchmark/dbally_benchmark/iql_benchmark.py b/benchmark/dbally_benchmark/iql_benchmark.py index f99411e5..0147dacc 100644 --- a/benchmark/dbally_benchmark/iql_benchmark.py +++ b/benchmark/dbally_benchmark/iql_benchmark.py @@ -22,7 +22,7 @@ from dbally.audit.event_tracker import EventTracker from dbally.iql_generator.iql_generator import IQLGenerator from dbally.iql_generator.iql_prompt_template import default_iql_template -from dbally.llm_client.openai_client import OpenAIClient +from dbally.llms.litellm import LiteLLM from dbally.utils.errors import UnsupportedQueryError from dbally.views.structured import BaseStructuredView @@ -96,13 +96,13 @@ async def evaluate(cfg: DictConfig) -> Any: view = VIEW_REGISTRY[ViewName(view_name)](engine) if "gpt" in cfg.model_name: - llm_client = OpenAIClient( + llm = LiteLLM( model_name=cfg.model_name, api_key=benchmark_cfg.openai_api_key, ) else: raise ValueError("Only OpenAI's GPT models are supported for now.") - iql_generator = IQLGenerator(llm_client=llm_client) + iql_generator = IQLGenerator(llm=llm) run = None if cfg.neptune.log: diff --git a/benchmark/dbally_benchmark/text2sql_benchmark.py b/benchmark/dbally_benchmark/text2sql_benchmark.py index 99eebe6a..ede53f88 100644 --- a/benchmark/dbally_benchmark/text2sql_benchmark.py +++ b/benchmark/dbally_benchmark/text2sql_benchmark.py @@ -21,8 +21,7 @@ from sqlalchemy import create_engine from dbally.audit.event_tracker import EventTracker -from dbally.llm_client.base import LLMClient -from dbally.llm_client.openai_client import OpenAIClient +from dbally.llms.litellm import LiteLLM def _load_db_schema(db_name: str, encoding: Optional[str] = None) -> str: @@ -35,12 +34,12 @@ def _load_db_schema(db_name: str, encoding: Optional[str] = None) -> str: return db_schema -async def _run_text2sql_for_single_example(example: BIRDExample, llm_client: LLMClient) -> Text2SQLResult: +async def _run_text2sql_for_single_example(example: BIRDExample, llm: LiteLLM) -> Text2SQLResult: event_tracker = EventTracker() db_schema = _load_db_schema(example.db_id) - response = await llm_client.text_generation( + response = await llm.generate_text( TEXT2SQL_PROMPT_TEMPLATE, {"schema": db_schema, "question": example.question}, event_tracker=event_tracker ) @@ -49,13 +48,13 @@ async def _run_text2sql_for_single_example(example: BIRDExample, llm_client: LLM ) -async def run_text2sql_for_dataset(dataset: BIRDDataset, llm_client: LLMClient) -> List[Text2SQLResult]: +async def run_text2sql_for_dataset(dataset: BIRDDataset, llm: LiteLLM) -> List[Text2SQLResult]: """ Transforms questions into SQL queries using a Text2SQL model. Args: dataset: The dataset containing questions to be transformed into SQL queries. - llm_client: LLM client. + llm: LLM client. Returns: A list of Text2SQLResult objects representing the predictions. @@ -64,9 +63,7 @@ async def run_text2sql_for_dataset(dataset: BIRDDataset, llm_client: LLMClient) results: List[Text2SQLResult] = [] for group in batch(dataset, 5): - current_results = await asyncio.gather( - *[_run_text2sql_for_single_example(example, llm_client) for example in group] - ) + current_results = await asyncio.gather(*[_run_text2sql_for_single_example(example, llm) for example in group]) results = [*current_results, *results] return results @@ -88,7 +85,7 @@ async def evaluate(cfg: DictConfig) -> Any: engine = create_engine(benchmark_cfg.pg_connection_string + f"/{cfg.db_name}") if "gpt" in cfg.model_name: - llm_client = OpenAIClient( + llm = LiteLLM( model_name=cfg.model_name, api_key=benchmark_cfg.openai_api_key, ) @@ -112,7 +109,7 @@ async def evaluate(cfg: DictConfig) -> Any: evaluation_dataset = BIRDDataset.from_json_file( Path(cfg.dataset_path), difficulty_levels=cfg.get("difficulty_levels") ) - text2sql_results = await run_text2sql_for_dataset(dataset=evaluation_dataset, llm_client=llm_client) + text2sql_results = await run_text2sql_for_dataset(dataset=evaluation_dataset, llm=llm) with open(output_dir / results_file_name, "w", encoding="utf-8") as outfile: json.dump([result.model_dump() for result in text2sql_results], outfile, indent=4) diff --git a/docs/concepts/collections.md b/docs/concepts/collections.md index 7371899c..b76a5298 100644 --- a/docs/concepts/collections.md +++ b/docs/concepts/collections.md @@ -3,7 +3,7 @@ At its core, a collection groups together multiple [views](views.md). Once you've defined your views, the next step is to register them within a collection. Here's how you might do it: ```python -my_collection = dbally.create_collection("collection_name", llm_client=OpenAIClient()) +my_collection = dbally.create_collection("collection_name", llm=LiteLLM()) my_collection.add(ExampleView) my_collection.add(RecipesView) ``` @@ -11,7 +11,7 @@ my_collection.add(RecipesView) Sometimes, view classes might need certain arguments when they're instantiated. In these instances, you'll want to register your view with a builder function that takes care of supplying these arguments. For instance, with views that rely on SQLAlchemy, you'll typically need to pass a database engine object like so: ```python -my_collection = dbally.create_collection("collection_name", llm_client=OpenAIClient()) +my_collection = dbally.create_collection("collection_name", llm=LiteLLM()) engine = sqlalchemy.create_engine("sqlite://") my_collection.add(ExampleView, lambda: ExampleView(engine)) my_collection.add(RecipesView, lambda: RecipesView(engine)) diff --git a/docs/concepts/freeform_views.md b/docs/concepts/freeform_views.md index 7f62e6ba..7cd3a333 100644 --- a/docs/concepts/freeform_views.md +++ b/docs/concepts/freeform_views.md @@ -2,7 +2,7 @@ Freeform views are a type of [view](views.md) that provides a way for developers using db-ally to define what they need from the LLM without requiring a fixed response structure. This flexibility is beneficial when the data structure is unknown beforehand or when potential queries are too diverse to be covered by a structured view. Though freeform views offer more flexibility than structured views, they are less predictable, efficient, and secure, and may be more challenging to integrate with other systems. For these reasons, we recommend using [structured views](./structured_views.md) when possible. -Unlike structured views, which define a response format and a set of operations the LLM may use in response to natural language queries, freeform views only have one task - to respond directly to natural language queries with data from the datasource. They accomplish this by implementing the [`ask`][dbally.views.base.BaseView] method. This method takes a natural language query as input and returns a response. The method also has access to the LLM model (via the `llm_client` attribute), which is typically used to retrieve the correct data from the source (for example, by generating a source-specific query string). To learn more about implementing freeform views, refer to the [How to: Custom Freeform Views](../how-to/custom_freeform_views.md) guide. +Unlike structured views, which define a response format and a set of operations the LLM may use in response to natural language queries, freeform views only have one task - to respond directly to natural language queries with data from the datasource. They accomplish this by implementing the [`ask`][dbally.views.base.BaseView] method. This method takes a natural language query as input and returns a response. The method also has access to the LLM model (via the `llm` attribute), which is typically used to retrieve the correct data from the source (for example, by generating a source-specific query string). To learn more about implementing freeform views, refer to the [How to: Custom Freeform Views](../how-to/custom_freeform_views.md) guide. ## Security diff --git a/docs/how-to/create_custom_event_handler.md b/docs/how-to/create_custom_event_handler.md index 46c2ab32..410973c5 100644 --- a/docs/how-to/create_custom_event_handler.md +++ b/docs/how-to/create_custom_event_handler.md @@ -117,11 +117,11 @@ To use our event handler, we need to pass it to the collection when creating it: ```python import dbally -from dbally.llm_client.openai_client import OpenAIClient +from dbally.llms.litellm import LiteLLM my_collection = bally.create_collection( "collection_name", - llm_client=OpenAIClient(), + llm=LiteLLM(), event_handlers=[FileEventHandler()], ) ``` diff --git a/docs/how-to/custom_views.md b/docs/how-to/custom_views.md index a82e94e4..999f5425 100644 --- a/docs/how-to/custom_views.md +++ b/docs/how-to/custom_views.md @@ -219,10 +219,10 @@ Finally, we can use the `CandidatesView` just like any other view in db-ally. We ```python import asyncio import dbally -from dbally.llm_client.openai_client import OpenAIClient +from dbally.llms.litellm import LiteLLM async def main(): - llm = OpenAIClient(model_name="gpt-3.5-turbo") + llm = LiteLLM(model_name="gpt-3.5-turbo") collection = dbally.create_collection("recruitment", llm) collection.add(CandidateView) diff --git a/docs/how-to/custom_views_code.py b/docs/how-to/custom_views_code.py index 5ea4fb8e..8e033783 100644 --- a/docs/how-to/custom_views_code.py +++ b/docs/how-to/custom_views_code.py @@ -10,7 +10,7 @@ from dbally.audit.event_handlers.cli_event_handler import CLIEventHandler from dbally.iql import IQLQuery, syntax from dbally.data_models.execution_result import ViewExecutionResult -from dbally.llm_client.openai_client import OpenAIClient +from dbally.llms.litellm import LiteLLM @dataclass class Candidate: @@ -99,7 +99,7 @@ def from_country(self, country: str) -> Callable[[Candidate], bool]: return lambda x: x.country == country async def main(): - llm = OpenAIClient(model_name="gpt-3.5-turbo") + llm = LiteLLM(model_name="gpt-3.5-turbo") event_handlers = [CLIEventHandler()] collection = dbally.create_collection("recruitment", llm, event_handlers=event_handlers) collection.add(CandidateView) diff --git a/docs/how-to/log_runs_to_langsmith.md b/docs/how-to/log_runs_to_langsmith.md index 12dea590..fd7209fc 100644 --- a/docs/how-to/log_runs_to_langsmith.md +++ b/docs/how-to/log_runs_to_langsmith.md @@ -29,7 +29,7 @@ from dbally.audit.event_handlers.langsmith_event_handler import LangSmithEventHa my_collection = dbally.create_collection( "collection_name", - llm_client=OpenAIClient(), + llm=LiteLLM(), event_handlers=[LangSmithEventHandler(api_key="your_api_key")], ) ``` diff --git a/docs/how-to/pandas_views.md b/docs/how-to/pandas_views.md index 8c233c29..f99dc762 100644 --- a/docs/how-to/pandas_views.md +++ b/docs/how-to/pandas_views.md @@ -74,9 +74,9 @@ To use the view, you need to create a [Collection](../concepts/collections.md) a ```python import dbally -from dbally.llm_client.openai_client import OpenAIClient +from dbally.llms.litellm import LiteLLM -llm = OpenAIClient(model_name="gpt-3.5-turbo") +llm = LiteLLM(model_name="gpt-3.5-turbo") collection = dbally.create_collection("recruitment", llm) collection.add(CandidateView, lambda: CandidateView(CANDIDATE_DATA)) diff --git a/docs/how-to/pandas_views_code.py b/docs/how-to/pandas_views_code.py index fe17f232..8b6ce17b 100644 --- a/docs/how-to/pandas_views_code.py +++ b/docs/how-to/pandas_views_code.py @@ -8,7 +8,7 @@ from dbally import decorators, DataFrameBaseView from dbally.audit.event_handlers.cli_event_handler import CLIEventHandler -from dbally.llm_client.openai_client import OpenAIClient +from dbally.llms.clients.litellm import LiteLLM class CandidateView(DataFrameBaseView): @@ -46,7 +46,7 @@ def senior_data_scientist_position(self) -> pd.Series: ]) async def main(): - llm = OpenAIClient(model_name="gpt-3.5-turbo") + llm = LiteLLM(model_name="gpt-3.5-turbo") collection = dbally.create_collection("recruitment", llm, event_handlers=[CLIEventHandler()]) collection.add(CandidateView, lambda: CandidateView(CANDIDATE_DATA)) diff --git a/docs/how-to/sql_views.md b/docs/how-to/sql_views.md index 05731521..85411180 100644 --- a/docs/how-to/sql_views.md +++ b/docs/how-to/sql_views.md @@ -84,9 +84,9 @@ engine = create_engine('sqlite:///candidates.db') Once you have defined your view and created an engine, you can register the view with db-ally. You do this by creating a collection and adding the view to it: ```python -from dbally.llm_client.openai_client import OpenAIClient +from dbally.llms.litellm import LiteLLM -my_collection = dbally.create_collection("collection_name", llm_client=OpenAIClient()) +my_collection = dbally.create_collection("collection_name", llm=LiteLLM()) my_collection.add(CandidateView, lambda: CandidateView(engine)) ``` diff --git a/docs/how-to/update_similarity_indexes.md b/docs/how-to/update_similarity_indexes.md index dc773e3b..c93de01a 100644 --- a/docs/how-to/update_similarity_indexes.md +++ b/docs/how-to/update_similarity_indexes.md @@ -61,9 +61,9 @@ If you have a [collection](../concepts/collections.md) and want to update Simila ```python from db_ally import create_collection -from db_ally.llm_client.openai_client import OpenAIClient +from db_ally.llms.litellm import LiteLLM -my_collection = create_collection("collection_name", llm_client=OpenAIClient()) +my_collection = create_collection("collection_name", llm=LiteLLM()) # ... add views to the collection diff --git a/docs/quickstart/index.md b/docs/quickstart/index.md index 1a08f663..405f651c 100644 --- a/docs/quickstart/index.md +++ b/docs/quickstart/index.md @@ -22,10 +22,10 @@ To install db-ally, execute the following command: pip install dbally ``` -Since we will be using OpenAI's GPT, you also need to install the `openai` extension: +Since we will be using OpenAI's GPT, you also need to install the `litellm` extension: ```bash -pip install dbally[openai] +pip install dbally[litellm] ``` ## Database Configuration @@ -104,9 +104,9 @@ By setting up these filters, you enable the LLM to fetch candidates while option To use OpenAI's GPT, configure db-ally and provide your OpenAI API key: ```python -from dbally.llm_client.openai_client import OpenAIClient +from dbally.llms.litellm import LiteLLM -llm = OpenAIClient(model_name="gpt-3.5-turbo", api_key="...") +llm = LiteLLM(model_name="gpt-3.5-turbo", api_key="...") ``` Replace `...` with your OpenAI API key. Alternatively, you can set the `OPENAI_API_KEY` environment variable with your API key and omit the `api_key` parameter altogether. diff --git a/docs/quickstart/quickstart2_code.py b/docs/quickstart/quickstart2_code.py index 9fa8cd4a..cd9669d0 100644 --- a/docs/quickstart/quickstart2_code.py +++ b/docs/quickstart/quickstart2_code.py @@ -12,7 +12,7 @@ from dbally.audit.event_handlers.cli_event_handler import CLIEventHandler from dbally.similarity import SimpleSqlAlchemyFetcher, FaissStore, SimilarityIndex from dbally.embedding_client.openai import OpenAiEmbeddingClient -from dbally.llm_client.openai_client import OpenAIClient +from dbally.llms.litellm import LiteLLM engine = create_engine('sqlite:///candidates.db') @@ -73,7 +73,7 @@ def from_country(self, country: Annotated[str, country_similarity]) -> sqlalchem async def main(): await country_similarity.update() - llm = OpenAIClient(model_name="gpt-3.5-turbo") + llm = LiteLLM(model_name="gpt-3.5-turbo") collection = dbally.create_collection("recruitment", llm, event_handlers=[CLIEventHandler()]) collection.add(CandidateView, lambda: CandidateView(engine)) diff --git a/docs/quickstart/quickstart3_code.py b/docs/quickstart/quickstart3_code.py index 1d474e8b..3732c9da 100644 --- a/docs/quickstart/quickstart3_code.py +++ b/docs/quickstart/quickstart3_code.py @@ -12,7 +12,7 @@ from dbally import decorators, SqlAlchemyBaseView, DataFrameBaseView, ExecutionResult from dbally.similarity import SimpleSqlAlchemyFetcher, FaissStore, SimilarityIndex from dbally.embedding_client.openai import OpenAiEmbeddingClient -from dbally.llm_client.openai_client import OpenAIClient +from dbally.llms.litellm import LiteLLM engine = create_engine('sqlite:///candidates.db') @@ -122,7 +122,7 @@ def display_results(result: ExecutionResult): async def main(): await country_similarity.update() - llm = OpenAIClient(model_name="gpt-3.5-turbo") + llm = LiteLLM(model_name="gpt-3.5-turbo") collection = dbally.create_collection("recruitment", llm) collection.add(CandidateView, lambda: CandidateView(engine)) collection.add(JobView, lambda: JobView(jobs_data)) diff --git a/docs/quickstart/quickstart_code.py b/docs/quickstart/quickstart_code.py index 77d9d87e..09c5924e 100644 --- a/docs/quickstart/quickstart_code.py +++ b/docs/quickstart/quickstart_code.py @@ -9,7 +9,7 @@ from dbally import decorators, SqlAlchemyBaseView from dbally.audit.event_handlers.cli_event_handler import CLIEventHandler -from dbally.llm_client.openai_client import OpenAIClient +from dbally.llms.litellm import LiteLLM engine = create_engine('sqlite:///candidates.db') @@ -54,7 +54,7 @@ def from_country(self, country: str) -> sqlalchemy.ColumnElement: return Candidate.country == country async def main(): - llm = OpenAIClient(model_name="gpt-3.5-turbo") + llm = LiteLLM(model_name="gpt-3.5-turbo") collection = dbally.create_collection("recruitment", llm, event_handlers=[CLIEventHandler()]) collection.add(CandidateView, lambda: CandidateView(engine)) diff --git a/docs/tutorials/LangGraphXdbally.ipynb b/docs/tutorials/LangGraphXdbally.ipynb index f330c793..436dc541 100644 --- a/docs/tutorials/LangGraphXdbally.ipynb +++ b/docs/tutorials/LangGraphXdbally.ipynb @@ -34,7 +34,7 @@ }, "outputs": [], "source": [ - "!pip install -U dbally[openai] langgraph langchain langchain_openai langchain_experimental dbally[langsmith]" + "!pip install -U dbally[litellm,langsmith] langgraph langchain langchain_openai langchain_experimental" ] }, { @@ -203,9 +203,9 @@ "outputs": [], "source": [ "import dbally\n", - "from dbally.llm_client.openai_client import OpenAIClient\n", + "from dbally.llms.litellm import LiteLLM\n", "\n", - "recruitment_db = dbally.create_collection(\"recruitment\", llm_client=OpenAIClient())\n", + "recruitment_db = dbally.create_collection(\"recruitment\", llm=LiteLLM())\n", "recruitment_db.add(JobOfferView, lambda: JobOfferView(ENGINE))\n", "recruitment_db.add(CandidateView, lambda: CandidateView(ENGINE))" ] diff --git a/docs/tutorials/db-ally_tutorial.ipynb b/docs/tutorials/db-ally_tutorial.ipynb index 8f9789f9..cd7f9395 100644 --- a/docs/tutorials/db-ally_tutorial.ipynb +++ b/docs/tutorials/db-ally_tutorial.ipynb @@ -29,7 +29,7 @@ }, "outputs": [], "source": [ - "!pip install dbally[openai] a-world-of-countries" + "!pip install dbally[litellm] a-world-of-countries" ] }, { @@ -158,14 +158,14 @@ "outputs": [], "source": [ "from dbally.audit.event_handlers.cli_event_handler import CLIEventHandler\n", - "from dbally.llm_client.openai_client import OpenAIClient\n", + "from dbally.llms.litellm import LiteLLM\n", "import dbally\n", "\n", "API_KEY = \"API-KEY-GOES-HERE\"\n", "\n", "\n", "async def ask_dbally(question: str):\n", - " llm = OpenAIClient(api_key=API_KEY)\n", + " llm = LiteLLM(api_key=API_KEY)\n", " recruitment_db = dbally.create_collection(\"recruitment\", llm, event_handlers=[CLIEventHandler()])\n", " recruitment_db.add(CandidateView, lambda: CandidateView(ENGINE))\n", "\n", @@ -484,7 +484,7 @@ "outputs": [], "source": [ "async def ask_dbally(question: str):\n", - " llm = OpenAIClient(api_key=API_KEY)\n", + " llm = LiteLLM(api_key=API_KEY)\n", " recruitment_db = dbally.create_collection(\"recruitment\", llm, event_handlers=[CLIEventHandler()])\n", " recruitment_db.add(CandidateView, lambda: CandidateView(ENGINE))\n", "\n", @@ -576,7 +576,7 @@ "\n", "\n", "async def ask_dbally(question: str):\n", - " llm = OpenAIClient(api_key=API_KEY)\n", + " llm = LiteLLM(api_key=API_KEY)\n", " recruitment_db = dbally.create_collection(\"recruitment\", llm, event_handlers=[CLIEventHandler()])\n", " recruitment_db.add(RecruitmentView, lambda: RecruitmentView(ENGINE))\n", "\n", @@ -669,7 +669,7 @@ "outputs": [], "source": [ "async def ask_dbally(question: str):\n", - " llm = OpenAIClient(api_key=API_KEY)\n", + " llm = LiteLLM(api_key=API_KEY)\n", " recruitment_db = dbally.create_collection(\"recruitment\", llm, event_handlers=[CLIEventHandler()])\n", " recruitment_db.add(CandidateView, lambda: CandidateView(ENGINE))\n", " recruitment_db.add(RecruitmentView, lambda: RecruitmentView(ENGINE))\n", diff --git a/docs/tutorials/dbally-assistants-api.ipynb b/docs/tutorials/dbally-assistants-api.ipynb index 9d2905d6..8fdc0162 100644 --- a/docs/tutorials/dbally-assistants-api.ipynb +++ b/docs/tutorials/dbally-assistants-api.ipynb @@ -34,7 +34,7 @@ }, "outputs": [], "source": [ - "!pip install dbally[openai] nest_asyncio" + "!pip install dbally[litellm] nest_asyncio" ] }, { @@ -349,9 +349,9 @@ "outputs": [], "source": [ "import dbally\n", - "from dbally.llm_client.openai_client import OpenAIClient\n", + "from dbally.llms.litellm import LiteLLM\n", "\n", - "llm = OpenAIClient(api_key=OPENAI_API_KEY)\n", + "llm = LiteLLM(api_key=OPENAI_API_KEY)\n", "recruitment_db = dbally.create_collection(\"recruitment\", llm)\n", "recruitment_db.add(JobOfferView, lambda: JobOfferView(ENGINE))\n", "recruitment_db.add(CandidateView, lambda: CandidateView(ENGINE))" diff --git a/examples/recruiting.py b/examples/recruiting.py index b5290344..fedb2ea0 100644 --- a/examples/recruiting.py +++ b/examples/recruiting.py @@ -8,7 +8,7 @@ import dbally from dbally.audit.event_handlers.cli_event_handler import CLIEventHandler from dbally.audit.event_tracker import EventTracker -from dbally.llm_client.openai_client import OpenAIClient +from dbally.llms.litellm import LiteLLM from dbally.prompts import PromptTemplate TEXT2SQL_PROMPT_TEMPLATE = PromptTemplate( @@ -101,18 +101,18 @@ async def recruiting_example(db_description: str, benchmark: Benchmark = example recruitment_db = dbally.create_collection( "recruitment", - llm_client=OpenAIClient(), + llm=LiteLLM(), event_handlers=[CLIEventHandler()], ) recruitment_db.add(RecruitmentView, lambda: RecruitmentView(ENGINE)) event_tracker = EventTracker() - llm_client = OpenAIClient("gpt-4") + llm = LiteLLM("gpt-4") for question in benchmark.questions: await recruitment_db.ask(question.dbally_question, return_natural_response=True) gpt_question = question.gpt_question if question.gpt_question else question.dbally_question - gpt_response = await llm_client.text_generation( + gpt_response = await llm.generate_text( TEXT2SQL_PROMPT_TEMPLATE, {"schema": db_description, "question": gpt_question}, event_tracker=event_tracker ) diff --git a/requirements-dev.txt b/requirements-dev.txt index 6f92eed8..ae740034 100644 --- a/requirements-dev.txt +++ b/requirements-dev.txt @@ -1,7 +1,7 @@ # Requirements as needed for development for this project. # --------------------------------------------------------- # Install current project --e.[openai,transformers,chromadb] +-e.[litellm,chromadb] # developer tools: pre-commit pytest>=6.2.5 diff --git a/setup.cfg b/setup.cfg index 8cac4258..62d3bef8 100644 --- a/setup.cfg +++ b/setup.cfg @@ -41,19 +41,18 @@ install_requires = numpy>=1.24.0 [options.extras_require] -openai = - openai>=1.10.0 - tiktoken>=0.6.0 +litellm = + litellm>=1.37.9 faiss = faiss-cpu>=1.8.0 +chromadb = + chromadb>=0.4.24 +langsmith= + langsmith~=0.1.57 examples = pydantic~=2.6.0 pydantic_settings~=2.1.0 psycopg2-binary~=2.9.9 -langsmith= - langsmith~=0.1.57 -transformers= - transformers>=4.37.1 benchmark = asyncpg~=0.28.0 eval-type-backport~=0.1.3 @@ -64,8 +63,6 @@ benchmark = pydantic-core~=2.16.2 pydantic-settings~=2.0.3 psycopg2-binary~=2.9.9 -chromadb = - chromadb>=0.4.24 [options.packages.find] where = src diff --git a/src/dbally/__init__.py b/src/dbally/__init__.py index bb3b6486..a947eb97 100644 --- a/src/dbally/__init__.py +++ b/src/dbally/__init__.py @@ -8,9 +8,11 @@ from dbally.views.structured import BaseStructuredView from .__version__ import __version__ +from ._exceptions import DbAllyError from ._main import create_collection from ._types import NOT_GIVEN, NotGiven from .collection import Collection +from .llms.clients._exceptions import LLMConnectionError, LLMError, LLMResponseError, LLMStatusError __all__ = [ "__version__", @@ -22,6 +24,11 @@ "BaseStructuredView", "DataFrameBaseView", "ExecutionResult", + "DbAllyError", + "LLMError", + "LLMConnectionError", + "LLMResponseError", + "LLMStatusError", "NotGiven", "NOT_GIVEN", ] diff --git a/src/dbally/_exceptions.py b/src/dbally/_exceptions.py new file mode 100644 index 00000000..6b095cd7 --- /dev/null +++ b/src/dbally/_exceptions.py @@ -0,0 +1,4 @@ +class DbAllyError(Exception): + """ + Base class for all exceptions raised by db-ally. + """ diff --git a/src/dbally/_main.py b/src/dbally/_main.py index 316bec01..eeb2d836 100644 --- a/src/dbally/_main.py +++ b/src/dbally/_main.py @@ -2,7 +2,7 @@ from .audit.event_handlers.base import EventHandler from .collection import Collection -from .llm_client.base import LLMClient +from .llms import LLM from .nl_responder.nl_responder import NLResponder from .view_selection.base import ViewSelector from .view_selection.llm_view_selector import LLMViewSelector @@ -10,7 +10,7 @@ def create_collection( name: str, - llm_client: LLMClient, + llm: LLM, event_handlers: Optional[List[EventHandler]] = None, view_selector: Optional[ViewSelector] = None, nl_responder: Optional[NLResponder] = None, @@ -27,16 +27,15 @@ def create_collection( ```python from dbally import create_collection - from dbally.llm_client.openai_client import OpenAIClient + from dbally.llms.litellm import LiteLLM - collection = create_collection("my_collection", llm_client=OpenAIClient()) + collection = create_collection("my_collection", llm=LiteLLM()) ``` Args: name: Name of the collection is available for [Event handlers](event_handlers/index.md) and is\ used to distinguish different db-ally runs. - llm_client: LLM client used by the collection to generate views and respond to natural language\ - queries. + llm: LLM used by the collection to generate responses for natural language queries. event_handlers: Event handlers used by the collection during query executions. Can be used to\ log events as [CLIEventHandler](event_handlers/cli_handler.md) or to validate system performance as\ [LangSmithEventHandler](event_handlers/langsmith_handler.md). @@ -52,14 +51,14 @@ def create_collection( Raises: ValueError: if default LLM client is not configured """ - view_selector = view_selector or LLMViewSelector(llm_client=llm_client) - nl_responder = nl_responder or NLResponder(llm_client=llm_client) + view_selector = view_selector or LLMViewSelector(llm=llm) + nl_responder = nl_responder or NLResponder(llm=llm) event_handlers = event_handlers or [] return Collection( name, nl_responder=nl_responder, view_selector=view_selector, - llm_client=llm_client, + llm=llm, event_handlers=event_handlers, ) diff --git a/src/dbally/collection.py b/src/dbally/collection.py index ee2271ba..d31c106e 100644 --- a/src/dbally/collection.py +++ b/src/dbally/collection.py @@ -8,7 +8,8 @@ from dbally.audit.event_tracker import EventTracker from dbally.data_models.audit import RequestEnd, RequestStart from dbally.data_models.execution_result import ExecutionResult -from dbally.llm_client.base import LLMClient, LLMOptions +from dbally.llms.base import LLM +from dbally.llms.clients.base import LLMOptions from dbally.nl_responder.nl_responder import NLResponder from dbally.similarity.index import AbstractSimilarityIndex from dbally.utils.errors import NoViewFoundError @@ -47,7 +48,7 @@ def __init__( self, name: str, view_selector: ViewSelector, - llm_client: LLMClient, + llm: LLM, event_handlers: List[EventHandler], nl_responder: NLResponder, n_retries: int = 3, @@ -59,7 +60,7 @@ def __init__( view_selector: As you register more then one [View](views/index.md) within single collection,\ before generating the IQL query, a View that fits query the most is selected by the\ [ViewSelector](view_selection/index.md). - llm_client: LLM client used by the collection to generate views and respond to natural language queries. + llm: LLM used by the collection to generate views and respond to natural language queries. event_handlers: Event handlers used by the collection during query executions. Can be used\ to log events as [CLIEventHandler](event_handlers/cli_handler.md) or to validate system performance\ as [LangSmithEventHandler](event_handlers/langsmith_handler.md). @@ -75,7 +76,7 @@ def __init__( self._view_selector = view_selector self._nl_responder = nl_responder self._event_handlers = event_handlers - self._llm_client = llm_client + self._llm = llm T = TypeVar("T", bound=BaseView) @@ -217,7 +218,7 @@ async def ask( start_time_view = time.monotonic() view_result = await view.ask( query=question, - llm_client=self._llm_client, + llm=self._llm, event_tracker=event_tracker, n_retries=self.n_retries, dry_run=dry_run, diff --git a/src/dbally/iql_generator/iql_generator.py b/src/dbally/iql_generator/iql_generator.py index 5bd310a0..8633afc0 100644 --- a/src/dbally/iql_generator/iql_generator.py +++ b/src/dbally/iql_generator/iql_generator.py @@ -3,7 +3,8 @@ from dbally.audit.event_tracker import EventTracker from dbally.iql_generator.iql_prompt_template import IQLPromptTemplate, default_iql_template -from dbally.llm_client.base import LLMClient, LLMOptions +from dbally.llms.base import LLM +from dbally.llms.clients.base import LLMOptions from dbally.views.exposed_functions import ExposedFunction @@ -25,17 +26,17 @@ class IQLGenerator: def __init__( self, - llm_client: LLMClient, + llm: LLM, prompt_template: Optional[IQLPromptTemplate] = None, promptify_view: Optional[Callable] = None, ) -> None: """ Args: - llm_client: LLM client used to generate IQL + llm: LLM used to generate IQL prompt_template: If not provided by the users is set to `default_iql_template` promptify_view: Function formatting filters for prompt """ - self._llm_client = llm_client + self._llm = llm self._prompt_template = prompt_template or copy.deepcopy(default_iql_template) self._promptify_view = promptify_view or _promptify_filters @@ -64,7 +65,7 @@ async def generate_iql( template = conversation or self._prompt_template - llm_response = await self._llm_client.text_generation( + llm_response = await self._llm.generate_text( template=template, fmt={"filters": filters_for_prompt, "question": question}, event_tracker=event_tracker, diff --git a/src/dbally/llm_client/__init__.py b/src/dbally/llm_client/__init__.py deleted file mode 100644 index e69de29b..00000000 diff --git a/src/dbally/llm_client/base.py b/src/dbally/llm_client/base.py deleted file mode 100644 index 896e3eed..00000000 --- a/src/dbally/llm_client/base.py +++ /dev/null @@ -1,141 +0,0 @@ -# disable args docstring check as args are documented in OpenAI API docs -# pylint: disable=W9015,R0914 - -import abc -from abc import ABC -from dataclasses import asdict, dataclass -from functools import cached_property -from typing import Any, ClassVar, Dict, Generic, Optional, Type, TypeVar, Union - -from dbally.audit.event_tracker import EventTracker -from dbally.data_models.audit import LLMEvent -from dbally.prompts import ChatFormat, PromptBuilder, PromptTemplate - -from .._types import NotGiven - -LLMOptionsNotGiven = TypeVar("LLMOptionsNotGiven") -LLMClientOptions = TypeVar("LLMClientOptions") - - -@dataclass -class LLMOptions(ABC): - """ - Abstract dataclass that represents all available LLM call options. - """ - - _not_given: ClassVar[Optional[LLMOptionsNotGiven]] = None - - def __or__(self, other: "LLMOptions") -> "LLMOptions": - """ - Merges two LLMOptions, prioritizing non-NOT_GIVEN values from the 'other' object. - """ - self_dict = asdict(self) - other_dict = asdict(other) - - updated_dict = { - key: other_dict.get(key, self_dict[key]) - if not isinstance(other_dict.get(key), NotGiven) - else self_dict[key] - for key in self_dict - } - - return self.__class__(**updated_dict) - - def dict(self) -> Dict[str, Any]: - """ - Creates a dictionary representation of the LLMOptions instance. - If a value is None, it will be replaced with a provider-specific not-given sentinel. - - Returns: - A dictionary representation of the LLMOptions instance. - """ - options = asdict(self) - return { - key: self._not_given if value is None or isinstance(value, NotGiven) else value - for key, value in options.items() - } - - -class LLMClient(Generic[LLMClientOptions], ABC): - """ - Abstract client for interaction with LLM. - - It constructs a prompt using the `PromptBuilder` instance and generates text using the `self.call` method. - """ - - _options_cls: Type[LLMClientOptions] - - def __init__(self, model_name: str, default_options: Optional[LLMClientOptions] = None) -> None: - self.model_name = model_name - self.default_options = default_options or self._options_cls() - - def __init_subclass__(cls) -> None: - if not hasattr(cls, "_options_cls"): - raise TypeError(f"Class {cls.__name__} is missing the '_options_cls' attribute") - - @cached_property - def _prompt_builder(self) -> PromptBuilder: - """ - Prompt builder used to construct final prompts for the LLM. - """ - return PromptBuilder() - - async def text_generation( # pylint: disable=R0913 - self, - template: PromptTemplate, - fmt: dict, - *, - event_tracker: Optional[EventTracker] = None, - options: Optional[LLMClientOptions] = None, - ) -> str: - """ - For a given a PromptType and format dict creates a prompt and - returns the response from LLM. - - Args: - template: Prompt template in system/user/assistant openAI format. - fmt: Dictionary with formatting. - event_tracker: Event store used to audit the generation process. - options: options to use for the LLM client. - - Returns: - Text response from LLM. - """ - options = (self.default_options | options) if options else self.default_options - - prompt = self._prompt_builder.build(template, fmt) - - event = LLMEvent(prompt=prompt, type=type(template).__name__) - - event_tracker = event_tracker or EventTracker() - async with event_tracker.track_event(event) as span: - event.response = await self.call( - prompt=prompt, - response_format=template.response_format, - options=options, - event=event, - ) - span(event) - - return event.response - - @abc.abstractmethod - async def call( - self, - prompt: Union[str, ChatFormat], - response_format: Optional[Dict[str, str]], - options: LLMOptions, - event: LLMEvent, - ) -> str: - """ - Calls LLM API endpoint. - - Args: - prompt: prompt passed to the LLM. - response_format: Optional argument used in the OpenAI API - used to force a json output - options: Additional settings used by LLM. - event: an LLMEvent instance which fields should be filled during the method execution. - - Returns: - Response string from LLM. - """ diff --git a/src/dbally/llm_client/openai_client.py b/src/dbally/llm_client/openai_client.py deleted file mode 100644 index 4458698b..00000000 --- a/src/dbally/llm_client/openai_client.py +++ /dev/null @@ -1,98 +0,0 @@ -from dataclasses import dataclass -from typing import ClassVar, Dict, List, Optional, Union - -from openai import NOT_GIVEN as OPENAI_NOT_GIVEN -from openai import NotGiven as OpenAINotGiven - -from dbally.data_models.audit import LLMEvent -from dbally.llm_client.base import LLMClient -from dbally.prompts import ChatFormat - -from .._types import NOT_GIVEN, NotGiven -from .base import LLMOptions - - -@dataclass -class OpenAIOptions(LLMOptions): - """ - Dataclass that represents all available LLM call options for the OpenAI API. Each of them is - described in the [OpenAI API documentation](https://platform.openai.com/docs/api-reference/chat/create.) - """ - - _not_given: ClassVar[Optional[OpenAINotGiven]] = OPENAI_NOT_GIVEN - - frequency_penalty: Union[Optional[float], NotGiven] = NOT_GIVEN - max_tokens: Union[Optional[int], NotGiven] = NOT_GIVEN - n: Union[Optional[int], NotGiven] = NOT_GIVEN - presence_penalty: Union[Optional[float], NotGiven] = NOT_GIVEN - seed: Union[Optional[int], NotGiven] = NOT_GIVEN - stop: Union[Optional[Union[str, List[str]]], NotGiven] = NOT_GIVEN - temperature: Union[Optional[float], NotGiven] = NOT_GIVEN - top_p: Union[Optional[float], NotGiven] = NOT_GIVEN - - -class OpenAIClient(LLMClient[OpenAIOptions]): - """ - `OpenAIClient` is a class designed to interact with OpenAI's language model (LLM) endpoints, - particularly for the GPT models. - - Args: - model_name: Name of the [OpenAI's model](https://platform.openai.com/docs/models) to be used,\ - default is "gpt-3.5-turbo". - api_key: OpenAI's API key. If None OPENAI_API_KEY environment variable will be used - default_options: Default options to be used in the LLM calls. - """ - - _options_cls = OpenAIOptions - - def __init__( - self, - model_name: str = "gpt-3.5-turbo", - api_key: Optional[str] = None, - default_options: Optional[OpenAIOptions] = None, - ) -> None: - try: - from openai import AsyncOpenAI # pylint: disable=import-outside-toplevel - except ImportError as exc: - raise ImportError("You need to install openai package to use GPT models") from exc - - super().__init__(model_name=model_name, default_options=default_options) - self._client = AsyncOpenAI(api_key=api_key) - - async def call( - self, - prompt: Union[str, ChatFormat], - response_format: Optional[Dict[str, str]], - options: OpenAIOptions, - event: LLMEvent, - ) -> str: - """ - Calls the OpenAI API endpoint. - - Args: - prompt: Prompt as an OpenAI client style list. - response_format: Optional argument used in the OpenAI API - used to force the json output - options: Additional settings used by the LLM. - event: container with the prompt, LLM response and call metrics. - - Returns: - Response string from LLM. - """ - - # only "turbo" models support response_format argument - # https://platform.openai.com/docs/api-reference/chat/create#chat-create-response_format - if "turbo" not in self.model_name: - response_format = None - - response = await self._client.chat.completions.create( - messages=prompt, - model=self.model_name, - response_format=response_format, - **options.dict(), # type: ignore - ) - - event.completion_tokens = response.usage.completion_tokens - event.prompt_tokens = response.usage.prompt_tokens - event.total_tokens = response.usage.total_tokens - - return response.choices[0].message.content # type: ignore diff --git a/src/dbally/llms/__init__.py b/src/dbally/llms/__init__.py new file mode 100644 index 00000000..111892eb --- /dev/null +++ b/src/dbally/llms/__init__.py @@ -0,0 +1,4 @@ +from .base import LLM +from .litellm import LiteLLM + +__all__ = ["LLM", "LiteLLM"] diff --git a/src/dbally/llms/base.py b/src/dbally/llms/base.py new file mode 100644 index 00000000..3d6c90e0 --- /dev/null +++ b/src/dbally/llms/base.py @@ -0,0 +1,104 @@ +from abc import ABC, abstractmethod +from functools import cached_property +from typing import Dict, Generic, Optional, Type + +from dbally.audit.event_tracker import EventTracker +from dbally.data_models.audit import LLMEvent +from dbally.llms.clients.base import LLMClient, LLMClientOptions, LLMOptions +from dbally.prompts.common_validation_utils import ChatFormat +from dbally.prompts.prompt_template import PromptTemplate + + +class LLM(Generic[LLMClientOptions], ABC): + """ + Abstract class for interaction with Large Language Model. + """ + + _options_cls: Type[LLMClientOptions] + + def __init__(self, model_name: str, default_options: Optional[LLMOptions] = None) -> None: + """ + Constructs a new LLM instance. + + Args: + model_name: Name of the model to be used. + default_options: Default options to be used. + + Raises: + TypeError: If the subclass is missing the '_options_cls' attribute. + """ + self.model_name = model_name + self.default_options = default_options or self._options_cls() + + def __init_subclass__(cls) -> None: + if not hasattr(cls, "_options_cls"): + raise TypeError(f"Class {cls.__name__} is missing the '_options_cls' attribute") + + @cached_property + @abstractmethod + def _client(self) -> LLMClient: + """ + Client for the LLM. + """ + + def _format_prompt(self, template: PromptTemplate, fmt: Dict[str, str]) -> ChatFormat: + """ + Applies formatting to the prompt template. + + Args: + template: Prompt template in system/user/assistant openAI format. + fmt: Dictionary with formatting. + + Returns: + Prompt in the format of the client. + """ + return [{**message, "content": message["content"].format(**fmt)} for message in template.chat] + + def count_tokens(self, messages: ChatFormat, fmt: Dict[str, str]) -> int: + """ + Counts tokens in the messages. + + Args: + messages: Messages to count tokens for. + fmt: Arguments to be used with prompt. + + Returns: + Number of tokens in the messages. + """ + return sum(len(message["content"].format(**fmt)) for message in messages) + + async def generate_text( + self, + template: PromptTemplate, + fmt: Dict[str, str], + *, + event_tracker: Optional[EventTracker] = None, + options: Optional[LLMOptions] = None, + ) -> str: + """ + Prepares and sends a prompt to the LLM and returns the response. + + Args: + template: Prompt template in system/user/assistant openAI format. + fmt: Dictionary with formatting. + event_tracker: Event store used to audit the generation process. + options: Options to use for the LLM client. + + Returns: + Text response from LLM. + """ + options = (self.default_options | options) if options else self.default_options + prompt = self._format_prompt(template, fmt) + event = LLMEvent(prompt=prompt, type=type(template).__name__) + event_tracker = event_tracker or EventTracker() + + async with event_tracker.track_event(event) as span: + event.response = await self._client.call( + prompt=prompt, + response_format=template.response_format, + options=options, + event=event, + ) + span(event) + + return event.response diff --git a/src/dbally/llms/clients/__init__.py b/src/dbally/llms/clients/__init__.py new file mode 100644 index 00000000..3e0a2cbe --- /dev/null +++ b/src/dbally/llms/clients/__init__.py @@ -0,0 +1,9 @@ +from .base import LLMClient, LLMOptions +from .litellm import LiteLLMClient, LiteLLMOptions + +__all__ = [ + "LLMClient", + "LLMOptions", + "LiteLLMClient", + "LiteLLMOptions", +] diff --git a/src/dbally/llms/clients/_exceptions.py b/src/dbally/llms/clients/_exceptions.py new file mode 100644 index 00000000..ffe3fb04 --- /dev/null +++ b/src/dbally/llms/clients/_exceptions.py @@ -0,0 +1,39 @@ +from ..._exceptions import DbAllyError + + +class LLMError(DbAllyError): + """ + Base class for all exceptions raised by the LLMClient. + """ + + def __init__(self, message: str) -> None: + super().__init__(message) + self.message = message + + +class LLMConnectionError(LLMError): + """ + Raised when there is an error connecting to the LLM API. + """ + + def __init__(self, message: str = "Connection error.") -> None: + super().__init__(message) + + +class LLMStatusError(LLMError): + """ + Raised when an API response has a status code of 4xx or 5xx. + """ + + def __init__(self, message: str, status_code: int) -> None: + super().__init__(message) + self.status_code = status_code + + +class LLMResponseError(LLMError): + """ + Raised when an API response has an invalid schema. + """ + + def __init__(self, message: str = "Data returned by API invalid for expected schema.") -> None: + super().__init__(message) diff --git a/src/dbally/llms/clients/base.py b/src/dbally/llms/clients/base.py new file mode 100644 index 00000000..bc55f6ea --- /dev/null +++ b/src/dbally/llms/clients/base.py @@ -0,0 +1,86 @@ +from abc import ABC, abstractmethod +from dataclasses import asdict, dataclass +from typing import Any, ClassVar, Dict, Generic, Optional, TypeVar + +from dbally.data_models.audit import LLMEvent +from dbally.prompts import ChatFormat + +from ..._types import NotGiven + +LLMOptionsNotGiven = TypeVar("LLMOptionsNotGiven") +LLMClientOptions = TypeVar("LLMClientOptions", bound="LLMOptions") + + +@dataclass +class LLMOptions(ABC): + """ + Abstract dataclass that represents all available LLM call options. + """ + + _not_given: ClassVar[Optional[LLMOptionsNotGiven]] = None + + def __or__(self, other: "LLMOptions") -> "LLMOptions": + """ + Merges two LLMOptions, prioritizing non-NOT_GIVEN values from the 'other' object. + """ + self_dict = asdict(self) + other_dict = asdict(other) + + updated_dict = { + key: other_dict.get(key, self_dict[key]) + if not isinstance(other_dict.get(key), NotGiven) + else self_dict[key] + for key in self_dict + } + + return self.__class__(**updated_dict) + + def dict(self) -> Dict[str, Any]: + """ + Creates a dictionary representation of the LLMOptions instance. + If a value is None, it will be replaced with a provider-specific not-given sentinel. + + Returns: + A dictionary representation of the LLMOptions instance. + """ + options = asdict(self) + return { + key: self._not_given if value is None or isinstance(value, NotGiven) else value + for key, value in options.items() + } + + +class LLMClient(Generic[LLMClientOptions], ABC): + """ + Abstract client for a direct communication with LLM. + """ + + def __init__(self, model_name: str) -> None: + """ + Constructs a new LLMClient instance. + + Args: + model_name: Name of the model to be used. + """ + self.model_name = model_name + + @abstractmethod + async def call( + self, + prompt: ChatFormat, + response_format: Optional[Dict[str, str]], + options: LLMClientOptions, + event: LLMEvent, + ) -> str: + """ + Calls LLM inference API. + + Args: + prompt: Prompt passed to the LLM. + response_format: Optional argument used in the OpenAI API - used to force a json output + options: Additional settings used by LLM. + event: LLMEvent instance which fields should be filled during the method execution. + + Returns: + Response string from LLM. + """ diff --git a/src/dbally/llms/clients/litellm.py b/src/dbally/llms/clients/litellm.py new file mode 100644 index 00000000..82752b1d --- /dev/null +++ b/src/dbally/llms/clients/litellm.py @@ -0,0 +1,107 @@ +from dataclasses import dataclass +from typing import Dict, List, Optional, Union + +import litellm +from openai import APIConnectionError, APIResponseValidationError, APIStatusError + +from dbally.data_models.audit import LLMEvent +from dbally.llms.clients.base import LLMClient, LLMOptions +from dbally.prompts import ChatFormat + +from ..._types import NOT_GIVEN, NotGiven +from ._exceptions import LLMConnectionError, LLMResponseError, LLMStatusError + + +@dataclass +class LiteLLMOptions(LLMOptions): + """ + Dataclass that represents all available LLM call options for the LiteLLM client. + Each of them is described in the [LiteLLM documentation](https://docs.litellm.ai/docs/completion/input). + """ + + frequency_penalty: Union[Optional[float], NotGiven] = NOT_GIVEN + max_tokens: Union[Optional[int], NotGiven] = NOT_GIVEN + n: Union[Optional[int], NotGiven] = NOT_GIVEN + presence_penalty: Union[Optional[float], NotGiven] = NOT_GIVEN + seed: Union[Optional[int], NotGiven] = NOT_GIVEN + stop: Union[Optional[Union[str, List[str]]], NotGiven] = NOT_GIVEN + temperature: Union[Optional[float], NotGiven] = NOT_GIVEN + top_p: Union[Optional[float], NotGiven] = NOT_GIVEN + + +class LiteLLMClient(LLMClient[LiteLLMOptions]): + """ + Client for the LiteLLM that supports calls to 100+ LLMs APIs, including OpenAI, Anthropic, VertexAI, + Hugging Face and others. + """ + + _options_cls = LiteLLMOptions + + def __init__( + self, + model_name: str, + *, + base_url: Optional[str] = None, + api_key: Optional[str] = None, + api_version: Optional[str] = None, + ) -> None: + """ + Constructs a new LiteLLMClient instance. + + Args: + model_name: Name of the model to use. + base_url: Base URL of the LLM API. + api_key: API key used to authenticate with the LLM API. + api_version: API version of the LLM API. + """ + super().__init__(model_name) + self.base_url = base_url + self.api_key = api_key + self.api_version = api_version + + async def call( + self, + prompt: ChatFormat, + response_format: Optional[Dict[str, str]], + options: LiteLLMOptions, + event: LLMEvent, + ) -> str: + """ + Calls the appropriate LLM endpoint with the given prompt and options. + + Args: + prompt: Prompt as an OpenAI client style list. + response_format: Optional argument used in the OpenAI API - used to force the json output + options: Additional settings used by the LLM. + event: Container with the prompt, LLM response and call metrics. + + Returns: + Response string from LLM. + + Raises: + LLMConnectionError: If there is a connection error with the LLM API. + LLMStatusError: If the LLM API returns an error status code. + LLMResponseError: If the LLM API response is invalid. + """ + try: + response = await litellm.acompletion( + messages=prompt, + model=self.model_name, + base_url=self.base_url, + api_key=self.api_key, + api_version=self.api_version, + response_format=response_format, + **options.dict(), # type: ignore + ) + except APIConnectionError as exc: + raise LLMConnectionError() from exc + except APIStatusError as exc: + raise LLMStatusError(exc.message, exc.status_code) from exc + except APIResponseValidationError as exc: + raise LLMResponseError() from exc + + event.completion_tokens = response.usage.completion_tokens + event.prompt_tokens = response.usage.prompt_tokens + event.total_tokens = response.usage.total_tokens + + return response.choices[0].message.content # type: ignore diff --git a/src/dbally/llms/litellm.py b/src/dbally/llms/litellm.py new file mode 100644 index 00000000..9b295214 --- /dev/null +++ b/src/dbally/llms/litellm.py @@ -0,0 +1,68 @@ +from functools import cached_property +from typing import Dict, Optional + +from litellm import token_counter + +from dbally.llms.base import LLM +from dbally.llms.clients.litellm import LiteLLMClient, LiteLLMOptions +from dbally.prompts import ChatFormat + + +class LiteLLM(LLM[LiteLLMOptions]): + """ + Class for interaction with any LLM supported by LiteLLM API. + """ + + _options_cls = LiteLLMOptions + + def __init__( + self, + model_name: str = "gpt-3.5-turbo", + default_options: Optional[LiteLLMOptions] = None, + *, + base_url: Optional[str] = None, + api_key: Optional[str] = None, + api_version: Optional[str] = None, + ) -> None: + """ + Construct a new LiteLLM instance. + + Args: + model_name: Name of the [LiteLLM supported model](https://docs.litellm.ai/docs/providers) to be used, + default is "gpt-3.5-turbo". + default_options: Default options to be used. + base_url: Base URL of the LLM API. + api_key: API key to be used. API key to be used. If not specified, an environment variable will be used, + for more information, follow the instructions for your specific vendor in the\ + [LiteLLM documentation](https://docs.litellm.ai/docs/providers). + api_version: API version to be used. If not specified, the default version will be used. + """ + super().__init__(model_name, default_options) + self.base_url = base_url + self.api_key = api_key + self.api_version = api_version + + @cached_property + def _client(self) -> LiteLLMClient: + """ + Client for the LLM. + """ + return LiteLLMClient( + model_name=self.model_name, + base_url=self.base_url, + api_key=self.api_key, + api_version=self.api_version, + ) + + def count_tokens(self, messages: ChatFormat, fmt: Dict[str, str]) -> int: + """ + Count tokens in the messages using a specified model. + + Args: + messages: Messages to count tokens for. + fmt: Arguments to be used with prompt. + + Returns: + Number of tokens in the messages. + """ + return sum(token_counter(model=self.model_name, text=message["content"].format(**fmt)) for message in messages) diff --git a/src/dbally/nl_responder/nl_responder.py b/src/dbally/nl_responder/nl_responder.py index a0d093b4..02aaf647 100644 --- a/src/dbally/nl_responder/nl_responder.py +++ b/src/dbally/nl_responder/nl_responder.py @@ -5,13 +5,13 @@ from dbally.audit.event_tracker import EventTracker from dbally.data_models.execution_result import ViewExecutionResult -from dbally.llm_client.base import LLMClient, LLMOptions +from dbally.llms.base import LLM +from dbally.llms.clients.base import LLMOptions from dbally.nl_responder.nl_responder_prompt_template import NLResponderPromptTemplate, default_nl_responder_template from dbally.nl_responder.query_explainer_prompt_template import ( QueryExplainerPromptTemplate, default_query_explainer_template, ) -from dbally.nl_responder.token_counters import count_tokens_for_huggingface, count_tokens_for_openai class NLResponder: @@ -22,22 +22,21 @@ class NLResponder: def __init__( self, - llm_client: LLMClient, + llm: LLM, query_explainer_prompt_template: Optional[QueryExplainerPromptTemplate] = None, nl_responder_prompt_template: Optional[NLResponderPromptTemplate] = None, max_tokens_count: int = 4096, ) -> None: """ Args: - llm_client: LLM client used to generate natural language response - query_explainer_prompt_template: template for the prompt used to generate the iql explanation\ - if not set defaults to `default_query_explainer_template` - nl_responder_prompt_template: template for the prompt used to generate the NL response\ - if not set defaults to `nl_responder_prompt_template` + llm: LLM used to generate natural language response + query_explainer_prompt_template: template for the prompt used to generate the iql explanation + if not set defaults to `default_query_explainer_template` + nl_responder_prompt_template: template for the prompt used to generate the NL response + if not set defaults to `nl_responder_prompt_template` max_tokens_count: maximum number of tokens that can be used in the prompt """ - - self._llm_client = llm_client + self._llm = llm self._nl_responder_prompt_template = nl_responder_prompt_template or copy.deepcopy( default_nl_responder_template ) @@ -65,27 +64,17 @@ async def generate_response( Returns: Natural language response to the user question. """ - rows = _promptify_rows(result.results) - if "gpt" in self._llm_client.model_name: - tokens_count = count_tokens_for_openai( - messages=self._nl_responder_prompt_template.chat, - fmt={"rows": rows, "question": question}, - model=self._llm_client.model_name, - ) - - else: - tokens_count = count_tokens_for_huggingface( - messages=self._nl_responder_prompt_template.chat, - fmt={"rows": rows, "question": question}, - model=self._llm_client.model_name, - ) + tokens_count = self._llm.count_tokens( + messages=self._nl_responder_prompt_template.chat, + fmt={"rows": rows, "question": question}, + ) if tokens_count > self._max_tokens_count: context = result.context query = next((context.get(key) for key in self.QUERY_KEYS if context.get(key)), question) - llm_response = await self._llm_client.text_generation( + llm_response = await self._llm.generate_text( template=self._query_explainer_prompt_template, fmt={"question": question, "query": query, "number_of_results": len(result.results)}, event_tracker=event_tracker, @@ -94,7 +83,7 @@ async def generate_response( return llm_response - llm_response = await self._llm_client.text_generation( + llm_response = await self._llm.generate_text( template=self._nl_responder_prompt_template, fmt={"rows": _promptify_rows(result.results), "question": question}, event_tracker=event_tracker, diff --git a/src/dbally/nl_responder/token_counters.py b/src/dbally/nl_responder/token_counters.py deleted file mode 100644 index 2b242eba..00000000 --- a/src/dbally/nl_responder/token_counters.py +++ /dev/null @@ -1,63 +0,0 @@ -from typing import Dict - -from dbally.prompts import ChatFormat - - -def count_tokens_for_openai(messages: ChatFormat, fmt: Dict[str, str], model: str) -> int: - """ - Counts the number of tokens in the messages for OpenAIs' models. - - Args: - messages: Messages to count tokens for. - fmt: Arguments to be used with prompt. - model: Model name. - - Returns: - Number of tokens in the messages. - - Raises: - ImportError: If tiktoken package is not installed. - """ - - try: - import tiktoken # pylint: disable=import-outside-toplevel - except ImportError as exc: - raise ImportError("You need to install tiktoken package to use GPT models") from exc - - encoding = tiktoken.encoding_for_model(model) - num_tokens = 0 - for message in messages: - num_tokens += 4 # every message follows "{role/name}\n{content}\n" - num_tokens += len(encoding.encode(message["content"].format(**fmt))) - - num_tokens += 2 # every reply starts with "assistant" - return num_tokens - - -def count_tokens_for_huggingface(messages: ChatFormat, fmt: Dict[str, str], model: str) -> int: - """ - Counts the number of tokens in the messages for models available on HuggingFace. - - Args: - messages: Messages to count tokens for. - fmt: Arguments to be used with prompt. - model: Model name. - - Returns: - Number of tokens in the messages. - - Raises: - ImportError: If transformers package is not installed. - """ - - try: - from transformers import AutoTokenizer # pylint: disable=import-outside-toplevel - except ImportError as exc: - raise ImportError("You need to install transformers package to use huggingface models' tokenizers.") from exc - - tokenizer = AutoTokenizer.from_pretrained(model) - - for message in messages: - message["content"] = message["content"].format(**fmt) - - return len(tokenizer.apply_chat_template(messages, tokenize=True, add_generation_prompt=True)) diff --git a/src/dbally/prompts/__init__.py b/src/dbally/prompts/__init__.py index 279ac384..38e20cc7 100644 --- a/src/dbally/prompts/__init__.py +++ b/src/dbally/prompts/__init__.py @@ -1,5 +1,4 @@ from .common_validation_utils import ChatFormat, PromptTemplateError, check_prompt_variables -from .prompt_builder import PromptBuilder from .prompt_template import PromptTemplate -__all__ = ["PromptBuilder", "PromptTemplate", "PromptTemplateError", "check_prompt_variables", "ChatFormat"] +__all__ = ["PromptTemplate", "PromptTemplateError", "check_prompt_variables", "ChatFormat"] diff --git a/src/dbally/prompts/common_validation_utils.py b/src/dbally/prompts/common_validation_utils.py index f62d72b1..124246aa 100644 --- a/src/dbally/prompts/common_validation_utils.py +++ b/src/dbally/prompts/common_validation_utils.py @@ -1,7 +1,7 @@ import re -from typing import Dict, List, Set, Tuple +from typing import Dict, List, Set -ChatFormat = Tuple[Dict[str, str], ...] +ChatFormat = List[Dict[str, str]] class PromptTemplateError(Exception): diff --git a/src/dbally/prompts/prompt_builder.py b/src/dbally/prompts/prompt_builder.py deleted file mode 100644 index 6ab00852..00000000 --- a/src/dbally/prompts/prompt_builder.py +++ /dev/null @@ -1,63 +0,0 @@ -from typing import TYPE_CHECKING, Dict, Optional, Union - -from .common_validation_utils import ChatFormat -from .prompt_template import PromptTemplate - -if TYPE_CHECKING: - from transformers.tokenization_utils import PreTrainedTokenizer - - -class PromptBuilder: - """Class used to build prompts""" - - def __init__(self, model_name: Optional[str] = None) -> None: - """ - Args: - model_name: name of the tokenizer model to use. If provided, the tokenizer will convert the prompt to the - format expected by the model. The model_name should be a model available on huggingface.co/models. - - Raises: - OSError: If model_name is not found in huggingface.co/models - """ - self._tokenizer: Optional["PreTrainedTokenizer"] = None - - if model_name is not None: - try: - from transformers import AutoTokenizer # pylint: disable=import-outside-toplevel - except ImportError as exc: - raise ImportError("You need to install transformers package to use huggingface tokenizers") from exc - - self._tokenizer = AutoTokenizer.from_pretrained(model_name) - - def format_prompt(self, prompt_template: PromptTemplate, fmt: Dict[str, str]) -> ChatFormat: - """ - Format prompt using provided arguments - - Args: - prompt_template: this template will be modified in place - fmt: formatting dict - - Returns: - ChatFormat formatted prompt - """ - return tuple({**msg, "content": msg["content"].format(**fmt)} for msg in prompt_template.chat) - - def build(self, prompt_template: PromptTemplate, fmt: Dict[str, str]) -> Union[str, ChatFormat]: - """Build the prompt - - Args: - prompt_template: Prompt template in system/user/assistant openAI format. - fmt: Dictionary with formatting. - - Returns: - Either prompt as a string (if it was formatted for a hf model, model_name provided), or prompt as an - openAI client style list. - - Raises: - KeyError: If fmt does not fill all template arguments. - """ - - prompt = self.format_prompt(prompt_template, fmt) - if self._tokenizer is not None: - prompt = self._tokenizer.apply_chat_template(prompt, tokenize=False, add_generation_prompt=True) - return prompt diff --git a/src/dbally/view_selection/base.py b/src/dbally/view_selection/base.py index 94495e39..e504c7a3 100644 --- a/src/dbally/view_selection/base.py +++ b/src/dbally/view_selection/base.py @@ -2,7 +2,7 @@ from typing import Dict, Optional from dbally.audit.event_tracker import EventTracker -from dbally.llm_client.base import LLMOptions +from dbally.llms.clients.base import LLMOptions class ViewSelector(abc.ABC): diff --git a/src/dbally/view_selection/llm_view_selector.py b/src/dbally/view_selection/llm_view_selector.py index 477f763f..2d501922 100644 --- a/src/dbally/view_selection/llm_view_selector.py +++ b/src/dbally/view_selection/llm_view_selector.py @@ -3,7 +3,8 @@ from dbally.audit.event_tracker import EventTracker from dbally.iql_generator.iql_prompt_template import IQLPromptTemplate -from dbally.llm_client.base import LLMClient, LLMOptions +from dbally.llms.base import LLM +from dbally.llms.clients.base import LLMOptions from dbally.view_selection.base import ViewSelector from dbally.view_selection.view_selector_prompt_template import default_view_selector_template @@ -21,18 +22,18 @@ class LLMViewSelector(ViewSelector): def __init__( self, - llm_client: LLMClient, + llm: LLM, prompt_template: Optional[IQLPromptTemplate] = None, promptify_views: Optional[Callable[[Dict[str, str]], str]] = None, ) -> None: """ Args: - llm_client: LLM client used to generate IQL + llm: LLM used to generate IQL prompt_template: template for the prompt used for the view selection promptify_views: Function formatting filters for prompt. By default names and descriptions of\ all views are concatenated """ - self._llm_client = llm_client + self._llm = llm self._prompt_template = prompt_template or copy.deepcopy(default_view_selector_template) self._promptify_views = promptify_views or _promptify_views @@ -58,7 +59,7 @@ async def select_view( views_for_prompt = self._promptify_views(views) - llm_response = await self._llm_client.text_generation( + llm_response = await self._llm.generate_text( template=self._prompt_template, fmt={"views": views_for_prompt, "question": question}, event_tracker=event_tracker, diff --git a/src/dbally/view_selection/random_view_selector.py b/src/dbally/view_selection/random_view_selector.py index 8b2b1173..61dce39d 100644 --- a/src/dbally/view_selection/random_view_selector.py +++ b/src/dbally/view_selection/random_view_selector.py @@ -2,7 +2,7 @@ from typing import Dict, Optional from dbally.audit.event_tracker import EventTracker -from dbally.llm_client.base import LLMOptions +from dbally.llms.clients.base import LLMOptions from dbally.view_selection.base import ViewSelector diff --git a/src/dbally/views/base.py b/src/dbally/views/base.py index a0846153..24a86822 100644 --- a/src/dbally/views/base.py +++ b/src/dbally/views/base.py @@ -3,7 +3,8 @@ from dbally.audit.event_tracker import EventTracker from dbally.data_models.execution_result import ViewExecutionResult -from dbally.llm_client.base import LLMClient, LLMOptions +from dbally.llms.base import LLM +from dbally.llms.clients.base import LLMOptions class BaseView(metaclass=abc.ABCMeta): @@ -16,7 +17,7 @@ class BaseView(metaclass=abc.ABCMeta): async def ask( self, query: str, - llm_client: LLMClient, + llm: LLM, event_tracker: EventTracker, n_retries: int = 3, dry_run: bool = False, @@ -27,11 +28,11 @@ async def ask( Args: query: The natural language query to execute. - llm_client: The LLM client used to execute the query. + llm: The LLM used to execute the query. event_tracker: The event tracker used to audit the query execution. n_retries: The number of retries to execute the query in case of errors. dry_run: If True, the query will not be used to fetch data from the datasource. - llm_options: options to use for the LLM client. + llm_options: Options to use for the LLM. Returns: The result of the query. diff --git a/src/dbally/views/freeform/text2sql/_autodiscovery.py b/src/dbally/views/freeform/text2sql/_autodiscovery.py index b579876d..c649684f 100644 --- a/src/dbally/views/freeform/text2sql/_autodiscovery.py +++ b/src/dbally/views/freeform/text2sql/_autodiscovery.py @@ -4,7 +4,7 @@ from sqlalchemy.sql.ddl import CreateTable from typing_extensions import Self -from dbally.llm_client.base import LLMClient +from dbally.llms.base import LLM from dbally.prompts import PromptTemplate from ._config import Text2SQLConfig, Text2SQLTableConfig @@ -36,7 +36,7 @@ class _AutoDiscoveryBuilderBase: Builder class for configuring the auto-discovery of the database for text2sql freeform view. """ - _llm_client: Optional[LLMClient] + _llm: Optional[LLM] _blacklist: Optional[List[str]] _whitelist: Optional[List[str]] _description_extraction: _DescriptionExtractionStrategy @@ -49,10 +49,10 @@ def __init__( whitelist: Optional[List[str]] = None, description_extraction: Optional[_DescriptionExtractionStrategy] = None, similarity_enabled: bool = False, - llm_client: Optional[LLMClient] = None, + llm: Optional[LLM] = None, ) -> None: self._engine = engine - self._llm_client = llm_client + self._llm = llm self._blacklist = blacklist self._whitelist = whitelist @@ -115,7 +115,7 @@ async def discover(self) -> Text2SQLConfig: Text2SQLConfig: The configuration object for the text2sql freeform view. """ return await _Text2SQLAutoDiscovery( - llm_client=self._llm_client, + llm=self._llm, engine=self._engine, whitelist=self._whitelist, blacklist=self._blacklist, @@ -150,12 +150,12 @@ class AutoDiscoveryBuilder(_AutoDiscoveryBuilderBase): Builder class for configuring the auto-discovery of the database for text2sql freeform view. """ - def use_llm(self, llm_client: LLMClient) -> AutoDiscoveryBuilderWithLLM: + def use_llm(self, llm: LLM) -> AutoDiscoveryBuilderWithLLM: """ Set the LLM client to use for generating descriptions. Args: - llm_client: The LLM client to use for generating descriptions. + llm: The LLM client to use for generating descriptions. Returns: The builder instance. @@ -166,7 +166,7 @@ def use_llm(self, llm_client: LLMClient) -> AutoDiscoveryBuilderWithLLM: blacklist=self._blacklist, description_extraction=self._description_extraction, similarity_enabled=self._similarity_enabled, - llm_client=llm_client, + llm=llm, ) @@ -229,11 +229,11 @@ def __init__( engine: Engine, description_extraction: _DescriptionExtractionStrategy, whitelist: Optional[List[str]] = None, - llm_client: Optional[LLMClient] = None, + llm: Optional[LLM] = None, blacklist: Optional[List[str]] = None, similarity_enabled: bool = False, ) -> None: - self._llm_client = llm_client + self._llm = llm self._engine = engine self._whitelist = whitelist self._blacklist = blacklist @@ -285,13 +285,13 @@ async def discover(self) -> Text2SQLConfig: async def _suggest_similarity_indexes( self, connection: Connection, description: str, table: Table ) -> Dict[str, str]: - if self._llm_client is None: + if self._llm is None: raise ValueError("LLM client is required for suggesting similarity indexes.") similarity = {} for column_name, column in self._iterate_str_columns(table): example_values = self._get_column_example_values(connection, table, column) - similarity_type = await self._llm_client.text_generation( + similarity_type = await self._llm.generate_text( template=similarity_template, fmt={"table_summary": description, "column_name": column.name, "values": example_values}, ) @@ -300,10 +300,10 @@ async def _suggest_similarity_indexes( return similarity async def _generate_llm_summary(self, ddl: str, example_rows: List[dict]) -> str: - if self._llm_client is None: + if self._llm is None: raise ValueError("LLM client is required for generating descriptions.") - return await self._llm_client.text_generation( + return await self._llm.generate_text( template=discovery_template, fmt={"dialect": self._engine.dialect.name, "table_ddl": ddl, "example_rows": example_rows}, ) diff --git a/src/dbally/views/freeform/text2sql/_view.py b/src/dbally/views/freeform/text2sql/_view.py index 3efe0e66..a2610888 100644 --- a/src/dbally/views/freeform/text2sql/_view.py +++ b/src/dbally/views/freeform/text2sql/_view.py @@ -5,7 +5,8 @@ from dbally.audit.event_tracker import EventTracker from dbally.data_models.execution_result import ViewExecutionResult -from dbally.llm_client.base import LLMClient, LLMOptions +from dbally.llms.base import LLM +from dbally.llms.clients.base import LLMOptions from dbally.prompts import PromptTemplate from dbally.views.base import BaseView @@ -39,7 +40,7 @@ def __init__(self, engine: sqlalchemy.engine.Engine, config: Text2SQLConfig) -> async def ask( self, query: str, - llm_client: LLMClient, + llm: LLM, event_tracker: EventTracker, n_retries: int = 3, dry_run: bool = False, @@ -51,11 +52,11 @@ async def ask( Args: query: The natural language query to execute. - llm_client: The LLM client used to execute the query. + llm: The LLM used to execute the query. event_tracker: The event tracker used to audit the query execution. n_retries: The number of retries to execute the query in case of errors. dry_run: If True, the query will not be used to fetch data from the datasource. - llm_options: options to use for the LLM client. + llm_options: Options to use for the LLM. Returns: The result of the query. @@ -75,7 +76,7 @@ async def ask( sql, conversation = await self._generate_sql( query=query, conversation=conversation, - llm_client=llm_client, + llm=llm, event_tracker=event_tracker, llm_options=llm_options, ) @@ -106,11 +107,11 @@ async def _generate_sql( self, query: str, conversation: PromptTemplate, - llm_client: LLMClient, + llm: LLM, event_tracker: EventTracker, llm_options: Optional[LLMOptions] = None, ) -> Tuple[str, PromptTemplate]: - response = await llm_client.text_generation( + response = await llm.generate_text( template=conversation, fmt={"tables": self._get_tables_context(), "dialect": self._engine.dialect.name, "question": query}, event_tracker=event_tracker, diff --git a/src/dbally/views/structured.py b/src/dbally/views/structured.py index 4cead890..e6034a18 100644 --- a/src/dbally/views/structured.py +++ b/src/dbally/views/structured.py @@ -5,7 +5,8 @@ from dbally.data_models.execution_result import ViewExecutionResult from dbally.iql import IQLError, IQLQuery from dbally.iql_generator.iql_generator import IQLGenerator -from dbally.llm_client.base import LLMClient, LLMOptions +from dbally.llms.base import LLM +from dbally.llms.clients.base import LLMOptions from dbally.views.exposed_functions import ExposedFunction from .base import BaseView @@ -17,22 +18,22 @@ class BaseStructuredView(BaseView): to be able to list all available filters, apply them and execute queries. """ - def get_iql_generator(self, llm_client: LLMClient) -> IQLGenerator: + def get_iql_generator(self, llm: LLM) -> IQLGenerator: """ Returns the IQL generator for the view. Args: - llm_client: LLM client used to generate the IQL queries + llm: LLM used to generate the IQL queries Returns: IQLGenerator: IQL generator for the view """ - return IQLGenerator(llm_client=llm_client) + return IQLGenerator(llm=llm) async def ask( self, query: str, - llm_client: LLMClient, + llm: LLM, event_tracker: EventTracker, n_retries: int = 3, dry_run: bool = False, @@ -44,16 +45,16 @@ async def ask( Args: query: The natural language query to execute. - llm_client: The LLM client used to execute the query. + llm: The LLM used to execute the query. event_tracker: The event tracker used to audit the query execution. n_retries: The number of retries to execute the query in case of errors. dry_run: If True, the query will not be used to fetch data from the datasource. - llm_options: options to use for the LLM client. + llm_options: Options to use for the LLM. Returns: The result of the query. """ - iql_generator = self.get_iql_generator(llm_client) + iql_generator = self.get_iql_generator(llm) filter_list = self.list_filters() iql_filters, conversation = await iql_generator.generate_iql( diff --git a/tests/integration/test_llm_options.py b/tests/integration/test_llm_options.py index 35fd1c8b..79dde194 100644 --- a/tests/integration/test_llm_options.py +++ b/tests/integration/test_llm_options.py @@ -3,15 +3,8 @@ import pytest from dbally import create_collection -from tests.unit.mocks import MockLLMClient, MockLLMOptions, MockViewBase - - -class MockView1(MockViewBase): - ... - - -class MockView2(MockViewBase): - ... +from tests.unit.mocks import MockLLM, MockLLMOptions +from tests.unit.test_collection import MockView1, MockView2 @pytest.mark.asyncio @@ -20,28 +13,26 @@ async def test_llm_options_propagation(): custom_options = MockLLMOptions(mock_property1=2) expected_options = MockLLMOptions(mock_property1=2, mock_property2="default mock") - llm_client = MockLLMClient(default_options=default_options) + llm = MockLLM(default_options=default_options) + llm._client.call = AsyncMock(return_value="MockView1") collection = create_collection( name="test_collection", - llm_client=llm_client, + llm=llm, ) - + collection.n_retries = 0 collection.add(MockView1) collection.add(MockView2) - collection.n_retries = 0 - collection._llm_client.call = AsyncMock(return_value="MockView1") - await collection.ask( question="Mock question", return_natural_response=True, llm_options=custom_options, ) - assert llm_client.call.call_count == 3 + assert llm._client.call.call_count == 3 - llm_client.call.assert_has_calls( + llm._client.call.assert_has_calls( [ call( prompt=ANY, diff --git a/tests/unit/mocks.py b/tests/unit/mocks.py index 2e8ceafa..f7de3faf 100644 --- a/tests/unit/mocks.py +++ b/tests/unit/mocks.py @@ -5,14 +5,15 @@ """ from dataclasses import dataclass +from functools import cached_property from typing import List, Optional, Tuple, Union -from unittest.mock import create_autospec from dbally import NOT_GIVEN, NotGiven from dbally.iql import IQLQuery from dbally.iql_generator.iql_generator import IQLGenerator from dbally.iql_generator.iql_prompt_template import IQLPromptTemplate, default_iql_template -from dbally.llm_client.base import LLMClient, LLMOptions +from dbally.llms.base import LLM +from dbally.llms.clients.base import LLMClient, LLMOptions from dbally.similarity.index import AbstractSimilarityIndex from dbally.view_selection.base import ViewSelector from dbally.views.structured import BaseStructuredView, ExposedFunction, ViewExecutionResult @@ -36,7 +37,7 @@ def execute(self, dry_run=False) -> ViewExecutionResult: class MockIQLGenerator(IQLGenerator): def __init__(self, iql: str) -> None: self.iql = iql - super().__init__(llm_client=create_autospec(LLMClient)) + super().__init__(llm=MockLLM()) async def generate_iql(self, *_, **__) -> Tuple[str, IQLPromptTemplate]: return self.iql, default_iql_template @@ -69,14 +70,19 @@ class MockLLMOptions(LLMOptions): class MockLLMClient(LLMClient[MockLLMOptions]): - _options_cls = MockLLMOptions - - def __init__( - self, - model_name: str = "gpt-4-mock", - default_options: Optional[MockLLMOptions] = None, - ) -> None: - super().__init__(model_name, default_options) + def __init__(self, model_name: str) -> None: + super().__init__(model_name) async def call(self, *_, **__) -> str: return "mock response" + + +class MockLLM(LLM[MockLLMOptions]): + _options_cls = MockLLMOptions + + def __init__(self, default_options: Optional[MockLLMOptions] = None) -> None: + super().__init__("mock-llm", default_options) + + @cached_property + def _client(self) -> MockLLMClient: + return MockLLMClient(model_name=self.model_name) diff --git a/tests/unit/test_collection.py b/tests/unit/test_collection.py index 4ce99a20..edc43fda 100644 --- a/tests/unit/test_collection.py +++ b/tests/unit/test_collection.py @@ -13,7 +13,7 @@ from dbally.utils.errors import NoViewFoundError from dbally.views.exposed_functions import ExposedFunction, MethodParamWithTyping from dbally.views.structured import BaseStructuredView -from tests.unit.mocks import MockIQLGenerator, MockLLMClient, MockSimilarityIndex, MockViewBase, MockViewSelector +from tests.unit.mocks import MockIQLGenerator, MockLLM, MockSimilarityIndex, MockViewBase, MockViewSelector class MockView1(MockViewBase): @@ -129,7 +129,7 @@ def mock_collection() -> Collection: """ collection = create_collection( "foo", - llm_client=MockLLMClient(), + llm=MockLLM(), view_selector=MockViewSelector("MockView1"), nl_responder=AsyncMock(), ) @@ -270,9 +270,7 @@ class ViewWithMockGenerator(MockViewBase): def get_iql_generator(self, *_, **__): return iql_generator - collection = Collection( - "foo", view_selector=Mock(), llm_client=MockLLMClient(), nl_responder=Mock(), event_handlers=[] - ) + collection = Collection("foo", view_selector=Mock(), llm=MockLLM(), nl_responder=Mock(), event_handlers=[]) collection.add(ViewWithMockGenerator) return collection @@ -293,7 +291,7 @@ async def test_ask_feedback_loop(collection_feedback: Collection) -> None: mock_iql_query.side_effect = errors view = collection_feedback.get("ViewWithMockGenerator") assert isinstance(view, BaseStructuredView) - iql_generator = view.get_iql_generator(llm_client=MockLLMClient()) + iql_generator = view.get_iql_generator(llm=MockLLM()) await collection_feedback.ask("Mock question") @@ -320,7 +318,7 @@ async def test_ask_view_selection_single_view() -> None: collection = Collection( "foo", view_selector=MockViewSelector(""), - llm_client=MockLLMClient(), + llm=MockLLM(), nl_responder=AsyncMock(), event_handlers=[], ) @@ -339,7 +337,7 @@ async def test_ask_view_selection_multiple_views() -> None: collection = Collection( "foo", view_selector=MockViewSelector("MockViewWithResults"), - llm_client=MockLLMClient(), + llm=MockLLM(), nl_responder=AsyncMock(), event_handlers=[], ) @@ -360,7 +358,7 @@ async def test_ask_view_selection_no_views() -> None: collection = Collection( "foo", view_selector=MockViewSelector(""), - llm_client=MockLLMClient(), + llm=MockLLM(), nl_responder=AsyncMock(), event_handlers=[], ) diff --git a/tests/unit/test_iql_generator.py b/tests/unit/test_iql_generator.py index cdedeb06..af991c31 100644 --- a/tests/unit/test_iql_generator.py +++ b/tests/unit/test_iql_generator.py @@ -1,6 +1,6 @@ # mypy: disable-error-code="empty-body" -from unittest.mock import AsyncMock, Mock +from unittest.mock import AsyncMock import pytest import sqlalchemy @@ -11,6 +11,7 @@ from dbally.iql_generator.iql_generator import IQLGenerator from dbally.iql_generator.iql_prompt_template import default_iql_template from dbally.views.methods_base import MethodsBaseView +from tests.unit.mocks import MockLLM class MockView(MethodsBaseView): @@ -33,26 +34,25 @@ def filter_by_name(self, city: str) -> sqlalchemy.ColumnElement: @pytest.fixture -def view(): - view = MockView() - return view +def view() -> MockView: + return MockView() @pytest.fixture -def llm_client(): - mock_client = Mock() - mock_client.text_generation = AsyncMock(return_value="LLM IQL mock answer") - return mock_client +def llm() -> MockLLM: + llm = MockLLM() + llm._client.call = AsyncMock(return_value="LLM IQL mock answer") + return llm @pytest.fixture -def event_tracker(): +def event_tracker() -> EventTracker: return EventTracker() @pytest.mark.asyncio -async def test_iql_generation(llm_client, event_tracker, view): - iql_generator = IQLGenerator(llm_client, default_iql_template) +async def test_iql_generation(llm: MockLLM, event_tracker: EventTracker, view: MockView) -> None: + iql_generator = IQLGenerator(llm, default_iql_template) filters_for_prompt = iql_generator._promptify_view(view.list_filters()) filters_in_prompt = set(filters_for_prompt.split("\n")) @@ -72,8 +72,8 @@ async def test_iql_generation(llm_client, event_tracker, view): assert response2 == ("LLM IQL mock answer", template_after_2nd_response) -def test_add_error_msg(llm_client): - iql_generator = IQLGenerator(llm_client, default_iql_template) +def test_add_error_msg(llm: MockLLM) -> None: + iql_generator = IQLGenerator(llm, default_iql_template) errors = [ValueError("Mock_error")] conversation = default_iql_template.add_assistant_message(content="Assistant") diff --git a/tests/unit/test_nl_responder.py b/tests/unit/test_nl_responder.py index 78a0040e..b4237d01 100644 --- a/tests/unit/test_nl_responder.py +++ b/tests/unit/test_nl_responder.py @@ -1,33 +1,32 @@ -from unittest.mock import AsyncMock, Mock +from unittest.mock import AsyncMock import pytest from dbally.audit.event_tracker import EventTracker from dbally.data_models.execution_result import ViewExecutionResult from dbally.nl_responder.nl_responder import NLResponder +from tests.unit.mocks import MockLLM @pytest.fixture -def llm_client(): - mock_client = Mock() - mock_client.text_generation = AsyncMock(return_value="db-ally is the best") - mock_client.model_name = "gpt-4" - return mock_client +def llm() -> MockLLM: + llm = MockLLM() + llm._client.call = AsyncMock(return_value="db-ally is the best") + return llm @pytest.fixture -def event_tracker(): +def event_tracker() -> EventTracker: return EventTracker() @pytest.fixture -def answer(): +def answer() -> ViewExecutionResult: return ViewExecutionResult(results=[{"id": 1, "name": "Mock name"}], context={"sql": "Mock SQL"}) @pytest.mark.asyncio -async def test_nl_responder(llm_client, answer, event_tracker): - nl_responder = NLResponder(llm_client) - +async def test_nl_responder(llm: MockLLM, answer: ViewExecutionResult, event_tracker: EventTracker): + nl_responder = NLResponder(llm) response = await nl_responder.generate_response(answer, "Mock question", event_tracker) assert response == "db-ally is the best" diff --git a/tests/unit/test_prompt_builder.py b/tests/unit/test_prompt_builder.py index a4a2fb59..9c3d0695 100644 --- a/tests/unit/test_prompt_builder.py +++ b/tests/unit/test_prompt_builder.py @@ -1,7 +1,8 @@ import pytest from dbally.iql_generator.iql_prompt_template import IQLPromptTemplate -from dbally.prompts import ChatFormat, PromptBuilder, PromptTemplate, PromptTemplateError +from dbally.prompts import ChatFormat, PromptTemplate, PromptTemplateError +from tests.unit.mocks import MockLLM @pytest.fixture() @@ -16,35 +17,24 @@ def simple_template(): @pytest.fixture() -def default_prompt_builder(): - builder = PromptBuilder() - return builder +def llm(): + return MockLLM() -@pytest.fixture() -def hf_prompt_builder(): - builder = PromptBuilder("HuggingFaceH4/zephyr-7b-beta") - return builder - - -def test_openai_client_prompt(default_prompt_builder, simple_template): - prompt = default_prompt_builder.build(simple_template, fmt={"question": "Example user question?"}) - assert prompt == ( +def test_default_llm_format_prompt(llm, simple_template): + prompt = llm._format_prompt( + template=simple_template, + fmt={"question": "Example user question?"}, + ) + assert prompt == [ {"content": "You are a helpful assistant.", "role": "system"}, {"content": "Example user question?", "role": "user"}, - ) - - -def test_text_prompt(hf_prompt_builder, simple_template): - prompt = hf_prompt_builder.build(simple_template, fmt={"question": "Example user question?"}) - assert ( - prompt == "<|system|>\nYou are a helpful assistant.\n<|user|>\nExample user question?\n<|assistant|>\n" - ) + ] -def test_missing_format_dict(default_prompt_builder, simple_template): +def test_missing_format_dict(llm, simple_template): with pytest.raises(KeyError): - _ = default_prompt_builder.build(simple_template, fmt={}) + _ = llm._format_prompt(simple_template, fmt={}) @pytest.mark.parametrize( @@ -73,10 +63,10 @@ def test_chat_order_validation(invalid_chat): _ = PromptTemplate(chat=invalid_chat) -def test_dynamic_few_shot(default_prompt_builder, simple_template): +def test_dynamic_few_shot(llm, simple_template): assert ( len( - default_prompt_builder.build( + llm._format_prompt( simple_template.add_assistant_message("assistant message").add_user_message("user message"), fmt={"question": "user question"}, ) diff --git a/tests/unit/test_view_selector.py b/tests/unit/test_view_selector.py index 09a7f1d3..a135b34e 100644 --- a/tests/unit/test_view_selector.py +++ b/tests/unit/test_view_selector.py @@ -1,39 +1,37 @@ # mypy: disable-error-code="empty-body" # pylint: disable=missing-return-doc from typing import Dict -from unittest.mock import AsyncMock, Mock +from unittest.mock import AsyncMock import pytest import dbally from dbally.audit.event_tracker import EventTracker -from dbally.llm_client.base import LLMClient +from dbally.llms.base import LLM from dbally.view_selection.llm_view_selector import LLMViewSelector - -from .mocks import MockLLMClient -from .test_collection import MockView1, MockView2 +from tests.unit.mocks import MockLLM +from tests.unit.test_collection import MockView1, MockView2 @pytest.fixture -def llm_client() -> LLMClient: +def llm() -> LLM: """Return a mock LLM client.""" - client = Mock() - client.text_generation = AsyncMock(return_value="MockView1") - return client + llm = MockLLM() + llm._client.call = AsyncMock(return_value="MockView1") + return llm @pytest.fixture def views() -> Dict[str, str]: """Return a map of view names + view descriptions to be used in the test.""" - mock_collection = dbally.create_collection("mock_collection", llm_client=MockLLMClient()) + mock_collection = dbally.create_collection("mock_collection", llm=MockLLM()) mock_collection.add(MockView1) mock_collection.add(MockView2) return mock_collection.list() @pytest.mark.asyncio -async def test_view_selection(llm_client: LLMClient, views: Dict[str, str]): - view_selector = LLMViewSelector(llm_client) - +async def test_view_selection(llm: LLM, views: Dict[str, str]): + view_selector = LLMViewSelector(llm) view = await view_selector.select_view("Mock question?", views, event_tracker=EventTracker()) assert view == "MockView1" diff --git a/tests/unit/views/text2sql/test_autodiscovery.py b/tests/unit/views/text2sql/test_autodiscovery.py index bde682ec..3321c841 100644 --- a/tests/unit/views/text2sql/test_autodiscovery.py +++ b/tests/unit/views/text2sql/test_autodiscovery.py @@ -64,7 +64,7 @@ async def test_autodiscovery_whitelist(sample_db: Engine): async def test_autodiscovery_llm_descriptions(sample_db: Engine): mock_client = Mock() - mock_client.text_generation = AsyncMock(return_value="LLM mock answer") + mock_client.generate_text = AsyncMock(return_value="LLM mock answer") config = await ( configure_text2sql_auto_discovery(sample_db) diff --git a/tests/unit/views/text2sql/test_view.py b/tests/unit/views/text2sql/test_view.py index 18fcb34d..02dad2fe 100644 --- a/tests/unit/views/text2sql/test_view.py +++ b/tests/unit/views/text2sql/test_view.py @@ -1,4 +1,4 @@ -from unittest.mock import AsyncMock, Mock +from unittest.mock import AsyncMock import pytest import sqlalchemy @@ -7,6 +7,7 @@ import dbally from dbally.views.freeform.text2sql import Text2SQLConfig, Text2SQLFreeformView from dbally.views.freeform.text2sql._config import Text2SQLTableConfig +from tests.unit.mocks import MockLLM @pytest.fixture @@ -31,8 +32,8 @@ def sample_db() -> Engine: async def test_text2sql_view(sample_db: Engine): - mock_llm = Mock() - mock_llm.text_generation = AsyncMock(return_value="SELECT * FROM customers WHERE city = 'New York'") + llm = MockLLM() + llm._client.call = AsyncMock(return_value="SELECT * FROM customers WHERE city = 'New York'") config = Text2SQLConfig( tables={ @@ -43,7 +44,7 @@ async def test_text2sql_view(sample_db: Engine): } ) - collection = dbally.create_collection(name="test_collection", llm_client=mock_llm) + collection = dbally.create_collection(name="test_collection", llm=llm) collection.add(Text2SQLFreeformView, lambda: Text2SQLFreeformView(sample_db, config)) response = await collection.ask("Show me customers from New York")