Skip to content

Commit

Permalink
Azure OpenAI deployments compatibility (#457)
Browse files Browse the repository at this point in the history
(non finetunning)

Co-authored-by: Julio Navarro <[email protected]>
  • Loading branch information
jmanuelnavarro and navarrojulio authored May 7, 2024
1 parent 514e498 commit 581b6ff
Show file tree
Hide file tree
Showing 9 changed files with 100 additions and 32 deletions.
15 changes: 15 additions & 0 deletions README.md
Original file line number Diff line number Diff line change
Expand Up @@ -83,6 +83,21 @@ UPPER_LIMIT_QUERY_RETURN_ROWS = 50
DH_ENGINE_TIMEOUT = 150
```

In case you want to use models deployed in Azure OpenAI, you must set the following variables:
```
AZURE_API_KEY = "xxxxx"
AZURE_OPENAI_API_KEY = "xxxxxx"
API_BASE = "azure_openai_endpoint"
AZURE_OPENAI_ENDPOINT = "azure_openai_endpoint"
AZURE_API_VERSION = "version of the API to use"
LLM_MODEL = "name_of_the_deployment"
```
In addition, an embedding model will be also used. There must be a deployment created with name "text-embedding-3-large".

The existence of AZURE_API_KEY as environment variable indicates Azure models must be used.

Remember to remove comments beside the environment variables.

While not strictly required, we also strongly suggest you change the MONGO username and password fields as well.

Follow the next commands to generate an ENCRYPT_KEY and paste it in the .env file like
Expand Down
4 changes: 4 additions & 0 deletions dataherald/config.py
Original file line number Diff line number Diff line change
Expand Up @@ -45,6 +45,10 @@ class Settings(BaseSettings):
encrypt_key: str = os.environ.get("ENCRYPT_KEY")
s3_aws_access_key_id: str | None = os.environ.get("S3_AWS_ACCESS_KEY_ID")
s3_aws_secret_access_key: str | None = os.environ.get("S3_AWS_SECRET_ACCESS_KEY")
#Needed for Azure OpenAI integration:
azure_api_key: str | None = os.environ.get("AZURE_API_KEY")
embedding_model: str | None = os.environ.get("EMBEDDING_MODEL")
azure_api_version: str | None = os.environ.get("AZURE_API_VERSION")
only_store_csv_files_locally: bool | None = os.environ.get(
"ONLY_STORE_CSV_FILES_LOCALLY", False
)
Expand Down
2 changes: 1 addition & 1 deletion dataherald/finetuning/openai_finetuning.py
Original file line number Diff line number Diff line change
Expand Up @@ -43,7 +43,7 @@ def __init__(self, storage: Any, fine_tuning_model: Finetuning):
db_connection = db_connection_repository.find_by_id(
fine_tuning_model.db_connection_id
)
self.embedding = OpenAIEmbeddings(
self.embedding = OpenAIEmbeddings( #TODO AzureOpenAIEmbeddings when Azure
openai_api_key=db_connection.decrypt_api_key(),
model=EMBEDDING_MODEL,
)
Expand Down
7 changes: 7 additions & 0 deletions dataherald/model/base_model.py
Original file line number Diff line number Diff line change
Expand Up @@ -16,6 +16,7 @@ def __init__(self, system):
self.aleph_alpha_api_key = os.environ.get("ALEPH_ALPHA_API_KEY")
self.anthropic_api_key = os.environ.get("ANTHROPIC_API_KEY")
self.cohere_api_key = os.environ.get("COHERE_API_KEY")
self.azure_api_key = os.environ.get("AZURE_API_KEY")

@override
def get_model(
Expand All @@ -26,6 +27,8 @@ def get_model(
api_base: str | None = None, # noqa: ARG002
**kwargs: Any
) -> Any:
if self.system.settings['azure_api_key'] != None:
model_family = 'azure'
if database_connection.llm_api_key is not None:
fernet_encrypt = FernetEncrypt()
api_key = fernet_encrypt.decrypt(database_connection.llm_api_key)
Expand All @@ -35,6 +38,8 @@ def get_model(
self.anthropic_api_key = api_key
elif model_family == "google":
self.google_api_key = api_key
elif model_family == "azure":
self.azure_api_key == api_key
if self.openai_api_key:
self.model = OpenAI(model_name=model_name, **kwargs)
elif self.aleph_alpha_api_key:
Expand All @@ -43,6 +48,8 @@ def get_model(
self.model = Anthropic(model=model_name, **kwargs)
elif self.cohere_api_key:
self.model = Cohere(model=model_name, **kwargs)
elif self.azure_api_key:
self.model = AzureOpenAI(model=model_name, **kwargs)
else:
raise ValueError("No valid API key environment variable found")
return self.model
14 changes: 13 additions & 1 deletion dataherald/model/chat_model.py
Original file line number Diff line number Diff line change
@@ -1,7 +1,7 @@
from typing import Any

from langchain_community.chat_models import ChatAnthropic, ChatCohere, ChatGooglePalm
from langchain_openai import ChatOpenAI
from langchain_openai import ChatOpenAI, AzureChatOpenAI
from overrides import override

from dataherald.model import LLMModel
Expand All @@ -22,6 +22,18 @@ def get_model(
**kwargs: Any
) -> Any:
api_key = database_connection.decrypt_api_key()
if self.system.settings['azure_api_key'] != None:
model_family = 'azure'
if model_family == "azure":
if api_base.endswith("/"): #TODO check where final "/" is added to api_base
api_base = api_base[:-1]
return AzureChatOpenAI(
deployment_name=model_name,
openai_api_key=api_key,
azure_endpoint= api_base,
api_version=self.system.settings['azure_api_version'],
**kwargs
)
if model_family == "openai":
return ChatOpenAI(
model_name=model_name,
Expand Down
4 changes: 2 additions & 2 deletions dataherald/sql_generator/dataherald_finetuning_agent.py
Original file line number Diff line number Diff line change
Expand Up @@ -605,7 +605,7 @@ def generate_response(
use_finetuned_model_only=self.use_fintuned_model_only,
model_name=finetuning.base_llm.model_name,
openai_fine_tuning=openai_fine_tuning,
embedding=OpenAIEmbeddings(
embedding=OpenAIEmbeddings( #TODO AzureOpenAIEmbeddings when Azure
openai_api_key=database_connection.decrypt_api_key(),
model=EMBEDDING_MODEL,
),
Expand Down Expand Up @@ -710,7 +710,7 @@ def stream_response(
use_finetuned_model_only=self.use_fintuned_model_only,
model_name=finetuning.base_llm.model_name,
openai_fine_tuning=openai_fine_tuning,
embedding=OpenAIEmbeddings(
embedding=OpenAIEmbeddings( #TODO AzureOpenAIEmbeddings when Azure
openai_api_key=database_connection.decrypt_api_key(),
model=EMBEDDING_MODEL,
),
Expand Down
82 changes: 56 additions & 26 deletions dataherald/sql_generator/dataherald_sqlagent.py
Original file line number Diff line number Diff line change
Expand Up @@ -22,7 +22,7 @@
from langchain.chains.llm import LLMChain
from langchain.tools.base import BaseTool
from langchain_community.callbacks import get_openai_callback
from langchain_openai import OpenAIEmbeddings
from langchain_openai import OpenAIEmbeddings, AzureOpenAIEmbeddings
from overrides import override
from pydantic import BaseModel, Field
from sql_metadata import Parser
Expand Down Expand Up @@ -753,18 +753,33 @@ def generate_response(
number_of_samples = 0
logger.info(f"Generating SQL response to question: {str(user_prompt.dict())}")
self.database = SQLDatabase.get_sql_engine(database_connection)
toolkit = SQLDatabaseToolkit(
db=self.database,
context=context,
few_shot_examples=new_fewshot_examples,
instructions=instructions,
is_multiple_schema=True if user_prompt.schemas else False,
db_scan=db_scan,
embedding=OpenAIEmbeddings(
openai_api_key=database_connection.decrypt_api_key(),
model=EMBEDDING_MODEL,
),
)
#Set Embeddings class depending on azure / not azure
if self.llm.openai_api_type == "azure":
toolkit = SQLDatabaseToolkit(
db=self.database,
context=context,
few_shot_examples=new_fewshot_examples,
instructions=instructions,
is_multiple_schema=True if user_prompt.schemas else False,
db_scan=db_scan,
embedding=AzureOpenAIEmbeddings(
openai_api_key=database_connection.decrypt_api_key(),
model=EMBEDDING_MODEL,
),
)
else:
toolkit = SQLDatabaseToolkit(
db=self.database,
context=context,
few_shot_examples=new_fewshot_examples,
instructions=instructions,
is_multiple_schema=True if user_prompt.schemas else False,
db_scan=db_scan,
embedding=OpenAIEmbeddings(
openai_api_key=database_connection.decrypt_api_key(),
model=EMBEDDING_MODEL,
),
)
agent_executor = self.create_sql_agent(
toolkit=toolkit,
verbose=True,
Expand Down Expand Up @@ -858,19 +873,34 @@ def stream_response(
new_fewshot_examples = None
number_of_samples = 0
self.database = SQLDatabase.get_sql_engine(database_connection)
toolkit = SQLDatabaseToolkit(
queuer=queue,
db=self.database,
context=[{}],
few_shot_examples=new_fewshot_examples,
instructions=instructions,
is_multiple_schema=True if user_prompt.schemas else False,
db_scan=db_scan,
embedding=OpenAIEmbeddings(
openai_api_key=database_connection.decrypt_api_key(),
model=EMBEDDING_MODEL,
),
)
#Set Embeddings class depending on azure / not azure
if self.llm.openai_api_type == "azure":
toolkit = SQLDatabaseToolkit(
db=self.database,
context=context,
few_shot_examples=new_fewshot_examples,
instructions=instructions,
is_multiple_schema=True if user_prompt.schemas else False,
db_scan=db_scan,
embedding=AzureOpenAIEmbeddings(
openai_api_key=database_connection.decrypt_api_key(),
model=EMBEDDING_MODEL,
),
)
else:
toolkit = SQLDatabaseToolkit(
queuer=queue,
db=self.database,
context=[{}],
few_shot_examples=new_fewshot_examples,
instructions=instructions,
is_multiple_schema=True if user_prompt.schemas else False,
db_scan=db_scan,
embedding=OpenAIEmbeddings(
openai_api_key=database_connection.decrypt_api_key(),
model=EMBEDDING_MODEL,
),
)
agent_executor = self.create_sql_agent(
toolkit=toolkit,
verbose=True,
Expand Down
2 changes: 1 addition & 1 deletion docker-compose.yml
Original file line number Diff line number Diff line change
Expand Up @@ -24,7 +24,7 @@ services:
ports:
- 27017:27017
volumes:
- ./initdb.d/:/docker-entrypoint-initdb.d/
- ./initdb.d/init-mongo.sh:/docker-entrypoint-initdb.d/init-mongo.sh:ro
- ./dbdata/mongo_data/data:/data/db/
- ./dbdata/mongo_data/db_config:/data/configdb/
environment:
Expand Down
2 changes: 1 addition & 1 deletion initdb.d/init-mongo.sh
Original file line number Diff line number Diff line change
@@ -1,4 +1,4 @@
set -e
# set -e

mongosh <<EOF
use $MONGO_INITDB_DATABASE
Expand Down

0 comments on commit 581b6ff

Please sign in to comment.