Skip to content

Commit

Permalink
Merge pull request #24 from parkervg/mypy-fixes
Browse files Browse the repository at this point in the history
`mypy` fixes, using `singledispatch` to route model generation behavior
  • Loading branch information
parkervg authored Jun 17, 2024
2 parents 76fa0c1 + f1a6c8f commit 4f3447c
Show file tree
Hide file tree
Showing 44 changed files with 587 additions and 556 deletions.
2 changes: 2 additions & 0 deletions .gitignore
Original file line number Diff line number Diff line change
@@ -1,5 +1,7 @@
.chainlit/
.files/
.zed/
pyrightconfig.json
proxies.py
site/
tests/data/
Expand Down
10 changes: 5 additions & 5 deletions blendsql/_constants.py
Original file line number Diff line number Diff line change
Expand Up @@ -23,8 +23,8 @@ class IngredientType(str, Enum, metaclass=StrInMeta):

@dataclass
class IngredientKwarg:
QUESTION = "question"
CONTEXT = "context"
VALUES = "values"
OPTIONS = "options"
MODEL = "model"
QUESTION: str = "question"
CONTEXT: str = "context"
VALUES: str = "values"
OPTIONS: str = "options"
MODEL: str = "model"
42 changes: 4 additions & 38 deletions blendsql/_program.py
Original file line number Diff line number Diff line change
@@ -1,17 +1,14 @@
from __future__ import annotations
from typing import Tuple
from typing import Tuple, Type
import inspect
from outlines.models import LogitsGenerator
import ast
import textwrap
import logging
from colorama import Fore
from abc import abstractmethod

from typing import TYPE_CHECKING

if TYPE_CHECKING:
from .models import Model
from ._logger import logger


class Program:
Expand All @@ -36,7 +33,7 @@ def __call__(self, model: Model, context: pd.DataFrame) -> Tuple[str, str]:
prompt = f"Summarize the following table. {context.to_string()}"
# Below we follow the outlines pattern for unconstrained text generation
# https://github.com/outlines-dev/outlines
generator = outlines.generate.text(model.logits_generator)
generator = outlines.generate.text(model.model_obj)
response: str = generator(prompt)
# Finally, return (response, prompt) tuple
# Returning the prompt here allows the underlying BlendSQL classes to track token usage
Expand Down Expand Up @@ -64,38 +61,7 @@ def __call__(self, model: Model, *args, **kwargs) -> Tuple[str, str]:
...


def return_ollama_response(
logits_generator: LogitsGenerator, prompt, **kwargs
) -> Tuple[str, str]:
"""Helper function to work with Ollama models,
since they're not recognized in the Outlines ecosystem.
"""
from ollama import Options

options = Options(**kwargs)
if options.get("temperature") is None:
options["temperature"] = 0.0
stream = logger.level <= logging.DEBUG
response = logits_generator(
messages=[{"role": "user", "content": prompt}],
options=options,
stream=stream,
)
if stream:
chunked_res = []
for chunk in response:
chunked_res.append(chunk["message"]["content"])
print(
Fore.CYAN + chunk["message"]["content"] + Fore.RESET,
end="",
flush=True,
)
print("\n")
return ("".join(chunked_res), prompt)
return (response["message"]["content"], prompt)


