Skip to content

Commit

Permalink
Merge pull request #7 from parkervg/vqa-ingredients
Browse files Browse the repository at this point in the history
VQA Ingredient, rendering notebooks in `examples`, `LLM` -> `Model`
  • Loading branch information
parkervg authored Feb 29, 2024
2 parents 8e2cfe1 + 202ec52 commit 0365d38
Show file tree
Hide file tree
Showing 41 changed files with 1,630 additions and 469 deletions.
1 change: 1 addition & 0 deletions .github/workflows/publish-documentation.yml
Original file line number Diff line number Diff line change
Expand Up @@ -27,5 +27,6 @@ jobs:
- run: pip install mkdocstrings
- run: pip install mkdocs-section-index
- run: pip install mkdocstrings-python
- run: pip install mkdocs-jupyter
- name: Build documentation
run: mkdocs gh-deploy --force
38 changes: 19 additions & 19 deletions README.md
Original file line number Diff line number Diff line change
Expand Up @@ -294,7 +294,7 @@ pip install blendsql
```python
from blendsql import blend, LLMQA, LLMMap
from blendsql.db import SQLite
from blendsql.llms import OpenaiLLM
from blendsql.models import OpenaiLLM

blendsql = """
SELECT merchant FROM transactions WHERE
Expand All @@ -303,11 +303,11 @@ SELECT merchant FROM transactions WHERE
"""
# Make our smoothie - the executed BlendSQL script
smoothie = blend(
query=blendsql,
blender=OpenaiLLM("gpt-3.5-turbo-0613"),
ingredients={LLMMap, LLMQA},
db=SQLite(db_path="transactions.db"),
verbose=True
query=blendsql,
blender=OpenaiLLM("gpt-3.5-turbo-0613"),
ingredients={LLMMap, LLMQA},
db=SQLite(db_path="transactions.db"),
verbose=True
)

```
Expand Down Expand Up @@ -348,7 +348,7 @@ The `blend()` function is used to execute a BlendSQL query against a database an
```python
from blendsql import blend, LLMMap, LLMQA, LLMJoin
from blendsql.db import SQLite
from blendsql.llms import OpenaiLLM
from blendsql.models import OpenaiLLM
blendsql = """
SELECT * FROM w
Expand All @@ -362,18 +362,18 @@ WHERE city = {{
"""
db = SQLite(db_path)
smoothie = blend(
query=blendsql,
db=db,
ingredients={LLMMap, LLMQA, LLMJoin},
blender=AzureOpenaiLLM("gpt-4"),
# Optional args below
infer_gen_constraints=True,
silence_db_exec_errors=False,
verbose=True,
blender_args={
"few_shot": True,
"temperature": 0.01
}
query=blendsql,
db=db,
ingredients={LLMMap, LLMQA, LLMJoin},
blender=AzureOpenaiLLM("gpt-4"),
# Optional args below
infer_gen_constraints=True,
silence_db_exec_errors=False,
verbose=True,
blender_args={
"few_shot": True,
"temperature": 0.01
}
)
```
Expand Down
1 change: 0 additions & 1 deletion blendsql/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -2,5 +2,4 @@


