Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Ca 470 implement hard limits on seats #894

Merged
merged 12 commits into from
Oct 9, 2024
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension


Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
1 change: 1 addition & 0 deletions .github/workflows/functional-tests.yml
Original file line number Diff line number Diff line change
Expand Up @@ -184,6 +184,7 @@ jobs:
echo EMAIL_HOST_PASSWORD=password >> .env
echo EMAIL_PORT=1025 >> .env
echo DJANGO_SETTINGS_MODULE=enterprise_core.settings >> .env
echo LICENSE_SEATS=999 >> .env
- name: Run migrations
working-directory: ${{ env.backend-directory }}
run: |
Expand Down
23 changes: 17 additions & 6 deletions backend/core/views.py
Original file line number Diff line number Diff line change
Expand Up @@ -39,10 +39,6 @@
from rest_framework.views import APIView
from weasyprint import HTML

from ciso_assistant.settings import (
BUILD,
VERSION,
)
from core.helpers import *
from core.models import (
AppliedControl,
Expand All @@ -56,12 +52,19 @@
from .models import *
from .serializers import *

import structlog

logger = structlog.get_logger(__name__)

User = get_user_model()

SHORT_CACHE_TTL = 2 # mn
MED_CACHE_TTL = 5 # mn
LONG_CACHE_TTL = 60 # mn

SETTINGS_MODULE = __import__(os.environ.get("DJANGO_SETTINGS_MODULE"))
MODULE_PATHS = SETTINGS_MODULE.settings.MODULE_PATHS


class BaseModelViewSet(viewsets.ModelViewSet):
filter_backends = [
Expand Down Expand Up @@ -95,13 +98,19 @@ def get_queryset(self):
return queryset

def get_serializer_class(self, **kwargs):
MODULE_PATHS = settings.MODULE_PATHS
serializer_factory = SerializerFactory(
self.serializers_module, *MODULE_PATHS.get("serializers", [])
self.serializers_module, MODULE_PATHS.get("serializers", [])
)
serializer_class = serializer_factory.get_serializer(
self.model.__name__, kwargs.get("action", self.action)
)
logger.debug(
"Serializer class",
serializer_class=serializer_class,
action=kwargs.get("action", self.action),
viewset=self,
module_paths=MODULE_PATHS,
)

return serializer_class

Expand Down Expand Up @@ -1941,6 +1950,8 @@ def get_build(request):
"""
API endpoint that returns the build version of the application.
"""
BUILD = settings.BUILD
VERSION = settings.VERSION
return Response({"version": VERSION, "build": BUILD})


Expand Down
31 changes: 25 additions & 6 deletions backend/iam/models.py
Original file line number Diff line number Diff line change
Expand Up @@ -235,6 +235,10 @@ def get_localization_dict(self) -> dict:
"role": BUILTIN_USERGROUP_CODENAMES.get(self.name),
}

@property
def permissions(self):
return RoleAssignment.get_permissions(self)


class UserManager(BaseUserManager):
use_in_migrations = True
Expand Down Expand Up @@ -502,6 +506,19 @@ def get_admin_users() -> List[Self]:
def is_admin(self) -> bool:
return self.user_groups.filter(name="BI-UG-ADM").exists()

@property
def is_editor(self) -> bool:
permissions = RoleAssignment.get_permissions(self)
editor_prefixes = {"add_", "change_", "delete_"}
return any(
any(perm.startswith(prefix) for prefix in editor_prefixes)
for perm in permissions
)

@classmethod
def get_editors(cls) -> List[Self]:
return [user for user in cls.objects.all() if user.is_editor]


class Role(NameDescriptionMixin, FolderMixin):
"""A role is a list of permissions"""
Expand Down Expand Up @@ -718,18 +735,20 @@ def is_user_assigned(self, user) -> bool:
)

@staticmethod
def get_role_assignments(user):
def get_role_assignments(principal: AbstractBaseUser | AnonymousUser | UserGroup):
"""get all role assignments attached to a user directly or indirectly"""
assignments = list(user.roleassignment_set.all())
for user_group in user.user_groups.all():
assignments += list(user_group.roleassignment_set.all())
assignments = list(principal.roleassignment_set.all())
if hasattr(principal, "user_groups"):
for user_group in principal.user_groups.all():
assignments += list(user_group.roleassignment_set.all())
assignments += list(principal.roleassignment_set.all())
return assignments

