From 5d0e5288c0ed31f80299c76a30b36a19a3216e6b Mon Sep 17 00:00:00 2001 From: Robin5605 Date: Thu, 22 Feb 2024 15:22:57 -0600 Subject: [PATCH] Fix existing tests so they work with new models --- tests/conftest.py | 23 +-- .../test_data/first_safe_second_malicious.py | 79 ++++----- tests/test_data/sample.py | 156 ++++++++++-------- tests/test_job.py | 12 +- tests/test_package.py | 20 ++- tests/test_report.py | 20 ++- tests/test_stats.py | 5 +- 7 files changed, 171 insertions(+), 144 deletions(-) diff --git a/tests/conftest.py b/tests/conftest.py index 3aa83972..aa70996f 100644 --- a/tests/conftest.py +++ b/tests/conftest.py @@ -1,5 +1,4 @@ import logging -from copy import deepcopy from datetime import datetime, timedelta from typing import Generator from unittest.mock import MagicMock @@ -7,7 +6,7 @@ import pytest import requests from letsbuilda.pypi import PyPIServices -from letsbuilda.pypi.models import Package +from letsbuilda.pypi.models import Package as PyPIPackage from letsbuilda.pypi.models.models_package import Distribution, Release from msgraph.core import GraphClient from sqlalchemy import Engine, create_engine @@ -15,7 +14,7 @@ from mainframe.constants import mainframe_settings from mainframe.json_web_token import AuthenticationData -from mainframe.models.orm import Base, Scan +from mainframe.models.orm import Base, Package from mainframe.rules import Rules from .test_data import data @@ -34,19 +33,15 @@ def engine() -> Engine: return create_engine(mainframe_settings.db_url) -@pytest.fixture(params=data, scope="session") -def test_data(request: pytest.FixtureRequest) -> list[Scan]: - return request.param - - -@pytest.fixture(scope="session", autouse=True) -def initial_populate_db(engine: Engine, sm: sessionmaker[Session], test_data: list[Scan]): +@pytest.fixture(params=data, scope="session", autouse=True) +def initial_populate_db(request: pytest.FixtureRequest, engine: Engine, sm: sessionmaker[Session]): Base.metadata.drop_all(engine) Base.metadata.create_all(engine) + packages: list[Package] = request.param + session = sm() - for scan in test_data: - session.add(deepcopy(scan)) + session.add_all(packages) session.commit() @@ -86,8 +81,8 @@ def pypi_client() -> PyPIServices: session = requests.Session() pypi_client = PyPIServices(session) - def side_effect(name: str, version: str) -> Package: - return Package( + def side_effect(name: str, version: str) -> PyPIPackage: + return PyPIPackage( title=name, releases=[Release(version=version, distributions=[Distribution(filename="test", url="test")])], ) diff --git a/tests/test_data/first_safe_second_malicious.py b/tests/test_data/first_safe_second_malicious.py index 6ddbb893..4431edb8 100644 --- a/tests/test_data/first_safe_second_malicious.py +++ b/tests/test_data/first_safe_second_malicious.py @@ -1,44 +1,49 @@ from datetime import datetime -from mainframe.models.orm import Rule, Scan, Status +from mainframe.models.orm import Package, Rule, Scan, Status data = [ - Scan( + Package( name="a", - version="0.1.0", - status=Status.FINISHED, - score=0, - inspector_url=None, - queued_at=datetime.fromisoformat("2023-05-12T18:00:00+00:00"), - queued_by="remmy", - pending_at=datetime.fromisoformat("2023-05-12T18:30:00+00:00"), - pending_by="remmy", - finished_at=datetime.fromisoformat("2023-05-12T19:00:00+00:00"), - finished_by="remmy", - reported_at=None, - reported_by=None, - rules=[], - download_urls=[], - commit_hash=None, - fail_reason=None, - ), - Scan( - name="a", - version="0.2.0", - status=Status.FINISHED, - score=10, - inspector_url="some inspector URL", - queued_at=datetime.fromisoformat("2023-05-12T19:00:00+00:00"), - queued_by="remmy", - pending_at=datetime.fromisoformat("2023-05-12T19:30:00+00:00"), - pending_by="remmy", - finished_at=datetime.fromisoformat("2023-05-12T20:00:00+00:00"), - finished_by="remmy", - reported_at=None, - reported_by=None, - download_urls=[], - rules=[Rule(name="test rule 1")], - commit_hash="test commit hash", - fail_reason=None, + scans=[ + Scan( + name="a", + version="0.1.0", + status=Status.FINISHED, + score=0, + inspector_url=None, + queued_at=datetime.fromisoformat("2023-05-12T18:00:00+00:00"), + queued_by="remmy", + pending_at=datetime.fromisoformat("2023-05-12T18:30:00+00:00"), + pending_by="remmy", + finished_at=datetime.fromisoformat("2023-05-12T19:00:00+00:00"), + finished_by="remmy", + reported_at=None, + reported_by=None, + rules=[], + download_urls=[], + commit_hash=None, + fail_reason=None, + ), + Scan( + name="a", + version="0.2.0", + status=Status.FINISHED, + score=10, + inspector_url="some inspector URL", + queued_at=datetime.fromisoformat("2023-05-12T19:00:00+00:00"), + queued_by="remmy", + pending_at=datetime.fromisoformat("2023-05-12T19:30:00+00:00"), + pending_by="remmy", + finished_at=datetime.fromisoformat("2023-05-12T20:00:00+00:00"), + finished_by="remmy", + reported_at=None, + reported_by=None, + download_urls=[], + rules=[Rule(name="test rule 1")], + commit_hash="test commit hash", + fail_reason=None, + ), + ], ), ] diff --git a/tests/test_data/sample.py b/tests/test_data/sample.py index 8f6ceaed..6f344f77 100644 --- a/tests/test_data/sample.py +++ b/tests/test_data/sample.py @@ -1,82 +1,92 @@ from datetime import datetime -from mainframe.models.orm import Scan, Status +from mainframe.models.orm import Package, Scan, Status data = [ - Scan( + Package( name="a", - version="0.1.0", - status=Status.FINISHED, - score=0, - inspector_url=None, - queued_at=datetime.fromisoformat("2023-05-12T18:00:00+00:00"), - queued_by="remmy", - pending_at=datetime.fromisoformat("2023-05-12T18:30:00+00:00"), - pending_by="remmy", - finished_at=datetime.fromisoformat("2023-05-12T19:00:00+00:00"), - finished_by="remmy", - reported_at=None, - reported_by=None, - rules=[], - download_urls=[], - commit_hash=None, - fail_reason=None, + scans=[ + Scan( + name="a", + version="0.1.0", + status=Status.FINISHED, + score=0, + inspector_url=None, + queued_at=datetime.fromisoformat("2023-05-12T18:00:00+00:00"), + queued_by="remmy", + pending_at=datetime.fromisoformat("2023-05-12T18:30:00+00:00"), + pending_by="remmy", + finished_at=datetime.fromisoformat("2023-05-12T19:00:00+00:00"), + finished_by="remmy", + reported_at=None, + reported_by=None, + rules=[], + download_urls=[], + commit_hash=None, + fail_reason=None, + ), + Scan( + name="a", + version="0.2.0", + status=Status.QUEUED, + score=None, + inspector_url=None, + queued_at=datetime.fromisoformat("2023-05-12T17:00:00+00:00"), + queued_by="remmy", + pending_at=None, + pending_by=None, + finished_at=None, + finished_by=None, + reported_at=None, + reported_by=None, + rules=[], + download_urls=[], + commit_hash=None, + fail_reason=None, + ), + ], ), - Scan( + Package( name="b", - version="0.1.0", - status=Status.FINISHED, - score=0, - inspector_url=None, - queued_at=datetime.fromisoformat("2023-05-12T15:00:00+00:00"), - queued_by="remmy", - pending_at=datetime.fromisoformat("2023-05-12T16:00:00+00:00"), - pending_by="remmy", - finished_at=datetime.fromisoformat("2023-05-12T16:30:00+00:00"), - finished_by="remmy", - reported_at=None, - reported_by=None, - rules=[], - download_urls=[], - commit_hash="test commit hash", - fail_reason=None, - ), - Scan( - name="a", - version="0.2.0", - status=Status.QUEUED, - score=None, - inspector_url=None, - queued_at=datetime.fromisoformat("2023-05-12T17:00:00+00:00"), - queued_by="remmy", - pending_at=None, - pending_by=None, - finished_at=None, - finished_by=None, - reported_at=None, - reported_by=None, - rules=[], - download_urls=[], - commit_hash=None, - fail_reason=None, - ), - Scan( - name="b", - version="0.2.0", - status=Status.PENDING, - score=None, - inspector_url=None, - queued_at=datetime.fromisoformat("2023-05-12T19:00:00+00:00"), - queued_by="remmy", - pending_at=datetime.fromisoformat("2023-05-12T20:00:00+00:00"), - pending_by="remmy", - finished_at=None, - finished_by=None, - reported_at=None, - reported_by=None, - rules=[], - download_urls=[], - commit_hash=None, - fail_reason=None, + scans=[ + Scan( + name="b", + version="0.1.0", + status=Status.FINISHED, + score=0, + inspector_url=None, + queued_at=datetime.fromisoformat("2023-05-12T15:00:00+00:00"), + queued_by="remmy", + pending_at=datetime.fromisoformat("2023-05-12T16:00:00+00:00"), + pending_by="remmy", + finished_at=datetime.fromisoformat("2023-05-12T16:30:00+00:00"), + finished_by="remmy", + reported_at=None, + reported_by=None, + rules=[], + download_urls=[], + commit_hash="test commit hash", + fail_reason=None, + ), + Scan( + name="b", + version="0.2.0", + status=Status.PENDING, + score=None, + inspector_url=None, + queued_at=datetime.fromisoformat("2023-05-12T19:00:00+00:00"), + queued_by="remmy", + pending_at=datetime.fromisoformat("2023-05-12T20:00:00+00:00"), + pending_by="remmy", + finished_at=None, + finished_by=None, + reported_at=None, + reported_by=None, + rules=[], + download_urls=[], + commit_hash=None, + fail_reason=None, + ), + ], ), ] diff --git a/tests/test_job.py b/tests/test_job.py index dffb596c..e9f62e23 100644 --- a/tests/test_job.py +++ b/tests/test_job.py @@ -13,7 +13,8 @@ def oldest_queued_package(db_session: Session): return db_session.scalar(select(func.min(Scan.queued_at)).where(Scan.status == Status.QUEUED)) -def test_min_queue_date_of_queued_rows(test_data: list[Scan], db_session: Session): +def test_min_queue_date_of_queued_rows(db_session: Session): + test_data = db_session.scalars(select(Scan)).all() queued_at_times = [ scan.queued_at for scan in test_data if scan.status is Status.QUEUED and scan.queued_at is not None ] @@ -28,12 +29,14 @@ def fetch_queue_time(name: str, version: str, db_session: Session) -> dt.datetim return db_session.scalar(select(Scan.queued_at).where(Scan.name == name).where(Scan.version == version)) -def test_fetch_queue_time(test_data: list[Scan], db_session: Session): +def test_fetch_queue_time(db_session: Session): + test_data = db_session.scalars(select(Scan)).all() for scan in test_data: assert scan.queued_at == fetch_queue_time(scan.name, scan.version, db_session) -def test_job(test_data: list[Scan], db_session: Session, auth: AuthenticationData, rules_state: Rules): +def test_job(db_session: Session, auth: AuthenticationData, rules_state: Rules): + test_data = db_session.scalars(select(Scan)).all() job = get_jobs(db_session, auth, rules_state, batch=1) if job: job = job[0] @@ -47,7 +50,8 @@ def test_job(test_data: list[Scan], db_session: Session, auth: AuthenticationDat assert all(scan.status != Status.QUEUED for scan in test_data) -def test_batch_job(test_data: list[Scan], db_session: Session, auth: AuthenticationData, rules_state: Rules): +def test_batch_job(db_session: Session, auth: AuthenticationData, rules_state: Rules): + test_data = db_session.scalars(select(Scan)).all() jobs = {(job.name, job.version) for job in get_jobs(db_session, auth, rules_state, batch=len(test_data))} # check if each returned job should have actually been returned diff --git a/tests/test_package.py b/tests/test_package.py index cac1cef6..5cfee239 100644 --- a/tests/test_package.py +++ b/tests/test_package.py @@ -1,3 +1,4 @@ +import itertools from typing import Optional import pytest @@ -38,9 +39,9 @@ def test_package_lookup( since: Optional[int], name: Optional[str], version: Optional[str], - test_data: list[Scan], db_session: Session, ): + test_data = db_session.scalars(select(Scan)).all() exp: set[tuple[str, str]] = set() for scan in test_data: if since is not None and (scan.finished_at is None or since > int(scan.finished_at.timestamp())): @@ -51,7 +52,8 @@ def test_package_lookup( continue exp.add((scan.name, scan.version)) - scans = lookup_package_info(db_session, since, name, version) + packages = lookup_package_info(db_session, since, name, version) + scans = itertools.chain.from_iterable(package.scans for package in packages) assert exp == {(scan.name, scan.version) for scan in scans} @@ -77,8 +79,9 @@ def test_package_lookup_rejects_invalid_combinations( assert e.value.status_code == 400 -def test_handle_success(db_session: Session, test_data: list[Scan], auth: AuthenticationData, rules_state: Rules): +def test_handle_success(db_session: Session, auth: AuthenticationData, rules_state: Rules): job = get_jobs(db_session, auth, rules_state, batch=1) + test_data = db_session.scalars(select(Scan)).all() if job: job = job[0] @@ -105,8 +108,9 @@ def test_handle_success(db_session: Session, test_data: list[Scan], auth: Authen assert all(scan.status != Status.QUEUED for scan in test_data) -def test_handle_fail(db_session: Session, test_data: list[Scan], auth: AuthenticationData, rules_state: Rules): +def test_handle_fail(db_session: Session, auth: AuthenticationData, rules_state: Rules): job = get_jobs(db_session, auth, rules_state, batch=1) + test_data = db_session.scalars(select(Scan)).all() if job: job = job[0] @@ -129,7 +133,8 @@ def test_handle_fail(db_session: Session, test_data: list[Scan], auth: Authentic assert all(scan.status != Status.QUEUED for scan in test_data) -def test_batch_queue(db_session: Session, test_data: list[Scan], pypi_client: PyPIServices, auth: AuthenticationData): +def test_batch_queue(db_session: Session, pypi_client: PyPIServices, auth: AuthenticationData): + test_data = db_session.scalars(select(Scan)).all() packages_to_add = [PackageSpecifier(name=scan.name, version=scan.version) for scan in test_data] packages_to_add.append(PackageSpecifier(name="c", version="1.0.0")) batch_queue_package(packages_to_add, db_session, auth, pypi_client) @@ -210,10 +215,9 @@ def test_submit_nonexistent_package(db_session: Session, auth: AuthenticationDat assert e.value.status_code == 404 -def test_submit_duplicate_package( - db_session: Session, test_data: list[Scan], auth: AuthenticationData, rules_state: Rules -): +def test_submit_duplicate_package(db_session: Session, auth: AuthenticationData, rules_state: Rules): job = get_jobs(db_session, auth, rules_state, batch=1) + test_data = db_session.scalars(select(Scan)).all() if job: job = job[0] diff --git a/tests/test_report.py b/tests/test_report.py index baac81df..153188c1 100644 --- a/tests/test_report.py +++ b/tests/test_report.py @@ -13,7 +13,7 @@ from mainframe.endpoints.report import report_package from mainframe.json_web_token import AuthenticationData -from mainframe.models.orm import DownloadURL, Rule, Scan, Status +from mainframe.models.orm import DownloadURL, Package, Rule, Scan, Status from mainframe.models.schemas import ReportPackageBody @@ -38,7 +38,9 @@ def test_report(db_session: Session, auth: AuthenticationData, pypi_client: PyPI commit_hash="test commit hash", ) - db_session.add(result) + package = Package(name="c", people=[], scans=[result]) + + db_session.add(package) db_session.commit() body = ReportPackageBody( @@ -149,7 +151,9 @@ def test_report_multi_versions( commit_hash="test commit hash", ) - db_session.add_all((version1, version2)) + package = Package(name="c", people=[], scans=[version1, version2]) + + db_session.add(package) db_session.commit() body = ReportPackageBody( @@ -187,7 +191,9 @@ def test_report_invalid_version( fail_reason=None, commit_hash="test commit hash", ) - db_session.add(scan) + + package = Package(name="c", people=[], scans=[scan]) + db_session.add(package) db_session.commit() body = ReportPackageBody( @@ -225,7 +231,8 @@ def test_report_missing_inspector_url( fail_reason=None, commit_hash="test commit hash", ) - db_session.add(scan) + package = Package(name="c", people=[], scans=[scan]) + db_session.add(package) db_session.commit() body = ReportPackageBody( @@ -263,7 +270,8 @@ def test_report_missing_additional_information( fail_reason=None, commit_hash="test commit hash", ) - db_session.add(scan) + package = Package(name="c", people=[], scans=[scan]) + db_session.add(package) db_session.commit() body = ReportPackageBody( diff --git a/tests/test_stats.py b/tests/test_stats.py index bd3a8308..4d2b4232 100644 --- a/tests/test_stats.py +++ b/tests/test_stats.py @@ -4,7 +4,7 @@ from sqlalchemy.orm import Session from mainframe.endpoints.stats import get_stats -from mainframe.models.orm import DownloadURL, Rule, Scan, Status +from mainframe.models.orm import DownloadURL, Package, Rule, Scan, Status def test_stats(db_session: Session): @@ -27,8 +27,9 @@ def test_stats(db_session: Session): fail_reason=None, commit_hash="test commit hash", ) + package = Package(name="c", people=[], scans=[scan]) - db_session.add(scan) + db_session.add_all((package, scan)) db_session.commit() stats = get_stats(db_session)