Skip to content

Commit

Permalink
Merge pull request #21 from stanford-oval/fix/20
Browse files Browse the repository at this point in the history
fix for #20
  • Loading branch information
george1459 authored May 7, 2024
2 parents d1c86a0 + db308fe commit ba7b3a7
Show file tree
Hide file tree
Showing 2 changed files with 11 additions and 14 deletions.
2 changes: 1 addition & 1 deletion setup.py
Original file line number Diff line number Diff line change
Expand Up @@ -2,7 +2,7 @@

# Package metadata
name = "suql"
version = "1.1.7a1"
version = "1.1.7a2"
description = "Structured and Unstructured Query Language (SUQL) Python API"
author = "Shicheng Liu"
author_email = "[email protected]"
Expand Down
23 changes: 10 additions & 13 deletions src/suql/sql_free_text_support/execute_free_text_sql.py
Original file line number Diff line number Diff line change
Expand Up @@ -939,7 +939,8 @@ def visit_A_Expr(self, ancestors: Ancestor, node: A_Expr):
# change projection to include everything
to_execute_node.targetList = (ResTarget(val=ColumnRef(fields=(A_Star(),))),)
# set limit to 1, to see if there are results
to_execute_node.limitCount = Integer(ival=1)
to_execute_node.limitCount = A_Const(val=Integer(ival=1))
to_execute_node.limitOption = pglast.enums.nodes.LimitOption.LIMIT_OPTION_COUNT
# change predicates
to_execute_node.whereClause = node
# reset any groupby clause
Expand Down Expand Up @@ -1121,6 +1122,9 @@ def visit_ColumnRef(self, ancestors: Ancestor, node: ColumnRef):
if len(list(map(lambda x: x.sval, node.fields))) > 1:
assert len(list(map(lambda x: x.sval, node.fields))) == 2
node.fields = (String(sval="^".join(map(lambda x: x.sval, node.fields))),)
elif "^" in node.fields[0].sval:
# this means that it has already been replaced
pass
else:
res = None
for table_name, columns in self.table_column_mapping.items():
Expand All @@ -1142,6 +1146,7 @@ def _execute_structural_sql(
select_userpswd: str,
llm_model_name: str
):
_ = RawStream()(original_node) # RawStream takes care of some issue, to investigate
node = deepcopy(original_node)
# change projection to include everything
# there are a couple of cases here
Expand All @@ -1154,15 +1159,11 @@ def _execute_structural_sql(
table_column_mapping = {}
for table in [node.fromClause[0].larg, node.fromClause[0].rarg]:
# find out what columns this table has
_, columns = _execute_structural_sql(
SelectStmt(fromClause=(table,)),
_, columns = execute_sql_with_column_info(
RawStream()(SelectStmt(fromClause=(table,), targetList=(ResTarget(val=ColumnRef(fields=(A_Star(),))),))),
database,
None,
cache,
fts_fields,
select_username,
select_userpswd,
llm_model_name
)
# give the projection fields new names
projection_table_name = (
Expand Down Expand Up @@ -1195,15 +1196,11 @@ def _execute_structural_sql(
all_projection_fields = []
for table in node.fromClause:
# find out what columns this table has
_, columns = _execute_structural_sql(
SelectStmt(fromClause=(table,)),
_, columns = execute_sql_with_column_info(
RawStream()(SelectStmt(fromClause=(table,), targetList=(ResTarget(val=ColumnRef(fields=(A_Star(),))),))),
database,
None,
cache,
fts_fields,
select_username,
select_userpswd,
llm_model_name
)
# give the projection fields new names
projection_table_name = (
Expand Down

0 comments on commit ba7b3a7

Please sign in to comment.