@staticmethod
def get_permissions(user: AbstractBaseUser | AnonymousUser):
def get_permissions(principal: AbstractBaseUser | AnonymousUser | UserGroup):
"""get all permissions attached to a user directly or indirectly"""
permissions = {}
for ra in RoleAssignment.get_role_assignments(user):
for ra in RoleAssignment.get_role_assignments(principal):
for p in ra.role.permissions.all():
permission_dict = {p.codename: {"str": str(p)}}
permissions.update(permission_dict)
Expand Down
116 changes: 116 additions & 0 deletions backend/iam/tests/test_user.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,116 @@
import pytest
from django.contrib.auth.models import Permission

from core.tests.fixtures import *
from iam.models import Folder, Role, RoleAssignment, User


@pytest.mark.django_db
class TestUser:
pytestmark = pytest.mark.django_db

@pytest.mark.usefixtures("domain_project_fixture")
def test_reader_user_is_not_editor(self):
user = User.objects.create_user(email="[email protected]", password="password")
assert user is not None

folder = Folder.objects.filter(content_type=Folder.ContentType.DOMAIN).last()
reader_role = Role.objects.create(name="test reader")
reader_permissions = Permission.objects.filter(
codename__in=[
"view_project",
"view_riskassessment",
"view_appliedcontrol",
"view_riskscenario",
"view_riskacceptance",
"view_asset",
"view_threat",
"view_referencecontrol",
"view_folder",
"view_usergroup",
]
)
reader_role.permissions.set(reader_permissions)
reader_role.save()
reader_role_assignment = RoleAssignment.objects.create(
user=user,
role=reader_role,
folder=folder,
is_recursive=True,
)
reader_role_assignment.perimeter_folders.add(folder)
reader_role_assignment.save()

assert not user.is_editor

editors = User.get_editors()
assert len(editors) == 0
assert user not in editors

@pytest.mark.usefixtures("domain_project_fixture")
def test_editor_user_is_editor(self):
user = User.objects.create_user(email="[email protected]", password="password")
assert user is not None

folder = Folder.objects.filter(content_type=Folder.ContentType.DOMAIN).last()
editor_role = Role.objects.create(name="test editor")
editor_permissions = Permission.objects.filter(
codename__in=[
"view_project",
"view_riskassessment",
"view_appliedcontrol",
"view_riskscenario",
"view_riskacceptance",
"view_asset",
"view_threat",
"view_referencecontrol",
"view_folder",
"view_usergroup",
"add_project",
"change_project",
"delete_project",
"add_riskassessment",
"change_riskassessment",
"delete_riskassessment",
"add_appliedcontrol",
"change_appliedcontrol",
"delete_appliedcontrol",
"add_riskscenario",
"change_riskscenario",
"delete_riskscenario",
"add_riskacceptance",
"change_riskacceptance",
"delete_riskacceptance",
"add_asset",
"change_asset",
"delete_asset",
"add_threat",
"change_threat",
"delete_threat",
"add_referencecontrol",
"change_referencecontrol",
"delete_referencecontrol",
"add_folder",
"change_folder",
"delete_folder",
"add_usergroup",
"change_usergroup",
"delete_usergroup",
]
)
editor_role.permissions.set(editor_permissions)
editor_role.save()
editor_role_assignment = RoleAssignment.objects.create(
user=user,
role=editor_role,
folder=folder,
is_recursive=True,
)
editor_role_assignment.perimeter_folders.add(folder)
editor_role_assignment.save()

assert user.is_editor

editors = User.get_editors()
assert len(editors) == 1
assert user in editors
52 changes: 50 additions & 2 deletions enterprise/backend/enterprise_core/serializers.py
Original file line number Diff line number Diff line change
@@ -1,8 +1,15 @@
from django.conf import settings
from rest_framework import serializers
from core.serializers import BaseModelSerializer
from iam.models import Folder
from core.serializers import (
BaseModelSerializer,
UserWriteSerializer as CommunityUserWriteSerializer,
)
from iam.models import Folder, User

from .models import ClientSettings
import structlog

logger = structlog.get_logger(__name__)


class FolderWriteSerializer(BaseModelSerializer):
Expand All @@ -14,6 +21,47 @@ class Meta:
]


class EditorPermissionMixin:
@staticmethod
def check_editor_permissions(instance, group):
editor_prefixes = {"add_", "change_", "delete_"}
editors = User.get_editors()
seats = settings.LICENSE_SEATS

perms = group.permissions
if any(perm.startswith(prefix) for prefix in editor_prefixes for perm in perms):
logger.info("Adding editor permissions to user", user=instance, group=group)
if instance not in editors and len(editors) >= seats:
logger.error(
"License seats exceeded, cannot add editor user groups to user",
user=instance,
seats=seats,
)
raise serializers.ValidationError(
{"user_groups": "errorLicenseSeatsExceeded"}
)


