Skip to content

Commit

Permalink
DH-5033/ the new endpoints for finetuning (#267)
Browse files Browse the repository at this point in the history
* DH-5033/ the new endpoints for finetuning

* DH-5033/new endpoints for finetuning

* DH-5033/ finalized llm finetuning with openai

* DH-5033/adding the documents for finetuning

* DH-5033/update the docs

* DH-5033/add alias
  • Loading branch information
MohammadrezaPourreza authored and DishenWang2023 committed May 7, 2024
1 parent fabbc14 commit 5cd4429
Show file tree
Hide file tree
Showing 11 changed files with 877 additions and 0 deletions.
19 changes: 19 additions & 0 deletions dataherald/api/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -9,8 +9,11 @@
from dataherald.db_scanner.models.types import TableDescription
from dataherald.sql_database.models.types import DatabaseConnection, SSHSettings
from dataherald.types import (
CancelFineTuningRequest,
CreateResponseRequest,
DatabaseConnectionRequest,
Finetuning,
FineTuningRequest,
GoldenRecord,
GoldenRecordRequest,
Instruction,
Expand Down Expand Up @@ -167,3 +170,19 @@ def update_instruction(
instruction_request: UpdateInstruction,
) -> Instruction:
pass

@abstractmethod
def create_finetuning_job(
    self, fine_tuning_request: FineTuningRequest, background_tasks: BackgroundTasks
) -> Finetuning:
    """Create a new LLM fine-tuning job.

    Args:
        fine_tuning_request: Parameters for the job (database connection,
            base LLM, optional explicit golden-record ids, alias).
        background_tasks: FastAPI task queue; implementations schedule the
            long-running fine-tuning work here so the request returns quickly.

    Returns:
        The persisted ``Finetuning`` record describing the new job.
    """
    pass

@abstractmethod
def cancel_finetuning_job(
    self, cancel_fine_tuning_request: CancelFineTuningRequest
) -> Finetuning:
    """Cancel an existing fine-tuning job.

    Args:
        cancel_fine_tuning_request: Identifies the fine-tuning job to cancel.

    Returns:
        The updated ``Finetuning`` record after the cancellation attempt.
    """
    pass

@abstractmethod
def get_finetuning_job(self, finetuning_job_id: str) -> Finetuning:
    """Retrieve the current state of a fine-tuning job.

    Args:
        finetuning_job_id: Identifier of the stored fine-tuning record.

    Returns:
        The ``Finetuning`` record for the requested job.
    """
    pass
100 changes: 100 additions & 0 deletions dataherald/api/fastapi.py
Original file line number Diff line number Diff line change
Expand Up @@ -25,8 +25,10 @@
TableDescriptionRepository,
)
from dataherald.eval import Evaluator
from dataherald.finetuning.openai_finetuning import OpenAIFineTuning
from dataherald.repositories.base import ResponseRepository
from dataherald.repositories.database_connections import DatabaseConnectionRepository
from dataherald.repositories.finetunings import FinetuningsRepository
from dataherald.repositories.golden_records import GoldenRecordRepository
from dataherald.repositories.instructions import InstructionRepository
from dataherald.repositories.question import QuestionRepository
Expand All @@ -39,8 +41,11 @@
from dataherald.sql_generator import SQLGenerator
from dataherald.sql_generator.generates_nl_answer import GeneratesNlAnswer
from dataherald.types import (
CancelFineTuningRequest,
CreateResponseRequest,
DatabaseConnectionRequest,
Finetuning,
FineTuningRequest,
GoldenRecord,
GoldenRecordRequest,
Instruction,
Expand All @@ -52,6 +57,7 @@
TableDescriptionRequest,
UpdateInstruction,
)
from dataherald.utils.models_context_window import OPENAI_CONTEXT_WIDNOW_SIZES
from dataherald.utils.s3 import S3

logger = logging.getLogger(__name__)
Expand All @@ -68,6 +74,12 @@ def async_scanning(scanner, database, scanner_request, storage):
)


def async_fine_tuning(storage, model):
    """Background-task entry point that runs the full fine-tuning pipeline.

    Builds the training dataset and then launches the provider-side job, so
    the HTTP request that scheduled this task can return immediately.
    """
    fine_tuner = OpenAIFineTuning(storage, model)
    # NOTE(review): "fintuning" is the actual (misspelled) method name on the
    # FinetuningModel interface; it must be kept as-is for compatibility.
    fine_tuner.create_fintuning_dataset()
    fine_tuner.create_fine_tuning_job()


def delete_file(file_location: str):
    """Remove the file at *file_location* from disk.

    Used as a background cleanup task; raises ``FileNotFoundError`` if the
    file has already been removed.
    """
    os.remove(file_location)

Expand Down Expand Up @@ -638,3 +650,91 @@ def update_instruction(
)
instruction_repository.update(updated_instruction)
return json.loads(json_util.dumps(updated_instruction))

