Skip to content

Commit

Permalink
feat: stripe self checkout (#325)
Browse files Browse the repository at this point in the history
* feat: extend org model with stripe fields

* feat: add stripe dependency

* feat: add queries and mutation for stripe self-checkout

* feat: create customer on stripe if running in cloud mode

* feat: add misc utils for orgs and billing

* feat: add webhook handlers for stripe

* feat: load stripe env vars in settings

* feat: add stripe js

* feat: add frontend queries, mutations, schema

* feat: add self-checkout flow to frontend

* feat: add stripe public key to variable replacement script

* feat: update csp for stripe

* fix: csp

* fix: add placeholder for frontend stripe public key

* fix: store free tier stripe subscription id when creating customer

* chore: conditionally import ee utils

* feat: add size variants to generic dialog

* chore: remove unused imports and code, resize upgrade dialog

* feat: restyle checkout preview screen

* fix: tweaks to post-checkout screen

* feat: enable proration

* fix: settings page height

* feat: restyle plan label

* feat: add plan label to org selection menu + misc restyle

* feat: add rework upsell ui

* fix: tweak copy for self-hosted instances

* fix: layout height

* fix: pat page vertical overflow

* fix: tweak org menu button style

---------

Co-authored-by: Nimish <[email protected]>
  • Loading branch information
rohan-chaturvedi and nimish-ks authored Aug 5, 2024
1 parent cb7d381 commit abab513
Show file tree
Hide file tree
Showing 36 changed files with 984 additions and 171 deletions.
Original file line number Diff line number Diff line change
@@ -0,0 +1,33 @@
# Generated by Django 4.2.7 on 2024-07-30 10:08

from django.db import migrations, models


class Migration(migrations.Migration):

dependencies = [
('api', '0074_correct_set_index_values'),
]

operations = [
migrations.AddField(
model_name='organisation',
name='stripe_customer_id',
field=models.CharField(blank=True, max_length=255, null=True),
),
migrations.AddField(
model_name='organisation',
name='stripe_subscription_id',
field=models.CharField(blank=True, max_length=255, null=True),
),
migrations.AlterField(
model_name='activatedphaselicense',
name='seats',
field=models.IntegerField(null=True),
),
migrations.AlterField(
model_name='activatedphaselicense',
name='tokens',
field=models.IntegerField(null=True),
),
]
6 changes: 4 additions & 2 deletions backend/api/models.py
Original file line number Diff line number Diff line change
Expand Up @@ -92,6 +92,8 @@ class Organisation(models.Model):
choices=PLAN_TIERS,
default=FREE_PLAN,
)
stripe_customer_id = models.CharField(max_length=255, blank=True, null=True)
stripe_subscription_id = models.CharField(max_length=255, blank=True, null=True)
list_display = ("name", "identity_key", "id")

def __str__(self):
Expand All @@ -107,8 +109,8 @@ class ActivatedPhaseLicense(models.Model):
choices=Organisation.PLAN_TIERS,
default=Organisation.ENTERPRISE_PLAN,
)
seats = models.IntegerField()
tokens = models.IntegerField()
seats = models.IntegerField(null=True)
tokens = models.IntegerField(null=True)
metadata = models.JSONField()
environment = models.CharField(max_length=255)
license_type = models.CharField(max_length=255)
Expand Down
15 changes: 15 additions & 0 deletions backend/api/utils/organisations.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,15 @@
from api.models import OrganisationMember, OrganisationMemberInvite
from django.utils import timezone


def get_organisation_seats(organisation):
seats = (
OrganisationMember.objects.filter(
organisation=organisation, deleted_at=None
).count()
+ OrganisationMemberInvite.objects.filter(
organisation=organisation, valid=True, expires_at__gte=timezone.now()
).count()
)

return seats
15 changes: 15 additions & 0 deletions backend/backend/graphene/mutations/organisation.py
Original file line number Diff line number Diff line change
Expand Up @@ -48,6 +48,11 @@ def mutate(
wrapped_recovery=wrapped_recovery,
)

if settings.APP_HOST == "cloud":
from ee.billing.stripe import create_stripe_customer

create_stripe_customer(org, owner.email)

if settings.PHASE_LICENSE:
from ee.license.utils import activate_license

