diff --git a/tests/test_incidents.py b/tests/test_incidents.py index 91df0c916..87f470836 100644 --- a/tests/test_incidents.py +++ b/tests/test_incidents.py @@ -4,6 +4,7 @@ import pytest from sqlalchemy import distinct, func, desc +from keep.api.bl.incidents_bl import IncidentBl from keep.api.core.db import ( IncidentSorting, add_alerts_to_incident_by_incident_id, @@ -810,9 +811,18 @@ def test_merge_incidents_app( assert incident_3_via_api["status"] == IncidentStatus.MERGED.value assert incident_3_via_api["merged_into_incident_id"] == str(incident_1.id) """ - -@pytest.mark.parametrize("test_app", ["NO_AUTH"], indirect=True) -def test_split_incident_app(db_session, client, test_app, create_alert): +@pytest.mark.asyncio +async def test_split_incident(db_session, create_alert): + # Create source incident with multiple alerts + incident_source = create_incident_from_dict( + SINGLE_TENANT_UUID, + { + "user_generated_name": "Source incident with mixed severity", + "user_summary": "Source incident with mixed severity", + }, + ) + + # Create alerts with different severities create_alert( "fp1", AlertStatus.FIRING, @@ -823,47 +833,121 @@ def test_split_incident_app(db_session, client, test_app, create_alert): "fp2", AlertStatus.FIRING, datetime.utcnow(), - {"severity": AlertSeverity.CRITICAL.value}, + {"severity": AlertSeverity.WARNING.value}, ) create_alert( "fp3", AlertStatus.FIRING, datetime.utcnow(), - {"severity": AlertSeverity.CRITICAL.value}, + {"severity": AlertSeverity.INFO.value}, ) + alerts = db_session.query(Alert).all() - incident_1 = create_incident_from_dict( - SINGLE_TENANT_UUID, {"user_generated_name": "test", "user_summary": "test"} - ) add_alerts_to_incident_by_incident_id( - SINGLE_TENANT_UUID, incident_1.id, [a.fingerprint for a in alerts] - ) - - incident_1 = get_incident_by_id(SINGLE_TENANT_UUID, incident_1.id, with_alerts=True) - assert len(incident_1._alerts) == 3 - - incident_2 = create_incident_from_dict( - SINGLE_TENANT_UUID, {"user_generated_name": "test", "user_summary": "test"} + SINGLE_TENANT_UUID, incident_source.id, [a.fingerprint for a in alerts] ) - incident_2 = get_incident_by_id(SINGLE_TENANT_UUID, incident_2.id, with_alerts=True) - assert len(incident_2._alerts) == 0 - response = client.post( - f"/incidents/{str(incident_1.id)}/split", - headers={"x-api-key": "some-key"}, - json={ - "alert_fingerprints": [alerts[2].fingerprint], - "destination_incident_id": str(incident_2.id), + # Create destination incident + incident_dest = create_incident_from_dict( + SINGLE_TENANT_UUID, + { + "user_generated_name": "Destination incident", + "user_summary": "Destination incident", }, ) - assert response.status_code == 200 + # Verify initial state + incident_source = get_incident_by_id(SINGLE_TENANT_UUID, incident_source.id, with_alerts=True) + assert len(incident_source._alerts) == 3 + assert incident_source.severity == IncidentSeverity.CRITICAL.order - incident_2 = get_incident_by_id(SINGLE_TENANT_UUID, incident_2.id, with_alerts=True) - assert incident_2._alerts[0].fingerprint == alerts[2].fingerprint + incident_dest = get_incident_by_id(SINGLE_TENANT_UUID, incident_dest.id, with_alerts=True) + assert len(incident_dest._alerts) == 0 - incident_1 = get_incident_by_id(SINGLE_TENANT_UUID, incident_1.id, with_alerts=True) - assert len(incident_1._alerts) == 2 + # Split the critical alert using IncidentBl + critical_alert = next(a for a in alerts if a.event["severity"] == AlertSeverity.CRITICAL.value) + incident_bl = IncidentBl(SINGLE_TENANT_UUID, db_session, pusher_client=None) + + # Move alert to destination incident + await incident_bl.add_alerts_to_incident( + incident_id=incident_dest.id, + alert_fingerprints=[critical_alert.fingerprint] + ) + + # Remove alert from source incident + incident_bl.delete_alerts_from_incident( + incident_id=incident_source.id, + alert_fingerprints=[critical_alert.fingerprint] + ) + + db_session.expire_all() + + # Verify final state + incident_source = get_incident_by_id(SINGLE_TENANT_UUID, incident_source.id, with_alerts=True) + assert len(incident_source._alerts) == 2 + assert incident_source.severity == IncidentSeverity.WARNING.order + + incident_dest = get_incident_by_id(SINGLE_TENANT_UUID, incident_dest.id, with_alerts=True) + assert len(incident_dest._alerts) == 1 + assert incident_dest.severity == IncidentSeverity.CRITICAL.order + assert incident_dest._alerts[0].fingerprint == critical_alert.fingerprint + assert len(incident_dest._alerts) == 1 + assert incident_dest.severity == IncidentSeverity.CRITICAL.order + assert incident_dest._alerts[0].fingerprint == critical_alert.fingerprint + +# @pytest.mark.parametrize("test_app", ["NO_AUTH"], indirect=True) +# def test_split_incident_app(db_session, client, test_app, create_alert): +# create_alert( +# "fp1", +# AlertStatus.FIRING, +# datetime.utcnow(), +# {"severity": AlertSeverity.CRITICAL.value}, +# ) +# create_alert( +# "fp2", +# AlertStatus.FIRING, +# datetime.utcnow(), +# {"severity": AlertSeverity.CRITICAL.value}, +# ) +# create_alert( +# "fp3", +# AlertStatus.FIRING, +# datetime.utcnow(), +# {"severity": AlertSeverity.CRITICAL.value}, +# ) +# alerts = db_session.query(Alert).all() +# incident_1 = create_incident_from_dict( +# SINGLE_TENANT_UUID, {"user_generated_name": "test", "user_summary": "test"} +# ) +# add_alerts_to_incident_by_incident_id( +# SINGLE_TENANT_UUID, incident_1.id, [a.fingerprint for a in alerts] +# ) + +# incident_1 = get_incident_by_id(SINGLE_TENANT_UUID, incident_1.id, with_alerts=True) +# assert len(incident_1._alerts) == 3 + +# incident_2 = create_incident_from_dict( +# SINGLE_TENANT_UUID, {"user_generated_name": "test", "user_summary": "test"} +# ) +# incident_2 = get_incident_by_id(SINGLE_TENANT_UUID, incident_2.id, with_alerts=True) +# assert len(incident_2._alerts) == 0 + +# response = client.post( +# f"/incidents/{str(incident_1.id)}/split", +# headers={"x-api-key": "some-key"}, +# json={ +# "alert_fingerprints": [alerts[2].fingerprint], +# "destination_incident_id": str(incident_2.id), +# }, +# ) + +# assert response.status_code == 200 + +# incident_2 = get_incident_by_id(SINGLE_TENANT_UUID, incident_2.id, with_alerts=True) +# assert incident_2._alerts[0].fingerprint == alerts[2].fingerprint + +# incident_1 = get_incident_by_id(SINGLE_TENANT_UUID, incident_1.id, with_alerts=True) +# assert len(incident_1._alerts) == 2 def test_cross_tenant_exposure_issue_2768(db_session, create_alert):