Skip to content

Commit

Permalink
DH-5033/new endpoints for finetuning
Browse files Browse the repository at this point in the history
  • Loading branch information
MohammadrezaPourreza committed Nov 29, 2023
1 parent 1c607ba commit 9922110
Show file tree
Hide file tree
Showing 7 changed files with 262 additions and 8 deletions.
19 changes: 19 additions & 0 deletions dataherald/api/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -9,8 +9,11 @@
from dataherald.db_scanner.models.types import TableDescription
from dataherald.sql_database.models.types import DatabaseConnection, SSHSettings
from dataherald.types import (
CancelFineTuningRequest,
CreateResponseRequest,
DatabaseConnectionRequest,
Finetuning,
FineTuningRequest,
GoldenRecord,
GoldenRecordRequest,
Instruction,
Expand Down Expand Up @@ -167,3 +170,19 @@ def update_instruction(
instruction_request: UpdateInstruction,
) -> Instruction:
pass

@abstractmethod
def create_finetuning_job(
self, fine_tuning_request: FineTuningRequest, background_tasks: BackgroundTasks
) -> Finetuning:
pass

@abstractmethod
def cancel_finetuning_job(
self, cancel_fine_tuning_request: CancelFineTuningRequest
) -> Finetuning:
pass

@abstractmethod
def get_finetuning_job(self, finetuning_job_id: str) -> Finetuning:
pass
90 changes: 90 additions & 0 deletions dataherald/api/fastapi.py
Original file line number Diff line number Diff line change
Expand Up @@ -27,6 +27,7 @@
from dataherald.eval import Evaluator
from dataherald.repositories.base import ResponseRepository
from dataherald.repositories.database_connections import DatabaseConnectionRepository
from dataherald.repositories.finetunings import FinetuningsRepository
from dataherald.repositories.golden_records import GoldenRecordRepository
from dataherald.repositories.instructions import InstructionRepository
from dataherald.repositories.question import QuestionRepository
Expand All @@ -39,8 +40,11 @@
from dataherald.sql_generator import SQLGenerator
from dataherald.sql_generator.generates_nl_answer import GeneratesNlAnswer
from dataherald.types import (
CancelFineTuningRequest,
CreateResponseRequest,
DatabaseConnectionRequest,
Finetuning,
FineTuningRequest,
GoldenRecord,
GoldenRecordRequest,
Instruction,
Expand Down Expand Up @@ -68,6 +72,10 @@ def async_scanning(scanner, database, scanner_request, storage):
)


def async_fine_tuning():
pass


def delete_file(file_location: str):
os.remove(file_location)

Expand Down Expand Up @@ -638,3 +646,85 @@ def update_instruction(
)
instruction_repository.update(updated_instruction)
return json.loads(json_util.dumps(updated_instruction))

@override
def create_finetuning_job(
self, fine_tuning_request: FineTuningRequest, background_tasks: BackgroundTasks
) -> Finetuning:
db_connection_repository = DatabaseConnectionRepository(self.storage)

db_connection = db_connection_repository.find_by_id(
fine_tuning_request.db_connection_id
)
if not db_connection:
raise HTTPException(status_code=404, detail="Database connection not found")

golden_records_repository = GoldenRecordRepository(self.storage)
golden_records = []
if fine_tuning_request.golden_records:
for golden_record_id in fine_tuning_request.golden_records:
golden_record = golden_records_repository.find_by_id(golden_record_id)
if not golden_record:
raise HTTPException(
status_code=404, detail="Golden record not found"
)
golden_records.append(golden_record)
else:
golden_records = golden_records_repository.find_by(
{"db_connection_id": ObjectId(fine_tuning_request.db_connection_id)},
page=0,
limit=0,
)
if not golden_records:
raise HTTPException(status_code=404, detail="No golden records found")

model_repository = FinetuningsRepository(self.storage)
model = model_repository.insert(
Finetuning(
db_connection_id=fine_tuning_request.db_connection_id,
base_llm=fine_tuning_request.base_llm,
golden_records=[
str(golden_record.id) for golden_record in golden_records
],
)
)

background_tasks.add_task(
async_fine_tuning, fine_tuning_request, self.storage, golden_records
)

return model

@override
def cancel_finetuning_job(
self, cancel_fine_tuning_request: CancelFineTuningRequest
) -> Finetuning:
model_repository = FinetuningsRepository(self.storage)
model = model_repository.find_by_id(cancel_fine_tuning_request.finetuning_id)
if not model:
raise HTTPException(status_code=404, detail="Model not found")

if model.status == "succeeded":
raise HTTPException(
status_code=400, detail="Model has already succeeded. Cannot cancel."
)
if model.status == "failed":
raise HTTPException(
status_code=400, detail="Model has already failed. Cannot cancel."
)
if model.status == "cancelled":
raise HTTPException(
status_code=400, detail="Model has already been cancelled."
)

# Todo: Add code to cancel the fine tuning job

return model

@override
def get_finetuning_job(self, finetuning_job_id: str) -> Finetuning:
model_repository = FinetuningsRepository(self.storage)
model = model_repository.find_by_id(finetuning_job_id)
if not model:
raise HTTPException(status_code=404, detail="Model not found")
return model
2 changes: 2 additions & 0 deletions dataherald/finetuning/__init__py
Original file line number Diff line number Diff line change
@@ -0,0 +1,2 @@
class Finetuning:
pass
17 changes: 9 additions & 8 deletions dataherald/finetuning/openai_finetuning.py
Original file line number Diff line number Diff line change
@@ -1,4 +1,5 @@
import json
import logging
import os
import time
import uuid
Expand All @@ -17,6 +18,7 @@

FILE_PROCESSING_ATTEMPTS = 20

logger = logging.getLogger(__name__)

class OpenAIFineTuning:
finetuning_dataset_path: str
Expand Down Expand Up @@ -118,12 +120,11 @@ def create_fintuning_dataset(cls, fine_tuning_request: Finetuning, storage: Any)
openai.api_key = db_connection.decrypt_api_key()
model_repository = FinetuningsRepository(storage)
model = model_repository.find_by_id(fine_tuning_request.id)
model.finetuning_file_id = openai.File.create(
file=open(cls.finetuning_dataset_path, purpose="fine-tune")
)["id"]
model.finetuning_file_id = openai.File.create(file=open(cls.finetuning_dataset_path,purpose='fine-tune'))['id']
model_repository.update(model)
os.remove(cls.finetuning_dataset_path)


@classmethod
def create_fine_tuning_job(cls, fine_tuning_request: Finetuning, storage: Any):
db_connection_repository = DatabaseConnectionRepository(storage)
Expand All @@ -135,10 +136,7 @@ def create_fine_tuning_job(cls, fine_tuning_request: Finetuning, storage: Any):
model = model_repository.find_by_id(fine_tuning_request.id)
retrieve_file_attempt = 0
while True:
if (
openai.File.retrieve(id=model.finetuning_file_id)["status"]
== "processed"
):
if openai.File.retrieve(id=model.finetuning_file_id)["status"] == "processed":
break
time.sleep(5)
retrieve_file_attempt += 1
Expand All @@ -150,7 +148,7 @@ def create_fine_tuning_job(cls, fine_tuning_request: Finetuning, storage: Any):
finetuning_request = openai.FineTune.create(
training_file=model.finetuning_file_id,
model=model.base_llm.model_name,
hyperparameters=model.base_llm.model_parameters,
hyperparameters= model.base_llm.model_parameters
)
model.finetuning_job_id = finetuning_request["id"]
if finetuning_request["status"] == "failed":
Expand Down Expand Up @@ -186,3 +184,6 @@ def cancel_finetuning_job(cls, fine_tuning_request: Finetuning, storage: Any):
model.status = finetuning_request["status"]
model.error = "Fine tuning cancelled by the user"
model_repository.update(model)



61 changes: 61 additions & 0 deletions dataherald/repositories/finetunings.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,61 @@
from bson.objectid import ObjectId

from dataherald.types import Finetuning

DB_COLLECTION = "finetunings"


class FinetuningsRepository:
def __init__(self, storage):
self.storage = storage

def insert(self, model: Finetuning) -> Finetuning:
model.id = str(
self.storage.insert_one(DB_COLLECTION, model.dict(exclude={"id"}))
)
return model

def find_one(self, query: dict) -> Finetuning | None:
row = self.storage.find_one(DB_COLLECTION, query)
if not row:
return None
obj = Finetuning(**row)
obj.id = str(row["_id"])
return obj

def update(self, model: Finetuning) -> Finetuning:
self.storage.update_or_create(
DB_COLLECTION,
{"_id": ObjectId(model.id)},
model.dict(exclude={"id"}),
)
return model

def find_by_id(self, id: str) -> Finetuning | None:
row = self.storage.find_one(DB_COLLECTION, {"_id": ObjectId(id)})
if not row:
return None
obj = Finetuning(**row)
obj.id = str(row["_id"])
return obj

def find_by(self, query: dict, page: int = 1, limit: int = 10) -> list[Finetuning]:
rows = self.storage.find(DB_COLLECTION, query, page=page, limit=limit)
result = []
for row in rows:
obj = Finetuning(**row)
obj.id = str(row["_id"])
result.append(obj)
return result

def find_all(self, page: int = 0, limit: int = 0) -> list[Finetuning]:
rows = self.storage.find_all(DB_COLLECTION, page=page, limit=limit)
result = []
for row in rows:
obj = Finetuning(**row)
obj.id = str(row["_id"])
result.append(obj)
return result

def delete_by_id(self, id: str) -> int:
return self.storage.delete_by_id(DB_COLLECTION, id)
41 changes: 41 additions & 0 deletions dataherald/server/fastapi/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -13,8 +13,11 @@
from dataherald.db_scanner.models.types import TableDescription
from dataherald.sql_database.models.types import DatabaseConnection, SSHSettings
from dataherald.types import (
CancelFineTuningRequest,
CreateResponseRequest,
DatabaseConnectionRequest,
Finetuning,
FineTuningRequest,
GoldenRecord,
GoldenRecordRequest,
Instruction,
Expand Down Expand Up @@ -215,6 +218,28 @@ def __init__(self, settings: Settings):
tags=["Instructions"],
)

self.router.add_api_route(
"/api/v1/finetunings",
self.create_finetuning_job,
methods=["POST"],
status_code=201,
tags=["Finetunings"],
)

self.router.add_api_route(
"/api/v1/finetunings/{finetuning_id}",
self.get_finetuning_job,
methods=["GET"],
tags=["Finetunings"],
)

self.router.add_api_route(
"/api/v1/finetunings/{finetuning_id}/cancel",
self.cancel_finetuning_job,
methods=["POST"],
tags=["Finetunings"],
)

self.router.add_api_route(
"/api/v1/heartbeat", self.heartbeat, methods=["GET"], tags=["System"]
)
Expand Down Expand Up @@ -379,3 +404,19 @@ def update_instruction(
) -> Instruction:
"""Updates an instruction"""
return self._api.update_instruction(instruction_id, instruction_request)

def create_finetuning_job(
self, fine_tuning_request: FineTuningRequest, background_tasks: BackgroundTasks
) -> Finetuning:
"""Creates a fine tuning job"""
return self._api.create_finetuning_job(fine_tuning_request, background_tasks)

def cancel_finetuning_job(
self, cancel_fine_tuning_request: CancelFineTuningRequest
) -> Finetuning:
"""Cancels a fine tuning job"""
return self._api.cancel_finetuning_job(cancel_fine_tuning_request)

def get_finetuning_job(self, finetuning_job_id: str) -> Finetuning:
"""Gets fine tuning jobs"""
return self._api.get_finetuning_job(finetuning_job_id)
40 changes: 40 additions & 0 deletions dataherald/types.py
Original file line number Diff line number Diff line change
Expand Up @@ -139,3 +139,43 @@ class ColumnDescriptionRequest(BaseModel):
class TableDescriptionRequest(BaseModel):
description: str | None
columns: list[ColumnDescriptionRequest] | None


class FineTuningStatus(Enum):
QUEUED = "queued"
RUNNING = "running"
SUCCEEDED = "succeeded"
FAILED = "failed"
CANCELLED = "cancelled"


class BaseLLM(BaseModel):
model_provider: str | None = None
model_name: str | None = None
model_parameters: dict[str, str] | None = None


class Finetuning(BaseModel):
id: str | None = None
db_connection_id: str | None = None
status: str = "queued"
error: str | None = None
base_llm: BaseLLM | None = None
finetuning_file_id: str | None = None
finetuning_job_id: str | None = None
model_id: str | None = None
created_at: datetime = Field(default_factory=datetime.now)
golden_records: list[str] | None = None
metadata: dict[str, str] | None = None


class FineTuningRequest(BaseModel):
db_connection_id: str
base_llm: BaseLLM
golden_records: list[str] | None = None
metadata: dict[str, str] | None = None


class CancelFineTuningRequest(BaseModel):
finetuning_id: str
metadata: dict[str, str] | None = None

0 comments on commit 9922110

Please sign in to comment.