diff --git a/src/dbally/exceptions.py b/src/dbally/exceptions.py index 62faac37..6b095cd7 100644 --- a/src/dbally/exceptions.py +++ b/src/dbally/exceptions.py @@ -2,10 +2,3 @@ class DbAllyError(Exception): """ Base class for all exceptions raised by db-ally. """ - - -class UnsupportedAggregationError(DbAllyError): - """ - Error raised when AggregationFormatter is unable to construct a query - with given aggregation. - """ diff --git a/src/dbally/iql_generator/iql_prompt_template.py b/src/dbally/iql_generator/iql_prompt_template.py deleted file mode 100644 index e69de29b..00000000 diff --git a/src/dbally/prompt/aggregation.py b/src/dbally/prompt/aggregation.py deleted file mode 100644 index 8dedd95c..00000000 --- a/src/dbally/prompt/aggregation.py +++ /dev/null @@ -1,122 +0,0 @@ -from typing import List, Optional - -from dbally.audit import EventTracker -from dbally.exceptions import UnsupportedAggregationError -from dbally.iql import IQLQuery -from dbally.llms.base import LLM -from dbally.llms.clients import LLMOptions -from dbally.prompt.template import PromptFormat, PromptTemplate -from dbally.views.exposed_functions import ExposedFunction - - -def _validate_agg_response(llm_response: str) -> str: - """ - Validates LLM response to IQL - - Args: - llm_response: LLM response - - Returns: - A string containing aggregations. - - Raises: - UnsupportedAggregationError: When IQL generator is unable to construct a query - with given aggregation. - """ - if "unsupported query" in llm_response.lower(): - raise UnsupportedAggregationError - return llm_response - - -class AggregationPromptFormat(PromptFormat): - """ - Aggregation prompt format, providing a question and aggregation to be used in the conversation. - """ - - def __init__( - self, - question: str, - aggregations: List[ExposedFunction] = None, - ) -> None: - super().__init__() - self.question = question - self.aggregations = "\n".join([str(aggregation) for aggregation in aggregations]) if aggregations else [] - - -class AggregationFormatter: - """ - Class used to manage choice and formatting of aggregation based on natural language question. - """ - - def __init__(self, llm: LLM, prompt_template: Optional[PromptTemplate[AggregationPromptFormat]] = None) -> None: - """ - Constructs a new AggregationFormatter instance. - - Args: - llm: LLM used to generate IQL - prompt_template: If not provided by the users is set to `AGGREGATION_GENERATION_TEMPLATE` - """ - self._llm = llm - self._prompt_template = prompt_template or AGGREGATION_GENERATION_TEMPLATE - - async def format_to_query_object( - self, - question: str, - event_tracker: EventTracker, - aggregations: List[ExposedFunction] = None, - llm_options: Optional[LLMOptions] = None, - ) -> IQLQuery: - """ - Generates IQL in text form using LLM. - - Args: - question: User question. - event_tracker: Event store used to audit the generation process. - aggregations: List of aggregations exposed by the view. - llm_options: Options to use for the LLM client. - - Returns: - Generated aggregation query. - """ - prompt_format = AggregationPromptFormat( - question=question, - aggregations=aggregations, - ) - - formatted_prompt = self._prompt_template.format_prompt(prompt_format) - - response = await self._llm.generate_text( - prompt=formatted_prompt, - event_tracker=event_tracker, - options=llm_options, - ) - # TODO: Move response parsing to llm generate_text method - agg = formatted_prompt.response_parser(response) - # TODO: Move IQL query parsing to prompt response parser - return await IQLQuery.parse( - source=agg, - allowed_functions=aggregations or [], - event_tracker=event_tracker, - ) - - -AGGREGATION_GENERATION_TEMPLATE = PromptTemplate[AggregationPromptFormat]( - [ - { - "role": "system", - "content": "You have access to an API that lets you query a database supporting a SINGLE aggregation.\n" - "When prompted for an aggregation, use the following methods: \n" - "{aggregations}" - "DO NOT INCLUDE arguments names in your response. Only the values.\n" - "You MUST use only these methods:\n" - "\n{aggregations}\n" - "It is VERY IMPORTANT not to use methods other than those listed above." - """If you DON'T KNOW HOW TO ANSWER DON'T SAY anything other than `UNSUPPORTED QUERY`""" - "This is CRUCIAL to put `UNSUPPORTED QUERY` text only, otherwise the system will crash. " - "Structure output to resemble the following pattern:\n" - 'aggregation1("arg1", arg2)\n', - }, - {"role": "user", "content": "{question}"}, - ], - response_parser=_validate_agg_response, -) diff --git a/src/dbally/views/structured.py b/src/dbally/views/structured.py index 019bafea..2e5cff85 100644 --- a/src/dbally/views/structured.py +++ b/src/dbally/views/structured.py @@ -1,6 +1,6 @@ import abc from collections import defaultdict -from typing import Any, Dict, List, Optional, TypeVar +from typing import Dict, List, Optional from dbally.audit.event_tracker import EventTracker from dbally.collection.results import ViewExecutionResult @@ -14,8 +14,6 @@ from ..similarity import AbstractSimilarityIndex from .base import BaseView, IndexLocation -DataT = TypeVar("DataT", bound=Any) - class BaseStructuredView(BaseView): """ diff --git a/tests/unit/views/test_pandas_base.py b/tests/unit/views/test_pandas_base.py index 46e89750..029fe30f 100644 --- a/tests/unit/views/test_pandas_base.py +++ b/tests/unit/views/test_pandas_base.py @@ -41,19 +41,19 @@ class MockDataFrameView(DataFrameBaseView): @view_filter() def filter_city(self, city: str) -> pd.Series: - return self.data["city"] == city + return self.df["city"] == city @view_filter() def filter_year(self, year: int) -> pd.Series: - return self.data["year"] == year + return self.df["year"] == year @view_filter() def filter_age(self, age: int) -> pd.Series: - return self.data["age"] == age + return self.df["age"] == age @view_filter() def filter_name(self, name: str) -> pd.Series: - return self.data["name"] == name + return self.df["name"] == name @view_aggregation() def mean_age_by_city(self) -> AggregationGroup: