diff --git a/dataherald/api/fastapi.py b/dataherald/api/fastapi.py
index 82931f8d..23e3976c 100644
--- a/dataherald/api/fastapi.py
+++ b/dataherald/api/fastapi.py
@@ -91,6 +91,10 @@
 )
 from dataherald.utils.encrypt import FernetEncrypt
 from dataherald.utils.error_codes import error_response, stream_error_response
+from dataherald.utils.sql_utils import (
+    filter_golden_records_based_on_schema,
+    validate_finetuning_schema,
+)
 
 logger = logging.getLogger(__name__)
 
@@ -564,7 +568,6 @@ def create_finetuning_job(
     ) -> Finetuning:
         try:
             db_connection_repository = DatabaseConnectionRepository(self.storage)
-
             db_connection = db_connection_repository.find_by_id(
                 fine_tuning_request.db_connection_id
             )
@@ -572,7 +575,7 @@ def create_finetuning_job(
                 raise DatabaseConnectionNotFoundError(
                     f"Database connection not found, {fine_tuning_request.db_connection_id}"
                 )
-
+            validate_finetuning_schema(fine_tuning_request, db_connection)
             golden_sqls_repository = GoldenSQLRepository(self.storage)
             golden_sqls = []
             if fine_tuning_request.golden_sqls:
@@ -593,6 +596,9 @@ def create_finetuning_job(
                 raise GoldenSQLNotFoundError(
                     f"No golden sqls found for db_connection: {fine_tuning_request.db_connection_id}"
                 )
+            golden_sqls = filter_golden_records_based_on_schema(
+                golden_sqls, fine_tuning_request.schemas
+            )
             default_base_llm = BaseLLM(
                 model_provider="openai",
                 model_name="gpt-3.5-turbo-1106",
@@ -606,6 +612,7 @@ def create_finetuning_job(
             model = model_repository.insert(
                 Finetuning(
                     db_connection_id=fine_tuning_request.db_connection_id,
+                    schemas=fine_tuning_request.schemas,
                     alias=fine_tuning_request.alias
                     if fine_tuning_request.alias
                     else f"{db_connection.alias}_{datetime.datetime.now().strftime('%Y%m%d%H')}",
diff --git a/dataherald/finetuning/openai_finetuning.py b/dataherald/finetuning/openai_finetuning.py
index 5c553579..95b0f10f 100644
--- a/dataherald/finetuning/openai_finetuning.py
+++ b/dataherald/finetuning/openai_finetuning.py
@@ -69,6 +69,12 @@ def map_finetuning_status(status: str) -> str:
             return FineTuningStatus.QUEUED.value
         return mapped_statuses[status]
 
+    @staticmethod
+    def _filter_tables_by_schema(db_scan: List[TableDescription], schemas: List[str]):
+        if schemas:
+            return [table for table in db_scan if table.schema_name in schemas]
+        return db_scan
+
     def format_columns(
         self, table: TableDescription, top_k: int = CATEGORICAL_COLUMNS_THRESHOLD
     ) -> str:
@@ -197,6 +203,7 @@ def create_fintuning_dataset(self):
                 "status": TableDescriptionStatus.SCANNED.value,
             }
         )
+        db_scan = self._filter_tables_by_schema(db_scan, self.fine_tuning_model.schemas)
         golden_sqls_repository = GoldenSQLRepository(self.storage)
         finetuning_dataset_path = f"tmp/{str(uuid.uuid4())}.jsonl"
         model_repository = FinetuningsRepository(self.storage)
diff --git a/dataherald/types.py b/dataherald/types.py
index 8977fbca..de113e72 100644
--- a/dataherald/types.py
+++ b/dataherald/types.py
@@ -150,6 +150,7 @@ class Finetuning(BaseModel):
     id: str | None = None
     alias: str | None = None
     db_connection_id: str | None = None
+    schemas: list[str] | None
     status: str = "QUEUED"
     error: str | None = None
     base_llm: BaseLLM | None = None
@@ -163,6 +164,7 @@ class FineTuningRequest(BaseModel):
 
 class FineTuningRequest(BaseModel):
     db_connection_id: str
+    schemas: list[str] | None
     alias: str | None = None
     base_llm: BaseLLM | None = None
     golden_sqls: list[str] | None = None
diff --git a/dataherald/utils/sql_utils.py b/dataherald/utils/sql_utils.py
index 8f3e3a80..a1ab3680 100644
--- a/dataherald/utils/sql_utils.py
+++ b/dataherald/utils/sql_utils.py
@@ -1,7 +1,11 @@
 from sql_metadata import Parser
 
+from dataherald.sql_database.models.types import DatabaseConnection
+from dataherald.sql_database.services.database_connection import SchemaNotSupportedError
+from dataherald.types import FineTuningRequest, GoldenSQL
 
-def extract_the_schemas_from_sql(sql):
+
+def extract_the_schemas_from_sql(sql: str) -> list[str]:
     table_names = Parser(sql).tables
     schemas = []
     for table_name in table_names:
@@ -9,3 +13,35 @@ def extract_the_schemas_from_sql(sql):
             schema = table_name.split(".")[0]
             schemas.append(schema.strip())
     return schemas
+
+
+def filter_golden_records_based_on_schema(
+    golden_sqls: list[GoldenSQL], schemas: list[str]
+) -> list[GoldenSQL]:
+    filtered_records = []
+    if not schemas:
+        return golden_sqls
+    for record in golden_sqls:
+        used_schemas = extract_the_schemas_from_sql(record.sql)
+        for schema in schemas:
+            if schema in used_schemas:
+                filtered_records.append(record)
+                break
+    return filtered_records
+
+
+def validate_finetuning_schema(
+    finetuning_request: FineTuningRequest, db_connection: DatabaseConnection
+):
+    if finetuning_request.schemas:
+        if not db_connection.schemas:
+            raise SchemaNotSupportedError(
+                "Schema not supported for this db",
+                description=f"The {db_connection.id} db doesn't have schemas",
+            )
+        for schema in finetuning_request.schemas:
+            if schema not in db_connection.schemas:
+                raise SchemaNotSupportedError(
+                    f"Schema {schema} not supported for this db",
+                    description=f"The {db_connection.dialect} dialect doesn't support schema {schema}",
+                )
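
Reviewer note: below is a minimal, self-contained sketch of the schema-filtering rule that the new `filter_golden_records_based_on_schema` helper applies, handy for sanity-checking the behavior locally. The plain dicts, the `"sql"` field name, and the helper names `schemas_used` / `filter_by_schema` are illustrative assumptions, not the project's types; only `sql_metadata.Parser` (already used in `dataherald/utils/sql_utils.py` above) is taken from the diff.

```python
# Illustrative sketch only: plain dicts stand in for GoldenSQL records.
from sql_metadata import Parser


def schemas_used(sql: str) -> list[str]:
    # Collect the schema prefix of every qualified table name, e.g. "sales.orders" -> "sales".
    return [t.split(".")[0].strip() for t in Parser(sql).tables if "." in t]


def filter_by_schema(records: list[dict], schemas: list[str]) -> list[dict]:
    # No schema filter requested: keep every record, mirroring the PR's behavior.
    if not schemas:
        return records
    # Keep a record if its SQL touches at least one of the requested schemas.
    return [r for r in records if any(s in schemas_used(r["sql"]) for s in schemas)]


records = [
    {"sql": "SELECT * FROM sales.orders"},  # kept: references the requested "sales" schema
    {"sql": "SELECT * FROM hr.employees"},  # dropped: "hr" was not requested
]
print(filter_by_schema(records, ["sales"]))  # -> [{'sql': 'SELECT * FROM sales.orders'}]
```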