diff --git a/src/dbally/iql_generator/iql_generator.py b/src/dbally/iql_generator/iql_generator.py index 98bb2d6a..a9656457 100644 --- a/src/dbally/iql_generator/iql_generator.py +++ b/src/dbally/iql_generator/iql_generator.py @@ -87,7 +87,7 @@ async def generate_iql( # TODO: Move IQL query parsing to prompt response parser return await IQLQuery.parse( source=iql, - allowed_functions=filters + aggregations, + allowed_functions=filters or [] + aggregations or [], event_tracker=event_tracker, ) except IQLError as exc: diff --git a/src/dbally/views/sqlalchemy_base.py b/src/dbally/views/sqlalchemy_base.py index ecb74536..e2d62984 100644 --- a/src/dbally/views/sqlalchemy_base.py +++ b/src/dbally/views/sqlalchemy_base.py @@ -35,7 +35,7 @@ def get_subquery(self) -> sqlalchemy.Subquery: Returns: The sqlalchemy.Subquery object based on private _select attribute. """ - return self._select.subquery("aggregation") + return self._select.subquery("subquery") async def apply_filters(self, filters: IQLQuery) -> None: """ diff --git a/tests/integration/test_llm_options.py b/tests/integration/test_llm_options.py index fb8cfba4..411ffe14 100644 --- a/tests/integration/test_llm_options.py +++ b/tests/integration/test_llm_options.py @@ -30,7 +30,7 @@ async def test_llm_options_propagation(): llm_options=custom_options, ) - assert llm.client.call.call_count == 3 + assert llm.client.call.call_count == 4 llm.client.call.assert_has_calls( [ diff --git a/tests/unit/test_iql_format.py b/tests/unit/test_iql_format.py index 8f583c4c..b0bae183 100644 --- a/tests/unit/test_iql_format.py +++ b/tests/unit/test_iql_format.py @@ -14,14 +14,14 @@ async def test_iql_prompt_format_default() -> None: { "role": "system", "content": "You have access to API that lets you query a database:\n" - "\n\n" - "Please suggest which one(s) to call and how they should be joined with logic operators (AND, OR, NOT).\n" + "\n[]\n" + "Suggest which one(s) to call and how they should be joined with logic operators (AND, OR, NOT).\n" "Remember! Don't give any comments, just the function calls.\n" "The output will look like this:\n" 'filter1("arg1") AND (NOT filter2(120) OR filter3(True))\n' "DO NOT INCLUDE arguments names in your response. Only the values.\n" "You MUST use only these methods:\n" - "\n\n" + "\n[]\n" "It is VERY IMPORTANT not to use methods other than those listed above." """If you DON'T KNOW HOW TO ANSWER DON'T SAY \"\", SAY: `UNSUPPORTED QUERY` INSTEAD! """ "This is CRUCIAL, otherwise the system will crash. ", @@ -44,14 +44,14 @@ async def test_iql_prompt_format_few_shots_injected() -> None: { "role": "system", "content": "You have access to API that lets you query a database:\n" - "\n\n" - "Please suggest which one(s) to call and how they should be joined with logic operators (AND, OR, NOT).\n" + "\n[]\n" + "Suggest which one(s) to call and how they should be joined with logic operators (AND, OR, NOT).\n" "Remember! Don't give any comments, just the function calls.\n" "The output will look like this:\n" 'filter1("arg1") AND (NOT filter2(120) OR filter3(True))\n' "DO NOT INCLUDE arguments names in your response. Only the values.\n" "You MUST use only these methods:\n" - "\n\n" + "\n[]\n" "It is VERY IMPORTANT not to use methods other than those listed above." """If you DON'T KNOW HOW TO ANSWER DON'T SAY \"\", SAY: `UNSUPPORTED QUERY` INSTEAD! """ "This is CRUCIAL, otherwise the system will crash. ", diff --git a/tests/unit/test_iql_generator.py b/tests/unit/test_iql_generator.py index ce3f593d..4a79c394 100644 --- a/tests/unit/test_iql_generator.py +++ b/tests/unit/test_iql_generator.py @@ -21,6 +21,9 @@ def get_select(self) -> sqlalchemy.Select: async def apply_filters(self, filters: IQLQuery) -> None: ... + async def apply_aggregation(self, filters: IQLQuery) -> None: + ... + def execute(self, dry_run: bool = False): ... diff --git a/tests/unit/views/test_methods_base.py b/tests/unit/views/test_methods_base.py index 58959a64..841af1a7 100644 --- a/tests/unit/views/test_methods_base.py +++ b/tests/unit/views/test_methods_base.py @@ -28,6 +28,9 @@ def method_bar(self, cities: List[str], year: Literal["2023", "2024"], pairs: Li async def apply_filters(self, filters: IQLQuery) -> None: ... + async def apply_aggregation(self, filters: IQLQuery) -> None: + ... + def execute(self, dry_run: bool = False) -> ViewExecutionResult: return ViewExecutionResult(results=[], context={})