Skip to content

Commit

Permalink
Extended unit tests to cover aggregation handling.
Browse files Browse the repository at this point in the history
  • Loading branch information
patryk-wyzgowski committed Jul 12, 2024
1 parent 44d9b6d commit 67b34fd
Show file tree
Hide file tree
Showing 6 changed files with 15 additions and 9 deletions.
2 changes: 1 addition & 1 deletion src/dbally/iql_generator/iql_generator.py
Original file line number Diff line number Diff line change
Expand Up @@ -87,7 +87,7 @@ async def generate_iql(
# TODO: Move IQL query parsing to prompt response parser
return await IQLQuery.parse(
source=iql,
allowed_functions=filters + aggregations,
allowed_functions=filters or [] + aggregations or [],
event_tracker=event_tracker,
)
except IQLError as exc:
Expand Down
2 changes: 1 addition & 1 deletion src/dbally/views/sqlalchemy_base.py
Original file line number Diff line number Diff line change
Expand Up @@ -35,7 +35,7 @@ def get_subquery(self) -> sqlalchemy.Subquery:
Returns:
The sqlalchemy.Subquery object based on private _select attribute.
"""
return self._select.subquery("aggregation")
return self._select.subquery("subquery")

async def apply_filters(self, filters: IQLQuery) -> None:
"""
Expand Down
2 changes: 1 addition & 1 deletion tests/integration/test_llm_options.py
Original file line number Diff line number Diff line change
Expand Up @@ -30,7 +30,7 @@ async def test_llm_options_propagation():
llm_options=custom_options,
)

assert llm.client.call.call_count == 3
assert llm.client.call.call_count == 4

llm.client.call.assert_has_calls(
[
Expand Down
12 changes: 6 additions & 6 deletions tests/unit/test_iql_format.py
Original file line number Diff line number Diff line change
Expand Up @@ -14,14 +14,14 @@ async def test_iql_prompt_format_default() -> None:
{
"role": "system",
"content": "You have access to API that lets you query a database:\n"
"\n\n"
"Please suggest which one(s) to call and how they should be joined with logic operators (AND, OR, NOT).\n"
"\n[]\n"
"Suggest which one(s) to call and how they should be joined with logic operators (AND, OR, NOT).\n"
"Remember! Don't give any comments, just the function calls.\n"
"The output will look like this:\n"
'filter1("arg1") AND (NOT filter2(120) OR filter3(True))\n'
"DO NOT INCLUDE arguments names in your response. Only the values.\n"
"You MUST use only these methods:\n"
"\n\n"
"\n[]\n"
"It is VERY IMPORTANT not to use methods other than those listed above."
"""If you DON'T KNOW HOW TO ANSWER DON'T SAY \"\", SAY: `UNSUPPORTED QUERY` INSTEAD! """
"This is CRUCIAL, otherwise the system will crash. ",
Expand All @@ -44,14 +44,14 @@ async def test_iql_prompt_format_few_shots_injected() -> None:
{
"role": "system",
"content": "You have access to API that lets you query a database:\n"
"\n\n"
"Please suggest which one(s) to call and how they should be joined with logic operators (AND, OR, NOT).\n"
"\n[]\n"
"Suggest which one(s) to call and how they should be joined with logic operators (AND, OR, NOT).\n"
"Remember! Don't give any comments, just the function calls.\n"
"The output will look like this:\n"
'filter1("arg1") AND (NOT filter2(120) OR filter3(True))\n'
"DO NOT INCLUDE arguments names in your response. Only the values.\n"
"You MUST use only these methods:\n"
"\n\n"
"\n[]\n"
"It is VERY IMPORTANT not to use methods other than those listed above."
"""If you DON'T KNOW HOW TO ANSWER DON'T SAY \"\", SAY: `UNSUPPORTED QUERY` INSTEAD! """
"This is CRUCIAL, otherwise the system will crash. ",
Expand Down
3 changes: 3 additions & 0 deletions tests/unit/test_iql_generator.py
Original file line number Diff line number Diff line change
Expand Up @@ -21,6 +21,9 @@ def get_select(self) -> sqlalchemy.Select:
async def apply_filters(self, filters: IQLQuery) -> None:
...

async def apply_aggregation(self, filters: IQLQuery) -> None:
...

def execute(self, dry_run: bool = False):
...

Expand Down
3 changes: 3 additions & 0 deletions tests/unit/views/test_methods_base.py
Original file line number Diff line number Diff line change
Expand Up @@ -28,6 +28,9 @@ def method_bar(self, cities: List[str], year: Literal["2023", "2024"], pairs: Li
async def apply_filters(self, filters: IQLQuery) -> None:
...

async def apply_aggregation(self, filters: IQLQuery) -> None:
...

def execute(self, dry_run: bool = False) -> ViewExecutionResult:
return ViewExecutionResult(results=[], context={})

Expand Down

0 comments on commit 67b34fd

Please sign in to comment.