Skip to content

Commit

Permalink
core: add primitives for source property mappings (#10651)
Browse files Browse the repository at this point in the history
  • Loading branch information
rissson authored Jul 26, 2024
1 parent ecd6c0a commit 45e4643
Show file tree
Hide file tree
Showing 11 changed files with 149 additions and 167 deletions.
27 changes: 22 additions & 5 deletions authentik/core/api/property_mappings.py
Original file line number Diff line number Diff line change
Expand Up @@ -2,8 +2,15 @@

from json import dumps

from django_filters.filters import AllValuesMultipleFilter, BooleanFilter
from django_filters.filterset import FilterSet
from drf_spectacular.types import OpenApiTypes
from drf_spectacular.utils import OpenApiParameter, OpenApiResponse, extend_schema
from drf_spectacular.utils import (
OpenApiParameter,
OpenApiResponse,
extend_schema,
extend_schema_field,
)
from guardian.shortcuts import get_objects_for_user
from rest_framework import mixins
from rest_framework.decorators import action
Expand Down Expand Up @@ -67,6 +74,18 @@ class Meta:
]


class PropertyMappingFilterSet(FilterSet):
"""Filter for PropertyMapping"""

managed = extend_schema_field(OpenApiTypes.STR)(AllValuesMultipleFilter(field_name="managed"))

managed__isnull = BooleanFilter(field_name="managed", lookup_expr="isnull")

class Meta:
model = PropertyMapping
fields = ["name", "managed"]


