From 2d21e57c29cf90f3e8231ae5315761ff7b5e7497 Mon Sep 17 00:00:00 2001 From: Vladimir Filonov Date: Wed, 16 Oct 2024 16:48:32 +0400 Subject: [PATCH] incident.alerts_count should return count of unique fingerprints instead of alerts --- keep/api/core/db.py | 30 ++++++++++---- keep/api/tasks/process_event_task.py | 5 ++- tests/test_incidents.py | 59 ++++++++++++++++++++++------ 3 files changed, 74 insertions(+), 20 deletions(-) diff --git a/keep/api/core/db.py b/keep/api/core/db.py index 0d1c0e3c5..9d6bc7380 100644 --- a/keep/api/core/db.py +++ b/keep/api/core/db.py @@ -2950,7 +2950,9 @@ def get_all_same_alert_ids( def get_alerts_data_for_incident( - alert_ids: List[str | UUID], session: Optional[Session] = None + alert_ids: List[str | UUID], + existed_fingerprints: Optional[List[str]] = None, + session: Optional[Session] = None ) -> dict: """ Function to prepare aggregated data for incidents from the given list of alert_ids @@ -2962,12 +2964,14 @@ def get_alerts_data_for_incident( Returns: dict {sources: list[str], services: list[str], count: int} """ + existed_fingerprints = existed_fingerprints or [] with existed_or_new_session(session) as session: fields = ( get_json_extract_field(session, Alert.event, "service"), Alert.provider_type, + Alert.fingerprint, get_json_extract_field(session, Alert.event, "severity"), ) @@ -2980,8 +2984,9 @@ def get_alerts_data_for_incident( sources = [] services = [] severities = [] + fingerprints = set() - for service, source, severity in alerts_data: + for service, source, fingerprint, severity in alerts_data: if source: sources.append(source) if service: @@ -2991,12 +2996,14 @@ def get_alerts_data_for_incident( severities.append(IncidentSeverity.from_number(severity)) else: severities.append(IncidentSeverity(severity)) + if fingerprint and fingerprint not in existed_fingerprints: + fingerprints.add(fingerprint) return { "sources": set(sources), "services": set(services), "max_severity": max(severities), - "count": len(alerts_data), + "count": len(fingerprints), } @@ -3047,6 +3054,17 @@ def add_alerts_to_incident( ) ).all() ) + existing_fingerprints = set( + session.exec( + select(Alert.fingerprint) + .join(AlertToIncident, AlertToIncident.alert_id == Alert.id) + .where( + AlertToIncident.deleted_at == NULL_FOR_DELETED_AT, + AlertToIncident.tenant_id == tenant_id, + AlertToIncident.incident_id == incident.id, + ) + ).all() + ) new_alert_ids = [ alert_id for alert_id in all_alert_ids if alert_id not in existing_alert_ids @@ -3055,9 +3073,7 @@ def add_alerts_to_incident( if not new_alert_ids: return incident - alerts_data_for_incident = get_alerts_data_for_incident( - new_alert_ids, session - ) + alerts_data_for_incident = get_alerts_data_for_incident(new_alert_ids, existing_fingerprints, session) incident.sources = list( set(incident.sources if incident.sources else []) | set(alerts_data_for_incident["sources"]) @@ -3177,7 +3193,7 @@ def remove_alerts_to_incident_by_incident_id( session.commit() # Getting aggregated data for incidents for alerts which just was removed - alerts_data_for_incident = get_alerts_data_for_incident(all_alert_ids, session) + alerts_data_for_incident = get_alerts_data_for_incident(all_alert_ids, session=session) service_field = get_json_extract_field(session, Alert.event, "service") diff --git a/keep/api/tasks/process_event_task.py b/keep/api/tasks/process_event_task.py index 5e2070743..9347fe49b 100644 --- a/keep/api/tasks/process_event_task.py +++ b/keep/api/tasks/process_event_task.py @@ -527,7 +527,10 @@ def process_event( and isinstance(event, dict) or isinstance(event, FormData) ): - provider_class = ProvidersFactory.get_provider_class(provider_type) + try: + provider_class = ProvidersFactory.get_provider_class(provider_type) + except Exception: + provider_class = ProvidersFactory.get_provider_class("keep") event = provider_class.format_alert( tenant_id=tenant_id, event=event, diff --git a/tests/test_incidents.py b/tests/test_incidents.py index 5cd78fbd2..b9758b638 100644 --- a/tests/test_incidents.py +++ b/tests/test_incidents.py @@ -2,7 +2,7 @@ from itertools import cycle import pytest -from sqlalchemy import func +from sqlalchemy import func, distinct from sqlalchemy.orm.exc import DetachedInstanceError from keep.api.core.db import ( @@ -23,24 +23,40 @@ IncidentSeverity, IncidentStatus, ) -from keep.api.models.db.alert import Alert +from keep.api.models.db.alert import Alert, AlertToIncident from keep.api.utils.enrichment_helpers import convert_db_alerts_to_dto_alerts from tests.fixtures.client import client, test_app # noqa +def test_get_alerts_data_for_incident(db_session, create_alert): + for i in range(100): + create_alert( + f"alert-test-{i % 10}", + AlertStatus.FIRING, + datetime.utcnow(), + { + "source": [f"source_{i % 10}"], + "service": f"service_{i % 10}", + } + ) + + alerts = db_session.query(Alert).all() + + unique_fingerprints = db_session.query(func.count(distinct(Alert.fingerprint))).scalar() -def test_get_alerts_data_for_incident(db_session, setup_stress_alerts_no_elastic): - alerts = setup_stress_alerts_no_elastic(100) assert 100 == db_session.query(func.count(Alert.id)).scalar() + assert 10 == unique_fingerprints data = get_alerts_data_for_incident([a.id for a in alerts]) - assert data["sources"] == set(["source_{}".format(i) for i in range(10)]) - assert data["services"] == set(["service_{}".format(i) for i in range(10)]) - assert data["count"] == 100 + assert data["sources"] == set([f"source_{i}" for i in range(10)]) + assert data["services"] == set([f"service_{i}" for i in range(10)]) + assert data["count"] == unique_fingerprints def test_add_remove_alert_to_incidents(db_session, setup_stress_alerts_no_elastic): alerts = setup_stress_alerts_no_elastic(100) + # Adding 10 non-unique fingerprints + alerts.extend(setup_stress_alerts_no_elastic(10)) incident = create_incident_from_dict( SINGLE_TENANT_UUID, {"user_generated_name": "test", "user_summary": "test"} ) @@ -53,7 +69,10 @@ def test_add_remove_alert_to_incidents(db_session, setup_stress_alerts_no_elasti incident = get_incident_by_id(SINGLE_TENANT_UUID, incident.id) - assert len(incident.alerts) == 100 + # 110 alerts + assert len(incident.alerts) == 110 + # But 100 unique fingerprints + assert incident.alerts_count == 100 assert sorted(incident.affected_services) == sorted( ["service_{}".format(i) for i in range(10)] @@ -66,6 +85,18 @@ def test_add_remove_alert_to_incidents(db_session, setup_stress_alerts_no_elasti service_0 = db_session.query(Alert.id).filter(service_field == "service_0").all() + # Testing unique fingerprints + more_alerts_with_same_fingerprints = setup_stress_alerts_no_elastic(10) + + add_alerts_to_incident_by_incident_id( + SINGLE_TENANT_UUID, incident.id, [a.id for a in more_alerts_with_same_fingerprints] + ) + + incident = get_incident_by_id(SINGLE_TENANT_UUID, incident.id) + + assert incident.alerts_count == 100 + assert db_session.query(func.count(AlertToIncident.alert_id)).scalar() == 120 + remove_alerts_to_incident_by_incident_id( SINGLE_TENANT_UUID, incident.id, @@ -76,7 +107,8 @@ def test_add_remove_alert_to_incidents(db_session, setup_stress_alerts_no_elasti incident = get_incident_by_id(SINGLE_TENANT_UUID, incident.id) - assert len(incident.alerts) == 99 + # 117 because we removed multiple alerts with service_0 + assert len(incident.alerts) == 117 assert "service_0" in incident.affected_services assert len(incident.affected_services) == 10 assert sorted(incident.affected_services) == sorted( @@ -92,11 +124,12 @@ def test_add_remove_alert_to_incidents(db_session, setup_stress_alerts_no_elasti incident_id=incident.id, tenant_id=incident.tenant_id, include_unlinked=True - )[0]) == 100 + )[0]) == 120 incident = get_incident_by_id(SINGLE_TENANT_UUID, incident.id) - assert len(incident.alerts) == 90 + # 108 because we removed multiple alert with same fingerprints + assert len(incident.alerts) == 108 assert "service_0" not in incident.affected_services assert len(incident.affected_services) == 9 assert sorted(incident.affected_services) == sorted( @@ -117,7 +150,7 @@ def test_add_remove_alert_to_incidents(db_session, setup_stress_alerts_no_elasti incident = get_incident_by_id(SINGLE_TENANT_UUID, incident.id) - assert len(incident.alerts) == 89 + assert len(incident.alerts) == 105 assert "source_1" in incident.sources # source_0 was removed together with service_0 assert len(incident.sources) == 9 @@ -137,6 +170,8 @@ def test_add_remove_alert_to_incidents(db_session, setup_stress_alerts_no_elasti ) + + def test_get_last_incidents(db_session, create_alert): severity_cycle = cycle([s.order for s in IncidentSeverity])