Skip to content

Commit

Permalink
Merge pull request #894 from intuitem/CA-470-implement-hard-limits-on…
Browse files Browse the repository at this point in the history
…-seats

Ca 470 implement hard limits on seats
  • Loading branch information
Mohamed-Hacene authored Oct 9, 2024
2 parents be5ee6d + c0b3b8b commit cc3616d
Show file tree
Hide file tree
Showing 15 changed files with 261 additions and 59 deletions.
1 change: 1 addition & 0 deletions .github/workflows/functional-tests.yml
Original file line number Diff line number Diff line change
Expand Up @@ -188,6 +188,7 @@ jobs:
echo EMAIL_HOST_PASSWORD=password >> .env
echo EMAIL_PORT=1025 >> .env
echo DJANGO_SETTINGS_MODULE=enterprise_core.settings >> .env
echo LICENSE_SEATS=999 >> .env
- name: Run migrations
working-directory: ${{ env.backend-directory }}
run: |
Expand Down
23 changes: 17 additions & 6 deletions backend/core/views.py
Original file line number Diff line number Diff line change
Expand Up @@ -39,10 +39,6 @@
from rest_framework.views import APIView
from weasyprint import HTML

from ciso_assistant.settings import (
BUILD,
VERSION,
)
from core.helpers import *
from core.models import (
AppliedControl,
Expand All @@ -56,12 +52,19 @@
from .models import *
from .serializers import *

import structlog

logger = structlog.get_logger(__name__)

User = get_user_model()

SHORT_CACHE_TTL = 2 # mn
MED_CACHE_TTL = 5 # mn
LONG_CACHE_TTL = 60 # mn

SETTINGS_MODULE = __import__(os.environ.get("DJANGO_SETTINGS_MODULE"))
MODULE_PATHS = SETTINGS_MODULE.settings.MODULE_PATHS


class BaseModelViewSet(viewsets.ModelViewSet):
filter_backends = [
Expand Down Expand Up @@ -95,13 +98,19 @@ def get_queryset(self):
return queryset

def get_serializer_class(self, **kwargs):
MODULE_PATHS = settings.MODULE_PATHS
serializer_factory = SerializerFactory(
self.serializers_module, *MODULE_PATHS.get("serializers", [])
self.serializers_module, MODULE_PATHS.get("serializers", [])
)
serializer_class = serializer_factory.get_serializer(
self.model.__name__, kwargs.get("action", self.action)
)
logger.debug(
"Serializer class",
serializer_class=serializer_class,
action=kwargs.get("action", self.action),
viewset=self,
module_paths=MODULE_PATHS,
)

return serializer_class

Expand Down Expand Up @@ -1941,6 +1950,8 @@ def get_build(request):
"""
API endpoint that returns the build version of the application.
"""
BUILD = settings.BUILD
VERSION = settings.VERSION
return Response({"version": VERSION, "build": BUILD})


Expand Down
31 changes: 25 additions & 6 deletions backend/iam/models.py
Original file line number Diff line number Diff line change
Expand Up @@ -235,6 +235,10 @@ def get_localization_dict(self) -> dict:
"role": BUILTIN_USERGROUP_CODENAMES.get(self.name),
}

@property
def permissions(self):
return RoleAssignment.get_permissions(self)


class UserManager(BaseUserManager):
use_in_migrations = True
Expand Down Expand Up @@ -502,6 +506,19 @@ def get_admin_users() -> List[Self]:
def is_admin(self) -> bool:
return self.user_groups.filter(name="BI-UG-ADM").exists()

@property
def is_editor(self) -> bool:
permissions = RoleAssignment.get_permissions(self)
editor_prefixes = {"add_", "change_", "delete_"}
return any(
any(perm.startswith(prefix) for prefix in editor_prefixes)
for perm in permissions
)

@classmethod
def get_editors(cls) -> List[Self]:
return [user for user in cls.objects.all() if user.is_editor]


class Role(NameDescriptionMixin, FolderMixin):
"""A role is a list of permissions"""
Expand Down Expand Up @@ -718,18 +735,20 @@ def is_user_assigned(self, user) -> bool:
)

@staticmethod
def get_role_assignments(user):
def get_role_assignments(principal: AbstractBaseUser | AnonymousUser | UserGroup):
"""get all role assignments attached to a user directly or indirectly"""
assignments = list(user.roleassignment_set.all())
for user_group in user.user_groups.all():
assignments += list(user_group.roleassignment_set.all())
assignments = list(principal.roleassignment_set.all())
if hasattr(principal, "user_groups"):
for user_group in principal.user_groups.all():
assignments += list(user_group.roleassignment_set.all())
assignments += list(principal.roleassignment_set.all())
return assignments

