From 70d9caf9a4e9b50906fe84e06dc51e331468328b Mon Sep 17 00:00:00 2001
From: Tal
Date: Mon, 2 Dec 2024 16:44:48 +0200
Subject: [PATCH 01/12] fix: pagerduty to get single incident (#2726)

---
 .../pagerduty_provider/pagerduty_provider.py | 35 ++++++++++++++-----
 1 file changed, 26 insertions(+), 9 deletions(-)

diff --git a/keep/providers/pagerduty_provider/pagerduty_provider.py b/keep/providers/pagerduty_provider/pagerduty_provider.py
index 991b146e0..c1ada8245 100644
--- a/keep/providers/pagerduty_provider/pagerduty_provider.py
+++ b/keep/providers/pagerduty_provider/pagerduty_provider.py
@@ -593,15 +593,10 @@ def _notify(
         )

     def _query(self, incident_id: str = None):
-        incidents = self.__get_all_incidents_or_alerts()
-        return (
-            next(
-                [incident for incident in incidents if incident.id == incident_id],
-                None,
-            )
-            if incident_id
-            else incidents
-        )
+        if incident_id:
+            return self._get_specific_incident(incident_id)
+        else:
+            return self.__get_all_incidents_or_alerts()

     @staticmethod
     def _format_alert(
@@ -694,6 +689,28 @@ def _format_alert_old(event: dict) -> AlertDto:
             labels=metadata,
         )

+    def _get_specific_incident(self, incident_id: str):
+        self.logger.info("Getting Incident", extra={"incident_id": incident_id})
+        url = f"{self.BASE_API_URL}/incidents/{incident_id}"
+        params = {
+            "include[]": [
+                "acknowledgers",
+                "agents",
+                "assignees",
+                "conference_bridge",
+                "custom_fields",
+                "escalation_policies",
+                "first_trigger_log_entries",
+                "priorities",
+                "services",
+                "teams",
+                "users",
+            ]
+        }
+        response = requests.get(url, headers=self.__get_headers(), params=params)
+        response.raise_for_status()
+        return response.json()
+
     def __get_all_incidents_or_alerts(self, incident_id: str = None):
         self.logger.info(
             "Getting incidents or alerts", extra={"incident_id": incident_id}
         )

From 01d29ecb89505fe66fc2575d9f44c154d0c8d953 Mon Sep 17 00:00:00 2001
From: Vladimir Filonov
Date: Mon, 2 Dec 2024 19:59:42 +0400
Subject: [PATCH 02/12] feat: change relation between alerts and incidents to
 work with fingerprints instead of alert ids (#2473)

Signed-off-by: Vladimir Filonov
Co-authored-by: Kirill Chernakov
---
 .../alerts/alert-associate-incident-modal.tsx |   2 +-
 .../[id]/alerts/incident-alert-menu.tsx       |   2 +-
 keep-ui/entities/incidents/model/models.ts    |   1 +
 keep/api/bl/incidents_bl.py                   |  26 +-
 keep/api/core/db.py                           | 423 ++++++++++--------
 keep/api/models/db/alert.py                   |  85 ++--
 .../versions/2024-12-01-16-40_3ad5308e7200.py |   7 +-
 .../versions/2024-12-02-13-36_bdae8684d0b4.py | 161 +++++++
 keep/api/routes/incidents.py                  |  10 +-
 keep/api/routes/workflows.py                  |   3 +-
 keep/api/tasks/process_event_task.py          |   4 +
 keep/api/utils/enrichment_helpers.py          |  89 ++--
 keep/rulesengine/rulesengine.py               |   4 +-
 tests/conftest.py                             |  28 +-
 tests/test_incidents.py                       | 183 +++++---
 tests/test_metrics.py                         |   2 +-
 tests/test_rules_engine.py                    |  28 +-
 17 files changed, 686 insertions(+), 372 deletions(-)
 create mode 100644 keep/api/models/db/migrations/versions/2024-12-02-13-36_bdae8684d0b4.py

diff --git a/keep-ui/app/(keep)/alerts/alert-associate-incident-modal.tsx b/keep-ui/app/(keep)/alerts/alert-associate-incident-modal.tsx
index e8b85f1ee..fb1d8a8bc 100644
--- a/keep-ui/app/(keep)/alerts/alert-associate-incident-modal.tsx
+++ b/keep-ui/app/(keep)/alerts/alert-associate-incident-modal.tsx
@@ -42,7 +42,7 @@ const AlertAssociateIncidentModal = ({
     try {
       const response = await api.post(
         `/incidents/${incidentId}/alerts`,
-        alerts.map(({ event_id }) => event_id)
+        alerts.map(({ fingerprint }) => fingerprint)
       );
       handleSuccess();
       await mutate();
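An aside on the endpoint call patch 01 introduces, before the rest of this diff: the request pattern can be exercised on its own. This is a minimal sketch, assuming the public PagerDuty REST base URL, classic "Token token=" auth, and a trimmed include[] list; the provider itself derives all three via self.BASE_API_URL and self.__get_headers().

    # Standalone sketch of the single-incident fetch from patch 01.
    # Assumptions (not taken from the provider): base URL, auth scheme,
    # and a shortened include[] list.
    import requests

    PAGERDUTY_API = "https://api.pagerduty.com"

    def get_incident(incident_id: str, api_token: str) -> dict:
        url = f"{PAGERDUTY_API}/incidents/{incident_id}"
        # requests encodes the list as repeated include[]=... query params,
        # which is how the PagerDuty API expects side-loaded resources.
        params = {"include[]": ["assignees", "services", "teams"]}
        headers = {
            "Authorization": f"Token token={api_token}",
            "Content-Type": "application/json",
        }
        response = requests.get(url, headers=headers, params=params)
        response.raise_for_status()  # surface 4xx/5xx instead of an error body
        return response.json()

Raising for status here mirrors the patch: the old _query silently returned None for unknown ids, while the new path turns an unknown incident into an explicit HTTP error.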
diff --git a/keep-ui/app/(keep)/incidents/[id]/alerts/incident-alert-menu.tsx b/keep-ui/app/(keep)/incidents/[id]/alerts/incident-alert-menu.tsx index fb3ab512a..5ef346b8c 100644 --- a/keep-ui/app/(keep)/incidents/[id]/alerts/incident-alert-menu.tsx +++ b/keep-ui/app/(keep)/incidents/[id]/alerts/incident-alert-menu.tsx @@ -18,7 +18,7 @@ export default function IncidentAlertMenu({ incidentId, alert }: Props) { if (confirm("Are you sure you want to remove correlation?")) { api .delete(`/incidents/${incidentId}/alerts`, { - body: [alert.event_id], + body: [alert.fingerprint], }) .then(() => { toast.success("Alert removed from incident successfully", { diff --git a/keep-ui/entities/incidents/model/models.ts b/keep-ui/entities/incidents/model/models.ts index 0f8e8cfb2..bdc311549 100644 --- a/keep-ui/entities/incidents/model/models.ts +++ b/keep-ui/entities/incidents/model/models.ts @@ -31,6 +31,7 @@ export interface IncidentDto { merged_into_incident_id: string; merged_by: string; merged_at: Date; + fingerprint: string; } export interface IncidentCandidateDto { diff --git a/keep/api/bl/incidents_bl.py b/keep/api/bl/incidents_bl.py index a0bd58584..5bc3d9216 100644 --- a/keep/api/bl/incidents_bl.py +++ b/keep/api/bl/incidents_bl.py @@ -93,51 +93,51 @@ def create_incident( return new_incident_dto async def add_alerts_to_incident( - self, incident_id: UUID, alert_ids: List[UUID], is_created_by_ai: bool = False + self, incident_id: UUID, alert_fingerprints: List[str], is_created_by_ai: bool = False ) -> None: self.logger.info( "Adding alerts to incident", - extra={"incident_id": incident_id, "alert_ids": alert_ids}, + extra={"incident_id": incident_id, "alert_fingerprints": alert_fingerprints}, ) incident = get_incident_by_id(tenant_id=self.tenant_id, incident_id=incident_id) if not incident: raise HTTPException(status_code=404, detail="Incident not found") - add_alerts_to_incident_by_incident_id(self.tenant_id, incident_id, alert_ids, is_created_by_ai) + add_alerts_to_incident_by_incident_id(self.tenant_id, incident_id, alert_fingerprints, is_created_by_ai) self.logger.info( "Alerts added to incident", - extra={"incident_id": incident_id, "alert_ids": alert_ids}, + extra={"incident_id": incident_id, "alert_fingerprints": alert_fingerprints}, ) - self.__update_elastic(incident_id, alert_ids) + self.__update_elastic(incident_id, alert_fingerprints) self.logger.info( "Alerts pushed to elastic", - extra={"incident_id": incident_id, "alert_ids": alert_ids}, + extra={"incident_id": incident_id, "alert_fingerprints": alert_fingerprints}, ) self.__update_client_on_incident_change(incident_id) self.logger.info( "Client updated on incident change", - extra={"incident_id": incident_id, "alert_ids": alert_ids}, + extra={"incident_id": incident_id, "alert_fingerprints": alert_fingerprints}, ) incident_dto = IncidentDto.from_db_incident(incident) self.__run_workflows(incident_dto, "updated") self.logger.info( "Workflows run on incident", - extra={"incident_id": incident_id, "alert_ids": alert_ids}, + extra={"incident_id": incident_id, "alert_fingerprints": alert_fingerprints}, ) await self.__generate_summary(incident_id, incident) self.logger.info( "Summary generated", - extra={"incident_id": incident_id, "alert_ids": alert_ids}, + extra={"incident_id": incident_id, "alert_fingerprints": alert_fingerprints}, ) - def __update_elastic(self, incident_id: UUID, alert_ids: List[UUID]): + def __update_elastic(self, incident_id: UUID, alert_fingerprints: List[str]): try: elastic_client = ElasticClient(self.tenant_id) 
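            # Re-read the incident's alerts (capped at the number of
            # fingerprints just linked) and re-serialize them with their
            # incident references, so the Elasticsearch documents stay
            # consistent with the new fingerprint-based links.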
if elastic_client.enabled: db_alerts, _ = get_incident_alerts_by_incident_id( tenant_id=self.tenant_id, incident_id=incident_id, - limit=len(alert_ids), + limit=len(alert_fingerprints), ) enriched_alerts_dto = convert_db_alerts_to_dto_alerts( db_alerts, with_incidents=True @@ -203,7 +203,7 @@ async def __generate_summary(self, incident_id: UUID, incident: Incident): ) def delete_alerts_from_incident( - self, incident_id: UUID, alert_ids: List[UUID] + self, incident_id: UUID, alert_fingerprints: List[str] ) -> None: self.logger.info( "Fetching incident", @@ -216,7 +216,7 @@ def delete_alerts_from_incident( if not incident: raise HTTPException(status_code=404, detail="Incident not found") - remove_alerts_to_incident_by_incident_id(self.tenant_id, incident_id, alert_ids) + remove_alerts_to_incident_by_incident_id(self.tenant_id, incident_id, alert_fingerprints) def delete_incident(self, incident_id: UUID) -> None: self.logger.info( diff --git a/keep/api/core/db.py b/keep/api/core/db.py index cc9294f26..c600c1214 100644 --- a/keep/api/core/db.py +++ b/keep/api/core/db.py @@ -16,6 +16,7 @@ from uuid import uuid4 import validators +from dateutil.tz import tz from dotenv import find_dotenv, load_dotenv from opentelemetry.instrumentation.sqlalchemy import SQLAlchemyInstrumentor from sqlalchemy import ( @@ -44,8 +45,8 @@ # This import is required to create the tables from keep.api.models.ai_external import ( - ExternalAIConfigAndMetadata, - ExternalAIConfigAndMetadataDto, + ExternalAIConfigAndMetadata, + ExternalAIConfigAndMetadataDto, ) from keep.api.models.alert import ( AlertStatus, @@ -1302,13 +1303,13 @@ def get_last_alerts( # SQLite version - using JSON incidents_subquery = ( session.query( - AlertToIncident.alert_id, + LastAlertToIncident.fingerprint, func.json_group_array( - cast(AlertToIncident.incident_id, String) + cast(LastAlertToIncident.incident_id, String) ).label("incidents"), ) - .filter(AlertToIncident.deleted_at == NULL_FOR_DELETED_AT) - .group_by(AlertToIncident.alert_id) + .filter(LastAlertToIncident.deleted_at == NULL_FOR_DELETED_AT) + .group_by(LastAlertToIncident.fingerprint) .subquery() ) @@ -1316,13 +1317,13 @@ def get_last_alerts( # MySQL version - using GROUP_CONCAT incidents_subquery = ( session.query( - AlertToIncident.alert_id, + LastAlertToIncident.fingerprint, func.group_concat( - cast(AlertToIncident.incident_id, String) + cast(LastAlertToIncident.incident_id, String) ).label("incidents"), ) - .filter(AlertToIncident.deleted_at == NULL_FOR_DELETED_AT) - .group_by(AlertToIncident.alert_id) + .filter(LastAlertToIncident.deleted_at == NULL_FOR_DELETED_AT) + .group_by(LastAlertToIncident.fingerprint) .subquery() ) @@ -1330,14 +1331,14 @@ def get_last_alerts( # PostgreSQL version - using string_agg incidents_subquery = ( session.query( - AlertToIncident.alert_id, + LastAlertToIncident.fingerprint, func.string_agg( - cast(AlertToIncident.incident_id, String), + cast(LastAlertToIncident.incident_id, String), ",", ).label("incidents"), ) - .filter(AlertToIncident.deleted_at == NULL_FOR_DELETED_AT) - .group_by(AlertToIncident.alert_id) + .filter(LastAlertToIncident.deleted_at == NULL_FOR_DELETED_AT) + .group_by(LastAlertToIncident.fingerprint) .subquery() ) else: @@ -1345,7 +1346,7 @@ def get_last_alerts( query = query.add_columns(incidents_subquery.c.incidents) query = query.outerjoin( - incidents_subquery, Alert.id == incidents_subquery.c.alert_id + incidents_subquery, Alert.fingerprint == incidents_subquery.c.fingerprint ) if provider_id: @@ -1741,9 +1742,10 @@ def 
get_incident_for_grouping_rule( # if the last alert in the incident is older than the timeframe, create a new incident is_incident_expired = False - if incident and incident.alerts: + if incident and incident.alerts_count > 0: + enrich_incidents_with_alerts(tenant_id, [incident], session) is_incident_expired = max( - alert.timestamp for alert in incident.alerts + alert.timestamp for alert in incident._alerts ) < datetime.utcnow() - timedelta(seconds=timeframe) # if there is no incident with the rule_fingerprint, create it or existed is already expired @@ -1792,13 +1794,13 @@ def get_rule_distribution(tenant_id, minute=False): # Check the dialect if session.bind.dialect.name == "mysql": time_format = "%Y-%m-%d %H:%i" if minute else "%Y-%m-%d %H" - timestamp_format = func.date_format(AlertToIncident.timestamp, time_format) + timestamp_format = func.date_format(LastAlertToIncident.timestamp, time_format) elif session.bind.dialect.name == "postgresql": time_format = "YYYY-MM-DD HH:MI" if minute else "YYYY-MM-DD HH" - timestamp_format = func.to_char(AlertToIncident.timestamp, time_format) + timestamp_format = func.to_char(LastAlertToIncident.timestamp, time_format) elif session.bind.dialect.name == "sqlite": time_format = "%Y-%m-%d %H:%M" if minute else "%Y-%m-%d %H" - timestamp_format = func.strftime(time_format, AlertToIncident.timestamp) + timestamp_format = func.strftime(time_format, LastAlertToIncident.timestamp) else: raise ValueError("Unsupported database dialect") # Construct the query @@ -1806,20 +1808,20 @@ def get_rule_distribution(tenant_id, minute=False): session.query( Rule.id.label("rule_id"), Rule.name.label("rule_name"), - Incident.id.label("group_id"), + Incident.id.label("incident_id"), Incident.rule_fingerprint.label("rule_fingerprint"), timestamp_format.label("time"), - func.count(AlertToIncident.alert_id).label("hits"), + func.count(LastAlertToIncident.fingerprint).label("hits"), ) .join(Incident, Rule.id == Incident.rule_id) - .join(AlertToIncident, Incident.id == AlertToIncident.incident_id) + .join(LastAlertToIncident, Incident.id == LastAlertToIncident.incident_id) .filter( - AlertToIncident.deleted_at == NULL_FOR_DELETED_AT, - AlertToIncident.timestamp >= seven_days_ago, + LastAlertToIncident.deleted_at == NULL_FOR_DELETED_AT, + LastAlertToIncident.timestamp >= seven_days_ago, ) .filter(Rule.tenant_id == tenant_id) # Filter by tenant_id .group_by( - "rule_id", "rule_name", "incident_id", "rule_fingerprint", "time" + Rule.id, "rule_name", Incident.id, "rule_fingerprint", "time" ) # Adjusted here .order_by("time") ) @@ -2824,24 +2826,24 @@ def update_preset_options(tenant_id: str, preset_id: str, options: dict) -> Pres def assign_alert_to_incident( - alert_id: UUID | str, + fingerprint: str, incident: Incident, tenant_id: str, session: Optional[Session] = None, ): - return add_alerts_to_incident(tenant_id, incident, [alert_id], session=session) + return add_alerts_to_incident(tenant_id, incident, [fingerprint], session=session) def is_alert_assigned_to_incident( - alert_id: UUID, incident_id: UUID, tenant_id: str + fingerprint: str, incident_id: UUID, tenant_id: str ) -> bool: with Session(engine) as session: assigned = session.exec( - select(AlertToIncident) - .where(AlertToIncident.alert_id == alert_id) - .where(AlertToIncident.incident_id == incident_id) - .where(AlertToIncident.tenant_id == tenant_id) - .where(AlertToIncident.deleted_at == NULL_FOR_DELETED_AT) + select(LastAlertToIncident) + .where(LastAlertToIncident.fingerprint == fingerprint) + 
.where(LastAlertToIncident.incident_id == incident_id) + .where(LastAlertToIncident.tenant_id == tenant_id) + .where(LastAlertToIncident.deleted_at == NULL_FOR_DELETED_AT) ).first() return assigned is not None @@ -3106,6 +3108,32 @@ def filter_query(session: Session, query, field, value): return query +def enrich_incidents_with_alerts(tenant_id: str, incidents: List[Incident], session: Optional[Session]=None): + with existed_or_new_session(session) as session: + incident_alerts = session.exec( + select(LastAlertToIncident.incident_id, Alert) + .select_from(LastAlert) + .join(LastAlertToIncident, and_( + LastAlertToIncident.fingerprint == LastAlert.fingerprint, + LastAlertToIncident.deleted_at == NULL_FOR_DELETED_AT, + )) + .join(Alert, LastAlert.alert_id == Alert.id) + .where( + LastAlert.tenant_id == tenant_id, + LastAlertToIncident.incident_id.in_([incident.id for incident in incidents]) + ) + ).all() + + alerts_per_incident = defaultdict(list) + for incident_id, alert in incident_alerts: + alerts_per_incident[incident_id].append(alert) + + for incident in incidents: + incident._alerts = alerts_per_incident[incident.id] + + return incidents + + def get_last_incidents( tenant_id: str, limit: int = 25, @@ -3147,9 +3175,6 @@ def get_last_incidents( if allowed_incident_ids: query = query.filter(Incident.id.in_(allowed_incident_ids)) - if with_alerts: - query = query.options(joinedload(Incident.alerts)) - if is_predicted is not None: query = query.filter(Incident.is_predicted == is_predicted) @@ -3181,23 +3206,30 @@ def get_last_incidents( # Execute the query incidents = query.all() + if with_alerts: + enrich_incidents_with_alerts(tenant_id, incidents, session) + return incidents, total_count def get_incident_by_id( - tenant_id: str, incident_id: str | UUID, with_alerts: bool = False + tenant_id: str, incident_id: str | UUID, with_alerts: bool = False, + session: Optional[Session] = None, ) -> Optional[Incident]: - with Session(engine) as session: + with existed_or_new_session(session) as session: query = session.query( Incident, ).filter( Incident.tenant_id == tenant_id, Incident.id == incident_id, ) + incident = query.first() if with_alerts: - query = query.options(joinedload(Incident.alerts)) + enrich_incidents_with_alerts( + tenant_id, [incident], session, + ) - return query.first() + return incident def create_incident_from_dto( @@ -3254,7 +3286,6 @@ def create_incident_from_dict( session.add(new_incident) session.commit() session.refresh(new_incident) - new_incident.alerts = [] return new_incident @@ -3271,7 +3302,6 @@ def update_incident_from_dto_by_id( Incident.tenant_id == tenant_id, Incident.id == incident_id, ) - .options(joinedload(Incident.alerts)) ).first() if not incident: @@ -3330,10 +3360,10 @@ def delete_incident_by_id( # Delete all associations with alerts: ( - session.query(AlertToIncident) + session.query(LastAlertToIncident) .where( - AlertToIncident.tenant_id == tenant_id, - AlertToIncident.incident_id == incident.id, + LastAlertToIncident.tenant_id == tenant_id, + LastAlertToIncident.incident_id == incident.id, ) .delete() ) @@ -3363,46 +3393,27 @@ def get_incident_alerts_and_links_by_incident_id( offset: Optional[int] = 0, session: Optional[Session] = None, include_unlinked: bool = False, -) -> tuple[List[tuple[Alert, AlertToIncident]], int]: +) -> tuple[List[tuple[Alert, LastAlertToIncident]], int]: with existed_or_new_session(session) as session: - last_fingerprints_subquery = ( - session.query( - Alert.fingerprint, func.max(Alert.timestamp).label("max_timestamp") - 
) - .join(AlertToIncident, AlertToIncident.alert_id == Alert.id) - .filter( - AlertToIncident.tenant_id == tenant_id, - AlertToIncident.incident_id == incident_id, - ) - .group_by(Alert.fingerprint) - .subquery() - ) - query = ( session.query( Alert, - AlertToIncident, - ) - .select_from(last_fingerprints_subquery) - .outerjoin( - Alert, - and_( - last_fingerprints_subquery.c.fingerprint == Alert.fingerprint, - last_fingerprints_subquery.c.max_timestamp == Alert.timestamp, - ), + LastAlertToIncident, ) - .join(AlertToIncident, AlertToIncident.alert_id == Alert.id) + .select_from(LastAlertToIncident) + .join(LastAlert, LastAlert.fingerprint == LastAlertToIncident.fingerprint) + .join(Alert, LastAlert.alert_id == Alert.id) .filter( - AlertToIncident.tenant_id == tenant_id, - AlertToIncident.incident_id == incident_id, + LastAlertToIncident.tenant_id == tenant_id, + LastAlertToIncident.incident_id == incident_id, ) - .order_by(col(Alert.timestamp).desc()) + .order_by(col(LastAlert.timestamp).desc()) .options(joinedload(Alert.alert_enrichment)) ) if not include_unlinked: query = query.filter( - AlertToIncident.deleted_at == NULL_FOR_DELETED_AT, + LastAlertToIncident.deleted_at == NULL_FOR_DELETED_AT, ) total_count = query.count() @@ -3415,7 +3426,7 @@ def get_incident_alerts_and_links_by_incident_id( def get_incident_alerts_by_incident_id(*args, **kwargs) -> tuple[List[Alert], int]: """ - Unpacking (List[(Alert, AlertToIncident)], int) to (List[Alert], int). + Unpacking (List[(Alert, LastAlertToIncident)], int) to (List[Alert], int). """ alerts_and_links, total_alerts = get_incident_alerts_and_links_by_incident_id( *args, **kwargs @@ -3464,8 +3475,7 @@ def get_all_same_alert_ids( def get_alerts_data_for_incident( tenant_id: str, - alert_ids: List[str | UUID], - existed_fingerprints: Optional[List[str]] = None, + fingerprints: Optional[List[str]] = None, session: Optional[Session] = None, ) -> dict: """ @@ -3479,8 +3489,6 @@ def get_alerts_data_for_incident( Returns: dict {sources: list[str], services: list[str], count: int} """ - existed_fingerprints = existed_fingerprints or [] - with existed_or_new_session(session) as session: fields = ( @@ -3491,16 +3499,18 @@ def get_alerts_data_for_incident( ) alerts_data = session.exec( - select(*fields).where( - Alert.tenant_id == tenant_id, - col(Alert.id).in_(alert_ids), + select(*fields) + .select_from(LastAlert) + .join(Alert, LastAlert.alert_id == Alert.id) + .where( + LastAlert.tenant_id == tenant_id, + col(LastAlert.fingerprint).in_(fingerprints), ) ).all() sources = [] services = [] severities = [] - fingerprints = set() for service, source, fingerprint, severity in alerts_data: if source: @@ -3512,21 +3522,19 @@ def get_alerts_data_for_incident( severities.append(IncidentSeverity.from_number(severity)) else: severities.append(IncidentSeverity(severity)) - if fingerprint and fingerprint not in existed_fingerprints: - fingerprints.add(fingerprint) return { "sources": set(sources), "services": set(services), "max_severity": max(severities), - "count": len(fingerprints), + "count": len(alerts_data), } def add_alerts_to_incident_by_incident_id( tenant_id: str, incident_id: str | UUID, - alert_ids: List[UUID], + fingerprints: List[str], is_created_by_ai: bool = False, session: Optional[Session] = None, ) -> Optional[Incident]: @@ -3540,62 +3548,52 @@ def add_alerts_to_incident_by_incident_id( if not incident: return None return add_alerts_to_incident( - tenant_id, incident, alert_ids, is_created_by_ai, session + tenant_id, incident, fingerprints, 
is_created_by_ai, session ) def add_alerts_to_incident( tenant_id: str, incident: Incident, - alert_ids: List[UUID], + fingerprints: List[str], is_created_by_ai: bool = False, session: Optional[Session] = None, override_count: bool = False, ) -> Optional[Incident]: logger.info( - f"Adding alerts to incident {incident.id} in database, total {len(alert_ids)} alerts", + f"Adding alerts to incident {incident.id} in database, total {len(fingerprints)} alerts", extra={"tags": {"tenant_id": tenant_id, "incident_id": incident.id}}, ) with existed_or_new_session(session) as session: with session.no_autoflush: - all_alert_ids = get_all_same_alert_ids(tenant_id, alert_ids, session) # Use a set for faster membership checks - existing_alert_ids = set( - session.exec( - select(AlertToIncident.alert_id).where( - AlertToIncident.deleted_at == NULL_FOR_DELETED_AT, - AlertToIncident.tenant_id == tenant_id, - AlertToIncident.incident_id == incident.id, - col(AlertToIncident.alert_id).in_(all_alert_ids), - ) - ).all() - ) + existing_fingerprints = set( session.exec( - select(Alert.fingerprint) - .join(AlertToIncident, AlertToIncident.alert_id == Alert.id) + select(LastAlert.fingerprint) + .join(LastAlertToIncident, LastAlertToIncident.fingerprint == LastAlert.fingerprint) .where( - AlertToIncident.deleted_at == NULL_FOR_DELETED_AT, - AlertToIncident.tenant_id == tenant_id, - AlertToIncident.incident_id == incident.id, + LastAlertToIncident.deleted_at == NULL_FOR_DELETED_AT, + LastAlertToIncident.tenant_id == tenant_id, + LastAlertToIncident.incident_id == incident.id, ) ).all() ) - new_alert_ids = [ - alert_id - for alert_id in all_alert_ids - if alert_id not in existing_alert_ids - ] + new_fingerprints = { + fingerprint + for fingerprint in fingerprints + if fingerprint not in existing_fingerprints + } - if not new_alert_ids: + if not new_fingerprints: return incident alerts_data_for_incident = get_alerts_data_for_incident( - tenant_id, new_alert_ids, existing_fingerprints, session + tenant_id, new_fingerprints, session ) incident.sources = list( @@ -3617,13 +3615,13 @@ def add_alerts_to_incident( else: incident.alerts_count = alerts_data_for_incident["count"] alert_to_incident_entries = [ - AlertToIncident( - alert_id=alert_id, + LastAlertToIncident( + fingerprint=fingerprint, incident_id=incident.id, tenant_id=tenant_id, is_created_by_ai=is_created_by_ai, ) - for alert_id in new_alert_ids + for fingerprint in new_fingerprints ] for idx, entry in enumerate(alert_to_incident_entries): @@ -3640,11 +3638,11 @@ def add_alerts_to_incident( started_at, last_seen_at = session.exec( select(func.min(Alert.timestamp), func.max(Alert.timestamp)) - .join(AlertToIncident, AlertToIncident.alert_id == Alert.id) + .join(LastAlertToIncident, LastAlertToIncident.fingerprint == Alert.fingerprint) .where( - AlertToIncident.deleted_at == NULL_FOR_DELETED_AT, - AlertToIncident.tenant_id == tenant_id, - AlertToIncident.incident_id == incident.id, + LastAlertToIncident.deleted_at == NULL_FOR_DELETED_AT, + LastAlertToIncident.tenant_id == tenant_id, + LastAlertToIncident.incident_id == incident.id, ) ).one() @@ -3661,12 +3659,11 @@ def get_incident_unique_fingerprint_count(tenant_id: str, incident_id: str) -> i with Session(engine) as session: return session.execute( select(func.count(1)) - .select_from(AlertToIncident) - .join(Alert, AlertToIncident.alert_id == Alert.id) + .select_from(LastAlertToIncident) .where( - AlertToIncident.deleted_at == NULL_FOR_DELETED_AT, - Alert.tenant_id == tenant_id, - AlertToIncident.incident_id == 
incident_id, + LastAlertToIncident.deleted_at == NULL_FOR_DELETED_AT, + LastAlertToIncident.tenant_id == tenant_id, + LastAlertToIncident.incident_id == incident_id, ) ).scalar() @@ -3678,12 +3675,14 @@ def get_last_alerts_for_incidents( query = ( session.query( Alert, - AlertToIncident.incident_id, + LastAlertToIncident.incident_id, ) - .join(AlertToIncident, Alert.id == AlertToIncident.alert_id) + .select_from(LastAlert) + .join(LastAlertToIncident, LastAlert.fingerprint == LastAlertToIncident.fingerprint) + .join(Alert, LastAlert.alert_id == Alert.id) .filter( - AlertToIncident.deleted_at == NULL_FOR_DELETED_AT, - AlertToIncident.incident_id.in_(incident_ids), + LastAlertToIncident.deleted_at == NULL_FOR_DELETED_AT, + LastAlertToIncident.incident_id.in_(incident_ids), ) .order_by(Alert.timestamp.desc()) ) @@ -3698,7 +3697,7 @@ def get_last_alerts_for_incidents( def remove_alerts_to_incident_by_incident_id( - tenant_id: str, incident_id: str | UUID, alert_ids: List[UUID] + tenant_id: str, incident_id: str | UUID, fingerprints: List[str] ) -> Optional[int]: with Session(engine) as session: incident = session.exec( @@ -3711,16 +3710,14 @@ def remove_alerts_to_incident_by_incident_id( if not incident: return None - all_alert_ids = get_all_same_alert_ids(tenant_id, alert_ids, session) - # Removing alerts-to-incident relation for provided alerts_ids deleted = ( - session.query(AlertToIncident) + session.query(LastAlertToIncident) .where( - AlertToIncident.deleted_at == NULL_FOR_DELETED_AT, - AlertToIncident.tenant_id == tenant_id, - AlertToIncident.incident_id == incident.id, - col(AlertToIncident.alert_id).in_(all_alert_ids), + LastAlertToIncident.deleted_at == NULL_FOR_DELETED_AT, + LastAlertToIncident.tenant_id == tenant_id, + LastAlertToIncident.incident_id == incident.id, + col(LastAlertToIncident.fingerprint).in_(fingerprints), ) .update( { @@ -3732,7 +3729,7 @@ def remove_alerts_to_incident_by_incident_id( # Getting aggregated data for incidents for alerts which just was removed alerts_data_for_incident = get_alerts_data_for_incident( - tenant_id, all_alert_ids, session=session + tenant_id, fingerprints, session=session ) service_field = get_json_extract_field(session, Alert.event, "service") @@ -3741,10 +3738,12 @@ def remove_alerts_to_incident_by_incident_id( # which still assigned with the incident existed_services_query = ( select(func.distinct(service_field)) - .join(AlertToIncident, Alert.id == AlertToIncident.alert_id) + .select_from(LastAlert) + .join(LastAlertToIncident, LastAlert.fingerprint == LastAlertToIncident.fingerprint) + .join(Alert, LastAlert.alert_id == Alert.id) .filter( - AlertToIncident.deleted_at == NULL_FOR_DELETED_AT, - AlertToIncident.incident_id == incident_id, + LastAlertToIncident.deleted_at == NULL_FOR_DELETED_AT, + LastAlertToIncident.incident_id == incident_id, service_field.in_(alerts_data_for_incident["services"]), ) ) @@ -3754,10 +3753,12 @@ def remove_alerts_to_incident_by_incident_id( # which still assigned with the incident existed_sources_query = ( select(col(Alert.provider_type).distinct()) - .join(AlertToIncident, Alert.id == AlertToIncident.alert_id) + .select_from(LastAlert) + .join(LastAlertToIncident, LastAlert.fingerprint == LastAlertToIncident.fingerprint) + .join(Alert, LastAlert.alert_id == Alert.id) .filter( - AlertToIncident.deleted_at == NULL_FOR_DELETED_AT, - AlertToIncident.incident_id == incident_id, + LastAlertToIncident.deleted_at == NULL_FOR_DELETED_AT, + LastAlertToIncident.incident_id == incident_id, 
col(Alert.provider_type).in_(alerts_data_for_incident["sources"]), ) ) @@ -3777,10 +3778,12 @@ def remove_alerts_to_incident_by_incident_id( started_at, last_seen_at = session.exec( select(func.min(Alert.timestamp), func.max(Alert.timestamp)) - .join(AlertToIncident, AlertToIncident.alert_id == Alert.id) + .select_from(LastAlert) + .join(LastAlertToIncident, LastAlert.fingerprint == LastAlertToIncident.fingerprint) + .join(Alert, LastAlert.alert_id == Alert.id) .where( - AlertToIncident.tenant_id == tenant_id, - AlertToIncident.incident_id == incident.id, + LastAlertToIncident.tenant_id == tenant_id, + LastAlertToIncident.incident_id == incident.id, ) ).one() @@ -3822,7 +3825,6 @@ def merge_incidents_to_id( .where( Incident.tenant_id == tenant_id, Incident.id == destination_incident_id ) - .options(joinedload(Incident.alerts)) ).first() if not destination_incident: @@ -3837,12 +3839,14 @@ def merge_incidents_to_id( ) ).all() + enrich_incidents_with_alerts(tenant_id, source_incidents, session=session) + merged_incident_ids = [] skipped_incident_ids = [] failed_incident_ids = [] for source_incident in source_incidents: - source_incident_alerts_ids = [alert.id for alert in source_incident.alerts] - if not source_incident_alerts_ids: + source_incident_alerts_fingerprints = [alert.fingerprint for alert in source_incident._alerts] + if not source_incident_alerts_fingerprints: logger.info(f"Source incident {source_incident.id} doesn't have alerts") skipped_incident_ids.append(source_incident.id) continue @@ -3854,7 +3858,7 @@ def merge_incidents_to_id( remove_alerts_to_incident_by_incident_id( tenant_id, source_incident.id, - [alert.id for alert in source_incident.alerts], + [alert.fingerprint for alert in source_incident._alerts], ) except OperationalError as e: logger.error( @@ -3864,7 +3868,7 @@ def merge_incidents_to_id( add_alerts_to_incident( tenant_id, destination_incident, - source_incident_alerts_ids, + source_incident_alerts_fingerprints, session=session, ) merged_incident_ids.append(source_incident.id) @@ -4222,12 +4226,13 @@ def get_workflow_executions_for_incident_or_alert( # Query for workflow executions associated with alerts tied to the incident alert_query = ( base_query.join( - Alert, WorkflowToAlertExecution.alert_fingerprint == Alert.fingerprint + LastAlert, WorkflowToAlertExecution.alert_fingerprint == LastAlert.fingerprint ) - .join(AlertToIncident, Alert.id == AlertToIncident.alert_id) + .join(Alert, LastAlert.alert_id == Alert.id) + .join(LastAlertToIncident, Alert.fingerprint == LastAlertToIncident.fingerprint) .where( - AlertToIncident.deleted_at == NULL_FOR_DELETED_AT, - AlertToIncident.incident_id == incident_id, + LastAlertToIncident.deleted_at == NULL_FOR_DELETED_AT, + LastAlertToIncident.incident_id == incident_id, ) ) @@ -4270,17 +4275,16 @@ def is_all_incident_alerts_resolved( enriched_status_field.label("enriched_status"), status_field.label("status"), ) - .select_from(Alert) + .select_from(LastAlert) + .join(Alert, LastAlert.alert_id == Alert.id) .outerjoin( AlertEnrichment, Alert.fingerprint == AlertEnrichment.alert_fingerprint ) - .join(AlertToIncident, AlertToIncident.alert_id == Alert.id) + .join(LastAlertToIncident, LastAlertToIncident.fingerprint == LastAlert.fingerprint) .where( - AlertToIncident.deleted_at == NULL_FOR_DELETED_AT, - AlertToIncident.incident_id == incident.id, + LastAlertToIncident.deleted_at == NULL_FOR_DELETED_AT, + LastAlertToIncident.incident_id == incident.id, ) - .group_by(Alert.fingerprint) - .having(func.max(Alert.timestamp)) 
).subquery() not_resolved_exists = session.query( @@ -4337,8 +4341,8 @@ def is_edge_incident_alert_resolved( .outerjoin( AlertEnrichment, Alert.fingerprint == AlertEnrichment.alert_fingerprint ) - .join(AlertToIncident, AlertToIncident.alert_id == Alert.id) - .where(AlertToIncident.incident_id == incident.id) + .join(LastAlertToIncident, LastAlertToIncident.fingerprint == Alert.fingerprint) + .where(LastAlertToIncident.incident_id == incident.id) .group_by(Alert.fingerprint) .having(func.max(Alert.timestamp)) .order_by(direction(Alert.timestamp)) @@ -4379,11 +4383,12 @@ def get_alerts_metrics_by_provider( Alert.provider_id, func.count(Alert.id).label("total_alerts"), func.sum( - case([(AlertToIncident.alert_id.isnot(None), 1)], else_=0) + case([(LastAlertToIncident.fingerprint.isnot(None), 1)], else_=0) ).label("correlated_alerts"), *dynamic_field_sums, ) - .outerjoin(AlertToIncident, Alert.id == AlertToIncident.alert_id) + .join(LastAlert, Alert.id == LastAlert.alert_id) + .outerjoin(LastAlertToIncident, LastAlert.fingerprint == LastAlertToIncident.fingerprint) .filter( Alert.tenant_id == tenant_id, ) @@ -4504,28 +4509,28 @@ def get_resource_ids_by_resource_type( result = session.exec(query) return result.all() - -def get_or_creat_posthog_instance_id(session: Optional[Session] = None): - POSTHOG_INSTANCE_ID_KEY = "posthog_instance_id" - with Session(engine) as session: - system = session.exec( - select(System).where(System.name == POSTHOG_INSTANCE_ID_KEY) - ).first() - if system: +def get_or_creat_posthog_instance_id( + session: Optional[Session] = None + ): + POSTHOG_INSTANCE_ID_KEY = "posthog_instance_id" + with Session(engine) as session: + system = session.exec(select(System).where(System.name == POSTHOG_INSTANCE_ID_KEY)).first() + if system: + return system.value + + system = System( + id=str(uuid4()), + name=POSTHOG_INSTANCE_ID_KEY, + value=str(uuid4()), + ) + session.add(system) + session.commit() + session.refresh(system) return system.value - system = System( - id=str(uuid4()), - name=POSTHOG_INSTANCE_ID_KEY, - value=str(uuid4()), - ) - session.add(system) - session.commit() - session.refresh(system) - return system.value - - -def get_activity_report(session: Optional[Session] = None): +def get_activity_report( + session: Optional[Session] = None + ): from keep.api.models.db.user import User last_24_hours = datetime.utcnow() - timedelta(hours=24) @@ -4553,8 +4558,46 @@ def get_activity_report(session: Optional[Session] = None): ) activity_report["last_24_hours_workflows_executed"] = ( session.query(WorkflowExecution) - .filter(WorkflowExecution.started >= last_24_hours) - .count() - ) - + .filter(WorkflowExecution.started >= last_24_hours).count() +) return activity_report + + +def get_last_alert_by_fingerprint( + tenant_id: str, fingerprint: str, session: Optional[Session] = None +) -> Optional[LastAlert]: + with existed_or_new_session(session) as session: + return session.exec( + select(LastAlert) + .where( + and_( + LastAlert.tenant_id == tenant_id, + LastAlert.fingerprint == fingerprint, + ) + ) + ).first() + +def set_last_alert( + tenant_id: str, alert: Alert, session: Optional[Session] = None +) -> None: + last_alert = get_last_alert_by_fingerprint(tenant_id, alert.fingerprint, session) + + # To prevent rare, but possible race condition + # For example if older alert failed to process + # and retried after new one + + if last_alert and last_alert.timestamp.replace(tzinfo=tz.UTC) < alert.timestamp.replace(tzinfo=tz.UTC): + last_alert.timestamp = alert.timestamp + 
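+        # Repointing alert_id (next line) keeps fingerprint joins resolving
+        # to the newest alert row, not just refreshing the timestamp.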
last_alert.alert_id = alert.id + session.add(last_alert) + session.commit() + + elif not last_alert: + last_alert = LastAlert( + tenant_id=tenant_id, + fingerprint=alert.fingerprint, + timestamp=alert.timestamp, + alert_id=alert.id, + ) + session.add(last_alert) + session.commit() diff --git a/keep/api/models/db/alert.py b/keep/api/models/db/alert.py index 11efa9ef7..0d0bac569 100644 --- a/keep/api/models/db/alert.py +++ b/keep/api/models/db/alert.py @@ -4,7 +4,8 @@ from typing import List, Optional from uuid import UUID, uuid4 -from sqlalchemy import ForeignKey, UniqueConstraint +from pydantic import PrivateAttr +from sqlalchemy import ForeignKey, UniqueConstraint, ForeignKeyConstraint from sqlalchemy.dialects.mssql import DATETIME2 as MSSQL_DATETIME2 from sqlalchemy.dialects.mysql import DATETIME as MySQL_DATETIME from sqlalchemy.engine.url import make_url @@ -60,8 +61,6 @@ class AlertToIncident(SQLModel, table=True): primary_key=True, ) ) - alert: "Alert" = Relationship(back_populates="alert_to_incident_link") - incident: "Incident" = Relationship(back_populates="alert_to_incident_link") is_created_by_ai: bool = Field(default=False) @@ -72,6 +71,43 @@ class AlertToIncident(SQLModel, table=True): default=NULL_FOR_DELETED_AT, ) +class LastAlert(SQLModel, table=True): + + tenant_id: str = Field(foreign_key="tenant.id", nullable=False, primary_key=True) + fingerprint: str = Field(primary_key=True, index=True) + alert_id: UUID = Field(foreign_key="alert.id") + timestamp: datetime = Field(nullable=False, index=True) + + +class LastAlertToIncident(SQLModel, table=True): + tenant_id: str = Field(foreign_key="tenant.id", nullable=False, primary_key=True) + timestamp: datetime = Field(default_factory=datetime.utcnow) + + fingerprint: str = Field(primary_key=True) + incident_id: UUID = Field( + sa_column=Column( + UUIDType(binary=False), + ForeignKey("incident.id", ondelete="CASCADE"), + primary_key=True, + ) + ) + + is_created_by_ai: bool = Field(default=False) + + deleted_at: datetime = Field( + default_factory=None, + nullable=True, + primary_key=True, + default=NULL_FOR_DELETED_AT, + ) + + __table_args__ = ( + ForeignKeyConstraint( + ["tenant_id", "fingerprint"], + ["lastalert.tenant_id", "lastalert.fingerprint"]), + {} + ) + class Incident(SQLModel, table=True): id: UUID = Field(default_factory=uuid4, primary_key=True) @@ -96,26 +132,6 @@ class Incident(SQLModel, table=True): end_time: datetime | None last_seen_time: datetime | None - # map of attributes to values - alerts: List["Alert"] = Relationship( - back_populates="incidents", - link_model=AlertToIncident, - # primaryjoin is used to filter out deleted links for various DB dialects - sa_relationship_kwargs={ - "primaryjoin": f"""and_(AlertToIncident.incident_id == Incident.id, - or_( - AlertToIncident.deleted_at == '{NULL_FOR_DELETED_AT.strftime('%Y-%m-%d %H:%M:%S.%f')}', - AlertToIncident.deleted_at == '{NULL_FOR_DELETED_AT.strftime('%Y-%m-%d %H:%M:%S')}' - ))""", - "uselist": True, - "overlaps": "alert,incident", - }, - ) - alert_to_incident_link: List[AlertToIncident] = Relationship( - back_populates="incident", - sa_relationship_kwargs={"overlaps": "alerts,incidents"}, - ) - is_predicted: bool = Field(default=False) is_confirmed: bool = Field(default=False) @@ -183,10 +199,7 @@ class Incident(SQLModel, table=True): ), ) - def __init__(self, **kwargs): - super().__init__(**kwargs) - if "alerts" not in kwargs: - self.alerts = [] + _alerts: List["Alert"] = PrivateAttr() class Config: arbitrary_types_allowed = True @@ -224,24 +237,6 @@ 
class Alert(SQLModel, table=True): } ) - incidents: List["Incident"] = Relationship( - back_populates="alerts", - link_model=AlertToIncident, - sa_relationship_kwargs={ - # primaryjoin is used to filter out deleted links for various DB dialects - "primaryjoin": f"""and_(AlertToIncident.alert_id == Alert.id, - or_( - AlertToIncident.deleted_at == '{NULL_FOR_DELETED_AT.strftime('%Y-%m-%d %H:%M:%S.%f')}', - AlertToIncident.deleted_at == '{NULL_FOR_DELETED_AT.strftime('%Y-%m-%d %H:%M:%S')}' - ))""", - "uselist": True, - "overlaps": "alert,incident", - }, - ) - alert_to_incident_link: List[AlertToIncident] = Relationship( - back_populates="alert", sa_relationship_kwargs={"overlaps": "alerts,incidents"} - ) - __table_args__ = ( Index( "ix_alert_tenant_fingerprint_timestamp", diff --git a/keep/api/models/db/migrations/versions/2024-12-01-16-40_3ad5308e7200.py b/keep/api/models/db/migrations/versions/2024-12-01-16-40_3ad5308e7200.py index f92d246a8..e6c2140e3 100644 --- a/keep/api/models/db/migrations/versions/2024-12-01-16-40_3ad5308e7200.py +++ b/keep/api/models/db/migrations/versions/2024-12-01-16-40_3ad5308e7200.py @@ -7,10 +7,7 @@ """ import sqlalchemy as sa -import sqlalchemy_utils -import sqlmodel from alembic import op -from sqlalchemy.dialects import postgresql # revision identifiers, used by Alembic. revision = "3ad5308e7200" @@ -24,13 +21,15 @@ def upgrade() -> None: with op.batch_alter_table("externalaiconfigandmetadata", schema=None) as batch_op: batch_op.alter_column( - "settings", existing_type=sa.VARCHAR(), type_=sa.JSON(), nullable=True + "settings", existing_type=sa.VARCHAR(), type_=sa.JSON(), nullable=True, + postgresql_using="settings::json" ) batch_op.alter_column( "settings_proposed_by_algorithm", existing_type=sa.VARCHAR(), type_=sa.JSON(), existing_nullable=True, + postgresql_using="settings::json" ) batch_op.alter_column( "feedback_logs", diff --git a/keep/api/models/db/migrations/versions/2024-12-02-13-36_bdae8684d0b4.py b/keep/api/models/db/migrations/versions/2024-12-02-13-36_bdae8684d0b4.py new file mode 100644 index 000000000..905748deb --- /dev/null +++ b/keep/api/models/db/migrations/versions/2024-12-02-13-36_bdae8684d0b4.py @@ -0,0 +1,161 @@ +"""add lastalert and lastalerttoincident table + +Revision ID: bdae8684d0b4 +Revises: 3ad5308e7200 +Create Date: 2024-11-05 22:48:04.733192 + +""" + +import sqlalchemy as sa +import sqlalchemy_utils +import sqlmodel +from alembic import op +from sqlalchemy.orm import Session + +# revision identifiers, used by Alembic. 
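+# Chain note: this revision follows 3ad5308e7200 and, besides creating the
+# lastalert / lastalerttoincident tables, backfills them from existing
+# alert and alerttoincident rows via populate_db() below.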
+revision = "bdae8684d0b4" +down_revision = "3ad5308e7200" +branch_labels = None +depends_on = None + +migration_metadata = sa.MetaData() + + +def populate_db(): + session = Session(op.get_bind()) + + if session.bind.dialect.name == "postgresql": + migrate_lastalert_query = """ + insert into lastalert (tenant_id, fingerprint, alert_id, timestamp) + select alert.tenant_id, alert.fingerprint, alert.id as alert_id, alert.timestamp + from alert + join ( + select + alert.tenant_id, alert.fingerprint, max(alert.timestamp) as last_received + from alert + group by fingerprint, tenant_id + ) as a ON alert.fingerprint = a.fingerprint and alert.timestamp = a.last_received and alert.tenant_id = a.tenant_id + on conflict + do nothing + """ + + migrate_lastalerttoincident_query = """ + insert into lastalerttoincident (incident_id, tenant_id, timestamp, fingerprint, is_created_by_ai, deleted_at) + select ati.incident_id, ati.tenant_id, ati.timestamp, lf.fingerprint, ati.is_created_by_ai, ati.deleted_at + from alerttoincident as ati + join + ( + select alert.tenant_id, alert.id, alert.fingerprint + from alert + join ( + select + alert.tenant_id, alert.fingerprint, max(alert.timestamp) as last_received + from alert + group by fingerprint, tenant_id + ) as a on alert.fingerprint = a.fingerprint and alert.timestamp = a.last_received and alert.tenant_id = a.tenant_id + ) as lf on ati.alert_id = lf.id + on conflict + do nothing + """ + + else: + migrate_lastalert_query = """ + INSERT INTO lastalert (tenant_id, fingerprint, alert_id, timestamp) + SELECT + grouped_alerts.tenant_id, + grouped_alerts.fingerprint, + MAX(grouped_alerts.alert_id) as alert_id, -- Using MAX to consistently pick one alert_id + grouped_alerts.timestamp + FROM ( + select alert.tenant_id, alert.fingerprint, alert.id as alert_id, alert.timestamp + from alert + join ( + select + alert.tenant_id, alert.fingerprint, max(alert.timestamp) as last_received + from alert + group by fingerprint, tenant_id + ) as a ON alert.fingerprint = a.fingerprint + and alert.timestamp = a.last_received + and alert.tenant_id = a.tenant_id + ) as grouped_alerts + GROUP BY grouped_alerts.tenant_id, grouped_alerts.fingerprint, grouped_alerts.timestamp; +""" + + migrate_lastalerttoincident_query = """ + REPLACE INTO lastalerttoincident (incident_id, tenant_id, timestamp, fingerprint, is_created_by_ai, deleted_at) + select ati.incident_id, ati.tenant_id, ati.timestamp, lf.fingerprint, ati.is_created_by_ai, ati.deleted_at + from alerttoincident as ati + join + ( + select alert.id, alert.fingerprint, alert.tenant_id + from alert + join ( + select + alert.tenant_id,alert.fingerprint, max(alert.timestamp) as last_received + from alert + group by fingerprint, tenant_id + ) as a on alert.fingerprint = a.fingerprint and alert.timestamp = a.last_received and alert.tenant_id = a.tenant_id + ) as lf on ati.alert_id = lf.id; + """ + + session.execute(migrate_lastalert_query) + session.execute(migrate_lastalerttoincident_query) + + +def upgrade() -> None: + op.create_table( + "lastalert", + sa.Column("tenant_id", sqlmodel.sql.sqltypes.AutoString(), nullable=False), + sa.Column("fingerprint", sqlmodel.sql.sqltypes.AutoString(), nullable=False), + sa.Column("alert_id", sqlmodel.sql.sqltypes.GUID(), nullable=False), + sa.Column("timestamp", sa.DateTime(), nullable=False), + sa.ForeignKeyConstraint( + ["alert_id"], + ["alert.id"], + ), + sa.ForeignKeyConstraint( + ["tenant_id"], + ["tenant.id"], + ), + sa.PrimaryKeyConstraint("tenant_id", "fingerprint"), + ) + with 
op.batch_alter_table("lastalert", schema=None) as batch_op: + batch_op.create_index( + batch_op.f("ix_lastalert_timestamp"), ["timestamp"], unique=False + ) + # Add index for the fingerprint column that will be referenced by foreign key + batch_op.create_index("ix_lastalert_fingerprint", ["fingerprint"], unique=False) + + op.create_table( + "lastalerttoincident", + sa.Column( + "incident_id", + sqlalchemy_utils.types.uuid.UUIDType(binary=False), + nullable=False, + ), + sa.Column("tenant_id", sqlmodel.sql.sqltypes.AutoString(), nullable=False), + sa.Column("timestamp", sa.DateTime(), nullable=False), + sa.Column("fingerprint", sqlmodel.sql.sqltypes.AutoString(), nullable=False), + sa.Column("is_created_by_ai", sa.Boolean(), nullable=False), + sa.Column("deleted_at", sa.DateTime(), nullable=True), + sa.ForeignKeyConstraint( + ["tenant_id", "fingerprint"], + ["lastalert.tenant_id", "lastalert.fingerprint"], + ), + sa.ForeignKeyConstraint(["incident_id"], ["incident.id"], ondelete="CASCADE"), + sa.ForeignKeyConstraint( + ["tenant_id"], + ["tenant.id"], + ), + sa.PrimaryKeyConstraint("tenant_id", "incident_id", "fingerprint", "deleted_at"), + ) + + populate_db() + + +def downgrade() -> None: + op.drop_table("lastalerttoincident") + with op.batch_alter_table("lastalert", schema=None) as batch_op: + batch_op.drop_index(batch_op.f("ix_lastalert_timestamp")) + + op.drop_table("lastalert") \ No newline at end of file diff --git a/keep/api/routes/incidents.py b/keep/api/routes/incidents.py index 4820a415a..d7e4bbd39 100644 --- a/keep/api/routes/incidents.py +++ b/keep/api/routes/incidents.py @@ -457,7 +457,7 @@ def get_incident_workflows( ) async def add_alerts_to_incident( incident_id: UUID, - alert_ids: List[UUID], + alert_fingerprints: List[str], is_created_by_ai: bool = False, authenticated_entity: AuthenticatedEntity = Depends( IdentityManagerFactory.get_auth_verifier(["write:incident"]) @@ -467,7 +467,7 @@ async def add_alerts_to_incident( ): tenant_id = authenticated_entity.tenant_id incident_bl = IncidentBl(tenant_id, session, pusher_client) - await incident_bl.add_alerts_to_incident(incident_id, alert_ids, is_created_by_ai) + await incident_bl.add_alerts_to_incident(incident_id, alert_fingerprints, is_created_by_ai) return Response(status_code=202) @@ -479,7 +479,7 @@ async def add_alerts_to_incident( ) def delete_alerts_from_incident( incident_id: UUID, - alert_ids: List[UUID], + fingerprints: List[str], authenticated_entity: AuthenticatedEntity = Depends( IdentityManagerFactory.get_auth_verifier(["write:incident"]) ), @@ -489,7 +489,7 @@ def delete_alerts_from_incident( tenant_id = authenticated_entity.tenant_id incident_bl = IncidentBl(tenant_id, session, pusher_client) incident_bl.delete_alerts_from_incident( - incident_id=incident_id, alert_ids=alert_ids + incident_id=incident_id, alert_fingerprints=fingerprints ) return Response(status_code=202) @@ -611,7 +611,7 @@ def change_incident_status( # TODO: same this change to audit table with the comment if change.status == IncidentStatus.RESOLVED: - for alert in incident.alerts: + for alert in incident._alerts: _enrich_alert( EnrichAlertRequestBody( enrichments={"status": "resolved"}, diff --git a/keep/api/routes/workflows.py b/keep/api/routes/workflows.py index 8d355d08f..6326a6b33 100644 --- a/keep/api/routes/workflows.py +++ b/keep/api/routes/workflows.py @@ -190,7 +190,8 @@ def run_workflow( event_body = body.get("body", {}) or body # if its event that was triggered by the UI with the Modal - if "test-workflow" in 
event_body.get("fingerprint", "") or not body: + fingerprint = event_body.get("fingerprint", "") + if (fingerprint and "test-workflow" in fingerprint) or not body: # some random event_body["id"] = event_body.get("fingerprint", "manual-run") event_body["name"] = event_body.get("fingerprint", "manual-run") diff --git a/keep/api/tasks/process_event_task.py b/keep/api/tasks/process_event_task.py index 5a5aae8b1..28ed135fa 100644 --- a/keep/api/tasks/process_event_task.py +++ b/keep/api/tasks/process_event_task.py @@ -23,6 +23,7 @@ get_all_presets_dtos, get_enrichment_with_session, get_session_sync, + set_last_alert, ) from keep.api.core.dependencies import get_pusher_client from keep.api.core.elastic import ElasticClient @@ -188,6 +189,9 @@ def __save_to_db( ) session.add(audit) alert_dto = AlertDto(**formatted_event.dict()) + + set_last_alert(tenant_id, alert, session=session) + # Mapping try: enrichments_bl.run_mapping_rules(alert_dto) diff --git a/keep/api/utils/enrichment_helpers.py b/keep/api/utils/enrichment_helpers.py index 86e9795be..2b0ad2b62 100644 --- a/keep/api/utils/enrichment_helpers.py +++ b/keep/api/utils/enrichment_helpers.py @@ -1,10 +1,13 @@ import logging from datetime import datetime +from typing import Optional from opentelemetry import trace +from sqlmodel import Session +from keep.api.core.db import existed_or_new_session from keep.api.models.alert import AlertDto, AlertStatus, AlertWithIncidentLinkMetadataDto -from keep.api.models.db.alert import Alert, AlertToIncident +from keep.api.models.db.alert import Alert, LastAlertToIncident tracer = trace.get_tracer(__name__) logger = logging.getLogger(__name__) @@ -77,8 +80,9 @@ def calculated_start_firing_time( def convert_db_alerts_to_dto_alerts( - alerts: list[Alert | tuple[Alert, AlertToIncident]], - with_incidents: bool = False + alerts: list[Alert | tuple[Alert, LastAlertToIncident]], + with_incidents: bool = False, + session: Optional[Session] = None, ) -> list[AlertDto | AlertWithIncidentLinkMetadataDto]: """ Enriches the alerts with the enrichment data. @@ -90,46 +94,47 @@ def convert_db_alerts_to_dto_alerts( Returns: list[AlertDto | AlertWithIncidentLinkMetadataDto]: The enriched alerts. 
""" - alerts_dto = [] - with tracer.start_as_current_span("alerts_enrichment"): - # enrich the alerts with the enrichment data - for _object in alerts: - - # We may have an Alert only or and Alert with an AlertToIncident - if isinstance(_object, Alert): - alert, alert_to_incident = _object, None - else: - alert, alert_to_incident = _object - - if alert.alert_enrichment: - alert.event.update(alert.alert_enrichment.enrichments) - if with_incidents: - if alert.incidents: - alert.event["incident"] = ",".join(str(incident.id) for incident in alert.incidents) - try: - if alert_to_incident is not None: - alert_dto = AlertWithIncidentLinkMetadataDto.from_db_instance(alert, alert_to_incident) + with existed_or_new_session(session) as session: + alerts_dto = [] + with tracer.start_as_current_span("alerts_enrichment"): + # enrich the alerts with the enrichment data + for _object in alerts: + + # We may have an Alert only or and Alert with an LastAlertToIncident + if isinstance(_object, Alert): + alert, alert_to_incident = _object, None else: - alert_dto = AlertDto(**alert.event) + alert, alert_to_incident = _object + if alert.alert_enrichment: - parse_and_enrich_deleted_and_assignees( - alert_dto, alert.alert_enrichment.enrichments + alert.event.update(alert.alert_enrichment.enrichments) + if with_incidents: + if alert.incidents: + alert.event["incident"] = ",".join(str(incident.id) for incident in alert.incidents) + try: + if alert_to_incident is not None: + alert_dto = AlertWithIncidentLinkMetadataDto.from_db_instance(alert, alert_to_incident) + else: + alert_dto = AlertDto(**alert.event) + if alert.alert_enrichment: + parse_and_enrich_deleted_and_assignees( + alert_dto, alert.alert_enrichment.enrichments + ) + except Exception: + # should never happen but just in case + logger.exception( + "Failed to parse alert", + extra={ + "alert": alert, + }, ) - except Exception: - # should never happen but just in case - logger.exception( - "Failed to parse alert", - extra={ - "alert": alert, - }, - ) - continue - - alert_dto.event_id = str(alert.id) - - # enrich provider id when it's possible - if alert_dto.providerId is None: - alert_dto.providerId = alert.provider_id - alert_dto.providerType = alert.provider_type - alerts_dto.append(alert_dto) + continue + + alert_dto.event_id = str(alert.id) + + # enrich provider id when it's possible + if alert_dto.providerId is None: + alert_dto.providerId = alert.provider_id + alert_dto.providerType = alert.provider_type + alerts_dto.append(alert_dto) return alerts_dto diff --git a/keep/rulesengine/rulesengine.py b/keep/rulesengine/rulesengine.py index 363901c16..03538b8e3 100644 --- a/keep/rulesengine/rulesengine.py +++ b/keep/rulesengine/rulesengine.py @@ -81,7 +81,7 @@ def run_rules( ) incident = assign_alert_to_incident( - alert_id=event.event_id, + fingerprint=event.fingerprint, incident=incident, tenant_id=self.tenant_id, session=session, @@ -101,7 +101,7 @@ def run_rules( ): should_resolve = True - if ( + elif ( rule.resolve_on == ResolveOn.LAST.value and is_last_incident_alert_resolved(incident, session=session) ): diff --git a/tests/conftest.py b/tests/conftest.py index 06725ca5b..6a353aa31 100644 --- a/tests/conftest.py +++ b/tests/conftest.py @@ -333,6 +333,7 @@ def is_elastic_responsive(host, port, user, password): basic_auth=(user, password), ) info = elastic_client._client.info() + print("Elastic still up now") return True if info else False except Exception: print("Elastic still not up") @@ -547,6 +548,32 @@ def 
_setup_stress_alerts_no_elastic(num_alerts): db_session.add_all(alerts) db_session.commit() + existed_last_alerts = db_session.query(LastAlert).all() + existed_last_alerts_dict = { + last_alert.fingerprint: last_alert + for last_alert in existed_last_alerts + } + last_alerts = [] + for alert in alerts: + if alert.fingerprint in existed_last_alerts_dict: + last_alert = existed_last_alerts_dict[alert.fingerprint] + last_alert.alert_id = alert.id + last_alert.timestamp=alert.timestamp + last_alerts.append( + last_alert + ) + else: + last_alerts.append( + LastAlert( + tenant_id=SINGLE_TENANT_UUID, + fingerprint=alert.fingerprint, + timestamp=alert.timestamp, + alert_id=alert.id, + ) + ) + db_session.add_all(last_alerts) + db_session.commit() + return alerts return _setup_stress_alerts_no_elastic @@ -559,7 +586,6 @@ def setup_stress_alerts( num_alerts = request.param.get( "num_alerts", 1000 ) # Default to 1000 alerts if not specified - alerts = setup_stress_alerts_no_elastic(num_alerts) # add all to elasticsearch alerts_dto = convert_db_alerts_to_dto_alerts(alerts) diff --git a/tests/test_incidents.py b/tests/test_incidents.py index e675c9075..6002058f6 100644 --- a/tests/test_incidents.py +++ b/tests/test_incidents.py @@ -2,8 +2,7 @@ from itertools import cycle import pytest -from sqlalchemy import distinct, func -from sqlalchemy.orm.exc import DetachedInstanceError +from sqlalchemy import distinct, func, desc from keep.api.core.db import ( IncidentSorting, @@ -24,7 +23,7 @@ IncidentSeverity, IncidentStatus, ) -from keep.api.models.db.alert import Alert, AlertToIncident +from keep.api.models.db.alert import Alert, LastAlertToIncident from keep.api.utils.enrichment_helpers import convert_db_alerts_to_dto_alerts from tests.fixtures.client import client, test_app # noqa @@ -50,7 +49,7 @@ def test_get_alerts_data_for_incident(db_session, create_alert): assert 100 == db_session.query(func.count(Alert.id)).scalar() assert 10 == unique_fingerprints - data = get_alerts_data_for_incident(SINGLE_TENANT_UUID, [a.id for a in alerts]) + data = get_alerts_data_for_incident(SINGLE_TENANT_UUID, [a.fingerprint for a in alerts]) assert data["sources"] == set([f"source_{i}" for i in range(10)]) assert data["services"] == set([f"service_{i}" for i in range(10)]) assert data["count"] == unique_fingerprints @@ -64,16 +63,27 @@ def test_add_remove_alert_to_incidents(db_session, setup_stress_alerts_no_elasti SINGLE_TENANT_UUID, {"user_generated_name": "test", "user_summary": "test"} ) - assert len(incident.alerts) == 0 + incident_alerts, total_incident_alerts = get_incident_alerts_by_incident_id( + tenant_id=SINGLE_TENANT_UUID, + incident_id=incident.id, + ) + + assert len(incident_alerts) == 0 + assert total_incident_alerts == 0 add_alerts_to_incident_by_incident_id( - SINGLE_TENANT_UUID, incident.id, [a.id for a in alerts] + SINGLE_TENANT_UUID, incident.id, [a.fingerprint for a in alerts] ) incident = get_incident_by_id(SINGLE_TENANT_UUID, incident.id) - # 110 alerts - assert len(incident.alerts) == 110 + incident_alerts, total_incident_alerts = get_incident_alerts_by_incident_id( + tenant_id=SINGLE_TENANT_UUID, + incident_id=incident.id, + ) + + assert len(incident_alerts) == 100 + assert total_incident_alerts == 100 # But 100 unique fingerprints assert incident.alerts_count == 100 @@ -86,7 +96,7 @@ def test_add_remove_alert_to_incidents(db_session, setup_stress_alerts_no_elasti service_field = get_json_extract_field(db_session, Alert.event, "service") - service_0 = db_session.query(Alert.id).filter(service_field 
== "service_0").all() + service_0 = db_session.query(Alert.fingerprint).filter(service_field == "service_0").all() # Testing unique fingerprints more_alerts_with_same_fingerprints = setup_stress_alerts_no_elastic(10) @@ -94,26 +104,32 @@ def test_add_remove_alert_to_incidents(db_session, setup_stress_alerts_no_elasti add_alerts_to_incident_by_incident_id( SINGLE_TENANT_UUID, incident.id, - [a.id for a in more_alerts_with_same_fingerprints], + [a.fingerprint for a in more_alerts_with_same_fingerprints], ) incident = get_incident_by_id(SINGLE_TENANT_UUID, incident.id) assert incident.alerts_count == 100 - assert db_session.query(func.count(AlertToIncident.alert_id)).scalar() == 120 + assert db_session.query(func.count(LastAlertToIncident.fingerprint)).scalar() == 100 remove_alerts_to_incident_by_incident_id( SINGLE_TENANT_UUID, incident.id, [ - service_0[0].id, + service_0[0].fingerprint, ], ) incident = get_incident_by_id(SINGLE_TENANT_UUID, incident.id) - # 117 because we removed multiple alerts with service_0 - assert len(incident.alerts) == 117 + incident_alerts, total_incident_alerts = get_incident_alerts_by_incident_id( + tenant_id=SINGLE_TENANT_UUID, + incident_id=incident.id, + ) + + assert len(incident_alerts) == 99 + assert total_incident_alerts == 99 + assert "service_0" in incident.affected_services assert len(incident.affected_services) == 10 assert sorted(incident.affected_services) == sorted( @@ -121,7 +137,7 @@ def test_add_remove_alert_to_incidents(db_session, setup_stress_alerts_no_elasti ) remove_alerts_to_incident_by_incident_id( - SINGLE_TENANT_UUID, incident.id, [a.id for a in service_0] + SINGLE_TENANT_UUID, incident.id, [a.fingerprint for a in service_0] ) # Removing shouldn't impact links between alert and incident if include_unlinked=True @@ -138,8 +154,14 @@ def test_add_remove_alert_to_incidents(db_session, setup_stress_alerts_no_elasti incident = get_incident_by_id(SINGLE_TENANT_UUID, incident.id) - # 108 because we removed multiple alert with same fingerprints - assert len(incident.alerts) == 108 + incident_alerts, total_incident_alerts = get_incident_alerts_by_incident_id( + tenant_id=SINGLE_TENANT_UUID, + incident_id=incident.id, + ) + + assert len(incident_alerts) == 90 + assert total_incident_alerts == 90 + assert "service_0" not in incident.affected_services assert len(incident.affected_services) == 9 assert sorted(incident.affected_services) == sorted( @@ -147,29 +169,36 @@ def test_add_remove_alert_to_incidents(db_session, setup_stress_alerts_no_elasti ) source_1 = ( - db_session.query(Alert.id).filter(Alert.provider_type == "source_1").all() + db_session.query(Alert.fingerprint).filter(Alert.provider_type == "source_1").all() ) remove_alerts_to_incident_by_incident_id( SINGLE_TENANT_UUID, incident.id, [ - source_1[0].id, + source_1[0].fingerprint, ], ) incident = get_incident_by_id(SINGLE_TENANT_UUID, incident.id) - assert len(incident.alerts) == 105 + incident_alerts, total_incident_alerts = get_incident_alerts_by_incident_id( + tenant_id=SINGLE_TENANT_UUID, + incident_id=incident.id, + ) + + assert len(incident_alerts) == 89 + assert total_incident_alerts == 89 + assert "source_1" in incident.sources - # source_0 was removed together with service_0 + # source_0 was removed together with service_1 assert len(incident.sources) == 9 assert sorted(incident.sources) == sorted( ["source_{}".format(i) for i in range(1, 10)] ) remove_alerts_to_incident_by_incident_id( - "keep", incident.id, [a.id for a in source_1] + "keep", incident.id, [a.fingerprint for a 
in source_1] ) incident = get_incident_by_id(SINGLE_TENANT_UUID, incident.id) @@ -213,8 +242,19 @@ def test_get_last_incidents(db_session, create_alert): ) alert = db_session.query(Alert).order_by(Alert.timestamp.desc()).first() + create_alert( + f"alert-test-2-{i}", + AlertStatus(status), + datetime.utcnow(), + { + "severity": AlertSeverity.from_number(severity), + "service": service, + }, + ) + alert2 = db_session.query(Alert).order_by(Alert.timestamp.desc()).first() + add_alerts_to_incident_by_incident_id( - SINGLE_TENANT_UUID, incident.id, [alert.id] + SINGLE_TENANT_UUID, incident.id, [alert.fingerprint, alert2.fingerprint] ) incidents_default, incidents_default_count = get_last_incidents(SINGLE_TENANT_UUID) @@ -246,19 +286,14 @@ def test_get_last_incidents(db_session, create_alert): for i, j in enumerate(range(5, 10)): assert incidents_limit_5_page_2[i].user_generated_name == f"test-{j}" - # If alerts not preloaded, we will have detached session issue during attempt to get them - # Background on this error at: https://sqlalche.me/e/14/bhk3 - with pytest.raises(DetachedInstanceError): - alerts = incidents_confirmed[0].alerts # noqa - incidents_with_alerts, _ = get_last_incidents( SINGLE_TENANT_UUID, is_confirmed=True, with_alerts=True ) for i in range(25): if incidents_with_alerts[i].status == IncidentStatus.MERGED.value: - assert len(incidents_with_alerts[i].alerts) == 0 + assert len(incidents_with_alerts[i]._alerts) == 0 else: - assert len(incidents_with_alerts[i].alerts) == 1 + assert len(incidents_with_alerts[i]._alerts) == 2 # Test sorting @@ -319,11 +354,16 @@ def test_incident_status_change( "keep", {"name": "test", "description": "test"} ) - add_alerts_to_incident_by_incident_id("keep", incident.id, [a.id for a in alerts]) + add_alerts_to_incident_by_incident_id( + "keep", + incident.id, + [a.fingerprint for a in alerts], + session=db_session + ) - incident = get_incident_by_id("keep", incident.id, with_alerts=True) + incident = get_incident_by_id("keep", incident.id, with_alerts=True, session=db_session) - alerts_dtos = convert_db_alerts_to_dto_alerts(incident.alerts) + alerts_dtos = convert_db_alerts_to_dto_alerts(incident._alerts, session=db_session) assert ( len( [ @@ -348,10 +388,11 @@ def test_incident_status_change( assert data["id"] == str(incident.id) assert data["status"] == IncidentStatus.ACKNOWLEDGED.value - incident = get_incident_by_id("keep", incident.id, with_alerts=True) + db_session.expire_all() + incident = get_incident_by_id("keep", incident.id, with_alerts=True, session=db_session) assert incident.status == IncidentStatus.ACKNOWLEDGED.value - alerts_dtos = convert_db_alerts_to_dto_alerts(incident.alerts) + alerts_dtos = convert_db_alerts_to_dto_alerts(incident._alerts) assert ( len( [ @@ -376,11 +417,12 @@ def test_incident_status_change( assert data["id"] == str(incident.id) assert data["status"] == IncidentStatus.RESOLVED.value - incident = get_incident_by_id("keep", incident.id, with_alerts=True) + db_session.expire_all() + incident = get_incident_by_id("keep", incident.id, with_alerts=True, session=db_session) assert incident.status == IncidentStatus.RESOLVED.value # All alerts are resolved as well - alerts_dtos = convert_db_alerts_to_dto_alerts(incident.alerts) + alerts_dtos = convert_db_alerts_to_dto_alerts(incident._alerts, session=db_session) assert ( len( [ @@ -475,23 +517,48 @@ def test_add_alerts_with_same_fingerprint_to_incident(db_session, create_alert): SINGLE_TENANT_UUID, {"user_generated_name": "test", "user_summary": "test"} ) - assert 
len(incident.alerts) == 0 + incident_alerts, total_incident_alerts = get_incident_alerts_by_incident_id( + tenant_id=SINGLE_TENANT_UUID, + incident_id=incident.id, + ) + + assert len(incident_alerts) == 0 + assert total_incident_alerts == 0 add_alerts_to_incident_by_incident_id( - SINGLE_TENANT_UUID, incident.id, [fp1_alerts[0].id] + SINGLE_TENANT_UUID, incident.id, [fp1_alerts[0].fingerprint] ) incident = get_incident_by_id(SINGLE_TENANT_UUID, incident.id) - assert len(incident.alerts) == 2 + incident_alerts, total_incident_alerts = get_incident_alerts_by_incident_id( + tenant_id=SINGLE_TENANT_UUID, + incident_id=incident.id, + ) + + assert len(incident_alerts) == 1 + last_fp1_alert = ( + db_session + .query(Alert.timestamp) + .where(Alert.fingerprint == "fp1") + .order_by(desc(Alert.timestamp)).first() + ) + assert incident_alerts[0].timestamp == last_fp1_alert.timestamp + assert total_incident_alerts == 1 remove_alerts_to_incident_by_incident_id( - SINGLE_TENANT_UUID, incident.id, [fp1_alerts[0].id] + SINGLE_TENANT_UUID, incident.id, [fp1_alerts[0].fingerprint] ) incident = get_incident_by_id(SINGLE_TENANT_UUID, incident.id) - assert len(incident.alerts) == 0 + incident_alerts, total_incident_alerts = get_incident_alerts_by_incident_id( + tenant_id=SINGLE_TENANT_UUID, + incident_id=incident.id, + ) + + assert len(incident_alerts) == 0 + assert total_incident_alerts == 0 def test_merge_incidents(db_session, create_alert, setup_stress_alerts_no_elastic): @@ -522,7 +589,7 @@ def test_merge_incidents(db_session, create_alert, setup_stress_alerts_no_elasti ) alerts_1 = db_session.query(Alert).all() add_alerts_to_incident_by_incident_id( - SINGLE_TENANT_UUID, incident_1.id, [a.id for a in alerts_1] + SINGLE_TENANT_UUID, incident_1.id, [a.fingerprint for a in alerts_1] ) incident_2 = create_incident_from_dict( SINGLE_TENANT_UUID, @@ -532,19 +599,19 @@ def test_merge_incidents(db_session, create_alert, setup_stress_alerts_no_elasti }, ) create_alert( - "fp20", + "fp20-0", AlertStatus.FIRING, datetime.utcnow(), {"severity": AlertSeverity.CRITICAL.value}, ) create_alert( - "fp20", + "fp20-1", AlertStatus.FIRING, datetime.utcnow(), {"severity": AlertSeverity.CRITICAL.value}, ) create_alert( - "fp20", + "fp20-2", AlertStatus.FIRING, datetime.utcnow(), {"severity": AlertSeverity.CRITICAL.value}, @@ -553,7 +620,7 @@ def test_merge_incidents(db_session, create_alert, setup_stress_alerts_no_elasti db_session.query(Alert).filter(Alert.fingerprint.startswith("fp20")).all() ) add_alerts_to_incident_by_incident_id( - SINGLE_TENANT_UUID, incident_2.id, [a.id for a in alerts_2] + SINGLE_TENANT_UUID, incident_2.id, [a.fingerprint for a in alerts_2] ) incident_3 = create_incident_from_dict( SINGLE_TENANT_UUID, @@ -563,28 +630,28 @@ def test_merge_incidents(db_session, create_alert, setup_stress_alerts_no_elasti }, ) create_alert( - "fp30", + "fp30-0", AlertStatus.FIRING, datetime.utcnow(), {"severity": AlertSeverity.WARNING.value}, ) create_alert( - "fp30", + "fp30-1", AlertStatus.FIRING, datetime.utcnow(), - {"severity": AlertSeverity.WARNING.value}, + {"severity": AlertSeverity.INFO.value}, ) create_alert( - "fp30", + "fp30-2", AlertStatus.FIRING, datetime.utcnow(), - {"severity": AlertSeverity.INFO.value}, + {"severity": AlertSeverity.WARNING.value}, ) alerts_3 = ( db_session.query(Alert).filter(Alert.fingerprint.startswith("fp30")).all() ) add_alerts_to_incident_by_incident_id( - SINGLE_TENANT_UUID, incident_3.id, [a.id for a in alerts_3] + SINGLE_TENANT_UUID, incident_3.id, [a.fingerprint for a in 
alerts_3] ) # before merge @@ -602,19 +669,21 @@ def test_merge_incidents(db_session, create_alert, setup_stress_alerts_no_elasti "test-user-email", ) + db_session.expire_all() + incident_1 = get_incident_by_id(SINGLE_TENANT_UUID, incident_1.id, with_alerts=True) - assert len(incident_1.alerts) == 9 + assert len(incident_1._alerts) == 8 assert incident_1.severity == IncidentSeverity.CRITICAL.order incident_2 = get_incident_by_id(SINGLE_TENANT_UUID, incident_2.id, with_alerts=True) - assert len(incident_2.alerts) == 0 + assert len(incident_2._alerts) == 0 assert incident_2.status == IncidentStatus.MERGED.value assert incident_2.merged_into_incident_id == incident_1.id assert incident_2.merged_at is not None assert incident_2.merged_by == "test-user-email" - incident_3 = get_incident_by_id(SINGLE_TENANT_UUID, incident_3.id, with_alerts=True) - assert len(incident_3.alerts) == 0 + incident_3 = get_incident_by_id(SINGLE_TENANT_UUID, incident_3.id, with_alerts=True, session=db_session) + assert len(incident_3._alerts) == 0 assert incident_3.status == IncidentStatus.MERGED.value assert incident_3.merged_into_incident_id == incident_1.id assert incident_3.merged_at is not None diff --git a/tests/test_metrics.py b/tests/test_metrics.py index fcd4306d6..8908896b6 100644 --- a/tests/test_metrics.py +++ b/tests/test_metrics.py @@ -18,7 +18,7 @@ def test_add_remove_alert_to_incidents( valid_api_key = "valid_api_key" setup_api_key(db_session, valid_api_key) - add_alerts_to_incident_by_incident_id("keep", incident.id, [a.id for a in alerts]) + add_alerts_to_incident_by_incident_id("keep", incident.id, [a.fingerprint for a in alerts]) response = client.get("/metrics?labels=a.b", headers={"X-API-KEY": "valid_api_key"}) diff --git a/tests/test_rules_engine.py b/tests/test_rules_engine.py index 40627b372..07ad6cf70 100644 --- a/tests/test_rules_engine.py +++ b/tests/test_rules_engine.py @@ -7,7 +7,7 @@ import pytest from keep.api.core.db import create_rule as create_rule_db -from keep.api.core.db import get_incident_alerts_by_incident_id, get_last_incidents +from keep.api.core.db import get_incident_alerts_by_incident_id, get_last_incidents, set_last_alert from keep.api.core.db import get_rules as get_rules_db from keep.api.core.dependencies import SINGLE_TENANT_UUID from keep.api.models.alert import ( @@ -63,10 +63,13 @@ def test_sanity(db_session): provider_type="test", provider_id="test", event=alerts[0].dict(), - fingerprint="test", + fingerprint=alerts[0].fingerprint, ) + db_session.add(alert) db_session.commit() + + set_last_alert(SINGLE_TENANT_UUID, alert, db_session) # run the rules engine alerts[0].event_id = alert.id results = rules_engine.run_rules(alerts) @@ -110,10 +113,11 @@ def test_sanity_2(db_session): provider_type="test", provider_id="test", event=alerts[0].dict(), - fingerprint="test", + fingerprint=alerts[0].fingerprint, ) db_session.add(alert) db_session.commit() + set_last_alert(SINGLE_TENANT_UUID, alert, db_session) # run the rules engine alerts[0].event_id = alert.id results = rules_engine.run_rules(alerts) @@ -158,10 +162,11 @@ def test_sanity_3(db_session): provider_type="test", provider_id="test", event=alerts[0].dict(), - fingerprint="test", + fingerprint=alerts[0].fingerprint, ) db_session.add(alert) db_session.commit() + set_last_alert(SINGLE_TENANT_UUID, alert, db_session) # run the rules engine alerts[0].event_id = alert.id results = rules_engine.run_rules(alerts) @@ -206,10 +211,11 @@ def test_sanity_4(db_session): provider_type="test", provider_id="test", 
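+        # alerts are now linked to incidents by fingerprint (via LastAlert),
+        # so the stored row reuses the DTO fingerprint for set_last_alert below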
event=alerts[0].dict(), - fingerprint="test", + fingerprint=alerts[0].fingerprint, ) db_session.add(alert) db_session.commit() + set_last_alert(SINGLE_TENANT_UUID, alert, db_session) # run the rules engine alerts[0].event_id = alert.id results = rules_engine.run_rules(alerts) @@ -223,7 +229,7 @@ def test_incident_attributes(db_session): AlertDto( id=str(uuid.uuid4()), source=["grafana"], - name="grafana-test-alert", + name=f"grafana-test-alert-{i}", status=AlertStatus.FIRING, severity=AlertSeverity.CRITICAL, lastReceived=datetime.datetime.now().isoformat(), @@ -255,13 +261,15 @@ def test_incident_attributes(db_session): provider_type="test", provider_id="test", event=alert.dict(), - fingerprint=hashlib.sha256(json.dumps(alert.dict()).encode()).hexdigest(), + fingerprint=alert.fingerprint, timestamp=alert.lastReceived, ) for alert in alerts_dto ] db_session.add_all(alerts) db_session.commit() + for alert in alerts: + set_last_alert(SINGLE_TENANT_UUID, alert, db_session) for i, alert in enumerate(alerts_dto): alert.event_id = alerts[i].id @@ -283,7 +291,7 @@ def test_incident_severity(db_session): AlertDto( id=str(uuid.uuid4()), source=["grafana"], - name="grafana-test-alert", + name=f"grafana-test-alert-{i}", status=AlertStatus.FIRING, severity=AlertSeverity.INFO, lastReceived=datetime.datetime.now().isoformat(), @@ -315,13 +323,15 @@ def test_incident_severity(db_session): provider_type="test", provider_id="test", event=alert.dict(), - fingerprint=hashlib.sha256(json.dumps(alert.dict()).encode()).hexdigest(), + fingerprint=alert.fingerprint, timestamp=alert.lastReceived, ) for alert in alerts_dto ] db_session.add_all(alerts) db_session.commit() + for alert in alerts: + set_last_alert(SINGLE_TENANT_UUID, alert, db_session) for i, alert in enumerate(alerts_dto): alert.event_id = alerts[i].id From 8e8f086ec42d69dc4794421e3a9a775a863eb330 Mon Sep 17 00:00:00 2001 From: Matvey Kukuy Date: Mon, 2 Dec 2024 20:42:18 +0200 Subject: [PATCH 03/12] fix: incident linking (#2730) --- .../ui/change-same-incident-in-the-past-form.tsx | 5 +++-- 1 file changed, 3 insertions(+), 2 deletions(-) diff --git a/keep-ui/features/same-incidents-in-the-past/ui/change-same-incident-in-the-past-form.tsx b/keep-ui/features/same-incidents-in-the-past/ui/change-same-incident-in-the-past-form.tsx index 77c0fb026..f97c165a4 100644 --- a/keep-ui/features/same-incidents-in-the-past/ui/change-same-incident-in-the-past-form.tsx +++ b/keep-ui/features/same-incidents-in-the-past/ui/change-same-incident-in-the-past-form.tsx @@ -35,8 +35,9 @@ export function ChangeSameIncidentInThePastForm({ await updateIncident( incident.id, { - // TODO: remove this once the backend supports partial updates - ...incident, + user_generated_name: incident.user_generated_name, + user_summary: incident.user_summary, + assignee: incident.assignee, same_incident_in_the_past_id: selectedIncidentId, }, false From 641420d37aa8d7bd448a5b08bb3154563fafa41a Mon Sep 17 00:00:00 2001 From: Vladimir Filonov Date: Tue, 3 Dec 2024 12:01:51 +0400 Subject: [PATCH 04/12] fix: Handle Duplication entry error in set_last_alert (#2733) --- keep/api/core/db.py | 57 +++++++++++++++++++++++++++++++-------------- 1 file changed, 39 insertions(+), 18 deletions(-) diff --git a/keep/api/core/db.py b/keep/api/core/db.py index c600c1214..0d360659f 100644 --- a/keep/api/core/db.py +++ b/keep/api/core/db.py @@ -4580,24 +4580,45 @@ def get_last_alert_by_fingerprint( def set_last_alert( tenant_id: str, alert: Alert, session: Optional[Session] = None ) -> None: - last_alert = 
get_last_alert_by_fingerprint(tenant_id, alert.fingerprint, session) + logger.info( + f"Set last alert for `{alert.fingerprint}`" + ) + with existed_or_new_session(session) as session: + last_alert = get_last_alert_by_fingerprint(tenant_id, alert.fingerprint, session) - # To prevent rare, but possible race condition - # For example if older alert failed to process - # and retried after new one + # To prevent rare, but possible race condition + # For example if older alert failed to process + # and retried after new one - if last_alert and last_alert.timestamp.replace(tzinfo=tz.UTC) < alert.timestamp.replace(tzinfo=tz.UTC): - last_alert.timestamp = alert.timestamp - last_alert.alert_id = alert.id - session.add(last_alert) - session.commit() + if last_alert and last_alert.timestamp.replace(tzinfo=tz.UTC) < alert.timestamp.replace(tzinfo=tz.UTC): - elif not last_alert: - last_alert = LastAlert( - tenant_id=tenant_id, - fingerprint=alert.fingerprint, - timestamp=alert.timestamp, - alert_id=alert.id, - ) - session.add(last_alert) - session.commit() + logger.info( + f"Update last alert for `{alert.fingerprint}`: {last_alert.alert_id} -> {alert.id}" + ) + last_alert.timestamp = alert.timestamp + last_alert.alert_id = alert.id + session.add(last_alert) + session.commit() + + elif not last_alert: + logger.info( + f"No last alert for `{alert.fingerprint}`, creating new" + ) + last_alert = LastAlert( + tenant_id=tenant_id, + fingerprint=alert.fingerprint, + timestamp=alert.timestamp, + alert_id=alert.id, + ) + + try: + session.add(last_alert) + session.commit() + except IntegrityError as ex: + reason = ex.args[0] + if "Duplicate entry" in reason: + logger.info( + f"Duplicate primary key for `{alert.fingerprint}`. Retrying." + ) + session.rollback() + return set_last_alert(tenant_id, alert, session) From cd2eeb0784b1d4ae0eec29204d84c9066c39098c Mon Sep 17 00:00:00 2001 From: "dependabot[bot]" <49699333+dependabot[bot]@users.noreply.github.com> Date: Tue, 3 Dec 2024 10:24:44 +0200 Subject: [PATCH 05/12] chore(deps): bump python-multipart from 0.0.7 to 0.0.18 (#2734) Signed-off-by: dependabot[bot] Co-authored-by: dependabot[bot] <49699333+dependabot[bot]@users.noreply.github.com> Co-authored-by: Tal --- poetry.lock | 13 +++++-------- pyproject.toml | 2 +- 2 files changed, 6 insertions(+), 9 deletions(-) diff --git a/poetry.lock b/poetry.lock index 2a7833d9c..3bcbc2d9c 100644 --- a/poetry.lock +++ b/poetry.lock @@ -4009,18 +4009,15 @@ requests-toolbelt = ">=0.6.0" [[package]] name = "python-multipart" -version = "0.0.7" +version = "0.0.18" description = "A streaming multipart parser for Python" optional = false -python-versions = ">=3.7" +python-versions = ">=3.8" files = [ - {file = "python_multipart-0.0.7-py3-none-any.whl", hash = "sha256:b1fef9a53b74c795e2347daac8c54b252d9e0df9c619712691c1cc8021bd3c49"}, - {file = "python_multipart-0.0.7.tar.gz", hash = "sha256:288a6c39b06596c1b988bb6794c6fbc80e6c369e35e5062637df256bee0c9af9"}, + {file = "python_multipart-0.0.18-py3-none-any.whl", hash = "sha256:efe91480f485f6a361427a541db4796f9e1591afc0fb8e7a4ba06bfbc6708996"}, + {file = "python_multipart-0.0.18.tar.gz", hash = "sha256:7a68db60c8bfb82e460637fa4750727b45af1d5e2ed215593f917f64694d34fe"}, ] -[package.extras] -dev = ["atomicwrites (==1.2.1)", "attrs (==19.2.0)", "coverage (==6.5.0)", "hatch", "invoke (==2.2.0)", "more-itertools (==4.3.0)", "pbr (==4.3.0)", "pluggy (==1.0.0)", "py (==1.11.0)", "pytest (==7.2.0)", "pytest-cov (==4.0.0)", "pytest-timeout (==2.1.0)", "pyyaml (==5.1)"] - [[package]] name = 
"python-socketio" version = "5.11.2" @@ -5202,4 +5199,4 @@ test = ["big-O", "importlib-resources", "jaraco.functools", "jaraco.itertools", [metadata] lock-version = "2.0" python-versions = ">=3.11,<3.12" -content-hash = "ba01130745d4f3b915f1f6c509b06e5a681244cac5129b900d03a9bd4bde62c6" +content-hash = "e866f8f4cf8210e17e03248ad91f473c777ecdd3405773f0d23a7c93210c9196" diff --git a/pyproject.toml b/pyproject.toml index 59646de6c..29b099c2e 100644 --- a/pyproject.toml +++ b/pyproject.toml @@ -48,7 +48,7 @@ posthog = "^3.0.1" google-cloud-storage = "^2.10.0" auth0-python = "^4.4.1" asyncio = "^3.4.3" -python-multipart = "^0.0.7" +python-multipart = "^0.0.18" kubernetes = "^27.2.0" opentelemetry-exporter-otlp-proto-grpc = "^1.20.0" opentelemetry-instrumentation-sqlalchemy = "^0.41b0" From 036f40613c81480ec58ac6c94a4fa29d5dba243e Mon Sep 17 00:00:00 2001 From: Tal Date: Tue, 3 Dec 2024 13:07:57 +0200 Subject: [PATCH 06/12] feat: microsoft teams provider add support for adaptivecards (#2736) --- .../documentation/teams-provider.mdx | 170 ++++++++++++++++-- .../workflows/keep-teams-adaptive-cards.yaml | 23 +++ keep-ui/package-lock.json | 8 +- keep-ui/package.json | 2 +- .../teams_provider/teams_provider.py | 69 +++++-- pyproject.toml | 2 +- 6 files changed, 237 insertions(+), 37 deletions(-) create mode 100644 examples/workflows/keep-teams-adaptive-cards.yaml diff --git a/docs/providers/documentation/teams-provider.mdx b/docs/providers/documentation/teams-provider.mdx index 673b9dc07..345fc3a40 100644 --- a/docs/providers/documentation/teams-provider.mdx +++ b/docs/providers/documentation/teams-provider.mdx @@ -11,14 +11,19 @@ The `notify` function in the `TeamsProvider` class takes the following parameter ```python kwargs (dict): message (str): The message to send. *Required* - typeCard (str): The card type. (MessageCard is default) - themeColor (str): Hexadecimal color. - sections (array): Array of custom informations + typeCard (str): The card type. Can be "MessageCard" (legacy) or "message" (for Adaptive Cards). Default is "message" + themeColor (str): Hexadecimal color (only used with MessageCard type) + sections (array/str): For MessageCard: Array of custom information sections + For Adaptive Cards: Array of card elements following the Adaptive Card schema + Can be provided as a JSON string or array + attachments (array/str): Custom attachments array for Adaptive Cards (overrides default attachment structure) + Can be provided as a JSON string or array + schema (str): Schema URL for Adaptive Cards. Default is "http://adaptivecards.io/schemas/adaptive-card.json" ``` ## Outputs -_No information yet, feel free to contribute it using the "Edit this page" link the bottom of the page_ +The response as JSON, which is the response from the Microsoft Teams API. ## Authentication Parameters @@ -28,26 +33,159 @@ The TeamsProviderAuthConfig class takes the following parameters: ## Connecting with the Provider -1. Open the Microsoft Teams application or website and select the team or channel where you want to add the webhook. + + + 1. In the New Teams client, select Teams and navigate to the channel where + you want to add an Incoming Webhook. 2. Select More options ••• on the right + side of the channel name. 3. Select Manage Channel + + + + + For members who aren't admins of the channel, the Manage channel option is + available under the Open channel details option in the upper-right corner + of a channel. + + 4. Select Edit + + + + 5. Search for Incoming Webhook and select Add. + + + + 6. 
Select Add + + + + 7. Provide a name for the webhook and upload an image if necessary. 8. Select + Create. + + + + 9. Copy and save the unique webhook URL present in the dialog. The URL maps to + the channel and you can use it to send information to Teams. 10. Select Done. + The webhook is now available in the Teams channel. + + + + + + 1. In the Classic Teams client, select Teams and navigate to the channel + where you want to add an Incoming Webhook. 2. Select More options ••• from + the upper-right corner. 3. Select Connectors from the dropdown menu. + + + + 4. Search for Incoming Webhook and select Add. + + + + 5. Select Add. + + + + 6. Provide a name for the webhook and upload an image if necessary. 7. + Select Create. + + + + 8. Copy and save the unique webhook URL present in the dialog. The URL maps + to the channel and you can use it to send information to Teams. 9. Select + Done. + + + + + -2. Click on the three-dot icon next to the team or channel name and select "Connectors" from the dropdown menu. - -3. Search for "Incoming Webhook" and click on the "Add" button. - -4. Give your webhook a name and an optional icon, then click on the "Create" button. +## Notes -5. Copy the webhook URL that is generated and save it for later use. +When using Adaptive Cards (`typeCard="message"`): + +- The `sections` parameter should follow the [Adaptive Cards schema](https://adaptivecards.io/explorer/) +- `themeColor` is ignored for Adaptive Cards +- If no sections are provided, the message will be displayed as a simple text block +- Both `sections` and `attachments` can be provided as JSON strings or arrays + +### Workflow Example + +You can also find this example in our [examples](https://github.com/keephq/keep/tree/main/examples/workflows/keep-teams-adaptive-cards.yaml) folder in the Keep GitHub repository. + +```yaml +id: 6bc7c72e-ab3d-4913-84dd-08b9323195ae +description: Teams Adaptive Cards Example +disabled: false +triggers: + - type: manual + - filters: + - key: source + value: r".*" + type: alert +consts: {} +name: Keep Teams Adaptive Cards +owners: [] +services: [] +steps: [] +actions: + - name: teams-action + provider: + config: "{{ providers.teams }}" + type: teams + with: + message: "" + sections: '[{"type": "TextBlock", "text": "{{alert.name}}"}, {"type": "TextBlock", "text": "Tal from Keep"}]' + typeCard: message +``` -6. Select the options that you want to configure for your webhook, such as the default name and avatar that will be used when posting messages. + + The sections parameter is a JSON string that follows the Adaptive Cards schema, but can also be an object. + If it's a string, it will be parsed as a JSON string. + -7. Click on the "Save" button to save your webhook settings. +### Using Sections -You can now use the webhook URL to send messages to the selected channel or team in Microsoft Teams. +```python +provider.notify( + message="Fallback text", + typeCard="message", + sections=[ + { + "type": "TextBlock", + "text": "Hello from Adaptive Card!" 
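+            # each entry is an Adaptive Card element from the schema at
+            # https://adaptivecards.io/explorer/ (e.g. TextBlock, Image)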
+ }, + { + "type": "Image", + "url": "https://example.com/image.jpg" + } + ] +) +``` -## Notes +### Using Custom Attachments -_No information yet, feel free to contribute it using the "Edit this page" link the bottom of the page_ +```python +provider.notify( + typeCard="message", + attachments=[{ + "contentType": "application/vnd.microsoft.card.adaptive", + "content": { + "type": "AdaptiveCard", + "version": "1.2", + "body": [ + { + "type": "TextBlock", + "text": "Custom Attachment Example" + } + ] + } + }] +) +``` ## Useful Links - https://learn.microsoft.com/pt-br/microsoftteams/platform/webhooks-and-connectors/how-to/add-incoming-webhook +- https://learn.microsoft.com/en-us/microsoftteams/platform/webhooks-and-connectors/how-to/connectors-using +- https://adaptivecards.io/explorer/ +- https://adaptivecards.io/schemas/adaptive-card.json diff --git a/examples/workflows/keep-teams-adaptive-cards.yaml b/examples/workflows/keep-teams-adaptive-cards.yaml new file mode 100644 index 000000000..ad2f64509 --- /dev/null +++ b/examples/workflows/keep-teams-adaptive-cards.yaml @@ -0,0 +1,23 @@ +id: 6bc7c72e-ab3d-4913-84dd-08b9323195ae +description: Teams Adaptive Cards Example +disabled: false +triggers: + - type: manual + - filters: + - key: source + value: r".*" + type: alert +consts: {} +name: Keep Teams Adaptive Cards +owners: [] +services: [] +steps: [] +actions: + - name: teams-action + provider: + config: "{{ providers.teams }}" + type: teams + with: + message: "" + sections: '[{"type": "TextBlock", "text": "{{alert.name}}"}, {"type": "TextBlock", "text": "Tal from Keep"}]' + typeCard: message diff --git a/keep-ui/package-lock.json b/keep-ui/package-lock.json index d72714d7d..0a3fd9436 100644 --- a/keep-ui/package-lock.json +++ b/keep-ui/package-lock.json @@ -63,7 +63,7 @@ "postcss-nested": "^6.0.1", "postcss-selector-parser": "^6.0.12", "postcss-value-parser": "^4.2.0", - "posthog-js": "^1.194.1", + "posthog-js": "^1.194.2", "posthog-node": "^3.1.1", "pusher-js": "^8.3.0", "react": "^18.3.1", @@ -16753,9 +16753,9 @@ } }, "node_modules/posthog-js": { - "version": "1.194.1", - "resolved": "https://registry.npmjs.org/posthog-js/-/posthog-js-1.194.1.tgz", - "integrity": "sha512-d68hmU9DY4iPe3WneBlnglERhimRhXuF7Lx0Au6OTmOL+IFdFUxB3Qf5LaLqJc1QLt3NUolMq1HiXOaIULe3kQ==", + "version": "1.194.2", + "resolved": "https://registry.npmjs.org/posthog-js/-/posthog-js-1.194.2.tgz", + "integrity": "sha512-UVFVvx6iJMEjHo+N/HmPDK4zjkVY8m+G13jTQmvHMtByfyn/fH6JhOz/ph+gtmvXPI03130y1qrwwgPIZ3ty8A==", "dependencies": { "core-js": "^3.38.1", "fflate": "^0.4.8", diff --git a/keep-ui/package.json b/keep-ui/package.json index 58420dd53..84ec9dad5 100644 --- a/keep-ui/package.json +++ b/keep-ui/package.json @@ -64,7 +64,7 @@ "postcss-nested": "^6.0.1", "postcss-selector-parser": "^6.0.12", "postcss-value-parser": "^4.2.0", - "posthog-js": "^1.194.1", + "posthog-js": "^1.194.2", "posthog-node": "^3.1.1", "pusher-js": "^8.3.0", "react": "^18.3.1", diff --git a/keep/providers/teams_provider/teams_provider.py b/keep/providers/teams_provider/teams_provider.py index d7898e9a5..3c709f71a 100644 --- a/keep/providers/teams_provider/teams_provider.py +++ b/keep/providers/teams_provider/teams_provider.py @@ -4,6 +4,7 @@ import dataclasses +import json5 as json import pydantic import requests @@ -51,37 +52,78 @@ def dispose(self): def _notify( self, message="", - typeCard="MessageCard", + typeCard="message", themeColor=None, sections=[], + schema="http://adaptivecards.io/schemas/adaptive-card.json", + attachments=[], **kwargs: dict, 
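+        # typeCard="message" (the default) builds an Adaptive Card payload;
+        # pass typeCard="MessageCard" to keep the legacy connector format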
): """ Notify alert message to Teams using the Teams Incoming Webhook API - https://learn.microsoft.com/pt-br/microsoftteams/platform/webhooks-and-connectors/how-to/connectors-using?tabs=cURL Args: - kwargs (dict): The providers with context + message (str): The message to send + typeCard (str): Type of card to send ("MessageCard" or "message" for Adaptive Cards) + themeColor (str): Color theme for MessageCard + sections (list): Sections for MessageCard or Adaptive Card content + attachments (list): Attachments for Adaptive Card + **kwargs (dict): Additional arguments """ self.logger.debug("Notifying alert message to Teams") - webhook_url = self.authentication_config.webhook_url - response = requests.post( - webhook_url, - json={ + if isinstance(sections, str): + try: + sections = json.loads(sections) + except json.JSONDecodeError as e: + self.logger.error(f"Failed to decode sections string to JSON: {e}") + + if attachments and isinstance(attachments, str): + try: + attachments = json.loads(attachments) + except json.JSONDecodeError as e: + self.logger.error(f"Failed to decode attachments string to JSON: {e}") + + if typeCard == "message": + # Adaptive Card format + payload = { + "type": "message", + "attachments": attachments + or [ + { + "contentType": "application/vnd.microsoft.card.adaptive", + "contentUrl": None, + "content": { + "$schema": schema, + "type": "AdaptiveCard", + "version": "1.2", + "body": ( + sections + if sections + else [{"type": "TextBlock", "text": message}] + ), + }, + } + ], + } + else: + # Standard MessageCard format + payload = { "@type": typeCard, "themeColor": themeColor, "text": message, "sections": sections, - }, - ) + } + + response = requests.post(webhook_url, json=payload) if not response.ok: raise ProviderException( f"{self.__class__.__name__} failed to notify alert message to Teams: {response.text}" ) self.logger.debug("Alert message notified to Teams") + return response.json() if __name__ == "__main__": @@ -106,12 +148,9 @@ def _notify( ) provider = TeamsProvider(context_manager, provider_id="teams", config=config) provider.notify( - typeCard="MessageCard", - themeColor="0076D7", - message="Microsoft Teams alert", + typeCard="message", sections=[ - {"name": "Assigned to", "value": "Danilo Vaz"}, - {"name": "Sum", "value": 10}, - {"name": "Count", "value": 100}, + {"type": "TextBlock", "text": "Danilo Vaz"}, + {"type": "TextBlock", "text": "Tal from Keep"}, ], ) diff --git a/pyproject.toml b/pyproject.toml index 29b099c2e..3d6a966d9 100644 --- a/pyproject.toml +++ b/pyproject.toml @@ -1,6 +1,6 @@ [tool.poetry] name = "keep" -version = "0.30.7" +version = "0.31.0" description = "Alerting. for developers, by developers." 
authors = ["Keep Alerting LTD"] packages = [{include = "keep"}] From 07f0856fbaf72a405c0af095ef522e29c07636b5 Mon Sep 17 00:00:00 2001 From: Posi Adedeji <39467790+theedigerati@users.noreply.github.com> Date: Tue, 3 Dec 2024 13:16:40 +0100 Subject: [PATCH 07/12] fix: add provider input validation & type safety (#2430) Co-authored-by: Matvey Kukuy Co-authored-by: Kirill Chernakov --- keep-ui/app/(keep)/providers/form-fields.tsx | 522 ++++++++++ .../app/(keep)/providers/form-validation.ts | 343 +++++++ .../(keep)/providers/provider-form-scopes.tsx | 25 +- .../app/(keep)/providers/provider-form.tsx | 907 ++++++------------ .../app/(keep)/providers/providers-tiles.tsx | 42 +- keep-ui/app/(keep)/providers/providers.tsx | 28 +- .../app/(keep)/workflows/workflow-tile.tsx | 33 - keep-ui/package-lock.json | 1 + keep-ui/package.json | 1 + keep-ui/shared/lib/encodings.ts | 25 + keep/parser/parser.py | 3 +- .../appdynamics_provider.py | 3 +- .../auth0_provider/auth0_provider.py | 23 +- keep/providers/base/base_provider.py | 34 +- .../bigquery_provider/bigquery_provider.py | 2 +- .../centreon_provider/centreon_provider.py | 13 +- .../cilium_provider/cilium_provider.py | 7 +- .../clickhouse_provider.py | 18 +- .../datadog_provider/datadog_provider.py | 22 +- .../discord_provider/discord_provider.py | 4 +- .../elastic_provider/elastic_provider.py | 39 +- .../gcpmonitoring_provider.py | 2 +- .../gitlab_provider/gitlab_provider.py | 5 +- keep/providers/gke_provider/gke_provider.py | 2 +- .../google_chat_provider.py | 8 +- .../grafana_incident_provider.py | 5 +- .../grafana_oncall_provider.py | 12 +- .../grafana_provider/grafana_provider.py | 15 +- .../graylog_provider/graylog_provider.py | 3 +- .../ilert_provider/ilert_provider.py | 25 +- keep/providers/jira_provider/jira_provider.py | 18 +- .../jiraonprem_provider.py | 3 +- .../kafka_provider/kafka_provider.py | 8 +- .../kibana_provider/kibana_provider.py | 25 +- .../kubernetes_provider.py | 6 +- .../mattermost_provider.py | 5 +- .../mongodb_provider/mongodb_provider.py | 15 +- .../mysql_provider/mysql_provider.py | 9 +- .../newrelic_provider/newrelic_provider.py | 18 +- keep/providers/ntfy_provider/ntfy_provider.py | 3 +- .../openobserve_provider.py | 28 +- .../openshift_provider/openshift_provider.py | 8 +- .../opsgenie_provider/opsgenie_provider.py | 3 +- .../postgres_provider/postgres_provider.py | 18 +- .../prometheus_provider.py | 3 +- .../redmine_provider/redmine_provider.py | 3 +- .../sentry_provider/sentry_provider.py | 4 +- .../servicenow_provider.py | 11 +- .../site24x7_provider/site24x7_provider.py | 7 +- .../slack_provider/slack_provider.py | 7 +- keep/providers/smtp_provider/smtp_provider.py | 7 +- .../splunk_provider/splunk_provider.py | 25 +- .../squadcast_provider/squadcast_provider.py | 14 +- keep/providers/ssh_provider/ssh_provider.py | 18 +- .../teams_provider/teams_provider.py | 4 +- .../uptimekuma_provider.py | 13 +- .../victoriametrics_provider.py | 15 +- .../webhook_provider/webhook_provider.py | 3 +- .../zabbix_provider/zabbix_provider.py | 3 +- keep/validation/__init__.py | 0 keep/validation/fields.py | 163 ++++ tests/e2e_tests/test_end_to_end.py | 154 ++- tests/test_provider_validation_fields.py | 171 ++++ 63 files changed, 2030 insertions(+), 934 deletions(-) create mode 100644 keep-ui/app/(keep)/providers/form-fields.tsx create mode 100644 keep-ui/app/(keep)/providers/form-validation.ts create mode 100644 keep-ui/shared/lib/encodings.ts create mode 100644 keep/validation/__init__.py create mode 100644 keep/validation/fields.py 
create mode 100644 tests/test_provider_validation_fields.py diff --git a/keep-ui/app/(keep)/providers/form-fields.tsx b/keep-ui/app/(keep)/providers/form-fields.tsx new file mode 100644 index 000000000..6ab5090b7 --- /dev/null +++ b/keep-ui/app/(keep)/providers/form-fields.tsx @@ -0,0 +1,522 @@ +import { useMemo, useRef, useState } from "react"; +import { + Provider, + ProviderAuthConfig, + ProviderFormData, + ProviderFormKVData, + ProviderFormValue, + ProviderInputErrors, +} from "./providers"; +import { + Title, + Text, + Button, + Callout, + Icon, + Subtitle, + Divider, + TextInput, + Select, + SelectItem, + Card, + Tab, + TabList, + TabGroup, + TabPanel, + TabPanels, + Accordion, + AccordionHeader, + AccordionBody, + Badge, + Switch, +} from "@tremor/react"; +import { + QuestionMarkCircleIcon, + ArrowLongRightIcon, + ArrowLongLeftIcon, + ArrowTopRightOnSquareIcon, + ArrowDownOnSquareIcon, + GlobeAltIcon, + DocumentTextIcon, + PlusIcon, + TrashIcon, +} from "@heroicons/react/24/outline"; + +export function getRequiredConfigs( + config: Provider["config"] +): Provider["config"] { + const configs = Object.entries(config).filter( + ([_, config]) => config.required && !config.config_main_group + ); + return Object.fromEntries(configs); +} + +export function getOptionalConfigs( + config: Provider["config"] +): Provider["config"] { + const configs = Object.entries(config).filter( + ([_, config]) => + !config.required && !config.hidden && !config.config_main_group + ); + return Object.fromEntries(configs); +} + +function getConfigGroup(type: "config_main_group" | "config_sub_group") { + return (configs: Provider["config"]) => { + return Object.entries(configs).reduce( + (acc: Record, [key, config]) => { + const group = config[type]; + if (!group) return acc; + acc[group] ??= {}; + acc[group][key] = config; + return acc; + }, + {} + ); + }; +} + +export const getConfigByMainGroup = getConfigGroup("config_main_group"); +export const getConfigBySubGroup = getConfigGroup("config_sub_group"); + +export function GroupFields({ + groupName, + fields, + data, + errors, + disabled, + onChange, +}: { + groupName: string; + fields: Provider["config"]; + data: ProviderFormData; + errors: ProviderInputErrors; + disabled: boolean; + onChange: (key: string, value: ProviderFormValue) => void; +}) { + const subGroups = useMemo(() => getConfigBySubGroup(fields), [fields]); + + if (Object.keys(subGroups).length === 0) { + // If no subgroups, render fields directly + return ( + + {groupName} + {Object.entries(fields).map(([field, config]) => ( +
+ +
+ ))} +
+ ); + } + + return ( + + {groupName} + + + {Object.keys(subGroups).map((name) => ( + + {name} + + ))} + + + {Object.entries(subGroups).map(([name, subGroup]) => ( + + {Object.entries(subGroup).map(([field, config]) => ( +
+ +
+ ))} +
+ ))} +
+
+
+ ); +} + +export function FormField({ + id, + config, + value, + error, + disabled, + title, + onChange, +}: { + id: string; + config: ProviderAuthConfig; + value: ProviderFormValue; + error?: string; + disabled: boolean; + title?: string; + onChange: (key: string, value: ProviderFormValue) => void; +}) { + function handleInputChange(event: React.ChangeEvent) { + let value; + const files = event.target.files; + const name = event.target.name; + + // If the input is a file, retrieve the file object, otherwise retrieve the value + if (files && files.length > 0) { + value = files[0]; // Assumes single file upload + } else { + value = event.target.value; + } + + onChange(name, value); + } + + switch (config.type) { + case "select": + return ( + onChange(id, value)} + /> + ); + case "form": + return ( + onChange(id, data)} + onChange={(value) => onChange(id, value)} + /> + ); + case "file": + return ( + + ); + case "switch": + return ( + onChange(id, value)} + /> + ); + default: + return ( + + ); + } +} + +export function TextField({ + id, + config, + value, + error, + disabled, + title, + onChange, +}: { + id: string; + config: ProviderAuthConfig; + value: ProviderFormValue; + error?: string; + disabled: boolean; + title?: string; + onChange: (e: React.ChangeEvent) => void; +}) { + return ( + <> + + + + ); +} + +export function SelectField({ + id, + config, + value, + error, + disabled, + onChange, +}: { + id: string; + config: ProviderAuthConfig; + value: ProviderFormValue; + error?: string; + disabled: boolean; + onChange: (value: string) => void; +}) { + return ( + <> + + + + ); +} + +export function FileField({ + id, + config, + disabled, + error, + onChange, +}: { + id: string; + config: ProviderAuthConfig; + disabled: boolean; + error?: string; + onChange: (e: React.ChangeEvent) => void; +}) { + const [selected, setSelected] = useState(); + const ref = useRef(null); + + function handleClick(e: React.MouseEvent) { + e.preventDefault(); + if (ref.current) ref.current.click(); + } + + function handleChange(e: React.ChangeEvent) { + if (e.target.files && e.target.files[0]) { + setSelected(e.target.files[0].name); + } + onChange(e); + } + + return ( + <> + + + + {error && error?.length > 0 && ( +

{error}

+ )} + + ); +} + +export function KVForm({ + id, + config, + value, + error, + disabled, + onAdd, + onChange, +}: { + id: string; + config: ProviderAuthConfig; + value: ProviderFormValue; + error?: string; + disabled: boolean; + onAdd: (data: ProviderFormKVData) => void; + onChange: (value: ProviderFormKVData) => void; +}) { + function handleAdd() { + const newData = Array.isArray(value) + ? [...value, { key: "", value: "" }] + : [{ key: "", value: "" }]; + onAdd(newData); + } + + return ( +
+
+ + +
+ {Array.isArray(value) && } + {error && error?.length > 0 && ( +

{error}

+ )} +
+ ); +} + +export const KVInput = ({ + data, + onChange, +}: { + data: ProviderFormKVData; + onChange: (entries: ProviderFormKVData) => void; +}) => { + const handleEntryChange = (index: number, name: string, value: string) => { + const newEntries = data.map((entry, i) => + i === index ? { ...entry, [name]: value } : entry + ); + onChange(newEntries); + }; + + const removeEntry = (index: number) => { + const newEntries = data.filter((_, i) => i !== index); + onChange(newEntries); + }; + + return ( +
+ {data.map((entry, index) => ( +
+ handleEntryChange(index, "key", e.target.value)} + placeholder="Key" + className="mr-2" + /> + handleEntryChange(index, "value", e.target.value)} + placeholder="Value" + className="mr-2" + /> +
+ ))} +
+ ); +}; + +export function SwitchInput({ + id, + config, + value, + disabled, + onChange, +}: { + id: string; + config: ProviderAuthConfig; + value: ProviderFormValue; + disabled?: boolean; + onChange: (value: boolean) => void; +}) { + if (typeof value !== "boolean") return null; + + return ( +
+ + +
+ ); +} + +export function FieldLabel({ + id, + config, +}: { + id: string; + config: ProviderAuthConfig; +}) { + return ( + + ); +} diff --git a/keep-ui/app/(keep)/providers/form-validation.ts b/keep-ui/app/(keep)/providers/form-validation.ts new file mode 100644 index 000000000..cf1bf5e1b --- /dev/null +++ b/keep-ui/app/(keep)/providers/form-validation.ts @@ -0,0 +1,343 @@ +import { z } from "zod"; +import { Provider } from "./providers"; + +type URLOptions = { + protocols: string[]; + requireTld: boolean; + requireProtocol: boolean; + requirePort: boolean; + alllowMultihost: boolean; + validateLength: boolean; + maxLength: number; +}; + +type ValidatorRes = { success: true } | { success: false; msg: string }; + +const defaultURLOptions: URLOptions = { + protocols: [], + requireTld: false, + requireProtocol: true, + requirePort: false, + alllowMultihost: false, + validateLength: true, + maxLength: 2 ** 16, +}; + +function mergeOptions>( + defaults: T, + opts?: Partial +): T { + if (!opts) return defaults; + return { ...defaults, ...opts }; +} + +const error = (msg: string) => ({ success: false, msg }); +const urlError = error("Please provide a valid URL"); +const protocolError = error("A valid URL protocol is required"); +const relProtocolError = error("A protocol-relavie URL is not allowed"); +const multiProtocolError = error("URL cannot have more than one protocol"); +const missingPortError = error("A URL with a port number is required"); +const portError = error("Invalid port number"); +const hostError = error("Invalid URL host"); +const hostWildcardError = error("Wildcard in URL host is not allowed"); +const multihostError = error("Multiple hosts are not allowed"); +const multihostProtocolError = error("Invalid multihost protocol"); +const tldError = error( + "URL must contain a valid TLD e.g .com, .io, .dev, .net" +); + +function getProtocolError(protocols: URLOptions["protocols"]) { + if (protocols.length === 0) return protocolError; + if (protocols.length === 1) + return error(`A URL with \`${protocols[0]}\` protocol is required`); + if (protocols.length === 2) + return error( + `A URL with \`${protocols[0]}\` or \`${protocols[1]}\` protocol is required` + ); + const lst = protocols.length - 1; + const wrap = (acc: string, p: string) => acc + `\`${p}\``; + const optsStr = protocols.reduce( + (acc, p, i) => + i === lst + ? wrap(acc, p) + : i === lst - 1 + ? wrap(acc, p) + " or " + : wrap(acc, p) + ", ", + "" + ); + return error(`A URL with one of ${optsStr} protocols is required`); +} + +function isFQDN(str: string, options?: Partial): ValidatorRes { + const opts = mergeOptions(defaultURLOptions, options); + + if (str[str.length - 1] === ".") return hostError; // trailing dot not allowed + if (str.indexOf("*.") === 0) return hostWildcardError; // wildcard not allowed + + const parts = str.split("."); + const tld = parts[parts.length - 1]; + const tldRegex = + /^([a-z\u00A1-\u00A8\u00AA-\uD7FF\uF900-\uFDCF\uFDF0-\uFFEF]{2,}|xn[a-z0-9-]{2,})$/i; + + if ( + opts.requireTld && + (parts.length < 2 || !tldRegex.test(tld) || /\s/.test(tld)) + ) + return tldError; + + const partsValid = parts.every((part) => { + if (!/^[a-z_\u00a1-\uffff0-9-]+$/i.test(part)) { + return false; + } + + // disallow full-width chars + if (/[\uff01-\uff5e]/.test(part)) { + return false; + } + + // disallow parts starting or ending with hyphen + if (/^-|-$/.test(part)) { + return false; + } + + return true; + }); + + return partsValid ? 
{ success: true } : hostError; +} + +function isIP(str: string) { + const validation = z.string().ip().safeParse(str); + return validation.success; +} + +function validateHost(hostname: string, opts: URLOptions): ValidatorRes { + let host: string; + let port: number; + let portStr: string = ""; + let split: string[]; + + // extract ipv6 & port + const wrapped_ipv6 = /^\[([^\]]+)\](?::([0-9]+))?$/; + const ipv6Match = hostname.match(wrapped_ipv6); + if (ipv6Match) { + host = ipv6Match[1]; + portStr = ipv6Match[2]; + } else { + split = hostname.split(":"); + host = split.shift() ?? ""; + if (split.length) portStr = split.join(":"); + } + + if (portStr.length) { + port = parseInt(portStr, 10); + if (Number.isNaN(port)) return urlError; + if (port <= 0 || port > 65_535) return portError; + } else if (opts.requirePort) return missingPortError; + + if (!host) return hostError; + if (isIP(host)) return { success: true }; + return isFQDN(host, opts); +} + +function isURL(str: string, options?: Partial): ValidatorRes { + const opts = mergeOptions(defaultURLOptions, options); + + if (str.length === 0 || /[\s<>]/.test(str)) return urlError; + if (opts.validateLength && str.length > opts.maxLength) { + return error(`Invalid url length, max of ${opts.maxLength} expected.`); + } + + let url = str; + let split: string[]; + + split = url.split("#"); + url = split.shift() ?? ""; + + split = url.split("?"); + url = split.shift() ?? ""; + + if (url.slice(0, 2) === "//") return relProtocolError; + + // extract protocol & validate + split = url.split("://"); + if (split.length > 2) return multiProtocolError; + if (split.length > 1) { + const protocol = split.shift()?.toLowerCase() ?? ""; + if (opts.protocols.length && opts.protocols.indexOf(protocol) === -1) + return getProtocolError(opts.protocols); + if (protocol.includes(",")) return multihostProtocolError; + url = split.join("://"); + } else if (opts.requireProtocol) { + return getProtocolError(opts.protocols); + } + + split = url.split("/"); + url = split.shift() ?? ""; + if (!url.length) return urlError; + + // extract auth details & validate + split = url.split("@"); + if (split.length > 1 && !split[0]) return urlError; + if (split.length > 1) { + const auth = split.shift() ?? ""; + if (auth.split(":").length > 2) return urlError; + const [user, pass] = auth.split(":"); + if (!user && !pass) return urlError; + } + const hostname = split.join("@"); + + // validate multihost + split = hostname.split(","); + if (split.length > 1 && !opts.alllowMultihost) return multihostError; + if (split.length > 1) { + for (const host of split) { + const res = validateHost(host, opts); + if (!res.success) return res; + } + return { success: true }; + } + return validateHost(hostname, opts); +} + +const required_error = "This field is required"; + +function getBaseUrlSchema(options?: Partial) { + const urlStr = z.string({ required_error }); + const schema = urlStr.superRefine((url, ctx) => { + const valdn = isURL(url, options); + if (valdn.success) return; + ctx.addIssue({ + code: z.ZodIssueCode.custom, + message: valdn.msg, + }); + }); + return schema; +} + +export function getZodSchema(fields: Provider["config"], installed: boolean) { + const portError = "Invalid port number"; + + const kvPairs = Object.entries(fields).map(([field, config]) => { + if (config.type === "form") { + const baseSchema = z.record(z.string(), z.string()).array(); + const schema = config.required + ? 
baseSchema.nonempty({ + message: "At least one key-value entry should be provided.", + }) + : baseSchema.optional(); + return [field, schema]; + } + + if (config.type === "file") { + const baseSchema = z + .instanceof(File, { message: "Please upload a file here." }) + .or(z.string()) + .refine( + (file) => { + if (config.file_type == undefined) return true; + if (config.file_type.length <= 1) return true; + if (typeof file === "string" && installed) return true; + return ( + typeof file !== "string" && config.file_type.includes(file.type) + ); + }, + { + message: + config.file_type && config.file_type?.split(",").length > 1 + ? `File type should be one of ${config.file_type}.` + : `File should be of type ${config.file_type}.`, + } + ); + const schema = config.required ? baseSchema : baseSchema.optional(); + return [field, schema]; + } + + if (config.type === "switch") { + const schema = config.required ? z.boolean() : z.boolean().optional(); + return [field, schema]; + } + + if (config.validation === "any_url") { + const baseSchema = getBaseUrlSchema(); + const schema = config.required ? baseSchema : baseSchema.optional(); + return [field, schema]; + } + + if (config.validation === "any_http_url") { + const baseSchema = getBaseUrlSchema({ protocols: ["http", "https"] }); + const schema = config.required ? baseSchema : baseSchema.optional(); + return [field, schema]; + } + + if (config.validation === "https_url") { + const baseSchema = getBaseUrlSchema({ + protocols: ["https"], + requireTld: true, + maxLength: 2083, + }); + const schema = config.required ? baseSchema : baseSchema.optional(); + return [field, schema]; + } + + if (config.validation === "no_scheme_url") { + const baseSchema = getBaseUrlSchema({ requireProtocol: false }); + const schema = config.required ? baseSchema : baseSchema.optional(); + return [field, schema]; + } + + if (config.validation === "multihost_url") { + const baseSchema = getBaseUrlSchema({ alllowMultihost: true }); + const schema = config.required ? baseSchema : baseSchema.optional(); + return [field, schema]; + } + + if (config.validation === "no_scheme_multihost_url") { + const baseSchema = getBaseUrlSchema({ + alllowMultihost: true, + requireProtocol: false, + }); + const schema = config.required ? baseSchema : baseSchema.optional(); + return [field, schema]; + } + + if (config.validation === "tld") { + const baseSchema = z + .string({ required_error }) + .regex(new RegExp(/\.[a-z]{2,63}$/), { + message: "Please provide a valid TLD e.g .com, .io, .dev, .net", + }); + const schema = config.required ? baseSchema : baseSchema.optional(); + return [field, schema]; + } + + if (config.validation === "port") { + const baseSchema = z + .string({ required_error }) + .pipe( + z.coerce + .number({ invalid_type_error: portError }) + .min(1, { message: portError }) + .max(65_535, { message: portError }) + ); + const schema = config.required ? baseSchema : baseSchema.optional(); + return [field, schema]; + } + return [ + field, + config.required + ? 
z + .string({ required_error }) + .trim() + .min(1, { message: required_error }) + : z.string().optional(), + ]; + }); + return z.object({ + provider_name: z + .string({ required_error }) + .trim() + .min(1, { message: required_error }), + ...Object.fromEntries(kvPairs), + }); +} diff --git a/keep-ui/app/(keep)/providers/provider-form-scopes.tsx b/keep-ui/app/(keep)/providers/provider-form-scopes.tsx index 8457eec0d..d5b60130e 100644 --- a/keep-ui/app/(keep)/providers/provider-form-scopes.tsx +++ b/keep-ui/app/(keep)/providers/provider-form-scopes.tsx @@ -22,29 +22,24 @@ import "./provider-form-scopes.css"; const ProviderFormScopes = ({ provider, validatedScopes, - installedProvidersMode = false, refreshLoading, - triggerRevalidateScope, + onRevalidate, }: { provider: Provider; validatedScopes: { [key: string]: string | boolean }; - installedProvidersMode?: boolean; refreshLoading: boolean; - triggerRevalidateScope: any; + onRevalidate: () => void; }) => { return ( Scopes - {installedProvidersMode && ( + {provider.installed && ( - - handleDictInputChange(configKey, value)} - error={Object.keys(inputErrors).includes(configKey)} - disabled={provider.provisioned} - /> - - ); - case "file": - return ( - <> - {renderFieldHeader()} - - { - if (e.target.files && e.target.files[0]) { - setSelectedFile(e.target.files[0].name); - } - handleInputChange(e); - }} - disabled={provider.provisioned} - /> - - ); - default: - return ( - <> - {renderFieldHeader()} - - - ); + function setApiError(error: string) { + if (error.includes("SyntaxError")) { + setFormErrors( + "Bad response from API: Check the backend logs for more details" + ); + } else if (error.includes("Failed to fetch")) { + setFormErrors( + "Failed to connect to API: Check provider settings and your internet connection" + ); + } else { + setFormErrors(error); } - }; - - const requiredConfigs = Object.entries(provider.config) - .filter(([_, config]) => config.required && !config.config_main_group) - .reduce((acc, [key, value]) => ({ ...acc, [key]: value }), {}); + } - const optionalConfigs = Object.entries(provider.config) - .filter( - ([_, config]) => - !config.required && !config.hidden && !config.config_main_group - ) - .reduce((acc, [key, value]) => ({ ...acc, [key]: value }), {}); + async function handleUpdateClick() { + if (provider.webhook_required) callInstallWebhook(); + if (!validate()) return; + setIsLoading(true); + submit(`/providers/${provider.id}`, "PUT") + .then(() => { + setIsLoading(false); + toast.success("Updated provider successfully", { + position: "top-left", + }); + mutate(); + }) + .catch((error) => { + showErrorToast("Failed to update provider"); + handleSubmitError(error); + setIsLoading(false); + }); + } - const groupConfigsByMainGroup = (configs) => { - return Object.entries(configs).reduce((acc, [key, config]) => { - const mainGroup = config.config_main_group; - if (mainGroup) { - if (!acc[mainGroup]) { - acc[mainGroup] = {}; + async function handleConnectClick() { + if (!validate()) return; + setIsLoading(true); + onConnectChange?.(true, false); + submit(`/providers/install`) + .then(async (data) => { + console.log("Connect Result:", data); + setIsLoading(false); + onConnectChange?.(false, true); + if ( + formValues.install_webhook && + provider.can_setup_webhook && + !isLocalhost + ) { + // mutate after webhook installation + await installWebhook(data as Provider); } - acc[mainGroup][key] = config; - } - return acc; - }, {}); - }; - - const groupConfigsBySubGroup = (configs) => { - return 
Object.entries(configs).reduce((acc, [key, config]) => { - const subGroup = config.config_sub_group || "default"; - if (!acc[subGroup]) { - acc[subGroup] = {}; - } - acc[subGroup][key] = config; - return acc; - }, {}); - }; - - const getSubGroups = (configs) => { - return [ - ...new Set( - Object.values(configs).map((config) => config.config_sub_group) - ), - ].filter(Boolean); - }; - - const renderGroupFields = (groupName, groupConfigs) => { - const subGroups = groupConfigsBySubGroup(groupConfigs); - const subGroupNames = getSubGroups(groupConfigs); - - if (subGroupNames.length === 0) { - // If no subgroups, render fields directly - return ( - - - {groupName.charAt(0).toUpperCase() + groupName.slice(1)} - - {Object.entries(groupConfigs).map(([configKey, config]) => ( -
- {renderFormField(configKey, config)} -
- ))} -
- ); - } + mutate(); + }) + .catch((error) => { + handleSubmitError(error); + setIsLoading(false); + onConnectChange?.(false, false); + }); + } - return ( - - {groupName.charAt(0).toUpperCase() + groupName.slice(1)} - - setActiveTabsState((prev) => ({ - ...prev, - [groupName]: subGroupNames[index], - })) - } - > - - {subGroupNames.map((subGroup) => ( - - {subGroup.replace("_", " ").toUpperCase()} - - ))} - - - {subGroupNames.map((subGroup) => ( - - {Object.entries(subGroups[subGroup] || {}).map( - ([configKey, config]) => ( -
- {renderFormField(configKey, config)} -
- ) - )} -
- ))} -
-
-
- ); - }; + const installOrUpdateWebhookEnabled = provider.scopes + ?.filter((scope) => scope.mandatory_for_webhook) + .every((scope) => providerValidatedScopes[scope.name] === true); - const groupedConfigs = groupConfigsByMainGroup(provider.config); - console.log("ProviderForm component loaded"); return (
@@ -797,6 +451,7 @@ const ProviderForm = ({ {provider.provisioned && (
)} - {provider.scopes?.length > 0 && ( + {provider.scopes && provider.scopes.length > 0 && ( )}
@@ -857,6 +511,7 @@ const ProviderForm = ({ {provider.oauth2_url && !provider.installed ? ( <>
{/* Render required fields */} - {Object.entries(requiredConfigs).map(([configKey, config]) => ( -
- {renderFormField(configKey, config)} + {Object.entries(requiredConfigs).map(([field, config]) => ( +
+
))} {/* Render grouped fields */} - {Object.entries(groupedConfigs).map(([groupName, groupConfigs]) => ( - - {renderGroupFields(groupName, groupConfigs)} + {Object.entries(groupedConfigs).map(([name, fields]) => ( + + ))} @@ -915,13 +574,18 @@ const ProviderForm = ({ Provider Optional Settings - {Object.entries(optionalConfigs).map( - ([configKey, config]) => ( -
- {renderFormField(configKey, config)} -
- ) - )} + {Object.entries(optionalConfigs).map(([field, config]) => ( +
+ +
+ ))}
@@ -937,7 +601,10 @@ const ProviderForm = ({ className="mr-2.5" onChange={handleWebhookChange} checked={ - (formValues["install_webhook"] || false) && !isLocalhost + "install_webhook" in formValues && + typeof formValues["install_webhook"] === "boolean" && + formValues["install_webhook"] && + !isLocalhost } disabled={isLocalhost || provider.webhook_required} /> @@ -963,7 +630,7 @@ const ProviderForm = ({ name="pulling_enabled" className="mr-2.5" onChange={handlePullingEnabledChange} - checked={formValues["pulling_enabled"] || false} + checked={Boolean(formValues["pulling_enabled"])} />