Skip to content

Commit

Permalink
Merge pull request #13 from parkervg/grammar-prompting
Browse files Browse the repository at this point in the history
Syncing grammar-prompting into main
  • Loading branch information
parkervg authored May 18, 2024
2 parents 457f900 + d3a6af6 commit 71f2dec
Show file tree
Hide file tree
Showing 27 changed files with 502 additions and 318 deletions.
10 changes: 10 additions & 0 deletions README.md
Original file line number Diff line number Diff line change
Expand Up @@ -388,6 +388,16 @@ print(smoothie.meta.prompts)
}
```
### Acknowledgements
Special thanks to those below for inspiring this project. Definitely recommend checking out the linked work below, and citing when applicable!

- The authors of [Binding Language Models in Symbolic Languages](https://arxiv.org/abs/2210.02875)
- This paper was the primary inspiration for BlendSQL.
- The authors of [EHRXQA: A Multi-Modal Question Answering Dataset for Electronic Health Records with Chest X-ray Images](https://arxiv.org/pdf/2310.18652)
- As far as I can tell, the first publication to propose unifying model calls within SQL
- Served as the inspiration for the [vqa-ingredient.ipynb](./examples/vqa-ingredient.ipynb) example
- The authors of [Grammar Prompting for Domain-Specific Language Generation with Large Language Models](https://arxiv.org/abs/2305.19234)

# Documentation


Expand Down
10 changes: 5 additions & 5 deletions blendsql/_constants.py
Original file line number Diff line number Diff line change
@@ -1,4 +1,4 @@
from enum import Enum, EnumMeta, auto
from enum import Enum, EnumMeta
from dataclasses import dataclass

HF_REPO_ID = "parkervg/blendsql-test-dbs"
Expand All @@ -15,10 +15,10 @@ def __contains__(cls, item):


class IngredientType(str, Enum, metaclass=StrInMeta):
MAP = auto()
STRING = auto()
QA = auto()
JOIN = auto()
MAP = "MAP"
STRING = "STRING"
QA = "QA"
JOIN = "JOIN"


@dataclass
Expand Down
6 changes: 6 additions & 0 deletions blendsql/_exceptions.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,6 @@
class InvalidBlendSQL(ValueError):
pass


class IngredientException(ValueError):
pass
31 changes: 22 additions & 9 deletions blendsql/blend.py
Original file line number Diff line number Diff line change
Expand Up @@ -15,6 +15,7 @@
Optional,
Callable,
Collection,
Union,
)
from sqlite3 import OperationalError
import sqlglot.expressions
Expand All @@ -31,8 +32,9 @@
recover_blendsql,
get_tablename_colname,
)
from ._exceptions import InvalidBlendSQL
from .db import Database
from .db.utils import double_quote_escape, single_quote_escape
from .db.utils import double_quote_escape, select_all_from_table_query
from ._sqlglot import (
MODIFIERS,
get_first_child,
Expand Down Expand Up @@ -428,7 +430,7 @@ def _blend(

# Preliminary check - we can't have anything that modifies database state
if _query.find(MODIFIERS):
raise ValueError("BlendSQL query cannot have `DELETE` clause!")
raise InvalidBlendSQL("BlendSQL query cannot have `DELETE` clause!")

# If there's no `SELECT` and just a QAIngredient, wrap it in a `SELECT CASE` query
if _query.find(exp.Select) is None:
Expand All @@ -442,7 +444,7 @@ def _blend(
# If we don't have any ingredient calls, execute as normal SQL
if len(ingredients) == 0 or len(ingredient_alias_to_parsed_dict) == 0:
return Smoothie(
df=db.execute_query(query),
df=db.execute_to_df(query),
meta=SmoothieMeta(
num_values_passed=0,
num_prompt_tokens=0,
Expand Down Expand Up @@ -541,7 +543,7 @@ def _blend(
)
try:
db.to_temp_table(
df=db.execute_query(abstracted_query),
df=db.execute_to_df(abstracted_query),
tablename=_get_temp_subquery_table(tablename),
)
except OperationalError as e:
Expand Down Expand Up @@ -572,6 +574,7 @@ def _blend(
if prev_subquery_has_ingredient:
scm.set_node(scm.node.transform(maybe_set_subqueries_to_true))

lazy_limit: Union[int, None] = scm.get_lazy_limit()
# After above processing of AST, sync back to string repr
subquery_str = scm.sql()
# Now, 1) Find all ingredients to execute (e.g. '{{f(a, b, c)}}')
Expand Down Expand Up @@ -741,14 +744,14 @@ def _blend(
# On their left join merge command: https://github.com/HKUNLP/Binder/blob/9eede69186ef3f621d2a50572e1696bc418c0e77/nsql/database.py#L196
# We create a new temp table to avoid a potentially self-destructive operation
base_tablename = tablename
_base_table: pd.DataFrame = db.execute_query(
f'SELECT * FROM "{double_quote_escape(base_tablename)}";'
_base_table: pd.DataFrame = db.execute_to_df(
select_all_from_table_query(base_tablename)
)
base_table = _base_table
if db.has_temp_table(_get_temp_session_table(tablename)):
base_tablename = _get_temp_session_table(tablename)
base_table: pd.DataFrame = db.execute_query(
f"SELECT * FROM '{single_quote_escape(base_tablename)}';",
base_table: pd.DataFrame = db.execute_to_df(
select_all_from_table_query(base_tablename)
)
previously_added_columns = base_table.columns.difference(
_base_table.columns
Expand Down Expand Up @@ -804,7 +807,7 @@ def _blend(
)
logging.debug("")

df = db.execute_query(query)
df = db.execute_to_df(query)

return Smoothie(
df=df,
Expand Down Expand Up @@ -880,6 +883,16 @@ def blend(
schema_qualify=schema_qualify,
)
except Exception as error:
if not isinstance(error, (InvalidBlendSQL, IngredientException)):
from .grammars.minEarley.parser import EarleyParser
from .grammars.utils import load_cfg_parser

# Parse with CFG and try to get helpful recommendations
parser: EarleyParser = load_cfg_parser(ingredients)
try:
parser.parse(query)
except Exception as parser_error:
raise parser_error
raise error
finally:
# In the case of a recursive `_blend()` call,
Expand Down
56 changes: 30 additions & 26 deletions blendsql/blend_cli.py
Original file line number Diff line number Diff line change
@@ -1,16 +1,23 @@
import os
import argparse
import importlib
import json

from blendsql import blend
from blendsql.db import SQLite
from blendsql.db.utils import truncate_df_content
from blendsql.utils import tabulate
from blendsql.models import LlamaCppLLM
from blendsql.ingredients.builtin import LLMQA, LLMMap, LLMJoin, DT
from blendsql.models import OpenaiLLM, TransformersLLM, AzureOpenaiLLM, LlamaCppLLM
from blendsql.ingredients.builtin import LLMQA, LLMMap, LLMJoin

_has_readline = importlib.util.find_spec("readline") is not None

MODEL_TYPE_TO_CLASS = {
"openai": OpenaiLLM,
"azure_openai": AzureOpenaiLLM,
"llama_cpp": LlamaCppLLM,
"transformers": TransformersLLM,
}


def print_msg_box(msg, indent=1, width=None, title=None):
"""Print message-box with optional title."""
Expand All @@ -37,8 +44,21 @@ def main():

_ = readline
parser = argparse.ArgumentParser()
parser.add_argument("db_url", nargs="?")
parser.add_argument("secrets_path", nargs="?", default="./secrets.json")
parser.add_argument("db_url", nargs="?", help="Database URL,")
parser.add_argument(
"model_type",
nargs="?",
default="openai",
choices=list(MODEL_TYPE_TO_CLASS.keys()),
help="Model type, for the Blender to use in executing the BlendSQL query.",
)
parser.add_argument(
"model_name_or_path",
nargs="?",
default="gpt-3.5-turbo",
help="Model identifier to pass to the selected model_type class.",
)
parser.add_argument("-v", action="store_true", help="Flag to run in verbose mode.")
args = parser.parse_args()

db = SQLite(db_url=args.db_url)
Expand All @@ -61,30 +81,14 @@ def main():
smoothie = blend(
query=text,
db=db,
ingredients={LLMQA, LLMMap, LLMJoin, DT},
blender=LlamaCppLLM(
"./lark-constrained-parsing/tinyllama-1.1b-chat-v1.0.Q2_K.gguf"
ingredients={LLMQA, LLMMap, LLMJoin},
blender=MODEL_TYPE_TO_CLASS.get(args.model_type)(
args.model_name_or_path
),
infer_gen_constraints=True,
verbose=True,
verbose=args.v,
)
print()
print(tabulate(smoothie.df.iloc[:10]))
print()
print(json.dumps(smoothie.meta.prompts, indent=4))
print(tabulate(truncate_df_content(smoothie.df, 50)))
except Exception as error:
print(error)


"""
SELECT "common name" AS 'State Flower' FROM w
WHERE state = {{
LLMQA(
'Which is the smallest state by area?',
(SELECT title, content FROM documents WHERE documents MATCH 'smallest OR state OR area' LIMIT 3),
options='w::state'
)
}}
SELECT Symbol, Description, Quantity FROM portfolio WHERE {{LLMMap('Do they manufacture cell phones?', 'portfolio::Description')}} = TRUE AND portfolio.Symbol in (SELECT Symbol FROM constituents WHERE Sector = 'Information Technology')
"""
15 changes: 13 additions & 2 deletions blendsql/db/_database.py
Original file line number Diff line number Diff line change
@@ -1,4 +1,4 @@
from typing import Generator, List, Dict, Collection
from typing import Generator, List, Dict, Collection, Type, Optional
from typing import Iterable
import pandas as pd
from colorama import Fore
Expand Down Expand Up @@ -86,7 +86,7 @@ def to_temp_table(self, df: pd.DataFrame, tablename: str):
self.con.execute(text(create_table_stmt))
df.to_sql(name=tablename, con=self.con, if_exists="append", index=False)

def execute_query(self, query: str, params: dict = None) -> pd.DataFrame:
def execute_to_df(self, query: str, params: dict = None) -> pd.DataFrame:
"""
Execute the given query and return results as dataframe.
Expand All @@ -106,3 +106,14 @@ def execute_query(self, query: str, params: dict = None) -> pd.DataFrame:
```
"""
return pd.read_sql(text(query), self.con, params=params)

def execute_to_list(
self, query: str, to_type: Optional[Type] = lambda x: x
) -> list:
"""A lower-level execute method that doesn't use the pandas processing logic.
Returns results as a tuple.
"""
res = []
for row in self.con.execute(text(query)).fetchall():
res.append(to_type(row[0]))
return res
2 changes: 1 addition & 1 deletion blendsql/db/_postgres.py
Original file line number Diff line number Diff line change
Expand Up @@ -36,7 +36,7 @@ def __init__(self, db_path: str):
def has_temp_table(self, tablename: str) -> bool:
return (
tablename
in self.execute_query(
in self.execute_to_df(
"SELECT * FROM information_schema.tables WHERE table_schema LIKE 'pg_temp_%'"
)["table_name"].unique()
)
Expand Down
4 changes: 2 additions & 2 deletions blendsql/db/_sqlite.py
Original file line number Diff line number Diff line change
Expand Up @@ -23,7 +23,7 @@ def __init__(self, db_url: str):
def has_temp_table(self, tablename: str) -> bool:
return (
tablename
in self.execute_query(
in self.execute_to_df(
"SELECT name FROM sqlite_temp_master WHERE type='table';"
)["name"].unique()
)
Expand All @@ -39,7 +39,7 @@ def get_sqlglot_schema(self) -> dict:
schema = {}
for tablename in self.tables():
schema[f'"{double_quote_escape(tablename)}"'] = {}
for _, row in self.execute_query(
for _, row in self.execute_to_df(
f"""
SELECT name, type FROM pragma_table_info(:t)
""",
Expand Down
6 changes: 5 additions & 1 deletion blendsql/db/utils.py
Original file line number Diff line number Diff line change
Expand Up @@ -14,9 +14,13 @@ def escape(s):
return single_quote_escape(double_quote_escape(s))


def select_all_from_table_query(tablename: str) -> str:
return f'SELECT * FROM "{double_quote_escape(tablename)}";'


def truncate_df_content(df: pd.DataFrame, truncation_limit: int) -> pd.DataFrame:
# Truncate long strings
return df.applymap(
return df.map(
lambda x: f"{str(x)[:truncation_limit]}..."
if isinstance(x, str) and len(str(x)) > truncation_limit
else x
Expand Down
12 changes: 6 additions & 6 deletions blendsql/grammars/_cfg_grammar.lark
Original file line number Diff line number Diff line change
Expand Up @@ -66,22 +66,22 @@ JOIN_TYPE: "INNER"i | "FULL"i ["OUTER"i] | "LEFT"i["OUTER"i] | "RIGHT"i ["OUTER"
| "CASE"i (when_then)+ "ELSE"i expression_math "END"i -> case_expression
| "CAST"i "(" expression_math "AS"i TYPENAME ")" -> as_type
| "CAST"i "(" literal "AS"i TYPENAME ")" -> literal_cast
| AGGREGATION_FUNCTIONS expression_math ")" [window_form] -> sql_aggregation
| AGGREGATE_FUNCTIONS expression_math ")" [window_form] -> sql_aggregation
| SCALAR_FUNCTIONS [(expression_math ",")*] expression_math ")" -> sql_scalar
| blendsql_aggregation_expr -> blendsql_aggregation
| "RANK"i "(" ")" window_form -> rank_expression
| "DENSE_RANK"i "(" ")" window_form -> dense_rank_expression
| "|" "|" expression_math

BLENDSQL_AGGREGATION: ("LLMQA("i | "LLMVerify("i)
BLENDSQL_JOIN: ("LLMJoin("i)
BLENDSQL_AGGREGATE_FUNCTIONS: $blendsql_aggregate_functions
BLENDSQL_JOIN_FUNCTIONS: $blendsql_join_functions
left_on_arg: "left_on" "=" string
right_on_arg: "right_on" "=" string
blendsql_arg: (name "=" literal | literal | "(" start ")")

blendsql_expression_math: blendsql_arg ("," blendsql_arg)*
blendsql_aggregation_expr: "{{" BLENDSQL_AGGREGATION blendsql_expression_math ")" "}}"
blendsql_join_expr: "{{" BLENDSQL_JOIN (left_on_arg "," right_on_arg|right_on_arg "," left_on_arg) ")" "}}"
blendsql_aggregation_expr: "{{" BLENDSQL_AGGREGATE_FUNCTIONS blendsql_expression_math ")" "}}"
blendsql_join_expr: "{{" BLENDSQL_JOIN_FUNCTIONS (left_on_arg "," right_on_arg|right_on_arg "," left_on_arg) ")" "}}"

window_form: "OVER"i "(" ["PARTITION"i "BY"i (expression_math ",")* expression_math] ["ORDER"i "BY"i (order ",")* order [ row_range_clause ] ] ")"

Expand Down Expand Up @@ -125,7 +125,7 @@ TYPENAME: "object"i
| "string"i

// https://www.sqlite.org/lang_expr.html#*funcinexpr
AGGREGATION_FUNCTIONS: ("sum("i | "avg("i | "min("i | "max("i | "count("i ["distinct"i] )
AGGREGATE_FUNCTIONS: ("sum("i | "avg("i | "min("i | "max("i | "count("i ["distinct"i] )
SCALAR_FUNCTIONS: ("trim("i | "coalesce("i | "abs("i)

alias: string -> alias_string
Expand Down
Loading

0 comments on commit 71f2dec

Please sign in to comment.