Skip to content

Commit

Permalink
Fix typing in more places and configure mypy to follow imports (#3932)
Browse files Browse the repository at this point in the history
Co-authored-by: stephanie0x00 <[email protected]>
Co-authored-by: ammar92 <[email protected]>
Co-authored-by: Jan Klopper <[email protected]>
  • Loading branch information
4 people authored Dec 10, 2024
1 parent 39d0dbc commit ec9d80e
Show file tree
Hide file tree
Showing 144 changed files with 481 additions and 370 deletions.
1 change: 1 addition & 0 deletions .pre-commit-config.yaml
Original file line number Diff line number Diff line change
Expand Up @@ -92,6 +92,7 @@ repos:
- types-python-dateutil
- types-requests
- types-croniter
- boto3-stubs[s3]
exclude: |
(?x)(
^boefjes/tools |
Expand Down
2 changes: 1 addition & 1 deletion boefjes/boefjes/__main__.py
Original file line number Diff line number Diff line change
Expand Up @@ -37,7 +37,7 @@
@click.command()
@click.argument("worker_type", type=click.Choice([q.value for q in WorkerManager.Queue]))
@click.option("--log-level", type=click.Choice(["DEBUG", "INFO", "WARNING", "ERROR"]), help="Log level", default="INFO")
def cli(worker_type: str, log_level: str):
def cli(worker_type: str, log_level: str) -> None:
logger.setLevel(log_level)
logger.info("Starting runtime for %s", worker_type)

Expand Down
8 changes: 4 additions & 4 deletions boefjes/boefjes/api.py
Original file line number Diff line number Diff line change
Expand Up @@ -33,7 +33,7 @@ def __init__(self, config: Config):
self.server = Server(config=config)
self.config = config

def stop(self):
def stop(self) -> None:
self.terminate()

def run(self, *args, **kwargs):
Expand Down Expand Up @@ -88,7 +88,7 @@ def boefje_input(
task_id: UUID,
scheduler_client: SchedulerAPIClient = Depends(get_scheduler_client),
plugin_service: PluginService = Depends(get_plugin_service),
):
) -> BoefjeInput:
task = get_task(task_id, scheduler_client)

if task.status is not TaskStatus.RUNNING:
Expand All @@ -108,7 +108,7 @@ def boefje_output(
scheduler_client: SchedulerAPIClient = Depends(get_scheduler_client),
bytes_client: BytesAPIClient = Depends(get_bytes_client),
plugin_service: PluginService = Depends(get_plugin_service),
):
) -> Response:
task = get_task(task_id, scheduler_client)

if task.status is not TaskStatus.RUNNING:
Expand All @@ -127,7 +127,7 @@ def boefje_output(
for file in boefje_output.files:
raw = base64.b64decode(file.content)
# when supported, also save file.name to Bytes
bytes_client.save_raw(task_id, raw, mime_types.union(file.tags))
bytes_client.save_raw(task_id, raw, mime_types.union(file.tags) if file.tags else mime_types)

if boefje_output.status == StatusEnum.COMPLETED:
scheduler_client.patch_task(task_id, TaskStatus.COMPLETED)
Expand Down
6 changes: 3 additions & 3 deletions boefjes/boefjes/app.py
Original file line number Diff line number Diff line change
Expand Up @@ -80,7 +80,7 @@ def run(self, queue_type: WorkerManager.Queue) -> None:

raise

def _fill_queue(self, task_queue: Queue, queue_type: WorkerManager.Queue):
def _fill_queue(self, task_queue: Queue, queue_type: WorkerManager.Queue) -> None:
if task_queue.qsize() > self.settings.pool_size:
time.sleep(self.settings.worker_heartbeat)
return
Expand Down Expand Up @@ -189,7 +189,7 @@ def _cleanup_pending_worker_task(self, worker: BaseProcess) -> None:
def _worker_args(self) -> tuple:
return self.task_queue, self.item_handler, self.scheduler_client, self.handling_tasks

def exit(self, signum: int | None = None):
def exit(self, signum: int | None = None) -> None:
try:
if signum:
logger.info("Received %s, exiting", signal.Signals(signum).name)
Expand Down Expand Up @@ -238,7 +238,7 @@ def _start_working(
handler: Handler,
scheduler_client: SchedulerClientInterface,
handling_tasks: dict[int, str],
):
) -> None:
logger.info("Started listening for tasks from worker[pid=%s]", os.getpid())

