Skip to content

Commit

Permalink
Merge pull request #16 from stanford-oval/fix/15
Browse files Browse the repository at this point in the history
Fix for #15: remove `restaurants` as the default value of the `database` parameter in the PostgreSQL connection file
  • Loading branch information
george1459 authored Apr 25, 2024
2 parents bf73a44 + fa4245f commit e0d47a1
Show file tree
Hide file tree
Showing 5 changed files with 94 additions and 15 deletions.
7 changes: 5 additions & 2 deletions docs/install_pip.md
Original file line number Diff line number Diff line change
Expand Up @@ -88,13 +88,16 @@ You should be good to go! In a separate terminal, set up your LLM API key enviro

```python
>>> from suql import suql_execute
# e.g. suql = "SELECT * FROM restaurants WHERE answer(reviews, 'is this a family-friendly restaurant?') = 'Yes' AND rating = 4;"
# e.g. suql = "SELECT * FROM restaurants WHERE answer(reviews, 'is this a family-friendly restaurant?') = 'Yes' AND rating = 4 LIMIT 3;"
>>> suql = "Your favorite SUQL"

# e.g. table_w_ids = {"restaurants": "_id"}
>>> table_w_ids = "mapping between table name -> unique ID column name"

>>> suql_execute(suql, table_w_ids)
# e.g. database = "restaurants"
>>> database = "your postgres database name"

>>> suql_execute(suql, table_w_ids, database)
```