class UserWriteSerializer(CommunityUserWriteSerializer, EditorPermissionMixin):
def _update_user_groups(self, instance, validated_data):
if validated_data.get("user_groups"):
logger.info(
"Updating user groups",
user=instance,
groups=validated_data["user_groups"],
)
for group in validated_data["user_groups"]:
self.check_editor_permissions(instance, group)

def update(self, instance: User, validated_data):
self._update_user_groups(instance, validated_data)
return super().update(instance, validated_data)

def partial_update(self, instance, validated_data):
self._update_user_groups(instance, validated_data)
return super().partial_update(instance, validated_data)


class ClientSettingsWriteSerializer(BaseModelSerializer):
class Meta:
model = ClientSettings
Expand Down
6 changes: 2 additions & 4 deletions enterprise/backend/enterprise_core/settings.py
Original file line number Diff line number Diff line change
Expand Up @@ -82,7 +82,7 @@ def set_ciso_assistant_url(_, __, event_dict):
logger = structlog.getLogger(__name__)

FEATURE_FLAGS = {}
MODULE_PATHS = {}
MODULE_PATHS = {"serializers": "enterprise_core.serializers"}
ROUTES = {}
MODULES = {}

Expand Down Expand Up @@ -385,8 +385,6 @@ def set_ciso_assistant_url(_, __, event_dict):
},
}

MODULE_PATHS["serializers"] = ["enterprise_core.serializers"]

ROUTES["client-settings"] = {
"viewset": "enterprise_core.views.ClientSettingsViewSet",
"basename": "client-settings",
Expand All @@ -401,7 +399,7 @@ def set_ciso_assistant_url(_, __, event_dict):
"Enterprise startup info", feature_flags=FEATURE_FLAGS, module_paths=MODULE_PATHS
)

LICENSE_SEATS = int(os.environ.get("LICENSE_SEATS", 0))
LICENSE_SEATS = int(os.environ.get("LICENSE_SEATS", 1))
LICENSE_EXPIRATION = os.environ.get("LICENSE_EXPIRATION", "unset")

INSTALLED_APPS.append("enterprise_core")
7 changes: 4 additions & 3 deletions enterprise/backend/enterprise_core/views.py
Original file line number Diff line number Diff line change
@@ -1,9 +1,6 @@
import mimetypes
import magic

import structlog
from core.views import BaseModelViewSet
from django.http import HttpResponse
from rest_framework import status
from rest_framework.permissions import AllowAny
from rest_framework.decorators import (
Expand All @@ -16,6 +13,9 @@

from django.conf import settings

from core.views import BaseModelViewSet
from iam.models import User

from .models import ClientSettings
from .serializers import ClientSettingsReadSerializer

Expand Down Expand Up @@ -145,6 +145,7 @@ def get_build(request):
"version": VERSION,
"build": BUILD,
"license_seats": LICENSE_SEATS,
"available_seats": LICENSE_SEATS - len(User.get_editors()),
"license_expiration": LICENSE_EXPIRATION,
}
)
Original file line number Diff line number Diff line change
@@ -1,12 +1,12 @@
import { ClientSettingsSchema } from '$lib/utils/client-settings';
import { BASE_API_URL } from '$lib/utils/constants';
import { safeTranslate } from '$lib/utils/i18n';
import { ClientSettingsSchema } from '$lib/utils/client-settings';
import * as m from '$paraglide/messages';
import { fail, type Actions } from '@sveltejs/kit';
import { setFlash } from 'sveltekit-flash-message/server';
import { setError, superValidate } from 'sveltekit-superforms';
import { zod } from 'sveltekit-superforms/adapters';
import type { PageServerLoad } from './$types';
import * as m from '$paraglide/messages';

export const load: PageServerLoad = async ({ fetch }) => {
const settings = await fetch(`${BASE_API_URL}/client-settings/`)
Expand Down Expand Up @@ -62,7 +62,7 @@ export const actions: Actions = {
return { form };
}
if (response.error) {
setFlash({ type: 'error', message: response.error }, event);
setFlash({ type: 'error', message: safeTranslate(response.error) }, event);
return { form };
}
Object.entries(response).forEach(([key, value]) => {
Expand Down Expand Up @@ -98,7 +98,6 @@ export const actions: Actions = {
}
}

const modelVerboseName: string = 'clientSettings';

return setFlash(
{
Expand Down
Loading
Loading