diff --git a/keep-ui/app/(keep)/topology/api/index.ts b/keep-ui/app/(keep)/topology/api/index.ts
index 954f07d76..924259b5b 100644
--- a/keep-ui/app/(keep)/topology/api/index.ts
+++ b/keep-ui/app/(keep)/topology/api/index.ts
@@ -47,3 +47,7 @@ export async function getTopology(
   const url = buildTopologyUrl({ providerIds, services, environment });
   return await api.get(url);
 }
+
+export async function pullTopology(api: ApiClient) {
+  return await api.post("/topology/pull");
+}
diff --git a/keep-ui/app/(keep)/topology/topology-client.tsx b/keep-ui/app/(keep)/topology/topology-client.tsx
index 27bbb98eb..27e60dd46 100644
--- a/keep-ui/app/(keep)/topology/topology-client.tsx
+++ b/keep-ui/app/(keep)/topology/topology-client.tsx
@@ -6,6 +6,11 @@ import { ApplicationsList } from "./ui/applications/applications-list";
 import { useContext, useEffect, useState } from "react";
 import { TopologySearchContext } from "./TopologySearchContext";
 import { TopologyApplication, TopologyService } from "./model";
+import { Button } from "@/components/ui";
+import { ArrowPathIcon } from "@heroicons/react/24/outline";
+import { useApi } from "@/shared/lib/hooks/useApi";
+import { pullTopology } from "./api";
+import { toast } from "react-toastify";
 
 export function TopologyPageClient({
   applications,
@@ -16,6 +21,18 @@ export function TopologyPageClient({
 }) {
   const [tabIndex, setTabIndex] = useState(0);
   const { selectedObjectId } = useContext(TopologySearchContext);
+  const api = useApi();
+
+  const handlePullTopology = async (e: React.MouseEvent) => {
+    e.stopPropagation();
+    try {
+      await pullTopology(api);
+      toast.success("Topology pull initiated");
+    } catch (error) {
+      toast.error("Failed to pull topology");
+      console.error("Failed to pull topology:", error);
+    }
+  };
 
   useEffect(() => {
     if (!selectedObjectId) {
@@ -32,8 +49,26 @@ export function TopologyPageClient({
         onIndexChange={setTabIndex}
       >
-        <Tab>Topology Map</Tab>
-        <Tab>Applications</Tab>
+        <Tab>
+          <div className="flex items-center gap-2">
+            Topology Map
+            <Button
+              variant="light"
+              size="xs"
+              icon={ArrowPathIcon}
+              onClick={handlePullTopology}
+              tooltip="Pull topology data from providers"
+            />
+          </div>
+        </Tab>
+        <Tab>
+          Applications
+        </Tab>
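For testing the new endpoint without the UI, here is a rough sketch of an equivalent call from Python. The base URL and `X-API-KEY` header are assumptions for a local deployment, not something this diff defines:

```python
# Hedged sketch of exercising POST /topology/pull outside the UI.
import requests

KEEP_API_URL = "http://localhost:8080"  # hypothetical local deployment


def trigger_topology_pull(api_key: str, provider_ids: str | None = None) -> list[dict]:
    """POST /topology/pull and return the refreshed topology services."""
    params = {"provider_ids": provider_ids} if provider_ids else {}
    resp = requests.post(
        f"{KEEP_API_URL}/topology/pull",
        headers={"X-API-KEY": api_key},  # assumed auth scheme
        params=params,
        timeout=60,
    )
    resp.raise_for_status()
    return resp.json()
```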
diff --git a/keep-ui/app/(keep)/topology/ui/map/getNodesAndEdgesFromTopologyData.ts b/keep-ui/app/(keep)/topology/ui/map/getNodesAndEdgesFromTopologyData.ts
index beb147c2f..2bca3c13e 100644
--- a/keep-ui/app/(keep)/topology/ui/map/getNodesAndEdgesFromTopologyData.ts
+++ b/keep-ui/app/(keep)/topology/ui/map/getNodesAndEdgesFromTopologyData.ts
@@ -20,11 +20,13 @@ export function getNodesAndEdgesFromTopologyData(
 ) {
   const nodeMap = new Map();
   const edgeMap = new Map();
-  const allServices = topologyData.map((data) => data.display_name);
+
+  // Create nodes from service definitions
   for (const service of topologyData) {
-    const numIncidentsToService = allIncidents.filter((incident) =>
-      incident.services.includes(service.display_name)
+    const numIncidentsToService = allIncidents.filter(
+      (incident) =>
+        incident.services.includes(service.display_name) ||
+        incident.services.includes(service.service)
     );
     const node: ServiceNodeType = {
       id: service.service.toString(),
diff --git a/keep/api/api.py b/keep/api/api.py
index ae1a44ac9..f2cb8792a 100644
--- a/keep/api/api.py
+++ b/keep/api/api.py
@@ -64,6 +64,7 @@
     IdentityManagerFactory,
     IdentityManagerTypes,
 )
+from keep.topologies.topology_processor import TopologyProcessor
 
 # load all providers into cache
 from keep.workflowmanager.workflowmanager import WorkflowManager
@@ -76,6 +77,7 @@
 PORT = config("PORT", default=8080, cast=int)
 SCHEDULER = config("SCHEDULER", default="true", cast=bool)
 CONSUMER = config("CONSUMER", default="true", cast=bool)
+TOPOLOGY = config("KEEP_TOPOLOGY_PROCESSOR", default="false", cast=bool)
 KEEP_DEBUG_TASKS = config("KEEP_DEBUG_TASKS", default="false", cast=bool)
 
 AUTH_TYPE = config("AUTH_TYPE", default=IdentityManagerTypes.NOAUTH.value).lower()
@@ -142,6 +144,15 @@ async def startup():
                 logger.info("Consumer started successfully")
             except Exception:
                 logger.exception("Failed to start the consumer")
+        # Start the topology processor
+        if TOPOLOGY:
+            try:
+                logger.info("Starting the topology processor")
+                topology_processor = TopologyProcessor.get_instance()
+                await topology_processor.start()
+                logger.info("Topology processor started successfully")
+            except Exception:
+                logger.exception("Failed to start the topology processor")
 
     if KEEP_ARQ_TASK_POOL != KEEP_ARQ_TASK_POOL_NONE:
         event_loop = asyncio.get_event_loop()
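The processor is opt-in: `startup()` only instantiates it when `KEEP_TOPOLOGY_PROCESSOR` is set. A minimal, self-contained sketch of this env-flag-gated singleton startup pattern (illustrative names, not Keep's actual module layout):

```python
# Minimal sketch of the env-flag-gated background-singleton pattern used above.
import asyncio
import os


class BackgroundProcessor:
    _instance = None

    @classmethod
    def get_instance(cls) -> "BackgroundProcessor":
        if cls._instance is None:
            cls._instance = cls()
        return cls._instance

    async def start(self) -> None:
        print("processor started")


async def startup() -> None:
    # Opt-in: nothing runs unless the flag is explicitly set to "true".
    if os.environ.get("KEEP_TOPOLOGY_PROCESSOR", "false").lower() == "true":
        await BackgroundProcessor.get_instance().start()


asyncio.run(startup())
```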
diff --git a/keep/api/core/db.py b/keep/api/core/db.py
index 158972d6a..3078ee966 100644
--- a/keep/api/core/db.py
+++ b/keep/api/core/db.py
@@ -1816,13 +1816,43 @@ def create_incident_for_grouping_rule(
         rule_id=rule.id,
         rule_fingerprint=rule_fingerprint,
         is_predicted=False,
-        is_confirmed=rule.create_on == CreateIncidentOn.ANY.value and not rule.require_approve,
+        is_confirmed=rule.create_on == CreateIncidentOn.ANY.value
+        and not rule.require_approve,
+        incident_type=IncidentType.RULE.value,
     )
     session.add(incident)
     session.commit()
     session.refresh(incident)
     return incident
 
+
+def create_incident_for_topology(
+    tenant_id: str, alert_group: list[Alert], session: Session
+) -> Incident:
+    """Create a new incident from topology-connected alerts"""
+    # Get highest severity from alerts
+    severity = max(alert.severity for alert in alert_group)
+
+    # Get all services
+    services = set()
+    service_names = set()
+    for alert in alert_group:
+        services.update(alert.service_ids)
+        service_names.update(alert.service_names)
+
+    incident = Incident(
+        tenant_id=tenant_id,
+        user_generated_name=f"Topology incident: Multiple alerts across {', '.join(service_names)}",
+        severity=severity.value,
+        status=IncidentStatus.FIRING.value,
+        is_confirmed=True,
+        incident_type=IncidentType.TOPOLOGY.value,  # Set incident type for topology
+        data={"services": list(services), "alert_count": len(alert_group)},
+    )
+
+    return incident
+
+
 def get_rule(tenant_id, rule_id):
     with Session(engine) as session:
         rule = session.exec(
@@ -1915,6 +1945,7 @@ def get_all_deduplication_rules(tenant_id):
         ).all()
     return rules
 
+
 def get_deduplication_rule_by_id(tenant_id, rule_id: str):
     rule_uuid = __convert_to_uuid(rule_id)
     if not rule_uuid:
@@ -1952,7 +1983,7 @@ def create_deduplication_rule(
     full_deduplication: bool = False,
     ignore_fields: list[str] = [],
     priority: int = 0,
-    is_provisioned: bool = False
+    is_provisioned: bool = False,
 ):
     with Session(engine) as session:
         new_rule = AlertDeduplicationRule(
@@ -3403,15 +3434,20 @@ def create_incident_from_dto(
             "assignee": incident_dto.assignee,
             "is_predicted": False,  # its not a prediction, but an AI generation
             "is_confirmed": True,  # confirmed by the user :)
+            "incident_type": IncidentType.AI.value,
         }
     elif issubclass(type(incident_dto), IncidentDto):
         # we will reach this block when incident is pulled from a provider
         incident_dict = incident_dto.to_db_incident().dict()
-
+        if "incident_type" not in incident_dict:
+            incident_dict["incident_type"] = IncidentType.MANUAL.value
     else:
         # We'll reach this block when a user creates an incident
         incident_dict = incident_dto.dict()
+        # Keep existing incident_type if present, default to MANUAL if not
+        if "incident_type" not in incident_dict:
+            incident_dict["incident_type"] = IncidentType.MANUAL.value
 
     return create_incident_from_dict(tenant_id, incident_dict)
 
@@ -3815,7 +3851,9 @@ def add_alerts_to_incident(
     return incident
 
 
-def get_incident_unique_fingerprint_count(tenant_id: str, incident_id: str | UUID) -> int:
+def get_incident_unique_fingerprint_count(
+    tenant_id: str, incident_id: str | UUID
+) -> int:
     with Session(engine) as session:
         return session.execute(
             select(func.count(1))
@@ -4488,19 +4526,22 @@ def get_workflow_executions_for_incident_or_alert(
         results = session.execute(final_query).all()
     return results, total_count
 
+
 def is_all_alerts_resolved(
     fingerprints: Optional[List[str]] = None,
     incident: Optional[Incident] = None,
-    session: Optional[Session] = None
+    session: Optional[Session] = None,
 ):
-    return is_all_alerts_in_status(fingerprints, incident, AlertStatus.RESOLVED, session)
+    return is_all_alerts_in_status(
+        fingerprints, incident, AlertStatus.RESOLVED, session
+    )
 
 
 def is_all_alerts_in_status(
     fingerprints: Optional[List[str]] = None,
     incident: Optional[Incident] = None,
     status: AlertStatus = AlertStatus.RESOLVED,
-    session: Optional[Session] = None
+    session: Optional[Session] = None,
 ):
 
     if incident and incident.alerts_count == 0:
@@ -4533,19 +4574,15 @@ def is_all_alerts_in_status(
         subquery = subquery.where(LastAlert.fingerprint.in_(fingerprints))
 
         if incident:
-            subquery = (
-                subquery
-                .join(
+            subquery = subquery.join(
                 LastAlertToIncident,
                 and_(
                     LastAlertToIncident.tenant_id == LastAlert.tenant_id,
                     LastAlertToIncident.fingerprint == LastAlert.fingerprint,
                 ),
-                )
-                .where(
-                    LastAlertToIncident.deleted_at == NULL_FOR_DELETED_AT,
-                    LastAlertToIncident.incident_id == incident.id,
-                )
+            ).where(
+                LastAlertToIncident.deleted_at == NULL_FOR_DELETED_AT,
+                LastAlertToIncident.incident_id == incident.id,
             )
 
         subquery = subquery.subquery()
@@ -4920,8 +4957,8 @@ def set_last_alert(
                 timestamp=alert.timestamp,
                 first_timestamp=alert.timestamp,
                 alert_id=alert.id,
-                alert_hash=alert.alert_hash,
-            )
+                alert_hash=alert.alert_hash,
+            )
 
             session.add(last_alert)
             session.commit()
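`create_incident_for_topology` picks the incident severity with `max()` over the grouped alerts, which assumes alert severities compare in order of importance. A toy illustration of that assumption with an ordered enum (Keep's real `AlertSeverity` may order differently):

```python
# Sketch of the severity-aggregation step; Severity is a stand-in enum.
import enum
from dataclasses import dataclass


class Severity(enum.IntEnum):
    INFO = 0
    WARNING = 1
    HIGH = 2
    CRITICAL = 3


@dataclass
class FakeAlert:
    severity: Severity


alerts = [FakeAlert(Severity.WARNING), FakeAlert(Severity.CRITICAL)]
worst = max(a.severity for a in alerts)  # comparison uses the enum's int order
assert worst is Severity.CRITICAL
print(worst.value)  # 3 -> stored on the incident as severity=worst.value
```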
diff --git a/keep/api/models/db/alert.py b/keep/api/models/db/alert.py
index 1d9bbeb52..e522e2884 100644
--- a/keep/api/models/db/alert.py
+++ b/keep/api/models/db/alert.py
@@ -49,6 +49,13 @@
 NULL_FOR_DELETED_AT = datetime(1000, 1, 1, 0, 0)
 
 
+class IncidentType(str, enum.Enum):
+    MANUAL = "manual"  # Created manually by users
+    AI = "ai"  # Created by AI
+    RULE = "rule"  # Created by rules engine
+    TOPOLOGY = "topology"  # Created by topology processor
+
+
 class AlertToIncident(SQLModel, table=True):
     tenant_id: str = Field(foreign_key="tenant.id")
     timestamp: datetime = Field(default_factory=datetime.utcnow)
@@ -157,6 +164,10 @@ class Incident(SQLModel, table=True):
     # It's not a unique identifier in the DB (constraint), but when we have the same incident from some tools, we can use it to detect duplicates
     fingerprint: str | None = Field(default=None, sa_column=Column(TEXT))
 
+    incident_type: str = Field(default=IncidentType.MANUAL.value)
+    # for topology incidents
+    incident_application: str | None = Field(default=None)
+
     same_incident_in_the_past_id: UUID | None = Field(
         sa_column=Column(
             UUIDType(binary=False),
diff --git a/keep/api/models/db/topology.py b/keep/api/models/db/topology.py
index 971b51820..545928d52 100644
--- a/keep/api/models/db/topology.py
+++ b/keep/api/models/db/topology.py
@@ -39,6 +39,7 @@ class TopologyService(SQLModel, table=True):
     mac_address: Optional[str] = None
     category: Optional[str] = None
     manufacturer: Optional[str] = None
+    namespace: Optional[str] = None
 
     updated_at: Optional[datetime] = Field(
         sa_column=Column(
@@ -53,7 +54,7 @@ class TopologyService(SQLModel, table=True):
         back_populates="service",
         sa_relationship_kwargs={
             "foreign_keys": "[TopologyServiceDependency.service_id]",
-            "cascade": "all, delete-orphan"
+            "cascade": "all, delete-orphan",
         },
     )
 
@@ -112,6 +113,7 @@ class TopologyServiceDtoBase(BaseModel, extra="ignore"):
     mac_address: Optional[str] = None
     category: Optional[str] = None
     manufacturer: Optional[str] = None
+    namespace: Optional[str] = None
 
 
 class TopologyServiceInDto(TopologyServiceDtoBase):
@@ -212,4 +214,5 @@ def from_orm(
             ],
             application_ids=application_ids,
             updated_at=service.updated_at,
+            namespace=service.namespace,
         )
diff --git a/keep/api/routes/preset.py b/keep/api/routes/preset.py
index bbe3af59c..64f5017fd 100644
--- a/keep/api/routes/preset.py
+++ b/keep/api/routes/preset.py
@@ -151,7 +151,7 @@ def pull_data_from_providers(
         try:
             if isinstance(provider_class, BaseTopologyProvider):
                 logger.info("Pulling topology data", extra=extra)
-                topology_data = provider_class.pull_topology()
+                topology_data, _ = provider_class.pull_topology()
                 logger.info(
                     "Pulling topology data finished, processing",
                     extra={**extra, "topology_length": len(topology_data)},
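The `IncidentType` enum added above feeds the defaulting logic in `create_incident_from_dto`: an explicit type is preserved, otherwise `manual` is assumed. A condensed illustration of that rule:

```python
# Toy illustration of the "default to MANUAL" rule applied to incident dicts.
import enum


class IncidentType(str, enum.Enum):
    MANUAL = "manual"
    AI = "ai"
    RULE = "rule"
    TOPOLOGY = "topology"


def with_incident_type(incident_dict: dict) -> dict:
    # Keep an explicit type if the caller set one; otherwise assume MANUAL.
    incident_dict.setdefault("incident_type", IncidentType.MANUAL.value)
    return incident_dict


assert with_incident_type({})["incident_type"] == "manual"
assert with_incident_type({"incident_type": "topology"})["incident_type"] == "topology"
```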
diff --git a/keep/api/routes/topology.py b/keep/api/routes/topology.py
index 86dfd0462..13dc21841 100644
--- a/keep/api/routes/topology.py
+++ b/keep/api/routes/topology.py
@@ -6,20 +6,24 @@
 from fastapi.responses import JSONResponse
 from sqlmodel import Session
 
-from keep.api.core.db import get_session
+from keep.api.core.db import get_session, get_session_sync
 from keep.api.models.db.topology import (
     TopologyApplicationDtoIn,
     TopologyApplicationDtoOut,
+    TopologyServiceDtoIn,
     TopologyServiceDtoOut,
 )
+from keep.api.tasks.process_topology_task import process_topology
 from keep.identitymanager.authenticatedentity import AuthenticatedEntity
 from keep.identitymanager.identitymanagerfactory import IdentityManagerFactory
+from keep.providers.base.base_provider import BaseTopologyProvider
+from keep.providers.providers_factory import ProvidersFactory
 from keep.topologies.topologies_service import (
-    TopologiesService,
     ApplicationNotFoundException,
+    ApplicationParseException,
     InvalidApplicationDataException,
     ServiceNotFoundException,
-    ApplicationParseException,
+    TopologiesService,
 )
 
 logger = logging.getLogger(__name__)
@@ -138,3 +142,141 @@ def delete_application(
         )
     except ApplicationNotFoundException as e:
         raise HTTPException(status_code=404, detail=str(e))
+
+
+@router.post(
+    "/pull",
+    description="Pull topology data on demand from providers",
+    response_model=List[TopologyServiceDtoOut],
+)
+def pull_topology_data(
+    provider_ids: Optional[str] = None,
+    authenticated_entity: AuthenticatedEntity = Depends(
+        IdentityManagerFactory.get_auth_verifier(["write:topology"])
+    ),
+    session: Session = Depends(get_session),
+):
+    tenant_id = authenticated_entity.tenant_id
+    logger.info(
+        "Pulling topology data on demand",
+        extra={"tenant_id": tenant_id, "provider_ids": provider_ids},
+    )
+
+    try:
+        providers = ProvidersFactory.get_installed_providers(
+            tenant_id=tenant_id, include_details=False
+        )
+
+        # Filter providers if provider_ids is specified
+        if provider_ids:
+            provider_id_list = provider_ids.split(",")
+            providers = [p for p in providers if str(p.id) in provider_id_list]
+
+        for provider in providers:
+            extra = {
+                "provider_type": provider.type,
+                "provider_id": provider.id,
+                "tenant_id": tenant_id,
+            }
+
+            try:
+                provider_class = ProvidersFactory.get_installed_provider(
+                    tenant_id=tenant_id,
+                    provider_id=provider.id,
+                    provider_type=provider.type,
+                )
+
+                if isinstance(provider_class, BaseTopologyProvider):
+                    logger.info("Pulling topology data", extra=extra)
+                    topology_data, applications_to_create = (
+                        provider_class.pull_topology()
+                    )
+                    logger.info(
+                        "Pulling topology data finished, processing",
+                        extra={**extra, "topology_length": len(topology_data)},
+                    )
+                    process_topology(
+                        tenant_id, topology_data, provider.id, provider.type
+                    )
+                    new_session = get_session_sync()
+                    # now we want to create the applications
+                    topology_data = TopologiesService.get_all_topology_data(
+                        tenant_id, new_session, provider_ids=[provider.id]
+                    )
+                    for app in applications_to_create:
+                        _app = TopologyApplicationDtoIn(
+                            name=app,
+                            services=[],
+                        )
+                        try:
+                            # replace service name with service id
+                            services = applications_to_create[app].get("services", [])
+                            for service in services:
+                                service_id = next(
+                                    (
+                                        s.id
+                                        for s in topology_data
+                                        if s.service == service
+                                    ),
+                                    None,
+                                )
+                                if not service_id:
+                                    # `service` is already the service name here
+                                    raise ServiceNotFoundException(service)
+                                _app.services.append(
+                                    TopologyServiceDtoIn(id=service_id)
+                                )
+
+                            # if the application already exists, update it
+                            existing_apps = (
+                                TopologiesService.get_applications_by_tenant_id(
+                                    tenant_id, new_session
+                                )
+                            )
+                            if any(a.name == app for a in existing_apps):
+                                app_id = next(
+                                    (a.id for a in existing_apps if a.name == app),
+                                    None,
+                                )
+                                TopologiesService.update_application_by_id(
+                                    tenant_id, app_id, _app, new_session
+                                )
+                            else:
+                                TopologiesService.create_application_by_tenant_id(
+                                    tenant_id, _app, new_session
+                                )
+                        except InvalidApplicationDataException as e:
+                            logger.error(
+                                f"Error creating application {app}: {str(e)}",
+                                extra=extra,
+                            )
+
+                    logger.info("Finished processing topology data", extra=extra)
+                else:
+                    logger.debug(
+                        f"Provider {provider.type} ({provider.id}) does not implement pulling topology data",
+                        extra=extra,
+                    )
+            except NotImplementedError:
+                logger.debug(
+                    f"Provider {provider.type} ({provider.id}) does not implement pulling topology data",
+                    extra=extra,
+                )
+            except Exception as e:
+                logger.exception(
+                    f"Error pulling topology from provider {provider.type} ({provider.id})",
+                    extra={**extra, "error": str(e)},
+                )
+
+        # Return the updated topology data
+        return TopologiesService.get_all_topology_data(
+            tenant_id, session, provider_ids=provider_ids
+        )
+
+    except Exception as e:
+        logger.exception(
+            "Error during on-demand topology pull",
+            extra={"tenant_id": tenant_id, "error": str(e)},
+        )
+        raise HTTPException(
+            status_code=500, detail=f"Failed to pull topology data: {str(e)}"
+        )
diff --git a/keep/api/tasks/process_event_task.py b/keep/api/tasks/process_event_task.py
index 09415a90d..8de4bd626 100644
--- a/keep/api/tasks/process_event_task.py
+++ b/keep/api/tasks/process_event_task.py
@@ -444,16 +444,10 @@ def __handle_formatted_events(
     if KEEP_CORRELATION_ENABLED:
         try:
             rules_engine = RulesEngine(tenant_id=tenant_id)
+            # handle incidents, also handle workflow execution
             incidents: List[IncidentDto] = rules_engine.run_rules(
                 enriched_formatted_events, session=session
             )
-
-            # TODO: Replace with incidents workflow triggers. Ticket: https://github.com/keephq/keep/issues/1527
-            # if new grouped incidents were created, we need to push them to the client
-            # if incidents:
-            #     logger.info("Adding group alerts to the workflow manager queue")
-            #     workflow_manager.insert_events(tenant_id, grouped_alerts)
-            #     logger.info("Added group alerts to the workflow manager queue")
         except Exception:
             logger.exception(
                 "Failed to run rules engine",
@@ -485,10 +479,7 @@ def __handle_formatted_events(
             logger.exception("Failed to tell client to poll alerts")
             pass
 
-    if (
-        incidents
-        and pusher_cache.should_notify(tenant_id, "incident-change")
-    ):
+    if incidents and pusher_cache.should_notify(tenant_id, "incident-change"):
         try:
             pusher_client.trigger(
                 f"private-{tenant_id}",
@@ -519,7 +510,9 @@ def __handle_formatted_events(
             pusher_client.trigger(
                 f"private-{tenant_id}",
                 "poll-presets",
-                json.dumps([p.name.lower() for p in presets_do_update], default=str),
+                json.dumps(
+                    [p.name.lower() for p in presets_do_update], default=str
+                ),
             )
         except Exception:
             logger.exception("Failed to send presets via pusher")
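Inside the new `/topology/pull` route, each application service name is resolved with a `next(...)` scan over the freshly stored topology, which is O(services) per lookup. An equivalent one-pass index, sketched with a stand-in DTO (the real route works on `TopologyServiceDtoOut` objects):

```python
# One-pass name -> id index instead of a next(...) scan per service name.
from dataclasses import dataclass


@dataclass
class TopologyService:  # stand-in for the DTO returned by get_all_topology_data
    id: int
    service: str


def build_service_index(topology_data: list[TopologyService]) -> dict[str, int]:
    return {s.service: s.id for s in topology_data}


index = build_service_index([TopologyService(1, "api"), TopologyService(2, "db")])
assert index.get("db") == 2
assert index.get("missing") is None  # caller still raises ServiceNotFoundException
```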
diff --git a/keep/api/tasks/process_topology_task.py b/keep/api/tasks/process_topology_task.py
index e39b5c7c0..835570c22 100644
--- a/keep/api/tasks/process_topology_task.py
+++ b/keep/api/tasks/process_topology_task.py
@@ -1,6 +1,8 @@
 import copy
 import logging
 
+from sqlalchemy import and_
+
 from keep.api.core.db import get_session_sync
 from keep.api.core.dependencies import get_pusher_client
 from keep.api.models.db.topology import (
@@ -37,6 +39,17 @@ def process_topology(
         extra=extra,
     )
 
+    # delete dependencies
+    session.query(TopologyServiceDependency).filter(
+        TopologyServiceDependency.service.has(
+            and_(
+                TopologyService.source_provider_id == provider_id,
+                TopologyService.tenant_id == tenant_id,
+            )
+        )
+    ).delete(synchronize_session=False)
+
+    # delete services
     session.query(TopologyService).filter(
         TopologyService.source_provider_id == provider_id,
         TopologyService.tenant_id == tenant_id,
@@ -72,7 +85,19 @@ def process_topology(
     for service in topology_data:
         for dependency in service.dependencies:
             service_id = service_to_keep_service_id_map.get(service.service)
+            if service_id is None:
+                logger.debug(
+                    "Found a dangling service, skipping",
+                    extra={"service": service.service, "service_id": service_id},
+                )
+                continue
             depends_on_service_id = service_to_keep_service_id_map.get(dependency)
+            if depends_on_service_id is None:
+                logger.debug(
+                    "Found a dangling service, skipping",
+                    extra={"service": service.service, "dependency": dependency},
+                )
+                continue
             if not service_id or not depends_on_service_id:
                 logger.debug(
                     "Found a dangling service, skipping",
diff --git a/keep/providers/base/base_provider.py b/keep/providers/base/base_provider.py
index f92588d2e..46ec93125 100644
--- a/keep/providers/base/base_provider.py
+++ b/keep/providers/base/base_provider.py
@@ -780,7 +780,7 @@ def is_provisioned(self) -> bool:
 
 
 class BaseTopologyProvider(BaseProvider):
-    def pull_topology(self) -> list[TopologyServiceInDto]:
+    def pull_topology(self) -> tuple[list[TopologyServiceInDto], dict]:
         raise NotImplementedError("get_topology() method not implemented")
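`pull_topology` now returns a `(services, applications)` tuple across all topology providers, and callers that don't care about applications unpack with a throwaway, as `preset.py` does above. A minimal sketch of the contract with a stand-in base class (not the real Keep classes):

```python
# Sketch of the new pull_topology contract: (services, applications).
class BaseTopologyProvider:
    def pull_topology(self) -> tuple[list, dict]:
        raise NotImplementedError("pull_topology() method not implemented")


class MyProvider(BaseTopologyProvider):
    def pull_topology(self) -> tuple[list, dict]:
        services = ["svc-a", "svc-b"]  # would be TopologyServiceInDto objects
        applications = {"shop": {"services": {"svc-a", "svc-b"}}}
        return services, applications


# Call sites that ignore applications unpack with a throwaway, as preset.py does:
topology_data, _ = MyProvider().pull_topology()
assert len(topology_data) == 2
```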
diff --git a/keep/providers/cilium_provider/cilium_provider.py b/keep/providers/cilium_provider/cilium_provider.py
index 400133827..a7e4396ff 100644
--- a/keep/providers/cilium_provider/cilium_provider.py
+++ b/keep/providers/cilium_provider/cilium_provider.py
@@ -78,7 +78,7 @@ def _get_service_name(self, endpoint) -> str:
             return "unknown"
         return service
 
-    def pull_topology(self) -> list[TopologyServiceInDto]:
+    def pull_topology(self) -> tuple[list[TopologyServiceInDto], dict]:
         # for some providers that depends on grpc like cilium provider, this might fail on imports not from Keep (such as the docs script)
         from keep.providers.cilium_provider.grpc.observer_pb2 import (  # noqa
             FlowFilter,
@@ -102,6 +102,12 @@ def pull_topology(self) -> tuple[list[TopologyServiceInDto], dict]:
         # Process the responses
         service_map = defaultdict(lambda: {"dependencies": set(), "namespace": ""})
         # https://docs.cilium.io/en/stable/_api/v1/flow/README/#flow-FlowFilter
+
+        # get the responses as list
+        responses = list(responses)
+
+        application_to_create = {}
+
         for response in responses:
             flow = response.flow
             if not flow.source:
@@ -111,27 +117,123 @@ def pull_topology(self) -> tuple[list[TopologyServiceInDto], dict]:
                 source = self._get_service_name(flow.source)
                 destination = self._get_service_name(flow.destination)
+                source_namespace = flow.source.namespace
+                destination_namespace = flow.destination.namespace
+
+                node_labels = list(flow.node_labels)
+
+                destination_port = flow.l4.TCP.destination_port
+                # source_port = flow.l4.TCP.source_port
+
+                category = "http"
+
+                if destination_port == 5432:
+                    category = "postgres"
+
+                application = None
+
+                try:
+                    application_label = [
+                        label
+                        for label in flow.source.labels
+                        if label.startswith("k8s:keepapp=")
+                    ][0]
+                    application = application_label.split("=")[1]
+
+                    if application not in application_to_create:
+                        application_to_create[application] = {"services": set()}
+                    application_to_create[application]["services"].add(source)
+                except Exception:
+                    pass
+
+                service_map[source]["dependencies"].add(destination)
+                service_map[source]["namespace"] = source_namespace
+                service_map[source]["tags"] = list(flow.source.labels)
+                service_map[source]["tags"].append(flow.source.pod_name)
+                service_map[source]["tags"].append(flow.source.cluster_name)
+                service_map[source]["tags"] += node_labels
+
+                if destination not in service_map:
+                    service_map[destination] = {
+                        "dependencies": set(),
+                        "namespace": destination_namespace or "internet",
+                    }
+                    service_map[destination]["dependencies"].add(source)
+                    service_map[destination]["tags"] = list(flow.destination.labels)
+                    service_map[destination]["category"] = category
+                else:
+                    service_map[destination]["dependencies"].add(source)
+                    service_map[destination]["tags"] = list(flow.destination.labels)
+            # if its outside the cluster
+            elif (
+                flow.destination
+                and flow.destination.labels
+                and "reserved:world" in flow.destination.labels
+            ):
+                source = self._get_service_name(flow.source)
+                destination = flow.IP.destination
                 source_namespace = flow.source.namespace
+                node_labels = list(flow.node_labels)
+
+                destination_port = flow.l4.TCP.destination_port
+                # source_port = flow.l4.TCP.source_port
+
+                category = "http"
+
+                if destination_port == 5432:
+                    category = "postgres"
 
                 service_map[source]["dependencies"].add(destination)
                 service_map[source]["namespace"] = source_namespace
-                service_map[source]["tags"] = flow.source.labels
+                service_map[source]["tags"] = list(flow.source.labels)
                 service_map[source]["tags"].append(flow.source.pod_name)
                 service_map[source]["tags"].append(flow.source.cluster_name)
-                # service_map[destination]["namespace"] = destination_namespace
+                service_map[source]["tags"] += node_labels
+
+                # look for the application
+                for application in application_to_create:
+                    if source in application_to_create[application]["services"]:
+                        self.logger.debug(f"Adding {destination} to {application}")
+                        application_to_create[application]["services"].add(destination)
+                        break
+
+                if destination not in service_map:
+                    service_map[destination] = {
+                        "dependencies": set(),
+                        "namespace": "internet",
+                    }  # destination_namespace is external
+                    service_map[destination]["dependencies"].add(source)
+                    service_map[destination]["tags"] = list(flow.destination.labels)
+                    service_map[destination]["category"] = category
+                else:
+                    service_map[destination]["dependencies"].add(source)
+                    service_map[destination]["tags"] = list(flow.destination.labels)
 
         # Convert to TopologyServiceInDto
         topology = []
         for service, data in service_map.items():
-            topology_service = TopologyServiceInDto(
-                source_provider_id=self.provider_id,
-                service=service,
-                display_name=service,
-                environment=data["namespace"],
-                dependencies={dep: "network" for dep in data["dependencies"]},
-                tags=list(data["tags"]),
-            )
-            topology.append(topology_service)
+            try:
+                topology_service = TopologyServiceInDto(
+                    source_provider_id=self.provider_id,
+                    service=service,
+                    display_name=service,
+                    environment=data["namespace"],
+                    dependencies={dep: "network" for dep in data["dependencies"]},
+                    tags=list(data["tags"]),
+                    category=data.get("category", "http"),
+                    namespace=data["namespace"],
+                )
+                topology.append(topology_service)
+            except Exception as e:
+                self.logger.error(
+                    "Error processing service",
+                    extra={
+                        "service": service,
+                        "data": data,
+                        "error": str(e),
+                    },
+                )
+                pass
 
         self.logger.info(
             "Topology pulling completed",
@@ -140,7 +242,11 @@ def pull_topology(self) -> tuple[list[TopologyServiceInDto], dict]:
                 "len_of_topology": len(topology),
             },
         )
-        return topology
+        return topology, application_to_create
+
+    def get_existing_services(self, all_services):
+        """Helper function to create a set of all valid service names"""
+        return set(all_services)
 
     def dispose(self):
         """
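The Cilium flow handler above hardcodes `5432 -> postgres` and defaults everything else to `http`. A small lookup table is the natural extension point; the extra ports below are illustrative guesses, not part of the PR:

```python
# Port -> category lookup generalizing the hardcoded 5432 check.
PORT_CATEGORIES = {
    5432: "postgres",
    3306: "mysql",   # assumption
    6379: "redis",   # assumption
    9092: "kafka",   # assumption
}


def categorize(destination_port: int) -> str:
    return PORT_CATEGORIES.get(destination_port, "http")


assert categorize(5432) == "postgres"
assert categorize(8080) == "http"  # default mirrors the diff's behavior
```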
diff --git a/keep/providers/datadog_provider/datadog_provider.py b/keep/providers/datadog_provider/datadog_provider.py
index e4cce8e33..dabc09c72 100644
--- a/keep/providers/datadog_provider/datadog_provider.py
+++ b/keep/providers/datadog_provider/datadog_provider.py
@@ -970,7 +970,7 @@ def simulate_alert(cls) -> dict:
         ).hexdigest()
         return simulated_alert
 
-    def pull_topology(self) -> list[TopologyServiceInDto]:
+    def pull_topology(self) -> tuple[list[TopologyServiceInDto], dict]:
         services = {}
         with ApiClient(self.configuration) as api_client:
             api_instance = ServiceDefinitionApi(api_client)
@@ -1017,7 +1017,7 @@ def pull_topology(self) -> tuple[list[TopologyServiceInDto], dict]:
                     dependency: "unknown" for dependency in dependencies
                 }
                 services[service_dep] = service
-        return list(services.values())
+        return list(services.values()), {}
 
 
 if __name__ == "__main__":
diff --git a/keep/providers/pagerduty_provider/pagerduty_provider.py b/keep/providers/pagerduty_provider/pagerduty_provider.py
index c2a69b438..34f60ec2e 100644
--- a/keep/providers/pagerduty_provider/pagerduty_provider.py
+++ b/keep/providers/pagerduty_provider/pagerduty_provider.py
@@ -475,16 +475,18 @@ def _trigger_incident(
             )  # This will give us a better error message in Keep workflows
             raise Exception(r.text) from e
-
+
     def clean_up(self):
         """
         Clean up the provider. It will remove the webhook from PagerDuty if it exists.
         """
-        self.logger.info("Cleaning up %s provider with id %s", self.PROVIDER_DISPLAY_NAME, self.provider_id)
-        keep_webhook_incidents_api_url = (
-            f"{self.context_manager.api_url}/incidents/event/{self.provider_type}?provider_id={self.provider_id}"
-        )
+        self.logger.info(
+            "Cleaning up %s provider with id %s",
+            self.PROVIDER_DISPLAY_NAME,
+            self.provider_id,
+        )
+        keep_webhook_incidents_api_url = f"{self.context_manager.api_url}/incidents/event/{self.provider_type}?provider_id={self.provider_id}"
         headers = self.__get_headers()
         request = requests.get(self.SUBSCRIPTION_API_URL, headers=headers)
         if not request.ok:
@@ -495,7 +497,8 @@ def clean_up(self):
                 [
                     webhook
                     for webhook in existing_webhooks
-                    if keep_webhook_incidents_api_url == webhook.get("delivery_method", {}).get("url", "")
+                    if keep_webhook_incidents_api_url
+                    == webhook.get("delivery_method", {}).get("url", "")
                 ]
             ),
             False,
@@ -510,7 +513,6 @@ def clean_up(self):
                 raise Exception("Could not remove existing webhook")
             self.logger.info("Webhook removed", extra={"webhook_id": webhook_id})
 
-
     def dispose(self):
         """
         No need to dispose of anything, so just do nothing.
@@ -824,10 +826,10 @@ def __get_all_services(self, business_services: bool = False):
             all_services.extend(services_response.get(endpoint, []))
         return all_services
 
-    def pull_topology(self) -> list[TopologyServiceInDto]:
+    def pull_topology(self) -> tuple[list[TopologyServiceInDto], dict]:
         # Skipping topology pulling when we're installed with routing_key
         if self.authentication_config.routing_key:
-            return []
+            return [], {}
         all_services = self.__get_all_services()
         all_business_services = self.__get_all_services(business_services=True)
@@ -879,7 +881,7 @@ def pull_topology(self) -> tuple[list[TopologyServiceInDto], dict]:
                 ),
             )
             service_topology[dependent["id"]].dependencies[supporting["id"]] = "unknown"
-        return list(service_topology.values())
+        return list(service_topology.values()), {}
 
     def _get_incidents(self) -> list[IncidentDto]:
         # Skipping incidents pulling when we're installed with routing_key
diff --git a/keep/providers/servicenow_provider/servicenow_provider.py b/keep/providers/servicenow_provider/servicenow_provider.py
index 473ab7f7f..5c3c677fb 100644
--- a/keep/providers/servicenow_provider/servicenow_provider.py
+++ b/keep/providers/servicenow_provider/servicenow_provider.py
@@ -27,7 +27,7 @@ class ServicenowProviderAuthConfig:
             "description": "The base URL of the ServiceNow instance",
             "sensitive": False,
             "hint": "https://dev12345.service-now.com",
-            "validation": "https_url"
+            "validation": "https_url",
         }
     )
 
@@ -232,7 +232,7 @@ def _query(
 
         return response.json().get("result", [])
 
-    def pull_topology(self) -> list[TopologyServiceInDto]:
+    def pull_topology(self) -> tuple[list[TopologyServiceInDto], dict]:
        # TODO: in scale, we'll need to use pagination around here
         headers = {"Content-Type": "application/json", "Accept": "application/json"}
         auth = (
@@ -282,7 +282,7 @@ def pull_topology(self) -> tuple[list[TopologyServiceInDto], dict]:
                     "status_code": cmdb_response.status_code,
                 },
             )
-            return topology
+            return topology, {}
 
         cmdb_data = cmdb_response.json().get("result", [])
         self.logger.info(
@@ -349,7 +349,7 @@ def pull_topology(self) -> tuple[list[TopologyServiceInDto], dict]:
                 "len_of_topology": len(topology),
             },
         )
-        return topology
+        return topology, {}
 
     def dispose(self):
         """
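In the `rulesengine.py` diff that follows, the reformatted `create_on` condition decides whether an incident is confirmed immediately: "any" confirms on the first matching event, "all" only once every subrule group has matched. A condensed sketch of that gate:

```python
# Condensed sketch of the create_on gate in _run_cel_rules.
def should_confirm_immediately(
    create_on: str, rule_groups: list[str], matched_rules: list[str]
) -> bool:
    return create_on == "any" or (
        create_on == "all" and len(rule_groups) == len(matched_rules)
    )


assert should_confirm_immediately("any", ["a", "b"], ["a"])
assert not should_confirm_immediately("all", ["a", "b"], ["a"])
assert should_confirm_immediately("all", ["a", "b"], ["a", "b"])
```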
diff --git a/keep/rulesengine/rulesengine.py b/keep/rulesengine/rulesengine.py
index 31f77b09e..5bc2a6d76 100644
--- a/keep/rulesengine/rulesengine.py
+++ b/keep/rulesengine/rulesengine.py
@@ -1,6 +1,6 @@
 import json
 import logging
-from typing import Optional, List
+from typing import List, Optional
 
 import celpy
 import celpy.c7nlib
@@ -12,16 +12,25 @@
 from keep.api.bl.incidents_bl import IncidentBl
 from keep.api.core.db import (
     assign_alert_to_incident,
-    get_incident_for_grouping_rule,
     create_incident_for_grouping_rule,
+    enrich_incidents_with_alerts,
+    get_incident_for_grouping_rule,
+)
+from keep.api.core.db import get_rules as get_rules_db
+from keep.api.core.db import (
+    is_all_alerts_in_status,
     is_all_alerts_resolved,
     is_first_incident_alert_resolved,
     is_last_incident_alert_resolved,
-    is_all_alerts_in_status,
-    enrich_incidents_with_alerts,
 )
-from keep.api.core.db import get_rules as get_rules_db
 from keep.api.core.dependencies import get_pusher_client
-from keep.api.models.alert import AlertDto, AlertSeverity, IncidentDto, IncidentStatus, AlertStatus
+from keep.api.models.alert import (
+    AlertDto,
+    AlertSeverity,
+    AlertStatus,
+    IncidentDto,
+    IncidentStatus,
+)
 from keep.api.models.db.alert import Incident
 from keep.api.models.db.rule import ResolveOn, Rule
 from keep.api.utils.cel_utils import preprocess_cel_expression
@@ -54,6 +63,27 @@ def __init__(self, tenant_id=None):
     def run_rules(
         self, events: list[AlertDto], session: Optional[Session] = None
     ) -> list[IncidentDto]:
+        """
+        Evaluate the rules on the events and create incidents if needed
+
+        Args:
+            events: list of events
+            session: db session
+        """
+        self.logger.info("Running CEL rules")
+        cel_incidents = self._run_cel_rules(events, session)
+        self.logger.info("CEL rules ran successfully")
+
+        return cel_incidents
+
+    def _run_cel_rules(
+        self, events: list[AlertDto], session: Optional[Session] = None
+    ) -> list[IncidentDto]:
+        """
+        Evaluate the rules on the events and create incidents if needed
+
+        Args:
+            events: list of events
+            session: db session
+        """
         self.logger.info("Running rules")
         rules = get_rules_db(tenant_id=self.tenant_id)
@@ -101,8 +131,13 @@
 
                 rule_groups = self._extract_subrules(rule.definition_cel)
 
-                if rule.create_on == "any" or (rule.create_on == "all" and len(rule_groups) == len(matched_rules)):
-                    self.logger.info("Single event is enough, so creating incident")
+                if rule.create_on == "any" or (
+                    rule.create_on == "all"
+                    and len(rule_groups) == len(matched_rules)
+                ):
+                    self.logger.info(
+                        "Single event is enough, so creating incident"
+                    )
                     incident.is_confirmed = True
                 elif rule.create_on == "all":
                     incident = self._process_event_for_history_based_rule(
@@ -111,7 +146,9 @@
 
                 send_created_event = incident.is_confirmed
 
-                incident = self._resolve_incident_if_require(rule, incident, session)
+                incident = self._resolve_incident_if_require(
+                    rule, incident, session
+                )
                 session.add(incident)
                 session.commit()
@@ -137,7 +174,6 @@
 
         return list(incidents_dto.values())
 
-
     def _get_or_create_incident(self, rule, rule_fingerprint, session):
         incident = get_incident_for_grouping_rule(
             self.tenant_id,
@@ -155,14 +191,9 @@ def _get_or_create_incident(self, rule, rule_fingerprint, session):
         return incident
 
     def _process_event_for_history_based_rule(
-        self,
-        incident: Incident,
-        rule: Rule,
-        session: Session
+        self, incident: Incident, rule: Rule, session: Session
     ) -> Incident:
-        self.logger.info(
-            "Multiple events required for the incident to start"
-        )
+        self.logger.info("Multiple events required for the incident to start")
 
         enrich_incidents_with_alerts(
             tenant_id=self.tenant_id,
@@ -178,7 +209,9 @@ def _process_event_for_history_based_rule(
 
         matched_sub_rules = set()
         for alert in incident.alerts:
-            matched_sub_rules = matched_sub_rules.union(self._check_if_rule_apply(rule, AlertDto(**alert.event)))
+            matched_sub_rules = matched_sub_rules.union(
+                self._check_if_rule_apply(rule, AlertDto(**alert.event))
+            )
             if all_sub_rules == matched_sub_rules:
                 is_all_conditions_met = True
                 break
@@ -193,13 +226,14 @@ def _process_event_for_history_based_rule(
         return incident
 
     @staticmethod
-    def _resolve_incident_if_require(rule: Rule, incident: Incident, session: Session) -> Incident:
+    def _resolve_incident_if_require(
+        rule: Rule, incident: Incident, session: Session
+    ) -> Incident:
 
         should_resolve = False
 
-        if (
-            rule.resolve_on == ResolveOn.ALL.value
-            and is_all_alerts_resolved(incident=incident, session=session)
+        if rule.resolve_on == ResolveOn.ALL.value and is_all_alerts_resolved(
+            incident=incident, session=session
         ):
             should_resolve = True
 
@@ -386,10 +420,11 @@ def filter_alerts(
 
         return filtered_alerts
 
-    def _send_workflow_event(self, session: Session, incident_dto: IncidentDto, action: str):
+    def _send_workflow_event(
+        self, session: Session, incident_dto: IncidentDto, action: str
+    ):
         pusher_client = get_pusher_client()
         incident_bl = IncidentBl(self.tenant_id, session, pusher_client)
         incident_bl.send_workflow_event(incident_dto, action)
         incident_bl.update_client_on_incident_change(incident_dto.id)
-
diff --git a/keep/topologies/topologies_service.py b/keep/topologies/topologies_service.py
index 502320960..a403e7fea 100644
--- a/keep/topologies/topologies_service.py
+++ b/keep/topologies/topologies_service.py
@@ -1,10 +1,10 @@
+import json
 import logging
 from typing import List, Optional
-from pydantic import ValidationError
-from sqlalchemy.orm import joinedload, selectinload
 from uuid import UUID
-import json
 
+from pydantic import ValidationError
+from sqlalchemy.orm import joinedload, selectinload
 from sqlmodel import Session, select
 
 from keep.api.core.db_utils import get_aggreated_field
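The new `keep/topologies/topology_processor.py` below runs its work on a daemon thread and sleeps via `threading.Event.wait`, so `stop()` interrupts a sleep immediately instead of waiting out a full interval. A minimal reproduction of that loop pattern (not Keep's actual class):

```python
# Minimal sketch of the interval loop used by the topology processor:
# Event.wait doubles as both the sleep and the shutdown signal.
import threading

stop_event = threading.Event()


def loop(interval: float, work) -> None:
    while not stop_event.is_set():
        work()
        stop_event.wait(interval)  # returns early once stop_event is set


t = threading.Thread(target=loop, args=(60.0, lambda: print("tick")), daemon=True)
t.start()
stop_event.set()   # equivalent to TopologyProcessor.stop()
t.join(timeout=5)
```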
interval [{}]".format( + self.process_interval + ) + ) + except Exception as e: + self.logger.exception("Error in topology processing: %s", str(e)) + + # Wait for the next interval or until stopped + self._stop_event.wait(self.process_interval) + + self.logger.info("Topology processing stopped") + + def stop(self): + """Stops the topology processor""" + if not self.started: + return + + self.logger.info("Stopping topology processor") + self._stop_event.set() + + if self.thread and self.thread.is_alive(): + self.thread.join(timeout=30) # Wait up to 30 seconds + if self.thread.is_alive(): + self.logger.warning("Topology processor thread did not stop gracefully") + + self.started = False + self.thread = None + self.logger.info("Stopped topology processor") + + def _process_all_tenants(self): + """Process topology for all tenants""" + tenants = self.enabled_tenants.keys() + for tenant_id in tenants: + try: + self.logger.info(f"Processing topology for tenant {tenant_id}") + self._process_tenant(tenant_id) + self.logger.info(f"Finished processing topology for tenant {tenant_id}") + except Exception as e: + self.logger.exception(f"Error processing tenant {tenant_id}: {str(e)}") + + def _process_tenant(self, tenant_id: str): + """Process topology for a single tenant""" + self.logger.debug(f"Processing topology for tenant {tenant_id}") + + # 1. Get last alerts for the tenant + topology_data = self._get_topology_data(tenant_id) + applications = self._get_applications_data(tenant_id) + services = [t.service for t in topology_data] + if not topology_data: + self.logger.debug(f"No topology data found for tenant {tenant_id}") + return + + # Currently topology-based incidents are created for applications only + # SHAHAR: this is harder to implement service-related incidents without applications + # TODO: add support for service-related incidents + if not applications: + self.logger.debug(f"No applications found for tenant {tenant_id}") + return + + db_last_alerts = get_last_alerts(tenant_id, with_incidents=True) + last_alerts = convert_db_alerts_to_dto_alerts(db_last_alerts) + + services_with_alerts = defaultdict(list) + # group by service + for alert in last_alerts: + if alert.service: + if alert.service not in services: + # ignore alerts for services not in topology data + self.logger.debug( + f"Alert service {alert.service} not in topology data" + ) + continue + services_with_alerts[alert.service].append(alert) + + for application in applications: + # check if there is an incident for the application + incident = self._get_application_based_incident(tenant_id, application) + print(incident) + application_services = [t.service for t in application.services] + # if more than one service in the application has alerts, create an incident + services_with_alerts = [ + service + for service in application_services + if service in services_with_alerts + ] + if len(services_with_alerts) > 1: + self._create_or_update_application_based_incident( + application, services_with_alerts, services_with_alerts[0] + ) + + def _get_topology_based_incidents(self, tenant_id: str) -> Dict[str, Incident]: + """Get all topology-based incidents for a tenant""" + with existed_or_new_session() as session: + incidents = session.exec( + select(Incident).where( + Incident.tenant_id == tenant_id + and Incident.incident_type == "topology" + ) + ).all() + return incidents + + def _check_topology_for_incidents( + self, + last_alerts: Dict[str, AlertDto], + topology_based_incidents: Dict[str, Incident], + ) -> Set[Incident]: + """Check 
+        incidents = []
+        # get all alerts within the same application:
+
+        # get all alerts within services that have dependencies:
+        return incidents
+
+    def _get_application_based_incident(
+        self, tenant_id, application: TopologyServiceApplication
+    ) -> Optional[Incident]:
+        """Get the incident for an application"""
+        with existed_or_new_session() as session:
+            incident = session.exec(
+                select(Incident).where(
+                    Incident.tenant_id == tenant_id,
+                    Incident.incident_application == application.id,
+                )
+            ).first()
+            return incident
+
+    def _get_topology_data(self, tenant_id: str):
+        """Get topology data for a tenant"""
+        with existed_or_new_session() as session:
+            topology_data = TopologiesService.get_all_topology_data(
+                tenant_id=tenant_id, session=session
+            )
+        return topology_data
+
+    def _get_applications_data(self, tenant_id: str):
+        """Get applications data for a tenant"""
+        with existed_or_new_session() as session:
+            applications = TopologiesService.get_applications_by_tenant_id(
+                tenant_id=tenant_id, session=session
+            )
+        return applications
+
+    def _get_nested_dependencies(self, topology_data):
+        """
+        Get nested dependencies for each service including all sub-dependencies.
+        Returns a dict mapping service name to list of all dependencies (direct and indirect).
+        """
+        # First, build a map of service_id to service and its dependencies
+        service_deps = {}
+        for service in topology_data:
+            service_deps[service.service] = {
+                "deps": list(service.dependencies),  # Use list instead of set
+                "processed": False,
+            }
+
+        def get_all_deps(service_name: str, visited: set):
+            """Recursively get all dependencies for a service"""
+            if service_name in visited:
+                # Avoid circular dependencies
+                return []
+
+            visited.add(service_name)
+
+            if service_name not in service_deps:
+                # Service not found in our data
+                return []
+
+            # Start with direct dependencies
+            all_deps = service_deps[service_name]["deps"].copy()
+
+            # For each direct dependency, get its dependencies
+            for dep in service_deps[service_name]["deps"]:
+                # Find the service object for this dependency
+                for service in topology_data:
+                    if service.service == dep.serviceName:
+                        # Get nested dependencies recursively
+                        nested_deps = get_all_deps(dep.serviceName, visited.copy())
+                        # Add nested deps if they're not already in all_deps
+                        for nested_dep in nested_deps:
+                            if not any(
+                                d.serviceId == nested_dep.serviceId for d in all_deps
+                            ):
+                                all_deps.append(nested_dep)
+                        break
+
+            return all_deps
+
+        # Build complete dependency map
+        nested_dependencies = {}
+        for service in topology_data:
+            nested_dependencies[service.service] = get_all_deps(service.service, set())
+
+        return nested_dependencies
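The recursive `get_all_deps` above copies `visited` per branch to tolerate cycles. For reference, an equivalent iterative formulation over a plain name-to-dependencies map (the real code walks dependency DTOs rather than strings):

```python
# Iterative transitive-dependency walk with a cycle guard, assuming
# dependencies are plain service-name strings.
def transitive_deps(graph: dict[str, set[str]], root: str) -> set[str]:
    seen: set[str] = set()
    stack = list(graph.get(root, ()))
    while stack:
        dep = stack.pop()
        if dep in seen or dep == root:  # skip cycles and self-references
            continue
        seen.add(dep)
        stack.extend(graph.get(dep, ()))
    return seen


graph = {"api": {"db", "cache"}, "cache": {"db"}, "db": set()}
assert transitive_deps(graph, "api") == {"db", "cache"}
```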