Check out [API documentation](https://stanford-oval.github.io/suql/suql/sql_free_text_support/execute_free_text_sql.html) for details.
Expand Down
2 changes: 1 addition & 1 deletion setup.py
Original file line number Diff line number Diff line change
Expand Up @@ -2,7 +2,7 @@

# Package metadata
name = "suql"
version = "1.1.4a3"
version = "1.1.5"
description = "Structured and Unstructured Query Language (SUQL) Python API"
author = "Shicheng Liu"
author_email = "[email protected]"
Expand Down
5 changes: 4 additions & 1 deletion src/suql/agent.py
Original file line number Diff line number Diff line change
Expand Up @@ -242,7 +242,10 @@ def parse_execute_sql(dlgHistory, user_query, prompt_file="prompts/parser_suql.p

suql_execute_start_time = time.time()
final_res, column_names, cache = suql_execute(
postprocessed_suql, {"restaurants": "_id"}, fts_fields=[("restaurants", "name")]
postprocessed_suql,
{"restaurants": "_id"},
"restaurants",
fts_fields=[("restaurants", "name")]
)
suql_execute_end_time = time.time()

Expand Down
4 changes: 2 additions & 2 deletions src/suql/postgresql_connection.py
Original file line number Diff line number Diff line change
Expand Up @@ -5,7 +5,7 @@

def execute_sql(
sql_query,
database="restaurants",
database,
user="select_user",
password="select_user",
data=None,
Expand Down Expand Up @@ -72,7 +72,7 @@ def execute_sql(

def execute_sql_with_column_info(
sql_query,
database="restaurants",
database,
user="select_user",
password="select_user",
unprotected=False,
Expand Down
91 changes: 82 additions & 9 deletions src/suql/sql_free_text_support/execute_free_text_sql.py
Original file line number Diff line number Diff line change
Expand Up @@ -196,6 +196,7 @@ class _SelectVisitor(Visitor):
def __init__(
self,
fts_fields,
database,
embedding_server_address,
select_username,
select_userpswd,
Expand Down Expand Up @@ -226,6 +227,9 @@ def __init__(

# store max verify param
self.max_verify = max_verify

# store database
self.database = database

def __call__(self, node):
super().__call__(node)
Expand Down Expand Up @@ -255,6 +259,7 @@ def visit_SelectStmt(self, ancestors, node: SelectStmt):
# main entry point for SUQL compiler optimization
results, column_info = _analyze_SelectStmt(
node,
self.database,
self.cache,
self.fts_fields,
self.embedding_server_address,
Expand All @@ -273,6 +278,7 @@ def visit_SelectStmt(self, ancestors, node: SelectStmt):
logging.info("created table {}".format(tmp_table_name))
execute_sql(
create_stmt,
self.database,
user=self.create_username,
password=self.create_userpswd,
commit_in_lieu_fetch=True,
Expand All @@ -296,6 +302,7 @@ def visit_SelectStmt(self, ancestors, node: SelectStmt):
)
execute_sql(
f"INSERT INTO {tmp_table_name} VALUES ({placeholder_str})",
self.database,
data=updated_results,
user=self.create_username,
password=self.create_userpswd,
Expand All @@ -311,6 +318,7 @@ def visit_SelectStmt(self, ancestors, node: SelectStmt):
else:
_classify_db_fields(
node,
self.database,
self.cache,
self.fts_fields,
self.select_username,
Expand Down Expand Up @@ -347,6 +355,7 @@ def drop_tmp_tables(self):
drop_stmt = f"DROP TABLE {tmp_table_name}"
execute_sql(
drop_stmt,
self.database,
user=self.create_username,
password=self.create_userpswd,
commit_in_lieu_fetch=True,
Expand Down Expand Up @@ -840,10 +849,18 @@ def _get_comma_separated_numbers(input_string):

class _StructuralClassification(Visitor):
def __init__(
self, node: SelectStmt, cache, fts_fields, select_username, select_userpswd, llm_model_name
self,
node: SelectStmt,
database,
cache,
fts_fields,
select_username,
select_userpswd,
llm_model_name
) -> None:
super().__init__()
self.node = node
self.database = database
self.cache = cache
self.fts_fields = fts_fields
self.select_username = select_username
Expand Down Expand Up @@ -933,6 +950,7 @@ def visit_A_Expr(self, ancestors: Ancestor, node: A_Expr):
try:
res, column_infos = execute_sql_with_column_info(
RawStream()(to_execute_node),
self.database,
unprotected=True,
user=self.select_username,
password=self.select_userpswd,
Expand All @@ -945,6 +963,7 @@ def visit_A_Expr(self, ancestors: Ancestor, node: A_Expr):
to_execute_node.whereClause = None
_, column_infos = execute_sql_with_column_info(
RawStream()(to_execute_node),
self.database,
user=self.select_username,
password=self.select_userpswd,
)
Expand Down Expand Up @@ -1026,6 +1045,7 @@ def visit_A_Expr(self, ancestors: Ancestor, node: A_Expr):
to_execute_node.limitCount = None # find all entries
field_value_choices, _ = execute_sql_with_column_info(
RawStream()(to_execute_node),
self.database,
user=self.select_username,
password=self.select_userpswd,
)
Expand Down Expand Up @@ -1066,6 +1086,7 @@ def visit_A_Expr(self, ancestors: Ancestor, node: A_Expr):

def _classify_db_fields(
node: SelectStmt,
database: str,
cache: dict,
fts_fields: List,
select_username: str,
Expand All @@ -1077,7 +1098,13 @@ def _classify_db_fields(
# the goal of this function is to determine which predicate leads to no results
# for a field without results, try to classify into one of the existing fields
visitor = _StructuralClassification(
node, cache, fts_fields, select_username, select_userpswd, llm_model_name
node,
database,
cache,
fts_fields,
select_username,
select_userpswd,
llm_model_name
)
visitor(node)

Expand Down Expand Up @@ -1107,6 +1134,7 @@ def visit_ColumnRef(self, ancestors: Ancestor, node: ColumnRef):

def _execute_structural_sql(
original_node: SelectStmt,
database: str,
predicate: BoolExpr,
cache: dict,
fts_fields: List,
Expand All @@ -1128,6 +1156,7 @@ def _execute_structural_sql(
# find out what columns this table has
_, columns = _execute_structural_sql(
SelectStmt(fromClause=(table,)),
database,
None,
cache,
fts_fields,
Expand Down Expand Up @@ -1168,6 +1197,7 @@ def _execute_structural_sql(
# find out what columns this table has
_, columns = _execute_structural_sql(
SelectStmt(fromClause=(table,)),
database,
None,
cache,
fts_fields,
Expand Down Expand Up @@ -1214,11 +1244,22 @@ def _execute_structural_sql(
assert _if_all_structural(node)

# deal with structural field classification
_classify_db_fields(node, cache, fts_fields, select_username, select_userpswd, llm_model_name)
_classify_db_fields(
node,
database,
cache,
fts_fields,
select_username,
select_userpswd,
llm_model_name
)

sql = RawStream()(node)
return execute_sql_with_column_info(
sql, user=select_username, password=select_userpswd
sql,
database,
user=select_username,
password=select_userpswd
)


Expand Down Expand Up @@ -1377,6 +1418,7 @@ def breakdown_unstructural_query(predicate: A_Expr):

def _execute_and(
sql_dnf_predicates,
database: str,
node: SelectStmt,
limit,
cache: dict,
Expand Down Expand Up @@ -1410,6 +1452,7 @@ def _execute_and(
# execute structural part
structural_res, column_info = _execute_structural_sql(
node,
database,
structural_predicates,
cache,
fts_fields,
Expand Down Expand Up @@ -1447,6 +1490,7 @@ def _execute_and(
if _if_all_structural(sql_dnf_predicates):
return _execute_structural_sql(
node,
database,
sql_dnf_predicates,
cache,
fts_fields,
Expand All @@ -1456,7 +1500,14 @@ def _execute_and(
)
else:
all_results, column_info = _execute_structural_sql(
node, None, cache, fts_fields, select_username, select_userpswd, llm_model_name
node,
database,
None,
cache,
fts_fields,
select_username,
select_userpswd,
llm_model_name
)
return _execute_free_text_queries(
node,
Expand All @@ -1473,6 +1524,7 @@ def _execute_and(

def _analyze_SelectStmt(
node: SelectStmt,
database: str,
cache: dict,
fts_fields: List,
embedding_server_address: str,
Expand All @@ -1498,6 +1550,7 @@ def _analyze_SelectStmt(
for choice in choices:
choice_res, column_info = _execute_and(
choice,
database,
node,
limit - len(res),
cache,
Expand All @@ -1523,6 +1576,7 @@ def _analyze_SelectStmt(
):
return _execute_and(
sql_dnf_predicates,
database,
node,
limit,
cache,
Expand All @@ -1541,6 +1595,7 @@ def _analyze_SelectStmt(
):
return _execute_and(
sql_dnf_predicates,
database,
node,
limit,
cache,
Expand Down Expand Up @@ -1588,6 +1643,7 @@ def _execute_standalone_answer(suql, source_file_mapping):
def suql_execute(
suql,
table_w_ids,
database,
fts_fields=[],
llm_model_name="gpt-3.5-turbo-0125",
max_verify=20,
Expand All @@ -1610,7 +1666,9 @@ def suql_execute(
`table_w_ids` (dict): A dictionary where each key is a table name, and each value is the corresponding
unique ID column name in this table, e.g., `table_w_ids = {"restaurants": "_id"}`, meaning that the
relevant tables to the SUQL compiler include only the `restaurants` table, which has unique ID column `_id`.
`database` (str): The name of the PostgreSQL database in which to execute the query.
`fts_fields` (List[str], optional): Fields that should use PostgreSQL's Full Text Search (FTS) operators.
The SUQL compiler changes certain string operators, such as "=", to use PostgreSQL's FTS operators.
It uses `websearch_to_tsquery` and the `@@` operator to match against these fields.
Expand Down Expand Up @@ -1685,6 +1743,7 @@ def suql_execute(
results, column_names, cache = _suql_execute_single(
suql,
table_w_ids,
database,
fts_fields,
llm_model_name,
max_verify,
Expand Down Expand Up @@ -1717,6 +1776,7 @@ def suql_execute(
def _suql_execute_single(
suql,
table_w_ids,
database,
fts_fields,
llm_model_name,
max_verify,
Expand All @@ -1735,6 +1795,7 @@ def _suql_execute_single(
if disable_try_catch:
visitor = _SelectVisitor(
fts_fields,
database,
embedding_server_address,
select_username,
select_userpswd,
Expand All @@ -1749,11 +1810,18 @@ def _suql_execute_single(
second_sql = RawStream()(root)
cache = visitor.serialize_cache()

return execute_sql(second_sql, user=select_username, password=select_userpswd, no_print=True)
return execute_sql(
second_sql,
database,
user=select_username,
password=select_userpswd,
no_print=True
)
else:
try:
visitor = _SelectVisitor(
fts_fields,
database,
embedding_server_address,
select_username,
select_userpswd,
Expand All @@ -1769,7 +1837,11 @@ def _suql_execute_single(
cache = visitor.serialize_cache()

results, column_names, cache = execute_sql(
second_sql, user=select_username, password=select_userpswd, no_print=True
second_sql,
database,
user=select_username,
password=select_userpswd,
no_print=True
)
except Exception as err:
with open("_suql_error_log.txt", "a") as file:
Expand All @@ -1785,12 +1857,13 @@ def _suql_execute_single(

if __name__ == "__main__":
# print(suql_execute(sql, disable_try_catch=True, fts_fields=[("restaurants", "name")] )[0])
database = "restaurants"
with open("sql_free_text_support/test_cases.txt", "r") as fd:
test_cases = fd.readlines()
res = []
for sql in test_cases:
sql = sql.strip()
i_res = suql_execute(sql, disable_try_catch=True)[0]
i_res = suql_execute(sql, database, disable_try_catch=True)[0]
res.append(i_res)
with open("sql_free_text_support/test_cases_res.txt", "w") as fd:
for i_res in res:
Expand Down

0 comments on commit e0d47a1

Please sign in to comment.