def program_to_str(program: Program):
def program_to_str(program: Type[Program]) -> str:
"""Create a string representation of a program.
It is slightly tricky, since in addition to getting the code content, we need to
1) identify all global variables referenced within a function, and then
Expand Down
6 changes: 3 additions & 3 deletions blendsql/_smoothie.py
Original file line number Diff line number Diff line change
@@ -1,5 +1,5 @@
from dataclasses import dataclass, field
from typing import List, Collection
from typing import List, Iterable, Type
import pandas as pd

from .ingredients import Ingredient
Expand All @@ -20,8 +20,8 @@ class SmoothieMeta:
num_values_passed: int # Number of values passed to a Map/Join/QA ingredient
prompt_tokens: int
completion_tokens: int
prompts: List[str] # Log of prompts submitted to model
ingredients: Collection[Ingredient]
prompts: List[dict] # Log of prompts submitted to model
ingredients: Iterable[Type[Ingredient]]
query: str
db_url: str
contains_ingredient: bool = True
Expand Down
91 changes: 70 additions & 21 deletions blendsql/_sqlglot.py
Original file line number Diff line number Diff line change
@@ -1,7 +1,19 @@
import sqlglot
from sqlglot import exp, Schema
from sqlglot.optimizer.scope import build_scope
from typing import Generator, List, Set, Tuple, Union, Callable, Type, Optional
from typing import (
Generator,
List,
Set,
Tuple,
Union,
Callable,
Type,
Optional,
Dict,
Any,
Literal,
)
from ast import literal_eval
from sqlglot.optimizer.scope import find_all_in_scope, find_in_scope
from attr import attrs, attrib
Expand Down Expand Up @@ -278,6 +290,8 @@ def extract_multi_table_predicates(
def get_first_child(node):
"""
Helper function to get first child of a node.
The default argument to `walk()` is bfs=True,
meaning we do breadth-first search.
"""
gen = node.walk()
_ = next(gen)
Expand Down Expand Up @@ -411,9 +425,11 @@ def get_scope_nodes(
@attrs
class QueryContextManager:
"""Handles manipulation of underlying SQL query.
We need to maintain two representations here:
- The underlying sqlglot exp.Expression
- The string representation of the query
We need to maintain two synced representations here:
1) The underlying sqlglot exp.Expression node
2) The string representation of the query
"""

node: exp.Expression = attrib(default=None)
Expand All @@ -425,6 +441,7 @@ def parse(self, query, schema: Optional[Union[dict, Schema]] = None):
self.node = _parse_one(query, schema=schema)

def to_string(self):
# Only call `recover_blendsql` if we need to
if hash(self.node) != hash(self._last_to_string_node):
self._query = recover_blendsql(self.node.sql(dialect=FTS5SQLite))
self.last_to_string_node = self.node
Expand All @@ -447,8 +464,6 @@ class SubqueryContextManager:
root: sqlglot.optimizer.scope.Scope = attrib(init=False)

def __attrs_post_init__(self):
if self.alias_to_subquery is None:
self.alias_to_subquery = {}
self.alias_to_tablename = {}
self.tablename_to_alias = {}
# https://github.com/tobymao/sqlglot/blob/v20.9.0/posts/ast_primer.md#scope
Expand All @@ -474,8 +489,18 @@ def abstracted_table_selects(self) -> Generator[Tuple[str, str], None, None]:
abstracted_queries: Generator with (tablename, exp.Select, alias_to_table). The exp.Select is the abstracted query.
Examples:
>>> {{Model('is this an italian restaurant?', 'transactions::merchant', endpoint_name='gpt-4')}} = 1 AND child_category = 'Restaurants & Dining'
```python
scm = SubqueryContextManager(
node=_parse_one(
"SELECT * FROM transactions WHERE {{Model('is this an italian restaurant?', 'transactions::merchant')}} = TRUE AND child_category = 'Restaurants & Dining'"
)
)
scm.abstracted_table_selects()
```
Returns:
```text
('transactions', 'SELECT * FROM transactions WHERE TRUE AND child_category = \'Restaurants & Dining\'')
```
"""
# TODO: don't really know how to optimize with 'CASE' queries right now
if self.node.find(exp.Case):
Expand Down Expand Up @@ -573,13 +598,18 @@ def _table_star_queries(
table_star_queries: Generator with (tablename, exp.Select). The exp.Select is the table_star query
Examples:
>>> SELECT "Run Date", Account, Action, ROUND("Amount ($)", 2) AS 'Total Dividend Payout ($$)', Name
>>> FROM account_history
>>> LEFT JOIN constituents ON account_history.Symbol = constituents.Symbol
>>> WHERE constituents.Sector = 'Information Technology'
>>> AND lower(Action) like "%dividend%"
```sql
SELECT "Run Date", Account, Action, ROUND("Amount ($)", 2) AS 'Total Dividend Payout ($$)', Name
FROM account_history
LEFT JOIN constituents ON account_history.Symbol = constituents.Symbol
WHERE constituents.Sector = 'Information Technology'
AND lower(Action) like "%dividend%"
```
Returns (after getting str representation of `exp.Select`):
```text
('account_history', 'SELECT * FROM account_history WHERE lower(Action) like "%dividend%')
('constituents', 'SELECT * FROM constituents WHERE sector = \'Information Technology\'')
```
"""
# Use `scope` to get all unique tablenodes in ast
tablenodes = set(
Expand All @@ -600,8 +630,10 @@ def _table_star_queries(
alias = subquery_node.args["alias"]
if alias is None:
# Try to get from parent
if "alias" in subquery_node.parent.args:
alias = subquery_node.parent.args["alias"]
parent_node = subquery_node.parent
if parent_node is not None:
if "alias" in parent_node.args:
alias = parent_node.args["alias"]
if alias is not None:
if not any(x.name == alias.name for x in tablenodes):
tablenodes.add(exp.Table(this=exp.Identifier(this=alias.name)))
Expand Down Expand Up @@ -672,18 +704,35 @@ def infer_gen_constraints(self, start: int, end: int) -> dict:
downstream Model generations.
For example:
>>> SELECT * FROM w WHERE {{LLMMap('Is this true?', 'w::colname')}}
```sql
SELECT * FROM w WHERE {{LLMMap('Is this true?', 'w::colname')}}
```
We can infer given the structure above that we expect `LLMMap` to return a boolean.
This function identifies that.
Arguments:
indices: The string indices pointing to the span within the overall BlendSQL query
containing our ingredient in question.
Returns:
dict, with keys:
output_type: numeric | string | boolean
pattern: regular expression pattern to use in constrained decoding with Model
- output_type
- 'boolean' | 'integer' | 'float' | 'string'
- pattern: regular expression pattern lambda to use in constrained decoding with Model
- See `create_pattern` for more info on these pattern lambdas
"""

def create_pattern(output_type: str) -> Callable[[int], str]:
def create_pattern(
output_type: Literal["boolean", "integer", "float"]
) -> Callable[[int], str]:
"""Helper function to create a pattern lambda.
These pattern lambdas take an integer (num_repeats) and return
a regex pattern which is restricted to repeat exclusively num_repeats times.
"""
if output_type == "boolean":
base_pattern = f"((t|f|{DEFAULT_NAN_ANS}){DEFAULT_ANS_SEP})"
elif output_type == "integer":
Expand All @@ -696,7 +745,7 @@ def create_pattern(output_type: str) -> Callable[[int], str]:
raise ValueError(f"Unknown output_type {output_type}")
return lambda num_repeats: base_pattern + "{" + str(num_repeats) + "}"

added_kwargs = {}
added_kwargs: Dict[str, Any] = {}
ingredient_node = _parse_one(self.sql()[start:end])
child = None
for child, _, _ in self.node.walk():
Expand All @@ -710,7 +759,7 @@ def create_pattern(output_type: str) -> Callable[[int], str]:
# Example: CAST({{LLMMap('jump distance', 'w::notes')}} AS FLOAT)
while isinstance(start_node, exp.Func) and start_node is not None:
start_node = start_node.parent
output_type = None
output_type: Literal["boolean", "integer", "float"] = None
predicate_literals: List[str] = []
if start_node is not None:
predicate_literals = get_predicate_literals(start_node)
Expand All @@ -736,7 +785,7 @@ def create_pattern(output_type: str) -> Callable[[int], str]:
elif isinstance(
ingredient_node_in_context.parent, (exp.Order, exp.Ordered, exp.AggFunc)
):
output_type = "float" # Use 'float' as default numeric pattern
output_type = "float" # Use 'float' as default numeric pattern, since it's more expressive than 'integer'
if output_type is not None:
added_kwargs["output_type"] = output_type
added_kwargs["pattern"] = create_pattern(output_type)
Expand Down
38 changes: 19 additions & 19 deletions blendsql/blend.py
Original file line number Diff line number Diff line change
Expand Up @@ -12,9 +12,9 @@
Generator,
Optional,
Callable,
Collection,
Type,
)
from collections.abc import Collection, Iterable
from sqlite3 import OperationalError
from attr import attrs, attrib
from functools import partial
Expand Down Expand Up @@ -64,23 +64,23 @@ class Kitchen(list):
db: Database = attrib()
session_uuid: str = attrib()

