From 72d2c6558cd12622818912a8994498d23f3c33d0 Mon Sep 17 00:00:00 2001
From: Tal Borenstein
Date: Mon, 9 Dec 2024 16:55:35 +0200
Subject: [PATCH] fix: tests

Commit on the session directly instead of wrapping set_last_alert in a
nested transaction, work around pysqlite's transaction handling in the
SQLite test engine, and make the deduplication and enrichment tests poll
for background processing, bounded by the new pytest-timeout dependency.

---
 keep/api/core/db.py          | 55 +++++++++++++-------------
 poetry.lock                  | 16 +++++++-
 pyproject.toml               |  1 +
 tests/conftest.py            | 68 ++++++++++++++++++++-----------
 tests/fixtures/client.py     |  3 +-
 tests/test_deduplications.py | 77 +++++++++++++++++++++++++++++++++++-
 tests/test_enrichments.py    | 11 ++++++
 7 files changed, 174 insertions(+), 57 deletions(-)

diff --git a/keep/api/core/db.py b/keep/api/core/db.py
index 2301162b6..4f414bba2 100644
--- a/keep/api/core/db.py
+++ b/keep/api/core/db.py
@@ -4688,35 +4688,34 @@ def set_last_alert(
 ) -> None:
     logger.info(f"Set last alert for `{alert.fingerprint}`")
     with existed_or_new_session(session) as session:
-        with session.begin_nested() as transaction:
-            last_alert = get_last_alert_by_fingerprint(
-                tenant_id, alert.fingerprint, session, for_update=True
-            )
+        last_alert = get_last_alert_by_fingerprint(
+            tenant_id, alert.fingerprint, session, for_update=True
+        )
 
-            # To prevent rare, but possible race condition
-            # For example if older alert failed to process
-            # and retried after new one
-            if last_alert and last_alert.timestamp.replace(
-                tzinfo=tz.UTC
-            ) < alert.timestamp.replace(tzinfo=tz.UTC):
+        # To prevent a rare but possible race condition,
+        # e.g. if an older alert failed to process
+        # and was retried after a newer one
+        if last_alert and last_alert.timestamp.replace(
+            tzinfo=tz.UTC
+        ) < alert.timestamp.replace(tzinfo=tz.UTC):
 
-                logger.info(
-                    f"Update last alert for `{alert.fingerprint}`: {last_alert.alert_id} -> {alert.id}"
-                )
-                last_alert.timestamp = alert.timestamp
-                last_alert.alert_id = alert.id
-                session.add(last_alert)
+            logger.info(
+                f"Update last alert for `{alert.fingerprint}`: {last_alert.alert_id} -> {alert.id}"
+            )
+            last_alert.timestamp = alert.timestamp
+            last_alert.alert_id = alert.id
+            session.add(last_alert)
 
-            elif not last_alert:
-                logger.info(f"No last alert for `{alert.fingerprint}`, creating new")
-                last_alert = LastAlert(
-                    tenant_id=tenant_id,
-                    fingerprint=alert.fingerprint,
-                    timestamp=alert.timestamp,
-                    first_timestamp=alert.timestamp,
-                    alert_id=alert.id,
-                    alert_hash=alert.alert_hash,
-                )
+        elif not last_alert:
+            logger.info(f"No last alert for `{alert.fingerprint}`, creating new")
+            last_alert = LastAlert(
+                tenant_id=tenant_id,
+                fingerprint=alert.fingerprint,
+                timestamp=alert.timestamp,
+                first_timestamp=alert.timestamp,
+                alert_id=alert.id,
+                alert_hash=alert.alert_hash,
+            )
 
-            session.add(last_alert)
-            transaction.commit()
+        session.add(last_alert)
+        session.commit()
diff --git a/poetry.lock b/poetry.lock
index a35f23948..a9cc055e0 100644
--- a/poetry.lock
+++ b/poetry.lock
@@ -3878,6 +3878,20 @@ pytest = ">=6.2.5"
 [package.extras]
 dev = ["pre-commit", "pytest-asyncio", "tox"]
 
