From fa4245f2c08be3b27b5e08bfdc8d41942ef5795d Mon Sep 17 00:00:00 2001 From: Shicheng Liu Date: Thu, 25 Apr 2024 17:51:15 +0000 Subject: [PATCH] fix for #15 --- docs/install_pip.md | 7 +- setup.py | 2 +- src/suql/agent.py | 5 +- src/suql/postgresql_connection.py | 4 +- .../execute_free_text_sql.py | 91 +++++++++++++++++-- 5 files changed, 94 insertions(+), 15 deletions(-) diff --git a/docs/install_pip.md b/docs/install_pip.md index 5192f82..4eb79c2 100644 --- a/docs/install_pip.md +++ b/docs/install_pip.md @@ -88,13 +88,16 @@ You should be good to go! In a separate terminal, set up your LLM API key enviro ```python >>> from suql import suql_execute -# e.g. suql = "SELECT * FROM restaurants WHERE answer(reviews, 'is this a family-friendly restaurant?') = 'Yes' AND rating = 4;" +# e.g. suql = "SELECT * FROM restaurants WHERE answer(reviews, 'is this a family-friendly restaurant?') = 'Yes' AND rating = 4 LIMIT 3;" >>> suql = "Your favorite SUQL" # e.g. table_w_ids = {"restaurants": "_id"} >>> table_w_ids = "mapping between table name -> unique ID column name" ->>> suql_execute(suql, table_w_ids) +# e.g. database = "restaurants" +>>> database = "your postgres database name" + +>>> suql_execute(suql, table_w_ids, database) ``` Check out [API documentation](https://stanford-oval.github.io/suql/suql/sql_free_text_support/execute_free_text_sql.html) for details. diff --git a/setup.py b/setup.py index c7fe05c..d8733a6 100644 --- a/setup.py +++ b/setup.py @@ -2,7 +2,7 @@ # Package metadata name = "suql" -version = "1.1.4a3" +version = "1.1.5" description = "Structured and Unstructured Query Language (SUQL) Python API" author = "Shicheng Liu" author_email = "shicheng@cs.stanford.edu" diff --git a/src/suql/agent.py b/src/suql/agent.py index 79e46dd..b541c4a 100644 --- a/src/suql/agent.py +++ b/src/suql/agent.py @@ -242,7 +242,10 @@ def parse_execute_sql(dlgHistory, user_query, prompt_file="prompts/parser_suql.p suql_execute_start_time = time.time() final_res, column_names, cache = suql_execute( - postprocessed_suql, {"restaurants": "_id"}, fts_fields=[("restaurants", "name")] + postprocessed_suql, + {"restaurants": "_id"}, + "restaurants", + fts_fields=[("restaurants", "name")] ) suql_execute_end_time = time.time() diff --git a/src/suql/postgresql_connection.py b/src/suql/postgresql_connection.py index ed78cdc..c024f36 100644 --- a/src/suql/postgresql_connection.py +++ b/src/suql/postgresql_connection.py @@ -5,7 +5,7 @@ def execute_sql( sql_query, - database="restaurants", + database, user="select_user", password="select_user", data=None, @@ -72,7 +72,7 @@ def execute_sql( def execute_sql_with_column_info( sql_query, - database="restaurants", + database, user="select_user", password="select_user", unprotected=False, diff --git a/src/suql/sql_free_text_support/execute_free_text_sql.py b/src/suql/sql_free_text_support/execute_free_text_sql.py index afaacf9..ae92a30 100644 --- a/src/suql/sql_free_text_support/execute_free_text_sql.py +++ b/src/suql/sql_free_text_support/execute_free_text_sql.py @@ -196,6 +196,7 @@ class _SelectVisitor(Visitor): def __init__( self, fts_fields, + database, embedding_server_address, select_username, select_userpswd, @@ -226,6 +227,9 @@ def __init__( # store max verify param self.max_verify = max_verify + + # store database + self.database = database def __call__(self, node): super().__call__(node) @@ -255,6 +259,7 @@ def visit_SelectStmt(self, ancestors, node: SelectStmt): # main entry point for SUQL compiler optimization results, column_info = _analyze_SelectStmt( node, + self.database, self.cache, self.fts_fields, self.embedding_server_address, @@ -273,6 +278,7 @@ def visit_SelectStmt(self, ancestors, node: SelectStmt): logging.info("created table {}".format(tmp_table_name)) execute_sql( create_stmt, + self.database, user=self.create_username, password=self.create_userpswd, commit_in_lieu_fetch=True, @@ -296,6 +302,7 @@ def visit_SelectStmt(self, ancestors, node: SelectStmt): ) execute_sql( f"INSERT INTO {tmp_table_name} VALUES ({placeholder_str})", + self.database, data=updated_results, user=self.create_username, password=self.create_userpswd, @@ -311,6 +318,7 @@ def visit_SelectStmt(self, ancestors, node: SelectStmt): else: _classify_db_fields( node, + self.database, self.cache, self.fts_fields, self.select_username, @@ -347,6 +355,7 @@ def drop_tmp_tables(self): drop_stmt = f"DROP TABLE {tmp_table_name}" execute_sql( drop_stmt, + self.database, user=self.create_username, password=self.create_userpswd, commit_in_lieu_fetch=True, @@ -840,10 +849,18 @@ def _get_comma_separated_numbers(input_string): class _StructuralClassification(Visitor): def __init__( - self, node: SelectStmt, cache, fts_fields, select_username, select_userpswd, llm_model_name + self, + node: SelectStmt, + database, + cache, + fts_fields, + select_username, + select_userpswd, + llm_model_name ) -> None: super().__init__() self.node = node + self.database = database self.cache = cache self.fts_fields = fts_fields self.select_username = select_username @@ -933,6 +950,7 @@ def visit_A_Expr(self, ancestors: Ancestor, node: A_Expr): try: res, column_infos = execute_sql_with_column_info( RawStream()(to_execute_node), + self.database, unprotected=True, user=self.select_username, password=self.select_userpswd, @@ -945,6 +963,7 @@ def visit_A_Expr(self, ancestors: Ancestor, node: A_Expr): to_execute_node.whereClause = None _, column_infos = execute_sql_with_column_info( RawStream()(to_execute_node), + self.database, user=self.select_username, password=self.select_userpswd, ) @@ -1026,6 +1045,7 @@ def visit_A_Expr(self, ancestors: Ancestor, node: A_Expr): to_execute_node.limitCount = None # find all entries field_value_choices, _ = execute_sql_with_column_info( RawStream()(to_execute_node), + self.database, user=self.select_username, password=self.select_userpswd, ) @@ -1066,6 +1086,7 @@ def visit_A_Expr(self, ancestors: Ancestor, node: A_Expr): def _classify_db_fields( node: SelectStmt, + database: str, cache: dict, fts_fields: List, select_username: str, @@ -1077,7 +1098,13 @@ def _classify_db_fields( # the goal of this function is to determine which predicate leads to no results # for a field without results, try to classify into one of the existing fields visitor = _StructuralClassification( - node, cache, fts_fields, select_username, select_userpswd, llm_model_name + node, + database, + cache, + fts_fields, + select_username, + select_userpswd, + llm_model_name ) visitor(node) @@ -1107,6 +1134,7 @@ def visit_ColumnRef(self, ancestors: Ancestor, node: ColumnRef): def _execute_structural_sql( original_node: SelectStmt, + database: str, predicate: BoolExpr, cache: dict, fts_fields: List, @@ -1128,6 +1156,7 @@ def _execute_structural_sql( # find out what columns this table has _, columns = _execute_structural_sql( SelectStmt(fromClause=(table,)), + database, None, cache, fts_fields, @@ -1168,6 +1197,7 @@ def _execute_structural_sql( # find out what columns this table has _, columns = _execute_structural_sql( SelectStmt(fromClause=(table,)), + database, None, cache, fts_fields, @@ -1214,11 +1244,22 @@ def _execute_structural_sql( assert _if_all_structural(node) # deal with sturctural field classification - _classify_db_fields(node, cache, fts_fields, select_username, select_userpswd, llm_model_name) + _classify_db_fields( + node, + database, + cache, + fts_fields, + select_username, + select_userpswd, + llm_model_name + ) sql = RawStream()(node) return execute_sql_with_column_info( - sql, user=select_username, password=select_userpswd + sql, + database, + user=select_username, + password=select_userpswd ) @@ -1377,6 +1418,7 @@ def breakdown_unstructural_query(predicate: A_Expr): def _execute_and( sql_dnf_predicates, + database: str, node: SelectStmt, limit, cache: dict, @@ -1410,6 +1452,7 @@ def _execute_and( # execute structural part structural_res, column_info = _execute_structural_sql( node, + database, structural_predicates, cache, fts_fields, @@ -1447,6 +1490,7 @@ def _execute_and( if _if_all_structural(sql_dnf_predicates): return _execute_structural_sql( node, + database, sql_dnf_predicates, cache, fts_fields, @@ -1456,7 +1500,14 @@ def _execute_and( ) else: all_results, column_info = _execute_structural_sql( - node, None, cache, fts_fields, select_username, select_userpswd, llm_model_name + node, + database, + None, + cache, + fts_fields, + select_username, + select_userpswd, + llm_model_name ) return _execute_free_text_queries( node, @@ -1473,6 +1524,7 @@ def _execute_and( def _analyze_SelectStmt( node: SelectStmt, + database: str, cache: dict, fts_fields: List, embedding_server_address: str, @@ -1498,6 +1550,7 @@ def _analyze_SelectStmt( for choice in choices: choice_res, column_info = _execute_and( choice, + database, node, limit - len(res), cache, @@ -1523,6 +1576,7 @@ def _analyze_SelectStmt( ): return _execute_and( sql_dnf_predicates, + database, node, limit, cache, @@ -1541,6 +1595,7 @@ def _analyze_SelectStmt( ): return _execute_and( sql_dnf_predicates, + database, node, limit, cache, @@ -1588,6 +1643,7 @@ def _execute_standalone_answer(suql, source_file_mapping): def suql_execute( suql, table_w_ids, + database, fts_fields=[], llm_model_name="gpt-3.5-turbo-0125", max_verify=20, @@ -1610,7 +1666,9 @@ def suql_execute( `table_w_ids` (dict): A dictionary where each key is a table name, and each value is the corresponding unique ID column name in this table, e.g., `table_w_ids = {"restaurants": "_id"}`, meaning that the relevant tables to the SUQL compiler include only the `restaurants` table, which has unique ID column `_id`. - + + `database` (str): The name of the PostgreSQL database to execute the query. + `fts_fields` (List[str], optional): Fields that should use PostgreSQL's Full Text Search (FTS) operators; The SUQL compiler would change certain string operators like "=" to use PostgreSQL's FTS operators. It uses `websearch_to_tsquery` and the `@@` operator to match against these fields. @@ -1685,6 +1743,7 @@ def suql_execute( results, column_names, cache = _suql_execute_single( suql, table_w_ids, + database, fts_fields, llm_model_name, max_verify, @@ -1717,6 +1776,7 @@ def suql_execute( def _suql_execute_single( suql, table_w_ids, + database, fts_fields, llm_model_name, max_verify, @@ -1735,6 +1795,7 @@ def _suql_execute_single( if disable_try_catch: visitor = _SelectVisitor( fts_fields, + database, embedding_server_address, select_username, select_userpswd, @@ -1749,11 +1810,18 @@ def _suql_execute_single( second_sql = RawStream()(root) cache = visitor.serialize_cache() - return execute_sql(second_sql, user=select_username, password=select_userpswd, no_print=True) + return execute_sql( + second_sql, + database, + user=select_username, + password=select_userpswd, + no_print=True + ) else: try: visitor = _SelectVisitor( fts_fields, + database, embedding_server_address, select_username, select_userpswd, @@ -1769,7 +1837,11 @@ def _suql_execute_single( cache = visitor.serialize_cache() results, column_names, cache = execute_sql( - second_sql, user=select_username, password=select_userpswd, no_print=True + second_sql, + database, + user=select_username, + password=select_userpswd, + no_print=True ) except Exception as err: with open("_suql_error_log.txt", "a") as file: @@ -1785,12 +1857,13 @@ def _suql_execute_single( if __name__ == "__main__": # print(suql_execute(sql, disable_try_catch=True, fts_fields=[("restaurants", "name")] )[0]) + database = "restaurants" with open("sql_free_text_support/test_cases.txt", "r") as fd: test_cases = fd.readlines() res = [] for sql in test_cases: sql = sql.strip() - i_res = suql_execute(sql, disable_try_catch=True)[0] + i_res = suql_execute(sql, database, disable_try_catch=True)[0] res.append(i_res) with open("sql_free_text_support/test_cases_res.txt", "w") as fd: for i_res in res: