diff --git a/keep/api/core/db.py b/keep/api/core/db.py index 735c0bec1..5e1020369 100644 --- a/keep/api/core/db.py +++ b/keep/api/core/db.py @@ -17,6 +17,7 @@ import numpy as np import validators +from dateutil.tz import tz from dotenv import find_dotenv, load_dotenv from opentelemetry.instrumentation.sqlalchemy import SQLAlchemyInstrumentor from sqlalchemy import ( @@ -30,7 +31,7 @@ null, select, union, - update, + update, asc, ) from sqlalchemy.dialects.mysql import insert as mysql_insert from sqlalchemy.dialects.postgresql import insert as pg_insert @@ -1722,9 +1723,10 @@ def get_incident_for_grouping_rule( # if the last alert in the incident is older than the timeframe, create a new incident is_incident_expired = False - if incident and incident.alerts: + if incident and incident.alerts_count > 0: + enrich_incidents_with_alerts(tenant_id, [incident], session) is_incident_expired = max( - alert.timestamp for alert in incident.alerts + alert.timestamp for alert in incident._alerts ) < datetime.utcnow() - timedelta(seconds=timeframe) # if there is no incident with the rule_fingerprint, create it or existed is already expired @@ -2802,12 +2804,12 @@ def update_preset_options(tenant_id: str, preset_id: str, options: dict) -> Pres def assign_alert_to_incident( - alert_id: UUID | str, + fingerprint: str, incident: Incident, tenant_id: str, session: Optional[Session] = None, ): - return add_alerts_to_incident(tenant_id, incident, [alert_id], session=session) + return add_alerts_to_incident(tenant_id, incident, [fingerprint], session=session) def is_alert_assigned_to_incident( @@ -3084,6 +3086,32 @@ def filter_query(session: Session, query, field, value): return query +def enrich_incidents_with_alerts(tenant_id: str, incidents: List[Incident], session: Optional[Session]=None): + with existed_or_new_session(session) as session: + incident_alerts = session.exec( + select(LastAlertToIncident.incident_id, Alert) + .select_from(LastAlert) + .join(LastAlertToIncident, and_( + 
LastAlertToIncident.fingerprint == LastAlert.fingerprint,
+                LastAlertToIncident.deleted_at == NULL_FOR_DELETED_AT,
+            ))
+            .join(Alert, LastAlert.alert_id == Alert.id)
+            .where(
+                LastAlert.tenant_id == tenant_id,
+                LastAlertToIncident.incident_id.in_([incident.id for incident in incidents])
+            )
+        ).all()
+
+        alerts_per_incident = defaultdict(list)
+        for incident_id, alert in incident_alerts:
+            alerts_per_incident[incident_id].append(alert)
+
+        for incident in incidents:
+            incident._alerts = alerts_per_incident[incident.id]
+
+    return incidents
+
+
 def get_last_incidents(
     tenant_id: str,
     limit: int = 25,
@@ -3125,9 +3153,6 @@ def get_last_incidents(
     if allowed_incident_ids:
         query = query.filter(Incident.id.in_(allowed_incident_ids))
 
-    if with_alerts:
-        query = query.options(joinedload(Incident.alerts))
-
     if is_predicted is not None:
         query = query.filter(Incident.is_predicted == is_predicted)
 
@@ -3159,23 +3184,30 @@ def get_last_incidents(
     # Execute the query
     incidents = query.all()
 
+    if with_alerts:
+        enrich_incidents_with_alerts(tenant_id, incidents, session)
+
     return incidents, total_count
 
 
 def get_incident_by_id(
-    tenant_id: str, incident_id: str | UUID, with_alerts: bool = False
+    tenant_id: str, incident_id: str | UUID, with_alerts: bool = False,
+    session: Optional[Session] = None,
 ) -> Optional[Incident]:
-    with Session(engine) as session:
+    with existed_or_new_session(session) as session:
         query = session.query(
             Incident,
         ).filter(
             Incident.tenant_id == tenant_id,
             Incident.id == incident_id,
         )
-        if with_alerts:
-            query = query.options(joinedload(Incident.alerts))
+        incident = query.first()
+        if incident is not None and with_alerts:
+            enrich_incidents_with_alerts(
+                tenant_id, [incident], session,
+            )
 
-        return query.first()
+        return incident
 
 
 def create_incident_from_dto(
@@ -3232,7 +3264,6 @@ def create_incident_from_dict(
         session.add(new_incident)
         session.commit()
         session.refresh(new_incident)
-        new_incident.alerts = []
 
         return new_incident
 
@@ -3344,43 +3375,24 @@ def 
get_incident_alerts_and_links_by_incident_id( ) -> tuple[List[tuple[Alert, AlertToIncident]], int]: with existed_or_new_session(session) as session: - last_fingerprints_subquery = ( - session.query( - Alert.fingerprint, func.max(Alert.timestamp).label("max_timestamp") - ) - .join(AlertToIncident, AlertToIncident.alert_id == Alert.id) - .filter( - AlertToIncident.tenant_id == tenant_id, - AlertToIncident.incident_id == incident_id, - ) - .group_by(Alert.fingerprint) - .subquery() - ) - query = ( session.query( Alert, - AlertToIncident, - ) - .select_from(last_fingerprints_subquery) - .outerjoin( - Alert, - and_( - last_fingerprints_subquery.c.fingerprint == Alert.fingerprint, - last_fingerprints_subquery.c.max_timestamp == Alert.timestamp, - ), + LastAlertToIncident, ) - .join(AlertToIncident, AlertToIncident.alert_id == Alert.id) + .select_from(LastAlertToIncident) + .join(LastAlert, LastAlert.fingerprint == LastAlertToIncident.fingerprint) + .join(Alert, LastAlert.alert_id == Alert.id) .filter( - AlertToIncident.tenant_id == tenant_id, - AlertToIncident.incident_id == incident_id, + LastAlertToIncident.tenant_id == tenant_id, + LastAlertToIncident.incident_id == incident_id, ) - .order_by(col(Alert.timestamp).desc()) + .order_by(col(LastAlert.timestamp).desc()) .options(joinedload(Alert.alert_enrichment)) ) if not include_unlinked: query = query.filter( - AlertToIncident.deleted_at == NULL_FOR_DELETED_AT, + LastAlertToIncident.deleted_at == NULL_FOR_DELETED_AT, ) total_count = query.count() @@ -3442,8 +3454,7 @@ def get_all_same_alert_ids( def get_alerts_data_for_incident( tenant_id: str, - alert_ids: List[str | UUID], - existed_fingerprints: Optional[List[str]] = None, + fingerprints: Optional[List[str]] = None, session: Optional[Session] = None, ) -> dict: """ @@ -3457,8 +3468,6 @@ def get_alerts_data_for_incident( Returns: dict {sources: list[str], services: list[str], count: int} """ - existed_fingerprints = existed_fingerprints or [] - with 
existed_or_new_session(session) as session: fields = ( @@ -3469,16 +3478,18 @@ def get_alerts_data_for_incident( ) alerts_data = session.exec( - select(*fields).where( - Alert.tenant_id == tenant_id, - col(Alert.id).in_(alert_ids), + select(*fields) + .select_from(LastAlert) + .join(Alert, LastAlert.alert_id == Alert.id) + .where( + LastAlert.tenant_id == tenant_id, + col(LastAlert.fingerprint).in_(fingerprints), ) ).all() sources = [] services = [] severities = [] - fingerprints = set() for service, source, fingerprint, severity in alerts_data: if source: @@ -3490,21 +3501,19 @@ def get_alerts_data_for_incident( severities.append(IncidentSeverity.from_number(severity)) else: severities.append(IncidentSeverity(severity)) - if fingerprint and fingerprint not in existed_fingerprints: - fingerprints.add(fingerprint) return { "sources": set(sources), "services": set(services), "max_severity": max(severities), - "count": len(fingerprints), + "count": len(alerts_data), } def add_alerts_to_incident_by_incident_id( tenant_id: str, incident_id: str | UUID, - alert_ids: List[UUID], + fingerprints: List[str], is_created_by_ai: bool = False, session: Optional[Session] = None, ) -> Optional[Incident]: @@ -3518,62 +3527,52 @@ def add_alerts_to_incident_by_incident_id( if not incident: return None return add_alerts_to_incident( - tenant_id, incident, alert_ids, is_created_by_ai, session + tenant_id, incident, fingerprints, is_created_by_ai, session ) def add_alerts_to_incident( tenant_id: str, incident: Incident, - alert_ids: List[UUID], + fingerprints: List[str], is_created_by_ai: bool = False, session: Optional[Session] = None, override_count: bool = False, ) -> Optional[Incident]: logger.info( - f"Adding alerts to incident {incident.id} in database, total {len(alert_ids)} alerts", + f"Adding alerts to incident {incident.id} in database, total {len(fingerprints)} alerts", extra={"tags": {"tenant_id": tenant_id, "incident_id": incident.id}}, ) with 
existed_or_new_session(session) as session: with session.no_autoflush: - all_alert_ids = get_all_same_alert_ids(tenant_id, alert_ids, session) # Use a set for faster membership checks - existing_alert_ids = set( - session.exec( - select(AlertToIncident.alert_id).where( - AlertToIncident.deleted_at == NULL_FOR_DELETED_AT, - AlertToIncident.tenant_id == tenant_id, - AlertToIncident.incident_id == incident.id, - col(AlertToIncident.alert_id).in_(all_alert_ids), - ) - ).all() - ) + existing_fingerprints = set( session.exec( - select(Alert.fingerprint) - .join(AlertToIncident, AlertToIncident.alert_id == Alert.id) + select(LastAlert.fingerprint) + .join(LastAlertToIncident, LastAlertToIncident.fingerprint == LastAlert.fingerprint) .where( - AlertToIncident.deleted_at == NULL_FOR_DELETED_AT, - AlertToIncident.tenant_id == tenant_id, - AlertToIncident.incident_id == incident.id, + LastAlertToIncident.deleted_at == NULL_FOR_DELETED_AT, + LastAlertToIncident.tenant_id == tenant_id, + LastAlertToIncident.incident_id == incident.id, ) ).all() ) - new_alert_ids = [ - alert_id - for alert_id in all_alert_ids - if alert_id not in existing_alert_ids - ] + new_fingerprints = { + fingerprint + for fingerprint in fingerprints + if fingerprint not in existing_fingerprints + } - if not new_alert_ids: + if not new_fingerprints: return incident alerts_data_for_incident = get_alerts_data_for_incident( - tenant_id, new_alert_ids, existing_fingerprints, session + tenant_id, new_fingerprints, session ) incident.sources = list( @@ -3595,13 +3594,13 @@ def add_alerts_to_incident( else: incident.alerts_count = alerts_data_for_incident["count"] alert_to_incident_entries = [ - AlertToIncident( - alert_id=alert_id, + LastAlertToIncident( + fingerprint=fingerprint, incident_id=incident.id, tenant_id=tenant_id, is_created_by_ai=is_created_by_ai, ) - for alert_id in new_alert_ids + for fingerprint in new_fingerprints ] for idx, entry in enumerate(alert_to_incident_entries): @@ -3618,11 +3617,11 @@ 
def add_alerts_to_incident( started_at, last_seen_at = session.exec( select(func.min(Alert.timestamp), func.max(Alert.timestamp)) - .join(AlertToIncident, AlertToIncident.alert_id == Alert.id) + .join(LastAlertToIncident, LastAlertToIncident.fingerprint == Alert.fingerprint) .where( - AlertToIncident.deleted_at == NULL_FOR_DELETED_AT, - AlertToIncident.tenant_id == tenant_id, - AlertToIncident.incident_id == incident.id, + LastAlertToIncident.deleted_at == NULL_FOR_DELETED_AT, + LastAlertToIncident.tenant_id == tenant_id, + LastAlertToIncident.incident_id == incident.id, ) ).one() @@ -3639,12 +3638,11 @@ def get_incident_unique_fingerprint_count(tenant_id: str, incident_id: str) -> i with Session(engine) as session: return session.execute( select(func.count(1)) - .select_from(AlertToIncident) - .join(Alert, AlertToIncident.alert_id == Alert.id) + .select_from(LastAlertToIncident) .where( - AlertToIncident.deleted_at == NULL_FOR_DELETED_AT, - Alert.tenant_id == tenant_id, - AlertToIncident.incident_id == incident_id, + LastAlertToIncident.deleted_at == NULL_FOR_DELETED_AT, + LastAlertToIncident.tenant_id == tenant_id, + LastAlertToIncident.incident_id == incident_id, ) ).scalar() @@ -3656,12 +3654,14 @@ def get_last_alerts_for_incidents( query = ( session.query( Alert, - AlertToIncident.incident_id, + LastAlertToIncident.incident_id, ) - .join(AlertToIncident, Alert.id == AlertToIncident.alert_id) + .select_from(LastAlert) + .join(LastAlertToIncident, LastAlert.fingerprint == LastAlertToIncident.fingerprint) + .join(Alert, LastAlert.alert_id == Alert.id) .filter( - AlertToIncident.deleted_at == NULL_FOR_DELETED_AT, - AlertToIncident.incident_id.in_(incident_ids), + LastAlertToIncident.deleted_at == NULL_FOR_DELETED_AT, + LastAlertToIncident.incident_id.in_(incident_ids), ) .order_by(Alert.timestamp.desc()) ) @@ -3676,7 +3676,7 @@ def get_last_alerts_for_incidents( def remove_alerts_to_incident_by_incident_id( - tenant_id: str, incident_id: str | UUID, alert_ids: 
List[UUID] + tenant_id: str, incident_id: str | UUID, fingerprints: List[str] ) -> Optional[int]: with Session(engine) as session: incident = session.exec( @@ -3689,16 +3689,14 @@ def remove_alerts_to_incident_by_incident_id( if not incident: return None - all_alert_ids = get_all_same_alert_ids(tenant_id, alert_ids, session) - # Removing alerts-to-incident relation for provided alerts_ids deleted = ( - session.query(AlertToIncident) + session.query(LastAlertToIncident) .where( - AlertToIncident.deleted_at == NULL_FOR_DELETED_AT, - AlertToIncident.tenant_id == tenant_id, - AlertToIncident.incident_id == incident.id, - col(AlertToIncident.alert_id).in_(all_alert_ids), + LastAlertToIncident.deleted_at == NULL_FOR_DELETED_AT, + LastAlertToIncident.tenant_id == tenant_id, + LastAlertToIncident.incident_id == incident.id, + col(LastAlertToIncident.fingerprint).in_(fingerprints), ) .update( { @@ -3710,7 +3708,7 @@ def remove_alerts_to_incident_by_incident_id( # Getting aggregated data for incidents for alerts which just was removed alerts_data_for_incident = get_alerts_data_for_incident( - tenant_id, all_alert_ids, session=session + tenant_id, fingerprints, session=session ) service_field = get_json_extract_field(session, Alert.event, "service") @@ -3719,10 +3717,12 @@ def remove_alerts_to_incident_by_incident_id( # which still assigned with the incident existed_services_query = ( select(func.distinct(service_field)) - .join(AlertToIncident, Alert.id == AlertToIncident.alert_id) + .select_from(LastAlert) + .join(LastAlertToIncident, LastAlert.fingerprint == LastAlertToIncident.fingerprint) + .join(Alert, LastAlert.alert_id == Alert.id) .filter( - AlertToIncident.deleted_at == NULL_FOR_DELETED_AT, - AlertToIncident.incident_id == incident_id, + LastAlertToIncident.deleted_at == NULL_FOR_DELETED_AT, + LastAlertToIncident.incident_id == incident_id, service_field.in_(alerts_data_for_incident["services"]), ) ) @@ -3732,7 +3732,9 @@ def 
remove_alerts_to_incident_by_incident_id(
         # which still assigned with the incident
         existed_sources_query = (
             select(col(Alert.provider_type).distinct())
-            .join(AlertToIncident, Alert.id == AlertToIncident.alert_id)
+            .select_from(LastAlert)
+            .join(LastAlertToIncident, LastAlert.fingerprint == LastAlertToIncident.fingerprint)
+            .join(Alert, LastAlert.alert_id == Alert.id)
             .filter(
-                AlertToIncident.deleted_at == NULL_FOR_DELETED_AT,
-                AlertToIncident.incident_id == incident_id,
+                LastAlertToIncident.deleted_at == NULL_FOR_DELETED_AT,
+                LastAlertToIncident.incident_id == incident_id,
@@ -3755,10 +3757,12 @@ def remove_alerts_to_incident_by_incident_id(
 
         started_at, last_seen_at = session.exec(
             select(func.min(Alert.timestamp), func.max(Alert.timestamp))
-            .join(AlertToIncident, AlertToIncident.alert_id == Alert.id)
+            .select_from(LastAlert)
+            .join(LastAlertToIncident, LastAlert.fingerprint == LastAlertToIncident.fingerprint)
+            .join(Alert, LastAlert.alert_id == Alert.id)
             .where(
-                AlertToIncident.tenant_id == tenant_id,
-                AlertToIncident.incident_id == incident.id,
+                LastAlertToIncident.tenant_id == tenant_id,
+                LastAlertToIncident.incident_id == incident.id,
             )
         ).one()
 
@@ -3800,7 +3804,6 @@ def merge_incidents_to_id(
         .where(
             Incident.tenant_id == tenant_id, Incident.id == destination_incident_id
         )
-        .options(joinedload(Incident.alerts))
    ).first()
 
    if not destination_incident:
@@ -3815,12 +3818,14 @@ def merge_incidents_to_id(
    ).all()
 
+    enrich_incidents_with_alerts(tenant_id, source_incidents, session=session)
+
     merged_incident_ids = []
     skipped_incident_ids = []
     failed_incident_ids = []
 
     for source_incident in source_incidents:
-        source_incident_alerts_ids = [alert.id for alert in source_incident.alerts]
-        if not source_incident_alerts_ids:
+        source_incident_alerts_fingerprints = [alert.fingerprint for alert in source_incident._alerts]
+        if not source_incident_alerts_fingerprints:
             logger.info(f"Source incident {source_incident.id} doesn't have alerts")
             skipped_incident_ids.append(source_incident.id)
             continue
@@ -3832,7 +3837,7 @@ def merge_incidents_to_id(
             
remove_alerts_to_incident_by_incident_id( tenant_id, source_incident.id, - [alert.id for alert in source_incident.alerts], + [alert.fingerprint for alert in source_incident._alerts], ) except OperationalError as e: logger.error( @@ -3842,7 +3847,7 @@ def merge_incidents_to_id( add_alerts_to_incident( tenant_id, destination_incident, - source_incident_alerts_ids, + source_incident_alerts_fingerprints, session=session, ) merged_incident_ids.append(source_incident.id) @@ -4267,17 +4272,16 @@ def is_all_incident_alerts_resolved( enriched_status_field.label("enriched_status"), status_field.label("status"), ) - .select_from(Alert) + .select_from(LastAlert) + .join(Alert, LastAlert.alert_id == Alert.id) .outerjoin( AlertEnrichment, Alert.fingerprint == AlertEnrichment.alert_fingerprint ) - .join(AlertToIncident, AlertToIncident.alert_id == Alert.id) + .join(LastAlertToIncident, LastAlertToIncident.fingerprint == LastAlert.fingerprint) .where( - AlertToIncident.deleted_at == NULL_FOR_DELETED_AT, - AlertToIncident.incident_id == incident.id, + LastAlertToIncident.deleted_at == NULL_FOR_DELETED_AT, + LastAlertToIncident.incident_id == incident.id, ) - .group_by(Alert.fingerprint) - .having(func.max(Alert.timestamp)) ).subquery() not_resolved_exists = session.query( @@ -4334,8 +4338,8 @@ def is_edge_incident_alert_resolved( .outerjoin( AlertEnrichment, Alert.fingerprint == AlertEnrichment.alert_fingerprint ) - .join(AlertToIncident, AlertToIncident.alert_id == Alert.id) - .where(AlertToIncident.incident_id == incident.id) + .join(LastAlertToIncident, LastAlertToIncident.fingerprint == Alert.fingerprint) + .where(LastAlertToIncident.incident_id == incident.id) .group_by(Alert.fingerprint) .having(func.max(Alert.timestamp)) .order_by(direction(Alert.timestamp)) @@ -4469,3 +4473,43 @@ def get_resource_ids_by_resource_type( # Execute the query and return results result = session.exec(query) return result.all() + + +def get_last_alert_by_fingerprint( + tenant_id: str, fingerprint: 
str, session: Optional[Session] = None +) -> Optional[Alert]: + with existed_or_new_session(session) as session: + return session.exec( + select(LastAlert) + .where( + and_( + LastAlert.tenant_id == tenant_id, + LastAlert.fingerprint == fingerprint, + ) + ) + ).first() + +def set_last_alert( + tenant_id: str, alert: Alert, session: Optional[Session] = None +) -> None: + last_alert = get_last_alert_by_fingerprint(tenant_id, alert.fingerprint, session) + + # To prevent rare, but possible race condition + # For example if older alert failed to process + # and retried after new one + + if last_alert and last_alert.timestamp.replace(tzinfo=tz.UTC) < alert.timestamp.replace(tzinfo=tz.UTC): + last_alert.timestamp = alert.timestamp + last_alert.alert_id = alert.id + session.add(last_alert) + session.commit() + + elif not last_alert: + last_alert = LastAlert( + tenant_id=tenant_id, + fingerprint=alert.fingerprint, + timestamp=alert.timestamp, + alert_id=alert.id, + ) + session.add(last_alert) + session.commit() diff --git a/keep/api/models/db/alert.py b/keep/api/models/db/alert.py index 11efa9ef7..86449a5a5 100644 --- a/keep/api/models/db/alert.py +++ b/keep/api/models/db/alert.py @@ -4,12 +4,13 @@ from typing import List, Optional from uuid import UUID, uuid4 +from pydantic import PrivateAttr from sqlalchemy import ForeignKey, UniqueConstraint from sqlalchemy.dialects.mssql import DATETIME2 as MSSQL_DATETIME2 from sqlalchemy.dialects.mysql import DATETIME as MySQL_DATETIME from sqlalchemy.engine.url import make_url from sqlalchemy_utils import UUIDType -from sqlmodel import JSON, TEXT, Column, DateTime, Field, Index, Relationship, SQLModel +from sqlmodel import JSON, TEXT, Column, DateTime, Field, Index, Relationship, SQLModel, Session from keep.api.consts import RUNNING_IN_CLOUD_RUN from keep.api.core.config import config @@ -60,8 +61,8 @@ class AlertToIncident(SQLModel, table=True): primary_key=True, ) ) - alert: "Alert" = 
Relationship(back_populates="alert_to_incident_link") - incident: "Incident" = Relationship(back_populates="alert_to_incident_link") + # alert: "Alert" = Relationship(back_populates="alert_to_incident_link") + # incident: "Incident" = Relationship(back_populates="alert_to_incident_link") is_created_by_ai: bool = Field(default=False) @@ -72,6 +73,49 @@ class AlertToIncident(SQLModel, table=True): default=NULL_FOR_DELETED_AT, ) +class LastAlert(SQLModel, table=True): + + tenant_id: str = Field(foreign_key="tenant.id", nullable=False) + fingerprint: str = Field(primary_key=True) + alert_id: UUID = Field(foreign_key="alert.id") + timestamp: datetime = Field(nullable=False, index=True) + + +class LastAlertToIncident(SQLModel, table=True): + tenant_id: str = Field(foreign_key="tenant.id", nullable=False) + timestamp: datetime = Field(default_factory=datetime.utcnow) + + fingerprint: str = Field(foreign_key="lastalert.fingerprint", primary_key=True) + incident_id: UUID = Field( + sa_column=Column( + UUIDType(binary=False), + ForeignKey("incident.id", ondelete="CASCADE"), + primary_key=True, + ) + ) + + is_created_by_ai: bool = Field(default=False) + + deleted_at: datetime = Field( + default_factory=None, + nullable=True, + primary_key=True, + default=NULL_FOR_DELETED_AT, + ) + + # alert: "Alert" = Relationship( + # back_populates="alert_to_incident_link", + # sa_relationship = relationship( + # "Alert", + # secondary="lastalert", + # primaryjoin=f"""LastAlertToIncident.fingerprint == LastAlert.fingerprint""", + # secondaryjoin="LastAlert.alert_id == Alert.id", + # overlaps="alert,lastalert", + # viewonly=True, + # ), + # ) + # incident: "Incident" = Relationship(back_populates="alert_to_incident_link") + class Incident(SQLModel, table=True): id: UUID = Field(default_factory=uuid4, primary_key=True) @@ -96,25 +140,10 @@ class Incident(SQLModel, table=True): end_time: datetime | None last_seen_time: datetime | None - # map of attributes to values - alerts: List["Alert"] = 
Relationship( - back_populates="incidents", - link_model=AlertToIncident, - # primaryjoin is used to filter out deleted links for various DB dialects - sa_relationship_kwargs={ - "primaryjoin": f"""and_(AlertToIncident.incident_id == Incident.id, - or_( - AlertToIncident.deleted_at == '{NULL_FOR_DELETED_AT.strftime('%Y-%m-%d %H:%M:%S.%f')}', - AlertToIncident.deleted_at == '{NULL_FOR_DELETED_AT.strftime('%Y-%m-%d %H:%M:%S')}' - ))""", - "uselist": True, - "overlaps": "alert,incident", - }, - ) - alert_to_incident_link: List[AlertToIncident] = Relationship( - back_populates="incident", - sa_relationship_kwargs={"overlaps": "alerts,incidents"}, - ) + # alert_to_incident_link: List[LastAlertToIncident] = Relationship( + # back_populates="incident", + # sa_relationship_kwargs={"overlaps": "alerts,incidents"}, + # ) is_predicted: bool = Field(default=False) is_confirmed: bool = Field(default=False) @@ -183,10 +212,7 @@ class Incident(SQLModel, table=True): ), ) - def __init__(self, **kwargs): - super().__init__(**kwargs) - if "alerts" not in kwargs: - self.alerts = [] + _alerts: List["Alert"] = PrivateAttr() class Config: arbitrary_types_allowed = True @@ -224,24 +250,6 @@ class Alert(SQLModel, table=True): } ) - incidents: List["Incident"] = Relationship( - back_populates="alerts", - link_model=AlertToIncident, - sa_relationship_kwargs={ - # primaryjoin is used to filter out deleted links for various DB dialects - "primaryjoin": f"""and_(AlertToIncident.alert_id == Alert.id, - or_( - AlertToIncident.deleted_at == '{NULL_FOR_DELETED_AT.strftime('%Y-%m-%d %H:%M:%S.%f')}', - AlertToIncident.deleted_at == '{NULL_FOR_DELETED_AT.strftime('%Y-%m-%d %H:%M:%S')}' - ))""", - "uselist": True, - "overlaps": "alert,incident", - }, - ) - alert_to_incident_link: List[AlertToIncident] = Relationship( - back_populates="alert", sa_relationship_kwargs={"overlaps": "alerts,incidents"} - ) - __table_args__ = ( Index( "ix_alert_tenant_fingerprint_timestamp", diff --git 
a/keep/api/models/db/migrations/versions/2024-11-05-22-48_bdae8684d0b4.py b/keep/api/models/db/migrations/versions/2024-11-05-22-48_bdae8684d0b4.py new file mode 100644 index 000000000..30d63bd3c --- /dev/null +++ b/keep/api/models/db/migrations/versions/2024-11-05-22-48_bdae8684d0b4.py @@ -0,0 +1,180 @@ +"""add lastalert and lastalerttoincident table + +Revision ID: bdae8684d0b4 +Revises: ef0b5b0df41c +Create Date: 2024-11-05 22:48:04.733192 + +""" +import warnings + +import sqlalchemy as sa +import sqlalchemy_utils +import sqlmodel +from alembic import op +from sqlalchemy.dialects.postgresql import UUID +from sqlalchemy.orm import Session +from sqlalchemy.sql import expression +from sqlalchemy import exc as sa_exc + +# revision identifiers, used by Alembic. +revision = "bdae8684d0b4" +down_revision = "ef0b5b0df41c" +branch_labels = None +depends_on = None + +migration_metadata = sa.MetaData() +# +# alert_to_incident_table = sa.Table( +# 'alerttoincident', +# migration_metadata, +# sa.Column("tenant_id", sqlmodel.sql.sqltypes.AutoString(), nullable=False), +# sa.Column('alert_id', UUID(as_uuid=False), sa.ForeignKey('alert.id', ondelete='CASCADE'), primary_key=True), +# sa.Column('incident_id', UUID(as_uuid=False), sa.ForeignKey('incident.id', ondelete='CASCADE'), primary_key=True), +# sa.Column("timestamp", sa.DateTime(), nullable=False, server_default=sa.func.current_timestamp()), +# sa.Column("is_created_by_ai", sa.Boolean(), nullable=False, server_default=expression.false()), +# sa.Column("deleted_at", sa.DateTime(), nullable=False, server_default="1000-01-01 00:00:00"), +# +# ) +# +# # The following code will shoow SA warning about dialect, so we suppress it. 
+# with warnings.catch_warnings():
+#     warnings.simplefilter("ignore", category=sa_exc.SAWarning)
+#     incident_table = sa.Table(
+#         'incident',
+#         migration_metadata,
+#         sa.Column('id', UUID(as_uuid=False), primary_key=True),
+#         sa.Column('alerts_count', sa.Integer, default=0),
+#         sa.Column('affected_services', sa.JSON, default_factory=list),
+#     )
+#
+#     alert_table = sa.Table(
+#         'alert',
+#         migration_metadata,
+#         sa.Column('id', UUID(as_uuid=False), primary_key=True),
+#         sa.Column('fingerprint', sa.String),
+#         sa.Column('provider_type', sa.String),
+#         sa.Column('event', sa.JSON)
+#     )
+
+def populate_db():
+    session = Session(op.get_bind())
+
+    if session.bind.dialect.name == "postgresql":
+        migrate_lastalert_query = """
+            insert into lastalert (tenant_id, fingerprint, alert_id, timestamp)
+            select alert.tenant_id, alert.fingerprint, alert.id as alert_id, alert.timestamp
+            from alert
+            join (
+                select
+                    alert.fingerprint, max(alert.timestamp) as last_received
+                from alert
+                group by fingerprint
+            ) as a ON alert.fingerprint = a.fingerprint and alert.timestamp = a.last_received
+            on conflict
+            do nothing
+        """
+
+        migrate_lastalerttoincodent_query = """
+            insert into lastalerttoincident (incident_id, tenant_id, timestamp, fingerprint, is_created_by_ai, deleted_at)
+            select ati.incident_id, ati.tenant_id, ati.timestamp, lf.fingerprint, ati.is_created_by_ai, ati.deleted_at
+            from alerttoincident as ati
+            join
+            (
+                select alert.id, alert.fingerprint
+                from alert
+                join (
+                    select
+                        alert.fingerprint, max(alert.timestamp) as last_received
+                    from alert
+                    group by fingerprint
+                ) as a on alert.fingerprint = a.fingerprint and alert.timestamp = a.last_received
+            ) as lf on ati.alert_id = lf.id
+            on conflict
+            do nothing
+        """
+
+    else:
+        migrate_lastalert_query = """
+            replace into lastalert (tenant_id, fingerprint, alert_id, timestamp)
+            select alert.tenant_id, alert.fingerprint, alert.id as alert_id, alert.timestamp
+            from alert
+            join (
+                select
+                    alert.fingerprint, max(alert.timestamp) as last_received
+                from alert
+                group by fingerprint
+            ) as a ON alert.fingerprint = a.fingerprint and alert.timestamp = a.last_received;
+        """
+
+        migrate_lastalerttoincodent_query = """
+            replace into lastalerttoincident (incident_id, tenant_id, timestamp, fingerprint, is_created_by_ai, deleted_at)
+            select ati.incident_id, ati.tenant_id, ati.timestamp, lf.fingerprint, ati.is_created_by_ai, ati.deleted_at
+            from alerttoincident as ati
+            join
+            (
+                select alert.id, alert.fingerprint
+                from alert
+                join (
+                    select
+                        alert.fingerprint, max(alert.timestamp) as last_received
+                    from alert
+                    group by fingerprint
+                ) as a on alert.fingerprint = a.fingerprint and alert.timestamp = a.last_received
+            ) as lf on ati.alert_id = lf.id
+        """
+
+    session.execute(migrate_lastalert_query)
+    session.execute(migrate_lastalerttoincodent_query)
+
+
+def upgrade() -> None:
+    op.create_table(
+        "lastalert",
+        sa.Column("tenant_id", sqlmodel.sql.sqltypes.AutoString(), nullable=False),
+        sa.Column("fingerprint", sqlmodel.sql.sqltypes.AutoString(), nullable=False),
+        sa.Column("alert_id", sqlmodel.sql.sqltypes.GUID(), nullable=False),
+        sa.Column("timestamp", sa.DateTime(), nullable=False),
+        sa.ForeignKeyConstraint(
+            ["alert_id"],
+            ["alert.id"],
+        ),
+        sa.ForeignKeyConstraint(["tenant_id"], ["tenant.id"],),
+        sa.PrimaryKeyConstraint("fingerprint"),
+    )
+    with op.batch_alter_table("lastalert", schema=None) as batch_op:
+        batch_op.create_index(
+            batch_op.f("ix_lastalert_timestamp"), ["timestamp"], unique=False
+        )
+
+    op.create_table(
+        "lastalerttoincident",
+        sa.Column(
+            "incident_id",
+            sqlalchemy_utils.types.uuid.UUIDType(binary=False),
+            nullable=False,
+        ),
+        sa.Column("tenant_id", sqlmodel.sql.sqltypes.AutoString(), nullable=False),
+        sa.Column("timestamp", sa.DateTime(), nullable=False),
+        sa.Column("fingerprint", sqlmodel.sql.sqltypes.AutoString(), nullable=False),
+        sa.Column("is_created_by_ai", sa.Boolean(), nullable=False),
+        sa.Column("deleted_at", sa.DateTime(), nullable=True),
+        sa.ForeignKeyConstraint(
+            ["fingerprint"],
+            ["lastalert.fingerprint"],
+        ),
+        
sa.ForeignKeyConstraint(["incident_id"], ["incident.id"], ondelete="CASCADE"), + sa.ForeignKeyConstraint( + ["tenant_id"], + ["tenant.id"], + ), + sa.PrimaryKeyConstraint("incident_id", "fingerprint", "deleted_at"), + ) + + populate_db() + +def downgrade() -> None: + op.drop_table("lastalerttoincident") + with op.batch_alter_table("lastalert", schema=None) as batch_op: + batch_op.drop_index(batch_op.f("ix_lastalert_timestamp")) + + op.drop_table("lastalert") diff --git a/keep/api/routes/incidents.py b/keep/api/routes/incidents.py index d670a6ac6..7aa9ca5c8 100644 --- a/keep/api/routes/incidents.py +++ b/keep/api/routes/incidents.py @@ -614,7 +614,7 @@ def change_incident_status( # TODO: same this change to audit table with the comment if change.status == IncidentStatus.RESOLVED: - for alert in incident.alerts: + for alert in incident._alerts: _enrich_alert( EnrichAlertRequestBody( enrichments={"status": "resolved"}, diff --git a/keep/api/tasks/process_event_task.py b/keep/api/tasks/process_event_task.py index 53a064dcc..e442ebc12 100644 --- a/keep/api/tasks/process_event_task.py +++ b/keep/api/tasks/process_event_task.py @@ -22,7 +22,7 @@ get_alerts_by_fingerprint, get_all_presets, get_enrichment_with_session, - get_session_sync, + get_session_sync, set_last_alert, ) from keep.api.core.dependencies import get_pusher_client from keep.api.core.elastic import ElasticClient @@ -184,6 +184,9 @@ def __save_to_db( ) session.add(audit) alert_dto = AlertDto(**formatted_event.dict()) + + set_last_alert(tenant_id, alert, session=session) + # Mapping try: enrichments_bl.run_mapping_rules(alert_dto) @@ -400,7 +403,7 @@ def __handle_formatted_events( # logger.info("Adding group alerts to the workflow manager queue") # workflow_manager.insert_events(tenant_id, grouped_alerts) # logger.info("Added group alerts to the workflow manager queue") - except Exception: + except Exception as ex: logger.exception( "Failed to run rules engine", extra={ diff --git 
a/keep/api/utils/enrichment_helpers.py b/keep/api/utils/enrichment_helpers.py
index 86e9795be..c3af0321a 100644
--- a/keep/api/utils/enrichment_helpers.py
+++ b/keep/api/utils/enrichment_helpers.py
@@ -1,8 +1,11 @@
 import logging
 from datetime import datetime
+from typing import Optional
 
 from opentelemetry import trace
+from sqlmodel import Session
 
+from keep.api.core.db import existed_or_new_session
 from keep.api.models.alert import AlertDto, AlertStatus, AlertWithIncidentLinkMetadataDto
 from keep.api.models.db.alert import Alert, AlertToIncident
 
@@ -78,7 +82,8 @@ def calculated_start_firing_time(
 
 def convert_db_alerts_to_dto_alerts(
     alerts: list[Alert | tuple[Alert, AlertToIncident]],
-    with_incidents: bool = False
+    with_incidents: bool = False,
+    session: Optional[Session] = None,
 ) -> list[AlertDto | AlertWithIncidentLinkMetadataDto]:
     """
     Enriches the alerts with the enrichment data.
@@ -90,46 +95,47 @@ def convert_db_alerts_to_dto_alerts(
 
     Returns:
         list[AlertDto | AlertWithIncidentLinkMetadataDto]: The enriched alerts.
     
""" - alerts_dto = [] - with tracer.start_as_current_span("alerts_enrichment"): - # enrich the alerts with the enrichment data - for _object in alerts: - - # We may have an Alert only or and Alert with an AlertToIncident - if isinstance(_object, Alert): - alert, alert_to_incident = _object, None - else: - alert, alert_to_incident = _object - - if alert.alert_enrichment: - alert.event.update(alert.alert_enrichment.enrichments) - if with_incidents: - if alert.incidents: - alert.event["incident"] = ",".join(str(incident.id) for incident in alert.incidents) - try: - if alert_to_incident is not None: - alert_dto = AlertWithIncidentLinkMetadataDto.from_db_instance(alert, alert_to_incident) + with existed_or_new_session(session) as session: + alerts_dto = [] + with tracer.start_as_current_span("alerts_enrichment"): + # enrich the alerts with the enrichment data + for _object in alerts: + + # We may have an Alert only or and Alert with an AlertToIncident + if isinstance(_object, Alert): + alert, alert_to_incident = _object, None else: - alert_dto = AlertDto(**alert.event) + alert, alert_to_incident = _object + if alert.alert_enrichment: - parse_and_enrich_deleted_and_assignees( - alert_dto, alert.alert_enrichment.enrichments + alert.event.update(alert.alert_enrichment.enrichments) + if with_incidents: + if alert.incidents: + alert.event["incident"] = ",".join(str(incident.id) for incident in alert.incidents) + try: + if alert_to_incident is not None: + alert_dto = AlertWithIncidentLinkMetadataDto.from_db_instance(alert, alert_to_incident) + else: + alert_dto = AlertDto(**alert.event) + if alert.alert_enrichment: + parse_and_enrich_deleted_and_assignees( + alert_dto, alert.alert_enrichment.enrichments + ) + except Exception: + # should never happen but just in case + logger.exception( + "Failed to parse alert", + extra={ + "alert": alert, + }, ) - except Exception: - # should never happen but just in case - logger.exception( - "Failed to parse alert", - extra={ - "alert": 
alert, - }, - ) - continue - - alert_dto.event_id = str(alert.id) - - # enrich provider id when it's possible - if alert_dto.providerId is None: - alert_dto.providerId = alert.provider_id - alert_dto.providerType = alert.provider_type - alerts_dto.append(alert_dto) + continue + + alert_dto.event_id = str(alert.id) + + # enrich provider id when it's possible + if alert_dto.providerId is None: + alert_dto.providerId = alert.provider_id + alert_dto.providerType = alert.provider_type + alerts_dto.append(alert_dto) return alerts_dto diff --git a/keep/rulesengine/rulesengine.py b/keep/rulesengine/rulesengine.py index 363901c16..03538b8e3 100644 --- a/keep/rulesengine/rulesengine.py +++ b/keep/rulesengine/rulesengine.py @@ -81,7 +81,7 @@ def run_rules( ) incident = assign_alert_to_incident( - alert_id=event.event_id, + fingerprint=event.fingerprint, incident=incident, tenant_id=self.tenant_id, session=session, @@ -101,7 +101,7 @@ def run_rules( ): should_resolve = True - if ( + elif ( rule.resolve_on == ResolveOn.LAST.value and is_last_incident_alert_resolved(incident, session=session) ): diff --git a/tests/conftest.py b/tests/conftest.py index e1e21649d..134db6a0f 100644 --- a/tests/conftest.py +++ b/tests/conftest.py @@ -16,6 +16,7 @@ from sqlmodel import SQLModel, Session, create_engine from starlette_context import context, request_cycle_context +from keep.api.core.db import set_last_alert # This import is required to create the tables from keep.api.core.dependencies import SINGLE_TENANT_UUID from keep.api.core.elastic import ElasticClient @@ -547,6 +548,10 @@ def _setup_stress_alerts_no_elastic(num_alerts): db_session.add_all(alerts) db_session.commit() + last_alerts = [] + for alert in alerts: + set_last_alert(SINGLE_TENANT_UUID, alert, db_session) + return alerts return _setup_stress_alerts_no_elastic diff --git a/tests/test_incidents.py b/tests/test_incidents.py index e675c9075..cc1e32a90 100644 --- a/tests/test_incidents.py +++ b/tests/test_incidents.py @@ -2,8 
+2,7 @@ from itertools import cycle import pytest -from sqlalchemy import distinct, func -from sqlalchemy.orm.exc import DetachedInstanceError +from sqlalchemy import distinct, func, desc from keep.api.core.db import ( IncidentSorting, @@ -14,7 +13,7 @@ get_incident_by_id, get_last_incidents, merge_incidents_to_id, - remove_alerts_to_incident_by_incident_id, + remove_alerts_to_incident_by_incident_id, enrich_incidents_with_alerts, ) from keep.api.core.db_utils import get_json_extract_field from keep.api.core.dependencies import SINGLE_TENANT_UUID @@ -24,7 +23,7 @@ IncidentSeverity, IncidentStatus, ) -from keep.api.models.db.alert import Alert, AlertToIncident +from keep.api.models.db.alert import Alert, AlertToIncident, LastAlertToIncident from keep.api.utils.enrichment_helpers import convert_db_alerts_to_dto_alerts from tests.fixtures.client import client, test_app # noqa @@ -50,7 +49,7 @@ def test_get_alerts_data_for_incident(db_session, create_alert): assert 100 == db_session.query(func.count(Alert.id)).scalar() assert 10 == unique_fingerprints - data = get_alerts_data_for_incident(SINGLE_TENANT_UUID, [a.id for a in alerts]) + data = get_alerts_data_for_incident(SINGLE_TENANT_UUID, [a.fingerprint for a in alerts]) assert data["sources"] == set([f"source_{i}" for i in range(10)]) assert data["services"] == set([f"service_{i}" for i in range(10)]) assert data["count"] == unique_fingerprints @@ -64,16 +63,27 @@ def test_add_remove_alert_to_incidents(db_session, setup_stress_alerts_no_elasti SINGLE_TENANT_UUID, {"user_generated_name": "test", "user_summary": "test"} ) - assert len(incident.alerts) == 0 + incident_alerts, total_incident_alerts = get_incident_alerts_by_incident_id( + tenant_id=SINGLE_TENANT_UUID, + incident_id=incident.id, + ) + + assert len(incident_alerts) == 0 + assert total_incident_alerts == 0 add_alerts_to_incident_by_incident_id( - SINGLE_TENANT_UUID, incident.id, [a.id for a in alerts] + SINGLE_TENANT_UUID, incident.id, [a.fingerprint for a in 
alerts] ) incident = get_incident_by_id(SINGLE_TENANT_UUID, incident.id) - # 110 alerts - assert len(incident.alerts) == 110 + incident_alerts, total_incident_alerts = get_incident_alerts_by_incident_id( + tenant_id=SINGLE_TENANT_UUID, + incident_id=incident.id, + ) + + assert len(incident_alerts) == 100 + assert total_incident_alerts == 100 # But 100 unique fingerprints assert incident.alerts_count == 100 @@ -86,7 +96,7 @@ def test_add_remove_alert_to_incidents(db_session, setup_stress_alerts_no_elasti service_field = get_json_extract_field(db_session, Alert.event, "service") - service_0 = db_session.query(Alert.id).filter(service_field == "service_0").all() + service_0 = db_session.query(Alert.fingerprint).filter(service_field == "service_0").all() # Testing unique fingerprints more_alerts_with_same_fingerprints = setup_stress_alerts_no_elastic(10) @@ -94,26 +104,32 @@ def test_add_remove_alert_to_incidents(db_session, setup_stress_alerts_no_elasti add_alerts_to_incident_by_incident_id( SINGLE_TENANT_UUID, incident.id, - [a.id for a in more_alerts_with_same_fingerprints], + [a.fingerprint for a in more_alerts_with_same_fingerprints], ) incident = get_incident_by_id(SINGLE_TENANT_UUID, incident.id) assert incident.alerts_count == 100 - assert db_session.query(func.count(AlertToIncident.alert_id)).scalar() == 120 + assert db_session.query(func.count(LastAlertToIncident.fingerprint)).scalar() == 100 remove_alerts_to_incident_by_incident_id( SINGLE_TENANT_UUID, incident.id, [ - service_0[0].id, + service_0[0].fingerprint, ], ) incident = get_incident_by_id(SINGLE_TENANT_UUID, incident.id) - # 117 because we removed multiple alerts with service_0 - assert len(incident.alerts) == 117 + incident_alerts, total_incident_alerts = get_incident_alerts_by_incident_id( + tenant_id=SINGLE_TENANT_UUID, + incident_id=incident.id, + ) + + assert len(incident_alerts) == 99 + assert total_incident_alerts == 99 + assert "service_0" in incident.affected_services assert 
len(incident.affected_services) == 10 assert sorted(incident.affected_services) == sorted( @@ -121,7 +137,7 @@ def test_add_remove_alert_to_incidents(db_session, setup_stress_alerts_no_elasti ) remove_alerts_to_incident_by_incident_id( - SINGLE_TENANT_UUID, incident.id, [a.id for a in service_0] + SINGLE_TENANT_UUID, incident.id, [a.fingerprint for a in service_0] ) # Removing shouldn't impact links between alert and incident if include_unlinked=True @@ -138,8 +154,14 @@ def test_add_remove_alert_to_incidents(db_session, setup_stress_alerts_no_elasti incident = get_incident_by_id(SINGLE_TENANT_UUID, incident.id) - # 108 because we removed multiple alert with same fingerprints - assert len(incident.alerts) == 108 + incident_alerts, total_incident_alerts = get_incident_alerts_by_incident_id( + tenant_id=SINGLE_TENANT_UUID, + incident_id=incident.id, + ) + + assert len(incident_alerts) == 90 + assert total_incident_alerts == 90 + assert "service_0" not in incident.affected_services assert len(incident.affected_services) == 9 assert sorted(incident.affected_services) == sorted( @@ -147,29 +169,36 @@ def test_add_remove_alert_to_incidents(db_session, setup_stress_alerts_no_elasti ) source_1 = ( - db_session.query(Alert.id).filter(Alert.provider_type == "source_1").all() + db_session.query(Alert.fingerprint).filter(Alert.provider_type == "source_1").all() ) remove_alerts_to_incident_by_incident_id( SINGLE_TENANT_UUID, incident.id, [ - source_1[0].id, + source_1[0].fingerprint, ], ) incident = get_incident_by_id(SINGLE_TENANT_UUID, incident.id) - assert len(incident.alerts) == 105 - assert "source_1" in incident.sources - # source_0 was removed together with service_0 - assert len(incident.sources) == 9 + incident_alerts, total_incident_alerts = get_incident_alerts_by_incident_id( + tenant_id=SINGLE_TENANT_UUID, + incident_id=incident.id, + ) + + assert len(incident_alerts) == 89 + assert total_incident_alerts == 89 + + assert "source_1" not in incident.sources + # 
source_0 was removed together with service_1 + assert len(incident.sources) == 8 assert sorted(incident.sources) == sorted( - ["source_{}".format(i) for i in range(1, 10)] + ["source_{}".format(i) for i in range(2, 10)] ) remove_alerts_to_incident_by_incident_id( - "keep", incident.id, [a.id for a in source_1] + "keep", incident.id, [a.fingerprint for a in source_1] ) incident = get_incident_by_id(SINGLE_TENANT_UUID, incident.id) @@ -213,8 +242,19 @@ def test_get_last_incidents(db_session, create_alert): ) alert = db_session.query(Alert).order_by(Alert.timestamp.desc()).first() + create_alert( + f"alert-test-2-{i}", + AlertStatus(status), + datetime.utcnow(), + { + "severity": AlertSeverity.from_number(severity), + "service": service, + }, + ) + alert2 = db_session.query(Alert).order_by(Alert.timestamp.desc()).first() + add_alerts_to_incident_by_incident_id( - SINGLE_TENANT_UUID, incident.id, [alert.id] + SINGLE_TENANT_UUID, incident.id, [alert.fingerprint, alert2.fingerprint] ) incidents_default, incidents_default_count = get_last_incidents(SINGLE_TENANT_UUID) @@ -246,19 +286,14 @@ def test_get_last_incidents(db_session, create_alert): for i, j in enumerate(range(5, 10)): assert incidents_limit_5_page_2[i].user_generated_name == f"test-{j}" - # If alerts not preloaded, we will have detached session issue during attempt to get them - # Background on this error at: https://sqlalche.me/e/14/bhk3 - with pytest.raises(DetachedInstanceError): - alerts = incidents_confirmed[0].alerts # noqa - incidents_with_alerts, _ = get_last_incidents( SINGLE_TENANT_UUID, is_confirmed=True, with_alerts=True ) for i in range(25): if incidents_with_alerts[i].status == IncidentStatus.MERGED.value: - assert len(incidents_with_alerts[i].alerts) == 0 + assert len(incidents_with_alerts[i]._alerts) == 0 else: - assert len(incidents_with_alerts[i].alerts) == 1 + assert len(incidents_with_alerts[i]._alerts) == 2 # Test sorting @@ -319,11 +354,16 @@ def test_incident_status_change( "keep", 
{"name": "test", "description": "test"} ) - add_alerts_to_incident_by_incident_id("keep", incident.id, [a.id for a in alerts]) + add_alerts_to_incident_by_incident_id( + "keep", + incident.id, + [a.fingerprint for a in alerts], + session=db_session + ) - incident = get_incident_by_id("keep", incident.id, with_alerts=True) + incident = get_incident_by_id("keep", incident.id, with_alerts=True, session=db_session) - alerts_dtos = convert_db_alerts_to_dto_alerts(incident.alerts) + alerts_dtos = convert_db_alerts_to_dto_alerts(incident._alerts, session=db_session) assert ( len( [ @@ -348,10 +388,11 @@ def test_incident_status_change( assert data["id"] == str(incident.id) assert data["status"] == IncidentStatus.ACKNOWLEDGED.value - incident = get_incident_by_id("keep", incident.id, with_alerts=True) + db_session.expire_all() + incident = get_incident_by_id("keep", incident.id, with_alerts=True, session=db_session) assert incident.status == IncidentStatus.ACKNOWLEDGED.value - alerts_dtos = convert_db_alerts_to_dto_alerts(incident.alerts) + alerts_dtos = convert_db_alerts_to_dto_alerts(incident._alerts) assert ( len( [ @@ -376,11 +417,12 @@ def test_incident_status_change( assert data["id"] == str(incident.id) assert data["status"] == IncidentStatus.RESOLVED.value - incident = get_incident_by_id("keep", incident.id, with_alerts=True) + db_session.expire_all() + incident = get_incident_by_id("keep", incident.id, with_alerts=True, session=db_session) assert incident.status == IncidentStatus.RESOLVED.value # All alerts are resolved as well - alerts_dtos = convert_db_alerts_to_dto_alerts(incident.alerts) + alerts_dtos = convert_db_alerts_to_dto_alerts(incident._alerts, session=db_session) assert ( len( [ @@ -475,23 +517,48 @@ def test_add_alerts_with_same_fingerprint_to_incident(db_session, create_alert): SINGLE_TENANT_UUID, {"user_generated_name": "test", "user_summary": "test"} ) - assert len(incident.alerts) == 0 + incident_alerts, total_incident_alerts = 
get_incident_alerts_by_incident_id( + tenant_id=SINGLE_TENANT_UUID, + incident_id=incident.id, + ) + + assert len(incident_alerts) == 0 + assert total_incident_alerts == 0 add_alerts_to_incident_by_incident_id( - SINGLE_TENANT_UUID, incident.id, [fp1_alerts[0].id] + SINGLE_TENANT_UUID, incident.id, [fp1_alerts[0].fingerprint] ) incident = get_incident_by_id(SINGLE_TENANT_UUID, incident.id) - assert len(incident.alerts) == 2 + incident_alerts, total_incident_alerts = get_incident_alerts_by_incident_id( + tenant_id=SINGLE_TENANT_UUID, + incident_id=incident.id, + ) + + assert len(incident_alerts) == 1 + last_fp1_alert = ( + db_session + .query(Alert.timestamp) + .where(Alert.fingerprint == "fp1") + .order_by(desc(Alert.timestamp)).first() + ) + assert incident_alerts[0].timestamp == last_fp1_alert.timestamp + assert total_incident_alerts == 1 remove_alerts_to_incident_by_incident_id( - SINGLE_TENANT_UUID, incident.id, [fp1_alerts[0].id] + SINGLE_TENANT_UUID, incident.id, [fp1_alerts[0].fingerprint] ) incident = get_incident_by_id(SINGLE_TENANT_UUID, incident.id) - assert len(incident.alerts) == 0 + incident_alerts, total_incident_alerts = get_incident_alerts_by_incident_id( + tenant_id=SINGLE_TENANT_UUID, + incident_id=incident.id, + ) + + assert len(incident_alerts) == 0 + assert total_incident_alerts == 0 def test_merge_incidents(db_session, create_alert, setup_stress_alerts_no_elastic): @@ -522,7 +589,7 @@ def test_merge_incidents(db_session, create_alert, setup_stress_alerts_no_elasti ) alerts_1 = db_session.query(Alert).all() add_alerts_to_incident_by_incident_id( - SINGLE_TENANT_UUID, incident_1.id, [a.id for a in alerts_1] + SINGLE_TENANT_UUID, incident_1.id, [a.fingerprint for a in alerts_1] ) incident_2 = create_incident_from_dict( SINGLE_TENANT_UUID, @@ -532,19 +599,19 @@ def test_merge_incidents(db_session, create_alert, setup_stress_alerts_no_elasti }, ) create_alert( - "fp20", + "fp20-0", AlertStatus.FIRING, datetime.utcnow(), {"severity": 
AlertSeverity.CRITICAL.value}, ) create_alert( - "fp20", + "fp20-1", AlertStatus.FIRING, datetime.utcnow(), {"severity": AlertSeverity.CRITICAL.value}, ) create_alert( - "fp20", + "fp20-2", AlertStatus.FIRING, datetime.utcnow(), {"severity": AlertSeverity.CRITICAL.value}, @@ -553,7 +620,7 @@ def test_merge_incidents(db_session, create_alert, setup_stress_alerts_no_elasti db_session.query(Alert).filter(Alert.fingerprint.startswith("fp20")).all() ) add_alerts_to_incident_by_incident_id( - SINGLE_TENANT_UUID, incident_2.id, [a.id for a in alerts_2] + SINGLE_TENANT_UUID, incident_2.id, [a.fingerprint for a in alerts_2] ) incident_3 = create_incident_from_dict( SINGLE_TENANT_UUID, @@ -563,28 +630,28 @@ def test_merge_incidents(db_session, create_alert, setup_stress_alerts_no_elasti }, ) create_alert( - "fp30", + "fp30-0", AlertStatus.FIRING, datetime.utcnow(), {"severity": AlertSeverity.WARNING.value}, ) create_alert( - "fp30", + "fp30-1", AlertStatus.FIRING, datetime.utcnow(), - {"severity": AlertSeverity.WARNING.value}, + {"severity": AlertSeverity.INFO.value}, ) create_alert( - "fp30", + "fp30-2", AlertStatus.FIRING, datetime.utcnow(), - {"severity": AlertSeverity.INFO.value}, + {"severity": AlertSeverity.WARNING.value}, ) alerts_3 = ( db_session.query(Alert).filter(Alert.fingerprint.startswith("fp30")).all() ) add_alerts_to_incident_by_incident_id( - SINGLE_TENANT_UUID, incident_3.id, [a.id for a in alerts_3] + SINGLE_TENANT_UUID, incident_3.id, [a.fingerprint for a in alerts_3] ) # before merge @@ -602,19 +669,21 @@ def test_merge_incidents(db_session, create_alert, setup_stress_alerts_no_elasti "test-user-email", ) + db_session.expire_all() + incident_1 = get_incident_by_id(SINGLE_TENANT_UUID, incident_1.id, with_alerts=True) - assert len(incident_1.alerts) == 9 + assert len(incident_1._alerts) == 8 assert incident_1.severity == IncidentSeverity.CRITICAL.order incident_2 = get_incident_by_id(SINGLE_TENANT_UUID, incident_2.id, with_alerts=True) - assert 
len(incident_2.alerts) == 0 + assert len(incident_2._alerts) == 0 assert incident_2.status == IncidentStatus.MERGED.value assert incident_2.merged_into_incident_id == incident_1.id assert incident_2.merged_at is not None assert incident_2.merged_by == "test-user-email" - incident_3 = get_incident_by_id(SINGLE_TENANT_UUID, incident_3.id, with_alerts=True) - assert len(incident_3.alerts) == 0 + incident_3 = get_incident_by_id(SINGLE_TENANT_UUID, incident_3.id, with_alerts=True, session=db_session) + assert len(incident_3._alerts) == 0 assert incident_3.status == IncidentStatus.MERGED.value assert incident_3.merged_into_incident_id == incident_1.id assert incident_3.merged_at is not None diff --git a/tests/test_metrics.py b/tests/test_metrics.py index fcd4306d6..8908896b6 100644 --- a/tests/test_metrics.py +++ b/tests/test_metrics.py @@ -18,7 +18,7 @@ def test_add_remove_alert_to_incidents( valid_api_key = "valid_api_key" setup_api_key(db_session, valid_api_key) - add_alerts_to_incident_by_incident_id("keep", incident.id, [a.id for a in alerts]) + add_alerts_to_incident_by_incident_id("keep", incident.id, [a.fingerprint for a in alerts]) response = client.get("/metrics?labels=a.b", headers={"X-API-KEY": "valid_api_key"}) diff --git a/tests/test_rules_engine.py b/tests/test_rules_engine.py index 5abf1f5a9..d075a43c8 100644 --- a/tests/test_rules_engine.py +++ b/tests/test_rules_engine.py @@ -8,7 +8,8 @@ import pytest from sqlalchemy import desc, asc -from keep.api.core.db import create_rule as create_rule_db, get_last_incidents, get_incident_alerts_by_incident_id +from keep.api.core.db import create_rule as create_rule_db, get_last_incidents, get_incident_alerts_by_incident_id, \ + set_last_alert from keep.api.core.db import get_rules as get_rules_db from keep.api.core.dependencies import SINGLE_TENANT_UUID from keep.api.models.alert import AlertDto, AlertSeverity, AlertStatus, IncidentSeverity, IncidentStatus @@ -58,10 +59,13 @@ def test_sanity(db_session): 
provider_type="test", provider_id="test", event=alerts[0].dict(), - fingerprint="test", + fingerprint=alerts[0].fingerprint, ) + db_session.add(alert) db_session.commit() + + set_last_alert(SINGLE_TENANT_UUID, alert, db_session) # run the rules engine alerts[0].event_id = alert.id results = rules_engine.run_rules(alerts) @@ -105,10 +109,11 @@ def test_sanity_2(db_session): provider_type="test", provider_id="test", event=alerts[0].dict(), - fingerprint="test", + fingerprint=alerts[0].fingerprint, ) db_session.add(alert) db_session.commit() + set_last_alert(SINGLE_TENANT_UUID, alert, db_session) # run the rules engine alerts[0].event_id = alert.id results = rules_engine.run_rules(alerts) @@ -153,10 +158,11 @@ def test_sanity_3(db_session): provider_type="test", provider_id="test", event=alerts[0].dict(), - fingerprint="test", + fingerprint=alerts[0].fingerprint, ) db_session.add(alert) db_session.commit() + set_last_alert(SINGLE_TENANT_UUID, alert, db_session) # run the rules engine alerts[0].event_id = alert.id results = rules_engine.run_rules(alerts) @@ -201,10 +207,11 @@ def test_sanity_4(db_session): provider_type="test", provider_id="test", event=alerts[0].dict(), - fingerprint="test", + fingerprint=alerts[0].fingerprint, ) db_session.add(alert) db_session.commit() + set_last_alert(SINGLE_TENANT_UUID, alert, db_session) # run the rules engine alerts[0].event_id = alert.id results = rules_engine.run_rules(alerts) @@ -218,7 +225,7 @@ def test_incident_attributes(db_session): AlertDto( id=str(uuid.uuid4()), source=["grafana"], - name="grafana-test-alert", + name=f"grafana-test-alert-{i}", status=AlertStatus.FIRING, severity=AlertSeverity.CRITICAL, lastReceived=datetime.datetime.now().isoformat(), @@ -250,13 +257,15 @@ def test_incident_attributes(db_session): provider_type="test", provider_id="test", event=alert.dict(), - fingerprint=hashlib.sha256(json.dumps(alert.dict()).encode()).hexdigest(), + fingerprint=alert.fingerprint, timestamp=alert.lastReceived, ) for 
alert in alerts_dto ] db_session.add_all(alerts) db_session.commit() + for alert in alerts: + set_last_alert(SINGLE_TENANT_UUID, alert, db_session) for i, alert in enumerate(alerts_dto): alert.event_id = alerts[i].id @@ -275,7 +284,7 @@ def test_incident_severity(db_session): AlertDto( id=str(uuid.uuid4()), source=["grafana"], - name="grafana-test-alert", + name=f"grafana-test-alert-{i}", status=AlertStatus.FIRING, severity=AlertSeverity.INFO, lastReceived=datetime.datetime.now().isoformat(), @@ -307,13 +316,15 @@ def test_incident_severity(db_session): provider_type="test", provider_id="test", event=alert.dict(), - fingerprint=hashlib.sha256(json.dumps(alert.dict()).encode()).hexdigest(), + fingerprint=alert.fingerprint, timestamp=alert.lastReceived, ) for alert in alerts_dto ] db_session.add_all(alerts) db_session.commit() + for alert in alerts: + set_last_alert(SINGLE_TENANT_UUID, alert, db_session) for i, alert in enumerate(alerts_dto): alert.event_id = alerts[i].id