Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

DH-5033/ the new endpoints for finetuning #267

Merged
merged 6 commits into from
Dec 14, 2023
Merged
Show file tree
Hide file tree
Changes from 1 commit
Commits
File filter

Filter by extension

Filter by extension


Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
Prev Previous commit
Next Next commit
DH-5033/ finalized llm finetuning with openai
  • Loading branch information
MohammadrezaPourreza committed Nov 30, 2023
commit d683a32de7adf2d968b04f305cf434a42fad65b7
25 changes: 17 additions & 8 deletions dataherald/api/fastapi.py
Original file line number Diff line number Diff line change
Expand Up @@ -25,6 +25,7 @@
TableDescriptionRepository,
)
from dataherald.eval import Evaluator
from dataherald.finetuning.openai_finetuning import OpenAIFineTuning
from dataherald.repositories.base import ResponseRepository
from dataherald.repositories.database_connections import DatabaseConnectionRepository
from dataherald.repositories.finetunings import FinetuningsRepository
Expand Down Expand Up @@ -56,6 +57,7 @@
TableDescriptionRequest,
UpdateInstruction,
)
from dataherald.utils.models_context_window import OPENAI_CONTEXT_WIDNOW_SIZES
from dataherald.utils.s3 import S3

logger = logging.getLogger(__name__)
Expand All @@ -72,8 +74,10 @@ def async_scanning(scanner, database, scanner_request, storage):
)


def async_fine_tuning():
pass
def async_fine_tuning(storage, model):
openai_fine_tuning = OpenAIFineTuning(storage, model)
openai_fine_tuning.create_fintuning_dataset()
openai_fine_tuning.create_fine_tuning_job()


