Skip to content

Commit

Permalink
Merge branch 'main' into mp/prompt-tuning
Browse files Browse the repository at this point in the history
  • Loading branch information
micpst committed Aug 9, 2024
2 parents c1a503c + 11a7b21 commit 7449409
Show file tree
Hide file tree
Showing 10 changed files with 261 additions and 45 deletions.
12 changes: 6 additions & 6 deletions src/dbally/iql/_exceptions.py
Original file line number Diff line number Diff line change
Expand Up @@ -20,19 +20,19 @@ def __init__(self, source: str) -> None:
super().__init__(message, source)


class IQLEmptyExpressionError(IQLError):
"""Raised when IQL expression is empty."""
class IQLNoStatementError(IQLError):
"""Raised when IQL does not have any statement."""

def __init__(self, source: str) -> None:
message = "Empty IQL expression"
message = "Empty IQL"
super().__init__(message, source)


class IQLMultipleExpressionsError(IQLError):
"""Raised when IQL contains multiple expressions."""
class IQLMultipleStatementsError(IQLError):
"""Raised when IQL contains multiple statements."""

def __init__(self, nodes: List[ast.stmt], source: str) -> None:
message = "Multiple expressions or statements in IQL are not supported"
message = "Multiple statements in IQL are not supported"
super().__init__(message, source)
self.nodes = nodes

Expand Down
8 changes: 4 additions & 4 deletions src/dbally/iql/_processor.py
Original file line number Diff line number Diff line change
Expand Up @@ -6,11 +6,11 @@
from dbally.iql._exceptions import (
IQLArgumentParsingError,
IQLArgumentValidationError,
IQLEmptyExpressionError,
IQLFunctionNotExists,
IQLIncorrectNumberArgumentsError,
IQLMultipleExpressionsError,
IQLMultipleStatementsError,
IQLNoExpressionError,
IQLNoStatementError,
IQLSyntaxError,
IQLUnsupportedSyntaxError,
)
Expand Down Expand Up @@ -50,10 +50,10 @@ async def process(self) -> syntax.Node:
raise IQLSyntaxError(self.source) from exc

if not ast_tree.body:
raise IQLEmptyExpressionError(self.source)
raise IQLNoStatementError(self.source)

if len(ast_tree.body) > 1:
raise IQLMultipleExpressionsError(ast_tree.body, self.source)
raise IQLMultipleStatementsError(ast_tree.body, self.source)

if not isinstance(ast_tree.body[0], ast.Expr):
raise IQLNoExpressionError(ast_tree.body[0], self.source)
Expand Down
110 changes: 104 additions & 6 deletions src/dbally/iql_generator/iql_generator.py
Original file line number Diff line number Diff line change
Expand Up @@ -2,7 +2,12 @@

