diff --git a/dataherald/context_store/default.py b/dataherald/context_store/default.py
index df0f40bd..34a1d8e8 100644
--- a/dataherald/context_store/default.py
+++ b/dataherald/context_store/default.py
@@ -13,6 +13,7 @@
 from dataherald.repositories.golden_sqls import GoldenSQLRepository
 from dataherald.repositories.instructions import InstructionRepository
 from dataherald.types import GoldenSQL, GoldenSQLRequest, Prompt
+from dataherald.utils.sql_utils import extract_the_schemas_from_sql
 
 logger = logging.getLogger(__name__)
 
@@ -86,6 +87,18 @@ def add_golden_sqls(self, golden_sqls: List[GoldenSQLRequest]) -> List[GoldenSQL
                     f"Database connection not found, {record.db_connection_id}"
                 )
 
+            if db_connection.schemas:
+                schema_not_found = True
+                used_schemas = extract_the_schemas_from_sql(record.sql)
+                for schema in db_connection.schemas:
+                    if schema in used_schemas:
+                        schema_not_found = False
+                        break
+                if schema_not_found:
+                    raise MalformedGoldenSQLError(
+                        f"SQL {record.sql} does not contain any of the schemas {db_connection.schemas}"
+                    )
+
             prompt_text = record.prompt_text
             golden_sql = GoldenSQL(
                 prompt_text=prompt_text,
diff --git a/dataherald/utils/sql_utils.py b/dataherald/utils/sql_utils.py
new file mode 100644
index 00000000..8f3e3a80
--- /dev/null
+++ b/dataherald/utils/sql_utils.py
@@ -0,0 +1,27 @@
+from typing import List
+
+from sql_metadata import Parser
+
+
+def extract_the_schemas_from_sql(sql: str) -> List[str]:
+    """Return the schema qualifier of each schema-qualified table in ``sql``.
+
+    Table names are read from ``sql_metadata.Parser(sql).tables``; names
+    without a ``.`` qualifier contribute nothing.
+
+    Args:
+        sql: The SQL statement to inspect.
+
+    Returns:
+        One entry per qualified table, in the order the parser reports them
+        (duplicates are kept).
+    """
+    schemas = []
+    for table_name in Parser(sql).tables:
+        parts = table_name.split(".")
+        # In ``schema.table`` and ``catalog.schema.table`` alike, the schema is
+        # the part directly before the table name; the first part would be
+        # the catalog/database for multi-part names.
+        if len(parts) > 1:
+            schemas.append(parts[-2].strip())
+    return schemas