diff --git a/README.md b/README.md index 0bbc2843..b44e63b0 100644 --- a/README.md +++ b/README.md @@ -1,18 +1,26 @@ -#

🦮 db-ally

+
-

- Efficient, consistent and secure library for querying structured data with natural language + + + dbally logo + + +
+
+ +

+ Efficient, consistent and secure library for querying structured data with natural language

---- +[![PyPI - License](https://img.shields.io/pypi/l/dbally)](https://pypi.org/project/dbally) +[![PyPI - Version](https://img.shields.io/pypi/v/dbally)](https://pypi.org/project/dbally) +[![PyPI - Python Version](https://img.shields.io/pypi/pyversions/dbally)](https://pypi.org/project/dbally) -* **Documentation:** [db-ally.deepsense.ai](https://db-ally.deepsense.ai/) -* **Source code:** [github.com/deepsense-ai/db-ally](https://github.com/deepsense-ai/db-ally) +
--- - -**db-ally** is an LLM-powered library for creating natural language interfaces to data sources. While it occupies a similar space to the text-to-SQL solutions, its goals and methods are different. db-ally allows developers to outline specific use cases for the LLM to handle, detailing the desired data format and the possible operations to fetch this data. +db-ally is an LLM-powered library for creating natural language interfaces to data sources. While it occupies a similar space to the text-to-SQL solutions, its goals and methods are different. db-ally allows developers to outline specific use cases for the LLM to handle, detailing the desired data format and the possible operations to fetch this data. db-ally effectively shields the complexity of the underlying data source from the model, presenting only the essential information needed for solving the specific use cases. Instead of generating arbitrary SQL, the model is asked to generate responses in a simplified query language. @@ -25,7 +33,7 @@ The benefits of db-ally can be described in terms of its four main characteristi ## Quickstart -In db-ally, developers define their use cases by implementing [**views**](https://db-ally.deepsense.ai/concepts/views) and **filters**. A list of possible filters is presented to the LLM in terms of [**IQL**](https://db-ally.deepsense.ai/concepts/iql) (Intermediate Query Language). Views are grouped and registered within a [**collection**](https://db-ally.deepsense.ai/concepts/views), which then serves as an entry point for asking questions in natural language. +In db-ally, developers define their use cases by implementing [**views**](https://db-ally.deepsense.ai/concepts/views), **filters** and **aggregations**. A list of possible filters and aggregations is presented to the LLM in terms of [**IQL**](https://db-ally.deepsense.ai/concepts/iql) (Intermediate Query Language). 
Views are grouped and registered within a [**collection**](https://db-ally.deepsense.ai/concepts/views), which then serves as an entry point for asking questions in natural language. This is a basic implementation of a db-ally view for an example HR application, which retrieves candidates from an SQL database: @@ -52,8 +60,10 @@ class CandidateView(SqlAlchemyBaseView): """ return Candidate.country == country -engine = create_engine('sqlite:///examples/recruiting/data/candidates.db') + llm = LiteLLM(model_name="gpt-3.5-turbo") +engine = create_engine("sqlite:///examples/recruiting/data/candidates.db") + my_collection = create_collection("collection_name", llm) my_collection.add(CandidateView, lambda: CandidateView(engine)) diff --git a/benchmarks/sql/bench/pipelines/base.py b/benchmarks/sql/bench/pipelines/base.py index 38bcb304..dc8d83ea 100644 --- a/benchmarks/sql/bench/pipelines/base.py +++ b/benchmarks/sql/bench/pipelines/base.py @@ -1,8 +1,12 @@ from abc import ABC, abstractmethod from dataclasses import dataclass -from typing import Any, Dict, Optional +from typing import Any, Dict, Optional, Union +from dbally.iql._exceptions import IQLError +from dbally.iql._query import IQLQuery +from dbally.iql_generator.prompt import UnsupportedQueryError from dbally.llms.base import LLM +from dbally.llms.clients.exceptions import LLMError from dbally.llms.litellm import LiteLLM from dbally.llms.local import LocalLLM @@ -16,6 +20,25 @@ class IQL: source: Optional[str] = None unsupported: bool = False valid: bool = True + generated: bool = True + + @classmethod + def from_query(cls, query: Optional[Union[IQLQuery, Exception]]) -> "IQL": + """ + Creates an IQL object from the query. + + Args: + query: The IQL query or exception. + + Returns: + The IQL object. 
+ """ + return cls( + source=query.source if isinstance(query, (IQLQuery, IQLError)) else None, + unsupported=isinstance(query, UnsupportedQueryError), + valid=not isinstance(query, IQLError), + generated=not isinstance(query, LLMError), + ) @dataclass @@ -47,6 +70,7 @@ class EvaluationResult: """ db_id: str + question_id: str question: str reference: ExecutionResult prediction: ExecutionResult diff --git a/benchmarks/sql/bench/pipelines/collection.py b/benchmarks/sql/bench/pipelines/collection.py index dfc127cf..19831b0d 100644 --- a/benchmarks/sql/bench/pipelines/collection.py +++ b/benchmarks/sql/bench/pipelines/collection.py @@ -5,10 +5,8 @@ import dbally from dbally.collection.collection import Collection from dbally.collection.exceptions import NoViewFoundError -from dbally.iql._exceptions import IQLError -from dbally.iql_generator.prompt import UnsupportedQueryError from dbally.view_selection.llm_view_selector import LLMViewSelector -from dbally.views.exceptions import IQLGenerationError +from dbally.views.exceptions import ViewExecutionError from ..views import VIEWS_REGISTRY from .base import IQL, EvaluationPipeline, EvaluationResult, ExecutionResult, IQLResult @@ -74,44 +72,23 @@ async def __call__(self, data: Dict[str, Any]) -> EvaluationResult: return_natural_response=False, ) except NoViewFoundError: - prediction = ExecutionResult( - view_name=None, - iql=None, - sql=None, - ) - except IQLGenerationError as exc: + prediction = ExecutionResult() + except ViewExecutionError as exc: prediction = ExecutionResult( view_name=exc.view_name, iql=IQLResult( - filters=IQL( - source=exc.filters, - unsupported=isinstance(exc.__cause__, UnsupportedQueryError), - valid=not (exc.filters and not exc.aggregation and isinstance(exc.__cause__, IQLError)), - ), - aggregation=IQL( - source=exc.aggregation, - unsupported=isinstance(exc.__cause__, UnsupportedQueryError), - valid=not (exc.aggregation and isinstance(exc.__cause__, IQLError)), - ), + 
filters=IQL.from_query(exc.iql.filters), + aggregation=IQL.from_query(exc.iql.aggregation), ), - sql=None, ) else: prediction = ExecutionResult( view_name=result.view_name, iql=IQLResult( - filters=IQL( - source=result.context.get("iql"), - unsupported=False, - valid=True, - ), - aggregation=IQL( - source=None, - unsupported=False, - valid=True, - ), + filters=IQL(source=result.context["iql"]["filters"]), + aggregation=IQL(source=result.context["iql"]["aggregation"]), ), - sql=result.context.get("sql"), + sql=result.context["sql"], ) reference = ExecutionResult( @@ -134,6 +111,7 @@ async def __call__(self, data: Dict[str, Any]) -> EvaluationResult: return EvaluationResult( db_id=data["db_id"], + question_id=data["question_id"], question=data["question"], reference=reference, prediction=prediction, diff --git a/benchmarks/sql/bench/pipelines/view.py b/benchmarks/sql/bench/pipelines/view.py index d4ae8515..be9d8263 100644 --- a/benchmarks/sql/bench/pipelines/view.py +++ b/benchmarks/sql/bench/pipelines/view.py @@ -5,9 +5,7 @@ from sqlalchemy import create_engine -from dbally.iql._exceptions import IQLError -from dbally.iql_generator.prompt import UnsupportedQueryError -from dbally.views.exceptions import IQLGenerationError +from dbally.views.exceptions import ViewExecutionError from dbally.views.freeform.text2sql.view import BaseText2SQLView from dbally.views.sqlalchemy_base import SqlAlchemyBaseView @@ -94,37 +92,20 @@ async def __call__(self, data: Dict[str, Any]) -> EvaluationResult: dry_run=True, n_retries=0, ) - except IQLGenerationError as exc: + except ViewExecutionError as exc: prediction = ExecutionResult( view_name=data["view_name"], iql=IQLResult( - filters=IQL( - source=exc.filters, - unsupported=isinstance(exc.__cause__, UnsupportedQueryError), - valid=not (exc.filters and not exc.aggregation and isinstance(exc.__cause__, IQLError)), - ), - aggregation=IQL( - source=exc.aggregation, - unsupported=isinstance(exc.__cause__, UnsupportedQueryError), - 
valid=not (exc.aggregation and isinstance(exc.__cause__, IQLError)), - ), + filters=IQL.from_query(exc.iql.filters), + aggregation=IQL.from_query(exc.iql.aggregation), ), - sql=None, ) else: prediction = ExecutionResult( view_name=data["view_name"], iql=IQLResult( - filters=IQL( - source=result.context["iql"], - unsupported=False, - valid=True, - ), - aggregation=IQL( - source=None, - unsupported=False, - valid=True, - ), + filters=IQL(source=result.context["iql"]["filters"]), + aggregation=IQL(source=result.context["iql"]["aggregation"]), ), sql=result.context["sql"], ) @@ -135,12 +116,10 @@ async def __call__(self, data: Dict[str, Any]) -> EvaluationResult: filters=IQL( source=data["iql_filters"], unsupported=data["iql_filters_unsupported"], - valid=True, ), aggregation=IQL( source=data["iql_aggregation"], unsupported=data["iql_aggregation_unsupported"], - valid=True, ), context=data["iql_context"], ), @@ -149,6 +128,7 @@ async def __call__(self, data: Dict[str, Any]) -> EvaluationResult: return EvaluationResult( db_id=data["db_id"], + question_id=data["question_id"], question=data["question"], reference=reference, prediction=prediction, @@ -209,6 +189,7 @@ async def __call__(self, data: Dict[str, Any]) -> EvaluationResult: return EvaluationResult( db_id=data["db_id"], + question_id=data["question_id"], question=data["question"], reference=reference, prediction=prediction, diff --git a/benchmarks/sql/bench/views/structured/superhero.py b/benchmarks/sql/bench/views/structured/superhero.py index 8a6bc38a..305369bb 100644 --- a/benchmarks/sql/bench/views/structured/superhero.py +++ b/benchmarks/sql/bench/views/structured/superhero.py @@ -552,7 +552,6 @@ class SuperheroView( SqlAlchemyBaseView, SuperheroAggregationMixin, SuperheroFilterMixin, - SuperheroColourAggregationMixin, SuperheroColourFilterMixin, AlignmentAggregationMixin, AlignmentFilterMixin, diff --git a/docs/about/roadmap.md b/docs/about/roadmap.md index a6f5312e..600f6cf2 100644 --- 
a/docs/about/roadmap.md +++ b/docs/about/roadmap.md @@ -9,7 +9,7 @@ Below you can find a list of planned features and integrations. ## Planned Features -- [ ] **Support analytical queries**: support for exposing operations beyond filtering. +- [x] **Support analytical queries**: support for exposing operations beyond filtering. - [x] **Few-shot prompting configuration**: allow users to configure the few-shot prompting in View definition to improve IQL generation accuracy. - [ ] **Request contextualization**: allow to provide extra context for db-ally runs, such as user asking the question. diff --git a/docs/assets/banner-dark.svg b/docs/assets/banner-dark.svg new file mode 100644 index 00000000..3b580fe3 --- /dev/null +++ b/docs/assets/banner-dark.svg @@ -0,0 +1,266 @@ + + + + + + diff --git a/docs/assets/banner-light.svg b/docs/assets/banner-light.svg new file mode 100644 index 00000000..e0ef1d4d --- /dev/null +++ b/docs/assets/banner-light.svg @@ -0,0 +1,266 @@ + + + + + + diff --git a/docs/assets/favicon.ico b/docs/assets/favicon.ico new file mode 100644 index 00000000..7683ac51 Binary files /dev/null and b/docs/assets/favicon.ico differ diff --git a/docs/assets/guide_dog_lg.png b/docs/assets/guide_dog_lg.png deleted file mode 100644 index dee16c22..00000000 Binary files a/docs/assets/guide_dog_lg.png and /dev/null differ diff --git a/docs/assets/guide_dog_sm.png b/docs/assets/guide_dog_sm.png deleted file mode 100644 index 85f91ee5..00000000 Binary files a/docs/assets/guide_dog_sm.png and /dev/null differ diff --git a/docs/assets/logo.svg b/docs/assets/logo.svg new file mode 100644 index 00000000..5d4381d0 --- /dev/null +++ b/docs/assets/logo.svg @@ -0,0 +1,229 @@ + + + + diff --git a/docs/concepts/iql.md b/docs/concepts/iql.md index c4496aca..5f79c5ec 100644 --- a/docs/concepts/iql.md +++ b/docs/concepts/iql.md @@ -1,12 +1,45 @@ # Concept: IQL -Intermediate Query Language (IQL) is a simple language that serves as an abstraction layer between natural language 
and data source-specific query syntax, such as SQL. With db-ally's [structured views](./structured_views.md), LLM utilizes IQL to express complex queries in a simplified way. +Intermediate Query Language (IQL) is a simple language that serves as an abstraction layer between natural language and data source-specific query syntax, such as SQL. With db-ally's [structured views](structured_views.md), LLM utilizes IQL to express complex queries in a simplified way. IQL allows developers to model operations such as filtering and aggregation on the underlying data. + +## Filtering For instance, an LLM might generate an IQL query like this when asked "Find me French candidates suitable for a senior data scientist position": +```python +from_country("France") AND senior_data_scientist_position() ``` -from_country('France') AND senior_data_scientist_position() + +The capabilities made available to the AI model via IQL differ between projects. Developers control these by defining special [views](structured_views.md). db-ally automatically exposes special methods defined in structured views, known as "filters", via IQL. For instance, the expression above suggests that the specific project contains a view that includes the `from_country` and `senior_data_scientist_position` methods (and possibly others that the LLM did not choose to use for this particular question). Additionally, the LLM can use boolean operators (`AND`, `OR`, `NOT`) to combine individual filters into more complex expressions. + +## Aggregation + +Similar to filtering, developers can define special methods in [structured views](structured_views.md) that perform aggregation. These methods are also exposed to the LLM via IQL. For example, an LLM might generate the following IQL query when asked "What's the average salary for each country?": + +```python +average_salary_by_country() ``` -The capabilities made available to the AI model via IQL differ between projects. 
Developers control these by defining special [Views](structured_views.md). db-ally automatically exposes special methods defined in structured views, known as "filters", via IQL. For instance, the expression above suggests that the specific project contains a view that includes the `from_country` and `senior_data_scientist_position` methods (and possibly others that the LLM did not choose to use for this particular question). Additionally, the LLM can use Boolean operators (`and`,`or`, `not`) to combine individual filters into more complex expressions. +The `average_salary_by_country` groups candidates by country and calculates the average salary for each group. + +The aggregation IQL call has access to the raw query, so it can perform even more complex aggregations, like grouping different columns or applying custom functions. We can ask db-ally to generate a candidates report with the following IQL query: + +```python +candidate_report() +``` + +In this case, the `candidate_report` method is defined in a structured view, and it performs a series of aggregations and calculations to produce a report with the average salary, number of candidates, and other metrics, by country. + +## Operation chaining + +Some queries require filtering and aggregation. For example, to calculate the average salary for a data scientist in the US, we first need to filter the data to include only US candidates who are senior specialists, and then calculate the average salary. In this case, db-ally will first generate an IQL query to filter the data, and then another IQL query to calculate the average salary. + +```python +from_country("USA") AND senior_data_scientist_position() +``` + +```python +average_salary() +``` +In this case, db-ally will execute queries sequentially to build a single query plan to execute on the data source. 
diff --git a/docs/concepts/similarity_indexes.md b/docs/concepts/similarity_indexes.md index 1c0ede7e..70966c61 100644 --- a/docs/concepts/similarity_indexes.md +++ b/docs/concepts/similarity_indexes.md @@ -11,4 +11,4 @@ A similarity index consists of two main components: The concept of "similarity" is deliberately broad, as it varies depending on the store's implementation to best fit the use case. Db-ally not only supports custom store implementations but also includes various built-in implementations, from simple case-insensitive text matches to more complex embedding-based semantic similarities. -See the [Quickstart Part 2: Semantic Similarity](../quickstart/quickstart2.md) for an example of using a similarity index with a semantic similarity store. +See the [Quickstart Part 2: Semantic Similarity](../quickstart/semantic-similarity.md) for an example of using a similarity index with a semantic similarity store. diff --git a/docs/concepts/structured_views.md b/docs/concepts/structured_views.md index db048c8f..aab4970f 100644 --- a/docs/concepts/structured_views.md +++ b/docs/concepts/structured_views.md @@ -7,7 +7,7 @@ Structured views are a type of [view](../concepts/views.md), which provide a way Given different natural language queries, a db-ally view will produce different responses while maintaining a consistent data structure. This consistency offers a reliable interface for integration - the code consuming responses from a particular structured view knows what data structure to expect and can utilize this knowledge when displaying or processing the data. This feature of db-ally makes it stand out in terms of reliability and stability compared to standard text-to-SQL approaches. -Each structured view can contain one or more “filters”, which the LLM may decide to choose and apply to the extracted data so that it meets the criteria specified in the natural language query. 
Given such a query, LLM chooses which filters to use, provides arguments to the filters, and connects the filters with Boolean operators. The LLM expresses these filter combinations using a special language called [IQL](iql.md), in which the defined view filters provide a layer of abstraction between the LLM and the raw syntax used to query the data source (e.g., SQL). +Each structured view can contain one or more **filters** or **aggregations**, which the LLM may decide to choose and apply to the extracted data so that it meets the criteria specified in the natural language query. Given such a query, LLM chooses which filters to use, provides arguments to the filters, and connects the filters with boolean operators. For aggregations, the LLM selects an appropriate aggregation method and applies it to the data. The LLM expresses these filter combinations and aggregation using a special language called [IQL](iql.md), in which the defined view filters and aggregations provide a layer of abstraction between the LLM and the raw syntax used to query the data source (e.g., SQL). !!! example For instance, this is a simple [view that uses SQLAlchemy](../how-to/views/sql.md) to select data from specific columns in a SQL database. It contains a single filter, that the LLM may optionally use to control which table rows to fetch: @@ -18,14 +18,14 @@ Each structured view can contain one or more “filters”, which the LLM may de A view for retrieving candidates from the database. """ - def get_select(self): + def get_select(self) -> Select: """ Defines which columns to select """ return sqlalchemy.select(Candidate.id, Candidate.name, Candidate.country) @decorators.view_filter() - def from_country(self, country: str): + def from_country(self, country: str) -> ColumnElement: """ Filter candidates from a specific country. 
""" diff --git a/docs/how-to/update_similarity_indexes.md b/docs/how-to/update_similarity_indexes.md index 763bc1a7..75aaf322 100644 --- a/docs/how-to/update_similarity_indexes.md +++ b/docs/how-to/update_similarity_indexes.md @@ -2,7 +2,7 @@ The Similarity Index is a feature provided by db-ally that takes user input and maps it to the closest matching value in the data source using a chosen similarity metric. This feature is handy when the user input does not exactly match the data source, such as when the user asks to "list all employees in the IT department," while the database categorizes this group as the "computer department." To learn more about Similarity Indexes, refer to the [Concept: Similarity Indexes](../concepts/similarity_indexes.md) page. -While Similarity Indexes can be used directly, they are usually used with [Views](../concepts/views.md), annotating arguments to filter methods. This technique lets db-ally automatically match user-provided arguments to the most similar value in the data source. You can see an example of using similarity indexes with views on the [Quickstart Part 2: Semantic Similarity](../quickstart/quickstart2.md) page. +While Similarity Indexes can be used directly, they are usually used with [Views](../concepts/views.md), annotating arguments to filter methods. This technique lets db-ally automatically match user-provided arguments to the most similar value in the data source. You can see an example of using similarity indexes with views on the [Quickstart Part 2: Semantic Similarity](../quickstart/semantic-similarity.md) page. Similarity Indexes are designed to index all possible values (e.g., on disk or in a different data store). Consequently, when the data source undergoes changes, the Similarity Index must update to reflect these alterations. This guide will explain how to update Similarity Indexes in your code. 
diff --git a/docs/how-to/use_custom_similarity_fetcher.md b/docs/how-to/use_custom_similarity_fetcher.md index b79b6554..bcdbf0dd 100644 --- a/docs/how-to/use_custom_similarity_fetcher.md +++ b/docs/how-to/use_custom_similarity_fetcher.md @@ -55,7 +55,7 @@ In this example, we used the FaissStore, which utilizes the `faiss` library for ## Using the Similarity Index -You can use the index with a custom fetcher [the same way](../quickstart/quickstart2.md) as you would with a built-in fetcher. The similarity index will map user input to the closest matching value from your data source, allowing you to deliver more precise responses to user queries. Remember to frequently update the similarity index with new values from your data source to maintain its relevance. You can accomplish this by calling the `update` method on the similarity index. +You can use the index with a custom fetcher [the same way](../quickstart/semantic-similarity.md) as you would with a built-in fetcher. The similarity index will map user input to the closest matching value from your data source, allowing you to deliver more precise responses to user queries. Remember to frequently update the similarity index with new values from your data source to maintain its relevance. You can accomplish this by calling the `update` method on the similarity index. ```python await breeds_similarity.update() @@ -72,4 +72,4 @@ print(await breeds_similarity.similar("bagle")) This will return the most similar dog breed to "bagle" based on the data retrieved from the dog.ceo API - in this case, "beagle". -In general, instead of directly calling the similarity index, you would usually use it to annotate arguments to views, as demonstrated in the [Quickstart guide](../quickstart/quickstart2.md). \ No newline at end of file +In general, instead of directly calling the similarity index, you would usually use it to annotate arguments to views, as demonstrated in the [Quickstart guide](../quickstart/semantic-similarity.md). 
\ No newline at end of file diff --git a/docs/how-to/use_custom_similarity_store.md b/docs/how-to/use_custom_similarity_store.md index d98b0219..7af70cb1 100644 --- a/docs/how-to/use_custom_similarity_store.md +++ b/docs/how-to/use_custom_similarity_store.md @@ -56,11 +56,11 @@ country_similarity = SimilarityIndex( ) ``` -In this example, we used the sample `DogBreedsFetcher` fetcher detailed in the [custom fetcher guide](./use_custom_similarity_fetcher.md) and the `PickleStore` to store the values in a Python pickle file. You can use a different fetcher depending on your needs, for example [the Sqlalchemy one described in the Quickstart guide](../quickstart/quickstart2.md)). +In this example, we used the sample `DogBreedsFetcher` fetcher detailed in the [custom fetcher guide](./use_custom_similarity_fetcher.md) and the `PickleStore` to store the values in a Python pickle file. You can use a different fetcher depending on your needs, for example [the Sqlalchemy one described in the Quickstart guide](../quickstart/semantic-similarity.md)). ## Using the Similarity Index -You can use an index with a custom store [the same way](../quickstart/quickstart2.md) you would use one with a built-in store. The similarity index will map user input to the closest matching value from your data source, enabling you to deliver more accurate responses. It's important to regularly update the similarity index with new values from your data source to keep it current. Do this by invoking the `update` method on the similarity index. +You can use an index with a custom store [the same way](../quickstart/semantic-similarity.md) you would use one with a built-in store. The similarity index will map user input to the closest matching value from your data source, enabling you to deliver more accurate responses. It's important to regularly update the similarity index with new values from your data source to keep it current. Do this by invoking the `update` method on the similarity index. 
```python await country_similarity.update() @@ -77,4 +77,4 @@ print(await country_similarity.similar("bagle")) This will return the closest matching dog breed to "bagle" - in this case, "beagle". -Typically, instead of directly invoking the similarity index, you would employ it to annotate arguments to views, as demonstrated in the [Quickstart guide](../quickstart/quickstart2.md). \ No newline at end of file +Typically, instead of directly invoking the similarity index, you would employ it to annotate arguments to views, as demonstrated in the [Quickstart guide](../quickstart/semantic-similarity.md). \ No newline at end of file diff --git a/docs/index.md b/docs/index.md index 1b5c6460..64471909 100644 --- a/docs/index.md +++ b/docs/index.md @@ -3,6 +3,119 @@ hide: - navigation --- +# db-ally docs + + + +
+ ![dbally logo](https://raw.githubusercontent.com/deepsense-ai/db-ally/main/docs/assets/banner-light.svg#only-light){ width="30%" } + ![dbally logo](https://raw.githubusercontent.com/deepsense-ai/db-ally/main/docs/assets/banner-dark.svg#only-dark){ width="30%" } +
+ +

+ Efficient, consistent and secure library for querying structured data with natural language +

+ +
+ + + PyPI - License + + + + PyPI - Version + + + + PyPI - Python Version + + +
+ --- ---8<-- "README.md" +db-ally is an LLM-powered library for creating natural language interfaces to data sources. While it occupies a similar space to the text-to-SQL solutions, its goals and methods are different. db-ally allows developers to outline specific use cases for the LLM to handle, detailing the desired data format and the possible operations to fetch this data. + +db-ally effectively shields the complexity of the underlying data source from the model, presenting only the essential information needed for solving the specific use cases. Instead of generating arbitrary SQL, the model is asked to generate responses in a simplified query language. + +The benefits of db-ally can be described in terms of its four main characteristics: + +* **Consistency**: db-ally ensures predictable output formats and confines operations to those predefined by developers, making it particularly well-suited for applications with precise requirements on their behavior or data format +* **Security**: db-ally prevents direct database access and arbitrary SQL execution, bolstering system safety +* **Efficiency**: db-ally hides most of the underlying database complexity, enabling the LLM to concentrate on essential aspects and improving performance +* **Portability**: db-ally introduces an abstraction layer between the model and the data, ensuring easy integration with various database technologies and other data sources. + +## Quickstart + +In db-ally, developers define their use cases by implementing [**views**](https://db-ally.deepsense.ai/concepts/views), **filters** and **aggregations**. A list of possible filters and aggregations is presented to the LLM in terms of [**IQL**](https://db-ally.deepsense.ai/concepts/iql) (Intermediate Query Language). Views are grouped and registered within a [**collection**](https://db-ally.deepsense.ai/concepts/views), which then serves as an entry point for asking questions in natural language. 
+ +This is a basic implementation of a db-ally view for an example HR application, which retrieves candidates from an SQL database: + +```python +from dbally import decorators, SqlAlchemyBaseView, create_collection +from dbally.llms.litellm import LiteLLM +from sqlalchemy import create_engine + +class CandidateView(SqlAlchemyBaseView): + """ + A view for retrieving candidates from the database. + """ + + def get_select(self): + """ + Defines which columns to select. + """ + return sqlalchemy.select(Candidate.id, Candidate.name, Candidate.country) + + @decorators.view_filter() + def from_country(self, country: str): + """ + Filter candidates from a specific country. + """ + return Candidate.country == country + + +llm = LiteLLM(model_name="gpt-3.5-turbo") +engine = create_engine("sqlite:///examples/recruiting/data/candidates.db") + +my_collection = create_collection("collection_name", llm) +my_collection.add(CandidateView, lambda: CandidateView(engine)) + +my_collection.ask("Find candidates from United States") +``` + +For a concrete step-by-step example on how to use db-ally, go to [Quickstart](https://db-ally.deepsense.ai/quickstart/) guide. For a more learning-oriented experience, check our db-ally [Tutorial](https://db-ally.deepsense.ai/tutorials/). + +## Motivation + +db-ally was originally developed at [deepsense.ai](https://deepsense.ai). In our work on various projects, we frequently encountered the need to retrieve data from data sources, typically databases, in response to natural language queries. + +The standard approach to this issue involves using the text-to-SQL technique. While this method is powerful, it is also complex and challenging to control. Often, the results were unsatisfactory because the Language Model lacked the necessary context to understand the specific requirements of the application and the business logic behind the data. + +This led us to experiment with a more structured approach. 
In this method, the developer defines the specific use cases for the Language Model to handle, detailing the desired data format and the possible operations to retrieve this data. This approach proved to be more efficient, predictable, and easier to manage, making it simpler to integrate with the rest of the system. + +Eventually, we decided to create a library that would allow us to use this approach in a more systematic way, and we made it open-source for the community. + +## Installation + +To install db-ally, execute the following command: + +```bash +pip install dbally +``` + +Additionally, you can install one of our extensions to use specific features. + +* `dbally[litellm]`: Use [100+ LLMs](https://docs.litellm.ai/docs/providers) +* `dbally[faiss]`: Use [Faiss](https://github.com/facebookresearch/faiss) indexes for similarity search +* `dbally[langsmith]`: Use [LangSmith](https://www.langchain.com/langsmith) for query tracking + +```bash +pip install dbally[litellm,faiss,langsmith] +``` + +## License + +db-ally is released under MIT license. diff --git a/docs/quickstart/aggregations.md b/docs/quickstart/aggregations.md new file mode 100644 index 00000000..951543fb --- /dev/null +++ b/docs/quickstart/aggregations.md @@ -0,0 +1,93 @@ +# Quickstart: Aggregations + +This guide is a continuation of the [Intro](./intro.md) guide. It assumes that you have already set up the views and the collection. If not, please refer to the complete Part 1 code on [GitHub](https://github.com/deepsense-ai/db-ally/blob/main/examples/intro.py){:target="_blank"}. + +In this guide, we will add aggregations to our view to calculate general metrics about the candidates. + +## View Definition + +To add aggregations to our [structured view](../concepts/structured_views.md), we'll define new methods. These methods will allow the LLM model to perform calculations and summarize data across multiple rows. 
Let's add three aggregation methods to our `CandidateView`:
+
+```python
+class CandidateView(SqlAlchemyBaseView):
+    """
+    A view for retrieving candidates from the database.
+    """
+
+    def get_select(self) -> sqlalchemy.Select:
+        """
+        Creates the initial SqlAlchemy select object, which will be used to build the query.
+        """
+        return sqlalchemy.select(Candidate)
+
+    @decorators.view_aggregation()
+    def average_years_of_experience(self) -> sqlalchemy.Select:
+        """
+        Calculates the average years of experience of candidates.
+        """
+        return self.select.with_only_columns(
+            sqlalchemy.func.avg(Candidate.years_of_experience).label("average_years_of_experience")
+        )
+
+    @decorators.view_aggregation()
+    def positions_per_country(self) -> sqlalchemy.Select:
+        """
+        Returns the number of candidates per position per country.
+        """
+        return (
+            self.select.with_only_columns(
+                sqlalchemy.func.count(Candidate.position).label("number_of_positions"),
+                Candidate.position,
+                Candidate.country,
+            )
+            .group_by(Candidate.position, Candidate.country)
+            .order_by(sqlalchemy.desc("number_of_positions"))
+        )
+
+    @decorators.view_aggregation()
+    def candidates_per_country(self) -> sqlalchemy.Select:
+        """
+        Returns the number of candidates per country.
+        """
+        return (
+            self.select.with_only_columns(
+                sqlalchemy.func.count(Candidate.id).label("number_of_candidates"),
+                Candidate.country,
+            )
+            .group_by(Candidate.country)
+        )
+```
+
+By setting up these aggregations, you enable the LLM to calculate metrics about the average years of experience, the number of candidates per position per country, and the number of candidates per country.
+
+## Query Execution
+
+Having already defined and registered the view with the collection, we can now execute the query:
+
+```python
+result = await collection.ask("What is the average years of experience of candidates?")
+print(result.results)
+```
+
+This will return the average years of experience of candidates.
+
+<details>
+ The expected output +``` +The generated SQL query is: SELECT avg(candidates.years_of_experience) AS average_years_of_experience +FROM candidates + +Number of rows: 1 +{'average_years_of_experience': 4.98} +``` +
+ +Feel free to try other questions like: "What's the distribution of candidates across different positions and countries?" or "How many candidates are from China?". + +## Full Example + +Access the full example on [GitHub](https://github.com/deepsense-ai/db-ally/blob/main/examples/aggregations.py){:target="_blank"}. + +## Next Steps + +Explore [Quickstart Part 3: Semantic Similarity](./semantic-similarity.md) to expand on the example and learn about using semantic similarity. diff --git a/docs/quickstart/index.md b/docs/quickstart/index.md index 764f351b..cd856274 100644 --- a/docs/quickstart/index.md +++ b/docs/quickstart/index.md @@ -52,7 +52,7 @@ Candidate = Base.classes.candidates ## View Definition -To use db-ally, define the views you want to use. A [structured view](../concepts/structured_views.md) is a class that specifies what to select from the database and includes methods that the AI model can use to filter rows. These methods are known as "filters". +To use db-ally, define the views you want to use. A [structured view](../concepts/structured_views.md) is a class that specifies what to select from the database and includes methods that the AI model can use to filter rows. These methods are known as **filters**. ```python from dbally import decorators, SqlAlchemyBaseView @@ -97,7 +97,7 @@ class CandidateView(SqlAlchemyBaseView): By setting up these filters, you enable the LLM to fetch candidates while optionally applying filters based on experience, country, and eligibility for a senior data scientist position. !!! note - The `from_country` filter defined above supports only exact matches, which is not always ideal. Thankfully, db-ally comes with a solution for this problem - Similarity Indexes, which can be used to find the most similar value from the ones available. Refer to [Quickstart Part 2: Semantic Similarity](./quickstart2.md) for an example of using semantic similarity when filtering candidates by country. 
+    The `from_country` filter defined above supports only exact matches, which is not always ideal. Thankfully, db-ally comes with a solution for this problem - Similarity Indexes, which can be used to find the most similar value from the ones available. Refer to [Quickstart Part 3: Semantic Similarity](./semantic-similarity.md) for an example of using semantic similarity when filtering candidates by country.
 
 ## OpenAI Access Configuration
 
@@ -170,8 +170,8 @@ Retrieved 1 candidates:
 
 ## Full Example
 
-Access the full example here: [quickstart_code.py](quickstart_code.py)
+Access the full example on [GitHub](https://github.com/deepsense-ai/db-ally/blob/main/examples/intro.py){:target="_blank"}.
 
 ## Next Steps
 
-Explore [Quickstart Part 2: Semantic Similarity](./quickstart2.md) to expand on the example and learn about using semantic similarity.
\ No newline at end of file
+Explore [Quickstart Part 3: Semantic Similarity](./semantic-similarity.md) to expand on the example and learn about using semantic similarity.
diff --git a/docs/quickstart/quickstart3.md b/docs/quickstart/multiple-views.md
similarity index 90%
rename from docs/quickstart/quickstart3.md
rename to docs/quickstart/multiple-views.md
index 8a2789d2..33cdc2d1 100644
--- a/docs/quickstart/quickstart3.md
+++ b/docs/quickstart/multiple-views.md
@@ -1,6 +1,6 @@
 # Quickstart: Multiple Views
 
-This guide continues from [Semantic Similarity](./quickstart2.md) guide. It assumes that you have already set up the views and the collection. If not, please refer to the complete Part 2 code here: [quickstart2_code.py](quickstart2_code.py).
+This guide continues from [Semantic Similarity](./semantic-similarity.md) guide. It assumes that you have already set up the views and the collection. If not, please refer to the complete Part 3 code on [GitHub](https://github.com/deepsense-ai/db-ally/blob/main/examples/semantic_similarity.py){:target="_blank"}.
The guide illustrates how to use multiple views to handle queries requiring different types of data. `CandidateView` and `JobView` are used as examples.
 
@@ -28,6 +28,7 @@ jobs_data = pd.DataFrame.from_records([
     {"title": "Machine Learning Engineer", "company": "Company C", "location": "Berlin", "salary": 90000},
     {"title": "Data Scientist", "company": "Company D", "location": "London", "salary": 110000},
     {"title": "Data Scientist", "company": "Company E", "location": "Warsaw", "salary": 80000},
+    {"title": "Data Scientist", "company": "Company F", "location": "Warsaw", "salary": 100000},
 ])
 ```
 
@@ -124,7 +125,7 @@ Julia Nowak - Adobe XD;Sketch;Figma
 Anna Kowalska - AWS;Azure;Google Cloud
 ```
 
-That wraps it up! You can find the full example code here: [quickstart3_code.py](quickstart3_code.py).
+That wraps it up! You can find the full example code on [GitHub](https://github.com/deepsense-ai/db-ally/blob/main/examples/multiple_views.py){:target="_blank"}.
 
 ## Next Steps
 
 Visit the [Tutorial](../tutorials.md) for a more comprehensive guide on how to use db-ally.
\ No newline at end of file
diff --git a/docs/quickstart/quickstart2.md b/docs/quickstart/semantic-similarity.md
similarity index 91%
rename from docs/quickstart/quickstart2.md
rename to docs/quickstart/semantic-similarity.md
index d3bf968a..0baf0446 100644
--- a/docs/quickstart/quickstart2.md
+++ b/docs/quickstart/semantic-similarity.md
@@ -1,6 +1,6 @@
 # Quickstart: Semantic Similarity
 
-This guide is a continuation of the [Intro](./index.md) guide. It assumes that you have already set up the views and the collection. If not, please refer to the complete Part 1 code here: [quickstart_code.py](quickstart_code.py).
+This guide is a continuation of the [Aggregations](./aggregations.md) guide. It assumes that you have already set up the views and the collection. If not, please refer to the complete Part 2 code on [GitHub](https://github.com/deepsense-ai/db-ally/blob/main/examples/aggregations.py){:target="_blank"}.
This guide will demonstrate how to use semantic similarity to handle queries in which the filter values are similar to those in the database, without requiring an exact match. We will use filtering by country as an example. @@ -146,8 +146,8 @@ Retrieved 1 candidates: That's it! You can apply similar techniques to any other filter that takes a string value. -To see the full example, you can find the code here: [quickstart2_code.py](quickstart2_code.py). +To see the full example, you can find the code on [GitHub](https://github.com/deepsense-ai/db-ally/blob/main/examples/semantic_similarity.py){:target="_blank"}. ## Next Steps -Explore [Quickstart Part 3: Multiple Views](./quickstart3.md) to learn how to run queries with multiple views and display the results based on the view that was used to fetch the data. +Explore [Quickstart Part 4: Multiple Views](./multiple-views.md) to learn how to run queries with multiple views and display the results. diff --git a/docs/reference/similarity/index.md b/docs/reference/similarity/index.md index aca1f119..3f0ae26f 100644 --- a/docs/reference/similarity/index.md +++ b/docs/reference/similarity/index.md @@ -10,6 +10,6 @@ Explore [Similarity Stores](./similarity_store/index.md) and [Similarity Fetcher * [How-To: Use Similarity Indexes with Data from Custom Sources](../../how-to/use_custom_similarity_fetcher.md) * [How-To: Store Similarity Index in a Custom Store](../../how-to/use_custom_similarity_store.md) * [How-To: Update Similarity Indexes](../../how-to/update_similarity_indexes.md) - * [Quickstart: Semantic Similarity](../../quickstart/quickstart2.md) + * [Quickstart: Semantic Similarity](../../quickstart/semantic-similarity.md) ::: dbally.similarity.SimilarityIndex diff --git a/docs/stylesheets/extra.css b/docs/stylesheets/extra.css index 698837d0..01232bf2 100644 --- a/docs/stylesheets/extra.css +++ b/docs/stylesheets/extra.css @@ -2,12 +2,6 @@ --md-primary-fg-color: #00b0e0; } -.md-header__button.md-logo { - margin: 0; - 
padding: 0; +.md-header__title { + margin-left: 0.5rem !important; } - -.md-header__button.md-logo img, .md-header__button.md-logo svg { - height: 1.8rem; - width: 1.8rem; -} \ No newline at end of file diff --git a/examples/aggregations.py b/examples/aggregations.py new file mode 100644 index 00000000..14d71127 --- /dev/null +++ b/examples/aggregations.py @@ -0,0 +1,107 @@ +# pylint: disable=missing-return-doc, missing-param-doc, missing-function-docstring, duplicate-code + +import asyncio + +import sqlalchemy +from sqlalchemy import create_engine +from sqlalchemy.ext.automap import automap_base + +import dbally +from dbally import SqlAlchemyBaseView, decorators +from dbally.audit.event_handlers.cli_event_handler import CLIEventHandler +from dbally.llms.litellm import LiteLLM + +engine = create_engine("sqlite:///examples/recruiting/data/candidates.db") + +Base = automap_base() +Base.prepare(autoload_with=engine) + +Candidate = Base.classes.candidates + + +class CandidateView(SqlAlchemyBaseView): + """ + A view for retrieving candidates from the database. + """ + + def get_select(self) -> sqlalchemy.Select: + """ + Creates the initial SqlAlchemy select object, which will be used to build the query. + """ + return sqlalchemy.select(Candidate) + + @decorators.view_filter() + def at_least_experience(self, years: int) -> sqlalchemy.ColumnElement: + """ + Filters candidates with at least `years` of experience. + """ + return Candidate.years_of_experience >= years + + @decorators.view_filter() + def senior_data_scientist_position(self) -> sqlalchemy.ColumnElement: + """ + Filters candidates that can be considered for a senior data scientist position. + """ + return sqlalchemy.and_( + Candidate.position.in_(["Data Scientist", "Machine Learning Engineer", "Data Engineer"]), + Candidate.years_of_experience >= 3, + ) + + @decorators.view_filter() + def from_country(self, country: str) -> sqlalchemy.ColumnElement: + """ + Filters candidates from a specific country. 
+ """ + return Candidate.country == country + + @decorators.view_aggregation() + def average_years_of_experience(self) -> sqlalchemy.Select: + """ + Calculates the average years of experience of candidates. + """ + return self.select.with_only_columns( + sqlalchemy.func.avg(Candidate.years_of_experience).label("average_years_of_experience") + ) + + @decorators.view_aggregation() + def positions_per_country(self) -> sqlalchemy.Select: + """ + Returns the number of candidates per position per country. + """ + return ( + self.select.with_only_columns( + sqlalchemy.func.count(Candidate.position).label("number_of_positions"), + Candidate.position, + Candidate.country, + ) + .group_by(Candidate.position, Candidate.country) + .order_by(sqlalchemy.desc("number_of_positions")) + ) + + @decorators.view_aggregation() + def candidates_per_country(self) -> sqlalchemy.Select: + """ + Returns the number of candidates per country. + """ + return self.select.with_only_columns( + sqlalchemy.func.count(Candidate.id).label("number_of_candidates"), + Candidate.country, + ).group_by(Candidate.country) + + +async def main() -> None: + llm = LiteLLM(model_name="gpt-3.5-turbo") + dbally.event_handlers = [CLIEventHandler()] + + collection = dbally.create_collection("recruitment", llm) + collection.add(CandidateView, lambda: CandidateView(engine)) + + result = await collection.ask("What is the average years of experience of candidates?") + + print(f"The generated SQL query is: {result.context.get('sql')}") + for row in result.results: + print(row) + + +if __name__ == "__main__": + asyncio.run(main()) diff --git a/docs/quickstart/quickstart_code.py b/examples/intro.py similarity index 96% rename from docs/quickstart/quickstart_code.py rename to examples/intro.py index ef73cad0..7bce4e46 100644 --- a/docs/quickstart/quickstart_code.py +++ b/examples/intro.py @@ -1,16 +1,16 @@ -# pylint: disable=missing-return-doc, missing-param-doc, missing-function-docstring -import dbally +# pylint: 
disable=missing-return-doc, missing-param-doc, missing-function-docstring, duplicate-code + import asyncio import sqlalchemy from sqlalchemy import create_engine from sqlalchemy.ext.automap import automap_base -from dbally import decorators, SqlAlchemyBaseView +import dbally +from dbally import SqlAlchemyBaseView, decorators from dbally.audit.event_handlers.cli_event_handler import CLIEventHandler from dbally.llms.litellm import LiteLLM - engine = create_engine("sqlite:///examples/recruiting/data/candidates.db") Base = automap_base() diff --git a/docs/quickstart/quickstart3_code.py b/examples/multiple_views.py similarity index 61% rename from docs/quickstart/quickstart3_code.py rename to examples/multiple_views.py index 0ad8b1a7..a8b9423d 100644 --- a/docs/quickstart/quickstart3_code.py +++ b/examples/multiple_views.py @@ -1,19 +1,21 @@ -# pylint: disable=missing-return-doc, missing-param-doc, missing-function-docstring -import os +# pylint: disable=missing-return-doc, missing-param-doc, missing-function-docstring, duplicate-code + import asyncio -from typing_extensions import Annotated +import os +import pandas as pd import sqlalchemy from sqlalchemy import create_engine from sqlalchemy.ext.automap import automap_base -import pandas as pd +from typing_extensions import Annotated import dbally -from dbally import decorators, SqlAlchemyBaseView, DataFrameBaseView, ExecutionResult +from dbally import DataFrameBaseView, ExecutionResult, SqlAlchemyBaseView, decorators from dbally.audit import CLIEventHandler -from dbally.similarity import SimpleSqlAlchemyFetcher, FaissStore, SimilarityIndex from dbally.embeddings.litellm import LiteLLMEmbeddingClient from dbally.llms.litellm import LiteLLM +from dbally.similarity import FaissStore, SimilarityIndex, SimpleSqlAlchemyFetcher +from dbally.views.pandas_base import Aggregation, AggregationGroup engine = create_engine("sqlite:///examples/recruiting/data/candidates.db") @@ -75,6 +77,45 @@ def from_country(self, country: 
Annotated[str, country_similarity]) -> sqlalchem """ return Candidate.country == country + @decorators.view_aggregation() + def average_years_of_experience(self) -> sqlalchemy.Select: + """ + Calculates the average years of experience of candidates. + """ + return self.select.with_only_columns( + sqlalchemy.func.avg(Candidate.years_of_experience).label("average_years_of_experience") + ) + + @decorators.view_aggregation() + def positions_per_country(self) -> sqlalchemy.Select: + """ + Returns the number of candidates per position per country. + """ + return ( + self.select.with_only_columns( + sqlalchemy.func.count(Candidate.position).label("number_of_candidates"), + Candidate.position, + Candidate.country, + ) + .group_by(Candidate.position, Candidate.country) + .order_by(sqlalchemy.desc("number_of_candidates")) + ) + + @decorators.view_aggregation() + def top_universities(self, limit: int) -> sqlalchemy.Select: + """ + Returns the top universities by the number of candidates. + """ + return ( + self.select.with_only_columns( + sqlalchemy.func.count(Candidate.id).label("number_of_candidates"), + Candidate.university, + ) + .group_by(Candidate.university) + .order_by(sqlalchemy.desc("number_of_candidates")) + .limit(limit) + ) + jobs_data = pd.DataFrame.from_records( [ @@ -83,6 +124,7 @@ def from_country(self, country: Annotated[str, country_similarity]) -> sqlalchem {"title": "Machine Learning Engineer", "company": "Company C", "location": "Berlin", "salary": 90000}, {"title": "Data Scientist", "company": "Company D", "location": "London", "salary": 110000}, {"title": "Data Scientist", "company": "Company E", "location": "Warsaw", "salary": 80000}, + {"title": "Data Scientist", "company": "Company F", "location": "Warsaw", "salary": 100000}, ] ) @@ -113,6 +155,46 @@ def from_company(self, company: str) -> pd.Series: """ return self.df.company == company + @decorators.view_aggregation() + def average_salary(self) -> AggregationGroup: + """ + Calculates the average 
salary of job offers. + """ + return AggregationGroup( + aggregations=[ + Aggregation(column="salary", function="mean"), + ], + ) + + @decorators.view_aggregation() + def average_salary_per_location(self) -> AggregationGroup: + """ + Calculates the average salary of job offers per location and title. + """ + return AggregationGroup( + aggregations=[ + Aggregation(column="salary", function="mean"), + ], + groupbys=[ + "location", + "title", + ], + ) + + @decorators.view_aggregation() + def count_per_title(self) -> AggregationGroup: + """ + Counts the number of job offers per title. + """ + return AggregationGroup( + aggregations=[ + Aggregation(column="title", function="count"), + ], + groupbys=[ + "title", + ], + ) + def display_results(result: ExecutionResult): if result.view_name == "CandidateView": diff --git a/docs/quickstart/quickstart2_code.py b/examples/semantic_similarity.py similarity index 65% rename from docs/quickstart/quickstart2_code.py rename to examples/semantic_similarity.py index d1504cd8..098f167a 100644 --- a/docs/quickstart/quickstart2_code.py +++ b/examples/semantic_similarity.py @@ -1,19 +1,20 @@ -# pylint: disable=missing-return-doc, missing-param-doc, missing-function-docstring -import os +# pylint: disable=missing-return-doc, missing-param-doc, missing-function-docstring, duplicate-code + import asyncio -from typing_extensions import Annotated +import os -from dotenv import load_dotenv import sqlalchemy +from dotenv import load_dotenv from sqlalchemy import create_engine from sqlalchemy.ext.automap import automap_base +from typing_extensions import Annotated import dbally -from dbally import decorators, SqlAlchemyBaseView +from dbally import SqlAlchemyBaseView, decorators from dbally.audit.event_handlers.cli_event_handler import CLIEventHandler -from dbally.similarity import SimpleSqlAlchemyFetcher, FaissStore, SimilarityIndex from dbally.embeddings.litellm import LiteLLMEmbeddingClient from dbally.llms.litellm import LiteLLM +from 
dbally.similarity import FaissStore, SimilarityIndex, SimpleSqlAlchemyFetcher load_dotenv() engine = create_engine("sqlite:///examples/recruiting/data/candidates.db") @@ -75,6 +76,45 @@ def from_country(self, country: Annotated[str, country_similarity]) -> sqlalchem """ return Candidate.country == country + @decorators.view_aggregation() + def average_years_of_experience(self) -> sqlalchemy.Select: + """ + Calculates the average years of experience of candidates. + """ + return self.select.with_only_columns( + sqlalchemy.func.avg(Candidate.years_of_experience).label("average_years_of_experience") + ) + + @decorators.view_aggregation() + def positions_per_country(self) -> sqlalchemy.Select: + """ + Returns the number of candidates per position per country. + """ + return ( + self.select.with_only_columns( + sqlalchemy.func.count(Candidate.position).label("number_of_candidates"), + Candidate.position, + Candidate.country, + ) + .group_by(Candidate.position, Candidate.country) + .order_by(sqlalchemy.desc("number_of_candidates")) + ) + + @decorators.view_aggregation() + def top_universities(self, limit: int) -> sqlalchemy.Select: + """ + Returns the top universities by the number of candidates. 
+ """ + return ( + self.select.with_only_columns( + sqlalchemy.func.count(Candidate.id).label("number_of_candidates"), + Candidate.university, + ) + .group_by(Candidate.university) + .order_by(sqlalchemy.desc("number_of_candidates")) + .limit(limit) + ) + async def main(): dbally.event_handlers = [CLIEventHandler()] diff --git a/mkdocs.yml b/mkdocs.yml index 0129b1ff..f92b8932 100644 --- a/mkdocs.yml +++ b/mkdocs.yml @@ -1,12 +1,16 @@ -site_name: db-ally +site_name: db-ally docs +site_description: Efficient, consistent and secure library for querying structured data with natural language +site_url: https://db-ally.deepsense.ai +repo_name: deepsense-ai/db-ally repo_url: https://github.com/deepsense-ai/db-ally +copyright: Copyright © 2024 deepsense.ai nav: - - Home: index.md + - db-ally docs: index.md - Quickstart: - quickstart/index.md - - quickstart/quickstart2.md - - quickstart/quickstart3.md - - Tutorials: tutorials.md + - quickstart/aggregations.md + - quickstart/semantic-similarity.md + - quickstart/multiple-views.md - Concepts: - concepts/views.md - concepts/structured_views.md @@ -37,6 +41,7 @@ nav: - how-to/trace_runs_with_otel.md - how-to/create_custom_event_handler.md - how-to/openai_assistants_integration.md + - Tutorials: tutorials.md - API Reference: - reference/index.md - reference/collection.md @@ -86,26 +91,31 @@ nav: - about/contact.md theme: name: material - logo: assets/guide_dog_lg.png - favicon: assets/guide_dog_sm.png + logo: assets/logo.svg + favicon: assets/favicon.ico icon: repo: fontawesome/brands/github palette: - - media: "(prefers-color-scheme: light)" + - media: "(prefers-color-scheme)" + toggle: + icon: material/lightbulb-auto + name: Switch to light mode + - media: '(prefers-color-scheme: light)' scheme: default primary: primary toggle: - icon: material/brightness-7 + icon: material/lightbulb name: Switch to dark mode - - media: "(prefers-color-scheme: dark)" + - media: '(prefers-color-scheme: dark)' scheme: slate primary: primary 
toggle: - icon: material/brightness-4 - name: Switch to light mode + icon: material/lightbulb-outline + name: Switch to system preference features: - navigation.footer - navigation.tabs + - navigation.tabs.sticky - navigation.top - content.code.annotate - content.code.copy @@ -125,6 +135,7 @@ markdown_extensions: - pymdownx.snippets - pymdownx.inlinehilite - attr_list + - md_in_html - pymdownx.details - def_list - pymdownx.tasklist: @@ -161,3 +172,29 @@ extra: analytics: provider: google property: G-FBBJRN0H0G + feedback: + title: Was this page helpful? + ratings: + - icon: material/emoticon-happy-outline + name: This page was helpful + data: 1 + note: >- + Thanks for your feedback! + - icon: material/emoticon-sad-outline + name: This page could be improved + data: 0 + note: >- + Thanks for your feedback! + social: + - icon: fontawesome/brands/github + link: https://github.com/deepsense-ai + - icon: fontawesome/brands/x-twitter + link: https://x.com/deepsense_ai + - icon: fontawesome/brands/linkedin + link: https://linkedin.com/company/deepsense-ai + - icon: fontawesome/brands/youtube + link: https://youtube.com/@deepsenseai + - icon: fontawesome/brands/medium + link: https://medium.com/deepsense-ai + - icon: fontawesome/solid/globe + link: https://deepsense.ai diff --git a/src/dbally/iql/__init__.py b/src/dbally/iql/__init__.py index 0df0a766..20bde9eb 100644 --- a/src/dbally/iql/__init__.py +++ b/src/dbally/iql/__init__.py @@ -1,5 +1,13 @@ from . 
import syntax from ._exceptions import IQLArgumentParsingError, IQLError, IQLUnsupportedSyntaxError -from ._query import IQLQuery +from ._query import IQLAggregationQuery, IQLFiltersQuery, IQLQuery -__all__ = ["IQLQuery", "syntax", "IQLError", "IQLArgumentParsingError", "IQLUnsupportedSyntaxError"] +__all__ = [ + "IQLQuery", + "IQLFiltersQuery", + "IQLAggregationQuery", + "syntax", + "IQLError", + "IQLArgumentParsingError", + "IQLUnsupportedSyntaxError", +] diff --git a/src/dbally/iql/_processor.py b/src/dbally/iql/_processor.py index f1adf64c..1bd72bcc 100644 --- a/src/dbally/iql/_processor.py +++ b/src/dbally/iql/_processor.py @@ -1,5 +1,6 @@ import ast -from typing import TYPE_CHECKING, Any, List, Optional, Union +from abc import ABC, abstractmethod +from typing import TYPE_CHECKING, Any, Generic, List, Optional, TypeVar, Union from dbally.audit.event_tracker import EventTracker from dbally.iql import syntax @@ -19,10 +20,12 @@ if TYPE_CHECKING: from dbally.views.structured import ExposedFunction +RootT = TypeVar("RootT", bound=syntax.Node) -class IQLProcessor: + +class IQLProcessor(Generic[RootT], ABC): """ - Parses IQL string to tree structure. + Base class for IQL processors. """ def __init__( @@ -32,9 +35,9 @@ def __init__( self.allowed_functions = {func.name: func for func in allowed_functions} self._event_tracker = event_tracker or EventTracker() - async def process(self) -> syntax.Node: + async def process(self) -> RootT: """ - Process IQL string to root IQL.Node. + Process IQL string to IQL root node. Returns: IQL node which is root of the tree representing IQL query. 
@@ -60,25 +63,17 @@ async def process(self) -> syntax.Node: return await self._parse_node(ast_tree.body[0].value) - async def _parse_node(self, node: Union[ast.expr, ast.Expr]) -> syntax.Node: - if isinstance(node, ast.BoolOp): - return await self._parse_bool_op(node) - if isinstance(node, ast.UnaryOp) and isinstance(node.op, ast.Not): - return syntax.Not(await self._parse_node(node.operand)) - if isinstance(node, ast.Call): - return await self._parse_call(node) - - raise IQLUnsupportedSyntaxError(node, self.source) + @abstractmethod + async def _parse_node(self, node: Union[ast.expr, ast.Expr]) -> RootT: + """ + Parses AST node to IQL node. - async def _parse_bool_op(self, node: ast.BoolOp) -> syntax.BoolOp: - if isinstance(node.op, ast.Not): - return syntax.Not(await self._parse_node(node.values[0])) - if isinstance(node.op, ast.And): - return syntax.And([await self._parse_node(x) for x in node.values]) - if isinstance(node.op, ast.Or): - return syntax.Or([await self._parse_node(x) for x in node.values]) + Args: + node: AST node to parse. - raise IQLUnsupportedSyntaxError(node, self.source, context="BoolOp") + Returns: + IQL node. + """ async def _parse_call(self, node: ast.Call) -> syntax.FunctionCall: func = node.func @@ -153,3 +148,41 @@ def _to_lower_except_in_quotes(text: str, keywords: List[str]) -> str: converted_text = converted_text[: len(converted_text) - len(keyword)] + keyword.lower() return converted_text + + +class IQLFiltersProcessor(IQLProcessor[syntax.Node]): + """ + IQL processor for filters. 
+ """ + + async def _parse_node(self, node: Union[ast.expr, ast.Expr]) -> syntax.Node: + if isinstance(node, ast.BoolOp): + return await self._parse_bool_op(node) + if isinstance(node, ast.UnaryOp) and isinstance(node.op, ast.Not): + return syntax.Not(await self._parse_node(node.operand)) + if isinstance(node, ast.Call): + return await self._parse_call(node) + + raise IQLUnsupportedSyntaxError(node, self.source) + + async def _parse_bool_op(self, node: ast.BoolOp) -> syntax.BoolOp: + if isinstance(node.op, ast.Not): + return syntax.Not(await self._parse_node(node.values[0])) + if isinstance(node.op, ast.And): + return syntax.And([await self._parse_node(x) for x in node.values]) + if isinstance(node.op, ast.Or): + return syntax.Or([await self._parse_node(x) for x in node.values]) + + raise IQLUnsupportedSyntaxError(node, self.source, context="BoolOp") + + +class IQLAggregationProcessor(IQLProcessor[syntax.FunctionCall]): + """ + IQL processor for aggregation. + """ + + async def _parse_node(self, node: Union[ast.expr, ast.Expr]) -> syntax.FunctionCall: + if isinstance(node, ast.Call): + return await self._parse_call(node) + + raise IQLUnsupportedSyntaxError(node, self.source) diff --git a/src/dbally/iql/_query.py b/src/dbally/iql/_query.py index dd831a91..57b3b4ed 100644 --- a/src/dbally/iql/_query.py +++ b/src/dbally/iql/_query.py @@ -1,26 +1,29 @@ -from typing import TYPE_CHECKING, List, Optional +from abc import ABC +from typing import TYPE_CHECKING, Generic, List, Optional, Type from ..audit.event_tracker import EventTracker from . import syntax -from ._processor import IQLProcessor +from ._processor import IQLAggregationProcessor, IQLFiltersProcessor, IQLProcessor, RootT if TYPE_CHECKING: from dbally.views.structured import ExposedFunction -class IQLQuery: +class IQLQuery(Generic[RootT], ABC): """ IQLQuery container. It stores IQL as a syntax tree defined in `IQL` class. 
""" - root: syntax.Node + root: RootT + source: str + _processor: Type[IQLProcessor[RootT]] - def __init__(self, root: syntax.Node, source: str) -> None: + def __init__(self, root: RootT, source: str) -> None: self.root = root - self._source = source + self.source = source def __str__(self) -> str: - return self._source + return self.source @classmethod async def parse( @@ -28,7 +31,7 @@ async def parse( source: str, allowed_functions: List["ExposedFunction"], event_tracker: Optional[EventTracker] = None, - ) -> "IQLQuery": + ) -> "IQLQuery[RootT]": """ Parse IQL string to IQLQuery object. @@ -43,5 +46,21 @@ async def parse( Raises: IQLError: If parsing fails. """ - root = await IQLProcessor(source, allowed_functions, event_tracker=event_tracker).process() + root = await cls._processor(source, allowed_functions, event_tracker=event_tracker).process() return cls(root=root, source=source) + + +class IQLFiltersQuery(IQLQuery[syntax.Node]): + """ + IQL filters query container. + """ + + _processor: Type[IQLFiltersProcessor] = IQLFiltersProcessor + + +class IQLAggregationQuery(IQLQuery[syntax.FunctionCall]): + """ + IQL aggregation query container. 
+ """ + + _processor: Type[IQLAggregationProcessor] = IQLAggregationProcessor diff --git a/src/dbally/iql_generator/iql_generator.py b/src/dbally/iql_generator/iql_generator.py index 27347734..4ea65340 100644 --- a/src/dbally/iql_generator/iql_generator.py +++ b/src/dbally/iql_generator/iql_generator.py @@ -1,11 +1,16 @@ -from typing import List, Optional +import asyncio +from dataclasses import dataclass +from typing import Generic, List, Optional, TypeVar, Union from dbally.audit.event_tracker import EventTracker from dbally.iql import IQLError, IQLQuery +from dbally.iql._query import IQLAggregationQuery, IQLFiltersQuery from dbally.iql_generator.prompt import ( + AGGREGATION_DECISION_TEMPLATE, + AGGREGATION_GENERATION_TEMPLATE, FILTERING_DECISION_TEMPLATE, - IQL_GENERATION_TEMPLATE, - FilteringDecisionPromptFormat, + FILTERS_GENERATION_TEMPLATE, + DecisionPromptFormat, IQLGenerationPromptFormat, ) from dbally.llms.base import LLM @@ -15,57 +20,151 @@ from dbally.prompt.template import PromptTemplate from dbally.views.exposed_functions import ExposedFunction -ERROR_MESSAGE = "Unfortunately, generated IQL is not valid. Please try again, \ - generation of correct IQL is very important. Below you have errors generated by the system:\n{error}" +IQLQueryT = TypeVar("IQLQueryT", bound=IQLQuery) -class IQLGenerator: +@dataclass +class IQLGeneratorState: + """ + State of the IQL generator. """ - Class used to generate IQL from natural language question. - In db-ally, LLM uses IQL (Intermediate Query Language) to express complex queries in a simplified way. - The class used to generate IQL from natural language query is `IQLGenerator`. + filters: Optional[Union[IQLFiltersQuery, Exception]] = None + aggregation: Optional[Union[IQLAggregationQuery, Exception]] = None + + @property + def failed(self) -> bool: + """ + Checks if the generation failed. - IQL generation is done using the method `self.generate_iql`. 
- It uses LLM to generate text-based responses, passing in the prompt template, formatted filters, and user question. + Returns: + True if the generation failed, False otherwise. + """ + return isinstance(self.filters, Exception) or isinstance(self.aggregation, Exception) + + +class IQLGenerator: + """ + Orchestrates all IQL operations for the given question. """ def __init__( self, - llm: LLM, - *, - decision_prompt: Optional[PromptTemplate[FilteringDecisionPromptFormat]] = None, - generation_prompt: Optional[PromptTemplate[IQLGenerationPromptFormat]] = None, + filters_generation: Optional["IQLOperationGenerator"] = None, + aggregation_generation: Optional["IQLOperationGenerator"] = None, ) -> None: """ Constructs a new IQLGenerator instance. Args: - llm: LLM used to generate IQL. decision_prompt: Prompt template for filtering decision making. generation_prompt: Prompt template for IQL generation. """ - self._llm = llm - self._decision_prompt = decision_prompt or FILTERING_DECISION_TEMPLATE - self._generation_prompt = generation_prompt or IQL_GENERATION_TEMPLATE + self._filters_generation = filters_generation or IQLOperationGenerator[IQLFiltersQuery]( + FILTERING_DECISION_TEMPLATE, + FILTERS_GENERATION_TEMPLATE, + ) + self._aggregation_generation = aggregation_generation or IQLOperationGenerator[IQLAggregationQuery]( + AGGREGATION_DECISION_TEMPLATE, + AGGREGATION_GENERATION_TEMPLATE, + ) - async def generate( + # pylint: disable=too-many-arguments + async def __call__( self, + *, question: str, filters: List[ExposedFunction], - event_tracker: EventTracker, - examples: Optional[List[FewShotExample]] = None, + aggregations: List[ExposedFunction], + examples: List[FewShotExample], + llm: LLM, + event_tracker: Optional[EventTracker] = None, llm_options: Optional[LLMOptions] = None, n_retries: int = 3, - ) -> Optional[IQLQuery]: + ) -> IQLGeneratorState: """ - Generates IQL in text form using LLM. + Generates IQL operations for the given question. 
Args: question: User question. filters: List of filters exposed by the view. + aggregations: List of aggregations exposed by the view. + examples: List of examples to be injected during filters and aggregation generation. + llm: LLM used to generate IQL. event_tracker: Event store used to audit the generation process. + llm_options: Options to use for the LLM client. + n_retries: Number of retries to regenerate IQL in case of errors in parsing or LLM connection. + + Returns: + Generated IQL operations. + """ + filters, aggregation = await asyncio.gather( + self._filters_generation( + question=question, + methods=filters, + examples=examples, + llm=llm, + llm_options=llm_options, + event_tracker=event_tracker, + n_retries=n_retries, + ), + self._aggregation_generation( + question=question, + methods=aggregations, + examples=examples, + llm=llm, + llm_options=llm_options, + event_tracker=event_tracker, + n_retries=n_retries, + ), + return_exceptions=True, + ) + return IQLGeneratorState( + filters=filters, + aggregation=aggregation, + ) + + +class IQLOperationGenerator(Generic[IQLQueryT]): + """ + Generates IQL queries for the given question. + """ + + def __init__( + self, + assessor_prompt: PromptTemplate[DecisionPromptFormat], + generator_prompt: PromptTemplate[IQLGenerationPromptFormat], + ) -> None: + """ + Constructs a new IQLGenerator instance. + + Args: + assessor_prompt: Prompt template for filtering decision making. + generator_prompt: Prompt template for IQL generation. + """ + self.assessor = IQLQuestionAssessor(assessor_prompt) + self.generator = IQLQueryGenerator[IQLQueryT](generator_prompt) + + async def __call__( + self, + *, + question: str, + methods: List[ExposedFunction], + examples: List[FewShotExample], + llm: LLM, + event_tracker: Optional[EventTracker] = None, + llm_options: Optional[LLMOptions] = None, + n_retries: int = 3, + ) -> Optional[IQLQueryT]: + """ + Generates IQL query for the given question. 
+ + Args: + llm: LLM used to generate IQL. + question: User question. + methods: List of methods exposed by the view. examples: List of examples to be injected into the conversation. + event_tracker: Event store used to audit the generation process. llm_options: Options to use for the LLM client. n_retries: Number of retries to regenerate IQL in case of errors in parsing or LLM connection. @@ -77,38 +176,52 @@ async def generate( IQLError: If IQL parsing fails after all retries. UnsupportedQueryError: If the question is not supported by the view. """ - decision = await self._decide_on_generation( + decision = await self.assessor( question=question, - event_tracker=event_tracker, + llm=llm, llm_options=llm_options, + event_tracker=event_tracker, n_retries=n_retries, ) if not decision: return None - return await self._generate_iql( + return await self.generator( question=question, - filters=filters, - event_tracker=event_tracker, + methods=methods, examples=examples, + llm=llm, llm_options=llm_options, + event_tracker=event_tracker, n_retries=n_retries, ) - async def _decide_on_generation( + +class IQLQuestionAssessor: + """ + Assesses whether a question requires applying IQL operation or not. + """ + + def __init__(self, prompt: PromptTemplate[DecisionPromptFormat]) -> None: + self.prompt = prompt + + async def __call__( self, + *, question: str, - event_tracker: EventTracker, + llm: LLM, llm_options: Optional[LLMOptions] = None, + event_tracker: Optional[EventTracker] = None, n_retries: int = 3, ) -> bool: """ - Decides whether the question requires filtering or not. + Decides whether the question requires generating IQL or not. Args: question: User question. - event_tracker: Event store used to audit the generation process. + llm: LLM used to generate IQL. llm_options: Options to use for the LLM client. + event_tracker: Event store used to audit the generation process. n_retries: Number of retries to LLM API in case of errors. 
Returns: @@ -117,12 +230,14 @@ async def _decide_on_generation( Raises: LLMError: If LLM text generation fails after all retries. """ - prompt_format = FilteringDecisionPromptFormat(question=question) - formatted_prompt = self._decision_prompt.format_prompt(prompt_format) + prompt_format = DecisionPromptFormat( + question=question, + ) + formatted_prompt = self.prompt.format_prompt(prompt_format) for retry in range(n_retries + 1): try: - response = await self._llm.generate_text( + response = await llm.generate_text( prompt=formatted_prompt, event_tracker=event_tracker, options=llm_options, @@ -133,24 +248,39 @@ async def _decide_on_generation( if retry == n_retries: raise exc - async def _generate_iql( + +class IQLQueryGenerator(Generic[IQLQueryT]): + """ + Generates IQL queries for the given question. + """ + + ERROR_MESSAGE = "Unfortunately, generated IQL is not valid. Please try again, \ + generation of correct IQL is very important. Below you have errors generated by the system:\n{error}" + + def __init__(self, prompt: PromptTemplate[IQLGenerationPromptFormat]) -> None: + self.prompt = prompt + + async def __call__( self, + *, question: str, - filters: List[ExposedFunction], - event_tracker: Optional[EventTracker] = None, - examples: Optional[List[FewShotExample]] = None, + methods: List[ExposedFunction], + examples: List[FewShotExample], + llm: LLM, llm_options: Optional[LLMOptions] = None, + event_tracker: Optional[EventTracker] = None, n_retries: int = 3, - ) -> IQLQuery: + ) -> IQLQueryT: """ - Generates IQL in text form using LLM. + Generates IQL query for the given question. Args: question: User question. filters: List of filters exposed by the view. - event_tracker: Event store used to audit the generation process. examples: List of examples to be injected into the conversation. + llm: LLM used to generate IQL. llm_options: Options to use for the LLM client. + event_tracker: Event store used to audit the generation process. 
n_retries: Number of retries to regenerate IQL in case of errors in parsing or LLM connection. Returns: @@ -163,24 +293,22 @@ async def _generate_iql( """ prompt_format = IQLGenerationPromptFormat( question=question, - filters=filters, + methods=methods, examples=examples, ) - formatted_prompt = self._generation_prompt.format_prompt(prompt_format) + formatted_prompt = self.prompt.format_prompt(prompt_format) for retry in range(n_retries + 1): try: - response = await self._llm.generate_text( + response = await llm.generate_text( prompt=formatted_prompt, event_tracker=event_tracker, options=llm_options, ) # TODO: Move response parsing to llm generate_text method - iql = formatted_prompt.response_parser(response) - # TODO: Move IQL query parsing to prompt response parser - return await IQLQuery.parse( - source=iql, - allowed_functions=filters, + return await formatted_prompt.response_parser( + response=response, + allowed_functions=methods, event_tracker=event_tracker, ) except LLMError as exc: @@ -190,4 +318,4 @@ async def _generate_iql( if retry == n_retries: raise exc formatted_prompt = formatted_prompt.add_assistant_message(response) - formatted_prompt = formatted_prompt.add_user_message(ERROR_MESSAGE.format(error=exc)) + formatted_prompt = formatted_prompt.add_user_message(self.ERROR_MESSAGE.format(error=exc)) diff --git a/src/dbally/iql_generator/prompt.py b/src/dbally/iql_generator/prompt.py index 4e5a45ec..f2c29d62 100644 --- a/src/dbally/iql_generator/prompt.py +++ b/src/dbally/iql_generator/prompt.py @@ -1,8 +1,10 @@ # pylint: disable=C0301 -from typing import List +from typing import List, Optional +from dbally.audit.event_tracker import EventTracker from dbally.exceptions import DbAllyError +from dbally.iql._query import IQLAggregationQuery, IQLFiltersQuery from dbally.prompt.elements import FewShotExample from dbally.prompt.template import PromptFormat, PromptTemplate from dbally.views.exposed_functions import ExposedFunction @@ -15,26 +17,65 @@ class 
UnsupportedQueryError(DbAllyError): """ -def _validate_iql_response(llm_response: str) -> str: +async def _iql_filters_parser( + response: str, + allowed_functions: List[ExposedFunction], + event_tracker: Optional[EventTracker] = None, +) -> IQLFiltersQuery: """ - Validates LLM response to IQL + Parses the response from the LLM to IQL. Args: - llm_response: LLM response + response: LLM response. + allowed_functions: List of functions that can be used in the IQL. + event_tracker: Event tracker to be used for auditing. Returns: - A string containing IQL for filters. + IQL query for filters. Raises: - UnsuppotedQueryError: When IQL generator is unable to construct a query - with given filters. + UnsupportedQueryError: When IQL generator is unable to construct a query with given filters. """ - if "unsupported query" in llm_response.lower(): + if "unsupported query" in response.lower(): raise UnsupportedQueryError - return llm_response + return await IQLFiltersQuery.parse( + source=response, + allowed_functions=allowed_functions, + event_tracker=event_tracker, + ) -def _decision_iql_response_parser(response: str) -> bool: + +async def _iql_aggregation_parser( + response: str, + allowed_functions: List[ExposedFunction], + event_tracker: Optional[EventTracker] = None, +) -> IQLAggregationQuery: + """ + Parses the response from the LLM to IQL. + + Args: + response: LLM response. + allowed_functions: List of functions that can be used in the IQL. + event_tracker: Event tracker to be used for auditing. + + Returns: + IQL query for aggregations. + + Raises: + UnsupportedQueryError: When IQL generator is unable to construct a query with given aggregations. + """ + if "unsupported query" in response.lower(): + raise UnsupportedQueryError + + return await IQLAggregationQuery.parse( + source=response, + allowed_functions=allowed_functions, + event_tracker=event_tracker, + ) + + +def _decision_parser(response: str) -> bool: """ Parses the response from the decision prompt. 
@@ -52,7 +93,7 @@ def _decision_iql_response_parser(response: str) -> bool: return "true" in decision -class FilteringDecisionPromptFormat(PromptFormat): +class DecisionPromptFormat(PromptFormat): """ IQL prompt format, providing a question and filters to be used in the conversation. """ @@ -71,44 +112,96 @@ def __init__(self, *, question: str, examples: List[FewShotExample] = None) -> N class IQLGenerationPromptFormat(PromptFormat): """ - IQL prompt format, providing a question and filters to be used in the conversation. + IQL prompt format, providing a question and methods to be used in the conversation. """ def __init__( self, *, question: str, - filters: List[ExposedFunction], - examples: List[FewShotExample] = None, + methods: List[ExposedFunction], + examples: Optional[List[FewShotExample]] = None, ) -> None: """ Constructs a new IQLGenerationPromptFormat instance. Args: question: Question to be asked. - filters: List of filters exposed by the view. + methods: List of methods exposed by the view. examples: List of examples to be injected into the conversation. aggregations: List of aggregations exposed by the view. """ super().__init__(examples) self.question = question - self.filters = "\n".join([str(condition) for condition in filters]) if filters else [] + self.methods = "\n".join(str(method) for method in methods) -IQL_GENERATION_TEMPLATE = PromptTemplate[IQLGenerationPromptFormat]( +FILTERING_DECISION_TEMPLATE = PromptTemplate[DecisionPromptFormat]( + [ + { + "role": "system", + "content": ( + "Given a question, determine whether the answer requires data filtering in order to compute it.\n" + "Data filtering is a process in which the result set is filtered based on the specific features " + "stated in the question. Such a question can be easily identified by using words that refer to " + "specific feature values (rather than feature names).\n" + "Look for words indicating specific values that the answer should contain. 
\n\n" + "---\n\n" + "Follow the following format.\n\n" + "Question: ${{question}}\n" + "Reasoning: Let's think step by step in order to ${{produce the decision}}. We...\n" + "Decision: indicates whether the answer to the question requires data filtering. " + "(Respond with True or False)\n\n" + ), + }, + { + "role": "user", + "content": ("Question: {question}\n" "Reasoning: Let's think step by step in order to "), + }, + ], + response_parser=_decision_parser, +) + +AGGREGATION_DECISION_TEMPLATE = PromptTemplate[DecisionPromptFormat]( + [ + { + "role": "system", + "content": ( + "Given a question, determine whether the answer requires data aggregation in order to compute it.\n" + "Data aggregation is a process in which we calculate a single value for a group of rows in the " + "result set.\n" + "Most common aggregation functions are counting, averaging, summing, but other types of aggregation " + "are possible.\n\n" + "---\n\n" + "Follow the following format.\n\n" + "Question: ${{question}}\n" + "Reasoning: Let's think step by step in order to ${{produce the decision}}. We...\n" + "Decision: indicates whether the answer to the question requires data aggregation. " + "(Respond with True or False)\n\n" + ), + }, + { + "role": "user", + "content": "Question: {question}\n" "Reasoning: Let's think step by step in order to ", + }, + ], + response_parser=_decision_parser, +) + +FILTERS_GENERATION_TEMPLATE = PromptTemplate[IQLGenerationPromptFormat]( + [ + { + "role": "system", + "content": ( + "You have access to an API that lets you query a database:\n" + "\n{methods}\n" + "Suggest which one(s) to call and how they should be joined with logic operators (AND, OR, NOT).\n" + "Remember! Don't give any comments, just the function calls.\n" + "The output will look like this:\n" + 'filter1("arg1") AND (NOT filter2(120) OR filter3(True))\n' + "DO NOT INCLUDE arguments names in your response. 
Only the values.\n" "You MUST use only these methods:\n" - "\n{filters}\n" + "\n{methods}\n" "It is VERY IMPORTANT not to use methods other than those listed above." """If you DON'T KNOW HOW TO ANSWER DON'T SAY anything other than `UNSUPPORTED QUERY`""" "This is CRUCIAL, otherwise the system will crash. " @@ -119,35 +212,31 @@ def __init__( "content": "{question}", }, ], - response_parser=_validate_iql_response, + response_parser=_iql_filters_parser, ) - -FILTERING_DECISION_TEMPLATE = PromptTemplate[FilteringDecisionPromptFormat]( +AGGREGATION_GENERATION_TEMPLATE = PromptTemplate[IQLGenerationPromptFormat]( [ { "role": "system", "content": ( - "Given a question, determine whether the answer requires initial data filtering in order to compute it.\n" - "Initial data filtering is a process in which the result set is reduced to only include the rows " - "that meet certain criteria specified in the question.\n\n" - "---\n\n" - "Follow the following format.\n\n" - "Question: ${{question}}\n" - "Hint: ${{hint}}" - "Reasoning: Let's think step by step in order to ${{produce the decision}}. We...\n" - "Decision: indicates whether the answer to the question requires initial data filtering. " - "(Respond with True or False)\n\n" + "You have access to an API that lets you query a database supporting a SINGLE aggregation.\n" + "When prompted for an aggregation, use the following methods: \n" + "{methods}" + "DO NOT INCLUDE arguments names in your response. Only the values.\n" + "You MUST use only these methods:\n" + "\n{methods}\n" + "It is VERY IMPORTANT not to use methods other than those listed above." + """If you DON'T KNOW HOW TO ANSWER DON'T SAY anything other than `UNSUPPORTED QUERY`""" + "This is CRUCIAL to put `UNSUPPORTED QUERY` text only, otherwise the system will crash. 
" + "Structure output to resemble the following pattern:\n" + 'aggregation1("arg1", arg2)\n' ), }, { "role": "user", - "content": ( - "Question: {question}\n" - "Hint: Look for words indicating data specific features.\n" - "Reasoning: Let's think step by step in order to " - ), + "content": "{question}", }, ], - response_parser=_decision_iql_response_parser, + response_parser=_iql_aggregation_parser, ) diff --git a/src/dbally/prompt/template.py b/src/dbally/prompt/template.py index 124a3e1c..b4ef650d 100644 --- a/src/dbally/prompt/template.py +++ b/src/dbally/prompt/template.py @@ -1,6 +1,6 @@ import copy import re -from typing import Callable, Dict, Generic, List, TypeVar +from typing import Callable, Dict, Generic, List, Optional, TypeVar from typing_extensions import Self @@ -55,7 +55,7 @@ class PromptFormat: Generic format for prompts allowing to inject few shot examples into the conversation. """ - def __init__(self, examples: List[FewShotExample] = None) -> None: + def __init__(self, examples: Optional[List[FewShotExample]] = None) -> None: """ Constructs a new PromptFormat instance. diff --git a/src/dbally/views/exceptions.py b/src/dbally/views/exceptions.py index 277064a4..15770e9a 100644 --- a/src/dbally/views/exceptions.py +++ b/src/dbally/views/exceptions.py @@ -1,26 +1,22 @@ -from typing import Optional - from dbally.exceptions import DbAllyError +from dbally.iql_generator.iql_generator import IQLGeneratorState -class IQLGenerationError(DbAllyError): +class ViewExecutionError(DbAllyError): """ - Exception for when an error occurs while generating IQL for a view. + Exception for when an error occurs while executing a view. """ def __init__( self, view_name: str, - filters: Optional[str] = None, - aggregation: Optional[str] = None, + iql: IQLGeneratorState, ) -> None: """ Args: view_name: Name of the view that caused the error. - filters: Filters generated by the view. - aggregation: Aggregation generated by the view. + iql: View IQL generator state. 
""" - super().__init__(f"Error while generating IQL for view {view_name}") + super().__init__(f"Error while executing view {view_name}") self.view_name = view_name - self.filters = filters - self.aggregation = aggregation + self.iql = iql diff --git a/src/dbally/views/methods_base.py b/src/dbally/views/methods_base.py index 977a2fa1..8bf93363 100644 --- a/src/dbally/views/methods_base.py +++ b/src/dbally/views/methods_base.py @@ -1,15 +1,15 @@ import inspect import textwrap from abc import ABC -from typing import Any, Callable, Generic, List, Tuple +from typing import Any, Callable, List, Tuple from dbally.iql import syntax from dbally.views import decorators from dbally.views.exposed_functions import ExposedFunction, MethodParamWithTyping -from dbally.views.structured import BaseStructuredView, DataT +from dbally.views.structured import BaseStructuredView -class MethodsBaseView(Generic[DataT], BaseStructuredView, ABC): +class MethodsBaseView(BaseStructuredView, ABC): """ Base class for views that use view methods to expose filters. """ @@ -67,14 +67,14 @@ def list_aggregations(self) -> List[ExposedFunction]: def _method_with_args_from_call( self, func: syntax.FunctionCall, method_decorator: Callable - ) -> Tuple[Callable, list]: + ) -> Tuple[Callable, List]: """ Converts a IQL FunctionCall node to a method object and its arguments. Args: func: IQL FunctionCall node method_decorator: The decorator that the method should have - (currently allows discrimination between filters and aggregations) + (currently allows discrimination between filters and aggregations) Returns: Tuple with the method object and its arguments @@ -94,6 +94,21 @@ def _method_with_args_from_call( return method, func.arguments + async def _call_method(self, method: Callable, args: List) -> Any: + """ + Calls the method with the given arguments. If the method is a coroutine, it will be awaited. + + Args: + method: The method to call. + args: The arguments to pass to the method. 
+ + Returns: + The result of the method call. + """ + if inspect.iscoroutinefunction(method): + return await method(*args) + return method(*args) + async def call_filter_method(self, func: syntax.FunctionCall) -> Any: """ Converts a IQL FunctonCall filter to a method call. If the method is a coroutine, it will be awaited. @@ -105,12 +120,9 @@ async def call_filter_method(self, func: syntax.FunctionCall) -> Any: The result of the method call """ method, args = self._method_with_args_from_call(func, decorators.view_filter) + return await self._call_method(method, args) - if inspect.iscoroutinefunction(method): - return await method(*args) - return method(*args) - - async def call_aggregation_method(self, func: syntax.FunctionCall) -> DataT: + async def call_aggregation_method(self, func: syntax.FunctionCall) -> Any: """ Converts a IQL FunctonCall aggregation to a method call. If the method is a coroutine, it will be awaited. @@ -121,7 +133,4 @@ async def call_aggregation_method(self, func: syntax.FunctionCall) -> DataT: The result of the method call """ method, args = self._method_with_args_from_call(func, decorators.view_aggregation) - - if inspect.iscoroutinefunction(method): - return await method(*args) - return method(*args) + return await self._call_method(method, args) diff --git a/src/dbally/views/pandas_base.py b/src/dbally/views/pandas_base.py index 5f7bc8ce..e4da84c4 100644 --- a/src/dbally/views/pandas_base.py +++ b/src/dbally/views/pandas_base.py @@ -1,15 +1,37 @@ import asyncio +from dataclasses import dataclass from functools import reduce -from typing import Optional +from typing import List, Optional, Union import pandas as pd from dbally.collection.results import ViewExecutionResult -from dbally.iql import IQLQuery, syntax +from dbally.iql import syntax +from dbally.iql._query import IQLAggregationQuery, IQLFiltersQuery from dbally.views.methods_base import MethodsBaseView -class DataFrameBaseView(MethodsBaseView[pd.DataFrame]): 
+@dataclass(frozen=True) +class Aggregation: + """ + Represents an aggregation to be applied to a Pandas DataFrame. + """ + + column: str + function: str + + +@dataclass(frozen=True) +class AggregationGroup: + """ + Represents aggregations and groupbys to be applied to a Pandas DataFrame. + """ + + aggregations: Optional[List[Aggregation]] = None + groupbys: Optional[Union[str, List[str]]] = None + + +class DataFrameBaseView(MethodsBaseView): """ Base class for views that use Pandas DataFrames to store and filter data. @@ -24,35 +46,30 @@ def __init__(self, df: pd.DataFrame) -> None: Args: df: Pandas DataFrame with the data to be filtered. """ - super().__init__(df) - - # The mask to be applied to the dataframe to filter the data + super().__init__() + self.df = df self._filter_mask: Optional[pd.Series] = None + self._aggregation_group: AggregationGroup = AggregationGroup() - async def apply_filters(self, filters: IQLQuery) -> None: + async def apply_filters(self, filters: IQLFiltersQuery) -> None: """ Applies the chosen filters to the view. Args: filters: IQLQuery object representing the filters to apply. """ - # data is defined in the parent class - # pylint: disable=attribute-defined-outside-init - self._filter_mask = await self.build_filter_node(filters.root) - self.data = self.data.loc[self._filter_mask] + self._filter_mask = await self._build_filter_node(filters.root) - async def apply_aggregation(self, aggregation: IQLQuery) -> None: + async def apply_aggregation(self, aggregation: IQLAggregationQuery) -> None: """ Applies the aggregation of choice to the view. Args: aggregation: IQLQuery object representing the aggregation to apply. 
""" - # data is defined in the parent class - # pylint: disable=attribute-defined-outside-init - self.data = await self.call_aggregation_method(aggregation.root) + self._aggregation_group = await self.call_aggregation_method(aggregation.root) - async def build_filter_node(self, node: syntax.Node) -> pd.Series: + async def _build_filter_node(self, node: syntax.Node) -> pd.Series: """ Converts a filter node from the IQLQuery to a Pandas Series representing a boolean mask to be applied to the dataframe. @@ -69,13 +86,13 @@ async def build_filter_node(self, node: syntax.Node) -> pd.Series: if isinstance(node, syntax.FunctionCall): return await self.call_filter_method(node) if isinstance(node, syntax.And): # logical AND - children = await asyncio.gather(*[self.build_filter_node(child) for child in node.children]) + children = await asyncio.gather(*[self._build_filter_node(child) for child in node.children]) return reduce(lambda x, y: x & y, children) if isinstance(node, syntax.Or): # logical OR - children = await asyncio.gather(*[self.build_filter_node(child) for child in node.children]) + children = await asyncio.gather(*[self._build_filter_node(child) for child in node.children]) return reduce(lambda x, y: x | y, children) if isinstance(node, syntax.Not): - child = await self.build_filter_node(node.child) + child = await self._build_filter_node(node.child) return ~child raise ValueError(f"Unsupported grammar: {node}") @@ -90,11 +107,30 @@ def execute(self, dry_run: bool = False) -> ViewExecutionResult: Returns: ExecutionResult object with the results and the context information with the binary mask. 
""" - results = pd.DataFrame.empty if dry_run else self.data + results = pd.DataFrame() + + if not dry_run: + results = self.df + if self._filter_mask is not None: + results = results.loc[self._filter_mask] + + if self._aggregation_group.groupbys is not None: + results = results.groupby(self._aggregation_group.groupbys) + + if self._aggregation_group.aggregations is not None: + results = results.agg( + **{ + f"{agg.column}_{agg.function}": (agg.column, agg.function) + for agg in self._aggregation_group.aggregations + } + ) + results = results.reset_index() return ViewExecutionResult( results=results.to_dict(orient="records"), context={ "filter_mask": self._filter_mask, + "groupbys": self._aggregation_group.groupbys, + "aggregations": self._aggregation_group.aggregations, }, ) diff --git a/src/dbally/views/sqlalchemy_base.py b/src/dbally/views/sqlalchemy_base.py index 4863aa6f..3a7c7981 100644 --- a/src/dbally/views/sqlalchemy_base.py +++ b/src/dbally/views/sqlalchemy_base.py @@ -4,11 +4,12 @@ import sqlalchemy from dbally.collection.results import ViewExecutionResult -from dbally.iql import IQLQuery, syntax +from dbally.iql import syntax +from dbally.iql._query import IQLAggregationQuery, IQLFiltersQuery from dbally.views.methods_base import MethodsBaseView -class SqlAlchemyBaseView(MethodsBaseView[sqlalchemy.Select]): +class SqlAlchemyBaseView(MethodsBaseView): """ Base class for views that use SQLAlchemy to generate SQL queries. """ @@ -20,39 +21,36 @@ def __init__(self, sqlalchemy_engine: sqlalchemy.Engine) -> None: Args: sqlalchemy_engine: SQLAlchemy engine to use for executing the queries. """ - super().__init__(self.get_select()) + super().__init__() + self.select = self.get_select() self._sqlalchemy_engine = sqlalchemy_engine @abc.abstractmethod def get_select(self) -> sqlalchemy.Select: """ - Creates initial SELECT statement for the view. - - Returns: - SQLAlchemy Select object for the view. 
+ Creates the initial + [SqlAlchemy select object + ](https://docs.sqlalchemy.org/en/20/core/selectable.html#sqlalchemy.sql.expression.Select) + which will be used to build the query. """ - async def apply_filters(self, filters: IQLQuery) -> None: + async def apply_filters(self, filters: IQLFiltersQuery) -> None: """ Applies the chosen filters to the view. Args: filters: IQLQuery object representing the filters to apply. """ - # data is defined in the parent class - # pylint: disable=attribute-defined-outside-init - self.data = self.data.where(await self._build_filter_node(filters.root)) + self.select = self.select.where(await self._build_filter_node(filters.root)) - async def apply_aggregation(self, aggregation: IQLQuery) -> None: + async def apply_aggregation(self, aggregation: IQLAggregationQuery) -> None: """ Applies the chosen aggregation to the view. Args: aggregation: IQLQuery object representing the aggregation to apply. """ - # data is defined in the parent class - # pylint: disable=attribute-defined-outside-init - self.data = await self.call_aggregation_method(aggregation.root) + self.select = await self.call_aggregation_method(aggregation.root) async def _build_filter_node(self, node: syntax.Node) -> sqlalchemy.ColumnElement: """ @@ -95,11 +93,11 @@ def execute(self, dry_run: bool = False) -> ViewExecutionResult: list if `dry_run` is set to `True`. Inside the `context` field the generated sql will be stored. 
""" results = [] - sql = str(self.data.compile(bind=self._sqlalchemy_engine, compile_kwargs={"literal_binds": True})) + sql = str(self.select.compile(bind=self._sqlalchemy_engine, compile_kwargs={"literal_binds": True})) if not dry_run: with self._sqlalchemy_engine.connect() as connection: - rows = connection.execute(self.data).fetchall() + rows = connection.execute(self.select).fetchall() # The underscore is used by sqlalchemy to avoid conflicts with column names # pylint: disable=protected-access results = [dict(row._mapping) for row in rows] diff --git a/src/dbally/views/structured.py b/src/dbally/views/structured.py index c3ac91e0..019bafea 100644 --- a/src/dbally/views/structured.py +++ b/src/dbally/views/structured.py @@ -4,57 +4,33 @@ from dbally.audit.event_tracker import EventTracker from dbally.collection.results import ViewExecutionResult -from dbally.exceptions import UnsupportedAggregationError -from dbally.iql import IQLQuery -from dbally.iql._exceptions import IQLError +from dbally.iql._query import IQLAggregationQuery, IQLFiltersQuery from dbally.iql_generator.iql_generator import IQLGenerator -from dbally.iql_generator.prompt import UnsupportedQueryError from dbally.llms.base import LLM from dbally.llms.clients.base import LLMOptions -from dbally.views.exceptions import IQLGenerationError +from dbally.views.exceptions import ViewExecutionError from dbally.views.exposed_functions import ExposedFunction -from ..prompt.aggregation import AggregationFormatter from ..similarity import AbstractSimilarityIndex from .base import BaseView, IndexLocation DataT = TypeVar("DataT", bound=Any) -# TODO(Python 3.9+): Make BaseStructuredView a generic class class BaseStructuredView(BaseView): """ Base class for all structured [Views](../../concepts/views.md). All classes implementing this interface has\ to be able to list all available filters, apply them and execute queries. 
""" - def __init__(self, data: DataT) -> None: - super().__init__() - self.data = data - - def get_iql_generator(self, llm: LLM) -> IQLGenerator: + def get_iql_generator(self) -> IQLGenerator: """ Returns the IQL generator for the view. - Args: - llm: LLM used to generate the IQL queries. - Returns: IQL generator for the view. """ - return IQLGenerator(llm=llm) - - def get_agg_formatter(self, llm: LLM) -> AggregationFormatter: - """ - Returns the AggregtionFormatter for the view. - - Args: - llm: LLM used to generate the queries. - - Returns: - AggregtionFormatter for the view. - """ - return AggregationFormatter(llm=llm) + return IQLGenerator() async def ask( self, @@ -81,68 +57,41 @@ async def ask( The result of the query. Raises: - LLMError: If LLM text generation API fails. - IQLGenerationError: If the IQL generation fails. + ViewExecutionError: When an error occurs while executing the view. """ - iql_generator = self.get_iql_generator(llm) - agg_formatter = self.get_agg_formatter(llm) filters = self.list_filters() examples = self.list_few_shots() aggregations = self.list_aggregations() - try: - iql = await iql_generator.generate( - question=query, - filters=filters, - examples=examples, - event_tracker=event_tracker, - llm_options=llm_options, - n_retries=n_retries, - ) - except UnsupportedQueryError as exc: - raise IQLGenerationError( - view_name=self.__class__.__name__, - filters=None, - aggregation=None, - ) from exc - except IQLError as exc: - raise IQLGenerationError( + iql_generator = self.get_iql_generator() + iql = await iql_generator( + question=query, + filters=filters, + aggregations=aggregations, + examples=examples, + llm=llm, + event_tracker=event_tracker, + llm_options=llm_options, + n_retries=n_retries, + ) + + if iql.failed: + raise ViewExecutionError( view_name=self.__class__.__name__, - filters=exc.source, - aggregation=None, - ) from exc - - if iql: - await self.apply_filters(iql) - - try: - agg_node = await 
agg_formatter.format_to_query_object( - question=query, - aggregations=aggregations, - event_tracker=event_tracker, - llm_options=llm_options, + iql=iql, ) - except UnsupportedAggregationError as exc: - raise IQLGenerationError( - view_name=self.__class__.__name__, - filters=str(iql) if iql else None, - aggregation=None, - ) from exc - except IQLError as exc: - raise IQLGenerationError( - view_name=self.__class__.__name__, - filters=str(iql) if iql else None, - aggregation=exc.source, - ) from exc - await self.apply_aggregation(agg_node) + if iql.filters: + await self.apply_filters(iql.filters) + + if iql.aggregation: + await self.apply_aggregation(iql.aggregation) result = self.execute(dry_run=dry_run) result.context["iql"] = { - "filters": str(iql) if iql else None, - "aggregation": str(agg_node), + "filters": str(iql.filters) if iql.filters else None, + "aggregation": str(iql.aggregation) if iql.aggregation else None, } - return result @abc.abstractmethod @@ -164,21 +113,21 @@ def list_aggregations(self) -> List[ExposedFunction]: """ @abc.abstractmethod - async def apply_filters(self, filters: IQLQuery) -> None: + async def apply_filters(self, filters: IQLFiltersQuery) -> None: """ Applies the chosen filters to the view. Args: - filters: [IQLQuery](../../concepts/iql.md) object representing the filters to apply. + filters: IQLQuery object representing the filters to apply. """ @abc.abstractmethod - async def apply_aggregation(self, aggregation: IQLQuery) -> None: + async def apply_aggregation(self, aggregation: IQLAggregationQuery) -> None: """ Applies the chosen aggregation to the view. Args: - aggregation: [IQLQuery](../../concepts/iql.md) object representing the filters to apply. + aggregation: IQLQuery object representing the aggregation to apply. 
""" @abc.abstractmethod diff --git a/tests/unit/iql/test_iql_parser.py b/tests/unit/iql/test_iql_parser.py index ae5d2269..bed83d0a 100644 --- a/tests/unit/iql/test_iql_parser.py +++ b/tests/unit/iql/test_iql_parser.py @@ -3,7 +3,7 @@ import pytest -from dbally.iql import IQLArgumentParsingError, IQLQuery, IQLUnsupportedSyntaxError, syntax +from dbally.iql import IQLArgumentParsingError, IQLUnsupportedSyntaxError, syntax from dbally.iql._exceptions import ( IQLArgumentValidationError, IQLFunctionNotExists, @@ -14,11 +14,12 @@ IQLSyntaxError, ) from dbally.iql._processor import IQLProcessor +from dbally.iql._query import IQLAggregationQuery, IQLFiltersQuery from dbally.views.exposed_functions import ExposedFunction, MethodParamWithTyping -async def test_iql_parser(): - parsed = await IQLQuery.parse( +async def test_iql_filter_parser(): + parsed = await IQLFiltersQuery.parse( "not (filter_by_name(['John', 'Anne']) and filter_by_city('cracow') and filter_by_company('deepsense.ai'))", allowed_functions=[ ExposedFunction( @@ -51,9 +52,9 @@ async def test_iql_parser(): assert company_filter.arguments[0] == "deepsense.ai" -async def test_iql_parser_arg_error(): +async def test_iql_filter_parser_arg_error(): with pytest.raises(IQLArgumentParsingError) as exc_info: - await IQLQuery.parse( + await IQLFiltersQuery.parse( "filter_by_city('Cracow') and filter_by_name(lambda x: x + 1)", allowed_functions=[ ExposedFunction( @@ -76,9 +77,9 @@ async def test_iql_parser_arg_error(): assert exc_info.match(re.escape("Not a valid IQL argument: lambda x: x + 1")) -async def test_iql_parser_syntax_error(): +async def test_iql_filter_parser_syntax_error(): with pytest.raises(IQLSyntaxError) as exc_info: - await IQLQuery.parse( + await IQLFiltersQuery.parse( "filter_by_age(", allowed_functions=[ ExposedFunction( @@ -94,9 +95,9 @@ async def test_iql_parser_syntax_error(): assert exc_info.match(re.escape("Syntax error in: filter_by_age(")) -async def 
test_iql_parser_multiple_expression_error(): +async def test_iql_filter_parser_multiple_expression_error(): with pytest.raises(IQLMultipleStatementsError) as exc_info: - await IQLQuery.parse( + await IQLFiltersQuery.parse( "filter_by_age\nfilter_by_age", allowed_functions=[ ExposedFunction( @@ -112,9 +113,9 @@ async def test_iql_parser_multiple_expression_error(): assert exc_info.match(re.escape("Multiple statements in IQL are not supported")) -async def test_iql_parser_empty_expression_error(): +async def test_iql_filter_parser_empty_expression_error(): with pytest.raises(IQLNoStatementError) as exc_info: - await IQLQuery.parse( + await IQLFiltersQuery.parse( "", allowed_functions=[ ExposedFunction( @@ -130,9 +131,9 @@ async def test_iql_parser_empty_expression_error(): assert exc_info.match(re.escape("Empty IQL")) -async def test_iql_parser_no_expression_error(): +async def test_iql_filter_parser_no_expression_error(): with pytest.raises(IQLNoExpressionError) as exc_info: - await IQLQuery.parse( + await IQLFiltersQuery.parse( "import filter_by_age", allowed_functions=[ ExposedFunction( @@ -148,9 +149,9 @@ async def test_iql_parser_no_expression_error(): assert exc_info.match(re.escape("No expression found in IQL: import filter_by_age")) -async def test_iql_parser_unsupported_syntax_error(): +async def test_iql_filter_parser_unsupported_syntax_error(): with pytest.raises(IQLUnsupportedSyntaxError) as exc_info: - await IQLQuery.parse( + await IQLFiltersQuery.parse( "filter_by_age() >= 30", allowed_functions=[ ExposedFunction( @@ -166,9 +167,9 @@ async def test_iql_parser_unsupported_syntax_error(): assert exc_info.match(re.escape("Compare syntax is not supported in IQL: filter_by_age() >= 30")) -async def test_iql_parser_method_not_exists(): +async def test_iql_filter_parser_method_not_exists(): with pytest.raises(IQLFunctionNotExists) as exc_info: - await IQLQuery.parse( + await IQLFiltersQuery.parse( "filter_by_how_old_somebody_is(40)", allowed_functions=[ 
ExposedFunction( @@ -184,9 +185,9 @@ async def test_iql_parser_method_not_exists(): assert exc_info.match(re.escape("Function filter_by_how_old_somebody_is not exists: filter_by_how_old_somebody_is")) -async def test_iql_parser_incorrect_number_of_arguments_fail(): +async def test_iql_filter_parser_incorrect_number_of_arguments_fail(): with pytest.raises(IQLIncorrectNumberArgumentsError) as exc_info: - await IQLQuery.parse( + await IQLFiltersQuery.parse( "filter_by_age('too old', 40)", allowed_functions=[ ExposedFunction( @@ -204,9 +205,9 @@ async def test_iql_parser_incorrect_number_of_arguments_fail(): ) -async def test_iql_parser_argument_validation_fail(): +async def test_iql_filter_parser_argument_validation_fail(): with pytest.raises(IQLArgumentValidationError) as exc_info: - await IQLQuery.parse( + await IQLFiltersQuery.parse( "filter_by_age('too old')", allowed_functions=[ ExposedFunction( @@ -222,6 +223,189 @@ async def test_iql_parser_argument_validation_fail(): assert exc_info.match(re.escape("'too old' is not of type int: 'too old'")) +async def test_iql_aggregation_parser(): + parsed = await IQLAggregationQuery.parse( + "mean_age_by_city('Paris')", + allowed_functions=[ + ExposedFunction( + name="mean_age_by_city", + description="", + parameters=[ + MethodParamWithTyping(name="city", type=str), + ], + ), + ], + ) + + assert isinstance(parsed.root, syntax.FunctionCall) + assert parsed.root.name == "mean_age_by_city" + assert parsed.root.arguments == ["Paris"] + + +async def test_iql_aggregation_parser_arg_error(): + with pytest.raises(IQLArgumentParsingError) as exc_info: + await IQLAggregationQuery.parse( + "mean_age_by_city(lambda x: x + 1)", + allowed_functions=[ + ExposedFunction( + name="mean_age_by_city", + description="", + parameters=[ + MethodParamWithTyping(name="city", type=str), + ], + ), + ], + ) + + assert exc_info.match(re.escape("Not a valid IQL argument: lambda x: x + 1")) + + +async def test_iql_aggregation_parser_syntax_error(): + 
with pytest.raises(IQLSyntaxError) as exc_info: + await IQLAggregationQuery.parse( + "mean_age_by_city(", + allowed_functions=[ + ExposedFunction( + name="mean_age_by_city", + description="", + parameters=[ + MethodParamWithTyping(name="city", type=str), + ], + ), + ], + ) + + assert exc_info.match(re.escape("Syntax error in: mean_age_by_city(")) + + +async def test_iql_aggregation_parser_multiple_expression_error(): + with pytest.raises(IQLMultipleStatementsError) as exc_info: + await IQLAggregationQuery.parse( + "mean_age_by_city\nmean_age_by_city", + allowed_functions=[ + ExposedFunction( + name="mean_age_by_city", + description="", + parameters=[], + ), + ], + ) + + assert exc_info.match(re.escape("Multiple statements in IQL are not supported")) + + +async def test_iql_aggregation_parser_empty_expression_error(): + with pytest.raises(IQLNoStatementError) as exc_info: + await IQLAggregationQuery.parse( + "", + allowed_functions=[ + ExposedFunction( + name="mean_age_by_city", + description="", + parameters=[], + ), + ], + ) + + assert exc_info.match(re.escape("Empty IQL")) + + +async def test_iql_aggregation_parser_no_expression_error(): + with pytest.raises(IQLNoExpressionError) as exc_info: + await IQLAggregationQuery.parse( + "import mean_age_by_city", + allowed_functions=[ + ExposedFunction( + name="mean_age_by_city", + description="", + parameters=[], + ), + ], + ) + + assert exc_info.match(re.escape("No expression found in IQL: import mean_age_by_city")) + + +@pytest.mark.parametrize( + "iql, info", + [ + ("mean_age_by_city() >= 30", "Compare syntax is not supported in IQL: mean_age_by_city() >= 30"), + ( + "mean_age_by_city('Paris') and mean_age_by_city('London')", + "BoolOp syntax is not supported in IQL: mean_age_by_city('Paris') and mean_age_by_city('London')", + ), + ( + "mean_age_by_city('Paris') or mean_age_by_city('London')", + "BoolOp syntax is not supported in IQL: mean_age_by_city('Paris') or mean_age_by_city('London')", + ), + ("not 
mean_age_by_city('Paris')", "UnaryOp syntax is not supported in IQL: not mean_age_by_city('Paris')"), + ], +) +async def test_iql_aggregation_parser_unsupported_syntax_error(iql, info): + with pytest.raises(IQLUnsupportedSyntaxError) as exc_info: + await IQLAggregationQuery.parse( + iql, + allowed_functions=[ + ExposedFunction( + name="mean_age_by_city", + description="", + parameters=[], + ), + ], + ) + assert exc_info.match(re.escape(info)) + + +async def test_iql_aggregation_parser_method_not_exists(): + with pytest.raises(IQLFunctionNotExists) as exc_info: + await IQLAggregationQuery.parse( + "mean_age_by_town()", + allowed_functions=[ + ExposedFunction( + name="mean_age_by_city", + description="", + parameters=[], + ), + ], + ) + + assert exc_info.match(re.escape("Function mean_age_by_town not exists: mean_age_by_town")) + + +async def test_iql_aggregation_parser_incorrect_number_of_arguments_fail(): + with pytest.raises(IQLIncorrectNumberArgumentsError) as exc_info: + await IQLAggregationQuery.parse( + "mean_age_by_city('too old')", + allowed_functions=[ + ExposedFunction( + name="mean_age_by_city", + description="", + parameters=[], + ), + ], + ) + + assert exc_info.match( + re.escape("The method mean_age_by_city has incorrect number of arguments: mean_age_by_city('too old')") + ) + + +async def test_iql_aggregation_parser_argument_validation_fail(): + with pytest.raises(IQLArgumentValidationError): + await IQLAggregationQuery.parse( + "mean_age_by_city(12)", + allowed_functions=[ + ExposedFunction( + name="mean_age_by_city", + description="", + parameters=[ + MethodParamWithTyping(name="city", type=str), + ], + ), + ], + ) + + def test_keywords_lowercase(): rv = IQLProcessor._to_lower_except_in_quotes( """NOT filter1(230) AND (NOT filter_2("NOT ADMIN") AND filter_('IS NOT ADMIN')) OR NOT filter_4()""", diff --git a/tests/unit/mocks.py b/tests/unit/mocks.py index 992fd03d..69174389 100644 --- a/tests/unit/mocks.py +++ b/tests/unit/mocks.py @@ -9,11 +9,10 @@ 
from typing import List, Optional, Union from dbally import NOT_GIVEN, NotGiven -from dbally.iql import IQLQuery -from dbally.iql_generator.iql_generator import IQLGenerator +from dbally.iql._query import IQLAggregationQuery, IQLFiltersQuery +from dbally.iql_generator.iql_generator import IQLGenerator, IQLGeneratorState from dbally.llms.base import LLM from dbally.llms.clients.base import LLMClient, LLMOptions -from dbally.prompt.aggregation import AggregationFormatter from dbally.similarity.index import AbstractSimilarityIndex from dbally.view_selection.base import ViewSelector from dbally.views.structured import BaseStructuredView, ExposedFunction, ViewExecutionResult @@ -24,19 +23,16 @@ class MockViewBase(BaseStructuredView): Mock view base class """ - def __init__(self) -> None: - super().__init__([]) - def list_filters(self) -> List[ExposedFunction]: return [] - async def apply_filters(self, filters: IQLQuery) -> None: - ... - def list_aggregations(self) -> List[ExposedFunction]: return [] - async def apply_aggregation(self, filters: IQLQuery) -> None: + async def apply_filters(self, filters: IQLFiltersQuery) -> None: + ... + + async def apply_aggregation(self, aggregation: IQLAggregationQuery) -> None: ... 
def execute(self, dry_run: bool = False) -> ViewExecutionResult: @@ -44,21 +40,12 @@ def execute(self, dry_run: bool = False) -> ViewExecutionResult: class MockIQLGenerator(IQLGenerator): - def __init__(self, iql: IQLQuery) -> None: - self.iql = iql - super().__init__(llm=MockLLM()) - - async def generate(self, *_, **__) -> IQLQuery: - return self.iql - - -class MockAggregationFormatter(AggregationFormatter): - def __init__(self, iql_query: IQLQuery) -> None: - self.iql_query = iql_query - super().__init__(llm=MockLLM()) + def __init__(self, state: IQLGeneratorState) -> None: + self.state = state + super().__init__() - async def format_to_query_object(self, *_, **__) -> IQLQuery: - return self.iql_query + async def __call__(self, *_, **__) -> IQLGeneratorState: + return self.state class MockViewSelector(ViewSelector): diff --git a/tests/unit/similarity/sample_module/submodule.py b/tests/unit/similarity/sample_module/submodule.py index 42e05c0a..ab4b6c7e 100644 --- a/tests/unit/similarity/sample_module/submodule.py +++ b/tests/unit/similarity/sample_module/submodule.py @@ -3,7 +3,7 @@ from typing_extensions import Annotated from dbally import MethodsBaseView, decorators -from dbally.iql import IQLQuery +from dbally.iql._query import IQLAggregationQuery, IQLFiltersQuery from dbally.views.structured import ViewExecutionResult from tests.unit.mocks import MockSimilarityIndex @@ -20,7 +20,10 @@ def method_foo(self, idx: Annotated[str, index_foo]) -> str: def method_bar(self, city: Annotated[str, index_foo], year: Annotated[int, index_bar]) -> str: return f"hello {city} in {year}" - async def apply_filters(self, filters: IQLQuery) -> None: + async def apply_filters(self, filters: IQLFiltersQuery) -> None: + ... + + async def apply_aggregation(self, aggregation: IQLAggregationQuery) -> None: ... 
def execute(self, dry_run: bool = False) -> ViewExecutionResult: @@ -39,7 +42,10 @@ def method_qux(self, city: str, year: int) -> str: """ return f"hello {city} in {year}" - async def apply_filters(self, filters: IQLQuery) -> None: + async def apply_filters(self, filters: IQLFiltersQuery) -> None: + ... + + async def apply_aggregation(self, aggregation: IQLAggregationQuery) -> None: ... def execute(self, dry_run: bool = False) -> ViewExecutionResult: diff --git a/tests/unit/test_collection.py b/tests/unit/test_collection.py index a077286d..1d675d84 100644 --- a/tests/unit/test_collection.py +++ b/tests/unit/test_collection.py @@ -10,17 +10,11 @@ from dbally.collection import Collection from dbally.collection.exceptions import IndexUpdateError, NoViewFoundError from dbally.collection.results import ViewExecutionResult -from dbally.iql import IQLQuery +from dbally.iql._query import IQLAggregationQuery, IQLFiltersQuery from dbally.iql.syntax import FunctionCall +from dbally.iql_generator.iql_generator import IQLGeneratorState from dbally.views.exposed_functions import ExposedFunction, MethodParamWithTyping -from tests.unit.mocks import ( - MockAggregationFormatter, - MockIQLGenerator, - MockLLM, - MockSimilarityIndex, - MockViewBase, - MockViewSelector, -) +from tests.unit.mocks import MockIQLGenerator, MockLLM, MockSimilarityIndex, MockViewBase, MockViewSelector class MockView1(MockViewBase): @@ -66,15 +60,17 @@ def execute(self, dry_run=False) -> ViewExecutionResult: def list_filters(self) -> List[ExposedFunction]: return [ExposedFunction("test_filter", "", [])] - def get_iql_generator(self, *_, **__) -> MockIQLGenerator: - return MockIQLGenerator(IQLQuery(FunctionCall("test_filter", []), "test_filter()")) + def get_iql_generator(self) -> MockIQLGenerator: + return MockIQLGenerator( + IQLGeneratorState( + filters=IQLFiltersQuery(FunctionCall("test_filter", []), "test_filter()"), + aggregation=IQLAggregationQuery(FunctionCall("test_aggregation", []), 
"test_aggregation()"), + ), + ) def list_aggregations(self) -> List[ExposedFunction]: return [ExposedFunction("test_aggregation", "", [])] - def get_agg_formatter(self, *_, **__) -> MockAggregationFormatter: - return MockAggregationFormatter(IQLQuery(FunctionCall("test_aggregation", []), "test_aggregation()")) - @pytest.fixture(name="similarity_classes") def mock_similarity_classes() -> ( diff --git a/tests/unit/test_iql_format.py b/tests/unit/test_iql_format.py index b798e533..3a21a1fe 100644 --- a/tests/unit/test_iql_format.py +++ b/tests/unit/test_iql_format.py @@ -1,27 +1,27 @@ -from dbally.iql_generator.prompt import IQL_GENERATION_TEMPLATE, IQLGenerationPromptFormat +from dbally.iql_generator.prompt import FILTERS_GENERATION_TEMPLATE, IQLGenerationPromptFormat from dbally.prompt.elements import FewShotExample async def test_iql_prompt_format_default() -> None: prompt_format = IQLGenerationPromptFormat( question="", - filters=[], + methods=[], examples=[], ) - formatted_prompt = IQL_GENERATION_TEMPLATE.format_prompt(prompt_format) + formatted_prompt = FILTERS_GENERATION_TEMPLATE.format_prompt(prompt_format) assert formatted_prompt.chat == [ { "role": "system", "content": "You have access to an API that lets you query a database:\n" - "\n[]\n" + "\n\n" "Suggest which one(s) to call and how they should be joined with logic operators (AND, OR, NOT).\n" "Remember! Don't give any comments, just the function calls.\n" "The output will look like this:\n" 'filter1("arg1") AND (NOT filter2(120) OR filter3(True))\n' "DO NOT INCLUDE arguments names in your response. Only the values.\n" "You MUST use only these methods:\n" - "\n[]\n" + "\n\n" "It is VERY IMPORTANT not to use methods other than those listed above." """If you DON'T KNOW HOW TO ANSWER DON'T SAY anything other than `UNSUPPORTED QUERY`""" "This is CRUCIAL, otherwise the system will crash. 
", @@ -35,23 +35,23 @@ async def test_iql_prompt_format_few_shots_injected() -> None: examples = [FewShotExample("q1", "a1")] prompt_format = IQLGenerationPromptFormat( question="", - filters=[], + methods=[], examples=examples, ) - formatted_prompt = IQL_GENERATION_TEMPLATE.format_prompt(prompt_format) + formatted_prompt = FILTERS_GENERATION_TEMPLATE.format_prompt(prompt_format) assert formatted_prompt.chat == [ { "role": "system", "content": "You have access to an API that lets you query a database:\n" - "\n[]\n" + "\n\n" "Suggest which one(s) to call and how they should be joined with logic operators (AND, OR, NOT).\n" "Remember! Don't give any comments, just the function calls.\n" "The output will look like this:\n" 'filter1("arg1") AND (NOT filter2(120) OR filter3(True))\n' "DO NOT INCLUDE arguments names in your response. Only the values.\n" "You MUST use only these methods:\n" - "\n[]\n" + "\n\n" "It is VERY IMPORTANT not to use methods other than those listed above." """If you DON'T KNOW HOW TO ANSWER DON'T SAY anything other than `UNSUPPORTED QUERY`""" "This is CRUCIAL, otherwise the system will crash. 
", @@ -67,12 +67,12 @@ async def test_iql_input_format_few_shot_examples_repeat_no_example_duplicates() examples = [FewShotExample("q1", "a1")] prompt_format = IQLGenerationPromptFormat( question="", - filters=[], + methods=[], examples=examples, ) - formatted_prompt = IQL_GENERATION_TEMPLATE.format_prompt(prompt_format) + formatted_prompt = FILTERS_GENERATION_TEMPLATE.format_prompt(prompt_format) - assert len(formatted_prompt.chat) == len(IQL_GENERATION_TEMPLATE.chat) + (len(examples) * 2) + assert len(formatted_prompt.chat) == len(FILTERS_GENERATION_TEMPLATE.chat) + (len(examples) * 2) assert formatted_prompt.chat[1]["role"] == "user" assert formatted_prompt.chat[1]["content"] == examples[0].question assert formatted_prompt.chat[2]["role"] == "assistant" diff --git a/tests/unit/test_iql_generator.py b/tests/unit/test_iql_generator.py index b95fe585..0defc8e1 100644 --- a/tests/unit/test_iql_generator.py +++ b/tests/unit/test_iql_generator.py @@ -1,35 +1,23 @@ # mypy: disable-error-code="empty-body" -from unittest.mock import AsyncMock, call, patch +from unittest.mock import AsyncMock, patch import pytest import sqlalchemy from dbally import decorators from dbally.audit.event_tracker import EventTracker -from dbally.iql import IQLError, IQLQuery -from dbally.iql_generator.iql_generator import IQLGenerator -from dbally.iql_generator.prompt import ( - FILTERING_DECISION_TEMPLATE, - IQL_GENERATION_TEMPLATE, - FilteringDecisionPromptFormat, - IQLGenerationPromptFormat, -) +from dbally.iql import IQLAggregationQuery, IQLError, IQLFiltersQuery +from dbally.iql_generator.iql_generator import IQLGenerator, IQLGeneratorState from dbally.views.methods_base import MethodsBaseView from tests.unit.mocks import MockLLM class MockView(MethodsBaseView): - def __init__(self) -> None: - super().__init__(None) - - def get_select(self) -> sqlalchemy.Select: - ... 
- - async def apply_filters(self, filters: IQLQuery) -> None: + async def apply_filters(self, filters: IQLFiltersQuery) -> None: ... - async def apply_aggregation(self, filters: IQLQuery) -> None: + async def apply_aggregation(self, aggregation: IQLAggregationQuery) -> None: ... def execute(self, dry_run: bool = False): @@ -62,125 +50,177 @@ def event_tracker() -> EventTracker: @pytest.fixture -def iql_generator(llm: MockLLM) -> IQLGenerator: - return IQLGenerator(llm) +def iql_generator() -> IQLGenerator: + return IQLGenerator() @pytest.mark.asyncio -async def test_iql_generation(iql_generator: IQLGenerator, event_tracker: EventTracker, view: MockView) -> None: +async def test_iql_generation( + iql_generator: IQLGenerator, + llm: MockLLM, + event_tracker: EventTracker, + view: MockView, +) -> None: filters = view.list_filters() - - decision_format = FilteringDecisionPromptFormat( - question="Mock_question", - ) - generation_format = IQLGenerationPromptFormat( - question="Mock_question", - filters=filters, - ) - - decision_prompt = FILTERING_DECISION_TEMPLATE.format_prompt(decision_format) - generation_prompt = IQL_GENERATION_TEMPLATE.format_prompt(generation_format) + aggregations = view.list_aggregations() + examples = view.list_few_shots() llm_responses = [ "decision: true", "filter_by_id(1)", + "decision: true", + "aggregate_by_id()", ] - iql_generator._llm.generate_text = AsyncMock(side_effect=llm_responses) - with patch("dbally.iql.IQLQuery.parse", AsyncMock(return_value="filter_by_id(1)")) as mock_parse: - iql = await iql_generator.generate( + iql_filter_parser_response = "filter_by_id(1)" + iql_aggregation_parser_response = "aggregate_by_id()" + + llm.generate_text = AsyncMock(side_effect=llm_responses) + with patch( + "dbally.iql.IQLFiltersQuery.parse", AsyncMock(return_value=iql_filter_parser_response) + ) as mock_filters_parse, patch( + "dbally.iql.IQLAggregationQuery.parse", AsyncMock(return_value=iql_aggregation_parser_response) + ) as 
mock_aggregation_parse: + iql = await iql_generator( question="Mock_question", filters=filters, + aggregations=aggregations, + examples=examples, + llm=llm, event_tracker=event_tracker, ) - assert iql == "filter_by_id(1)" - iql_generator._llm.generate_text.assert_has_calls( - [ - call( - prompt=decision_prompt, - event_tracker=event_tracker, - options=None, - ), - call( - prompt=generation_prompt, - event_tracker=event_tracker, - options=None, - ), - ] + assert iql == IQLGeneratorState( + filters=iql_filter_parser_response, + aggregation=iql_aggregation_parser_response, ) - mock_parse.assert_called_once_with( - source="filter_by_id(1)", + assert llm.generate_text.call_count == 4 + mock_filters_parse.assert_called_once_with( + source=llm_responses[1], allowed_functions=filters, event_tracker=event_tracker, ) + mock_aggregation_parse.assert_called_once_with( + source=llm_responses[3], + allowed_functions=aggregations, + event_tracker=event_tracker, + ) @pytest.mark.asyncio async def test_iql_generation_error_escalation_after_max_retires( iql_generator: IQLGenerator, + llm: MockLLM, event_tracker: EventTracker, view: MockView, ) -> None: filters = view.list_filters() - responses = [ + aggregations = view.list_aggregations() + examples = view.list_few_shots() + + llm_responses = [ + "decision: true", + "wrong_filter", + "wrong_filter", + "wrong_filter", + "wrong_filter", + "decision: true", + "wrong_aggregation", + "wrong_aggregation", + "wrong_aggregation", + "wrong_aggregation", + ] + iql_filter_parser_responses = [ IQLError("err1", "src1"), IQLError("err2", "src2"), IQLError("err3", "src3"), IQLError("err4", "src4"), ] - llm_responses = [ - "decision: true", - "filter_by_id(1)", - "filter_by_id(1)", - "filter_by_id(1)", - "filter_by_id(1)", + iql_aggregation_parser_responses = [ + IQLError("err1", "src1"), + IQLError("err2", "src2"), + IQLError("err3", "src3"), + IQLError("err4", "src4"), ] - iql_generator._llm.generate_text = AsyncMock(side_effect=llm_responses) - 
with patch("dbally.iql.IQLQuery.parse", AsyncMock(side_effect=responses)), pytest.raises(IQLError): - iql = await iql_generator.generate( + llm.generate_text = AsyncMock(side_effect=llm_responses) + with patch("dbally.iql.IQLFiltersQuery.parse", AsyncMock(side_effect=iql_filter_parser_responses)), patch( + "dbally.iql.IQLAggregationQuery.parse", AsyncMock(side_effect=iql_aggregation_parser_responses) + ): + iql = await iql_generator( question="Mock_question", filters=filters, + aggregations=aggregations, + examples=examples, + llm=llm, event_tracker=event_tracker, n_retries=3, ) - assert iql is None - assert iql_generator._llm.generate_text.call_count == 4 - for i, arg in enumerate(iql_generator._llm.generate_text.call_args_list[1:], start=1): + assert iql == IQLGeneratorState( + filters=iql_filter_parser_responses[-1], + aggregation=iql_aggregation_parser_responses[-1], + ) + assert llm.generate_text.call_count == 10 + for i, arg in enumerate(llm.generate_text.call_args_list[2:5], start=1): + assert f"err{i}" in arg[1]["prompt"].chat[-1]["content"] + for i, arg in enumerate(llm.generate_text.call_args_list[7:10], start=1): assert f"err{i}" in arg[1]["prompt"].chat[-1]["content"] @pytest.mark.asyncio async def test_iql_generation_response_after_max_retries( iql_generator: IQLGenerator, + llm: MockLLM, event_tracker: EventTracker, view: MockView, ) -> None: filters = view.list_filters() - responses = [ + aggregations = view.list_aggregations() + examples = view.list_few_shots() + + llm_responses = [ + "decision: true", + "wrong_filter", + "wrong_filter", + "wrong_filter", + "filter_by_id(1)", + "decision: true", + "wrong_aggregation", + "wrong_aggregation", + "wrong_aggregation", + "aggregate_by_id()", + ] + iql_filter_parser_responses = [ IQLError("err1", "src1"), IQLError("err2", "src2"), IQLError("err3", "src3"), "filter_by_id(1)", ] - llm_responses = [ - "decision: true", - "filter_by_id(1)", - "filter_by_id(1)", - "filter_by_id(1)", - "filter_by_id(1)", + 
iql_aggregation_parser_responses = [ + IQLError("err1", "src1"), + IQLError("err2", "src2"), + IQLError("err3", "src3"), + "aggregate_by_id()", ] - iql_generator._llm.generate_text = AsyncMock(side_effect=llm_responses) - with patch("dbally.iql.IQLQuery.parse", AsyncMock(side_effect=responses)): - iql = await iql_generator.generate( + llm.generate_text = AsyncMock(side_effect=llm_responses) + with patch("dbally.iql.IQLFiltersQuery.parse", AsyncMock(side_effect=iql_filter_parser_responses)), patch( + "dbally.iql.IQLAggregationQuery.parse", AsyncMock(side_effect=iql_aggregation_parser_responses) + ): + iql = await iql_generator( question="Mock_question", filters=filters, + aggregations=aggregations, + examples=examples, + llm=llm, event_tracker=event_tracker, n_retries=3, ) - - assert iql == "filter_by_id(1)" - assert iql_generator._llm.generate_text.call_count == 5 - for i, arg in enumerate(iql_generator._llm.generate_text.call_args_list[2:], start=1): + assert iql == IQLGeneratorState( + filters=iql_filter_parser_responses[-1], + aggregation=iql_aggregation_parser_responses[-1], + ) + assert llm.generate_text.call_count == len(llm_responses) + for i, arg in enumerate(llm.generate_text.call_args_list[2:5], start=1): + assert f"err{i}" in arg[1]["prompt"].chat[-1]["content"] + for i, arg in enumerate(llm.generate_text.call_args_list[7:10], start=1): assert f"err{i}" in arg[1]["prompt"].chat[-1]["content"] diff --git a/tests/unit/views/test_methods_base.py b/tests/unit/views/test_methods_base.py index 8d90ffc3..57c0b68a 100644 --- a/tests/unit/views/test_methods_base.py +++ b/tests/unit/views/test_methods_base.py @@ -4,7 +4,7 @@ from typing import List, Literal, Tuple from dbally.collection.results import ViewExecutionResult -from dbally.iql import IQLQuery +from dbally.iql._query import IQLAggregationQuery, IQLFiltersQuery from dbally.views.decorators import view_aggregation, view_filter from dbally.views.exposed_functions import MethodParamWithTyping from 
dbally.views.methods_base import MethodsBaseView @@ -15,9 +15,6 @@ class MockMethodsBase(MethodsBaseView): Mock class for testing the MethodsBaseView """ - def __init__(self) -> None: - super().__init__(None) - @view_filter() def method_foo(self, idx: int) -> None: """ @@ -35,13 +32,13 @@ def method_baz(self) -> None: """ @view_aggregation() - def method_qux(self, ages: List[int], names: List[str]) -> None: + def method_qux(self, ages: List[int], names: List[str]) -> str: return f"hello {ages} and {names}" - async def apply_filters(self, filters: IQLQuery) -> None: + async def apply_filters(self, filters: IQLFiltersQuery) -> None: ... - async def apply_aggregation(self, filters: IQLQuery) -> None: + async def apply_aggregation(self, aggregation: IQLAggregationQuery) -> None: ... def execute(self, dry_run: bool = False) -> ViewExecutionResult: diff --git a/tests/unit/views/test_pandas_base.py b/tests/unit/views/test_pandas_base.py index 52a8f405..46e89750 100644 --- a/tests/unit/views/test_pandas_base.py +++ b/tests/unit/views/test_pandas_base.py @@ -1,10 +1,12 @@ # pylint: disable=missing-docstring, missing-return-doc, missing-param-doc, disallowed-name + import pandas as pd -from dbally.iql import IQLQuery +from dbally.iql import IQLFiltersQuery +from dbally.iql._query import IQLAggregationQuery from dbally.views.decorators import view_aggregation, view_filter -from dbally.views.pandas_base import DataFrameBaseView +from dbally.views.pandas_base import Aggregation, AggregationGroup, DataFrameBaseView MOCK_DATA = [ {"name": "Alice", "city": "London", "year": 2020, "age": 30}, @@ -54,8 +56,21 @@ def filter_name(self, name: str) -> pd.Series: return self.data["name"] == name @view_aggregation() - def mean_age_by_city(self) -> pd.DataFrame: - return self.data.groupby(["city"]).agg({"age": "mean"}).reset_index() + def mean_age_by_city(self) -> AggregationGroup: + return AggregationGroup( + aggregations=[ + Aggregation(column="age", function="mean"), + ], + 
groupbys="city", + ) + + @view_aggregation() + def count_records(self) -> AggregationGroup: + return AggregationGroup( + aggregations=[ + Aggregation(column="name", function="count"), + ], + ) async def test_filter_or() -> None: @@ -63,7 +78,7 @@ async def test_filter_or() -> None: Test that the filtering the DataFrame with logical OR works correctly """ mock_view = MockDataFrameView(pd.DataFrame.from_records(MOCK_DATA)) - query = await IQLQuery.parse( + query = await IQLFiltersQuery.parse( 'filter_city("Berlin") or filter_city("London")', allowed_functions=mock_view.list_filters(), ) @@ -71,6 +86,8 @@ async def test_filter_or() -> None: result = mock_view.execute() assert result.results == MOCK_DATA_BERLIN_OR_LONDON assert result.context["filter_mask"].tolist() == [True, False, True, False, True] + assert result.context["groupbys"] is None + assert result.context["aggregations"] is None async def test_filter_and() -> None: @@ -78,7 +95,7 @@ async def test_filter_and() -> None: Test that the filtering the DataFrame with logical AND works correctly """ mock_view = MockDataFrameView(pd.DataFrame.from_records(MOCK_DATA)) - query = await IQLQuery.parse( + query = await IQLFiltersQuery.parse( 'filter_city("Paris") and filter_year(2020)', allowed_functions=mock_view.list_filters(), ) @@ -86,6 +103,8 @@ async def test_filter_and() -> None: result = mock_view.execute() assert result.results == MOCK_DATA_PARIS_2020 assert result.context["filter_mask"].tolist() == [False, True, False, False, False] + assert result.context["groupbys"] is None + assert result.context["aggregations"] is None async def test_filter_not() -> None: @@ -93,7 +112,7 @@ async def test_filter_not() -> None: Test that the filtering the DataFrame with logical NOT works correctly """ mock_view = MockDataFrameView(pd.DataFrame.from_records(MOCK_DATA)) - query = await IQLQuery.parse( + query = await IQLFiltersQuery.parse( 'not (filter_city("Paris") and filter_year(2020))', 
allowed_functions=mock_view.list_filters(), ) @@ -101,25 +120,48 @@ async def test_filter_not() -> None: result = mock_view.execute() assert result.results == MOCK_DATA_NOT_PARIS_2020 assert result.context["filter_mask"].tolist() == [True, False, True, True, True] + assert result.context["groupbys"] is None + assert result.context["aggregations"] is None -async def test_aggregtion() -> None: +async def test_aggregation() -> None: """ Test that DataFrame aggregation works correctly """ mock_view = MockDataFrameView(pd.DataFrame.from_records(MOCK_DATA)) - query = await IQLQuery.parse( + query = await IQLAggregationQuery.parse( + "count_records()", + allowed_functions=mock_view.list_aggregations(), + ) + await mock_view.apply_aggregation(query) + result = mock_view.execute() + assert result.results == [ + {"index": "name_count", "name": 5}, + ] + assert result.context["filter_mask"] is None + assert result.context["groupbys"] is None + assert result.context["aggregations"] == [Aggregation(column="name", function="count")] + + +async def test_aggregtion_with_groupby() -> None: + """ + Test that DataFrame aggregation with groupby works correctly + """ + mock_view = MockDataFrameView(pd.DataFrame.from_records(MOCK_DATA)) + query = await IQLAggregationQuery.parse( "mean_age_by_city()", allowed_functions=mock_view.list_aggregations(), ) await mock_view.apply_aggregation(query) result = mock_view.execute() assert result.results == [ - {"city": "Berlin", "age": 45.0}, - {"city": "London", "age": 32.5}, - {"city": "Paris", "age": 32.5}, + {"city": "Berlin", "age_mean": 45.0}, + {"city": "London", "age_mean": 32.5}, + {"city": "Paris", "age_mean": 32.5}, ] assert result.context["filter_mask"] is None + assert result.context["groupbys"] == "city" + assert result.context["aggregations"] == [Aggregation(column="age", function="mean")] async def test_filters_and_aggregtion() -> None: @@ -127,16 +169,18 @@ async def test_filters_and_aggregtion() -> None: Test that DataFrame 
filtering and aggregation works correctly """ mock_view = MockDataFrameView(pd.DataFrame.from_records(MOCK_DATA)) - query = await IQLQuery.parse( + query = await IQLFiltersQuery.parse( "filter_city('Paris')", allowed_functions=mock_view.list_filters(), ) await mock_view.apply_filters(query) - query = await IQLQuery.parse( + query = await IQLAggregationQuery.parse( "mean_age_by_city()", allowed_functions=mock_view.list_aggregations(), ) await mock_view.apply_aggregation(query) result = mock_view.execute() - assert result.results == [{"city": "Paris", "age": 32.5}] + assert result.results == [{"city": "Paris", "age_mean": 32.5}] assert result.context["filter_mask"].tolist() == [False, True, False, True, False] + assert result.context["groupbys"] == "city" + assert result.context["aggregations"] == [Aggregation(column="age", function="mean")] diff --git a/tests/unit/views/test_sqlalchemy_base.py b/tests/unit/views/test_sqlalchemy_base.py index 435c8f8e..571e6a70 100644 --- a/tests/unit/views/test_sqlalchemy_base.py +++ b/tests/unit/views/test_sqlalchemy_base.py @@ -4,7 +4,8 @@ import sqlalchemy -from dbally.iql import IQLQuery +from dbally.iql import IQLFiltersQuery +from dbally.iql._query import IQLAggregationQuery from dbally.views.decorators import view_aggregation, view_filter from dbally.views.sqlalchemy_base import SqlAlchemyBaseView @@ -33,7 +34,7 @@ def method_baz(self) -> sqlalchemy.Select: """ Some documentation string """ - return self.data.add_columns(sqlalchemy.literal("baz")).group_by(sqlalchemy.literal("baz")) + return self.select.add_columns(sqlalchemy.literal("baz")).group_by(sqlalchemy.literal("baz")) def normalize_whitespace(s: str) -> str: @@ -50,7 +51,7 @@ async def test_filter_sql_generation() -> None: mock_connection = sqlalchemy.create_mock_engine("postgresql://", executor=None) mock_view = MockSqlAlchemyView(mock_connection.engine) - query = await IQLQuery.parse( + query = await IQLFiltersQuery.parse( 'method_foo(1) and method_bar("London", 
2020)', allowed_functions=mock_view.list_filters(), ) @@ -66,7 +67,7 @@ async def test_aggregation_sql_generation() -> None: mock_connection = sqlalchemy.create_mock_engine("postgresql://", executor=None) mock_view = MockSqlAlchemyView(mock_connection.engine) - query = await IQLQuery.parse( + query = await IQLAggregationQuery.parse( "method_baz()", allowed_functions=mock_view.list_aggregations(), ) @@ -82,12 +83,12 @@ async def test_filter_and_aggregation_sql_generation() -> None: mock_connection = sqlalchemy.create_mock_engine("postgresql://", executor=None) mock_view = MockSqlAlchemyView(mock_connection.engine) - query = await IQLQuery.parse( + query = await IQLFiltersQuery.parse( 'method_foo(1) and method_bar("London", 2020)', allowed_functions=mock_view.list_filters() + mock_view.list_aggregations(), ) await mock_view.apply_filters(query) - query = await IQLQuery.parse( + query = await IQLAggregationQuery.parse( "method_baz()", allowed_functions=mock_view.list_aggregations(), )