diff --git a/services/enterprise/.env.example b/services/enterprise/.env.example index 2564c286..9d9e5ff7 100644 --- a/services/enterprise/.env.example +++ b/services/enterprise/.env.example @@ -22,9 +22,8 @@ DEFAULT_ENGINE_TIMEOUT=120 S3_AWS_ACCESS_KEY_ID= S3_AWS_SECRET_ACCESS_KEY= - -# Optional posthog analytics if disabled POSTHOG_DISABLED=True +# Optional posthog analytics if disabled POSTHOG_API_KEY= POSTHOG_HOST= @@ -34,7 +33,8 @@ POSTHOG_HOST= SSH_PRIVATE_KEY_PASSWORD= SSH_PATH_TO_CREDENTIAL_FILE= -# Optional stripe env vars if organizations are only on ENTERPRISE plan (organization.invoice_details.plan = "ENTERPRISE") +STRIPE_DISABLED=True +# Optional stripe env vars if STRIPE_DISABLED is set to False # Otherwise you would need to create a new stripe account and fillout the env vars below STRIPE_API_KEY= STRIPE_WEBHOOK_SECRET= diff --git a/services/enterprise/config.py b/services/enterprise/config.py index 85133ae7..4493d8d2 100644 --- a/services/enterprise/config.py +++ b/services/enterprise/config.py @@ -102,6 +102,7 @@ def __getitem__(self, key: str) -> Any: class InvoiceSettings(BaseSettings): load_dotenv() + stripe_disabled: bool = os.environ.get("STRIPE_DISABLED", False) stripe_api_key: str = os.environ.get("STRIPE_API_KEY", None) stripe_webhook_secret: str = os.environ.get("STRIPE_WEBHOOK_SECRET", None) diff --git a/services/enterprise/modules/organization/invoice/controller.py b/services/enterprise/modules/organization/invoice/controller.py index 8d8e60cf..3cf89285 100644 --- a/services/enterprise/modules/organization/invoice/controller.py +++ b/services/enterprise/modules/organization/invoice/controller.py @@ -1,6 +1,8 @@ -from fastapi import APIRouter, Security, status +from fastapi import APIRouter, Depends, Security, status from starlette.requests import Request +from config import invoice_settings +from modules.organization.invoice.models.exceptions import StripeDisabledError from modules.organization.invoice.models.requests import ( CreditRequest, PaymentMethodRequest, @@ -16,9 +18,17 @@ from modules.organization.invoice.webhook import InvoiceWebhook from utils.auth import Authorize, User, authenticate_user + +def check_stripe_disabled(request: Request): + if invoice_settings.stripe_disabled: + raise StripeDisabledError() + return request + + router = APIRouter( prefix="/organizations", responses={404: {"description": "Not found"}}, + dependencies=[Depends(check_stripe_disabled)], ) authorize = Authorize() diff --git a/services/enterprise/modules/organization/invoice/models/exceptions.py b/services/enterprise/modules/organization/invoice/models/exceptions.py index c3559908..c84cd144 100644 --- a/services/enterprise/modules/organization/invoice/models/exceptions.py +++ b/services/enterprise/modules/organization/invoice/models/exceptions.py @@ -9,6 +9,9 @@ class InvoiceErrorCode(BaseErrorCode): + stripe_disabled = ErrorCodeData( + status_code=HTTP_400_BAD_REQUEST, message="Stripe is disabled" + ) no_payment_method = ErrorCodeData( status_code=HTTP_402_PAYMENT_REQUIRED, message="No payment method on file", @@ -59,6 +62,11 @@ class InvoiceError(BaseError): ERROR_CODES: BaseErrorCode = InvoiceErrorCode +class StripeDisabledError(InvoiceError): + def __init__(self) -> None: + super().__init__(error_code=InvoiceErrorCode.stripe_disabled.name) + + class NoPaymentMethodError(InvoiceError): def __init__(self, organization_id: str) -> None: super().__init__( diff --git a/services/enterprise/modules/organization/invoice/service.py b/services/enterprise/modules/organization/invoice/service.py index 8b8bb6b1..9c358a5c 100644 --- a/services/enterprise/modules/organization/invoice/service.py +++ b/services/enterprise/modules/organization/invoice/service.py @@ -244,6 +244,8 @@ def record_usage( quantity: int = 0, description: str = None, ): + if invoice_settings.stripe_disabled: + return organization = self.org_repo.get_organization(org_id) if organization.invoice_details.plan == PaymentPlan.ENTERPRISE: return @@ -283,7 +285,8 @@ def check_usage( type: UsageType, quantity: int = 0, ): - # check if organization has payment method + if invoice_settings.stripe_disabled: + return organization = self.org_repo.get_organization(org_id) if not organization.invoice_details: raise MissingInvoiceDetailsError(org_id) diff --git a/services/enterprise/pyproject.toml b/services/enterprise/pyproject.toml index efc53cfb..89213eba 100644 --- a/services/enterprise/pyproject.toml +++ b/services/enterprise/pyproject.toml @@ -21,7 +21,7 @@ select = [ "UP", "W", ] -ignore = ["A001", "A002", "A003", "B008", "UP006", "UP035", "PLR0913", "N805"] +ignore = ["A001", "A002", "A003", "B008", "UP006", "UP035", "PLR0913", "N805", "C901"] target-version = "py310" [tool.ruff.per-file-ignores]