+[[package]]
+name = "pytest-timeout"
+version = "2.3.1"
+description = "pytest plugin to abort hanging tests"
+optional = false
+python-versions = ">=3.7"
+files = [
+    {file = "pytest-timeout-2.3.1.tar.gz", hash = "sha256:12397729125c6ecbdaca01035b9e5239d4db97352320af155b3f5de1ba5165d9"},
+    {file = "pytest_timeout-2.3.1-py3-none-any.whl", hash = "sha256:68188cb703edfc6a18fad98dc25a3c61e9f24d644b0b70f33af545219fc7813e"},
+]
+
+[package.dependencies]
+pytest = ">=7.0.0"
+
 [[package]]
 name = "pytest-xdist"
 version = "3.6.1"
@@ -5199,4 +5213,4 @@ test = ["big-O", "importlib-resources", "jaraco.functools", "jaraco.itertools",
 [metadata]
 lock-version = "2.0"
 python-versions = ">=3.11,<3.12"
">=3.11,<3.12" -content-hash = "4c7d5e35b4f4b687fed8f4604a39eaadf05902d5551bdb1f36f6db646f6c52ce" +content-hash = "d47e2fb172413ac623ee2e4a6b695b9d206b80d21ba0c39485df64436f626c30" diff --git a/pyproject.toml b/pyproject.toml index 0eb60699d..121e132c2 100644 --- a/pyproject.toml +++ b/pyproject.toml @@ -103,6 +103,7 @@ pytest-docker = "^2.0.1" playwright = "^1.44.0" freezegun = "^1.5.1" +pytest-timeout = "^2.3.1" [build-system] requires = ["poetry-core"] build-backend = "poetry.core.masonry.api" diff --git a/tests/conftest.py b/tests/conftest.py index dbab94aa1..aeaaf101d 100644 --- a/tests/conftest.py +++ b/tests/conftest.py @@ -11,12 +11,12 @@ import requests from dotenv import find_dotenv, load_dotenv from pytest_docker.plugin import get_docker_services +from sqlalchemy import event, text from sqlalchemy.orm import sessionmaker from sqlalchemy.pool import StaticPool -from sqlmodel import SQLModel, Session, create_engine +from sqlmodel import Session, SQLModel, create_engine from starlette_context import context, request_cycle_context -from keep.api.core.db import set_last_alert # This import is required to create the tables from keep.api.core.dependencies import SINGLE_TENANT_UUID from keep.api.core.elastic import ElasticClient @@ -177,6 +177,23 @@ def db_session(request): connect_args={"check_same_thread": False}, poolclass=StaticPool, ) + + # @tb: leaving this here if anybody else gets to problem with nested transactions + # https://docs.sqlalchemy.org/en/20/dialects/sqlite.html#serializable-isolation-savepoints-transactional-ddl + @event.listens_for(mock_engine, "connect") + def do_connect(dbapi_connection, connection_record): + # disable pysqlite's emitting of the BEGIN statement entirely. + # also stops it from emitting COMMIT before any DDL. 
+        dbapi_connection.isolation_level = None
+
+    @event.listens_for(mock_engine, "begin")
+    def do_begin(conn):
+        # emit our own BEGIN (exec_driver_sql takes a plain SQL string)
+        try:
+            conn.exec_driver_sql("BEGIN EXCLUSIVE")
+        except Exception:
+            pass
+
     SQLModel.metadata.create_all(mock_engine)
 
     # Mock the environment variables so db.py will use it
@@ -184,7 +201,9 @@ def db_session(request):
 
     # Create a session
     # Passing class_=Session to use the Session class from sqlmodel (https://github.com/fastapi/sqlmodel/issues/75#issuecomment-2109911909)
-    SessionLocal = sessionmaker(class_=Session, autocommit=False, autoflush=False, bind=mock_engine)
+    SessionLocal = sessionmaker(
+        class_=Session, autocommit=False, autoflush=False, bind=mock_engine
+    )
     session = SessionLocal()
 
     # Prepopulate the database with test data
@@ -510,8 +529,7 @@ def setup_alerts(elastic_client, db_session, request):
 
     existed_last_alerts = db_session.query(LastAlert).all()
     existed_last_alerts_dict = {
-        last_alert.fingerprint: last_alert
-        for last_alert in existed_last_alerts
+        last_alert.fingerprint: last_alert for last_alert in existed_last_alerts
     }
 
     last_alerts = []
@@ -520,9 +538,7 @@ def setup_alerts(elastic_client, db_session, request):
             last_alert = existed_last_alerts_dict[alert.fingerprint]
             last_alert.alert_id = alert.id
             last_alert.timestamp = alert.timestamp
-            last_alerts.append(
-                last_alert
-            )
+            last_alerts.append(last_alert)
         else:
             last_alerts.append(
                 LastAlert(
@@ -580,18 +596,15 @@ def _setup_stress_alerts_no_elastic(num_alerts):
 
     existed_last_alerts = db_session.query(LastAlert).all()
     existed_last_alerts_dict = {
-        last_alert.fingerprint: last_alert
-        for last_alert in existed_last_alerts
+        last_alert.fingerprint: last_alert for last_alert in existed_last_alerts
     }
     last_alerts = []
     for alert in alerts:
         if alert.fingerprint in existed_last_alerts_dict:
             last_alert = existed_last_alerts_dict[alert.fingerprint]
             last_alert.alert_id = alert.id
-            last_alert.timestamp=alert.timestamp
-            last_alerts.append(
-                last_alert
-            )
+            last_alert.timestamp = alert.timestamp
+            last_alerts.append(last_alert)
         else:
             last_alerts.append(
                 LastAlert(
@@ -625,7 +638,9 @@ def setup_stress_alerts(
 
 @pytest.fixture
 def create_alert(db_session):
-    def _create_alert(fingerprint, status, timestamp, details=None, tenant_id=SINGLE_TENANT_UUID):
+    def _create_alert(
+        fingerprint, status, timestamp, details=None, tenant_id=SINGLE_TENANT_UUID
+    ):
         details = details or {}
         random_name = "test-{}".format(fingerprint)
         process_event(
@@ -634,7 +649,9 @@ def _create_alert(fingerprint, status, timestamp, details=None, tenant_id=SINGLE
             tenant_id=tenant_id,
             provider_id="test",
             provider_type=(
-                details["source"][0] if details and "source" in details and details["source"] else None
+                details["source"][0]
+                if details and "source" in details and details["source"]
+                else None
             ),
             fingerprint=fingerprint,
             api_key_name="test",
@@ -658,11 +675,14 @@ def pytest_addoption(parser):
     """
 
     parser.addoption(
-        "--integration", action="store_const", const=True,
-        dest="run_integration")
+        "--integration", action="store_const", const=True, dest="run_integration"
+    )
     parser.addoption(
-        "--non-integration", action="store_const", const=True,
-        dest="run_non_integration")
+        "--non-integration",
+        action="store_const",
+        const=True,
+        dest="run_non_integration",
+    )
 
 
 def pytest_configure(config):
@@ -706,9 +726,9 @@ def pytest_collection_modifyitems(items):
         elif "keycloak_client" in fixturenames:
             item.add_marker("integration")
         elif (
-                hasattr(item, "callspec")
-                and "db_session" in item.callspec.params
item.callspec.params["db_session"] - and "db" in item.callspec.params["db_session"] + hasattr(item, "callspec") + and "db_session" in item.callspec.params + and item.callspec.params["db_session"] + and "db" in item.callspec.params["db_session"] ): item.add_marker("integration") diff --git a/tests/fixtures/client.py b/tests/fixtures/client.py index a8fd97837..393d2ebbc 100644 --- a/tests/fixtures/client.py +++ b/tests/fixtures/client.py @@ -46,8 +46,7 @@ def test_app(monkeypatch, request): provision_resources() app = get_app() - - yield app + return app # Fixture for TestClient using the test_app fixture diff --git a/tests/test_deduplications.py b/tests/test_deduplications.py index f87e971ca..1667fea17 100644 --- a/tests/test_deduplications.py +++ b/tests/test_deduplications.py @@ -1,4 +1,5 @@ import random +import time import uuid import pytest @@ -7,6 +8,15 @@ from tests.fixtures.client import client, setup_api_key, test_app # noqa +def wait_for_alerts(client, num_alerts): + alerts = client.get("/alerts", headers={"x-api-key": "some-api-key"}).json() + print(f"------------- Total alerts: {len(alerts)}") + while len(alerts) != num_alerts: + time.sleep(1) + alerts = client.get("/alerts", headers={"x-api-key": "some-api-key"}).json() + print(f"------------- Total alerts: {len(alerts)}") + + @pytest.mark.parametrize( "test_app", [ @@ -16,7 +26,7 @@ ], indirect=True, ) -def test_default_deduplication_rule(db_session, client): +def test_default_deduplication_rule(db_session, client, test_app): # insert an alert with some provider_id and make sure that the default deduplication rule is working provider_classes = { provider: ProvidersFactory.get_provider_class(provider) @@ -30,6 +40,8 @@ def test_default_deduplication_rule(db_session, client): headers={"x-api-key": "some-api-key"}, ) + wait_for_alerts(client, 2) + deduplication_rules = client.get( "/deduplications", headers={"x-api-key": "some-api-key"} ).json() @@ -71,6 +83,8 @@ def test_deduplication_sanity(db_session, client, test_app): "/alerts/event/datadog", json=alert, headers={"x-api-key": "some-api-key"} ) + wait_for_alerts(client, 1) + deduplication_rules = client.get( "/deduplications", headers={"x-api-key": "some-api-key"} ).json() @@ -120,6 +134,8 @@ def test_deduplication_sanity_2(db_session, client, test_app): headers={"x-api-key": "some-api-key"}, ) + wait_for_alerts(client, 2) + deduplication_rules = client.get( "/deduplications", headers={"x-api-key": "some-api-key"} ).json() @@ -133,6 +149,7 @@ def test_deduplication_sanity_2(db_session, client, test_app): assert dedup_rule.get("default") +@pytest.mark.timeout(20) @pytest.mark.parametrize( "test_app", [ @@ -157,6 +174,8 @@ def test_deduplication_sanity_3(db_session, client, test_app): "/alerts/event/datadog", json=alert, headers={"x-api-key": "some-api-key"} ) + wait_for_alerts(client, 10) + deduplication_rules = client.get( "/deduplications", headers={"x-api-key": "some-api-key"} ).json() @@ -186,6 +205,9 @@ def test_custom_deduplication_rule(db_session, client, test_app): "/alerts/event/datadog", json=alert1, headers={"x-api-key": "some-api-key"} ) + # wait for the background tasks to finish + wait_for_alerts(client, 1) + # create a custom deduplication rule and insert alerts that should be deduplicated by this custom_rule = { "name": "Custom Rule", @@ -209,6 +231,8 @@ def test_custom_deduplication_rule(db_session, client, test_app): "/alerts/event/datadog", json=alert, headers={"x-api-key": "some-api-key"} ) + wait_for_alerts(client, 2) + deduplication_rules = client.get( 
"/deduplications", headers={"x-api-key": "some-api-key"} ).json() @@ -224,6 +248,7 @@ def test_custom_deduplication_rule(db_session, client, test_app): assert custom_rule_found +@pytest.mark.timeout(10) @pytest.mark.parametrize( "test_app", [ @@ -240,6 +265,10 @@ def test_custom_deduplication_rule_behaviour(db_session, client, test_app): client.post( "/alerts/event/datadog", json=alert1, headers={"x-api-key": "some-api-key"} ) + + # wait for the background tasks to finish + wait_for_alerts(client, 1) + custom_rule = { "name": "Custom Rule", "description": "Custom Rule Description", @@ -269,6 +298,14 @@ def test_custom_deduplication_rule_behaviour(db_session, client, test_app): "/deduplications", headers={"x-api-key": "some-api-key"} ).json() + while not any( + [rule for rule in deduplication_rules if rule.get("dedup_ratio") == 50.0] + ): + time.sleep(1) + deduplication_rules = client.get( + "/deduplications", headers={"x-api-key": "some-api-key"} + ).json() + custom_rule_found = False for dedup_rule in deduplication_rules: if dedup_rule.get("name") == "Custom Rule": @@ -332,6 +369,12 @@ def test_custom_deduplication_rule_2(db_session, client, test_app): headers={"x-api-key": "some-api-key"}, ) + # wait for the background tasks to finish + alerts = client.get("/alerts", headers={"x-api-key": "some-api-key"}).json() + while len(alerts) < 2: + time.sleep(1) + alerts = client.get("/alerts", headers={"x-api-key": "some-api-key"}).json() + deduplication_rules = client.get( "/deduplications", headers={"x-api-key": "some-api-key"} ).json() @@ -455,6 +498,8 @@ def test_update_deduplication_rule_linked_provider(db_session, client, test_app) response = client.post( "/alerts/event/datadog", json=alert1, headers={"x-api-key": "some-api-key"} ) + + time.sleep(2) custom_rule = { "name": "Custom Rule", "description": "Custom Rule Description", @@ -557,6 +602,11 @@ def test_delete_deduplication_rule_default(db_session, client, test_app): "/alerts/event/datadog", json=alert, headers={"x-api-key": "some-api-key"} ) + alerts = client.get("/alerts", headers={"x-api-key": "some-api-key"}).json() + while len(alerts) != 1: + time.sleep(1) + alerts = client.get("/alerts", headers={"x-api-key": "some-api-key"}).json() + # try to delete a default deduplication rule deduplication_rules = client.get( "/deduplications", headers={"x-api-key": "some-api-key"} @@ -627,6 +677,7 @@ def test_full_deduplication(db_session, client, test_app): """ +@pytest.mark.timeout(15) @pytest.mark.parametrize( "test_app", [ @@ -652,10 +703,18 @@ def test_partial_deduplication(db_session, client, test_app): "/alerts/event/datadog", json=alert, headers={"x-api-key": "some-api-key"} ) + wait_for_alerts(client, 1) + deduplication_rules = client.get( "/deduplications", headers={"x-api-key": "some-api-key"} ).json() + while not any([rule for rule in deduplication_rules if rule.get("ingested") == 3]): + time.sleep(1) + deduplication_rules = client.get( + "/deduplications", headers={"x-api-key": "some-api-key"} + ).json() + datadog_rule_found = False for dedup_rule in deduplication_rules: if dedup_rule.get("provider_type") == "datadog" and dedup_rule.get("default"): @@ -690,6 +749,11 @@ def test_ingesting_alert_without_fingerprint_fields(db_session, client, test_app "/alerts/event/datadog", json=alert, headers={"x-api-key": "some-api-key"} ) + alerts = client.get("/alerts", headers={"x-api-key": "some-api-key"}).json() + while len(alerts) != 1: + time.sleep(1) + alerts = client.get("/alerts", headers={"x-api-key": "some-api-key"}).json() + 
     deduplication_rules = client.get(
         "/deduplications", headers={"x-api-key": "some-api-key"}
     ).json()
@@ -704,6 +768,7 @@ def test_ingesting_alert_without_fingerprint_fields(db_session, client, test_app
     assert datadog_rule_found
 
 
+@pytest.mark.timeout(15)
 @pytest.mark.parametrize(
     "test_app",
     [
@@ -729,15 +794,23 @@ def test_deduplication_fields(db_session, client, test_app):
         "/alerts/event/datadog", json=alert, headers={"x-api-key": "some-api-key"}
     )
 
+    wait_for_alerts(client, 1)
+
     deduplication_rules = client.get(
         "/deduplications", headers={"x-api-key": "some-api-key"}
     ).json()
 
+    while not any([rule for rule in deduplication_rules if rule.get("ingested") == 3]):
+        print("Waiting for deduplication rules to be ingested")
+        time.sleep(1)
+        deduplication_rules = client.get(
+            "/deduplications", headers={"x-api-key": "some-api-key"}
+        ).json()
+
     datadog_rule_found = False
     for dedup_rule in deduplication_rules:
         if dedup_rule.get("provider_type") == "datadog" and dedup_rule.get("default"):
             datadog_rule_found = True
             assert dedup_rule.get("ingested") == 3
             assert 66.667 - dedup_rule.get("dedup_ratio") < 0.1  # 0.66666666....7
-
     assert datadog_rule_found
diff --git a/tests/test_enrichments.py b/tests/test_enrichments.py
index 943d3a420..59a76519d 100644
--- a/tests/test_enrichments.py
+++ b/tests/test_enrichments.py
@@ -1,4 +1,5 @@
 # test_enrichments.py
+import time
 from unittest.mock import MagicMock, Mock, patch
 
 import pytest
@@ -46,6 +47,7 @@ def mock_alert_dto():
         severity="high",
         lastReceived="2021-01-01T00:00:00Z",
         source=["test_source"],
+        fingerprint="mock_fingerprint",
         labels={},
     )
 
@@ -519,6 +521,15 @@ def test_disposable_enrichment(db_session, client, test_app, mock_alert_dto):
         json=mock_alert_dto.dict(),
     )
 
+    while (
+        client.get(
+            f"/alerts/{mock_alert_dto.fingerprint}",
+            headers={"x-api-key": "some-key"},
+        ).status_code
+        != 200
+    ):
+        time.sleep(0.1)
+
     # 2. enrich with disposable alert
     response = client.post(
         "/alerts/enrich?dispose_on_new_alert=true",