@staticmethod
def get_permissions(user: AbstractBaseUser | AnonymousUser):
def get_permissions(principal: AbstractBaseUser | AnonymousUser | UserGroup):
"""get all permissions attached to a user directly or indirectly"""
permissions = {}
for ra in RoleAssignment.get_role_assignments(user):
for ra in RoleAssignment.get_role_assignments(principal):
for p in ra.role.permissions.all():
permission_dict = {p.codename: {"str": str(p)}}
permissions.update(permission_dict)
Expand Down
116 changes: 116 additions & 0 deletions backend/iam/tests/test_user.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,116 @@
import pytest
from django.contrib.auth.models import Permission

from core.tests.fixtures import *
from iam.models import Folder, Role, RoleAssignment, User


@pytest.mark.django_db
class TestUser:
pytestmark = pytest.mark.django_db

@pytest.mark.usefixtures("domain_project_fixture")
def test_reader_user_is_not_editor(self):
user = User.objects.create_user(email="[email protected]", password="password")
assert user is not None

folder = Folder.objects.filter(content_type=Folder.ContentType.DOMAIN).last()
reader_role = Role.objects.create(name="test reader")
reader_permissions = Permission.objects.filter(
codename__in=[
"view_project",
"view_riskassessment",
"view_appliedcontrol",
"view_riskscenario",
"view_riskacceptance",
"view_asset",
"view_threat",
"view_referencecontrol",
"view_folder",
"view_usergroup",
]
)
reader_role.permissions.set(reader_permissions)
reader_role.save()
reader_role_assignment = RoleAssignment.objects.create(
user=user,
role=reader_role,
folder=folder,
is_recursive=True,
)
reader_role_assignment.perimeter_folders.add(folder)
reader_role_assignment.save()

assert not user.is_editor

editors = User.get_editors()
assert len(editors) == 0
assert user not in editors

@pytest.mark.usefixtures("domain_project_fixture")
def test_editor_user_is_editor(self):
user = User.objects.create_user(email="[email protected]", password="password")
assert user is not None

folder = Folder.objects.filter(content_type=Folder.ContentType.DOMAIN).last()
editor_role = Role.objects.create(name="test editor")
editor_permissions = Permission.objects.filter(
codename__in=[
"view_project",
"view_riskassessment",
"view_appliedcontrol",
"view_riskscenario",
"view_riskacceptance",
"view_asset",
"view_threat",
"view_referencecontrol",
"view_folder",
"view_usergroup",
"add_project",
"change_project",
"delete_project",
"add_riskassessment",
"change_riskassessment",
"delete_riskassessment",
"add_appliedcontrol",
"change_appliedcontrol",
"delete_appliedcontrol",
"add_riskscenario",
"change_riskscenario",
"delete_riskscenario",
"add_riskacceptance",
"change_riskacceptance",
"delete_riskacceptance",
"add_asset",
"change_asset",
"delete_asset",
"add_threat",
"change_threat",
"delete_threat",
"add_referencecontrol",
"change_referencecontrol",
"delete_referencecontrol",
"add_folder",
"change_folder",
"delete_folder",
"add_usergroup",
"change_usergroup",
"delete_usergroup",
]
)
editor_role.permissions.set(editor_permissions)
editor_role.save()
editor_role_assignment = RoleAssignment.objects.create(
user=user,
role=editor_role,
folder=folder,
is_recursive=True,
)
editor_role_assignment.perimeter_folders.add(folder)
editor_role_assignment.save()

assert user.is_editor

editors = User.get_editors()
assert len(editors) == 1
assert user in editors
52 changes: 50 additions & 2 deletions enterprise/backend/enterprise_core/serializers.py
Original file line number Diff line number Diff line change
@@ -1,8 +1,15 @@
from django.conf import settings
from rest_framework import serializers
from core.serializers import BaseModelSerializer
from iam.models import Folder
from core.serializers import (
BaseModelSerializer,
UserWriteSerializer as CommunityUserWriteSerializer,
)
from iam.models import Folder, User

from .models import ClientSettings
import structlog

logger = structlog.get_logger(__name__)


class FolderWriteSerializer(BaseModelSerializer):
Expand All @@ -14,6 +21,47 @@ class Meta:
]


class EditorPermissionMixin:
@staticmethod
def check_editor_permissions(instance, group):
editor_prefixes = {"add_", "change_", "delete_"}
editors = User.get_editors()
seats = settings.LICENSE_SEATS

