From 4ba0a2ff240c99bcdac281c43ff229b4b7a40735 Mon Sep 17 00:00:00 2001 From: mohammadrezapourreza Date: Fri, 17 May 2024 12:48:10 -0400 Subject: [PATCH] reformat with black --- services/engine/dataherald/config.py | 2 +- .../dataherald/services/sql_generations.py | 4 ++-- .../exceptions/exception_handlers.py | 1 - services/enterprise/exceptions/exceptions.py | 1 - .../modules/db_connection/controller.py | 1 - .../modules/db_connection/service.py | 10 ++++++---- .../modules/organization/invoice/service.py | 19 ++++++++++--------- 7 files changed, 19 insertions(+), 19 deletions(-) diff --git a/services/engine/dataherald/config.py b/services/engine/dataherald/config.py index 370947c1..9a43e334 100644 --- a/services/engine/dataherald/config.py +++ b/services/engine/dataherald/config.py @@ -45,7 +45,7 @@ class Settings(BaseSettings): encrypt_key: str = os.environ.get("ENCRYPT_KEY") s3_aws_access_key_id: str | None = os.environ.get("S3_AWS_ACCESS_KEY_ID") s3_aws_secret_access_key: str | None = os.environ.get("S3_AWS_SECRET_ACCESS_KEY") - #Needed for Azure OpenAI integration: + # Needed for Azure OpenAI integration: azure_api_key: str | None = os.environ.get("AZURE_API_KEY") embedding_model: str | None = os.environ.get("EMBEDDING_MODEL") azure_api_version: str | None = os.environ.get("AZURE_API_VERSION") diff --git a/services/engine/dataherald/services/sql_generations.py b/services/engine/dataherald/services/sql_generations.py index 2d9a8121..413101ca 100644 --- a/services/engine/dataherald/services/sql_generations.py +++ b/services/engine/dataherald/services/sql_generations.py @@ -63,9 +63,9 @@ def update_the_initial_sql_generation( initial_sql_generation.intermediate_steps = sql_generation.intermediate_steps return self.sql_generation_repository.update(initial_sql_generation) - def create( # noqa: PLR0912 + def create( # noqa: PLR0912 self, prompt_id: str, sql_generation_request: SQLGenerationRequest - ) -> SQLGeneration: # noqa: PLR0912 + ) -> SQLGeneration: # noqa: PLR0912 initial_sql_generation = SQLGeneration( prompt_id=prompt_id, created_at=datetime.now(), diff --git a/services/enterprise/exceptions/exception_handlers.py b/services/enterprise/exceptions/exception_handlers.py index 658e5b1f..cf599533 100644 --- a/services/enterprise/exceptions/exception_handlers.py +++ b/services/enterprise/exceptions/exception_handlers.py @@ -13,7 +13,6 @@ async def exception_handler(request: Request, exc: BaseError): # noqa: ARG001 - trace_id = exc.trace_id error_code = exc.error_code status_code = exc.status_code diff --git a/services/enterprise/exceptions/exceptions.py b/services/enterprise/exceptions/exceptions.py index 32480d67..a1ad997f 100644 --- a/services/enterprise/exceptions/exceptions.py +++ b/services/enterprise/exceptions/exceptions.py @@ -39,7 +39,6 @@ def __init__( description: str = None, detail: dict = None, ) -> None: - if type(self) is BaseError: raise TypeError("BaseError class may not be instantiated directly") diff --git a/services/enterprise/modules/db_connection/controller.py b/services/enterprise/modules/db_connection/controller.py index d104182f..750e4e40 100644 --- a/services/enterprise/modules/db_connection/controller.py +++ b/services/enterprise/modules/db_connection/controller.py @@ -95,7 +95,6 @@ async def ac_get_db_connection( id: ObjectIdString, user: User = Security(authenticate_user), ) -> DBConnectionResponse: - return db_connection_service.get_db_connection(id, user.organization_id) diff --git a/services/enterprise/modules/db_connection/service.py b/services/enterprise/modules/db_connection/service.py index 3c482675..4c924dc9 100644 --- a/services/enterprise/modules/db_connection/service.py +++ b/services/enterprise/modules/db_connection/service.py @@ -84,7 +84,9 @@ async def add_db_connection( ) if organization.llm_api_key: - db_connection_internal_request.llm_api_key = FernetEncrypt().decrypt(organization.llm_api_key) + db_connection_internal_request.llm_api_key = FernetEncrypt().decrypt( + organization.llm_api_key + ) if db_connection_request.use_ssh: db_connection_internal_request.ssh_settings.private_key_password = ( @@ -140,8 +142,9 @@ async def update_db_connection( ) if organization.llm_api_key: - db_connection_internal_request.llm_api_key = FernetEncrypt().decrypt(organization.llm_api_key) - + db_connection_internal_request.llm_api_key = FernetEncrypt().decrypt( + organization.llm_api_key + ) if db_connection_request.use_ssh: db_connection_internal_request.ssh_settings.private_key_password = ( @@ -163,7 +166,6 @@ async def update_db_connection( async def add_sample_db_connection( self, sample_request: SampleDBRequest, org_id: str ) -> DBConnectionResponse: - sample_db_dict = await self.sample_db.add_sample_db( sample_request.sample_db_id, org_id ) diff --git a/services/enterprise/modules/organization/invoice/service.py b/services/enterprise/modules/organization/invoice/service.py index 6cdf2b27..8b8bb6b1 100644 --- a/services/enterprise/modules/organization/invoice/service.py +++ b/services/enterprise/modules/organization/invoice/service.py @@ -84,12 +84,12 @@ def update_spending_limit( raise CannotUpdateSpendingLimitError(org_id) def get_pending_invoice(self, org_id: str) -> InvoiceResponse: - organization = self.org_repo.get_organization(org_id) - current_period_start, current_period_end = ( - self.billing.get_current_subscription_period_with_anchor( - organization.invoice_details.billing_cycle_anchor - ) + ( + current_period_start, + current_period_end, + ) = self.billing.get_current_subscription_period_with_anchor( + organization.invoice_details.billing_cycle_anchor ) upcoming_invoice = self.billing.get_upcoming_invoice( organization.invoice_details.stripe_customer_id @@ -304,10 +304,11 @@ def check_usage( ): raise SubscriptionCanceledError(org_id) raise UnknownSubscriptionStatusError(org_id) - start_date, end_date = ( - self.billing.get_current_subscription_period_with_anchor( - organization.invoice_details.billing_cycle_anchor - ) + ( + start_date, + end_date, + ) = self.billing.get_current_subscription_period_with_anchor( + organization.invoice_details.billing_cycle_anchor ) usages = self.repo.get_usages(org_id, start_date, end_date) usage = Usage(