From ff55482fe5d274174867a090788ec5bcf4f2b552 Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?Micha=C5=82=20Pstr=C4=85g?= Date: Tue, 20 Aug 2024 10:45:23 +0200 Subject: [PATCH] add aggregations to IQLGenerator --- src/dbally/iql_generator/iql_generator.py | 228 ++++++++++++++---- .../iql_generator/iql_prompt_template.py | 0 src/dbally/iql_generator/prompt.py | 104 +++++--- src/dbally/prompt/template.py | 4 +- src/dbally/views/structured.py | 94 ++------ tests/unit/mocks.py | 22 +- tests/unit/test_collection.py | 22 +- tests/unit/test_iql_format.py | 16 +- tests/unit/test_iql_generator.py | 132 +++++----- 9 files changed, 377 insertions(+), 245 deletions(-) delete mode 100644 src/dbally/iql_generator/iql_prompt_template.py diff --git a/src/dbally/iql_generator/iql_generator.py b/src/dbally/iql_generator/iql_generator.py index 27347734..12bd5e6d 100644 --- a/src/dbally/iql_generator/iql_generator.py +++ b/src/dbally/iql_generator/iql_generator.py @@ -1,71 +1,174 @@ +from dataclasses import dataclass from typing import List, Optional from dbally.audit.event_tracker import EventTracker from dbally.iql import IQLError, IQLQuery from dbally.iql_generator.prompt import ( + AGGREGATION_DECISION_TEMPLATE, + AGGREGATION_GENERATION_TEMPLATE, FILTERING_DECISION_TEMPLATE, - IQL_GENERATION_TEMPLATE, - FilteringDecisionPromptFormat, + FILTERS_GENERATION_TEMPLATE, + DecisionPromptFormat, IQLGenerationPromptFormat, + UnsupportedQueryError, ) from dbally.llms.base import LLM from dbally.llms.clients.base import LLMOptions from dbally.llms.clients.exceptions import LLMError from dbally.prompt.elements import FewShotExample from dbally.prompt.template import PromptTemplate +from dbally.views.exceptions import IQLGenerationError from dbally.views.exposed_functions import ExposedFunction -ERROR_MESSAGE = "Unfortunately, generated IQL is not valid. Please try again, \ - generation of correct IQL is very important. Below you have errors generated by the system:\n{error}" - -class IQLGenerator: +@dataclass +class IQLGeneratorState: + """ + State of the IQL generator. """ - Class used to generate IQL from natural language question. - In db-ally, LLM uses IQL (Intermediate Query Language) to express complex queries in a simplified way. - The class used to generate IQL from natural language query is `IQLGenerator`. + filters: Optional[IQLQuery] = None + aggregation: Optional[IQLQuery] = None - IQL generation is done using the method `self.generate_iql`. - It uses LLM to generate text-based responses, passing in the prompt template, formatted filters, and user question. + +class IQLGenerator: + """ + Program that orchestrates all IQL operations for the given question. """ def __init__( self, - llm: LLM, - *, - decision_prompt: Optional[PromptTemplate[FilteringDecisionPromptFormat]] = None, - generation_prompt: Optional[PromptTemplate[IQLGenerationPromptFormat]] = None, + filters_generation: Optional["IQLOperationGenerator"] = None, + aggregation_generation: Optional["IQLOperationGenerator"] = None, ) -> None: """ Constructs a new IQLGenerator instance. Args: - llm: LLM used to generate IQL. decision_prompt: Prompt template for filtering decision making. generation_prompt: Prompt template for IQL generation. """ - self._llm = llm - self._decision_prompt = decision_prompt or FILTERING_DECISION_TEMPLATE - self._generation_prompt = generation_prompt or IQL_GENERATION_TEMPLATE + self._filters_generation = filters_generation or IQLOperationGenerator( + FILTERING_DECISION_TEMPLATE, + FILTERS_GENERATION_TEMPLATE, + ) + self._aggregation_generation = aggregation_generation or IQLOperationGenerator( + AGGREGATION_DECISION_TEMPLATE, + AGGREGATION_GENERATION_TEMPLATE, + ) - async def generate( + # pylint: disable=too-many-arguments + async def __call__( self, + *, question: str, filters: List[ExposedFunction], - event_tracker: EventTracker, - examples: Optional[List[FewShotExample]] = None, + aggregations: List[ExposedFunction], + examples: List[FewShotExample], + llm: LLM, + event_tracker: Optional[EventTracker] = None, llm_options: Optional[LLMOptions] = None, n_retries: int = 3, - ) -> Optional[IQLQuery]: + ) -> IQLGeneratorState: """ - Generates IQL in text form using LLM. + Generates IQL operations for the given question. Args: question: User question. filters: List of filters exposed by the view. + aggregations: List of aggregations exposed by the view. + examples: List of examples to be injected during filters and aggregation generation. + llm: LLM used to generate IQL. event_tracker: Event store used to audit the generation process. + llm_options: Options to use for the LLM client. + n_retries: Number of retries to regenerate IQL in case of errors in parsing or LLM connection. + + Returns: + Generated IQL operations. + + Raises: + IQLGenerationError: If IQL generation fails. + """ + try: + filters = await self._filters_generation( + question=question, + methods=filters, + examples=examples, + llm=llm, + llm_options=llm_options, + event_tracker=event_tracker, + n_retries=n_retries, + ) + except (IQLError, UnsupportedQueryError) as exc: + raise IQLGenerationError( + view_name=self.__class__.__name__, + filters=exc.source if isinstance(exc, IQLError) else None, + aggregation=None, + ) from exc + + try: + aggregation = await self._aggregation_generation( + question=question, + methods=aggregations, + examples=examples, + llm=llm, + llm_options=llm_options, + event_tracker=event_tracker, + n_retries=n_retries, + ) + except (IQLError, UnsupportedQueryError) as exc: + raise IQLGenerationError( + view_name=self.__class__.__name__, + filters=str(filters) if filters else None, + aggregation=exc.source if isinstance(exc, IQLError) else None, + ) from exc + + return IQLGeneratorState( + filters=filters, + aggregation=aggregation, + ) + + +class IQLOperationGenerator: + """ + Program that generates IQL queries for the given question. + """ + + def __init__( + self, + assessor_prompt: PromptTemplate[DecisionPromptFormat], + generator_prompt: PromptTemplate[IQLGenerationPromptFormat], + ) -> None: + """ + Constructs a new IQLGenerator instance. + + Args: + assessor_prompt: Prompt template for filtering decision making. + generator_prompt: Prompt template for IQL generation. + """ + self.assessor = IQLQuestionAssessor(assessor_prompt) + self.generator = IQLQueryGenerator(generator_prompt) + + async def __call__( + self, + *, + question: str, + methods: List[ExposedFunction], + examples: List[FewShotExample], + llm: LLM, + event_tracker: Optional[EventTracker] = None, + llm_options: Optional[LLMOptions] = None, + n_retries: int = 3, + ) -> Optional[IQLQuery]: + """ + Generates IQL query for the given question. + + Args: + llm: LLM used to generate IQL. + question: User question. + methods: List of methods exposed by the view. examples: List of examples to be injected into the conversation. + event_tracker: Event store used to audit the generation process. llm_options: Options to use for the LLM client. n_retries: Number of retries to regenerate IQL in case of errors in parsing or LLM connection. @@ -77,38 +180,52 @@ async def generate( IQLError: If IQL parsing fails after all retries. UnsupportedQueryError: If the question is not supported by the view. """ - decision = await self._decide_on_generation( + decision = await self.assessor( question=question, - event_tracker=event_tracker, + llm=llm, llm_options=llm_options, + event_tracker=event_tracker, n_retries=n_retries, ) if not decision: return None - return await self._generate_iql( + return await self.generator( question=question, - filters=filters, - event_tracker=event_tracker, + methods=methods, examples=examples, + llm=llm, llm_options=llm_options, + event_tracker=event_tracker, n_retries=n_retries, ) - async def _decide_on_generation( + +class IQLQuestionAssessor: + """ + Program that assesses whether a question requires applying IQL operation or not. + """ + + def __init__(self, prompt: PromptTemplate[DecisionPromptFormat]) -> None: + self.prompt = prompt + + async def __call__( self, + *, question: str, - event_tracker: EventTracker, + llm: LLM, llm_options: Optional[LLMOptions] = None, + event_tracker: Optional[EventTracker] = None, n_retries: int = 3, ) -> bool: """ - Decides whether the question requires filtering or not. + Decides whether the question requires generating IQL or not. Args: question: User question. - event_tracker: Event store used to audit the generation process. + llm: LLM used to generate IQL. llm_options: Options to use for the LLM client. + event_tracker: Event store used to audit the generation process. n_retries: Number of retries to LLM API in case of errors. Returns: @@ -117,12 +234,14 @@ async def _decide_on_generation( Raises: LLMError: If LLM text generation fails after all retries. """ - prompt_format = FilteringDecisionPromptFormat(question=question) - formatted_prompt = self._decision_prompt.format_prompt(prompt_format) + prompt_format = DecisionPromptFormat( + question=question, + ) + formatted_prompt = self.prompt.format_prompt(prompt_format) for retry in range(n_retries + 1): try: - response = await self._llm.generate_text( + response = await llm.generate_text( prompt=formatted_prompt, event_tracker=event_tracker, options=llm_options, @@ -133,24 +252,39 @@ async def _decide_on_generation( if retry == n_retries: raise exc - async def _generate_iql( + +class IQLQueryGenerator: + """ + Program that generates IQL queries for the given question. + """ + + ERROR_MESSAGE = "Unfortunately, generated IQL is not valid. Please try again, \ + generation of correct IQL is very important. Below you have errors generated by the system:\n{error}" + + def __init__(self, prompt: PromptTemplate[IQLGenerationPromptFormat]) -> None: + self.prompt = prompt + + async def __call__( self, + *, question: str, - filters: List[ExposedFunction], - event_tracker: Optional[EventTracker] = None, - examples: Optional[List[FewShotExample]] = None, + methods: List[ExposedFunction], + examples: List[FewShotExample], + llm: LLM, llm_options: Optional[LLMOptions] = None, + event_tracker: Optional[EventTracker] = None, n_retries: int = 3, ) -> IQLQuery: """ - Generates IQL in text form using LLM. + Generates IQL query for the given question. Args: question: User question. filters: List of filters exposed by the view. - event_tracker: Event store used to audit the generation process. examples: List of examples to be injected into the conversation. + llm: LLM used to generate IQL. llm_options: Options to use for the LLM client. + event_tracker: Event store used to audit the generation process. n_retries: Number of retries to regenerate IQL in case of errors in parsing or LLM connection. Returns: @@ -163,14 +297,14 @@ async def _generate_iql( """ prompt_format = IQLGenerationPromptFormat( question=question, - filters=filters, + methods=methods, examples=examples, ) - formatted_prompt = self._generation_prompt.format_prompt(prompt_format) + formatted_prompt = self.prompt.format_prompt(prompt_format) for retry in range(n_retries + 1): try: - response = await self._llm.generate_text( + response = await llm.generate_text( prompt=formatted_prompt, event_tracker=event_tracker, options=llm_options, @@ -180,7 +314,7 @@ async def _generate_iql( # TODO: Move IQL query parsing to prompt response parser return await IQLQuery.parse( source=iql, - allowed_functions=filters, + allowed_functions=methods, event_tracker=event_tracker, ) except LLMError as exc: @@ -190,4 +324,4 @@ async def _generate_iql( if retry == n_retries: raise exc formatted_prompt = formatted_prompt.add_assistant_message(response) - formatted_prompt = formatted_prompt.add_user_message(ERROR_MESSAGE.format(error=exc)) + formatted_prompt = formatted_prompt.add_user_message(self.ERROR_MESSAGE.format(error=exc)) diff --git a/src/dbally/iql_generator/iql_prompt_template.py b/src/dbally/iql_generator/iql_prompt_template.py deleted file mode 100644 index e69de29b..00000000 diff --git a/src/dbally/iql_generator/prompt.py b/src/dbally/iql_generator/prompt.py index 4e5a45ec..395ce50f 100644 --- a/src/dbally/iql_generator/prompt.py +++ b/src/dbally/iql_generator/prompt.py @@ -1,6 +1,6 @@ # pylint: disable=C0301 -from typing import List +from typing import List, Optional from dbally.exceptions import DbAllyError from dbally.prompt.elements import FewShotExample @@ -52,7 +52,7 @@ def _decision_iql_response_parser(response: str) -> bool: return "true" in decision -class FilteringDecisionPromptFormat(PromptFormat): +class DecisionPromptFormat(PromptFormat): """ IQL prompt format, providing a question and filters to be used in the conversation. """ @@ -71,44 +71,96 @@ def __init__(self, *, question: str, examples: List[FewShotExample] = None) -> N class IQLGenerationPromptFormat(PromptFormat): """ - IQL prompt format, providing a question and filters to be used in the conversation. + IQL prompt format, providing a question and methods to be used in the conversation. """ def __init__( self, *, question: str, - filters: List[ExposedFunction], - examples: List[FewShotExample] = None, + methods: List[ExposedFunction], + examples: Optional[List[FewShotExample]] = None, ) -> None: """ Constructs a new IQLGenerationPromptFormat instance. Args: question: Question to be asked. - filters: List of filters exposed by the view. + methods: List of filters exposed by the view. examples: List of examples to be injected into the conversation. aggregations: List of aggregations exposed by the view. """ super().__init__(examples) self.question = question - self.filters = "\n".join([str(condition) for condition in filters]) if filters else [] + self.methods = "\n".join([str(condition) for condition in methods]) if methods else [] + +FILTERING_DECISION_TEMPLATE = PromptTemplate[DecisionPromptFormat]( + [ + { + "role": "system", + "content": ( + "Given a question, determine whether the answer requires initial data filtering in order to compute it.\n" + "Initial data filtering is a process in which the result set is reduced to only include the rows " + "that meet certain criteria specified in the question.\n\n" + "---\n\n" + "Follow the following format.\n\n" + "Question: ${{question}}\n" + "Hint: ${{hint}}" + "Reasoning: Let's think step by step in order to ${{produce the decision}}. We...\n" + "Decision: indicates whether the answer to the question requires initial data filtering. " + "(Respond with True or False)\n\n" + ), + }, + { + "role": "user", + "content": ( + "Question: {question}\n" + "Hint: Look for words indicating data specific features.\n" + "Reasoning: Let's think step by step in order to " + ), + }, + ], + response_parser=_decision_iql_response_parser, +) -IQL_GENERATION_TEMPLATE = PromptTemplate[IQLGenerationPromptFormat]( +AGGREGATION_DECISION_TEMPLATE = PromptTemplate[DecisionPromptFormat]( + [ + { + "role": "system", + "content": ( + "Given a question, determine whether the answer requires computing the aggregation in order to compute it.\n" + "Aggregation is a process in which the result set is reduced to a single value.\n\n" + "---\n\n" + "Follow the following format.\n\n" + "Question: ${{question}}\n" + "Reasoning: Let's think step by step in order to ${{produce the decision}}. We...\n" + "Decision: indicates whether the answer to the question requires initial data filtering. " + "(Respond with True or False)\n\n" + ), + }, + { + "role": "user", + "content": ("Question: {question}\n" "Reasoning: Let's think step by step in order to "), + }, + ], + response_parser=_decision_iql_response_parser, +) + +FILTERS_GENERATION_TEMPLATE = PromptTemplate[IQLGenerationPromptFormat]( [ { "role": "system", "content": ( "You have access to an API that lets you query a database:\n" - "\n{filters}\n" + "\n{methods}\n" "Suggest which one(s) to call and how they should be joined with logic operators (AND, OR, NOT).\n" "Remember! Don't give any comments, just the function calls.\n" "The output will look like this:\n" 'filter1("arg1") AND (NOT filter2(120) OR filter3(True))\n' "DO NOT INCLUDE arguments names in your response. Only the values.\n" "You MUST use only these methods:\n" - "\n{filters}\n" + "\n{methods}\n" "It is VERY IMPORTANT not to use methods other than those listed above." """If you DON'T KNOW HOW TO ANSWER DON'T SAY anything other than `UNSUPPORTED QUERY`""" "This is CRUCIAL, otherwise the system will crash. " @@ -122,32 +174,28 @@ def __init__( response_parser=_validate_iql_response, ) - -FILTERING_DECISION_TEMPLATE = PromptTemplate[FilteringDecisionPromptFormat]( +AGGREGATION_GENERATION_TEMPLATE = PromptTemplate[IQLGenerationPromptFormat]( [ { "role": "system", "content": ( - "Given a question, determine whether the answer requires initial data filtering in order to compute it.\n" - "Initial data filtering is a process in which the result set is reduced to only include the rows " - "that meet certain criteria specified in the question.\n\n" - "---\n\n" - "Follow the following format.\n\n" - "Question: ${{question}}\n" - "Hint: ${{hint}}" - "Reasoning: Let's think step by step in order to ${{produce the decision}}. We...\n" - "Decision: indicates whether the answer to the question requires initial data filtering. " - "(Respond with True or False)\n\n" + "You have access to an API that lets you query a database supporting a SINGLE aggregation.\n" + "When prompted for an aggregation, use the following methods: \n" + "{methods}" + "DO NOT INCLUDE arguments names in your response. Only the values.\n" + "You MUST use only these methods:\n" + "\n{methods}\n" + "It is VERY IMPORTANT not to use methods other than those listed above." + """If you DON'T KNOW HOW TO ANSWER DON'T SAY anything other than `UNSUPPORTED QUERY`""" + "This is CRUCIAL to put `UNSUPPORTED QUERY` text only, otherwise the system will crash. " + "Structure output to resemble the following pattern:\n" + 'aggregation1("arg1", arg2)\n' ), }, { "role": "user", - "content": ( - "Question: {question}\n" - "Hint: Look for words indicating data specific features.\n" - "Reasoning: Let's think step by step in order to " - ), + "content": "{question}", }, ], - response_parser=_decision_iql_response_parser, + response_parser=_validate_iql_response, ) diff --git a/src/dbally/prompt/template.py b/src/dbally/prompt/template.py index 124a3e1c..b4ef650d 100644 --- a/src/dbally/prompt/template.py +++ b/src/dbally/prompt/template.py @@ -1,6 +1,6 @@ import copy import re -from typing import Callable, Dict, Generic, List, TypeVar +from typing import Callable, Dict, Generic, List, Optional, TypeVar from typing_extensions import Self @@ -55,7 +55,7 @@ class PromptFormat: Generic format for prompts allowing to inject few shot examples into the conversation. """ - def __init__(self, examples: List[FewShotExample] = None) -> None: + def __init__(self, examples: Optional[List[FewShotExample]] = None) -> None: """ Constructs a new PromptFormat instance. diff --git a/src/dbally/views/structured.py b/src/dbally/views/structured.py index c3ac91e0..d6444826 100644 --- a/src/dbally/views/structured.py +++ b/src/dbally/views/structured.py @@ -4,17 +4,12 @@ from dbally.audit.event_tracker import EventTracker from dbally.collection.results import ViewExecutionResult -from dbally.exceptions import UnsupportedAggregationError from dbally.iql import IQLQuery -from dbally.iql._exceptions import IQLError from dbally.iql_generator.iql_generator import IQLGenerator -from dbally.iql_generator.prompt import UnsupportedQueryError from dbally.llms.base import LLM from dbally.llms.clients.base import LLMOptions -from dbally.views.exceptions import IQLGenerationError from dbally.views.exposed_functions import ExposedFunction -from ..prompt.aggregation import AggregationFormatter from ..similarity import AbstractSimilarityIndex from .base import BaseView, IndexLocation @@ -32,29 +27,14 @@ def __init__(self, data: DataT) -> None: super().__init__() self.data = data - def get_iql_generator(self, llm: LLM) -> IQLGenerator: + def get_iql_generator(self) -> IQLGenerator: """ Returns the IQL generator for the view. - Args: - llm: LLM used to generate the IQL queries. - Returns: IQL generator for the view. """ - return IQLGenerator(llm=llm) - - def get_agg_formatter(self, llm: LLM) -> AggregationFormatter: - """ - Returns the AggregtionFormatter for the view. - - Args: - llm: LLM used to generate the queries. - - Returns: - AggregtionFormatter for the view. - """ - return AggregationFormatter(llm=llm) + return IQLGenerator() async def ask( self, @@ -84,65 +64,33 @@ async def ask( LLMError: If LLM text generation API fails. IQLGenerationError: If the IQL generation fails. """ - iql_generator = self.get_iql_generator(llm) - agg_formatter = self.get_agg_formatter(llm) filters = self.list_filters() examples = self.list_few_shots() aggregations = self.list_aggregations() - try: - iql = await iql_generator.generate( - question=query, - filters=filters, - examples=examples, - event_tracker=event_tracker, - llm_options=llm_options, - n_retries=n_retries, - ) - except UnsupportedQueryError as exc: - raise IQLGenerationError( - view_name=self.__class__.__name__, - filters=None, - aggregation=None, - ) from exc - except IQLError as exc: - raise IQLGenerationError( - view_name=self.__class__.__name__, - filters=exc.source, - aggregation=None, - ) from exc - - if iql: - await self.apply_filters(iql) - - try: - agg_node = await agg_formatter.format_to_query_object( - question=query, - aggregations=aggregations, - event_tracker=event_tracker, - llm_options=llm_options, - ) - except UnsupportedAggregationError as exc: - raise IQLGenerationError( - view_name=self.__class__.__name__, - filters=str(iql) if iql else None, - aggregation=None, - ) from exc - except IQLError as exc: - raise IQLGenerationError( - view_name=self.__class__.__name__, - filters=str(iql) if iql else None, - aggregation=exc.source, - ) from exc - - await self.apply_aggregation(agg_node) + iql_generator = self.get_iql_generator() + iql = await iql_generator( + question=query, + filters=filters, + aggregations=aggregations, + examples=examples, + llm=llm, + event_tracker=event_tracker, + llm_options=llm_options, + n_retries=n_retries, + ) + + if iql.filters: + await self.apply_filters(iql.filters) + + if iql.aggregation: + await self.apply_aggregation(iql.aggregation) result = self.execute(dry_run=dry_run) result.context["iql"] = { - "filters": str(iql) if iql else None, - "aggregation": str(agg_node), + "filters": str(iql.filters) if iql.filters else None, + "aggregation": str(iql.aggregation) if iql.aggregation else None, } - return result @abc.abstractmethod diff --git a/tests/unit/mocks.py b/tests/unit/mocks.py index 992fd03d..cc79d76c 100644 --- a/tests/unit/mocks.py +++ b/tests/unit/mocks.py @@ -10,10 +10,9 @@ from dbally import NOT_GIVEN, NotGiven from dbally.iql import IQLQuery -from dbally.iql_generator.iql_generator import IQLGenerator +from dbally.iql_generator.iql_generator import IQLGenerator, IQLGeneratorState from dbally.llms.base import LLM from dbally.llms.clients.base import LLMClient, LLMOptions -from dbally.prompt.aggregation import AggregationFormatter from dbally.similarity.index import AbstractSimilarityIndex from dbally.view_selection.base import ViewSelector from dbally.views.structured import BaseStructuredView, ExposedFunction, ViewExecutionResult @@ -44,21 +43,12 @@ def execute(self, dry_run: bool = False) -> ViewExecutionResult: class MockIQLGenerator(IQLGenerator): - def __init__(self, iql: IQLQuery) -> None: - self.iql = iql - super().__init__(llm=MockLLM()) + def __init__(self, state: IQLGeneratorState) -> None: + self.state = state + super().__init__() - async def generate(self, *_, **__) -> IQLQuery: - return self.iql - - -class MockAggregationFormatter(AggregationFormatter): - def __init__(self, iql_query: IQLQuery) -> None: - self.iql_query = iql_query - super().__init__(llm=MockLLM()) - - async def format_to_query_object(self, *_, **__) -> IQLQuery: - return self.iql_query + async def __call__(self, *_, **__) -> IQLQuery: + return self.state class MockViewSelector(ViewSelector): diff --git a/tests/unit/test_collection.py b/tests/unit/test_collection.py index a077286d..50aee918 100644 --- a/tests/unit/test_collection.py +++ b/tests/unit/test_collection.py @@ -12,15 +12,9 @@ from dbally.collection.results import ViewExecutionResult from dbally.iql import IQLQuery from dbally.iql.syntax import FunctionCall +from dbally.iql_generator.iql_generator import IQLGeneratorState from dbally.views.exposed_functions import ExposedFunction, MethodParamWithTyping -from tests.unit.mocks import ( - MockAggregationFormatter, - MockIQLGenerator, - MockLLM, - MockSimilarityIndex, - MockViewBase, - MockViewSelector, -) +from tests.unit.mocks import MockIQLGenerator, MockLLM, MockSimilarityIndex, MockViewBase, MockViewSelector class MockView1(MockViewBase): @@ -66,15 +60,17 @@ def execute(self, dry_run=False) -> ViewExecutionResult: def list_filters(self) -> List[ExposedFunction]: return [ExposedFunction("test_filter", "", [])] - def get_iql_generator(self, *_, **__) -> MockIQLGenerator: - return MockIQLGenerator(IQLQuery(FunctionCall("test_filter", []), "test_filter()")) + def get_iql_generator(self) -> MockIQLGenerator: + return MockIQLGenerator( + IQLGeneratorState( + filters=IQLQuery(FunctionCall("test_filter", []), "test_filter()"), + aggregation=IQLQuery(FunctionCall("test_aggregation", []), "test_aggregation()"), + ), + ) def list_aggregations(self) -> List[ExposedFunction]: return [ExposedFunction("test_aggregation", "", [])] - def get_agg_formatter(self, *_, **__) -> MockAggregationFormatter: - return MockAggregationFormatter(IQLQuery(FunctionCall("test_aggregation", []), "test_aggregation()")) - @pytest.fixture(name="similarity_classes") def mock_similarity_classes() -> ( diff --git a/tests/unit/test_iql_format.py b/tests/unit/test_iql_format.py index b798e533..a2bf23c4 100644 --- a/tests/unit/test_iql_format.py +++ b/tests/unit/test_iql_format.py @@ -1,14 +1,14 @@ -from dbally.iql_generator.prompt import IQL_GENERATION_TEMPLATE, IQLGenerationPromptFormat +from dbally.iql_generator.prompt import FILTERS_GENERATION_TEMPLATE, IQLGenerationPromptFormat from dbally.prompt.elements import FewShotExample async def test_iql_prompt_format_default() -> None: prompt_format = IQLGenerationPromptFormat( question="", - filters=[], + methods=[], examples=[], ) - formatted_prompt = IQL_GENERATION_TEMPLATE.format_prompt(prompt_format) + formatted_prompt = FILTERS_GENERATION_TEMPLATE.format_prompt(prompt_format) assert formatted_prompt.chat == [ { @@ -35,10 +35,10 @@ async def test_iql_prompt_format_few_shots_injected() -> None: examples = [FewShotExample("q1", "a1")] prompt_format = IQLGenerationPromptFormat( question="", - filters=[], + methods=[], examples=examples, ) - formatted_prompt = IQL_GENERATION_TEMPLATE.format_prompt(prompt_format) + formatted_prompt = FILTERS_GENERATION_TEMPLATE.format_prompt(prompt_format) assert formatted_prompt.chat == [ { @@ -67,12 +67,12 @@ async def test_iql_input_format_few_shot_examples_repeat_no_example_duplicates() examples = [FewShotExample("q1", "a1")] prompt_format = IQLGenerationPromptFormat( question="", - filters=[], + methods=[], examples=examples, ) - formatted_prompt = IQL_GENERATION_TEMPLATE.format_prompt(prompt_format) + formatted_prompt = FILTERS_GENERATION_TEMPLATE.format_prompt(prompt_format) - assert len(formatted_prompt.chat) == len(IQL_GENERATION_TEMPLATE.chat) + (len(examples) * 2) + assert len(formatted_prompt.chat) == len(FILTERS_GENERATION_TEMPLATE.chat) + (len(examples) * 2) assert formatted_prompt.chat[1]["role"] == "user" assert formatted_prompt.chat[1]["content"] == examples[0].question assert formatted_prompt.chat[2]["role"] == "assistant" diff --git a/tests/unit/test_iql_generator.py b/tests/unit/test_iql_generator.py index b95fe585..b8d81013 100644 --- a/tests/unit/test_iql_generator.py +++ b/tests/unit/test_iql_generator.py @@ -8,13 +8,8 @@ from dbally import decorators from dbally.audit.event_tracker import EventTracker from dbally.iql import IQLError, IQLQuery -from dbally.iql_generator.iql_generator import IQLGenerator -from dbally.iql_generator.prompt import ( - FILTERING_DECISION_TEMPLATE, - IQL_GENERATION_TEMPLATE, - FilteringDecisionPromptFormat, - IQLGenerationPromptFormat, -) +from dbally.iql_generator.iql_generator import IQLGenerator, IQLGeneratorState +from dbally.views.exceptions import IQLGenerationError from dbally.views.methods_base import MethodsBaseView from tests.unit.mocks import MockLLM @@ -62,71 +57,71 @@ def event_tracker() -> EventTracker: @pytest.fixture -def iql_generator(llm: MockLLM) -> IQLGenerator: - return IQLGenerator(llm) +def iql_generator() -> IQLGenerator: + return IQLGenerator() @pytest.mark.asyncio -async def test_iql_generation(iql_generator: IQLGenerator, event_tracker: EventTracker, view: MockView) -> None: +async def test_iql_generation( + iql_generator: IQLGenerator, + llm: MockLLM, + event_tracker: EventTracker, + view: MockView, +) -> None: filters = view.list_filters() - - decision_format = FilteringDecisionPromptFormat( - question="Mock_question", - ) - generation_format = IQLGenerationPromptFormat( - question="Mock_question", - filters=filters, - ) - - decision_prompt = FILTERING_DECISION_TEMPLATE.format_prompt(decision_format) - generation_prompt = IQL_GENERATION_TEMPLATE.format_prompt(generation_format) + aggregations = view.list_aggregations() + examples = view.list_few_shots() llm_responses = [ "decision: true", "filter_by_id(1)", + "decision: true", + "aggregate_by_id()", + ] + iql_parser_responses = [ + "filter_by_id(1)", + "aggregate_by_id()", ] - iql_generator._llm.generate_text = AsyncMock(side_effect=llm_responses) - with patch("dbally.iql.IQLQuery.parse", AsyncMock(return_value="filter_by_id(1)")) as mock_parse: - iql = await iql_generator.generate( + + llm.generate_text = AsyncMock(side_effect=llm_responses) + with patch("dbally.iql.IQLQuery.parse", AsyncMock(side_effect=iql_parser_responses)) as mock_parse: + iql = await iql_generator( question="Mock_question", filters=filters, + aggregations=aggregations, + examples=examples, + llm=llm, event_tracker=event_tracker, ) - assert iql == "filter_by_id(1)" - iql_generator._llm.generate_text.assert_has_calls( + assert iql == IQLGeneratorState(filters="filter_by_id(1)", aggregation="aggregate_by_id()") + assert llm.generate_text.call_count == 4 + mock_parse.assert_has_calls( [ call( - prompt=decision_prompt, + source="filter_by_id(1)", + allowed_functions=filters, event_tracker=event_tracker, - options=None, ), call( - prompt=generation_prompt, + source="aggregate_by_id()", + allowed_functions=aggregations, event_tracker=event_tracker, - options=None, ), ] ) - mock_parse.assert_called_once_with( - source="filter_by_id(1)", - allowed_functions=filters, - event_tracker=event_tracker, - ) @pytest.mark.asyncio async def test_iql_generation_error_escalation_after_max_retires( iql_generator: IQLGenerator, + llm: MockLLM, event_tracker: EventTracker, view: MockView, ) -> None: filters = view.list_filters() - responses = [ - IQLError("err1", "src1"), - IQLError("err2", "src2"), - IQLError("err3", "src3"), - IQLError("err4", "src4"), - ] + aggregations = view.list_aggregations() + examples = view.list_few_shots() + llm_responses = [ "decision: true", "filter_by_id(1)", @@ -134,53 +129,74 @@ async def test_iql_generation_error_escalation_after_max_retires( "filter_by_id(1)", "filter_by_id(1)", ] + iql_parser_responses = [ + IQLError("err1", "src1"), + IQLError("err2", "src2"), + IQLError("err3", "src3"), + IQLError("err4", "src4"), + ] - iql_generator._llm.generate_text = AsyncMock(side_effect=llm_responses) - with patch("dbally.iql.IQLQuery.parse", AsyncMock(side_effect=responses)), pytest.raises(IQLError): - iql = await iql_generator.generate( + llm.generate_text = AsyncMock(side_effect=llm_responses) + with patch("dbally.iql.IQLQuery.parse", AsyncMock(side_effect=iql_parser_responses)), pytest.raises( + IQLGenerationError + ): + iql = await iql_generator( question="Mock_question", filters=filters, + aggregations=aggregations, + examples=examples, + llm=llm, event_tracker=event_tracker, n_retries=3, ) assert iql is None - assert iql_generator._llm.generate_text.call_count == 4 - for i, arg in enumerate(iql_generator._llm.generate_text.call_args_list[1:], start=1): + assert llm.generate_text.call_count == 4 + for i, arg in enumerate(llm.generate_text.call_args_list[1:], start=1): assert f"err{i}" in arg[1]["prompt"].chat[-1]["content"] @pytest.mark.asyncio async def test_iql_generation_response_after_max_retries( iql_generator: IQLGenerator, + llm: MockLLM, event_tracker: EventTracker, view: MockView, ) -> None: filters = view.list_filters() - responses = [ - IQLError("err1", "src1"), - IQLError("err2", "src2"), - IQLError("err3", "src3"), - "filter_by_id(1)", - ] + aggregations = view.list_aggregations() + examples = view.list_few_shots() + llm_responses = [ "decision: true", "filter_by_id(1)", "filter_by_id(1)", "filter_by_id(1)", "filter_by_id(1)", + "decision: true", + "aggregate_by_id()", + ] + iql_parser_responses = [ + IQLError("err1", "src1"), + IQLError("err2", "src2"), + IQLError("err3", "src3"), + "filter_by_id(1)", + "aggregate_by_id()", ] - iql_generator._llm.generate_text = AsyncMock(side_effect=llm_responses) - with patch("dbally.iql.IQLQuery.parse", AsyncMock(side_effect=responses)): - iql = await iql_generator.generate( + llm.generate_text = AsyncMock(side_effect=llm_responses) + with patch("dbally.iql.IQLQuery.parse", AsyncMock(side_effect=iql_parser_responses)): + iql = await iql_generator( question="Mock_question", filters=filters, + aggregations=aggregations, + examples=examples, + llm=llm, event_tracker=event_tracker, n_retries=3, ) - assert iql == "filter_by_id(1)" - assert iql_generator._llm.generate_text.call_count == 5 - for i, arg in enumerate(iql_generator._llm.generate_text.call_args_list[2:], start=1): + assert iql == IQLGeneratorState(filters="filter_by_id(1)", aggregation="aggregate_by_id()") + assert llm.generate_text.call_count == 7 + for i, arg in enumerate(llm.generate_text.call_args_list[2:5], start=1): assert f"err{i}" in arg[1]["prompt"].chat[-1]["content"]