class PropertyMappingViewSet(
TypesMixin,
mixins.RetrieveModelMixin,
Expand All @@ -87,11 +106,9 @@ class PropertyMappingTestSerializer(PolicyTestSerializer):

queryset = PropertyMapping.objects.select_subclasses()
serializer_class = PropertyMappingSerializer
search_fields = [
"name",
]
filterset_fields = {"managed": ["isnull"]}
filterset_class = PropertyMappingFilterSet
ordering = ["name"]
search_fields = ["name"]

@permission_required("authentik_core.view_propertymapping")
@extend_schema(
Expand Down
115 changes: 75 additions & 40 deletions authentik/core/models.py
Original file line number Diff line number Diff line change
Expand Up @@ -28,6 +28,7 @@
from authentik.lib.avatars import get_avatar
from authentik.lib.expression.exceptions import ControlFlowException
from authentik.lib.generators import generate_id
from authentik.lib.merge import MERGE_LIST_UNIQUE
from authentik.lib.models import (
CreatedUpdatedModel,
DomainlessFormattedURLValidator,
Expand Down Expand Up @@ -100,6 +101,38 @@ class UserTypes(models.TextChoices):
INTERNAL_SERVICE_ACCOUNT = "internal_service_account"


class AttributesMixin(models.Model):
"""Adds an attributes property to a model"""

attributes = models.JSONField(default=dict, blank=True)

class Meta:
abstract = True

def update_attributes(self, properties: dict[str, Any]):
"""Update fields and attributes, but correctly by merging dicts"""
for key, value in properties.items():
if key == "attributes":
continue
setattr(self, key, value)
final_attributes = {}
MERGE_LIST_UNIQUE.merge(final_attributes, self.attributes)
MERGE_LIST_UNIQUE.merge(final_attributes, properties.get("attributes", {}))
self.attributes = final_attributes
self.save()

@classmethod
def update_or_create_attributes(
cls, query: dict[str, Any], properties: dict[str, Any]
) -> tuple[models.Model, bool]:
"""Same as django's update_or_create but correctly updates attributes by merging dicts"""
instance = cls.objects.filter(**query).first()
if not instance:
return cls.objects.create(**properties), True
instance.update_attributes(properties)
return instance, False


class GroupQuerySet(CTEQuerySet):
def with_children_recursive(self):
"""Recursively get all groups that have the current queryset as parents
Expand Down Expand Up @@ -134,7 +167,7 @@ def make_cte(cte):
return cte.join(Group, group_uuid=cte.col.group_uuid).with_cte(cte)


class Group(SerializerModel):
class Group(SerializerModel, AttributesMixin):
"""Group model which supports a basic hierarchy and has attributes"""

group_uuid = models.UUIDField(primary_key=True, editable=False, default=uuid4)
Expand All @@ -154,10 +187,27 @@ class Group(SerializerModel):
on_delete=models.SET_NULL,
related_name="children",
)
attributes = models.JSONField(default=dict, blank=True)

objects = GroupQuerySet.as_manager()

class Meta:
unique_together = (
(
"name",
"parent",
),
)
indexes = [models.Index(fields=["name"])]
verbose_name = _("Group")
verbose_name_plural = _("Groups")
permissions = [
("add_user_to_group", _("Add user to group")),
("remove_user_from_group", _("Remove user from group")),
]

def __str__(self):
return f"Group {self.name}"

@property
def serializer(self) -> Serializer:
from authentik.core.api.groups import GroupSerializer
Expand All @@ -182,24 +232,6 @@ def children_recursive(self: Self | QuerySet["Group"]) -> QuerySet["Group"]:
qs = Group.objects.filter(group_uuid=self.group_uuid)
return qs.with_children_recursive()

def __str__(self):
return f"Group {self.name}"

class Meta:
unique_together = (
(
"name",
"parent",
),
)
indexes = [models.Index(fields=["name"])]
verbose_name = _("Group")
verbose_name_plural = _("Groups")
permissions = [
("add_user_to_group", _("Add user to group")),
("remove_user_from_group", _("Remove user from group")),
]


class UserQuerySet(models.QuerySet):
"""User queryset"""
Expand All @@ -225,7 +257,7 @@ def exclude_anonymous(self) -> QuerySet:
return self.get_queryset().exclude_anonymous()


class User(SerializerModel, GuardianUserMixin, AbstractUser):
class User(SerializerModel, GuardianUserMixin, AttributesMixin, AbstractUser):
"""authentik User model, based on django's contrib auth user model."""

uuid = models.UUIDField(default=uuid4, editable=False, unique=True)
Expand All @@ -241,6 +273,28 @@ class User(SerializerModel, GuardianUserMixin, AbstractUser):

objects = UserManager()

class Meta:
verbose_name = _("User")
verbose_name_plural = _("Users")
permissions = [
("reset_user_password", _("Reset Password")),
("impersonate", _("Can impersonate other users")),
("assign_user_permissions", _("Can assign permissions to users")),
("unassign_user_permissions", _("Can unassign permissions from users")),
("preview_user", _("Can preview user data sent to providers")),
("view_user_applications", _("View applications the user has access to")),
]
indexes = [
models.Index(fields=["last_login"]),
models.Index(fields=["password_change_date"]),
models.Index(fields=["uuid"]),
models.Index(fields=["path"]),
models.Index(fields=["type"]),
]

def __str__(self):
return self.username

@staticmethod
def default_path() -> str:
"""Get the default user path"""
Expand Down Expand Up @@ -322,25 +376,6 @@ def avatar(self) -> str:
"""Get avatar, depending on authentik.avatar setting"""
return get_avatar(self)

class Meta:
verbose_name = _("User")
verbose_name_plural = _("Users")
permissions = [
("reset_user_password", _("Reset Password")),
("impersonate", _("Can impersonate other users")),
("assign_user_permissions", _("Can assign permissions to users")),
("unassign_user_permissions", _("Can unassign permissions from users")),
("preview_user", _("Can preview user data sent to providers")),
("view_user_applications", _("View applications the user has access to")),
]
indexes = [
models.Index(fields=["last_login"]),
models.Index(fields=["password_change_date"]),
models.Index(fields=["uuid"]),
models.Index(fields=["path"]),
models.Index(fields=["type"]),
]


class Provider(SerializerModel):
"""Application-independent Provider instance. For example SAML2 Remote, OAuth2 Application"""
Expand Down
14 changes: 4 additions & 10 deletions authentik/providers/oauth2/api/scopes.py
Original file line number Diff line number Diff line change
@@ -1,14 +1,10 @@
"""OAuth2Provider API Views"""

from django_filters.filters import AllValuesMultipleFilter
from django_filters.filterset import FilterSet
from drf_spectacular.types import OpenApiTypes
from drf_spectacular.utils import extend_schema_field
from rest_framework.fields import CharField
from rest_framework.serializers import ValidationError
from rest_framework.viewsets import ModelViewSet

from authentik.core.api.property_mappings import PropertyMappingSerializer
from authentik.core.api.property_mappings import PropertyMappingFilterSet, PropertyMappingSerializer
from authentik.core.api.used_by import UsedByMixin
from authentik.providers.oauth2.models import ScopeMapping

Expand All @@ -33,14 +29,12 @@ class Meta:
]


class ScopeMappingFilter(FilterSet):
class ScopeMappingFilter(PropertyMappingFilterSet):
"""Filter for ScopeMapping"""

managed = extend_schema_field(OpenApiTypes.STR)(AllValuesMultipleFilter(field_name="managed"))

class Meta:
class Meta(PropertyMappingFilterSet.Meta):
model = ScopeMapping
fields = ["scope_name", "name", "managed"]
fields = PropertyMappingFilterSet.Meta.fields + ["scope_name"]


class ScopeMappingViewSet(UsedByMixin, ModelViewSet):
Expand Down
13 changes: 3 additions & 10 deletions authentik/providers/radius/api/property_mappings.py
Original file line number Diff line number Diff line change
@@ -1,12 +1,8 @@
"""Radius Property mappings API Views"""

from django_filters.filters import AllValuesMultipleFilter
from django_filters.filterset import FilterSet
from drf_spectacular.types import OpenApiTypes
from drf_spectacular.utils import extend_schema_field
from rest_framework.viewsets import ModelViewSet

from authentik.core.api.property_mappings import PropertyMappingSerializer
from authentik.core.api.property_mappings import PropertyMappingFilterSet, PropertyMappingSerializer
from authentik.core.api.used_by import UsedByMixin
from authentik.providers.radius.models import RadiusProviderPropertyMapping

Expand All @@ -19,14 +15,11 @@ class Meta:
fields = PropertyMappingSerializer.Meta.fields


class RadiusProviderPropertyMappingFilter(FilterSet):
class RadiusProviderPropertyMappingFilter(PropertyMappingFilterSet):
"""Filter for RadiusProviderPropertyMapping"""

managed = extend_schema_field(OpenApiTypes.STR)(AllValuesMultipleFilter(field_name="managed"))

class Meta:
class Meta(PropertyMappingFilterSet.Meta):
model = RadiusProviderPropertyMapping
fields = "__all__"


class RadiusProviderPropertyMappingViewSet(UsedByMixin, ModelViewSet):
Expand Down
13 changes: 3 additions & 10 deletions authentik/providers/saml/api/property_mappings.py
Original file line number Diff line number Diff line change
@@ -1,12 +1,8 @@
"""SAML Property mappings API Views"""

from django_filters.filters import AllValuesMultipleFilter
from django_filters.filterset import FilterSet
from drf_spectacular.types import OpenApiTypes
from drf_spectacular.utils import extend_schema_field
from rest_framework.viewsets import ModelViewSet

from authentik.core.api.property_mappings import PropertyMappingSerializer
from authentik.core.api.property_mappings import PropertyMappingFilterSet, PropertyMappingSerializer
from authentik.core.api.used_by import UsedByMixin
from authentik.providers.saml.models import SAMLPropertyMapping

Expand All @@ -22,14 +18,11 @@ class Meta:
]


class SAMLPropertyMappingFilter(FilterSet):
class SAMLPropertyMappingFilter(PropertyMappingFilterSet):
"""Filter for SAMLPropertyMapping"""

managed = extend_schema_field(OpenApiTypes.STR)(AllValuesMultipleFilter(field_name="managed"))

class Meta:
class Meta(PropertyMappingFilterSet.Meta):
model = SAMLPropertyMapping
fields = "__all__"


class SAMLPropertyMappingViewSet(UsedByMixin, ModelViewSet):
Expand Down
13 changes: 3 additions & 10 deletions authentik/providers/scim/api/property_mappings.py
Original file line number Diff line number Diff line change
@@ -1,12 +1,8 @@
"""scim Property mappings API Views"""

from django_filters.filters import AllValuesMultipleFilter
from django_filters.filterset import FilterSet
from drf_spectacular.types import OpenApiTypes
from drf_spectacular.utils import extend_schema_field
from rest_framework.viewsets import ModelViewSet

from authentik.core.api.property_mappings import PropertyMappingSerializer
from authentik.core.api.property_mappings import PropertyMappingFilterSet, PropertyMappingSerializer
from authentik.core.api.used_by import UsedByMixin
from authentik.providers.scim.models import SCIMMapping

Expand All @@ -19,14 +15,11 @@ class Meta:
fields = PropertyMappingSerializer.Meta.fields


class SCIMMappingFilter(FilterSet):
class SCIMMappingFilter(PropertyMappingFilterSet):
"""Filter for SCIMMapping"""

managed = extend_schema_field(OpenApiTypes.STR)(AllValuesMultipleFilter(field_name="managed"))

class Meta:
class Meta(PropertyMappingFilterSet.Meta):
model = SCIMMapping
fields = "__all__"


class SCIMMappingViewSet(UsedByMixin, ModelViewSet):
Expand Down
14 changes: 4 additions & 10 deletions authentik/sources/ldap/api.py
Original file line number Diff line number Diff line change
Expand Up @@ -3,10 +3,7 @@
from typing import Any

from django.core.cache import cache
from django_filters.filters import AllValuesMultipleFilter
from django_filters.filterset import FilterSet
from drf_spectacular.types import OpenApiTypes
from drf_spectacular.utils import extend_schema, extend_schema_field, inline_serializer
from drf_spectacular.utils import extend_schema, inline_serializer
from guardian.shortcuts import get_objects_for_user
from rest_framework.decorators import action
from rest_framework.exceptions import ValidationError
Expand All @@ -16,7 +13,7 @@
from rest_framework.response import Response
from rest_framework.viewsets import ModelViewSet

from authentik.core.api.property_mappings import PropertyMappingSerializer
from authentik.core.api.property_mappings import PropertyMappingFilterSet, PropertyMappingSerializer
from authentik.core.api.sources import SourceSerializer
from authentik.core.api.used_by import UsedByMixin
from authentik.crypto.models import CertificateKeyPair
Expand Down Expand Up @@ -185,14 +182,11 @@ class Meta:
fields = PropertyMappingSerializer.Meta.fields


class LDAPSourcePropertyMappingFilter(FilterSet):
class LDAPSourcePropertyMappingFilter(PropertyMappingFilterSet):
"""Filter for LDAPSourcePropertyMapping"""

managed = extend_schema_field(OpenApiTypes.STR)(AllValuesMultipleFilter(field_name="managed"))

class Meta:
class Meta(PropertyMappingFilterSet.Meta):
model = LDAPSourcePropertyMapping
fields = "__all__"


class LDAPSourcePropertyMappingViewSet(UsedByMixin, ModelViewSet):
Expand Down
Loading

0 comments on commit 45e4643

Please sign in to comment.