Skip to content

Commit

Permalink
DH-5765/add support for multiple schemas for finetuning
Browse files Browse the repository at this point in the history
  • Loading branch information
MohammadrezaPourreza committed Apr 26, 2024
1 parent b4f57c4 commit 2b63aa8
Show file tree
Hide file tree
Showing 4 changed files with 55 additions and 3 deletions.
11 changes: 9 additions & 2 deletions dataherald/api/fastapi.py
Original file line number Diff line number Diff line change
Expand Up @@ -91,6 +91,10 @@
)
from dataherald.utils.encrypt import FernetEncrypt
from dataherald.utils.error_codes import error_response, stream_error_response
from dataherald.utils.sql_utils import (
filter_golden_records_based_on_schema,
validate_finetuning_schema,
)

logger = logging.getLogger(__name__)

Expand Down Expand Up @@ -564,15 +568,14 @@ def create_finetuning_job(
) -> Finetuning:
try:
db_connection_repository = DatabaseConnectionRepository(self.storage)

db_connection = db_connection_repository.find_by_id(
fine_tuning_request.db_connection_id
)
if not db_connection:
raise DatabaseConnectionNotFoundError(
f"Database connection not found, {fine_tuning_request.db_connection_id}"
)

validate_finetuning_schema(fine_tuning_request, db_connection)
golden_sqls_repository = GoldenSQLRepository(self.storage)
golden_sqls = []
if fine_tuning_request.golden_sqls:
Expand All @@ -593,6 +596,9 @@ def create_finetuning_job(
raise GoldenSQLNotFoundError(
f"No golden sqls found for db_connection: {fine_tuning_request.db_connection_id}"
)
golden_sqls = filter_golden_records_based_on_schema(
golden_sqls, fine_tuning_request.schemas
)
default_base_llm = BaseLLM(
model_provider="openai",
model_name="gpt-3.5-turbo-1106",
Expand All @@ -606,6 +612,7 @@ def create_finetuning_job(
model = model_repository.insert(
Finetuning(
db_connection_id=fine_tuning_request.db_connection_id,
schemas=fine_tuning_request.schemas,
alias=fine_tuning_request.alias
if fine_tuning_request.alias
else f"{db_connection.alias}_{datetime.datetime.now().strftime('%Y%m%d%H')}",
Expand Down
7 changes: 7 additions & 0 deletions dataherald/finetuning/openai_finetuning.py
Original file line number Diff line number Diff line change
Expand Up @@ -69,6 +69,12 @@ def map_finetuning_status(status: str) -> str:
return FineTuningStatus.QUEUED.value
return mapped_statuses[status]

@staticmethod
def _filter_tables_by_schema(db_scan: List[TableDescription], schemas: List[str]):
    """Restrict a table scan to the given schemas.

    A falsy ``schemas`` (None or empty) means "no filter": the scan is
    returned untouched. Otherwise only tables whose ``schema_name`` appears
    in ``schemas`` are kept, preserving their original order.
    """
    if not schemas:
        return db_scan
    wanted = set(schemas)
    return [table for table in db_scan if table.schema_name in wanted]

def format_columns(
self, table: TableDescription, top_k: int = CATEGORICAL_COLUMNS_THRESHOLD
) -> str:
Expand Down Expand Up @@ -197,6 +203,7 @@ def create_fintuning_dataset(self):
"status": TableDescriptionStatus.SCANNED.value,
}
)
db_scan = self._filter_tables_by_schema(db_scan, self.fine_tuning_model.schemas)
golden_sqls_repository = GoldenSQLRepository(self.storage)
finetuning_dataset_path = f"tmp/{str(uuid.uuid4())}.jsonl"
model_repository = FinetuningsRepository(self.storage)
Expand Down
2 changes: 2 additions & 0 deletions dataherald/types.py
Original file line number Diff line number Diff line change
Expand Up @@ -150,6 +150,7 @@ class Finetuning(BaseModel):
id: str | None = None
alias: str | None = None
db_connection_id: str | None = None
schemas: list[str] | None
status: str = "QUEUED"
error: str | None = None
base_llm: BaseLLM | None = None
Expand All @@ -163,6 +164,7 @@ class Finetuning(BaseModel):

class FineTuningRequest(BaseModel):
db_connection_id: str
schemas: list[str] | None
alias: str | None = None
base_llm: BaseLLM | None = None
golden_sqls: list[str] | None = None
Expand Down
38 changes: 37 additions & 1 deletion dataherald/utils/sql_utils.py
Original file line number Diff line number Diff line change
@@ -1,11 +1,47 @@
from sql_metadata import Parser

from dataherald.sql_database.models.types import DatabaseConnection
from dataherald.sql_database.services.database_connection import SchemaNotSupportedError
from dataherald.types import FineTuningRequest, GoldenSQL

def extract_the_schemas_from_sql(sql):

def extract_the_schemas_from_sql(sql: str) -> list[str]:
    """Return the schema prefix of every schema-qualified table in *sql*.

    Tables referenced without a dot (i.e. with no explicit schema) contribute
    nothing; duplicates are kept, in the order the parser reports the tables.
    """
    return [
        table_name.split(".")[0].strip()
        for table_name in Parser(sql).tables
        if "." in table_name
    ]


def filter_golden_records_based_on_schema(
    golden_sqls: list[GoldenSQL], schemas: list[str]
) -> list[GoldenSQL]:
    """Keep only golden SQL records that reference at least one of *schemas*.

    With a falsy *schemas* (None or empty) no filtering is applied and the
    input list is returned as-is. A record qualifies as soon as any requested
    schema appears among the schemas its SQL text actually uses.
    """
    if not schemas:
        return golden_sqls
    kept = []
    for record in golden_sqls:
        used_schemas = extract_the_schemas_from_sql(record.sql)
        if any(schema in used_schemas for schema in schemas):
            kept.append(record)
    return kept


def validate_finetuning_schema(
    finetuning_request: FineTuningRequest, db_connection: DatabaseConnection
):
    """Ensure every schema requested for finetuning exists on the connection.

    A request with no schemas is always valid. Otherwise the connection must
    itself declare schemas, and each requested schema must be among them.
    The first violation raises SchemaNotSupportedError.
    """
    requested = finetuning_request.schemas
    if not requested:
        return
    supported = db_connection.schemas
    if not supported:
        raise SchemaNotSupportedError(
            "Schema not supported for this db",
            description=f"The {db_connection.id} db doesn't have schemas",
        )
    for schema in requested:
        if schema not in supported:
            raise SchemaNotSupportedError(
                f"Schema {schema} not supported for this db",
                description=f"The {db_connection.dialect} dialect doesn't support schema {schema}",
            )

0 comments on commit 2b63aa8

Please sign in to comment.