Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

DH-5776/fixing the bug with Azure OpenAI #481

Merged
merged 4 commits into from
May 10, 2024
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
4 changes: 2 additions & 2 deletions dataherald/services/sql_generations.py
Original file line number Diff line number Diff line change
Expand Up @@ -63,9 +63,9 @@ def update_the_initial_sql_generation(
initial_sql_generation.intermediate_steps = sql_generation.intermediate_steps
return self.sql_generation_repository.update(initial_sql_generation)

def create(
def create( # noqa: PLR0912
self, prompt_id: str, sql_generation_request: SQLGenerationRequest
) -> SQLGeneration:
) -> SQLGeneration: # noqa: PLR0912
initial_sql_generation = SQLGeneration(
prompt_id=prompt_id,
created_at=datetime.now(),
Expand Down
4 changes: 2 additions & 2 deletions dataherald/sql_generator/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -179,15 +179,15 @@ def generate_response(
"""Generates a response to a user question."""
pass

def stream_agent_steps( # noqa: C901
def stream_agent_steps( # noqa: PLR0912, C901
self,
question: str,
agent_executor: AgentExecutor,
response: SQLGeneration,
sql_generation_repository: SQLGenerationRepository,
queue: Queue,
metadata: dict = None,
):
): # noqa: PLR0912
try:
with get_openai_callback() as cb:
for chunk in agent_executor.stream(
Expand Down
4 changes: 2 additions & 2 deletions dataherald/sql_generator/dataherald_finetuning_agent.py
Original file line number Diff line number Diff line change
Expand Up @@ -595,7 +595,7 @@ def generate_response(
f"Finetuning should have the status {FineTuningStatus.SUCCEEDED.value} to generate SQL queries."
)
self.database = SQLDatabase.get_sql_engine(database_connection)
if self.llm.openai_api_type == "azure":
if self.system.settings["azure_api_key"] is not None:
embedding = AzureOpenAIEmbeddings(
openai_api_key=database_connection.decrypt_api_key(),
model=EMBEDDING_MODEL,
Expand Down Expand Up @@ -708,7 +708,7 @@ def stream_response(
f"Finetuning should have the status {FineTuningStatus.SUCCEEDED.value} to generate SQL queries."
)
self.database = SQLDatabase.get_sql_engine(database_connection)
if self.llm.openai_api_type == "azure":
if self.system.settings["azure_api_key"] is not None:
embedding = AzureOpenAIEmbeddings(
openai_api_key=database_connection.decrypt_api_key(),
model=EMBEDDING_MODEL,
Expand Down
8 changes: 4 additions & 4 deletions dataherald/sql_generator/dataherald_sqlagent.py
Original file line number Diff line number Diff line change
Expand Up @@ -710,13 +710,13 @@ def create_sql_agent(
)

@override
def generate_response(
def generate_response( # noqa: PLR0912
self,
user_prompt: Prompt,
database_connection: DatabaseConnection,
context: List[dict] = None,
metadata: dict = None,
) -> SQLGeneration:
) -> SQLGeneration: # noqa: PLR0912
context_store = self.system.instance(ContextStore)
storage = self.system.instance(DB)
response = SQLGeneration(
Expand Down Expand Up @@ -754,7 +754,7 @@ def generate_response(
logger.info(f"Generating SQL response to question: {str(user_prompt.dict())}")
self.database = SQLDatabase.get_sql_engine(database_connection)
# Set Embeddings class depending on azure / not azure
if self.llm.openai_api_type == "azure":
if self.system.settings["azure_api_key"] is not None:
toolkit = SQLDatabaseToolkit(
db=self.database,
context=context,
Expand Down Expand Up @@ -874,7 +874,7 @@ def stream_response(
number_of_samples = 0
self.database = SQLDatabase.get_sql_engine(database_connection)
# Set Embeddings class depending on azure / not azure
if self.llm.openai_api_type == "azure":
if self.system.settings["azure_api_key"] is not None:
embedding = AzureOpenAIEmbeddings(
openai_api_key=database_connection.decrypt_api_key(),
model=EMBEDDING_MODEL,
Expand Down
Loading