diff --git a/keep/api/bl/incidents_bl.py b/keep/api/bl/incidents_bl.py index 5bc3d9216..a1b91a8fe 100644 --- a/keep/api/bl/incidents_bl.py +++ b/keep/api/bl/incidents_bl.py @@ -14,11 +14,12 @@ add_alerts_to_incident_by_incident_id, create_incident_from_dto, delete_incident_by_id, - get_incident_alerts_by_incident_id, get_incident_by_id, get_incident_unique_fingerprint_count, remove_alerts_to_incident_by_incident_id, update_incident_from_dto_by_id, + enrich_alerts_with_incidents, + get_all_alerts_by_fingerprints, ) from keep.api.core.elastic import ElasticClient from keep.api.models.alert import IncidentDto, IncidentDtoIn @@ -108,43 +109,30 @@ async def add_alerts_to_incident( "Alerts added to incident", extra={"incident_id": incident_id, "alert_fingerprints": alert_fingerprints}, ) - self.__update_elastic(incident_id, alert_fingerprints) - self.logger.info( - "Alerts pushed to elastic", - extra={"incident_id": incident_id, "alert_fingerprints": alert_fingerprints}, - ) - self.__update_client_on_incident_change(incident_id) - self.logger.info( - "Client updated on incident change", - extra={"incident_id": incident_id, "alert_fingerprints": alert_fingerprints}, - ) - incident_dto = IncidentDto.from_db_incident(incident) - self.__run_workflows(incident_dto, "updated") - self.logger.info( - "Workflows run on incident", - extra={"incident_id": incident_id, "alert_fingerprints": alert_fingerprints}, - ) + self.__postprocess_alerts_change(incident, alert_fingerprints) await self.__generate_summary(incident_id, incident) self.logger.info( "Summary generated", extra={"incident_id": incident_id, "alert_fingerprints": alert_fingerprints}, ) - def __update_elastic(self, incident_id: UUID, alert_fingerprints: List[str]): + def __update_elastic(self, alert_fingerprints: List[str]): try: elastic_client = ElasticClient(self.tenant_id) if elastic_client.enabled: - db_alerts, _ = get_incident_alerts_by_incident_id( + db_alerts = get_all_alerts_by_fingerprints( tenant_id=self.tenant_id, - incident_id=incident_id, - limit=len(alert_fingerprints), + fingerprints=alert_fingerprints, + session=self.session, ) + db_alerts = enrich_alerts_with_incidents(self.tenant_id, db_alerts, session=self.session) enriched_alerts_dto = convert_db_alerts_to_dto_alerts( db_alerts, with_incidents=True ) elastic_client.index_alerts(alerts=enriched_alerts_dto) except Exception: self.logger.exception("Failed to push alert to elasticsearch") + raise def __update_client_on_incident_change(self, incident_id: Optional[UUID] = None): if self.pusher_client is not None: @@ -217,6 +205,7 @@ def delete_alerts_from_incident( raise HTTPException(status_code=404, detail="Incident not found") remove_alerts_to_incident_by_incident_id(self.tenant_id, incident_id, alert_fingerprints) + self.__postprocess_alerts_change(incident, alert_fingerprints) def delete_incident(self, incident_id: UUID) -> None: self.logger.info( @@ -255,7 +244,7 @@ def update_incident( incident_id: UUID, updated_incident_dto: IncidentDtoIn, generated_by_ai: bool, - ) -> None: + ) -> IncidentDto: self.logger.info( "Fetching incident", extra={ @@ -270,16 +259,34 @@ def update_incident( raise HTTPException(status_code=404, detail="Incident not found") new_incident_dto = IncidentDto.from_db_incident(incident) - try: - workflow_manager = WorkflowManager.get_instance() - self.logger.info("Adding incident to the workflow manager queue") - workflow_manager.insert_incident( - self.tenant_id, new_incident_dto, "updated" - ) - self.logger.info("Added incident to the workflow manager queue") - except Exception: - self.logger.exception( - "Failed to run workflows based on incident", - extra={"incident_id": new_incident_dto.id, "tenant_id": self.tenant_id}, - ) + + self.__update_client_on_incident_change(incident.id) + self.logger.info( + "Client updated on incident change", + extra={"incident_id": incident.id}, + ) + self.__run_workflows(new_incident_dto, "updated") + self.logger.info( + "Workflows run on incident", + extra={"incident_id": incident.id}, + ) return new_incident_dto + + def __postprocess_alerts_change(self, incident, alert_fingerprints): + + self.__update_elastic(alert_fingerprints) + self.logger.info( + "Alerts pushed to elastic", + extra={"incident_id": incident.id, "alert_fingerprints": alert_fingerprints}, + ) + self.__update_client_on_incident_change(incident.id) + self.logger.info( + "Client updated on incident change", + extra={"incident_id": incident.id, "alert_fingerprints": alert_fingerprints}, + ) + incident_dto = IncidentDto.from_db_incident(incident) + self.__run_workflows(incident_dto, "updated") + self.logger.info( + "Workflows run on incident", + extra={"incident_id": incident.id, "alert_fingerprints": alert_fingerprints}, + ) diff --git a/keep/api/core/db.py b/keep/api/core/db.py index 321f38500..c8ff1f885 100644 --- a/keep/api/core/db.py +++ b/keep/api/core/db.py @@ -1428,6 +1428,18 @@ def get_alerts_by_fingerprint( return alerts +def get_all_alerts_by_fingerprints( + tenant_id: str, fingerprints: List[str], session: Optional[Session] = None +) -> List[Alert]: + with existed_or_new_session(session) as session: + query = ( + select(Alert) + .filter(Alert.tenant_id == tenant_id) + .filter(Alert.fingerprint.in_(fingerprints)) + .order_by(Alert.timestamp.desc()) + ) + return session.exec(query).all() + def get_alert_by_fingerprint_and_event_id( tenant_id: str, fingerprint: str, event_id: str @@ -3215,11 +3227,11 @@ def enrich_alerts_with_incidents( ).all() incidents_per_alert = defaultdict(list) - for alert_id, incident in alert_incidents: - incidents_per_alert[alert_id].append(incident) + for fingerprint, incident in alert_incidents: + incidents_per_alert[fingerprint].append(incident) for alert in alerts: - alert._incidents = incidents_per_alert[incident.id] + alert._incidents = incidents_per_alert[alert.fingerprint] return alerts diff --git a/keep/api/utils/enrichment_helpers.py b/keep/api/utils/enrichment_helpers.py index 2b0ad2b62..85897ad9e 100644 --- a/keep/api/utils/enrichment_helpers.py +++ b/keep/api/utils/enrichment_helpers.py @@ -80,10 +80,10 @@ def calculated_start_firing_time( def convert_db_alerts_to_dto_alerts( - alerts: list[Alert | tuple[Alert, LastAlertToIncident]], - with_incidents: bool = False, - session: Optional[Session] = None, - ) -> list[AlertDto | AlertWithIncidentLinkMetadataDto]: + alerts: list[Alert | tuple[Alert, LastAlertToIncident]], + with_incidents: bool = False, + session: Optional[Session] = None, +) -> list[AlertDto | AlertWithIncidentLinkMetadataDto]: """ Enriches the alerts with the enrichment data. @@ -109,8 +109,8 @@ def convert_db_alerts_to_dto_alerts( if alert.alert_enrichment: alert.event.update(alert.alert_enrichment.enrichments) if with_incidents: - if alert.incidents: - alert.event["incident"] = ",".join(str(incident.id) for incident in alert.incidents) + if alert._incidents: + alert.event["incident"] = ",".join(str(incident.id) for incident in alert._incidents) try: if alert_to_incident is not None: alert_dto = AlertWithIncidentLinkMetadataDto.from_db_instance(alert, alert_to_incident) diff --git a/poetry.lock b/poetry.lock index 9d66c1c4e..8eaff4e8f 100644 --- a/poetry.lock +++ b/poetry.lock @@ -3952,6 +3952,24 @@ pluggy = ">=0.12,<2.0" [package.extras] testing = ["argcomplete", "attrs (>=19.2.0)", "hypothesis (>=3.56)", "mock", "nose", "pygments (>=2.7.2)", "requests", "setuptools", "xmlschema"] +[[package]] +name = "pytest-asyncio" +version = "0.21.0" +description = "Pytest support for asyncio" +optional = false +python-versions = ">=3.7" +files = [ + {file = "pytest-asyncio-0.21.0.tar.gz", hash = "sha256:2b38a496aef56f56b0e87557ec313e11e1ab9276fc3863f6a7be0f1d0e415e1b"}, + {file = "pytest_asyncio-0.21.0-py3-none-any.whl", hash = "sha256:f2b3366b7cd501a4056858bd39349d5af19742aed2d81660b7998b6341c7eb9c"}, +] + +[package.dependencies] +pytest = ">=7.0.0" + +[package.extras] +docs = ["sphinx (>=5.3)", "sphinx-rtd-theme (>=1.0)"] +testing = ["coverage (>=6.2)", "flaky (>=3.5.0)", "hypothesis (>=5.7.1)", "mypy (>=0.931)", "pytest-trio (>=0.7.0)"] + [[package]] name = "pytest-docker" version = "2.2.0" @@ -5329,4 +5347,4 @@ test = ["big-O", "importlib-resources", "jaraco.functools", "jaraco.itertools", [metadata] lock-version = "2.0" python-versions = ">=3.11,<3.12" -content-hash = "a3319e110409281b9a3fa4c09300681fb82ef73bd05bbe1f93081a669b7635d7" +content-hash = "00bee7e27325f82b4779b56a87f600af1706019e765e6aa9fcdf3dd2d000dd06" diff --git a/pyproject.toml b/pyproject.toml index 413d0e05f..dbc9066b6 100644 --- a/pyproject.toml +++ b/pyproject.toml @@ -90,6 +90,7 @@ psycopg-binary = "^3.2.3" psycopg = "^3.2.3" prometheus-client = "^0.21.1" psycopg2-binary = "^2.9.10" +pytest-asyncio = "0.21.0" [tool.poetry.group.dev.dependencies] pre-commit = "^3.0.4" pre-commit-hooks = "^4.4.0" diff --git a/tests/test_incidents.py b/tests/test_incidents.py index 4ecf89b89..a8c035fc9 100644 --- a/tests/test_incidents.py +++ b/tests/test_incidents.py @@ -1,9 +1,14 @@ from datetime import datetime +from unittest.mock import patch + +from fastapi import HTTPException from itertools import cycle +from uuid import uuid4 import pytest -from sqlalchemy import distinct, func, desc +from sqlalchemy import distinct, func, desc, and_ +from keep.api.bl.incidents_bl import IncidentBl from keep.api.core.db import ( IncidentSorting, add_alerts_to_incident_by_incident_id, @@ -22,8 +27,10 @@ AlertStatus, IncidentSeverity, IncidentStatus, + IncidentDtoIn, + IncidentDto, ) -from keep.api.models.db.alert import Alert, LastAlertToIncident +from keep.api.models.db.alert import Alert, LastAlertToIncident, Incident, NULL_FOR_DELETED_AT from keep.api.models.db.tenant import Tenant from keep.api.utils.enrichment_helpers import convert_db_alerts_to_dto_alerts from tests.fixtures.client import client, test_app # noqa @@ -860,3 +867,391 @@ def test_cross_tenant_exposure_issue_2768(db_session, create_alert): assert incident_tenant_2.alerts_count == 1 assert total_incident_tenant_2_alerts == 1 assert len(incident_tenant_2_alerts) == 1 + + +class PusherMock: + + def __init__(self): + self.triggers = [] + + def trigger(self, channel, event_name, data): + self.triggers.append((channel, event_name, data)) + +class WorkflowManagerMock: + + def __init__(self): + self.events = [] + + def get_instance(self): + return self + + def insert_incident(self, tenant_id, incident_dto, action): + self.events.append((tenant_id, incident_dto, action)) + + +class ElasticClientMock: + + def __init__(self): + self.alerts = [] + self.tenant_id = None + self.enabled = True + + def __call__(self, tenant_id): + self.tenant_id = tenant_id + return self + + def index_alerts(self, alerts): + self.alerts.append((self.tenant_id, alerts)) + + +def test_incident_bl_create_incident(db_session): + + pusher = PusherMock() + workflow_manager = WorkflowManagerMock() + + with patch("keep.api.bl.incidents_bl.WorkflowManager", workflow_manager): + incident_bl = IncidentBl(tenant_id=SINGLE_TENANT_UUID, session=db_session, pusher_client=pusher) + + incidents_count = db_session.query(Incident).count() + assert incidents_count == 0 + + incident_dto_in = IncidentDtoIn(**{ + "user_generated_name": "Incident name", + "user_summary": "Keep: Incident description", + "status": "firing", + }) + + incident_dto = incident_bl.create_incident(incident_dto_in, generated_from_ai=False) + assert isinstance(incident_dto, IncidentDto) + + incidents_count = db_session.query(Incident).count() + assert incidents_count == 1 + + assert incident_dto.is_confirmed is True + assert incident_dto.is_predicted is False + + incident = db_session.query(Incident).get(incident_dto.id) + assert incident.user_generated_name == "Incident name" + assert incident.status == "firing" + assert incident.user_summary == "Keep: Incident description" + assert incident.is_confirmed is True + assert incident.is_predicted is False + + # Check pusher + + assert len(pusher.triggers) == 1 + channel, event_name, data = pusher.triggers[0] + assert channel == f"private-{SINGLE_TENANT_UUID}" + assert event_name == "incident-change" + assert isinstance(data, dict) + assert "incident_id" in data + assert data["incident_id"] is None # For new incidents we don't send incident.id + + # Check workflow manager + assert len(workflow_manager.events) == 1 + wf_tenant_id, wf_incident_dto, wf_action = workflow_manager.events[0] + assert wf_tenant_id == SINGLE_TENANT_UUID + assert wf_incident_dto.id == incident_dto.id + assert wf_action == "created" + + incident_dto_ai = incident_bl.create_incident(incident_dto_in, generated_from_ai=True) + assert isinstance(incident_dto_ai, IncidentDto) + + incidents_count = db_session.query(Incident).count() + assert incidents_count == 2 + + assert incident_dto_ai.is_confirmed is True + assert incident_dto_ai.is_predicted is False + + +def test_incident_bl_update_incident(db_session): + pusher = PusherMock() + workflow_manager = WorkflowManagerMock() + + with patch("keep.api.bl.incidents_bl.WorkflowManager", workflow_manager): + incident_bl = IncidentBl( + tenant_id=SINGLE_TENANT_UUID, + session=db_session, + pusher_client=pusher + ) + incident_dto_in = IncidentDtoIn(**{ + "user_generated_name": "Incident name", + "user_summary": "Keep: Incident description", + "status": "firing", + }) + + incident_dto = incident_bl.create_incident(incident_dto_in) + + incidents_count = db_session.query(Incident).count() + assert incidents_count == 1 + + new_incident_dto_in = IncidentDtoIn(**{ + "user_generated_name": "Not an incident", + "user_summary": "Keep: Incident description", + "status": "firing", + }) + + incident_dto_update = incident_bl.update_incident(incident_dto.id, new_incident_dto_in, False) + + incidents_count = db_session.query(Incident).count() + assert incidents_count == 1 + + assert incident_dto_update.name == "Not an incident" + + incident = db_session.query(Incident).get(incident_dto.id) + assert incident.user_generated_name == "Not an incident" + assert incident.status == "firing" + assert incident.user_summary == "Keep: Incident description" + + # Check error if no incident found + with pytest.raises(HTTPException, match="Incident not found"): + incident_bl.update_incident(uuid4(), incident_dto_update, False) + + # Check workflowmanager + assert len(workflow_manager.events) == 2 + wf_tenant_id, wf_incident_dto, wf_action = workflow_manager.events[-1] + assert wf_tenant_id == SINGLE_TENANT_UUID + assert wf_incident_dto.id == incident_dto.id + assert wf_action == "updated" + + # Check pusher + assert len(pusher.triggers) == 2 # 1 for create, 1 for update + channel, event_name, data = pusher.triggers[-1] + assert channel == f"private-{SINGLE_TENANT_UUID}" + assert event_name == "incident-change" + assert isinstance(data, dict) + assert "incident_id" in data + assert data["incident_id"] == str(incident_dto.id) + + +def test_incident_bl_delete_incident(db_session): + pusher = PusherMock() + workflow_manager = WorkflowManagerMock() + + with patch("keep.api.bl.incidents_bl.WorkflowManager", workflow_manager): + incident_bl = IncidentBl( + tenant_id=SINGLE_TENANT_UUID, + session=db_session, + pusher_client=pusher + ) + # Check error if no incident found + with pytest.raises(HTTPException, match="Incident not found"): + incident_bl.delete_incident(uuid4()) + + incident_dto_in = IncidentDtoIn(**{ + "user_generated_name": "Incident name", + "user_summary": "Keep: Incident description", + "status": "firing", + }) + + incident_dto = incident_bl.create_incident(incident_dto_in) + + incidents_count = db_session.query(Incident).count() + assert incidents_count == 1 + + incident_bl.delete_incident(incident_dto.id) + + incidents_count = db_session.query(Incident).count() + assert incidents_count == 0 + + # Check pusher + assert len(pusher.triggers) == 2 # Created, deleted + + channel, event_name, data = pusher.triggers[-1] + assert channel == f"private-{SINGLE_TENANT_UUID}" + assert event_name == "incident-change" + assert isinstance(data, dict) + assert "incident_id" in data + assert data["incident_id"] is None + + # Check workflow manager + assert len(workflow_manager.events) == 2 # Created, deleted + wf_tenant_id, wf_incident_dto, wf_action = workflow_manager.events[-1] + assert wf_tenant_id == SINGLE_TENANT_UUID + assert wf_incident_dto.id == incident_dto.id + assert wf_action == "deleted" + + +@pytest.mark.asyncio +async def test_incident_bl_add_alert_to_incident(db_session, create_alert): + pusher = PusherMock() + workflow_manager = WorkflowManagerMock() + elastic_client = ElasticClientMock() + + with patch("keep.api.bl.incidents_bl.WorkflowManager", workflow_manager): + with patch("keep.api.bl.incidents_bl.ElasticClient", elastic_client): + incident_bl = IncidentBl( + tenant_id=SINGLE_TENANT_UUID, + session=db_session, + pusher_client=pusher + ) + incident_dto_in = IncidentDtoIn(**{ + "user_generated_name": "Incident name", + "user_summary": "Keep: Incident description", + "status": "firing", + }) + + incident_dto = incident_bl.create_incident(incident_dto_in) + + incidents_count = db_session.query(Incident).count() + assert incidents_count == 1 + + with pytest.raises(HTTPException, match="Incident not found"): + await incident_bl.add_alerts_to_incident(uuid4(), [], False) + + create_alert( + "alert-test-1", + AlertStatus("firing"), + datetime.utcnow(), + {}, + ) + + await incident_bl.add_alerts_to_incident( + incident_dto.id, + ["alert-test-1"], + False + ) + + alerts_to_incident_count = ( + db_session + .query(LastAlertToIncident) + .where( + LastAlertToIncident.incident_id == incident_dto.id + ) + .count() + ) + assert alerts_to_incident_count == 1 + + alert_to_incident = ( + db_session + .query(LastAlertToIncident) + .where( + LastAlertToIncident.fingerprint == "alert-test-1" + ) + .first() + ) + assert alert_to_incident is not None + + # Check pusher + assert len(pusher.triggers) == 2 # Created, update + + channel, event_name, data = pusher.triggers[-1] + assert channel == f"private-{SINGLE_TENANT_UUID}" + assert event_name == "incident-change" + assert isinstance(data, dict) + assert "incident_id" in data + assert data["incident_id"] == str(incident_dto.id) + + # Check workflow manager + assert len(workflow_manager.events) == 2 # Created, update + wf_tenant_id, wf_incident_dto, wf_action = workflow_manager.events[-1] + assert wf_tenant_id == SINGLE_TENANT_UUID + assert wf_incident_dto.id == incident_dto.id + assert wf_action == "updated" + + # Check elastic + assert len(elastic_client.alerts) == 1 + el_tenant_id, el_alerts = elastic_client.alerts[-1] + assert len(el_alerts) == 1 + assert el_tenant_id == SINGLE_TENANT_UUID + assert el_alerts[-1].fingerprint == "alert-test-1" + assert el_alerts[-1].incident == str(incident_dto.id) + + +@pytest.mark.asyncio +async def test_incident_bl_delete_alerts_from_incident(db_session, create_alert): + pusher = PusherMock() + workflow_manager = WorkflowManagerMock() + elastic_client = ElasticClientMock() + + with patch("keep.api.bl.incidents_bl.WorkflowManager", workflow_manager): + with patch("keep.api.bl.incidents_bl.ElasticClient", elastic_client): + incident_bl = IncidentBl( + tenant_id=SINGLE_TENANT_UUID, + session=db_session, + pusher_client=pusher + ) + incident_dto_in = IncidentDtoIn(**{ + "user_generated_name": "Incident name", + "user_summary": "Keep: Incident description", + "status": "firing", + }) + + incident_dto = incident_bl.create_incident(incident_dto_in) + + incidents_count = db_session.query(Incident).count() + assert incidents_count == 1 + + with pytest.raises(HTTPException, match="Incident not found"): + incident_bl.delete_alerts_from_incident(uuid4(), []) + + create_alert( + "alert-test-1", + AlertStatus("firing"), + datetime.utcnow(), + {}, + ) + + await incident_bl.add_alerts_to_incident( + incident_dto.id, + ["alert-test-1"], + False + ) + + alerts_to_incident_count = ( + db_session + .query(LastAlertToIncident) + .where( + and_( + LastAlertToIncident.incident_id == incident_dto.id, + LastAlertToIncident.deleted_at == NULL_FOR_DELETED_AT + ) + ) + .count() + ) + assert alerts_to_incident_count == 1 + + incident_bl.delete_alerts_from_incident( + incident_dto.id, + ["alert-test-1"], + ) + + alerts_to_incident_count = ( + db_session + .query(LastAlertToIncident) + .where( + and_( + LastAlertToIncident.incident_id == incident_dto.id, + LastAlertToIncident.deleted_at == NULL_FOR_DELETED_AT + ) + ) + .count() + ) + assert alerts_to_incident_count == 0 + + # Check pusher + # Created, updated (added event), updated(deleted event) + assert len(pusher.triggers) == 3 + + channel, event_name, data = pusher.triggers[-1] + assert channel == f"private-{SINGLE_TENANT_UUID}" + assert event_name == "incident-change" + assert isinstance(data, dict) + assert "incident_id" in data + assert data["incident_id"] == str(incident_dto.id) + + # Check workflow manager + # Created, updated (added event), updated(deleted event) + assert len(workflow_manager.events) == 3 + wf_tenant_id, wf_incident_dto, wf_action = workflow_manager.events[-1] + assert wf_tenant_id == SINGLE_TENANT_UUID + assert wf_incident_dto.id == incident_dto.id + assert wf_action == "updated" + + # Check elastic + assert len(elastic_client.alerts) == 2 + el_tenant_id, el_alerts = elastic_client.alerts[-1] + assert len(el_alerts) == 1 + assert el_tenant_id == SINGLE_TENANT_UUID + assert el_alerts[-1].fingerprint == "alert-test-1" + assert el_alerts[-1].incident is None