Skip to content

Commit

Permalink
DH-5776/fixing the bug with Azure OpenAI
Browse files Browse the repository at this point in the history
  • Loading branch information
MohammadrezaPourreza committed May 10, 2024
1 parent 2537ca5 commit 0a3f3c5
Show file tree
Hide file tree
Showing 2 changed files with 4 additions and 4 deletions.
4 changes: 2 additions & 2 deletions dataherald/sql_generator/dataherald_finetuning_agent.py
Original file line number Diff line number Diff line change
Expand Up @@ -595,7 +595,7 @@ def generate_response(
f"Finetuning should have the status {FineTuningStatus.SUCCEEDED.value} to generate SQL queries."
)
self.database = SQLDatabase.get_sql_engine(database_connection)
if self.llm.openai_api_type == "azure":
if self.system.settings["azure_api_key"] is not None:
embedding = AzureOpenAIEmbeddings(
openai_api_key=database_connection.decrypt_api_key(),
model=EMBEDDING_MODEL,
Expand Down Expand Up @@ -708,7 +708,7 @@ def stream_response(
f"Finetuning should have the status {FineTuningStatus.SUCCEEDED.value} to generate SQL queries."
)
self.database = SQLDatabase.get_sql_engine(database_connection)
if self.llm.openai_api_type == "azure":
if self.system.settings["azure_api_key"] is not None:
embedding = AzureOpenAIEmbeddings(
openai_api_key=database_connection.decrypt_api_key(),
model=EMBEDDING_MODEL,
Expand Down
4 changes: 2 additions & 2 deletions dataherald/sql_generator/dataherald_sqlagent.py
Original file line number Diff line number Diff line change
Expand Up @@ -754,7 +754,7 @@ def generate_response(
logger.info(f"Generating SQL response to question: {str(user_prompt.dict())}")
self.database = SQLDatabase.get_sql_engine(database_connection)
# Set Embeddings class depending on azure / not azure
if self.llm.openai_api_type == "azure":
if self.system.settings["azure_api_key"] is not None:
toolkit = SQLDatabaseToolkit(
db=self.database,
context=context,
Expand Down Expand Up @@ -874,7 +874,7 @@ def stream_response(
number_of_samples = 0
self.database = SQLDatabase.get_sql_engine(database_connection)
# Set Embeddings class depending on azure / not azure
if self.llm.openai_api_type == "azure":
if self.system.settings["azure_api_key"] is not None:
embedding = AzureOpenAIEmbeddings(
openai_api_key=database_connection.decrypt_api_key(),
model=EMBEDDING_MODEL,
Expand Down

0 comments on commit 0a3f3c5

Please sign in to comment.