Skip to content

Commit

Permalink
Merge pull request #10 from parkervg/sqlalchemy
Browse files Browse the repository at this point in the history
SQLalchemy Refactor - PostgreSQL support, `CREATE TEMP TABLE` logic, Guidance LlamaCpp Model class
  • Loading branch information
parkervg authored May 13, 2024
2 parents 8090cd4 + 061c872 commit 99ff6bb
Show file tree
Hide file tree
Showing 96 changed files with 7,472 additions and 3,544 deletions.
3 changes: 3 additions & 0 deletions .gitignore
Original file line number Diff line number Diff line change
@@ -1,3 +1,5 @@
.chainlit/
.files/
proxies.py
site/
tests/data/
Expand Down Expand Up @@ -104,3 +106,4 @@ target/

# kernprof line profiling
*.lprof
/lark-constrained-parsing/
356 changes: 225 additions & 131 deletions README.md

Large diffs are not rendered by default.

78 changes: 78 additions & 0 deletions app.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,78 @@
import chainlit as cl
from dotenv import load_dotenv
import textwrap

import json
from chainlit import make_async
from blendsql import blend, LLMQA, LLMJoin, LLMMap
from blendsql.models import AzureOpenaiLLM
from blendsql.db import SQLite
from research.prompts.parser_program import ParserProgram
from research.utils.database import to_serialized

# Load variables from the local .env file into the process environment.
# NOTE(review): presumably the Azure OpenAI credentials live there - confirm.
load_dotenv(".env")


def _read_prompt(path):
    """Return the full text of the prompt file at *path*.

    The file is opened in a ``with`` block so the handle is closed
    deterministically instead of being left for the garbage collector.
    """
    with open(path) as prompt_file:
        return prompt_file.read()


DB_PATH = "./research/db/hybridqa/2004_United_States_Grand_Prix_0.db"
# NOTE(review): check_same_thread=False is presumably needed because queries
# are executed through make_async, i.e. off the thread that opened the
# connection - confirm against the SQLite wrapper.
db = SQLite(DB_PATH, check_same_thread=False)
few_shot_prompt = _read_prompt("./research/prompts/hybridqa/few_shot.txt")
ingredients_prompt = _read_prompt("./research/prompts/hybridqa/ingredients.txt")
# Serialized view of the database (num_rows=3) handed to the parser prompt.
serialized_db = to_serialized(db, num_rows=3)


def fewshot_parse(model, **input_program_args):
    """Run ``ParserProgram`` on *model* and return its dedented result.

    Every string-valued keyword argument is passed through
    ``textwrap.dedent`` before being forwarded to ``model.predict``;
    non-string values are forwarded untouched.

    Args:
        model: Object exposing ``predict(program=..., **kwargs)`` whose return
            value supports ``["result"]`` lookup.
        **input_program_args: Keyword arguments forwarded to the program.
    Returns:
        The ``"result"`` entry of the prediction, dedented.
    """
    dedented_args = {
        name: textwrap.dedent(value) if isinstance(value, str) else value
        for name, value in input_program_args.items()
    }
    prediction = model.predict(program=ParserProgram, **dedented_args)
    return textwrap.dedent(prediction["result"])


@cl.on_message  # Chainlit invokes this handler for every message typed into the UI
async def main(message: cl.Message):
    """
    Answer one chat message by parsing it to BlendSQL and executing the query.

    Two Chainlit steps are reported as intermediate output: the few-shot parse
    of the question into a BlendSQL query, and the execution of that query
    (shown as the JSON-formatted log of prompts). The values of the first row
    of the resulting dataframe are then sent back, one per line, as the answer.

    Args:
        message: The user's message.
    Returns:
        None.
    """
    question = message.content
    parser_model = AzureOpenaiLLM("gpt-4")
    blender_model = AzureOpenaiLLM("gpt-4")

    # Step 1: natural-language question -> BlendSQL query.
    async with cl.Step(
        name="Fewshot Parse to BlendSQL", language="sql", type="llm"
    ) as parser_step:
        parser_step.input = question
        # fewshot_parse is synchronous; make_async lets it be awaited here.
        parse_to_blendsql = make_async(fewshot_parse)
        blendsql_query = await parse_to_blendsql(
            model=parser_model,
            ingredients_prompt=ingredients_prompt,
            few_shot_prompt=few_shot_prompt,
            serialized_db=serialized_db,
            question=question,
        )
        parser_step.output = blendsql_query

    # Step 2: execute the generated query against the database.
    async with cl.Step(
        name="Execute BlendSQL Script", language="json", type="llm"
    ) as blender_step:
        blender_step.input = blendsql_query
        run_blend = make_async(blend)
        res = await run_blend(
            query=blendsql_query,
            db=db,
            ingredients={LLMMap, LLMQA, LLMJoin},
            blender=blender_model,
            infer_gen_constraints=True,
            verbose=False,
        )
        blender_step.output = json.dumps(res.meta.prompts, indent=4)

    # Send the final answer: only the first result row is reported.
    if res.df.empty:
        reply = "Empty response."
    else:
        reply = "\n".join(str(cell) for cell in res.df.values[0])
    await cl.Message(content=reply).send()
3 changes: 2 additions & 1 deletion blendsql/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -2,4 +2,5 @@


from .ingredients.builtin import LLMMap, LLMQA, LLMJoin, DT, LLMValidate
from .blendsql import blend
from .blend import blend
from .nl_to_blendsql import nl_to_blendsql
14 changes: 2 additions & 12 deletions blendsql/_constants.py
Original file line number Diff line number Diff line change
Expand Up @@ -9,20 +9,9 @@ def __contains__(cls, item):
return item in cls.__members__.values()