perms = group.permissions
if any(perm.startswith(prefix) for prefix in editor_prefixes for perm in perms):
logger.info("Adding editor permissions to user", user=instance, group=group)
if instance not in editors and len(editors) >= seats:
logger.error(
"License seats exceeded, cannot add editor user groups to user",
user=instance,
seats=seats,
)
raise serializers.ValidationError(
{"user_groups": "errorLicenseSeatsExceeded"}
)


class UserWriteSerializer(CommunityUserWriteSerializer, EditorPermissionMixin):
def _update_user_groups(self, instance, validated_data):
if validated_data.get("user_groups"):
logger.info(
"Updating user groups",
user=instance,
groups=validated_data["user_groups"],
)
for group in validated_data["user_groups"]:
self.check_editor_permissions(instance, group)

def update(self, instance: User, validated_data):
self._update_user_groups(instance, validated_data)
return super().update(instance, validated_data)

def partial_update(self, instance, validated_data):
self._update_user_groups(instance, validated_data)
return super().partial_update(instance, validated_data)


class ClientSettingsWriteSerializer(BaseModelSerializer):
class Meta:
model = ClientSettings
Expand Down
6 changes: 2 additions & 4 deletions enterprise/backend/enterprise_core/settings.py
Original file line number Diff line number Diff line change
Expand Up @@ -82,7 +82,7 @@ def set_ciso_assistant_url(_, __, event_dict):
logger = structlog.getLogger(__name__)

FEATURE_FLAGS = {}
MODULE_PATHS = {}
MODULE_PATHS = {"serializers": "enterprise_core.serializers"}
ROUTES = {}
MODULES = {}

Expand Down Expand Up @@ -385,8 +385,6 @@ def set_ciso_assistant_url(_, __, event_dict):
},
}

MODULE_PATHS["serializers"] = ["enterprise_core.serializers"]

ROUTES["client-settings"] = {
"viewset": "enterprise_core.views.ClientSettingsViewSet",
"basename": "client-settings",
Expand All @@ -401,7 +399,7 @@ def set_ciso_assistant_url(_, __, event_dict):
"Enterprise startup info", feature_flags=FEATURE_FLAGS, module_paths=MODULE_PATHS
)

LICENSE_SEATS = int(os.environ.get("LICENSE_SEATS", 0))
LICENSE_SEATS = int(os.environ.get("LICENSE_SEATS", 1))
LICENSE_EXPIRATION = os.environ.get("LICENSE_EXPIRATION", "unset")

INSTALLED_APPS.append("enterprise_core")
7 changes: 4 additions & 3 deletions enterprise/backend/enterprise_core/views.py
Original file line number Diff line number Diff line change
@@ -1,9 +1,6 @@
import mimetypes
import magic

import structlog
from core.views import BaseModelViewSet
from django.http import HttpResponse
from rest_framework import status
from rest_framework.permissions import AllowAny
from rest_framework.decorators import (
Expand All @@ -16,6 +13,9 @@

from django.conf import settings

from core.views import BaseModelViewSet
from iam.models import User

from .models import ClientSettings
from .serializers import ClientSettingsReadSerializer

Expand Down Expand Up @@ -145,6 +145,7 @@ def get_build(request):
"version": VERSION,
"build": BUILD,
"license_seats": LICENSE_SEATS,
"available_seats": LICENSE_SEATS - len(User.get_editors()),
"license_expiration": LICENSE_EXPIRATION,
}
)
Original file line number Diff line number Diff line change
@@ -1,12 +1,12 @@
import { ClientSettingsSchema } from '$lib/utils/client-settings';
import { BASE_API_URL } from '$lib/utils/constants';
import { safeTranslate } from '$lib/utils/i18n';
import { ClientSettingsSchema } from '$lib/utils/client-settings';
import * as m from '$paraglide/messages';
import { fail, type Actions } from '@sveltejs/kit';
import { setFlash } from 'sveltekit-flash-message/server';
import { setError, superValidate } from 'sveltekit-superforms';
import { zod } from 'sveltekit-superforms/adapters';
import type { PageServerLoad } from './$types';
import * as m from '$paraglide/messages';

export const load: PageServerLoad = async ({ fetch }) => {
const settings = await fetch(`${BASE_API_URL}/client-settings/`)
Expand Down Expand Up @@ -62,7 +62,7 @@ export const actions: Actions = {
return { form };
}
if (response.error) {
setFlash({ type: 'error', message: response.error }, event);
setFlash({ type: 'error', message: safeTranslate(response.error) }, event);
return { form };
}
Object.entries(response).forEach(([key, value]) => {
Expand Down Expand Up @@ -98,7 +98,6 @@ export const actions: Actions = {
}
}

const modelVerboseName: string = 'clientSettings';

return setFlash(
{
Expand Down
Loading

0 comments on commit cc3616d

Please sign in to comment.