diff --git a/README.md b/README.md index 40fdd2d16..d06b1ec80 100644 --- a/README.md +++ b/README.md @@ -9,8 +9,6 @@
- - Join Slack diff --git a/docker/Dockerfile.dev.ui b/docker/Dockerfile.dev.ui index 0e963bb59..0fd7ae4c8 100644 --- a/docker/Dockerfile.dev.ui +++ b/docker/Dockerfile.dev.ui @@ -19,7 +19,7 @@ COPY ./keep-ui/ /app RUN npm install # Install next globally and create a symlink RUN npm install -g next -RUN ln -s /usr/local/lib/node_modules/next/dist/bin/next /usr/local/bin/next +RUN ln -s /usr/local/lib/node_modules/next/dist/bin/next /usr/local/bin/next || echo "next binary already linked to bin" # Ensure port 3000 is accessible to our system EXPOSE 3000 diff --git a/docs/development/getting-started.mdx b/docs/development/getting-started.mdx index 998f339cf..342b00062 100644 --- a/docs/development/getting-started.mdx +++ b/docs/development/getting-started.mdx @@ -13,7 +13,7 @@ git clone https://github.com/keephq/keep.git && cd keep Next, run ``` -docker-compose -f docker-compose.dev.yml up +docker compose -f docker-compose.dev.yml up ``` ### Testing @@ -33,7 +33,7 @@ poetry run coverage run --branch -m pytest -s tests/e2e_tests/ Migrations are automatically executed on a server startup. To create a migration: ```bash -cd keep && alembic revision --autogenerate -m "Your message" +alembic -c keep/alembic.ini revision --autogenerate -m "Your message" ``` Hint: make sure your models are imported at `./api/models/db/migrations/env.py` for autogenerator to pick them up. @@ -52,6 +52,7 @@ You can run Keep from your VSCode (after cloning the repo) by adding this config "program": "keep/cli/cli.py", "console": "integratedTerminal", "justMyCode": false, + "python": "venv/bin/python", "args": ["--json", "api","--multi-tenant"], "env": { "PYDEVD_DISABLE_FILE_VALIDATION": "1", @@ -72,6 +73,7 @@ You can run Keep from your VSCode (after cloning the repo) by adding this config "program": "scripts/simulate_alerts.py", "console": "integratedTerminal", "justMyCode": false, + "python": "venv/bin/python", "env": { "PYDEVD_DISABLE_FILE_VALIDATION": "1", "PYTHONPATH": "${workspaceFolder}/", @@ -92,9 +94,17 @@ You can run Keep from your VSCode (after cloning the repo) by adding this config Install dependencies: ``` +python3.11 -m venv venv; +source venv/bin/activate; pip install poetry; poetry install; -cd keep-ui && npm i; +cd keep-ui && npm i && cd ..; +``` + +Set frontend envs: +``` +cp keep-ui/.env.local.example keep-ui/.env.local; +echo "\n\n\n\nNEXTAUTH_SECRET="$(openssl rand -hex 32) >> keep-ui/.env.local; ``` Launch Pusher ([soketi](https://soketi.app/)) container in parallel: diff --git a/docs/mint.json b/docs/mint.json index 51a2b654c..c845ee5ee 100644 --- a/docs/mint.json +++ b/docs/mint.json @@ -46,7 +46,6 @@ "group": "Development", "pages": [ "development/getting-started", - "development/authentication", "development/external-url" ] }, diff --git a/keep-ui/.env.local.example b/keep-ui/.env.local.example index 87ec162aa..a9f72e248 100644 --- a/keep-ui/.env.local.example +++ b/keep-ui/.env.local.example @@ -1,5 +1,8 @@ NEXTAUTH_URL=http://localhost:3000 -NEXTAUTH_SECRET= # Linux: `openssl rand -hex 32` or go to https://generate-secret.now.sh/32 + +# Required: +# NEXTAUTH_SECRET= # Linux: `openssl rand -hex 32` or go to https://generate-secret.now.sh/32 + # API API_URL=http://localhost:8080 # Auth diff --git a/keep-ui/.gitignore b/keep-ui/.gitignore index f070db838..82729d76a 100644 --- a/keep-ui/.gitignore +++ b/keep-ui/.gitignore @@ -38,3 +38,5 @@ jspm_packages .next .env.local + +app/topology/mock-topology-data.tsx diff --git a/keep-ui/app/alerts/alert-associate-incident-modal.tsx 
b/keep-ui/app/alerts/alert-associate-incident-modal.tsx index 0f391d5a0..df6cf5399 100644 --- a/keep-ui/app/alerts/alert-associate-incident-modal.tsx +++ b/keep-ui/app/alerts/alert-associate-incident-modal.tsx @@ -1,13 +1,14 @@ -import React, {FormEvent, useState} from "react"; -import { useSession } from "next-auth/react"; -import { AlertDto } from "./models"; import Modal from "@/components/ui/Modal"; -import { useIncidents } from "../../utils/hooks/useIncidents"; +import { Button, Divider, Select, SelectItem, Title } from "@tremor/react"; +import CreateOrUpdateIncident from "app/incidents/create-or-update-incident"; +import { useSession } from "next-auth/react"; +import { useRouter } from "next/navigation"; +import { FormEvent, useCallback, useEffect, useState } from "react"; +import { toast } from "react-toastify"; +import { getApiURL } from "../../utils/apiUrl"; +import { useIncidents, usePollIncidents } from "../../utils/hooks/useIncidents"; import Loading from "../loading"; -import {Button, Divider, Select, SelectItem, Title} from "@tremor/react"; -import {useRouter} from "next/navigation"; -import {getApiURL} from "../../utils/apiUrl"; -import {toast} from "react-toastify"; +import { AlertDto } from "./models"; interface AlertAssociateIncidentModalProps { isOpen: boolean; @@ -22,19 +23,22 @@ const AlertAssociateIncidentModal = ({ handleClose, alerts, }: AlertAssociateIncidentModalProps) => { + const [createIncident, setCreateIncident] = useState(false) const { data: incidents, isLoading, mutate } = useIncidents(true, 100); + usePollIncidents(mutate) + const [selectedIncident, setSelectedIncident] = useState(null); // get the token const { data: session } = useSession(); const router = useRouter(); // if this modal should not be open, do nothing if (!alerts) return null; - const handleAssociateAlerts = async (e: FormEvent) => { - e.preventDefault(); + + const associateAlertsHandler = async (incidentId: string) => { const apiUrl = getApiURL(); const response = await fetch( - `${apiUrl}/incidents/${selectedIncident}/alerts`, + `${apiUrl}/incidents/${incidentId}/alerts`, { method: "POST", headers: { @@ -54,6 +58,28 @@ const AlertAssociateIncidentModal = ({ } } + const handleAssociateAlerts = (e: FormEvent) => { + e.preventDefault(); + associateAlertsHandler(selectedIncident) + } + + const showCreateIncidentForm = useCallback(() => setCreateIncident(true), []) + + const hideCreateIncidentForm = useCallback(() => setCreateIncident(false), []) + + const onIncidentCreated = useCallback((incidentId: string) => { + hideCreateIncidentForm() + handleClose() + associateAlertsHandler(incidentId) + }, []) + + // reset modal state after closing + useEffect(() => { + if (!isOpen) { + hideCreateIncidentForm() + setSelectedIncident(null) + } + }, [isOpen]) return (
{isLoading ? ( - - ) : incidents && incidents.items.length > 0 ? ( + + ) : createIncident ? ( + + ): incidents && incidents.items.length > 0 ? (
- Filter 1 - - */} -
-
- setProvidersSearchString(e.target.value)} - /> - - Alert - Messaging - Ticketing - Data - + +
+
+
+ + +
-
-
{children}
-
- + + ); } diff --git a/keep-ui/app/providers/page.client.tsx b/keep-ui/app/providers/page.client.tsx index b67ca6d09..5db731a03 100644 --- a/keep-ui/app/providers/page.client.tsx +++ b/keep-ui/app/providers/page.client.tsx @@ -1,6 +1,5 @@ "use client"; import { - Providers, defaultProvider, Provider, ProvidersResponse, @@ -13,7 +12,7 @@ import ProvidersTiles from "./providers-tiles"; import React, { useState, Suspense, useContext, useEffect } from "react"; import useSWR from "swr"; import Loading from "../loading"; -import { LayoutContext } from "./context"; +import { useFilterContext } from "./filter-context"; import { toast } from "react-toastify"; import { useRouter } from "next/navigation"; @@ -140,7 +139,7 @@ export default function ProvidersPage({ isSlowLoading, isLocalhost, } = useFetchProviders(); - const { searchProviderString, selectedTags } = useContext(LayoutContext); + const { providersSearchString, providersSelectedTags } = useFilterContext(); const router = useRouter(); useEffect(() => { if (searchParams?.oauth === "failure") { @@ -154,14 +153,14 @@ export default function ProvidersPage({ }); } }, [searchParams]); - + if (error) { + throw new KeepApiError(error.message, `${getApiURL()}/providers`, `Failed to query ${getApiURL()}/providers, is Keep API up?`); + } if (status === "loading") return ; if (status === "unauthenticated") router.push("/signin"); if (!providers || !installedProviders || providers.length <= 0) return ; - if (error) { - throw new KeepApiError(error.message, `${getApiURL()}/providers`); - } + const addProvider = (provider: Provider) => { setInstalledProviders((prevProviders) => { @@ -183,15 +182,15 @@ export default function ProvidersPage({ const searchProviders = (provider: Provider) => { return ( - !searchProviderString || - provider.type?.toLowerCase().includes(searchProviderString.toLowerCase()) + !providersSearchString || + provider.type?.toLowerCase().includes(providersSearchString.toLowerCase()) ); }; const searchTags = (provider: Provider) => { return ( - selectedTags.length === 0 || - provider.tags.some((tag) => selectedTags.includes(tag)) + providersSelectedTags.length === 0 || + provider.tags.some((tag) => providersSelectedTags.includes(tag)) ); }; diff --git a/keep-ui/app/providers/provider-tile.tsx b/keep-ui/app/providers/provider-tile.tsx index e25c6a636..d2cc54101 100644 --- a/keep-ui/app/providers/provider-tile.tsx +++ b/keep-ui/app/providers/provider-tile.tsx @@ -14,6 +14,7 @@ import { CircleStackIcon, QueueListIcon, TicketIcon, + MapIcon, } from "@heroicons/react/20/solid"; import "./provider-tile.css"; import moment from "moment"; @@ -199,6 +200,8 @@ export default function ProviderTile({ provider, onClick }: Props) { ? TicketIcon : tag === "queue" ? QueueListIcon + : tag === "topology" + ? MapIcon : ChatBubbleBottomCenterIcon; return ( { + const { useAllAlerts } = useAlerts(); + const { data: alerts, mutate } = useAllAlerts("feed"); + const { data: pollAlerts } = useAlertPolling(); + const router = useRouter(); + + useEffect(() => { + if (pollAlerts) { + mutate(); + } + }, [pollAlerts, mutate]); + + const relevantAlerts = alerts?.filter((alert) => alert.service === data.service); + + const handleClick = () => { + router.push( + `/alerts/feed?cel=service%3D%3D${encodeURIComponent(`"${data.service}"`)}` + ); + }; + + const alertCount = relevantAlerts?.length || 0; + const badgeColor = alertCount < THRESHOLD ? "bg-orange-500" : "bg-red-500"; + + return ( +
+ {data.service} + {alertCount > 0 && ( + + {alertCount} + + )} + + +
+ ); +}; + +export default CustomNode; diff --git a/keep-ui/app/topology/layout.tsx b/keep-ui/app/topology/layout.tsx new file mode 100644 index 000000000..ab479c40f --- /dev/null +++ b/keep-ui/app/topology/layout.tsx @@ -0,0 +1,5 @@ +export default function Layout({ children }: { children: any }) { + return ( +
{children}
+ ); +} diff --git a/keep-ui/app/topology/models.tsx b/keep-ui/app/topology/models.tsx new file mode 100644 index 000000000..ebd15ebed --- /dev/null +++ b/keep-ui/app/topology/models.tsx @@ -0,0 +1,20 @@ +export interface TopologyServiceDependency { + serviceId: string; + serviceName: string; + protocol?: string; +} + +export interface TopologyService { + id: string; + source_provider_id?: string; + repository?: string; + tags?: string[]; + service: string; + display_name: string; + description?: string; + team?: string; + application?: string; + email?: string; + slack?: string; + dependencies: TopologyServiceDependency[]; +} diff --git a/keep-ui/app/topology/page.tsx b/keep-ui/app/topology/page.tsx new file mode 100644 index 000000000..497f57a3f --- /dev/null +++ b/keep-ui/app/topology/page.tsx @@ -0,0 +1,16 @@ +import { Title } from "@tremor/react"; +import TopologyPage from "./topology"; + +export default function Page() { + return ( + <> + Service Topology + + + ); +} + +export const metadata = { + title: "Keep - Service Topology", + description: "See service topology and information about your services", +}; diff --git a/keep-ui/app/topology/styles.tsx b/keep-ui/app/topology/styles.tsx new file mode 100644 index 000000000..36090ebcc --- /dev/null +++ b/keep-ui/app/topology/styles.tsx @@ -0,0 +1,28 @@ +import { MarkerType } from "@xyflow/react"; + +export const nodeWidth = 220; +export const nodeHeight = 80; + +// Edge No Hover +export const edgeLabelBgStyleNoHover = { + strokeWidth: 1, + strokeDasharray: "5,5", + stroke: "#b1b1b7", // default graph stroke line color +}; +export const edgeLabelBgBorderRadiusNoHover = 10; +export const edgeLabelBgPaddingNoHover: [number, number] = [10, 5]; +export const edgeMarkerEndNoHover = { + type: MarkerType.ArrowClosed, +}; + +// Edge Hover +export const edgeLabelBgStyleHover = { + ...edgeLabelBgStyleNoHover, + stroke: "none", + fill: "orange", + color: "white", +}; +export const edgeMarkerEndHover = { + ...edgeMarkerEndNoHover, + color: "orange", +}; diff --git a/keep-ui/app/topology/topology.css b/keep-ui/app/topology/topology.css new file mode 100644 index 000000000..dcea1c6a6 --- /dev/null +++ b/keep-ui/app/topology/topology.css @@ -0,0 +1,10 @@ +.react-flow__handle-left, +.react-flow__handle-right, +.react-flow__handle-top, +.react-flow__handle-bottom { + opacity: 0; +} + +.react-flow__edge.selectable { + cursor: default; +} diff --git a/keep-ui/app/topology/topology.tsx b/keep-ui/app/topology/topology.tsx new file mode 100644 index 000000000..ba7c26e57 --- /dev/null +++ b/keep-ui/app/topology/topology.tsx @@ -0,0 +1,236 @@ +"use client"; +import React, { useCallback, useEffect, useState } from "react"; +import { + Background, + BackgroundVariant, + Controls, + Edge, + Node, + ReactFlow, + ReactFlowInstance, + ReactFlowProvider, +} from "@xyflow/react"; +import dagre, { graphlib } from "@dagrejs/dagre"; +import "@xyflow/react/dist/style.css"; +import CustomNode from "./custom-node"; +import { Card, TextInput } from "@tremor/react"; +import { + edgeLabelBgPaddingNoHover, + edgeLabelBgStyleNoHover, + edgeLabelBgBorderRadiusNoHover, + edgeMarkerEndNoHover, + edgeLabelBgStyleHover, + edgeMarkerEndHover, + nodeHeight, + nodeWidth, +} from "./styles"; +import "./topology.css"; +import { useTopology } from "utils/hooks/useTopology"; +import Loading from "app/loading"; +import { EmptyStateCard } from "@/components/ui/EmptyStateCard"; +import { useRouter } from "next/navigation"; + +interface Props { + providerId?: string; + service?: string; + 
environment?: string; + showSearch?: boolean; +} + +// Function to create a Dagre layout +const dagreGraph = new graphlib.Graph(); +dagreGraph.setDefaultEdgeLabel(() => ({})); + +const getLayoutedElements = (nodes: any[], edges: any[]) => { + dagreGraph.setGraph({ rankdir: "LR", nodesep: 50, ranksep: 200 }); + + nodes.forEach((node) => { + dagreGraph.setNode(node.id, { width: nodeWidth, height: nodeHeight }); + }); + + edges.forEach((edge) => { + dagreGraph.setEdge(edge.source, edge.target); + }); + + dagre.layout(dagreGraph); + + nodes.forEach((node) => { + const nodeWithPosition = dagreGraph.node(node.id); + node.targetPosition = "left"; + node.sourcePosition = "right"; + + node.position = { + x: nodeWithPosition.x - nodeWidth / 2, + y: nodeWithPosition.y - nodeHeight / 2, + }; + + return node; + }); + + return { nodes, edges }; +}; + +const TopologyPage = ({ + providerId, + service, + environment, + showSearch = true, +}: Props) => { + const router = useRouter(); + // State for nodes and edges + const [nodes, setNodes] = useState([]); + const [edges, setEdges] = useState([]); + const [serviceInput, setServiceInput] = useState(""); + const [reactFlowInstance, setReactFlowInstance] = + useState>(); + + const { topologyData, error, isLoading } = useTopology( + providerId, + service, + environment + ); + + const onEdgeHover = (eventType: "enter" | "leave", edge: Edge) => { + const newEdges = [...edges]; + const currentEdge = newEdges.find((e) => e.id === edge.id); + if (currentEdge) { + currentEdge.style = eventType === "enter" ? { stroke: "orange" } : {}; + currentEdge.labelBgStyle = + eventType === "enter" ? edgeLabelBgStyleHover : edgeLabelBgStyleNoHover; + currentEdge.markerEnd = + eventType === "enter" ? edgeMarkerEndHover : edgeMarkerEndNoHover; + currentEdge.labelStyle = eventType === "enter" ? { fill: "white" } : {}; + setEdges(newEdges); + } + }; + + const zoomToNode = useCallback( + (nodeId: string) => { + const node = reactFlowInstance?.getNode(nodeId); + if (node && reactFlowInstance) { + reactFlowInstance.setCenter(node.position.x, node.position.y); + } + }, + [reactFlowInstance] + ); + + useEffect(() => { + if (serviceInput) { + zoomToNode(serviceInput); + } + }, [serviceInput, zoomToNode]); + + useEffect(() => { + if (!topologyData) return; + + // Create nodes from service definitions + const newNodes = topologyData.map((service) => ({ + id: service.service.toString(), + type: "customNode", + data: service, + position: { x: 0, y: 0 }, // Dagre will handle the actual positioning + })); + + // Create edges from service dependencies + const edgeMap = new Map(); + + topologyData.forEach((service) => { + service.dependencies.forEach((dependency) => { + const dependencyService = topologyData.find( + (s) => s.service === dependency.serviceName + ); + const edgeId = `${service.service}_${dependency.protocol}_${ + dependencyService + ? dependencyService.service + : dependency.serviceId.toString() + }`; + if (!edgeMap.has(edgeId)) { + edgeMap.set(edgeId, { + id: edgeId, + source: service.service.toString(), + target: dependency.serviceName.toString(), + label: dependency.protocol === "unknown" ? 
"" : dependency.protocol, + animated: false, + labelBgPadding: edgeLabelBgPaddingNoHover, + labelBgStyle: edgeLabelBgStyleNoHover, + labelBgBorderRadius: edgeLabelBgBorderRadiusNoHover, + markerEnd: edgeMarkerEndNoHover, + }); + } + }); + }); + + const newEdges = Array.from(edgeMap.values()); + const layoutedElements = getLayoutedElements(newNodes, newEdges); + setNodes(layoutedElements.nodes); + setEdges(layoutedElements.edges); + }, [topologyData]); + + if (isLoading) return ; + if (error) + return ( +
+ { + window.open("https://slack.keephq.dev/", "_blank"); + }} + /> +
+ ); + + return ( + + {showSearch && ( +
+ +
+ )} + + onEdgeHover("enter", edge)} + onEdgeMouseLeave={(_event, edge) => onEdgeHover("leave", edge)} + nodeTypes={{ customNode: CustomNode }} + onInit={(instance) => { + setReactFlowInstance(instance); + }} + > + + + + + {!topologyData || + (topologyData?.length === 0 && ( + <> +
+
+
+ router.push("/providers?labels=topology")} + /> +
+
+ + ))} + + ); +}; + +export default TopologyPage; diff --git a/keep-ui/app/workflows/builder/builder-card.tsx b/keep-ui/app/workflows/builder/builder-card.tsx index e27e60541..b1bf5310c 100644 --- a/keep-ui/app/workflows/builder/builder-card.tsx +++ b/keep-ui/app/workflows/builder/builder-card.tsx @@ -50,7 +50,8 @@ export function BuilderCard({ if (error) { throw new KeepApiError( "The builder has failed to load providers", - `${apiUrl}/providers` + `${apiUrl}/providers`, + `Failed to query ${apiUrl}/providers, is Keep API up?` ); } diff --git a/keep-ui/components/navbar/CustomPresetAlertLinks.tsx b/keep-ui/components/navbar/CustomPresetAlertLinks.tsx index ff104f7d4..b08d66061 100644 --- a/keep-ui/components/navbar/CustomPresetAlertLinks.tsx +++ b/keep-ui/components/navbar/CustomPresetAlertLinks.tsx @@ -112,7 +112,14 @@ export const CustomPresetAlertLinks = ({ const [presetsOrder, setPresetsOrder] = useState([]); // Check for noisy presets and control sound playback - const anyNoisyNow = presets.some(preset => preset.should_do_noise_now); + const anyNoisyNow = presets.some((preset) => preset.should_do_noise_now); + + const checkValidPreset = (preset: Preset) => { + if (!preset.is_private) { + return true; + } + return preset && preset.created_by == session?.user?.email; + }; useEffect(() => { const filteredLS = presetsOrderFromLS.filter( @@ -120,11 +127,11 @@ export const CustomPresetAlertLinks = ({ ); // Combine live presets and local storage order - const combinedOrder = presets.reduce((acc, preset) => { - if (!acc.find(p => p.id === preset.id)) { + const combinedOrder = presets.reduce((acc, preset: Preset) => { + if (!acc.find((p) => p.id === preset.id)) { acc.push(preset); } - return acc; + return acc.filter((preset) => checkValidPreset(preset)); }, [...filteredLS]); // Only update state if there's an actual change to prevent infinite loops diff --git a/keep-ui/components/navbar/DashboardLinks.tsx b/keep-ui/components/navbar/DashboardLinks.tsx index 4e1591abf..7f8fb865a 100644 --- a/keep-ui/components/navbar/DashboardLinks.tsx +++ b/keep-ui/components/navbar/DashboardLinks.tsx @@ -136,14 +136,15 @@ export const DashboardLinks = ({ session }: DashboardProps) => { )): Dashboards will appear here when saved. } +
+ />
); diff --git a/keep-ui/components/navbar/IncidentLinks.tsx b/keep-ui/components/navbar/IncidentLinks.tsx index d0ee861f6..15a310d76 100644 --- a/keep-ui/components/navbar/IncidentLinks.tsx +++ b/keep-ui/components/navbar/IncidentLinks.tsx @@ -8,7 +8,7 @@ import { Session } from "next-auth"; import { Disclosure } from "@headlessui/react"; import { IoChevronUp } from "react-icons/io5"; import classNames from "classnames"; -import { useIncidents } from "utils/hooks/useIncidents"; +import { useIncidents, usePollIncidents } from "utils/hooks/useIncidents"; import { MdNearbyError } from "react-icons/md"; type IncidentsLinksProps = { session: Session | null }; @@ -16,7 +16,8 @@ const SHOW_N_INCIDENTS = 3; export const IncidentsLinks = ({ session }: IncidentsLinksProps) => { const isNOCRole = session?.userRole === "noc"; - const { data: incidents } = useIncidents(); + const { data: incidents, mutate } = useIncidents(); + usePollIncidents(mutate) const currentPath = usePathname(); if (isNOCRole) { diff --git a/keep-ui/components/navbar/Navbar.tsx b/keep-ui/components/navbar/Navbar.tsx index 074d7d956..c7bddd903 100644 --- a/keep-ui/components/navbar/Navbar.tsx +++ b/keep-ui/components/navbar/Navbar.tsx @@ -19,8 +19,8 @@ export default async function NavbarInner() {
+ -
diff --git a/keep-ui/components/navbar/NoiseReductionLinks.tsx b/keep-ui/components/navbar/NoiseReductionLinks.tsx index d24f01eca..4ccf7a056 100644 --- a/keep-ui/components/navbar/NoiseReductionLinks.tsx +++ b/keep-ui/components/navbar/NoiseReductionLinks.tsx @@ -8,6 +8,7 @@ import { Disclosure } from "@headlessui/react"; import { IoChevronUp } from "react-icons/io5"; import classNames from "classnames"; import { AILink } from "./AILink"; +import { TbTopologyRing } from "react-icons/tb"; type NoiseReductionLinksProps = { session: Session | null }; @@ -39,34 +40,31 @@ export const NoiseReductionLinks = ({ session }: NoiseReductionLinksProps) => {
- - Alert Groups - + Alert Groups
- - Workflows - + Workflows + +
+
+ + Service Topology
- - Mapping - + Mapping
- - Extraction - + Extraction
- +
diff --git a/keep-ui/components/ui/EmptyStateCard.tsx new file mode 100644 index 000000000..8563e44a4 --- /dev/null +++ b/keep-ui/components/ui/EmptyStateCard.tsx @@ -0,0 +1,38 @@ +import { Button, Card } from "@tremor/react"; +import { CircleStackIcon } from "@heroicons/react/24/outline"; + +export function EmptyStateCard({ + title, + description, + buttonText, + onClick, + className, +}: { + title: string; + description: string; + buttonText: string; + onClick: () => void; + className?: string; }) { + return ( +
+ +
+ {title} +
+
+ {description} +
+ +
+
    + ); +} diff --git a/keep-ui/package-lock.json b/keep-ui/package-lock.json index 6c46e4df4..cb5fb4e97 100644 --- a/keep-ui/package-lock.json +++ b/keep-ui/package-lock.json @@ -10,6 +10,7 @@ "license": "ISC", "dependencies": { "@boiseitguru/cookie-cutter": "^0.2.3", + "@dagrejs/dagre": "^1.1.3", "@dnd-kit/core": "^6.1.0", "@dnd-kit/sortable": "^8.0.0", "@dnd-kit/utilities": "^3.2.2", @@ -28,7 +29,7 @@ "@tanstack/react-table": "^8.11.0", "@tremor/react": "^3.15.1", "@types/react-select": "^5.0.1", - "@xyflow/react": "^12.0.1", + "@xyflow/react": "^12.0.2", "add": "^2.0.6", "ajv": "^6.12.6", "ansi-regex": "^5.0.1", @@ -2217,6 +2218,22 @@ "resolved": "https://registry.npmjs.org/commander/-/commander-2.20.3.tgz", "integrity": "sha512-GpVkmM8vF2vQUkj2LvZmD35JxeJOLCwJ9cUkugyk2nuhbv3+mJvpLYYt+0+USMxE+oj+ey/lJEnhZw75x/OMcQ==" }, + "node_modules/@dagrejs/dagre": { + "version": "1.1.3", + "resolved": "https://registry.npmjs.org/@dagrejs/dagre/-/dagre-1.1.3.tgz", + "integrity": "sha512-umT7fBPECI4zgxxXW07H3vJN7W1WZcnBjk613eOEAKcwoFrYNyMZO+1SHmoC8zPZWR18DquK2wRUp9VHUE+94g==", + "dependencies": { + "@dagrejs/graphlib": "2.2.2" + } + }, + "node_modules/@dagrejs/graphlib": { + "version": "2.2.2", + "resolved": "https://registry.npmjs.org/@dagrejs/graphlib/-/graphlib-2.2.2.tgz", + "integrity": "sha512-CbyGpCDKsiTg/wuk79S7Muoj8mghDGAESWGxcSyhHX5jD35vYMBZochYVFzlHxynpE9unpu6O+4ZuhrLxASsOg==", + "engines": { + "node": ">17.0.0" + } + }, "node_modules/@dnd-kit/accessibility": { "version": "3.1.0", "resolved": "https://registry.npmjs.org/@dnd-kit/accessibility/-/accessibility-3.1.0.tgz", @@ -4309,11 +4326,11 @@ "integrity": "sha512-zuVdFrMJiuCDQUMCzQaD6KL28MjnqqN8XnAqiEq9PNm/hCPTSGfrXCOfwj1ow4LFb/tNymJPwsNbVePc1xFqrQ==" }, "node_modules/@xyflow/react": { - "version": "12.0.1", - "resolved": "https://registry.npmjs.org/@xyflow/react/-/react-12.0.1.tgz", - "integrity": "sha512-iGh/nO7key0sVH0c8TW2qvLNU0akJ20Mi3LPUF2pymhRqerrBk0EJhPLXRThbYWy4pNWUnkhpBLB0/gr884qnw==", + "version": "12.0.3", + "resolved": "https://registry.npmjs.org/@xyflow/react/-/react-12.0.3.tgz", + "integrity": "sha512-PJB9ARsyDesjS9fY3b62mm36nHx9aRA8tvUc5y0ubrMkSCvQRECkOamVDyx+u65UgUkZCgcO/KFdXPdbTWwaJQ==", "dependencies": { - "@xyflow/system": "0.0.35", + "@xyflow/system": "0.0.37", "classcat": "^5.0.3", "zustand": "^4.4.0" }, @@ -4323,9 +4340,9 @@ } }, "node_modules/@xyflow/system": { - "version": "0.0.35", - "resolved": "https://registry.npmjs.org/@xyflow/system/-/system-0.0.35.tgz", - "integrity": "sha512-QaUkahvmMs2gY2ykxUfjs5CbkXzU5fQNtmoQQ6HmHoAr8n2D7UyLO/UEXlke2jxuCDuiwpXhrzn4DmffVJd2qA==", + "version": "0.0.37", + "resolved": "https://registry.npmjs.org/@xyflow/system/-/system-0.0.37.tgz", + "integrity": "sha512-hSIhezhxgftPUpC+xiQVIorcRILZUOWlLjpYPTyGWRu8s4RJvM4GqvrsFmD5OnMKXLgpU7/PqqUibDVO67oWQQ==", "dependencies": { "@types/d3-drag": "^3.0.7", "@types/d3-selection": "^3.0.10", diff --git a/keep-ui/package.json b/keep-ui/package.json index 4de180653..79453dc0b 100644 --- a/keep-ui/package.json +++ b/keep-ui/package.json @@ -11,6 +11,7 @@ }, "dependencies": { "@boiseitguru/cookie-cutter": "^0.2.3", + "@dagrejs/dagre": "^1.1.3", "@dnd-kit/core": "^6.1.0", "@dnd-kit/sortable": "^8.0.0", "@dnd-kit/utilities": "^3.2.2", @@ -29,7 +30,7 @@ "@tanstack/react-table": "^8.11.0", "@tremor/react": "^3.15.1", "@types/react-select": "^5.0.1", - "@xyflow/react": "^12.0.1", + "@xyflow/react": "^12.0.2", "add": "^2.0.6", "ajv": "^6.12.6", "ansi-regex": "^5.0.1", diff --git a/keep-ui/utils/hooks/useIncidents.ts 
b/keep-ui/utils/hooks/useIncidents.ts index 722bd95df..849bf799f 100644 --- a/keep-ui/utils/hooks/useIncidents.ts +++ b/keep-ui/utils/hooks/useIncidents.ts @@ -1,5 +1,5 @@ import { AlertDto } from "app/alerts/models"; -import { IncidentDto, PaginatedIncidentsDto } from "app/incidents/model"; +import {IncidentDto, PaginatedIncidentAlertsDto, PaginatedIncidentsDto} from "app/incidents/model"; import { useSession } from "next-auth/react"; import useSWR, { SWRConfiguration } from "swr"; import { getApiURL } from "utils/apiUrl"; @@ -33,14 +33,16 @@ export const useIncidents = ( export const useIncidentAlerts = ( incidentId: string, + limit: number = 20, + offset: number = 0, options: SWRConfiguration = { revalidateOnFocus: false, } ) => { const apiUrl = getApiURL(); const { data: session } = useSession(); - return useSWR( - () => (session ? `${apiUrl}/incidents/${incidentId}/alerts` : null), + return useSWR( + () => (session ? `${apiUrl}/incidents/${incidentId}/alerts?limit=${limit}&offset=${offset}` : null), (url) => fetcher(url, session?.accessToken), options ); diff --git a/keep-ui/utils/hooks/usePresets.ts b/keep-ui/utils/hooks/usePresets.ts index a76989d89..3bfe9d4f3 100644 --- a/keep-ui/utils/hooks/usePresets.ts +++ b/keep-ui/utils/hooks/usePresets.ts @@ -43,7 +43,9 @@ export const usePresets = () => { updatedPresets.set(newPresetId, { ...currentPreset, alerts_count: currentPreset.alerts_count + newPreset.alerts_count, - }); + created_by: newPreset.created_by, + is_private: newPreset.is_private + }); } else { // If the preset is not in the current presets, add it updatedPresets.set(newPresetId, { diff --git a/keep-ui/utils/hooks/useTopology.ts b/keep-ui/utils/hooks/useTopology.ts new file mode 100644 index 000000000..74d46c47b --- /dev/null +++ b/keep-ui/utils/hooks/useTopology.ts @@ -0,0 +1,36 @@ +import { TopologyService } from "app/topology/models"; +import { useSession } from "next-auth/react"; +import useSWR from "swr"; +import { getApiURL } from "utils/apiUrl"; +import { fetcher } from "utils/fetcher"; + +const isNullOrUndefined = (value: any) => value === null || value === undefined; + +export const useTopology = ( + providerId?: string, + service?: string, + environment?: string +) => { + const { data: session } = useSession(); + const apiUrl = getApiURL(); + + const url = !session + ? null + : !isNullOrUndefined(providerId) && + !isNullOrUndefined(service) && + !isNullOrUndefined(environment) + ? `${apiUrl}/topology?provider_id=${providerId}&service_id=${service}&environment=${environment}` + : `${apiUrl}/topology`; + + const { data, error, mutate } = useSWR( + url, + (url: string) => fetcher(url, session!.accessToken) + ); + + return { + topologyData: data, + error, + isLoading: !data && !error, + mutate, + }; +}; diff --git a/keep/alembic.ini b/keep/alembic.ini index 74284710e..1714b9bd6 100644 --- a/keep/alembic.ini +++ b/keep/alembic.ini @@ -1,6 +1,6 @@ [alembic] # Re-defined in the keep/api/core/db_on_start.py to make it stable while keep is installed as a package -script_location = api/models/db/migrations +script_location = keep/api/models/db/migrations file_template = %%(year)d-%%(month).2d-%%(day).2d-%%(hour).2d-%%(minute).2d_%%(rev)s prepend_sys_path = . 
output_encoding = utf-8 diff --git a/keep/api/api.py b/keep/api/api.py index 58bf2faad..f36fccb6e 100644 --- a/keep/api/api.py +++ b/keep/api/api.py @@ -4,8 +4,8 @@ from importlib import metadata import jwt -import uvicorn import requests +import uvicorn from dotenv import find_dotenv, load_dotenv from fastapi import FastAPI, HTTPException, Request, Response from fastapi.middleware.gzip import GZipMiddleware @@ -25,6 +25,7 @@ from keep.api.logging import CONFIG as logging_config from keep.api.routes import ( actions, + ai, alerts, dashboard, extraction, @@ -38,11 +39,10 @@ rules, settings, status, + topology, users, whoami, workflows, - incidents, - ai ) from keep.event_subscriber.event_subscriber import EventSubscriber from keep.posthog.posthog import get_posthog_client @@ -68,10 +68,12 @@ # Monkey patch requests to disable redirects original_request = requests.Session.request + def no_redirect_request(self, method, url, **kwargs): - kwargs['allow_redirects'] = False + kwargs["allow_redirects"] = False return original_request(self, method, url, **kwargs) + requests.Session.request = no_redirect_request @@ -187,6 +189,7 @@ def get_app( app.include_router(preset.router, prefix="/preset", tags=["preset"]) app.include_router(groups.router, prefix="/groups", tags=["groups"]) app.include_router(users.router, prefix="/users", tags=["users"]) + app.include_router(topology.router, prefix="/topology", tags=["topology"]) app.include_router( mapping.router, prefix="/mapping", tags=["enrichment", "mapping"] ) diff --git a/keep/api/arq_worker.py b/keep/api/arq_worker.py index 0fb69e45d..726c9d5bf 100644 --- a/keep/api/arq_worker.py +++ b/keep/api/arq_worker.py @@ -16,7 +16,10 @@ ARQ_BACKGROUND_FUNCTIONS: Optional[CommaSeparatedStrings] = config( "ARQ_BACKGROUND_FUNCTIONS", cast=CommaSeparatedStrings, - default=["keep.api.tasks.process_event_task.async_process_event"], + default=[ + "keep.api.tasks.process_event_task.async_process_event", + "keep.api.tasks.process_topology_task.async_process_topology", + ], ) FUNCTIONS: list = ( [ diff --git a/keep/api/consts.py b/keep/api/consts.py index 7f0f535e7..f7972c615 100644 --- a/keep/api/consts.py +++ b/keep/api/consts.py @@ -3,6 +3,9 @@ from keep.api.models.db.preset import PresetDto, StaticPresetsId RUNNING_IN_CLOUD_RUN = os.environ.get("K_SERVICE") is not None +PROVIDER_PULL_INTERVAL_DAYS = int( + os.environ.get("KEEP_PULL_INTERVAL", 7) +) # maximum once a week STATIC_PRESETS = { "feed": PresetDto( id=StaticPresetsId.FEED_PRESET_ID.value, diff --git a/keep/api/core/db.py b/keep/api/core/db.py index a719eb829..e824f46a1 100644 --- a/keep/api/core/db.py +++ b/keep/api/core/db.py @@ -19,7 +19,7 @@ import validators from dotenv import find_dotenv, load_dotenv from opentelemetry.instrumentation.sqlalchemy import SQLAlchemyInstrumentor -from sqlalchemy import and_, desc, func, null, update +from sqlalchemy import and_, desc, null, update from sqlalchemy.exc import IntegrityError, OperationalError from sqlalchemy.orm import joinedload, selectinload, subqueryload from sqlalchemy.orm.attributes import flag_modified @@ -27,7 +27,7 @@ from sqlalchemy.sql import expression from sqlmodel import Session, col, or_, select -from keep.api.core.db_utils import create_db_engine +from keep.api.core.db_utils import create_db_engine, get_json_extract_field # This import is required to create the tables from keep.api.models.alert import AlertStatus, IncidentDtoIn @@ -40,6 +40,7 @@ from keep.api.models.db.provider import * # pylint: disable=unused-wildcard-import from 
keep.api.models.db.rule import * # pylint: disable=unused-wildcard-import from keep.api.models.db.tenant import * # pylint: disable=unused-wildcard-import +from keep.api.models.db.topology import * # pylint: disable=unused-wildcard-import from keep.api.models.db.workflow import * # pylint: disable=unused-wildcard-import from keep.api.models.db.statistics import * # pylint: disable=unused-wildcard-import @@ -87,6 +88,7 @@ def create_workflow_execution( execution_number: int = 1, event_id: str = None, fingerprint: str = None, + execution_id: str = None, ) -> WorkflowExecution: with Session(engine) as session: try: @@ -94,7 +96,7 @@ def create_workflow_execution( triggered_by = triggered_by[:255] workflow_execution = WorkflowExecution( - id=str(uuid4()), + id=execution_id or str(uuid4()), workflow_id=workflow_id, tenant_id=tenant_id, started=datetime.now(tz=timezone.utc), @@ -496,6 +498,31 @@ def get_raw_workflow(tenant_id: str, workflow_id: str) -> str: return workflow.workflow_raw +def update_provider_last_pull_time(tenant_id: str, provider_id: str): + extra = {"tenant_id": tenant_id, "provider_id": provider_id} + logger.info("Updating provider last pull time", extra=extra) + with Session(engine) as session: + provider = session.exec( + select(Provider).where( + Provider.tenant_id == tenant_id, Provider.id == provider_id + ) + ).first() + + if not provider: + logger.warning( + "Could not update provider last pull time since provider does not exist", + extra=extra, + ) + + try: + provider.last_pull_time = datetime.now(tz=timezone.utc) + session.commit() + except Exception: + logger.exception("Failed to update provider last pull time", extra=extra) + raise + logger.info("Successfully updated provider last pull time", extra=extra) + + def get_installed_providers(tenant_id: str) -> List[Provider]: with Session(engine) as session: providers = session.exec( @@ -770,13 +797,22 @@ def count_alerts( ) -def get_enrichment(tenant_id, fingerprint): +def get_enrichment(tenant_id, fingerprint, refresh=False): with Session(engine) as session: alert_enrichment = session.exec( select(AlertEnrichment) .where(AlertEnrichment.tenant_id == tenant_id) .where(AlertEnrichment.alert_fingerprint == fingerprint) ).first() + + if refresh: + try: + session.refresh(alert_enrichment) + except Exception: + logger.exception( + "Failed to refresh enrichment", + extra={"tenant_id": tenant_id, "fingerprint": fingerprint}, + ) return alert_enrichment @@ -1842,28 +1878,10 @@ def update_preset_options(tenant_id: str, preset_id: str, options: dict) -> Pres return preset -def get_incident_by_id(incident_id: UUID) -> Incident: - with Session(engine) as session: - incident = session.exec( - select(Incident) - .options(selectinload(Incident.alerts)) - .where(Incident.id == incident_id) - ).first() - return incident - - def assign_alert_to_incident( alert_id: UUID, incident_id: UUID, tenant_id: str -) -> AlertToIncident: - with Session(engine) as session: - assignment = AlertToIncident( - alert_id=alert_id, incident_id=incident_id, tenant_id=tenant_id - ) - session.add(assignment) - session.commit() - session.refresh(assignment) - - return assignment +): + return add_alerts_to_incident_by_incident_id(tenant_id, incident_id, [alert_id]) def is_alert_assigned_to_incident(alert_id: UUID, incident_id: UUID, tenant_id: str) -> bool: with Session(engine) as session: @@ -1997,6 +2015,7 @@ def get_last_incidents( .filter(Incident.tenant_id == tenant_id) .filter(Incident.is_confirmed == is_confirmed) .options(joinedload(Incident.alerts)) + 
.order_by(desc(Incident.creation_time)) ) if timeframe: @@ -2028,7 +2047,7 @@ return incidents, total_count -def get_incident_by_id(tenant_id: str, incident_id: str) -> Optional[Incident]: +def get_incident_by_id(tenant_id: str, incident_id: str | UUID) -> Optional[Incident]: with Session(engine) as session: query = session.query( Incident, @@ -2139,7 +2158,7 @@ ) -def get_incident_alerts_by_incident_id(tenant_id: str, incident_id: str) -> List[Alert]: +def get_incident_alerts_by_incident_id(tenant_id: str, incident_id: str, limit: int, offset: int) -> tuple[List[Alert], int]: with Session(engine) as session: query = ( session.query( @@ -2153,11 +2172,66 @@ ) ) - return query.all() + total_count = query.count() + + return query.limit(limit).offset(offset).all(), total_count + + +def get_alerts_data_for_incident( + alert_ids: list[str | UUID], + session: Optional[Session] = None +) -> dict: + + """ + Prepare aggregated data for an incident from the given list of alert_ids. + The logic is wrapped in an inner function so it can run with either a provided database session or a fresh one. + + Args: + alert_ids (list[str | UUID]): list of alert ids for aggregation + session (Optional[Session]): The database session or None + + Returns: dict {sources: set[str], services: set[str], count: int} + """ + + def inner(db_session: Session): + + fields = ( + get_json_extract_field(db_session, Alert.event, 'service'), + Alert.provider_type + ) + + alerts_data = db_session.exec( + select( + *fields + ).where( + col(Alert.id).in_(alert_ids), + ) + ).all() + + sources = [] + services = [] + + for (service, source) in alerts_data: + if source: + sources.append(source) + if service: + services.append(service) + + return { + "sources": set(sources), + "services": set(services), + "count": len(alerts_data) + }
+ + # Ensure that we have a session to execute the query. If not, create a new one + if not session: + with Session(engine) as session: + return inner(session) + return inner(session) def add_alerts_to_incident_by_incident_id( - tenant_id: str, incident_id: str, alert_ids: List[UUID] + tenant_id: str, incident_id: str | UUID, alert_ids: List[UUID] ): with Session(engine) as session: incident = session.exec( @@ -2178,21 +2252,34 @@ ) ).all() + new_alert_ids = [alert_id for alert_id in alert_ids + if alert_id not in existed_alert_ids] + + alerts_data_for_incident = get_alerts_data_for_incident(new_alert_ids, session) + + incident.sources = list( + set(incident.sources) | set(alerts_data_for_incident["sources"]) + ) + incident.affected_services = list( + set(incident.affected_services) | set(alerts_data_for_incident["services"]) + ) + incident.alerts_count += alerts_data_for_incident["count"] + alert_to_incident_entries = [ AlertToIncident( alert_id=alert_id, incident_id=incident.id, tenant_id=tenant_id ) - for alert_id in alert_ids - if alert_id not in existed_alert_ids + for alert_id in new_alert_ids ] session.bulk_save_objects(alert_to_incident_entries) + session.add(incident) session.commit() return True def remove_alerts_to_incident_by_incident_id( - tenant_id: str, incident_id: str + tenant_id: str, incident_id: str | UUID, alert_ids: List[UUID] ) -> Optional[int]: with Session(engine) as session: incident = session.exec( @@ -2205,6 +2292,7 @@ if not incident: return None + # Remove the alerts-to-incident relations for the provided alert_ids deleted = ( session.query(AlertToIncident) .where( @@ -2214,8 +2302,51 @@ ) .delete() ) + session.commit() + # Get aggregated data for the alerts that were just removed + alerts_data_for_incident = get_alerts_data_for_incident(alert_ids, session) + + service_field = get_json_extract_field(session, Alert.event, 'service') + + # check whether the services of the removed alerts are still present in alerts + # that are still assigned to the incident + services_existed = session.exec( + session.query(func.distinct(service_field)) + .join(AlertToIncident, Alert.id == AlertToIncident.alert_id) + .filter( + AlertToIncident.incident_id == incident_id, + service_field.in_(alerts_data_for_incident["services"]) + ) + ).scalars() + + # check whether the sources (providers) of the removed alerts are still present in alerts + # that are still assigned to the incident + sources_existed = session.exec( + session.query(col(Alert.provider_type).distinct()) + .join(AlertToIncident, Alert.id == AlertToIncident.alert_id) + .filter( + AlertToIncident.incident_id == incident_id, + col(Alert.provider_type).in_(alerts_data_for_incident["sources"]) + ) + ).scalars() + + # Build the lists of services and sources to remove from the incident + services_to_remove = [service for service in alerts_data_for_incident["services"] + if service not in services_existed] + sources_to_remove = [source for source in alerts_data_for_incident["sources"] + if source not in sources_existed] + + # Filter the removed entities out of the incident's affected services and sources + incident.affected_services = [service for service in incident.affected_services + if service not in services_to_remove] + incident.sources = [source for source in incident.sources + if source not in sources_to_remove] + + incident.alerts_count -= alerts_data_for_incident["count"] + session.add(incident) session.commit() + return deleted @@
-2388,4 +2519,46 @@ def update_incident_summary(incident_id: UUID, summary: str) -> Incident: session.commit() session.refresh(incident) - return incident \ No newline at end of file + return incident + +# Fetch all topology data +def get_all_topology_data( + tenant_id: str, + provider_id: Optional[str] = None, + service: Optional[str] = None, + environment: Optional[str] = None, +) -> List[TopologyServiceDtoOut]: + with Session(engine) as session: + query = select(TopologyService).where(TopologyService.tenant_id == tenant_id) + + # @tb: let's filter by service only for now and take care of it when we handle multiple + # services and environments and cmdbs + # the idea is that we show the service topology regardless of the underlying provider/env + # if provider_id is not None and service is not None and environment is not None: + if service is not None: + query = query.where( + TopologyService.service == service, + # TopologyService.source_provider_id == provider_id, + # TopologyService.environment == environment, + ) + + service_instance = session.exec(query).first() + if not service_instance: + return [] + + services = session.exec( + select(TopologyServiceDependency) + .where( + TopologyServiceDependency.depends_on_service_id + == service_instance.id + ) + .options(joinedload(TopologyServiceDependency.service)) + ).all() + services = [service_instance, *[service.service for service in services]] + else: + # Fetch services for the tenant + services = session.exec(query).all() + + service_dtos = [TopologyServiceDtoOut.from_orm(service) for service in services] + + return service_dtos diff --git a/keep/api/core/db_utils.py b/keep/api/core/db_utils.py index c30207083..a7a5ebb91 100644 --- a/keep/api/core/db_utils.py +++ b/keep/api/core/db_utils.py @@ -11,6 +11,7 @@ import pymysql from dotenv import find_dotenv, load_dotenv from google.cloud.sql.connector import Connector +from sqlalchemy import func from sqlmodel import create_engine # This import is required to create the tables @@ -150,3 +151,13 @@ def create_db_engine(): echo=DB_ECHO, ) return engine + + +def get_json_extract_field(session, base_field, key): + + if session.bind.dialect.name == "postgresql": + return func.json_extract_path_text(base_field, key) + elif session.bind.dialect.name == "mysql": + return func.json_unquote(func.json_extract(base_field, '$.{}'.format(key))) + else: + return func.json_extract(base_field, '$.{}'.format(key)) diff --git a/keep/api/core/elastic.py b/keep/api/core/elastic.py index adf9e6f03..e6ff4a89c 100644 --- a/keep/api/core/elastic.py +++ b/keep/api/core/elastic.py @@ -198,9 +198,11 @@ def index_alert(self, alert: AlertDto): # change severity to number so we can sort by it alert.severity = AlertSeverity(alert.severity.lower()).order # query + alert_dict = alert.dict() + alert_dict["dismissed"] = bool(alert_dict["dismissed"]) self._client.index( index=self.alerts_index, - body=alert.dict(), + body=alert_dict, id=alert.fingerprint, # we want to update the alert if it already exists so that elastic will have the latest version refresh="true", ) diff --git a/keep/api/models/alert.py b/keep/api/models/alert.py index ac8bf662c..1b474ceaf 100644 --- a/keep/api/models/alert.py +++ b/keep/api/models/alert.py @@ -334,8 +334,8 @@ class UnEnrichAlertRequestBody(BaseModel): class IncidentDtoIn(BaseModel): name: str - user_summary: str assignee: str | None + user_summary: str | None class Config: extra = Extra.allow @@ -365,6 +365,8 @@ class IncidentDto(IncidentDtoIn): is_predicted: bool generated_summary: str
| None + generated_summary: str | None + def __str__(self) -> str: # Convert the model instance to a dictionary model_dict = self.dict() @@ -382,15 +384,6 @@ class Config: @classmethod def from_db_incident(cls, db_incident): - alerts_dto = [AlertDto(**alert.event) for alert in db_incident.alerts] - - unique_sources_list = list( - set([source for alert_dto in alerts_dto for source in alert_dto.source]) - ) - unique_service_list = list( - set([alert.service for alert in alerts_dto if alert.service is not None]) - ) - return cls( id=db_incident.id, name=db_incident.name, @@ -400,9 +393,9 @@ def from_db_incident(cls, db_incident): creation_time=db_incident.creation_time, start_time=db_incident.start_time, end_time=db_incident.end_time, - number_of_alerts=len(db_incident.alerts), - alert_sources=unique_sources_list, + number_of_alerts=db_incident.alerts_count, + alert_sources=db_incident.sources, severity=IncidentSeverity.CRITICAL, assignee=db_incident.assignee, - services=unique_service_list, + services=db_incident.affected_services, ) diff --git a/keep/api/models/db/alert.py b/keep/api/models/db/alert.py index 2c0072ff6..08e037959 100644 --- a/keep/api/models/db/alert.py +++ b/keep/api/models/db/alert.py @@ -120,6 +120,10 @@ class Incident(SQLModel, table=True): is_predicted: bool = Field(default=False) is_confirmed: bool = Field(default=False) + alerts_count: int = Field(default=0) + affected_services: list = Field(sa_column=Column(JSON), default_factory=list) + sources: list = Field(sa_column=Column(JSON), default_factory=list) + def __init__(self, **kwargs): super().__init__(**kwargs) if "alerts" not in kwargs: diff --git a/keep/api/models/db/migrations/env.py b/keep/api/models/db/migrations/env.py index 8e38b59b2..61c58f9ab 100644 --- a/keep/api/models/db/migrations/env.py +++ b/keep/api/models/db/migrations/env.py @@ -16,6 +16,7 @@ from keep.api.models.db.provider import * from keep.api.models.db.rule import * from keep.api.models.db.tenant import * +from keep.api.models.db.topology import * from keep.api.models.db.user import * from keep.api.models.db.workflow import * from keep.api.models.db.statistics import * diff --git a/keep/api/models/db/migrations/versions/2024-07-25-17-13_67f1efb93c99.py b/keep/api/models/db/migrations/versions/2024-07-25-17-13_67f1efb93c99.py new file mode 100644 index 000000000..8ef5c61f8 --- /dev/null +++ b/keep/api/models/db/migrations/versions/2024-07-25-17-13_67f1efb93c99.py @@ -0,0 +1,58 @@ +"""Add fields for prepopulated data from alerts + +Revision ID: 67f1efb93c99 +Revises: dcbd2873dcfd +Create Date: 2024-07-25 17:13:04.428633 + +""" + +import sqlalchemy as sa +from alembic import op +from sqlalchemy.orm import Session, joinedload + +from keep.api.models.alert import AlertDto +from keep.api.models.db.alert import Incident + +# revision identifiers, used by Alembic. 
+revision = "67f1efb93c99" +down_revision = "dcbd2873dcfd" +branch_labels = None +depends_on = None + + +def populate_db(session): + + incidents = session.query(Incident).options(joinedload(Incident.alerts)).all() + + for incident in incidents: + alerts_dto = [AlertDto(**alert.event) for alert in incident.alerts] + + incident.sources = list( + set([source for alert_dto in alerts_dto for source in alert_dto.source]) + ) + incident.affected_services = list( + set([alert.service for alert in alerts_dto if alert.service is not None]) + ) + incident.alerts_count = len(incident.alerts) + session.add(incident) + session.commit() + + +def upgrade() -> None: + # ### commands auto generated by Alembic - please adjust! ### + op.add_column("incident", sa.Column("affected_services", sa.JSON(), nullable=True)) + op.add_column("incident", sa.Column("sources", sa.JSON(), nullable=True)) + op.add_column("incident", sa.Column("alerts_count", sa.Integer(), nullable=False, server_default="0")) + + session = Session(op.get_bind()) + populate_db(session) + + # ### end Alembic commands ### + + +def downgrade() -> None: + # ### commands auto generated by Alembic - please adjust! ### + op.drop_column("incident", "alerts_count") + op.drop_column("incident", "sources") + op.drop_column("incident", "affected_services") + # ### end Alembic commands ### diff --git a/keep/api/models/db/migrations/versions/2024-07-29-18-10_92f4f93f2140.py b/keep/api/models/db/migrations/versions/2024-07-29-18-10_92f4f93f2140.py new file mode 100644 index 000000000..9287bdc35 --- /dev/null +++ b/keep/api/models/db/migrations/versions/2024-07-29-18-10_92f4f93f2140.py @@ -0,0 +1,78 @@ +"""Topology Migrations + +Revision ID: 92f4f93f2140 +Revises: dcbd2873dcfd +Create Date: 2024-07-29 18:10:37.723465 + +""" + +import sqlalchemy as sa +import sqlmodel +from alembic import op + +# revision identifiers, used by Alembic. +revision = "92f4f93f2140" +down_revision = "dcbd2873dcfd" +branch_labels = None +depends_on = None + + +def upgrade() -> None: + # ### commands auto generated by Alembic - please adjust! 
### + op.create_table( + "topologyservice", + sa.Column("tenant_id", sqlmodel.sql.sqltypes.AutoString(), nullable=True), + sa.Column("tags", sa.JSON(), nullable=True), + sa.Column( + "updated_at", + sa.DateTime(timezone=True), + server_default=sa.text("(CURRENT_TIMESTAMP)"), + nullable=True, + ), + sa.Column("id", sa.Integer(), nullable=False), + sa.Column( + "source_provider_id", sqlmodel.sql.sqltypes.AutoString(), nullable=False + ), + sa.Column("repository", sqlmodel.sql.sqltypes.AutoString(), nullable=True), + sa.Column("service", sqlmodel.sql.sqltypes.AutoString(), nullable=False), + sa.Column("environment", sqlmodel.sql.sqltypes.AutoString(), nullable=False), + sa.Column("display_name", sqlmodel.sql.sqltypes.AutoString(), nullable=False), + sa.Column("description", sqlmodel.sql.sqltypes.AutoString(), nullable=True), + sa.Column("team", sqlmodel.sql.sqltypes.AutoString(), nullable=True), + sa.Column("application", sqlmodel.sql.sqltypes.AutoString(), nullable=True), + sa.Column("email", sqlmodel.sql.sqltypes.AutoString(), nullable=True), + sa.Column("slack", sqlmodel.sql.sqltypes.AutoString(), nullable=True), + sa.ForeignKeyConstraint( + ["tenant_id"], + ["tenant.id"], + ), + sa.PrimaryKeyConstraint("id"), + ) + op.create_table( + "topologyservicedependency", + sa.Column("service_id", sa.Integer(), nullable=True), + sa.Column("depends_on_service_id", sa.Integer(), nullable=True), + sa.Column( + "updated_at", + sa.DateTime(timezone=True), + server_default=sa.text("(CURRENT_TIMESTAMP)"), + nullable=True, + ), + sa.Column("id", sa.Integer(), nullable=False), + sa.Column("protocol", sqlmodel.sql.sqltypes.AutoString(), nullable=True), + sa.ForeignKeyConstraint( + ["depends_on_service_id"], ["topologyservice.id"], ondelete="CASCADE" + ), + sa.ForeignKeyConstraint( + ["service_id"], ["topologyservice.id"], ondelete="CASCADE" + ), + sa.PrimaryKeyConstraint("id"), + ) + # ### end Alembic commands ### + + +def downgrade() -> None: + # ### commands auto generated by Alembic - please adjust! ### + op.drop_table("topologyservicedependency") + op.drop_table("topologyservice") + # ### end Alembic commands ### diff --git a/keep/api/models/db/migrations/versions/2024-08-05-13-09_4147d9e706c0.py b/keep/api/models/db/migrations/versions/2024-08-05-13-09_4147d9e706c0.py new file mode 100644 index 000000000..e2bc19350 --- /dev/null +++ b/keep/api/models/db/migrations/versions/2024-08-05-13-09_4147d9e706c0.py @@ -0,0 +1,28 @@ +"""Provider last pull time + +Revision ID: 4147d9e706c0 +Revises: 92f4f93f2140 +Create Date: 2024-08-05 13:09:18.851721 + +""" + +import sqlalchemy as sa +from alembic import op + +# revision identifiers, used by Alembic. +revision = "4147d9e706c0" +down_revision = "92f4f93f2140" +branch_labels = None +depends_on = None + + +def upgrade() -> None: + # ### commands auto generated by Alembic - please adjust! ### + op.add_column("provider", sa.Column("last_pull_time", sa.DateTime(), nullable=True)) + # ### end Alembic commands ### + + +def downgrade() -> None: + # ### commands auto generated by Alembic - please adjust! 
### + op.drop_column("provider", "last_pull_time") + # ### end Alembic commands ### diff --git a/keep/api/models/db/migrations/versions/2024-08-08-13-55_42098785763c.py b/keep/api/models/db/migrations/versions/2024-08-08-13-55_42098785763c.py new file mode 100644 index 000000000..583ae21f2 --- /dev/null +++ b/keep/api/models/db/migrations/versions/2024-08-08-13-55_42098785763c.py @@ -0,0 +1,21 @@ +"""Merging 2 heads + +Revision ID: 42098785763c +Revises: 67f1efb93c99, 4147d9e706c0 +Create Date: 2024-08-08 13:55:55.191243 + +""" + +# revision identifiers, used by Alembic. +revision = "42098785763c" +down_revision = ("67f1efb93c99", "4147d9e706c0") +branch_labels = None +depends_on = None + + +def upgrade() -> None: + pass + + +def downgrade() -> None: + pass diff --git a/keep/api/models/db/provider.py b/keep/api/models/db/provider.py index 3d3b53c5d..78ef3c68e 100644 --- a/keep/api/models/db/provider.py +++ b/keep/api/models/db/provider.py @@ -20,6 +20,7 @@ class Provider(SQLModel, table=True): sa_column=Column(JSON) ) # scope name is key and value is either True if validated or string with error message, e.g: {"read": True, "write": "error message"} consumer: bool = False + last_pull_time: Optional[datetime] class Config: orm_mode = True diff --git a/keep/api/models/db/topology.py b/keep/api/models/db/topology.py new file mode 100644 index 000000000..7b802edbd --- /dev/null +++ b/keep/api/models/db/topology.py @@ -0,0 +1,128 @@ +from datetime import datetime +from typing import List, Optional + +from pydantic import BaseModel +from sqlalchemy import DateTime, ForeignKey +from sqlmodel import JSON, Column, Field, Relationship, SQLModel, func + + +class TopologyService(SQLModel, table=True): + id: Optional[int] = Field(primary_key=True, default=None) + tenant_id: str = Field(sa_column=Column(ForeignKey("tenant.id"))) + source_provider_id: str = "unknown" + repository: Optional[str] + tags: Optional[List[str]] = Field(sa_column=Column(JSON)) + service: str + environment: str = Field(default="unknown") + display_name: str + description: Optional[str] + team: Optional[str] + application: Optional[str] + email: Optional[str] + slack: Optional[str] + updated_at: Optional[datetime] = Field( + sa_column=Column( + DateTime(timezone=True), + name="updated_at", + onupdate=func.now(), + server_default=func.now(), + ) + ) + + dependencies: List["TopologyServiceDependency"] = Relationship( + back_populates="service", + sa_relationship_kwargs={ + "foreign_keys": "[TopologyServiceDependency.service_id]" + }, + ) + + class Config: + orm_mode = True + unique_together = ["tenant_id", "service", "environment", "source_provider_id"] + + +class TopologyServiceDependency(SQLModel, table=True): + id: Optional[int] = Field(primary_key=True, default=None) + service_id: int = Field( + sa_column=Column(ForeignKey("topologyservice.id", ondelete="CASCADE")) + ) + depends_on_service_id: int = Field( + sa_column=Column(ForeignKey("topologyservice.id", ondelete="CASCADE")) + ) # service_id calls deponds_on_service_id (A->B) + protocol: Optional[str] = "unknown" + updated_at: Optional[datetime] = Field( + sa_column=Column( + DateTime(timezone=True), + name="updated_at", + onupdate=func.now(), + server_default=func.now(), + ) + ) + + service: TopologyService = Relationship( + back_populates="dependencies", + sa_relationship_kwargs={ + "foreign_keys": "[TopologyServiceDependency.service_id]" + }, + ) + dependent_service: TopologyService = Relationship( + sa_relationship_kwargs={ + "foreign_keys": 
"[TopologyServiceDependency.depends_on_service_id]" + } + ) + + +class TopologyServiceDtoBase(BaseModel, extra="ignore"): + source_provider_id: Optional[str] + repository: Optional[str] = None + tags: Optional[List[str]] = None + service: str + display_name: str + environment: str = "unknown" + description: Optional[str] = None + team: Optional[str] = None + application: Optional[str] = None + email: Optional[str] = None + slack: Optional[str] = None + + +class TopologyServiceInDto(TopologyServiceDtoBase): + dependencies: dict[str, str] = {} # dict of service it depends on : protocol + + +class TopologyServiceDependencyDto(BaseModel, extra="ignore"): + serviceId: int + serviceName: str + protocol: Optional[str] = "unknown" + + +class TopologyServiceDtoOut(TopologyServiceDtoBase): + id: int + dependencies: List[TopologyServiceDependencyDto] + updated_at: Optional[datetime] + + @classmethod + def from_orm(cls, service: "TopologyService") -> "TopologyServiceDtoOut": + return cls( + id=service.id, + source_provider_id=service.source_provider_id, + repository=service.repository, + tags=service.tags, + service=service.service, + display_name=service.display_name, + environment=service.environment, + description=service.description, + team=service.team, + application=service.application, + email=service.email, + slack=service.slack, + dependencies=[ + TopologyServiceDependencyDto( + serviceId=dep.depends_on_service_id, + protocol=dep.protocol, + serviceName=dep.dependent_service.service, + ) + for dep in service.dependencies + ], + updated_at=service.updated_at, + ) diff --git a/keep/api/models/provider.py b/keep/api/models/provider.py index a871d873f..78df4eb62 100644 --- a/keep/api/models/provider.py +++ b/keep/api/models/provider.py @@ -37,7 +37,10 @@ class Provider(BaseModel): methods: list[ProviderMethod] = [] installed_by: str | None = None installation_time: datetime | None = None + last_pull_time: datetime | None = None docs: str | None = None - tags: list[Literal["alert", "ticketing", "messaging", "data", "queue"]] = [] + tags: list[ + Literal["alert", "ticketing", "messaging", "data", "queue", "topology"] + ] = [] alertsDistribution: dict[str, int] | None = None alertExample: dict | None = None diff --git a/keep/api/routes/incidents.py b/keep/api/routes/incidents.py index 07db9073b..3120cb085 100644 --- a/keep/api/routes/incidents.py +++ b/keep/api/routes/incidents.py @@ -29,8 +29,8 @@ ) from keep.api.models.alert import AlertDto, IncidentDto, IncidentDtoIn from keep.api.utils.enrichment_helpers import convert_db_alerts_to_dto_alerts -from keep.api.utils.pagination import IncidentsPaginatedResultsDto from keep.api.utils.import_ee import mine_incidents_and_create_objects +from keep.api.utils.pagination import IncidentsPaginatedResultsDto, AlertPaginatedResultsDto router = APIRouter() logger = logging.getLogger(__name__) @@ -230,8 +230,10 @@ def delete_incident( ) def get_incident_alerts( incident_id: str, + limit: int = 25, + offset: int = 0, authenticated_entity: AuthenticatedEntity = Depends(AuthVerifier(["read:alert"])), -) -> List[AlertDto]: +) -> AlertPaginatedResultsDto: tenant_id = authenticated_entity.tenant_id logger.info( "Fetching incident", @@ -251,7 +253,12 @@ def get_incident_alerts( "tenant_id": tenant_id, }, ) - db_alerts = get_incident_alerts_by_incident_id(tenant_id, incident_id) + db_alerts, total_count = get_incident_alerts_by_incident_id( + tenant_id=tenant_id, + incident_id=incident_id, + limit=limit, + offset=offset, + ) enriched_alerts_dto = 
diff --git a/keep/api/routes/incidents.py b/keep/api/routes/incidents.py
index 07db9073b..3120cb085 100644
--- a/keep/api/routes/incidents.py
+++ b/keep/api/routes/incidents.py
@@ -29,8 +29,8 @@
 )
 from keep.api.models.alert import AlertDto, IncidentDto, IncidentDtoIn
 from keep.api.utils.enrichment_helpers import convert_db_alerts_to_dto_alerts
-from keep.api.utils.pagination import IncidentsPaginatedResultsDto
 from keep.api.utils.import_ee import mine_incidents_and_create_objects
+from keep.api.utils.pagination import IncidentsPaginatedResultsDto, AlertPaginatedResultsDto
 
 router = APIRouter()
 logger = logging.getLogger(__name__)
@@ -230,8 +230,10 @@ def delete_incident(
 )
 def get_incident_alerts(
     incident_id: str,
+    limit: int = 25,
+    offset: int = 0,
     authenticated_entity: AuthenticatedEntity = Depends(AuthVerifier(["read:alert"])),
-) -> List[AlertDto]:
+) -> AlertPaginatedResultsDto:
     tenant_id = authenticated_entity.tenant_id
     logger.info(
         "Fetching incident",
@@ -251,7 +253,12 @@
             "tenant_id": tenant_id,
         },
     )
-    db_alerts = get_incident_alerts_by_incident_id(tenant_id, incident_id)
+    db_alerts, total_count = get_incident_alerts_by_incident_id(
+        tenant_id=tenant_id,
+        incident_id=incident_id,
+        limit=limit,
+        offset=offset,
+    )
     enriched_alerts_dto = convert_db_alerts_to_dto_alerts(db_alerts)
 
     logger.info(
@@ -261,7 +268,7 @@
         },
     )
 
-    return enriched_alerts_dto
+    return AlertPaginatedResultsDto(limit=limit, offset=offset, count=total_count, items=enriched_alerts_dto)
 
 
 @router.post(
diff --git a/keep/api/routes/preset.py b/keep/api/routes/preset.py
index d12d3932f..e5fddf438 100644
--- a/keep/api/routes/preset.py
+++ b/keep/api/routes/preset.py
@@ -1,5 +1,7 @@
 import logging
+import os
 import uuid
+from datetime import datetime
 
 from fastapi import (
     APIRouter,
@@ -12,15 +14,21 @@
 from pydantic import BaseModel
 from sqlmodel import Session, select
 
-from keep.api.consts import STATIC_PRESETS
+from keep.api.consts import PROVIDER_PULL_INTERVAL_DAYS, STATIC_PRESETS
 from keep.api.core.db import get_preset_by_name as get_preset_by_name_db
 from keep.api.core.db import get_presets as get_presets_db
-from keep.api.core.db import get_session, update_preset_options
+from keep.api.core.db import (
+    get_session,
+    update_preset_options,
+    update_provider_last_pull_time,
+)
 from keep.api.core.dependencies import AuthenticatedEntity, AuthVerifier
 from keep.api.models.alert import AlertDto
 from keep.api.models.db.preset import Preset, PresetDto, PresetOption
 from keep.api.tasks.process_event_task import process_event
+from keep.api.tasks.process_topology_task import process_topology
 from keep.contextmanager.contextmanager import ContextManager
+from keep.providers.base.base_provider import BaseTopologyProvider
 from keep.providers.providers_factory import ProvidersFactory
 from keep.searchengine.searchengine import SearchEngine
 
@@ -30,7 +38,7 @@
 
 # SHAHAR: this function runs as a background task in a separate thread
 # DO NOT ADD async HERE as it will run in the main thread and block the whole server
-def pull_alerts_from_providers(
+def pull_data_from_providers(
    tenant_id: str,
    trace_id: str,
 ) -> list[AlertDto]:
@@ -39,29 +47,72 @@
     """
     "Get or create logics".
""" + if os.environ.get("KEEP_PULL_DATA_ENABLED", "true") != "true": + logger.debug("Pull data from providers is disabled") + return + context_manager = ContextManager( tenant_id=tenant_id, workflow_id=None, ) for provider in ProvidersFactory.get_installed_providers(tenant_id=tenant_id): + extra = { + "provider_type": provider.type, + "provider_id": provider.id, + "tenant_id": tenant_id, + } + + if provider.last_pull_time is not None: + now = datetime.now() + days_passed = (now - provider.last_pull_time).days + if days_passed <= PROVIDER_PULL_INTERVAL_DAYS: + logger.info( + "Skipping provider data pulling since not enough time has passed", + extra={ + **extra, + "days_passed": days_passed, + "provider_last_pull_time": str(provider.last_pull_time), + }, + ) + continue + provider_class = ProvidersFactory.get_provider( context_manager=context_manager, provider_id=provider.id, provider_type=provider.type, provider_config=provider.details, ) + logger.info( f"Pulling alerts from provider {provider.type} ({provider.id})", - extra={ - "provider_type": provider.type, - "provider_id": provider.id, - "tenant_id": tenant_id, - }, + extra=extra, ) sorted_provider_alerts_by_fingerprint = ( provider_class.get_alerts_by_fingerprint(tenant_id=tenant_id) ) + + try: + if isinstance(provider_class, BaseTopologyProvider): + logger.info("Getting topology data", extra=extra) + topology_data = provider_class.pull_topology() + logger.info("Got topology data, processing", extra=extra) + process_topology(tenant_id, topology_data, provider.id) + logger.info("Processed topology data", extra=extra) + except NotImplementedError: + logger.warning( + f"Provider {provider.type} ({provider.id}) does not support topology data", + extra=extra, + ) + except Exception: + logger.error( + f"Unknown error pulling topology from provider {provider.type} ({provider.id})", + extra=extra, + ) + + # Even if we failed at processing some event, lets save the last pull time to not iterate this process over and over again. + update_provider_last_pull_time(tenant_id=tenant_id, provider_id=provider.id) + for fingerprint, alert in sorted_provider_alerts_by_fingerprint.items(): process_event( {}, @@ -215,7 +266,7 @@ async def get_preset_alerts( # In the worst case, gathered alerts will be pulled in the next request. 
diff --git a/keep/api/routes/providers.py b/keep/api/routes/providers.py
index 30abdb5f2..8618d878c 100644
--- a/keep/api/routes/providers.py
+++ b/keep/api/routes/providers.py
@@ -13,7 +13,7 @@
 from starlette.datastructures import UploadFile
 
 from keep.api.core.config import config
-from keep.api.core.db import get_provider_distribution, get_session, count_alerts
+from keep.api.core.db import count_alerts, get_provider_distribution, get_session
 from keep.api.core.dependencies import AuthenticatedEntity, AuthVerifier
 from keep.api.models.db.provider import Provider
 from keep.api.models.provider import ProviderAlertsCountResponseDTO
@@ -673,9 +673,7 @@ async def install_provider_oauth2(
 ):
     tenant_id = authenticated_entity.tenant_id
     installed_by = authenticated_entity.email
-    # Extract parameters from the provider_info dictionary
-    provider_name = f"{provider_type}-oauth2"
     provider_unique_id = uuid.uuid4().hex
 
     logger.info(
         "Installing provider",
@@ -688,6 +686,10 @@
     try:
         provider_class = ProvidersFactory.get_provider_class(provider_type)
         provider_info = provider_class.oauth2_logic(**provider_info)
+        provider_name = provider_info.pop(
+            "provider_name", f"{provider_unique_id}-oauth2"
+        )
+        provider_name = provider_name.lower().replace(" ", "").replace("_", "-")
         provider_config = {
             "authentication": provider_info,
             "name": provider_name,
diff --git a/keep/api/routes/topology.py b/keep/api/routes/topology.py
new file mode 100644
index 000000000..a8c0e56bd
--- /dev/null
+++ b/keep/api/routes/topology.py
@@ -0,0 +1,51 @@
+import logging
+from typing import List, Optional
+
+from fastapi import APIRouter, Depends, HTTPException
+
+from keep.api.core.db import (  # Assuming this function exists to fetch topology data
+    get_all_topology_data,
+)
+from keep.api.core.dependencies import AuthenticatedEntity, AuthVerifier
+from keep.api.models.db.topology import TopologyServiceDtoOut
+
+logger = logging.getLogger(__name__)
+router = APIRouter()
+
+
+# GET all topology data
+@router.get(
+    "", description="Get all topology data", response_model=List[TopologyServiceDtoOut]
+)
+def get_topology_data(
+    provider_id: Optional[str] = None,
+    service_id: Optional[str] = None,
+    environment: Optional[str] = None,
+    authenticated_entity: AuthenticatedEntity = Depends(
+        AuthVerifier(["read:topology"])
+    ),
+) -> List[TopologyServiceDtoOut]:
+    tenant_id = authenticated_entity.tenant_id
+    logger.info("Getting topology data", extra={"tenant_id": tenant_id})
+
+    # @tb: although we expect all, we just take service_id for now.
+    # Check out the `get_all_topology_data` function in db.py for more details
+    # if (
+    #     provider_id is not None or service_id is not None or environment is not None
+    # ) and not (provider_id and service_id and environment):
+    #     raise HTTPException(
+    #         status_code=400,
+    #         detail="If any of provider_id, service_id, or environment are provided, all must be provided.",
+    #     )
+
+    try:
+        topology_data = get_all_topology_data(
+            tenant_id, provider_id, service_id, environment
+        )
+        return topology_data
+    except Exception:
+        logger.exception("Failed to get topology data")
+        raise HTTPException(
+            status_code=400,
+            detail="Unknown exception when getting topology data, please contact us",
+        )
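# --- Illustration (not part of the diff): a hedged sketch of calling the new
# endpoint above. The mount path (/topology), port, and auth header are
# assumptions based on how other Keep routers are exposed, not something this
# diff pins down.
import requests

resp = requests.get(
    "http://localhost:8080/topology",          # assumed mount point of this router
    headers={"x-api-key": "<your-api-key>"},   # the key must carry the read:topology scope
    params={"service_id": "checkout"},         # optional filter; per the comment above, only service_id is honored for now
)
resp.raise_for_status()
for svc in resp.json():  # a list of serialized TopologyServiceDtoOut objects
    deps = ", ".join(d["serviceName"] for d in svc["dependencies"])
    print(f'{svc["service"]} -> [{deps}]')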
diff --git a/keep/api/tasks/process_topology_task.py b/keep/api/tasks/process_topology_task.py
new file mode 100644
index 000000000..37788cbb0
--- /dev/null
+++ b/keep/api/tasks/process_topology_task.py
@@ -0,0 +1,93 @@
+import copy
+import logging
+
+from keep.api.core.db import get_session_sync
+from keep.api.models.db.topology import (
+    TopologyService,
+    TopologyServiceDependency,
+    TopologyServiceInDto,
+)
+
+logger = logging.getLogger(__name__)
+
+TIMES_TO_RETRY_JOB = 5  # the number of times to retry the job in case of failure
+
+
+def process_topology(
+    tenant_id: str, topology_data: list[TopologyServiceInDto], provider_id: str
+):
+    extra = {"provider_id": provider_id, "tenant_id": tenant_id}
+    if not topology_data:
+        logger.info(
+            "No topology data to process",
+            extra=extra,
+        )
+        return
+
+    logger.info("Processing topology data", extra=extra)
+    session = get_session_sync()
+
+    try:
+        logger.info(
+            "Deleting existing topology data",
+            extra=extra,
+        )
+        session.query(TopologyService).filter(
+            TopologyService.source_provider_id == provider_id,
+            TopologyService.tenant_id == tenant_id,
+        ).delete()
+        session.commit()
+        logger.info(
+            "Deleted existing topology data",
+            extra=extra,
+        )
+    except Exception:
+        logger.exception(
+            "Failed to delete existing topology data",
+            extra=extra,
+        )
+        raise
+
+    logger.info(
+        "Creating new topology data",
+        extra={"provider_id": provider_id, "tenant_id": tenant_id},
+    )
+    service_to_keep_service_id_map = {}
+    # First create the services so we have ids
+    for service in topology_data:
+        service_copy = copy.deepcopy(service.dict())
+        service_copy.pop("dependencies")
+        db_service = TopologyService(**service_copy, tenant_id=tenant_id)
+        session.add(db_service)
+        session.flush()
+        service_to_keep_service_id_map[service.service] = db_service.id
+
+    # Then create the dependencies
+    for service in topology_data:
+        for dependency in service.dependencies:
+            session.add(
+                TopologyServiceDependency(
+                    service_id=service_to_keep_service_id_map[service.service],
+                    depends_on_service_id=service_to_keep_service_id_map[dependency],
+                    protocol=service.dependencies[dependency],
+                )
+            )
+
+    session.commit()
+
+    try:
+        session.close()
+    except Exception as e:
+        logger.warning(
+            "Failed to close session",
+            extra={**extra, "error": str(e)},
+        )
+
+    logger.info(
+        "Created new topology data",
+        extra=extra,
+    )
+
+
+async def async_process_topology(*args, **kwargs):
+    return process_topology(*args, **kwargs)
diff --git a/keep/api/utils/pagination.py b/keep/api/utils/pagination.py
index cac63e905..ba33c5256 100644
--- a/keep/api/utils/pagination.py
+++ b/keep/api/utils/pagination.py
@@ -2,7 +2,7 @@
 
 from pydantic import BaseModel
 
-from keep.api.models.alert import IncidentDto
+from keep.api.models.alert import IncidentDto, AlertDto
 
 
 class PaginatedResultsDto(BaseModel):
@@ -14,3 +14,7 @@ class PaginatedResultsDto(BaseModel):
 
 class IncidentsPaginatedResultsDto(PaginatedResultsDto):
     items: list[IncidentDto]
+
+
+class AlertPaginatedResultsDto(PaginatedResultsDto):
+    items: list[AlertDto]
diff --git a/keep/providers/base/base_provider.py b/keep/providers/base/base_provider.py
index 5d0de6661..c004fc475 100644
--- a/keep/providers/base/base_provider.py
+++ b/keep/providers/base/base_provider.py
@@ -22,6 +22,7 @@
 from keep.api.core.db import get_enrichments
 from keep.api.models.alert import AlertDto, AlertSeverity, AlertStatus
 from keep.api.models.db.alert import AlertActionType
+from keep.api.models.db.topology import TopologyServiceInDto
 from keep.api.utils.enrichment_helpers import parse_and_enrich_deleted_and_assignees
 from keep.contextmanager.contextmanager import ContextManager
 from keep.providers.models.provider_config import ProviderConfig, ProviderScope
@@ -135,6 +136,7 @@ def _enrich_alert(self, enrichments, results):
         """
         self.logger.debug("Extracting the fingerprint from the alert")
+        event = None
         if "fingerprint" in results:
             fingerprint = results["fingerprint"]
         elif self.context_manager.foreach_context.get("value", {}):
@@ -144,11 +146,13 @@
             if isinstance(foreach_context, tuple):
                 # This is when we are in a foreach context that is zipped
                 foreach_context: dict = foreach_context[0]
+            event = foreach_context
             fingerprint = foreach_context.get("fingerprint")
         # else, if we are in an event context, use the event fingerprint
         elif self.context_manager.event_context:
             # TODO: map all cases event_context is dict and update them to the DTO
             # and remove this if statement
+            event = self.context_manager.event_context
             if isinstance(self.context_manager.event_context, dict):
                 fingerprint = self.context_manager.event_context.get("fingerprint")
             # Alert DTO
@@ -169,15 +173,20 @@
         # enrich only the requested fields
         for enrichment in enrichments:
             try:
-                if enrichment["value"].startswith("results."):
+                value = enrichment["value"]
+                if value.startswith("results."):
                     val = enrichment["value"].replace("results.", "")
                     parts = val.split(".")
                     r = copy.copy(results)
                     for part in parts:
                         r = r[part]
-                    _enrichments[enrichment["key"]] = r
-                else:
-                    _enrichments[enrichment["key"]] = enrichment["value"]
+                    value = r
+                _enrichments[enrichment["key"]] = value
+                if event is not None:
+                    if isinstance(event, dict):
+                        event[enrichment["key"]] = value
+                    else:
+                        setattr(event, enrichment["key"], value)
             except Exception:
                 self.logger.error(
                     f"Failed to enrich alert - enrichment: {enrichment}",
@@ -583,3 +592,8 @@
     simulated_alert = alert_data["payload"].copy()
 
     return simulated_alert
+
+
+class BaseTopologyProvider(BaseProvider):
+    def pull_topology(self) -> list[TopologyServiceInDto]:
+        raise NotImplementedError("pull_topology() method not implemented")
diff --git a/keep/providers/datadog_provider/datadog_provider.py b/keep/providers/datadog_provider/datadog_provider.py
index b97478555..f0232fdf6 100644
--- a/keep/providers/datadog_provider/datadog_provider.py
+++ b/keep/providers/datadog_provider/datadog_provider.py
@@ -1,9 +1,11 @@
 """
 Datadog Provider is a class that allows to ingest/digest data from Datadog.
""" + import dataclasses import datetime import json +import logging import os import time from typing import Optional @@ -26,10 +28,12 @@ from datadog_api_client.v1.model.monitor_options import MonitorOptions from datadog_api_client.v1.model.monitor_thresholds import MonitorThresholds from datadog_api_client.v1.model.monitor_type import MonitorType +from datadog_api_client.v2.api.service_definition_api import ServiceDefinitionApi from keep.api.models.alert import AlertDto, AlertSeverity, AlertStatus +from keep.api.models.db.topology import TopologyServiceInDto from keep.contextmanager.contextmanager import ContextManager -from keep.providers.base.base_provider import BaseProvider +from keep.providers.base.base_provider import BaseTopologyProvider from keep.providers.base.provider_exceptions import GetAlertException from keep.providers.datadog_provider.datadog_alert_format_description import ( DatadogAlertFormatDescription, @@ -38,6 +42,8 @@ from keep.providers.models.provider_method import ProviderMethod from keep.providers.providers_factory import ProvidersFactory +logger = logging.getLogger(__name__) + @pydantic.dataclasses.dataclass class DatadogProviderAuthConfig: @@ -74,6 +80,15 @@ class DatadogProviderAuthConfig: }, default="https://api.datadoghq.com", ) + environment: str = dataclasses.field( + metadata={ + "required": False, + "description": "Topology environment name", + "sensitive": False, + "hint": "Defaults to *", + }, + default="*", + ) oauth_token: dict = dataclasses.field( metadata={ "description": "For OAuth flow", @@ -85,7 +100,7 @@ class DatadogProviderAuthConfig: ) -class DatadogProvider(BaseProvider): +class DatadogProvider(BaseTopologyProvider): """Pull/push alerts from Datadog.""" OAUTH2_URL = os.environ.get("DATADOG_OAUTH2_URL") @@ -133,6 +148,18 @@ class DatadogProvider(BaseProvider): mandatory=False, alias="Logs Read Data", ), + ProviderScope( + name="apm_read", + description="Read APM data for Topology creation.", + mandatory=False, + alias="Read APM Data", + ), + ProviderScope( + name="apm_service_catalog_read", + description="Read APM service catalog for Topology creation.", + mandatory=False, + alias="Read APM service catalog Data", + ), ] PROVIDER_METHODS = [ ProviderMethod( @@ -201,12 +228,12 @@ def __init__( super().__init__(context_manager, provider_id, config) self.configuration = Configuration(request_timeout=5) if self.authentication_config.api_key and self.authentication_config.app_key: - self.configuration.api_key[ - "apiKeyAuth" - ] = self.authentication_config.api_key - self.configuration.api_key[ - "appKeyAuth" - ] = self.authentication_config.app_key + self.configuration.api_key["apiKeyAuth"] = ( + self.authentication_config.api_key + ) + self.configuration.api_key["appKeyAuth"] = ( + self.authentication_config.app_key + ) domain = self.authentication_config.domain or "https://api.datadoghq.com" self.configuration.host = domain elif self.authentication_config.oauth_token: @@ -492,6 +519,16 @@ def validate_scopes(self): api.list_events( start=int(start.timestamp()), end=int(end.timestamp()) ) + elif scope.name == "apm_read": + api_instance = ServiceDefinitionApi(api_client) + api_instance.list_service_definitions(schema_version="v1") + elif scope.name == "apm_service_catalog_read": + endpoint = self.__get_service_deps_endpoint(api_client) + epoch_time_one_year_ago = self.__get_epoch_one_year_ago() + endpoint.call_with_http_info( + env=self.authentication_config.environment, + start=str(epoch_time_one_year_ago), + ) except ApiException as e: # API 
failed and it means we're probably lacking some permissions # perhaps we should check if status code is 403 and otherwise mark as valid? @@ -627,9 +664,11 @@ def _get_alerts(self) -> list[AlertDto]: tags=tags, environment=tags.get("environment", "undefined"), service=tags.get("service"), - created_by=monitor.creator.email - if monitor and monitor.creator - else None, + created_by=( + monitor.creator.email + if monitor and monitor.creator + else None + ), ) alert.fingerprint = self.get_alert_fingerprint( alert, self.fingerprint_fields @@ -754,7 +793,15 @@ def _format_alert( ) -> AlertDto: tags_list = event.get("tags", "").split(",") tags_list.remove("monitor") - tags = {k: v for k, v in map(lambda tag: tag.split(":"), tags_list)} + + try: + tags = {k: v for k, v in map(lambda tag: tag.split(":"), tags_list)} + except Exception as e: + logger.error( + "Failed to parse tags", extra={"error": str(e), "tags": tags_list} + ) + tags = {} + event_time = datetime.datetime.fromtimestamp( int(event.get("last_updated")) / 1000, tz=datetime.timezone.utc ) @@ -766,6 +813,7 @@ def _format_alert( severity = DatadogProvider.SEVERITIES_MAP.get( event.get("severity"), AlertSeverity.INFO ) + service = tags.get("service") url = event.pop("url", None) @@ -785,6 +833,7 @@ def _format_alert( message=event.get("body"), groups=groups, severity=severity, + service=service, url=url, tags=tags, monitor_id=event.get("monitor_id"), @@ -820,6 +869,47 @@ def get_logs(self, limit: int = 5) -> list: def get_alert_schema(): return DatadogAlertFormatDescription.schema() + @staticmethod + def __get_epoch_one_year_ago() -> int: + # Get the current time + current_time = datetime.datetime.now() + + # Calculate the time one year ago + one_year_ago = current_time - datetime.timedelta(days=365) + + # Convert the time one year ago to epoch time + return int(time.mktime(one_year_ago.timetuple())) + + @staticmethod + def __get_service_deps_endpoint(api_client) -> Endpoint: + return Endpoint( + settings={ + "auth": ["apiKeyAuth", "appKeyAuth", "AuthZ"], + "endpoint_path": "/api/v1/service_dependencies", + "response_type": (dict,), + "http_method": "GET", + "operation_id": "get_service_dependencies", + "version": "v1", + }, + params_map={ + "start": { + "openapi_types": (str,), + "attribute": "start", + "location": "query", + }, + "env": { + "openapi_types": (str,), + "attribute": "env", + "location": "query", + }, + }, + headers_map={ + "accept": ["application/json"], + "content_type": ["application/json"], + }, + api_client=api_client, + ) + @classmethod def simulate_alert(cls) -> dict: # Choose a random alert type @@ -854,6 +944,55 @@ def simulate_alert(cls) -> dict: ).hexdigest() return simulated_alert + def pull_topology(self) -> list[TopologyServiceInDto]: + services = {} + with ApiClient(self.configuration) as api_client: + api_instance = ServiceDefinitionApi(api_client) + service_definitions = api_instance.list_service_definitions( + schema_version="v1" + ) + epoch_time_one_year_ago = self.__get_epoch_one_year_ago() + endpoint = self.__get_service_deps_endpoint(api_client) + service_dependencies = endpoint.call_with_http_info( + env=self.authentication_config.environment, + start=str(epoch_time_one_year_ago), + ) + + # Parse data + environment = self.authentication_config.environment + if environment == "*": + environment = "unknown" + for service_definition in service_definitions.data: + name = service_definition.attributes.schema.info.dd_service + services[name] = TopologyServiceInDto( + source_provider_id=self.provider_id, + 
repository=service_definition.attributes.schema.integrations.github, + tags=service_definition.attributes.schema.tags, + service=name, + display_name=service_definition.attributes.schema.info.display_name, + environment=environment, + description=service_definition.attributes.schema.info.description, + team=service_definition.attributes.schema.org.team, + application=service_definition.attributes.schema.org.application, + email=service_definition.attributes.schema.contact.email, + slack=service_definition.attributes.schema.contact.slack, + ) + for service_dep in service_dependencies: + service = services.get(service_dep) + if not service: + service = TopologyServiceInDto( + source_provider_id=self.provider_id, + service=service_dep, + display_name=service_dep, + environment=environment, + ) + dependencies = service_dependencies[service_dep].get("calls", []) + service.dependencies = { + dependency: "unknown" for dependency in dependencies + } + services[service_dep] = service + return list(services.values()) + if __name__ == "__main__": # Output debug messages @@ -873,11 +1012,11 @@ def simulate_alert(cls) -> dict: provider_config = { "authentication": {"api_key": api_key, "app_key": app_key}, } - provider = ProvidersFactory.get_provider( + provider: BaseTopologyProvider = ProvidersFactory.get_provider( context_manager=context_manager, provider_id="datadog-keephq", provider_type="datadog", provider_config=provider_config, ) - result = provider._get_alerts() + result = provider.pull_topology() print(result) diff --git a/keep/providers/datadog_provider/topology_mock.py b/keep/providers/datadog_provider/topology_mock.py new file mode 100644 index 000000000..8ec5e443c --- /dev/null +++ b/keep/providers/datadog_provider/topology_mock.py @@ -0,0 +1,40 @@ +import json + +from keep.api.models.db.topology import TopologyServiceInDto +from keep.api.tasks.process_topology_task import process_topology + +if __name__ == "__main__": + services = {} + environment = "production" + with open("/tmp/service_definitions.json", "r") as file: + service_definitions = json.load(file) + with open("/tmp/service_dependencies.json", "r") as file: + service_dependencies = json.load(file) + for service_definition in service_definitions["data"]: + name = service_definition["attributes"]["schema"].get("dd-service") + services[name] = TopologyServiceInDto( + source_provider_id="datadog", + repository=service_definition["attributes"]["schema"]["integrations"].get( + "github" + ), + tags=service_definition["attributes"]["schema"].get("tags"), + service=name, + display_name=name, + environment=environment, + ) + for service_dep in service_dependencies: + service = services.get(service_dep) + if not service: + service = TopologyServiceInDto( + source_provider_id="datadog", + service=service_dep, + display_name=service_dep, + environment=environment, + ) + dependencies = service_dependencies[service_dep].get("calls", []) + service.dependencies = {dependency: "unknown" for dependency in dependencies} + services[service_dep] = service + topology_data = list(services.values()) + print(topology_data) + + process_topology("keep", topology_data, "datadog") diff --git a/keep/providers/openobserve_provider/openobserve_provider.py b/keep/providers/openobserve_provider/openobserve_provider.py index 7eefa6a81..ac264b88d 100644 --- a/keep/providers/openobserve_provider/openobserve_provider.py +++ b/keep/providers/openobserve_provider/openobserve_provider.py @@ -419,6 +419,35 @@ def _format_alert( "Calculating fingerprint fields", 
extra={"fingerprint_fields": fingerprint_fields}, ) + + # sort the fields to ensure the fingerprint is consistent + # for e.g. host1, host2 is the same as host2, host1 + for field in fingerprint_fields: + try: + field_attr = getattr(alert_dto, field) + if "," not in field_attr: + continue + # sort it lexographically + logger.info( + "Sorting field attributes", + extra={"field": field, "field_attr": field_attr}, + ) + sorted_field_attr = sorted(field_attr.replace(" ", "").split(",")) + sorted_field_attr = ", ".join(sorted_field_attr) + logger.info( + "Sorted field attributes", + extra={"field": field, "sorted_field_attr": sorted_field_attr}, + ) + # set the attr + setattr(alert_dto, field, sorted_field_attr) + except AttributeError: + pass + except Exception as e: + logger.error( + "Error while sorting field attributes", + extra={"field": field, "error": e}, + ) + alert_dto.fingerprint = OpenobserveProvider.get_alert_fingerprint( alert_dto, fingerprint_fields=fingerprint_fields ) diff --git a/keep/providers/prometheus_provider/alerts_mock.py b/keep/providers/prometheus_provider/alerts_mock.py index c3c06af0f..1287f1a68 100644 --- a/keep/providers/prometheus_provider/alerts_mock.py +++ b/keep/providers/prometheus_provider/alerts_mock.py @@ -11,6 +11,7 @@ }, "parameters": { "labels.host": ["host1", "host2", "host3"], + "labels.service": ["calendar-producer-java-otel-api-dd", "kafka"], "labels.instance": ["instance1", "instance2", "instance3"], }, }, @@ -23,6 +24,7 @@ }, "parameters": { "labels.queue": ["queue1", "queue2", "queue3"], + "labels.service": ["calendar-producer-java-otel-api-dd", "kafka"], "labels.mq_manager": ["mq_manager1", "mq_manager2", "mq_manager3"], }, }, @@ -35,6 +37,7 @@ }, "parameters": { "labels.host": ["host1", "host2", "host3"], + "labels.service": ["calendar-producer-java-otel-api-dd", "kafka"], "labels.instance": ["instance1", "instance2", "instance3"], }, }, @@ -47,6 +50,7 @@ }, "parameters": { "labels.host": ["host1", "host2", "host3"], + "labels.service": ["calendar-producer-java-otel-api-dd", "kafka"], "labels.instance": ["instance1", "instance2", "instance3"], }, }, diff --git a/keep/providers/prometheus_provider/prometheus_provider.py b/keep/providers/prometheus_provider/prometheus_provider.py index 21b3201b3..7f4ecf0c8 100644 --- a/keep/providers/prometheus_provider/prometheus_provider.py +++ b/keep/providers/prometheus_provider/prometheus_provider.py @@ -160,7 +160,8 @@ def get_status(event: dict) -> AlertStatus: @staticmethod def _format_alert( - event: dict | list[AlertDto], provider_instance: Optional["PrometheusProvider"] = None + event: dict | list[AlertDto], + provider_instance: Optional["PrometheusProvider"] = None, ) -> list[AlertDto]: # TODO: need to support more than 1 alert per event alert_dtos = [] @@ -179,6 +180,7 @@ def _format_alert( annotations = { k.lower(): v for k, v in alert.pop("annotations", {}).items() } + service = labels.get("service", annotations.get("service", None)) # map severity and status to keep's format status = alert.pop("state", None) or alert.pop("status", None) status = PrometheusProvider.STATUS_MAP.get(status, AlertStatus.FIRING) @@ -190,6 +192,7 @@ def _format_alert( name=alert_id, description=description, status=status, + service=service, lastReceived=datetime.datetime.now( tz=datetime.timezone.utc ).isoformat(), diff --git a/keep/providers/providers_factory.py b/keep/providers/providers_factory.py index 29a60170e..bb80d4ca9 100644 --- a/keep/providers/providers_factory.py +++ b/keep/providers/providers_factory.py @@ 
-21,7 +21,7 @@ ) from keep.api.models.provider import Provider from keep.contextmanager.contextmanager import ContextManager -from keep.providers.base.base_provider import BaseProvider +from keep.providers.base.base_provider import BaseProvider, BaseTopologyProvider from keep.providers.models.provider_config import ProviderConfig from keep.providers.models.provider_method import ProviderMethodDTO, ProviderMethodParam from keep.secretmanager.secretmanagerfactory import SecretManagerFactory @@ -224,7 +224,7 @@ def get_all_providers() -> list[Provider]: logger = logging.getLogger(__name__) # use the cache if exists if ProvidersFactory._loaded_providers_cache: - logger.info("Using cached providers") + logger.debug("Using cached providers") return ProvidersFactory._loaded_providers_cache logger.info("Loading providers") @@ -312,9 +312,12 @@ def get_all_providers() -> list[Provider]: ) oauth2_url = provider_class.__dict__.get("OAUTH2_URL") docs = provider_class.__doc__ + can_fetch_topology = issubclass(provider_class, BaseTopologyProvider) provider_tags = [] provider_tags.extend(provider_class.PROVIDER_TAGS) + if can_fetch_topology: + provider_tags.append("topology") if can_query and "data" not in provider_tags: provider_tags.append("data") if ( @@ -397,6 +400,7 @@ def get_installed_providers( provider_copy.id = p.id provider_copy.installed_by = p.installed_by provider_copy.installation_time = p.installation_time + provider_copy.last_pull_time = p.last_pull_time try: provider_auth = {"name": p.name} if include_details: diff --git a/keep/providers/slack_provider/slack_provider.py b/keep/providers/slack_provider/slack_provider.py index 1684a7519..87d2ae608 100644 --- a/keep/providers/slack_provider/slack_provider.py +++ b/keep/providers/slack_provider/slack_provider.py @@ -68,7 +68,7 @@ def dispose(self): pass @staticmethod - def oauth2_logic(**payload): + def oauth2_logic(**payload) -> dict: """ Logic for handling oauth2 callback. @@ -95,7 +95,13 @@ def oauth2_logic(**payload): raise Exception( response_json.get("error"), ) - return {"access_token": response_json.get("access_token")} + new_provider_info = {"access_token": response_json.get("access_token")} + + team_name = response_json.get("team", {}).get("name") + if team_name: + new_provider_info["provider_name"] = team_name + + return new_provider_info def _notify(self, message="", blocks=[], channel="", **kwargs: dict): """ diff --git a/keep/workflowmanager/workflowmanager.py b/keep/workflowmanager/workflowmanager.py index e64f5671f..82502aa72 100644 --- a/keep/workflowmanager/workflowmanager.py +++ b/keep/workflowmanager/workflowmanager.py @@ -48,7 +48,7 @@ def stop(self): def _apply_filter(self, filter_val, value): # if it's a regex, apply it - if filter_val.startswith('r"'): + if isinstance(filter_val, str) and filter_val.startswith('r"'): try: # remove the r" and the last " pattern = re.compile(filter_val[2:-1]) @@ -60,6 +60,9 @@ def _apply_filter(self, filter_val, value): ) return False else: + # For cases like `dismissed` + if isinstance(filter_val, bool) and isinstance(value, str): + return value == str(filter_val) return value == filter_val def insert_events(self, tenant_id, events: typing.List[AlertDto]): @@ -104,9 +107,15 @@ def insert_events(self, tenant_id, events: typing.List[AlertDto]): filter_key = filter.get("key") filter_val = filter.get("value") event_val = self._get_event_value(event, filter_key) - if not event_val: + if event_val is None: self.logger.warning( - "Failed to run filter, skipping the event. 
Probably misconfigured workflow." + "Failed to run filter, skipping the event. Probably misconfigured workflow.", + extra={ + "tenant_id": tenant_id, + "filter_key": filter_key, + "filter_val": filter_val, + "workflow_id": workflow_model.id, + }, ) should_run = False continue @@ -119,11 +128,7 @@ def insert_events(self, tenant_id, events: typing.List[AlertDto]): break should_run = False # elif the filter is string/int/float, compare them: - elif type(event_val) in [ - int, - str, - float, - ]: + elif type(event_val) in [int, str, float, bool]: if not self._apply_filter(filter_val, event_val): self.logger.debug( "Filter didn't match, skipping", diff --git a/keep/workflowmanager/workflowscheduler.py b/keep/workflowmanager/workflowscheduler.py index f5cdc4f54..1a98c3a92 100644 --- a/keep/workflowmanager/workflowscheduler.py +++ b/keep/workflowmanager/workflowscheduler.py @@ -323,6 +323,9 @@ def _handle_event_workflows(self): # In manual, we create the workflow execution id sync so it could be tracked by the caller (UI) # In event (e.g. alarm), we will create it here if not workflow_execution_id: + # creating the execution id here to be able to trace it in logs even in case of IntegrityError + # eventually, workflow_execution_id == execution_id + execution_id = str(uuid.uuid4()) try: # if the workflow can run in parallel, we just to create a some random execution number if workflow.workflow_strategy == WorkflowStrategy.PARALLEL.value: @@ -339,6 +342,7 @@ def _handle_event_workflows(self): execution_number=workflow_execution_number, fingerprint=event.fingerprint, event_id=event.event_id, + execution_id=execution_id, ) # If there is already running workflow from the same event except IntegrityError: @@ -348,7 +352,11 @@ def _handle_event_workflows(self): == WorkflowStrategy.NONPARALLEL_WITH_RETRY.value ): self.logger.info( - "Collision with workflow execution! will retry next time" + "Collision with workflow execution! will retry next time", + extra={ + "workflow_id": workflow_id, + "tenant_id": tenant_id, + }, ) with self.lock: self.workflows_to_run.append( @@ -367,7 +375,11 @@ def _handle_event_workflows(self): workflow.workflow_strategy == WorkflowStrategy.NONPARALLEL.value ): self.logger.error( - "Collision with workflow execution! will not retry" + "Collision with workflow execution! 
will not retry", + extra={ + "workflow_id": workflow_id, + "tenant_id": tenant_id, + }, ) self._finish_workflow_execution( tenant_id=tenant_id, @@ -394,16 +406,40 @@ def _handle_event_workflows(self): # and will trigger a workflow that will update the ticket with "resolved" if workflow_to_run.get("retry", False): try: - self.logger.info("Updating enrichment") - new_enrichment = get_enrichment(tenant_id, event.fingerprint) + self.logger.info( + "Updating enrichments for workflow after retry", + extra={ + "workflow_id": workflow_id, + "workflow_execution_id": workflow_execution_id, + "tenant_id": tenant_id, + }, + ) + new_enrichment = get_enrichment( + tenant_id, event.fingerprint, refresh=True + ) # merge the new enrichment with the original event if new_enrichment: new_event = event.dict() new_event.update(new_enrichment.enrichments) event = AlertDto(**new_event) - self.logger.info("Enrichment updated") + self.logger.info( + "Enrichments updated for workflow after retry", + extra={ + "workflow_id": workflow_id, + "workflow_execution_id": workflow_execution_id, + "tenant_id": tenant_id, + "new_enrichment": new_enrichment, + }, + ) except Exception as e: - self.logger.error(f"Failed to get enrichment: {e}") + self.logger.error( + f"Failed to get enrichment: {e}", + extra={ + "workflow_id": workflow_id, + "workflow_execution_id": workflow_execution_id, + "tenant_id": tenant_id, + }, + ) self._finish_workflow_execution( tenant_id=tenant_id, workflow_id=workflow_id, diff --git a/tests/conftest.py b/tests/conftest.py index 29f6c0b89..53a0bcf34 100644 --- a/tests/conftest.py +++ b/tests/conftest.py @@ -368,40 +368,51 @@ def setup_alerts(elastic_client, db_session, request): @pytest.fixture -def setup_stress_alerts(elastic_client, db_session, request): +def setup_stress_alerts_no_elastic(db_session): + + def _setup_stress_alerts_no_elastic(num_alerts): + alert_details = [ + { + "source": [ + "source_{}".format(i % 10) + ], # Cycle through 10 different sources + "service": "service_{}".format(i % 10), # Random of 10 different services + "severity": random.choice( + ["info", "warning", "critical"] + ), # Alternate between 'critical' and 'warning' + "fingerprint": f"test-{i}", + } + for i in range(num_alerts) + ] + alerts = [] + for i, detail in enumerate(alert_details): + random_timestamp = datetime.utcnow() - timedelta(days=random.uniform(0, 7)) + alerts.append( + Alert( + timestamp=random_timestamp, + tenant_id=SINGLE_TENANT_UUID, + provider_type=detail['source'][0], + provider_id="test_{}".format( + i % 5 + ), # Cycle through 5 different provider_ids + event=_create_valid_event(detail, lastReceived=random_timestamp), + fingerprint="fingerprint_{}".format(i), + ) + ) + db_session.add_all(alerts) + db_session.commit() + + return alerts + return _setup_stress_alerts_no_elastic + + +@pytest.fixture +def setup_stress_alerts(elastic_client, db_session, request, setup_stress_alerts_no_elastic): num_alerts = request.param.get( "num_alerts", 1000 ) # Default to 1000 alerts if not specified - alert_details = [ - { - "source": [ - "source_{}".format(i % 10) - ], # Cycle through 10 different sources - "severity": random.choice( - ["info", "warning", "critical"] - ), # Alternate between 'critical' and 'warning' - "fingerprint": f"test-{i}", - } - for i in range(num_alerts) - ] - alerts = [] - for i, detail in enumerate(alert_details): - random_timestamp = datetime.utcnow() - timedelta(days=random.uniform(0, 7)) - alerts.append( - Alert( - timestamp=random_timestamp, - tenant_id=SINGLE_TENANT_UUID, - 
provider_type="test", - provider_id="test_{}".format( - i % 5 - ), # Cycle through 5 different provider_ids - event=_create_valid_event(detail, lastReceived=random_timestamp), - fingerprint="fingerprint_{}".format(i), - ) - ) - db_session.add_all(alerts) - db_session.commit() + alerts = setup_stress_alerts_no_elastic(num_alerts) # add all to elasticsearch alerts_dto = convert_db_alerts_to_dto_alerts(alerts) elastic_client.index_alerts(alerts_dto) @@ -409,17 +420,19 @@ def setup_stress_alerts(elastic_client, db_session, request): @pytest.fixture def create_alert(db_session): - def _create_alert(fingerprint, status, timestamp): + def _create_alert(fingerprint, status, timestamp, details=None): + details = details or {} alert = Alert( tenant_id=SINGLE_TENANT_UUID, - provider_type="test", + provider_type=details["source"][0] if details and "source" in details else "test", provider_id="test", - event={"fingerprint": fingerprint, "status": status.value}, + event={"fingerprint": fingerprint, "status": status.value, **details}, fingerprint=fingerprint, alert_hash="test_hash", timestamp=timestamp.replace(tzinfo=pytz.utc), ) db_session.add(alert) db_session.commit() + return alert return _create_alert diff --git a/tests/test_incidents.py b/tests/test_incidents.py new file mode 100644 index 000000000..7701f5d2a --- /dev/null +++ b/tests/test_incidents.py @@ -0,0 +1,112 @@ +from datetime import datetime + +import pytz +from sqlalchemy import func + +from keep.api.core.db import ( + add_alerts_to_incident_by_incident_id, + create_incident_from_dict, + get_alerts_data_for_incident, + get_incident_by_id, + remove_alerts_to_incident_by_incident_id, +) +from keep.api.core.db_utils import get_json_extract_field +from keep.api.models.alert import AlertStatus +from keep.api.models.db.alert import Alert + + +def test_get_alerts_data_for_incident(db_session, setup_stress_alerts_no_elastic): + alerts = setup_stress_alerts_no_elastic(100) + assert 100 == db_session.query(func.count(Alert.id)).scalar() + + data = get_alerts_data_for_incident([a.id for a in alerts]) + assert data["sources"] == set(["source_{}".format(i) for i in range(10)]) + assert data["services"] == set(["service_{}".format(i) for i in range(10)]) + assert data["count"] == 100 + + +def test_add_remove_alert_to_incidents(db_session, setup_stress_alerts_no_elastic): + alerts = setup_stress_alerts_no_elastic(100) + incident = create_incident_from_dict("keep", {"name": "test", "description": "test"}) + + assert len(incident.alerts) == 0 + + add_alerts_to_incident_by_incident_id( + "keep", + incident.id, + [a.id for a in alerts] + ) + + incident = get_incident_by_id("keep", incident.id) + + assert sorted(incident.affected_services) == sorted(["service_{}".format(i) for i in range(10)]) + assert sorted(incident.sources) == sorted(["source_{}".format(i) for i in range(10)]) + + service_field = get_json_extract_field(db_session, Alert.event, 'service') + + service_0 = ( + db_session.query(Alert.id) + .filter( + service_field == "service_0" + ) + .all() + ) + + remove_alerts_to_incident_by_incident_id( + "keep", + incident.id, + [service_0[0].id, ] + ) + + incident = get_incident_by_id("keep", incident.id) + + assert len(incident.alerts) == 99 + assert "service_0" in incident.affected_services + assert len(incident.affected_services) == 10 + assert sorted(incident.affected_services) == sorted(["service_{}".format(i) for i in range(10)]) + + remove_alerts_to_incident_by_incident_id( + "keep", + incident.id, + [a.id for a in service_0] + ) + + incident = 
get_incident_by_id("keep", incident.id) + + assert len(incident.alerts) == 90 + assert "service_0" not in incident.affected_services + assert len(incident.affected_services) == 9 + assert sorted(incident.affected_services) == sorted(["service_{}".format(i) for i in range(1, 10)]) + + source_1 = ( + db_session.query(Alert.id) + .filter( + Alert.provider_type == "source_1" + ) + .all() + ) + + remove_alerts_to_incident_by_incident_id( + "keep", + incident.id, + [source_1[0].id, ] + ) + + incident = get_incident_by_id("keep", incident.id) + + assert len(incident.alerts) == 89 + assert "source_1" in incident.sources + # source_0 was removed together with service_0 + assert len(incident.sources) == 9 + assert sorted(incident.sources) == sorted(["source_{}".format(i) for i in range(1, 10)]) + + remove_alerts_to_incident_by_incident_id( + "keep", + incident.id, + [a.id for a in source_1] + ) + + incident = get_incident_by_id("keep", incident.id) + + assert len(incident.sources) == 8 + assert sorted(incident.sources) == sorted(["source_{}".format(i) for i in range(2, 10)])