diff --git a/src/suql/sql_free_text_support/execute_free_text_sql.py b/src/suql/sql_free_text_support/execute_free_text_sql.py index 0967e52..1c536bd 100644 --- a/src/suql/sql_free_text_support/execute_free_text_sql.py +++ b/src/suql/sql_free_text_support/execute_free_text_sql.py @@ -4,6 +4,7 @@ import string import time import traceback +import logging from collections import defaultdict from copy import deepcopy from typing import List, Union @@ -229,7 +230,7 @@ def visit_SelectStmt(self, ancestors, node: SelectStmt): list(map(lambda x: f'"{x[0]}" {x[1]}', column_info)) ) create_stmt = f"CREATE TABLE {tmp_table_name} (\n{column_create_stmt}\n); GRANT SELECT ON {tmp_table_name} TO {self.select_username};" - print("created table {}".format(tmp_table_name)) + logging.info("created table {}".format(tmp_table_name)) execute_sql( create_stmt, user=self.create_username, @@ -484,9 +485,9 @@ def verify_single_value(single_value, single_column_name): ) ) if all_found: - print("\n".join(found_stmt)) + logging.info("\n".join(found_stmt)) elif found_stmt: - print("partially verified: " + "\n".join(found_stmt)) + logging.info("partially verified: " + "\n".join(found_stmt)) return all_found @@ -677,7 +678,7 @@ def _retrieve_and_verify( id_res.append(each_res[0]) end_time = time.time() - print("retrieve + verification time {}s".format(end_time - start_time)) + logging.info("retrieve + verification time {}s".format(end_time - start_time)) if single_table: res = list(filter(lambda x: x[id_index] in id_res, existing_results)) @@ -906,13 +907,13 @@ def visit_A_Expr(self, ancestors: Ancestor, node: A_Expr): password=self.select_userpswd, ) except psyconpg2Error: - print( + logging.info( "above error happens during ENUM classification attempts. Marking this predicate as returning answer." ) res = True if not res: - print("determined the above predicate returns no result") + logging.info("determined the above predicate returns no result") # try to classify into one of the known values # first, we need to find out what is the value here - some heuristics here to find out column_name, value_res = _get_a_expr_field_value(node) @@ -926,7 +927,7 @@ def visit_A_Expr(self, ancestors: Ancestor, node: A_Expr): else: raise ValueError() - print( + logging.info( "determined column name: {}; value: {}".format( column_name, value_res_clear ) @@ -1524,6 +1525,7 @@ def suql_execute( llm_model_name="gpt-3.5-turbo-0125", max_verify=20, loggings="", + log_filename=None, disable_try_catch=False, embedding_server_address="http://127.0.0.1:8501", select_username="select_user", @@ -1553,6 +1555,8 @@ def suql_execute( `loggings` (str, optional): Prefix for error case loggings. Errors are written to a "_suql_error_log.txt" file by default. + + `log_filename` (str, optional): Logging file name for the SUQL compiler. If not provided, logging is disabled. `disable_try_catch` (bool, optional): whether to disable try-catch (errors would directly propagate to caller). @@ -1589,6 +1593,18 @@ def suql_execute( Ideally, this query should match against all `Mcdonald's`, as opposed to just 'mcdonalds'. FTS helps with such cases. """ + if log_filename: + logging.basicConfig(level=logging.INFO, + format='%(asctime)s - %(levelname)s - %(message)s', + datefmt='%Y-%m-%d %H:%M:%S', + handlers=[ + logging.FileHandler(log_filename), + logging.StreamHandler() + ]) + + else: + logging.basicConfig(level=logging.CRITICAL + 1) + results, column_names, cache = _suql_execute_single( suql, table_w_ids,