def delete_file(file_location: str):
Expand Down Expand Up @@ -678,6 +682,12 @@ def create_finetuning_job(
if not golden_records:
raise HTTPException(status_code=404, detail="No golden records found")

if fine_tuning_request.base_llm.model_name not in OPENAI_CONTEXT_WIDNOW_SIZES:
raise HTTPException(
status_code=400,
detail=f"Model {fine_tuning_request.base_llm.model_name} not supported",
)

model_repository = FinetuningsRepository(self.storage)
model = model_repository.insert(
Finetuning(
Expand All @@ -689,9 +699,7 @@ def create_finetuning_job(
)
)

background_tasks.add_task(
async_fine_tuning, fine_tuning_request, self.storage, golden_records
)
background_tasks.add_task(async_fine_tuning, self.storage, model)

return model

Expand All @@ -717,14 +725,15 @@ def cancel_finetuning_job(
status_code=400, detail="Model has already been cancelled."
)

# Todo: Add code to cancel the fine tuning job
openai_fine_tuning = OpenAIFineTuning(self.storage, model)

return model
return openai_fine_tuning.cancel_finetuning_job()

@override
def get_finetuning_job(self, finetuning_job_id: str) -> Finetuning:
model_repository = FinetuningsRepository(self.storage)
model = model_repository.find_by_id(finetuning_job_id)
if not model:
raise HTTPException(status_code=404, detail="Model not found")
return model
openai_fine_tuning = OpenAIFineTuning(self.storage, model)
return openai_fine_tuning.retrieve_finetuning_job()
29 changes: 29 additions & 0 deletions dataherald/finetuning/__init__.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,29 @@
from abc import ABC, abstractmethod

from dataherald.config import Component
from dataherald.types import Finetuning


class FinetuningModel(Component, ABC):
    """Abstract interface for LLM fine-tuning backends.

    A concrete implementation (e.g. an OpenAI-backed one) drives the full
    fine-tuning lifecycle: building the training dataset, submitting the
    job, polling its status, and cancelling it.
    """

    def __init__(self, storage):
        # Storage handle; implementations use it to reach the repositories.
        self.storage = storage

    @abstractmethod
    def count_tokens(self, messages: dict) -> int:
        """Return the number of tokens in the prompt content of *messages*."""

    # NOTE(review): "fintuning" spelling (missing the 'e') is kept as-is —
    # subclasses and callers already reference this exact name.
    @abstractmethod
    def create_fintuning_dataset(self):
        """Build the training dataset for the fine-tuning job."""

    @abstractmethod
    def create_fine_tuning_job(self):
        """Submit the fine-tuning job to the provider."""

    @abstractmethod
    def retrieve_finetuning_job(self) -> Finetuning:
        """Fetch the current job state and return the updated record."""

    @abstractmethod
    def cancel_finetuning_job(self) -> Finetuning:
        """Cancel the running job and return the updated record."""
2 changes: 0 additions & 2 deletions dataherald/finetuning/__init__py

This file was deleted.

194 changes: 116 additions & 78 deletions dataherald/finetuning/openai_finetuning.py
Original file line number Diff line number Diff line change
Expand Up @@ -5,25 +5,47 @@
import uuid
from typing import Any, List

import openai
import tiktoken
from bson.objectid import ObjectId
from openai import OpenAI
from overrides import override
from tiktoken import Encoding

from dataherald.db_scanner.models.types import TableDescription, TableDescriptionStatus
from dataherald.db_scanner.repository.base import TableDescriptionRepository
from dataherald.finetuning import FinetuningModel
from dataherald.repositories.database_connections import DatabaseConnectionRepository
from dataherald.repositories.finetunings import FinetuningsRepository
from dataherald.repositories.golden_records import GoldenRecordRepository
from dataherald.types import Finetuning
from dataherald.utils.agent_prompts import FINETUNING_SYSTEM_INFORMATION
from dataherald.utils.models_context_window import OPENAI_CONTEXT_WIDNOW_SIZES

FILE_PROCESSING_ATTEMPTS = 20

logger = logging.getLogger(__name__)

class OpenAIFineTuning:
finetuning_dataset_path: str

def format_columns(self, table: TableDescription, top_k: int = 100) -> str:
class OpenAIFineTuning(FinetuningModel):
encoding: Encoding
fine_tuning_model: Finetuning
storage: Any
client: OpenAI

def __init__(self, storage: Any, fine_tuning_model: Finetuning):
self.storage = storage
self.fine_tuning_model = fine_tuning_model
db_connection_repository = DatabaseConnectionRepository(storage)
db_connection = db_connection_repository.find_by_id(
fine_tuning_model.db_connection_id
)
self.encoding = tiktoken.encoding_for_model(
fine_tuning_model.base_llm.model_name
)
self.client = OpenAI(api_key=db_connection.decrypt_api_key())

@classmethod
def format_columns(cls, table: TableDescription, top_k: int = 100) -> str:
"""
format_columns formats the columns.

Expand Down Expand Up @@ -65,14 +87,14 @@ def format_columns(self, table: TableDescription, top_k: int = 100) -> str:
)
return columns_information

@staticmethod
def format_dataset(self, db_scan: List[TableDescription]) -> str:
@classmethod
def format_dataset(cls, db_scan: List[TableDescription]) -> str:
schema_of_database = ""
for table in db_scan:
tables_schema = table.table_schema
schema_of_database += f"{tables_schema}\n"
schema_of_database += "# Categorical Columns:\n"
columns_information = self.format_columns(table)
columns_information = cls.format_columns(table)
schema_of_database += columns_information
sample_rows = table.examples
schema_of_database += "# Sample rows:\n"
Expand All @@ -83,107 +105,123 @@ def format_dataset(self, db_scan: List[TableDescription]) -> str:
schema_of_database += "\n\n"
return schema_of_database

@classmethod
def create_fintuning_dataset(cls, fine_tuning_request: Finetuning, storage: Any):
db_connection_id = fine_tuning_request.db_connection_id
repository = TableDescriptionRepository(storage)
@override
def count_tokens(self, messages: dict) -> int:
prompt = ""
for message in messages["messages"]:
prompt += message["content"]
return len(self.encoding.encode(prompt))

@override
def create_fintuning_dataset(self):
db_connection_id = self.fine_tuning_model.db_connection_id
repository = TableDescriptionRepository(self.storage)
db_scan = repository.get_all_tables_by_db(
{
"db_connection_id": ObjectId(db_connection_id),
"status": TableDescriptionStatus.SYNCHRONIZED.value,
}
)
golden_records_repository = GoldenRecordRepository(storage)
database_schema = cls.format_dataset(db_scan)
cls.finetuning_dataset_path = f"tmp/{str(uuid.uuid4())}.jsonl"
for golden_record_id in fine_tuning_request.golden_records:
golden_records_repository = GoldenRecordRepository(self.storage)
database_schema = self.format_dataset(db_scan)
finetuning_dataset_path = f"tmp/{str(uuid.uuid4())}.jsonl"
model_repository = FinetuningsRepository(self.storage)
model = model_repository.find_by_id(self.fine_tuning_model.id)
for golden_record_id in self.fine_tuning_model.golden_records:
golden_record = golden_records_repository.find_by_id(golden_record_id)
question = golden_record.question
query = golden_record.sql_query
system_prompt = FINETUNING_SYSTEM_INFORMATION + database_schema
user_prompt = "User Question: " + question + "\n SQL: "
assistant_prompt = query + "\n"
with open(cls.finetuning_dataset_path, "a") as outfile:
with open(finetuning_dataset_path, "a") as outfile:
messages = {
"messages": [
{"role": "system", "content": f"{system_prompt}"},
{"role": "user", "content": f"Question : {user_prompt}"},
{"role": "assistant", "content": f"{assistant_prompt}"},
]
}
number_of_tokens = self.count_tokens(messages)
if (
number_of_tokens
> OPENAI_CONTEXT_WIDNOW_SIZES[
self.fine_tuning_model.base_llm.model_name
]
):
model.status = "failed"
model.error = "The number of tokens in the prompt is too large"
model_repository.update(model)
os.remove(finetuning_dataset_path)
return
json.dump(messages, outfile)
outfile.write("\n")
db_connection_repository = DatabaseConnectionRepository(storage)
db_connection = db_connection_repository.find_by_id(
fine_tuning_request.db_connection_id
)
openai.api_key = db_connection.decrypt_api_key()
model_repository = FinetuningsRepository(storage)
model = model_repository.find_by_id(fine_tuning_request.id)
model.finetuning_file_id = openai.File.create(file=open(cls.finetuning_dataset_path,purpose='fine-tune'))['id']
model.finetuning_file_id = self.client.files.create(
file=open(finetuning_dataset_path, "rb"), purpose="fine-tune"
).id
model_repository.update(model)
os.remove(cls.finetuning_dataset_path)

os.remove(finetuning_dataset_path)

@classmethod
def create_fine_tuning_job(cls, fine_tuning_request: Finetuning, storage: Any):
db_connection_repository = DatabaseConnectionRepository(storage)
db_connection = db_connection_repository.find_by_id(
fine_tuning_request.db_connection_id
)
openai.api_key = db_connection.decrypt_api_key()
model_repository = FinetuningsRepository(storage)
model = model_repository.find_by_id(fine_tuning_request.id)
def check_file_status(self, file_id: str) -> bool:
retrieve_file_attempt = 0
while True:
if openai.File.retrieve(id=model.finetuning_file_id)["status"] == "processed":
break
file_info = self.client.files.retrieve(file_id=file_id)
if file_info.status == "processed":
return True
time.sleep(5)
retrieve_file_attempt += 1
if retrieve_file_attempt == FILE_PROCESSING_ATTEMPTS:
model.status = "failed"
model.error = "File processing failed"
model_repository.update(model)
return
finetuning_request = openai.FineTune.create(
training_file=model.finetuning_file_id,
model=model.base_llm.model_name,
hyperparameters= model.base_llm.model_parameters
return False

@override
def create_fine_tuning_job(self):
model_repository = FinetuningsRepository(self.storage)
model = model_repository.find_by_id(self.fine_tuning_model.id)
if self.check_file_status(model.finetuning_file_id):
finetuning_request = self.client.fine_tuning.jobs.create(
training_file=model.finetuning_file_id,
model=model.base_llm.model_name,
hyperparameters=model.base_llm.model_parameters
if model.base_llm.model_parameters
else {
"batch_size": 1,
"learning_rate_multiplier": "auto",
"n_epochs": 3,
},
)
model.finetuning_job_id = finetuning_request.id
if finetuning_request.status == "failed":
model.error = "Fine tuning failed before starting"
model.status = finetuning_request.status
model_repository.update(model)
else:
model.status = "failed"
model.error = "File processing failed"
model_repository.update(model)

@override
def retrieve_finetuning_job(self) -> Finetuning:
model_repository = FinetuningsRepository(self.storage)
model = model_repository.find_by_id(self.fine_tuning_model.id)
finetuning_request = self.client.fine_tuning.jobs.retrieve(
fine_tuning_job_id=model.finetuning_job_id
)
model.finetuning_job_id = finetuning_request["id"]
if finetuning_request["status"] == "failed":
model.error = "Fine tuning failed before starting"
model.status = finetuning_request["status"]
if finetuning_request.status == "failed":
model.error = finetuning_request.error.message
model.status = finetuning_request.status
if finetuning_request.fine_tuned_model:
model.model_id = finetuning_request.fine_tuned_model
model_repository.update(model)

@classmethod
def retrieve_finetuning_job(cls, fine_tuning_request: Finetuning, storage: Any):
db_connection_repository = DatabaseConnectionRepository(storage)
db_connection = db_connection_repository.find_by_id(
fine_tuning_request.db_connection_id
return model

@override
def cancel_finetuning_job(self) -> Finetuning:
model_repository = FinetuningsRepository(self.storage)
model = model_repository.find_by_id(self.fine_tuning_model.id)
finetuning_request = self.client.fine_tuning.jobs.cancel(
fine_tuning_job_id=model.finetuning_job_id
)
openai.api_key = db_connection.decrypt_api_key()
model_repository = FinetuningsRepository(storage)
model = model_repository.find_by_id(fine_tuning_request.id)
finetuning_request = openai.FineTune.retrieve(id=model.finetuning_job_id)
if finetuning_request["status"] == "failed":
model.error = "Fine tuning failed during processing by OpenAI"
model.status = finetuning_request["status"]
model_repository.update(model)

@classmethod
def cancel_finetuning_job(cls, fine_tuning_request: Finetuning, storage: Any):
db_connection_repository = DatabaseConnectionRepository(storage)
db_connection = db_connection_repository.find_by_id(
fine_tuning_request.db_connection_id
)
openai.api_key = db_connection.decrypt_api_key()
model_repository = FinetuningsRepository(storage)
model = model_repository.find_by_id(fine_tuning_request.id)
finetuning_request = openai.FineTune.cancel(id=model.finetuning_job_id)
model.status = finetuning_request["status"]
model.status = finetuning_request.status
model.error = "Fine tuning cancelled by the user"
model_repository.update(model)



return model
1 change: 1 addition & 0 deletions dataherald/types.py
Original file line number Diff line number Diff line change
Expand Up @@ -147,6 +147,7 @@ class FineTuningStatus(Enum):
SUCCEEDED = "succeeded"
FAILED = "failed"
CANCELLED = "cancelled"
VALIDATING_FILES = "validating_files"


class BaseLLM(BaseModel):
Expand Down
4 changes: 2 additions & 2 deletions requirements.txt
Original file line number Diff line number Diff line change
@@ -1,10 +1,10 @@
dnspython==2.3.0
fastapi==0.98.0
httpx==0.24.1
langchain==0.0.312
langchain==0.0.335
load-dotenv==0.1.0
mypy-extensions==1.0.0
openai==0.27.8
openai==1.3.6
openapi-schema-pydantic==1.2.4
overrides==7.3.1
packaging==23.1
Expand Down