From 2c72a46306c43b5f4d1034140c5a63d81856d788 Mon Sep 17 00:00:00 2001 From: parkervg Date: Fri, 14 Jun 2024 13:22:45 -0400 Subject: [PATCH 01/10] Lots of mypy fixes --- .zed/settings.json | 0 blendsql/_constants.py | 10 +- blendsql/_program.py | 6 +- blendsql/_smoothie.py | 6 +- blendsql/_sqlglot.py | 23 ++- blendsql/blend.py | 38 ++--- blendsql/db/_database.py | 4 +- blendsql/db/_duckdb.py | 5 +- blendsql/db/_sqlalchemy.py | 7 +- blendsql/grammars/utils.py | 3 +- blendsql/ingredients/builtin/join/main.py | 2 +- blendsql/ingredients/builtin/map/main.py | 4 +- blendsql/ingredients/builtin/qa/main.py | 2 +- blendsql/ingredients/ingredient.py | 38 +++-- blendsql/models/_model.py | 27 +-- blendsql/models/local/_llama_cpp.py | 7 +- blendsql/models/local/_transformers.py | 8 +- blendsql/models/remote/_ollama.py | 2 +- blendsql/models/remote/_openai.py | 18 +- blendsql/nl_to_blendsql/nl_to_blendsql.py | 3 +- blendsql/prompts/_prompts.py | 4 +- blendsql/utils.py | 2 +- examples/benchmarks/with_blendsql.py | 53 ------ examples/benchmarks/without_blendsql.py | 157 ------------------ .../without_blendsql_with_guidance.py | 93 ----------- pyrightconfig.json | 4 + run-debug.py | 124 ++++++++++++++ run-nl-to-blendsql.py | 30 ++++ run-tests.py | 45 +++++ 29 files changed, 321 insertions(+), 404 deletions(-) create mode 100644 .zed/settings.json delete mode 100644 examples/benchmarks/with_blendsql.py delete mode 100644 examples/benchmarks/without_blendsql.py delete mode 100644 examples/benchmarks/without_blendsql_with_guidance.py create mode 100644 pyrightconfig.json create mode 100644 run-debug.py create mode 100644 run-nl-to-blendsql.py create mode 100644 run-tests.py diff --git a/.zed/settings.json b/.zed/settings.json new file mode 100644 index 00000000..e69de29b diff --git a/blendsql/_constants.py b/blendsql/_constants.py index 529816f0..01c455da 100644 --- a/blendsql/_constants.py +++ b/blendsql/_constants.py @@ -23,8 +23,8 @@ class IngredientType(str, Enum, metaclass=StrInMeta): @dataclass class IngredientKwarg: - QUESTION = "question" - CONTEXT = "context" - VALUES = "values" - OPTIONS = "options" - MODEL = "model" + QUESTION: str = "question" + CONTEXT: str = "context" + VALUES: str = "values" + OPTIONS: str = "options" + MODEL: str = "model" diff --git a/blendsql/_program.py b/blendsql/_program.py index 8fb8a35b..88f06de2 100644 --- a/blendsql/_program.py +++ b/blendsql/_program.py @@ -1,12 +1,12 @@ from __future__ import annotations from typing import Tuple import inspect -from outlines.models import LogitsGenerator import ast import textwrap import logging from colorama import Fore from abc import abstractmethod +from functools import partial from typing import TYPE_CHECKING if TYPE_CHECKING: @@ -65,7 +65,7 @@ def __call__(self, model: Model, *args, **kwargs) -> Tuple[str, str]: def return_ollama_response( - logits_generator: LogitsGenerator, prompt, **kwargs + logits_generator: partial, prompt, **kwargs ) -> Tuple[str, str]: """Helper function to work with Ollama models, since they're not recognized in the Outlines ecosystem. @@ -95,7 +95,7 @@ def return_ollama_response( return (response["message"]["content"], prompt) -def program_to_str(program: Program): +def program_to_str(program: Program) -> str: """Create a string representation of a program. 
It is slightly tricky, since in addition to getting the code content, we need to 1) identify all global variables referenced within a function, and then diff --git a/blendsql/_smoothie.py b/blendsql/_smoothie.py index a94bd471..c311a97b 100644 --- a/blendsql/_smoothie.py +++ b/blendsql/_smoothie.py @@ -1,5 +1,5 @@ from dataclasses import dataclass, field -from typing import List, Collection +from typing import List, Iterable, Type import pandas as pd from .ingredients import Ingredient @@ -20,8 +20,8 @@ class SmoothieMeta: num_values_passed: int # Number of values passed to a Map/Join/QA ingredient prompt_tokens: int completion_tokens: int - prompts: List[str] # Log of prompts submitted to model - ingredients: Collection[Ingredient] + prompts: List[dict] # Log of prompts submitted to model + ingredients: Iterable[Type[Ingredient]] query: str db_url: str contains_ingredient: bool = True diff --git a/blendsql/_sqlglot.py b/blendsql/_sqlglot.py index 38cd92cd..a13c94f8 100644 --- a/blendsql/_sqlglot.py +++ b/blendsql/_sqlglot.py @@ -1,7 +1,18 @@ import sqlglot from sqlglot import exp, Schema from sqlglot.optimizer.scope import build_scope -from typing import Generator, List, Set, Tuple, Union, Callable, Type, Optional +from typing import ( + Generator, + List, + Set, + Tuple, + Union, + Callable, + Type, + Optional, + Dict, + Any, +) from ast import literal_eval from sqlglot.optimizer.scope import find_all_in_scope, find_in_scope from attr import attrs, attrib @@ -447,8 +458,6 @@ class SubqueryContextManager: root: sqlglot.optimizer.scope.Scope = attrib(init=False) def __attrs_post_init__(self): - if self.alias_to_subquery is None: - self.alias_to_subquery = {} self.alias_to_tablename = {} self.tablename_to_alias = {} # https://github.com/tobymao/sqlglot/blob/v20.9.0/posts/ast_primer.md#scope @@ -600,8 +609,10 @@ def _table_star_queries( alias = subquery_node.args["alias"] if alias is None: # Try to get from parent - if "alias" in subquery_node.parent.args: - alias = subquery_node.parent.args["alias"] + parent_node = subquery_node.parent + if parent_node is not None: + if "alias" in parent_node.args: + alias = parent_node.args["alias"] if alias is not None: if not any(x.name == alias.name for x in tablenodes): tablenodes.add(exp.Table(this=exp.Identifier(this=alias.name))) @@ -696,7 +707,7 @@ def create_pattern(output_type: str) -> Callable[[int], str]: raise ValueError(f"Unknown output_type {output_type}") return lambda num_repeats: base_pattern + "{" + str(num_repeats) + "}" - added_kwargs = {} + added_kwargs: Dict[str, Any] = {} ingredient_node = _parse_one(self.sql()[start:end]) child = None for child, _, _ in self.node.walk(): diff --git a/blendsql/blend.py b/blendsql/blend.py index d8cf9ae2..f626cc13 100644 --- a/blendsql/blend.py +++ b/blendsql/blend.py @@ -12,9 +12,9 @@ Generator, Optional, Callable, - Collection, Type, ) +from collections.abc import Collection, Iterable from sqlite3 import OperationalError from attr import attrs, attrib from functools import partial @@ -64,23 +64,23 @@ class Kitchen(list): db: Database = attrib() session_uuid: str = attrib() - added_ingredient_names: set = attrib(init=False) + name_to_ingredient: Dict[str, Ingredient] = attrib(init=False) def __attrs_post_init__(self): - self.added_ingredient_names = set() + self.name_to_ingredient = {} def names(self): return [i.name for i in self] def get_from_name(self, name: str): - for f in self: - if f.name == name.upper(): - return f - raise InvalidBlendSQL( - f"Ingredient '{name}' called, but not found in 
passed `ingredient` arg!"
-        )
+        try:
+            return self.name_to_ingredient[name.upper()]
+        except KeyError:
+            raise InvalidBlendSQL(
+                f"Ingredient '{name}' called, but not found in passed `ingredient` arg!"
+            ) from None
 
-    def extend(self, ingredients: Collection[Type[Ingredient]]) -> None:
+    def extend(self, ingredients: Iterable[Type[Ingredient]]) -> None:
         """ "Initializes ingredients class with base attributes, for use in later operations."""
         try:
             if not all(issubclass(x, Ingredient) for x in ingredients):
@@ -94,17 +94,18 @@ def extend(self, ingredients: Collection[Type[Ingredient]]) -> None:
             for ingredient in ingredients:
                 name = ingredient.__name__.upper()
                 assert (
-                    name not in self.added_ingredient_names
+                    name not in self.name_to_ingredient
                 ), f"Duplicate ingredient names passed! These are case insensitive, be careful.\n{name}"
-                ingredient = ingredient(
+                # Initialize the ingredient, going from `Type[Ingredient]` to `Ingredient`
+                initialized_ingredient: Ingredient = ingredient(
                     name=name,
                     # Add db and session_uuid as default kwargs
                     # This way, ingredients are able to interact with data
                     db=self.db,
                     session_uuid=self.session_uuid,
                 )
-                self.added_ingredient_names.add(name)
-                self.append(ingredient)
+                self.name_to_ingredient[name] = initialized_ingredient
+                self.append(initialized_ingredient)
 
 
 def autowrap_query(
@@ -499,7 +500,7 @@ def _blend(
                     subquery_str
                 ),  # Need to do this so we don't track parents into construct_abstracted_selects
                 prev_subquery_has_ingredient=prev_subquery_has_ingredient,
-                alias_to_subquery={table_alias_name: subquery} if in_cte else None,
+                alias_to_subquery={table_alias_name: subquery} if in_cte else {},
                 tables_in_ingredients=tables_in_ingredients,
             )
             for tablename, abstracted_query in scm.abstracted_table_selects():
@@ -705,10 +706,9 @@ def _blend(
                 temp_join_tablename,
             ) = function_out
             # Special case for when we have more than 1 ingredient in `JOIN` node left at this point
-            num_ingredients_in_join = (
-                len(list(query_context.node.find(exp.Join).find_all(exp.Struct)))
-                // 2
-            )
+            join_node = query_context.node.find(exp.Join)
+            assert join_node is not None
+            num_ingredients_in_join = len(list(join_node.find_all(exp.Struct))) // 2
             if num_ingredients_in_join > 1:
                 # Case where we have
                 # `SELECT * FROM w0 JOIN w0 ON {{B()}} > 1 AND {{A()}} WHERE TRUE`
diff --git a/blendsql/db/_database.py b/blendsql/db/_database.py
index 618e33f2..a5f8a85b 100644
--- a/blendsql/db/_database.py
+++ b/blendsql/db/_database.py
@@ -1,5 +1,5 @@
-from typing import Generator, Union, List, Collection, Optional
-from typing_extensions import Callable
+from typing import Generator, Union, List, Callable, Optional
+from collections.abc import Collection
 import pandas as pd
 from attr import attrib
 from sqlalchemy.engine import URL
diff --git a/blendsql/db/_duckdb.py b/blendsql/db/_duckdb.py
index 46baff1b..2aed5118 100644
--- a/blendsql/db/_duckdb.py
+++ b/blendsql/db/_duckdb.py
@@ -1,5 +1,6 @@
 import importlib.util
-from typing import Dict, Optional, List, Collection, Type, Generator, Set, Union
+from typing import Dict, Optional, List, Generator, Set, Union, Callable
+from collections.abc import Collection
 import pandas as pd
 from colorama import Fore
 from attr import attrs, attrib
@@ -167,7 +168,7 @@ def execute_to_df(self, query: str, params: Optional[dict] = None) -> pd.DataFra
         return self.con.sql(query).df()
 
     def execute_to_list(
-        self, query: str, to_type: Optional[Type] = lambda x: x
+        self, query: str, to_type: Optional[Callable] = lambda x: x
     ) -> list:
         res = []
         for row in 
self.con.execute(query).fetchall(): diff --git a/blendsql/db/_sqlalchemy.py b/blendsql/db/_sqlalchemy.py index f481cf55..e47f131d 100644 --- a/blendsql/db/_sqlalchemy.py +++ b/blendsql/db/_sqlalchemy.py @@ -1,4 +1,5 @@ -from typing import Generator, List, Collection, Type, Optional, Union +from typing import Generator, List, Callable, Optional, Union +from collections.abc import Collection import pandas as pd from colorama import Fore import re @@ -149,9 +150,7 @@ def execute_to_df(self, query: str, params: Optional[dict] = None) -> pd.DataFra """ return pd.read_sql(text(query), self.con, params=params) - def execute_to_list( - self, query: str, to_type: Optional[Type] = lambda x: x - ) -> list: + def execute_to_list(self, query: str, to_type: Callable = lambda x: x) -> list: """A lower-level execute method that doesn't use the pandas processing logic. Returns results as a tuple. """ diff --git a/blendsql/grammars/utils.py b/blendsql/grammars/utils.py index baf03242..65db5c6d 100644 --- a/blendsql/grammars/utils.py +++ b/blendsql/grammars/utils.py @@ -1,5 +1,6 @@ from pathlib import Path -from typing import Optional, Collection, List, Dict, Type +from typing import Optional, List, Dict, Type +from collections.abc import Collection from string import Template from colorama import Fore diff --git a/blendsql/ingredients/builtin/join/main.py b/blendsql/ingredients/builtin/join/main.py index 9b5d5e77..916d25f2 100644 --- a/blendsql/ingredients/builtin/join/main.py +++ b/blendsql/ingredients/builtin/join/main.py @@ -117,7 +117,7 @@ def __call__( if isinstance(model, OllamaLLM): # Handle call to ollama return return_ollama_response( - logits_generator=model.logits_generator, + logits_generator=model.logits_generator, # type: ignore prompt=prompt, max_tokens=max_tokens, temperature=0.0, diff --git a/blendsql/ingredients/builtin/map/main.py b/blendsql/ingredients/builtin/map/main.py index 696f5df0..3ce0a09d 100644 --- a/blendsql/ingredients/builtin/map/main.py +++ b/blendsql/ingredients/builtin/map/main.py @@ -123,7 +123,7 @@ def __call__( if isinstance(model, OllamaLLM): # Handle call to ollama return return_ollama_response( - logits_generator=model.logits_generator, + logits_generator=model.logits_generator, # type: ignore prompt=prompt, max_tokens=max_tokens, temperature=0.0, @@ -180,7 +180,7 @@ def run( table_title = table_to_title[tablename] split_results: List[Union[str, None]] = [] # Only use tqdm if we're in debug mode - context_manager = ( + context_manager: Iterable = ( tqdm( range(0, len(values), CONST.MAP_BATCH_SIZE), total=len(values) // CONST.MAP_BATCH_SIZE, diff --git a/blendsql/ingredients/builtin/qa/main.py b/blendsql/ingredients/builtin/qa/main.py index 2326455b..0542b1de 100644 --- a/blendsql/ingredients/builtin/qa/main.py +++ b/blendsql/ingredients/builtin/qa/main.py @@ -70,7 +70,7 @@ def __call__( if isinstance(model, OllamaLLM): # Handle call to ollama return return_ollama_response( - logits_generator=model.logits_generator, + logits_generator=model.logits_generator, # type: ignore prompt=prompt, max_tokens=max_tokens, temperature=0.0, diff --git a/blendsql/ingredients/ingredient.py b/blendsql/ingredients/ingredient.py index 47bb3a2b..f7fe9103 100644 --- a/blendsql/ingredients/ingredient.py +++ b/blendsql/ingredients/ingredient.py @@ -6,17 +6,15 @@ from skrub import Joiner from typing import ( Any, - Iterable, Union, Dict, Tuple, - Type, Callable, Set, - Collection, Optional, List, ) +from collections.abc import Collection, Iterable import uuid from colorama import Fore from 
typeguard import check_type @@ -54,7 +52,7 @@ class Ingredient: session_uuid: str = attrib() ingredient_type: str = attrib(init=False) - allowed_output_types: Tuple[Type] = attrib(init=False) + allowed_output_types: Tuple[Any] = attrib(init=False) num_values_passed: int = 0 def __repr__(self): @@ -67,6 +65,10 @@ def __str__(self): def run(self, *args, **kwargs) -> Any: ... + @abstractmethod + def __call__(self, *args, **kwargs) -> Any: + ... + def _run(self, *args, **kwargs): return check_type(self.run(*args, **kwargs), self.allowed_output_types) @@ -88,7 +90,7 @@ class MapIngredient(Ingredient): to each of the given values, creating a new column.""" ingredient_type: str = IngredientType.MAP.value - allowed_output_types: Tuple[Type] = (Iterable[Any],) + allowed_output_types: Tuple[Any] = (Iterable[Any],) def unpack_default_kwargs(self, **kwargs): return unpack_default_kwargs(**kwargs) @@ -96,10 +98,10 @@ def unpack_default_kwargs(self, **kwargs): def __call__(self, question: str, context: str, *args, **kwargs) -> tuple: """Returns tuple with format (arg, tablename, colname, new_table)""" # Unpack kwargs - aliases_to_tablenames: Dict[str, str] = kwargs.get("aliases_to_tablenames") - get_temp_subquery_table: Callable = kwargs.get("get_temp_subquery_table") - get_temp_session_table: Callable = kwargs.get("get_temp_session_table") - prev_subquery_map_columns: Set[str] = kwargs.get("prev_subquery_map_columns") + aliases_to_tablenames: Dict[str, str] = kwargs["aliases_to_tablenames"] + get_temp_subquery_table: Callable = kwargs["get_temp_subquery_table"] + get_temp_session_table: Callable = kwargs["get_temp_session_table"] + prev_subquery_map_columns: Set[str] = kwargs["prev_subquery_map_columns"] tablename, colname = utils.get_tablename_colname(context) tablename = aliases_to_tablenames.get(tablename, tablename) @@ -200,7 +202,7 @@ class JoinIngredient(Ingredient): use_skrub_joiner: bool = attrib(default=True) ingredient_type: str = IngredientType.JOIN.value - allowed_output_types: Tuple[Type] = (dict,) + allowed_output_types: Tuple[Any] = (dict,) @classmethod def from_args(cls, use_skrub_joiner: bool = True): @@ -215,16 +217,16 @@ def __call__( **kwargs, ) -> tuple: # Unpack kwargs - aliases_to_tablenames: Dict[str, str] = kwargs.get("aliases_to_tablenames") - get_temp_subquery_table: Callable = kwargs.get("get_temp_subquery_table") - get_temp_session_table: Callable = kwargs.get("get_temp_session_table") + aliases_to_tablenames: Dict[str, str] = kwargs["aliases_to_tablenames"] + get_temp_subquery_table: Callable = kwargs["get_temp_subquery_table"] + get_temp_session_table: Callable = kwargs["get_temp_session_table"] # Depending on the size of the underlying data, it may be optimal to swap # the order of 'left_on' and 'right_on' columns during processing swapped = False values = [] original_lr_identifiers = [] modified_lr_identifiers = [] - mapping = {} + mapping: Dict[str, str] = {} for on_arg in [left_on, right_on]: tablename, colname = utils.get_tablename_colname(on_arg) tablename = aliases_to_tablenames.get(tablename, tablename) @@ -344,7 +346,7 @@ def run(self, *args, **kwargs) -> dict: @attrs class QAIngredient(Ingredient): ingredient_type: str = IngredientType.QA.value - allowed_output_types: Tuple[Type] = (Union[str, int, float],) + allowed_output_types: Tuple[Any] = (Union[str, int, float],) def __call__( self, @@ -355,7 +357,7 @@ def __call__( **kwargs, ) -> Tuple[Union[str, int, float], Optional[exp.Expression]]: # Unpack kwargs - aliases_to_tablenames: Dict[str, str] = 
kwargs.get("aliases_to_tablenames") + aliases_to_tablenames: Dict[str, str] = kwargs["aliases_to_tablenames"] subtable: Union[pd.DataFrame, None] = None if context is not None: @@ -371,7 +373,7 @@ def __call__( f'SELECT "{colname}" FROM "{tablename}"' ) elif isinstance(context, pd.DataFrame): - subtable = context + subtable: pd.DataFrame = context else: raise ValueError( f"Unknown type for `identifier` arg in QAIngredient: {type(context)}" @@ -417,7 +419,7 @@ class StringIngredient(Ingredient): """Outputs a string to be placed directly into the SQL query.""" ingredient_type: str = IngredientType.STRING.value - allowed_output_types: Tuple[Type] = (str,) + allowed_output_types: Tuple[Any] = (str,) def unpack_default_kwargs(self, **kwargs): return unpack_default_kwargs(**kwargs) diff --git a/blendsql/models/_model.py b/blendsql/models/_model.py index 77ff64e5..50582c30 100644 --- a/blendsql/models/_model.py +++ b/blendsql/models/_model.py @@ -1,4 +1,4 @@ -from typing import Any, List, Optional, Type +from typing import Any, List, Optional, Type, Dict import pandas as pd from attr import attrib, attrs from pathlib import Path @@ -45,12 +45,12 @@ class Model: tokenizer: Any = attrib(default=None) requires_config: bool = attrib(default=False) refresh_interval_min: Optional[int] = attrib(default=None) - load_model_kwargs: Optional[dict] = attrib(default={}) + load_model_kwargs: dict = attrib(default=None) env: str = attrib(default=".") caching: bool = attrib(default=True) logits_generator: LogitsGenerator = attrib(init=False) - prompts: list = attrib(init=False) + prompts: List[dict] = attrib(init=False) prompt_tokens: int = attrib(init=False) completion_tokens: int = attrib(init=False) num_calls: int = attrib(init=False) @@ -63,7 +63,9 @@ def __attrs_post_init__(self): Path(platformdirs.user_cache_dir("blendsql")) / f"{self.model_name_or_path}.diskcache" ) - self.prompts: List[str] = [] + if self.load_model_kwargs is None: + self.load_model_kwargs = {} + self.prompts: List[dict] = [] self.prompt_tokens = 0 self.completion_tokens = 0 self.num_calls = 0 @@ -107,13 +109,12 @@ def predict(self, program: Type[Program], **kwargs) -> str: """ if self.caching: # First, check our cache - key = self._create_key(program, **kwargs) + key: str = self._create_key(program, **kwargs) if key in self.cache: logger.debug(Fore.MAGENTA + "Using cache..." + Fore.RESET) - self.prompts.insert( - -1, self.format_prompt(self.cache.get(key), **kwargs) - ) - return self.cache.get(key) + response: str = self.cache.get(key) # type: ignore + self.prompts.insert(-1, self.format_prompt(response, **kwargs)) + return response # Modify fields used for tracking Model usage response: str prompt: str @@ -124,10 +125,10 @@ def predict(self, program: Type[Program], **kwargs) -> str: self.prompt_tokens += len(self.tokenizer.encode(prompt)) self.completion_tokens += len(self.tokenizer.encode(response)) if self.caching: - self.cache[key] = response + self.cache[key] = response # type: ignore return response - def _create_key(self, program: Program, **kwargs) -> str: + def _create_key(self, program: Type[Program], **kwargs) -> str: """Generates a hash to use in diskcache Cache. This way, we don't need to send our prompts to the same Model if our context of Model + program + kwargs is the same. 
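To make the caching contract in the docstring above concrete: all `predict` needs in order to short-circuit into diskcache is a deterministic key derived from the program plus its kwargs. A minimal sketch of that idea, assuming a SHA-256 over the program's source and order-independent kwargs is sufficient (the real `_create_key`, partially shown in the next hunk, hashes more context via `program_to_str`; `make_cache_key` and `program_source` here are hypothetical names for illustration):

import hashlib
import json

def make_cache_key(program_source: str, **kwargs) -> str:
    # Hash the program source plus its kwargs into a stable hex digest.
    hasher = hashlib.sha256()
    hasher.update(program_source.encode())
    # sort_keys=True makes the key independent of keyword-argument order;
    # default=str is a fallback for values json can't serialize natively.
    hasher.update(json.dumps(kwargs, sort_keys=True, default=str).encode())
    return hasher.hexdigest()

Identical (program, kwargs) pairs then always land on the same cache entry, which is why `predict` can safely return the cached response while still logging the prompt via `format_prompt`.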
@@ -157,11 +158,11 @@ def _create_key(self, program: Program, **kwargs) -> str: @cached_property def logits_generator(self) -> LogitsGenerator: """Allows for lazy loading of underlying model.""" - return self._load_model(**self.load_model_kwargs) + return self._load_model() @staticmethod def format_prompt(response: str, **kwargs) -> dict: - d = {"answer": response} + d: Dict[str, Any] = {"answer": response} if IngredientKwarg.QUESTION in kwargs: d[IngredientKwarg.QUESTION] = kwargs.get(IngredientKwarg.QUESTION) if IngredientKwarg.CONTEXT in kwargs: diff --git a/blendsql/models/local/_llama_cpp.py b/blendsql/models/local/_llama_cpp.py index 98b1f6a4..372e21ea 100644 --- a/blendsql/models/local/_llama_cpp.py +++ b/blendsql/models/local/_llama_cpp.py @@ -1,5 +1,6 @@ import importlib.util -from outlines.models import llamacpp, LogitsGenerator +from outlines.models import LogitsGenerator +from outlines.models.llamacpp import llamacpp from .._model import LocalModel from typing import Optional @@ -55,10 +56,10 @@ def __init__( caching=caching, ) - def _load_model(self, filename: str, **kwargs) -> LogitsGenerator: + def _load_model(self, filename: str) -> LogitsGenerator: return llamacpp( self.model_name_or_path, filename=filename, tokenizer=self._llama_tokenizer, - **kwargs + **self.load_model_kwargs ) diff --git a/blendsql/models/local/_transformers.py b/blendsql/models/local/_transformers.py index 74d80ef9..ddb6a1da 100644 --- a/blendsql/models/local/_transformers.py +++ b/blendsql/models/local/_transformers.py @@ -1,8 +1,11 @@ import importlib.util -from outlines.models import transformers, LogitsGenerator +from outlines.models import LogitsGenerator +from outlines.models.transformers import transformers from .._model import LocalModel +DEFAULT_KWARGS = {"do_sample": True, "temperature": 0.0, "top_p": 1.0} + _has_transformers = importlib.util.find_spec("transformers") is not None _has_torch = importlib.util.find_spec("torch") is not None @@ -42,6 +45,7 @@ def __init__(self, model_name_or_path: str, caching: bool = True, **kwargs): model_name_or_path=model_name_or_path, requires_config=False, tokenizer=transformers.AutoTokenizer.from_pretrained(model_name_or_path), + load_model_kwargs=DEFAULT_KWARGS | kwargs, caching=caching, **kwargs ) @@ -50,5 +54,5 @@ def _load_model(self) -> LogitsGenerator: # https://huggingface.co/blog/how-to-generate return transformers( self.model_name_or_path, - model_kwargs={"do_sample": True, "temperature": 0.0, "top_p": 1.0}, + model_kwargs=self.load_model_kwargs, ) diff --git a/blendsql/models/remote/_ollama.py b/blendsql/models/remote/_ollama.py index 60283438..42b26005 100644 --- a/blendsql/models/remote/_ollama.py +++ b/blendsql/models/remote/_ollama.py @@ -52,7 +52,7 @@ def __init__( **kwargs ) - def _load_model(self, **kwargs) -> partial: + def _load_model(self) -> partial: import ollama return partial( diff --git a/blendsql/models/remote/_openai.py b/blendsql/models/remote/_openai.py index a08a5985..20dc53eb 100644 --- a/blendsql/models/remote/_openai.py +++ b/blendsql/models/remote/_openai.py @@ -1,7 +1,7 @@ import os import importlib.util -from outlines.models import openai, azure_openai, LogitsGenerator -from outlines.models.openai import OpenAIConfig +from outlines.models import LogitsGenerator +from outlines.models.openai import openai, azure_openai, OpenAIConfig from .._model import RemoteModel from typing import Optional @@ -104,14 +104,13 @@ def __init__( **kwargs ) - def _load_model(self, config: OpenAIConfig, **kwargs) -> LogitsGenerator: + def 
_load_model(self, config: OpenAIConfig) -> LogitsGenerator: return azure_openai( self.model_name_or_path, config=config, azure_endpoint=os.getenv("OPENAI_API_BASE"), api_key=os.getenv("OPENAI_API_KEY"), - **kwargs - ) + ) # type: ignore def _setup(self, **kwargs) -> None: openai_setup() @@ -172,13 +171,10 @@ def __init__( **kwargs ) - def _load_model(self, config: OpenAIConfig, **kwargs) -> LogitsGenerator: + def _load_model(self, config: OpenAIConfig) -> LogitsGenerator: return openai( - self.model_name_or_path, - config=config, - api_key=os.getenv("OPENAI_API_KEY"), - **kwargs - ) + self.model_name_or_path, config=config, api_key=os.getenv("OPENAI_API_KEY") + ) # type: ignore def _setup(self, **kwargs) -> None: openai_setup() diff --git a/blendsql/nl_to_blendsql/nl_to_blendsql.py b/blendsql/nl_to_blendsql/nl_to_blendsql.py index 2a2d2acd..1bf4c75b 100644 --- a/blendsql/nl_to_blendsql/nl_to_blendsql.py +++ b/blendsql/nl_to_blendsql/nl_to_blendsql.py @@ -1,4 +1,5 @@ -from typing import Collection, Tuple, Set, Optional, Union, Type +from typing import Tuple, Set, Optional, Union, Type +from collections.abc import Collection from textwrap import dedent import outlines from colorama import Fore diff --git a/blendsql/prompts/_prompts.py b/blendsql/prompts/_prompts.py index eab965ba..9732b630 100644 --- a/blendsql/prompts/_prompts.py +++ b/blendsql/prompts/_prompts.py @@ -1,7 +1,7 @@ from dataclasses import dataclass from pathlib import Path from attr import attrs, attrib -from typing import List, Collection, Type, Set +from typing import List, Iterable, Type, Set from ..ingredients import Ingredient from ..grammars._peg_grammar import grammar as peg_grammar @@ -67,7 +67,7 @@ def is_valid_query(self, query: str, ingredient_names: Set[str]) -> bool: stack.append(arg) return True - def filter(self, ingredients: Collection[Type[Ingredient]]) -> "Examples": + def filter(self, ingredients: Iterable[Type[Ingredient]]) -> "Examples": """Retrieve only those prompts which do not include any ingredient not specified in `ingredients`.""" ingredient_names: Set[str] = { ingredient.__name__.upper() for ingredient in ingredients diff --git a/blendsql/utils.py b/blendsql/utils.py index 58931b27..2f24fa4e 100644 --- a/blendsql/utils.py +++ b/blendsql/utils.py @@ -71,7 +71,7 @@ def recover_blendsql(select_sql: str): def get_temp_subquery_table( - session_uuid: str, subquery_idx: str, tablename: str + session_uuid: str, subquery_idx: int, tablename: str ) -> str: """Generates temporary tablename for a subquery""" return f"{session_uuid}_{tablename}_{subquery_idx}" diff --git a/examples/benchmarks/with_blendsql.py b/examples/benchmarks/with_blendsql.py deleted file mode 100644 index 4f714e7b..00000000 --- a/examples/benchmarks/with_blendsql.py +++ /dev/null @@ -1,53 +0,0 @@ -import time -import statistics - -from blendsql import blend, LLMMap -from blendsql.db import SQLite -from blendsql.models import AzureOpenaiLLM -from blendsql.utils import fetch_from_hub -from constants import model - -if __name__ == "__main__": - """ - PYTHONPATH=$PWD:$PYTHONPATH kernprof -lv examples/benchmarks/with_blendsql.py - """ - times = [] - iterations = 1 - db = SQLite(fetch_from_hub("multi_table.db")) - blender = AzureOpenaiLLM(model) - print(f"Using {model}...") - for _ in range(iterations): - # Uncomment if we want to clear cache first - # guidance.llms.OpenAI.cache.clear() - start = time.time() - blendsql = """ - SELECT "Run Date", Account, Action, ROUND("Amount ($)", 2) AS 'Total Dividend Payout ($$)' - FROM account_history - 
WHERE Symbol IN - ( - SELECT Symbol FROM constituents - WHERE sector = 'Information Technology' - AND {{ - LLMMap( - 'does this company manufacture cell phones?', - 'constituents::Name' - ) - }} = TRUE - ) - AND lower(Action) like "%dividend%" - """ - smoothie = blend( - query=blendsql, - db=db, - ingredients={LLMMap}, - blender=blender, - blender_args={"few_shot": False}, - verbose=False, - ) - runtime = time.time() - start - print(f"Completed with_blendsql benchmark in {runtime} seconds") - print(f"Passed {smoothie.meta.num_values_passed} total values to LLM") - times.append(runtime) - print( - f"For {iterations} iterations, average runtime is {statistics.mean(times)} with stdev {statistics.stdev(times) if len(times) > 1 else 0}" - ) diff --git a/examples/benchmarks/without_blendsql.py b/examples/benchmarks/without_blendsql.py deleted file mode 100644 index f9fdbc15..00000000 --- a/examples/benchmarks/without_blendsql.py +++ /dev/null @@ -1,157 +0,0 @@ -import time -import os -import pandas as pd -import sqlite3 -from openai import AzureOpenAI -import statistics -import logging -from colorama import Fore -from typing import List, Union -from dotenv import load_dotenv - -from blendsql._constants import VALUE_BATCH_SIZE -from blendsql.utils import fetch_from_hub -from constants import model - - -def construct_messages_payload(prompt: Union[str, None], question: str) -> List: - messages = [] - # Add system prompt - # messages.append({"role": "system", "content": "" if prompt is None else prompt}) - messages.append({"role": "user", "content": prompt}) - return messages - - -map_prompt = """Answer the question row-by-row, in order. -Values can be either '1' (True) or '0' (False). -The answer should be a list separated by ';', and have {answer_length} items in total. - -Question: {question} - -{values} -""" - - -def openai_setup() -> None: - """Setup helper for AzureOpenAI and OpenAI models.""" - if all( - x is not None - for x in { - os.getenv("TENANT_ID"), - os.getenv("CLIENT_ID"), - os.getenv("CLIENT_SECRET"), - } - ): - try: - from azure.identity import ClientSecretCredential - except ImportError: - raise ValueError( - "Found ['TENANT_ID', 'CLIENT_ID', 'CLIENT_SECRET'] in .env file, using Azure OpenAI\nIn order to use Azure OpenAI, run `pip install azure-identity`!" - ) from None - credential = ClientSecretCredential( - tenant_id=os.environ["TENANT_ID"], - client_id=os.environ["CLIENT_ID"], - client_secret=os.environ["CLIENT_SECRET"], - disable_instance_discovery=True, - ) - access_token = credential.get_token( - os.environ["TOKEN_SCOPE"], - tenant_id=os.environ["TENANT_ID"], - ) - os.environ["AZURE_OPENAI_API_KEY"] = access_token.token - elif os.getenv("AZURE_OPENAI_API_KEY") is not None: - pass - else: - raise ValueError( - "Error authenticating with OpenAI\n Without explicit `OPENAI_API_KEY`, you need to provide ['TENANT_ID', 'CLIENT_ID', 'CLIENT_SECRET']" - ) from None - - -if __name__ == "__main__": - load_dotenv() - openai_setup() - con = sqlite3.connect(fetch_from_hub("multi_table.db")) - iterations = 1 - times = [] - client = AzureOpenAI() - print(f"Using {model}...") - for _ in range(iterations): - start = time.time() - # Select initial query results - sql = """ - SELECT * FROM constituents WHERE sector = 'Information Technology' - """ - question = "does this company manufacture cell phones?" 
- target_column = "Name" - df = pd.read_sql(sql, con) - - # Make our calls to the Model - values = df[target_column].unique().tolist() - # values_dict = [{"value": value, "idx": idx} for idx, value in enumerate(values)] - split_results = [] - # Pass in batches - batch_size = VALUE_BATCH_SIZE - for i in range(0, len(values), batch_size): - prompt = map_prompt.format( - answer_length=len(values[i : i + batch_size]), - question=question, - values="\n".join(values[i : i + batch_size]), - ) - # print(prompt) - payload = construct_messages_payload(prompt=prompt, question="") - res = ( - client.chat.completions.create( - model=model, - # Can be one of {"gpt-4", "gpt-4-32k", "gpt-35-turbo", "text-davinci-003"}, or others in Azure - messages=payload, - ) - .choices[0] - .message.content - ) - _r = [i.strip() for i in res.strip(";").split(";")] - expected_len = len(values[i : i + batch_size]) - if len(_r) != expected_len: - logging.debug( - Fore.YELLOW - + f"Mismatch between length of values and answers!\nvalues:{expected_len}, answers:{len(_r)}" - + Fore.RESET - ) - logging.debug(_r) - # Cut off, in case we over-predicted - _r = _r[:expected_len] - # Add, in case we under-predicted - while len(_r) < expected_len: - _r.append(None) - split_results.extend(_r) - values_passed = len(split_results) - df_as_dict = {target_column: [], question: []} - for idx, value in enumerate(values): - df_as_dict[target_column].append(value) - df_as_dict[question].append( - split_results[idx] if len(split_results) - 1 >= idx else None - ) - subtable = pd.DataFrame(df_as_dict) - # Add new_table to original table - new_table = df.merge(subtable, how="left", on=target_column) - new_table.to_sql("modified_constituents", con, if_exists="replace", index=False) - # Now, new table has original columns + column with the name of the question we answered - sql = """ - SELECT "Run Date", Account, Action, ROUND("Amount ($)", 2) AS 'Total Dividend Payout ($$)' - FROM account_history - WHERE Symbol IN - ( - SELECT Symbol FROM modified_constituents - WHERE sector = 'Information Technology' - AND "does this company manufacture cell phones?" = 1 - ) - AND lower(Action) like "%dividend%" - """ - answer = pd.read_sql(sql, con) - con.execute(f"DROP TABLE 'modified_constituents'") - runtime = time.time() - start - print(f"Completed without_blendsql benchmark in {runtime} seconds") - print(f"Passed {values_passed} total values to LLM") - times.append(runtime) - print( - f"For {iterations} iterations, average runtime is {statistics.mean(times)} with stdev {statistics.stdev(times)}" - ) diff --git a/examples/benchmarks/without_blendsql_with_guidance.py b/examples/benchmarks/without_blendsql_with_guidance.py deleted file mode 100644 index c28ba2e1..00000000 --- a/examples/benchmarks/without_blendsql_with_guidance.py +++ /dev/null @@ -1,93 +0,0 @@ -import statistics -import time -import pandas as pd -import sqlite3 -import logging -from colorama import Fore - -from blendsql.models import AzureOpenaiLLM -from blendsql._constants import VALUE_BATCH_SIZE -from blendsql._programs import MapProgram -from blendsql.utils import fetch_from_hub -from constants import model - -if __name__ == "__main__": - start = time.time() - con = sqlite3.connect(fetch_from_hub("multi_table.db")) - iterations = 1 - times = [] - print(f"Using {model}...") - for _ in range(iterations): - # Select initial query results - sql = """ - SELECT * FROM constituents WHERE sector = 'Information Technology' - """ - question = "does this company manufacture cell phones?" 
- target_column = "Name" - df = pd.read_sql(sql, con) - - # Make our calls to the Model - blender = AzureOpenaiLLM("gpt-4") - values = df[target_column].unique().tolist() - split_results = [] - # Pass in batches - batch_size = VALUE_BATCH_SIZE - for i in range(0, len(values), batch_size): - res = MapProgram( - model=blender.model, - question=question, - sep=";", - values=values[i : i + batch_size], - example_outputs=None, - few_shot=False, - ) - _r = [i.strip() for i in res["result"].strip(";").split(";")] - expected_len = len(values[i : i + batch_size]) - if len(_r) != expected_len: - logging.debug( - Fore.YELLOW - + f"Mismatch between length of values and answers!\nvalues:{expected_len}, answers:{len(_r)}" - + Fore.RESET - ) - logging.debug(_r) - # Cut off, in case we over-predicted - _r = _r[:expected_len] - # Add, in case we under-predicted - while len(_r) < expected_len: - _r.append(None) - split_results.extend(_r) - values_passed = len(split_results) - df_as_dict = {target_column: [], question: []} - for idx, value in enumerate(values): - df_as_dict[target_column].append(value) - df_as_dict[question].append( - split_results[idx] if len(split_results) - 1 >= idx else None - ) - subtable = pd.DataFrame(df_as_dict) - - # Add new_table to original table - new_table = df.merge(subtable, how="left", on=target_column) - new_table.to_sql("modified_constituents", con, if_exists="replace", index=False) - # Now, new table has original columns + column with the name of the question we answered - sql = """ - SELECT "Run Date", Account, Action, ROUND("Amount ($)", 2) AS 'Total Dividend Payout ($$)' - FROM account_history - WHERE Symbol IN - ( - SELECT Symbol FROM modified_constituents - WHERE sector = 'Information Technology' - AND "does this company manufacture cell phones?" 
= 1 - ) - AND lower(Action) like "%dividend%" - """ - answer = pd.read_sql(sql, con) - con.execute(f"DROP TABLE 'modified_constituents'") - runtime = time.time() - start - print( - f"Completed without_blendsql_with_guidance benchmark in {runtime} seconds" - ) - print(f"Passed {values_passed} total values to LLM") - times.append(runtime) - print( - f"For {iterations} iterations, average runtime is {statistics.mean(times)} with stdev {statistics.stdev(times) if len(times) > 1 else 0}" - ) diff --git a/pyrightconfig.json b/pyrightconfig.json new file mode 100644 index 00000000..3c226a6c --- /dev/null +++ b/pyrightconfig.json @@ -0,0 +1,4 @@ +{ + "venvPath": "/Users/parkerglenn/opt/miniconda3/envs/", + "venv": "/Users/parkerglenn/opt/miniconda3/envs/blendsql" +} diff --git a/run-debug.py b/run-debug.py new file mode 100644 index 00000000..0f625322 --- /dev/null +++ b/run-debug.py @@ -0,0 +1,124 @@ +from blendsql import blend, LLMJoin, LLMMap, LLMQA +from blendsql.db import SQLite +from blendsql.utils import fetch_from_hub + +# db = DuckDB.from_pandas( +# pd.DataFrame( +# { +# "name": ["John", "Parker"], +# "age": [12, 26] +# }, +# ) +# ) +# DuckDB.from_sqlite(fetch_from_hub("1884_New_Zealand_rugby_union_tour_of_New_South_Wales_1.db")) +# print() +# +# db.to_temp_table(df=pd.DataFrame( +# { +# "class": ["Boxing 101"], +# "num_enrolled": [23] +# } +# ), tablename="classes" +# ) +# print() + + +TEST_QUERIES = [ + """ + SELECT DISTINCT venue FROM w + WHERE city = 'sydney' AND {{ + LLMMap( + 'More than 30 total points?', + 'w::score' + ) + }} = TRUE + """, + # """ + # SELECT * FROM w + # WHERE city = {{ + # LLMQA( + # 'Which city is located 120 miles west of Sydney?', + # (SELECT * FROM documents), + # options='w::city' + # ) + # }} + # """, + # """ + # SELECT date, rival, score, documents.content AS "Team Description" FROM w + # JOIN {{ + # LLMJoin( + # left_on='documents::title', + # right_on='w::rival' + # ) + # }} + # """, + # """ + # {{ + # LLMQA( + # 'What is this table about?', + # (SELECT * FROM w;) + # ) + # }} + # """ +] + +# TEST_QUERIES = [ +# """ +# SELECT title, player FROM w JOIN {{ +# LLMJoin( +# left_on='documents::title', +# right_on='w::player' +# ) +# }} WHERE {{ +# LLMMap( +# 'How many years with the franchise?', +# 'w::career with the franchise' +# ) +# }} > 5 +# """ +# ] +if __name__ == "__main__": + """ + Without cached LLM response (10 runs): + before: 3.16 + after: 1.91 + With cached LLM response (100 runs): + before: 0.0175 + after: 0.0166 + With cached LLM response (30 runs): + with fuzzy join: 0.431 + without fuzzy join: 0.073 + Without cached LLM response (30 runs): + with fuzzy join: 0.286 + without fuzzy join: 318.85 + + DuckDB, with iterating over temp_tables (10 runs): + 0.5055 + Recreating database on reset: + 0.538 - 0.7 + + """ + db = SQLite( + fetch_from_hub("1884_New_Zealand_rugby_union_tour_of_New_South_Wales_1.db") + ) + ingredients = {LLMQA, LLMMap, LLMJoin} + # db = SQLite(fetch_from_hub("1884_New_Zealand_rugby_union_tour_of_New_South_Wales_1.db")) + from blendsql.models import OpenaiLLM + + # model = OpenaiLLM("gpt-3.5-turbo", caching=False) + times = [] + for _i in range(1): + for q in TEST_QUERIES: + # Make our smoothie - the executed BlendSQL script + smoothie = blend( + query=q, + db=db, + blender=OpenaiLLM("gpt-3.5-turbo", caching=False), + # blender=TransformersLLM("microsoft/Phi-3-mini-128k-instruct"), + # blender=OllamaLLM("phi3", caching=False), + verbose=True, + ingredients={LLMJoin.from_args(use_skrub_joiner=False), LLMMap, LLMQA}, + ) + 
times.append(smoothie.meta.process_time_seconds) + # print(smoothie.df) + print(f"Average time across {len(times)} runs: {sum(times) / len(times)}") diff --git a/run-nl-to-blendsql.py b/run-nl-to-blendsql.py new file mode 100644 index 00000000..93ede724 --- /dev/null +++ b/run-nl-to-blendsql.py @@ -0,0 +1,30 @@ +from blendsql.models import OllamaLLM, OpenaiLLM +from blendsql.db import SQLite +from blendsql.utils import fetch_from_hub +from blendsql.nl_to_blendsql import nl_to_blendsql, NLtoBlendSQLArgs, FewShot +from blendsql import LLMMap, LLMQA, blend + +if __name__ == "__main__": + model = OllamaLLM("phi3") + db = SQLite( + fetch_from_hub("1884_New_Zealand_rugby_union_tour_of_New_South_Wales_1.db") + ) + prediction = nl_to_blendsql( + question="Show me all info about the game played 120 miles west of Sydney", + db=db, + model=model, + ingredients={LLMQA, LLMMap}, + correction_model=OpenaiLLM("gpt-3.5-turbo"), + few_shot_examples=FewShot.hybridqa, + args=NLtoBlendSQLArgs( + use_tables=["w"], include_db_content_tables=["w"], use_bridge_encoder=True + ), + verbose=True, + ) + smoothie = blend( + query=prediction, + blender=OpenaiLLM("gpt-3.5-turbo"), + ingredients={LLMMap, LLMQA}, + db=db, + ) + print(smoothie.df) diff --git a/run-tests.py b/run-tests.py new file mode 100644 index 00000000..85f254fe --- /dev/null +++ b/run-tests.py @@ -0,0 +1,45 @@ +from blendsql import blend +from blendsql.db import DuckDB, SQLite +from blendsql.utils import fetch_from_hub +from tests.utils import ( + starts_with, + get_length, + select_first_sorted, + get_table_size, + select_first_option, + do_join, +) + +if __name__ == "__main__": + query = """ + {{ + select_first_option( + 'I hope this test works', + (SELECT * FROM transactions), + options=(SELECT DISTINCT merchant FROM transactions WHERE merchant = 'Paypal') + ) + }} + """ + ingredients = { + starts_with, + get_length, + select_first_sorted, + get_table_size, + select_first_option, + do_join, + } + sqlite_db = SQLite(fetch_from_hub("single_table.db")) + db = DuckDB.from_sqlite(fetch_from_hub("single_table.db")) + + smoothie = blend(query=query, db=db, ingredients=ingredients, verbose=True) + # sql = """ + # SELECT w."Percent of Account" FROM (SELECT * FROM "portfolio" WHERE Quantity > 200 OR "Today''s Gain/Loss Percent" > 0.05) as w + # JOIN geographic ON w.Symbol = geographic.Symbol + # WHERE w.Symbol LIKE 'F%' + # AND w."Percent of Account" < 0.2 + # """ + # sql_df = db.execute_to_df(sql) + from tests.utils import assert_equality + + assert_equality(smoothie=smoothie, sql_df=sql_df, args=["Z"]) + print() From f8f92b9cb144eca6dda8a2fb421972791e462545 Mon Sep 17 00:00:00 2001 From: parkervg Date: Mon, 17 Jun 2024 12:02:15 -0700 Subject: [PATCH 02/10] LogitsGenerator -> ModelObj https://github.com/outlines-dev/outlines/pull/970 --- blendsql/ingredients/builtin/join/main.py | 6 +++--- blendsql/ingredients/builtin/map/main.py | 8 +++----- blendsql/ingredients/builtin/qa/main.py | 6 +++--- blendsql/ingredients/builtin/validate/main.py | 4 ++-- blendsql/models/__init__.py | 2 +- blendsql/models/_model.py | 16 +++++++--------- blendsql/models/local/_transformers.py | 5 ++--- blendsql/models/remote/_openai.py | 7 +++---- blendsql/nl_to_blendsql/nl_to_blendsql.py | 8 ++++---- ...aching-blendsql-via-in-context-learning.ipynb | 2 +- docs/reference/examples/vqa-ingredient.ipynb | 6 +++--- ...aching-blendsql-via-in-context-learning.ipynb | 2 +- examples/vqa-ingredient.ipynb | 6 +++--- tests/models/test_ollama.py | 4 ++-- 14 files changed, 38 insertions(+), 44 
deletions(-) diff --git a/blendsql/ingredients/builtin/join/main.py b/blendsql/ingredients/builtin/join/main.py index 916d25f2..364c3c01 100644 --- a/blendsql/ingredients/builtin/join/main.py +++ b/blendsql/ingredients/builtin/join/main.py @@ -111,18 +111,18 @@ def __call__( if isinstance(model, LocalModel): generator = outlines.generate.regex( - model.logits_generator, regex(len(left_values)) + model.model_obj, regex(len(left_values)) ) else: if isinstance(model, OllamaLLM): # Handle call to ollama return return_ollama_response( - logits_generator=model.logits_generator, # type: ignore + model_obj=model.model_obj, # type: ignore prompt=prompt, max_tokens=max_tokens, temperature=0.0, ) - generator = outlines.generate.text(model.logits_generator) + generator = outlines.generate.text(model.model_obj) response: str = generator( prompt, diff --git a/blendsql/ingredients/builtin/map/main.py b/blendsql/ingredients/builtin/map/main.py index 3ce0a09d..58ed437a 100644 --- a/blendsql/ingredients/builtin/map/main.py +++ b/blendsql/ingredients/builtin/map/main.py @@ -116,19 +116,17 @@ def __call__( prompt += f"\nHere are some example outputs: {example_outputs}\n" prompt += "\nA:" if isinstance(model, LocalModel) and regex is not None: - generator = outlines.generate.regex( - model.logits_generator, regex(len(values)) - ) + generator = outlines.generate.regex(model.model_obj, regex(len(values))) else: if isinstance(model, OllamaLLM): # Handle call to ollama return return_ollama_response( - logits_generator=model.logits_generator, # type: ignore + model_obj=model.model_obj, # type: ignore prompt=prompt, max_tokens=max_tokens, temperature=0.0, ) - generator = outlines.generate.text(model.logits_generator) + generator = outlines.generate.text(model.model_obj) return (generator(prompt, max_tokens=max_tokens, stop_at="\n"), prompt) diff --git a/blendsql/ingredients/builtin/qa/main.py b/blendsql/ingredients/builtin/qa/main.py index 0542b1de..e295d4cf 100644 --- a/blendsql/ingredients/builtin/qa/main.py +++ b/blendsql/ingredients/builtin/qa/main.py @@ -61,7 +61,7 @@ def __call__( "Can't use `options` argument in LLMQA with an Ollama model!" 
) generator = outlines.generate.choice( - model.logits_generator, [re.escape(str(i)) for i in options] + model.model_obj, [re.escape(str(i)) for i in options] ) _response: str = generator(prompt, max_tokens=max_tokens) # Map from modified options to original, as they appear in DB @@ -70,12 +70,12 @@ def __call__( if isinstance(model, OllamaLLM): # Handle call to ollama return return_ollama_response( - logits_generator=model.logits_generator, # type: ignore + model_obj=model.model_obj, # type: ignore prompt=prompt, max_tokens=max_tokens, temperature=0.0, ) - generator = outlines.generate.text(model.logits_generator) + generator = outlines.generate.text(model.model_obj) response: str = generator(prompt, max_tokens=max_tokens) return (response, prompt) diff --git a/blendsql/ingredients/builtin/validate/main.py b/blendsql/ingredients/builtin/validate/main.py index 1b6db115..4b9b82c8 100644 --- a/blendsql/ingredients/builtin/validate/main.py +++ b/blendsql/ingredients/builtin/validate/main.py @@ -23,8 +23,8 @@ def __call__( if table_title: prompt += f"\nTable Description: {table_title}" prompt += f"\n{serialized_db}\n\nAnswer:" - generator = outlines.generate.choice(model.logits_generator, ["true", "false"]) - response: str = generator(prompt) + generator = outlines.generate.choice(model.model_obj, ["true", "false"]) + response: str = generator(prompt) # type: ignore return (response, prompt) diff --git a/blendsql/models/__init__.py b/blendsql/models/__init__.py index 894b28ad..b5a7f182 100644 --- a/blendsql/models/__init__.py +++ b/blendsql/models/__init__.py @@ -2,4 +2,4 @@ from .local._llama_cpp import LlamaCppLLM from .remote._ollama import OllamaLLM from .remote._openai import OpenaiLLM, AzureOpenaiLLM -from ._model import Model, RemoteModel, LocalModel +from ._model import Model, RemoteModel, LocalModel, ModelObj diff --git a/blendsql/models/_model.py b/blendsql/models/_model.py index 50582c30..f77ee1a9 100644 --- a/blendsql/models/_model.py +++ b/blendsql/models/_model.py @@ -1,4 +1,4 @@ -from typing import Any, List, Optional, Type, Dict +from typing import Any, List, Optional, Type, Dict, TypeVar import pandas as pd from attr import attrib, attrs from pathlib import Path @@ -11,7 +11,6 @@ import hashlib from abc import abstractmethod from functools import cached_property -from outlines.models import LogitsGenerator from .._logger import logger from .._program import Program, program_to_str @@ -19,6 +18,7 @@ from ..db.utils import truncate_df_content CONTEXT_TRUNCATION_LIMIT = 100 +ModelObj = TypeVar("ModelObj") class TokenTimer(threading.Thread): @@ -49,7 +49,7 @@ class Model: env: str = attrib(default=".") caching: bool = attrib(default=True) - logits_generator: LogitsGenerator = attrib(init=False) + model_obj: ModelObj = attrib(init=False) prompts: List[dict] = attrib(init=False) prompt_tokens: int = attrib(init=False) completion_tokens: int = attrib(init=False) @@ -116,8 +116,6 @@ def predict(self, program: Type[Program], **kwargs) -> str: self.prompts.insert(-1, self.format_prompt(response, **kwargs)) return response # Modify fields used for tracking Model usage - response: str - prompt: str response, prompt = program(model=self, **kwargs) self.prompts.insert(-1, self.format_prompt(response, **kwargs)) self.num_calls += 1 @@ -156,8 +154,8 @@ def _create_key(self, program: Type[Program], **kwargs) -> str: return hasher.hexdigest() @cached_property - def logits_generator(self) -> LogitsGenerator: - """Allows for lazy loading of underlying model.""" + def model_obj(self) -> ModelObj: 
+ """Allows for lazy loading of underlying model weights.""" return self._load_model() @staticmethod @@ -183,9 +181,9 @@ def _setup(self, *args, **kwargs) -> None: ... @abstractmethod - def _load_model(self, *args, **kwargs) -> Any: + def _load_model(self, *args, **kwargs) -> ModelObj: """Logic for instantiating the model class goes here. - Will most likely be an outlines.LogitsGenerator object, + Will most likely be an outlines model object, but in some cases (like OllamaLLM) we make an exception. """ ... diff --git a/blendsql/models/local/_transformers.py b/blendsql/models/local/_transformers.py index ddb6a1da..8dd08ab7 100644 --- a/blendsql/models/local/_transformers.py +++ b/blendsql/models/local/_transformers.py @@ -1,8 +1,7 @@ import importlib.util -from outlines.models import LogitsGenerator from outlines.models.transformers import transformers -from .._model import LocalModel +from .._model import LocalModel, ModelObj DEFAULT_KWARGS = {"do_sample": True, "temperature": 0.0, "top_p": 1.0} @@ -50,7 +49,7 @@ def __init__(self, model_name_or_path: str, caching: bool = True, **kwargs): **kwargs ) - def _load_model(self) -> LogitsGenerator: + def _load_model(self) -> ModelObj: # https://huggingface.co/blog/how-to-generate return transformers( self.model_name_or_path, diff --git a/blendsql/models/remote/_openai.py b/blendsql/models/remote/_openai.py index 20dc53eb..a9794f47 100644 --- a/blendsql/models/remote/_openai.py +++ b/blendsql/models/remote/_openai.py @@ -1,9 +1,8 @@ import os import importlib.util -from outlines.models import LogitsGenerator from outlines.models.openai import openai, azure_openai, OpenAIConfig -from .._model import RemoteModel +from .._model import RemoteModel, ModelObj from typing import Optional DEFAULT_CONFIG = OpenAIConfig(temperature=0.0) @@ -104,7 +103,7 @@ def __init__( **kwargs ) - def _load_model(self, config: OpenAIConfig) -> LogitsGenerator: + def _load_model(self, config: OpenAIConfig) -> ModelObj: return azure_openai( self.model_name_or_path, config=config, @@ -171,7 +170,7 @@ def __init__( **kwargs ) - def _load_model(self, config: OpenAIConfig) -> LogitsGenerator: + def _load_model(self, config: OpenAIConfig) -> ModelObj: return openai( self.model_name_or_path, config=config, api_key=os.getenv("OPENAI_API_KEY") ) # type: ignore diff --git a/blendsql/nl_to_blendsql/nl_to_blendsql.py b/blendsql/nl_to_blendsql/nl_to_blendsql.py index 1bf4c75b..0806293a 100644 --- a/blendsql/nl_to_blendsql/nl_to_blendsql.py +++ b/blendsql/nl_to_blendsql/nl_to_blendsql.py @@ -60,12 +60,12 @@ def __call__( if isinstance(model, OllamaLLM): # Handle call to ollama return return_ollama_response( - logits_generator=model.logits_generator, + model_obj=model.model_obj, prompt=prompt, stop=PARSER_STOP_TOKENS, temperature=0.0, ) - generator = outlines.generate.text(model.logits_generator) + generator = outlines.generate.text(model.model_obj) response: str = generator(prompt, stop_at=PARSER_STOP_TOKENS) return (response, prompt) @@ -93,9 +93,9 @@ def __call__( prompt += f"BlendSQL:\n" prompt += partial_completion generator = outlines.generate.choice( - model.logits_generator, [re.escape(str(i)) for i in candidates] + model.model_obj, [re.escape(str(i)) for i in candidates] ) - response: str = generator(prompt) + response: str = generator(prompt) # type: ignore return (response, prompt) diff --git a/docs/reference/examples/teaching-blendsql-via-in-context-learning.ipynb b/docs/reference/examples/teaching-blendsql-via-in-context-learning.ipynb index 82df3c3a..834ed588 100644 --- 
a/docs/reference/examples/teaching-blendsql-via-in-context-learning.ipynb +++ b/docs/reference/examples/teaching-blendsql-via-in-context-learning.ipynb @@ -172,7 +172,7 @@ " prompt += f\"{serialized_db}\\n\\n\"\n", " prompt += f\"Question: {question}\\n\"\n", " prompt += f\"BlendSQL: \"\n", - " generator = outlines.generate.text(model.logits_generator)\n", + " generator = outlines.generate.text(model.model_obj)\n", " result = generator(prompt)\n", " return (result, prompt)" ], diff --git a/docs/reference/examples/vqa-ingredient.ipynb b/docs/reference/examples/vqa-ingredient.ipynb index 1aa72cd6..6300a349 100644 --- a/docs/reference/examples/vqa-ingredient.ipynb +++ b/docs/reference/examples/vqa-ingredient.ipynb @@ -26,7 +26,7 @@ "source": [ "from typing import List\n", "from blendsql import blend\n", - "from blendsql.models import TransformersLLM\n", + "from blendsql.models import TransformersLLM, ModelObj\n", "from blendsql.ingredients import MapIngredient, IngredientException\n", "from blendsql.utils import fetch_from_hub\n", "from blendsql.db import SQLite" @@ -96,12 +96,12 @@ "\n", "class VQAModel(TransformersLLM):\n", " \n", - " def _load_model(self):\n", + " def _load_model(self) -> ModelObj:\n", " return pipeline(\"image-to-text\", model=self.model_name_or_path)\n", "\n", " def predict(self, question: str, img_bytes: List[bytes]) -> str:\n", " prompt = f\"USER: \\n{question}\"\n", - " model_output = self.logits_generator(\n", + " model_output = self.model_obj(\n", " images=[\n", " Image.open(BytesIO(value)) for value in img_bytes\n", " ],\n", diff --git a/examples/teaching-blendsql-via-in-context-learning.ipynb b/examples/teaching-blendsql-via-in-context-learning.ipynb index 82df3c3a..834ed588 100644 --- a/examples/teaching-blendsql-via-in-context-learning.ipynb +++ b/examples/teaching-blendsql-via-in-context-learning.ipynb @@ -172,7 +172,7 @@ " prompt += f\"{serialized_db}\\n\\n\"\n", " prompt += f\"Question: {question}\\n\"\n", " prompt += f\"BlendSQL: \"\n", - " generator = outlines.generate.text(model.logits_generator)\n", + " generator = outlines.generate.text(model.model_obj)\n", " result = generator(prompt)\n", " return (result, prompt)" ], diff --git a/examples/vqa-ingredient.ipynb b/examples/vqa-ingredient.ipynb index 1aa72cd6..6300a349 100644 --- a/examples/vqa-ingredient.ipynb +++ b/examples/vqa-ingredient.ipynb @@ -26,7 +26,7 @@ "source": [ "from typing import List\n", "from blendsql import blend\n", - "from blendsql.models import TransformersLLM\n", + "from blendsql.models import TransformersLLM, ModelObj\n", "from blendsql.ingredients import MapIngredient, IngredientException\n", "from blendsql.utils import fetch_from_hub\n", "from blendsql.db import SQLite" @@ -96,12 +96,12 @@ "\n", "class VQAModel(TransformersLLM):\n", " \n", - " def _load_model(self):\n", + " def _load_model(self) -> ModelObj:\n", " return pipeline(\"image-to-text\", model=self.model_name_or_path)\n", "\n", " def predict(self, question: str, img_bytes: List[bytes]) -> str:\n", " prompt = f\"USER: \\n{question}\"\n", - " model_output = self.logits_generator(\n", + " model_output = self.model_obj(\n", " images=[\n", " Image.open(BytesIO(value)) for value in img_bytes\n", " ],\n", diff --git a/tests/models/test_ollama.py b/tests/models/test_ollama.py index 7f2af59a..a2c91956 100644 --- a/tests/models/test_ollama.py +++ b/tests/models/test_ollama.py @@ -27,7 +27,7 @@ def ingredients() -> set: def test_ollama_basic_llmqa(db, ingredients): try: model = OllamaLLM(TEST_OLLAMA_LLM, caching=False) - 
model.logits_generator(messages=[{"role": "user", "content": "hello"}]) + model.model_obj(messages=[{"role": "user", "content": "hello"}]) except httpx.ConnectError: pytest.skip("Ollama server is not running, skipping this test") blendsql = """ @@ -67,7 +67,7 @@ def test_ollama_raise_exception(db, ingredients): def test_ollama_join(db, ingredients): try: model = OllamaLLM(TEST_OLLAMA_LLM, caching=False) - model.logits_generator(messages=[{"role": "user", "content": "hello"}]) + model.model_obj(messages=[{"role": "user", "content": "hello"}]) except httpx.ConnectError: pytest.skip("Ollama server is not running, skipping this test") res = blend( From 152fb5354d79754b61e1275b766c9f1c611edcd7 Mon Sep 17 00:00:00 2001 From: parkervg Date: Mon, 17 Jun 2024 12:02:31 -0700 Subject: [PATCH 03/10] LogitsGenerator -> ModelObj https://github.com/outlines-dev/outlines/pull/970 --- blendsql/models/local/_llama_cpp.py | 5 ++--- 1 file changed, 2 insertions(+), 3 deletions(-) diff --git a/blendsql/models/local/_llama_cpp.py b/blendsql/models/local/_llama_cpp.py index 372e21ea..874271ff 100644 --- a/blendsql/models/local/_llama_cpp.py +++ b/blendsql/models/local/_llama_cpp.py @@ -1,8 +1,7 @@ import importlib.util -from outlines.models import LogitsGenerator from outlines.models.llamacpp import llamacpp -from .._model import LocalModel +from .._model import LocalModel, ModelObj from typing import Optional _has_llama_cpp = importlib.util.find_spec("llama_cpp") is not None @@ -56,7 +55,7 @@ def __init__( caching=caching, ) - def _load_model(self, filename: str) -> LogitsGenerator: + def _load_model(self, filename: str) -> ModelObj: return llamacpp( self.model_name_or_path, filename=filename, From da799ab4d8702d2d9e055adc6fc207a58516a1ae Mon Sep 17 00:00:00 2001 From: parkervg Date: Mon, 17 Jun 2024 12:02:39 -0700 Subject: [PATCH 04/10] LogitsGenerator -> ModelObj https://github.com/outlines-dev/outlines/pull/970 --- docs/reference/ingredients/creating-custom-ingredients.md | 4 ++-- 1 file changed, 2 insertions(+), 2 deletions(-) diff --git a/docs/reference/ingredients/creating-custom-ingredients.md b/docs/reference/ingredients/creating-custom-ingredients.md index 144cd2fd..2dfe6e92 100644 --- a/docs/reference/ingredients/creating-custom-ingredients.md +++ b/docs/reference/ingredients/creating-custom-ingredients.md @@ -31,11 +31,11 @@ class SummaryProgram(Program): """Program to call Model and return summary of the passed table. 
""" - def __call__(self, model: Model, serialized_db: str): + def __call__(self, model: Model, serialized_db: str): prompt = f"Summarize the table below.\n\n{serialized_db}\n" # Below we follow the outlines pattern for unconstrained text generation # https://github.com/outlines-dev/outlines - generator = outlines.generate.text(model.logits_generator) + generator = outlines.generate.text(model.model_obj) # Finally, return (response, prompt) tuple # Returning the prompt here allows the underlying BlendSQL classes to track token usage return (generator(prompt), prompt) From e01170b8907256c4956cdd97687b4ec205450018 Mon Sep 17 00:00:00 2001 From: parkervg Date: Mon, 17 Jun 2024 12:02:59 -0700 Subject: [PATCH 05/10] More mypy fixes Mostly changing some internal variable names --- blendsql/_program.py | 12 +++++------- blendsql/db/_duckdb.py | 2 +- blendsql/db/_pandas.py | 4 ++-- blendsql/db/bridge_content_encoder.py | 1 + blendsql/ingredients/ingredient.py | 14 +++++++++----- 5 files changed, 18 insertions(+), 15 deletions(-) diff --git a/blendsql/_program.py b/blendsql/_program.py index 88f06de2..95a57d4e 100644 --- a/blendsql/_program.py +++ b/blendsql/_program.py @@ -1,5 +1,5 @@ from __future__ import annotations -from typing import Tuple +from typing import Tuple, Type import inspect import ast import textwrap @@ -36,7 +36,7 @@ def __call__(self, model: Model, context: pd.DataFrame) -> Tuple[str, str]: prompt = f"Summarize the following table. {context.to_string()}" # Below we follow the outlines pattern for unconstrained text generation # https://github.com/outlines-dev/outlines - generator = outlines.generate.text(model.logits_generator) + generator = outlines.generate.text(model.model_obj) response: str = generator(prompt) # Finally, return (response, prompt) tuple # Returning the prompt here allows the underlying BlendSQL classes to track token usage @@ -64,9 +64,7 @@ def __call__(self, model: Model, *args, **kwargs) -> Tuple[str, str]: ... -def return_ollama_response( - logits_generator: partial, prompt, **kwargs -) -> Tuple[str, str]: +def return_ollama_response(model_obj: partial, prompt, **kwargs) -> Tuple[str, str]: """Helper function to work with Ollama models, since they're not recognized in the Outlines ecosystem. """ @@ -76,7 +74,7 @@ def return_ollama_response( if options.get("temperature") is None: options["temperature"] = 0.0 stream = logger.level <= logging.DEBUG - response = logits_generator( + response = model_obj( messages=[{"role": "user", "content": prompt}], options=options, stream=stream, @@ -95,7 +93,7 @@ def return_ollama_response( return (response["message"]["content"], prompt) -def program_to_str(program: Program) -> str: +def program_to_str(program: Type[Program]) -> str: """Create a string representation of a program. 
It is slightly tricky, since in addition to getting the code content, we need to 1) identify all global variables referenced within a function, and then diff --git a/blendsql/db/_duckdb.py b/blendsql/db/_duckdb.py index 2aed5118..e14d48e5 100644 --- a/blendsql/db/_duckdb.py +++ b/blendsql/db/_duckdb.py @@ -100,7 +100,7 @@ def from_sqlite(cls, db_url: str): import duckdb con = duckdb.connect(database=":memory:") - db_url = Path(db_url).resolve() + db_url = str(Path(db_url).resolve()) con.sql("INSTALL sqlite;") con.sql("LOAD sqlite;") con.sql(f"ATTACH '{db_url}' AS sqlite_db (TYPE sqlite);") diff --git a/blendsql/db/_pandas.py b/blendsql/db/_pandas.py index 402f105b..62aa4230 100644 --- a/blendsql/db/_pandas.py +++ b/blendsql/db/_pandas.py @@ -1,11 +1,11 @@ -from typing import Dict, Optional, Union +from typing import Dict, Union import pandas as pd from ._duckdb import DuckDB def Pandas( - data: Union[Dict[str, pd.DataFrame], pd.DataFrame], tablename: Optional[str] = "w" + data: Union[Dict[str, pd.DataFrame], pd.DataFrame], tablename: str = "w" ) -> DuckDB: """This is just a wrapper over the `DuckDB.from_pandas` class method. Makes it more intuitive to do a `from blendsql.db import Pandas`, for those diff --git a/blendsql/db/bridge_content_encoder.py b/blendsql/db/bridge_content_encoder.py index 0cb1b51d..5352e573 100644 --- a/blendsql/db/bridge_content_encoder.py +++ b/blendsql/db/bridge_content_encoder.py @@ -1,3 +1,4 @@ +# mypy: ignore-errors """ Copyright (c) 2020, salesforce.com, inc. All rights reserved. diff --git a/blendsql/ingredients/ingredient.py b/blendsql/ingredients/ingredient.py index f7fe9103..48c08fea 100644 --- a/blendsql/ingredients/ingredient.py +++ b/blendsql/ingredients/ingredient.py @@ -282,14 +282,18 @@ def __call__( # length(new inner) = length(inner) - #matched by fuzzy join _outer = res["out"][res["in"].isnull()].to_list() # length(new _outer) = length(_outer) - #matched by fuzzy join - _mapping = res.dropna(subset=["in"]).set_index("out")["in"].to_dict() + _skrub_mapping = ( + res.dropna(subset=["in"]).set_index("out")["in"].to_dict() + ) logger.debug( Fore.YELLOW + "Made the following alignment with `skrub.Joiner`:" + Fore.RESET ) - logger.debug(Fore.YELLOW + json.dumps(_mapping, indent=4) + Fore.RESET) - mapping = mapping | _mapping + logger.debug( + Fore.YELLOW + json.dumps(_skrub_mapping, indent=4) + Fore.RESET + ) + mapping = mapping | _skrub_mapping # order by length is still preserved regardless of using fuzzy join, so after initial matching and possible fuzzy join matching # This is because the lengths of each list will decrease at the same rate, so whichever list was larger at the beginning, # will be larger here at the end. 
@@ -316,8 +320,8 @@ def __call__( ) kwargs[IngredientKwarg.QUESTION] = question - _mapping: Dict[str, str] = self._run(*args, **kwargs) - mapping = mapping | _mapping + _predicted_mapping: Dict[str, str] = self._run(*args, **kwargs) + mapping = mapping | _predicted_mapping # Using mapped left/right values, create intermediary mapping table temp_join_tablename = get_temp_session_table(str(uuid.uuid4())[:4]) # Below, we check to see if 'swapped' is True From 2fd9e0ab73717635a0bcd35610984df21f4e8606 Mon Sep 17 00:00:00 2001 From: parkervg Date: Mon, 17 Jun 2024 13:58:37 -0700 Subject: [PATCH 06/10] `generate` directory, using `singledispatch` to handle routing to Ollama/other lms --- blendsql/_program.py | 34 +---------- blendsql/db/_duckdb.py | 2 +- blendsql/generate/__init__.py | 3 + blendsql/generate/choice.py | 22 +++++++ blendsql/generate/regex.py | 28 +++++++++ blendsql/generate/text.py | 56 +++++++++++++++++ blendsql/ingredients/builtin/join/main.py | 34 ++++------- blendsql/ingredients/builtin/map/main.py | 23 +++---- blendsql/ingredients/builtin/qa/main.py | 23 +++---- blendsql/ingredients/builtin/validate/main.py | 8 +-- blendsql/ingredients/ingredient.py | 2 +- blendsql/models/_model.py | 6 +- blendsql/nl_to_blendsql/nl_to_blendsql.py | 25 ++++---- run-debug.py | 60 +++++++++---------- run-nl-to-blendsql.py | 50 +++++++++------- 15 files changed, 219 insertions(+), 157 deletions(-) create mode 100644 blendsql/generate/__init__.py create mode 100644 blendsql/generate/choice.py create mode 100644 blendsql/generate/regex.py create mode 100644 blendsql/generate/text.py diff --git a/blendsql/_program.py b/blendsql/_program.py index 95a57d4e..217bd919 100644 --- a/blendsql/_program.py +++ b/blendsql/_program.py @@ -3,15 +3,12 @@ import inspect import ast import textwrap -import logging -from colorama import Fore from abc import abstractmethod -from functools import partial + from typing import TYPE_CHECKING if TYPE_CHECKING: from .models import Model -from ._logger import logger class Program: @@ -64,35 +61,6 @@ def __call__(self, model: Model, *args, **kwargs) -> Tuple[str, str]: ... -def return_ollama_response(model_obj: partial, prompt, **kwargs) -> Tuple[str, str]: - """Helper function to work with Ollama models, - since they're not recognized in the Outlines ecosystem. - """ - from ollama import Options - - options = Options(**kwargs) - if options.get("temperature") is None: - options["temperature"] = 0.0 - stream = logger.level <= logging.DEBUG - response = model_obj( - messages=[{"role": "user", "content": prompt}], - options=options, - stream=stream, - ) - if stream: - chunked_res = [] - for chunk in response: - chunked_res.append(chunk["message"]["content"]) - print( - Fore.CYAN + chunk["message"]["content"] + Fore.RESET, - end="", - flush=True, - ) - print("\n") - return ("".join(chunked_res), prompt) - return (response["message"]["content"], prompt) - - def program_to_str(program: Type[Program]) -> str: """Create a string representation of a program. 
    It is slightly tricky, since in addition to getting the code content, we need to
diff --git a/blendsql/db/_duckdb.py b/blendsql/db/_duckdb.py
index e14d48e5..1f7f2df0 100644
--- a/blendsql/db/_duckdb.py
+++ b/blendsql/db/_duckdb.py
@@ -127,7 +127,7 @@ def sqlglot_schema(self) -> dict:
         > {"x": {"A": "INT", "B": "INT", "C": "INT", "D": "INT", "Z": "STRING"}}
         ```
         """
-        schema = {}
+        schema: Dict[str, dict] = {}
         for tablename in self.tables():
             schema[f'"{double_quote_escape(tablename)}"'] = {}
             for column_name, column_type in self.con.execute(
diff --git a/blendsql/generate/__init__.py b/blendsql/generate/__init__.py
new file mode 100644
index 00000000..b3eb688b
--- /dev/null
+++ b/blendsql/generate/__init__.py
@@ -0,0 +1,3 @@
+from .text import text
+from .regex import regex
+from .choice import choice
diff --git a/blendsql/generate/choice.py b/blendsql/generate/choice.py
new file mode 100644
index 00000000..0b56b422
--- /dev/null
+++ b/blendsql/generate/choice.py
@@ -0,0 +1,22 @@
+from functools import singledispatch
+from typing import List
+import outlines
+
+from ..models import Model, OllamaLLM
+
+
+@singledispatch
+def choice(model: Model, prompt: str, choices: List[str], **kwargs) -> str:
+    generator = outlines.generate.choice(model.model_obj, choices=choices)
+    return generator(prompt)
+
+
+@choice.register(OllamaLLM)
+def choice_ollama(*_, **__) -> str:
+    """Helper function to work with Ollama models,
+    since they're not recognized in the Outlines ecosystem.
+    """
+    raise NotImplementedError(
+        "Cannot use choice generation with an Ollama model "
+        + "due to the limitations of the Ollama API."
+    )
diff --git a/blendsql/generate/regex.py b/blendsql/generate/regex.py
new file mode 100644
index 00000000..20c0bc90
--- /dev/null
+++ b/blendsql/generate/regex.py
@@ -0,0 +1,28 @@
+from functools import singledispatch
+from typing import Optional, List, Union
+import outlines
+
+from ..models import Model, OllamaLLM
+
+
+@singledispatch
+def regex(
+    model: Model,
+    prompt: str,
+    pattern: str,
+    max_tokens: Optional[int] = None,
+    stop_at: Optional[Union[List[str], str]] = None,
+) -> str:
+    generator = outlines.generate.regex(model.model_obj, regex_str=pattern)
+    return generator(prompt, max_tokens=max_tokens, stop_at=stop_at)
+
+
+@regex.register(OllamaLLM)
+def regex_ollama(*_, **__) -> str:
+    """Helper function to work with Ollama models,
+    since they're not recognized in the Outlines ecosystem.
+    """
+    raise NotImplementedError(
+        "Cannot use regex-structured generation with an Ollama model "
+        + "due to the limitations of the Ollama API."
+    )
diff --git a/blendsql/generate/text.py b/blendsql/generate/text.py
new file mode 100644
index 00000000..b0851119
--- /dev/null
+++ b/blendsql/generate/text.py
@@ -0,0 +1,56 @@
+from functools import singledispatch
+import logging
+from colorama import Fore
+from typing import Optional, List, Union
+import outlines
+
+from .._logger import logger
+from ..models import Model, OllamaLLM
+
+
+@singledispatch
+def text(
+    model: Model,
+    prompt: str,
+    max_tokens: Optional[int] = None,
+    stop_at: Optional[Union[List[str], str]] = None,
+    **kwargs
+) -> str:
+    generator = outlines.generate.text(model.model_obj)
+    return generator(prompt, max_tokens=max_tokens, stop_at=stop_at)
+
+
+@text.register(OllamaLLM)
+def text_ollama(model: OllamaLLM, prompt, **kwargs) -> str:
+    """Helper function to work with Ollama models,
+    since they're not recognized in the Outlines ecosystem.
+ """ + from ollama import Options + + # Turn outlines kwargs into Ollama + if "stop_at" in kwargs: + stop_at = kwargs.pop("stop_at") + if isinstance(stop_at, str): + stop_at = [stop_at] + kwargs["stop"] = stop_at + options = Options(**kwargs) + if options.get("temperature") is None: + options["temperature"] = 0.0 + stream = logger.level <= logging.DEBUG + response = model.model_obj( + messages=[{"role": "user", "content": prompt}], + options=options, + stream=stream, + ) # type: ignore + if stream: + chunked_res = [] + for chunk in response: + chunked_res.append(chunk["message"]["content"]) + print( + Fore.CYAN + chunk["message"]["content"] + Fore.RESET, + end="", + flush=True, + ) + print("\n") + return "".join(chunked_res) + return response["message"]["content"] diff --git a/blendsql/ingredients/builtin/join/main.py b/blendsql/ingredients/builtin/join/main.py index 364c3c01..bbfb0db8 100644 --- a/blendsql/ingredients/builtin/join/main.py +++ b/blendsql/ingredients/builtin/join/main.py @@ -1,14 +1,14 @@ from typing import List, Optional, Tuple -import outlines import re from colorama import Fore -from blendsql.models import Model, LocalModel, OllamaLLM -from blendsql._program import Program, return_ollama_response +from blendsql.models import Model, LocalModel +from blendsql._program import Program from blendsql._logger import logger from blendsql import _constants as CONST from blendsql.ingredients.ingredient import JoinIngredient from blendsql.utils import newline_dedent +from blendsql import generate class JoinProgram(Program): @@ -110,25 +110,17 @@ def __call__( ) if isinstance(model, LocalModel): - generator = outlines.generate.regex( - model.model_obj, regex(len(left_values)) + response = generate.regex( + model, + prompt=prompt, + pattern=regex(len(left_values)), + max_tokens=max_tokens, + stop_at=["---"], ) else: - if isinstance(model, OllamaLLM): - # Handle call to ollama - return return_ollama_response( - model_obj=model.model_obj, # type: ignore - prompt=prompt, - max_tokens=max_tokens, - temperature=0.0, - ) - generator = outlines.generate.text(model.model_obj) - - response: str = generator( - prompt, - max_tokens=max_tokens, - stop_at=["---"], - ) + response = generate.text( + model, prompt=prompt, max_tokens=max_tokens, stop_at=["---"] + ) logger.debug(Fore.CYAN + prompt + Fore.RESET) logger.debug(Fore.LIGHTCYAN_EX + response + Fore.RESET) return (response, prompt) @@ -158,7 +150,7 @@ def run( join_criteria=question, **kwargs, ) - + # Post-process language model response _result = result.split("\n") mapping: dict = {} for item in _result: diff --git a/blendsql/ingredients/builtin/map/main.py b/blendsql/ingredients/builtin/map/main.py index 58ed437a..326c1adb 100644 --- a/blendsql/ingredients/builtin/map/main.py +++ b/blendsql/ingredients/builtin/map/main.py @@ -5,15 +5,15 @@ import pandas as pd from colorama import Fore from tqdm import tqdm -import outlines from blendsql.utils import newline_dedent from blendsql._logger import logger -from blendsql.models import Model, LocalModel, RemoteModel, OpenaiLLM, OllamaLLM +from blendsql.models import Model, LocalModel, RemoteModel, OpenaiLLM from ast import literal_eval from blendsql import _constants as CONST from blendsql.ingredients.ingredient import MapIngredient -from blendsql._program import Program, return_ollama_response +from blendsql._program import Program +from blendsql import generate class MapProgram(Program): @@ -116,18 +116,12 @@ def __call__( prompt += f"\nHere are some example outputs: {example_outputs}\n" prompt 
+= "\nA:" if isinstance(model, LocalModel) and regex is not None: - generator = outlines.generate.regex(model.model_obj, regex(len(values))) + response = generate.regex(model, prompt=prompt, pattern=regex(len(values))) else: - if isinstance(model, OllamaLLM): - # Handle call to ollama - return return_ollama_response( - model_obj=model.model_obj, # type: ignore - prompt=prompt, - max_tokens=max_tokens, - temperature=0.0, - ) - generator = outlines.generate.text(model.model_obj) - return (generator(prompt, max_tokens=max_tokens, stop_at="\n"), prompt) + response = generate.text( + model, prompt=prompt, max_tokens=max_tokens, stop_at="\n" + ) + return (response, prompt) class LLMMap(MapIngredient): @@ -212,6 +206,7 @@ def run( max_tokens=max_tokens, **kwargs, ) + # Post-process language model response _r = [ i.strip() for i in result.strip(CONST.DEFAULT_ANS_SEP).split( diff --git a/blendsql/ingredients/builtin/qa/main.py b/blendsql/ingredients/builtin/qa/main.py index e295d4cf..da620403 100644 --- a/blendsql/ingredients/builtin/qa/main.py +++ b/blendsql/ingredients/builtin/qa/main.py @@ -1,14 +1,14 @@ import copy from typing import Dict, Union, Optional, Set, Tuple import pandas as pd -import outlines import re from blendsql.models import Model, OllamaLLM from blendsql._exceptions import InvalidBlendSQL -from blendsql._program import Program, return_ollama_response +from blendsql._program import Program from blendsql.ingredients.ingredient import QAIngredient from blendsql.db.utils import single_quote_escape +from blendsql import generate class QAProgram(Program): @@ -60,23 +60,15 @@ def __call__( raise InvalidBlendSQL( "Can't use `options` argument in LLMQA with an Ollama model!" ) - generator = outlines.generate.choice( - model.model_obj, [re.escape(str(i)) for i in options] + _response = generate.choice( + model, prompt=prompt, choices=[re.escape(str(i)) for i in options] ) - _response: str = generator(prompt, max_tokens=max_tokens) # Map from modified options to original, as they appear in DB response: str = modified_option_to_original.get(_response, _response) else: - if isinstance(model, OllamaLLM): - # Handle call to ollama - return return_ollama_response( - model_obj=model.model_obj, # type: ignore - prompt=prompt, - max_tokens=max_tokens, - temperature=0.0, - ) - generator = outlines.generate.text(model.model_obj) - response: str = generator(prompt, max_tokens=max_tokens) + response = generate.text( + model, prompt=prompt, max_tokens=max_tokens, stop_at="\n" + ) return (response, prompt) @@ -111,4 +103,5 @@ def run( table_title=None, **kwargs, ) + # Post-process language model response return "'{}'".format(single_quote_escape(result.strip().lower())) diff --git a/blendsql/ingredients/builtin/validate/main.py b/blendsql/ingredients/builtin/validate/main.py index 4b9b82c8..44d0e3f8 100644 --- a/blendsql/ingredients/builtin/validate/main.py +++ b/blendsql/ingredients/builtin/validate/main.py @@ -1,10 +1,10 @@ from typing import Dict, Union, Optional, Tuple import pandas as pd -import outlines from blendsql.models import Model from blendsql._program import Program from blendsql.ingredients.ingredient import QAIngredient +from blendsql import generate class ValidateProgram(Program): @@ -23,8 +23,7 @@ def __call__( if table_title: prompt += f"\nTable Description: {table_title}" prompt += f"\n{serialized_db}\n\nAnswer:" - generator = outlines.generate.choice(model.model_obj, ["true", "false"]) - response: str = generator(prompt) # type: ignore + response = generate.choice(model, 
prompt=prompt, choices=["true", "false"])
         return (response, prompt)
@@ -49,4 +48,5 @@ def run(
             table_title=None,
             **kwargs,
         )
-        return int(response == "true")
+        # Post-process language model response
+        return bool(response == "true")
diff --git a/blendsql/ingredients/ingredient.py b/blendsql/ingredients/ingredient.py
index 48c08fea..2cd716a5 100644
--- a/blendsql/ingredients/ingredient.py
+++ b/blendsql/ingredients/ingredient.py
@@ -405,7 +405,7 @@ def __call__(
                 )
             except ValueError:
                 unpacked_options = options.split(";")
-        unpacked_options = set(unpacked_options)
+        unpacked_options: Set[str] = set(unpacked_options)
         self.num_values_passed += len(subtable) if subtable is not None else 0
         kwargs[IngredientKwarg.OPTIONS] = unpacked_options
         kwargs[IngredientKwarg.CONTEXT] = subtable
diff --git a/blendsql/models/_model.py b/blendsql/models/_model.py
index f77ee1a9..f34f6e33 100644
--- a/blendsql/models/_model.py
+++ b/blendsql/models/_model.py
@@ -1,4 +1,4 @@
-from typing import Any, List, Optional, Type, Dict, TypeVar
+from typing import Any, List, Optional, Generic, Type, Dict, TypeVar
 import pandas as pd
 from attr import attrib, attrs
 from pathlib import Path
@@ -49,7 +49,7 @@ class Model:
     env: str = attrib(default=".")
     caching: bool = attrib(default=True)

-    model_obj: ModelObj = attrib(init=False)
+    model_obj: Generic[ModelObj] = attrib(init=False)
     prompts: List[dict] = attrib(init=False)
     prompt_tokens: int = attrib(init=False)
     completion_tokens: int = attrib(init=False)
@@ -116,6 +116,8 @@ def predict(self, program: Type[Program], **kwargs) -> str:
             self.prompts.insert(-1, self.format_prompt(response, **kwargs))
             return response
         # Modify fields used for tracking Model usage
+        response: str
+        prompt: str
         response, prompt = program(model=self, **kwargs)
         self.prompts.insert(-1, self.format_prompt(response, **kwargs))
         self.num_calls += 1
diff --git a/blendsql/nl_to_blendsql/nl_to_blendsql.py b/blendsql/nl_to_blendsql/nl_to_blendsql.py
index 0806293a..235b6b6d 100644
--- a/blendsql/nl_to_blendsql/nl_to_blendsql.py
+++ b/blendsql/nl_to_blendsql/nl_to_blendsql.py
@@ -1,7 +1,6 @@
 from typing import Tuple, Set, Optional, Union, Type
 from collections.abc import Collection
 from textwrap import dedent
-import outlines
 from colorama import Fore
 import re
 import logging
@@ -10,12 +9,14 @@
 from ..ingredients import Ingredient, IngredientException
 from ..models import Model, OllamaLLM
 from ..db import Database, double_quote_escape
-from .._program import Program, return_ollama_response
+from .._program import Program
 from ..grammars.minEarley.parser import EarleyParser
 from ..grammars.utils import load_cfg_parser
 from ..prompts import FewShot
+from ..
import generate from .args import NLtoBlendSQLArgs + PARSER_STOP_TOKENS = ["---", ";", "\n\n", "Q:"] PARSER_SYSTEM_PROMPT = dedent( """ @@ -57,16 +58,11 @@ def __call__( + prompt + Fore.RESET ) - if isinstance(model, OllamaLLM): - # Handle call to ollama - return return_ollama_response( - model_obj=model.model_obj, - prompt=prompt, - stop=PARSER_STOP_TOKENS, - temperature=0.0, - ) - generator = outlines.generate.text(model.model_obj) - response: str = generator(prompt, stop_at=PARSER_STOP_TOKENS) + response = generate.text( + model, + prompt=prompt, + stop_at=PARSER_STOP_TOKENS, + ) return (response, prompt) @@ -92,10 +88,9 @@ def __call__( prompt += f"Question: {question}\n" prompt += f"BlendSQL:\n" prompt += partial_completion - generator = outlines.generate.choice( - model.model_obj, [re.escape(str(i)) for i in candidates] + response = generate.choice( + model, prompt=prompt, choices=[re.escape(str(i)) for i in candidates] ) - response: str = generator(prompt) # type: ignore return (response, prompt) diff --git a/run-debug.py b/run-debug.py index 0f625322..2fdd9308 100644 --- a/run-debug.py +++ b/run-debug.py @@ -33,33 +33,33 @@ ) }} = TRUE """, - # """ - # SELECT * FROM w - # WHERE city = {{ - # LLMQA( - # 'Which city is located 120 miles west of Sydney?', - # (SELECT * FROM documents), - # options='w::city' - # ) - # }} - # """, - # """ - # SELECT date, rival, score, documents.content AS "Team Description" FROM w - # JOIN {{ - # LLMJoin( - # left_on='documents::title', - # right_on='w::rival' - # ) - # }} - # """, - # """ - # {{ - # LLMQA( - # 'What is this table about?', - # (SELECT * FROM w;) - # ) - # }} - # """ + """ + SELECT * FROM w + WHERE city = {{ + LLMQA( + 'Which city is located 120 miles west of Sydney?', + (SELECT * FROM documents), + options='w::city' + ) + }} + """, + """ + SELECT date, rival, score, documents.content AS "Team Description" FROM w + JOIN {{ + LLMJoin( + left_on='documents::title', + right_on='w::rival' + ) + }} + """, + """ + {{ + LLMQA( + 'What is this table about?', + (SELECT * FROM w;) + ) + }} + """, ] # TEST_QUERIES = [ @@ -103,7 +103,7 @@ ) ingredients = {LLMQA, LLMMap, LLMJoin} # db = SQLite(fetch_from_hub("1884_New_Zealand_rugby_union_tour_of_New_South_Wales_1.db")) - from blendsql.models import OpenaiLLM + from blendsql.models import TransformersLLM # model = OpenaiLLM("gpt-3.5-turbo", caching=False) times = [] @@ -113,8 +113,8 @@ smoothie = blend( query=q, db=db, - blender=OpenaiLLM("gpt-3.5-turbo", caching=False), - # blender=TransformersLLM("microsoft/Phi-3-mini-128k-instruct"), + # blender=OpenaiLLM("gpt-3.5-turbo", caching=False), + blender=TransformersLLM("Qwen/Qwen1.5-0.5B", caching=False), # blender=OllamaLLM("phi3", caching=False), verbose=True, ingredients={LLMJoin.from_args(use_skrub_joiner=False), LLMMap, LLMQA}, diff --git a/run-nl-to-blendsql.py b/run-nl-to-blendsql.py index 93ede724..ac0f8e0d 100644 --- a/run-nl-to-blendsql.py +++ b/run-nl-to-blendsql.py @@ -1,30 +1,38 @@ -from blendsql.models import OllamaLLM, OpenaiLLM +from blendsql.models import OllamaLLM, TransformersLLM from blendsql.db import SQLite from blendsql.utils import fetch_from_hub from blendsql.nl_to_blendsql import nl_to_blendsql, NLtoBlendSQLArgs, FewShot from blendsql import LLMMap, LLMQA, blend if __name__ == "__main__": - model = OllamaLLM("phi3") + ollama_model = OllamaLLM("phi3") + transformers_model = TransformersLLM("Qwen/Qwen1.5-0.5B", caching=False) db = SQLite( fetch_from_hub("1884_New_Zealand_rugby_union_tour_of_New_South_Wales_1.db") ) - prediction = 
nl_to_blendsql( - question="Show me all info about the game played 120 miles west of Sydney", - db=db, - model=model, - ingredients={LLMQA, LLMMap}, - correction_model=OpenaiLLM("gpt-3.5-turbo"), - few_shot_examples=FewShot.hybridqa, - args=NLtoBlendSQLArgs( - use_tables=["w"], include_db_content_tables=["w"], use_bridge_encoder=True - ), - verbose=True, - ) - smoothie = blend( - query=prediction, - blender=OpenaiLLM("gpt-3.5-turbo"), - ingredients={LLMMap, LLMQA}, - db=db, - ) - print(smoothie.df) + while True: + # question = "Show me all info about the game played 120 miles west of Sydney" + question = input(">>> ") + print("\n") + prediction = nl_to_blendsql( + question=question, + db=db, + model=ollama_model, + ingredients={LLMQA, LLMMap}, + correction_model=transformers_model, + few_shot_examples=FewShot.hybridqa, + args=NLtoBlendSQLArgs( + max_grammar_corrections=3, + use_tables=["w"], + include_db_content_tables=["w"], + use_bridge_encoder=True, + ), + verbose=True, + ) + smoothie = blend( + query=prediction, + blender=transformers_model, + ingredients={LLMMap, LLMQA}, + db=db, + ) + print(smoothie.df) From 2ac08e59e10a68fe18abb9ba961a1a93a84c1b7b Mon Sep 17 00:00:00 2001 From: parkervg Date: Mon, 17 Jun 2024 14:10:00 -0700 Subject: [PATCH 07/10] Some documentation on `infer_gen_constraints` --- blendsql/_sqlglot.py | 27 ++++++++++++++++++++------- 1 file changed, 20 insertions(+), 7 deletions(-) diff --git a/blendsql/_sqlglot.py b/blendsql/_sqlglot.py index a13c94f8..5f62c95e 100644 --- a/blendsql/_sqlglot.py +++ b/blendsql/_sqlglot.py @@ -12,6 +12,7 @@ Optional, Dict, Any, + Literal, ) from ast import literal_eval from sqlglot.optimizer.scope import find_all_in_scope, find_in_scope @@ -422,8 +423,8 @@ def get_scope_nodes( @attrs class QueryContextManager: """Handles manipulation of underlying SQL query. - We need to maintain two representations here: - - The underlying sqlglot exp.Expression + We need to maintain two synced representations here: + - The underlying sqlglot exp.Expression node - The string representation of the query """ @@ -436,6 +437,7 @@ def parse(self, query, schema: Optional[Union[dict, Schema]] = None): self.node = _parse_one(query, schema=schema) def to_string(self): + # Only call `recover_blendsql` if we need to if hash(self.node) != hash(self._last_to_string_node): self._query = recover_blendsql(self.node.sql(dialect=FTS5SQLite)) self.last_to_string_node = self.node @@ -688,13 +690,24 @@ def infer_gen_constraints(self, start: int, end: int) -> dict: We can infer given the structure above that we expect `LLMMap` to return a boolean. This function identifies that. + Arguments: + start, end: The string indices pointing to the span within the overall BlendSQL query + containing our ingredient in question. + Returns: dict, with keys: - output_type: numeric | string | boolean - pattern: regular expression pattern to use in constrained decoding with Model + output_type: boolean | integer | float | string + pattern: regular expression pattern lambda to use in constrained decoding with Model + See `create_pattern` for more info on these pattern lambdas """ - def create_pattern(output_type: str) -> Callable[[int], str]: + def create_pattern( + output_type: Literal["boolean", "integer", "float"] + ) -> Callable[[int], str]: + """Helper function to create a pattern lambda. + These pattern lambdas take an integer (num_repeats) and return + a regex pattern which is restricted to repeat exclusively num_repeats times. 
+ """ if output_type == "boolean": base_pattern = f"((t|f|{DEFAULT_NAN_ANS}){DEFAULT_ANS_SEP})" elif output_type == "integer": @@ -721,7 +734,7 @@ def create_pattern(output_type: str) -> Callable[[int], str]: # Example: CAST({{LLMMap('jump distance', 'w::notes')}} AS FLOAT) while isinstance(start_node, exp.Func) and start_node is not None: start_node = start_node.parent - output_type = None + output_type: Literal["boolean", "integer", "float"] = None predicate_literals: List[str] = [] if start_node is not None: predicate_literals = get_predicate_literals(start_node) @@ -747,7 +760,7 @@ def create_pattern(output_type: str) -> Callable[[int], str]: elif isinstance( ingredient_node_in_context.parent, (exp.Order, exp.Ordered, exp.AggFunc) ): - output_type = "float" # Use 'float' as default numeric pattern + output_type = "float" # Use 'float' as default numeric pattern, since it's more expressive than 'integer' if output_type is not None: added_kwargs["output_type"] = output_type added_kwargs["pattern"] = create_pattern(output_type) From 6389d9c5d69e7b363d6af35309d09fe729593ba8 Mon Sep 17 00:00:00 2001 From: parkervg Date: Mon, 17 Jun 2024 14:10:55 -0700 Subject: [PATCH 08/10] Some documentation on `infer_gen_constraints` --- blendsql/_sqlglot.py | 2 ++ 1 file changed, 2 insertions(+) diff --git a/blendsql/_sqlglot.py b/blendsql/_sqlglot.py index 5f62c95e..71d216bc 100644 --- a/blendsql/_sqlglot.py +++ b/blendsql/_sqlglot.py @@ -290,6 +290,8 @@ def extract_multi_table_predicates( def get_first_child(node): """ Helper function to get first child of a node. + The default argument to `walk()` is bfs=True, + meaning we do breadth-first search. """ gen = node.walk() _ = next(gen) From 5a711b2cebf104580cc1ac1694a614e09d0123cc Mon Sep 17 00:00:00 2001 From: parkervg Date: Mon, 17 Jun 2024 14:29:14 -0700 Subject: [PATCH 09/10] Some docs for query optimization logic --- blendsql/_sqlglot.py | 49 ++++++++++++++++++++-------- docs/reference/query_optimization.md | 18 ++++++++++ mkdocs.yml | 4 ++- 3 files changed, 57 insertions(+), 14 deletions(-) create mode 100644 docs/reference/query_optimization.md diff --git a/blendsql/_sqlglot.py b/blendsql/_sqlglot.py index 71d216bc..6d63c926 100644 --- a/blendsql/_sqlglot.py +++ b/blendsql/_sqlglot.py @@ -426,8 +426,10 @@ def get_scope_nodes( class QueryContextManager: """Handles manipulation of underlying SQL query. We need to maintain two synced representations here: - - The underlying sqlglot exp.Expression node - - The string representation of the query + + 1) The underlying sqlglot exp.Expression node + + 2) The string representation of the query """ node: exp.Expression = attrib(default=None) @@ -487,8 +489,18 @@ def abstracted_table_selects(self) -> Generator[Tuple[str, str], None, None]: abstracted_queries: Generator with (tablename, exp.Select, alias_to_table). The exp.Select is the abstracted query. 
Examples: - >>> {{Model('is this an italian restaurant?', 'transactions::merchant', endpoint_name='gpt-4')}} = 1 AND child_category = 'Restaurants & Dining' + ```python + scm = SubqueryContextManager( + node=_parse_one( + "SELECT * FROM transactions WHERE {{Model('is this an italian restaurant?', 'transactions::merchant')}} = TRUE AND child_category = 'Restaurants & Dining'" + ) + ) + scm.abstracted_table_selects() + ``` + Returns: + ```text ('transactions', 'SELECT * FROM transactions WHERE TRUE AND child_category = \'Restaurants & Dining\'') + ``` """ # TODO: don't really know how to optimize with 'CASE' queries right now if self.node.find(exp.Case): @@ -586,13 +598,18 @@ def _table_star_queries( table_star_queries: Generator with (tablename, exp.Select). The exp.Select is the table_star query Examples: - >>> SELECT "Run Date", Account, Action, ROUND("Amount ($)", 2) AS 'Total Dividend Payout ($$)', Name - >>> FROM account_history - >>> LEFT JOIN constituents ON account_history.Symbol = constituents.Symbol - >>> WHERE constituents.Sector = 'Information Technology' - >>> AND lower(Action) like "%dividend%" + ```sql + SELECT "Run Date", Account, Action, ROUND("Amount ($)", 2) AS 'Total Dividend Payout ($$)', Name + FROM account_history + LEFT JOIN constituents ON account_history.Symbol = constituents.Symbol + WHERE constituents.Sector = 'Information Technology' + AND lower(Action) like "%dividend%" + ``` + Returns (after getting str representation of `exp.Select`): + ```text ('account_history', 'SELECT * FROM account_history WHERE lower(Action) like "%dividend%') ('constituents', 'SELECT * FROM constituents WHERE sector = \'Information Technology\'') + ``` """ # Use `scope` to get all unique tablenodes in ast tablenodes = set( @@ -687,20 +704,26 @@ def infer_gen_constraints(self, start: int, end: int) -> dict: downstream Model generations. For example: - >>> SELECT * FROM w WHERE {{LLMMap('Is this true?', 'w::colname')}} + + ```sql + SELECT * FROM w WHERE {{LLMMap('Is this true?', 'w::colname')}} + ``` We can infer given the structure above that we expect `LLMMap` to return a boolean. This function identifies that. Arguments: - start, end: The string indices pointing to the span within the overall BlendSQL query + indices: The string indices pointing to the span within the overall BlendSQL query containing our ingredient in question. 
Returns: dict, with keys: - output_type: boolean | integer | float | string - pattern: regular expression pattern lambda to use in constrained decoding with Model - See `create_pattern` for more info on these pattern lambdas + + - output_type + - 'boolean' | 'integer' | 'float' | 'string' + + - pattern: regular expression pattern lambda to use in constrained decoding with Model + - See `create_pattern` for more info on these pattern lambdas """ def create_pattern( diff --git a/docs/reference/query_optimization.md b/docs/reference/query_optimization.md new file mode 100644 index 00000000..7d7f82d6 --- /dev/null +++ b/docs/reference/query_optimization.md @@ -0,0 +1,18 @@ +--- +hide: + - toc +--- +### Query Optimization + +::: blendsql._sqlglot.QueryContextManager + handler: python + show_source: true + +::: blendsql._sqlglot.SubqueryContextManager + handler: python + show_source: true + options: + members: + - abstracted_table_selects + - _table_star_queries + - infer_gen_constraints \ No newline at end of file diff --git a/mkdocs.yml b/mkdocs.yml index c77bcb70..c94ead60 100644 --- a/mkdocs.yml +++ b/mkdocs.yml @@ -100,4 +100,6 @@ nav: - reference/programs.md - Natural Language to BlendSQL: - reference/nl_to_blendsql/nl-to-blendsql.md - - Technical Walkthrough: reference/technical_walkthrough.md + - Technical Walkthrough: + - reference/technical_walkthrough.md + - Query Optimization: reference/query_optimization.md From f1a6c8f5af95103020ef14c17d260524418e6bca Mon Sep 17 00:00:00 2001 From: parkervg Date: Mon, 17 Jun 2024 18:58:02 -0500 Subject: [PATCH 10/10] Remove accidental commits --- .gitignore | 2 ++ .zed/settings.json | 0 pyrightconfig.json | 4 ---- 3 files changed, 2 insertions(+), 4 deletions(-) delete mode 100644 .zed/settings.json delete mode 100644 pyrightconfig.json diff --git a/.gitignore b/.gitignore index 22e2306f..fb197798 100644 --- a/.gitignore +++ b/.gitignore @@ -1,5 +1,7 @@ .chainlit/ .files/ +.zed/ +pyrightconfig.json proxies.py site/ tests/data/ diff --git a/.zed/settings.json b/.zed/settings.json deleted file mode 100644 index e69de29b..00000000 diff --git a/pyrightconfig.json b/pyrightconfig.json deleted file mode 100644 index 3c226a6c..00000000 --- a/pyrightconfig.json +++ /dev/null @@ -1,4 +0,0 @@ -{ - "venvPath": "/Users/parkerglenn/opt/miniconda3/envs/", - "venv": "/Users/parkerglenn/opt/miniconda3/envs/blendsql" -}
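
A note for readers following the `generate` refactor in PATCH 06: the routing there is plain `functools.singledispatch`. The sketch below distills the mechanics; the `Model` and `OllamaLLM` classes here are bare stand-ins, not the real blendsql implementations.

```python
# Minimal sketch of the dispatch pattern behind `blendsql.generate`.
# `Model` and `OllamaLLM` are stand-in classes for illustration only.
from functools import singledispatch


class Model:
    """Base class: models whose `model_obj` Outlines can wrap."""


class OllamaLLM(Model):
    """Ollama models sit outside the Outlines ecosystem."""


@singledispatch
def text(model: Model, prompt: str, **kwargs) -> str:
    # Default path: assume an Outlines-compatible model object.
    return f"<outlines text generation for {prompt!r}>"


@text.register(OllamaLLM)
def _(model: OllamaLLM, prompt: str, **kwargs) -> str:
    # Ollama path: call its chat-style API directly instead.
    return f"<ollama text generation for {prompt!r}>"


print(text(Model(), "hello"))      # dispatches to the default branch
print(text(OllamaLLM(), "hello"))  # dispatches to the Ollama branch
```

Because dispatch happens on the model's type, the ingredient callsites (`join/main.py`, `map/main.py`, `qa/main.py`, `validate/main.py`) no longer need per-call `isinstance(model, OllamaLLM)` checks, as the diffs above show.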
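Similarly, for the pattern lambdas documented in PATCHES 07 and 09 (`infer_gen_constraints` / `create_pattern`): the sketch below shows their general shape. The separator and NaN constants are illustrative stand-ins for blendsql's `_constants` values, not the library's actual defaults.

```python
# Sketch of the `create_pattern` contract: given an inferred output_type,
# return a lambda mapping num_repeats -> a regex restricted to exactly
# that many delimited answers. Constants below are assumed values.
import re
from typing import Callable

ANS_SEP = ";"  # stand-in for DEFAULT_ANS_SEP
NAN_ANS = "-"  # stand-in for DEFAULT_NAN_ANS


def create_pattern(output_type: str) -> Callable[[int], str]:
    if output_type == "boolean":
        base_pattern = f"((t|f|{NAN_ANS}){ANS_SEP})"
    elif output_type == "integer":
        base_pattern = f"((\\d+|{NAN_ANS}){ANS_SEP})"
    else:
        raise ValueError(f"Unknown output_type {output_type}")
    # Constrain decoding to exactly `num_repeats` answers, one per value
    return lambda num_repeats: base_pattern + "{" + str(num_repeats) + "}"


pattern = create_pattern("boolean")(3)  # e.g. an LLMMap over 3 column values
assert re.fullmatch(pattern, "t;f;-;") is not None  # 't', 'f', or NaN, x3
```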