while True:
Expand Down
4 changes: 2 additions & 2 deletions boefjes/boefjes/dependencies/encryption.py
Original file line number Diff line number Diff line change
Expand Up @@ -34,8 +34,8 @@ def __init__(self, private_key: str, public_key: str):

def encode(self, contents: str) -> str:
encrypted_contents = self.box.encrypt(contents.encode())
encrypted_contents = base64.b64encode(encrypted_contents)
return encrypted_contents.decode()
base64_encrypted_contents = base64.b64encode(encrypted_contents)
return base64_encrypted_contents.decode()

def decode(self, contents: str) -> str:
encrypted_binary = base64.b64decode(contents)
Expand Down
8 changes: 4 additions & 4 deletions boefjes/boefjes/katalogus/root.py
Original file line number Diff line number Diff line change
Expand Up @@ -73,22 +73,22 @@


@app.exception_handler(NotFound)
def entity_not_found_handler(request: Request, exc: NotFound):
def entity_not_found_handler(request: Request, exc: NotFound) -> JSONResponse:
return JSONResponse(status_code=status.HTTP_404_NOT_FOUND, content={"message": exc.message})


@app.exception_handler(NotAllowed)
def not_allowed_handler(request: Request, exc: NotAllowed):
def not_allowed_handler(request: Request, exc: NotAllowed) -> JSONResponse:
return JSONResponse(status_code=status.HTTP_400_BAD_REQUEST, content={"message": exc.message})


@app.exception_handler(IntegrityError)
def integrity_error_handler(request: Request, exc: IntegrityError):
def integrity_error_handler(request: Request, exc: IntegrityError) -> JSONResponse:
return JSONResponse(status_code=status.HTTP_400_BAD_REQUEST, content={"message": exc.message})


@app.exception_handler(StorageError)
def storage_error_handler(request: Request, exc: StorageError):
def storage_error_handler(request: Request, exc: StorageError) -> JSONResponse:
return JSONResponse(status_code=status.HTTP_500_INTERNAL_SERVER_ERROR, content={"message": exc.message})


Expand Down
10 changes: 7 additions & 3 deletions boefjes/boefjes/katalogus/settings.py
Original file line number Diff line number Diff line change
Expand Up @@ -6,19 +6,23 @@


@router.get("", response_model=dict)
def list_settings(organisation_id: str, plugin_id: str, plugin_service: PluginService = Depends(get_plugin_service)):
def list_settings(
organisation_id: str, plugin_id: str, plugin_service: PluginService = Depends(get_plugin_service)
) -> dict[str, str]:
return plugin_service.get_all_settings(organisation_id, plugin_id)


@router.put("")
def upsert_settings(
organisation_id: str, plugin_id: str, values: dict, plugin_service: PluginService = Depends(get_plugin_service)
):
) -> None:
with plugin_service as p:
p.upsert_settings(values, organisation_id, plugin_id)


@router.delete("")
def remove_settings(organisation_id: str, plugin_id: str, plugin_service: PluginService = Depends(get_plugin_service)):
def remove_settings(
organisation_id: str, plugin_id: str, plugin_service: PluginService = Depends(get_plugin_service)
) -> None:
with plugin_service as p:
p.delete_settings(organisation_id, plugin_id)
2 changes: 1 addition & 1 deletion boefjes/boefjes/migrations/env.py
Original file line number Diff line number Diff line change
Expand Up @@ -4,7 +4,7 @@
from sqlalchemy import engine_from_config, pool