OPENAI_COMPLETE_LLM = ["text-davinci-003"]
OPENAI_CHAT_LLM = [
"gpt-4",
"gpt-4-32k",
"gpt-4-32k-0613",
"gpt-35-turbo",
"gpt-35-turbo-0613",
"gpt-35-turbo-16k-0613",
"gpt-35-turbo-instruct-0914",
]

DEFAULT_ANS_SEP = ";"
DEFAULT_NAN_ANS = "-"
VALUE_BATCH_SIZE = 10
VALUE_BATCH_SIZE = 5


class IngredientType(str, Enum, metaclass=StrInMeta):
Expand All @@ -36,5 +25,6 @@ class IngredientType(str, Enum, metaclass=StrInMeta):
class IngredientKwarg:
QUESTION = "question"
CONTEXT = "context"
VALUES = "values"
OPTIONS = "options"
MODEL = "model"
34 changes: 22 additions & 12 deletions blendsql/_program.py
Original file line number Diff line number Diff line change
Expand Up @@ -3,39 +3,48 @@
https://github.com/guidance-ai/guidance
"""
from typing import Optional
from guidance.models import Model, Chat
from guidance.models import Model as GuidanceModel
from guidance.models import Chat as GuidanceChatModel
from guidance import user, system, assistant
from contextlib import nullcontext
import inspect
import ast
import textwrap


def get_contexts(model: Model):
def get_contexts(model: GuidanceModel):
usercontext = nullcontext()
systemcontext = nullcontext()
assistantcontext = nullcontext()
if isinstance(model, Chat):
if isinstance(model, GuidanceChatModel):
usercontext = user()
systemcontext = system()
assistantcontext = assistant()
return (usercontext, systemcontext, assistantcontext)


class Program:
"""
TODO: how to add streaming?
if isinstance(res, ModelStream):
# Fetch actual model by iterating
# https://github.com/guidance-ai/guidance/blob/main/tests/models/test_model.py#L85
for idx, event in enumerate(res):
print(Fore.LIGHTCYAN_EX + str(event) + Fore.RESET)
res = event
"""

def __new__(
self,
model: Model,
model: GuidanceModel,
gen_kwargs: Optional[dict] = None,
few_shot: bool = True,
**kwargs,
):
self.model = model
self.gen_kwargs = gen_kwargs if gen_kwargs is not None else {}
self.gen_kwargs = gen_kwargs
self.gen_kwargs = {} if gen_kwargs is None else gen_kwargs
self.few_shot = few_shot
assert isinstance(
self.model, Model
), f"GuidanceProgram needs a guidance.models.Model object!\nGot {type(self.model)}"
(
self.usercontext,
self.systemcontext,
Expand Down Expand Up @@ -66,18 +75,19 @@ def program_to_str(program: Program):
Some helpful refs:
- https://github.com/universe-proton/universe-topology/issues/15
"""
call_content = textwrap.dedent(inspect.getsource(program.__call__))
source_func = program.__call__
call_content = textwrap.dedent(inspect.getsource(source_func))
root = ast.parse(call_content)
root_names = {node.id for node in ast.walk(root) if isinstance(node, ast.Name)}
co_varnames = set(program.__call__.__code__.co_varnames)
co_varnames = set(source_func.__code__.co_varnames)
names_to_resolve = sorted(root_names.difference(co_varnames))
resolved_names = ""
if len(names_to_resolve) > 0:
globals_as_dict = dict(inspect.getmembers(program.__call__))["__globals__"]
globals_as_dict = dict(inspect.getmembers(source_func))["__globals__"]
for name in names_to_resolve:
if name in globals_as_dict:
val = globals_as_dict[name]
# Ignore functions
# Ignore functions - we really only want scalars here
if not callable(val):
resolved_names += f"{val}\n"
return f"{call_content}{resolved_names}"
5 changes: 2 additions & 3 deletions blendsql/_smoothie.py
Original file line number Diff line number Diff line change
@@ -1,5 +1,5 @@
from dataclasses import dataclass
from typing import List, Any, Collection
from typing import List, Collection
import pandas as pd

from .ingredients import Ingredient
Expand All @@ -11,15 +11,14 @@

@dataclass
class SmoothieMeta:
process_time_seconds: float
num_values_passed: int # Number of values passed to a Map/Join/QA ingredient
num_prompt_tokens: int # Number of prompt tokens (counting user and assistant, i.e. input/output)
prompts: List[str] # Log of prompts submitted to model
example_map_outputs: List[Any] # outputs from a Map ingredient, for debugging
ingredients: Collection[Ingredient]
query: str
db_path: str
contains_ingredient: bool = True
process_time_seconds: float = None


@dataclass
Expand Down
7 changes: 5 additions & 2 deletions blendsql/_sqlglot.py
Original file line number Diff line number Diff line change
Expand Up @@ -85,8 +85,9 @@ def prune_empty_where(node) -> Union[exp.Expression, None]:
# Don't check *all* predicates here
# Since 'WHERE a = TRUE' is valid and should be kept
elif isinstance(node, exp.In):
if isinstance(node.args["query"], exp.Boolean):
return None
if "query" in node.args:
if isinstance(node.args["query"], exp.Boolean):
return None
return node


Expand Down Expand Up @@ -323,7 +324,9 @@ def get_reversed_subqueries(node):
"""Iterates through all subqueries (either parentheses or select).
Reverses all EXCEPT for CTEs, which should remain in order.
"""
# First, fetch all common table expressions
r = [i for i in node.find_all(SUBQUERY_EXP + (exp.Paren,)) if is_in_cte(i)]
# Then, add (reversed) other subqueries
return (
r
+ [i for i in node.find_all(SUBQUERY_EXP + (exp.Paren,)) if not is_in_cte(i)][
Expand Down
Loading

0 comments on commit 99ff6bb

Please sign in to comment.