Skip to content

Commit

Permalink
Refactor library import function and use atomic transaction for import
Browse files Browse the repository at this point in the history
  • Loading branch information
nas-tabchiche committed Feb 13, 2024
1 parent 0a6201f commit 268d9c8
Showing 1 changed file with 58 additions and 52 deletions.
110 changes: 58 additions & 52 deletions backend/library/utils.py
Original file line number Diff line number Diff line change
@@ -1,24 +1,19 @@
import json
import os
from typing import List, Union

import yaml
from ciso_assistant import settings
from core.models import (
Framework,
RequirementNode,
RequirementNode,
RequirementLevel,
Library,
RequirementNode,
RiskMatrix,
SecurityFunction,
Threat,
)
from core.models import Threat, SecurityFunction, RiskMatrix
from django.contrib import messages
from django.contrib.auth.models import Permission
from iam.models import Folder, RoleAssignment
from ciso_assistant import settings
from django.utils.translation import gettext_lazy as _

from .validators import *

import os
import yaml
import json

from typing import Union, List
from django.db import transaction
from iam.models import Folder


def get_available_library_files():
Expand Down Expand Up @@ -69,8 +64,8 @@ def get_library_names(libraries):
names: list of available library names
"""
names = []
for l in libraries:
names.append(l.get("name"))
for lib in libraries:
names.append(lib.get("name"))
return names


Expand All @@ -85,9 +80,9 @@ def get_library(urn: str) -> dict | None:
library: library with the given urn
"""
libraries = get_available_libraries()
for l in libraries:
if l["urn"] == urn:
return l
for lib in libraries:
if lib["urn"] == urn:
return lib
return None


Expand Down Expand Up @@ -465,33 +460,26 @@ def init(self) -> Union[str, None]:
) is not None:
return security_function_import_error

def import_library(self) -> Union[str, None]:
if (error_message := self.init()) is not None:
return error_message

if self._library_data.get("dependencies"):
for dependency in self._library_data["dependencies"]:
if not Library.objects.filter(urn=dependency).exists():
import_library_view(get_library(dependency))
def check_and_import_dependencies(self):
"""Check and import library dependencies."""
dependencies = self._library_data.get("dependencies", [])
for dependency in dependencies:
if not Library.objects.filter(urn=dependency).exists():
import_library_view(get_library(dependency))

def create_or_update_library(self):
"""Create or update the library object."""
_urn = self._library_data["urn"]
_locale = self._library_data["locale"]
_default_locale = not Library.objects.filter(urn=_urn).exists()

# todo: import only if new or newer version
# if Library.objects.filter(
# urn=self._library_data['urn'],
# locale=self._library_data["locale"]
# ).exists():
# return "A library with the same URN and the same locale value has already been loaded !"
_urn = self._library_data["urn"]
_locale = self._library_data["locale"]
library_object, _created = Library.objects.update_or_create(
defaults={
"ref_id": self._library_data["ref_id"],
"name": self._library_data.get("name"),
"description": self._library_data.get("description", None),
"urn": _urn,
"locale": self._library_data["locale"],
"locale": _locale,
"default_locale": _default_locale,
"version": self._library_data.get("version", None),
"provider": self._library_data.get("provider", None),
Expand All @@ -501,25 +489,42 @@ def import_library(self) -> Union[str, None]:
urn=_urn,
locale=_locale,
)
return library_object

import_error_msg = None
try:
if self._framework_importer is not None:
self._framework_importer.import_framework(library_object)
def import_objects(self, library_object):
"""Import library objects."""
if self._framework_importer is not None:
self._framework_importer.import_framework(library_object)

for threat in self._threats:
threat.import_threat(library_object)

for threat in self._threats:
threat.import_threat(library_object)
for security_function in self._security_functions:
security_function.import_security_function(library_object)

for security_function in self._security_functions:
security_function.import_security_function(library_object)
for risk_matrix in self._risk_matrices:
risk_matrix.import_risk_matrix(library_object)

for risk_matrix in self._risk_matrices:
risk_matrix.import_risk_matrix(library_object)
def import_library(self):
"""Main method to import a library."""
if (error_message := self.init()) is not None:
return error_message

self.check_and_import_dependencies()

try:
with transaction.atomic():
library_object = self.create_or_update_library()
self.import_objects(library_object)
library_object.dependencies.set(
Library.objects.filter(
urn__in=self._library_data.get("dependencies", [])
)
)
except Exception as e:
print("lib exception", e)
library_object.delete()
raise e
# TODO: Switch to proper logging
print(f"Library import exception: {e}")
raise


def import_library_view(library: dict) -> Union[str, None]:
Expand All @@ -536,5 +541,6 @@ def import_library_view(library: dict) -> Union[str, None]:
optional_error : Union[str,None]
A string describing the error if the function fails and returns None on success.
"""
# NOTE: We should just use LibraryImporter.import_library at this point
library_importer = LibraryImporter(library)
return library_importer.import_library()

0 comments on commit 268d9c8

Please sign in to comment.