diff --git a/boefjes/.ci/docker-compose.yml b/boefjes/.ci/docker-compose.yml index bf7ee7aeadd..0d1933fd7ad 100644 --- a/boefjes/.ci/docker-compose.yml +++ b/boefjes/.ci/docker-compose.yml @@ -25,7 +25,7 @@ services: dockerfile: boefjes/Dockerfile args: - ENVIRONMENT=dev - command: bash -c "python -m cProfile -o .ci/bench_$(date +%Y_%m_%d-%H:%M:%S).pstat -m pytest -v -m slow tests/integration" + command: bash -c "python -m cProfile -o .ci/bench_$(date +%Y_%m_%d-%H:%M:%S).pstat -m pytest -v -m slow tests/integration/test_bench.py::test_migration" depends_on: - ci_bytes - ci_octopoes diff --git a/boefjes/Makefile b/boefjes/Makefile index 09e5fc21b47..6b077245d31 100644 --- a/boefjes/Makefile +++ b/boefjes/Makefile @@ -75,14 +75,21 @@ itest: ## Run the integration tests. $(ci-docker-compose) build $(ci-docker-compose) down --remove-orphans $(ci-docker-compose) run --rm katalogus_integration - $(ci-docker-compose) down + $(ci-docker-compose) stop -bench: ## Run the report benchmark. +migration_bench: ## Run the migration benchmark. $(ci-docker-compose) build $(ci-docker-compose) down --remove-orphans $(ci-docker-compose) run --rm migration_bench $(ci-docker-compose) stop +bench: ## Run the other benchmarks + $(ci-docker-compose) build + $(ci-docker-compose) down --remove-orphans + $(ci-docker-compose) run --rm katalogus_integration \ + python -m cProfile -o .ci/bench_$$(date +%Y_%m_%d-%H:%M:%S).pstat -m pytest -m slow --no-cov tests/integration + $(ci-docker-compose) stop + debian12: docker run --rm \ --env PKG_NAME=kat-boefjes \ diff --git a/boefjes/boefjes/dependencies/plugins.py b/boefjes/boefjes/dependencies/plugins.py index 6f216c6fe78..a4d4fb5dc6a 100644 --- a/boefjes/boefjes/dependencies/plugins.py +++ b/boefjes/boefjes/dependencies/plugins.py @@ -1,4 +1,3 @@ -import contextlib from collections.abc import Iterator from pathlib import Path from typing import Literal @@ -17,7 +16,6 @@ from boefjes.storage.interfaces import ( ConfigStorage, DuplicatePlugin, - NotFound, PluginNotFound, PluginStorage, SettingsNotConformingToSchema, @@ -45,8 +43,15 @@ def __exit__(self, exc_type, exc_val, exc_tb): def get_all(self, organisation_id: str) -> list[PluginType]: all_plugins = self._get_all_without_enabled() + plugin_states = self.config_storage.get_state_by_id(organisation_id) - return [self._set_plugin_enabled(plugin, organisation_id) for plugin in all_plugins.values()] + for plugin in all_plugins.values(): + if plugin.id not in plugin_states: + continue + + plugin.enabled = plugin_states[plugin.id] + + return list(all_plugins.values()) def _get_all_without_enabled(self) -> dict[str, PluginType]: all_plugins = {plugin.id: plugin for plugin in self.plugin_storage.get_all()} @@ -217,12 +222,6 @@ def _assert_settings_match_schema(self, all_settings: dict, plugin_id: str, orga except ValidationError as e: raise SettingsNotConformingToSchema(plugin_id, e.message) from e - def _set_plugin_enabled(self, plugin: PluginType, organisation_id: str) -> PluginType: - with contextlib.suppress(KeyError, NotFound): - plugin.enabled = self.config_storage.is_enabled_by_id(plugin.id, organisation_id) - - return plugin - def get_plugin_service() -> Iterator[PluginService]: def closure(session: Session): diff --git a/boefjes/boefjes/katalogus/organisations.py b/boefjes/boefjes/katalogus/organisations.py index 259ea2e7943..29f307dc141 100644 --- a/boefjes/boefjes/katalogus/organisations.py +++ b/boefjes/boefjes/katalogus/organisations.py @@ -3,25 +3,11 @@ from boefjes.models import Organisation from boefjes.sql.db import ObjectNotFoundException from boefjes.sql.organisation_storage import get_organisations_store -from boefjes.storage.interfaces import OrganisationNotFound, OrganisationStorage +from boefjes.storage.interfaces import OrganisationStorage router = APIRouter(prefix="/organisations", tags=["organisations"]) -def check_organisation_exists( - organisation_id: str, storage: OrganisationStorage = Depends(get_organisations_store) -) -> None: - """ - Checks if an organisation exists, if not, creates it. - """ - with storage as store: - try: - store.get_by_id(organisation_id) - except OrganisationNotFound: - add_organisation(Organisation(id=organisation_id, name=organisation_id), storage) - storage.get_by_id(organisation_id) - - @router.get("", response_model=dict[str, Organisation]) def list_organisations(storage: OrganisationStorage = Depends(get_organisations_store)): return storage.get_all() diff --git a/boefjes/boefjes/katalogus/plugins.py b/boefjes/boefjes/katalogus/plugins.py index e63f11ab93c..1284401726a 100644 --- a/boefjes/boefjes/katalogus/plugins.py +++ b/boefjes/boefjes/katalogus/plugins.py @@ -1,5 +1,4 @@ import datetime -from functools import partial import structlog from croniter import croniter @@ -15,14 +14,11 @@ get_plugin_service, get_plugins_filter_parameters, ) -from boefjes.katalogus.organisations import check_organisation_exists from boefjes.models import FilterParameters, PaginationParameters, PluginType from boefjes.sql.plugin_storage import get_plugin_storage from boefjes.storage.interfaces import DuplicatePlugin, IntegrityError, NotAllowed, PluginStorage -router = APIRouter( - prefix="/organisations/{organisation_id}", tags=["plugins"], dependencies=[Depends(check_organisation_exists)] -) +router = APIRouter(prefix="/organisations/{organisation_id}", tags=["plugins"]) logger = structlog.get_logger(__name__) @@ -44,33 +40,34 @@ def list_plugins( pagination_params: PaginationParameters = Depends(get_pagination_parameters), plugin_service: PluginService = Depends(get_plugin_service), ) -> list[PluginType]: - with plugin_service as p: - if filter_params.ids: - try: - plugins = p.by_plugin_ids(filter_params.ids, organisation_id) - except KeyError: - raise HTTPException(status.HTTP_404_NOT_FOUND, "Plugin not found") - else: - plugins = p.get_all(organisation_id) + if filter_params.ids: + try: + plugins = plugin_service.by_plugin_ids(filter_params.ids, organisation_id) + except KeyError: + raise HTTPException(status.HTTP_404_NOT_FOUND, "Plugin not found") + else: + plugins = plugin_service.get_all(organisation_id) # filter plugins by id, name or description if filter_params.q is not None: - plugins = filter(partial(_plugin_matches_query, query=filter_params.q), plugins) + plugins = [plugin for plugin in plugins if _plugin_matches_query(plugin, filter_params.q)] # filter plugins by type if filter_params.type is not None: - plugins = filter(lambda plugin: plugin.type == filter_params.type, plugins) + plugins = [plugin for plugin in plugins if plugin.type == filter_params.type] # filter plugins by state if filter_params.state is not None: - plugins = filter(lambda x: x.enabled is filter_params.state, plugins) + plugins = [plugin for plugin in plugins if plugin.enabled is filter_params.state] # filter plugins by oci_image if filter_params.oci_image is not None: - plugins = filter(lambda x: x.type == "boefje" and x.oci_image == filter_params.oci_image, plugins) + plugins = [ + plugin for plugin in plugins if plugin.type == "boefje" and plugin.oci_image == filter_params.oci_image + ] # filter plugins by scan level for boefje plugins - plugins = list(filter(lambda x: x.type != "boefje" or x.scan_level >= filter_params.scan_level, plugins)) + plugins = [plugin for plugin in plugins if plugin.type != "boefje" or plugin.scan_level >= filter_params.scan_level] if pagination_params.limit is None: return plugins[pagination_params.offset :] @@ -84,8 +81,7 @@ def get_plugin( plugin_id: str, organisation_id: str, plugin_service: PluginService = Depends(get_plugin_service) ) -> PluginType: try: - with plugin_service as p: - return p.by_plugin_id(plugin_id, organisation_id) + return plugin_service.by_plugin_id(plugin_id, organisation_id) except KeyError: raise HTTPException(status.HTTP_404_NOT_FOUND, "Plugin not found") diff --git a/boefjes/boefjes/katalogus/settings.py b/boefjes/boefjes/katalogus/settings.py index 07b9af453f2..4ec2b48cb83 100644 --- a/boefjes/boefjes/katalogus/settings.py +++ b/boefjes/boefjes/katalogus/settings.py @@ -1,19 +1,13 @@ from fastapi import APIRouter, Depends from boefjes.dependencies.plugins import PluginService, get_plugin_service -from boefjes.katalogus.organisations import check_organisation_exists -router = APIRouter( - prefix="/organisations/{organisation_id}/{plugin_id}/settings", - tags=["settings"], - dependencies=[Depends(check_organisation_exists)], -) +router = APIRouter(prefix="/organisations/{organisation_id}/{plugin_id}/settings", tags=["settings"]) @router.get("", response_model=dict) def list_settings(organisation_id: str, plugin_id: str, plugin_service: PluginService = Depends(get_plugin_service)): - with plugin_service as p: - return p.get_all_settings(organisation_id, plugin_id) + return plugin_service.get_all_settings(organisation_id, plugin_id) @router.put("") diff --git a/boefjes/boefjes/local_repository.py b/boefjes/boefjes/local_repository.py index 719b26f06a8..48944bf6323 100644 --- a/boefjes/boefjes/local_repository.py +++ b/boefjes/boefjes/local_repository.py @@ -1,7 +1,7 @@ import json import pkgutil +from functools import cache, lru_cache from pathlib import Path -from typing import Any import structlog @@ -15,6 +15,7 @@ BoefjeResource, ModuleException, NormalizerResource, + hash_path, ) logger = structlog.get_logger(__name__) @@ -23,16 +24,11 @@ class LocalPluginRepository: def __init__(self, path: Path): self.path = path - self._cached_boefjes: dict[str, Any] | None = None - self._cached_normalizers: dict[str, Any] | None = None def get_all(self) -> list[PluginType]: - all_plugins = [boefje_resource.boefje for boefje_resource in self.resolve_boefjes().values()] - normalizers = [normalizer_resource.normalizer for normalizer_resource in self.resolve_normalizers().values()] - - all_plugins += normalizers - - return all_plugins + boefjes = [resource.boefje for resource in self.resolve_boefjes().values()] + normalizers = [resource.normalizer for resource in self.resolve_normalizers().values()] + return boefjes + normalizers def by_id(self, plugin_id: str) -> PluginType: boefjes = self.resolve_boefjes() @@ -107,66 +103,83 @@ def description_path(self, id_: str) -> Path | None: return boefjes[id_].path / "description.md" def resolve_boefjes(self) -> dict[str, BoefjeResource]: - if self._cached_boefjes: - return self._cached_boefjes + return _cached_resolve_boefjes(self.path) + + def resolve_normalizers(self) -> dict[str, NormalizerResource]: + return _cached_resolve_normalizers(self.path) - paths_and_packages = self._find_packages_in_path_containing_files([BOEFJE_DEFINITION_FILE]) - boefje_resources = [] - for path, package in paths_and_packages: - try: - boefje_resources.append(BoefjeResource(path, package)) - except ModuleException as exc: - logger.exception(exc) +@cache +def _cached_resolve_boefjes(path: Path) -> dict[str, BoefjeResource]: + paths_and_packages = _find_packages_in_path_containing_files(path, (BOEFJE_DEFINITION_FILE,)) + boefje_resources = [] - self._cached_boefjes = {resource.boefje.id: resource for resource in boefje_resources} + for path, package in paths_and_packages: + try: + boefje_resources.append(get_boefje_resource(path, package, hash_path(path))) + except ModuleException as exc: + logger.exception(exc) - return self._cached_boefjes + return {resource.boefje.id: resource for resource in boefje_resources} - def resolve_normalizers(self) -> dict[str, NormalizerResource]: - if self._cached_normalizers: - return self._cached_normalizers - paths_and_packages = self._find_packages_in_path_containing_files( - [NORMALIZER_DEFINITION_FILE, ENTRYPOINT_NORMALIZERS] - ) - normalizer_resources = [] +@cache +def _cached_resolve_normalizers(path: Path) -> dict[str, NormalizerResource]: + paths_and_packages = _find_packages_in_path_containing_files( + path, (NORMALIZER_DEFINITION_FILE, ENTRYPOINT_NORMALIZERS) + ) + normalizer_resources = [] + + for path, package in paths_and_packages: + try: + normalizer_resources.append(get_normalizer_resource(path, package, hash(path))) + except ModuleException as exc: + logger.exception(exc) + + return {resource.normalizer.id: resource for resource in normalizer_resources} + + +def _find_packages_in_path_containing_files(path: Path, required_files: tuple[str, ...]) -> list[tuple[Path, str]]: + prefix = create_relative_import_statement_from_cwd(path) + paths = [] + + for package in pkgutil.walk_packages([str(path)], prefix): + if not package.ispkg: + logger.debug("%s is not a package", package.name) + continue + + new_path = path / package.name.replace(prefix, "").replace(".", "/") + missing_files = [file for file in required_files if not (new_path / file).exists()] + + if missing_files: + logger.debug("Files %s not found for %s", missing_files, package.name) + continue - for path, package in paths_and_packages: - try: - normalizer_resources.append(NormalizerResource(path, package)) - except ModuleException as exc: - logger.exception(exc) + paths.append((new_path, package.name)) - self._cached_normalizers = {resource.normalizer.id: resource for resource in normalizer_resources} + return paths - return self._cached_normalizers - def _find_packages_in_path_containing_files(self, required_files: list[str]) -> list[tuple[Path, str]]: - prefix = self.create_relative_import_statement_from_cwd(self.path) - paths = [] +def create_relative_import_statement_from_cwd(package_dir: Path) -> str: + relative_path = str(package_dir.absolute()).replace(str(Path.cwd()), "") # e.g. "/boefjes/plugins" - for package in pkgutil.walk_packages([str(self.path)], prefix): - if not package.ispkg: - logger.debug("%s is not a package", package.name) - continue + return f"{relative_path[1:].replace('/', '.')}." # Turns into "boefjes.plugins." - path = self.path / package.name.replace(prefix, "").replace(".", "/") - missing_files = [file for file in required_files if not (path / file).exists()] - if missing_files: - logger.debug("Files %s not found for %s", missing_files, package.name) - continue +@lru_cache(maxsize=200) +def get_boefje_resource(path: Path, package: str, path_hash: str): + """The cache size in theory only has to be the amount of local boefjes available, but 200 gives us some extra + space. Adding the hash to the arguments makes sure we refresh this.""" - paths.append((path, package.name)) + return BoefjeResource(path, package, path_hash) - return paths - @staticmethod - def create_relative_import_statement_from_cwd(package_dir: Path) -> str: - relative_path = str(package_dir.absolute()).replace(str(Path.cwd()), "") # e.g. "/boefjes/plugins" +@lru_cache(maxsize=200) +def get_normalizer_resource(path: Path, package: str, path_hash: str): + """The cache size in theory only has to be the amount of local normalizers available, but 200 gives us some extra + space. Adding the hash to the arguments makes sure we refresh this.""" - return f"{relative_path[1:].replace('/', '.')}." # Turns into "boefjes.plugins." + return NormalizerResource(path, package) def get_local_repository(): diff --git a/boefjes/boefjes/migrations/versions/5be152459a7b_introduce_schema_field_to_boefje_model.py b/boefjes/boefjes/migrations/versions/5be152459a7b_introduce_schema_field_to_boefje_model.py index 470385e1879..da8aba03bf1 100644 --- a/boefjes/boefjes/migrations/versions/5be152459a7b_introduce_schema_field_to_boefje_model.py +++ b/boefjes/boefjes/migrations/versions/5be152459a7b_introduce_schema_field_to_boefje_model.py @@ -34,7 +34,7 @@ def upgrade() -> None: plugins = local_repo.get_all() logger.info("Found %s plugins", len(plugins)) - for plugin in local_repo.get_all(): + for plugin in plugins: schema = local_repo.schema(plugin.id) if schema: query = text("UPDATE boefje SET schema = :schema WHERE plugin_id = :plugin_id") # noqa: S608 diff --git a/boefjes/boefjes/plugins/models.py b/boefjes/boefjes/plugins/models.py index b468124a6ff..e7753604486 100644 --- a/boefjes/boefjes/plugins/models.py +++ b/boefjes/boefjes/plugins/models.py @@ -57,10 +57,10 @@ def get_runnable_module_from_package(package: str, module_file: str, *, paramete class BoefjeResource: """Represents a Boefje package that we can run. Throws a ModuleException if any validation fails.""" - def __init__(self, path: Path, package: str): + def __init__(self, path: Path, package: str, path_hash: str): self.path = path self.boefje: Boefje = Boefje.model_validate_json(path.joinpath(BOEFJE_DEFINITION_FILE).read_text()) - self.boefje.runnable_hash = get_runnable_hash(self.path) + self.boefje.runnable_hash = path_hash self.boefje.produces = self.boefje.produces.union(set(_default_mime_types(self.boefje))) self.module: Runnable | None = None @@ -86,7 +86,7 @@ def __init__(self, path: Path, package: str): self.module = get_runnable_module_from_package(package, ENTRYPOINT_NORMALIZERS, parameter_count=2) -def get_runnable_hash(path: Path) -> str: +def hash_path(path: Path) -> str: """Returns sha256(file1 + file2 + ...) of all files in the given path.""" folder_hash = hashlib.sha256() diff --git a/boefjes/boefjes/sql/config_storage.py b/boefjes/boefjes/sql/config_storage.py index cd2e2b31db8..b944088bc32 100644 --- a/boefjes/boefjes/sql/config_storage.py +++ b/boefjes/boefjes/sql/config_storage.py @@ -88,7 +88,7 @@ def is_enabled_by_id(self, plugin_id: str, organisation_id: str) -> bool: def get_enabled_boefjes(self, organisation_id: str) -> list[str]: enabled_boefjes = ( - self.session.query(BoefjeInDB) + self.session.query(BoefjeInDB.plugin_id) .join(BoefjeConfigInDB) .filter(BoefjeConfigInDB.boefje_id == BoefjeInDB.id) .join(OrganisationInDB) @@ -96,8 +96,40 @@ def get_enabled_boefjes(self, organisation_id: str) -> list[str]: .filter(OrganisationInDB.id == organisation_id) .filter(BoefjeConfigInDB.enabled) ) + return [x[0] for x in enabled_boefjes.all()] - return [x.plugin_id for x in enabled_boefjes.all()] + def get_enabled_normalizers(self, organisation_id: str) -> list[str]: + enabled_normalizers = ( + self.session.query(NormalizerInDB.plugin_id) + .join(NormalizerConfigInDB) + .filter(NormalizerConfigInDB.normalizer_id == NormalizerInDB.id) + .join(OrganisationInDB) + .filter(NormalizerConfigInDB.organisation_pk == OrganisationInDB.pk) + .filter(OrganisationInDB.id == organisation_id) + .filter(NormalizerConfigInDB.enabled) + ) + + return [plugin[0] for plugin in enabled_normalizers.all()] + + def get_state_by_id(self, organisation_id: str) -> dict[str, bool]: + enabled_boefjes = ( + self.session.query(BoefjeInDB.plugin_id, BoefjeConfigInDB.enabled) + .join(BoefjeConfigInDB) + .filter(BoefjeConfigInDB.boefje_id == BoefjeInDB.id) + .join(OrganisationInDB) + .filter(BoefjeConfigInDB.organisation_pk == OrganisationInDB.pk) + .filter(OrganisationInDB.id == organisation_id) + ) + enabled_normalizers = ( + self.session.query(NormalizerInDB.plugin_id, NormalizerConfigInDB.enabled) + .join(NormalizerConfigInDB) + .filter(NormalizerConfigInDB.normalizer_id == NormalizerInDB.id) + .join(OrganisationInDB) + .filter(NormalizerConfigInDB.organisation_pk == OrganisationInDB.pk) + .filter(OrganisationInDB.id == organisation_id) + ) + + return {plugin[0]: plugin[1] for plugin in enabled_boefjes.union_all(enabled_normalizers).all()} def _db_instance_by_id(self, organisation_id: str, plugin_id: str) -> BoefjeConfigInDB | NormalizerConfigInDB: instance = ( diff --git a/boefjes/boefjes/storage/interfaces.py b/boefjes/boefjes/storage/interfaces.py index ccb9831b319..be6132b2cdd 100644 --- a/boefjes/boefjes/storage/interfaces.py +++ b/boefjes/boefjes/storage/interfaces.py @@ -153,3 +153,9 @@ def is_enabled_by_id(self, plugin_id: str, organisation_id: str) -> bool: def get_enabled_boefjes(self, organisation_id: str) -> list[str]: raise NotImplementedError + + def get_enabled_normalizers(self, organisation_id: str) -> list[str]: + raise NotImplementedError + + def get_state_by_id(self, organisation_id: str) -> dict[str, bool]: + raise NotImplementedError diff --git a/boefjes/boefjes/storage/memory.py b/boefjes/boefjes/storage/memory.py index df14785c240..1db992ba7f1 100644 --- a/boefjes/boefjes/storage/memory.py +++ b/boefjes/boefjes/storage/memory.py @@ -123,3 +123,13 @@ def get_enabled_boefjes(self, organisation_id: str) -> list[str]: for plugin_id, enabled in self._enabled.get(organisation_id, {}).items() if enabled and "norm" not in plugin_id ] + + def get_enabled_normalizers(self, organisation_id: str) -> list[str]: + return [ + plugin_id + for plugin_id, enabled in self._enabled.get(organisation_id, {}).items() + if enabled and "norm" in plugin_id + ] + + def get_state_by_id(self, organisation_id: str) -> dict[str, bool]: + return self._enabled.get(organisation_id, {}) diff --git a/boefjes/tests/conftest.py b/boefjes/tests/conftest.py index 7c74bb383d6..917af0d4a71 100644 --- a/boefjes/tests/conftest.py +++ b/boefjes/tests/conftest.py @@ -20,10 +20,16 @@ from boefjes.dependencies.plugins import PluginService, get_plugin_service from boefjes.job_handler import bytes_api_client from boefjes.job_models import BoefjeMeta, NormalizerMeta -from boefjes.katalogus.organisations import check_organisation_exists from boefjes.katalogus.root import app from boefjes.local import LocalBoefjeJobRunner, LocalNormalizerJobRunner -from boefjes.local_repository import LocalPluginRepository, get_local_repository +from boefjes.local_repository import ( + LocalPluginRepository, + _cached_resolve_boefjes, + _cached_resolve_normalizers, + get_boefje_resource, + get_local_repository, + get_normalizer_resource, +) from boefjes.models import Organisation from boefjes.runtime_interfaces import Handler, WorkerManager from boefjes.sql.config_storage import SQLConfigStorage, create_encrypter @@ -129,6 +135,14 @@ def get_all(self) -> list[BoefjeMeta | NormalizerMeta]: return [self.queue.get() for _ in range(self.queue.qsize())] +@pytest.fixture(autouse=True) +def clear_caches(): + get_boefje_resource.cache_clear() + get_normalizer_resource.cache_clear() + _cached_resolve_boefjes.cache_clear() + _cached_resolve_normalizers.cache_clear() + + @pytest.fixture def item_handler(tmp_path: Path): return MockHandler() @@ -255,7 +269,6 @@ def get_service(organisation_id: str): app.dependency_overrides[get_organisations_store] = lambda: _store app.dependency_overrides[get_plugin_service] = get_service - app.dependency_overrides[check_organisation_exists] = lambda: None yield client diff --git a/boefjes/tests/integration/test_api.py b/boefjes/tests/integration/test_api.py index 327274ce937..40dfb1c5b7b 100644 --- a/boefjes/tests/integration/test_api.py +++ b/boefjes/tests/integration/test_api.py @@ -20,7 +20,8 @@ def test_filter_plugins(test_client, organisation): assert len(response.json()) == 101 response = test_client.get(f"/v1/organisations/{organisation.id}/plugins?plugin_type=boefje") assert len(response.json()) == 45 - + response = test_client.get(f"/v1/organisations/{organisation.id}/plugins?state=true") + assert len(response.json()) == 62 response = test_client.get(f"/v1/organisations/{organisation.id}/plugins?limit=10") assert len(response.json()) == 10 diff --git a/boefjes/tests/integration/test_bench.py b/boefjes/tests/integration/test_bench.py index d017a33a5d2..dd9e055b90d 100644 --- a/boefjes/tests/integration/test_bench.py +++ b/boefjes/tests/integration/test_bench.py @@ -16,6 +16,7 @@ from tests.loading import get_boefje_meta, get_normalizer_meta +@pytest.mark.skipif("not os.getenv('DATABASE_MIGRATION')") @pytest.mark.slow def test_migration( octopoes_api_connector: OctopoesAPIConnector, @@ -115,3 +116,8 @@ def test_migration( assert observation.source_method == "boefje_udp" assert octopoes_api_connector.list_objects(set(), valid_time).count == total_oois + + +@pytest.mark.slow +def test_plugins_bench(plugin_service, organisation): + plugin_service.get_all(organisation.id) diff --git a/octopoes/octopoes/api/router.py b/octopoes/octopoes/api/router.py index b03d673b06e..62db165bd40 100644 --- a/octopoes/octopoes/api/router.py +++ b/octopoes/octopoes/api/router.py @@ -339,6 +339,24 @@ def save_declaration(declaration: ValidatedDeclaration, octopoes: OctopoesServic octopoes.commit() +@router.post("/declarations/save_many", tags=["Origins"]) +def save_many_declarations( + declarations: list[ValidatedDeclaration], octopoes: OctopoesService = Depends(octopoes_service) +) -> None: + for declaration in declarations: + origin = Origin( + origin_type=OriginType.DECLARATION, + method=declaration.method if declaration.method else "manual", + source=declaration.ooi.reference, + source_method=declaration.source_method, + result=[declaration.ooi.reference], + task_id=declaration.task_id if declaration.task_id else uuid.uuid4(), + ) + octopoes.save_origin(origin, [declaration.ooi], declaration.valid_time, declaration.end_valid_time) + + octopoes.commit() + + @router.post("/affirmations", tags=["Origins"]) def save_affirmation(affirmation: ValidatedAffirmation, octopoes: OctopoesService = Depends(octopoes_service)) -> None: origin = Origin( diff --git a/octopoes/octopoes/connector/octopoes.py b/octopoes/octopoes/connector/octopoes.py index 816e970b0b1..d76a03480e0 100644 --- a/octopoes/octopoes/connector/octopoes.py +++ b/octopoes/octopoes/connector/octopoes.py @@ -188,6 +188,15 @@ def save_declaration(self, declaration: Declaration) -> None: self.logger.info("Saved declaration", declaration=declaration, event_code=DECLARATION_CREATED) + def save_many_declarations(self, declarations: list[Declaration]) -> None: + self.session.post( + f"/{self.client}/declarations/save_many", + headers={"Content-Type": "application/json"}, + content=TypeAdapter(list[Declaration]).dump_json(declarations), + ) + + self.logger.info("Saved %s declarations", len(declarations), event_code=DECLARATION_CREATED) + def save_affirmation(self, affirmation: Affirmation) -> None: self.session.post( f"/{self.client}/affirmations", diff --git a/octopoes/tests/integration/test_api_connector.py b/octopoes/tests/integration/test_api_connector.py index 6795ecafaf8..17dac4724c6 100644 --- a/octopoes/tests/integration/test_api_connector.py +++ b/octopoes/tests/integration/test_api_connector.py @@ -67,7 +67,16 @@ def test_bulk_operations(octopoes_api_connector: OctopoesAPIConnector, valid_tim assert octopoes_api_connector.list_objects(types={Network, Hostname}, valid_time=valid_time).count == 6 with pytest.raises(ObjectNotFoundException): - octopoes_api_connector.delete_many(["bla"], valid_time=valid_time) + octopoes_api_connector.delete_many(["test"], valid_time=valid_time) + + assert len(octopoes_api_connector.list_origins(origin_type=OriginType.DECLARATION, valid_time=valid_time)) == 1 + + octopoes_api_connector.save_many_declarations([Declaration(ooi=h, valid_time=valid_time) for h in hostnames]) + + assert ( + len(octopoes_api_connector.list_origins(origin_type=OriginType.DECLARATION, valid_time=valid_time)) + == len(hostnames) + 1 + ) def test_history(octopoes_api_connector: OctopoesAPIConnector): diff --git a/rocky/rocky/views/upload_csv.py b/rocky/rocky/views/upload_csv.py index f2056b6aa3d..6690e9bd438 100644 --- a/rocky/rocky/views/upload_csv.py +++ b/rocky/rocky/views/upload_csv.py @@ -11,6 +11,7 @@ from django.urls.base import reverse_lazy from django.utils.translation import gettext as _ from django.views.generic.edit import FormView +from httpx import HTTPError from pydantic import ValidationError from tools.forms.upload_csv import CSV_ERRORS from tools.forms.upload_oois import UploadOOICSVForm @@ -158,15 +159,14 @@ def process_csv(self, form): csv_data = io.StringIO(csv_raw_data.decode("UTF-8")) rows_with_error = [] + oois = [] try: for row_number, row in enumerate(csv.DictReader(csv_data, delimiter=",", quotechar='"'), start=1): if not row: continue # skip empty lines try: ooi, level = self.get_ooi_from_csv(object_type, row) - self.octopoes_api_connector.save_declaration( - Declaration(ooi=ooi, valid_time=datetime.now(timezone.utc), task_id=task_id) - ) + oois.append(Declaration(ooi=ooi, valid_time=datetime.now(timezone.utc), task_id=task_id)) if isinstance(level, int): self.raise_clearance_level(ooi.reference, level) except ValidationError: @@ -179,3 +179,8 @@ def process_csv(self, form): self.add_success_notification(_("Object(s) successfully added.")) except (csv.Error, IndexError): return self.add_error_notification(CSV_ERRORS["csv_error"]) + + try: + self.octopoes_api_connector.save_many_declarations(oois) + except HTTPError: + return self.add_error_notification("Failed to save data from the CSV") diff --git a/rocky/tests/test_upload_csv.py b/rocky/tests/test_upload_csv.py index 8ef08718a8d..3a4d2b22312 100644 --- a/rocky/tests/test_upload_csv.py +++ b/rocky/tests/test_upload_csv.py @@ -122,7 +122,8 @@ def test_upload_csv( response = UploadCSV.as_view()(request, organization_code=redteam_member.organization.code) assert response.status_code == 302 - assert mock_organization_view_octopoes().save_declaration.call_count == expected_ooi_counts + assert mock_organization_view_octopoes().save_declaration.call_count == expected_ooi_counts / 2 + assert mock_organization_view_octopoes().save_many_declarations.call_count == 1 task_id = mock_bytes_client().add_manual_proof.call_args[0][0] mock_bytes_client().add_manual_proof.assert_called_once_with(