from boefjes.config import settings
from boefjes.sql.db_models import SQL_BASE
from boefjes.sql.db import SQL_BASE

# this is the Alembic Config object, which provides
# access to the values within the .ini file in use.
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -42,7 +42,7 @@ def upgrade() -> None:
# ### end Alembic commands ###


def upgrade_encrypted_settings(conn: Connection):
def upgrade_encrypted_settings(conn: Connection) -> None:
encrypter = create_encrypter()

with conn.begin():
Expand Down Expand Up @@ -90,7 +90,7 @@ def downgrade() -> None:
# ### end Alembic commands ###


def downgrade_encrypted_settings(conn: Connection):
def downgrade_encrypted_settings(conn: Connection) -> None:
encrypter = create_encrypter()

with conn.begin():
Expand Down
8 changes: 7 additions & 1 deletion boefjes/boefjes/plugins/kat_crt_sh/main.py
Original file line number Diff line number Diff line change
Expand Up @@ -31,7 +31,13 @@
)


def request_certs(search_string, search_type="Identity", match="=", deduplicate=True, json_output=True) -> str:
def request_certs(
search_string: str,
search_type: str = "Identity",
match: str = "=",
deduplicate: bool = True,
json_output: bool = True,
) -> str:
"""Queries the public service CRT.sh for certificate information
the searchtype can be specified and defaults to Identity.
the type of sql matching can be specified and defaults to "="
Expand Down
18 changes: 8 additions & 10 deletions boefjes/boefjes/plugins/kat_cve_2023_35078/normalize.py
Original file line number Diff line number Diff line change
Expand Up @@ -4,12 +4,12 @@
from octopoes.models import Reference
from octopoes.models.ooi.findings import CVEFindingType, Finding
from octopoes.models.ooi.software import Software, SoftwareInstance
from packaging import version
from packaging.version import Version, parse

VULNERABLE_RANGES: list[tuple[str, str]] = [("0", "11.8.1.1"), ("11.9.0.0", "11.9.1.1"), ("11.10.0.0", "11.10.0.2")]


def extract_js_version(html_content: str) -> version.Version | bool:
def extract_js_version(html_content: str) -> Version | bool:
telltale = "/mifs/scripts/auth.js?"
telltale_position = html_content.find(telltale)
if telltale_position == -1:
Expand All @@ -20,10 +20,10 @@ def extract_js_version(html_content: str) -> version.Version | bool:
version_string = html_content[telltale_position + len(telltale) : version_end]
if not version_string:
return False
return version.parse(" ".join(strip_vsp_and_build(version_string)))
return parse(" ".join(strip_vsp_and_build(version_string)))


def extract_css_version(html_content: str) -> version.Version | bool:
def extract_css_version(html_content: str) -> Version | bool:
telltale = "/mifs/css/windowsAllAuth.css?"
telltale_position = html_content.find(telltale)
if telltale_position == -1:
Expand All @@ -34,7 +34,7 @@ def extract_css_version(html_content: str) -> version.Version | bool:
version_string = html_content[telltale_position + len(telltale) : version_end]
if not version_string:
return False
return version.parse(" ".join(strip_vsp_and_build(version_string)))
return parse(" ".join(strip_vsp_and_build(version_string)))


def strip_vsp_and_build(url: str) -> Iterable[str]:
Expand All @@ -47,9 +47,7 @@ def strip_vsp_and_build(url: str) -> Iterable[str]:
yield part


def is_vulnerable_version(
vulnerable_ranges: list[tuple[version.Version, version.Version]], detected_version: version.Version
) -> bool:
def is_vulnerable_version(vulnerable_ranges: list[tuple[Version, Version]], detected_version: Version) -> bool:
return any(start <= detected_version < end for start, end in vulnerable_ranges)


Expand All @@ -70,11 +68,11 @@ def run(input_ooi: dict, raw: bytes) -> Iterable[NormalizerOutput]:
yield software_instance
if js_detected_version:
vulnerable = is_vulnerable_version(
[(version.parse(start), version.parse(end)) for start, end in VULNERABLE_RANGES], js_detected_version
[(parse(start), parse(end)) for start, end in VULNERABLE_RANGES], js_detected_version
)
else:
# The CSS version only included the first two parts of the version number so we don't know the patch level
vulnerable = css_detected_version < version.parse("11.8")
vulnerable = css_detected_version < parse("11.8")
if vulnerable:
finding_type = CVEFindingType(id="CVE-2023-35078")
finding = Finding(
Expand Down
6 changes: 2 additions & 4 deletions boefjes/boefjes/plugins/kat_dnssec/main.py
Original file line number Diff line number Diff line change
Expand Up @@ -2,7 +2,7 @@
import subprocess


def run(boefje_meta: dict):
def run(boefje_meta: dict) -> list[tuple[set, bytes | str]]:
input_ = boefje_meta["arguments"]["input"]
domain = input_["name"]

Expand All @@ -19,6 +19,4 @@ def run(boefje_meta: dict):

output.check_returncode()

results = [({"openkat/dnssec-output"}, output.stdout)]

return results
return [({"openkat/dnssec-output"}, output.stdout)]
10 changes: 5 additions & 5 deletions boefjes/boefjes/plugins/kat_manual/csv/normalize.py
Original file line number Diff line number Diff line change
Expand Up @@ -7,7 +7,7 @@
from pydantic import ValidationError

from boefjes.job_models import NormalizerDeclaration, NormalizerOutput
from octopoes.models import Reference
from octopoes.models import OOI, Reference
from octopoes.models.ooi.dns.zone import Hostname
from octopoes.models.ooi.network import IPAddressV4, IPAddressV6, Network
from octopoes.models.ooi.web import URL
Expand All @@ -30,7 +30,7 @@ def run(input_ooi: dict, raw: bytes) -> Iterable[NormalizerOutput]:
yield from process_csv(raw, reference_cache)


def process_csv(csv_raw_data, reference_cache) -> Iterable[NormalizerOutput]:
def process_csv(csv_raw_data: bytes, reference_cache: dict) -> Iterable[NormalizerOutput]:
csv_data = io.StringIO(csv_raw_data.decode("UTF-8"))

object_type = get_object_type(csv_data)
Expand Down Expand Up @@ -74,7 +74,7 @@ def get_object_type(csv_data: io.StringIO) -> str:


def get_ooi_from_csv(
ooi_type_name: str, values: dict[str, str], reference_cache
ooi_type_name: str, values: dict[str, str], reference_cache: dict
) -> tuple[OOIType, list[NormalizerDeclaration]]:
skip_properties = ("object_type", "scan_profile", "primary_key")

Expand All @@ -85,7 +85,7 @@ def get_ooi_from_csv(
if field not in skip_properties
]

kwargs = {}
kwargs: dict[str, Reference | str | None] = {}
extra_declarations: list[NormalizerDeclaration] = []

for field, is_reference, required in ooi_fields:
Expand All @@ -109,7 +109,7 @@ def get_ooi_from_csv(
return ooi_type(**kwargs), extra_declarations


def get_or_create_reference(ooi_type_name: str, value: str | None, reference_cache):
def get_or_create_reference(ooi_type_name: str, value: str | None, reference_cache: dict) -> OOI:
ooi_type_name = next(filter(lambda x: x.casefold() == ooi_type_name.casefold(), OOI_TYPES.keys()))

# get from cache
Expand Down
2 changes: 1 addition & 1 deletion boefjes/boefjes/plugins/kat_masscan/main.py
Original file line number Diff line number Diff line change
Expand Up @@ -10,7 +10,7 @@
FILE_PATH = "/tmp/output.json" # noqa: S108


def run_masscan(target_ip) -> bytes:
def run_masscan(target_ip: str) -> bytes:
"""Run Masscan in Docker."""
client = docker.from_env()

Expand Down
6 changes: 2 additions & 4 deletions boefjes/boefjes/plugins/kat_nmap_tcp/main.py
Original file line number Diff line number Diff line change
Expand Up @@ -5,7 +5,7 @@
TOP_PORTS_DEFAULT = 250


def run(boefje_meta: dict):
def run(boefje_meta: dict) -> list[tuple[set, bytes | str]]:
top_ports_key = "TOP_PORTS"
if boefje_meta["boefje"]["id"] == "nmap-udp":
top_ports_key = "TOP_PORTS_UDP"
Expand All @@ -22,6 +22,4 @@ def run(boefje_meta: dict):

output.check_returncode()

results = [({"openkat/nmap-output"}, output.stdout.decode())]

return results
return [({"openkat/nmap-output"}, output.stdout.decode())]
8 changes: 6 additions & 2 deletions boefjes/boefjes/plugins/kat_security_txt_downloader/main.py
Original file line number Diff line number Diff line change
Expand Up @@ -5,6 +5,7 @@
import requests
from forcediphttpsadapter.adapters import ForcedIPHTTPSAdapter
from requests import Session
from requests.models import Response

from boefjes.job_models import BoefjeMeta

Expand Down Expand Up @@ -41,7 +42,10 @@ def run(boefje_meta: BoefjeMeta) -> list[tuple[set, bytes | str]]:
elif response.status_code in [301, 302, 307, 308]:
uri = response.headers["Location"]
response = requests.get(uri, stream=True, timeout=30, verify=False) # noqa: S501
ip = response.raw._connection.sock.getpeername()[0]
if response.raw._connection:
ip = response.raw._connection.sock.getpeername()[0]
else:
ip = ""
results[path] = {
"content": response.content.decode(),
"url": response.url,
Expand All @@ -53,7 +57,7 @@ def run(boefje_meta: BoefjeMeta) -> list[tuple[set, bytes | str]]:
return [(set(), json.dumps(results))]


def do_request(hostname: str, session: Session, uri: str, useragent: str):
def do_request(hostname: str, session: Session, uri: str, useragent: str) -> Response:
response = session.get(
uri, headers={"Host": hostname, "User-Agent": useragent}, verify=False, allow_redirects=False
)
Expand Down
8 changes: 5 additions & 3 deletions boefjes/boefjes/plugins/kat_snyk/check_version.py
Original file line number Diff line number Diff line change
Expand Up @@ -70,7 +70,7 @@ def check_version(version1: str, version2: str) -> VersionCheck:
return check_version(version1_split[1], version2_split[1])


def check_version_agains_versionlist(my_version: str, all_versions: list[str]):
def check_version_agains_versionlist(my_version: str, all_versions: list[str]) -> tuple[bool, list[str] | None]:
lowerbound = all_versions.pop(0).strip()
upperbound = None

Expand Down Expand Up @@ -164,10 +164,12 @@ def check_version_agains_versionlist(my_version: str, all_versions: list[str]):
return True, all_versions


def check_version_in(version: str, versions: str):
def check_version_in(version: str, versions: str) -> bool:
if not version:
return False
all_versions = versions.split(",") # Example: https://snyk.io/vuln/composer%3Awoocommerce%2Fwoocommerce-blocks
all_versions: list[str] | None = versions.split(
","
) # Example: https://snyk.io/vuln/composer%3Awoocommerce%2Fwoocommerce-blocks
in_range = False
while not in_range and all_versions:
in_range, all_versions = check_version_agains_versionlist(version, all_versions)
Expand Down
Loading

0 comments on commit ec9d80e

Please sign in to comment.