Skip to content

Commit

Permalink
expose _check_required_params as experimental feature
Browse files Browse the repository at this point in the history
  • Loading branch information
george1459 committed May 2, 2024
1 parent 4c4d4b3 commit fc514e8
Showing 1 changed file with 102 additions and 2 deletions.
104 changes: 102 additions & 2 deletions src/suql/sql_free_text_support/execute_free_text_sql.py
Original file line number Diff line number Diff line change
Expand Up @@ -1638,6 +1638,106 @@ def _execute_standalone_answer(suql, source_file_mapping):
source_content = fd.read()

return _answer(source_content, query)

def _check_predicate_exist(a_expr: A_Expr, field_name: str):
if isinstance(a_expr.lexpr, ColumnRef):
for i in a_expr.lexpr.fields:
if isinstance(i, String) and i.sval == field_name:
return True

if isinstance(a_expr.rexpr, ColumnRef):
for i in a_expr.rexpr.fields:
if isinstance(i, String) and i.sval == field_name:
return True

return False


class _RequiredParamMappingVisitor(Visitor):
def __init__(
self,
required_params_mapping
) -> None:
super().__init__()
self.required_params_mapping = required_params_mapping
self.missing_params = defaultdict(set)

def visit_SelectStmt(self, ancestors, node: SelectStmt):

def check_a_expr_or_and_expr(_dnf_predicate, _field):
if isinstance(_dnf_predicate, A_Expr):
return _check_predicate_exist(_dnf_predicate, _field)
elif (
isinstance(_dnf_predicate, BoolExpr)
and _dnf_predicate.boolop == BoolExprType.AND_EXPR
):
found = False
for i in _dnf_predicate.args:
# there could also be NOT clauses
if isinstance(i, A_Expr):
if _check_predicate_exist(i, _field):
found = True
break

return found

return False


for table in node.fromClause:
if isinstance(table, RangeVar) and table.relname in self.required_params_mapping:
assert type(self.required_params_mapping[table.relname]) == list

if not node.whereClause:
self.missing_params[table.relname].update(self.required_params_mapping[table.relname])
continue

dnf_predicate = _convert2dnf(node.whereClause)

if (
isinstance(dnf_predicate, BoolExpr)
and dnf_predicate.boolop == BoolExprType.OR_EXPR
):
for field in self.required_params_mapping[table.relname]:
if not all(check_a_expr_or_and_expr(i, field) for i in dnf_predicate.args):
self.missing_params[table.relname].add(field)
else:
# target condition:
# if isinstance(dnf_predicate, A_Expr) or (
# isinstance(dnf_predicate, BoolExpr)
# and dnf_predicate.boolop == BoolExprType.AND_EXPR
# ):
# and if it is a NOT, in which case we just return False
for field in self.required_params_mapping[table.relname]:
if not check_a_expr_or_and_expr(dnf_predicate, field):
self.missing_params[table.relname].add(field)


def _check_required_params(suql, required_params_mapping):
"""
Check whether all required parameters exist in the `suql`.
# Parameters:
`suql` (str): The to-be-executed suql query.
`required_params_mapping` (Dict(str -> List[str]), optional): *Experimental feature*: a dictionary mapping
from table names to a list of "required" parameters for the tables. The SUQL compiler will check whether the
SUQL query contains all required parameters (i.e., whether for each such table there exists a `WHERE` clause
with the required parameter).
# Returns:
`if_all_exist` (bool): whether all required parameters exist.
`missing_params` (Dict(str -> List[str]): a mapping from table names to a list of required missing parameters.
"""
root = parse_sql(suql)
visitor = _RequiredParamMappingVisitor(required_params_mapping)
visitor(root)

if visitor.missing_params:
return False, {key: list(value) for key, value in visitor.missing_params.items()}
else:
return True, {}


def suql_execute(
Expand Down Expand Up @@ -1696,11 +1796,11 @@ def suql_execute(
`create_userpswd` (str, optional): above user's password with create privilege in db. Defaults to "creator_role".
`source_file_mapping` (Dict(str -> str), optional): Experimental feature - a dictionary mapping from variable
`source_file_mapping` (Dict(str -> str), optional): *Experimental feature*: a dictionary mapping from variable
names to the file locations. This would support queries that only need a free text source, e.g.,
`suql = answer(yelp_general_info, 'what is your cancellation policy?')`. In this case, you can specify
`source_file_mapping = {"yelp_general_info": "PATH TO FILE"}` to inform the SUQL compiler where to find
`yelp_general_info`.
`yelp_general_info`. Defaults to `{}`.
# Returns:
`results` (List[[*]]): A list of returned database results. Each inner list stores a row of returned result.
Expand Down

0 comments on commit fc514e8

Please sign in to comment.