from dbally.audit.event_tracker import EventTracker
from dbally.iql import IQLError, IQLQuery
from dbally.iql_generator.prompt import IQL_GENERATION_TEMPLATE, IQLGenerationPromptFormat
from dbally.iql_generator.prompt import (
FILTERING_DECISION_TEMPLATE,
IQL_GENERATION_TEMPLATE,
FilteringDecisionPromptFormat,
IQLGenerationPromptFormat,
)
from dbally.llms.base import LLM
from dbally.llms.clients.base import LLMOptions
from dbally.llms.clients.exceptions import LLMError
Expand All @@ -25,17 +30,110 @@ class IQLGenerator:
It uses LLM to generate text-based responses, passing in the prompt template, formatted filters, and user question.
"""

def __init__(self, llm: LLM, prompt_template: Optional[PromptTemplate[IQLGenerationPromptFormat]] = None) -> None:
def __init__(
self,
llm: LLM,
*,
decision_prompt: Optional[PromptTemplate[FilteringDecisionPromptFormat]] = None,
generation_prompt: Optional[PromptTemplate[IQLGenerationPromptFormat]] = None,
) -> None:
"""
Constructs a new IQLGenerator instance.
Args:
llm: LLM used to generate IQL
llm: LLM used to generate IQL.
decision_prompt: Prompt template for filtering decision making.
generation_prompt: Prompt template for IQL generation.
"""
self._llm = llm
self._prompt_template = prompt_template or IQL_GENERATION_TEMPLATE
self._decision_prompt = decision_prompt or FILTERING_DECISION_TEMPLATE
self._generation_prompt = generation_prompt or IQL_GENERATION_TEMPLATE

async def generate(
self,
question: str,
filters: List[ExposedFunction],
event_tracker: EventTracker,
examples: Optional[List[FewShotExample]] = None,
llm_options: Optional[LLMOptions] = None,
n_retries: int = 3,
) -> Optional[IQLQuery]:
"""
Generates IQL in text form using LLM.
Args:
question: User question.
filters: List of filters exposed by the view.
event_tracker: Event store used to audit the generation process.
examples: List of examples to be injected into the conversation.
llm_options: Options to use for the LLM client.
n_retries: Number of retries to regenerate IQL in case of errors in parsing or LLM connection.
Returns:
Generated IQL query or None if the decision is not to continue.
Raises:
LLMError: If LLM text generation fails after all retries.
IQLError: If IQL parsing fails after all retries.
UnsupportedQueryError: If the question is not supported by the view.
"""
decision = await self._decide_on_generation(
question=question,
event_tracker=event_tracker,
llm_options=llm_options,
n_retries=n_retries,
)
if not decision:
return None

return await self._generate_iql(
question=question,
filters=filters,
event_tracker=event_tracker,
examples=examples,
llm_options=llm_options,
n_retries=n_retries,
)

async def _decide_on_generation(
self,
question: str,
event_tracker: EventTracker,
llm_options: Optional[LLMOptions] = None,
n_retries: int = 3,
) -> bool:
"""
Decides whether the question requires filtering or not.
Args:
question: User question.
event_tracker: Event store used to audit the generation process.
llm_options: Options to use for the LLM client.
n_retries: Number of retries to LLM API in case of errors.
Returns:
Decision whether to generate IQL or not.
Raises:
LLMError: If LLM text generation fails after all retries.
"""
prompt_format = FilteringDecisionPromptFormat(question=question)
formatted_prompt = self._decision_prompt.format_prompt(prompt_format)

for retry in range(n_retries + 1):
try:
response = await self._llm.generate_text(
prompt=formatted_prompt,
event_tracker=event_tracker,
options=llm_options,
)
# TODO: Move response parsing to llm generate_text method
return formatted_prompt.response_parser(response)
except LLMError as exc:
if retry == n_retries:
raise exc

async def generate_iql(
async def _generate_iql(
self,
question: str,
filters: List[ExposedFunction],
Expand Down Expand Up @@ -68,7 +166,7 @@ async def generate_iql(
filters=filters,
examples=examples,
)
formatted_prompt = self._prompt_template.format_prompt(prompt_format)
formatted_prompt = self._generation_prompt.format_prompt(prompt_format)

for retry in range(n_retries + 1):
try:
Expand Down
65 changes: 65 additions & 0 deletions src/dbally/iql_generator/prompt.py
Original file line number Diff line number Diff line change
Expand Up @@ -34,6 +34,41 @@ def _validate_iql_response(llm_response: str) -> str:
return llm_response


def _decision_iql_response_parser(response: str) -> bool:
"""
Parses the response from the decision prompt.
Args:
response: Response from the LLM.
Returns:
True if the response is positive, False otherwise.
"""
response = response.lower()
if "decision:" not in response:
return False

_, decision = response.split("decision:", 1)
return "true" in decision


class FilteringDecisionPromptFormat(PromptFormat):
"""
IQL prompt format, providing a question and filters to be used in the conversation.
"""

def __init__(self, *, question: str, examples: List[FewShotExample] = None) -> None:
"""
Constructs a new IQLGenerationPromptFormat instance.
Args:
question: Question to be asked.
examples: List of examples to be injected into the conversation.
"""
super().__init__(examples)
self.question = question


class IQLGenerationPromptFormat(PromptFormat):
"""
IQL prompt format, providing a question and filters to be used in the conversation.
Expand Down Expand Up @@ -85,3 +120,33 @@ def __init__(
],
response_parser=_validate_iql_response,
)


FILTERING_DECISION_TEMPLATE = PromptTemplate[FilteringDecisionPromptFormat](
[
{
"role": "system",
"content": (
"Given a question, determine whether the answer requires initial data filtering in order to compute it.\n"
"Initial data filtering is a process in which the result set is reduced to only include the rows "
"that meet certain criteria specified in the question.\n\n"
"---\n\n"
"Follow the following format.\n\n"
"Question: ${{question}}\n"
"Hint: ${{hint}}"
"Reasoning: Let's think step by step in order to ${{produce the decision}}. We...\n"
"Decision: indicates whether the answer to the question requires initial data filtering. "
"(Respond with True or False)\n\n"
),
},
{
"role": "user",
"content": (
"Question: {question}\n"
"Hint: Look for words indicating data specific features.\n"
"Reasoning: Let's think step by step in order to "
),
},
],
response_parser=_decision_iql_response_parser,
)
3 changes: 2 additions & 1 deletion src/dbally/views/pandas_base.py
Original file line number Diff line number Diff line change
@@ -1,5 +1,6 @@
import asyncio
from functools import reduce
from typing import Optional

import pandas as pd

Expand All @@ -25,7 +26,7 @@ def __init__(self, df: pd.DataFrame) -> None:
self.df = df

