add aggregations to IQLGenerator
micpst committed Aug 20, 2024
1 parent d21f4e1 commit ff55482
Showing 9 changed files with 377 additions and 245 deletions.
228 changes: 181 additions & 47 deletions src/dbally/iql_generator/iql_generator.py
@@ -1,71 +1,174 @@
from dataclasses import dataclass
from typing import List, Optional

from dbally.audit.event_tracker import EventTracker
from dbally.iql import IQLError, IQLQuery
from dbally.iql_generator.prompt import (
AGGREGATION_DECISION_TEMPLATE,
AGGREGATION_GENERATION_TEMPLATE,
FILTERING_DECISION_TEMPLATE,
IQL_GENERATION_TEMPLATE,
FilteringDecisionPromptFormat,
FILTERS_GENERATION_TEMPLATE,
DecisionPromptFormat,
IQLGenerationPromptFormat,
UnsupportedQueryError,
)
from dbally.llms.base import LLM
from dbally.llms.clients.base import LLMOptions
from dbally.llms.clients.exceptions import LLMError
from dbally.prompt.elements import FewShotExample
from dbally.prompt.template import PromptTemplate
from dbally.views.exceptions import IQLGenerationError
from dbally.views.exposed_functions import ExposedFunction

ERROR_MESSAGE = "Unfortunately, generated IQL is not valid. Please try again, \
generation of correct IQL is very important. Below you have errors generated by the system:\n{error}"


class IQLGenerator:
@dataclass
class IQLGeneratorState:
"""
State of the IQL generator.
"""
Class used to generate IQL from natural language question.

In db-ally, LLM uses IQL (Intermediate Query Language) to express complex queries in a simplified way.
The class used to generate IQL from natural language query is `IQLGenerator`.
filters: Optional[IQLQuery] = None
aggregation: Optional[IQLQuery] = None

IQL generation is done using the method `self.generate_iql`.
It uses LLM to generate text-based responses, passing in the prompt template, formatted filters, and user question.

class IQLGenerator:
"""
Program that orchestrates all IQL operations for the given question.
"""

def __init__(
self,
llm: LLM,
*,
decision_prompt: Optional[PromptTemplate[FilteringDecisionPromptFormat]] = None,
generation_prompt: Optional[PromptTemplate[IQLGenerationPromptFormat]] = None,
filters_generation: Optional["IQLOperationGenerator"] = None,
aggregation_generation: Optional["IQLOperationGenerator"] = None,
) -> None:
"""
Constructs a new IQLGenerator instance.
Args:
llm: LLM used to generate IQL.
decision_prompt: Prompt template for filtering decision making.
generation_prompt: Prompt template for IQL generation.
"""
self._llm = llm
self._decision_prompt = decision_prompt or FILTERING_DECISION_TEMPLATE
self._generation_prompt = generation_prompt or IQL_GENERATION_TEMPLATE
self._filters_generation = filters_generation or IQLOperationGenerator(
FILTERING_DECISION_TEMPLATE,
FILTERS_GENERATION_TEMPLATE,
)
self._aggregation_generation = aggregation_generation or IQLOperationGenerator(
AGGREGATION_DECISION_TEMPLATE,
AGGREGATION_GENERATION_TEMPLATE,
)

async def generate(
# pylint: disable=too-many-arguments
async def __call__(
self,
*,
question: str,
filters: List[ExposedFunction],
event_tracker: EventTracker,
examples: Optional[List[FewShotExample]] = None,
aggregations: List[ExposedFunction],
examples: List[FewShotExample],
llm: LLM,
event_tracker: Optional[EventTracker] = None,
llm_options: Optional[LLMOptions] = None,
n_retries: int = 3,
) -> Optional[IQLQuery]:
) -> IQLGeneratorState:
"""
Generates IQL in text form using LLM.
Generates IQL operations for the given question.
Args:
question: User question.
filters: List of filters exposed by the view.
aggregations: List of aggregations exposed by the view.
examples: List of examples to be injected during filters and aggregation generation.
llm: LLM used to generate IQL.
event_tracker: Event store used to audit the generation process.
llm_options: Options to use for the LLM client.
n_retries: Number of retries to regenerate IQL in case of errors in parsing or LLM connection.
Returns:
Generated IQL operations.
Raises:
IQLGenerationError: If IQL generation fails.
"""
try:
filters = await self._filters_generation(
question=question,
methods=filters,
examples=examples,
llm=llm,
llm_options=llm_options,
event_tracker=event_tracker,
n_retries=n_retries,
)
except (IQLError, UnsupportedQueryError) as exc:
raise IQLGenerationError(
view_name=self.__class__.__name__,
filters=exc.source if isinstance(exc, IQLError) else None,
aggregation=None,
) from exc

try:
aggregation = await self._aggregation_generation(
question=question,
methods=aggregations,
examples=examples,
llm=llm,
llm_options=llm_options,
event_tracker=event_tracker,
n_retries=n_retries,
)
except (IQLError, UnsupportedQueryError) as exc:
raise IQLGenerationError(
view_name=self.__class__.__name__,
filters=str(filters) if filters else None,
aggregation=exc.source if isinstance(exc, IQLError) else None,
) from exc

return IQLGeneratorState(
filters=filters,
aggregation=aggregation,
)


class IQLOperationGenerator:
"""
Program that generates an IQL query for the given question, deciding first whether the operation is required at all.
"""

def __init__(
self,
assessor_prompt: PromptTemplate[DecisionPromptFormat],
generator_prompt: PromptTemplate[IQLGenerationPromptFormat],
) -> None:
"""
Constructs a new IQLOperationGenerator instance.
Args:
assessor_prompt: Prompt template for deciding whether the IQL operation is needed.
generator_prompt: Prompt template for IQL generation.
"""
self.assessor = IQLQuestionAssessor(assessor_prompt)
self.generator = IQLQueryGenerator(generator_prompt)

async def __call__(
self,
*,
question: str,
methods: List[ExposedFunction],
examples: List[FewShotExample],
llm: LLM,
event_tracker: Optional[EventTracker] = None,
llm_options: Optional[LLMOptions] = None,
n_retries: int = 3,
) -> Optional[IQLQuery]:
"""
Generates IQL query for the given question.
Args:
llm: LLM used to generate IQL.
question: User question.
methods: List of methods exposed by the view.
examples: List of examples to be injected into the conversation.
event_tracker: Event store used to audit the generation process.
llm_options: Options to use for the LLM client.
n_retries: Number of retries to regenerate IQL in case of errors in parsing or LLM connection.
@@ -77,38 +180,52 @@ async def generate(
IQLError: If IQL parsing fails after all retries.
UnsupportedQueryError: If the question is not supported by the view.
"""
decision = await self._decide_on_generation(
decision = await self.assessor(
question=question,
event_tracker=event_tracker,
llm=llm,
llm_options=llm_options,
event_tracker=event_tracker,
n_retries=n_retries,
)
if not decision:
return None

return await self._generate_iql(
return await self.generator(
question=question,
filters=filters,
event_tracker=event_tracker,
methods=methods,
examples=examples,
llm=llm,
llm_options=llm_options,
event_tracker=event_tracker,
n_retries=n_retries,
)

async def _decide_on_generation(

class IQLQuestionAssessor:
"""
Program that assesses whether a question requires applying an IQL operation or not.
"""

def __init__(self, prompt: PromptTemplate[DecisionPromptFormat]) -> None:
self.prompt = prompt

async def __call__(
self,
*,
question: str,
event_tracker: EventTracker,
llm: LLM,
llm_options: Optional[LLMOptions] = None,
event_tracker: Optional[EventTracker] = None,
n_retries: int = 3,
) -> bool:
"""
Decides whether the question requires filtering or not.
Decides whether the question requires generating IQL or not.
Args:
question: User question.
event_tracker: Event store used to audit the generation process.
llm: LLM used to generate IQL.
llm_options: Options to use for the LLM client.
event_tracker: Event store used to audit the generation process.
n_retries: Number of retries for the LLM API call in case of errors.
Returns:
@@ -117,12 +234,14 @@ async def _decide_on_generation(
Raises:
LLMError: If LLM text generation fails after all retries.
"""
prompt_format = FilteringDecisionPromptFormat(question=question)
formatted_prompt = self._decision_prompt.format_prompt(prompt_format)
prompt_format = DecisionPromptFormat(
question=question,
)
formatted_prompt = self.prompt.format_prompt(prompt_format)

for retry in range(n_retries + 1):
try:
response = await self._llm.generate_text(
response = await llm.generate_text(
prompt=formatted_prompt,
event_tracker=event_tracker,
options=llm_options,
@@ -133,24 +252,39 @@
if retry == n_retries:
raise exc

async def _generate_iql(

class IQLQueryGenerator:
"""
Program that generates IQL queries for the given question.
"""

ERROR_MESSAGE = "Unfortunately, generated IQL is not valid. Please try again, \
generation of correct IQL is very important. Below you have errors generated by the system:\n{error}"

def __init__(self, prompt: PromptTemplate[IQLGenerationPromptFormat]) -> None:
self.prompt = prompt

async def __call__(
self,
*,
question: str,
filters: List[ExposedFunction],
event_tracker: Optional[EventTracker] = None,
examples: Optional[List[FewShotExample]] = None,
methods: List[ExposedFunction],
examples: List[FewShotExample],
llm: LLM,
llm_options: Optional[LLMOptions] = None,
event_tracker: Optional[EventTracker] = None,
n_retries: int = 3,
) -> IQLQuery:
"""
Generates IQL in text form using LLM.
Generates IQL query for the given question.
Args:
question: User question.
filters: List of filters exposed by the view.
event_tracker: Event store used to audit the generation process.
examples: List of examples to be injected into the conversation.
llm: LLM used to generate IQL.
llm_options: Options to use for the LLM client.
event_tracker: Event store used to audit the generation process.
n_retries: Number of retries to regenerate IQL in case of errors in parsing or LLM connection.
Returns:
@@ -163,14 +297,14 @@ async def _generate_iql(
"""
prompt_format = IQLGenerationPromptFormat(
question=question,
filters=filters,
methods=methods,
examples=examples,
)
formatted_prompt = self._generation_prompt.format_prompt(prompt_format)
formatted_prompt = self.prompt.format_prompt(prompt_format)

for retry in range(n_retries + 1):
try:
response = await self._llm.generate_text(
response = await llm.generate_text(
prompt=formatted_prompt,
event_tracker=event_tracker,
options=llm_options,
@@ -180,7 +314,7 @@
# TODO: Move IQL query parsing to prompt response parser
return await IQLQuery.parse(
source=iql,
allowed_functions=filters,
allowed_functions=methods,
event_tracker=event_tracker,
)
except LLMError as exc:
@@ -190,4 +324,4 @@ async def _generate_iql(
if retry == n_retries:
raise exc
formatted_prompt = formatted_prompt.add_assistant_message(response)
formatted_prompt = formatted_prompt.add_user_message(ERROR_MESSAGE.format(error=exc))
formatted_prompt = formatted_prompt.add_user_message(self.ERROR_MESSAGE.format(error=exc))
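
For reference, a minimal sketch of how the reworked generator might be invoked after this change. It assumes, as the removal of self._llm suggests, that the constructor no longer takes the LLM and that the LLM is passed at call time; the LiteLLM backend, model name, question, and the empty filter/aggregation lists are illustrative placeholders, not part of the commit.

```python
import asyncio
from typing import List

from dbally.iql_generator.iql_generator import IQLGenerator
from dbally.llms.litellm import LiteLLM  # assumed LLM backend; any dbally LLM subclass should work
from dbally.views.exposed_functions import ExposedFunction


async def main() -> None:
    llm = LiteLLM(model_name="gpt-4o")  # hypothetical model choice
    generator = IQLGenerator()  # defaults to the filtering/aggregation decision and generation templates

    # In practice these lists come from a structured view's exposed filter and
    # aggregation methods; empty lists keep the sketch self-contained.
    filters: List[ExposedFunction] = []
    aggregations: List[ExposedFunction] = []

    state = await generator(
        question="What is the average salary per department?",
        filters=filters,
        aggregations=aggregations,
        examples=[],
        llm=llm,
        n_retries=3,
    )
    # IQLGeneratorState holds the two independently generated operations.
    print("filters IQL:", state.filters)
    print("aggregation IQL:", state.aggregation)


asyncio.run(main())
```

Each operation is produced by its own IQLOperationGenerator, which first asks the LLM whether the operation is needed (IQLQuestionAssessor) and only then generates and parses the IQL (IQLQueryGenerator), so a question that needs no aggregation simply yields None for that field.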