Review notes (free text before the first "diff --git" header is skipped by patch tools):
  * process_incident: trigger Pusher only when a client is configured; previously
    `if not pusher_client: pass` fell through to `None.trigger(...)`.
  * PagerdutyProvider._query: `next()` was given a list (TypeError); it now gets a
    generator, and raw incidents are matched with `.get("id")`.
  * PagerdutyProvider._get_incident_id: return annotation/docstring said `str`,
    the function returns `uuid.UUID`.
  * NOTE(review): __handle_formatted_events keeps `if not pusher_client: pass`; its
    try/except is outside the visible hunk, so it is left untouched here - confirm.
  * NOTE(review): line breaks and indentation were reconstructed from a
    whitespace-collapsed copy of this patch. Context-line columns are assumptions;
    apply with `git apply --ignore-whitespace` and verify. Post-image blob ids on
    the `index` lines predate the review fixes.

diff --git a/keep/api/core/db.py b/keep/api/core/db.py
index 25c447472..56f5aaaef 100644
--- a/keep/api/core/db.py
+++ b/keep/api/core/db.py
@@ -3524,6 +3524,7 @@ def add_alerts_to_incident(
     alert_ids: List[UUID],
     is_created_by_ai: bool = False,
     session: Optional[Session] = None,
+    override_count: bool = False,
 ) -> Optional[Incident]:
     logger.info(
         f"Adding alerts to incident {incident.id} in database, total {len(alert_ids)} alerts",
@@ -3585,8 +3586,10 @@ def add_alerts_to_incident(
             if incident.alerts_count
             else alerts_data_for_incident["max_severity"].order
         )
-        incident.alerts_count += alerts_data_for_incident["count"]
-
+        if not override_count:
+            incident.alerts_count += alerts_data_for_incident["count"]
+        else:
+            incident.alerts_count = alerts_data_for_incident["count"]
         alert_to_incident_entries = [
             AlertToIncident(
                 alert_id=alert_id,
diff --git a/keep/api/tasks/process_event_task.py b/keep/api/tasks/process_event_task.py
index 167b6932c..f4b35ea91 100644
--- a/keep/api/tasks/process_event_task.py
+++ b/keep/api/tasks/process_event_task.py
@@ -414,7 +414,7 @@ def __handle_formatted_events(
     if notify_client and incidents:
         pusher_client = get_pusher_client()
         if not pusher_client:
-            return
+            pass
         try:
             pusher_client.trigger(
                 f"private-{tenant_id}",
@@ -483,6 +483,7 @@ def __handle_formatted_events(
             "tenant_id": tenant_id,
         },
     )
+    return enriched_formatted_events
 
 
 def process_event(
@@ -498,7 +499,7 @@ def process_event(
     ),  # the event to process, either plain (generic) or from a specific provider
     notify_client: bool = True,
     timestamp_forced: datetime.datetime | None = None,
-):
+) -> list[Alert]:
     extra_dict = {
         "tenant_id": tenant_id,
         "provider_type": provider_type,
@@ -557,7 +558,7 @@ def process_event(
             raw_event = [raw_event]
 
         __internal_prepartion(event, fingerprint, api_key_name)
-        __handle_formatted_events(
+        return __handle_formatted_events(
             tenant_id,
             provider_type,
             session,
diff --git a/keep/api/tasks/process_incident_task.py b/keep/api/tasks/process_incident_task.py
index 688dfbae9..e189526b0 100644
--- a/keep/api/tasks/process_incident_task.py
+++ b/keep/api/tasks/process_incident_task.py
@@ -3,12 +3,15 @@
 from arq import Retry
 
 from keep.api.core.db import (
+    add_alerts_to_incident,
     create_incident_from_dto,
     get_incident_by_fingerprint,
     get_incident_by_id,
     update_incident_from_dto_by_id,
 )
+from keep.api.core.dependencies import get_pusher_client
 from keep.api.models.alert import IncidentDto
+from keep.api.tasks.process_event_task import process_event
 
 TIMES_TO_RETRY_JOB = 5  # the number of times to retry the job in case of failure
 logger = logging.getLogger(__name__)
@@ -60,7 +63,7 @@ def process_incident(
                     f"Updating incident: {incident.id}",
                     extra={**extra, "fingerprint": incident.fingerprint},
                 )
-                update_incident_from_dto_by_id(
+                incident_from_db = update_incident_from_dto_by_id(
                     tenant_id=tenant_id,
                     incident_id=incident_from_db.id,
                     updated_incident_dto=incident,
@@ -74,7 +77,7 @@ def process_incident(
                     f"Creating incident: {incident.id}",
                     extra={**extra, "fingerprint": incident.fingerprint},
                 )
-                create_incident_from_dto(
+                incident_from_db = create_incident_from_dto(
                     tenant_id=tenant_id,
                     incident_dto=incident,
                 )
@@ -82,7 +85,54 @@ def process_incident(
                     f"Created incident: {incident.id}",
                     extra={**extra, "fingerprint": incident.fingerprint},
                 )
+
+            try:
+                if incident.alerts:
+                    logger.info("Adding incident alerts", extra=extra)
+                    processed_alerts = process_event(
+                        {},
+                        tenant_id,
+                        provider_type,
+                        provider_id,
+                        None,
+                        None,
+                        trace_id,
+                        incident.alerts,
+                    )
+                    if processed_alerts:
+                        add_alerts_to_incident(
+                            tenant_id,
+                            incident_from_db,
+                            [
+                                processed_alert.event_id
+                                for processed_alert in processed_alerts
+                            ],
+                            # Because the incident was created with the alerts count, we need to override it
+                            # otherwise it will be the sum of the previous count + the newly attached alerts count
+                            override_count=True,
+                        )
+                        logger.info("Added incident alerts", extra=extra)
+                    else:
+                        logger.info(
+                            "No alerts to add to incident, probably deduplicated",
+                            extra=extra,
+                        )
+            except Exception:
+                logger.exception("Error adding incident alerts", extra=extra)
             logger.info("Processed incident", extra=extra)
+
+        pusher_client = get_pusher_client()
+        # Pusher is optional: skip the notification when no client is configured.
+        if pusher_client:
+            try:
+                pusher_client.trigger(
+                    f"private-{tenant_id}",
+                    "incident-change",
+                    {},
+                )
+            except Exception:
+                logger.exception("Failed to push incidents to the client")
+
         logger.info("Processed all incidents", extra=extra)
     except Exception:
         logger.exception(
diff --git a/keep/providers/pagerduty_provider/pagerduty_provider.py b/keep/providers/pagerduty_provider/pagerduty_provider.py
index e45793367..e657cdde5 100644
--- a/keep/providers/pagerduty_provider/pagerduty_provider.py
+++ b/keep/providers/pagerduty_provider/pagerduty_provider.py
@@ -11,6 +11,7 @@ import requests
 
 
 from keep.api.models.alert import (
+    AlertDto,
     AlertSeverity,
     AlertStatus,
     IncidentDto,
@@ -490,7 +491,41 @@ def _notify(
         )
 
     def _query(self, incident_id: str = None):
-        return self.__get_all_incidents_or_alerts(incident_id=incident_id)
+        """Return all incidents, or the one matching `incident_id` (None if absent)."""
+        incidents = self.__get_all_incidents_or_alerts()
+        return (
+            next(
+                # next() needs an iterator; raw incidents are read with .get() elsewhere.
+                (i for i in incidents if i.get("id") == incident_id),
+                None,
+            )
+            if incident_id
+            else incidents
+        )
+
+    def _format_alert(
+        self, event: dict, provider_instance: "BaseProvider" = None
+    ) -> AlertDto:
+        """Convert a raw PagerDuty alert payload into an AlertDto."""
+        status = self.ALERT_STATUS_MAP.get(event.get("status", "firing"))
+        severity = self.ALERT_SEVERITIES_MAP.get(event.get("severity", "info"))
+        source = ["pagerduty"]
+        origin = event.get("body", {}).get("cef_details", {}).get("source_origin")
+        fingerprint = event.get("alert_key", event.get("id"))
+        if origin:
+            source.append(origin)
+        return AlertDto(
+            id=event.get("id"),
+            name=event.get("summary"),
+            url=event.get("html_url"),
+            service=event.get("service", {}).get("name"),
+            lastReceived=event.get("created_at"),
+            status=status,
+            severity=severity,
+            source=source,
+            original_alert=event,
+            fingerprint=fingerprint,
+        )
 
     def __get_all_incidents_or_alerts(self, incident_id: str = None):
         self.logger.info(
@@ -504,7 +539,7 @@ def __get_all_incidents_or_alerts(self, incident_id: str = None):
         include = []
         resource = "incidents"
         if incident_id is not None:
-            url += f"/{incident_id}"
+            url += f"/{incident_id}/alerts"
             include = ["teams", "services"]
             resource = "alerts"
         response = requests.get(
@@ -609,26 +644,44 @@ def pull_topology(self) -> list[TopologyServiceInDto]:
         return list(service_topology.values())
 
     def _get_incidents(self) -> list[IncidentDto]:
-        incidents = self.__get_all_incidents_or_alerts()
-        incidents = [
-            self._format_incident({"event": {"data": incident}})
-            for incident in incidents
-        ]
+        raw_incidents = self.__get_all_incidents_or_alerts()
+        incidents = []
+        for incident in raw_incidents:
+            incident_dto = self._format_incident({"event": {"data": incident}})
+            incident_alerts = self.__get_all_incidents_or_alerts(
+                incident_id=incident_dto.fingerprint
+            )
+            incident_alerts = [self._format_alert(alert) for alert in incident_alerts]
+            incident_dto._alerts = incident_alerts
+            incidents.append(incident_dto)
         return incidents
 
+    @staticmethod
+    def _get_incident_id(incident_id: str) -> uuid.UUID:
+        """
+        Create a UUID from the incident id.
+
+        Args:
+            incident_id (str): The original incident id
+
+        Returns:
+            uuid.UUID: The UUID built from the MD5 hex digest of the id
+        """
+        md5 = hashlib.md5()
+        md5.update(incident_id.encode("utf-8"))
+        return uuid.UUID(md5.hexdigest())
+
     @staticmethod
     def _format_incident(
         event: dict, provider_instance: "BaseProvider" = None
     ) -> IncidentDto | list[IncidentDto]:
-        # Creating an uuid from incident id.
-        m = hashlib.md5()
 
         event = event["event"]["data"]
 
         # This will be the same for the same incident
-        event_id = event.get("id", "ping")
-        m.update(event_id.encode("utf-8"))
-        incident_id = uuid.UUID(m.hexdigest())
+        original_incident_id = event.get("id", "ping")
+
+        incident_id = PagerdutyProvider._get_incident_id(original_incident_id)
 
         status = PagerdutyProvider.INCIDENT_STATUS_MAP.get(
             event.get("status", "firing"), IncidentStatus.FIRING
@@ -648,7 +701,7 @@ def _format_incident(
         return IncidentDto(
             id=incident_id,
             creation_time=created_at,
-            user_generated_name=f'PD-{event.get("title", "unknown")}-{event_id}',
+            user_generated_name=f'PD-{event.get("title", "unknown")}-{original_incident_id}',
             status=status,
             severity=severity,
             alert_sources=["pagerduty"],
@@ -657,7 +710,7 @@ def _format_incident(
             is_predicted=False,
             is_confirmed=True,
             # This is the reference to the incident in PagerDuty
-            fingerprint=event_id,
+            fingerprint=original_incident_id,
         )
 
 