From 01d29ecb89505fe66fc2575d9f44c154d0c8d953 Mon Sep 17 00:00:00 2001 From: Vladimir Filonov Date: Mon, 2 Dec 2024 19:59:42 +0400 Subject: [PATCH] feat: change relation between alerts and incidents to work with fingerprints instead of alert ids (#2473) Signed-off-by: Vladimir Filonov Co-authored-by: Kirill Chernakov --- .../alerts/alert-associate-incident-modal.tsx | 2 +- .../[id]/alerts/incident-alert-menu.tsx | 2 +- keep-ui/entities/incidents/model/models.ts | 1 + keep/api/bl/incidents_bl.py | 26 +- keep/api/core/db.py | 423 ++++++++++-------- keep/api/models/db/alert.py | 85 ++-- .../versions/2024-12-01-16-40_3ad5308e7200.py | 7 +- .../versions/2024-12-02-13-36_bdae8684d0b4.py | 161 +++++++ keep/api/routes/incidents.py | 10 +- keep/api/routes/workflows.py | 3 +- keep/api/tasks/process_event_task.py | 4 + keep/api/utils/enrichment_helpers.py | 89 ++-- keep/rulesengine/rulesengine.py | 4 +- tests/conftest.py | 28 +- tests/test_incidents.py | 183 +++++--- tests/test_metrics.py | 2 +- tests/test_rules_engine.py | 28 +- 17 files changed, 686 insertions(+), 372 deletions(-) create mode 100644 keep/api/models/db/migrations/versions/2024-12-02-13-36_bdae8684d0b4.py diff --git a/keep-ui/app/(keep)/alerts/alert-associate-incident-modal.tsx b/keep-ui/app/(keep)/alerts/alert-associate-incident-modal.tsx index e8b85f1ee..fb1d8a8bc 100644 --- a/keep-ui/app/(keep)/alerts/alert-associate-incident-modal.tsx +++ b/keep-ui/app/(keep)/alerts/alert-associate-incident-modal.tsx @@ -42,7 +42,7 @@ const AlertAssociateIncidentModal = ({ try { const response = await api.post( `/incidents/${incidentId}/alerts`, - alerts.map(({ event_id }) => event_id) + alerts.map(({ fingerprint }) => fingerprint) ); handleSuccess(); await mutate(); diff --git a/keep-ui/app/(keep)/incidents/[id]/alerts/incident-alert-menu.tsx b/keep-ui/app/(keep)/incidents/[id]/alerts/incident-alert-menu.tsx index fb3ab512a..5ef346b8c 100644 --- a/keep-ui/app/(keep)/incidents/[id]/alerts/incident-alert-menu.tsx +++ b/keep-ui/app/(keep)/incidents/[id]/alerts/incident-alert-menu.tsx @@ -18,7 +18,7 @@ export default function IncidentAlertMenu({ incidentId, alert }: Props) { if (confirm("Are you sure you want to remove correlation?")) { api .delete(`/incidents/${incidentId}/alerts`, { - body: [alert.event_id], + body: [alert.fingerprint], }) .then(() => { toast.success("Alert removed from incident successfully", { diff --git a/keep-ui/entities/incidents/model/models.ts b/keep-ui/entities/incidents/model/models.ts index 0f8e8cfb2..bdc311549 100644 --- a/keep-ui/entities/incidents/model/models.ts +++ b/keep-ui/entities/incidents/model/models.ts @@ -31,6 +31,7 @@ export interface IncidentDto { merged_into_incident_id: string; merged_by: string; merged_at: Date; + fingerprint: string; } export interface IncidentCandidateDto { diff --git a/keep/api/bl/incidents_bl.py b/keep/api/bl/incidents_bl.py index a0bd58584..5bc3d9216 100644 --- a/keep/api/bl/incidents_bl.py +++ b/keep/api/bl/incidents_bl.py @@ -93,51 +93,51 @@ def create_incident( return new_incident_dto async def add_alerts_to_incident( - self, incident_id: UUID, alert_ids: List[UUID], is_created_by_ai: bool = False + self, incident_id: UUID, alert_fingerprints: List[str], is_created_by_ai: bool = False ) -> None: self.logger.info( "Adding alerts to incident", - extra={"incident_id": incident_id, "alert_ids": alert_ids}, + extra={"incident_id": incident_id, "alert_fingerprints": alert_fingerprints}, ) incident = get_incident_by_id(tenant_id=self.tenant_id, incident_id=incident_id) if not incident: 
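            # no incident with this id for this tenant; fail fast with a 404 before linking any alerts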
raise HTTPException(status_code=404, detail="Incident not found") - add_alerts_to_incident_by_incident_id(self.tenant_id, incident_id, alert_ids, is_created_by_ai) + add_alerts_to_incident_by_incident_id(self.tenant_id, incident_id, alert_fingerprints, is_created_by_ai) self.logger.info( "Alerts added to incident", - extra={"incident_id": incident_id, "alert_ids": alert_ids}, + extra={"incident_id": incident_id, "alert_fingerprints": alert_fingerprints}, ) - self.__update_elastic(incident_id, alert_ids) + self.__update_elastic(incident_id, alert_fingerprints) self.logger.info( "Alerts pushed to elastic", - extra={"incident_id": incident_id, "alert_ids": alert_ids}, + extra={"incident_id": incident_id, "alert_fingerprints": alert_fingerprints}, ) self.__update_client_on_incident_change(incident_id) self.logger.info( "Client updated on incident change", - extra={"incident_id": incident_id, "alert_ids": alert_ids}, + extra={"incident_id": incident_id, "alert_fingerprints": alert_fingerprints}, ) incident_dto = IncidentDto.from_db_incident(incident) self.__run_workflows(incident_dto, "updated") self.logger.info( "Workflows run on incident", - extra={"incident_id": incident_id, "alert_ids": alert_ids}, + extra={"incident_id": incident_id, "alert_fingerprints": alert_fingerprints}, ) await self.__generate_summary(incident_id, incident) self.logger.info( "Summary generated", - extra={"incident_id": incident_id, "alert_ids": alert_ids}, + extra={"incident_id": incident_id, "alert_fingerprints": alert_fingerprints}, ) - def __update_elastic(self, incident_id: UUID, alert_ids: List[UUID]): + def __update_elastic(self, incident_id: UUID, alert_fingerprints: List[str]): try: elastic_client = ElasticClient(self.tenant_id) if elastic_client.enabled: db_alerts, _ = get_incident_alerts_by_incident_id( tenant_id=self.tenant_id, incident_id=incident_id, - limit=len(alert_ids), + limit=len(alert_fingerprints), ) enriched_alerts_dto = convert_db_alerts_to_dto_alerts( db_alerts, with_incidents=True @@ -203,7 +203,7 @@ async def __generate_summary(self, incident_id: UUID, incident: Incident): ) def delete_alerts_from_incident( - self, incident_id: UUID, alert_ids: List[UUID] + self, incident_id: UUID, alert_fingerprints: List[str] ) -> None: self.logger.info( "Fetching incident", @@ -216,7 +216,7 @@ def delete_alerts_from_incident( if not incident: raise HTTPException(status_code=404, detail="Incident not found") - remove_alerts_to_incident_by_incident_id(self.tenant_id, incident_id, alert_ids) + remove_alerts_to_incident_by_incident_id(self.tenant_id, incident_id, alert_fingerprints) def delete_incident(self, incident_id: UUID) -> None: self.logger.info( diff --git a/keep/api/core/db.py b/keep/api/core/db.py index cc9294f26..c600c1214 100644 --- a/keep/api/core/db.py +++ b/keep/api/core/db.py @@ -16,6 +16,7 @@ from uuid import uuid4 import validators +from dateutil.tz import tz from dotenv import find_dotenv, load_dotenv from opentelemetry.instrumentation.sqlalchemy import SQLAlchemyInstrumentor from sqlalchemy import ( @@ -44,8 +45,8 @@ # This import is required to create the tables from keep.api.models.ai_external import ( - ExternalAIConfigAndMetadata, - ExternalAIConfigAndMetadataDto, + ExternalAIConfigAndMetadata, + ExternalAIConfigAndMetadataDto, ) from keep.api.models.alert import ( AlertStatus, @@ -1302,13 +1303,13 @@ def get_last_alerts( # SQLite version - using JSON incidents_subquery = ( session.query( - AlertToIncident.alert_id, + LastAlertToIncident.fingerprint, func.json_group_array( - 
cast(AlertToIncident.incident_id, String) + cast(LastAlertToIncident.incident_id, String) ).label("incidents"), ) - .filter(AlertToIncident.deleted_at == NULL_FOR_DELETED_AT) - .group_by(AlertToIncident.alert_id) + .filter(LastAlertToIncident.deleted_at == NULL_FOR_DELETED_AT) + .group_by(LastAlertToIncident.fingerprint) .subquery() ) @@ -1316,13 +1317,13 @@ def get_last_alerts( # MySQL version - using GROUP_CONCAT incidents_subquery = ( session.query( - AlertToIncident.alert_id, + LastAlertToIncident.fingerprint, func.group_concat( - cast(AlertToIncident.incident_id, String) + cast(LastAlertToIncident.incident_id, String) ).label("incidents"), ) - .filter(AlertToIncident.deleted_at == NULL_FOR_DELETED_AT) - .group_by(AlertToIncident.alert_id) + .filter(LastAlertToIncident.deleted_at == NULL_FOR_DELETED_AT) + .group_by(LastAlertToIncident.fingerprint) .subquery() ) @@ -1330,14 +1331,14 @@ def get_last_alerts( # PostgreSQL version - using string_agg incidents_subquery = ( session.query( - AlertToIncident.alert_id, + LastAlertToIncident.fingerprint, func.string_agg( - cast(AlertToIncident.incident_id, String), + cast(LastAlertToIncident.incident_id, String), ",", ).label("incidents"), ) - .filter(AlertToIncident.deleted_at == NULL_FOR_DELETED_AT) - .group_by(AlertToIncident.alert_id) + .filter(LastAlertToIncident.deleted_at == NULL_FOR_DELETED_AT) + .group_by(LastAlertToIncident.fingerprint) .subquery() ) else: @@ -1345,7 +1346,7 @@ def get_last_alerts( query = query.add_columns(incidents_subquery.c.incidents) query = query.outerjoin( - incidents_subquery, Alert.id == incidents_subquery.c.alert_id + incidents_subquery, Alert.fingerprint == incidents_subquery.c.fingerprint ) if provider_id: @@ -1741,9 +1742,10 @@ def get_incident_for_grouping_rule( # if the last alert in the incident is older than the timeframe, create a new incident is_incident_expired = False - if incident and incident.alerts: + if incident and incident.alerts_count > 0: + enrich_incidents_with_alerts(tenant_id, [incident], session) is_incident_expired = max( - alert.timestamp for alert in incident.alerts + alert.timestamp for alert in incident._alerts ) < datetime.utcnow() - timedelta(seconds=timeframe) # if there is no incident with the rule_fingerprint, create it or existed is already expired @@ -1792,13 +1794,13 @@ def get_rule_distribution(tenant_id, minute=False): # Check the dialect if session.bind.dialect.name == "mysql": time_format = "%Y-%m-%d %H:%i" if minute else "%Y-%m-%d %H" - timestamp_format = func.date_format(AlertToIncident.timestamp, time_format) + timestamp_format = func.date_format(LastAlertToIncident.timestamp, time_format) elif session.bind.dialect.name == "postgresql": time_format = "YYYY-MM-DD HH:MI" if minute else "YYYY-MM-DD HH" - timestamp_format = func.to_char(AlertToIncident.timestamp, time_format) + timestamp_format = func.to_char(LastAlertToIncident.timestamp, time_format) elif session.bind.dialect.name == "sqlite": time_format = "%Y-%m-%d %H:%M" if minute else "%Y-%m-%d %H" - timestamp_format = func.strftime(time_format, AlertToIncident.timestamp) + timestamp_format = func.strftime(time_format, LastAlertToIncident.timestamp) else: raise ValueError("Unsupported database dialect") # Construct the query @@ -1806,20 +1808,20 @@ def get_rule_distribution(tenant_id, minute=False): session.query( Rule.id.label("rule_id"), Rule.name.label("rule_name"), - Incident.id.label("group_id"), + Incident.id.label("incident_id"), Incident.rule_fingerprint.label("rule_fingerprint"), 
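            # one row per rule/incident/time bucket; "hits" counts the fingerprints linked in that bucket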
timestamp_format.label("time"), - func.count(AlertToIncident.alert_id).label("hits"), + func.count(LastAlertToIncident.fingerprint).label("hits"), ) .join(Incident, Rule.id == Incident.rule_id) - .join(AlertToIncident, Incident.id == AlertToIncident.incident_id) + .join(LastAlertToIncident, Incident.id == LastAlertToIncident.incident_id) .filter( - AlertToIncident.deleted_at == NULL_FOR_DELETED_AT, - AlertToIncident.timestamp >= seven_days_ago, + LastAlertToIncident.deleted_at == NULL_FOR_DELETED_AT, + LastAlertToIncident.timestamp >= seven_days_ago, ) .filter(Rule.tenant_id == tenant_id) # Filter by tenant_id .group_by( - "rule_id", "rule_name", "incident_id", "rule_fingerprint", "time" + Rule.id, "rule_name", Incident.id, "rule_fingerprint", "time" ) # Adjusted here .order_by("time") ) @@ -2824,24 +2826,24 @@ def update_preset_options(tenant_id: str, preset_id: str, options: dict) -> Pres def assign_alert_to_incident( - alert_id: UUID | str, + fingerprint: str, incident: Incident, tenant_id: str, session: Optional[Session] = None, ): - return add_alerts_to_incident(tenant_id, incident, [alert_id], session=session) + return add_alerts_to_incident(tenant_id, incident, [fingerprint], session=session) def is_alert_assigned_to_incident( - alert_id: UUID, incident_id: UUID, tenant_id: str + fingerprint: str, incident_id: UUID, tenant_id: str ) -> bool: with Session(engine) as session: assigned = session.exec( - select(AlertToIncident) - .where(AlertToIncident.alert_id == alert_id) - .where(AlertToIncident.incident_id == incident_id) - .where(AlertToIncident.tenant_id == tenant_id) - .where(AlertToIncident.deleted_at == NULL_FOR_DELETED_AT) + select(LastAlertToIncident) + .where(LastAlertToIncident.fingerprint == fingerprint) + .where(LastAlertToIncident.incident_id == incident_id) + .where(LastAlertToIncident.tenant_id == tenant_id) + .where(LastAlertToIncident.deleted_at == NULL_FOR_DELETED_AT) ).first() return assigned is not None @@ -3106,6 +3108,32 @@ def filter_query(session: Session, query, field, value): return query +def enrich_incidents_with_alerts(tenant_id: str, incidents: List[Incident], session: Optional[Session]=None): + with existed_or_new_session(session) as session: + incident_alerts = session.exec( + select(LastAlertToIncident.incident_id, Alert) + .select_from(LastAlert) + .join(LastAlertToIncident, and_( + LastAlertToIncident.fingerprint == LastAlert.fingerprint, + LastAlertToIncident.deleted_at == NULL_FOR_DELETED_AT, + )) + .join(Alert, LastAlert.alert_id == Alert.id) + .where( + LastAlert.tenant_id == tenant_id, + LastAlertToIncident.incident_id.in_([incident.id for incident in incidents]) + ) + ).all() + + alerts_per_incident = defaultdict(list) + for incident_id, alert in incident_alerts: + alerts_per_incident[incident_id].append(alert) + + for incident in incidents: + incident._alerts = alerts_per_incident[incident.id] + + return incidents + + def get_last_incidents( tenant_id: str, limit: int = 25, @@ -3147,9 +3175,6 @@ def get_last_incidents( if allowed_incident_ids: query = query.filter(Incident.id.in_(allowed_incident_ids)) - if with_alerts: - query = query.options(joinedload(Incident.alerts)) - if is_predicted is not None: query = query.filter(Incident.is_predicted == is_predicted) @@ -3181,23 +3206,30 @@ def get_last_incidents( # Execute the query incidents = query.all() + if with_alerts: + enrich_incidents_with_alerts(tenant_id, incidents, session) + return incidents, total_count def get_incident_by_id( - tenant_id: str, incident_id: str | UUID, with_alerts: 
bool = False + tenant_id: str, incident_id: str | UUID, with_alerts: bool = False, + session: Optional[Session] = None, ) -> Optional[Incident]: - with Session(engine) as session: + with existed_or_new_session(session) as session: query = session.query( Incident, ).filter( Incident.tenant_id == tenant_id, Incident.id == incident_id, ) + incident = query.first() if with_alerts: - query = query.options(joinedload(Incident.alerts)) + enrich_incidents_with_alerts( + tenant_id, [incident], session, + ) - return query.first() + return incident def create_incident_from_dto( @@ -3254,7 +3286,6 @@ def create_incident_from_dict( session.add(new_incident) session.commit() session.refresh(new_incident) - new_incident.alerts = [] return new_incident @@ -3271,7 +3302,6 @@ def update_incident_from_dto_by_id( Incident.tenant_id == tenant_id, Incident.id == incident_id, ) - .options(joinedload(Incident.alerts)) ).first() if not incident: @@ -3330,10 +3360,10 @@ def delete_incident_by_id( # Delete all associations with alerts: ( - session.query(AlertToIncident) + session.query(LastAlertToIncident) .where( - AlertToIncident.tenant_id == tenant_id, - AlertToIncident.incident_id == incident.id, + LastAlertToIncident.tenant_id == tenant_id, + LastAlertToIncident.incident_id == incident.id, ) .delete() ) @@ -3363,46 +3393,27 @@ def get_incident_alerts_and_links_by_incident_id( offset: Optional[int] = 0, session: Optional[Session] = None, include_unlinked: bool = False, -) -> tuple[List[tuple[Alert, AlertToIncident]], int]: +) -> tuple[List[tuple[Alert, LastAlertToIncident]], int]: with existed_or_new_session(session) as session: - last_fingerprints_subquery = ( - session.query( - Alert.fingerprint, func.max(Alert.timestamp).label("max_timestamp") - ) - .join(AlertToIncident, AlertToIncident.alert_id == Alert.id) - .filter( - AlertToIncident.tenant_id == tenant_id, - AlertToIncident.incident_id == incident_id, - ) - .group_by(Alert.fingerprint) - .subquery() - ) - query = ( session.query( Alert, - AlertToIncident, - ) - .select_from(last_fingerprints_subquery) - .outerjoin( - Alert, - and_( - last_fingerprints_subquery.c.fingerprint == Alert.fingerprint, - last_fingerprints_subquery.c.max_timestamp == Alert.timestamp, - ), + LastAlertToIncident, ) - .join(AlertToIncident, AlertToIncident.alert_id == Alert.id) + .select_from(LastAlertToIncident) + .join(LastAlert, LastAlert.fingerprint == LastAlertToIncident.fingerprint) + .join(Alert, LastAlert.alert_id == Alert.id) .filter( - AlertToIncident.tenant_id == tenant_id, - AlertToIncident.incident_id == incident_id, + LastAlertToIncident.tenant_id == tenant_id, + LastAlertToIncident.incident_id == incident_id, ) - .order_by(col(Alert.timestamp).desc()) + .order_by(col(LastAlert.timestamp).desc()) .options(joinedload(Alert.alert_enrichment)) ) if not include_unlinked: query = query.filter( - AlertToIncident.deleted_at == NULL_FOR_DELETED_AT, + LastAlertToIncident.deleted_at == NULL_FOR_DELETED_AT, ) total_count = query.count() @@ -3415,7 +3426,7 @@ def get_incident_alerts_and_links_by_incident_id( def get_incident_alerts_by_incident_id(*args, **kwargs) -> tuple[List[Alert], int]: """ - Unpacking (List[(Alert, AlertToIncident)], int) to (List[Alert], int). + Unpacking (List[(Alert, LastAlertToIncident)], int) to (List[Alert], int). 
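    Accepts the same arguments as get_incident_alerts_and_links_by_incident_id.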
""" alerts_and_links, total_alerts = get_incident_alerts_and_links_by_incident_id( *args, **kwargs @@ -3464,8 +3475,7 @@ def get_all_same_alert_ids( def get_alerts_data_for_incident( tenant_id: str, - alert_ids: List[str | UUID], - existed_fingerprints: Optional[List[str]] = None, + fingerprints: Optional[List[str]] = None, session: Optional[Session] = None, ) -> dict: """ @@ -3479,8 +3489,6 @@ def get_alerts_data_for_incident( Returns: dict {sources: list[str], services: list[str], count: int} """ - existed_fingerprints = existed_fingerprints or [] - with existed_or_new_session(session) as session: fields = ( @@ -3491,16 +3499,18 @@ def get_alerts_data_for_incident( ) alerts_data = session.exec( - select(*fields).where( - Alert.tenant_id == tenant_id, - col(Alert.id).in_(alert_ids), + select(*fields) + .select_from(LastAlert) + .join(Alert, LastAlert.alert_id == Alert.id) + .where( + LastAlert.tenant_id == tenant_id, + col(LastAlert.fingerprint).in_(fingerprints), ) ).all() sources = [] services = [] severities = [] - fingerprints = set() for service, source, fingerprint, severity in alerts_data: if source: @@ -3512,21 +3522,19 @@ def get_alerts_data_for_incident( severities.append(IncidentSeverity.from_number(severity)) else: severities.append(IncidentSeverity(severity)) - if fingerprint and fingerprint not in existed_fingerprints: - fingerprints.add(fingerprint) return { "sources": set(sources), "services": set(services), "max_severity": max(severities), - "count": len(fingerprints), + "count": len(alerts_data), } def add_alerts_to_incident_by_incident_id( tenant_id: str, incident_id: str | UUID, - alert_ids: List[UUID], + fingerprints: List[str], is_created_by_ai: bool = False, session: Optional[Session] = None, ) -> Optional[Incident]: @@ -3540,62 +3548,52 @@ def add_alerts_to_incident_by_incident_id( if not incident: return None return add_alerts_to_incident( - tenant_id, incident, alert_ids, is_created_by_ai, session + tenant_id, incident, fingerprints, is_created_by_ai, session ) def add_alerts_to_incident( tenant_id: str, incident: Incident, - alert_ids: List[UUID], + fingerprints: List[str], is_created_by_ai: bool = False, session: Optional[Session] = None, override_count: bool = False, ) -> Optional[Incident]: logger.info( - f"Adding alerts to incident {incident.id} in database, total {len(alert_ids)} alerts", + f"Adding alerts to incident {incident.id} in database, total {len(fingerprints)} alerts", extra={"tags": {"tenant_id": tenant_id, "incident_id": incident.id}}, ) with existed_or_new_session(session) as session: with session.no_autoflush: - all_alert_ids = get_all_same_alert_ids(tenant_id, alert_ids, session) # Use a set for faster membership checks - existing_alert_ids = set( - session.exec( - select(AlertToIncident.alert_id).where( - AlertToIncident.deleted_at == NULL_FOR_DELETED_AT, - AlertToIncident.tenant_id == tenant_id, - AlertToIncident.incident_id == incident.id, - col(AlertToIncident.alert_id).in_(all_alert_ids), - ) - ).all() - ) + existing_fingerprints = set( session.exec( - select(Alert.fingerprint) - .join(AlertToIncident, AlertToIncident.alert_id == Alert.id) + select(LastAlert.fingerprint) + .join(LastAlertToIncident, LastAlertToIncident.fingerprint == LastAlert.fingerprint) .where( - AlertToIncident.deleted_at == NULL_FOR_DELETED_AT, - AlertToIncident.tenant_id == tenant_id, - AlertToIncident.incident_id == incident.id, + LastAlertToIncident.deleted_at == NULL_FOR_DELETED_AT, + LastAlertToIncident.tenant_id == tenant_id, + LastAlertToIncident.incident_id 
== incident.id, ) ).all() ) - new_alert_ids = [ - alert_id - for alert_id in all_alert_ids - if alert_id not in existing_alert_ids - ] + new_fingerprints = { + fingerprint + for fingerprint in fingerprints + if fingerprint not in existing_fingerprints + } - if not new_alert_ids: + if not new_fingerprints: return incident alerts_data_for_incident = get_alerts_data_for_incident( - tenant_id, new_alert_ids, existing_fingerprints, session + tenant_id, new_fingerprints, session ) incident.sources = list( @@ -3617,13 +3615,13 @@ def add_alerts_to_incident( else: incident.alerts_count = alerts_data_for_incident["count"] alert_to_incident_entries = [ - AlertToIncident( - alert_id=alert_id, + LastAlertToIncident( + fingerprint=fingerprint, incident_id=incident.id, tenant_id=tenant_id, is_created_by_ai=is_created_by_ai, ) - for alert_id in new_alert_ids + for fingerprint in new_fingerprints ] for idx, entry in enumerate(alert_to_incident_entries): @@ -3640,11 +3638,11 @@ def add_alerts_to_incident( started_at, last_seen_at = session.exec( select(func.min(Alert.timestamp), func.max(Alert.timestamp)) - .join(AlertToIncident, AlertToIncident.alert_id == Alert.id) + .join(LastAlertToIncident, LastAlertToIncident.fingerprint == Alert.fingerprint) .where( - AlertToIncident.deleted_at == NULL_FOR_DELETED_AT, - AlertToIncident.tenant_id == tenant_id, - AlertToIncident.incident_id == incident.id, + LastAlertToIncident.deleted_at == NULL_FOR_DELETED_AT, + LastAlertToIncident.tenant_id == tenant_id, + LastAlertToIncident.incident_id == incident.id, ) ).one() @@ -3661,12 +3659,11 @@ def get_incident_unique_fingerprint_count(tenant_id: str, incident_id: str) -> i with Session(engine) as session: return session.execute( select(func.count(1)) - .select_from(AlertToIncident) - .join(Alert, AlertToIncident.alert_id == Alert.id) + .select_from(LastAlertToIncident) .where( - AlertToIncident.deleted_at == NULL_FOR_DELETED_AT, - Alert.tenant_id == tenant_id, - AlertToIncident.incident_id == incident_id, + LastAlertToIncident.deleted_at == NULL_FOR_DELETED_AT, + LastAlertToIncident.tenant_id == tenant_id, + LastAlertToIncident.incident_id == incident_id, ) ).scalar() @@ -3678,12 +3675,14 @@ def get_last_alerts_for_incidents( query = ( session.query( Alert, - AlertToIncident.incident_id, + LastAlertToIncident.incident_id, ) - .join(AlertToIncident, Alert.id == AlertToIncident.alert_id) + .select_from(LastAlert) + .join(LastAlertToIncident, LastAlert.fingerprint == LastAlertToIncident.fingerprint) + .join(Alert, LastAlert.alert_id == Alert.id) .filter( - AlertToIncident.deleted_at == NULL_FOR_DELETED_AT, - AlertToIncident.incident_id.in_(incident_ids), + LastAlertToIncident.deleted_at == NULL_FOR_DELETED_AT, + LastAlertToIncident.incident_id.in_(incident_ids), ) .order_by(Alert.timestamp.desc()) ) @@ -3698,7 +3697,7 @@ def get_last_alerts_for_incidents( def remove_alerts_to_incident_by_incident_id( - tenant_id: str, incident_id: str | UUID, alert_ids: List[UUID] + tenant_id: str, incident_id: str | UUID, fingerprints: List[str] ) -> Optional[int]: with Session(engine) as session: incident = session.exec( @@ -3711,16 +3710,14 @@ def remove_alerts_to_incident_by_incident_id( if not incident: return None - all_alert_ids = get_all_same_alert_ids(tenant_id, alert_ids, session) - # Removing alerts-to-incident relation for provided alerts_ids deleted = ( - session.query(AlertToIncident) + session.query(LastAlertToIncident) .where( - AlertToIncident.deleted_at == NULL_FOR_DELETED_AT, - AlertToIncident.tenant_id == tenant_id, - 
AlertToIncident.incident_id == incident.id, - col(AlertToIncident.alert_id).in_(all_alert_ids), + LastAlertToIncident.deleted_at == NULL_FOR_DELETED_AT, + LastAlertToIncident.tenant_id == tenant_id, + LastAlertToIncident.incident_id == incident.id, + col(LastAlertToIncident.fingerprint).in_(fingerprints), ) .update( { @@ -3732,7 +3729,7 @@ def remove_alerts_to_incident_by_incident_id( # Getting aggregated data for incidents for alerts which just was removed alerts_data_for_incident = get_alerts_data_for_incident( - tenant_id, all_alert_ids, session=session + tenant_id, fingerprints, session=session ) service_field = get_json_extract_field(session, Alert.event, "service") @@ -3741,10 +3738,12 @@ def remove_alerts_to_incident_by_incident_id( # which still assigned with the incident existed_services_query = ( select(func.distinct(service_field)) - .join(AlertToIncident, Alert.id == AlertToIncident.alert_id) + .select_from(LastAlert) + .join(LastAlertToIncident, LastAlert.fingerprint == LastAlertToIncident.fingerprint) + .join(Alert, LastAlert.alert_id == Alert.id) .filter( - AlertToIncident.deleted_at == NULL_FOR_DELETED_AT, - AlertToIncident.incident_id == incident_id, + LastAlertToIncident.deleted_at == NULL_FOR_DELETED_AT, + LastAlertToIncident.incident_id == incident_id, service_field.in_(alerts_data_for_incident["services"]), ) ) @@ -3754,10 +3753,12 @@ def remove_alerts_to_incident_by_incident_id( # which still assigned with the incident existed_sources_query = ( select(col(Alert.provider_type).distinct()) - .join(AlertToIncident, Alert.id == AlertToIncident.alert_id) + .select_from(LastAlert) + .join(LastAlertToIncident, LastAlert.fingerprint == LastAlertToIncident.fingerprint) + .join(Alert, LastAlert.alert_id == Alert.id) .filter( - AlertToIncident.deleted_at == NULL_FOR_DELETED_AT, - AlertToIncident.incident_id == incident_id, + LastAlertToIncident.deleted_at == NULL_FOR_DELETED_AT, + LastAlertToIncident.incident_id == incident_id, col(Alert.provider_type).in_(alerts_data_for_incident["sources"]), ) ) @@ -3777,10 +3778,12 @@ def remove_alerts_to_incident_by_incident_id( started_at, last_seen_at = session.exec( select(func.min(Alert.timestamp), func.max(Alert.timestamp)) - .join(AlertToIncident, AlertToIncident.alert_id == Alert.id) + .select_from(LastAlert) + .join(LastAlertToIncident, LastAlert.fingerprint == LastAlertToIncident.fingerprint) + .join(Alert, LastAlert.alert_id == Alert.id) .where( - AlertToIncident.tenant_id == tenant_id, - AlertToIncident.incident_id == incident.id, + LastAlertToIncident.tenant_id == tenant_id, + LastAlertToIncident.incident_id == incident.id, ) ).one() @@ -3822,7 +3825,6 @@ def merge_incidents_to_id( .where( Incident.tenant_id == tenant_id, Incident.id == destination_incident_id ) - .options(joinedload(Incident.alerts)) ).first() if not destination_incident: @@ -3837,12 +3839,14 @@ def merge_incidents_to_id( ) ).all() + enrich_incidents_with_alerts(tenant_id, source_incidents, session=session) + merged_incident_ids = [] skipped_incident_ids = [] failed_incident_ids = [] for source_incident in source_incidents: - source_incident_alerts_ids = [alert.id for alert in source_incident.alerts] - if not source_incident_alerts_ids: + source_incident_alerts_fingerprints = [alert.fingerprint for alert in source_incident._alerts] + if not source_incident_alerts_fingerprints: logger.info(f"Source incident {source_incident.id} doesn't have alerts") skipped_incident_ids.append(source_incident.id) continue @@ -3854,7 +3858,7 @@ def merge_incidents_to_id( 
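                # unlink the alerts from the source incident first; they are re-attached to the destination below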
remove_alerts_to_incident_by_incident_id( tenant_id, source_incident.id, - [alert.id for alert in source_incident.alerts], + [alert.fingerprint for alert in source_incident._alerts], ) except OperationalError as e: logger.error( @@ -3864,7 +3868,7 @@ def merge_incidents_to_id( add_alerts_to_incident( tenant_id, destination_incident, - source_incident_alerts_ids, + source_incident_alerts_fingerprints, session=session, ) merged_incident_ids.append(source_incident.id) @@ -4222,12 +4226,13 @@ def get_workflow_executions_for_incident_or_alert( # Query for workflow executions associated with alerts tied to the incident alert_query = ( base_query.join( - Alert, WorkflowToAlertExecution.alert_fingerprint == Alert.fingerprint + LastAlert, WorkflowToAlertExecution.alert_fingerprint == LastAlert.fingerprint ) - .join(AlertToIncident, Alert.id == AlertToIncident.alert_id) + .join(Alert, LastAlert.alert_id == Alert.id) + .join(LastAlertToIncident, Alert.fingerprint == LastAlertToIncident.fingerprint) .where( - AlertToIncident.deleted_at == NULL_FOR_DELETED_AT, - AlertToIncident.incident_id == incident_id, + LastAlertToIncident.deleted_at == NULL_FOR_DELETED_AT, + LastAlertToIncident.incident_id == incident_id, ) ) @@ -4270,17 +4275,16 @@ def is_all_incident_alerts_resolved( enriched_status_field.label("enriched_status"), status_field.label("status"), ) - .select_from(Alert) + .select_from(LastAlert) + .join(Alert, LastAlert.alert_id == Alert.id) .outerjoin( AlertEnrichment, Alert.fingerprint == AlertEnrichment.alert_fingerprint ) - .join(AlertToIncident, AlertToIncident.alert_id == Alert.id) + .join(LastAlertToIncident, LastAlertToIncident.fingerprint == LastAlert.fingerprint) .where( - AlertToIncident.deleted_at == NULL_FOR_DELETED_AT, - AlertToIncident.incident_id == incident.id, + LastAlertToIncident.deleted_at == NULL_FOR_DELETED_AT, + LastAlertToIncident.incident_id == incident.id, ) - .group_by(Alert.fingerprint) - .having(func.max(Alert.timestamp)) ).subquery() not_resolved_exists = session.query( @@ -4337,8 +4341,8 @@ def is_edge_incident_alert_resolved( .outerjoin( AlertEnrichment, Alert.fingerprint == AlertEnrichment.alert_fingerprint ) - .join(AlertToIncident, AlertToIncident.alert_id == Alert.id) - .where(AlertToIncident.incident_id == incident.id) + .join(LastAlertToIncident, LastAlertToIncident.fingerprint == Alert.fingerprint) + .where(LastAlertToIncident.incident_id == incident.id) .group_by(Alert.fingerprint) .having(func.max(Alert.timestamp)) .order_by(direction(Alert.timestamp)) @@ -4379,11 +4383,12 @@ def get_alerts_metrics_by_provider( Alert.provider_id, func.count(Alert.id).label("total_alerts"), func.sum( - case([(AlertToIncident.alert_id.isnot(None), 1)], else_=0) + case([(LastAlertToIncident.fingerprint.isnot(None), 1)], else_=0) ).label("correlated_alerts"), *dynamic_field_sums, ) - .outerjoin(AlertToIncident, Alert.id == AlertToIncident.alert_id) + .join(LastAlert, Alert.id == LastAlert.alert_id) + .outerjoin(LastAlertToIncident, LastAlert.fingerprint == LastAlertToIncident.fingerprint) .filter( Alert.tenant_id == tenant_id, ) @@ -4504,28 +4509,28 @@ def get_resource_ids_by_resource_type( result = session.exec(query) return result.all() - -def get_or_creat_posthog_instance_id(session: Optional[Session] = None): - POSTHOG_INSTANCE_ID_KEY = "posthog_instance_id" - with Session(engine) as session: - system = session.exec( - select(System).where(System.name == POSTHOG_INSTANCE_ID_KEY) - ).first() - if system: +def get_or_creat_posthog_instance_id( + session: Optional[Session] = 
None + ): + POSTHOG_INSTANCE_ID_KEY = "posthog_instance_id" + with Session(engine) as session: + system = session.exec(select(System).where(System.name == POSTHOG_INSTANCE_ID_KEY)).first() + if system: + return system.value + + system = System( + id=str(uuid4()), + name=POSTHOG_INSTANCE_ID_KEY, + value=str(uuid4()), + ) + session.add(system) + session.commit() + session.refresh(system) return system.value - system = System( - id=str(uuid4()), - name=POSTHOG_INSTANCE_ID_KEY, - value=str(uuid4()), - ) - session.add(system) - session.commit() - session.refresh(system) - return system.value - - -def get_activity_report(session: Optional[Session] = None): +def get_activity_report( + session: Optional[Session] = None + ): from keep.api.models.db.user import User last_24_hours = datetime.utcnow() - timedelta(hours=24) @@ -4553,8 +4558,46 @@ def get_activity_report(session: Optional[Session] = None): ) activity_report["last_24_hours_workflows_executed"] = ( session.query(WorkflowExecution) - .filter(WorkflowExecution.started >= last_24_hours) - .count() - ) - + .filter(WorkflowExecution.started >= last_24_hours).count() +) return activity_report + + +def get_last_alert_by_fingerprint( + tenant_id: str, fingerprint: str, session: Optional[Session] = None +) -> Optional[LastAlert]: + with existed_or_new_session(session) as session: + return session.exec( + select(LastAlert) + .where( + and_( + LastAlert.tenant_id == tenant_id, + LastAlert.fingerprint == fingerprint, + ) + ) + ).first() + +def set_last_alert( + tenant_id: str, alert: Alert, session: Optional[Session] = None +) -> None: + last_alert = get_last_alert_by_fingerprint(tenant_id, alert.fingerprint, session) + + # To prevent rare, but possible race condition + # For example if older alert failed to process + # and retried after new one + + if last_alert and last_alert.timestamp.replace(tzinfo=tz.UTC) < alert.timestamp.replace(tzinfo=tz.UTC): + last_alert.timestamp = alert.timestamp + last_alert.alert_id = alert.id + session.add(last_alert) + session.commit() + + elif not last_alert: + last_alert = LastAlert( + tenant_id=tenant_id, + fingerprint=alert.fingerprint, + timestamp=alert.timestamp, + alert_id=alert.id, + ) + session.add(last_alert) + session.commit() diff --git a/keep/api/models/db/alert.py b/keep/api/models/db/alert.py index 11efa9ef7..0d0bac569 100644 --- a/keep/api/models/db/alert.py +++ b/keep/api/models/db/alert.py @@ -4,7 +4,8 @@ from typing import List, Optional from uuid import UUID, uuid4 -from sqlalchemy import ForeignKey, UniqueConstraint +from pydantic import PrivateAttr +from sqlalchemy import ForeignKey, UniqueConstraint, ForeignKeyConstraint from sqlalchemy.dialects.mssql import DATETIME2 as MSSQL_DATETIME2 from sqlalchemy.dialects.mysql import DATETIME as MySQL_DATETIME from sqlalchemy.engine.url import make_url @@ -60,8 +61,6 @@ class AlertToIncident(SQLModel, table=True): primary_key=True, ) ) - alert: "Alert" = Relationship(back_populates="alert_to_incident_link") - incident: "Incident" = Relationship(back_populates="alert_to_incident_link") is_created_by_ai: bool = Field(default=False) @@ -72,6 +71,43 @@ class AlertToIncident(SQLModel, table=True): default=NULL_FOR_DELETED_AT, ) +class LastAlert(SQLModel, table=True): + + tenant_id: str = Field(foreign_key="tenant.id", nullable=False, primary_key=True) + fingerprint: str = Field(primary_key=True, index=True) + alert_id: UUID = Field(foreign_key="alert.id") + timestamp: datetime = Field(nullable=False, index=True) + + +class LastAlertToIncident(SQLModel, 
table=True): + tenant_id: str = Field(foreign_key="tenant.id", nullable=False, primary_key=True) + timestamp: datetime = Field(default_factory=datetime.utcnow) + + fingerprint: str = Field(primary_key=True) + incident_id: UUID = Field( + sa_column=Column( + UUIDType(binary=False), + ForeignKey("incident.id", ondelete="CASCADE"), + primary_key=True, + ) + ) + + is_created_by_ai: bool = Field(default=False) + + deleted_at: datetime = Field( + default_factory=None, + nullable=True, + primary_key=True, + default=NULL_FOR_DELETED_AT, + ) + + __table_args__ = ( + ForeignKeyConstraint( + ["tenant_id", "fingerprint"], + ["lastalert.tenant_id", "lastalert.fingerprint"]), + {} + ) + class Incident(SQLModel, table=True): id: UUID = Field(default_factory=uuid4, primary_key=True) @@ -96,26 +132,6 @@ class Incident(SQLModel, table=True): end_time: datetime | None last_seen_time: datetime | None - # map of attributes to values - alerts: List["Alert"] = Relationship( - back_populates="incidents", - link_model=AlertToIncident, - # primaryjoin is used to filter out deleted links for various DB dialects - sa_relationship_kwargs={ - "primaryjoin": f"""and_(AlertToIncident.incident_id == Incident.id, - or_( - AlertToIncident.deleted_at == '{NULL_FOR_DELETED_AT.strftime('%Y-%m-%d %H:%M:%S.%f')}', - AlertToIncident.deleted_at == '{NULL_FOR_DELETED_AT.strftime('%Y-%m-%d %H:%M:%S')}' - ))""", - "uselist": True, - "overlaps": "alert,incident", - }, - ) - alert_to_incident_link: List[AlertToIncident] = Relationship( - back_populates="incident", - sa_relationship_kwargs={"overlaps": "alerts,incidents"}, - ) - is_predicted: bool = Field(default=False) is_confirmed: bool = Field(default=False) @@ -183,10 +199,7 @@ class Incident(SQLModel, table=True): ), ) - def __init__(self, **kwargs): - super().__init__(**kwargs) - if "alerts" not in kwargs: - self.alerts = [] + _alerts: List["Alert"] = PrivateAttr() class Config: arbitrary_types_allowed = True @@ -224,24 +237,6 @@ class Alert(SQLModel, table=True): } ) - incidents: List["Incident"] = Relationship( - back_populates="alerts", - link_model=AlertToIncident, - sa_relationship_kwargs={ - # primaryjoin is used to filter out deleted links for various DB dialects - "primaryjoin": f"""and_(AlertToIncident.alert_id == Alert.id, - or_( - AlertToIncident.deleted_at == '{NULL_FOR_DELETED_AT.strftime('%Y-%m-%d %H:%M:%S.%f')}', - AlertToIncident.deleted_at == '{NULL_FOR_DELETED_AT.strftime('%Y-%m-%d %H:%M:%S')}' - ))""", - "uselist": True, - "overlaps": "alert,incident", - }, - ) - alert_to_incident_link: List[AlertToIncident] = Relationship( - back_populates="alert", sa_relationship_kwargs={"overlaps": "alerts,incidents"} - ) - __table_args__ = ( Index( "ix_alert_tenant_fingerprint_timestamp", diff --git a/keep/api/models/db/migrations/versions/2024-12-01-16-40_3ad5308e7200.py b/keep/api/models/db/migrations/versions/2024-12-01-16-40_3ad5308e7200.py index f92d246a8..e6c2140e3 100644 --- a/keep/api/models/db/migrations/versions/2024-12-01-16-40_3ad5308e7200.py +++ b/keep/api/models/db/migrations/versions/2024-12-01-16-40_3ad5308e7200.py @@ -7,10 +7,7 @@ """ import sqlalchemy as sa -import sqlalchemy_utils -import sqlmodel from alembic import op -from sqlalchemy.dialects import postgresql # revision identifiers, used by Alembic. 
revision = "3ad5308e7200" @@ -24,13 +21,15 @@ def upgrade() -> None: with op.batch_alter_table("externalaiconfigandmetadata", schema=None) as batch_op: batch_op.alter_column( - "settings", existing_type=sa.VARCHAR(), type_=sa.JSON(), nullable=True + "settings", existing_type=sa.VARCHAR(), type_=sa.JSON(), nullable=True, + postgresql_using="settings::json" ) batch_op.alter_column( "settings_proposed_by_algorithm", existing_type=sa.VARCHAR(), type_=sa.JSON(), existing_nullable=True, + postgresql_using="settings::json" ) batch_op.alter_column( "feedback_logs", diff --git a/keep/api/models/db/migrations/versions/2024-12-02-13-36_bdae8684d0b4.py b/keep/api/models/db/migrations/versions/2024-12-02-13-36_bdae8684d0b4.py new file mode 100644 index 000000000..905748deb --- /dev/null +++ b/keep/api/models/db/migrations/versions/2024-12-02-13-36_bdae8684d0b4.py @@ -0,0 +1,161 @@ +"""add lastalert and lastalerttoincident table + +Revision ID: bdae8684d0b4 +Revises: 3ad5308e7200 +Create Date: 2024-11-05 22:48:04.733192 + +""" + +import sqlalchemy as sa +import sqlalchemy_utils +import sqlmodel +from alembic import op +from sqlalchemy.orm import Session + +# revision identifiers, used by Alembic. +revision = "bdae8684d0b4" +down_revision = "3ad5308e7200" +branch_labels = None +depends_on = None + +migration_metadata = sa.MetaData() + + +def populate_db(): + session = Session(op.get_bind()) + + if session.bind.dialect.name == "postgresql": + migrate_lastalert_query = """ + insert into lastalert (tenant_id, fingerprint, alert_id, timestamp) + select alert.tenant_id, alert.fingerprint, alert.id as alert_id, alert.timestamp + from alert + join ( + select + alert.tenant_id, alert.fingerprint, max(alert.timestamp) as last_received + from alert + group by fingerprint, tenant_id + ) as a ON alert.fingerprint = a.fingerprint and alert.timestamp = a.last_received and alert.tenant_id = a.tenant_id + on conflict + do nothing + """ + + migrate_lastalerttoincident_query = """ + insert into lastalerttoincident (incident_id, tenant_id, timestamp, fingerprint, is_created_by_ai, deleted_at) + select ati.incident_id, ati.tenant_id, ati.timestamp, lf.fingerprint, ati.is_created_by_ai, ati.deleted_at + from alerttoincident as ati + join + ( + select alert.tenant_id, alert.id, alert.fingerprint + from alert + join ( + select + alert.tenant_id, alert.fingerprint, max(alert.timestamp) as last_received + from alert + group by fingerprint, tenant_id + ) as a on alert.fingerprint = a.fingerprint and alert.timestamp = a.last_received and alert.tenant_id = a.tenant_id + ) as lf on ati.alert_id = lf.id + on conflict + do nothing + """ + + else: + migrate_lastalert_query = """ + INSERT INTO lastalert (tenant_id, fingerprint, alert_id, timestamp) + SELECT + grouped_alerts.tenant_id, + grouped_alerts.fingerprint, + MAX(grouped_alerts.alert_id) as alert_id, -- Using MAX to consistently pick one alert_id + grouped_alerts.timestamp + FROM ( + select alert.tenant_id, alert.fingerprint, alert.id as alert_id, alert.timestamp + from alert + join ( + select + alert.tenant_id, alert.fingerprint, max(alert.timestamp) as last_received + from alert + group by fingerprint, tenant_id + ) as a ON alert.fingerprint = a.fingerprint + and alert.timestamp = a.last_received + and alert.tenant_id = a.tenant_id + ) as grouped_alerts + GROUP BY grouped_alerts.tenant_id, grouped_alerts.fingerprint, grouped_alerts.timestamp; +""" + + migrate_lastalerttoincident_query = """ + REPLACE INTO lastalerttoincident (incident_id, tenant_id, timestamp, fingerprint, 
is_created_by_ai, deleted_at) + select ati.incident_id, ati.tenant_id, ati.timestamp, lf.fingerprint, ati.is_created_by_ai, ati.deleted_at + from alerttoincident as ati + join + ( + select alert.id, alert.fingerprint, alert.tenant_id + from alert + join ( + select + alert.tenant_id,alert.fingerprint, max(alert.timestamp) as last_received + from alert + group by fingerprint, tenant_id + ) as a on alert.fingerprint = a.fingerprint and alert.timestamp = a.last_received and alert.tenant_id = a.tenant_id + ) as lf on ati.alert_id = lf.id; + """ + + session.execute(migrate_lastalert_query) + session.execute(migrate_lastalerttoincident_query) + + +def upgrade() -> None: + op.create_table( + "lastalert", + sa.Column("tenant_id", sqlmodel.sql.sqltypes.AutoString(), nullable=False), + sa.Column("fingerprint", sqlmodel.sql.sqltypes.AutoString(), nullable=False), + sa.Column("alert_id", sqlmodel.sql.sqltypes.GUID(), nullable=False), + sa.Column("timestamp", sa.DateTime(), nullable=False), + sa.ForeignKeyConstraint( + ["alert_id"], + ["alert.id"], + ), + sa.ForeignKeyConstraint( + ["tenant_id"], + ["tenant.id"], + ), + sa.PrimaryKeyConstraint("tenant_id", "fingerprint"), + ) + with op.batch_alter_table("lastalert", schema=None) as batch_op: + batch_op.create_index( + batch_op.f("ix_lastalert_timestamp"), ["timestamp"], unique=False + ) + # Add index for the fingerprint column that will be referenced by foreign key + batch_op.create_index("ix_lastalert_fingerprint", ["fingerprint"], unique=False) + + op.create_table( + "lastalerttoincident", + sa.Column( + "incident_id", + sqlalchemy_utils.types.uuid.UUIDType(binary=False), + nullable=False, + ), + sa.Column("tenant_id", sqlmodel.sql.sqltypes.AutoString(), nullable=False), + sa.Column("timestamp", sa.DateTime(), nullable=False), + sa.Column("fingerprint", sqlmodel.sql.sqltypes.AutoString(), nullable=False), + sa.Column("is_created_by_ai", sa.Boolean(), nullable=False), + sa.Column("deleted_at", sa.DateTime(), nullable=True), + sa.ForeignKeyConstraint( + ["tenant_id", "fingerprint"], + ["lastalert.tenant_id", "lastalert.fingerprint"], + ), + sa.ForeignKeyConstraint(["incident_id"], ["incident.id"], ondelete="CASCADE"), + sa.ForeignKeyConstraint( + ["tenant_id"], + ["tenant.id"], + ), + sa.PrimaryKeyConstraint("tenant_id", "incident_id", "fingerprint", "deleted_at"), + ) + + populate_db() + + +def downgrade() -> None: + op.drop_table("lastalerttoincident") + with op.batch_alter_table("lastalert", schema=None) as batch_op: + batch_op.drop_index(batch_op.f("ix_lastalert_timestamp")) + + op.drop_table("lastalert") \ No newline at end of file diff --git a/keep/api/routes/incidents.py b/keep/api/routes/incidents.py index 4820a415a..d7e4bbd39 100644 --- a/keep/api/routes/incidents.py +++ b/keep/api/routes/incidents.py @@ -457,7 +457,7 @@ def get_incident_workflows( ) async def add_alerts_to_incident( incident_id: UUID, - alert_ids: List[UUID], + alert_fingerprints: List[str], is_created_by_ai: bool = False, authenticated_entity: AuthenticatedEntity = Depends( IdentityManagerFactory.get_auth_verifier(["write:incident"]) @@ -467,7 +467,7 @@ async def add_alerts_to_incident( ): tenant_id = authenticated_entity.tenant_id incident_bl = IncidentBl(tenant_id, session, pusher_client) - await incident_bl.add_alerts_to_incident(incident_id, alert_ids, is_created_by_ai) + await incident_bl.add_alerts_to_incident(incident_id, alert_fingerprints, is_created_by_ai) return Response(status_code=202) @@ -479,7 +479,7 @@ async def add_alerts_to_incident( ) def 
delete_alerts_from_incident( incident_id: UUID, - alert_ids: List[UUID], + fingerprints: List[str], authenticated_entity: AuthenticatedEntity = Depends( IdentityManagerFactory.get_auth_verifier(["write:incident"]) ), @@ -489,7 +489,7 @@ def delete_alerts_from_incident( tenant_id = authenticated_entity.tenant_id incident_bl = IncidentBl(tenant_id, session, pusher_client) incident_bl.delete_alerts_from_incident( - incident_id=incident_id, alert_ids=alert_ids + incident_id=incident_id, alert_fingerprints=fingerprints ) return Response(status_code=202) @@ -611,7 +611,7 @@ def change_incident_status( # TODO: same this change to audit table with the comment if change.status == IncidentStatus.RESOLVED: - for alert in incident.alerts: + for alert in incident._alerts: _enrich_alert( EnrichAlertRequestBody( enrichments={"status": "resolved"}, diff --git a/keep/api/routes/workflows.py b/keep/api/routes/workflows.py index 8d355d08f..6326a6b33 100644 --- a/keep/api/routes/workflows.py +++ b/keep/api/routes/workflows.py @@ -190,7 +190,8 @@ def run_workflow( event_body = body.get("body", {}) or body # if its event that was triggered by the UI with the Modal - if "test-workflow" in event_body.get("fingerprint", "") or not body: + fingerprint = event_body.get("fingerprint", "") + if (fingerprint and "test-workflow" in fingerprint) or not body: # some random event_body["id"] = event_body.get("fingerprint", "manual-run") event_body["name"] = event_body.get("fingerprint", "manual-run") diff --git a/keep/api/tasks/process_event_task.py b/keep/api/tasks/process_event_task.py index 5a5aae8b1..28ed135fa 100644 --- a/keep/api/tasks/process_event_task.py +++ b/keep/api/tasks/process_event_task.py @@ -23,6 +23,7 @@ get_all_presets_dtos, get_enrichment_with_session, get_session_sync, + set_last_alert, ) from keep.api.core.dependencies import get_pusher_client from keep.api.core.elastic import ElasticClient @@ -188,6 +189,9 @@ def __save_to_db( ) session.add(audit) alert_dto = AlertDto(**formatted_event.dict()) + + set_last_alert(tenant_id, alert, session=session) + # Mapping try: enrichments_bl.run_mapping_rules(alert_dto) diff --git a/keep/api/utils/enrichment_helpers.py b/keep/api/utils/enrichment_helpers.py index 86e9795be..2b0ad2b62 100644 --- a/keep/api/utils/enrichment_helpers.py +++ b/keep/api/utils/enrichment_helpers.py @@ -1,10 +1,13 @@ import logging from datetime import datetime +from typing import Optional from opentelemetry import trace +from sqlmodel import Session +from keep.api.core.db import existed_or_new_session from keep.api.models.alert import AlertDto, AlertStatus, AlertWithIncidentLinkMetadataDto -from keep.api.models.db.alert import Alert, AlertToIncident +from keep.api.models.db.alert import Alert, LastAlertToIncident tracer = trace.get_tracer(__name__) logger = logging.getLogger(__name__) @@ -77,8 +80,9 @@ def calculated_start_firing_time( def convert_db_alerts_to_dto_alerts( - alerts: list[Alert | tuple[Alert, AlertToIncident]], - with_incidents: bool = False + alerts: list[Alert | tuple[Alert, LastAlertToIncident]], + with_incidents: bool = False, + session: Optional[Session] = None, ) -> list[AlertDto | AlertWithIncidentLinkMetadataDto]: """ Enriches the alerts with the enrichment data. @@ -90,46 +94,47 @@ def convert_db_alerts_to_dto_alerts( Returns: list[AlertDto | AlertWithIncidentLinkMetadataDto]: The enriched alerts. 
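        The optional session is reused for enrichment lookups via existed_or_new_session.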
""" - alerts_dto = [] - with tracer.start_as_current_span("alerts_enrichment"): - # enrich the alerts with the enrichment data - for _object in alerts: - - # We may have an Alert only or and Alert with an AlertToIncident - if isinstance(_object, Alert): - alert, alert_to_incident = _object, None - else: - alert, alert_to_incident = _object - - if alert.alert_enrichment: - alert.event.update(alert.alert_enrichment.enrichments) - if with_incidents: - if alert.incidents: - alert.event["incident"] = ",".join(str(incident.id) for incident in alert.incidents) - try: - if alert_to_incident is not None: - alert_dto = AlertWithIncidentLinkMetadataDto.from_db_instance(alert, alert_to_incident) + with existed_or_new_session(session) as session: + alerts_dto = [] + with tracer.start_as_current_span("alerts_enrichment"): + # enrich the alerts with the enrichment data + for _object in alerts: + + # We may have an Alert only or and Alert with an LastAlertToIncident + if isinstance(_object, Alert): + alert, alert_to_incident = _object, None else: - alert_dto = AlertDto(**alert.event) + alert, alert_to_incident = _object + if alert.alert_enrichment: - parse_and_enrich_deleted_and_assignees( - alert_dto, alert.alert_enrichment.enrichments + alert.event.update(alert.alert_enrichment.enrichments) + if with_incidents: + if alert.incidents: + alert.event["incident"] = ",".join(str(incident.id) for incident in alert.incidents) + try: + if alert_to_incident is not None: + alert_dto = AlertWithIncidentLinkMetadataDto.from_db_instance(alert, alert_to_incident) + else: + alert_dto = AlertDto(**alert.event) + if alert.alert_enrichment: + parse_and_enrich_deleted_and_assignees( + alert_dto, alert.alert_enrichment.enrichments + ) + except Exception: + # should never happen but just in case + logger.exception( + "Failed to parse alert", + extra={ + "alert": alert, + }, ) - except Exception: - # should never happen but just in case - logger.exception( - "Failed to parse alert", - extra={ - "alert": alert, - }, - ) - continue - - alert_dto.event_id = str(alert.id) - - # enrich provider id when it's possible - if alert_dto.providerId is None: - alert_dto.providerId = alert.provider_id - alert_dto.providerType = alert.provider_type - alerts_dto.append(alert_dto) + continue + + alert_dto.event_id = str(alert.id) + + # enrich provider id when it's possible + if alert_dto.providerId is None: + alert_dto.providerId = alert.provider_id + alert_dto.providerType = alert.provider_type + alerts_dto.append(alert_dto) return alerts_dto diff --git a/keep/rulesengine/rulesengine.py b/keep/rulesengine/rulesengine.py index 363901c16..03538b8e3 100644 --- a/keep/rulesengine/rulesengine.py +++ b/keep/rulesengine/rulesengine.py @@ -81,7 +81,7 @@ def run_rules( ) incident = assign_alert_to_incident( - alert_id=event.event_id, + fingerprint=event.fingerprint, incident=incident, tenant_id=self.tenant_id, session=session, @@ -101,7 +101,7 @@ def run_rules( ): should_resolve = True - if ( + elif ( rule.resolve_on == ResolveOn.LAST.value and is_last_incident_alert_resolved(incident, session=session) ): diff --git a/tests/conftest.py b/tests/conftest.py index 06725ca5b..6a353aa31 100644 --- a/tests/conftest.py +++ b/tests/conftest.py @@ -333,6 +333,7 @@ def is_elastic_responsive(host, port, user, password): basic_auth=(user, password), ) info = elastic_client._client.info() + print("Elastic still up now") return True if info else False except Exception: print("Elastic still not up") @@ -547,6 +548,32 @@ def 
_setup_stress_alerts_no_elastic(num_alerts): db_session.add_all(alerts) db_session.commit() + existed_last_alerts = db_session.query(LastAlert).all() + existed_last_alerts_dict = { + last_alert.fingerprint: last_alert + for last_alert in existed_last_alerts + } + last_alerts = [] + for alert in alerts: + if alert.fingerprint in existed_last_alerts_dict: + last_alert = existed_last_alerts_dict[alert.fingerprint] + last_alert.alert_id = alert.id + last_alert.timestamp=alert.timestamp + last_alerts.append( + last_alert + ) + else: + last_alerts.append( + LastAlert( + tenant_id=SINGLE_TENANT_UUID, + fingerprint=alert.fingerprint, + timestamp=alert.timestamp, + alert_id=alert.id, + ) + ) + db_session.add_all(last_alerts) + db_session.commit() + return alerts return _setup_stress_alerts_no_elastic @@ -559,7 +586,6 @@ def setup_stress_alerts( num_alerts = request.param.get( "num_alerts", 1000 ) # Default to 1000 alerts if not specified - alerts = setup_stress_alerts_no_elastic(num_alerts) # add all to elasticsearch alerts_dto = convert_db_alerts_to_dto_alerts(alerts) diff --git a/tests/test_incidents.py b/tests/test_incidents.py index e675c9075..6002058f6 100644 --- a/tests/test_incidents.py +++ b/tests/test_incidents.py @@ -2,8 +2,7 @@ from itertools import cycle import pytest -from sqlalchemy import distinct, func -from sqlalchemy.orm.exc import DetachedInstanceError +from sqlalchemy import distinct, func, desc from keep.api.core.db import ( IncidentSorting, @@ -24,7 +23,7 @@ IncidentSeverity, IncidentStatus, ) -from keep.api.models.db.alert import Alert, AlertToIncident +from keep.api.models.db.alert import Alert, LastAlertToIncident from keep.api.utils.enrichment_helpers import convert_db_alerts_to_dto_alerts from tests.fixtures.client import client, test_app # noqa @@ -50,7 +49,7 @@ def test_get_alerts_data_for_incident(db_session, create_alert): assert 100 == db_session.query(func.count(Alert.id)).scalar() assert 10 == unique_fingerprints - data = get_alerts_data_for_incident(SINGLE_TENANT_UUID, [a.id for a in alerts]) + data = get_alerts_data_for_incident(SINGLE_TENANT_UUID, [a.fingerprint for a in alerts]) assert data["sources"] == set([f"source_{i}" for i in range(10)]) assert data["services"] == set([f"service_{i}" for i in range(10)]) assert data["count"] == unique_fingerprints @@ -64,16 +63,27 @@ def test_add_remove_alert_to_incidents(db_session, setup_stress_alerts_no_elasti SINGLE_TENANT_UUID, {"user_generated_name": "test", "user_summary": "test"} ) - assert len(incident.alerts) == 0 + incident_alerts, total_incident_alerts = get_incident_alerts_by_incident_id( + tenant_id=SINGLE_TENANT_UUID, + incident_id=incident.id, + ) + + assert len(incident_alerts) == 0 + assert total_incident_alerts == 0 add_alerts_to_incident_by_incident_id( - SINGLE_TENANT_UUID, incident.id, [a.id for a in alerts] + SINGLE_TENANT_UUID, incident.id, [a.fingerprint for a in alerts] ) incident = get_incident_by_id(SINGLE_TENANT_UUID, incident.id) - # 110 alerts - assert len(incident.alerts) == 110 + incident_alerts, total_incident_alerts = get_incident_alerts_by_incident_id( + tenant_id=SINGLE_TENANT_UUID, + incident_id=incident.id, + ) + + assert len(incident_alerts) == 100 + assert total_incident_alerts == 100 # But 100 unique fingerprints assert incident.alerts_count == 100 @@ -86,7 +96,7 @@ def test_add_remove_alert_to_incidents(db_session, setup_stress_alerts_no_elasti service_field = get_json_extract_field(db_session, Alert.event, "service") - service_0 = db_session.query(Alert.id).filter(service_field 
== "service_0").all() + service_0 = db_session.query(Alert.fingerprint).filter(service_field == "service_0").all() # Testing unique fingerprints more_alerts_with_same_fingerprints = setup_stress_alerts_no_elastic(10) @@ -94,26 +104,32 @@ def test_add_remove_alert_to_incidents(db_session, setup_stress_alerts_no_elasti add_alerts_to_incident_by_incident_id( SINGLE_TENANT_UUID, incident.id, - [a.id for a in more_alerts_with_same_fingerprints], + [a.fingerprint for a in more_alerts_with_same_fingerprints], ) incident = get_incident_by_id(SINGLE_TENANT_UUID, incident.id) assert incident.alerts_count == 100 - assert db_session.query(func.count(AlertToIncident.alert_id)).scalar() == 120 + assert db_session.query(func.count(LastAlertToIncident.fingerprint)).scalar() == 100 remove_alerts_to_incident_by_incident_id( SINGLE_TENANT_UUID, incident.id, [ - service_0[0].id, + service_0[0].fingerprint, ], ) incident = get_incident_by_id(SINGLE_TENANT_UUID, incident.id) - # 117 because we removed multiple alerts with service_0 - assert len(incident.alerts) == 117 + incident_alerts, total_incident_alerts = get_incident_alerts_by_incident_id( + tenant_id=SINGLE_TENANT_UUID, + incident_id=incident.id, + ) + + assert len(incident_alerts) == 99 + assert total_incident_alerts == 99 + assert "service_0" in incident.affected_services assert len(incident.affected_services) == 10 assert sorted(incident.affected_services) == sorted( @@ -121,7 +137,7 @@ def test_add_remove_alert_to_incidents(db_session, setup_stress_alerts_no_elasti ) remove_alerts_to_incident_by_incident_id( - SINGLE_TENANT_UUID, incident.id, [a.id for a in service_0] + SINGLE_TENANT_UUID, incident.id, [a.fingerprint for a in service_0] ) # Removing shouldn't impact links between alert and incident if include_unlinked=True @@ -138,8 +154,14 @@ def test_add_remove_alert_to_incidents(db_session, setup_stress_alerts_no_elasti incident = get_incident_by_id(SINGLE_TENANT_UUID, incident.id) - # 108 because we removed multiple alert with same fingerprints - assert len(incident.alerts) == 108 + incident_alerts, total_incident_alerts = get_incident_alerts_by_incident_id( + tenant_id=SINGLE_TENANT_UUID, + incident_id=incident.id, + ) + + assert len(incident_alerts) == 90 + assert total_incident_alerts == 90 + assert "service_0" not in incident.affected_services assert len(incident.affected_services) == 9 assert sorted(incident.affected_services) == sorted( @@ -147,29 +169,36 @@ def test_add_remove_alert_to_incidents(db_session, setup_stress_alerts_no_elasti ) source_1 = ( - db_session.query(Alert.id).filter(Alert.provider_type == "source_1").all() + db_session.query(Alert.fingerprint).filter(Alert.provider_type == "source_1").all() ) remove_alerts_to_incident_by_incident_id( SINGLE_TENANT_UUID, incident.id, [ - source_1[0].id, + source_1[0].fingerprint, ], ) incident = get_incident_by_id(SINGLE_TENANT_UUID, incident.id) - assert len(incident.alerts) == 105 + incident_alerts, total_incident_alerts = get_incident_alerts_by_incident_id( + tenant_id=SINGLE_TENANT_UUID, + incident_id=incident.id, + ) + + assert len(incident_alerts) == 89 + assert total_incident_alerts == 89 + assert "source_1" in incident.sources - # source_0 was removed together with service_0 + # source_0 was removed together with service_1 assert len(incident.sources) == 9 assert sorted(incident.sources) == sorted( ["source_{}".format(i) for i in range(1, 10)] ) remove_alerts_to_incident_by_incident_id( - "keep", incident.id, [a.id for a in source_1] + "keep", incident.id, [a.fingerprint for a 
@@ -213,8 +242,19 @@ def test_get_last_incidents(db_session, create_alert):
         )
         alert = db_session.query(Alert).order_by(Alert.timestamp.desc()).first()

+        create_alert(
+            f"alert-test-2-{i}",
+            AlertStatus(status),
+            datetime.utcnow(),
+            {
+                "severity": AlertSeverity.from_number(severity),
+                "service": service,
+            },
+        )
+        alert2 = db_session.query(Alert).order_by(Alert.timestamp.desc()).first()
+
         add_alerts_to_incident_by_incident_id(
-            SINGLE_TENANT_UUID, incident.id, [alert.id]
+            SINGLE_TENANT_UUID, incident.id, [alert.fingerprint, alert2.fingerprint]
         )

     incidents_default, incidents_default_count = get_last_incidents(SINGLE_TENANT_UUID)
@@ -246,19 +286,14 @@ def test_get_last_incidents(db_session, create_alert):
     for i, j in enumerate(range(5, 10)):
         assert incidents_limit_5_page_2[i].user_generated_name == f"test-{j}"

-    # If alerts not preloaded, we will have detached session issue during attempt to get them
-    # Background on this error at: https://sqlalche.me/e/14/bhk3
-    with pytest.raises(DetachedInstanceError):
-        alerts = incidents_confirmed[0].alerts  # noqa
-
     incidents_with_alerts, _ = get_last_incidents(
         SINGLE_TENANT_UUID, is_confirmed=True, with_alerts=True
     )
     for i in range(25):
         if incidents_with_alerts[i].status == IncidentStatus.MERGED.value:
-            assert len(incidents_with_alerts[i].alerts) == 0
+            assert len(incidents_with_alerts[i]._alerts) == 0
         else:
-            assert len(incidents_with_alerts[i].alerts) == 1
+            assert len(incidents_with_alerts[i]._alerts) == 2

     # Test sorting
@@ -319,11 +354,16 @@ def test_incident_status_change(
         "keep", {"name": "test", "description": "test"}
     )

-    add_alerts_to_incident_by_incident_id("keep", incident.id, [a.id for a in alerts])
+    add_alerts_to_incident_by_incident_id(
+        "keep",
+        incident.id,
+        [a.fingerprint for a in alerts],
+        session=db_session,
+    )

-    incident = get_incident_by_id("keep", incident.id, with_alerts=True)
+    incident = get_incident_by_id("keep", incident.id, with_alerts=True, session=db_session)

-    alerts_dtos = convert_db_alerts_to_dto_alerts(incident.alerts)
+    alerts_dtos = convert_db_alerts_to_dto_alerts(incident._alerts, session=db_session)
     assert (
         len(
             [
@@ -348,10 +388,11 @@ def test_incident_status_change(
     assert data["id"] == str(incident.id)
     assert data["status"] == IncidentStatus.ACKNOWLEDGED.value

-    incident = get_incident_by_id("keep", incident.id, with_alerts=True)
+    db_session.expire_all()
+    incident = get_incident_by_id("keep", incident.id, with_alerts=True, session=db_session)
     assert incident.status == IncidentStatus.ACKNOWLEDGED.value

-    alerts_dtos = convert_db_alerts_to_dto_alerts(incident.alerts)
+    alerts_dtos = convert_db_alerts_to_dto_alerts(incident._alerts, session=db_session)
     assert (
         len(
             [
@@ -376,11 +417,12 @@ def test_incident_status_change(
     assert data["id"] == str(incident.id)
     assert data["status"] == IncidentStatus.RESOLVED.value

-    incident = get_incident_by_id("keep", incident.id, with_alerts=True)
+    db_session.expire_all()
+    incident = get_incident_by_id("keep", incident.id, with_alerts=True, session=db_session)
     assert incident.status == IncidentStatus.RESOLVED.value

     # All alerts are resolved as well
-    alerts_dtos = convert_db_alerts_to_dto_alerts(incident.alerts)
+    alerts_dtos = convert_db_alerts_to_dto_alerts(incident._alerts, session=db_session)
     assert (
         len(
             [
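
The new db_session.expire_all() calls are easy to miss but load-bearing: the status-change request runs in the API's own session, so without expiring, the test session would keep serving the incident from its identity map and assert against a stale status. A self-contained illustration of that failure mode in plain SQLAlchemy (the Incident model here is a stand-in, not Keep's):

from sqlalchemy import String, create_engine, update
from sqlalchemy.orm import DeclarativeBase, Mapped, Session, mapped_column


class Base(DeclarativeBase):
    pass


class Incident(Base):
    __tablename__ = "incident"
    id: Mapped[int] = mapped_column(primary_key=True)
    status: Mapped[str] = mapped_column(String(32))


# File-backed so both connections see the same database.
engine = create_engine("sqlite:///stale_read_demo.db")
Base.metadata.create_all(engine)

session = Session(engine, expire_on_commit=False)
session.add(Incident(id=1, status="firing"))
session.commit()

cached = session.get(Incident, 1)  # now held in the session's identity map

with engine.begin() as conn:  # another connection updates the row
    conn.execute(
        update(Incident).where(Incident.id == 1).values(status="acknowledged")
    )

assert cached.status == "firing"   # stale: served from the cache, no SQL emitted

session.expire_all()               # what the test does before re-reading
assert cached.status == "acknowledged"  # attribute access now re-SELECTs
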
@@ -475,23 +517,48 @@ def test_add_alerts_with_same_fingerprint_to_incident(db_session, create_alert):
         SINGLE_TENANT_UUID, {"user_generated_name": "test", "user_summary": "test"}
     )

-    assert len(incident.alerts) == 0
+    incident_alerts, total_incident_alerts = get_incident_alerts_by_incident_id(
+        tenant_id=SINGLE_TENANT_UUID,
+        incident_id=incident.id,
+    )
+
+    assert len(incident_alerts) == 0
+    assert total_incident_alerts == 0

     add_alerts_to_incident_by_incident_id(
-        SINGLE_TENANT_UUID, incident.id, [fp1_alerts[0].id]
+        SINGLE_TENANT_UUID, incident.id, [fp1_alerts[0].fingerprint]
     )

     incident = get_incident_by_id(SINGLE_TENANT_UUID, incident.id)
-    assert len(incident.alerts) == 2
+    incident_alerts, total_incident_alerts = get_incident_alerts_by_incident_id(
+        tenant_id=SINGLE_TENANT_UUID,
+        incident_id=incident.id,
+    )
+
+    assert len(incident_alerts) == 1
+    assert total_incident_alerts == 1
+
+    # The single link must resolve to the latest alert with that fingerprint
+    last_fp1_alert = (
+        db_session.query(Alert.timestamp)
+        .where(Alert.fingerprint == "fp1")
+        .order_by(desc(Alert.timestamp))
+        .first()
+    )
+    assert incident_alerts[0].timestamp == last_fp1_alert.timestamp

     remove_alerts_to_incident_by_incident_id(
-        SINGLE_TENANT_UUID, incident.id, [fp1_alerts[0].id]
+        SINGLE_TENANT_UUID, incident.id, [fp1_alerts[0].fingerprint]
     )

     incident = get_incident_by_id(SINGLE_TENANT_UUID, incident.id)
-    assert len(incident.alerts) == 0
+    incident_alerts, total_incident_alerts = get_incident_alerts_by_incident_id(
+        tenant_id=SINGLE_TENANT_UUID,
+        incident_id=incident.id,
+    )
+
+    assert len(incident_alerts) == 0
+    assert total_incident_alerts == 0


 def test_merge_incidents(db_session, create_alert, setup_stress_alerts_no_elastic):
@@ -522,7 +589,7 @@ def test_merge_incidents(db_session, create_alert, setup_stress_alerts_no_elasti
     )
     alerts_1 = db_session.query(Alert).all()
     add_alerts_to_incident_by_incident_id(
-        SINGLE_TENANT_UUID, incident_1.id, [a.id for a in alerts_1]
+        SINGLE_TENANT_UUID, incident_1.id, [a.fingerprint for a in alerts_1]
     )
     incident_2 = create_incident_from_dict(
         SINGLE_TENANT_UUID,
@@ -532,19 +599,19 @@ def test_merge_incidents(db_session, create_alert, setup_stress_alerts_no_elasti
         },
     )
     create_alert(
-        "fp20",
+        "fp20-0",
         AlertStatus.FIRING,
         datetime.utcnow(),
         {"severity": AlertSeverity.CRITICAL.value},
     )
     create_alert(
-        "fp20",
+        "fp20-1",
         AlertStatus.FIRING,
         datetime.utcnow(),
         {"severity": AlertSeverity.CRITICAL.value},
     )
     create_alert(
-        "fp20",
+        "fp20-2",
         AlertStatus.FIRING,
         datetime.utcnow(),
         {"severity": AlertSeverity.CRITICAL.value},
@@ -553,7 +620,7 @@ def test_merge_incidents(db_session, create_alert, setup_stress_alerts_no_elasti
         db_session.query(Alert).filter(Alert.fingerprint.startswith("fp20")).all()
     )
     add_alerts_to_incident_by_incident_id(
-        SINGLE_TENANT_UUID, incident_2.id, [a.id for a in alerts_2]
+        SINGLE_TENANT_UUID, incident_2.id, [a.fingerprint for a in alerts_2]
     )
     incident_3 = create_incident_from_dict(
         SINGLE_TENANT_UUID,
@@ -563,28 +630,28 @@ def test_merge_incidents(db_session, create_alert, setup_stress_alerts_no_elasti
         },
     )
     create_alert(
-        "fp30",
+        "fp30-0",
         AlertStatus.FIRING,
         datetime.utcnow(),
         {"severity": AlertSeverity.WARNING.value},
     )
     create_alert(
-        "fp30",
+        "fp30-1",
         AlertStatus.FIRING,
         datetime.utcnow(),
-        {"severity": AlertSeverity.WARNING.value},
+        {"severity": AlertSeverity.INFO.value},
     )
     create_alert(
-        "fp30",
+        "fp30-2",
         AlertStatus.FIRING,
         datetime.utcnow(),
-        {"severity": AlertSeverity.INFO.value},
+        {"severity": AlertSeverity.WARNING.value},
     )
     alerts_3 = (
         db_session.query(Alert).filter(Alert.fingerprint.startswith("fp30")).all()
     )
     add_alerts_to_incident_by_incident_id(
-        SINGLE_TENANT_UUID, incident_3.id, [a.id for a in alerts_3]
+        SINGLE_TENANT_UUID, incident_3.id, [a.fingerprint for a in alerts_3]
     )

     # before merge
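
Renaming the fixtures from a shared "fp20"/"fp30" to indexed fingerprints is not cosmetic: with links keyed by fingerprint, three alerts sharing one fingerprint collapse into a single link, and the merged incident would inherit one alert instead of three. A sketch of the merge arithmetic behind the == 8 assertion in the checks that follow (incident_1's own two fingerprints are illustrative, not taken from the fixture):

# The surviving incident ends up with the union of fingerprints, so its
# count depends on uniqueness, not on how many alert rows each source had.
incident_1_fps = {"fp1", "fp2"}  # illustrative
incident_2_fps = {"fp20-0", "fp20-1", "fp20-2"}
incident_3_fps = {"fp30-0", "fp30-1", "fp30-2"}

merged = incident_1_fps | incident_2_fps | incident_3_fps
assert len(merged) == 8

# Under the old shared fingerprints, each trio would have contributed one:
assert len({"fp20", "fp20", "fp20"}) == 1
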
@@ -602,19 +669,21 @@ def test_merge_incidents(db_session, create_alert, setup_stress_alerts_no_elasti
         "test-user-email",
     )

+    db_session.expire_all()
+
     incident_1 = get_incident_by_id(SINGLE_TENANT_UUID, incident_1.id, with_alerts=True)
-    assert len(incident_1.alerts) == 9
+    assert len(incident_1._alerts) == 8
     assert incident_1.severity == IncidentSeverity.CRITICAL.order

     incident_2 = get_incident_by_id(SINGLE_TENANT_UUID, incident_2.id, with_alerts=True)
-    assert len(incident_2.alerts) == 0
+    assert len(incident_2._alerts) == 0
     assert incident_2.status == IncidentStatus.MERGED.value
     assert incident_2.merged_into_incident_id == incident_1.id
     assert incident_2.merged_at is not None
     assert incident_2.merged_by == "test-user-email"

-    incident_3 = get_incident_by_id(SINGLE_TENANT_UUID, incident_3.id, with_alerts=True)
-    assert len(incident_3.alerts) == 0
+    incident_3 = get_incident_by_id(SINGLE_TENANT_UUID, incident_3.id, with_alerts=True, session=db_session)
+    assert len(incident_3._alerts) == 0
     assert incident_3.status == IncidentStatus.MERGED.value
     assert incident_3.merged_into_incident_id == incident_1.id
     assert incident_3.merged_at is not None
diff --git a/tests/test_metrics.py b/tests/test_metrics.py
index fcd4306d6..8908896b6 100644
--- a/tests/test_metrics.py
+++ b/tests/test_metrics.py
@@ -18,7 +18,7 @@ def test_add_remove_alert_to_incidents(
     valid_api_key = "valid_api_key"
     setup_api_key(db_session, valid_api_key)

-    add_alerts_to_incident_by_incident_id("keep", incident.id, [a.id for a in alerts])
+    add_alerts_to_incident_by_incident_id("keep", incident.id, [a.fingerprint for a in alerts])

     response = client.get("/metrics?labels=a.b", headers={"X-API-KEY": "valid_api_key"})
diff --git a/tests/test_rules_engine.py b/tests/test_rules_engine.py
index 40627b372..07ad6cf70 100644
--- a/tests/test_rules_engine.py
+++ b/tests/test_rules_engine.py
@@ -7,7 +7,7 @@
 import pytest

 from keep.api.core.db import create_rule as create_rule_db
-from keep.api.core.db import get_incident_alerts_by_incident_id, get_last_incidents
+from keep.api.core.db import get_incident_alerts_by_incident_id, get_last_incidents, set_last_alert
 from keep.api.core.db import get_rules as get_rules_db
 from keep.api.core.dependencies import SINGLE_TENANT_UUID
 from keep.api.models.alert import (
@@ -63,10 +63,13 @@ def test_sanity(db_session):
         provider_type="test",
         provider_id="test",
         event=alerts[0].dict(),
-        fingerprint="test",
+        fingerprint=alerts[0].fingerprint,
     )
+
     db_session.add(alert)
     db_session.commit()
+
+    set_last_alert(SINGLE_TENANT_UUID, alert, db_session)
     # run the rules engine
     alerts[0].event_id = alert.id
     results = rules_engine.run_rules(alerts)
@@ -110,10 +113,11 @@ def test_sanity_2(db_session):
         provider_type="test",
         provider_id="test",
         event=alerts[0].dict(),
-        fingerprint="test",
+        fingerprint=alerts[0].fingerprint,
     )
     db_session.add(alert)
     db_session.commit()
+    set_last_alert(SINGLE_TENANT_UUID, alert, db_session)
     # run the rules engine
     alerts[0].event_id = alert.id
     results = rules_engine.run_rules(alerts)
@@ -158,10 +162,11 @@ def test_sanity_3(db_session):
         provider_type="test",
         provider_id="test",
         event=alerts[0].dict(),
-        fingerprint="test",
+        fingerprint=alerts[0].fingerprint,
     )
     db_session.add(alert)
     db_session.commit()
+    set_last_alert(SINGLE_TENANT_UUID, alert, db_session)
     # run the rules engine
     alerts[0].event_id = alert.id
     results = rules_engine.run_rules(alerts)
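
The same three-step dance now recurs in each of the sanity tests above and in test_sanity_4 below: build the Alert with the DTO's fingerprint, commit so the row gets its id, then register it with set_last_alert so fingerprint-based linking can resolve it. Distilled as a sketch of the required ordering (alert_dto stands for the AlertDto each test builds; names otherwise come from the diff):

# Ordering matters: set_last_alert needs the committed alert's id, and any
# fingerprint-based incident linking afterwards needs the LastAlert row.
alert = Alert(
    tenant_id=SINGLE_TENANT_UUID,
    provider_type="test",
    provider_id="test",
    event=alert_dto.dict(),
    fingerprint=alert_dto.fingerprint,
)
db_session.add(alert)
db_session.commit()                                     # 1) alert.id assigned here
set_last_alert(SINGLE_TENANT_UUID, alert, db_session)   # 2) register the fingerprint
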
@@ -206,10 +211,11 @@ def test_sanity_4(db_session):
         provider_type="test",
         provider_id="test",
         event=alerts[0].dict(),
-        fingerprint="test",
+        fingerprint=alerts[0].fingerprint,
     )
     db_session.add(alert)
     db_session.commit()
+    set_last_alert(SINGLE_TENANT_UUID, alert, db_session)
     # run the rules engine
     alerts[0].event_id = alert.id
     results = rules_engine.run_rules(alerts)
@@ -223,7 +229,7 @@ def test_incident_attributes(db_session):
         AlertDto(
             id=str(uuid.uuid4()),
             source=["grafana"],
-            name="grafana-test-alert",
+            name=f"grafana-test-alert-{i}",
             status=AlertStatus.FIRING,
             severity=AlertSeverity.CRITICAL,
             lastReceived=datetime.datetime.now().isoformat(),
@@ -255,13 +261,15 @@ def test_incident_attributes(db_session):
             provider_type="test",
             provider_id="test",
             event=alert.dict(),
-            fingerprint=hashlib.sha256(json.dumps(alert.dict()).encode()).hexdigest(),
+            fingerprint=alert.fingerprint,
             timestamp=alert.lastReceived,
         )
         for alert in alerts_dto
     ]
     db_session.add_all(alerts)
     db_session.commit()
+    for alert in alerts:
+        set_last_alert(SINGLE_TENANT_UUID, alert, db_session)

     for i, alert in enumerate(alerts_dto):
         alert.event_id = alerts[i].id
@@ -283,7 +291,7 @@ def test_incident_severity(db_session):
         AlertDto(
             id=str(uuid.uuid4()),
             source=["grafana"],
-            name="grafana-test-alert",
+            name=f"grafana-test-alert-{i}",
             status=AlertStatus.FIRING,
             severity=AlertSeverity.INFO,
             lastReceived=datetime.datetime.now().isoformat(),
@@ -315,13 +323,15 @@ def test_incident_severity(db_session):
             provider_type="test",
             provider_id="test",
             event=alert.dict(),
-            fingerprint=hashlib.sha256(json.dumps(alert.dict()).encode()).hexdigest(),
+            fingerprint=alert.fingerprint,
             timestamp=alert.lastReceived,
         )
         for alert in alerts_dto
     ]
     db_session.add_all(alerts)
     db_session.commit()
+    for alert in alerts:
+        set_last_alert(SINGLE_TENANT_UUID, alert, db_session)

     for i, alert in enumerate(alerts_dto):
         alert.event_id = alerts[i].id
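
A final subtlety behind the renamed DTOs: when an AlertDto carries no explicit fingerprint, Keep derives one from the alert's fields, by default the name, which is why three identical names would collapse into a single LastAlert row and the grouped incident would only ever see one alert. Indexing the names keeps the fingerprints distinct. A sketch of that collapse, with an illustrative hash in place of the real derivation:

import hashlib


def name_fingerprint(name: str) -> str:
    # Illustrative stand-in for the DTO's default fingerprint derivation.
    return hashlib.sha256(name.encode()).hexdigest()


same = {name_fingerprint("grafana-test-alert") for _ in range(3)}
assert len(same) == 1                        # old fixtures: one fingerprint

unique = {name_fingerprint(f"grafana-test-alert-{i}") for i in range(3)}
assert len(unique) == 3                      # new fixtures: three fingerprints
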