from .ingredients.builtin import LLMMap, LLMQA, LLMJoin, DT, LLMValidate
from .ingredients.builtin import llm
from .blendsql import blend
1 change: 1 addition & 0 deletions blendsql/_constants.py
Original file line number Diff line number Diff line change
Expand Up @@ -37,3 +37,4 @@ class IngredientKwarg:
QUESTION = "question"
CONTEXT = "context"
OPTIONS = "options"
MODEL = "model"
2 changes: 0 additions & 2 deletions blendsql/_program.py
Original file line number Diff line number Diff line change
Expand Up @@ -13,13 +13,11 @@ class Program:
def __new__(
self,
model: Model,
question: Optional[str] = None,
gen_kwargs: Optional[dict] = None,
few_shot: bool = True,
**kwargs,
):
self.model = model
self.question = question
self.gen_kwargs = gen_kwargs if gen_kwargs is not None else {}
self.few_shot = few_shot
assert isinstance(
Expand Down
53 changes: 48 additions & 5 deletions blendsql/_sqlglot.py
Original file line number Diff line number Diff line change
Expand Up @@ -293,7 +293,7 @@ def get_predicate_literals(node) -> List[str]:
(We treat booleans as literals here, which might be a misuse of terminology.)
Examples:
>>> get_predicate_literals(_parse_one("{{LLM('year', 'w::year')}} IN ('2010', '2011', '2012')"))
>>> get_predicate_literals(_parse_one("{{Model('year', 'w::year')}} IN ('2010', '2011', '2012')"))
['2010', '2011', '2012']
"""
literals = set()
Expand Down Expand Up @@ -378,9 +378,10 @@ def all_terminals_are_true(node) -> bool:
class SubqueryContextManager:
node: exp.Select = attrib()
prev_subquery_has_ingredient: bool = attrib()
tables_in_ingredients: set = attrib()

# Keep a running log of what aliases we've initialized so far, per subquery
alias_to_subquery: dict = attrib(default=None)

alias_to_tablename: dict = attrib(init=False)
tablename_to_alias: dict = attrib(init=False)
root: sqlglot.optimizer.scope.Scope = attrib(init=False)
Expand Down Expand Up @@ -413,12 +414,54 @@ def abstracted_table_selects(self) -> Generator[Tuple[str, exp.Select], None, No
abstracted_queries: Generator with (tablename, exp.Select, alias_to_table). The exp.Select is the abstracted query.
Examples:
>>> {{LLM('is this an italian restaurant?', 'transactions::merchant', endpoint_name='gpt-4')}} = 1 AND child_category = 'Restaurants & Dining'
>>> {{Model('is this an italian restaurant?', 'transactions::merchant', endpoint_name='gpt-4')}} = 1 AND child_category = 'Restaurants & Dining'
('transactions', 'SELECT * FROM transactions WHERE TRUE AND child_category = \'Restaurants & Dining\'')
"""
# TODO: don't really know how to optimize with 'CASE' queries right now
if self.node.find(exp.Case):
return
# Special condition: If...
# 1) We have a `JOIN` clause
# 2) We *only* have an ingredient in the top-level `SELECT` clause
# 3) Our ingredients *only* call a single table
# ... then we should execute entire rest of SQL first and assign to temporary session table.
# Example: """SELECT w.title, w."designer ( s )", {{LLMMap('How many animals are in this image?', 'images::title')}}
# FROM images JOIN w ON w.title = images.title
# WHERE "designer ( s )" = 'georgia gerber'"""
join_exp = self.node.find(exp.Join)
if join_exp is not None:
select_exps = list(self.node.find_all(exp.Select))
if len(select_exps) == 1:
# Check if the only `STRUCT` nodes are found in select
all_struct_exps = list(self.node.find_all(exp.Struct))
if len(all_struct_exps) > 0:
num_select_struct_exps = sum(
[
len(list(n.find_all(exp.Struct)))
for n in select_exps[0].find_all(exp.Alias)
]
)
if num_select_struct_exps == len(all_struct_exps):
if len(self.tables_in_ingredients) == 1:
tablename = next(iter(self.tables_in_ingredients))
join_tablename = set(
[i.name for i in self.node.find_all(exp.Table)]
).difference({tablename})
if len(join_tablename) == 1:
join_tablename = next(iter(join_tablename))
base_select_str = f'SELECT "{tablename}".* FROM "{tablename}", {join_tablename} WHERE '
table_conditions_str = self.get_table_predicates_str(
tablename=tablename,
disambiguate_multi_tables=False,
)
abstracted_query = _parse_one(
base_select_str + table_conditions_str
)
abstracted_query_str = recover_blendsql(
abstracted_query.sql(dialect=FTS5SQLite)
)
yield (tablename, abstracted_query_str)
return
for tablename, table_star_query in self._table_star_queries():
# If this table_star_query doesn't have an ingredient at the top-level, we can safely ignore
if len(list(self._get_scope_nodes(exp.Struct, restrict_scope=True))) == 0:
Expand Down Expand Up @@ -572,7 +615,7 @@ def get_table_predicates_str(

def infer_gen_constraints(self, start: int, end: int) -> dict:
"""Given syntax of BlendSQL query, infers a regex pattern (if possible) to guide
downstream LLM generations.
downstream Model generations.
For example:
>>> SELECT * FROM w WHERE {{LLMMap('Is this true?', 'w::colname')}}
Expand All @@ -583,7 +626,7 @@ def infer_gen_constraints(self, start: int, end: int) -> dict:
Returns:
dict, with keys:
output_type: numeric | string | boolean
pattern: regular expression pattern to use in constrained decoding with LLM
pattern: regular expression pattern to use in constrained decoding with Model
"""
added_kwargs = {}
ingredient_node = _parse_one(self.sql()[start:end])
Expand Down
23 changes: 14 additions & 9 deletions blendsql/blendsql.py
Original file line number Diff line number Diff line change
Expand Up @@ -57,7 +57,7 @@
IngredientType,
IngredientKwarg,
)
from .llms._llm import LLM
from .models._model import Model


@attrs
Expand Down Expand Up @@ -148,15 +148,15 @@ def autowrap_query(


def preprocess_blendsql(
query: str, blender_args: dict, blender: LLM
query: str, blender_args: dict, blender: Model
) -> Tuple[str, dict, set]:
"""Parses BlendSQL string with our pyparsing grammar and returns objects
required for interpretation and execution.
Args:
query: The BlendSQL query to preprocess
blender_args: Arguments used as default in ingredient calls with LLMs
blender: LLM object, which we attach to each parsed_dict
blender: Model object, which we attach to each parsed_dict
Returns:
Tuple, containing:
Expand Down Expand Up @@ -250,7 +250,7 @@ def preprocess_blendsql(
+ Fore.RESET
)
kwargs_dict[k] = v
kwargs_dict["llm"] = blender
kwargs_dict[IngredientKwarg.MODEL] = blender
context_arg = kwargs_dict.get(
IngredientKwarg.CONTEXT,
parsed_results_dict["args"][1]
Expand Down Expand Up @@ -289,7 +289,7 @@ def set_subquery_to_alias(
aliasname: str,
query: exp.Expression,
db: SQLite,
blender: LLM,
blender: Model,
ingredient_alias_to_parsed_dict: Dict[str, dict],
**kwargs,
) -> exp.Expression:
Expand Down Expand Up @@ -339,7 +339,12 @@ def get_sorted_grammar_matches(
- `Function` object that is matched
"""
ooo = [IngredientType.MAP, IngredientType.QA, IngredientType.JOIN]
ooo = [
IngredientType.STRING,
IngredientType.MAP,
IngredientType.QA,
IngredientType.JOIN,
]
parse_results = [i for i in grammar.scanString(q)]
while len(parse_results) > 0:
curr_ingredient_target = ooo.pop(0)
Expand Down Expand Up @@ -379,11 +384,10 @@ def disambiguate_and_submit_blend(
return blend(query=query, **kwargs)


# @profile
def blend(
query: str,
db: SQLite,
blender: Optional[LLM] = None,
blender: Optional[Model] = None,
ingredients: Optional[Collection[Ingredient]] = None,
verbose: bool = False,
blender_args: Optional[Dict[str, str]] = None,
Expand All @@ -403,7 +407,7 @@ def blend(
db: Database connector object
ingredients: List of ingredient objects, to use in interpreting BlendSQL query
verbose: Boolean defining whether to run in logging.debug mode
blender: Optionally override whatever llm argument we pass to LLM ingredient.
blender: Optionally override whatever llm argument we pass to Model ingredient.
Useful for research applications, where we don't (necessarily) want the parser to choose endpoints.
infer_gen_constraints: Optionally infer the output format of an `IngredientMap` call, given the predicate context
For example, in `{{LLMMap('convert to date', 'w::listing date')}} <= '1960-12-31'`
Expand Down Expand Up @@ -529,6 +533,7 @@ def blend(
), # Need to do this so we don't track parents into construct_abstracted_selects
prev_subquery_has_ingredient=prev_subquery_has_ingredient,
alias_to_subquery={table_alias_name: subquery} if in_cte else None,
tables_in_ingredients=tables_in_ingredients,
)
for tablename, abstracted_query in scm.abstracted_table_selects():
# If this table isn't being used in any ingredient calls, there's no
Expand Down
25 changes: 12 additions & 13 deletions blendsql/db/sqlite.py
Original file line number Diff line number Diff line change
Expand Up @@ -7,14 +7,11 @@
from functools import lru_cache
import pandas as pd
from attr import attrib, attrs
from colorama import Fore, init
from colorama import Fore

from .utils import single_quote_escape, double_quote_escape


init(autoreset=True)


@attrs(auto_detect=True)
class SQLite:
"""
Expand Down Expand Up @@ -96,8 +93,16 @@ def to_serialized(
num_rows: int = 0,
table_description: str = None,
) -> str:
"""Generates a string representation of a database, via `CREATE` statements.
This can then be passed to a LLM as context.
Args:
ignore_tables: Name of tables to ignore in serialization. Default is just 'documents'.
num_rows: How many rows per table to include in serialization
table_description: Optional table description to add at top
"""
if ignore_tables is None:
ignore_tables = set()
ignore_tables = {"documents"}
serialized_db = (
[]
if table_description is None
Expand All @@ -108,19 +113,13 @@ def to_serialized(
continue
serialized_db.append(create_clause)
serialized_db.append("\n")
if num_rows > 0 and tablename != "docs":
if num_rows > 0:
get_rows_query = (
f'SELECT * FROM "{double_quote_escape(tablename)}" LIMIT {num_rows}'
)
serialized_db.append("\n/*")
serialized_db.append(f"\n{num_rows} example rows:")
serialized_db.append(f"\n{get_rows_query}")
elif tablename == "docs":
get_rows_query = (
f'SELECT DISTINCT title FROM "{double_quote_escape(tablename)}"'
)
serialized_db.append("\n/*")
serialized_db.append(f"\n{get_rows_query}")
else:
continue
rows = self.execute_query(get_rows_query)
Expand Down Expand Up @@ -150,7 +149,7 @@ def execute_query(
return df
except Exception as e:
if silence_errors:
print(Fore.RED + "Something went wrong!")
print(Fore.RED + "Something went wrong!" + Fore.RESET)
print(e)
if return_error:
return (pd.DataFrame(), str(e))
Expand Down
1 change: 1 addition & 0 deletions blendsql/ingredients/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -4,6 +4,7 @@
JoinIngredient,
StringIngredient,
QAIngredient,
IngredientException,
)

from .builtin import LLMQA, LLMJoin, LLMMap, LLMValidate, DT
1 change: 1 addition & 0 deletions blendsql/ingredients/builtin/dt/utils.py
Original file line number Diff line number Diff line change
Expand Up @@ -2,6 +2,7 @@
from fiscalyear import FiscalDateTime, FiscalQuarter
from dateutil.relativedelta import relativedelta
from datetime import datetime, timedelta
import dateparser


def initialize_date_map(today_dt: datetime):
Expand Down
6 changes: 3 additions & 3 deletions blendsql/ingredients/builtin/llm/join/main.py
Original file line number Diff line number Diff line change
Expand Up @@ -2,7 +2,7 @@
from guidance import gen
from textwrap import dedent

from blendsql.llms._llm import LLM
from blendsql.models._model import Model
from blendsql._program import Program
from blendsql import _constants as CONST
from blendsql.ingredients.ingredient import JoinIngredient
Expand Down Expand Up @@ -89,11 +89,11 @@ def run(
self,
left_values: List[str],
right_values: List[str],
llm: LLM,
model: Model,
join_criteria: str = "Join to same topics.",
**kwargs,
) -> dict:
res = llm.predict(
res = model.predict(
program=JoinProgram,
sep=CONST.DEFAULT_ANS_SEP,
left_values=left_values,
Expand Down
Loading

0 comments on commit 0365d38

Please sign in to comment.