# The mask to be applied to the dataframe to filter the data
self._filter_mask: pd.Series = None
self._filter_mask: Optional[pd.Series] = None

async def apply_filters(self, filters: IQLQuery) -> None:
"""
Expand Down
13 changes: 9 additions & 4 deletions src/dbally/views/sqlalchemy_base.py
Original file line number Diff line number Diff line change
@@ -1,5 +1,6 @@
import abc
import asyncio
from typing import Optional

import sqlalchemy

Expand All @@ -13,10 +14,11 @@ class SqlAlchemyBaseView(MethodsBaseView):
Base class for views that use SQLAlchemy to generate SQL queries.
"""

def __init__(self, sqlalchemy_engine: sqlalchemy.engine.Engine) -> None:
def __init__(self, sqlalchemy_engine: sqlalchemy.Engine) -> None:
super().__init__()
self._select = self.get_select()
self._sqlalchemy_engine = sqlalchemy_engine
self._select = self.get_select()
self._where_clause: Optional[sqlalchemy.ColumnElement] = None

@abc.abstractmethod
def get_select(self) -> sqlalchemy.Select:
Expand All @@ -34,7 +36,7 @@ async def apply_filters(self, filters: IQLQuery) -> None:
Args:
filters: IQLQuery object representing the filters to apply
"""
self._select = self._select.where(await self._build_filter_node(filters.root))
self._where_clause = await self._build_filter_node(filters.root)

async def _build_filter_node(self, node: syntax.Node) -> sqlalchemy.ColumnElement:
"""
Expand Down Expand Up @@ -75,8 +77,11 @@ def execute(self, dry_run: bool = False) -> ViewExecutionResult:
Results of the query where `results` will be a list of dictionaries representing retrieved rows or an empty\
list if `dry_run` is set to `True`. Inside the `context` field the generated sql will be stored.
"""

results = []

if self._where_clause is not None:
self._select = self._select.where(self._where_clause)

sql = str(self._select.compile(bind=self._sqlalchemy_engine, compile_kwargs={"literal_binds": True}))

if not dry_run:
Expand Down
7 changes: 4 additions & 3 deletions src/dbally/views/structured.py
Original file line number Diff line number Diff line change
Expand Up @@ -69,7 +69,7 @@ async def ask(
examples = self.list_few_shots()

try:
iql = await iql_generator.generate_iql(
iql = await iql_generator.generate(
question=query,
filters=filters,
examples=examples,
Expand All @@ -90,10 +90,11 @@ async def ask(
aggregation=None,
) from exc

await self.apply_filters(iql)
if iql:
await self.apply_filters(iql)

result = self.execute(dry_run=dry_run)
result.context["iql"] = f"{iql}"
result.context["iql"] = str(iql) if iql else None

return result

Expand Down
12 changes: 6 additions & 6 deletions tests/unit/iql/test_iql_parser.py
Original file line number Diff line number Diff line change
Expand Up @@ -6,11 +6,11 @@
from dbally.iql import IQLArgumentParsingError, IQLQuery, IQLUnsupportedSyntaxError, syntax
from dbally.iql._exceptions import (
IQLArgumentValidationError,
IQLEmptyExpressionError,
IQLFunctionNotExists,
IQLIncorrectNumberArgumentsError,
IQLMultipleExpressionsError,
IQLMultipleStatementsError,
IQLNoExpressionError,
IQLNoStatementError,
IQLSyntaxError,
)
from dbally.iql._processor import IQLProcessor
Expand Down Expand Up @@ -95,7 +95,7 @@ async def test_iql_parser_syntax_error():


async def test_iql_parser_multiple_expression_error():
with pytest.raises(IQLMultipleExpressionsError) as exc_info:
with pytest.raises(IQLMultipleStatementsError) as exc_info:
await IQLQuery.parse(
"filter_by_age\nfilter_by_age",
allowed_functions=[
Expand All @@ -109,11 +109,11 @@ async def test_iql_parser_multiple_expression_error():
],
)

assert exc_info.match(re.escape("Multiple expressions or statements in IQL are not supported"))
assert exc_info.match(re.escape("Multiple statements in IQL are not supported"))


async def test_iql_parser_empty_expression_error():
with pytest.raises(IQLEmptyExpressionError) as exc_info:
with pytest.raises(IQLNoStatementError) as exc_info:
await IQLQuery.parse(
"",
allowed_functions=[
Expand All @@ -127,7 +127,7 @@ async def test_iql_parser_empty_expression_error():
],
)

assert exc_info.match(re.escape("Empty IQL expression"))
assert exc_info.match(re.escape("Empty IQL"))


async def test_iql_parser_no_expression_error():
Expand Down
Loading

0 comments on commit 7449409

Please sign in to comment.