diff --git a/services/engine/dataherald/api/fastapi.py b/services/engine/dataherald/api/fastapi.py index 23e3976c..e9edbd57 100644 --- a/services/engine/dataherald/api/fastapi.py +++ b/services/engine/dataherald/api/fastapi.py @@ -110,8 +110,8 @@ def async_scanning(scanner, database, table_descriptions, storage): ) -def async_fine_tuning(storage, model): - openai_fine_tuning = OpenAIFineTuning(storage, model) +def async_fine_tuning(system, storage, model): + openai_fine_tuning = OpenAIFineTuning(system, storage, model) openai_fine_tuning.create_fintuning_dataset() openai_fine_tuning.create_fine_tuning_job() @@ -626,7 +626,7 @@ def create_finetuning_job( e, fine_tuning_request.dict(), "finetuning_not_created" ) - background_tasks.add_task(async_fine_tuning, self.storage, model) + background_tasks.add_task(async_fine_tuning, self.system, self.storage, model) return model @@ -652,7 +652,7 @@ def cancel_finetuning_job( status_code=400, detail="Model has already been cancelled." ) - openai_fine_tuning = OpenAIFineTuning(self.storage, model) + openai_fine_tuning = OpenAIFineTuning(self.system, self.storage, model) return openai_fine_tuning.cancel_finetuning_job() @@ -665,7 +665,7 @@ def get_finetunings(self, db_connection_id: str | None = None) -> list[Finetunin models = model_repository.find_by(query) result = [] for model in models: - openai_fine_tuning = OpenAIFineTuning(self.storage, model) + openai_fine_tuning = OpenAIFineTuning(self.system, self.storage, model) result.append( Finetuning(**openai_fine_tuning.retrieve_finetuning_job().dict()) ) @@ -685,7 +685,7 @@ def get_finetuning_job(self, finetuning_job_id: str) -> Finetuning: model = model_repository.find_by_id(finetuning_job_id) if not model: raise HTTPException(status_code=404, detail="Model not found") - openai_fine_tuning = OpenAIFineTuning(self.storage, model) + openai_fine_tuning = OpenAIFineTuning(self.system, self.storage, model) return openai_fine_tuning.retrieve_finetuning_job() @override diff --git a/services/engine/dataherald/finetuning/openai_finetuning.py b/services/engine/dataherald/finetuning/openai_finetuning.py index 58fe89c1..2876d2c8 100644 --- a/services/engine/dataherald/finetuning/openai_finetuning.py +++ b/services/engine/dataherald/finetuning/openai_finetuning.py @@ -7,12 +7,13 @@ import numpy as np import tiktoken -from langchain_openai import OpenAIEmbeddings +from langchain_openai import AzureOpenAIEmbeddings, OpenAIEmbeddings from openai import OpenAI from overrides import override from sql_metadata import Parser from tiktoken import Encoding +from dataherald.config import System from dataherald.db_scanner.models.types import TableDescription, TableDescriptionStatus from dataherald.db_scanner.repository.base import TableDescriptionRepository from dataherald.finetuning import FinetuningModel @@ -36,17 +37,24 @@ class OpenAIFineTuning(FinetuningModel): storage: Any client: OpenAI - def __init__(self, storage: Any, fine_tuning_model: Finetuning): + def __init__(self, system: System, storage: Any, fine_tuning_model: Finetuning): self.storage = storage + self.system = system self.fine_tuning_model = fine_tuning_model db_connection_repository = DatabaseConnectionRepository(storage) db_connection = db_connection_repository.find_by_id( fine_tuning_model.db_connection_id ) - self.embedding = OpenAIEmbeddings( #TODO AzureOpenAIEmbeddings when Azure - openai_api_key=db_connection.decrypt_api_key(), - model=EMBEDDING_MODEL, - ) + if self.system.settings["azure_api_key"] is not None: + self.embedding = AzureOpenAIEmbeddings( + azure_api_key=db_connection.decrypt_api_key(), + model=EMBEDDING_MODEL, + ) + else: + self.embedding = OpenAIEmbeddings( + openai_api_key=db_connection.decrypt_api_key(), + model=EMBEDDING_MODEL, + ) self.encoding = tiktoken.encoding_for_model( fine_tuning_model.base_llm.model_name ) diff --git a/services/engine/dataherald/model/base_model.py b/services/engine/dataherald/model/base_model.py index 6b7fb23a..398655a6 100644 --- a/services/engine/dataherald/model/base_model.py +++ b/services/engine/dataherald/model/base_model.py @@ -1,7 +1,7 @@ import os from typing import Any -from langchain.llms import AlephAlpha, Anthropic, Cohere, OpenAI +from langchain.llms import AlephAlpha, Anthropic, AzureOpenAI, Cohere, OpenAI from overrides import override from dataherald.model import LLMModel @@ -19,7 +19,7 @@ def __init__(self, system): self.azure_api_key = os.environ.get("AZURE_API_KEY") @override - def get_model( + def get_model( # noqa: C901 self, database_connection: DatabaseConnection, model_family="openai", @@ -27,8 +27,8 @@ def get_model( api_base: str | None = None, # noqa: ARG002 **kwargs: Any ) -> Any: - if self.system.settings['azure_api_key'] != None: - model_family = 'azure' + if self.system.settings["azure_api_key"] is not None: + model_family = "azure" if database_connection.llm_api_key is not None: fernet_encrypt = FernetEncrypt() api_key = fernet_encrypt.decrypt(database_connection.llm_api_key) @@ -39,7 +39,7 @@ def get_model( elif model_family == "google": self.google_api_key = api_key elif model_family == "azure": - self.azure_api_key == api_key + self.azure_api_key = api_key if self.openai_api_key: self.model = OpenAI(model_name=model_name, **kwargs) elif self.aleph_alpha_api_key: diff --git a/services/engine/dataherald/model/chat_model.py b/services/engine/dataherald/model/chat_model.py index 3ab1fcce..4c7d57f9 100644 --- a/services/engine/dataherald/model/chat_model.py +++ b/services/engine/dataherald/model/chat_model.py @@ -1,7 +1,7 @@ from typing import Any from langchain_community.chat_models import ChatAnthropic, ChatCohere, ChatGooglePalm -from langchain_openai import ChatOpenAI, AzureChatOpenAI +from langchain_openai import AzureChatOpenAI, ChatOpenAI from overrides import override from dataherald.model import LLMModel @@ -22,16 +22,16 @@ def get_model( **kwargs: Any ) -> Any: api_key = database_connection.decrypt_api_key() - if self.system.settings['azure_api_key'] != None: - model_family = 'azure' + if self.system.settings["azure_api_key"] is not None: + model_family = "azure" if model_family == "azure": - if api_base.endswith("/"): #TODO check where final "/" is added to api_base + if api_base.endswith("/"): # check where final "/" is added to api_base api_base = api_base[:-1] return AzureChatOpenAI( deployment_name=model_name, openai_api_key=api_key, - azure_endpoint= api_base, - api_version=self.system.settings['azure_api_version'], + azure_endpoint=api_base, + api_version=self.system.settings["azure_api_version"], **kwargs ) if model_family == "openai": diff --git a/services/engine/dataherald/sql_generator/__init__.py b/services/engine/dataherald/sql_generator/__init__.py index e997920e..6612332b 100644 --- a/services/engine/dataherald/sql_generator/__init__.py +++ b/services/engine/dataherald/sql_generator/__init__.py @@ -179,7 +179,7 @@ def generate_response( """Generates a response to a user question.""" pass - def stream_agent_steps( # noqa: C901 + def stream_agent_steps( # noqa: PLR0912, C901 self, question: str, agent_executor: AgentExecutor, @@ -187,7 +187,7 @@ def stream_agent_steps( # noqa: C901 sql_generation_repository: SQLGenerationRepository, queue: Queue, metadata: dict = None, - ): + ): # noqa: PLR0912 try: with get_openai_callback() as cb: for chunk in agent_executor.stream( diff --git a/services/engine/dataherald/sql_generator/dataherald_finetuning_agent.py b/services/engine/dataherald/sql_generator/dataherald_finetuning_agent.py index 6fc64b95..fe54dcf4 100644 --- a/services/engine/dataherald/sql_generator/dataherald_finetuning_agent.py +++ b/services/engine/dataherald/sql_generator/dataherald_finetuning_agent.py @@ -21,7 +21,7 @@ from langchain.chains.llm import LLMChain from langchain.tools.base import BaseTool from langchain_community.callbacks import get_openai_callback -from langchain_openai import OpenAIEmbeddings +from langchain_openai import AzureOpenAIEmbeddings, OpenAIEmbeddings from openai import OpenAI from overrides import override from pydantic import BaseModel, Field @@ -587,7 +587,7 @@ def generate_response( ) finetunings_repository = FinetuningsRepository(storage) finetuning = finetunings_repository.find_by_id(self.finetuning_id) - openai_fine_tuning = OpenAIFineTuning(storage, finetuning) + openai_fine_tuning = OpenAIFineTuning(self.system, storage, finetuning) finetuning = openai_fine_tuning.retrieve_finetuning_job() if finetuning.status != FineTuningStatus.SUCCEEDED.value: raise FinetuningNotAvailableError( @@ -595,6 +595,16 @@ def generate_response( f"Finetuning should have the status {FineTuningStatus.SUCCEEDED.value} to generate SQL queries." ) self.database = SQLDatabase.get_sql_engine(database_connection) + if self.system.settings["azure_api_key"] is not None: + embedding = AzureOpenAIEmbeddings( + openai_api_key=database_connection.decrypt_api_key(), + model=EMBEDDING_MODEL, + ) + else: + embedding = OpenAIEmbeddings( + openai_api_key=database_connection.decrypt_api_key(), + model=EMBEDDING_MODEL, + ) toolkit = SQLDatabaseToolkit( db=self.database, instructions=instructions, @@ -605,10 +615,7 @@ def generate_response( use_finetuned_model_only=self.use_fintuned_model_only, model_name=finetuning.base_llm.model_name, openai_fine_tuning=openai_fine_tuning, - embedding=OpenAIEmbeddings( #TODO AzureOpenAIEmbeddings when Azure - openai_api_key=database_connection.decrypt_api_key(), - model=EMBEDDING_MODEL, - ), + embedding=embedding, ) agent_executor = self.create_sql_agent( toolkit=toolkit, @@ -693,7 +700,7 @@ def stream_response( ) finetunings_repository = FinetuningsRepository(storage) finetuning = finetunings_repository.find_by_id(self.finetuning_id) - openai_fine_tuning = OpenAIFineTuning(storage, finetuning) + openai_fine_tuning = OpenAIFineTuning(self.system, storage, finetuning) finetuning = openai_fine_tuning.retrieve_finetuning_job() if finetuning.status != FineTuningStatus.SUCCEEDED.value: raise FinetuningNotAvailableError( @@ -701,6 +708,16 @@ def stream_response( f"Finetuning should have the status {FineTuningStatus.SUCCEEDED.value} to generate SQL queries." ) self.database = SQLDatabase.get_sql_engine(database_connection) + if self.system.settings["azure_api_key"] is not None: + embedding = AzureOpenAIEmbeddings( + openai_api_key=database_connection.decrypt_api_key(), + model=EMBEDDING_MODEL, + ) + else: + embedding = OpenAIEmbeddings( + openai_api_key=database_connection.decrypt_api_key(), + model=EMBEDDING_MODEL, + ) toolkit = SQLDatabaseToolkit( db=self.database, instructions=instructions, @@ -710,10 +727,7 @@ def stream_response( use_finetuned_model_only=self.use_fintuned_model_only, model_name=finetuning.base_llm.model_name, openai_fine_tuning=openai_fine_tuning, - embedding=OpenAIEmbeddings( #TODO AzureOpenAIEmbeddings when Azure - openai_api_key=database_connection.decrypt_api_key(), - model=EMBEDDING_MODEL, - ), + embedding=embedding, ) agent_executor = self.create_sql_agent( toolkit=toolkit, diff --git a/services/engine/dataherald/sql_generator/dataherald_sqlagent.py b/services/engine/dataherald/sql_generator/dataherald_sqlagent.py index a9635ff5..414ab089 100644 --- a/services/engine/dataherald/sql_generator/dataherald_sqlagent.py +++ b/services/engine/dataherald/sql_generator/dataherald_sqlagent.py @@ -22,7 +22,7 @@ from langchain.chains.llm import LLMChain from langchain.tools.base import BaseTool from langchain_community.callbacks import get_openai_callback -from langchain_openai import OpenAIEmbeddings, AzureOpenAIEmbeddings +from langchain_openai import AzureOpenAIEmbeddings, OpenAIEmbeddings from overrides import override from pydantic import BaseModel, Field from sql_metadata import Parser @@ -710,13 +710,13 @@ def create_sql_agent( ) @override - def generate_response( + def generate_response( # noqa: PLR0912 self, user_prompt: Prompt, database_connection: DatabaseConnection, context: List[dict] = None, metadata: dict = None, - ) -> SQLGeneration: + ) -> SQLGeneration: # noqa: PLR0912 context_store = self.system.instance(ContextStore) storage = self.system.instance(DB) response = SQLGeneration( @@ -753,8 +753,8 @@ def generate_response( number_of_samples = 0 logger.info(f"Generating SQL response to question: {str(user_prompt.dict())}") self.database = SQLDatabase.get_sql_engine(database_connection) - #Set Embeddings class depending on azure / not azure - if self.llm.openai_api_type == "azure": + # Set Embeddings class depending on azure / not azure + if self.system.settings["azure_api_key"] is not None: toolkit = SQLDatabaseToolkit( db=self.database, context=context, @@ -873,21 +873,17 @@ def stream_response( new_fewshot_examples = None number_of_samples = 0 self.database = SQLDatabase.get_sql_engine(database_connection) - #Set Embeddings class depending on azure / not azure - if self.llm.openai_api_type == "azure": - toolkit = SQLDatabaseToolkit( - db=self.database, - context=context, - few_shot_examples=new_fewshot_examples, - instructions=instructions, - is_multiple_schema=True if user_prompt.schemas else False, - db_scan=db_scan, - embedding=AzureOpenAIEmbeddings( - openai_api_key=database_connection.decrypt_api_key(), - model=EMBEDDING_MODEL, - ), + # Set Embeddings class depending on azure / not azure + if self.system.settings["azure_api_key"] is not None: + embedding = AzureOpenAIEmbeddings( + openai_api_key=database_connection.decrypt_api_key(), + model=EMBEDDING_MODEL, + ) + else: + embedding = OpenAIEmbeddings( + openai_api_key=database_connection.decrypt_api_key(), + model=EMBEDDING_MODEL, ) - else: toolkit = SQLDatabaseToolkit( queuer=queue, db=self.database, @@ -896,10 +892,7 @@ def stream_response( instructions=instructions, is_multiple_schema=True if user_prompt.schemas else False, db_scan=db_scan, - embedding=OpenAIEmbeddings( - openai_api_key=database_connection.decrypt_api_key(), - model=EMBEDDING_MODEL, - ), + embedding=embedding, ) agent_executor = self.create_sql_agent( toolkit=toolkit,