Skip to content

Commit

Permalink
refactor(iql): add iql gen exception (#77)
Browse files Browse the repository at this point in the history
  • Loading branch information
micpst authored Jul 31, 2024
1 parent 0a22942 commit 90b0e66
Show file tree
Hide file tree
Showing 6 changed files with 70 additions and 11 deletions.
10 changes: 8 additions & 2 deletions src/dbally/iql_generator/iql_generator.py
Original file line number Diff line number Diff line change
Expand Up @@ -5,6 +5,7 @@
from dbally.iql_generator.prompt import IQL_GENERATION_TEMPLATE, IQLGenerationPromptFormat
from dbally.llms.base import LLM
from dbally.llms.clients.base import LLMOptions
from dbally.llms.clients.exceptions import LLMError
from dbally.prompt.elements import FewShotExample
from dbally.prompt.template import PromptTemplate
from dbally.views.exposed_functions import ExposedFunction
Expand Down Expand Up @@ -52,13 +53,15 @@ async def generate_iql(
event_tracker: Event store used to audit the generation process.
examples: List of examples to be injected into the conversation.
llm_options: Options to use for the LLM client.
n_retries: Number of retries to regenerate IQL in case of errors.
n_retries: Number of retries to regenerate IQL in case of errors in parsing or LLM connection.
Returns:
Generated IQL query.
Raises:
IQLError: If IQL generation fails after all retries.
LLMError: If LLM text generation fails after all retries.
IQLError: If IQL parsing fails after all retries.
UnsupportedQueryError: If the question is not supported by the view.
"""
prompt_format = IQLGenerationPromptFormat(
question=question,
Expand All @@ -82,6 +85,9 @@ async def generate_iql(
allowed_functions=filters,
event_tracker=event_tracker,
)
except LLMError as exc:
if retry == n_retries:
raise exc
except IQLError as exc:
if retry == n_retries:
raise exc
Expand Down
3 changes: 3 additions & 0 deletions src/dbally/llms/base.py
Original file line number Diff line number Diff line change
Expand Up @@ -69,6 +69,9 @@ async def generate_text(
Returns:
Text response from LLM.
Raises:
LLMError: If LLM text generation fails.
"""
options = (self.default_options | options) if options else self.default_options
event = LLMEvent(prompt=prompt.chat, type=type(prompt).__name__)
Expand Down
3 changes: 3 additions & 0 deletions src/dbally/nl_responder/nl_responder.py
Original file line number Diff line number Diff line change
Expand Up @@ -59,6 +59,9 @@ async def generate_response(
Returns:
Natural language response to the user question.
Raises:
LLMError: If LLM text generation fails.
"""
prompt_format = NLResponsePromptFormat(
question=question,
Expand Down
3 changes: 3 additions & 0 deletions src/dbally/view_selection/llm_view_selector.py
Original file line number Diff line number Diff line change
Expand Up @@ -48,6 +48,9 @@ async def select_view(
Returns:
The most relevant view name.
Raises:
LLMError: If LLM text generation fails.
"""
prompt_format = ViewSelectionPromptFormat(question=question, views=views)
formatted_prompt = self._prompt_template.format_prompt(prompt_format)
Expand Down
26 changes: 26 additions & 0 deletions src/dbally/views/exceptions.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,26 @@
from typing import Optional

from dbally.exceptions import DbAllyError


class IQLGenerationError(DbAllyError):
"""
Exception for when an error occurs while generating IQL for a view.
"""

def __init__(
self,
view_name: str,
filters: Optional[str] = None,
aggregation: Optional[str] = None,
) -> None:
"""
Args:
view_name: Name of the view that caused the error.
filters: Filters generated by the view.
aggregation: Aggregation generated by the view.
"""
super().__init__(f"Error while generating IQL for view {view_name}")
self.view_name = view_name
self.filters = filters
self.aggregation = aggregation
36 changes: 27 additions & 9 deletions src/dbally/views/structured.py
Original file line number Diff line number Diff line change
Expand Up @@ -5,9 +5,12 @@
from dbally.audit.event_tracker import EventTracker
from dbally.collection.results import ViewExecutionResult
from dbally.iql import IQLQuery
from dbally.iql._exceptions import IQLError
from dbally.iql_generator.iql_generator import IQLGenerator
from dbally.iql_generator.prompt import UnsupportedQueryError
from dbally.llms.base import LLM
from dbally.llms.clients.base import LLMOptions
from dbally.views.exceptions import IQLGenerationError
from dbally.views.exposed_functions import ExposedFunction

from ..similarity import AbstractSimilarityIndex
Expand Down Expand Up @@ -57,21 +60,36 @@ async def ask(
The result of the query.
Raises:
IQLError: If the generated IQL query is not valid.
LLMError: If LLM text generation API fails.
IQLGenerationError: If the IQL generation fails.
"""
iql_generator = self.get_iql_generator(llm)

filters = self.list_filters()
examples = self.list_few_shots()

iql = await iql_generator.generate_iql(
question=query,
filters=filters,
examples=examples,
event_tracker=event_tracker,
llm_options=llm_options,
n_retries=n_retries,
)
try:
iql = await iql_generator.generate_iql(
question=query,
filters=filters,
examples=examples,
event_tracker=event_tracker,
llm_options=llm_options,
n_retries=n_retries,
)
except UnsupportedQueryError as exc:
raise IQLGenerationError(
view_name=self.__class__.__name__,
filters=None,
aggregation=None,
) from exc
except IQLError as exc:
raise IQLGenerationError(
view_name=self.__class__.__name__,
filters=exc.source,
aggregation=None,
) from exc

await self.apply_filters(iql)

result = self.execute(dry_run=dry_run)
Expand Down

0 comments on commit 90b0e66

Please sign in to comment.