From ec9d80e87034ccc553433152d0bd04956f62129e Mon Sep 17 00:00:00 2001 From: Jeroen Dekkers Date: Tue, 10 Dec 2024 15:35:07 +0100 Subject: [PATCH] Fix typing in more places and configure mypy to follow imports (#3932) Co-authored-by: stephanie0x00 <9821756+stephanie0x00@users.noreply.github.com> Co-authored-by: ammar92 Co-authored-by: Jan Klopper --- .pre-commit-config.yaml | 1 + boefjes/boefjes/__main__.py | 2 +- boefjes/boefjes/api.py | 8 +++--- boefjes/boefjes/app.py | 6 ++-- boefjes/boefjes/dependencies/encryption.py | 4 +-- boefjes/boefjes/katalogus/root.py | 8 +++--- boefjes/boefjes/katalogus/settings.py | 10 +++++-- boefjes/boefjes/migrations/env.py | 2 +- ...fafdaf_json_settings_for_settings_table.py | 4 +-- boefjes/boefjes/plugins/kat_crt_sh/main.py | 8 +++++- .../plugins/kat_cve_2023_35078/normalize.py | 18 ++++++------ boefjes/boefjes/plugins/kat_dnssec/main.py | 6 ++-- .../plugins/kat_manual/csv/normalize.py | 10 +++---- boefjes/boefjes/plugins/kat_masscan/main.py | 2 +- boefjes/boefjes/plugins/kat_nmap_tcp/main.py | 6 ++-- .../kat_security_txt_downloader/main.py | 8 ++++-- .../boefjes/plugins/kat_snyk/check_version.py | 8 ++++-- .../plugins/kat_webpage_analysis/main.py | 9 +++--- .../plugins/kat_webpage_capture/main.py | 4 +-- boefjes/boefjes/plugins/models.py | 2 +- boefjes/boefjes/runtime_interfaces.py | 2 +- boefjes/boefjes/sql/db.py | 2 +- boefjes/boefjes/sql/organisation_storage.py | 2 +- bytes/bytes/api/metrics.py | 2 +- bytes/bytes/api/root.py | 2 +- bytes/bytes/api/router.py | 2 +- bytes/bytes/database/migrations/env.py | 2 +- bytes/bytes/database/sql_meta_repository.py | 2 +- bytes/bytes/models.py | 4 +-- bytes/bytes/rabbitmq.py | 2 +- bytes/bytes/raw/file_raw_repository.py | 9 +++++- cveapi/cveapi.py | 4 +-- mula/scheduler/__init__.py | 4 +-- mula/scheduler/clients/amqp/__init__.py | 2 ++ mula/scheduler/clients/amqp/listeners.py | 2 +- mula/scheduler/clients/connector.py | 3 +- mula/scheduler/context/__init__.py | 2 ++ mula/scheduler/context/context.py | 3 ++ mula/scheduler/schedulers/rankers/__init__.py | 2 ++ .../scheduler/schedulers/schedulers/boefje.py | 2 +- mula/scheduler/server/__init__.py | 2 ++ mula/scheduler/storage/filters/__init__.py | 2 ++ mula/scheduler/storage/filters/functions.py | 3 +- mula/scheduler/utils/__init__.py | 2 ++ mula/scheduler/utils/cron.py | 2 +- octopoes/bits/check_cve_2021_41773/bit.py | 2 +- .../check_cve_2021_41773.py | 2 +- octopoes/bits/check_hsts_header/bit.py | 2 +- .../check_hsts_header/check_hsts_header.py | 2 +- .../cipher_classification.py | 5 ++-- .../missing_certificate.py | 2 +- octopoes/bits/nxdomain_flag/bit.py | 2 +- octopoes/bits/nxdomain_flag/nxdomain_flag.py | 2 +- octopoes/bits/nxdomain_header_flag/bit.py | 2 +- .../nxdomain_header_flag.py | 2 +- octopoes/bits/oois_in_headers/bit.py | 2 +- .../bits/oois_in_headers/oois_in_headers.py | 4 +-- octopoes/bits/retire_js/retire_js.py | 8 +++--- octopoes/bits/runner.py | 10 ++----- octopoes/bits/spf_discovery/spf_discovery.py | 2 +- octopoes/octopoes/api/router.py | 8 +++--- octopoes/octopoes/connector/__init__.py | 6 +--- octopoes/octopoes/connector/octopoes.py | 10 +++---- octopoes/octopoes/core/app.py | 2 +- octopoes/octopoes/core/service.py | 6 ++-- octopoes/octopoes/models/__init__.py | 14 +++++----- octopoes/octopoes/models/ooi/dns/records.py | 2 +- octopoes/octopoes/models/ooi/findings.py | 6 ++-- octopoes/octopoes/models/ooi/network.py | 2 +- octopoes/octopoes/models/origin.py | 4 ++- octopoes/octopoes/models/path.py | 10 +++---- octopoes/octopoes/models/persistence.py | 4 +-- octopoes/octopoes/models/tree.py | 2 +- octopoes/octopoes/models/types.py | 6 ++-- .../octopoes/repositories/ooi_repository.py | 19 +++++++++---- .../origin_parameter_repository.py | 2 +- .../repositories/scan_profile_repository.py | 2 +- octopoes/octopoes/tasks/tasks.py | 5 ++-- octopoes/octopoes/xtdb/client.py | 10 +++---- octopoes/octopoes/xtdb/query.py | 28 ++++++++++--------- octopoes/octopoes/xtdb/query_builder.py | 3 +- .../octopoes/xtdb/related_field_generator.py | 8 +++--- pyproject.toml | 15 +++++++--- rocky/account/admin.py | 1 - rocky/account/forms/__init__.py | 16 +++++++++++ rocky/account/forms/organization.py | 14 +++++----- rocky/account/mixins.py | 7 +++-- rocky/account/models.py | 2 +- rocky/account/views/account.py | 2 +- rocky/crisis_room/views.py | 3 +- rocky/katalogus/client.py | 8 +++--- rocky/katalogus/forms/plugin_settings.py | 4 ++- rocky/katalogus/views/boefje_setup.py | 6 ++-- rocky/katalogus/views/mixins.py | 6 ++-- rocky/onboarding/view_helpers.py | 4 +-- rocky/onboarding/views.py | 7 ++--- rocky/reports/forms.py | 5 ++-- .../aggregate_organisation_report/report.py | 7 +++-- rocky/reports/report_types/definitions.py | 6 ++-- .../multi_organization_report/report.py | 4 ++- .../report_types/name_server_report/report.py | 8 ++++-- .../report_types/web_system_report/report.py | 8 ++++-- .../report_overview/report_history_table.html | 10 +++---- rocky/reports/views/aggregate_report.py | 4 +-- rocky/reports/views/base.py | 12 ++++---- rocky/reports/views/generate_report.py | 4 +-- rocky/reports/views/multi_report.py | 3 +- rocky/rocky/bytes_client.py | 4 +-- rocky/rocky/exceptions.py | 7 +++-- rocky/rocky/keiko.py | 6 ++-- rocky/rocky/locale/django.pot | 23 +++++++++------ rocky/rocky/middleware/onboarding.py | 8 +++--- rocky/rocky/paginator.py | 12 ++++---- rocky/rocky/scheduler.py | 6 ++-- rocky/rocky/views/finding_add.py | 3 +- rocky/rocky/views/finding_list.py | 8 +++--- rocky/rocky/views/health.py | 6 ++-- rocky/rocky/views/mixins.py | 10 +++---- .../rocky/views/ooi_detail_related_object.py | 2 +- rocky/rocky/views/ooi_list.py | 20 +++++++------ rocky/rocky/views/ooi_tree.py | 8 +++--- rocky/rocky/views/ooi_view.py | 6 ++-- rocky/rocky/views/organization_list.py | 4 +-- rocky/rocky/views/organization_member_add.py | 13 +++++---- rocky/rocky/views/organization_member_list.py | 6 ++-- rocky/rocky/views/organization_settings.py | 3 +- rocky/rocky/views/privacy_statement.py | 2 +- rocky/rocky/views/scan_profile.py | 2 +- rocky/rocky/views/scans.py | 3 +- rocky/rocky/views/upload_csv.py | 8 +++--- rocky/tools/add_ooi_information.py | 2 +- rocky/tools/admin.py | 6 ++-- rocky/tools/forms/base.py | 6 ++-- rocky/tools/forms/finding_type.py | 5 ++-- rocky/tools/forms/ooi.py | 6 ++-- rocky/tools/forms/ooi_form.py | 13 ++++++--- rocky/tools/forms/settings.py | 5 ++-- .../management/commands/export_migrations.py | 3 +- .../management/commands/generate_report.py | 6 ++-- .../management/commands/setup_test_users.py | 6 ++-- rocky/tools/models.py | 7 +++-- rocky/tools/ooi_helpers.py | 11 ++++---- rocky/tools/templatetags/ooi_extra.py | 14 +++++----- rocky/tools/view_helpers.py | 25 ++++++++--------- 144 files changed, 481 insertions(+), 370 deletions(-) diff --git a/.pre-commit-config.yaml b/.pre-commit-config.yaml index 89228a63f0f..e11535302b1 100644 --- a/.pre-commit-config.yaml +++ b/.pre-commit-config.yaml @@ -92,6 +92,7 @@ repos: - types-python-dateutil - types-requests - types-croniter + - boto3-stubs[s3] exclude: | (?x)( ^boefjes/tools | diff --git a/boefjes/boefjes/__main__.py b/boefjes/boefjes/__main__.py index aef329cf194..a6ab713b15c 100644 --- a/boefjes/boefjes/__main__.py +++ b/boefjes/boefjes/__main__.py @@ -37,7 +37,7 @@ @click.command() @click.argument("worker_type", type=click.Choice([q.value for q in WorkerManager.Queue])) @click.option("--log-level", type=click.Choice(["DEBUG", "INFO", "WARNING", "ERROR"]), help="Log level", default="INFO") -def cli(worker_type: str, log_level: str): +def cli(worker_type: str, log_level: str) -> None: logger.setLevel(log_level) logger.info("Starting runtime for %s", worker_type) diff --git a/boefjes/boefjes/api.py b/boefjes/boefjes/api.py index 28bd5d24c92..409508cba3a 100644 --- a/boefjes/boefjes/api.py +++ b/boefjes/boefjes/api.py @@ -33,7 +33,7 @@ def __init__(self, config: Config): self.server = Server(config=config) self.config = config - def stop(self): + def stop(self) -> None: self.terminate() def run(self, *args, **kwargs): @@ -88,7 +88,7 @@ def boefje_input( task_id: UUID, scheduler_client: SchedulerAPIClient = Depends(get_scheduler_client), plugin_service: PluginService = Depends(get_plugin_service), -): +) -> BoefjeInput: task = get_task(task_id, scheduler_client) if task.status is not TaskStatus.RUNNING: @@ -108,7 +108,7 @@ def boefje_output( scheduler_client: SchedulerAPIClient = Depends(get_scheduler_client), bytes_client: BytesAPIClient = Depends(get_bytes_client), plugin_service: PluginService = Depends(get_plugin_service), -): +) -> Response: task = get_task(task_id, scheduler_client) if task.status is not TaskStatus.RUNNING: @@ -127,7 +127,7 @@ def boefje_output( for file in boefje_output.files: raw = base64.b64decode(file.content) # when supported, also save file.name to Bytes - bytes_client.save_raw(task_id, raw, mime_types.union(file.tags)) + bytes_client.save_raw(task_id, raw, mime_types.union(file.tags) if file.tags else mime_types) if boefje_output.status == StatusEnum.COMPLETED: scheduler_client.patch_task(task_id, TaskStatus.COMPLETED) diff --git a/boefjes/boefjes/app.py b/boefjes/boefjes/app.py index 4077edffc45..af30d6319da 100644 --- a/boefjes/boefjes/app.py +++ b/boefjes/boefjes/app.py @@ -80,7 +80,7 @@ def run(self, queue_type: WorkerManager.Queue) -> None: raise - def _fill_queue(self, task_queue: Queue, queue_type: WorkerManager.Queue): + def _fill_queue(self, task_queue: Queue, queue_type: WorkerManager.Queue) -> None: if task_queue.qsize() > self.settings.pool_size: time.sleep(self.settings.worker_heartbeat) return @@ -189,7 +189,7 @@ def _cleanup_pending_worker_task(self, worker: BaseProcess) -> None: def _worker_args(self) -> tuple: return self.task_queue, self.item_handler, self.scheduler_client, self.handling_tasks - def exit(self, signum: int | None = None): + def exit(self, signum: int | None = None) -> None: try: if signum: logger.info("Received %s, exiting", signal.Signals(signum).name) @@ -238,7 +238,7 @@ def _start_working( handler: Handler, scheduler_client: SchedulerClientInterface, handling_tasks: dict[int, str], -): +) -> None: logger.info("Started listening for tasks from worker[pid=%s]", os.getpid()) while True: diff --git a/boefjes/boefjes/dependencies/encryption.py b/boefjes/boefjes/dependencies/encryption.py index 45e001a9fb1..43c56376b16 100644 --- a/boefjes/boefjes/dependencies/encryption.py +++ b/boefjes/boefjes/dependencies/encryption.py @@ -34,8 +34,8 @@ def __init__(self, private_key: str, public_key: str): def encode(self, contents: str) -> str: encrypted_contents = self.box.encrypt(contents.encode()) - encrypted_contents = base64.b64encode(encrypted_contents) - return encrypted_contents.decode() + base64_encrypted_contents = base64.b64encode(encrypted_contents) + return base64_encrypted_contents.decode() def decode(self, contents: str) -> str: encrypted_binary = base64.b64decode(contents) diff --git a/boefjes/boefjes/katalogus/root.py b/boefjes/boefjes/katalogus/root.py index 4320c6042b4..542e35a854b 100644 --- a/boefjes/boefjes/katalogus/root.py +++ b/boefjes/boefjes/katalogus/root.py @@ -73,22 +73,22 @@ @app.exception_handler(NotFound) -def entity_not_found_handler(request: Request, exc: NotFound): +def entity_not_found_handler(request: Request, exc: NotFound) -> JSONResponse: return JSONResponse(status_code=status.HTTP_404_NOT_FOUND, content={"message": exc.message}) @app.exception_handler(NotAllowed) -def not_allowed_handler(request: Request, exc: NotAllowed): +def not_allowed_handler(request: Request, exc: NotAllowed) -> JSONResponse: return JSONResponse(status_code=status.HTTP_400_BAD_REQUEST, content={"message": exc.message}) @app.exception_handler(IntegrityError) -def integrity_error_handler(request: Request, exc: IntegrityError): +def integrity_error_handler(request: Request, exc: IntegrityError) -> JSONResponse: return JSONResponse(status_code=status.HTTP_400_BAD_REQUEST, content={"message": exc.message}) @app.exception_handler(StorageError) -def storage_error_handler(request: Request, exc: StorageError): +def storage_error_handler(request: Request, exc: StorageError) -> JSONResponse: return JSONResponse(status_code=status.HTTP_500_INTERNAL_SERVER_ERROR, content={"message": exc.message}) diff --git a/boefjes/boefjes/katalogus/settings.py b/boefjes/boefjes/katalogus/settings.py index 4ec2b48cb83..24ef4198a92 100644 --- a/boefjes/boefjes/katalogus/settings.py +++ b/boefjes/boefjes/katalogus/settings.py @@ -6,19 +6,23 @@ @router.get("", response_model=dict) -def list_settings(organisation_id: str, plugin_id: str, plugin_service: PluginService = Depends(get_plugin_service)): +def list_settings( + organisation_id: str, plugin_id: str, plugin_service: PluginService = Depends(get_plugin_service) +) -> dict[str, str]: return plugin_service.get_all_settings(organisation_id, plugin_id) @router.put("") def upsert_settings( organisation_id: str, plugin_id: str, values: dict, plugin_service: PluginService = Depends(get_plugin_service) -): +) -> None: with plugin_service as p: p.upsert_settings(values, organisation_id, plugin_id) @router.delete("") -def remove_settings(organisation_id: str, plugin_id: str, plugin_service: PluginService = Depends(get_plugin_service)): +def remove_settings( + organisation_id: str, plugin_id: str, plugin_service: PluginService = Depends(get_plugin_service) +) -> None: with plugin_service as p: p.delete_settings(organisation_id, plugin_id) diff --git a/boefjes/boefjes/migrations/env.py b/boefjes/boefjes/migrations/env.py index 883101aee81..0ec4438da97 100644 --- a/boefjes/boefjes/migrations/env.py +++ b/boefjes/boefjes/migrations/env.py @@ -4,7 +4,7 @@ from sqlalchemy import engine_from_config, pool from boefjes.config import settings -from boefjes.sql.db_models import SQL_BASE +from boefjes.sql.db import SQL_BASE # this is the Alembic Config object, which provides # access to the values within the .ini file in use. diff --git a/boefjes/boefjes/migrations/versions/cd34fdfafdaf_json_settings_for_settings_table.py b/boefjes/boefjes/migrations/versions/cd34fdfafdaf_json_settings_for_settings_table.py index f76286dee3c..03489b76f85 100644 --- a/boefjes/boefjes/migrations/versions/cd34fdfafdaf_json_settings_for_settings_table.py +++ b/boefjes/boefjes/migrations/versions/cd34fdfafdaf_json_settings_for_settings_table.py @@ -42,7 +42,7 @@ def upgrade() -> None: # ### end Alembic commands ### -def upgrade_encrypted_settings(conn: Connection): +def upgrade_encrypted_settings(conn: Connection) -> None: encrypter = create_encrypter() with conn.begin(): @@ -90,7 +90,7 @@ def downgrade() -> None: # ### end Alembic commands ### -def downgrade_encrypted_settings(conn: Connection): +def downgrade_encrypted_settings(conn: Connection) -> None: encrypter = create_encrypter() with conn.begin(): diff --git a/boefjes/boefjes/plugins/kat_crt_sh/main.py b/boefjes/boefjes/plugins/kat_crt_sh/main.py index cb2d891bae0..318a45efaf1 100644 --- a/boefjes/boefjes/plugins/kat_crt_sh/main.py +++ b/boefjes/boefjes/plugins/kat_crt_sh/main.py @@ -31,7 +31,13 @@ ) -def request_certs(search_string, search_type="Identity", match="=", deduplicate=True, json_output=True) -> str: +def request_certs( + search_string: str, + search_type: str = "Identity", + match: str = "=", + deduplicate: bool = True, + json_output: bool = True, +) -> str: """Queries the public service CRT.sh for certificate information the searchtype can be specified and defaults to Identity. the type of sql matching can be specified and defaults to "=" diff --git a/boefjes/boefjes/plugins/kat_cve_2023_35078/normalize.py b/boefjes/boefjes/plugins/kat_cve_2023_35078/normalize.py index 92670e81f45..d93a87dc9c9 100644 --- a/boefjes/boefjes/plugins/kat_cve_2023_35078/normalize.py +++ b/boefjes/boefjes/plugins/kat_cve_2023_35078/normalize.py @@ -4,12 +4,12 @@ from octopoes.models import Reference from octopoes.models.ooi.findings import CVEFindingType, Finding from octopoes.models.ooi.software import Software, SoftwareInstance -from packaging import version +from packaging.version import Version, parse VULNERABLE_RANGES: list[tuple[str, str]] = [("0", "11.8.1.1"), ("11.9.0.0", "11.9.1.1"), ("11.10.0.0", "11.10.0.2")] -def extract_js_version(html_content: str) -> version.Version | bool: +def extract_js_version(html_content: str) -> Version | bool: telltale = "/mifs/scripts/auth.js?" telltale_position = html_content.find(telltale) if telltale_position == -1: @@ -20,10 +20,10 @@ def extract_js_version(html_content: str) -> version.Version | bool: version_string = html_content[telltale_position + len(telltale) : version_end] if not version_string: return False - return version.parse(" ".join(strip_vsp_and_build(version_string))) + return parse(" ".join(strip_vsp_and_build(version_string))) -def extract_css_version(html_content: str) -> version.Version | bool: +def extract_css_version(html_content: str) -> Version | bool: telltale = "/mifs/css/windowsAllAuth.css?" telltale_position = html_content.find(telltale) if telltale_position == -1: @@ -34,7 +34,7 @@ def extract_css_version(html_content: str) -> version.Version | bool: version_string = html_content[telltale_position + len(telltale) : version_end] if not version_string: return False - return version.parse(" ".join(strip_vsp_and_build(version_string))) + return parse(" ".join(strip_vsp_and_build(version_string))) def strip_vsp_and_build(url: str) -> Iterable[str]: @@ -47,9 +47,7 @@ def strip_vsp_and_build(url: str) -> Iterable[str]: yield part -def is_vulnerable_version( - vulnerable_ranges: list[tuple[version.Version, version.Version]], detected_version: version.Version -) -> bool: +def is_vulnerable_version(vulnerable_ranges: list[tuple[Version, Version]], detected_version: Version) -> bool: return any(start <= detected_version < end for start, end in vulnerable_ranges) @@ -70,11 +68,11 @@ def run(input_ooi: dict, raw: bytes) -> Iterable[NormalizerOutput]: yield software_instance if js_detected_version: vulnerable = is_vulnerable_version( - [(version.parse(start), version.parse(end)) for start, end in VULNERABLE_RANGES], js_detected_version + [(parse(start), parse(end)) for start, end in VULNERABLE_RANGES], js_detected_version ) else: # The CSS version only included the first two parts of the version number so we don't know the patch level - vulnerable = css_detected_version < version.parse("11.8") + vulnerable = css_detected_version < parse("11.8") if vulnerable: finding_type = CVEFindingType(id="CVE-2023-35078") finding = Finding( diff --git a/boefjes/boefjes/plugins/kat_dnssec/main.py b/boefjes/boefjes/plugins/kat_dnssec/main.py index 9fc385134f2..4af36d11220 100644 --- a/boefjes/boefjes/plugins/kat_dnssec/main.py +++ b/boefjes/boefjes/plugins/kat_dnssec/main.py @@ -2,7 +2,7 @@ import subprocess -def run(boefje_meta: dict): +def run(boefje_meta: dict) -> list[tuple[set, bytes | str]]: input_ = boefje_meta["arguments"]["input"] domain = input_["name"] @@ -19,6 +19,4 @@ def run(boefje_meta: dict): output.check_returncode() - results = [({"openkat/dnssec-output"}, output.stdout)] - - return results + return [({"openkat/dnssec-output"}, output.stdout)] diff --git a/boefjes/boefjes/plugins/kat_manual/csv/normalize.py b/boefjes/boefjes/plugins/kat_manual/csv/normalize.py index a25bf73f0d1..262ba102cbe 100644 --- a/boefjes/boefjes/plugins/kat_manual/csv/normalize.py +++ b/boefjes/boefjes/plugins/kat_manual/csv/normalize.py @@ -7,7 +7,7 @@ from pydantic import ValidationError from boefjes.job_models import NormalizerDeclaration, NormalizerOutput -from octopoes.models import Reference +from octopoes.models import OOI, Reference from octopoes.models.ooi.dns.zone import Hostname from octopoes.models.ooi.network import IPAddressV4, IPAddressV6, Network from octopoes.models.ooi.web import URL @@ -30,7 +30,7 @@ def run(input_ooi: dict, raw: bytes) -> Iterable[NormalizerOutput]: yield from process_csv(raw, reference_cache) -def process_csv(csv_raw_data, reference_cache) -> Iterable[NormalizerOutput]: +def process_csv(csv_raw_data: bytes, reference_cache: dict) -> Iterable[NormalizerOutput]: csv_data = io.StringIO(csv_raw_data.decode("UTF-8")) object_type = get_object_type(csv_data) @@ -74,7 +74,7 @@ def get_object_type(csv_data: io.StringIO) -> str: def get_ooi_from_csv( - ooi_type_name: str, values: dict[str, str], reference_cache + ooi_type_name: str, values: dict[str, str], reference_cache: dict ) -> tuple[OOIType, list[NormalizerDeclaration]]: skip_properties = ("object_type", "scan_profile", "primary_key") @@ -85,7 +85,7 @@ def get_ooi_from_csv( if field not in skip_properties ] - kwargs = {} + kwargs: dict[str, Reference | str | None] = {} extra_declarations: list[NormalizerDeclaration] = [] for field, is_reference, required in ooi_fields: @@ -109,7 +109,7 @@ def get_ooi_from_csv( return ooi_type(**kwargs), extra_declarations -def get_or_create_reference(ooi_type_name: str, value: str | None, reference_cache): +def get_or_create_reference(ooi_type_name: str, value: str | None, reference_cache: dict) -> OOI: ooi_type_name = next(filter(lambda x: x.casefold() == ooi_type_name.casefold(), OOI_TYPES.keys())) # get from cache diff --git a/boefjes/boefjes/plugins/kat_masscan/main.py b/boefjes/boefjes/plugins/kat_masscan/main.py index 6ef83ece35c..f2604d00dc3 100644 --- a/boefjes/boefjes/plugins/kat_masscan/main.py +++ b/boefjes/boefjes/plugins/kat_masscan/main.py @@ -10,7 +10,7 @@ FILE_PATH = "/tmp/output.json" # noqa: S108 -def run_masscan(target_ip) -> bytes: +def run_masscan(target_ip: str) -> bytes: """Run Masscan in Docker.""" client = docker.from_env() diff --git a/boefjes/boefjes/plugins/kat_nmap_tcp/main.py b/boefjes/boefjes/plugins/kat_nmap_tcp/main.py index 3b14b10292c..9e9d052c2a0 100644 --- a/boefjes/boefjes/plugins/kat_nmap_tcp/main.py +++ b/boefjes/boefjes/plugins/kat_nmap_tcp/main.py @@ -5,7 +5,7 @@ TOP_PORTS_DEFAULT = 250 -def run(boefje_meta: dict): +def run(boefje_meta: dict) -> list[tuple[set, bytes | str]]: top_ports_key = "TOP_PORTS" if boefje_meta["boefje"]["id"] == "nmap-udp": top_ports_key = "TOP_PORTS_UDP" @@ -22,6 +22,4 @@ def run(boefje_meta: dict): output.check_returncode() - results = [({"openkat/nmap-output"}, output.stdout.decode())] - - return results + return [({"openkat/nmap-output"}, output.stdout.decode())] diff --git a/boefjes/boefjes/plugins/kat_security_txt_downloader/main.py b/boefjes/boefjes/plugins/kat_security_txt_downloader/main.py index 48538de257c..bad69556c3b 100644 --- a/boefjes/boefjes/plugins/kat_security_txt_downloader/main.py +++ b/boefjes/boefjes/plugins/kat_security_txt_downloader/main.py @@ -5,6 +5,7 @@ import requests from forcediphttpsadapter.adapters import ForcedIPHTTPSAdapter from requests import Session +from requests.models import Response from boefjes.job_models import BoefjeMeta @@ -41,7 +42,10 @@ def run(boefje_meta: BoefjeMeta) -> list[tuple[set, bytes | str]]: elif response.status_code in [301, 302, 307, 308]: uri = response.headers["Location"] response = requests.get(uri, stream=True, timeout=30, verify=False) # noqa: S501 - ip = response.raw._connection.sock.getpeername()[0] + if response.raw._connection: + ip = response.raw._connection.sock.getpeername()[0] + else: + ip = "" results[path] = { "content": response.content.decode(), "url": response.url, @@ -53,7 +57,7 @@ def run(boefje_meta: BoefjeMeta) -> list[tuple[set, bytes | str]]: return [(set(), json.dumps(results))] -def do_request(hostname: str, session: Session, uri: str, useragent: str): +def do_request(hostname: str, session: Session, uri: str, useragent: str) -> Response: response = session.get( uri, headers={"Host": hostname, "User-Agent": useragent}, verify=False, allow_redirects=False ) diff --git a/boefjes/boefjes/plugins/kat_snyk/check_version.py b/boefjes/boefjes/plugins/kat_snyk/check_version.py index 442978e0f4c..a758fc12246 100644 --- a/boefjes/boefjes/plugins/kat_snyk/check_version.py +++ b/boefjes/boefjes/plugins/kat_snyk/check_version.py @@ -70,7 +70,7 @@ def check_version(version1: str, version2: str) -> VersionCheck: return check_version(version1_split[1], version2_split[1]) -def check_version_agains_versionlist(my_version: str, all_versions: list[str]): +def check_version_agains_versionlist(my_version: str, all_versions: list[str]) -> tuple[bool, list[str] | None]: lowerbound = all_versions.pop(0).strip() upperbound = None @@ -164,10 +164,12 @@ def check_version_agains_versionlist(my_version: str, all_versions: list[str]): return True, all_versions -def check_version_in(version: str, versions: str): +def check_version_in(version: str, versions: str) -> bool: if not version: return False - all_versions = versions.split(",") # Example: https://snyk.io/vuln/composer%3Awoocommerce%2Fwoocommerce-blocks + all_versions: list[str] | None = versions.split( + "," + ) # Example: https://snyk.io/vuln/composer%3Awoocommerce%2Fwoocommerce-blocks in_range = False while not in_range and all_versions: in_range, all_versions = check_version_agains_versionlist(version, all_versions) diff --git a/boefjes/boefjes/plugins/kat_webpage_analysis/main.py b/boefjes/boefjes/plugins/kat_webpage_analysis/main.py index f9fef3921fa..65d124c9d60 100644 --- a/boefjes/boefjes/plugins/kat_webpage_analysis/main.py +++ b/boefjes/boefjes/plugins/kat_webpage_analysis/main.py @@ -7,6 +7,7 @@ import requests from forcediphttpsadapter.adapters import ForcedIPHTTPSAdapter from requests import Session +from requests.models import Response from boefjes.job_models import BoefjeMeta @@ -54,9 +55,9 @@ def run(boefje_meta: BoefjeMeta) -> list[tuple[set, bytes | str]]: body_mimetypes.add(content_type) # Pick up the content type for the body from the server and split away encodings to make normalization easier - content_type = content_type.split(";") - if content_type[0] in ALLOWED_CONTENT_TYPES: - body_mimetypes.add(content_type[0]) + content_type_splitted = content_type.split(";") + if content_type_splitted[0] in ALLOWED_CONTENT_TYPES: + body_mimetypes.add(content_type_splitted[0]) # in case of a full response object, we hexdump to avoid issues with binary data or different encoding response_dump = json.dumps(create_response_object(response)) @@ -87,7 +88,7 @@ def create_response_object(response: requests.Response) -> dict: } -def do_request(hostname: str, session: Session, uri: str, useragent: str): +def do_request(hostname: str, session: Session, uri: str, useragent: str) -> Response: response = session.get( uri, headers={"Host": hostname, "User-Agent": useragent}, verify=False, allow_redirects=False ) diff --git a/boefjes/boefjes/plugins/kat_webpage_capture/main.py b/boefjes/boefjes/plugins/kat_webpage_capture/main.py index 6ee7e8dd44e..afecb4bbc89 100644 --- a/boefjes/boefjes/plugins/kat_webpage_capture/main.py +++ b/boefjes/boefjes/plugins/kat_webpage_capture/main.py @@ -10,11 +10,11 @@ class WebpageCaptureException(Exception): """Exception raised when webpage capture fails.""" - def __init__(self, message, container_log=None): + def __init__(self, message: str, container_log: str): self.message = message self.container_log = container_log - def __str__(self): + def __str__(self) -> str: return str(self.message) + "\n\nContainer log:\n" + self.container_log diff --git a/boefjes/boefjes/plugins/models.py b/boefjes/boefjes/plugins/models.py index e7753604486..364f14fafb9 100644 --- a/boefjes/boefjes/plugins/models.py +++ b/boefjes/boefjes/plugins/models.py @@ -102,7 +102,7 @@ def hash_path(path: Path) -> str: return folder_hash.hexdigest() -def _default_mime_types(boefje: Boefje): +def _default_mime_types(boefje: Boefje) -> set: mime_types = {f"boefje/{boefje.id}"} if boefje.version is not None: diff --git a/boefjes/boefjes/runtime_interfaces.py b/boefjes/boefjes/runtime_interfaces.py index 0a8375bdb86..70bacfb5554 100644 --- a/boefjes/boefjes/runtime_interfaces.py +++ b/boefjes/boefjes/runtime_interfaces.py @@ -4,7 +4,7 @@ class Handler: - def handle(self, item: BoefjeMeta | NormalizerMeta): + def handle(self, item: BoefjeMeta | NormalizerMeta) -> None: raise NotImplementedError() diff --git a/boefjes/boefjes/sql/db.py b/boefjes/boefjes/sql/db.py index 97fa858b75b..ea3fad670c2 100644 --- a/boefjes/boefjes/sql/db.py +++ b/boefjes/boefjes/sql/db.py @@ -52,5 +52,5 @@ def session_managed_iterator(service_factory: Callable[[Session], Any]) -> Itera class ObjectNotFoundException(Exception): - def __init__(self, cls: type | UnionType, **kwargs): # type: ignore + def __init__(self, cls: type | UnionType, **kwargs): super().__init__(f"The object of type {cls} was not found for query parameters {kwargs}") diff --git a/boefjes/boefjes/sql/organisation_storage.py b/boefjes/boefjes/sql/organisation_storage.py index 2124f9a2901..6c69f28abed 100644 --- a/boefjes/boefjes/sql/organisation_storage.py +++ b/boefjes/boefjes/sql/organisation_storage.py @@ -59,7 +59,7 @@ def to_organisation(organisation_in_db: OrganisationInDB) -> Organisation: return Organisation(id=organisation_in_db.id, name=organisation_in_db.name) -def create_organisation_storage(session) -> SQLOrganisationStorage: +def create_organisation_storage(session: Session) -> SQLOrganisationStorage: return SQLOrganisationStorage(session, settings) diff --git a/bytes/bytes/api/metrics.py b/bytes/bytes/api/metrics.py index d6a44010b4a..0d1f760cedf 100644 --- a/bytes/bytes/api/metrics.py +++ b/bytes/bytes/api/metrics.py @@ -23,7 +23,7 @@ logger = structlog.get_logger(__name__) -def ignore_arguments_key(meta_repository: MetaDataRepository): +def ignore_arguments_key(meta_repository: MetaDataRepository) -> str: return "" diff --git a/bytes/bytes/api/root.py b/bytes/bytes/api/root.py index 53661005c33..399bfa588f7 100644 --- a/bytes/bytes/api/root.py +++ b/bytes/bytes/api/root.py @@ -47,7 +47,7 @@ def health() -> ServiceHealth: @router.get("/metrics", dependencies=[Depends(authenticate_token)]) -def metrics(meta_repository: MetaDataRepository = Depends(create_meta_data_repository)): +def metrics(meta_repository: MetaDataRepository = Depends(create_meta_data_repository)) -> Response: collector_registry = get_registry(meta_repository) data = prometheus_client.generate_latest(collector_registry) diff --git a/bytes/bytes/api/router.py b/bytes/bytes/api/router.py index 0b6f0cb901a..716177cc8d6 100644 --- a/bytes/bytes/api/router.py +++ b/bytes/bytes/api/router.py @@ -271,7 +271,7 @@ def get_raw_count_per_mime_type( return cached_counts_per_mime_type(meta_repository, query_filter) -def ignore_arguments_key(meta_repository: MetaDataRepository, query_filter: RawDataFilter): +def ignore_arguments_key(meta_repository: MetaDataRepository, query_filter: RawDataFilter) -> str: """Helper to not cache based on the stateful meta_repository, but only use the query parameters as a key.""" return query_filter.model_dump_json() diff --git a/bytes/bytes/database/migrations/env.py b/bytes/bytes/database/migrations/env.py index f91d8947d18..15d1e25a528 100644 --- a/bytes/bytes/database/migrations/env.py +++ b/bytes/bytes/database/migrations/env.py @@ -6,7 +6,7 @@ # this is the Alembic Config object, which provides # access to the values within the .ini file in use. from bytes.config import get_settings -from bytes.database.db_models import SQL_BASE +from bytes.database.db import SQL_BASE config = context.config diff --git a/bytes/bytes/database/sql_meta_repository.py b/bytes/bytes/database/sql_meta_repository.py index 93ee36cdc6f..e912d0fb409 100644 --- a/bytes/bytes/database/sql_meta_repository.py +++ b/bytes/bytes/database/sql_meta_repository.py @@ -229,7 +229,7 @@ def create_meta_data_repository() -> Iterator[MetaDataRepository]: class ObjectNotFoundException(Exception): - def __init__(self, cls: type[SQL_BASE], **kwargs): + def __init__(self, cls: type[SQL_BASE], **kwargs: str): super().__init__(f"The object of type {cls} was not found for query parameters {kwargs}") diff --git a/bytes/bytes/models.py b/bytes/bytes/models.py index 03ae39506aa..44ea08c8c85 100644 --- a/bytes/bytes/models.py +++ b/bytes/bytes/models.py @@ -38,10 +38,10 @@ def _validate_timezone_aware_datetime(value: datetime) -> datetime: class MimeType(BaseModel): value: str - def __hash__(self): + def __hash__(self) -> int: return hash(self.value) - def __lt__(self, other: MimeType): + def __lt__(self, other: MimeType) -> bool: return self.value < other.value diff --git a/bytes/bytes/rabbitmq.py b/bytes/bytes/rabbitmq.py index 4ffa498ddd8..8ae5b83f446 100644 --- a/bytes/bytes/rabbitmq.py +++ b/bytes/bytes/rabbitmq.py @@ -41,7 +41,7 @@ def publish(self, event: Event) -> None: logger.info("Published event [event_id=%s] to queue %s", event.event_id, queue_name) - def _check_connection(self): + def _check_connection(self) -> None: if self.connection.is_closed: self.connection = pika.BlockingConnection(pika.URLParameters(self.queue_uri)) self.channel = self.connection.channel() diff --git a/bytes/bytes/raw/file_raw_repository.py b/bytes/bytes/raw/file_raw_repository.py index 6d0fd6c843b..605a6505dc5 100644 --- a/bytes/bytes/raw/file_raw_repository.py +++ b/bytes/bytes/raw/file_raw_repository.py @@ -1,5 +1,8 @@ +from __future__ import annotations + import logging from pathlib import Path +from typing import TYPE_CHECKING from uuid import UUID import structlog @@ -14,6 +17,10 @@ logger = structlog.get_logger(__name__) +if TYPE_CHECKING: + from mypy_boto3_s3.service_resource import Bucket + + def create_raw_repository(settings: Settings) -> RawRepository: if settings.s3_bucket_name or settings.s3_bucket_prefix: return S3RawRepository( @@ -87,7 +94,7 @@ def __init__( set_boto3_stream_logger("", logging.WARNING) self._s3resource = BotoSession().resource("s3") - def get_or_create_bucket(self, organization: str): + def get_or_create_bucket(self, organization: str) -> Bucket: # Create a bucket, and if it exists already return that instead bucket_name = f"{self.s3_bucket_prefix}{organization}" if self.bucket_per_org else self.s3_bucket_name diff --git a/cveapi/cveapi.py b/cveapi/cveapi.py index 09fcab2aa27..0fdc4d1765f 100644 --- a/cveapi/cveapi.py +++ b/cveapi/cveapi.py @@ -13,7 +13,7 @@ logger = logging.getLogger("cveapi") -def download_files(directory, last_update, update_timestamp): +def download_files(directory: pathlib.Path, last_update: datetime | None, update_timestamp: datetime) -> None: index = 0 client = httpx.Client() error_count = 0 @@ -66,7 +66,7 @@ def download_files(directory, last_update, update_timestamp): logger.info("Downloaded new information of %s CVEs", response_json["totalResults"]) -def run(): +def run() -> None: loglevel = os.getenv("CVEAPI_LOGLEVEL", "INFO") numeric_level = getattr(logging, loglevel.upper(), None) if not isinstance(numeric_level, int): diff --git a/mula/scheduler/__init__.py b/mula/scheduler/__init__.py index 211a13feb2f..3bd881035da 100644 --- a/mula/scheduler/__init__.py +++ b/mula/scheduler/__init__.py @@ -1,4 +1,4 @@ from .app import App -from .version import version +from .version import __version__ -__version__ = version +__all__ = ["App", "__version__"] diff --git a/mula/scheduler/clients/amqp/__init__.py b/mula/scheduler/clients/amqp/__init__.py index ba7a3fab3ef..85194b47cb5 100644 --- a/mula/scheduler/clients/amqp/__init__.py +++ b/mula/scheduler/clients/amqp/__init__.py @@ -1,3 +1,5 @@ from .listeners import Listener, RabbitMQ from .raw_data import RawData from .scan_profile import ScanProfileMutation + +__all__ = ["Listener", "RabbitMQ", "RawData", "ScanProfileMutation"] diff --git a/mula/scheduler/clients/amqp/listeners.py b/mula/scheduler/clients/amqp/listeners.py index 10ab1665476..6a8955a9107 100644 --- a/mula/scheduler/clients/amqp/listeners.py +++ b/mula/scheduler/clients/amqp/listeners.py @@ -181,7 +181,7 @@ def callback( # Submit the message to the thread pool executor self.executor.submit(self.dispatch, channel, method.delivery_tag, body) - def dispatch(self, channel, delivery_tag, body: bytes) -> None: + def dispatch(self, channel: pika.channel.Channel, delivery_tag: int, body: bytes) -> None: # Check if we still have a connection if self.connection is None or self.connection.is_closed: self.logger.debug("No connection available, cannot dispatch message!") diff --git a/mula/scheduler/clients/connector.py b/mula/scheduler/clients/connector.py index 857fd0ad1c4..7fdb34949a2 100644 --- a/mula/scheduler/clients/connector.py +++ b/mula/scheduler/clients/connector.py @@ -1,6 +1,7 @@ import socket import time from collections.abc import Callable +from typing import Any import httpx import structlog @@ -47,7 +48,7 @@ def is_host_healthy(self, host: str, health_endpoint: str) -> bool: self.logger.warning("Exception: %s", exc) return False - def retry(self, func: Callable, *args, **kwargs) -> bool: + def retry(self, func: Callable, *args: Any, **kwargs: Any) -> bool: """Retry a function until it returns True. Args: diff --git a/mula/scheduler/context/__init__.py b/mula/scheduler/context/__init__.py index 61627b2c729..ae686bb6e0c 100644 --- a/mula/scheduler/context/__init__.py +++ b/mula/scheduler/context/__init__.py @@ -1 +1,3 @@ from .context import AppContext + +__all__ = ["AppContext"] diff --git a/mula/scheduler/context/context.py b/mula/scheduler/context/context.py index 00b4d8f16f4..540a86ba545 100644 --- a/mula/scheduler/context/context.py +++ b/mula/scheduler/context/context.py @@ -34,6 +34,9 @@ class AppContext: the schedulers. """ + metrics_qsize: Gauge + metrics_task_status_counts: Gauge + def __init__(self) -> None: """Initializer of the AppContext class.""" self.config: settings.Settings = settings.Settings() diff --git a/mula/scheduler/schedulers/rankers/__init__.py b/mula/scheduler/schedulers/rankers/__init__.py index 9fd14d2b21e..f2467f78e87 100644 --- a/mula/scheduler/schedulers/rankers/__init__.py +++ b/mula/scheduler/schedulers/rankers/__init__.py @@ -1,3 +1,5 @@ from .boefje import BoefjeRanker, BoefjeRankerTimeBased from .normalizer import NormalizerRanker from .ranker import Ranker + +__all__ = ["BoefjeRanker", "NormalizerRanker", "Ranker"] diff --git a/mula/scheduler/schedulers/schedulers/boefje.py b/mula/scheduler/schedulers/schedulers/boefje.py index 5b9fc5653fb..260b5cb40db 100644 --- a/mula/scheduler/schedulers/schedulers/boefje.py +++ b/mula/scheduler/schedulers/schedulers/boefje.py @@ -926,7 +926,7 @@ def has_boefje_task_grace_period_passed(self, task: BoefjeTask) -> bool: return True - def get_boefjes_for_ooi(self, ooi) -> list[Plugin]: + def get_boefjes_for_ooi(self, ooi: OOI) -> list[Plugin]: """Get available all boefjes (enabled and disabled) for an ooi. Args: diff --git a/mula/scheduler/server/__init__.py b/mula/scheduler/server/__init__.py index b7f2cf59516..09ed39ca17f 100644 --- a/mula/scheduler/server/__init__.py +++ b/mula/scheduler/server/__init__.py @@ -1 +1,3 @@ from .server import Server + +__all__ = ["Server"] diff --git a/mula/scheduler/storage/filters/__init__.py b/mula/scheduler/storage/filters/__init__.py index ddf32f56ef3..eb44f2d36e1 100644 --- a/mula/scheduler/storage/filters/__init__.py +++ b/mula/scheduler/storage/filters/__init__.py @@ -1,3 +1,5 @@ from .casting import cast_expression from .filters import Filter, FilterRequest from .functions import apply_filter + +__all__ = ["cast_expression", "Filter", "FilterRequest", "apply_filter"] diff --git a/mula/scheduler/storage/filters/functions.py b/mula/scheduler/storage/filters/functions.py index 805d7d0591d..f50aeb88fb3 100644 --- a/mula/scheduler/storage/filters/functions.py +++ b/mula/scheduler/storage/filters/functions.py @@ -1,4 +1,5 @@ import sqlalchemy +from sqlalchemy.orm import DeclarativeBase from sqlalchemy.orm.query import Query from sqlalchemy.sql.elements import BinaryExpression @@ -9,7 +10,7 @@ from .operators import FILTER_OPERATORS -def apply_filter(entity, query: Query, filter_request: FilterRequest) -> Query: +def apply_filter(entity: DeclarativeBase, query: Query, filter_request: FilterRequest) -> Query: """Apply the filter criteria to a SQLAlchemy query. Args: diff --git a/mula/scheduler/utils/__init__.py b/mula/scheduler/utils/__init__.py index 47047ddb189..cb8b2af51d1 100644 --- a/mula/scheduler/utils/__init__.py +++ b/mula/scheduler/utils/__init__.py @@ -2,3 +2,5 @@ from .dict_utils import ExpiredError, ExpiringDict, deep_get from .functions import remove_trailing_slash from .thread import ThreadRunner + +__all__ = ["GUID", "ExpiredError", "ExpiringDict", "deep_get", "remove_trailing_slash", "ThreadRunner"] diff --git a/mula/scheduler/utils/cron.py b/mula/scheduler/utils/cron.py index ac632c9ab71..45be36accb4 100644 --- a/mula/scheduler/utils/cron.py +++ b/mula/scheduler/utils/cron.py @@ -1,6 +1,6 @@ from datetime import datetime, timezone -from croniter import croniter # type: ignore +from croniter import croniter def next_run(expression: str, start_time: datetime | None = None) -> datetime: diff --git a/octopoes/bits/check_cve_2021_41773/bit.py b/octopoes/bits/check_cve_2021_41773/bit.py index 367183e5f8d..3e32a458c9c 100644 --- a/octopoes/bits/check_cve_2021_41773/bit.py +++ b/octopoes/bits/check_cve_2021_41773/bit.py @@ -1,5 +1,5 @@ from bits.definitions import BitDefinition -from octopoes.models.types import HTTPHeader +from octopoes.models.ooi.web import HTTPHeader BIT = BitDefinition( id="check_cve_2021_41773", diff --git a/octopoes/bits/check_cve_2021_41773/check_cve_2021_41773.py b/octopoes/bits/check_cve_2021_41773/check_cve_2021_41773.py index d07b81268c0..83076aed418 100644 --- a/octopoes/bits/check_cve_2021_41773/check_cve_2021_41773.py +++ b/octopoes/bits/check_cve_2021_41773/check_cve_2021_41773.py @@ -3,7 +3,7 @@ from octopoes.models import OOI from octopoes.models.ooi.findings import CVEFindingType, Finding -from octopoes.models.types import HTTPHeader +from octopoes.models.ooi.web import HTTPHeader def run(input_ooi: HTTPHeader, additional_oois: list, config: dict[str, Any]) -> Iterator[OOI]: diff --git a/octopoes/bits/check_hsts_header/bit.py b/octopoes/bits/check_hsts_header/bit.py index 5e8436084f6..6b98f2d86a9 100644 --- a/octopoes/bits/check_hsts_header/bit.py +++ b/octopoes/bits/check_hsts_header/bit.py @@ -1,5 +1,5 @@ from bits.definitions import BitDefinition -from octopoes.models.types import HTTPHeader +from octopoes.models.ooi.web import HTTPHeader BIT = BitDefinition( id="check-hsts-header", diff --git a/octopoes/bits/check_hsts_header/check_hsts_header.py b/octopoes/bits/check_hsts_header/check_hsts_header.py index 824e6948ace..de92e5e38dc 100644 --- a/octopoes/bits/check_hsts_header/check_hsts_header.py +++ b/octopoes/bits/check_hsts_header/check_hsts_header.py @@ -4,7 +4,7 @@ from octopoes.models import OOI, Reference from octopoes.models.ooi.findings import Finding, KATFindingType -from octopoes.models.types import HTTPHeader +from octopoes.models.ooi.web import HTTPHeader def run(input_ooi: HTTPHeader, additional_oois: list, config: dict[str, Any]) -> Iterator[OOI]: diff --git a/octopoes/bits/cipher_classification/cipher_classification.py b/octopoes/bits/cipher_classification/cipher_classification.py index 5f701df86cb..88cf99765b3 100644 --- a/octopoes/bits/cipher_classification/cipher_classification.py +++ b/octopoes/bits/cipher_classification/cipher_classification.py @@ -1,6 +1,7 @@ import csv from collections.abc import Iterator from pathlib import Path +from typing import Any from octopoes.models import OOI from octopoes.models.ooi.findings import Finding, KATFindingType @@ -13,7 +14,7 @@ } -def get_severity_and_reasons(cipher_suite) -> list[tuple[str, str]]: +def get_severity_and_reasons(cipher_suite: str) -> list[tuple[str, str]]: with Path.open(Path(__file__).parent / "list-ciphers-openssl-with-finding-type.csv", newline="") as csvfile: reader = csv.DictReader(csvfile) data = [{k.strip(): v.strip() for k, v in row.items() if k} for row in reader] @@ -76,7 +77,7 @@ def get_highest_severity_and_all_reasons(cipher_suites: dict) -> tuple[str, str] return highest_severity, all_reasons_str -def run(input_ooi: TLSCipher, additional_oois, config) -> Iterator[OOI]: +def run(input_ooi: TLSCipher, additional_oois: list, config: dict[str, Any]) -> Iterator[OOI]: # Get the highest severity and all reasons for the cipher suite highest_severity, all_reasons = get_highest_severity_and_all_reasons(input_ooi.suites) diff --git a/octopoes/bits/missing_certificate/missing_certificate.py b/octopoes/bits/missing_certificate/missing_certificate.py index 04721d51dc6..ae08a2a214e 100644 --- a/octopoes/bits/missing_certificate/missing_certificate.py +++ b/octopoes/bits/missing_certificate/missing_certificate.py @@ -6,7 +6,7 @@ from octopoes.models.ooi.web import Website -def run(input_ooi: Website, additional_oois, config: dict[str, Any]) -> Iterator[OOI]: +def run(input_ooi: Website, additional_oois: list, config: dict[str, Any]) -> Iterator[OOI]: if input_ooi.ip_service.tokenized.service.name.lower() != "https": return diff --git a/octopoes/bits/nxdomain_flag/bit.py b/octopoes/bits/nxdomain_flag/bit.py index 7c593bb563c..d5f44c3fe45 100644 --- a/octopoes/bits/nxdomain_flag/bit.py +++ b/octopoes/bits/nxdomain_flag/bit.py @@ -1,6 +1,6 @@ from bits.definitions import BitDefinition, BitParameterDefinition +from octopoes.models.ooi.dns.records import NXDOMAIN from octopoes.models.ooi.dns.zone import Hostname -from octopoes.models.types import NXDOMAIN BIT = BitDefinition( id="nxdomain-flag", diff --git a/octopoes/bits/nxdomain_flag/nxdomain_flag.py b/octopoes/bits/nxdomain_flag/nxdomain_flag.py index 34401e6ea96..b26aeb368c9 100644 --- a/octopoes/bits/nxdomain_flag/nxdomain_flag.py +++ b/octopoes/bits/nxdomain_flag/nxdomain_flag.py @@ -2,9 +2,9 @@ from typing import Any from octopoes.models import OOI +from octopoes.models.ooi.dns.records import NXDOMAIN from octopoes.models.ooi.dns.zone import Hostname from octopoes.models.ooi.findings import Finding, KATFindingType -from octopoes.models.types import NXDOMAIN def run(input_ooi: Hostname, additional_oois: list[NXDOMAIN], config: dict[str, Any]) -> Iterator[OOI]: diff --git a/octopoes/bits/nxdomain_header_flag/bit.py b/octopoes/bits/nxdomain_header_flag/bit.py index 3d4883adec1..296ac1a3580 100644 --- a/octopoes/bits/nxdomain_header_flag/bit.py +++ b/octopoes/bits/nxdomain_header_flag/bit.py @@ -1,7 +1,7 @@ from bits.definitions import BitDefinition, BitParameterDefinition +from octopoes.models.ooi.dns.records import NXDOMAIN from octopoes.models.ooi.dns.zone import Hostname from octopoes.models.ooi.web import HTTPHeaderHostname -from octopoes.models.types import NXDOMAIN BIT = BitDefinition( id="nxdomain-header-flag", diff --git a/octopoes/bits/nxdomain_header_flag/nxdomain_header_flag.py b/octopoes/bits/nxdomain_header_flag/nxdomain_header_flag.py index f55c6b54561..7a4f78a004e 100644 --- a/octopoes/bits/nxdomain_header_flag/nxdomain_header_flag.py +++ b/octopoes/bits/nxdomain_header_flag/nxdomain_header_flag.py @@ -2,10 +2,10 @@ from typing import Any from octopoes.models import OOI +from octopoes.models.ooi.dns.records import NXDOMAIN from octopoes.models.ooi.dns.zone import Hostname from octopoes.models.ooi.findings import Finding, KATFindingType from octopoes.models.ooi.web import HTTPHeaderHostname -from octopoes.models.types import NXDOMAIN def run( diff --git a/octopoes/bits/oois_in_headers/bit.py b/octopoes/bits/oois_in_headers/bit.py index 70beea24333..ef2dd5c40d8 100644 --- a/octopoes/bits/oois_in_headers/bit.py +++ b/octopoes/bits/oois_in_headers/bit.py @@ -1,5 +1,5 @@ from bits.definitions import BitDefinition -from octopoes.models.types import HTTPHeader +from octopoes.models.ooi.web import HTTPHeader BIT = BitDefinition( id="oois-in-headers", diff --git a/octopoes/bits/oois_in_headers/oois_in_headers.py b/octopoes/bits/oois_in_headers/oois_in_headers.py index cd49dd8dffb..23b5eff3008 100644 --- a/octopoes/bits/oois_in_headers/oois_in_headers.py +++ b/octopoes/bits/oois_in_headers/oois_in_headers.py @@ -7,8 +7,8 @@ from octopoes.models import OOI from octopoes.models.ooi.dns.zone import Hostname -from octopoes.models.ooi.web import HTTPHeaderHostname -from octopoes.models.types import URL, HTTPHeader, HTTPHeaderURL, Network +from octopoes.models.ooi.network import Network +from octopoes.models.ooi.web import URL, HTTPHeader, HTTPHeaderHostname, HTTPHeaderURL def is_url(input_str): diff --git a/octopoes/bits/retire_js/retire_js.py b/octopoes/bits/retire_js/retire_js.py index 5b57382ffc4..b79a06ec6fc 100644 --- a/octopoes/bits/retire_js/retire_js.py +++ b/octopoes/bits/retire_js/retire_js.py @@ -7,7 +7,7 @@ from octopoes.models import OOI from octopoes.models.ooi.findings import CVEFindingType, Finding, RetireJSFindingType from octopoes.models.ooi.software import Software, SoftwareInstance -from packaging import version +from packaging.version import parse def run(input_ooi: Software, additional_oois: list[SoftwareInstance], config: dict[str, Any]) -> Iterator[OOI]: @@ -40,7 +40,7 @@ def run(input_ooi: Software, additional_oois: list[SoftwareInstance], config: di ) -def _check_vulnerabilities(name, package_version: str, known_vulnerabilities: dict) -> dict[str, list[str]]: +def _check_vulnerabilities(name: str, package_version: str, known_vulnerabilities: dict) -> dict[str, list[str]]: vulnerabilities: dict[str, list[str]] = {"CVE": [], "RetireJS": []} processed_name = _process_name(name) found_brands = [brand for brand in known_vulnerabilities if processed_name == _process_name(brand)] @@ -70,10 +70,10 @@ def _hash_identifiers(identifiers: dict[str, str | list[str]]) -> str: def _check_versions(package_version: str, known_vulnerability: dict) -> bool: - below = version.parse(package_version) < version.parse(known_vulnerability["below"]) + below = parse(package_version) < parse(known_vulnerability["below"]) # Some packages are only vulnerable below a version and not above above = ( - version.parse(package_version) >= version.parse(known_vulnerability["atOrAbove"]) + parse(package_version) >= parse(known_vulnerability["atOrAbove"]) if "atOrAbove" in known_vulnerability else True ) diff --git a/octopoes/bits/runner.py b/octopoes/bits/runner.py index da851076c34..931297ae5df 100644 --- a/octopoes/bits/runner.py +++ b/octopoes/bits/runner.py @@ -1,7 +1,7 @@ from collections.abc import Iterator from importlib import import_module from inspect import isfunction, signature -from typing import Any, Protocol +from typing import Any from bits.definitions import BitDefinition from octopoes.models import OOI @@ -11,15 +11,11 @@ class ModuleException(Exception): """General error for modules""" -class Runnable(Protocol): - def run(self, *args, **kwargs) -> Any: ... - - class BitRunner: def __init__(self, bit_definition: BitDefinition): self.module = bit_definition.module - def run(self, *args, **kwargs) -> list[OOI]: + def run(self, *args: Any, **kwargs: Any) -> list[OOI]: module = import_module(self.module) if not hasattr(module, "run") or not isfunction(module.run): @@ -31,7 +27,7 @@ def run(self, *args, **kwargs) -> list[OOI]: ) return list(module.run(*args, **kwargs)) - def __str__(self): + def __str__(self) -> str: return f"BitRunner {self.module}" diff --git a/octopoes/bits/spf_discovery/spf_discovery.py b/octopoes/bits/spf_discovery/spf_discovery.py index 36a10f16003..a094cc28cfb 100644 --- a/octopoes/bits/spf_discovery/spf_discovery.py +++ b/octopoes/bits/spf_discovery/spf_discovery.py @@ -10,7 +10,7 @@ from octopoes.models.ooi.network import IPAddressV4, IPAddressV6, Network -def run(input_ooi: DNSTXTRecord, additional_oois, config: dict[str, Any]) -> Iterator[OOI]: +def run(input_ooi: DNSTXTRecord, additional_oois: list, config: dict[str, Any]) -> Iterator[OOI]: if input_ooi.value.startswith("v=spf1"): spf_value = input_ooi.value.replace("%(d)", input_ooi.hostname.tokenized.name) parsed = parse(spf_value) diff --git a/octopoes/octopoes/api/router.py b/octopoes/octopoes/api/router.py index 62db165bd40..5816d2a2d88 100644 --- a/octopoes/octopoes/api/router.py +++ b/octopoes/octopoes/api/router.py @@ -443,8 +443,8 @@ def get_scan_profile_inheritance( def list_findings( exclude_muted: bool = True, only_muted: bool = False, - offset=DEFAULT_OFFSET, - limit=DEFAULT_LIMIT, + offset: int = DEFAULT_OFFSET, + limit: int = DEFAULT_LIMIT, octopoes: OctopoesService = Depends(octopoes_service), valid_time: datetime = Depends(extract_valid_time), severities: set[RiskLevelSeverity] = Query(DEFAULT_SEVERITY_FILTER), @@ -459,8 +459,8 @@ def list_findings( @router.get("/reports", tags=["Reports"]) def list_reports( - offset=DEFAULT_OFFSET, - limit=DEFAULT_LIMIT, + offset: int = DEFAULT_OFFSET, + limit: int = DEFAULT_LIMIT, octopoes: OctopoesService = Depends(octopoes_service), valid_time: datetime = Depends(extract_valid_time), ) -> Paginated[tuple[Report, list[Report | None]]]: diff --git a/octopoes/octopoes/connector/__init__.py b/octopoes/octopoes/connector/__init__.py index a5ca0237d37..5f56b2ee59e 100644 --- a/octopoes/octopoes/connector/__init__.py +++ b/octopoes/octopoes/connector/__init__.py @@ -1,12 +1,8 @@ -# Keep for backwards compatibility -from octopoes.models.exception import ObjectNotFoundException - - class ConnectorException(Exception): def __init__(self, value: str): self.value = value - def __str__(self): + def __str__(self) -> str: return self.value diff --git a/octopoes/octopoes/connector/octopoes.py b/octopoes/octopoes/connector/octopoes.py index d76a03480e0..de2004fe0ce 100644 --- a/octopoes/octopoes/connector/octopoes.py +++ b/octopoes/octopoes/connector/octopoes.py @@ -1,5 +1,5 @@ import json -from collections.abc import Sequence, Set +from collections.abc import Iterable, Sequence, Set from datetime import datetime from typing import Literal from uuid import UUID @@ -206,7 +206,7 @@ def save_affirmation(self, affirmation: Affirmation) -> None: self.logger.info("Saved affirmation", affirmation=affirmation, event_code=DECLARATION_CREATED) - def save_scan_profile(self, scan_profile: ScanProfile, valid_time: datetime): + def save_scan_profile(self, scan_profile: ScanProfile, valid_time: datetime) -> None: params = {"valid_time": str(valid_time)} self.session.put( f"/{self.client}/scan_profiles", @@ -262,7 +262,7 @@ def count_findings_by_severity(self, valid_time: datetime) -> dict[str, int]: def list_findings( self, - severities: set[RiskLevelSeverity], + severities: Iterable[RiskLevelSeverity], valid_time: datetime, exclude_muted: bool = True, only_muted: bool = False, @@ -301,8 +301,8 @@ def get_report(self, report_id: str) -> Report: return TypeAdapter(Report).validate_json(res.content) - def load_objects_bulk(self, references: set[Reference], valid_time): - params = {"valid_time": valid_time} + def load_objects_bulk(self, references: set[Reference], valid_time: datetime) -> dict[Reference, OOIType]: + params = {"valid_time": str(valid_time)} res = self.session.post( f"/{self.client}/objects/load_bulk", params=params, json=[str(ref) for ref in references] ) diff --git a/octopoes/octopoes/core/app.py b/octopoes/octopoes/core/app.py index 9840f0beaf1..af61d604f11 100644 --- a/octopoes/octopoes/core/app.py +++ b/octopoes/octopoes/core/app.py @@ -20,7 +20,7 @@ def get_xtdb_client(base_uri: str, client: str) -> XTDBHTTPClient: return XTDBHTTPClient(f"{base_uri}/_xtdb", client) -def close_rabbit_channel(queue_uri: str): +def close_rabbit_channel(queue_uri: str) -> None: rabbit_channel = get_rabbit_channel(queue_uri) try: diff --git a/octopoes/octopoes/core/service.py b/octopoes/octopoes/core/service.py index 135ad39f8ba..ce019172c25 100644 --- a/octopoes/octopoes/core/service.py +++ b/octopoes/octopoes/core/service.py @@ -143,7 +143,7 @@ def list_ooi( def get_ooi_tree( self, reference: Reference, valid_time: datetime, search_types: set[type[OOI]] | None = None, depth: int = 1 - ): + ) -> ReferenceTree: tree = self.ooi_repository.get_tree(reference, valid_time, search_types, depth) self._populate_scan_profiles(tree.store.values(), valid_time) return tree @@ -257,7 +257,7 @@ def _run_inference(self, origin: Origin, valid_time: datetime) -> None: logger.exception("Error running inference", exc_info=e) @staticmethod - def check_path_level(path_level: int | None, current_level: int): + def check_path_level(path_level: int | None, current_level: int) -> bool: return path_level is not None and path_level >= current_level def recalculate_scan_profiles(self, valid_time: datetime) -> None: @@ -379,7 +379,7 @@ def recalculate_scan_profiles(self, valid_time: datetime) -> None: ) logger.info("Recalculated scan profiles") - def process_event(self, event: DBEvent): + def process_event(self, event: DBEvent) -> None: # handle event event_handler_name = f"_on_{event.operation_type.value}_{event.entity_type}" handler: Callable[[DBEvent], None] | None = getattr(self, event_handler_name) diff --git a/octopoes/octopoes/models/__init__.py b/octopoes/octopoes/models/__init__.py index 44153335a2c..362b4c0d254 100644 --- a/octopoes/octopoes/models/__init__.py +++ b/octopoes/octopoes/models/__init__.py @@ -43,12 +43,12 @@ def __get_pydantic_core_schema__(cls, source_type: Any, handler: GetCoreSchemaHa return core_schema.with_info_after_validator_function(cls.validate, core_schema.str_schema()) @classmethod - def validate(cls, v, info: ValidationInfo): + def validate(cls, v: str, info: ValidationInfo) -> Any: if not isinstance(v, str): raise TypeError("string required") return cls(str(v)) - def __repr__(self): + def __repr__(self) -> str: return f"Reference({super().__repr__()})" @classmethod @@ -124,7 +124,7 @@ class OOI(BaseModel): def model_post_init(self, __context: Any) -> None: # noqa: F841 self.primary_key = self.primary_key or f"{self.get_object_type()}|{self.natural_key}" - def __str__(self): + def __str__(self) -> str: return self.primary_key @classmethod @@ -191,11 +191,11 @@ def get_reverse_relation_name(cls, attr: str) -> str: return cls._reverse_relation_names.get(attr, f"{cls.get_object_type()}_{attr}") @classmethod - def get_tokenized_primary_key(cls, natural_key: str): + def get_tokenized_primary_key(cls, natural_key: str) -> PrimaryKeyToken: token_tree = build_token_tree(cls) natural_key_parts = natural_key.split("|") - def hydrate(node) -> dict | str: + def hydrate(node: dict[str, dict | str]) -> dict | str: for key, value in node.items(): if isinstance(value, dict): node[key] = hydrate(value) @@ -256,10 +256,10 @@ def format_id_short(id_: str) -> str: class PrimaryKeyToken(RootModel): root: dict[str, str | PrimaryKeyToken] - def __getattr__(self, item) -> Any: + def __getattr__(self, item: str) -> Any: return self.root[item] - def __getitem__(self, item) -> Any: + def __getitem__(self, item: str) -> Any: return self.root[item] diff --git a/octopoes/octopoes/models/ooi/dns/records.py b/octopoes/octopoes/models/ooi/dns/records.py index 972b20e4a07..fda89849a0f 100644 --- a/octopoes/octopoes/models/ooi/dns/records.py +++ b/octopoes/octopoes/models/ooi/dns/records.py @@ -144,7 +144,7 @@ class CAATAGS(Enum): ISSUEVMC = "issuevmc" ISSUEMAIL = "issuemail" - def __str__(self): + def __str__(self) -> str: return self.value diff --git a/octopoes/octopoes/models/ooi/findings.py b/octopoes/octopoes/models/ooi/findings.py index f7bce1365e9..af82e2de756 100644 --- a/octopoes/octopoes/models/ooi/findings.py +++ b/octopoes/octopoes/models/ooi/findings.py @@ -1,3 +1,5 @@ +from __future__ import annotations + from enum import Enum from functools import total_ordering from typing import Annotated, Literal @@ -24,10 +26,10 @@ class RiskLevelSeverity(Enum): # unknown = the third party has been contacted, but third party has not determined the risk level (yet) UNKNOWN = "unknown" - def __gt__(self, other: "RiskLevelSeverity") -> bool: + def __gt__(self, other: RiskLevelSeverity) -> bool: return severity_order.index(self.value) > severity_order.index(other.value) - def __str__(self): + def __str__(self) -> str: return self.value diff --git a/octopoes/octopoes/models/ooi/network.py b/octopoes/octopoes/models/ooi/network.py index 83094b15dc0..5157796a518 100644 --- a/octopoes/octopoes/models/ooi/network.py +++ b/octopoes/octopoes/models/ooi/network.py @@ -84,7 +84,7 @@ class IPPort(OOI): _information_value = ["protocol", "port"] @classmethod - def format_reference_human_readable(cls, reference: Reference): + def format_reference_human_readable(cls, reference: Reference) -> str: tokenized = reference.tokenized return f"{tokenized.address.address}:{tokenized.port}/{tokenized.protocol}" diff --git a/octopoes/octopoes/models/origin.py b/octopoes/octopoes/models/origin.py index b39ec31e7a2..755fdf98b4c 100644 --- a/octopoes/octopoes/models/origin.py +++ b/octopoes/octopoes/models/origin.py @@ -1,3 +1,5 @@ +from __future__ import annotations + from enum import Enum from uuid import UUID @@ -21,7 +23,7 @@ class Origin(BaseModel): result: list[Reference] = Field(default_factory=list) task_id: UUID | None = None - def __sub__(self, other) -> set[Reference]: + def __sub__(self, other: Origin) -> set[Reference]: if isinstance(other, Origin): return set(self.result) - set(other.result) else: diff --git a/octopoes/octopoes/models/path.py b/octopoes/octopoes/models/path.py index 399a9945804..650c9e06f1d 100644 --- a/octopoes/octopoes/models/path.py +++ b/octopoes/octopoes/models/path.py @@ -42,7 +42,7 @@ def parse_step(cls, step: str) -> tuple[Direction, str, type[OOI] | None]: raise ValueError(f"Could not parse step: {step}") @classmethod - def calculate_step(cls, source_type: type[OOI], step: str): + def calculate_step(cls, source_type: type[OOI], step: str) -> Segment: direction, property_name, explicit_target_type = cls.parse_step(step) if explicit_target_type: @@ -90,7 +90,7 @@ def __eq__(self, other: object) -> bool: and self.property_name == other.property_name ) - def __str__(self): + def __str__(self) -> str: if self.direction == Direction.INCOMING: if self.target_type is None: raise ValueError("Direction cannot be incoming if target type is None") @@ -99,7 +99,7 @@ def __str__(self): else: return f"{self.property_name}" - def __repr__(self): + def __repr__(self) -> str: return str(self) @@ -108,7 +108,7 @@ def __init__(self, segments: list[Segment]): self.segments = segments @classmethod - def parse(cls, path: str): + def parse(cls, path: str) -> Path: start_type, step, *rest = path.split(".") segments = [Segment.calculate_step(type_by_name(start_type), step)] @@ -140,7 +140,7 @@ def __lt__(self, other): def __hash__(self): return hash(str(self)) - def __repr__(self): + def __repr__(self) -> str: return str(self) diff --git a/octopoes/octopoes/models/persistence.py b/octopoes/octopoes/models/persistence.py index 7c28805e807..82ea8ececd0 100644 --- a/octopoes/octopoes/models/persistence.py +++ b/octopoes/octopoes/models/persistence.py @@ -1,4 +1,4 @@ -from __future__ import annotations +from typing import Any from pydantic import Field from pydantic.fields import FieldInfo @@ -11,7 +11,7 @@ def ReferenceField( *, max_issue_scan_level: int | None = None, max_inherit_scan_level: int | None = None, - **kwargs, + **kwargs: Any, ) -> FieldInfo: if not isinstance(object_type, str): object_type = object_type.get_object_type() diff --git a/octopoes/octopoes/models/tree.py b/octopoes/octopoes/models/tree.py index e6efdc3a3e2..06d7d92a70a 100644 --- a/octopoes/octopoes/models/tree.py +++ b/octopoes/octopoes/models/tree.py @@ -12,7 +12,7 @@ class ReferenceNode(BaseModel): reference: Reference children: dict[str, list[ReferenceNode]] - def filter_children(self, filter_fn: Callable[[ReferenceNode], bool]): + def filter_children(self, filter_fn: Callable[[ReferenceNode], bool]) -> bool: """ Mutable filter function to evict any children from the tree that do not adhere to the provided callback """ diff --git a/octopoes/octopoes/models/types.py b/octopoes/octopoes/models/types.py index 871620eb47d..7d073ce8f2f 100644 --- a/octopoes/octopoes/models/types.py +++ b/octopoes/octopoes/models/types.py @@ -2,6 +2,8 @@ from collections.abc import Iterator +from pydantic.fields import FieldInfo + from octopoes.models import OOI, Reference from octopoes.models.exception import TypeNotFound from octopoes.models.ooi.certificate import ( @@ -213,14 +215,14 @@ def to_concrete(object_types: set[type[OOI]]) -> set[type[OOI]]: return concrete_types -def type_by_name(type_name: str): +def type_by_name(type_name: str) -> type[OOI]: try: return next(t for t in ALL_TYPES if t.__name__ == type_name) except StopIteration: raise TypeNotFound -def related_object_type(field) -> type[OOI]: +def related_object_type(field: FieldInfo) -> type[OOI]: object_type: str | type[OOI] = field.json_schema_extra["object_type"] if isinstance(object_type, str): return type_by_name(object_type) diff --git a/octopoes/octopoes/repositories/ooi_repository.py b/octopoes/octopoes/repositories/ooi_repository.py index b582ffb6abd..7c432b6250a 100644 --- a/octopoes/octopoes/repositories/ooi_repository.py +++ b/octopoes/octopoes/repositories/ooi_repository.py @@ -134,7 +134,16 @@ def count_findings_by_severity(self, valid_time: datetime) -> Counter: raise NotImplementedError def list_findings( - self, severities, valid_time, exclude_muted, only_muted, offset, limit, search_string, order_by, asc_desc + self, + severities: set[RiskLevelSeverity], + valid_time: datetime, + exclude_muted: bool = False, + only_muted: bool = False, + offset: int = DEFAULT_OFFSET, + limit: int = DEFAULT_LIMIT, + search_string: str | None = None, + order_by: Literal["score", "finding_type"] = "score", + asc_desc: Literal["asc", "desc"] = "desc", ) -> Paginated[Finding]: raise NotImplementedError @@ -694,10 +703,10 @@ def list_findings( self, severities: set[RiskLevelSeverity], valid_time: datetime, - exclude_muted=False, - only_muted=False, - offset=DEFAULT_OFFSET, - limit=DEFAULT_LIMIT, + exclude_muted: bool = False, + only_muted: bool = False, + offset: int = DEFAULT_OFFSET, + limit: int = DEFAULT_LIMIT, search_string: str | None = None, order_by: Literal["score", "finding_type"] = "score", asc_desc: Literal["asc", "desc"] = "desc", diff --git a/octopoes/octopoes/repositories/origin_parameter_repository.py b/octopoes/octopoes/repositories/origin_parameter_repository.py index e6054f5a2e6..d4f335585c7 100644 --- a/octopoes/octopoes/repositories/origin_parameter_repository.py +++ b/octopoes/octopoes/repositories/origin_parameter_repository.py @@ -71,7 +71,7 @@ def list_by_origin(self, origin_id: set[str], valid_time: datetime) -> list[Orig results = self.session.client.query(query, valid_time=valid_time) return [self.deserialize(r[0]) for r in results] - def list_by_reference(self, reference: Reference, valid_time: datetime): + def list_by_reference(self, reference: Reference, valid_time: datetime) -> list[OriginParameter]: query = generate_pull_query( FieldSet.ALL_FIELDS, {"reference": str(reference), "type": OriginParameter.__name__} ) diff --git a/octopoes/octopoes/repositories/scan_profile_repository.py b/octopoes/octopoes/repositories/scan_profile_repository.py index e1b8fcf2bd3..c27954e0009 100644 --- a/octopoes/octopoes/repositories/scan_profile_repository.py +++ b/octopoes/octopoes/repositories/scan_profile_repository.py @@ -50,7 +50,7 @@ def commit(self): self.session.commit() @classmethod - def format_id(cls, ooi_reference: Reference): + def format_id(cls, ooi_reference: Reference) -> str: return f"{cls.object_type}|{ooi_reference}" @classmethod diff --git a/octopoes/octopoes/tasks/tasks.py b/octopoes/octopoes/tasks/tasks.py index 1434993e3c6..b9f734be4ec 100644 --- a/octopoes/octopoes/tasks/tasks.py +++ b/octopoes/octopoes/tasks/tasks.py @@ -3,6 +3,7 @@ from datetime import datetime, timezone from logging import config from pathlib import Path +from typing import Any import structlog import yaml @@ -65,7 +66,7 @@ def init_worker(**kwargs): @app.task(queue=QUEUE_NAME_OCTOPOES) -def handle_event(event: dict): +def handle_event(event: dict) -> None: try: parsed_event: DBEvent = TypeAdapter(DBEventType).validate_python(event) @@ -96,7 +97,7 @@ def schedule_scan_profile_recalculations(): @app.task(queue=QUEUE_NAME_OCTOPOES) -def recalculate_scan_profiles(org: str, *args, **kwargs): +def recalculate_scan_profiles(org: str, *args: Any, **kwargs: Any) -> None: session = XTDBSession(get_xtdb_client(str(settings.xtdb_uri), org)) octopoes = bootstrap_octopoes(settings, org, session) diff --git a/octopoes/octopoes/xtdb/client.py b/octopoes/octopoes/xtdb/client.py index 57561c43068..388dae2dfbc 100644 --- a/octopoes/octopoes/xtdb/client.py +++ b/octopoes/octopoes/xtdb/client.py @@ -56,7 +56,7 @@ def _get_xtdb_http_session(base_url: str) -> httpx.Client: class XTDBHTTPClient: - def __init__(self, base_url, client: str): + def __init__(self, base_url: str, client: str): self._client = client self._session = _get_xtdb_http_session(base_url) @@ -173,7 +173,7 @@ def export_transactions(self): self._verify_response(res) return res.json() - def sync(self, timeout: int | None = None): + def sync(self, timeout: int | None = None) -> Any: params = {} if timeout is not None: @@ -198,10 +198,10 @@ def __enter__(self): def __exit__(self, _exc_type: type[Exception], _exc_value: str, _exc_traceback: str) -> None: self.commit() - def add(self, operation: Operation): + def add(self, operation: Operation) -> None: self._operations.append(operation) - def put(self, document: str | dict[str, Any], valid_time: datetime): + def put(self, document: str | dict[str, Any], valid_time: datetime) -> None: self.add((OperationType.PUT, document, valid_time)) def commit(self) -> None: @@ -219,5 +219,5 @@ def commit(self) -> None: logger.info("Called %s callbacks after committing XTDBSession", len(self.post_commit_callbacks)) self.post_commit_callbacks = [] - def listen_post_commit(self, callback: Callable[[], None]): + def listen_post_commit(self, callback: Callable[[], None]) -> None: self.post_commit_callbacks.append(callback) diff --git a/octopoes/octopoes/xtdb/query.py b/octopoes/octopoes/xtdb/query.py index 3be2464028f..147c58174a0 100644 --- a/octopoes/octopoes/xtdb/query.py +++ b/octopoes/octopoes/xtdb/query.py @@ -1,3 +1,5 @@ +from __future__ import annotations + from dataclasses import dataclass, field from uuid import UUID, uuid4 @@ -76,13 +78,13 @@ class Query: _offset: int | None = None _order_by: tuple[Aliased, bool] | None = None - def where(self, ooi_type: Ref, **kwargs) -> "Query": + def where(self, ooi_type: Ref, **kwargs: Ref | str | set[str] | bool) -> Query: for field_name, value in kwargs.items(): self._where_field_is(ooi_type, field_name, value) return self - def where_in(self, ooi_type: Ref, **kwargs: list[str]) -> "Query": + def where_in(self, ooi_type: Ref, **kwargs: list[str]) -> Query: """Allows for filtering on multiple values for a specific field.""" for field_name, values in kwargs.items(): @@ -94,7 +96,7 @@ def format(self) -> str: return self._compile(separator="\n ") @classmethod - def from_path(cls, path: Path) -> "Query": + def from_path(cls, path: Path) -> Query: """ Create a query from a Path. @@ -147,14 +149,14 @@ def from_path(cls, path: Path) -> "Query": return query - def pull(self, ooi_type: Ref, *, fields: str = "[*]") -> "Query": + def pull(self, ooi_type: Ref, *, fields: str = "[*]") -> Query: """By default, we pull the target type. But when using find, count, etc., you have to pull explicitly.""" self._find_clauses.append(f"(pull {self._get_object_alias(ooi_type)} {fields})") return self - def find(self, item: Ref, *, index: int | None = None) -> "Query": + def find(self, item: Ref, *, index: int | None = None) -> Query: """Add a find clause, so we can select specific fields in a query to be returned as well.""" if index is None: @@ -164,27 +166,27 @@ def find(self, item: Ref, *, index: int | None = None) -> "Query": return self - def count(self, ooi_type: Ref) -> "Query": + def count(self, ooi_type: Ref) -> Query: self._find_clauses.append(f"(count {self._get_object_alias(ooi_type)})") return self - def limit(self, limit: int) -> "Query": + def limit(self, limit: int) -> Query: self._limit = limit return self - def offset(self, offset: int) -> "Query": + def offset(self, offset: int) -> Query: self._offset = offset return self - def order_by(self, ref: Aliased, ascending: bool = True) -> "Query": + def order_by(self, ref: Aliased, ascending: bool = True) -> Query: self._order_by = (ref, ascending) return self - def _where_field_is(self, ref: Ref, field_name: str, value: Ref | str | set[str]) -> None: + def _where_field_is(self, ref: Ref, field_name: str, value: Ref | str | set[str] | bool) -> None: """ We need isinstance(value, type) checks to verify value is an OOIType, as issubclass() fails on non-classes: @@ -321,7 +323,7 @@ def _assert_type(self, ref: Ref, ooi_type: type[OOI]) -> str: def _to_object_type_statement(self, ref: Ref, other_type: type[OOI]) -> str: return f'[ {self._get_object_alias(ref)} :object_type "{other_type.get_object_type()}" ]' - def _compile_where_clauses(self, *, separator=" ") -> str: + def _compile_where_clauses(self, *, separator: str = " ") -> str: """Sorted and deduplicated where clauses, since they are both idempotent and commutative""" return separator + separator.join(sorted(set(self._where_clauses))) @@ -329,7 +331,7 @@ def _compile_where_clauses(self, *, separator=" ") -> str: def _compile_find_clauses(self) -> str: return " ".join(self._find_clauses) - def _compile(self, *, separator=" ") -> str: + def _compile(self, *, separator: str = " ") -> str: result_ooi_type = self.result_type.type if isinstance(self.result_type, Aliased) else self.result_type self._where_clauses.append(self._assert_type(self.result_type, result_ooi_type)) @@ -365,7 +367,7 @@ def _get_object_alias(self, object_type: Ref) -> str: def __str__(self) -> str: return self._compile() - def __eq__(self, other: object): + def __eq__(self, other: object) -> bool: if not isinstance(other, Query): return NotImplemented diff --git a/octopoes/octopoes/xtdb/query_builder.py b/octopoes/octopoes/xtdb/query_builder.py index 172cfc745c2..54335f5c8d8 100644 --- a/octopoes/octopoes/xtdb/query_builder.py +++ b/octopoes/octopoes/xtdb/query_builder.py @@ -2,7 +2,8 @@ from collections.abc import Iterable, Mapping from typing import Any -from octopoes.xtdb.related_field_generator import FieldSet, RelatedFieldNode +from octopoes.xtdb import FieldSet +from octopoes.xtdb.related_field_generator import RelatedFieldNode def join_csv(values: Iterable[Any]) -> str: diff --git a/octopoes/octopoes/xtdb/related_field_generator.py b/octopoes/octopoes/xtdb/related_field_generator.py index 1b44112338c..4a26f9a2114 100644 --- a/octopoes/octopoes/xtdb/related_field_generator.py +++ b/octopoes/octopoes/xtdb/related_field_generator.py @@ -51,7 +51,7 @@ def construct_incoming_relations(self): RelatedFieldNode(self.data_model, {foreign_object_type}, self.path + (foreign_key,)) ) - def build_tree(self, depth: int): + def build_tree(self, depth: int) -> None: if depth > 0: self.construct_outgoing_relations() for child_node in self.relations_out.values(): @@ -61,7 +61,7 @@ def build_tree(self, depth: int): for child_node in self.relations_in.values(): child_node.build_tree(depth - 1) - def generate_field(self, field_set: FieldSet, pk_prefix: str): + def generate_field(self, field_set: FieldSet, pk_prefix: str) -> str: queried_fields = pk_prefix if field_set is FieldSet.ONLY_ID else "*" """ Output dicts in XTDB Query Language @@ -105,10 +105,10 @@ def search_nodes(self, search_object_types=set[str]): # Match self return not self.object_types.isdisjoint(search_object_types) - def __repr__(self): + def __repr__(self) -> str: return f"QueryNode[{self}]" - def __str__(self): + def __str__(self) -> str: return ",".join(self.object_types) def __eq__(self, other): diff --git a/pyproject.toml b/pyproject.toml index 021561d838d..7af9b569c1a 100644 --- a/pyproject.toml +++ b/pyproject.toml @@ -2,8 +2,6 @@ python_version = "3.10" plugins = ["pydantic.mypy"] strict = true -follow_imports = "skip" -warn_unused_ignores = false # This gives false positives in pre-commit as long as we don't enable follow imports disallow_subclassing_any = false disallow_untyped_decorators = false # Needed for FastAPI decorators disallow_any_generics = false @@ -14,8 +12,17 @@ no_implicit_reexport = false warn_return_any = false [[tool.mypy.overrides]] -module = ["httpx.*"] -follow_imports = "normal" +# Following pydantic imports currently gives 2000 errors +module = ["pydantic.*"] +follow_imports = "skip" + +[[tool.mypy.overrides]] +module = ["bytes.*", "cveapi.*"] +disallow_any_generics = true +disallow_untyped_defs = true +disallow_untyped_calls = true +disallow_incomplete_defs = true +no_implicit_reexport = true [tool.setuptools_scm] write_to = "_version.py" diff --git a/rocky/account/admin.py b/rocky/account/admin.py index 952d6ff76fc..e5a7e07a8d7 100644 --- a/rocky/account/admin.py +++ b/rocky/account/admin.py @@ -10,7 +10,6 @@ @admin.register(User) class KATUserAdmin(UserAdmin): - model = User list_display = ("email", "is_staff", "is_active") fieldsets = ( (None, {"fields": ("email", "password", "full_name")}), diff --git a/rocky/account/forms/__init__.py b/rocky/account/forms/__init__.py index 448a075559d..7da11d44645 100644 --- a/rocky/account/forms/__init__.py +++ b/rocky/account/forms/__init__.py @@ -11,3 +11,19 @@ from account.forms.login import LoginForm from account.forms.password_reset import PasswordResetForm from account.forms.token import TwoFactorBackupTokenForm, TwoFactorSetupTokenForm, TwoFactorVerifyTokenForm + +__all__ = [ + "AccountTypeSelectForm", + "IndemnificationAddForm", + "MemberRegistrationForm", + "OnboardingOrganizationUpdateForm", + "OrganizationForm", + "OrganizationMemberEditForm", + "OrganizationUpdateForm", + "SetPasswordForm", + "LoginForm", + "PasswordResetForm", + "TwoFactorBackupTokenForm", + "TwoFactorSetupTokenForm", + "TwoFactorVerifyTokenForm", +] diff --git a/rocky/account/forms/organization.py b/rocky/account/forms/organization.py index 64a6b426f60..7bc8b9e72ec 100644 --- a/rocky/account/forms/organization.py +++ b/rocky/account/forms/organization.py @@ -28,11 +28,11 @@ def populate_dropdown_list(self, user): organizations.append([organization.code, organization.name]) if organizations: - props = { - "required": True, - "label": _("Organizations"), - "help_text": _("The organization from which to clone settings."), - "error_messages": self.error_messages, - } - self.fields["organization"] = forms.ChoiceField(**props) + self.fields["organization"] = forms.ChoiceField( + required=True, + label=_("Organizations"), + help_text=_("The organization from which to clone settings."), + error_messages=self.error_messages, + ) + self.fields["organization"].choices = [BLANK_CHOICE] + organizations diff --git a/rocky/account/mixins.py b/rocky/account/mixins.py index 00dde712924..1ec7c87d61c 100644 --- a/rocky/account/mixins.py +++ b/rocky/account/mixins.py @@ -9,6 +9,7 @@ from django.http import Http404 from django.utils.translation import gettext_lazy as _ from django.views import View +from django.views.generic.base import ContextMixin from katalogus.client import KATalogus, get_katalogus from rest_framework.exceptions import ValidationError from rest_framework.request import Request @@ -31,7 +32,7 @@ class OrganizationPermLookupDict: def __init__(self, organization_member, app_label): self.organization_member, self.app_label = organization_member, app_label - def __repr__(self): + def __repr__(self) -> str: return str(self.organization_member.get_all_permissions) def __getitem__(self, perm_name): @@ -50,7 +51,7 @@ class OrganizationPermWrapper: def __init__(self, organization_member): self.organization_member = organization_member - def __repr__(self): + def __repr__(self) -> str: return f"{self.__class__.__qualname__}({self.organization_member!r})" def __getitem__(self, app_label): @@ -71,7 +72,7 @@ def __contains__(self, perm_name): return self[app_label][perm_name] -class OrganizationView(View): +class OrganizationView(ContextMixin, View): def setup(self, request, *args, **kwargs): super().setup(request, *args, **kwargs) diff --git a/rocky/account/models.py b/rocky/account/models.py index 2552a560146..55d7f75322e 100644 --- a/rocky/account/models.py +++ b/rocky/account/models.py @@ -138,7 +138,7 @@ class Meta: EVENT_CODES = {"created": 900111, "updated": 900122, "deleted": 900123} - def __str__(self): + def __str__(self) -> str: return f"{self.name} ({self.user})" def generate_new_token(self) -> str: diff --git a/rocky/account/views/account.py b/rocky/account/views/account.py index c16b78188ed..b287562b073 100644 --- a/rocky/account/views/account.py +++ b/rocky/account/views/account.py @@ -23,7 +23,7 @@ def post(self, request, *args, **kwargs): # Mypy doesn't have the information to understand this return self.get(request, *args, **kwargs) # type: ignore[attr-defined] - def handle_page_action(self, action: str): + def handle_page_action(self, action: str) -> None: if action == PageActions.ACCEPT_CLEARANCE.value: self.organization_member.acknowledged_clearance_level = self.organization_member.trusted_clearance_level elif action == PageActions.WITHDRAW_ACCEPTANCE.value: diff --git a/rocky/crisis_room/views.py b/rocky/crisis_room/views.py index 7a39cd7abcc..4a2f8163882 100644 --- a/rocky/crisis_room/views.py +++ b/rocky/crisis_room/views.py @@ -14,8 +14,7 @@ from octopoes.connector import ConnectorException from octopoes.connector.octopoes import OctopoesAPIConnector from octopoes.models.ooi.findings import RiskLevelSeverity -from rocky.views.mixins import ObservedAtMixin -from rocky.views.ooi_view import ConnectorFormMixin +from rocky.views.mixins import ConnectorFormMixin, ObservedAtMixin logger = structlog.get_logger(__name__) diff --git a/rocky/katalogus/client.py b/rocky/katalogus/client.py index b10dba57d12..d9ace3a80ce 100644 --- a/rocky/katalogus/client.py +++ b/rocky/katalogus/client.py @@ -56,7 +56,7 @@ class Plugin(BaseModel): # make sense out of it: for which organization is this plugin in fact enabled? enabled: bool - def can_scan(self, member) -> bool: + def can_scan(self, member: OrganizationMember) -> bool: return member.has_perm("tools.can_scan_organization") @@ -73,7 +73,7 @@ class Boefje(Plugin): # use a custom field_serializer for `consumes` @field_serializer("consumes") - def serialize_consumes(self, consumes: set[type[OOI]]): + def serialize_consumes(self, consumes: set[type[OOI]]) -> set[str]: return {ooi_class.get_ooi_type() for ooi_class in consumes} @field_validator("boefje_schema") @@ -89,7 +89,7 @@ def json_schema_valid(cls, boefje_schema: dict) -> dict | None: return boefje_schema - def can_scan(self, member) -> bool: + def can_scan(self, member: OrganizationMember) -> bool: return super().can_scan(member) and member.has_clearance_level(self.scan_level.value) @@ -99,7 +99,7 @@ class Normalizer(Plugin): # use a custom field_serializer for `produces` @field_serializer("produces") - def serialize_produces(self, produces: set[type[OOI]]): + def serialize_produces(self, produces: set[type[OOI]]) -> set[str]: return {ooi_class.get_ooi_type() for ooi_class in produces} diff --git a/rocky/katalogus/forms/plugin_settings.py b/rocky/katalogus/forms/plugin_settings.py index 303f0df4816..22a7b38b282 100644 --- a/rocky/katalogus/forms/plugin_settings.py +++ b/rocky/katalogus/forms/plugin_settings.py @@ -1,3 +1,5 @@ +from typing import Any + from django import forms from django.utils.translation import gettext_lazy as _ from jsonschema.validators import Draft202012Validator @@ -11,7 +13,7 @@ class PluginSchemaForm(forms.Form): error_messages = {"required": _("This field is required.")} - def __init__(self, plugin_schema: dict, values: dict, *args, **kwargs): + def __init__(self, plugin_schema: dict, values: dict, *args: Any, **kwargs: Any) -> None: super().__init__(*args, **kwargs) self.plugin_schema = plugin_schema self.values = values diff --git a/rocky/katalogus/views/boefje_setup.py b/rocky/katalogus/views/boefje_setup.py index 6d386a55961..af868f229de 100644 --- a/rocky/katalogus/views/boefje_setup.py +++ b/rocky/katalogus/views/boefje_setup.py @@ -23,8 +23,8 @@ class BoefjeSetupView(OrganizationPermissionRequiredMixin, OrganizationView, For def setup(self, request, *args, **kwargs): super().setup(request, *args, **kwargs) - self.plugin_id = uuid.uuid4() - self.created = str(datetime.now()) + self.plugin_id = str(uuid.uuid4()) + self.created: str | None = str(datetime.now()) self.query_params = urlencode({"new_variant": True}) def get_success_url(self) -> str: @@ -209,7 +209,7 @@ def get_context_data(self, **kwargs): return context -def create_boefje_with_form_data(form_data, plugin_id: str, created: str): +def create_boefje_with_form_data(form_data, plugin_id: str, created: str | None): arguments = [] if not form_data["oci_arguments"] else form_data["oci_arguments"].split() consumes = [] if not form_data["consumes"] else form_data["consumes"].strip("[]").replace("'", "").split(", ") produces = [] if not form_data["produces"] else form_data["produces"].split(",") diff --git a/rocky/katalogus/views/mixins.py b/rocky/katalogus/views/mixins.py index 022cb16ae80..b742bf91f58 100644 --- a/rocky/katalogus/views/mixins.py +++ b/rocky/katalogus/views/mixins.py @@ -1,7 +1,9 @@ +from typing import Any + import structlog from account.mixins import OrganizationView from django.contrib import messages -from django.http import Http404 +from django.http import Http404, HttpRequest from django.shortcuts import redirect from django.urls import reverse from django.utils.translation import gettext_lazy as _ @@ -17,7 +19,7 @@ class SinglePluginView(OrganizationView): katalogus_client: KATalogus plugin: Plugin - def setup(self, request, *args, plugin_id: str, **kwargs): + def setup(self, request: HttpRequest, *args: Any, plugin_id: str, **kwargs: Any) -> None: """ Prepare organization info and KAT-alogus API client. """ diff --git a/rocky/onboarding/view_helpers.py b/rocky/onboarding/view_helpers.py index 78146742e49..d7f57d41f54 100644 --- a/rocky/onboarding/view_helpers.py +++ b/rocky/onboarding/view_helpers.py @@ -2,7 +2,7 @@ from django.utils.translation import gettext_lazy as _ from reports.views.base import get_selection from tools.models import Organization -from tools.view_helpers import BreadcrumbsMixin, StepsMixin +from tools.view_helpers import Breadcrumb, BreadcrumbsMixin, StepsMixin ONBOARDING_PERMISSIONS = ( "tools.can_scan_organization", @@ -77,7 +77,7 @@ def build_steps(self): class OnboardingBreadcrumbsMixin(BreadcrumbsMixin): organization: Organization - def build_breadcrumbs(self): + def build_breadcrumbs(self) -> list[Breadcrumb]: return [ { "url": reverse_lazy("step_introduction", kwargs={"organization_code": self.organization.code}), diff --git a/rocky/onboarding/views.py b/rocky/onboarding/views.py index 4da8bf5b43e..783fd0599e2 100644 --- a/rocky/onboarding/views.py +++ b/rocky/onboarding/views.py @@ -400,11 +400,8 @@ class OnboardingOrganizationSetupView(PermissionRequiredMixin, IntroductionRegis permission_required = "tools.add_organization" def get(self, request, *args, **kwargs): - members = OrganizationMember.objects.filter(user=self.request.user) - if members: - return redirect( - reverse("step_organization_update", kwargs={"organization_code": members.first().organization.code}) - ) + if member := OrganizationMember.objects.filter(user=self.request.user).first(): + return redirect(reverse("step_organization_update", kwargs={"organization_code": member.organization.code})) return super().get(request, *args, **kwargs) def post(self, request, *args, **kwargs): diff --git a/rocky/reports/forms.py b/rocky/reports/forms.py index e9a6e58815e..7ee01c1ee1d 100644 --- a/rocky/reports/forms.py +++ b/rocky/reports/forms.py @@ -1,4 +1,5 @@ from datetime import datetime, timezone +from typing import Any from django import forms from django.utils.translation import gettext_lazy as _ @@ -12,7 +13,7 @@ class OOITypeMultiCheckboxForReportForm(BaseRockyForm): label=_("Filter by OOI types"), required=False, widget=forms.CheckboxSelectMultiple ) - def __init__(self, ooi_types: list[str], *args, **kwargs): + def __init__(self, ooi_types: list[str], *args: Any, **kwargs: Any): super().__init__(*args, **kwargs) self.fields["ooi_type"].choices = ((ooi_type, ooi_type) for ooi_type in ooi_types) @@ -22,7 +23,7 @@ class ReportTypeMultiselectForm(BaseRockyForm): label=_("Report types"), required=False, widget=forms.CheckboxSelectMultiple ) - def __init__(self, report_types: set[Report], *args, **kwargs): + def __init__(self, report_types: set[Report], *args: Any, **kwargs: Any): super().__init__(*args, **kwargs) report_types_choices = ((report_type.id, report_type.name) for report_type in report_types) self.fields["report_type"].choices = report_types_choices diff --git a/rocky/reports/report_types/aggregate_organisation_report/report.py b/rocky/reports/report_types/aggregate_organisation_report/report.py index 46f1ae506cf..de7f1fbc61f 100644 --- a/rocky/reports/report_types/aggregate_organisation_report/report.py +++ b/rocky/reports/report_types/aggregate_organisation_report/report.py @@ -2,6 +2,7 @@ from typing import Any import structlog +from django.utils.translation import gettext_lazy as _ from octopoes.connector.octopoes import OctopoesAPIConnector from octopoes.models import OOI @@ -24,7 +25,7 @@ class AggregateOrganisationReport(AggregateReport): id = "aggregate-organisation-report" - name = "Aggregate Organisation Report" + name = _("Aggregate Organisation Report") description = "Aggregate Organisation Report" reports = { "required": [SystemReport], @@ -411,7 +412,9 @@ def is_mail_compliant(result): "config_oois": config_oois, } - def collect_system_specific_data(self, data, services, system_type: str, report_id: str) -> dict[str, Any]: + def collect_system_specific_data( + self, data: dict, services: dict, system_type: str, report_id: str + ) -> dict[str, Any]: """Given a system, return a list of report data from the right sub-reports based on the related report_id""" report_data: dict[str, Any] = {} diff --git a/rocky/reports/report_types/definitions.py b/rocky/reports/report_types/definitions.py index ec7c4be9ed6..05262d3f3c0 100644 --- a/rocky/reports/report_types/definitions.py +++ b/rocky/reports/report_types/definitions.py @@ -3,6 +3,8 @@ from pathlib import Path from typing import Any, TypedDict, TypeVar +from django.utils.functional import Promise + from octopoes.connector.octopoes import OctopoesAPIConnector from octopoes.models import OOI, Reference from octopoes.models.ooi.dns.zone import Hostname @@ -37,8 +39,8 @@ def report_plugins_union(report_types: list[type["BaseReport"]]) -> ReportPlugin class BaseReport: id: str - name: str - description: str + name: Promise + description: Promise template_path: str = "report.html" plugins: ReportPlugins input_ooi_types: set[type[OOI]] diff --git a/rocky/reports/report_types/multi_organization_report/report.py b/rocky/reports/report_types/multi_organization_report/report.py index bc9d66ba5b2..1fbc91a7c4b 100644 --- a/rocky/reports/report_types/multi_organization_report/report.py +++ b/rocky/reports/report_types/multi_organization_report/report.py @@ -255,7 +255,9 @@ def post_process_data(self, data: dict[str, Any]) -> dict[str, Any]: } -def collect_report_data(connector: OctopoesAPIConnector, input_ooi_references: list[str], observed_at: datetime): +def collect_report_data( + connector: OctopoesAPIConnector, input_ooi_references: list[str], observed_at: datetime +) -> dict: report_data = {} for ooi in [x for x in input_ooi_references if Reference.from_str(x).class_type == ReportData]: report_data[ooi] = connector.get(Reference.from_str(ooi), observed_at).model_dump() diff --git a/rocky/reports/report_types/name_server_report/report.py b/rocky/reports/report_types/name_server_report/report.py index 6d0c0a84ec3..cdec60e352c 100644 --- a/rocky/reports/report_types/name_server_report/report.py +++ b/rocky/reports/report_types/name_server_report/report.py @@ -1,3 +1,5 @@ +from __future__ import annotations + from collections.abc import Iterable from dataclasses import dataclass, field from datetime import datetime @@ -37,13 +39,13 @@ def has_dnssec(self): def has_valid_dnssec(self): return sum([check.has_valid_dnssec for check in self.checks]) - def __bool__(self): + def __bool__(self) -> bool: return all(bool(check) for check in self.checks) - def __len__(self): + def __len__(self) -> int: return len(self.checks) - def __add__(self, other: "NameServerChecks"): + def __add__(self, other: NameServerChecks) -> NameServerChecks: return NameServerChecks(checks=self.checks + other.checks) diff --git a/rocky/reports/report_types/web_system_report/report.py b/rocky/reports/report_types/web_system_report/report.py index 20a6d22df70..47f62c760c9 100644 --- a/rocky/reports/report_types/web_system_report/report.py +++ b/rocky/reports/report_types/web_system_report/report.py @@ -1,3 +1,5 @@ +from __future__ import annotations + from collections.abc import Iterable from dataclasses import dataclass, field from datetime import datetime @@ -77,13 +79,13 @@ def certificates_not_expired(self): def certificates_not_expiring_soon(self): return sum([check.certificates_not_expiring_soon for check in self.checks]) - def __bool__(self): + def __bool__(self) -> bool: return all(bool(check) for check in self.checks) - def __len__(self): + def __len__(self) -> int: return len(self.checks) - def __add__(self, other: "WebChecks"): + def __add__(self, other: WebChecks) -> WebChecks: return WebChecks(checks=self.checks + other.checks) diff --git a/rocky/reports/templates/report_overview/report_history_table.html b/rocky/reports/templates/report_overview/report_history_table.html index 9452ad6cb24..cee41f0f203 100644 --- a/rocky/reports/templates/report_overview/report_history_table.html +++ b/rocky/reports/templates/report_overview/report_history_table.html @@ -154,11 +154,11 @@ {% translate "Subreports details:" %}
{% translate "Report types" %}

- {% blocktranslate count counter=report.total_children_reports %} - This report consist of {{counter}} subreport with the following report type and object. - {% plural %} - This report consist of {{counter}} subreports with the following report types and objects. - {% endblocktranslate %} + {% blocktranslate trimmed count counter=report.total_children_reports %} + This report consist of {{ counter }} subreport with the following report type and object. + {% plural %} + This report consist of {{ counter }} subreports with the following report types and objects. + {% endblocktranslate %}

diff --git a/rocky/reports/views/aggregate_report.py b/rocky/reports/views/aggregate_report.py index cf08dec316d..ef98c4fd153 100644 --- a/rocky/reports/views/aggregate_report.py +++ b/rocky/reports/views/aggregate_report.py @@ -8,7 +8,7 @@ from django.views.generic import TemplateView from httpx import HTTPError from katalogus.client import get_katalogus -from tools.view_helpers import PostRedirect +from tools.view_helpers import Breadcrumb, PostRedirect from reports.report_types.aggregate_organisation_report.report import AggregateOrganisationReport from reports.views.base import ( @@ -26,7 +26,7 @@ class BreadcrumbsAggregateReportView(ReportBreadcrumbs): - def build_breadcrumbs(self): + def build_breadcrumbs(self) -> list[Breadcrumb]: breadcrumbs = super().build_breadcrumbs() kwargs = self.get_kwargs() selection = get_selection(self.request) diff --git a/rocky/reports/views/base.py b/rocky/reports/views/base.py index 6adc3d81ea5..3eb27848bf7 100644 --- a/rocky/reports/views/base.py +++ b/rocky/reports/views/base.py @@ -1,5 +1,5 @@ from collections import defaultdict -from collections.abc import Iterable, Sequence +from collections.abc import Iterable, Mapping, Sequence from datetime import datetime, timezone from operator import attrgetter from typing import Any, Literal, cast @@ -20,7 +20,7 @@ from katalogus.client import Boefje, KATalogus, KATalogusError, Plugin from pydantic import RootModel, TypeAdapter from tools.ooi_helpers import create_ooi -from tools.view_helpers import BreadcrumbsMixin, PostRedirect, url_with_querystring +from tools.view_helpers import Breadcrumb, BreadcrumbsMixin, PostRedirect, url_with_querystring from octopoes.models import OOI, Reference from octopoes.models.ooi.reports import Report as ReportOOI @@ -47,7 +47,7 @@ REPORTS_PRE_SELECTION = {"clearance_level": ["2", "3", "4"], "clearance_type": "declared"} -def get_selection(request: HttpRequest, pre_selection: dict[str, str | Sequence[str]] | None = None) -> str: +def get_selection(request: HttpRequest, pre_selection: Mapping[str, str | Sequence[str]] | None = None) -> str: if pre_selection is not None: return "?" + urlencode(pre_selection, True) return "?" + urlencode(request.GET, True) @@ -80,13 +80,11 @@ def get_kwargs(self): def is_valid_breadcrumbs(self): return self.breadcrumbs_step < len(self.breadcrumbs) - def build_breadcrumbs(self): + def build_breadcrumbs(self) -> list[Breadcrumb]: kwargs = self.get_kwargs() selection = get_selection(self.request) - breadcrumbs = [{"url": reverse("reports", kwargs=kwargs) + selection, "text": _("Reports")}] - - return breadcrumbs + return [{"url": reverse("reports", kwargs=kwargs) + selection, "text": _("Reports")}] def get_breadcrumbs(self): if self.is_valid_breadcrumbs(): diff --git a/rocky/reports/views/generate_report.py b/rocky/reports/views/generate_report.py index 67de9a26778..c6f34a6ccd9 100644 --- a/rocky/reports/views/generate_report.py +++ b/rocky/reports/views/generate_report.py @@ -8,7 +8,7 @@ from django.views.generic import TemplateView from httpx import HTTPError from katalogus.client import get_katalogus -from tools.view_helpers import PostRedirect +from tools.view_helpers import Breadcrumb, PostRedirect from reports.views.base import ( REPORTS_PRE_SELECTION, @@ -25,7 +25,7 @@ class BreadcrumbsGenerateReportView(ReportBreadcrumbs): - def build_breadcrumbs(self): + def build_breadcrumbs(self) -> list[Breadcrumb]: breadcrumbs = super().build_breadcrumbs() kwargs = self.get_kwargs() selection = get_selection(self.request) diff --git a/rocky/reports/views/multi_report.py b/rocky/reports/views/multi_report.py index 51e771cdcb4..4d26a62dc37 100644 --- a/rocky/reports/views/multi_report.py +++ b/rocky/reports/views/multi_report.py @@ -5,6 +5,7 @@ from django.urls import reverse from django.utils.translation import gettext_lazy as _ from django.views.generic import TemplateView +from tools.view_helpers import Breadcrumb from reports.report_types.multi_organization_report.report import MultiOrganizationReport from reports.views.base import ( @@ -21,7 +22,7 @@ class BreadcrumbsMultiReportView(ReportBreadcrumbs): - def build_breadcrumbs(self): + def build_breadcrumbs(self) -> list[Breadcrumb]: breadcrumbs = super().build_breadcrumbs() kwargs = self.get_kwargs() selection = get_selection(self.request) diff --git a/rocky/rocky/bytes_client.py b/rocky/rocky/bytes_client.py index 80f90daff87..5694e82742c 100644 --- a/rocky/rocky/bytes_client.py +++ b/rocky/rocky/bytes_client.py @@ -28,14 +28,14 @@ def health(self) -> ServiceHealth: return ServiceHealth.model_validate(response.json()) @staticmethod - def raw_from_declarations(declarations: list[Declaration]): + def raw_from_declarations(declarations: list[Declaration]) -> bytes: json_string = f"[{','.join([declaration.model_dump_json() for declaration in declarations])}]" return json_string.encode("utf-8") def add_manual_proof( self, normalizer_id: uuid.UUID, raw: bytes, manual_mime_types: Set[str] = frozenset({"manual/ooi"}) - ): + ) -> None: """Per convention for a generic normalizer, we add a raw list of declarations, not a single declaration""" self.login() diff --git a/rocky/rocky/exceptions.py b/rocky/rocky/exceptions.py index ccd3e720012..8f76f5c052c 100644 --- a/rocky/rocky/exceptions.py +++ b/rocky/rocky/exceptions.py @@ -1,3 +1,6 @@ +from typing import Any + + class RockyError(Exception): pass @@ -21,13 +24,13 @@ class TrustedClearanceLevelTooLowException(ClearanceLevelTooLowException): class ServiceException(RockyError): """Base exception representing an issue with an (external) service""" - def __init__(self, service_name: str, *args): + def __init__(self, service_name: str, *args: Any): super().__init__(*args) self.service_name = service_name class OctopoesException(ServiceException): - def __init__(self, *args): + def __init__(self, *args: Any): super().__init__("Octopoes", *args) diff --git a/rocky/rocky/keiko.py b/rocky/rocky/keiko.py index 472c8139dbb..9078b7b3457 100644 --- a/rocky/rocky/keiko.py +++ b/rocky/rocky/keiko.py @@ -180,7 +180,7 @@ def get_organization_finding_report( return self.get_report(valid_time, "Organisatie", organization_name, store, filters) @classmethod - def ooi_report_file_name(cls, valid_time: datetime, organization_code: str, ooi_id: str): + def ooi_report_file_name(cls, valid_time: datetime, organization_code: str, ooi_id: str) -> str: report_file_name = "_".join( [ "bevindingenrapport", @@ -198,7 +198,7 @@ def ooi_report_file_name(cls, valid_time: datetime, organization_code: str, ooi_ return report_file_name @classmethod - def organization_report_file_name(cls, organization_code: str): + def organization_report_file_name(cls, organization_code: str) -> str: file_name = "_".join( [ "bevindingenrapport_nl", @@ -210,7 +210,7 @@ def organization_report_file_name(cls, organization_code: str): return f"{file_name}.pdf" -def _ooi_field_as_string(findings_grouped: dict, store: dict): +def _ooi_field_as_string(findings_grouped: dict, store: dict) -> dict: new_findings_grouped = {} for finding_type, finding_group in findings_grouped.items(): diff --git a/rocky/rocky/locale/django.pot b/rocky/rocky/locale/django.pot index 962ed1d10cf..744f46adbde 100644 --- a/rocky/rocky/locale/django.pot +++ b/rocky/rocky/locale/django.pot @@ -9,7 +9,7 @@ msgid "" msgstr "" "Project-Id-Version: PACKAGE VERSION\n" "Report-Msgid-Bugs-To: \n" -"POT-Creation-Date: 2024-11-25 09:27+0000\n" +"POT-Creation-Date: 2024-12-03 21:56+0000\n" "PO-Revision-Date: YEAR-MO-DA HO:MI+ZONE\n" "Last-Translator: FULL NAME \n" "Language-Team: LANGUAGE \n" @@ -2772,6 +2772,10 @@ msgstr "" msgid "No CVEs have been found." msgstr "" +#: reports/report_types/aggregate_organisation_report/report.py +msgid "Aggregate Organisation Report" +msgstr "" + #: reports/report_types/aggregate_organisation_report/report_design.html msgid "Observed at:" msgstr "" @@ -3809,6 +3813,7 @@ msgid "Add reference date" msgstr "" #: reports/templates/partials/report_names_form.html +#: reports/templates/report_overview/modal_partials/rename_modal.html #: rocky/views/scan_profile.py msgid "Reset" msgstr "" @@ -4261,15 +4266,11 @@ msgstr "" #: reports/templates/report_overview/report_history_table.html #, python-format msgid "" -"\n" -" This report consist of %(counter)s " -"subreport with the following report type and object.\n" -" " +"This report consist of %(counter)s subreport with the following report type " +"and object." msgid_plural "" -"\n" -" This report consist of %(counter)s " -"subreports with the following report types and objects.\n" -" " +"This report consist of %(counter)s subreports with the following report " +"types and objects." msgstr[0] "" msgstr[1] "" @@ -7350,6 +7351,10 @@ msgstr "" msgid "Can not reset scan level. Scan level of {ooi_name} not declared" msgstr "" +#: rocky/views/scans.py +msgid "Scans" +msgstr "" + #: rocky/views/scheduler.py msgid "Your report has been scheduled." msgstr "" diff --git a/rocky/rocky/middleware/onboarding.py b/rocky/rocky/middleware/onboarding.py index fb96f152fa7..3f14fcd71a0 100644 --- a/rocky/rocky/middleware/onboarding.py +++ b/rocky/rocky/middleware/onboarding.py @@ -31,12 +31,12 @@ def middleware(request): if request.user.is_superuser: return redirect(reverse("step_introduction_registration")) - member = OrganizationMember.objects.filter(user=request.user) - # Members with these permissions can run a full DNS-report onboarding. - if member.exists() and member.first().has_perms(ONBOARDING_PERMISSIONS): + if (member := OrganizationMember.objects.filter(user=request.user).first()) and member.has_perms( + ONBOARDING_PERMISSIONS + ): return redirect( - reverse("step_introduction", kwargs={"organization_code": member.first().organization.code}) + reverse("step_introduction", kwargs={"organization_code": member.organization.code}) ) return response diff --git a/rocky/rocky/paginator.py b/rocky/rocky/paginator.py index aa6b5fa8bcf..1ec070b3d18 100644 --- a/rocky/rocky/paginator.py +++ b/rocky/rocky/paginator.py @@ -1,3 +1,5 @@ +from typing import Any + from django.core.paginator import EmptyPage, Page, PageNotAnInteger, Paginator from django.utils.translation import gettext_lazy as _ @@ -9,19 +11,19 @@ def __init__(self, *args, **kwargs) -> None: if self.orphans != 0: raise ValueError("Setting orphans is not supported") - def validate_number(self, number) -> int: + def validate_number(self, number: Any) -> int: """Validate the given 1-based page number.""" try: if isinstance(number, float) and not number.is_integer(): raise ValueError - number = int(number) + parsed_number = int(number) except (TypeError, ValueError): raise PageNotAnInteger(_("That page number is not an integer")) - if number < 1: + if parsed_number < 1: raise EmptyPage(_("That page number is less than 1")) - return number + return parsed_number - def page(self, number) -> Page: + def page(self, number: Any) -> Page: """Return a Page object per page number.""" number = self.validate_number(number) bottom = (number - 1) * self.per_page diff --git a/rocky/rocky/scheduler.py b/rocky/rocky/scheduler.py index 603bdcdf9c5..5b160d5bfbf 100644 --- a/rocky/rocky/scheduler.py +++ b/rocky/rocky/scheduler.py @@ -190,7 +190,7 @@ class PaginatedSchedulesResponse(BaseModel): class LazyTaskList: HARD_LIMIT = 500 - def __init__(self, scheduler_client: SchedulerClient, **kwargs): + def __init__(self, scheduler_client: SchedulerClient, **kwargs: Any): self.scheduler_client = scheduler_client self.kwargs = kwargs self._count: int | None = None @@ -204,7 +204,7 @@ def count(self) -> int: def __len__(self): return self.count - def __getitem__(self, key) -> list[Task]: + def __getitem__(self, key: slice | int) -> list[Task]: if isinstance(key, slice): offset = key.start or 0 limit = min(LazyTaskList.HARD_LIMIT, key.stop - offset or key.stop or LazyTaskList.HARD_LIMIT) @@ -231,7 +231,7 @@ def __init__(self, *args: object, extra_message: str | None = None) -> None: if extra_message is not None: self.message = extra_message + self.message - def __str__(self): + def __str__(self) -> str: return str(self.message) diff --git a/rocky/rocky/views/finding_add.py b/rocky/rocky/views/finding_add.py index 40c313c3033..534fb72054f 100644 --- a/rocky/rocky/views/finding_add.py +++ b/rocky/rocky/views/finding_add.py @@ -1,6 +1,7 @@ from datetime import datetime, timezone from uuid import uuid4 +from django.forms import Form from django.shortcuts import redirect from django.urls.base import reverse from django.utils.translation import gettext_lazy as _ @@ -73,7 +74,7 @@ def get_form_kwargs(self): return kwargs - def get_form(self, form_class=None) -> FindingAddForm: + def get_form(self, form_class: type[Form] | None = None) -> FindingAddForm: if form_class is None: form_class = self.get_form_class() diff --git a/rocky/rocky/views/finding_list.py b/rocky/rocky/views/finding_list.py index 4de08ef74d6..4034ccff1cf 100644 --- a/rocky/rocky/views/finding_list.py +++ b/rocky/rocky/views/finding_list.py @@ -14,7 +14,7 @@ OrderByFindingTypeForm, OrderBySeverityForm, ) -from tools.view_helpers import BreadcrumbsMixin +from tools.view_helpers import Breadcrumb, BreadcrumbsMixin from octopoes.models.ooi.findings import RiskLevelSeverity from rocky.views.mixins import ConnectorFormMixin, FindingList, OctopoesView, SeveritiesMixin @@ -22,7 +22,7 @@ logger = structlog.get_logger(__name__) -def sort_by_severity_desc(findings) -> list[dict[str, Any]]: +def sort_by_severity_desc(findings: Iterable) -> list[dict[str, Any]]: # Sorting is stable (when multiple records have the same key, their original # order is preserved) so if we first sort by finding id the findings with # the same risk score will be sorted by finding id @@ -117,7 +117,7 @@ class FindingListView(BreadcrumbsMixin, FindingListFilter): template_name = "findings/finding_list.html" paginate_by = 150 - def build_breadcrumbs(self): + def build_breadcrumbs(self) -> list[Breadcrumb]: return [ { "url": reverse_lazy("finding_list", kwargs={"organization_code": self.organization.code}), @@ -130,7 +130,7 @@ class Top10FindingListView(FindingListView): template_name = "findings/finding_list.html" paginate_by = 10 - def build_breadcrumbs(self): + def build_breadcrumbs(self) -> list[Breadcrumb]: return [ { "url": reverse_lazy("organization_crisis_room", kwargs={"organization_code": self.organization.code}), diff --git a/rocky/rocky/views/health.py b/rocky/rocky/views/health.py index 857b952d13f..05d9de2b517 100644 --- a/rocky/rocky/views/health.py +++ b/rocky/rocky/views/health.py @@ -1,6 +1,8 @@ +from typing import Any + import structlog from account.mixins import OrganizationView -from django.http import JsonResponse +from django.http import HttpRequest, JsonResponse from django.urls.base import reverse from django.utils.translation import gettext_lazy as _ from django.views.generic import TemplateView, View @@ -18,7 +20,7 @@ class Health(OrganizationView, View): - def get(self, request, *args, **kwargs) -> JsonResponse: + def get(self, request: HttpRequest, *args: Any, **kwargs: Any) -> JsonResponse: octopoes_connector = self.octopoes_api_connector rocky_health = get_rocky_health(self.organization.code, octopoes_connector) return JsonResponse(rocky_health.model_dump()) diff --git a/rocky/rocky/views/mixins.py b/rocky/rocky/views/mixins.py index 50d0ec2056f..dc53525df47 100644 --- a/rocky/rocky/views/mixins.py +++ b/rocky/rocky/views/mixins.py @@ -1,4 +1,4 @@ -from collections.abc import Sequence +from collections.abc import Iterable, Sequence from dataclasses import dataclass from datetime import datetime, timedelta, timezone from functools import cached_property @@ -22,9 +22,9 @@ from tools.ooi_helpers import get_knowledge_base_data_for_ooi_store from tools.view_helpers import convert_date_to_datetime, get_ooi_url -from octopoes.connector import ObjectNotFoundException from octopoes.connector.octopoes import OctopoesAPIConnector from octopoes.models import OOI, Reference, ScanLevel, ScanProfileType +from octopoes.models.exception import ObjectNotFoundException from octopoes.models.explanation import InheritanceSection from octopoes.models.ooi.findings import Finding, FindingType, RiskLevelSeverity from octopoes.models.ooi.reports import Report @@ -168,7 +168,7 @@ def get_origins(self, reference: Reference, organization: Organization) -> Origi return results - def handle_connector_exception(self, exception: Exception): + def handle_connector_exception(self, exception: Exception) -> None: if isinstance(exception, ObjectNotFoundException): raise Http404("OOI not found") @@ -261,7 +261,7 @@ def __init__( self, octopoes_connector: OctopoesAPIConnector, valid_time: datetime, - severities: set[RiskLevelSeverity], + severities: Iterable[RiskLevelSeverity], exclude_muted: bool = True, only_muted: bool = False, search_string: str | None = None, @@ -494,7 +494,7 @@ def get_breadcrumb_list(self): }, ] - def get_ooi_properties(self, ooi: OOI): + def get_ooi_properties(self, ooi: OOI) -> dict: class_relations = get_relations(ooi.__class__) props = {field_name: value for field_name, value in ooi if field_name not in class_relations} diff --git a/rocky/rocky/views/ooi_detail_related_object.py b/rocky/rocky/views/ooi_detail_related_object.py index badcbe6ecd2..4efd07d7ad2 100644 --- a/rocky/rocky/views/ooi_detail_related_object.py +++ b/rocky/rocky/views/ooi_detail_related_object.py @@ -10,7 +10,7 @@ from octopoes.models import OOI from octopoes.models.ooi.findings import Finding, FindingType, RiskLevelSeverity from octopoes.models.types import OOI_TYPES, get_relations, to_concrete -from rocky.views.ooi_view import SingleOOITreeMixin +from rocky.views.mixins import SingleOOITreeMixin class OOIRelatedObjectManager(SingleOOITreeMixin): diff --git a/rocky/rocky/views/ooi_list.py b/rocky/rocky/views/ooi_list.py index 513916ec6dd..f5faa1246ee 100644 --- a/rocky/rocky/views/ooi_list.py +++ b/rocky/rocky/views/ooi_list.py @@ -2,12 +2,14 @@ import json from datetime import datetime, timezone from enum import Enum +from typing import Any from django.contrib import messages from django.http import Http404, HttpRequest, HttpResponse from django.shortcuts import redirect from django.urls import reverse, reverse_lazy -from django.utils.translation import gettext_lazy as _ +from django.utils.translation import gettext as _ +from django.utils.translation import gettext_lazy from httpx import HTTPError from tools.enums import CUSTOM_SCAN_LEVEL from tools.forms.ooi_form import OOISearchForm, OOITypeMultiCheckboxForm @@ -32,7 +34,7 @@ class PageActions(Enum): class OOIListView(BaseOOIListView, OctopoesView): - breadcrumbs = [{"url": reverse_lazy("ooi_list"), "text": _("Objects")}] + breadcrumbs = [{"url": reverse_lazy("ooi_list"), "text": gettext_lazy("Objects")}] template_name = "oois/ooi_list.html" def get_context_data(self, **kwargs): @@ -50,14 +52,14 @@ def get_context_data(self, **kwargs): return context - def get(self, request: HttpRequest, *args, status=200, **kwargs) -> HttpResponse: + def get(self, request: HttpRequest, *args: Any, status: int = 200, **kwargs: Any) -> HttpResponse: """Override the response status in case submitting a form returns an error message""" response = super().get(request, *args, **kwargs) response.status_code = status return response - def post(self, request: HttpRequest, *args, **kwargs) -> HttpResponse: + def post(self, request: HttpRequest, *args: Any, **kwargs: Any) -> HttpResponse: """Perform bulk action on selected oois.""" selected_oois = request.POST.getlist("ooi") if not selected_oois: @@ -82,10 +84,10 @@ def post(self, request: HttpRequest, *args, **kwargs) -> HttpResponse: return self.get(request, status=404, *args, **kwargs) def _set_scan_profiles( - self, selected_oois: list[Reference], level: CUSTOM_SCAN_LEVEL, request: HttpRequest, *args, **kwargs + self, selected_oois: list[str], level: CUSTOM_SCAN_LEVEL, request: HttpRequest, *args: Any, **kwargs: Any ) -> HttpResponse: try: - self.raise_clearance_levels(selected_oois, level.value) + self.raise_clearance_levels([Reference.from_str(ooi) for ooi in selected_oois], level.value) except IndemnificationNotPresentException: messages.add_message( self.request, @@ -139,7 +141,7 @@ def _set_scan_profiles( return self.get(request, *args, **kwargs) def _set_oois_to_inherit( - self, selected_oois: list[Reference], request: HttpRequest, *args, **kwargs + self, selected_oois: list[str], request: HttpRequest, *args: Any, **kwargs: Any ) -> HttpResponse: scan_profiles = [EmptyScanProfile(reference=Reference.from_str(ooi)) for ooi in selected_oois] @@ -163,12 +165,12 @@ def _set_oois_to_inherit( ) return self.get(request, *args, **kwargs) - def _delete_oois(self, selected_oois: list[Reference], request: HttpRequest, *args, **kwargs) -> HttpResponse: + def _delete_oois(self, selected_oois: list[str], request: HttpRequest, *args: Any, **kwargs: Any) -> HttpResponse: connector = self.octopoes_api_connector valid_time = datetime.now(timezone.utc) try: - connector.delete_many(selected_oois, valid_time) + connector.delete_many([Reference.from_str(ooi) for ooi in selected_oois], valid_time) except (HTTPError, RemoteException, ConnectionError): messages.add_message(request, messages.ERROR, _("An error occurred while deleting oois.")) return self.get(request, status=500, *args, **kwargs) diff --git a/rocky/rocky/views/ooi_tree.py b/rocky/rocky/views/ooi_tree.py index bc1ba41af3c..eae420724f3 100644 --- a/rocky/rocky/views/ooi_tree.py +++ b/rocky/rocky/views/ooi_tree.py @@ -16,7 +16,7 @@ class OOITreeView(BaseOOIDetailView, TemplateView): def get_tree_dict(self): return create_object_tree_item_from_ref(self.tree.root, self.tree.store) - def get_filtered_tree(self, tree_dict): + def get_filtered_tree(self, tree_dict: dict) -> dict: filtered_types = self.request.GET.getlist("ooi_type", []) return filter_ooi_tree(tree_dict, filtered_types) @@ -68,7 +68,7 @@ def get_last_breadcrumb(self): class OOIGraphView(OOITreeView): template_name = "graph-d3.html" - def get_filtered_tree(self, tree_dict): + def get_filtered_tree(self, tree_dict: dict) -> dict: filtered_tree = super().get_filtered_tree(tree_dict) return hydrate_tree(filtered_tree, self.organization.code) @@ -79,11 +79,11 @@ def get_last_breadcrumb(self): } -def hydrate_tree(tree, organization_code: str): +def hydrate_tree(tree: dict, organization_code: str) -> dict: return hydrate_branch(tree, organization_code) -def hydrate_branch(branch, organization_code: str): +def hydrate_branch(branch: dict, organization_code: str) -> dict: branch["name"] = branch["tree_meta"]["location"] + "-" + branch["ooi_type"] branch["overlay_data"] = {"Type": branch["ooi_type"]} if branch["ooi_type"] == "Finding": diff --git a/rocky/rocky/views/ooi_view.py b/rocky/rocky/views/ooi_view.py index d13c93e0170..eea44dec53b 100644 --- a/rocky/rocky/views/ooi_view.py +++ b/rocky/rocky/views/ooi_view.py @@ -2,7 +2,7 @@ from time import sleep from typing import Literal -from django import forms +from django.forms import Form from django.http import Http404 from django.shortcuts import redirect from django.urls import reverse @@ -188,12 +188,12 @@ def build_breadcrumbs(self) -> list[Breadcrumb]: class BaseOOIFormView(SingleOOIMixin, FormView): ooi_class: type[OOI] - form_class: forms.Form = OOIForm + form_class: type[BaseRockyForm] = OOIForm def get_ooi_class(self): return self.ooi.__class__ if hasattr(self, "ooi") else None - def get_form(self, form_class=None) -> BaseRockyForm: + def get_form(self, form_class: type[Form] | None = None) -> BaseRockyForm: form = super().get_form(form_class) # Disable natural key attributes diff --git a/rocky/rocky/views/organization_list.py b/rocky/rocky/views/organization_list.py index a4d6f163645..7617b79caa8 100644 --- a/rocky/rocky/views/organization_list.py +++ b/rocky/rocky/views/organization_list.py @@ -5,7 +5,7 @@ from django.conf import settings from django.contrib import messages from django.core.exceptions import PermissionDenied -from django.db.models import Count +from django.db.models import Count, QuerySet from django.http import HttpRequest, HttpResponse, HttpResponseBadRequest from django.utils.translation import gettext_lazy as _ from django.views.generic import ListView @@ -21,7 +21,7 @@ class OrganizationListView(OrganizationBreadcrumbsMixin, ListView): template_name = "organizations/organization_list.html" - def get_queryset(self) -> list[Organization]: + def get_queryset(self) -> QuerySet[Organization]: user: KATUser = self.request.user return ( Organization.objects.annotate(member_count=Count("members")) diff --git a/rocky/rocky/views/organization_member_add.py b/rocky/rocky/views/organization_member_add.py index c89a0ba0d6c..cff62ccf010 100644 --- a/rocky/rocky/views/organization_member_add.py +++ b/rocky/rocky/views/organization_member_add.py @@ -10,6 +10,7 @@ from django.contrib.auth.models import Group from django.core.exceptions import ObjectDoesNotExist, ValidationError from django.db import transaction +from django.forms import Form from django.http import FileResponse, HttpRequest, HttpResponse from django.shortcuts import redirect from django.urls import reverse_lazy @@ -19,7 +20,7 @@ from onboarding.view_helpers import DNS_REPORT_LEAST_CLEARANCE_LEVEL from tools.forms.upload_csv import UploadCSVForm from tools.models import GROUP_ADMIN, GROUP_CLIENT, GROUP_REDTEAM, OrganizationMember -from tools.view_helpers import OrganizationMemberBreadcrumbsMixin +from tools.view_helpers import Breadcrumb, OrganizationMemberBreadcrumbsMixin from rocky.messaging import clearance_level_warning_dns_report @@ -65,7 +66,7 @@ def get(self, request: HttpRequest, *args: str, **kwargs: Any) -> HttpResponse: ) ) - def build_breadcrumbs(self): + def build_breadcrumbs(self) -> list[Breadcrumb]: breadcrumbs = super().build_breadcrumbs() breadcrumbs.append( { @@ -109,7 +110,7 @@ def add_success_notification(self): def get_success_url(self, **kwargs): return reverse_lazy("organization_member_list", kwargs={"organization_code": self.organization.code}) - def build_breadcrumbs(self): + def build_breadcrumbs(self) -> list[Breadcrumb]: breadcrumbs = super().build_breadcrumbs() breadcrumbs.extend( [ @@ -157,7 +158,7 @@ def form_valid(self, form): self.process_csv(form) return super().form_valid(form) - def process_csv(self, form) -> None: + def process_csv(self, form: Form) -> None: csv_raw_data = form.cleaned_data["csv_file"].read() csv_data = io.StringIO(csv_raw_data.decode("UTF-8")) @@ -176,7 +177,7 @@ def process_csv(self, form) -> None: ) except KeyError: messages.add_message(self.request, messages.ERROR, _("The csv file is missing required columns")) - return redirect("organization_member_upload", self.organization.code) + return try: with transaction.atomic(): @@ -209,7 +210,7 @@ def process_csv(self, form) -> None: def save_models( self, name: str, email: str, account_type: str, trusted_clearance: int, acknowledged_clearance: int - ): + ) -> None: user, user_created = User.objects.get_or_create(email=email, defaults={"full_name": name}) member_kwargs = { diff --git a/rocky/rocky/views/organization_member_list.py b/rocky/rocky/views/organization_member_list.py index 79681c45b19..4df3a92073e 100644 --- a/rocky/rocky/views/organization_member_list.py +++ b/rocky/rocky/views/organization_member_list.py @@ -35,7 +35,7 @@ def get_queryset(self): queryset = self.model.objects.filter(organization=self.organization) if "client_status" in self.request.GET: status_filter = self.request.GET.getlist("client_status", []) - queryset = [member for member in queryset if member.status in status_filter] + queryset = queryset.filter(status__in=status_filter) if "blocked_status" in self.request.GET: blocked_filter = self.request.GET.getlist("blocked_status", []) @@ -48,7 +48,7 @@ def get_queryset(self): if filter_option == "unblocked": blocked_filter_bools.append(False) - queryset = [member for member in queryset if member.blocked in blocked_filter_bools] + queryset = queryset.filter(blocked__in=blocked_filter_bools) return queryset def setup(self, request, *args, **kwargs): @@ -63,7 +63,7 @@ def post(self, request, *args, **kwargs): self.handle_page_action(request.POST.get("action")) return redirect(reverse("organization_member_list", kwargs={"organization_code": self.organization.code})) - def handle_page_action(self, action: str): + def handle_page_action(self, action: str) -> None: member_id = self.request.POST.get("member_id") organizationmember = self.model.objects.get(id=member_id) try: diff --git a/rocky/rocky/views/organization_settings.py b/rocky/rocky/views/organization_settings.py index 1d4b2676516..35f4d189d79 100644 --- a/rocky/rocky/views/organization_settings.py +++ b/rocky/rocky/views/organization_settings.py @@ -1,5 +1,6 @@ from datetime import datetime from enum import Enum +from typing import Any from account.mixins import OrganizationPermissionRequiredMixin, OrganizationView from django.contrib import messages @@ -23,7 +24,7 @@ class OrganizationSettingsView( template_name = "organizations/organization_settings.html" permission_required = "tools.view_organization" - def post(self, request: HttpRequest, *args, **kwargs) -> HttpResponse: + def post(self, request: HttpRequest, *args: Any, **kwargs: Any) -> HttpResponse: """Perform actions based on action type""" action = request.POST.get("action") if not self.request.user.has_perm("tools.can_recalculate_bits"): diff --git a/rocky/rocky/views/privacy_statement.py b/rocky/rocky/views/privacy_statement.py index 7bba1e1e434..9b369c2e584 100644 --- a/rocky/rocky/views/privacy_statement.py +++ b/rocky/rocky/views/privacy_statement.py @@ -1,4 +1,4 @@ -from django.shortcuts import reverse +from django.urls import reverse from django.utils.translation import gettext_lazy as _ from django.views.generic import TemplateView diff --git a/rocky/rocky/views/scan_profile.py b/rocky/rocky/views/scan_profile.py index 7521e71c24b..a6fa4f6a45a 100644 --- a/rocky/rocky/views/scan_profile.py +++ b/rocky/rocky/views/scan_profile.py @@ -17,7 +17,7 @@ class ScanProfileDetailView(FormView, OOIDetailView): template_name = "scan_profiles/scan_profile_detail.html" form_class = SetClearanceLevelForm - def get_context_data(self, **kwargs) -> dict[str, Any]: + def get_context_data(self, **kwargs: Any) -> dict[str, Any]: context = super().get_context_data(**kwargs) context["mandatory_fields"] = get_mandatory_fields(self.request) if self.ooi.scan_profile and self.ooi.scan_profile.user_id: diff --git a/rocky/rocky/views/scans.py b/rocky/rocky/views/scans.py index c2f62ae72b1..0ce17c061dc 100644 --- a/rocky/rocky/views/scans.py +++ b/rocky/rocky/views/scans.py @@ -1,4 +1,5 @@ from account.mixins import OrganizationView +from django.utils.translation import gettext as _ from django.views.generic import TemplateView from tools.view_helpers import Breadcrumb, ObjectsBreadcrumbsMixin @@ -9,7 +10,7 @@ class ScanListView(ObjectsBreadcrumbsMixin, OrganizationView, TemplateView): def build_breadcrumbs(self) -> list[Breadcrumb]: breadcrumbs = super().build_breadcrumbs() - breadcrumbs.append({"url": "", "text": "Scans"}) + breadcrumbs.append({"url": "", "text": _("Scans")}) return breadcrumbs diff --git a/rocky/rocky/views/upload_csv.py b/rocky/rocky/views/upload_csv.py index 6690e9bd438..915a887bc87 100644 --- a/rocky/rocky/views/upload_csv.py +++ b/rocky/rocky/views/upload_csv.py @@ -17,7 +17,7 @@ from tools.forms.upload_oois import UploadOOICSVForm from octopoes.api.models import Declaration -from octopoes.models import Reference +from octopoes.models import OOI, Reference from octopoes.models.ooi.dns.zone import Hostname from octopoes.models.ooi.network import IPAddressV4, IPAddressV6, Network from octopoes.models.ooi.web import URL @@ -80,7 +80,7 @@ def get_context_data(self, **kwargs): context["criteria"] = CSV_CRITERIA return context - def get_or_create_reference(self, ooi_type_name: str, value: str | None): + def get_or_create_reference(self, ooi_type_name: str, value: str | None) -> OOI: ooi_type_name = next(filter(lambda x: x.casefold() == ooi_type_name.casefold(), self.ooi_types.keys())) # get from cache @@ -101,7 +101,7 @@ def get_or_create_reference(self, ooi_type_name: str, value: str | None): return ooi - def get_ooi_from_csv(self, ooi_type_name: str, values: dict[str, str]): + def get_ooi_from_csv(self, ooi_type_name: str, values: dict[str, str]) -> tuple[OOI, int | None]: key = "clearance" level = int(values[key]) if key in values and values[key] in CLEARANCE_VALUES else None ooi_type = self.ooi_types[ooi_type_name]["type"] @@ -111,7 +111,7 @@ def get_ooi_from_csv(self, ooi_type_name: str, values: dict[str, str]): if field not in self.skip_properties ] - kwargs = {} + kwargs: dict[str, Any] = {} for field, is_reference, required in ooi_fields: if is_reference and required: try: diff --git a/rocky/tools/add_ooi_information.py b/rocky/tools/add_ooi_information.py index 10d51ba24b5..e86470bc414 100644 --- a/rocky/tools/add_ooi_information.py +++ b/rocky/tools/add_ooi_information.py @@ -57,7 +57,7 @@ def iana_service_table(search_query: str) -> list[_Service]: return services -def service_info(value) -> tuple[str, str]: +def service_info(value: str) -> tuple[str, str]: """Provides information about IP Services such as common assigned ports for certain protocols and descriptions""" services = iana_service_table(value) source = "https://www.iana.org/assignments/service-names-port-numbers/service-names-port-numbers.xhtml" diff --git a/rocky/tools/admin.py b/rocky/tools/admin.py index 1d064e44a48..112d3201826 100644 --- a/rocky/tools/admin.py +++ b/rocky/tools/admin.py @@ -5,7 +5,7 @@ from django.contrib import admin, messages from django.db.models import JSONField from django.forms import widgets -from django.http import HttpResponseRedirect +from django.http import HttpRequest, HttpResponseRedirect from rocky.exceptions import RockyError from tools.models import Indemnification, OOIInformation, Organization, OrganizationMember, OrganizationTag @@ -34,7 +34,9 @@ class OOIInformationAdmin(admin.ModelAdmin): formfield_overrides = {JSONField: {"widget": JSONInfoWidget}} # if pk is not readonly, it will create a new record upon editing - def get_readonly_fields(self, request, obj=None): + def get_readonly_fields( + self, request: HttpRequest, obj: OOIInformation | None = None + ) -> list[str] | tuple[str, ...]: if obj is not None: # editing an existing object if not obj.value: return self.readonly_fields + ("id", "consult_api") diff --git a/rocky/tools/forms/base.py b/rocky/tools/forms/base.py index f6e7e3061c5..37b6acde7e6 100644 --- a/rocky/tools/forms/base.py +++ b/rocky/tools/forms/base.py @@ -95,15 +95,15 @@ class CheckboxGroup(forms.CheckboxSelectMultiple): required_options: list[str] wrap_label = True - def __init__(self, required_options: list[str] | None = None, *args, **kwargs) -> None: + def __init__(self, required_options: list[str] | None = None, *args: Any, **kwargs: Any) -> None: super().__init__(*args, **kwargs) self.required_options = required_options or [] - def get_context(self, name, value, attrs) -> dict[str, Any]: + def get_context(self, name: str, value: Any, attrs: dict[str, Any] | None) -> dict[str, Any]: context = super().get_context(name, value, attrs) return context - def create_option(self, *arg, **kwargs) -> dict[str, Any]: + def create_option(self, *arg: Any, **kwargs: Any) -> dict[str, Any]: option = super().create_option(*arg, **kwargs) option["wrap_label"] = self.wrap_label option["attrs"]["checked"] = self.is_required_option(option["value"]) diff --git a/rocky/tools/forms/finding_type.py b/rocky/tools/forms/finding_type.py index e6fb39e661d..c56e8f085c5 100644 --- a/rocky/tools/forms/finding_type.py +++ b/rocky/tools/forms/finding_type.py @@ -1,12 +1,13 @@ from datetime import datetime, timezone +from typing import Any from django import forms from django.core.exceptions import ValidationError from django.utils.translation import gettext_lazy as _ -from octopoes.connector import ObjectNotFoundException from octopoes.connector.octopoes import OctopoesAPIConnector from octopoes.models import Reference +from octopoes.models.exception import ObjectNotFoundException from tools.forms.base import BaseRockyForm, DataListInput, DateTimeInput from tools.forms.settings import ( FINDING_DATETIME_HELP_TEXT, @@ -114,7 +115,7 @@ class FindingAddForm(BaseRockyForm): help_text=FINDING_DATETIME_HELP_TEXT, ) - def __init__(self, connector: OctopoesAPIConnector, ooi_list: list[tuple[str, str]], *args, **kwargs): + def __init__(self, connector: OctopoesAPIConnector, ooi_list: list[tuple[str, str]], *args: Any, **kwargs: Any): self.octopoes_connector = connector super().__init__(*args, **kwargs) self.set_choices_for_widget("ooi_id", ooi_list) diff --git a/rocky/tools/forms/ooi.py b/rocky/tools/forms/ooi.py index b765f0e234c..81af675dc1a 100644 --- a/rocky/tools/forms/ooi.py +++ b/rocky/tools/forms/ooi.py @@ -17,7 +17,7 @@ class OOIReportSettingsForm(ObservedAtForm): class OoiTreeSettingsForm(OOIReportSettingsForm): ooi_type = forms.MultipleChoiceField(label=_("Filter types"), widget=forms.CheckboxSelectMultiple(), required=False) - def __init__(self, ooi_types: list[str], *args, **kwargs): + def __init__(self, ooi_types: list[str], *args: Any, **kwargs: Any): super().__init__(*args, **kwargs) self.set_ooi_types(ooi_types) @@ -42,7 +42,9 @@ class SelectOOIForm(BaseRockyForm): ), ) - def __init__(self, oois: list[OOI], organization_code: str, mandatory_fields: list | None = None, *args, **kwargs): + def __init__( + self, oois: list[OOI], organization_code: str, mandatory_fields: list | None = None, *args: Any, **kwargs: Any + ): super().__init__(*args, **kwargs) self.fields["ooi"].widget.attrs["organization_code"] = organization_code if mandatory_fields: diff --git a/rocky/tools/forms/ooi_form.py b/rocky/tools/forms/ooi_form.py index 82349ad9b4a..2e1c3232f49 100644 --- a/rocky/tools/forms/ooi_form.py +++ b/rocky/tools/forms/ooi_form.py @@ -2,7 +2,7 @@ from enum import Enum from inspect import isclass from ipaddress import IPv4Address, IPv6Address -from typing import Any, Literal, Union, get_args, get_origin +from typing import Any, Literal, TypedDict, Union, get_args, get_origin from django import forms from django.utils.translation import gettext_lazy as _ @@ -19,7 +19,7 @@ class OOIForm(BaseRockyForm): - def __init__(self, ooi_class: type[OOI], connector: OctopoesAPIConnector, *args, **kwargs): + def __init__(self, ooi_class: type[OOI], connector: OctopoesAPIConnector, *args: Any, **kwargs: Any): self.user_id = kwargs.pop("user_id", None) super().__init__(*args, **kwargs) self.ooi_class = ooi_class @@ -38,7 +38,7 @@ def get_fields(self) -> dict[str, forms.fields.Field]: return self.generate_form_fields() def generate_form_fields(self, hidden_ooi_fields: dict[str, str] | None = None) -> dict[str, forms.fields.Field]: - fields = {} + fields: dict[str, forms.fields.Field] = {} for name, field in self.ooi_class.model_fields.items(): annotation = field.annotation default_attrs = default_field_options(name, field) @@ -161,7 +161,12 @@ def generate_url_field(field: FieldInfo) -> forms.fields.Field: return field -def default_field_options(name: str, field_info: FieldInfo) -> dict[str, str | bool]: +class DefaultFieldOptions(TypedDict): + label: str + required: bool + + +def default_field_options(name: str, field_info: FieldInfo) -> DefaultFieldOptions: return {"label": name, "required": field_info.is_required()} diff --git a/rocky/tools/forms/settings.py b/rocky/tools/forms/settings.py index 7798c115c5e..88af5cf2f38 100644 --- a/rocky/tools/forms/settings.py +++ b/rocky/tools/forms/settings.py @@ -1,11 +1,10 @@ -from typing import Any - +from django.utils.functional import Promise from django.utils.safestring import mark_safe from django.utils.translation import gettext_lazy as _ from tools.enums import SCAN_LEVEL -Choice = tuple[Any, str] +Choice = tuple[str, Promise] Choices = list[Choice] ChoicesGroup = tuple[str, Choices] ChoicesGroups = list[ChoicesGroup] diff --git a/rocky/tools/management/commands/export_migrations.py b/rocky/tools/management/commands/export_migrations.py index c16124135ff..a4073f4c6c3 100644 --- a/rocky/tools/management/commands/export_migrations.py +++ b/rocky/tools/management/commands/export_migrations.py @@ -1,4 +1,5 @@ from pathlib import Path +from typing import Any import structlog from django.core.management import BaseCommand, CommandParser @@ -17,7 +18,7 @@ def add_arguments(self, parser: CommandParser) -> None: parser.add_argument("from_id", action="store", type=int, help="Migration id to start from") parser.add_argument("--output-folder", action="store", default="export_migrations", help="Output folder") - def handle(self, **options) -> None: + def handle(self, **options: Any) -> None: # Get the database we're operating from connection = connections[DEFAULT_DB_ALIAS] diff --git a/rocky/tools/management/commands/generate_report.py b/rocky/tools/management/commands/generate_report.py index 961a8013e18..0791b5e4d32 100644 --- a/rocky/tools/management/commands/generate_report.py +++ b/rocky/tools/management/commands/generate_report.py @@ -73,7 +73,9 @@ def handle(self, *args, **options): self.stdout.buffer.write(report.read()) @staticmethod - def get_findings_metadata(organization, valid_time, severities) -> list[dict[str, Any]]: + def get_findings_metadata( + organization: Organization, valid_time: datetime, severities: list[RiskLevelSeverity] + ) -> list[dict[str, Any]]: findings = FindingList( OctopoesAPIConnector( settings.OCTOPOES_API, organization.code, timeout=settings.ROCKY_OUTGOING_REQUEST_TIMEOUT @@ -85,7 +87,7 @@ def get_findings_metadata(organization, valid_time, severities) -> list[dict[str return generate_findings_metadata(findings, severities) @staticmethod - def get_organization(**options) -> Organization | None: + def get_organization(**options: str) -> Organization | None: if options["code"] and options["id"]: return None diff --git a/rocky/tools/management/commands/setup_test_users.py b/rocky/tools/management/commands/setup_test_users.py index 696914757f2..08ddf2aa6a3 100644 --- a/rocky/tools/management/commands/setup_test_users.py +++ b/rocky/tools/management/commands/setup_test_users.py @@ -19,7 +19,7 @@ def handle(self, **options): add_test_user("e2e-client", password, GROUP_CLIENT) -def add_superuser(email: str, password: str): +def add_superuser(email: str, password: str) -> None: user_kwargs: dict[str, str | bool] = { "email": email, "password": password, @@ -31,13 +31,13 @@ def add_superuser(email: str, password: str): add_user(user_kwargs) -def add_test_user(email: str, password: str, group_name: str | None = None): +def add_test_user(email: str, password: str, group_name: str | None = None) -> None: user_kwargs: dict[str, str | bool] = {"email": email, "password": password, "full_name": "End-to-end user"} add_user(user_kwargs, group_name) -def add_user(user_kwargs: dict[str, str | bool], group_name: str | None = None): +def add_user(user_kwargs: dict[str, str | bool], group_name: str | None = None) -> None: """ Creates a test user with the given user_kwargs. User is optionally added to group group_name. diff --git a/rocky/tools/models.py b/rocky/tools/models.py index 6e3f873cbea..16ab716bad6 100644 --- a/rocky/tools/models.py +++ b/rocky/tools/models.py @@ -68,6 +68,7 @@ def css_class(self): class Organization(models.Model): + id: int name = models.CharField(max_length=126, unique=True, help_text=_("The name of the organisation")) code = LowerCaseSlugField( max_length=ORGANIZATION_CODE_LENGTH, @@ -82,7 +83,7 @@ class Organization(models.Model): EVENT_CODES = {"created": 900201, "updated": 900202, "deleted": 900203} - def __str__(self): + def __str__(self) -> str: return str(self.name) class Meta: @@ -190,7 +191,7 @@ def has_clearance_level(self, level: int) -> bool: class Meta: unique_together = ["user", "organization"] - def __str__(self): + def __str__(self) -> str: return str(self.user) @@ -240,5 +241,5 @@ def get_internet_description(self): self.data[key] = value self.save() - def __str__(self): + def __str__(self) -> str: return self.id diff --git a/rocky/tools/ooi_helpers.py b/rocky/tools/ooi_helpers.py index 4649ca94ea0..28c05ecac1a 100644 --- a/rocky/tools/ooi_helpers.py +++ b/rocky/tools/ooi_helpers.py @@ -1,3 +1,4 @@ +from collections.abc import Sequence from datetime import datetime from enum import Enum from typing import Any @@ -47,7 +48,7 @@ def format_display(data: dict, ignore: list | None = None) -> dict[str, str]: return {format_attr_name(k): format_value(v) for k, v in data.items() if k not in ignore} -def get_knowledge_base_data_for_ooi_store(ooi_store) -> dict[str, dict]: +def get_knowledge_base_data_for_ooi_store(ooi_store: dict) -> dict[str, dict]: knowledge_base = {} for ooi in ooi_store.values(): @@ -126,9 +127,9 @@ def create_object_tree_item_from_ref( reference_node: ReferenceNode, ooi_store: dict[str, OOI], knowledge_base: dict[str, dict] | None = None, - depth=0, - position=1, - location="loc", + depth: int = 0, + position: int = 1, + location: str = "loc", ) -> dict: depth = sum([depth, 1]) location = location + "-" + str(position) @@ -177,7 +178,7 @@ def get_ooi_types_from_tree(ooi, include_self=True): return sorted(types) -def filter_ooi_tree(ooi_node: dict, show_types=[], hide_types=[]) -> dict: +def filter_ooi_tree(ooi_node: dict, show_types: Sequence = [], hide_types: Sequence = []) -> dict: if not show_types and not hide_types: return ooi_node diff --git a/rocky/tools/templatetags/ooi_extra.py b/rocky/tools/templatetags/ooi_extra.py index 5283eb47e51..7a6724540dd 100644 --- a/rocky/tools/templatetags/ooi_extra.py +++ b/rocky/tools/templatetags/ooi_extra.py @@ -12,7 +12,7 @@ @register.filter -def get_encoded_dict(data_dict: dict): +def get_encoded_dict(data_dict: dict) -> str: return parse.urlencode(data_dict) @@ -37,27 +37,27 @@ def get_scan_levels() -> list[str]: @register.filter -def ooi_types_to_strings(ooi_types: set[type[OOI]]): +def ooi_types_to_strings(ooi_types: set[type[OOI]]) -> list["str"]: return [ooi_type.get_ooi_type() for ooi_type in ooi_types] @register.filter() -def get_type(x: Any): +def get_type(x: Any) -> Any: return type(x) @register.simple_tag() -def ooi_url(routename: str, ooi_id: str, organization_code: str, **kwargs) -> str: +def ooi_url(routename: str, ooi_id: str, organization_code: str, **kwargs: str) -> str: return get_ooi_url(routename, ooi_id, organization_code, **kwargs) @register.filter() -def is_finding(ooi: OOI): +def is_finding(ooi: OOI) -> bool: return isinstance(ooi, Finding) @register.filter() -def is_finding_type(ooi: OOI): +def is_finding_type(ooi: OOI) -> bool: return isinstance(ooi, FindingType) @@ -79,7 +79,7 @@ def index(indexable, i): @register.filter -def pretty_json(obj: dict): +def pretty_json(obj: dict) -> str: return json.dumps(obj, default=str, indent=4) diff --git a/rocky/tools/view_helpers.py b/rocky/tools/view_helpers.py index 263512ca54c..8bc300818dc 100644 --- a/rocky/tools/view_helpers.py +++ b/rocky/tools/view_helpers.py @@ -1,11 +1,12 @@ import uuid from datetime import date, datetime, timezone -from typing import TypedDict +from typing import Any, TypedDict from urllib.parse import urlencode, urlparse, urlunparse from django.http import HttpRequest from django.http.response import HttpResponseRedirectBase from django.urls.base import reverse, reverse_lazy +from django.utils.functional import Promise from django.utils.translation import gettext_lazy as _ from octopoes.models.types import OOI_TYPES @@ -17,7 +18,7 @@ def convert_date_to_datetime(d: date) -> datetime: return datetime.combine(d, datetime.max.time(), tzinfo=timezone.utc) -def get_mandatory_fields(request, params: list[str] | None = None): +def get_mandatory_fields(request: HttpRequest, params: list[str] | None = None) -> list: mandatory_fields = [] if not params: @@ -37,7 +38,7 @@ def generate_job_id(): return str(uuid.uuid4()) -def url_with_querystring(path, doseq=False, **kwargs) -> str: +def url_with_querystring(path: str, doseq: bool = False, /, **kwargs: Any) -> str: parsed_route = urlparse(path) return str( @@ -54,7 +55,7 @@ def url_with_querystring(path, doseq=False, **kwargs) -> str: ) -def get_ooi_url(routename: str, ooi_id: str, organization_code: str, **kwargs) -> str: +def get_ooi_url(routename: str, ooi_id: str, organization_code: str, **kwargs: Any) -> str: if ooi_id: kwargs["ooi_id"] = ooi_id @@ -75,7 +76,7 @@ def existing_ooi_type(ooi_type: str) -> bool: class Breadcrumb(TypedDict): - text: str + text: str | Promise url: str @@ -129,35 +130,31 @@ class OrganizationBreadcrumbsMixin(BreadcrumbsMixin): class OrganizationDetailBreadcrumbsMixin(BreadcrumbsMixin): organization: Organization - def build_breadcrumbs(self): - breadcrumbs = [ + def build_breadcrumbs(self) -> list[Breadcrumb]: + return [ { "url": reverse("organization_settings", kwargs={"organization_code": self.organization.code}), "text": _("Settings"), } ] - return breadcrumbs - class OrganizationMemberBreadcrumbsMixin(BreadcrumbsMixin): organization: Organization - def build_breadcrumbs(self): - breadcrumbs = [ + def build_breadcrumbs(self) -> list[Breadcrumb]: + return [ { "url": reverse("organization_member_list", kwargs={"organization_code": self.organization.code}), "text": _("Members"), } ] - return breadcrumbs - class ObjectsBreadcrumbsMixin(BreadcrumbsMixin): organization: Organization - def build_breadcrumbs(self): + def build_breadcrumbs(self) -> list[Breadcrumb]: return [ { "url": reverse_lazy("ooi_list", kwargs={"organization_code": self.organization.code}),