added_ingredient_names: set = attrib(init=False)
name_to_ingredient: Dict[str, Ingredient] = attrib(init=False)

def __attrs_post_init__(self):
self.added_ingredient_names = set()
self.name_to_ingredient = {}

def names(self):
return [i.name for i in self]

def get_from_name(self, name: str):
for f in self:
if f.name == name.upper():
return f
raise InvalidBlendSQL(
f"Ingredient '{name}' called, but not found in passed `ingredient` arg!"
)
try:
return self.name_to_ingredient[name.upper()]
except KeyError:
raise InvalidBlendSQL(
f"Ingredient '{name}' called, but not found in passed `ingredient` arg!"
) from None

def extend(self, ingredients: Collection[Type[Ingredient]]) -> None:
def extend(self, ingredients: Iterable[Type[Ingredient]]) -> None:
""" "Initializes ingredients class with base attributes, for use in later operations."""
try:
if not all(issubclass(x, Ingredient) for x in ingredients):
Expand All @@ -94,17 +94,18 @@ def extend(self, ingredients: Collection[Type[Ingredient]]) -> None:
for ingredient in ingredients:
name = ingredient.__name__.upper()
assert (
name not in self.added_ingredient_names
name not in self.name_to_ingredient
), f"Duplicate ingredient names passed! These are case insensitive, be careful.\n{name}"
ingredient = ingredient(
# Initialize the ingredient, going from `Type[Ingredient]` to `Ingredient`
initialied_ingredient: Ingredient = ingredient(
name=name,
# Add db and session_uuid as default kwargs
# This way, ingredients are able to interact with data
db=self.db,
session_uuid=self.session_uuid,
)
self.added_ingredient_names.add(name)
self.append(ingredient)
self.name_to_ingredient[name] = initialied_ingredient
self.append(initialied_ingredient)


def autowrap_query(
Expand Down Expand Up @@ -499,7 +500,7 @@ def _blend(
subquery_str
), # Need to do this so we don't track parents into construct_abstracted_selects
prev_subquery_has_ingredient=prev_subquery_has_ingredient,
alias_to_subquery={table_alias_name: subquery} if in_cte else None,
alias_to_subquery={table_alias_name: subquery} if in_cte else {},
tables_in_ingredients=tables_in_ingredients,
)
for tablename, abstracted_query in scm.abstracted_table_selects():
Expand Down Expand Up @@ -705,10 +706,9 @@ def _blend(
temp_join_tablename,
) = function_out
# Special case for when we have more than 1 ingredient in `JOIN` node left at this point
num_ingredients_in_join = (
len(list(query_context.node.find(exp.Join).find_all(exp.Struct)))
// 2
)
join_node = query_context.node.find(exp.Join)
assert join_node is not None
num_ingredients_in_join = len(list(join_node.find_all(exp.Struct))) // 2
if num_ingredients_in_join > 1:
# Case where we have
# `SELECT * FROM w0 JOIN w0 ON {{B()}} > 1 AND {{A()}} WHERE TRUE`
Expand Down
4 changes: 2 additions & 2 deletions blendsql/db/_database.py
Original file line number Diff line number Diff line change
@@ -1,5 +1,5 @@
from typing import Generator, Union, List, Collection, Optional
from typing_extensions import Callable
from typing import Generator, Union, List, Callable, Optional
from collections.abc import Collection
import pandas as pd
from attr import attrib
from sqlalchemy.engine import URL
Expand Down
Loading

0 comments on commit 4f3447c

Please sign in to comment.