From 268d9c8b34b858fd42f750b3f1bf1ca1eacbef2b Mon Sep 17 00:00:00 2001 From: Nassim Tabchiche Date: Tue, 13 Feb 2024 16:52:59 +0100 Subject: [PATCH] Refactor library import function and use atomic transaction for import --- backend/library/utils.py | 110 +++++++++++++++++++++------------------ 1 file changed, 58 insertions(+), 52 deletions(-) diff --git a/backend/library/utils.py b/backend/library/utils.py index 332f54de6..be7603a03 100644 --- a/backend/library/utils.py +++ b/backend/library/utils.py @@ -1,24 +1,19 @@ +import json +import os +from typing import List, Union + +import yaml +from ciso_assistant import settings from core.models import ( Framework, - RequirementNode, - RequirementNode, - RequirementLevel, Library, + RequirementNode, + RiskMatrix, + SecurityFunction, + Threat, ) -from core.models import Threat, SecurityFunction, RiskMatrix -from django.contrib import messages -from django.contrib.auth.models import Permission -from iam.models import Folder, RoleAssignment -from ciso_assistant import settings -from django.utils.translation import gettext_lazy as _ - -from .validators import * - -import os -import yaml -import json - -from typing import Union, List +from django.db import transaction +from iam.models import Folder def get_available_library_files(): @@ -69,8 +64,8 @@ def get_library_names(libraries): names: list of available library names """ names = [] - for l in libraries: - names.append(l.get("name")) + for lib in libraries: + names.append(lib.get("name")) return names @@ -85,9 +80,9 @@ def get_library(urn: str) -> dict | None: library: library with the given urn """ libraries = get_available_libraries() - for l in libraries: - if l["urn"] == urn: - return l + for lib in libraries: + if lib["urn"] == urn: + return lib return None @@ -465,33 +460,26 @@ def init(self) -> Union[str, None]: ) is not None: return security_function_import_error - def import_library(self) -> Union[str, None]: - if (error_message := self.init()) is not None: - return error_message - - if self._library_data.get("dependencies"): - for dependency in self._library_data["dependencies"]: - if not Library.objects.filter(urn=dependency).exists(): - import_library_view(get_library(dependency)) + def check_and_import_dependencies(self): + """Check and import library dependencies.""" + dependencies = self._library_data.get("dependencies", []) + for dependency in dependencies: + if not Library.objects.filter(urn=dependency).exists(): + import_library_view(get_library(dependency)) + def create_or_update_library(self): + """Create or update the library object.""" _urn = self._library_data["urn"] + _locale = self._library_data["locale"] _default_locale = not Library.objects.filter(urn=_urn).exists() - # todo: import only if new or newer version - # if Library.objects.filter( - # urn=self._library_data['urn'], - # locale=self._library_data["locale"] - # ).exists(): - # return "A library with the same URN and the same locale value has already been loaded !" - _urn = self._library_data["urn"] - _locale = self._library_data["locale"] library_object, _created = Library.objects.update_or_create( defaults={ "ref_id": self._library_data["ref_id"], "name": self._library_data.get("name"), "description": self._library_data.get("description", None), "urn": _urn, - "locale": self._library_data["locale"], + "locale": _locale, "default_locale": _default_locale, "version": self._library_data.get("version", None), "provider": self._library_data.get("provider", None), @@ -501,25 +489,42 @@ def import_library(self) -> Union[str, None]: urn=_urn, locale=_locale, ) + return library_object - import_error_msg = None - try: - if self._framework_importer is not None: - self._framework_importer.import_framework(library_object) + def import_objects(self, library_object): + """Import library objects.""" + if self._framework_importer is not None: + self._framework_importer.import_framework(library_object) + + for threat in self._threats: + threat.import_threat(library_object) - for threat in self._threats: - threat.import_threat(library_object) + for security_function in self._security_functions: + security_function.import_security_function(library_object) - for security_function in self._security_functions: - security_function.import_security_function(library_object) + for risk_matrix in self._risk_matrices: + risk_matrix.import_risk_matrix(library_object) - for risk_matrix in self._risk_matrices: - risk_matrix.import_risk_matrix(library_object) + def import_library(self): + """Main method to import a library.""" + if (error_message := self.init()) is not None: + return error_message + self.check_and_import_dependencies() + + try: + with transaction.atomic(): + library_object = self.create_or_update_library() + self.import_objects(library_object) + library_object.dependencies.set( + Library.objects.filter( + urn__in=self._library_data.get("dependencies", []) + ) + ) except Exception as e: - print("lib exception", e) - library_object.delete() - raise e + # TODO: Switch to proper logging + print(f"Library import exception: {e}") + raise def import_library_view(library: dict) -> Union[str, None]: @@ -536,5 +541,6 @@ def import_library_view(library: dict) -> Union[str, None]: optional_error : Union[str,None] A string describing the error if the function fails and returns None on success. """ + # NOTE: We should just use LibraryImporter.import_library at this point library_importer = LibraryImporter(library) return library_importer.import_library()