Skip to content

Commit

Permalink
DH-5766/adding the validation to raise exception for queries without …
Browse files Browse the repository at this point in the history
…schema in multiple schema setting
  • Loading branch information
MohammadrezaPourreza committed Apr 26, 2024
1 parent b24d9c8 commit b4f57c4
Show file tree
Hide file tree
Showing 2 changed files with 24 additions and 0 deletions.
13 changes: 13 additions & 0 deletions dataherald/context_store/default.py
Original file line number Diff line number Diff line change
Expand Up @@ -13,6 +13,7 @@
from dataherald.repositories.golden_sqls import GoldenSQLRepository
from dataherald.repositories.instructions import InstructionRepository
from dataherald.types import GoldenSQL, GoldenSQLRequest, Prompt
from dataherald.utils.sql_utils import extract_the_schemas_from_sql

logger = logging.getLogger(__name__)

Expand Down Expand Up @@ -86,6 +87,18 @@ def add_golden_sqls(self, golden_sqls: List[GoldenSQLRequest]) -> List[GoldenSQL
f"Database connection not found, {record.db_connection_id}"
)

if db_connection.schemas:
schema_not_found = True
used_schemas = extract_the_schemas_from_sql(record.sql)
for schema in db_connection.schemas:
if schema in used_schemas:
schema_not_found = False
break
if schema_not_found:
raise MalformedGoldenSQLError(
f"SQL {record.sql} does not contain any of the schemas {db_connection.schemas}"
)

prompt_text = record.prompt_text
golden_sql = GoldenSQL(
prompt_text=prompt_text,
Expand Down
11 changes: 11 additions & 0 deletions dataherald/utils/sql_utils.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,11 @@
from sql_metadata import Parser


def extract_the_schemas_from_sql(sql):
    """Collect the schema prefixes of the tables referenced by a SQL string.

    The table names are taken from ``Parser(sql).tables``. For every name
    that contains a dot, the text before the first dot (with surrounding
    whitespace stripped) is recorded; names without a dot are skipped.

    Args:
        sql: The SQL text handed to ``Parser``.

    Returns:
        A list of prefixes in the order the parser reported the tables.
        Duplicates are kept, and the list is empty when no table name
        contains a dot.
    """
    # Only the first dotted component is kept, so a name with several dots
    # (e.g. "a.b.c") contributes "a".
    return [
        qualified_name.split(".")[0].strip()
        for qualified_name in Parser(sql).tables
        if "." in qualified_name
    ]

0 comments on commit b4f57c4

Please sign in to comment.