Expand Down Expand Up @@ -202,6 +207,11 @@ def mutate(
invite.valid = False
invite.save()

if settings.APP_HOST == "cloud":
from ee.billing.stripe import update_stripe_subscription_seats

update_stripe_subscription_seats(org)

try:
send_user_joined_email(invite, org_member)
except Exception as e:
Expand All @@ -228,6 +238,11 @@ def mutate(cls, root, info, member_id):
if user_is_admin(info.context.user.userId, org_member.organisation.id):
org_member.delete()

if settings.APP_HOST == "cloud":
from ee.billing.stripe import update_stripe_subscription_seats

update_stripe_subscription_seats(org_member.organisation)

return DeleteOrganisationMemberMutation(ok=True)
else:
raise GraphQLError("You don't have permission to perform that action")
Expand Down
14 changes: 14 additions & 0 deletions backend/backend/schema.py
Original file line number Diff line number Diff line change
Expand Up @@ -4,6 +4,11 @@
from api.utils.syncing.vault.main import VaultMountType
from api.utils.syncing.gitlab.main import GitLabGroupType, GitLabProjectType
from api.utils.syncing.railway.main import RailwayEnvironmentType, RailwayProjectType
from ee.billing.graphene.queries.stripe import (
StripeCheckoutDetails,
resolve_stripe_checkout_details,
)
from ee.billing.graphene.mutations.stripe import CreateProUpgradeCheckoutSession
from .graphene.mutations.lockbox import CreateLockboxMutation
from .graphene.queries.syncing import (
resolve_aws_secret_manager_secrets,
Expand Down Expand Up @@ -259,6 +264,10 @@ class Query(graphene.ObjectType):

test_nomad_creds = graphene.Field(graphene.Boolean, credential_id=graphene.ID())

stripe_checkout_details = graphene.Field(
StripeCheckoutDetails, stripe_session_id=graphene.String(required=True)
)

# --------------------------------------------------------------------

resolve_server_public_key = resolve_server_public_key
Expand Down Expand Up @@ -673,6 +682,8 @@ def resolve_app_activity_chart(root, info, app_id, period=TimeRange.DAY):

return time_series_logs

resolve_stripe_checkout_details = resolve_stripe_checkout_details


class Mutation(graphene.ObjectType):
create_organisation = CreateOrganisationMutation.Field()
Expand Down Expand Up @@ -751,5 +762,8 @@ class Mutation(graphene.ObjectType):
# Lockbox
create_lockbox = CreateLockboxMutation.Field()

# Billing
create_pro_upgrade_checkout_session = CreateProUpgradeCheckoutSession.Field()


schema = graphene.Schema(query=Query, mutation=Mutation)
14 changes: 14 additions & 0 deletions backend/backend/settings.py
Original file line number Diff line number Diff line change
Expand Up @@ -288,3 +288,17 @@
}

PHASE_LICENSE = check_license(os.getenv("PHASE_LICENSE_OFFLINE"))


STRIPE = {}
try:
STRIPE["secret_key"] = os.getenv("STRIPE_SECRET_KEY")
STRIPE["public_key"] = os.getenv("STRIPE_PUBLIC_KEY")
STRIPE["webhook_secret"] = os.getenv("STRIPE_WEBHOOK_SECRET")
STRIPE["prices"] = {
"free": os.getenv("STRIPE_FREE"),
"pro_monthly": os.getenv("STRIPE_PRO_MONTHLY"),
"pro_yearly": os.getenv("STRIPE_PRO_YEARLY"),
}
except:
pass
7 changes: 6 additions & 1 deletion backend/backend/urls.py
Original file line number Diff line number Diff line change
Expand Up @@ -8,6 +8,7 @@
from api.views.auth import logout_view, health_check, github_callback, secrets_tokens
from api.views.kms import kms


CLOUD_HOSTED = settings.APP_HOST == "cloud"

urlpatterns = [
Expand All @@ -24,8 +25,12 @@
path("lockbox/<box_id>", LockboxView.as_view()),
]

if not CLOUD_HOSTED:
if CLOUD_HOSTED:
from ee.billing.webhooks.stripe import stripe_webhook

urlpatterns.append(path("kms/<app_id>", kms))
urlpatterns.append(path("stripe/webhook/", stripe_webhook, name="stripe-webhook"))


try:
if settings.ADMIN_ENABLED:
Expand Down
55 changes: 55 additions & 0 deletions backend/ee/billing/graphene/mutations/stripe.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,55 @@
from api.models import Organisation
from api.utils.organisations import get_organisation_seats
import stripe
from django.conf import settings
from graphene import Mutation, ID, String
from graphql import GraphQLError


class CreateProUpgradeCheckoutSession(Mutation):
class Arguments:
organisation_id = ID(required=True)
billing_period = String()

client_secret = String()

def mutate(self, info, organisation_id, billing_period):

try:
stripe.api_key = settings.STRIPE["secret_key"]

organisation = Organisation.objects.get(id=organisation_id)
seats = get_organisation_seats(organisation)

# Ensure the organisation has a Stripe customer ID
if not organisation.stripe_customer_id:
raise GraphQLError("Organisation must have a Stripe customer ID.")

price = (
settings.STRIPE["prices"]["pro_monthly"]
if billing_period == "monthly"
else settings.STRIPE["prices"]["pro_yearly"]
)

# Create the checkout session
session = stripe.checkout.Session.create(
mode="subscription",
ui_mode="embedded",
line_items=[
{
"price": price,
"quantity": seats,
},
],
customer=organisation.stripe_customer_id,
payment_method_types=["card"],
return_url=f"{settings.OAUTH_REDIRECT_URI}/{organisation.name}/settings?stripe_session_id={{CHECKOUT_SESSION_ID}}",
)
return CreateProUpgradeCheckoutSession(client_secret=session.client_secret)

except Organisation.DoesNotExist:
raise GraphQLError("Organisation not found.")
except Exception as e:
raise GraphQLError(
f"Something went wrong during checkout. Please try again."
)
42 changes: 42 additions & 0 deletions backend/ee/billing/graphene/queries/stripe.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,42 @@
import graphene
from graphene import ObjectType, String, Field
import stripe
from django.conf import settings


class StripeCheckoutDetails(graphene.ObjectType):
payment_status = graphene.String()
customer_email = graphene.String()
billing_start_date = graphene.String()
billing_end_date = graphene.String()
subscription_id = graphene.String()
plan_name = graphene.String()


def resolve_stripe_checkout_details(self, info, stripe_session_id):
stripe.api_key = settings.STRIPE["secret_key"]

try:
session = stripe.checkout.Session.retrieve(stripe_session_id)

subscription_id = session.get("subscription")
if subscription_id:
subscription = stripe.Subscription.retrieve(subscription_id)
plan_name = subscription["items"]["data"][0]["plan"]["nickname"]
billing_start_date = subscription["current_period_start"]
billing_end_date = subscription["current_period_end"]
else:
plan_name = None
billing_start_date = None
billing_end_date = None

return StripeCheckoutDetails(
payment_status=session.payment_status,
customer_email=session.customer_details.email,
billing_start_date=str(billing_start_date),
billing_end_date=str(billing_end_date),
subscription_id=subscription_id,
plan_name=plan_name,
)
except stripe.error.StripeError as e:
return None
77 changes: 77 additions & 0 deletions backend/ee/billing/stripe.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,77 @@
from api.models import Organisation
from backend.api.notifier import notify_slack
from api.utils.organisations import get_organisation_seats
import stripe
from django.conf import settings


def create_stripe_customer(organisation, email):
stripe.api_key = settings.STRIPE["secret_key"]

stripe_customer = stripe.Customer.create(
name=organisation.name,
email=email,
)
organisation.stripe_customer_id = stripe_customer.id
subscription = stripe.Subscription.create(
customer=stripe_customer.id,
items=[
{
"price": settings.STRIPE["prices"]["free"],
}
],
)
organisation.stripe_subscription_id = subscription.id
organisation.save()


def update_stripe_subscription_seats(organisation):
stripe.api_key = settings.STRIPE["secret_key"]

if not organisation.stripe_subscription_id:
raise ValueError("Organisation must have a Stripe subscription ID.")

try:
new_seat_count = get_organisation_seats(organisation)

# Retrieve the subscription
subscription = stripe.Subscription.retrieve(organisation.stripe_subscription_id)

if not subscription["items"]["data"]:
raise ValueError("No items found in the subscription.")

# Assume we're updating the first item in the subscription
item_id = subscription["items"]["data"][0]["id"]

# Modify the subscription with the new seat count
updated_subscription = stripe.Subscription.modify(
organisation.stripe_subscription_id,
items=[
{
"id": item_id,
"quantity": new_seat_count,
}
],
proration_behavior='always_invoice'
)
return updated_subscription

except Exception as ex:
print("Failed to update Stripe seat count:", ex)
try:
notify_slack(
f"Failed to update Stripe seat count for organisation {organisation.id}: {ex}"
)
except:
pass
pass


def map_stripe_plan_to_tier(stripe_plan_id):
if (
stripe_plan_id == settings.STRIPE["prices"]["pro_monthly"]
or stripe_plan_id == settings.STRIPE["prices"]["pro_yearly"]
):
return Organisation.PRO_PLAN
elif stripe_plan_id == settings.STRIPE["prices"]["free"]:
return Organisation.FREE_PLAN
Loading

0 comments on commit abab513

Please sign in to comment.