Skip to content

Commit

Permalink
enterprise: UI improvements, better handling of expiry (#10828)
Browse files Browse the repository at this point in the history
* web/admin: show enterprise banner on the very top

Signed-off-by: Jens Langhammer <[email protected]>

* rework license

Signed-off-by: Jens Langhammer <[email protected]>

* fix a bunch of things

Signed-off-by: Jens Langhammer <[email protected]>

* add some more tests

Signed-off-by: Jens Langhammer <[email protected]>

* add more tests

Signed-off-by: Jens Langhammer <[email protected]>

* fix middleware

Signed-off-by: Jens Langhammer <[email protected]>

* better api

Signed-off-by: Jens Langhammer <[email protected]>

* format

Signed-off-by: Jens Langhammer <[email protected]>

* add tests for and fix read only mode

Signed-off-by: Jens Langhammer <[email protected]>

* field name consistency

Signed-off-by: Jens Langhammer <[email protected]>

---------

Signed-off-by: Jens Langhammer <[email protected]>
  • Loading branch information
BeryJu authored Aug 9, 2024
1 parent 3265b4a commit 4b5bb77
Show file tree
Hide file tree
Showing 20 changed files with 750 additions and 195 deletions.
2 changes: 1 addition & 1 deletion authentik/admin/api/system.py
Original file line number Diff line number Diff line change
Expand Up @@ -73,7 +73,7 @@ def get_runtime(self, request: Request) -> RuntimeDict:
"authentik_version": get_full_version(),
"environment": get_env(),
"openssl_fips_enabled": (
backend._fips_enabled if LicenseKey.get_total().is_valid() else None
backend._fips_enabled if LicenseKey.get_total().status().is_valid else None
),
"openssl_version": OPENSSL_VERSION,
"platform": platform.platform(),
Expand Down
2 changes: 1 addition & 1 deletion authentik/blueprints/v1/importer.py
Original file line number Diff line number Diff line change
Expand Up @@ -171,7 +171,7 @@ def __init__(self, blueprint: Blueprint, context: dict | None = None):
def default_context(self):
"""Default context"""
return {
"goauthentik.io/enterprise/licensed": LicenseKey.get_total().is_valid(),
"goauthentik.io/enterprise/licensed": LicenseKey.get_total().status().is_valid,
"goauthentik.io/rbac/models": rbac_models(),
}

Expand Down
6 changes: 3 additions & 3 deletions authentik/enterprise/api.py
Original file line number Diff line number Diff line change
Expand Up @@ -19,7 +19,7 @@
from authentik.core.api.utils import ModelSerializer, PassiveSerializer
from authentik.core.models import User, UserTypes
from authentik.enterprise.license import LicenseKey, LicenseSummarySerializer
from authentik.enterprise.models import License
from authentik.enterprise.models import License, LicenseUsageStatus
from authentik.rbac.decorators import permission_required
from authentik.tenants.utils import get_unique_identifier

Expand All @@ -30,7 +30,7 @@ class EnterpriseRequiredMixin:

def validate(self, attrs: dict) -> dict:
"""Check that a valid license exists"""
if not LicenseKey.cached_summary().has_license:
if LicenseKey.cached_summary().status != LicenseUsageStatus.UNLICENSED:
raise ValidationError(_("Enterprise is required to create/update this object."))
return super().validate(attrs)

Expand Down Expand Up @@ -128,7 +128,7 @@ def forecast(self, request: Request) -> Response:
forecast_for_months = 12
response = LicenseForecastSerializer(
data={
"internal_users": LicenseKey.get_default_user_count(),
"internal_users": LicenseKey.get_internal_user_count(),
"external_users": LicenseKey.get_external_user_count(),
"forecasted_internal_users": (internal_in_last_month * forecast_for_months),
"forecasted_external_users": (external_in_last_month * forecast_for_months),
Expand Down
2 changes: 1 addition & 1 deletion authentik/enterprise/apps.py
Original file line number Diff line number Diff line change
Expand Up @@ -25,4 +25,4 @@ def check_enabled(self):
"""Actual enterprise check, cached"""
from authentik.enterprise.license import LicenseKey

return LicenseKey.cached_summary().valid
return LicenseKey.cached_summary().status
133 changes: 78 additions & 55 deletions authentik/enterprise/license.py
Original file line number Diff line number Diff line change
Expand Up @@ -3,24 +3,36 @@
from base64 import b64decode
from binascii import Error
from dataclasses import asdict, dataclass, field
from datetime import datetime, timedelta
from datetime import UTC, datetime, timedelta
from enum import Enum
from functools import lru_cache
from time import mktime

from cryptography.exceptions import InvalidSignature
from cryptography.x509 import Certificate, load_der_x509_certificate, load_pem_x509_certificate
from dacite import from_dict
from dacite import DaciteError, from_dict
from django.core.cache import cache
from django.db.models.query import QuerySet
from django.utils.timezone import now
from jwt import PyJWTError, decode, get_unverified_header
from rest_framework.exceptions import ValidationError
from rest_framework.fields import BooleanField, DateTimeField, IntegerField
from rest_framework.fields import (
ChoiceField,
DateTimeField,
IntegerField,
)

from authentik.core.api.utils import PassiveSerializer
from authentik.core.models import User, UserTypes
from authentik.enterprise.models import License, LicenseUsage
from authentik.enterprise.models import (
THRESHOLD_READ_ONLY_WEEKS,
THRESHOLD_WARNING_ADMIN_WEEKS,
THRESHOLD_WARNING_EXPIRY_WEEKS,
THRESHOLD_WARNING_USER_WEEKS,
License,
LicenseUsage,
LicenseUsageStatus,
)
from authentik.tenants.utils import get_unique_identifier

CACHE_KEY_ENTERPRISE_LICENSE = "goauthentik.io/enterprise/license"
Expand All @@ -42,32 +54,26 @@ def get_license_aud() -> str:
class LicenseFlags(Enum):
"""License flags"""

TRIAL = "trial"


@dataclass
class LicenseSummary:
"""Internal representation of a license summary"""

internal_users: int
external_users: int
valid: bool
show_admin_warning: bool
show_user_warning: bool
read_only: bool
status: LicenseUsageStatus
latest_valid: datetime
has_license: bool


class LicenseSummarySerializer(PassiveSerializer):
"""Serializer for license status"""

internal_users = IntegerField(required=True)
external_users = IntegerField(required=True)
valid = BooleanField()
show_admin_warning = BooleanField()
show_user_warning = BooleanField()
read_only = BooleanField()
status = ChoiceField(choices=LicenseUsageStatus.choices)
latest_valid = DateTimeField()
has_license = BooleanField()


@dataclass
Expand All @@ -83,7 +89,7 @@ class LicenseKey:
flags: list[LicenseFlags] = field(default_factory=list)

@staticmethod
def validate(jwt: str) -> "LicenseKey":
def validate(jwt: str, check_expiry=True) -> "LicenseKey":
"""Validate the license from a given JWT"""
try:
headers = get_unverified_header(jwt)
Expand All @@ -107,6 +113,7 @@ def validate(jwt: str) -> "LicenseKey":
our_cert.public_key(),
algorithms=["ES512"],
audience=get_license_aud(),
options={"verify_exp": check_expiry},
),
)
except PyJWTError:
Expand All @@ -116,9 +123,8 @@ def validate(jwt: str) -> "LicenseKey":
@staticmethod
def get_total() -> "LicenseKey":
"""Get a summarized version of all (not expired) licenses"""
active_licenses = License.objects.filter(expiry__gte=now())
total = LicenseKey(get_license_aud(), 0, "Summarized license", 0, 0)
for lic in active_licenses:
for lic in License.objects.all():
total.internal_users += lic.internal_users
total.external_users += lic.external_users
exp_ts = int(mktime(lic.expiry.timetuple()))
Expand All @@ -135,7 +141,7 @@ def base_user_qs() -> QuerySet:
return User.objects.all().exclude_anonymous().exclude(is_active=False)

@staticmethod
def get_default_user_count():
def get_internal_user_count():
"""Get current default user count"""
return LicenseKey.base_user_qs().filter(type=UserTypes.INTERNAL).count()

Expand All @@ -144,59 +150,72 @@ def get_external_user_count():
"""Get current external user count"""
return LicenseKey.base_user_qs().filter(type=UserTypes.EXTERNAL).count()

def is_valid(self) -> bool:
"""Check if the given license body covers all users
Only checks the current count, no historical data is checked"""
default_users = self.get_default_user_count()
if default_users > self.internal_users:
return False
active_users = self.get_external_user_count()
if active_users > self.external_users:
return False
return True
def _last_valid_date(self):
last_valid_date = (
LicenseUsage.objects.order_by("-record_date")
.filter(status=LicenseUsageStatus.VALID)
.first()
)
if not last_valid_date:
return datetime.fromtimestamp(0, UTC)
return last_valid_date.record_date

def status(self) -> LicenseUsageStatus:
"""Check if the given license body covers all users, and is valid."""
last_valid = self._last_valid_date()
if self.exp == 0 and not License.objects.exists():
return LicenseUsageStatus.UNLICENSED
_now = now()
# Check limit-exceeded based status
internal_users = self.get_internal_user_count()
external_users = self.get_external_user_count()
if internal_users > self.internal_users or external_users > self.external_users:
if last_valid < _now - timedelta(weeks=THRESHOLD_READ_ONLY_WEEKS):
return LicenseUsageStatus.READ_ONLY
if last_valid < _now - timedelta(weeks=THRESHOLD_WARNING_USER_WEEKS):
return LicenseUsageStatus.LIMIT_EXCEEDED_USER
if last_valid < _now - timedelta(weeks=THRESHOLD_WARNING_ADMIN_WEEKS):
return LicenseUsageStatus.LIMIT_EXCEEDED_ADMIN
# Check expiry based status
if datetime.fromtimestamp(self.exp, UTC) < _now:
if datetime.fromtimestamp(self.exp, UTC) < _now - timedelta(
weeks=THRESHOLD_READ_ONLY_WEEKS
):
return LicenseUsageStatus.READ_ONLY
return LicenseUsageStatus.EXPIRED
# Expiry warning
if datetime.fromtimestamp(self.exp, UTC) <= _now + timedelta(
weeks=THRESHOLD_WARNING_EXPIRY_WEEKS
):
return LicenseUsageStatus.EXPIRY_SOON
return LicenseUsageStatus.VALID

def record_usage(self):
"""Capture the current validity status and metrics and save them"""
threshold = now() - timedelta(hours=8)
if not LicenseUsage.objects.filter(record_date__gte=threshold).exists():
LicenseUsage.objects.create(
user_count=self.get_default_user_count(),
usage = (
LicenseUsage.objects.order_by("-record_date").filter(record_date__gte=threshold).first()
)
if not usage:
usage = LicenseUsage.objects.create(
internal_user_count=self.get_internal_user_count(),
external_user_count=self.get_external_user_count(),
within_limits=self.is_valid(),
status=self.status(),
)
summary = asdict(self.summary())
# Also cache the latest summary for the middleware
cache.set(CACHE_KEY_ENTERPRISE_LICENSE, summary, timeout=CACHE_EXPIRY_ENTERPRISE_LICENSE)
return summary

@staticmethod
def last_valid_date() -> datetime:
"""Get the last date the license was valid"""
usage: LicenseUsage = (
LicenseUsage.filter_not_expired(within_limits=True).order_by("-record_date").first()
)
if not usage:
return now()
return usage.record_date
return usage

def summary(self) -> LicenseSummary:
"""Summary of license status"""
has_license = License.objects.all().count() > 0
last_valid = LicenseKey.last_valid_date()
show_admin_warning = last_valid < now() - timedelta(weeks=2)
show_user_warning = last_valid < now() - timedelta(weeks=4)
read_only = last_valid < now() - timedelta(weeks=6)
status = self.status()
latest_valid = datetime.fromtimestamp(self.exp)
return LicenseSummary(
show_admin_warning=show_admin_warning and has_license,
show_user_warning=show_user_warning and has_license,
read_only=read_only and has_license,
latest_valid=latest_valid,
internal_users=self.internal_users,
external_users=self.external_users,
valid=self.is_valid(),
has_license=has_license,
status=status,
)

@staticmethod
Expand All @@ -205,4 +224,8 @@ def cached_summary() -> LicenseSummary:
summary = cache.get(CACHE_KEY_ENTERPRISE_LICENSE)
if not summary:
return LicenseKey.get_total().summary()
return from_dict(LicenseSummary, summary)
try:
return from_dict(LicenseSummary, summary)
except DaciteError:
cache.delete(CACHE_KEY_ENTERPRISE_LICENSE)
return LicenseKey.get_total().summary()
7 changes: 4 additions & 3 deletions authentik/enterprise/middleware.py
Original file line number Diff line number Diff line change
Expand Up @@ -8,6 +8,7 @@

from authentik.enterprise.api import LicenseViewSet
from authentik.enterprise.license import LicenseKey
from authentik.enterprise.models import LicenseUsageStatus
from authentik.flows.views.executor import FlowExecutorView
from authentik.lib.utils.reflection import class_to_path

Expand Down Expand Up @@ -43,7 +44,7 @@ def is_request_allowed(self, request: HttpRequest) -> bool:
cached_status = LicenseKey.cached_summary()
if not cached_status:
return True
if cached_status.read_only:
if cached_status.status == LicenseUsageStatus.READ_ONLY:
return False
return True

Expand All @@ -53,10 +54,10 @@ def is_request_always_allowed(self, request: HttpRequest):
if request.method.lower() in ["get", "head", "options", "trace"]:
return True
# Always allow requests to manage licenses
if class_to_path(request.resolver_match.func) == class_to_path(LicenseViewSet):
if request.resolver_match._func_path == class_to_path(LicenseViewSet):
return True
# Flow executor is mounted as an API path but explicitly allowed
if class_to_path(request.resolver_match.func) == class_to_path(FlowExecutorView):
if request.resolver_match._func_path == class_to_path(FlowExecutorView):
return True
# Only apply these restrictions to the API
if "authentik_api" not in request.resolver_match.app_names:
Expand Down
Original file line number Diff line number Diff line change
@@ -0,0 +1,68 @@
# Generated by Django 5.0.8 on 2024-08-08 14:15

from django.db import migrations, models
from django.apps.registry import Apps
from django.db.backends.base.schema import BaseDatabaseSchemaEditor


def migrate_license_usage(apps: Apps, schema_editor: BaseDatabaseSchemaEditor):
LicenseUsage = apps.get_model("authentik_enterprise", "licenseusage")
db_alias = schema_editor.connection.alias

for usage in LicenseUsage.objects.using(db_alias).all():
usage.status = "valid" if usage.within_limits else "limit_exceeded_admin"
usage.save()


class Migration(migrations.Migration):

dependencies = [
("authentik_enterprise", "0002_rename_users_license_internal_users_and_more"),
]

operations = [
migrations.AddField(
model_name="licenseusage",
name="status",
field=models.TextField(
choices=[
("unlicensed", "Unlicensed"),
("valid", "Valid"),
("expired", "Expired"),
("expiry_soon", "Expiry Soon"),
("limit_exceeded_admin", "Limit Exceeded Admin"),
("limit_exceeded_user", "Limit Exceeded User"),
("read_only", "Read Only"),
],
default=None,
null=True,
),
preserve_default=False,
),
migrations.RunPython(migrate_license_usage),
migrations.RemoveField(
model_name="licenseusage",
name="within_limits",
),
migrations.AlterField(
model_name="licenseusage",
name="status",
field=models.TextField(
choices=[
("unlicensed", "Unlicensed"),
("valid", "Valid"),
("expired", "Expired"),
("expiry_soon", "Expiry Soon"),
("limit_exceeded_admin", "Limit Exceeded Admin"),
("limit_exceeded_user", "Limit Exceeded User"),
("read_only", "Read Only"),
],
),
preserve_default=False,
),
migrations.RenameField(
model_name="licenseusage",
old_name="user_count",
new_name="internal_user_count",
),
]
Loading

0 comments on commit 4b5bb77

Please sign in to comment.