Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

DH-5033/ the new endpoints for finetuning #267

Merged
merged 6 commits into from
Dec 14, 2023
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension


Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
19 changes: 19 additions & 0 deletions dataherald/api/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -9,8 +9,11 @@
from dataherald.db_scanner.models.types import TableDescription
from dataherald.sql_database.models.types import DatabaseConnection, SSHSettings
from dataherald.types import (
CancelFineTuningRequest,
CreateResponseRequest,
DatabaseConnectionRequest,
Finetuning,
FineTuningRequest,
GoldenRecord,
GoldenRecordRequest,
Instruction,
Expand Down Expand Up @@ -167,3 +170,19 @@ def update_instruction(
instruction_request: UpdateInstruction,
) -> Instruction:
pass

@abstractmethod
def create_finetuning_job(
    self, fine_tuning_request: FineTuningRequest, background_tasks: BackgroundTasks
) -> Finetuning:
    """Create a fine-tuning job record and schedule the actual training work
    on ``background_tasks``; implementations return the stored record."""

@abstractmethod
def cancel_finetuning_job(
    self, cancel_fine_tuning_request: CancelFineTuningRequest
) -> Finetuning:
    """Cancel the fine-tuning job named in the request and return the
    updated record."""

@abstractmethod
def get_finetuning_job(self, finetuning_job_id: str) -> Finetuning:
    """Return the current state of the fine-tuning job with the given id."""
100 changes: 100 additions & 0 deletions dataherald/api/fastapi.py
Original file line number Diff line number Diff line change
Expand Up @@ -25,8 +25,10 @@
TableDescriptionRepository,
)
from dataherald.eval import Evaluator
from dataherald.finetuning.openai_finetuning import OpenAIFineTuning
from dataherald.repositories.base import ResponseRepository
from dataherald.repositories.database_connections import DatabaseConnectionRepository
from dataherald.repositories.finetunings import FinetuningsRepository
from dataherald.repositories.golden_records import GoldenRecordRepository
from dataherald.repositories.instructions import InstructionRepository
from dataherald.repositories.question import QuestionRepository
Expand All @@ -39,8 +41,11 @@
from dataherald.sql_generator import SQLGenerator
from dataherald.sql_generator.generates_nl_answer import GeneratesNlAnswer
from dataherald.types import (
CancelFineTuningRequest,
CreateResponseRequest,
DatabaseConnectionRequest,
Finetuning,
FineTuningRequest,
GoldenRecord,
GoldenRecordRequest,
Instruction,
Expand All @@ -52,6 +57,7 @@
TableDescriptionRequest,
UpdateInstruction,
)
from dataherald.utils.models_context_window import OPENAI_CONTEXT_WIDNOW_SIZES
from dataherald.utils.s3 import S3

logger = logging.getLogger(__name__)
Expand All @@ -68,6 +74,12 @@ def async_scanning(scanner, database, scanner_request, storage):
)


def async_fine_tuning(storage, model):
    """Background task: build the training dataset, then submit the
    OpenAI fine-tuning job for ``model``."""
    fine_tuner = OpenAIFineTuning(storage, model)
    fine_tuner.create_fintuning_dataset()
    fine_tuner.create_fine_tuning_job()


def delete_file(file_location: str) -> None:
    """Background task: remove the temporary file at ``file_location``.

    A missing file is treated as already cleaned up rather than an error —
    the original ``os.remove`` call raised ``FileNotFoundError`` if another
    task (or a crash/retry) had removed the file first, killing the
    background task for no benefit.
    """
    try:
        os.remove(file_location)
    except FileNotFoundError:
        # Best-effort cleanup: the file is already gone, which is the goal.
        pass

Expand Down Expand Up @@ -638,3 +650,91 @@ def update_instruction(
)
instruction_repository.update(updated_instruction)
return json.loads(json_util.dumps(updated_instruction))