@override
def create_finetuning_job(
    self, fine_tuning_request: FineTuningRequest, background_tasks: BackgroundTasks
) -> Finetuning:
    """Validate the request, persist a Finetuning record, and schedule the job.

    Raises:
        HTTPException: 404 if the database connection, an explicitly listed
            golden record, or any golden record at all cannot be found;
            400 if the requested base model is not a supported OpenAI model.
    """
    connection_repo = DatabaseConnectionRepository(self.storage)
    if not connection_repo.find_by_id(fine_tuning_request.db_connection_id):
        raise HTTPException(status_code=404, detail="Database connection not found")

    records_repo = GoldenRecordRepository(self.storage)
    golden_records = []
    if fine_tuning_request.golden_records:
        # Explicit list of record ids: every one must exist.
        for record_id in fine_tuning_request.golden_records:
            record = records_repo.find_by_id(record_id)
            if not record:
                raise HTTPException(
                    status_code=404, detail="Golden record not found"
                )
            golden_records.append(record)
    else:
        # No explicit ids: train on every golden record for this connection
        # (limit=0 — presumably "no limit"; confirm against repository impl).
        golden_records = records_repo.find_by(
            {"db_connection_id": ObjectId(fine_tuning_request.db_connection_id)},
            page=0,
            limit=0,
        )
    if not golden_records:
        raise HTTPException(status_code=404, detail="No golden records found")

    if fine_tuning_request.base_llm.model_name not in OPENAI_CONTEXT_WIDNOW_SIZES:
        raise HTTPException(
            status_code=400,
            detail=f"Model {fine_tuning_request.base_llm.model_name} not supported",
        )

    model_repo = FinetuningsRepository(self.storage)
    model = model_repo.insert(
        Finetuning(
            db_connection_id=fine_tuning_request.db_connection_id,
            alias=fine_tuning_request.alias,
            base_llm=fine_tuning_request.base_llm,
            golden_records=[str(record.id) for record in golden_records],
        )
    )

    # The heavy work (dataset build + OpenAI job creation) runs after the
    # response is sent.
    background_tasks.add_task(async_fine_tuning, self.storage, model)

    return model

@override
def cancel_finetuning_job(
    self, cancel_fine_tuning_request: CancelFineTuningRequest
) -> Finetuning:
    """Cancel an in-progress fine-tuning job.

    Raises:
        HTTPException: 404 if the model is unknown; 400 if the job is already
            in a terminal state (succeeded, failed, or cancelled).
    """
    repo = FinetuningsRepository(self.storage)
    model = repo.find_by_id(cancel_fine_tuning_request.finetuning_id)
    if not model:
        raise HTTPException(status_code=404, detail="Model not found")

    # Jobs in a terminal state cannot be cancelled; each maps to its message.
    terminal_states = {
        "succeeded": "Model has already succeeded. Cannot cancel.",
        "failed": "Model has already failed. Cannot cancel.",
        "cancelled": "Model has already been cancelled.",
    }
    if model.status in terminal_states:
        raise HTTPException(status_code=400, detail=terminal_states[model.status])

    return OpenAIFineTuning(self.storage, model).cancel_finetuning_job()

@override
def get_finetuning_job(self, finetuning_job_id: str) -> Finetuning:
    """Return the stored fine-tuning record, refreshed via the OpenAI client.

    Raises:
        HTTPException: 404 if no record exists for *finetuning_job_id*.
    """
    repo = FinetuningsRepository(self.storage)
    stored_model = repo.find_by_id(finetuning_job_id)
    if not stored_model:
        raise HTTPException(status_code=404, detail="Model not found")
    return OpenAIFineTuning(self.storage, stored_model).retrieve_finetuning_job()
29 changes: 29 additions & 0 deletions dataherald/finetuning/__init__.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,29 @@
from abc import ABC, abstractmethod

from dataherald.config import Component
from dataherald.types import Finetuning


class FinetuningModel(Component, ABC):
    """Abstract interface for provider-specific LLM fine-tuning backends.

    Concrete implementations are constructed with a storage handle and
    implement the dataset-creation and job-lifecycle methods below.
    """

    def __init__(self, storage):
        # Storage backend handle; presumably the application's database
        # wrapper used by implementations to read/write job state — confirm
        # against concrete subclasses.
        self.storage = storage

    @abstractmethod
    def count_tokens(self, messages: dict) -> int:
        """Return the token count for *messages* (provider-specific counting)."""
        pass

    @abstractmethod
    def create_fintuning_dataset(self):
        """Prepare the fine-tuning dataset.

        NOTE(review): the name is misspelled ("fintuning") but is part of the
        public interface — callers (e.g. the background task in fastapi.py)
        use this exact spelling, so it cannot be renamed here alone.
        """
        pass

    @abstractmethod
    def create_fine_tuning_job(self):
        """Start the fine-tuning job with the provider."""
        pass

    @abstractmethod
    def retrieve_finetuning_job(self) -> Finetuning:
        """Return the Finetuning record reflecting the job's current state."""
        pass

    @abstractmethod
    def cancel_finetuning_job(self) -> Finetuning:
        """Cancel the job and return the updated Finetuning record."""
        pass
Loading

0 comments on commit 5cd4429

Please sign in to comment.