Skip to content

Commit

Permalink
Merge pull request #4 from parkervg/options-accept-query
Browse files Browse the repository at this point in the history
`options` Accepts Query
  • Loading branch information
parkervg authored Feb 22, 2024
2 parents 71105de + 88ed9a6 commit 8f9c403
Show file tree
Hide file tree
Showing 7 changed files with 134 additions and 54 deletions.
4 changes: 3 additions & 1 deletion blendsql/_constants.py
Original file line number Diff line number Diff line change
@@ -1,4 +1,5 @@
from enum import Enum, EnumMeta, auto
from dataclasses import dataclass

HF_REPO_ID = "parkervg/blendsql-test-dbs"

Expand Down Expand Up @@ -33,7 +34,8 @@ class IngredientType(str, Enum, metaclass=StrInMeta):
JOIN = auto()


class IngredientKwarg(str, Enum):
@dataclass
class IngredientKwarg:
QUESTION = "question"
CONTEXT = "context"
OPTIONS = "options"
75 changes: 39 additions & 36 deletions blendsql/blendsql.py
Original file line number Diff line number Diff line change
Expand Up @@ -159,11 +159,16 @@ def preprocess_blendsql(query: str) -> Tuple[str, dict]:
for arg_type in {"args", "kwargs"}:
for idx in range(len(parsed_results_dict[arg_type])):
curr_arg = parsed_results_dict[arg_type][idx]
curr_arg = curr_arg[-1] if arg_type == "kwargs" else curr_arg
if not isinstance(curr_arg, str):
continue
parsed_results_dict[arg_type][idx] = re.sub(
formatted_curr_arg = re.sub(
r"(^\()(.*)(\)$)", r"\2", curr_arg
).strip()
if arg_type == "args":
parsed_results_dict[arg_type][idx] = formatted_curr_arg
else:
parsed_results_dict[arg_type][idx][-1] = formatted_curr_arg
# Below we track the 'raw' representation, in case we need to pass into
# a recursive BlendSQL call later
ingredient_alias_to_parsed_dict[
Expand Down Expand Up @@ -507,47 +512,40 @@ def blend(
+ Fore.RESET
)
kwargs_dict[k] = v
# Handle llm, make sure we initialize it if it's a string
# llm = kwargs_dict.get("llm", None)
# llm_type = kwargs_dict.get("llm_type", None)
# if llm is not None:
# if not isinstance(llm, LLM):
# kwargs_dict["llm"] = initialize_llm(llm_type, llm)
# else:
# kwargs_dict["llm"] = initialize_llm(
# DEFAULT_LLM_TYPE, DEFAULT_LLM_NAME
# )
kwargs_dict["llm"] = blender

kwargs_dict["llm"] = blender
if _function.ingredient_type == IngredientType.MAP:
# Latter is the winner.
# So if we already define something in kwargs_dict,
# It's not overridden here
kwargs_dict = (
scm.infer_map_constraints(
start=start,
end=end,
if infer_map_constraints:
# Latter is the winner.
# So if we already define something in kwargs_dict,
# It's not overridden here
kwargs_dict = (
scm.infer_map_constraints(
start=start,
end=end,
)
| kwargs_dict
)
| kwargs_dict
if infer_map_constraints
else kwargs_dict
)

if table_to_title is not None:
kwargs_dict["table_to_title"] = table_to_title

if _function.ingredient_type == IngredientType.QA:
# Optionally, recursively call blend() again to get subtable
context_arg = kwargs_dict.get(
IngredientKwarg.CONTEXT,
parse_results_dict["args"][1]
if len(parse_results_dict["args"]) > 1
else parse_results_dict["args"][0]
if len(parse_results_dict["args"]) > 0
# Optionally, recursively call blend() again to get subtable
# This applies to `context` and `options`
for i, unpack_kwarg in enumerate(
[IngredientKwarg.CONTEXT, IngredientKwarg.OPTIONS]
):
unpack_value = kwargs_dict.get(
unpack_kwarg,
parse_results_dict["args"][i + 1]
if len(parse_results_dict["args"]) > i + 1
else parse_results_dict["args"][i]
if len(parse_results_dict["args"]) > i
else "",
)
if context_arg.upper().startswith(("SELECT", "WITH")):
if unpack_value.upper().startswith(("SELECT", "WITH")):
_smoothie = blend(
query=context_arg,
query=unpack_value,
db=db,
blender=blender,
ingredients=ingredients,
Expand All @@ -561,9 +559,14 @@ def blend(
)
_prev_passed_values = _smoothie.meta.num_values_passed
subtable = _smoothie.df
kwargs_dict[IngredientKwarg.CONTEXT] = subtable
# Below, we can remove the optional `context` arg we passed in args
parse_results_dict["args"] = parse_results_dict["args"][:1]
if unpack_kwarg == IngredientKwarg.OPTIONS:
# Here, we need to flatten the subtable values into a list
kwargs_dict[unpack_kwarg] = list(subtable.values.flat)
else:
kwargs_dict[unpack_kwarg] = subtable
# Below, we can remove the optional `context` arg we passed in args
parse_results_dict["args"] = parse_results_dict["args"][:1]

# Execute our ingredient function
function_out = _function(
*parse_results_dict["args"],
Expand Down
4 changes: 2 additions & 2 deletions blendsql/ingredients/builtin/llm/qa/main.py
Original file line number Diff line number Diff line change
@@ -1,4 +1,4 @@
from typing import Dict, Union, Optional, List
from typing import Dict, Union, Optional, Set

import pandas as pd
from blendsql.llms._llm import LLM
Expand All @@ -12,7 +12,7 @@ def run(
self,
question: str,
llm: LLM,
options: Optional[List[str]] = None,
options: Optional[Set[str]] = None,
context: Optional[pd.DataFrame] = None,
value_limit: Optional[int] = None,
table_to_title: Optional[Dict[str, str]] = None,
Expand Down
22 changes: 12 additions & 10 deletions blendsql/ingredients/ingredient.py
Original file line number Diff line number Diff line change
Expand Up @@ -312,16 +312,18 @@ def __call__(
raise IngredientException("Empty subtable passed to QAIngredient!")
unpacked_options = options
if options is not None:
try:
tablename, colname = utils.get_tablename_colname(options)
tablename = kwargs.get("aliases_to_tablenames").get(
tablename, tablename
)
unpacked_options = self.db.execute_query(
f'SELECT DISTINCT "{colname}" FROM "{tablename}"'
)[colname].tolist()
except ValueError:
unpacked_options = options.split(";")
if not isinstance(options, list):
try:
tablename, colname = utils.get_tablename_colname(options)
tablename = kwargs.get("aliases_to_tablenames").get(
tablename, tablename
)
unpacked_options = self.db.execute_query(
f'SELECT DISTINCT "{colname}" FROM "{tablename}"'
)[colname].tolist()
except ValueError:
unpacked_options = options.split(";")
unpacked_options = set(unpacked_options)
self.num_values_passed += len(subtable) if subtable is not None else 0
kwargs[IngredientKwarg.OPTIONS] = unpacked_options
kwargs[IngredientKwarg.CONTEXT] = subtable
Expand Down
38 changes: 37 additions & 1 deletion tests/test_multi_table_blendsql.py
Original file line number Diff line number Diff line change
@@ -1,8 +1,8 @@
import pytest
from blendsql import blend
from blendsql.db import SQLiteDBConnector
from blendsql.utils import fetch_from_hub
from tests.utils import (
fetch_from_hub,
assert_equality,
starts_with,
get_length,
Expand Down Expand Up @@ -300,6 +300,42 @@ def test_cte_qa_multi_exec(db, ingredients):
# assert smoothie.meta.num_values_passed == passed_to_ingredient.values[0].item()


def test_cte_qa_named_multi_exec(db, ingredients):
    """Run a QA ingredient whose CTE is supplied via the named `context=` kwarg.

    The BlendSQL result is checked against a plain-SQL `COUNT(*)` over the
    same CTE using `assert_equality`.
    """
    expected_sql = """
WITH a AS (
SELECT * FROM (SELECT DISTINCT * FROM portfolio) as w
WHERE w.Symbol LIKE 'F%'
) SELECT COUNT(*) FROM a WHERE LENGTH(a.Symbol) > 2
"""
    blendsql_query = """
{{
get_table_size(
question='Table size?',
context=(
WITH a AS (
SELECT * FROM (SELECT DISTINCT * FROM portfolio) as w
WHERE {{starts_with('F', 'w::Symbol')}} = TRUE
) SELECT * FROM a WHERE LENGTH(a.Symbol) > 2
)
)
}}
"""
    # Execute the BlendSQL first, then the reference SQL, as before.
    smoothie = blend(query=blendsql_query, db=db, ingredients=ingredients)
    assert_equality(
        smoothie=smoothie,
        sql_df=db.execute_query(expected_sql),
        args=["F"],
    )
# Make sure we only pass what's necessary to our ingredient
# passed_to_ingredient = db.execute_query(
# """
# SELECT COUNT(DISTINCT Symbol) FROM portfolio WHERE LENGTH(Symbol) > 3 AND Quantity > 200
# """
# )
# assert smoothie.meta.num_values_passed == passed_to_ingredient.values[0].item()


# def test_subquery_alias_with_join_multi_exec(db, ingredients):
# blendsql = """
# SELECT w."Percent of Account" FROM (SELECT * FROM "portfolio" WHERE Quantity > 200 OR "Today''s Gain/Loss Percent" > 0.05) as w
Expand Down
31 changes: 29 additions & 2 deletions tests/test_single_table_blendsql.py
Original file line number Diff line number Diff line change
@@ -1,13 +1,14 @@
import pytest
from blendsql import blend
from blendsql.db import SQLiteDBConnector
from blendsql.utils import fetch_from_hub
from tests.utils import (
fetch_from_hub,
assert_equality,
starts_with,
get_length,
select_first_sorted,
get_table_size,
select_first_option,
)


Expand All @@ -18,7 +19,13 @@ def db() -> SQLiteDBConnector:

@pytest.fixture
def ingredients() -> set:
return {starts_with, get_length, select_first_sorted, get_table_size}
return {
starts_with,
get_length,
select_first_sorted,
get_table_size,
select_first_option,
}


def test_simple_exec(db, ingredients):
Expand Down Expand Up @@ -457,5 +464,25 @@ def test_exists_isolated_qa_call(db, ingredients):
assert smoothie.meta.num_values_passed == passed_to_ingredient.values[0].item()


def test_query_options_arg(db, ingredients):
    """Check that `options=` accepts a subquery instead of a literal value.

    The options subquery only yields 'Paypal', so that is the single value
    the ingredient is expected to return.
    """
    # commit 5ffa26d
    options_query = """
{{
select_first_option(
'I hope this test works',
(SELECT * FROM transactions),
options=(SELECT DISTINCT merchant FROM transactions WHERE merchant = 'Paypal')
)
}}
"""
    result_df = blend(query=options_query, db=db, ingredients=ingredients).df
    assert len(result_df) == 1
    assert result_df.values.flat[0] == "Paypal"


if __name__ == "__main__":
pytest.main()
14 changes: 12 additions & 2 deletions tests/utils.py
Original file line number Diff line number Diff line change
@@ -1,6 +1,7 @@
import pandas as pd
from typing import Iterable, Any, List, Union
from blendsql.ingredients import MapIngredient, QAIngredient, JoinIngredient
from blendsql.db.utils import single_quote_escape


class starts_with(MapIngredient):
Expand All @@ -24,7 +25,7 @@ def run(self, question: str, values: List[str], **kwargs) -> Iterable[int]:


class select_first_sorted(QAIngredient):
def run(self, question: str, options: List[str], **kwargs) -> Iterable[Any]:
def run(self, question: str, options: set, **kwargs) -> Iterable[Any]:
"""Simple test function, equivalent to the following in SQL:
`ORDER BY {colname} LIMIT 1`
"""
Expand All @@ -42,12 +43,21 @@ def run(

class get_table_size(QAIngredient):
def run(
self, question: str, context: pd.DataFrame, options: str = None, **kwargs
self, question: str, context: pd.DataFrame, options: set = None, **kwargs
) -> Union[str, int, float]:
"""Returns the length of the context subtable passed to it."""
return len(context)


class select_first_option(QAIngredient):
    """Test ingredient that answers with the first option in sorted order."""

    def run(
        self, question: str, context: pd.DataFrame, options: set = None, **kwargs
    ) -> Union[str, int, float]:
        """Return the first sorted option, escaped and wrapped in single quotes.

        `question` and `context` are accepted but not used here; `options`
        must not be None.
        """
        assert options is not None
        first_option = sorted(options)[0]
        escaped = single_quote_escape(first_option)
        return f"'{escaped}'"


class do_join(JoinIngredient):
"""A very silly, overcomplicated way to do a traditional SQL join.
But useful for testing.
Expand Down

0 comments on commit 8f9c403

Please sign in to comment.