@override
def create_finetuning_job(
    self, fine_tuning_request: FineTuningRequest, background_tasks: BackgroundTasks
) -> Finetuning:
    """Validate the request, persist a new ``Finetuning`` record, and hand
    the actual dataset creation / job submission to a background task.

    Raises:
        HTTPException: 404 if the database connection, an explicitly
            requested golden record, or any golden record at all cannot be
            found; 400 if the requested base model is not a supported
            OpenAI model.
    """
    db_connection_repository = DatabaseConnectionRepository(self.storage)

    # A fine-tuned model is scoped to exactly one database connection.
    db_connection = db_connection_repository.find_by_id(
        fine_tuning_request.db_connection_id
    )
    if not db_connection:
        raise HTTPException(status_code=404, detail="Database connection not found")

    golden_records_repository = GoldenRecordRepository(self.storage)
    golden_records = []
    if fine_tuning_request.golden_records:
        # Explicit list of golden-record ids: every id must resolve,
        # otherwise the whole request is rejected.
        for golden_record_id in fine_tuning_request.golden_records:
            golden_record = golden_records_repository.find_by_id(golden_record_id)
            if not golden_record:
                raise HTTPException(
                    status_code=404, detail="Golden record not found"
                )
            golden_records.append(golden_record)
    else:
        # No explicit ids: train on every golden record of this connection.
        # NOTE(review): page=0/limit=0 presumably means "no pagination" —
        # confirm against the repository implementation.
        golden_records = golden_records_repository.find_by(
            {"db_connection_id": ObjectId(fine_tuning_request.db_connection_id)},
            page=0,
            limit=0,
        )
    if not golden_records:
        raise HTTPException(status_code=404, detail="No golden records found")

    # Only models with a registered context-window size are supported.
    if fine_tuning_request.base_llm.model_name not in OPENAI_CONTEXT_WIDNOW_SIZES:
        raise HTTPException(
            status_code=400,
            detail=f"Model {fine_tuning_request.base_llm.model_name} not supported",
        )

    model_repository = FinetuningsRepository(self.storage)
    # Persist the job record first so the client immediately receives an id
    # it can poll while training runs in the background.
    model = model_repository.insert(
        Finetuning(
            db_connection_id=fine_tuning_request.db_connection_id,
            alias=fine_tuning_request.alias,
            base_llm=fine_tuning_request.base_llm,
            golden_records=[
                str(golden_record.id) for golden_record in golden_records
            ],
        )
    )

    # Dataset creation and job submission happen after the response is sent.
    background_tasks.add_task(async_fine_tuning, self.storage, model)

    return model

@override
def cancel_finetuning_job(
    self, cancel_fine_tuning_request: CancelFineTuningRequest
) -> Finetuning:
    """Cancel a fine-tuning job that is still in progress.

    Jobs that already reached a terminal state (succeeded, failed, or
    cancelled) are rejected with a 400; an unknown id yields a 404.
    """
    model_repository = FinetuningsRepository(self.storage)
    model = model_repository.find_by_id(cancel_fine_tuning_request.finetuning_id)
    if not model:
        raise HTTPException(status_code=404, detail="Model not found")

    # Terminal states cannot be cancelled; each keeps its original message.
    terminal_state_messages = {
        "succeeded": "Model has already succeeded. Cannot cancel.",
        "failed": "Model has already failed. Cannot cancel.",
        "cancelled": "Model has already been cancelled.",
    }
    if model.status in terminal_state_messages:
        raise HTTPException(
            status_code=400, detail=terminal_state_messages[model.status]
        )

    return OpenAIFineTuning(self.storage, model).cancel_finetuning_job()

@override
def get_finetuning_job(self, finetuning_job_id: str) -> Finetuning:
    """Return the stored fine-tuning job, refreshed from OpenAI."""
    repository = FinetuningsRepository(self.storage)
    fine_tuning_model = repository.find_by_id(finetuning_job_id)
    if not fine_tuning_model:
        raise HTTPException(status_code=404, detail="Model not found")
    refresher = OpenAIFineTuning(self.storage, fine_tuning_model)
    return refresher.retrieve_finetuning_job()
29 changes: 29 additions & 0 deletions dataherald/finetuning/__init__.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,29 @@
from abc import ABC, abstractmethod

from dataherald.config import Component
from dataherald.types import Finetuning


class FinetuningModel(Component, ABC):
    """Abstract interface for a provider-specific fine-tuning backend.

    Concrete implementations (e.g. the OpenAI backend) are constructed with
    the application's storage handle and operate on one fine-tuning job.
    """

    def __init__(self, storage):
        self.storage = storage

    @abstractmethod
    def count_tokens(self, messages: dict) -> int:
        """Return the token count of ``messages``."""

    @abstractmethod
    def create_fintuning_dataset(self):
        """Build the training dataset for the job.

        NOTE(review): the name keeps the existing ``fintuning`` spelling
        because callers already invoke it under this name.
        """

    @abstractmethod
    def create_fine_tuning_job(self):
        """Submit the fine-tuning job to the provider."""

    @abstractmethod
    def retrieve_finetuning_job(self) -> Finetuning:
        """Fetch the latest job state and return the updated record."""

    @abstractmethod
    def cancel_finetuning_job(self) -> Finetuning:
        """Cancel the running job and return the updated record."""
Loading