diff --git a/keep/api/bl/incidents_bl.py b/keep/api/bl/incidents_bl.py
index 5bc3d9216..a1b91a8fe 100644
--- a/keep/api/bl/incidents_bl.py
+++ b/keep/api/bl/incidents_bl.py
@@ -14,11 +14,12 @@
     add_alerts_to_incident_by_incident_id,
     create_incident_from_dto,
     delete_incident_by_id,
-    get_incident_alerts_by_incident_id,
     get_incident_by_id,
     get_incident_unique_fingerprint_count,
     remove_alerts_to_incident_by_incident_id,
     update_incident_from_dto_by_id,
+    enrich_alerts_with_incidents,
+    get_all_alerts_by_fingerprints,
 )
 from keep.api.core.elastic import ElasticClient
 from keep.api.models.alert import IncidentDto, IncidentDtoIn
@@ -108,43 +109,30 @@ async def add_alerts_to_incident(
             "Alerts added to incident",
             extra={"incident_id": incident_id, "alert_fingerprints": alert_fingerprints},
         )
-        self.__update_elastic(incident_id, alert_fingerprints)
-        self.logger.info(
-            "Alerts pushed to elastic",
-            extra={"incident_id": incident_id, "alert_fingerprints": alert_fingerprints},
-        )
-        self.__update_client_on_incident_change(incident_id)
-        self.logger.info(
-            "Client updated on incident change",
-            extra={"incident_id": incident_id, "alert_fingerprints": alert_fingerprints},
-        )
-        incident_dto = IncidentDto.from_db_incident(incident)
-        self.__run_workflows(incident_dto, "updated")
-        self.logger.info(
-            "Workflows run on incident",
-            extra={"incident_id": incident_id, "alert_fingerprints": alert_fingerprints},
-        )
+        self.__postprocess_alerts_change(incident, alert_fingerprints)
         await self.__generate_summary(incident_id, incident)
         self.logger.info(
             "Summary generated",
             extra={"incident_id": incident_id, "alert_fingerprints": alert_fingerprints},
         )

-    def __update_elastic(self, incident_id: UUID, alert_fingerprints: List[str]):
+    def __update_elastic(self, alert_fingerprints: List[str]):
         try:
             elastic_client = ElasticClient(self.tenant_id)
             if elastic_client.enabled:
-                db_alerts, _ = get_incident_alerts_by_incident_id(
+                db_alerts = get_all_alerts_by_fingerprints(
                     tenant_id=self.tenant_id,
-                    incident_id=incident_id,
-                    limit=len(alert_fingerprints),
+                    fingerprints=alert_fingerprints,
+                    session=self.session,
                 )
+                db_alerts = enrich_alerts_with_incidents(self.tenant_id, db_alerts, session=self.session)
                 enriched_alerts_dto = convert_db_alerts_to_dto_alerts(
                     db_alerts, with_incidents=True
                 )
                 elastic_client.index_alerts(alerts=enriched_alerts_dto)
         except Exception:
             self.logger.exception("Failed to push alert to elasticsearch")
+            raise

     def __update_client_on_incident_change(self, incident_id: Optional[UUID] = None):
         if self.pusher_client is not None:
@@ -217,6 +205,7 @@ def delete_alerts_from_incident(
             raise HTTPException(status_code=404, detail="Incident not found")

         remove_alerts_to_incident_by_incident_id(self.tenant_id, incident_id, alert_fingerprints)
+        self.__postprocess_alerts_change(incident, alert_fingerprints)

     def delete_incident(self, incident_id: UUID) -> None:
         self.logger.info(
@@ -255,7 +244,7 @@ def update_incident(
         incident_id: UUID,
         updated_incident_dto: IncidentDtoIn,
         generated_by_ai: bool,
-    ) -> None:
+    ) -> IncidentDto:
         self.logger.info(
             "Fetching incident",
             extra={
@@ -270,16 +259,34 @@ def update_incident(
             raise HTTPException(status_code=404, detail="Incident not found")

         new_incident_dto = IncidentDto.from_db_incident(incident)
-        try:
-            workflow_manager = WorkflowManager.get_instance()
-            self.logger.info("Adding incident to the workflow manager queue")
-            workflow_manager.insert_incident(
-                self.tenant_id, new_incident_dto, "updated"
-            )
-            self.logger.info("Added incident to the workflow manager queue")
-        except Exception:
-            self.logger.exception(
-                "Failed to run workflows based on incident",
-                extra={"incident_id": new_incident_dto.id, "tenant_id": self.tenant_id},
-            )
+
+        self.__update_client_on_incident_change(incident.id)
+        self.logger.info(
+            "Client updated on incident change",
+            extra={"incident_id": incident.id},
+        )
+        self.__run_workflows(new_incident_dto, "updated")
+        self.logger.info(
+            "Workflows run on incident",
+            extra={"incident_id": incident.id},
+        )
         return new_incident_dto
+
+    def __postprocess_alerts_change(self, incident, alert_fingerprints):
+
+        self.__update_elastic(alert_fingerprints)
+        self.logger.info(
+            "Alerts pushed to elastic",
+            extra={"incident_id": incident.id, "alert_fingerprints": alert_fingerprints},
+        )
+        self.__update_client_on_incident_change(incident.id)
+        self.logger.info(
+            "Client updated on incident change",
+            extra={"incident_id": incident.id, "alert_fingerprints": alert_fingerprints},
+        )
+        incident_dto = IncidentDto.from_db_incident(incident)
+        self.__run_workflows(incident_dto, "updated")
+        self.logger.info(
+            "Workflows run on incident",
+            extra={"incident_id": incident.id, "alert_fingerprints": alert_fingerprints},
+        )
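
Note (illustrative, not part of the patch): after this refactor, the add and delete flows share one post-processing path. A minimal usage sketch of the business-logic layer; the constructor wiring is taken from the tests below, while the session and incident id are assumed to exist:

    from keep.api.bl.incidents_bl import IncidentBl

    async def attach_and_detach(tenant_id, session, incident_id):
        bl = IncidentBl(tenant_id=tenant_id, session=session, pusher_client=None)
        # triggers __postprocess_alerts_change: elastic reindex -> pusher -> workflows
        await bl.add_alerts_to_incident(incident_id, ["alert-test-1"], False)
        # the delete path now runs the same post-processing instead of skipping it
        bl.delete_alerts_from_incident(incident_id, ["alert-test-1"])
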
queue") - except Exception: - self.logger.exception( - "Failed to run workflows based on incident", - extra={"incident_id": new_incident_dto.id, "tenant_id": self.tenant_id}, - ) + + self.__update_client_on_incident_change(incident.id) + self.logger.info( + "Client updated on incident change", + extra={"incident_id": incident.id}, + ) + self.__run_workflows(new_incident_dto, "updated") + self.logger.info( + "Workflows run on incident", + extra={"incident_id": incident.id}, + ) return new_incident_dto + + def __postprocess_alerts_change(self, incident, alert_fingerprints): + + self.__update_elastic(alert_fingerprints) + self.logger.info( + "Alerts pushed to elastic", + extra={"incident_id": incident.id, "alert_fingerprints": alert_fingerprints}, + ) + self.__update_client_on_incident_change(incident.id) + self.logger.info( + "Client updated on incident change", + extra={"incident_id": incident.id, "alert_fingerprints": alert_fingerprints}, + ) + incident_dto = IncidentDto.from_db_incident(incident) + self.__run_workflows(incident_dto, "updated") + self.logger.info( + "Workflows run on incident", + extra={"incident_id": incident.id, "alert_fingerprints": alert_fingerprints}, + ) diff --git a/keep/api/core/db.py b/keep/api/core/db.py index eb570c19c..8949b6586 100644 --- a/keep/api/core/db.py +++ b/keep/api/core/db.py @@ -1428,6 +1428,18 @@ def get_alerts_by_fingerprint( return alerts +def get_all_alerts_by_fingerprints( + tenant_id: str, fingerprints: List[str], session: Optional[Session] = None +) -> List[Alert]: + with existed_or_new_session(session) as session: + query = ( + select(Alert) + .filter(Alert.tenant_id == tenant_id) + .filter(Alert.fingerprint.in_(fingerprints)) + .order_by(Alert.timestamp.desc()) + ) + return session.exec(query).all() + def get_alert_by_fingerprint_and_event_id( tenant_id: str, fingerprint: str, event_id: str @@ -3215,11 +3227,11 @@ def enrich_alerts_with_incidents( ).all() incidents_per_alert = defaultdict(list) - for alert_id, incident in alert_incidents: - incidents_per_alert[alert_id].append(incident) + for fingerprint, incident in alert_incidents: + incidents_per_alert[fingerprint].append(incident) for alert in alerts: - alert._incidents = incidents_per_alert[incident.id] + alert._incidents = incidents_per_alert[alert.fingerprint] return alerts diff --git a/keep/api/core/demo_mode.py b/keep/api/core/demo_mode.py index e438355b2..17f4b8b28 100644 --- a/keep/api/core/demo_mode.py +++ b/keep/api/core/demo_mode.py @@ -417,8 +417,11 @@ def perform_demo_ai(keep_api_key, keep_api_url): def simulate_alerts(*args, **kwargs): - asyncio.create_task(simulate_alerts_worker(0, keep_api_key, 0)) - asyncio.run(simulate_alerts_async(*args, **kwargs)) + loop = asyncio.new_event_loop() + asyncio.set_event_loop(loop) + loop.create_task(simulate_alerts_worker(0, kwargs.get("keep_api_key"), 0)) + loop.create_task(simulate_alerts_async(*args, **kwargs)) + loop.run_forever() async def simulate_alerts_async( diff --git a/keep/api/utils/enrichment_helpers.py b/keep/api/utils/enrichment_helpers.py index 2b0ad2b62..85897ad9e 100644 --- a/keep/api/utils/enrichment_helpers.py +++ b/keep/api/utils/enrichment_helpers.py @@ -80,10 +80,10 @@ def calculated_start_firing_time( def convert_db_alerts_to_dto_alerts( - alerts: list[Alert | tuple[Alert, LastAlertToIncident]], - with_incidents: bool = False, - session: Optional[Session] = None, - ) -> list[AlertDto | AlertWithIncidentLinkMetadataDto]: + alerts: list[Alert | tuple[Alert, LastAlertToIncident]], + with_incidents: bool = False, + 
diff --git a/keep/api/core/demo_mode.py b/keep/api/core/demo_mode.py
index e438355b2..17f4b8b28 100644
--- a/keep/api/core/demo_mode.py
+++ b/keep/api/core/demo_mode.py
@@ -417,8 +417,11 @@ def perform_demo_ai(keep_api_key, keep_api_url):


 def simulate_alerts(*args, **kwargs):
-    asyncio.create_task(simulate_alerts_worker(0, keep_api_key, 0))
-    asyncio.run(simulate_alerts_async(*args, **kwargs))
+    loop = asyncio.new_event_loop()
+    asyncio.set_event_loop(loop)
+    loop.create_task(simulate_alerts_worker(0, kwargs.get("keep_api_key"), 0))
+    loop.create_task(simulate_alerts_async(*args, **kwargs))
+    loop.run_forever()


 async def simulate_alerts_async(
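
Note (illustrative, not part of the patch): asyncio.create_task() requires a *running* event loop, so the old first line raised RuntimeError when simulate_alerts ran in a plain thread (and referenced keep_api_key, which was not in scope). The new version owns a fresh loop, schedules both coroutines on it, and blocks. A minimal reproduction of the pattern:

    import asyncio

    async def worker():
        await asyncio.sleep(0)

    def old_style():
        asyncio.create_task(worker())  # RuntimeError: no running event loop

    def new_style():
        loop = asyncio.new_event_loop()
        asyncio.set_event_loop(loop)
        loop.create_task(worker())   # fine: scheduled on the loop we just created
        loop.call_soon(loop.stop)    # stop instead of run_forever, for the demo
        loop.run_forever()

    new_style()
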
diff --git a/keep/api/utils/enrichment_helpers.py b/keep/api/utils/enrichment_helpers.py
index 2b0ad2b62..85897ad9e 100644
--- a/keep/api/utils/enrichment_helpers.py
+++ b/keep/api/utils/enrichment_helpers.py
@@ -80,10 +80,10 @@ def calculated_start_firing_time(


 def convert_db_alerts_to_dto_alerts(
-        alerts: list[Alert | tuple[Alert, LastAlertToIncident]],
-        with_incidents: bool = False,
-        session: Optional[Session] = None,
-    ) -> list[AlertDto | AlertWithIncidentLinkMetadataDto]:
+    alerts: list[Alert | tuple[Alert, LastAlertToIncident]],
+    with_incidents: bool = False,
+    session: Optional[Session] = None,
+) -> list[AlertDto | AlertWithIncidentLinkMetadataDto]:
     """
     Enriches the alerts with the enrichment data.
@@ -109,8 +109,8 @@ def convert_db_alerts_to_dto_alerts(
         if alert.alert_enrichment:
             alert.event.update(alert.alert_enrichment.enrichments)
         if with_incidents:
-            if alert.incidents:
-                alert.event["incident"] = ",".join(str(incident.id) for incident in alert.incidents)
+            if alert._incidents:
+                alert.event["incident"] = ",".join(str(incident.id) for incident in alert._incidents)
         try:
             if alert_to_incident is not None:
                 alert_dto = AlertWithIncidentLinkMetadataDto.from_db_instance(alert, alert_to_incident)
diff --git a/keep/providers/grafana_provider/grafana_provider.py b/keep/providers/grafana_provider/grafana_provider.py
index d26bf88d7..e3f958230 100644
--- a/keep/providers/grafana_provider/grafana_provider.py
+++ b/keep/providers/grafana_provider/grafana_provider.py
@@ -509,6 +509,8 @@ def simulate_alert(cls, **kwargs) -> dict:
         if not alert_type:
             alert_type = random.choice(list(ALERTS.keys()))

+        to_wrap_with_provider_type = kwargs.get("to_wrap_with_provider_type")
+
         if "payload" in ALERTS[alert_type]:
             alert_payload = ALERTS[alert_type]["payload"]
         else:
@@ -552,11 +554,15 @@ def simulate_alert(cls, **kwargs) -> dict:
         fingerprint = hashlib.md5(fingerprint_src.encode()).hexdigest()
         alert_payload["fingerprint"] = fingerprint

-        return {
-            "alerts": [alert_payload],
-            "severity": alert_payload.get("labels", {}).get("severity"),
-            "title": alert_type,
-        }
+
+        final_payload = {
+            "alerts": [alert_payload],
+            "severity": alert_payload.get("labels", {}).get("severity"),
+            "title": alert_type,
+        }
+        if to_wrap_with_provider_type:
+            return {"keep_source_type": "grafana", "event": final_payload}
+        return final_payload


 if __name__ == "__main__":
diff --git a/keep/providers/prometheus_provider/prometheus_provider.py b/keep/providers/prometheus_provider/prometheus_provider.py
index cb3c392a7..b4dd4927e 100644
--- a/keep/providers/prometheus_provider/prometheus_provider.py
+++ b/keep/providers/prometheus_provider/prometheus_provider.py
@@ -233,6 +233,8 @@ def simulate_alert(cls, **kwargs) -> dict:
         if not alert_type:
             alert_type = random.choice(list(ALERTS.keys()))

+        to_wrap_with_provider_type = kwargs.get("to_wrap_with_provider_type")
+
         alert_payload = ALERTS[alert_type]["payload"]
         alert_parameters = ALERTS[alert_type].get("parameters", [])
         # now generate some random data
@@ -267,6 +269,9 @@ def simulate_alert(cls, **kwargs) -> dict:
         fingerprint_src = json.dumps(alert_payload["labels"], sort_keys=True)
         fingerprint = hashlib.md5(fingerprint_src.encode()).hexdigest()
         alert_payload["fingerprint"] = fingerprint
+        if to_wrap_with_provider_type:
+            return {"keep_source_type": "prometheus", "event": alert_payload}
+
         return alert_payload
diff --git a/keep/providers/vectordev_provider/vectordev_provider.py b/keep/providers/vectordev_provider/vectordev_provider.py
index c3d6f8a9d..aeaa17aca 100644
--- a/keep/providers/vectordev_provider/vectordev_provider.py
+++ b/keep/providers/vectordev_provider/vectordev_provider.py
@@ -1,13 +1,19 @@
 import dataclasses
+
+import random
 import json
 import pydantic
+import logging

 from keep.api.models.alert import AlertDto
 from keep.contextmanager.contextmanager import ContextManager
 from keep.providers.base.base_provider import BaseProvider
 from keep.providers.models.provider_config import ProviderConfig
+from keep.providers.providers_factory import ProvidersFactory

+logging.basicConfig(level=logging.INFO)
+logger = logging.getLogger(__name__)

 @pydantic.dataclasses.dataclass
 class VectordevProviderAuthConfig:
@@ -21,6 +27,11 @@ class VectordevProvider(BaseProvider):
     PROVIDER_CATEGORY = ["Monitoring", "Developer Tools"]
     PROVIDER_COMING_SOON = True

+    # Mapping from vector sources to keep providers
+    SOURCE_TO_PROVIDER_MAP = {
+        "prometheus": "prometheus",
+    }
+
     def __init__(
         self, context_manager: ContextManager, provider_id: str, config: ProviderConfig
     ):
@@ -32,33 +43,42 @@ def validate_config(self):
         )

     def _format_alert(
-        event: list[dict], provider_instance: "BaseProvider" = None
+        event: dict, provider_instance: "BaseProvider" = None
     ) -> AlertDto | list[AlertDto]:
         events = []
-        # event is a list of events
-        for e in event:
-            event_json = None
-            try:
-                event_json = json.loads(e.get("message"))
-            except json.JSONDecodeError:
-                pass
-
-            events.append(
+        if isinstance(event, list):
+            events = event
+        else:
+            events = [event]
+        alert_dtos = []
+        for e in events:
+            if "keep_source_type" in e and e["keep_source_type"] in VectordevProvider.SOURCE_TO_PROVIDER_MAP:
+                provider_class = ProvidersFactory.get_provider_class(VectordevProvider.SOURCE_TO_PROVIDER_MAP[e["keep_source_type"]])
+                alert_dtos.extend(provider_class._format_alert(e["event"], provider_instance))
+            else:
+                message_str = json.dumps(e.get("message"))
+                alert_dtos.append(
                 AlertDto(
                     name="",
-                    host=e.get("host"),
-                    message=e.get("message"),
-                    description=e.get("message"),
+                    message=message_str,
+                    description=message_str,
                     lastReceived=e.get("timestamp"),
                     source_type=e.get("source_type"),
                     source=["vectordev"],
-                    original_event=event_json,
+                    original_event=e.get("message"),
                 )
             )
-        return events
+        return alert_dtos

     def dispose(self):
         """
         No need to dispose of anything, so just do nothing.
         """
         pass
+
+    @classmethod
+    def simulate_alert(cls, **kwargs) -> dict:
+        provider = random.choice(list(VectordevProvider.SOURCE_TO_PROVIDER_MAP.values()))
+        provider_class = ProvidersFactory.get_provider_class(provider)
+        return provider_class.simulate_alert(to_wrap_with_provider_type=True)
+
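
Note (illustrative, not part of the patch): a wrapped, simulated alert flows through VectordevProvider._format_alert as sketched below. The payload is a trimmed example; the wrapper keys match the simulate_alert changes above:

    wrapped = {
        "keep_source_type": "prometheus",
        "event": {
            "labels": {"severity": "critical", "alertname": "HighCPU"},
            "fingerprint": "abc123",
        },
    }

    # _format_alert sees "keep_source_type", resolves it via
    # SOURCE_TO_PROVIDER_MAP ("prometheus" -> "prometheus"), and delegates to
    # PrometheusProvider._format_alert(wrapped["event"], provider_instance).
    # Events without the wrapper fall through to the generic AlertDto branch.
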
""" # k8s requirements: https://kubernetes.io/docs/concepts/overview/working-with-objects/names/#names - secret_name = secret_name.replace("_", "-") + secret_name = secret_name.replace("_", "-").lower() self.logger.info("Writing secret", extra={"secret_name": secret_name}) body = kubernetes.client.V1Secret( @@ -70,7 +70,7 @@ def write_secret(self, secret_name: str, secret_value: str) -> None: def read_secret(self, secret_name: str, is_json: bool = False) -> str | dict: # k8s requirements: https://kubernetes.io/docs/concepts/overview/working-with-objects/names/#names - secret_name = secret_name.replace("_", "-") + secret_name = secret_name.replace("_", "-").lower() self.logger.info("Getting secret", extra={"secret_name": secret_name}) try: response = self.api.read_namespaced_secret( @@ -91,6 +91,7 @@ def read_secret(self, secret_name: str, is_json: bool = False) -> str | dict: raise def delete_secret(self, secret_name: str) -> None: + secret_name = secret_name.replace("_", "-").lower() self.logger.info("Deleting secret", extra={"secret_name": secret_name}) try: self.api.delete_namespaced_secret( diff --git a/poetry.lock b/poetry.lock index 9d66c1c4e..8eaff4e8f 100644 --- a/poetry.lock +++ b/poetry.lock @@ -3952,6 +3952,24 @@ pluggy = ">=0.12,<2.0" [package.extras] testing = ["argcomplete", "attrs (>=19.2.0)", "hypothesis (>=3.56)", "mock", "nose", "pygments (>=2.7.2)", "requests", "setuptools", "xmlschema"] +[[package]] +name = "pytest-asyncio" +version = "0.21.0" +description = "Pytest support for asyncio" +optional = false +python-versions = ">=3.7" +files = [ + {file = "pytest-asyncio-0.21.0.tar.gz", hash = "sha256:2b38a496aef56f56b0e87557ec313e11e1ab9276fc3863f6a7be0f1d0e415e1b"}, + {file = "pytest_asyncio-0.21.0-py3-none-any.whl", hash = "sha256:f2b3366b7cd501a4056858bd39349d5af19742aed2d81660b7998b6341c7eb9c"}, +] + +[package.dependencies] +pytest = ">=7.0.0" + +[package.extras] +docs = ["sphinx (>=5.3)", "sphinx-rtd-theme (>=1.0)"] +testing = ["coverage (>=6.2)", "flaky (>=3.5.0)", "hypothesis (>=5.7.1)", "mypy (>=0.931)", "pytest-trio (>=0.7.0)"] + [[package]] name = "pytest-docker" version = "2.2.0" @@ -5329,4 +5347,4 @@ test = ["big-O", "importlib-resources", "jaraco.functools", "jaraco.itertools", [metadata] lock-version = "2.0" python-versions = ">=3.11,<3.12" -content-hash = "a3319e110409281b9a3fa4c09300681fb82ef73bd05bbe1f93081a669b7635d7" +content-hash = "00bee7e27325f82b4779b56a87f600af1706019e765e6aa9fcdf3dd2d000dd06" diff --git a/pyproject.toml b/pyproject.toml index 2606373cf..080fd7434 100644 --- a/pyproject.toml +++ b/pyproject.toml @@ -1,6 +1,6 @@ [tool.poetry] name = "keep" -version = "0.32.1" +version = "0.32.2" description = "Alerting. for developers, by developers." 
diff --git a/poetry.lock b/poetry.lock
index 9d66c1c4e..8eaff4e8f 100644
--- a/poetry.lock
+++ b/poetry.lock
@@ -3952,6 +3952,24 @@ pluggy = ">=0.12,<2.0"

 [package.extras]
 testing = ["argcomplete", "attrs (>=19.2.0)", "hypothesis (>=3.56)", "mock", "nose", "pygments (>=2.7.2)", "requests", "setuptools", "xmlschema"]

+[[package]]
+name = "pytest-asyncio"
+version = "0.21.0"
+description = "Pytest support for asyncio"
+optional = false
+python-versions = ">=3.7"
+files = [
+    {file = "pytest-asyncio-0.21.0.tar.gz", hash = "sha256:2b38a496aef56f56b0e87557ec313e11e1ab9276fc3863f6a7be0f1d0e415e1b"},
+    {file = "pytest_asyncio-0.21.0-py3-none-any.whl", hash = "sha256:f2b3366b7cd501a4056858bd39349d5af19742aed2d81660b7998b6341c7eb9c"},
+]
+
+[package.dependencies]
+pytest = ">=7.0.0"
+
+[package.extras]
+docs = ["sphinx (>=5.3)", "sphinx-rtd-theme (>=1.0)"]
+testing = ["coverage (>=6.2)", "flaky (>=3.5.0)", "hypothesis (>=5.7.1)", "mypy (>=0.931)", "pytest-trio (>=0.7.0)"]
+
 [[package]]
 name = "pytest-docker"
 version = "2.2.0"
@@ -5329,4 +5347,4 @@ test = ["big-O", "importlib-resources", "jaraco.functools", "jaraco.itertools",

 [metadata]
 lock-version = "2.0"
 python-versions = ">=3.11,<3.12"
-content-hash = "a3319e110409281b9a3fa4c09300681fb82ef73bd05bbe1f93081a669b7635d7"
+content-hash = "00bee7e27325f82b4779b56a87f600af1706019e765e6aa9fcdf3dd2d000dd06"
diff --git a/pyproject.toml b/pyproject.toml
index 2606373cf..080fd7434 100644
--- a/pyproject.toml
+++ b/pyproject.toml
@@ -1,6 +1,6 @@
 [tool.poetry]
 name = "keep"
-version = "0.32.1"
+version = "0.32.2"
 description = "Alerting. for developers, by developers."
 authors = ["Keep Alerting LTD"]
 packages = [{include = "keep"}]
@@ -90,6 +90,7 @@ psycopg-binary = "^3.2.3"
 psycopg = "^3.2.3"
 prometheus-client = "^0.21.1"
 psycopg2-binary = "^2.9.10"
+pytest-asyncio = "0.21.0"

 [tool.poetry.group.dev.dependencies]
 pre-commit = "^3.0.4"
 pre-commit-hooks = "^4.4.0"
diff --git a/scripts/simulate_alerts.py b/scripts/simulate_alerts.py
index 29a38ca4b..45db1310b 100644
--- a/scripts/simulate_alerts.py
+++ b/scripts/simulate_alerts.py
@@ -36,7 +36,7 @@ async def main():
     SLEEP_INTERVAL = float(
         os.environ.get("SLEEP_INTERVAL", default_sleep_interval)
     )
-    keep_api_key = os.environ.get("KEEP_API_KEY")
+    keep_api_key = os.environ.get("KEEP_API_KEY") or "keepappkey"
     keep_api_url = os.environ.get("KEEP_API_URL") or "http://localhost:8080"

     for i in range(args.workers):
diff --git a/tests/test_incidents.py b/tests/test_incidents.py
index 87f470836..2184649e2 100644
--- a/tests/test_incidents.py
+++ b/tests/test_incidents.py
@@ -1,8 +1,12 @@
 from datetime import datetime
+from unittest.mock import patch
+
+from fastapi import HTTPException
 from itertools import cycle
+from uuid import uuid4

 import pytest
-from sqlalchemy import distinct, func, desc
+from sqlalchemy import distinct, func, desc, and_

 from keep.api.bl.incidents_bl import IncidentBl
 from keep.api.core.db import (
@@ -23,8 +27,10 @@
     AlertStatus,
     IncidentSeverity,
     IncidentStatus,
+    IncidentDtoIn,
+    IncidentDto,
 )
-from keep.api.models.db.alert import Alert, LastAlertToIncident
+from keep.api.models.db.alert import Alert, LastAlertToIncident, Incident, NULL_FOR_DELETED_AT
 from keep.api.models.db.tenant import Tenant
 from keep.api.utils.enrichment_helpers import convert_db_alerts_to_dto_alerts
 from tests.fixtures.client import client, test_app  # noqa
@@ -1017,3 +1023,391 @@ def test_cross_tenant_exposure_issue_2768(db_session, create_alert):
     assert incident_tenant_2.alerts_count == 1
     assert total_incident_tenant_2_alerts == 1
     assert len(incident_tenant_2_alerts) == 1
+
+
+class PusherMock:
+
+    def __init__(self):
+        self.triggers = []
+
+    def trigger(self, channel, event_name, data):
+        self.triggers.append((channel, event_name, data))
+
+
+class WorkflowManagerMock:
+
+    def __init__(self):
+        self.events = []
+
+    def get_instance(self):
+        return self
+
+    def insert_incident(self, tenant_id, incident_dto, action):
+        self.events.append((tenant_id, incident_dto, action))
+
+
+class ElasticClientMock:
+
+    def __init__(self):
+        self.alerts = []
+        self.tenant_id = None
+        self.enabled = True
+
+    def __call__(self, tenant_id):
+        self.tenant_id = tenant_id
+        return self
+
+    def index_alerts(self, alerts):
+        self.alerts.append((self.tenant_id, alerts))
+
+
+def test_incident_bl_create_incident(db_session):
+
+    pusher = PusherMock()
+    workflow_manager = WorkflowManagerMock()
+
+    with patch("keep.api.bl.incidents_bl.WorkflowManager", workflow_manager):
+        incident_bl = IncidentBl(tenant_id=SINGLE_TENANT_UUID, session=db_session, pusher_client=pusher)
+
+        incidents_count = db_session.query(Incident).count()
+        assert incidents_count == 0
+
+        incident_dto_in = IncidentDtoIn(**{
+            "user_generated_name": "Incident name",
+            "user_summary": "Keep: Incident description",
+            "status": "firing",
+        })
+
+        incident_dto = incident_bl.create_incident(incident_dto_in, generated_from_ai=False)
+        assert isinstance(incident_dto, IncidentDto)
+
+        incidents_count = db_session.query(Incident).count()
+        assert incidents_count == 1
+
+        assert incident_dto.is_confirmed is True
+        assert incident_dto.is_predicted is False
+
+        incident = db_session.query(Incident).get(incident_dto.id)
+        assert incident.user_generated_name == "Incident name"
+        assert incident.status == "firing"
+        assert incident.user_summary == "Keep: Incident description"
+        assert incident.is_confirmed is True
+        assert incident.is_predicted is False
+
+        # Check pusher
+
+        assert len(pusher.triggers) == 1
+        channel, event_name, data = pusher.triggers[0]
+        assert channel == f"private-{SINGLE_TENANT_UUID}"
+        assert event_name == "incident-change"
+        assert isinstance(data, dict)
+        assert "incident_id" in data
+        assert data["incident_id"] is None  # For new incidents we don't send incident.id
+
+        # Check workflow manager
+        assert len(workflow_manager.events) == 1
+        wf_tenant_id, wf_incident_dto, wf_action = workflow_manager.events[0]
+        assert wf_tenant_id == SINGLE_TENANT_UUID
+        assert wf_incident_dto.id == incident_dto.id
+        assert wf_action == "created"
+
+        incident_dto_ai = incident_bl.create_incident(incident_dto_in, generated_from_ai=True)
+        assert isinstance(incident_dto_ai, IncidentDto)
+
+        incidents_count = db_session.query(Incident).count()
+        assert incidents_count == 2
+
+        assert incident_dto_ai.is_confirmed is True
+        assert incident_dto_ai.is_predicted is False
+
+
+def test_incident_bl_update_incident(db_session):
+    pusher = PusherMock()
+    workflow_manager = WorkflowManagerMock()
+
+    with patch("keep.api.bl.incidents_bl.WorkflowManager", workflow_manager):
+        incident_bl = IncidentBl(
+            tenant_id=SINGLE_TENANT_UUID,
+            session=db_session,
+            pusher_client=pusher
+        )
+        incident_dto_in = IncidentDtoIn(**{
+            "user_generated_name": "Incident name",
+            "user_summary": "Keep: Incident description",
+            "status": "firing",
+        })
+
+        incident_dto = incident_bl.create_incident(incident_dto_in)
+
+        incidents_count = db_session.query(Incident).count()
+        assert incidents_count == 1
+
+        new_incident_dto_in = IncidentDtoIn(**{
+            "user_generated_name": "Not an incident",
+            "user_summary": "Keep: Incident description",
+            "status": "firing",
+        })
+
+        incident_dto_update = incident_bl.update_incident(incident_dto.id, new_incident_dto_in, False)
+
+        incidents_count = db_session.query(Incident).count()
+        assert incidents_count == 1
+
+        assert incident_dto_update.name == "Not an incident"
+
+        incident = db_session.query(Incident).get(incident_dto.id)
+        assert incident.user_generated_name == "Not an incident"
+        assert incident.status == "firing"
+        assert incident.user_summary == "Keep: Incident description"
+
+        # Check error if no incident found
+        with pytest.raises(HTTPException, match="Incident not found"):
+            incident_bl.update_incident(uuid4(), incident_dto_update, False)
+
+        # Check workflow manager
+        assert len(workflow_manager.events) == 2
+        wf_tenant_id, wf_incident_dto, wf_action = workflow_manager.events[-1]
+        assert wf_tenant_id == SINGLE_TENANT_UUID
+        assert wf_incident_dto.id == incident_dto.id
+        assert wf_action == "updated"
+
+        # Check pusher
+        assert len(pusher.triggers) == 2  # 1 for create, 1 for update
+        channel, event_name, data = pusher.triggers[-1]
+        assert channel == f"private-{SINGLE_TENANT_UUID}"
+        assert event_name == "incident-change"
+        assert isinstance(data, dict)
+        assert "incident_id" in data
+        assert data["incident_id"] == str(incident_dto.id)
+
+
+def test_incident_bl_delete_incident(db_session):
+    pusher = PusherMock()
+    workflow_manager = WorkflowManagerMock()
+
+    with patch("keep.api.bl.incidents_bl.WorkflowManager", workflow_manager):
+        incident_bl = IncidentBl(
+            tenant_id=SINGLE_TENANT_UUID,
+            session=db_session,
+            pusher_client=pusher
+        )
+        # Check error if no incident found
+        with pytest.raises(HTTPException, match="Incident not found"):
+            incident_bl.delete_incident(uuid4())
+
+        incident_dto_in = IncidentDtoIn(**{
+            "user_generated_name": "Incident name",
+            "user_summary": "Keep: Incident description",
+            "status": "firing",
+        })
+
+        incident_dto = incident_bl.create_incident(incident_dto_in)
+
+        incidents_count = db_session.query(Incident).count()
+        assert incidents_count == 1
+
+        incident_bl.delete_incident(incident_dto.id)
+
+        incidents_count = db_session.query(Incident).count()
+        assert incidents_count == 0
+
+        # Check pusher
+        assert len(pusher.triggers) == 2  # Created, deleted
+
+        channel, event_name, data = pusher.triggers[-1]
+        assert channel == f"private-{SINGLE_TENANT_UUID}"
+        assert event_name == "incident-change"
+        assert isinstance(data, dict)
+        assert "incident_id" in data
+        assert data["incident_id"] is None
+
+        # Check workflow manager
+        assert len(workflow_manager.events) == 2  # Created, deleted
+        wf_tenant_id, wf_incident_dto, wf_action = workflow_manager.events[-1]
+        assert wf_tenant_id == SINGLE_TENANT_UUID
+        assert wf_incident_dto.id == incident_dto.id
+        assert wf_action == "deleted"
+
+
+@pytest.mark.asyncio
+async def test_incident_bl_add_alert_to_incident(db_session, create_alert):
+    pusher = PusherMock()
+    workflow_manager = WorkflowManagerMock()
+    elastic_client = ElasticClientMock()
+
+    with patch("keep.api.bl.incidents_bl.WorkflowManager", workflow_manager):
+        with patch("keep.api.bl.incidents_bl.ElasticClient", elastic_client):
+            incident_bl = IncidentBl(
+                tenant_id=SINGLE_TENANT_UUID,
+                session=db_session,
+                pusher_client=pusher
+            )
+            incident_dto_in = IncidentDtoIn(**{
+                "user_generated_name": "Incident name",
+                "user_summary": "Keep: Incident description",
+                "status": "firing",
+            })
+
+            incident_dto = incident_bl.create_incident(incident_dto_in)
+
+            incidents_count = db_session.query(Incident).count()
+            assert incidents_count == 1
+
+            with pytest.raises(HTTPException, match="Incident not found"):
+                await incident_bl.add_alerts_to_incident(uuid4(), [], False)
+
+            create_alert(
+                "alert-test-1",
+                AlertStatus("firing"),
+                datetime.utcnow(),
+                {},
+            )
+
+            await incident_bl.add_alerts_to_incident(
+                incident_dto.id,
+                ["alert-test-1"],
+                False
+            )
+
+            alerts_to_incident_count = (
+                db_session
+                .query(LastAlertToIncident)
+                .where(
+                    LastAlertToIncident.incident_id == incident_dto.id
+                )
+                .count()
+            )
+            assert alerts_to_incident_count == 1
+
+            alert_to_incident = (
+                db_session
+                .query(LastAlertToIncident)
+                .where(
+                    LastAlertToIncident.fingerprint == "alert-test-1"
+                )
+                .first()
+            )
+            assert alert_to_incident is not None
+
+            # Check pusher
+            assert len(pusher.triggers) == 2  # Created, update
+
+            channel, event_name, data = pusher.triggers[-1]
+            assert channel == f"private-{SINGLE_TENANT_UUID}"
+            assert event_name == "incident-change"
+            assert isinstance(data, dict)
+            assert "incident_id" in data
+            assert data["incident_id"] == str(incident_dto.id)
+
+            # Check workflow manager
+            assert len(workflow_manager.events) == 2  # Created, update
+            wf_tenant_id, wf_incident_dto, wf_action = workflow_manager.events[-1]
+            assert wf_tenant_id == SINGLE_TENANT_UUID
+            assert wf_incident_dto.id == incident_dto.id
+            assert wf_action == "updated"
+
+            # Check elastic
+            assert len(elastic_client.alerts) == 1
+            el_tenant_id, el_alerts = elastic_client.alerts[-1]
+            assert len(el_alerts) == 1
+            assert el_tenant_id == SINGLE_TENANT_UUID
+            assert el_alerts[-1].fingerprint == "alert-test-1"
+            assert el_alerts[-1].incident == str(incident_dto.id)
+
+
+@pytest.mark.asyncio
+async def test_incident_bl_delete_alerts_from_incident(db_session, create_alert):
+    pusher = PusherMock()
+    workflow_manager = WorkflowManagerMock()
+    elastic_client = ElasticClientMock()
+
+    with patch("keep.api.bl.incidents_bl.WorkflowManager", workflow_manager):
+        with patch("keep.api.bl.incidents_bl.ElasticClient", elastic_client):
+            incident_bl = IncidentBl(
+                tenant_id=SINGLE_TENANT_UUID,
+                session=db_session,
+                pusher_client=pusher
+            )
+            incident_dto_in = IncidentDtoIn(**{
+                "user_generated_name": "Incident name",
+                "user_summary": "Keep: Incident description",
+                "status": "firing",
+            })
+
+            incident_dto = incident_bl.create_incident(incident_dto_in)
+
+            incidents_count = db_session.query(Incident).count()
+            assert incidents_count == 1
+
+            with pytest.raises(HTTPException, match="Incident not found"):
+                incident_bl.delete_alerts_from_incident(uuid4(), [])
+
+            create_alert(
+                "alert-test-1",
+                AlertStatus("firing"),
+                datetime.utcnow(),
+                {},
+            )
+
+            await incident_bl.add_alerts_to_incident(
+                incident_dto.id,
+                ["alert-test-1"],
+                False
+            )
+
+            alerts_to_incident_count = (
+                db_session
+                .query(LastAlertToIncident)
+                .where(
+                    and_(
+                        LastAlertToIncident.incident_id == incident_dto.id,
+                        LastAlertToIncident.deleted_at == NULL_FOR_DELETED_AT
+                    )
+                )
+                .count()
+            )
+            assert alerts_to_incident_count == 1
+
+            incident_bl.delete_alerts_from_incident(
+                incident_dto.id,
+                ["alert-test-1"],
+            )
+
+            alerts_to_incident_count = (
+                db_session
+                .query(LastAlertToIncident)
+                .where(
+                    and_(
+                        LastAlertToIncident.incident_id == incident_dto.id,
+                        LastAlertToIncident.deleted_at == NULL_FOR_DELETED_AT
+                    )
+                )
+                .count()
+            )
+            assert alerts_to_incident_count == 0
+
+            # Check pusher
+            # Created, updated (added event), updated (deleted event)
+            assert len(pusher.triggers) == 3
+
+            channel, event_name, data = pusher.triggers[-1]
+            assert channel == f"private-{SINGLE_TENANT_UUID}"
+            assert event_name == "incident-change"
+            assert isinstance(data, dict)
+            assert "incident_id" in data
+            assert data["incident_id"] == str(incident_dto.id)
+
+            # Check workflow manager
+            # Created, updated (added event), updated (deleted event)
+            assert len(workflow_manager.events) == 3
+            wf_tenant_id, wf_incident_dto, wf_action = workflow_manager.events[-1]
+            assert wf_tenant_id == SINGLE_TENANT_UUID
+            assert wf_incident_dto.id == incident_dto.id
+            assert wf_action == "updated"
+
+            # Check elastic
+            assert len(elastic_client.alerts) == 2
+            el_tenant_id, el_alerts = elastic_client.alerts[-1]
+            assert len(el_alerts) == 1
+            assert el_tenant_id == SINGLE_TENANT_UUID
+            assert el_alerts[-1].fingerprint == "alert-test-1"
+            assert el_alerts[-1].incident is None
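
Note (illustrative, not part of the patch): the two coroutine tests above depend on the pytest-asyncio plugin added to pyproject.toml. In 0.21 its default mode is strict, so only coroutines explicitly marked with @pytest.mark.asyncio are collected as async tests, which is the pattern used here:

    import pytest

    @pytest.mark.asyncio
    async def test_example():
        # runs inside an event loop provided by pytest-asyncio
        assert 1 + 1 == 2
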