From 581b6ffed46005ab34d9992b8357acd867203f36 Mon Sep 17 00:00:00 2001 From: jmanuelnavarro Date: Tue, 7 May 2024 17:21:50 +0200 Subject: [PATCH] Azure OpenAI deployments compatibility (#457) (non finetunning) Co-authored-by: Julio Navarro --- README.md | 15 ++++ dataherald/config.py | 4 + dataherald/finetuning/openai_finetuning.py | 2 +- dataherald/model/base_model.py | 7 ++ dataherald/model/chat_model.py | 14 +++- .../dataherald_finetuning_agent.py | 4 +- .../sql_generator/dataherald_sqlagent.py | 82 +++++++++++++------ docker-compose.yml | 2 +- initdb.d/init-mongo.sh | 2 +- 9 files changed, 100 insertions(+), 32 deletions(-) diff --git a/README.md b/README.md index 4287d802..49eedeed 100644 --- a/README.md +++ b/README.md @@ -83,6 +83,21 @@ UPPER_LIMIT_QUERY_RETURN_ROWS = 50 DH_ENGINE_TIMEOUT = 150 ``` +In case you want to use models deployed in Azure OpenAI, you must set the following variables: +``` +AZURE_API_KEY = "xxxxx" +AZURE_OPENAI_API_KEY = "xxxxxx" +API_BASE = "azure_openai_endpoint" +AZURE_OPENAI_ENDPOINT = "azure_openai_endpoint" +AZURE_API_VERSION = "version of the API to use" +LLM_MODEL = "name_of_the_deployment" +``` +In addition, an embedding model will be also used. There must be a deployment created with name "text-embedding-3-large". + +The existence of AZURE_API_KEY as environment variable indicates Azure models must be used. + +Remember to remove comments beside the environment variables. + While not strictly required, we also strongly suggest you change the MONGO username and password fields as well. 
Follow the next commands to generate an ENCRYPT_KEY and paste it in the .env file like diff --git a/dataherald/config.py b/dataherald/config.py index ec9ff313..370947c1 100644 --- a/dataherald/config.py +++ b/dataherald/config.py @@ -45,6 +45,10 @@ class Settings(BaseSettings): encrypt_key: str = os.environ.get("ENCRYPT_KEY") s3_aws_access_key_id: str | None = os.environ.get("S3_AWS_ACCESS_KEY_ID") s3_aws_secret_access_key: str | None = os.environ.get("S3_AWS_SECRET_ACCESS_KEY") + #Needed for Azure OpenAI integration: + azure_api_key: str | None = os.environ.get("AZURE_API_KEY") + embedding_model: str | None = os.environ.get("EMBEDDING_MODEL") + azure_api_version: str | None = os.environ.get("AZURE_API_VERSION") only_store_csv_files_locally: bool | None = os.environ.get( "ONLY_STORE_CSV_FILES_LOCALLY", False ) diff --git a/dataherald/finetuning/openai_finetuning.py b/dataherald/finetuning/openai_finetuning.py index 95b0f10f..58fe89c1 100644 --- a/dataherald/finetuning/openai_finetuning.py +++ b/dataherald/finetuning/openai_finetuning.py @@ -43,7 +43,7 @@ def __init__(self, storage: Any, fine_tuning_model: Finetuning): db_connection = db_connection_repository.find_by_id( fine_tuning_model.db_connection_id ) - self.embedding = OpenAIEmbeddings( + self.embedding = OpenAIEmbeddings( #TODO AzureOpenAIEmbeddings when Azure openai_api_key=db_connection.decrypt_api_key(), model=EMBEDDING_MODEL, ) diff --git a/dataherald/model/base_model.py b/dataherald/model/base_model.py index c30bf9e6..6b7fb23a 100644 --- a/dataherald/model/base_model.py +++ b/dataherald/model/base_model.py @@ -16,6 +16,7 @@ def __init__(self, system): self.aleph_alpha_api_key = os.environ.get("ALEPH_ALPHA_API_KEY") self.anthropic_api_key = os.environ.get("ANTHROPIC_API_KEY") self.cohere_api_key = os.environ.get("COHERE_API_KEY") + self.azure_api_key = os.environ.get("AZURE_API_KEY") @override def get_model( @@ -26,6 +27,8 @@ def get_model( api_base: str | None = None, # noqa: ARG002 **kwargs: Any ) -> 
Any: + if self.system.settings['azure_api_key'] is not None: + model_family = 'azure' if database_connection.llm_api_key is not None: fernet_encrypt = FernetEncrypt() api_key = fernet_encrypt.decrypt(database_connection.llm_api_key) @@ -35,6 +38,8 @@ def get_model( self.anthropic_api_key = api_key elif model_family == "google": self.google_api_key = api_key + elif model_family == "azure": + self.azure_api_key = api_key if self.openai_api_key: self.model = OpenAI(model_name=model_name, **kwargs) elif self.aleph_alpha_api_key: @@ -43,6 +48,8 @@ def get_model( self.model = Anthropic(model=model_name, **kwargs) elif self.cohere_api_key: self.model = Cohere(model=model_name, **kwargs) + elif self.azure_api_key: + self.model = AzureOpenAI(model=model_name, **kwargs) else: raise ValueError("No valid API key environment variable found") return self.model diff --git a/dataherald/model/chat_model.py b/dataherald/model/chat_model.py index 7ed2a393..3ab1fcce 100644 --- a/dataherald/model/chat_model.py +++ b/dataherald/model/chat_model.py @@ -1,7 +1,7 @@ from typing import Any from langchain_community.chat_models import ChatAnthropic, ChatCohere, ChatGooglePalm -from langchain_openai import ChatOpenAI +from langchain_openai import ChatOpenAI, AzureChatOpenAI from overrides import override from dataherald.model import LLMModel @@ -22,6 +22,18 @@ def get_model( **kwargs: Any ) -> Any: api_key = database_connection.decrypt_api_key() + if self.system.settings['azure_api_key'] is not None: + model_family = 'azure' + if model_family == "azure": + if api_base.endswith("/"): #TODO check where final "/" is added to api_base + api_base = api_base[:-1] + return AzureChatOpenAI( + deployment_name=model_name, + openai_api_key=api_key, + azure_endpoint=api_base, + api_version=self.system.settings['azure_api_version'], + **kwargs + ) if model_family == "openai": return ChatOpenAI( model_name=model_name, diff --git a/dataherald/sql_generator/dataherald_finetuning_agent.py 
b/dataherald/sql_generator/dataherald_finetuning_agent.py index 52c929fe..6fc64b95 100644 --- a/dataherald/sql_generator/dataherald_finetuning_agent.py +++ b/dataherald/sql_generator/dataherald_finetuning_agent.py @@ -605,7 +605,7 @@ def generate_response( use_finetuned_model_only=self.use_fintuned_model_only, model_name=finetuning.base_llm.model_name, openai_fine_tuning=openai_fine_tuning, - embedding=OpenAIEmbeddings( + embedding=OpenAIEmbeddings( #TODO AzureOpenAIEmbeddings when Azure openai_api_key=database_connection.decrypt_api_key(), model=EMBEDDING_MODEL, ), @@ -710,7 +710,7 @@ def stream_response( use_finetuned_model_only=self.use_fintuned_model_only, model_name=finetuning.base_llm.model_name, openai_fine_tuning=openai_fine_tuning, - embedding=OpenAIEmbeddings( + embedding=OpenAIEmbeddings( #TODO AzureOpenAIEmbeddings when Azure openai_api_key=database_connection.decrypt_api_key(), model=EMBEDDING_MODEL, ), diff --git a/dataherald/sql_generator/dataherald_sqlagent.py b/dataherald/sql_generator/dataherald_sqlagent.py index 8c62330f..a9635ff5 100644 --- a/dataherald/sql_generator/dataherald_sqlagent.py +++ b/dataherald/sql_generator/dataherald_sqlagent.py @@ -22,7 +22,7 @@ from langchain.chains.llm import LLMChain from langchain.tools.base import BaseTool from langchain_community.callbacks import get_openai_callback -from langchain_openai import OpenAIEmbeddings +from langchain_openai import OpenAIEmbeddings, AzureOpenAIEmbeddings from overrides import override from pydantic import BaseModel, Field from sql_metadata import Parser @@ -753,18 +753,33 @@ def generate_response( number_of_samples = 0 logger.info(f"Generating SQL response to question: {str(user_prompt.dict())}") self.database = SQLDatabase.get_sql_engine(database_connection) - toolkit = SQLDatabaseToolkit( - db=self.database, - context=context, - few_shot_examples=new_fewshot_examples, - instructions=instructions, - is_multiple_schema=True if user_prompt.schemas else False, - db_scan=db_scan, - 
embedding=OpenAIEmbeddings( - openai_api_key=database_connection.decrypt_api_key(), - model=EMBEDDING_MODEL, - ), - ) + #Set Embeddings class depending on azure / not azure + if self.llm.openai_api_type == "azure": + toolkit = SQLDatabaseToolkit( + db=self.database, + context=context, + few_shot_examples=new_fewshot_examples, + instructions=instructions, + is_multiple_schema=True if user_prompt.schemas else False, + db_scan=db_scan, + embedding=AzureOpenAIEmbeddings( + openai_api_key=database_connection.decrypt_api_key(), + model=EMBEDDING_MODEL, + ), + ) + else: + toolkit = SQLDatabaseToolkit( + db=self.database, + context=context, + few_shot_examples=new_fewshot_examples, + instructions=instructions, + is_multiple_schema=True if user_prompt.schemas else False, + db_scan=db_scan, + embedding=OpenAIEmbeddings( + openai_api_key=database_connection.decrypt_api_key(), + model=EMBEDDING_MODEL, + ), + ) agent_executor = self.create_sql_agent( toolkit=toolkit, verbose=True, @@ -858,19 +873,35 @@ def stream_response( new_fewshot_examples = None number_of_samples = 0 self.database = SQLDatabase.get_sql_engine(database_connection) - toolkit = SQLDatabaseToolkit( - queuer=queue, - db=self.database, - context=[{}], - few_shot_examples=new_fewshot_examples, - instructions=instructions, - is_multiple_schema=True if user_prompt.schemas else False, - db_scan=db_scan, - embedding=OpenAIEmbeddings( - openai_api_key=database_connection.decrypt_api_key(), - model=EMBEDDING_MODEL, - ), - ) + #Set Embeddings class depending on azure / not azure + if self.llm.openai_api_type == "azure": + toolkit = SQLDatabaseToolkit( + queuer=queue, + db=self.database, + context=[{}], + few_shot_examples=new_fewshot_examples, + instructions=instructions, + is_multiple_schema=True if user_prompt.schemas else False, + db_scan=db_scan, + embedding=AzureOpenAIEmbeddings( + openai_api_key=database_connection.decrypt_api_key(), + model=EMBEDDING_MODEL, + ), + ) + else: + toolkit = SQLDatabaseToolkit( + queuer=queue, 
db=self.database, + context=[{}], + few_shot_examples=new_fewshot_examples, + instructions=instructions, + is_multiple_schema=True if user_prompt.schemas else False, + db_scan=db_scan, + embedding=OpenAIEmbeddings( + openai_api_key=database_connection.decrypt_api_key(), + model=EMBEDDING_MODEL, + ), + ) agent_executor = self.create_sql_agent( toolkit=toolkit, verbose=True, diff --git a/docker-compose.yml b/docker-compose.yml index af64a2e6..a2c19c20 100644 --- a/docker-compose.yml +++ b/docker-compose.yml @@ -24,7 +24,7 @@ services: ports: - 27017:27017 volumes: - - ./initdb.d/:/docker-entrypoint-initdb.d/ + - ./initdb.d/init-mongo.sh:/docker-entrypoint-initdb.d/init-mongo.sh:ro - ./dbdata/mongo_data/data:/data/db/ - ./dbdata/mongo_data/db_config:/data/configdb/ environment: diff --git a/initdb.d/init-mongo.sh b/initdb.d/init-mongo.sh index d6124d21..3ae8092b 100644 --- a/initdb.d/init-mongo.sh +++ b/initdb.d/init-mongo.sh @@ -1,4 +1,4 @@ -set -e +# set -e mongosh <