diff --git a/dataherald/services/sql_generations.py b/dataherald/services/sql_generations.py index b1890443..413101ca 100644 --- a/dataherald/services/sql_generations.py +++ b/dataherald/services/sql_generations.py @@ -63,9 +63,9 @@ def update_the_initial_sql_generation( initial_sql_generation.intermediate_steps = sql_generation.intermediate_steps return self.sql_generation_repository.update(initial_sql_generation) - def create( + def create( # noqa: PLR0912 self, prompt_id: str, sql_generation_request: SQLGenerationRequest - ) -> SQLGeneration: + ) -> SQLGeneration: # noqa: PLR0912 initial_sql_generation = SQLGeneration( prompt_id=prompt_id, created_at=datetime.now(), diff --git a/dataherald/sql_generator/__init__.py b/dataherald/sql_generator/__init__.py index e997920e..6612332b 100644 --- a/dataherald/sql_generator/__init__.py +++ b/dataherald/sql_generator/__init__.py @@ -179,7 +179,7 @@ def generate_response( """Generates a response to a user question.""" pass - def stream_agent_steps( # noqa: C901 + def stream_agent_steps( # noqa: PLR0912, C901 self, question: str, agent_executor: AgentExecutor, @@ -187,7 +187,7 @@ def stream_agent_steps( # noqa: C901 sql_generation_repository: SQLGenerationRepository, queue: Queue, metadata: dict = None, - ): + ): # noqa: PLR0912 try: with get_openai_callback() as cb: for chunk in agent_executor.stream( diff --git a/dataherald/sql_generator/dataherald_finetuning_agent.py b/dataherald/sql_generator/dataherald_finetuning_agent.py index 536a4e17..fe54dcf4 100644 --- a/dataherald/sql_generator/dataherald_finetuning_agent.py +++ b/dataherald/sql_generator/dataherald_finetuning_agent.py @@ -595,7 +595,7 @@ def generate_response( f"Finetuning should have the status {FineTuningStatus.SUCCEEDED.value} to generate SQL queries." ) self.database = SQLDatabase.get_sql_engine(database_connection) - if self.llm.openai_api_type == "azure": + if self.system.settings["azure_api_key"] is not None: embedding = AzureOpenAIEmbeddings( openai_api_key=database_connection.decrypt_api_key(), model=EMBEDDING_MODEL, @@ -708,7 +708,7 @@ def stream_response( f"Finetuning should have the status {FineTuningStatus.SUCCEEDED.value} to generate SQL queries." ) self.database = SQLDatabase.get_sql_engine(database_connection) - if self.llm.openai_api_type == "azure": + if self.system.settings["azure_api_key"] is not None: embedding = AzureOpenAIEmbeddings( openai_api_key=database_connection.decrypt_api_key(), model=EMBEDDING_MODEL, diff --git a/dataherald/sql_generator/dataherald_sqlagent.py b/dataherald/sql_generator/dataherald_sqlagent.py index 0c898f34..414ab089 100644 --- a/dataherald/sql_generator/dataherald_sqlagent.py +++ b/dataherald/sql_generator/dataherald_sqlagent.py @@ -710,13 +710,13 @@ def create_sql_agent( ) @override - def generate_response( + def generate_response( # noqa: PLR0912 self, user_prompt: Prompt, database_connection: DatabaseConnection, context: List[dict] = None, metadata: dict = None, - ) -> SQLGeneration: + ) -> SQLGeneration: # noqa: PLR0912 context_store = self.system.instance(ContextStore) storage = self.system.instance(DB) response = SQLGeneration( @@ -754,7 +754,7 @@ def generate_response( logger.info(f"Generating SQL response to question: {str(user_prompt.dict())}") self.database = SQLDatabase.get_sql_engine(database_connection) # Set Embeddings class depending on azure / not azure - if self.llm.openai_api_type == "azure": + if self.system.settings["azure_api_key"] is not None: toolkit = SQLDatabaseToolkit( db=self.database, context=context, @@ -874,7 +874,7 @@ def stream_response( number_of_samples = 0 self.database = SQLDatabase.get_sql_engine(database_connection) # Set Embeddings class depending on azure / not azure - if self.llm.openai_api_type == "azure": + if self.system.settings["azure_api_key"] is not None: embedding = AzureOpenAIEmbeddings( openai_api_key=database_connection.decrypt_api_key(), model=EMBEDDING_MODEL,