diff --git a/.gitignore b/.gitignore index ecc7b8450..57a86b596 100644 --- a/.gitignore +++ b/.gitignore @@ -205,4 +205,7 @@ playwright_dump_*.html playwright_dump_*.png ee/experimental/ai_temp/* -!ee/experimental/ai_temp/.gitkeep +!ee/experimental/ai_temp/.gitkeep + +oauth2.cfg +scripts/keep_slack_bot.py diff --git a/docs/mint.json b/docs/mint.json index d06960ab6..191697a2b 100644 --- a/docs/mint.json +++ b/docs/mint.json @@ -45,6 +45,7 @@ ] }, "overview/maintenance-windows", + "overview/deduplication", "overview/examples" ] }, diff --git a/docs/overview/deduplication.mdx b/docs/overview/deduplication.mdx new file mode 100644 index 000000000..00d755ec5 --- /dev/null +++ b/docs/overview/deduplication.mdx @@ -0,0 +1,107 @@ +--- +title: "Alert Deduplication" +--- + +## Overview + +Alert deduplication is a crucial feature in Keep that helps reduce noise and streamline incident management by grouping similar alerts together. This process ensures that your team isn't overwhelmed by a flood of notifications for what is essentially the same issue, allowing for more efficient and focused incident response. + +## Glossary + +- **Deduplication Rule**: A set of criteria used to determine if alerts should be grouped together. +- **Partial Deduplication**: Correlates instances of alerts into single alerts, considering the case of the same alert with different statuses (e.g., firing and resolved). This is the default mode where specified fields are used to identify and group related alerts. +- **Fingerprint Fields**: Specific alert attributes used to identify similar alerts. +- **Full Deduplication**: A mode where alerts are considered identical if all fields match exactly (except those explicitly ignored). This helps avoid system overload by discarding duplicate alerts. +- **Ignore Fields**: In full deduplication mode, these are fields that are not considered when comparing alerts. 
+ +## Deduplication Types + +### Partial Deduplication +Partial deduplication allows you to specify certain fields (fingerprint fields) that are used to identify similar alerts. Alerts with matching values in these specified fields are considered duplicates and are grouped together. This method is flexible and allows for fine-tuned control over how alerts are deduplicated. + +Every provider integrated with Keep comes with a pre-built partial deduplication rule tailored to that provider's specific alert format and common use cases. +The default fingerprint fields are defined using `FINGERPRINT_FIELDS` attributes in the provider code (e.g. [datadog provider](https://github.com/keephq/keep/blob/main/keep/providers/datadog_provider/datadog_provider.py#L188) or [gcp monitoring provider](https://github.com/keephq/keep/blob/main/keep/providers/gcpmonitoring_provider/gcpmonitoring_provider.py#L52)). + +### Full Deduplication +When full deduplication is enabled, Keep will also discard exact same events (excluding ignore fields). This mode considers all fields of an alert when determining duplicates, except for explicitly ignored fields. + +By default, exact similar events excluding lastReceived time are fully deduplicated and discarded. This helps prevent system overload from repeated identical alerts. + +## Real Examples of Alerts and Results + +### Example 1: Partial Deduplication + +**Rule** - Deduplicate based on 'service' and 'error_message' fields. 
+ +```json +# alert 1 +{ + "service": "payment", + "error_message": "Database connection failed", + "severity": "high", + "lastReceived": "2023-05-01T10:00:00Z" +} +# alert 2 +{ + "service": "payment", + "error_message": "Database connection failed", + "severity": "critical", + "lastReceived": "2023-05-01T10:05:00Z" +} +# alert 3 +{ + "service": "auth", + "error_message": "Invalid token", + "severity": "medium", + "lastReceived": "2023-05-01T10:10:00Z" +} +``` + +**Result**: +- Alerts 1 and 2 are deduplicated into a single alert, fields are updated. +- Alert 3 remains separate as it has a different service and error message. + +### Example 2: Full Deduplication + +**Rule**: Full deduplication with 'lastReceived' as an ignore field + +**Incoming Alerts**: + +```json + +# alert 1 +{ + service: "api", + error: "Rate limit exceeded", + user_id: "12345", + lastReceived: "2023-05-02T14:00:00Z" +} +# alert 2 (discarded as it's identical) +{ + service: "api", + error: "Rate limit exceeded", + user_id: "12345", + lastReceived: "2023-05-02T14:01:00Z" +} +# alert 3 +{ + service: "api", + error: "Rate limit exceeded", + user_id: "67890", + lastReceived: "2023-05-02T14:02:00Z" +} +``` + +**Result**: +- Alerts 1 and 2 are deduplicated as they are identical except for the ignored lastReceived field. +- Alert 3 remains separate due to the different user_id. + +## How It Works + +Keep's deduplication process follows these steps: + +1. **Alert Ingestion**: Every alert received by Keep is first ingested into the system. + +2. **Enrichment**: After ingestion, each alert undergoes an enrichment process. This step adds additional context or information to the alert, enhancing its value and usefulness. + +3. **Deduplication**: Following enrichment, Keep's alert deduplicator comes into play. It applies the defined deduplication rules to the enriched alerts. 
diff --git a/keep-ui/app/deduplication/DeduplicationPlaceholder.tsx b/keep-ui/app/deduplication/DeduplicationPlaceholder.tsx new file mode 100644 index 000000000..e9fb49f1a --- /dev/null +++ b/keep-ui/app/deduplication/DeduplicationPlaceholder.tsx @@ -0,0 +1,30 @@ +import { Fragment, useState } from "react"; +import { Button, Card, Subtitle, Title } from "@tremor/react"; +// import { CorrelationSidebar } from "./CorrelationSidebar"; +import { DeduplicationSankey } from "./DeduplicationSankey"; + +export const DeduplicationPlaceholder = () => { + const [isSidebarOpen, setIsSidebarOpen] = useState(false); + + const onCorrelationClick = () => { + setIsSidebarOpen(true); + }; + + return ( + + +
+ No Deduplications Yet + + Reduce noise by creatiing deduplications. + + + Start sending alerts or connect providers to create deduplication + rules. + +
+ +
+
+ ); +}; diff --git a/keep-ui/app/deduplication/DeduplicationSankey.tsx b/keep-ui/app/deduplication/DeduplicationSankey.tsx new file mode 100644 index 000000000..a1c2ba23b --- /dev/null +++ b/keep-ui/app/deduplication/DeduplicationSankey.tsx @@ -0,0 +1,88 @@ +import {SVGProps} from "react"; + +export const DeduplicationSankey = (props: SVGProps) => ( + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + +); diff --git a/keep-ui/app/deduplication/DeduplicationSidebar.tsx b/keep-ui/app/deduplication/DeduplicationSidebar.tsx new file mode 100644 index 000000000..ddc2e6a8f --- /dev/null +++ b/keep-ui/app/deduplication/DeduplicationSidebar.tsx @@ -0,0 +1,554 @@ +import { Fragment, useEffect, useState, useMemo } from "react"; +import { Dialog, Transition } from "@headlessui/react"; +import { useForm, Controller, SubmitHandler } from "react-hook-form"; +import { + Text, + Button, + TextInput, + Callout, + Badge, + Switch, + Icon, + Title, + Card, +} from "@tremor/react"; +import { IoMdClose } from "react-icons/io"; +import { DeduplicationRule } from "app/deduplication/models"; +import { useProviders } from "utils/hooks/useProviders"; +import { useDeduplicationFields } from "utils/hooks/useDeduplicationRules"; +import { GroupBase } from "react-select"; +import Select from "@/components/ui/Select"; +import MultiSelect from "@/components/ui/MultiSelect"; +import { + ExclamationTriangleIcon, + InformationCircleIcon, +} from "@heroicons/react/24/outline"; +import { getApiURL } from "utils/apiUrl"; +import { useSession } from "next-auth/react"; +import { KeyedMutator } from "swr"; + +interface ProviderOption { + value: string; + label: string; + logoUrl: string; +} + +interface DeduplicationSidebarProps { + isOpen: boolean; + toggle: VoidFunction; + selectedDeduplicationRule: DeduplicationRule | null; + onSubmit: (data: Partial) => Promise; + 
mutateDeduplicationRules: KeyedMutator; +} + +const DeduplicationSidebar: React.FC = ({ + isOpen, + toggle, + selectedDeduplicationRule, + onSubmit, + mutateDeduplicationRules, +}) => { + const { + control, + handleSubmit, + setValue, + reset, + setError, + watch, + formState: { errors }, + clearErrors, + } = useForm>({ + defaultValues: selectedDeduplicationRule || { + name: "", + description: "", + provider_type: "", + provider_id: "", + fingerprint_fields: [], + full_deduplication: false, + ignore_fields: [], + }, + }); + + const [isSubmitting, setIsSubmitting] = useState(false); + const { + data: providers = { installed_providers: [], linked_providers: [] }, + } = useProviders(); + const { data: deduplicationFields = {} } = useDeduplicationFields(); + const { data: session } = useSession(); + + const alertProviders = useMemo( + () => + [ + { id: null, type: "keep", details: { name: "Keep" }, tags: ["alert"] }, + ...providers.installed_providers, + ...providers.linked_providers, + ].filter((provider) => provider.tags?.includes("alert")), + [providers] + ); + const fullDeduplication = watch("full_deduplication"); + const selectedProviderType = watch("provider_type"); + const selectedProviderId = watch("provider_id"); + const fingerprintFields = watch("fingerprint_fields"); + const ignoreFields = watch("ignore_fields"); + + const availableFields = useMemo(() => { + const defaultFields = [ + "source", + "service", + "description", + "fingerprint", + "name", + "lastReceived", + ]; + if (selectedProviderType) { + const key = `${selectedProviderType}_${selectedProviderId || "null"}`; + const providerFields = deduplicationFields[key] || []; + return [ + ...new Set([ + ...defaultFields, + ...providerFields, + ...(fingerprintFields ?? []), + ...(ignoreFields ?? []), + ]), + ]; + } + return [...new Set([...defaultFields, ...(fingerprintFields ?? 
[])])]; + }, [ + selectedProviderType, + selectedProviderId, + deduplicationFields, + fingerprintFields, + ignoreFields, + ]); + + useEffect(() => { + if (isOpen && selectedDeduplicationRule) { + reset(selectedDeduplicationRule); + } else if (isOpen) { + reset({ + name: "", + description: "", + provider_type: "", + provider_id: "", + fingerprint_fields: [], + full_deduplication: false, + ignore_fields: [], + }); + } + }, [isOpen, selectedDeduplicationRule, reset]); + + const handleToggle = () => { + if (isOpen) { + clearErrors(); + } + toggle(); + }; + + const onFormSubmit: SubmitHandler> = async ( + data + ) => { + setIsSubmitting(true); + clearErrors(); + try { + const apiUrl = getApiURL(); + let url = `${apiUrl}/deduplications`; + + if (selectedDeduplicationRule && selectedDeduplicationRule.id) { + url += `/${selectedDeduplicationRule.id}`; + } + + const method = + !selectedDeduplicationRule || !selectedDeduplicationRule.id + ? "POST" + : "PUT"; + + const response = await fetch(url, { + method: method, + headers: { + Authorization: `Bearer ${session?.accessToken}`, + "Content-Type": "application/json", + }, + body: JSON.stringify(data), + }); + + if (response.ok) { + console.log("Deduplication rule saved:", data); + reset(); + handleToggle(); + await mutateDeduplicationRules(); + } else { + const errorData = await response.json(); + setError("root.serverError", { + type: "manual", + message: errorData.message || "Failed to save deduplication rule", + }); + } + } catch (error) { + setError("root.serverError", { + type: "manual", + message: "An unexpected error occurred", + }); + } finally { + setIsSubmitting(false); + } + }; + + return ( + + + + + + ); +}; + +export default DeduplicationSidebar; diff --git a/keep-ui/app/deduplication/DeduplicationTable.tsx b/keep-ui/app/deduplication/DeduplicationTable.tsx new file mode 100644 index 000000000..5bef79f40 --- /dev/null +++ b/keep-ui/app/deduplication/DeduplicationTable.tsx @@ -0,0 +1,338 @@ +import React, { 
useEffect, useMemo, useState } from "react"; +import { + Button, + Card, + Subtitle, + Table, + TableBody, + TableCell, + TableHead, + TableHeaderCell, + TableRow, + Title, + Badge, + SparkAreaChart, +} from "@tremor/react"; +import { useRouter, useSearchParams } from "next/navigation"; +import { + createColumnHelper, + flexRender, + getCoreRowModel, + useReactTable, +} from "@tanstack/react-table"; +import { DeduplicationRule } from "app/deduplication/models"; +import DeduplicationSidebar from "app/deduplication/DeduplicationSidebar"; +import { TrashIcon, PauseIcon, PlusIcon } from "@heroicons/react/24/outline"; +import Image from "next/image"; +import { getApiURL } from "utils/apiUrl"; +import { useSession } from "next-auth/react"; + +const columnHelper = createColumnHelper(); + +import { KeyedMutator } from "swr"; + +type DeduplicationTableProps = { + deduplicationRules: DeduplicationRule[]; + mutateDeduplicationRules: KeyedMutator; +}; + +export const DeduplicationTable: React.FC = ({ + deduplicationRules, + mutateDeduplicationRules, +}) => { + const router = useRouter(); + const { data: session } = useSession(); + const searchParams = useSearchParams(); + + let selectedId = searchParams ? 
searchParams.get("id") : null; + const [isSidebarOpen, setIsSidebarOpen] = useState(false); + const [selectedDeduplicationRule, setSelectedDeduplicationRule] = + useState(null); + + const onDeduplicationClick = (rule: DeduplicationRule) => { + setSelectedDeduplicationRule(rule); + setIsSidebarOpen(true); + router.push(`/deduplication?id=${rule.id}`); + }; + + const onCloseDeduplication = () => { + setIsSidebarOpen(false); + setSelectedDeduplicationRule(null); + router.push("/deduplication"); + }; + + const handleDeleteRule = async ( + rule: DeduplicationRule, + event: React.MouseEvent + ) => { + event.stopPropagation(); + if (rule.default) return; // Don't delete default rules + + if ( + window.confirm("Are you sure you want to delete this deduplication rule?") + ) { + try { + const url = `${getApiURL()}/deduplications/${rule.id}`; + const response = await fetch(url, { + method: "DELETE", + headers: { + Authorization: `Bearer ${session?.accessToken}`, + }, + }); + + if (response.ok) { + await mutateDeduplicationRules(); + } else { + console.error("Failed to delete deduplication rule"); + } + } catch (error) { + console.error("Error deleting deduplication rule:", error); + } + } + }; + + useEffect(() => { + if (selectedId && !isSidebarOpen) { + const rule = deduplicationRules.find((r) => r.id === selectedId); + if (rule) { + setSelectedDeduplicationRule(rule); + setIsSidebarOpen(true); + } + } + }, [selectedId, deduplicationRules]); + + useEffect(() => { + if (!isSidebarOpen && selectedId) { + router.push("/deduplication"); + } + }, [isSidebarOpen, selectedId, router]); + + const DEDUPLICATION_TABLE_COLS = useMemo( + () => [ + columnHelper.accessor("provider_type", { + header: "", + cell: (info) => ( +
+ {info.getValue()} +
+ ), + }), + columnHelper.accessor("description", { + header: "Name", + cell: (info) => ( +
+ + {info.getValue()} + + {info.row.original.default ? ( + + Default + + ) : ( + + Custom + + )} + {info.row.original.full_deduplication && ( + + Full Deduplication + + )} +
+ ), + }), + columnHelper.accessor("ingested", { + header: "Ingested", + cell: (info) => ( + + {info.getValue() || 0} + + ), + }), + columnHelper.accessor("dedup_ratio", { + header: "Dedup Ratio", + cell: (info) => { + const value = info.getValue() || 0; + const formattedValue = Number(value).toFixed(1); + return ( + + {formattedValue}% + + ); + }, + }), + columnHelper.accessor("distribution", { + header: "Distribution", + cell: (info) => { + const rawData = info.getValue(); + const maxNumber = Math.max(...rawData.map((item) => item.number)); + const allZero = rawData.every((item) => item.number === 0); + const data = rawData.map((item) => ({ + ...item, + number: maxNumber > 0 ? item.number / maxNumber + 1 : 0.5, + })); + const colors = ["orange"]; + const showGradient = true; + return ( + + ); + }, + }), + columnHelper.accessor("fingerprint_fields", { + header: "Fields", + cell: (info) => { + const fields = info.getValue(); + const ignoreFields = info.row.original.ignore_fields; + const displayFields = + fields && fields.length > 0 ? fields : ignoreFields; + + if (!displayFields || displayFields.length === 0) { + return ( +
+ + N/A + +
+ ); + } + + return ( +
+ {displayFields.map((field: string, index: number) => ( + + {index > 0 && } + + {field} + + + ))} +
+ ); + }, + }), + columnHelper.display({ + id: "actions", + cell: (info) => ( +
+ {/*
+ ), + }), + ], + [handleDeleteRule] + ); + + const table = useReactTable({ + data: deduplicationRules, + columns: DEDUPLICATION_TABLE_COLS, + getCoreRowModel: getCoreRowModel(), + }); + + const handleSubmitDeduplicationRule = async ( + data: Partial + ) => { + // Implement the logic to submit the deduplication rule + // This is a placeholder function, replace with actual implementation + console.log("Submitting deduplication rule:", data); + // Add API call or state update logic here + }; + + return ( +
+
+
+ + Deduplication Rules{" "} + <span className="text-gray-400">({deduplicationRules.length})</span> + + + Set up rules to deduplicate similar alerts + +
+ +
+ + + + {table.getHeaderGroups().map((headerGroup) => ( + + {headerGroup.headers.map((header) => ( + + {flexRender( + header.column.columnDef.header, + header.getContext() + )} + + ))} + + ))} + + + {table.getRowModel().rows.map((row) => ( + onDeduplicationClick(row.original)} + > + {row.getVisibleCells().map((cell) => ( + + {flexRender(cell.column.columnDef.cell, cell.getContext())} + + ))} + + ))} + +
+
+ +
+ ); +}; diff --git a/keep-ui/app/deduplication/client.tsx b/keep-ui/app/deduplication/client.tsx new file mode 100644 index 000000000..98187ef28 --- /dev/null +++ b/keep-ui/app/deduplication/client.tsx @@ -0,0 +1,20 @@ +"use client"; + +import { useDeduplicationRules } from "utils/hooks/useDeduplicationRules"; +import { DeduplicationPlaceholder } from "./DeduplicationPlaceholder"; +import { DeduplicationTable } from "./DeduplicationTable"; +import Loading from "app/loading"; + +export const Client = () => { + const { data: deduplicationRules = [], isLoading, mutate: mutateDeduplicationRules } = useDeduplicationRules(); + + if (isLoading) { + return ; + } + + if (deduplicationRules.length === 0) { + return ; + } + + return ; +}; diff --git a/keep-ui/app/deduplication/models.tsx b/keep-ui/app/deduplication/models.tsx new file mode 100644 index 000000000..8e5e0d3ec --- /dev/null +++ b/keep-ui/app/deduplication/models.tsx @@ -0,0 +1,20 @@ +export interface DeduplicationRule { + id: string; + name: string; + description: string; + default: boolean; + distribution: { hour: number; number: number }[]; + provider_type: string; + provider_id: string; + last_updated: string; + last_updated_by: string; + created_at: string; + created_by: string; + enabled: boolean; + fingerprint_fields: string[]; + ingested: number; + dedup_ratio: number; + // full_deduplication is true if the deduplication rule is a full deduplication rule + full_deduplication: boolean; + ignore_fields: string[]; +} diff --git a/keep-ui/app/deduplication/page.tsx b/keep-ui/app/deduplication/page.tsx new file mode 100644 index 000000000..01a99a74d --- /dev/null +++ b/keep-ui/app/deduplication/page.tsx @@ -0,0 +1,10 @@ +import { Client } from "./client"; + +export default function Page() { + return ; +} + +export const metadata = { + title: "Keep - Deduplication", + description: "Create and manage Keep Deduplication.", +}; diff --git a/keep-ui/components/navbar/NoiseReductionLinks.tsx 
b/keep-ui/components/navbar/NoiseReductionLinks.tsx index b7799d42b..ec851c342 100644 --- a/keep-ui/components/navbar/NoiseReductionLinks.tsx +++ b/keep-ui/components/navbar/NoiseReductionLinks.tsx @@ -10,6 +10,7 @@ import classNames from "classnames"; import { AILink } from "./AILink"; import { TbTopologyRing } from "react-icons/tb"; import { FaVolumeMute } from "react-icons/fa"; +import { IoMdGitMerge } from "react-icons/io"; import { useTopology } from "utils/hooks/useTopology"; type NoiseReductionLinksProps = { session: Session | null }; @@ -41,6 +42,11 @@ export const NoiseReductionLinks = ({ session }: NoiseReductionLinksProps) => { +
  • + + Deduplication + +
  • Correlations diff --git a/keep-ui/components/ui/MultiSelect.tsx b/keep-ui/components/ui/MultiSelect.tsx new file mode 100644 index 000000000..eb22a710d --- /dev/null +++ b/keep-ui/components/ui/MultiSelect.tsx @@ -0,0 +1,112 @@ +import React from "react"; +import Select from "react-select"; +import { + components, + Props as SelectProps, + GroupBase, + StylesConfig, +} from "react-select"; +import { Badge } from "@tremor/react"; + +type OptionType = { value: string; label: string }; + +const customStyles: StylesConfig = { + control: (provided: any, state: any) => ({ + ...provided, + borderColor: state.isFocused ? "orange" : "#ccc", + "&:hover": { + borderColor: "orange", + }, + boxShadow: state.isFocused ? "0 0 0 1px orange" : null, + backgroundColor: "transparent", + }), + option: (provided, state) => ({ + ...provided, + backgroundColor: state.isSelected + ? "orange" + : state.isFocused + ? "rgba(255, 165, 0, 0.1)" + : "transparent", + color: state.isSelected ? "white" : "black", + "&:hover": { + backgroundColor: "rgba(255, 165, 0, 0.3)", + }, + }), + multiValue: (provided) => ({ + ...provided, + backgroundColor: "default", + }), + multiValueLabel: (provided) => ({ + ...provided, + color: "black", + }), + multiValueRemove: (provided) => ({ + ...provided, + color: "orange", + "&:hover": { + backgroundColor: "orange", + color: "white", + }, + }), + menuPortal: (base) => ({ + ...base, + zIndex: 9999, + }), + menu: (provided) => ({ + ...provided, + zIndex: 9999, + }), +}; + +type CustomSelectProps = SelectProps< + OptionType, + true, + GroupBase +> & { + components?: { + Option?: typeof components.Option; + MultiValue?: typeof components.MultiValue; + }; +}; + +const customComponents: CustomSelectProps["components"] = { + Option: ({ children, ...props }) => ( + + + {children} + + + ), + MultiValue: ({ children, ...props }) => ( + + + {children} + + + ), +}; + +type MultiSelectProps = SelectProps>; + +const MultiSelect: React.FC = ({ + value, + onChange, + options, 
+ placeholder, + ...rest +}) => ( + > value={value} onChange={onChange} options={options} placeholder={placeholder} styles={customStyles} components={customComponents} - menuPortalTarget={document.body} // Render the menu in a portal + menuPortalTarget={document.body} menuPosition="fixed" - getOptionLabel={getOptionLabel} // Support custom getOptionLabel - getOptionValue={getOptionValue} // Support custom getOptionValue + getOptionLabel={getOptionLabel} + getOptionValue={getOptionValue} + {...rest} /> ); diff --git a/keep-ui/tailwind.config.js b/keep-ui/tailwind.config.js index 1646a014b..595f66d20 100644 --- a/keep-ui/tailwind.config.js +++ b/keep-ui/tailwind.config.js @@ -23,7 +23,7 @@ module.exports = { muted: "rgb(255 237 213)", // orange-200 subtle: "rgb(251 146 60)", // orange-400 DEFAULT: "rgb(249 115 22)", // orange-500 - emphasis: "#1d4ed8", // blue-700 + emphasis: "#374151", // gray-700 inverted: "#ffffff", // white }, background: { diff --git a/keep-ui/utils/hooks/useDeduplicationRules.ts b/keep-ui/utils/hooks/useDeduplicationRules.ts new file mode 100644 index 000000000..1d6d43c54 --- /dev/null +++ b/keep-ui/utils/hooks/useDeduplicationRules.ts @@ -0,0 +1,28 @@ +import { DeduplicationRule } from "app/deduplication/models"; +import { useSession } from "next-auth/react"; +import { SWRConfiguration } from "swr"; +import useSWRImmutable from "swr/immutable"; +import { getApiURL } from "utils/apiUrl"; +import { fetcher } from "utils/fetcher"; + +export const useDeduplicationRules = (options: SWRConfiguration = {}) => { + const apiUrl = getApiURL(); + const { data: session } = useSession(); + + return useSWRImmutable( + () => (session ? `${apiUrl}/deduplications` : null), + (url) => fetcher(url, session?.accessToken), + options + ); +}; + +export const useDeduplicationFields = (options: SWRConfiguration = {}) => { + const apiUrl = getApiURL(); + const { data: session } = useSession(); + + return useSWRImmutable>( + () => (session ? 
`${apiUrl}/deduplications/fields` : null), + (url) => fetcher(url, session?.accessToken), + options + ); +}; diff --git a/keep/api/alert_deduplicator/alert_deduplicator.py b/keep/api/alert_deduplicator/alert_deduplicator.py index a16cf20af..08fb5c243 100644 --- a/keep/api/alert_deduplicator/alert_deduplicator.py +++ b/keep/api/alert_deduplicator/alert_deduplicator.py @@ -2,97 +2,157 @@ import hashlib import json import logging +import uuid -import celpy +from fastapi import HTTPException -from keep.api.core.db import get_all_filters, get_last_alert_hash_by_fingerprint -from keep.api.models.alert import AlertDto +from keep.api.core.config import config +from keep.api.core.db import ( + create_deduplication_event, + create_deduplication_rule, + delete_deduplication_rule, + get_alerts_fields, + get_all_deduplication_rules, + get_all_deduplication_stats, + get_custom_deduplication_rules, + get_last_alert_hash_by_fingerprint, + update_deduplication_rule, +) +from keep.api.models.alert import ( + AlertDto, + DeduplicationRuleDto, + DeduplicationRuleRequestDto, +) +from keep.providers.providers_factory import ProvidersFactory +from keep.searchengine.searchengine import SearchEngine + +DEFAULT_RULE_UUID = "00000000-0000-0000-0000-000000000000" -# decide whether this should be a singleton so that we can keep the filters in memory class AlertDeduplicator: - # this fields will be removed from the alert before hashing - # TODO: make this configurable - DEFAULT_FIELDS = ["lastReceived"] def __init__(self, tenant_id): - self.filters = get_all_filters(tenant_id) self.logger = logging.getLogger(__name__) self.tenant_id = tenant_id + self.provider_distribution_enabled = config( + "PROVIDER_DISTRIBUTION_ENABLED", cast=bool, default=True + ) + self.search_engine = SearchEngine(self.tenant_id) - def is_deduplicated(self, alert: AlertDto) -> bool: - # Apply all deduplication filters - for filt in self.filters: - alert = self._apply_deduplication_filter(filt, alert) + def 
_apply_deduplication_rule( + self, alert: AlertDto, rule: DeduplicationRuleDto + ) -> bool: + """ + Apply a deduplication rule to an alert. - # Remove default fields - for field in AlertDeduplicator.DEFAULT_FIELDS: - alert = self._remove_field(field, alert) + Gets an alert and a deduplication rule and apply the rule to the alert by: + - removing the fields that should be ignored + - calculating the hash + - checking if the hash is already in the database + - setting the isFullDuplicate or isPartialDuplicate flag + """ + # we don't want to remove fields from the original alert + alert_copy = copy.deepcopy(alert) + # remove the fields that should be ignored + for field in rule.ignore_fields: + alert_copy = self._remove_field(field, alert_copy) - # Calculate the hash + # calculate the hash alert_hash = hashlib.sha256( - json.dumps(alert.dict(), default=str).encode() + json.dumps(alert_copy.dict(), default=str).encode() ).hexdigest() - + alert.alert_hash = alert_hash # Check if the hash is already in the database last_alert_hash_by_fingerprint = get_last_alert_hash_by_fingerprint( self.tenant_id, alert.fingerprint ) - alert_deduplicate = ( - True - if last_alert_hash_by_fingerprint + # the hash is the same as the last alert hash by fingerprint - full deduplication + if ( + last_alert_hash_by_fingerprint and last_alert_hash_by_fingerprint == alert_hash - else False - ) - if alert_deduplicate: - self.logger.info(f"Alert {alert.id} is deduplicated {alert.source}") - - return alert_hash, alert_deduplicate - - def _run_matcher(self, matcher, alert: AlertDto) -> bool: - # run the CEL matcher - env = celpy.Environment() - ast = env.compile(matcher) - prgm = env.program(ast) - activation = celpy.json_to_cel( - json.loads(json.dumps(alert.dict(), default=str)) + ): + self.logger.info( + "Alert is deduplicated", + extra={ + "alert_id": alert.id, + "rule_id": rule.id, + "tenant_id": self.tenant_id, + }, + ) + alert.isFullDuplicate = True + # it means that there is another alert 
with the same fingerprint but different hash + # so its a deduplication + elif last_alert_hash_by_fingerprint: + self.logger.info( + "Alert is partially deduplicated", + extra={ + "alert_id": alert.id, + "tenant_id": self.tenant_id, + }, + ) + alert.isPartialDuplicate = True + + return alert + + def apply_deduplication(self, alert: AlertDto) -> bool: + # IMPOTRANT NOTE TO SOMEONE WORKING ON THIS CODE: + # apply_deduplication runs AFTER _format_alert, so you can assume that alert fields are in the expected format. + # you can also safe to assume that alert.fingerprint is set by the provider itself + + # get only relevant rules + rules = self.get_deduplication_rules( + self.tenant_id, alert.providerId, alert.providerType ) - try: - r = prgm.evaluate(activation) - except celpy.evaluation.CELEvalError as e: - # this is ok, it means that the subrule is not relevant for this event - if "no such member" in str(e): - return False - # unknown - raise - return True if r else False - - def _apply_deduplication_filter(self, filt, alert: AlertDto) -> AlertDto: - # check if the matcher applies - filter_apply = self._run_matcher(filt.matcher_cel, alert) - if not filter_apply: - self.logger.debug(f"Filter {filt.id} did not match") - return alert - - # remove the fields - for field in filt.fields: - alert = self._remove_field(field, alert) + + for rule in rules: + self.logger.debug( + "Applying deduplication rule to alert", + extra={ + "rule_id": rule.id, + "alert_id": alert.id, + }, + ) + alert = self._apply_deduplication_rule(alert, rule) + self.logger.debug( + "Alert after deduplication rule applied", + extra={ + "rule_id": rule.id, + "alert_id": alert.id, + "is_full_duplicate": alert.isFullDuplicate, + "is_partial_duplicate": alert.isPartialDuplicate, + }, + ) + if alert.isFullDuplicate or alert.isPartialDuplicate: + # create deduplication event + create_deduplication_event( + tenant_id=self.tenant_id, + deduplication_rule_id=rule.id, + deduplication_type="full" if 
alert.isFullDuplicate else "partial", + provider_id=alert.providerId, + provider_type=alert.providerType, + ) + # we don't need to check the other rules + break + else: + # create none deduplication event, for statistics + create_deduplication_event( + tenant_id=self.tenant_id, + deduplication_rule_id=rule.id, + deduplication_type="none", + provider_id=alert.providerId, + provider_type=alert.providerType, + ) return alert def _remove_field(self, field, alert: AlertDto) -> AlertDto: - # remove the field from the alert alert = copy.deepcopy(alert) field_parts = field.split(".") - # if its not a nested field if len(field_parts) == 1: try: delattr(alert, field) except AttributeError: - self.logger.warning("Failed to delete attribute {field} from alert") - pass - # if its a nested field, copy the dictionaty and remove the field - # this is for cases such as labels/tags + self.logger.warning(f"Failed to delete attribute {field} from alert") else: alert_attr = field_parts[0] d = copy.deepcopy(getattr(alert, alert_attr)) @@ -101,3 +161,325 @@ def _remove_field(self, field, alert: AlertDto) -> AlertDto: del d[field_parts[-1]] setattr(alert, field_parts[0], d) return alert + + def get_deduplication_rules( + self, tenant_id, provider_id, provider_type + ) -> DeduplicationRuleDto: + # try to get the rule from the database + rules = get_custom_deduplication_rules(tenant_id, provider_id, provider_type) + + if not rules: + self.logger.debug( + "No custom deduplication rules found, using deafult full deduplication rule", + extra={ + "provider_id": provider_id, + "provider_type": provider_type, + "tenant_id": tenant_id, + }, + ) + rule = self._get_default_full_deduplication_rule(provider_id, provider_type) + return [rule] + + # else, return the custom rules + self.logger.debug( + "Using custom deduplication rules", + extra={ + "provider_id": provider_id, + "provider_type": provider_type, + "tenant_id": tenant_id, + }, + ) + # + # check that at least one of them is full 
deduplication rule + full_deduplication_rules = [rule for rule in rules if rule.full_deduplication] + # if full deduplication rule found, return the rules + if full_deduplication_rules: + return rules + + # if not, assign them the default full deduplication rule ignore fields + self.logger.info( + "No full deduplication rule found, assigning default full deduplication rule ignore fields" + ) + default_full_dedup_rule = self._get_default_full_deduplication_rule( + provider_id=provider_id, provider_type=provider_type + ) + for rule in rules: + if not rule.full_deduplication: + self.logger.debug( + "Assigning default full deduplication rule ignore fields", + ) + rule.ignore_fields = default_full_dedup_rule.ignore_fields + return rules + + def _generate_uuid(self, provider_id, provider_type): + # this is a way to generate a unique uuid for the default deduplication rule per (provider_id, provider_type) + namespace_uuid = uuid.uuid5(uuid.NAMESPACE_DNS, "keephq.dev") + generated_uuid = str( + uuid.uuid5(namespace_uuid, f"{provider_id}_{provider_type}") + ) + return generated_uuid + + def _get_default_full_deduplication_rule( + self, provider_id, provider_type + ) -> DeduplicationRuleDto: + # this is a way to generate a unique uuid for the default deduplication rule per (provider_id, provider_type) + generated_uuid = self._generate_uuid(provider_id, provider_type) + + # just return a default deduplication rule with lastReceived field + if not provider_type: + provider_type = "keep" + + return DeduplicationRuleDto( + id=generated_uuid, + name=f"{provider_type} default deduplication rule", + description=f"{provider_type} default deduplication rule", + default=True, + distribution=[{"hour": i, "number": 0} for i in range(24)], + fingerprint_fields=[], # ["fingerprint"], # this is fallback + provider_type=provider_type or "keep", + provider_id=provider_id, + full_deduplication=True, + ignore_fields=["lastReceived"], + priority=0, + last_updated=None, + last_updated_by=None, + 
created_at=None, + created_by=None, + ingested=0, + dedup_ratio=0.0, + enabled=True, + ) + + def get_deduplications(self) -> list[DeduplicationRuleDto]: + # get all providers + installed_providers = ProvidersFactory.get_installed_providers(self.tenant_id) + installed_providers = [ + provider for provider in installed_providers if "alert" in provider.tags + ] + # get all linked providers + linked_providers = ProvidersFactory.get_linked_providers(self.tenant_id) + providers = [*installed_providers, *linked_providers] + + # get default deduplication rules + default_deduplications = ProvidersFactory.get_default_deduplication_rules() + default_deduplications_dict = { + dd.provider_type: dd for dd in default_deduplications + } + for dd in default_deduplications: + provider_id, provider_type = dd.provider_id, dd.provider_type + dd.id = self._generate_uuid(provider_id, provider_type) + # get custom deduplication rules + custom_deduplications = get_all_deduplication_rules(self.tenant_id) + # cast to dto + custom_deduplications_dto = [ + DeduplicationRuleDto( + id=str(rule.id), + name=rule.name, + description=rule.description, + default=False, + distribution=[{"hour": i, "number": 0} for i in range(24)], + fingerprint_fields=rule.fingerprint_fields, + provider_type=rule.provider_type, + provider_id=rule.provider_id, + full_deduplication=rule.full_deduplication, + ignore_fields=rule.ignore_fields, + priority=rule.priority, + last_updated=str(rule.last_updated), + last_updated_by=rule.last_updated_by, + created_at=str(rule.created_at), + created_by=rule.created_by, + ingested=0, + dedup_ratio=0.0, + enabled=rule.enabled, + ) + for rule in custom_deduplications + ] + + custom_deduplications_dict = {} + for rule in custom_deduplications_dto: + key = f"{rule.provider_type}_{rule.provider_id}" + if key not in custom_deduplications_dict: + custom_deduplications_dict[key] = [] + custom_deduplications_dict[key].append(rule) + + # get the "catch all" full deduplication rule + 
catch_all_full_deduplication = self._get_default_full_deduplication_rule(
+            provider_id=None, provider_type=None
+        )
+
+        # calculate the deduplications
+        # if a provider has custom deduplication rule, use it
+        # else, use the default deduplication rule of the provider
+        final_deduplications = [catch_all_full_deduplication]
+        for provider in providers:
+            # if the provider doesn't have a deduplication rule, use the default one
+            key = f"{provider.type}_{provider.id}"
+            if key not in custom_deduplications_dict:
+                # no default deduplication rule found [if provider doesn't have FINGERPRINT_FIELDS]
+                if provider.type not in default_deduplications_dict:
+                    self.logger.warning(
+                        f"Provider {provider.type} does not have a default deduplication"
+                    )
+                    continue
+
+                # create a copy of the default deduplication rule
+                default_deduplication = copy.deepcopy(
+                    default_deduplications_dict[provider.type]
+                )
+                default_deduplication.id = self._generate_uuid(
+                    provider.id, provider.type
+                )
+                # copy the provider id to the description
+                if provider.id:
+                    default_deduplication.description = (
+                        f"{default_deduplication.description} - {provider.id}"
+                    )
+                    default_deduplication.provider_id = provider.id
+                # set the provider type
+                final_deduplications.append(default_deduplication)
+            # else, just use the custom deduplication rule
+            else:
+                final_deduplications += custom_deduplications_dict[key]
+
+        # now calculate some statistics
+        # alerts_by_provider_stats = get_all_alerts_by_providers(self.tenant_id)
+        deduplication_stats = get_all_deduplication_stats(self.tenant_id)
+
+        result = []
+        for dedup in final_deduplications:
+            key = dedup.id
+            full_dedup = deduplication_stats.get(key, {"full_dedup_count": 0}).get(
+                "full_dedup_count", 0
+            )
+            partial_dedup = deduplication_stats.get(
+                key, {"partial_dedup_count": 0}
+            ).get("partial_dedup_count", 0)
+            none_dedup = deduplication_stats.get(key, {"none_dedup_count": 0}).get(
+                "none_dedup_count", 0
+            )
+
+            dedup.ingested = full_dedup + 
partial_dedup + none_dedup
+            # total dedup count is the sum of full and partial dedup count
+            dedup_count = full_dedup + partial_dedup
+
+            if dedup.ingested == 0:
+                dedup.dedup_ratio = 0.0
+            # this shouldn't happen, only in backward compatibility or some bug that dedup events are not created
+            elif key not in deduplication_stats:
+                self.logger.warning(f"Provider {key} does not have deduplication stats")
+                dedup.dedup_ratio = 0.0
+            elif dedup_count == 0:
+                dedup.dedup_ratio = 0.0
+            else:
+                dedup.dedup_ratio = (dedup_count / dedup.ingested) * 100
+                dedup.distribution = deduplication_stats[key].get(
+                    "alerts_last_24_hours"
+                )
+            result.append(dedup)
+
+        if self.provider_distribution_enabled:
+            for dedup in result:
+                for pd, stats in deduplication_stats.items():
+                    if pd == f"{dedup.provider_id}_{dedup.provider_type}":
+                        distribution = stats.get("alerts_last_24_hours")
+                        dedup.distribution = distribution
+                        break
+
+        # sort providers to have enabled first
+        result = sorted(result, key=lambda x: x.default, reverse=True)
+
+        # if the default is empty, remove it
+        if len(result) == 1 and result[0].ingested == 0:
+            # empty states, no alerts
+            return []
+
+        return result
+
+    def get_deduplication_fields(self) -> list[str]:
+        fields = get_alerts_fields(self.tenant_id)
+
+        fields_per_provider = {}
+        for field in fields:
+            provider_type = field.provider_type if field.provider_type else "null"
+            provider_id = field.provider_id if field.provider_id else "null"
+            key = f"{provider_type}_{provider_id}"
+            if key not in fields_per_provider:
+                fields_per_provider[key] = []
+            fields_per_provider[key].append(field.field_name)
+
+        return fields_per_provider
+
+    def create_deduplication_rule(
+        self, rule: DeduplicationRuleRequestDto, created_by: str
+    ) -> DeduplicationRuleDto:
+        # check that provider installed (cannot create deduplication rule for uninstalled provider)
+        provider = None
+        installed_providers = ProvidersFactory.get_installed_providers(self.tenant_id)
+        linked_providers = 
ProvidersFactory.get_linked_providers(self.tenant_id) + provider_key = f"{rule.provider_type}_{rule.provider_id}" + for p in installed_providers + linked_providers: + if provider_key == f"{p.type}_{p.id}": + provider = p + break + + if not provider: + message = f"Provider {rule.provider_type} not found" + if rule.provider_id: + message += f" with id {rule.provider_id}" + raise HTTPException( + status_code=404, + detail=message, + ) + + # Use the db function to create a new deduplication rule + new_rule = create_deduplication_rule( + tenant_id=self.tenant_id, + name=rule.name, + description=rule.description, + provider_id=rule.provider_id, + provider_type=rule.provider_type, + created_by=created_by, + enabled=True, + fingerprint_fields=rule.fingerprint_fields, + full_deduplication=rule.full_deduplication, + ignore_fields=rule.ignore_fields or [], + priority=0, + ) + + return new_rule + + def update_deduplication_rule( + self, rule_id: str, rule: DeduplicationRuleRequestDto, updated_by: str + ) -> DeduplicationRuleDto: + # check if this is a default rule + default_rule_id = self._generate_uuid(rule.provider_id, rule.provider_type) + # if its a default, we need to override and create a new rule + if rule_id == default_rule_id: + self.logger.info("Default rule update, creating a new rule") + rule_dto = self.create_deduplication_rule(rule, updated_by) + self.logger.info("Default rule updated") + return rule_dto + + # else, use the db function to update an existing deduplication rule + updated_rule = update_deduplication_rule( + rule_id=rule_id, + tenant_id=self.tenant_id, + name=rule.name, + description=rule.description, + provider_id=rule.provider_id, + provider_type=rule.provider_type, + last_updated_by=updated_by, + enabled=True, + fingerprint_fields=rule.fingerprint_fields, + full_deduplication=rule.full_deduplication, + ignore_fields=rule.ignore_fields or [], + priority=0, + ) + + return updated_rule + + def delete_deduplication_rule(self, rule_id: str) -> bool: + 
# Use the db function to delete a deduplication rule + success = delete_deduplication_rule(rule_id=rule_id, tenant_id=self.tenant_id) + + return success diff --git a/keep/api/api.py b/keep/api/api.py index 86c3445ea..5bcc0ddcc 100644 --- a/keep/api/api.py +++ b/keep/api/api.py @@ -37,6 +37,7 @@ ai, alerts, dashboard, + deduplications, extraction, healthcheck, incidents, @@ -235,7 +236,9 @@ def get_app( app.include_router(tags.router, prefix="/tags", tags=["tags"]) app.include_router(maintenance.router, prefix="/maintenance", tags=["maintenance"]) app.include_router(topology.router, prefix="/topology", tags=["topology"]) - + app.include_router( + deduplications.router, prefix="/deduplications", tags=["deduplications"] + ) # if its single tenant with authentication, add signin endpoint logger.info(f"Starting Keep with authentication type: {AUTH_TYPE}") # If we run Keep with SINGLE_TENANT auth type, we want to add the signin endpoint @@ -298,12 +301,21 @@ async def on_shutdown(): if SCHEDULER: logger.info("Stopping the scheduler") wf_manager = WorkflowManager.get_instance() - await wf_manager.stop() + # stop the scheduler + try: + await wf_manager.stop() + # in pytest, there could be race condition + except TypeError: + pass logger.info("Scheduler stopped successfully") if CONSUMER: logger.info("Stopping the consumer") event_subscriber = EventSubscriber.get_instance() - await event_subscriber.stop() + try: + await event_subscriber.stop() + # in pytest, there could be race condition + except TypeError: + pass logger.info("Consumer stopped successfully") # ARQ workers stops themselves? 
see "shutdown on SIGTERM" in logs logger.info("Keep shutdown complete") diff --git a/keep/api/core/db.py b/keep/api/core/db.py index 1381657a1..c523de1b0 100644 --- a/keep/api/core/db.py +++ b/keep/api/core/db.py @@ -20,10 +20,13 @@ from dotenv import find_dotenv, load_dotenv from opentelemetry.instrumentation.sqlalchemy import SQLAlchemyInstrumentor from sqlalchemy import and_, desc, null, update +from sqlalchemy.dialects.mysql import insert as mysql_insert +from sqlalchemy.dialects.postgresql import insert as pg_insert +from sqlalchemy.dialects.sqlite import insert as sqlite_insert from sqlalchemy.exc import IntegrityError, OperationalError from sqlalchemy.orm import joinedload, selectinload, subqueryload from sqlalchemy.sql import expression -from sqlmodel import Session, col, or_, select +from sqlmodel import Session, col, or_, select, text from keep.api.core.db_utils import create_db_engine, get_json_extract_field @@ -957,7 +960,6 @@ def count_alerts( def get_enrichment(tenant_id, fingerprint, refresh=False): with Session(engine) as session: return get_enrichment_with_session(session, tenant_id, fingerprint, refresh) - return alert_enrichment def get_enrichments( @@ -985,7 +987,7 @@ def get_enrichment_with_session(session, tenant_id, fingerprint, refresh=False): .where(AlertEnrichment.tenant_id == tenant_id) .where(AlertEnrichment.alert_fingerprint == fingerprint) ).first() - if refresh: + if refresh and alert_enrichment: try: session.refresh(alert_enrichment) except Exception: @@ -1681,19 +1683,275 @@ def get_rule_distribution(tenant_id, minute=False): return rule_distribution -def get_all_filters(tenant_id): +def get_all_deduplication_rules(tenant_id): with Session(engine) as session: - filters = session.exec( - select(AlertDeduplicationFilter).where( - AlertDeduplicationFilter.tenant_id == tenant_id + rules = session.exec( + select(AlertDeduplicationRule).where( + AlertDeduplicationRule.tenant_id == tenant_id ) ).all() - return filters + return rules -def 
get_last_alert_hash_by_fingerprint(tenant_id, fingerprint): - from sqlalchemy.dialects import mssql +def get_custom_deduplication_rule(tenant_id, provider_id, provider_type): + with Session(engine) as session: + rule = session.exec( + select(AlertDeduplicationRule) + .where(AlertDeduplicationRule.tenant_id == tenant_id) + .where(AlertDeduplicationRule.provider_id == provider_id) + .where(AlertDeduplicationRule.provider_type == provider_type) + ).first() + return rule + + +def create_deduplication_rule( + tenant_id: str, + name: str, + description: str, + provider_id: str | None, + provider_type: str, + created_by: str, + enabled: bool = True, + fingerprint_fields: list[str] = [], + full_deduplication: bool = False, + ignore_fields: list[str] = [], + priority: int = 0, +): + with Session(engine) as session: + new_rule = AlertDeduplicationRule( + tenant_id=tenant_id, + name=name, + description=description, + provider_id=provider_id, + provider_type=provider_type, + last_updated_by=created_by, # on creation, last_updated_by is the same as created_by + created_by=created_by, + enabled=enabled, + fingerprint_fields=fingerprint_fields, + full_deduplication=full_deduplication, + ignore_fields=ignore_fields, + priority=priority, + ) + session.add(new_rule) + session.commit() + session.refresh(new_rule) + return new_rule + +def update_deduplication_rule( + rule_id: str, + tenant_id: str, + name: str, + description: str, + provider_id: str | None, + provider_type: str, + last_updated_by: str, + enabled: bool = True, + fingerprint_fields: list[str] = [], + full_deduplication: bool = False, + ignore_fields: list[str] = [], + priority: int = 0, +): + with Session(engine) as session: + rule = session.exec( + select(AlertDeduplicationRule) + .where(AlertDeduplicationRule.id == rule_id) + .where(AlertDeduplicationRule.tenant_id == tenant_id) + ).first() + if not rule: + raise ValueError(f"No deduplication rule found with id {rule_id}") + + rule.name = name + rule.description = 
description + rule.provider_id = provider_id + rule.provider_type = provider_type + rule.last_updated_by = last_updated_by + rule.enabled = enabled + rule.fingerprint_fields = fingerprint_fields + rule.full_deduplication = full_deduplication + rule.ignore_fields = ignore_fields + rule.priority = priority + + session.add(rule) + session.commit() + session.refresh(rule) + return rule + + +def delete_deduplication_rule(rule_id: str, tenant_id: str) -> bool: + with Session(engine) as session: + rule = session.exec( + select(AlertDeduplicationRule) + .where(AlertDeduplicationRule.id == rule_id) + .where(AlertDeduplicationRule.tenant_id == tenant_id) + ).first() + if not rule: + return False + + session.delete(rule) + session.commit() + return True + + +def get_custom_deduplication_rules(tenant_id, provider_id, provider_type): + with Session(engine) as session: + rules = session.exec( + select(AlertDeduplicationRule) + .where(AlertDeduplicationRule.tenant_id == tenant_id) + .where(AlertDeduplicationRule.provider_id == provider_id) + .where(AlertDeduplicationRule.provider_type == provider_type) + ).all() + return rules + + +def create_deduplication_event( + tenant_id, deduplication_rule_id, deduplication_type, provider_id, provider_type +): + with Session(engine) as session: + deduplication_event = AlertDeduplicationEvent( + tenant_id=tenant_id, + deduplication_rule_id=deduplication_rule_id, + deduplication_type=deduplication_type, + provider_id=provider_id, + provider_type=provider_type, + timestamp=datetime.utcnow(), + date_hour=datetime.utcnow().replace(minute=0, second=0, microsecond=0), + ) + session.add(deduplication_event) + session.commit() + + +def get_all_alerts_by_providers(tenant_id): + with Session(engine) as session: + # Query to get the count of alerts per provider_id and provider_type + query = ( + select( + Alert.provider_id, + Alert.provider_type, + func.count(Alert.id).label("num_alerts"), + ) + .where(Alert.tenant_id == tenant_id) + 
.group_by(Alert.provider_id, Alert.provider_type) + ) + + results = session.exec(query).all() + + # Create a dictionary with the number of alerts for each provider + stats = {} + for result in results: + provider_id = result.provider_id + provider_type = result.provider_type + num_alerts = result.num_alerts + + key = f"{provider_type}_{provider_id}" + stats[key] = { + "num_alerts": num_alerts, + } + + return stats + + +def get_all_deduplication_stats(tenant_id): + with Session(engine) as session: + # Query to get all-time deduplication stats + all_time_query = ( + select( + AlertDeduplicationEvent.deduplication_rule_id, + AlertDeduplicationEvent.provider_id, + AlertDeduplicationEvent.provider_type, + AlertDeduplicationEvent.deduplication_type, + func.count(AlertDeduplicationEvent.id).label("dedup_count"), + ) + .where(AlertDeduplicationEvent.tenant_id == tenant_id) + .group_by( + AlertDeduplicationEvent.deduplication_rule_id, + AlertDeduplicationEvent.provider_id, + AlertDeduplicationEvent.provider_type, + AlertDeduplicationEvent.deduplication_type, + ) + ) + + all_time_results = session.exec(all_time_query).all() + + # Query to get alerts distribution in the last 24 hours + twenty_four_hours_ago = datetime.utcnow() - timedelta(hours=24) + alerts_last_24_hours_query = ( + select( + AlertDeduplicationEvent.deduplication_rule_id, + AlertDeduplicationEvent.provider_id, + AlertDeduplicationEvent.provider_type, + AlertDeduplicationEvent.date_hour, + func.count(AlertDeduplicationEvent.id).label("hourly_count"), + ) + .where(AlertDeduplicationEvent.tenant_id == tenant_id) + .where(AlertDeduplicationEvent.date_hour >= twenty_four_hours_ago) + .group_by( + AlertDeduplicationEvent.deduplication_rule_id, + AlertDeduplicationEvent.provider_id, + AlertDeduplicationEvent.provider_type, + AlertDeduplicationEvent.date_hour, + ) + ) + + alerts_last_24_hours_results = session.exec(alerts_last_24_hours_query).all() + + # Create a dictionary with deduplication stats for each rule + 
stats = {} + current_hour = datetime.utcnow().replace(minute=0, second=0, microsecond=0) + for result in all_time_results: + provider_id = result.provider_id + provider_type = result.provider_type + dedup_count = result.dedup_count + dedup_type = result.deduplication_type + + # alerts without provider_id and provider_type are considered as "keep" + if not provider_type: + provider_type = "keep" + + key = str(result.deduplication_rule_id) + if key not in stats: + # initialize the stats for the deduplication rule + stats[key] = { + "full_dedup_count": 0, + "partial_dedup_count": 0, + "none_dedup_count": 0, + "alerts_last_24_hours": [ + {"hour": (current_hour - timedelta(hours=i)).hour, "number": 0} + for i in range(0, 24) + ], + "provider_id": provider_id, + "provider_type": provider_type, + } + + if dedup_type == "full": + stats[key]["full_dedup_count"] += dedup_count + elif dedup_type == "partial": + stats[key]["partial_dedup_count"] += dedup_count + elif dedup_type == "none": + stats[key]["none_dedup_count"] += dedup_count + + # Add alerts distribution from the last 24 hours + for result in alerts_last_24_hours_results: + provider_id = result.provider_id + provider_type = result.provider_type + date_hour = result.date_hour + hourly_count = result.hourly_count + key = str(result.deduplication_rule_id) + + if not provider_type: + provider_type = "keep" + + if key in stats: + hours_ago = int((current_hour - date_hour).total_seconds() / 3600) + if 0 <= hours_ago < 24: + stats[key]["alerts_last_24_hours"][23 - hours_ago][ + "number" + ] = hourly_count + + return stats + + +def get_last_alert_hash_by_fingerprint(tenant_id, fingerprint): # get the last alert for a given fingerprint # to check deduplication with Session(engine) as session: @@ -1705,12 +1963,6 @@ def get_last_alert_hash_by_fingerprint(tenant_id, fingerprint): .limit(1) # Add LIMIT 1 for MSSQL ) - # Compile the query and log it - compiled_query = query.compile( - dialect=mssql.dialect(), 
compile_kwargs={"literal_binds": True} - ) - logger.debug(f"Compiled query: {compiled_query}") - alert_hash = session.exec(query).first() return alert_hash @@ -2911,6 +3163,109 @@ def get_provider_by_name(tenant_id: str, provider_name: str) -> Provider: return provider +def get_provider_by_type_and_id( + tenant_id: str, provider_type: str, provider_id: Optional[str] +) -> Provider: + with Session(engine) as session: + query = select(Provider).where( + Provider.tenant_id == tenant_id, + Provider.type == provider_type, + Provider.id == provider_id, + ) + provider = session.exec(query).first() + return provider + + +def bulk_upsert_alert_fields( + tenant_id: str, fields: List[str], provider_id: str, provider_type: str +): + with Session(engine) as session: + try: + # Prepare the data for bulk insert + data = [ + { + "tenant_id": tenant_id, + "field_name": field, + "provider_id": provider_id, + "provider_type": provider_type, + } + for field in fields + ] + + if engine.dialect.name == "postgresql": + stmt = pg_insert(AlertField).values(data) + stmt = stmt.on_conflict_do_update( + index_elements=[ + "tenant_id", + "field_name", + ], # Unique constraint columns + set_={ + "provider_id": stmt.excluded.provider_id, + "provider_type": stmt.excluded.provider_type, + }, + ) + elif engine.dialect.name == "mysql": + stmt = mysql_insert(AlertField).values(data) + stmt = stmt.on_duplicate_key_update( + provider_id=stmt.inserted.provider_id, + provider_type=stmt.inserted.provider_type, + ) + elif engine.dialect.name == "sqlite": + stmt = sqlite_insert(AlertField).values(data) + stmt = stmt.on_conflict_do_update( + index_elements=[ + "tenant_id", + "field_name", + ], # Unique constraint columns + set_={ + "provider_id": stmt.excluded.provider_id, + "provider_type": stmt.excluded.provider_type, + }, + ) + elif engine.dialect.name == "mssql": + # SQL Server requires a raw query with a MERGE statement + values = ", ".join( + f"('{tenant_id}', '{field}', '{provider_id}', 
'{provider_type}')" + for field in fields + ) + + merge_query = text( + f""" + MERGE INTO AlertField AS target + USING (VALUES {values}) AS source (tenant_id, field_name, provider_id, provider_type) + ON target.tenant_id = source.tenant_id AND target.field_name = source.field_name + WHEN MATCHED THEN + UPDATE SET provider_id = source.provider_id, provider_type = source.provider_type + WHEN NOT MATCHED THEN + INSERT (tenant_id, field_name, provider_id, provider_type) + VALUES (source.tenant_id, source.field_name, source.provider_id, source.provider_type) + """ + ) + + session.execute(merge_query) + else: + raise NotImplementedError( + f"Upsert not supported for {engine.dialect.name}" + ) + + # Execute the statement + if engine.dialect.name != "mssql": # Already executed for SQL Server + session.execute(stmt) + session.commit() + + except IntegrityError: + # Handle any potential race conditions + session.rollback() + + +def get_alerts_fields(tenant_id: str) -> List[AlertField]: + with Session(engine) as session: + fields = session.exec( + select(AlertField).where(AlertField.tenant_id == tenant_id) + ).all() + return fields + + def change_incident_status_by_id( tenant_id: str, incident_id: UUID | str, status: IncidentStatus ) -> bool: diff --git a/keep/api/logging.py b/keep/api/logging.py index bd3a75e59..4fbaac2d7 100644 --- a/keep/api/logging.py +++ b/keep/api/logging.py @@ -93,6 +93,29 @@ def dump(self): LOG_FORMAT = os.environ.get("LOG_FORMAT", LOG_FORMAT_OPEN_TELEMETRY) +class DevTerminalFormatter(logging.Formatter): + def format(self, record): + message = super().format(record) + extra_info = "" + + # Use inspect to go up the stack until we find the _log function + frame = inspect.currentframe() + while frame: + if frame.f_code.co_name == "_log": + # Extract extra from the _log function's local variables + extra = frame.f_locals.get("extra", {}) + if extra: + extra_info = " ".join( + [f"[{k}: {v}]" for k, v in extra.items() if k != "raw_event"] + ) + else: + 
extra_info = "" + break + frame = frame.f_back + + return f"{message} {extra_info}" + + CONFIG = { "version": 1, "disable_existing_loggers": False, @@ -100,18 +123,26 @@ def dump(self): "json": { "format": "%(asctime)s %(message)s %(levelname)s %(name)s %(filename)s %(otelTraceID)s %(otelSpanID)s %(otelServiceName)s %(threadName)s %(process)s %(module)s", "class": "pythonjsonlogger.jsonlogger.JsonFormatter", - } + }, + "dev_terminal": { + "()": DevTerminalFormatter, + "format": "%(asctime)s - %(levelname)s - %(message)s", + }, }, "handlers": { "default": { "level": "DEBUG", - "formatter": "json" if LOG_FORMAT == LOG_FORMAT_OPEN_TELEMETRY else None, + "formatter": ( + "json" if LOG_FORMAT == LOG_FORMAT_OPEN_TELEMETRY else "dev_terminal" + ), "class": "logging.StreamHandler", "stream": "ext://sys.stdout", }, "context": { "level": "DEBUG", - "formatter": "json" if LOG_FORMAT == LOG_FORMAT_OPEN_TELEMETRY else None, + "formatter": ( + "json" if LOG_FORMAT == LOG_FORMAT_OPEN_TELEMETRY else "dev_terminal" + ), "class": "keep.api.logging.WorkflowDBHandler", }, }, diff --git a/keep/api/models/alert.py b/keep/api/models/alert.py index 5d635f6eb..6426a8238 100644 --- a/keep/api/models/alert.py +++ b/keep/api/models/alert.py @@ -4,7 +4,7 @@ import logging import uuid from enum import Enum -from typing import Any, Dict, List +from typing import Any, Dict, List, Optional from uuid import UUID import pytz @@ -124,7 +124,8 @@ class AlertDto(BaseModel): lastReceived: str firingStartTime: str | None = None environment: str = "undefined" - isDuplicate: bool | None = None + isFullDuplicate: bool | None = False + isPartialDuplicate: bool | None = False duplicateReason: str | None = None service: str | None = None source: list[str] | None = [] @@ -305,7 +306,6 @@ class Config: "status": "firing", "lastReceived": "2021-01-01T00:00:00.000Z", "environment": "production", - "isDuplicate": False, "duplicateReason": None, "service": "backend", "source": ["keep"], @@ -472,6 +472,36 @@ def 
from_db_incident(cls, db_incident): return dto +class DeduplicationRuleDto(BaseModel): + id: str | None # UUID + name: str + description: str + default: bool + distribution: list[dict] # list of {hour: int, count: int} + provider_id: str | None # None for default rules + provider_type: str + last_updated: str | None + last_updated_by: str | None + created_at: str | None + created_by: str | None + ingested: int + dedup_ratio: float + enabled: bool + fingerprint_fields: list[str] + full_deduplication: bool + ignore_fields: list[str] + + +class DeduplicationRuleRequestDto(BaseModel): + name: str + description: Optional[str] = None + provider_type: str + provider_id: Optional[str] = None + fingerprint_fields: list[str] + full_deduplication: bool = False + ignore_fields: Optional[list[str]] = None + + class IncidentStatusChangeDto(BaseModel): status: IncidentStatus comment: str | None diff --git a/keep/api/models/db/alert.py b/keep/api/models/db/alert.py index 4e26bd989..769777adf 100644 --- a/keep/api/models/db/alert.py +++ b/keep/api/models/db/alert.py @@ -4,7 +4,7 @@ from typing import List from uuid import UUID, uuid4 -from sqlalchemy import ForeignKey +from sqlalchemy import ForeignKey, UniqueConstraint from sqlalchemy.dialects.mssql import DATETIME2 as MSSQL_DATETIME2 from sqlalchemy.dialects.mysql import DATETIME as MySQL_DATETIME from sqlalchemy.engine.url import make_url @@ -170,13 +170,86 @@ class Config: arbitrary_types_allowed = True -class AlertDeduplicationFilter(SQLModel, table=True): +class AlertDeduplicationRule(SQLModel, table=True): id: UUID = Field(default_factory=uuid4, primary_key=True) tenant_id: str = Field(foreign_key="tenant.id") - # the list of fields to pop from the alert before hashing - fields: list = Field(sa_column=Column(JSON), default=[]) - # a CEL expression to match the alert - matcher_cel: str + name: str = Field(index=True) + description: str + provider_id: str | None = Field(default=None) # None for default rules + provider_type: 
str + last_updated: datetime = Field(default_factory=datetime.utcnow) + last_updated_by: str + created_at: datetime = Field(default_factory=datetime.utcnow) + created_by: str + enabled: bool = Field(default=True) + fingerprint_fields: list[str] = Field(sa_column=Column(JSON), default=[]) + full_deduplication: bool = Field(default=False) + ignore_fields: list[str] = Field(sa_column=Column(JSON), default=[]) + priority: int = Field(default=0) # for future use + + class Config: + arbitrary_types_allowed = True + + +class AlertDeduplicationEvent(SQLModel, table=True): + id: UUID = Field(default_factory=uuid4, primary_key=True) + tenant_id: str = Field(foreign_key="tenant.id", index=True) + timestamp: datetime = Field( + sa_column=Column(datetime_column_type, nullable=False), + default_factory=datetime.utcnow, + ) + deduplication_rule_id: UUID # TODO: currently rules can also be implicit (like default) so they won't exists on db Field(foreign_key="alertdeduplicationrule.id", index=True) + deduplication_type: str = Field() # 'full' or 'partial' + date_hour: datetime = Field( + sa_column=Column(datetime_column_type), + default_factory=lambda: datetime.utcnow().replace( + minute=0, second=0, microsecond=0 + ), + ) + # these are only soft reference since it could be linked provider + provider_id: str | None = Field() + provider_type: str | None = Field() + + __table_args__ = ( + Index( + "ix_alert_deduplication_event_provider_id", + "provider_id", + ), + Index( + "ix_alert_deduplication_event_provider_type", + "provider_type", + ), + Index( + "ix_alert_deduplication_event_provider_id_date_hour", + "provider_id", + "date_hour", + ), + Index( + "ix_alert_deduplication_event_provider_type_date_hour", + "provider_type", + "date_hour", + ), + ) + + class Config: + arbitrary_types_allowed = True + + +class AlertField(SQLModel, table=True): + id: UUID = Field(default_factory=uuid4, primary_key=True) + tenant_id: str = Field(foreign_key="tenant.id", index=True) + field_name: str = 
Field(index=True) + provider_id: str | None = Field(index=True) + provider_type: str | None = Field(index=True) + + __table_args__ = ( + UniqueConstraint("tenant_id", "field_name", name="uq_tenant_field"), + Index("ix_alert_field_tenant_id", "tenant_id"), + Index("ix_alert_field_tenant_id_field_name", "tenant_id", "field_name"), + Index( + "ix_alert_field_provider_id_provider_type", "provider_id", "provider_type" + ), + ) class Config: arbitrary_types_allowed = True diff --git a/keep/api/models/db/migrations/versions/2024-09-19-15-26_493f217af6b6.py b/keep/api/models/db/migrations/versions/2024-09-19-15-26_493f217af6b6.py new file mode 100644 index 000000000..39bc26942 --- /dev/null +++ b/keep/api/models/db/migrations/versions/2024-09-19-15-26_493f217af6b6.py @@ -0,0 +1,198 @@ +"""Dedup + +Revision ID: 493f217af6b6 +Revises: 5d7ae55efc6a +Create Date: 2024-09-19 15:26:21.564118 + +""" + +import sqlalchemy as sa +import sqlmodel +from alembic import op +from sqlalchemy.dialects import sqlite + +# revision identifiers, used by Alembic. +revision = "493f217af6b6" +down_revision = "5d7ae55efc6a" +branch_labels = None +depends_on = None + + +def upgrade() -> None: + # ### commands auto generated by Alembic - please adjust! 
### + op.create_table( + "alertdeduplicationevent", + sa.Column("timestamp", sa.DateTime(), nullable=False), + sa.Column("date_hour", sa.DateTime(), nullable=True), + sa.Column("id", sqlmodel.sql.sqltypes.GUID(), nullable=False), + sa.Column("tenant_id", sqlmodel.sql.sqltypes.AutoString(), nullable=False), + sa.Column( + "deduplication_rule_id", sqlmodel.sql.sqltypes.GUID(), nullable=False + ), + sa.Column( + "deduplication_type", sqlmodel.sql.sqltypes.AutoString(), nullable=False + ), + sa.Column("provider_id", sqlmodel.sql.sqltypes.AutoString(), nullable=True), + sa.Column("provider_type", sqlmodel.sql.sqltypes.AutoString(), nullable=True), + sa.ForeignKeyConstraint( + ["tenant_id"], + ["tenant.id"], + ), + sa.PrimaryKeyConstraint("id"), + ) + op.create_index( + "ix_alert_deduplication_event_provider_id", + "alertdeduplicationevent", + ["provider_id"], + unique=False, + ) + op.create_index( + "ix_alert_deduplication_event_provider_id_date_hour", + "alertdeduplicationevent", + ["provider_id", "date_hour"], + unique=False, + ) + op.create_index( + "ix_alert_deduplication_event_provider_type", + "alertdeduplicationevent", + ["provider_type"], + unique=False, + ) + op.create_index( + "ix_alert_deduplication_event_provider_type_date_hour", + "alertdeduplicationevent", + ["provider_type", "date_hour"], + unique=False, + ) + op.create_index( + op.f("ix_alertdeduplicationevent_tenant_id"), + "alertdeduplicationevent", + ["tenant_id"], + unique=False, + ) + op.create_table( + "alertdeduplicationrule", + sa.Column("fingerprint_fields", sa.JSON(), nullable=True), + sa.Column("ignore_fields", sa.JSON(), nullable=True), + sa.Column("id", sqlmodel.sql.sqltypes.GUID(), nullable=False), + sa.Column("tenant_id", sqlmodel.sql.sqltypes.AutoString(), nullable=False), + sa.Column("name", sqlmodel.sql.sqltypes.AutoString(), nullable=False), + sa.Column("description", sqlmodel.sql.sqltypes.AutoString(), nullable=False), + sa.Column("provider_id", sqlmodel.sql.sqltypes.AutoString(), 
nullable=True), + sa.Column("provider_type", sqlmodel.sql.sqltypes.AutoString(), nullable=False), + sa.Column("last_updated", sa.DateTime(), nullable=False), + sa.Column( + "last_updated_by", sqlmodel.sql.sqltypes.AutoString(), nullable=False + ), + sa.Column("created_at", sa.DateTime(), nullable=False), + sa.Column("created_by", sqlmodel.sql.sqltypes.AutoString(), nullable=False), + sa.Column("enabled", sa.Boolean(), nullable=False), + sa.Column("full_deduplication", sa.Boolean(), nullable=False), + sa.Column("priority", sa.Integer(), nullable=False), + sa.ForeignKeyConstraint( + ["tenant_id"], + ["tenant.id"], + ), + sa.PrimaryKeyConstraint("id"), + ) + op.create_index( + op.f("ix_alertdeduplicationrule_name"), + "alertdeduplicationrule", + ["name"], + unique=False, + ) + op.create_table( + "alertfield", + sa.Column("id", sqlmodel.sql.sqltypes.GUID(), nullable=False), + sa.Column("tenant_id", sqlmodel.sql.sqltypes.AutoString(), nullable=False), + sa.Column("field_name", sqlmodel.sql.sqltypes.AutoString(), nullable=False), + sa.Column("provider_id", sqlmodel.sql.sqltypes.AutoString(), nullable=True), + sa.Column("provider_type", sqlmodel.sql.sqltypes.AutoString(), nullable=True), + sa.ForeignKeyConstraint( + ["tenant_id"], + ["tenant.id"], + ), + sa.PrimaryKeyConstraint("id"), + sa.UniqueConstraint("tenant_id", "field_name", name="uq_tenant_field"), + ) + op.create_index( + "ix_alert_field_provider_id_provider_type", + "alertfield", + ["provider_id", "provider_type"], + unique=False, + ) + op.create_index( + "ix_alert_field_tenant_id", "alertfield", ["tenant_id"], unique=False + ) + op.create_index( + "ix_alert_field_tenant_id_field_name", + "alertfield", + ["tenant_id", "field_name"], + unique=False, + ) + op.create_index( + op.f("ix_alertfield_field_name"), "alertfield", ["field_name"], unique=False + ) + op.create_index( + op.f("ix_alertfield_provider_id"), "alertfield", ["provider_id"], unique=False + ) + op.create_index( + op.f("ix_alertfield_provider_type"), 
+ "alertfield", + ["provider_type"], + unique=False, + ) + op.create_index( + op.f("ix_alertfield_tenant_id"), "alertfield", ["tenant_id"], unique=False + ) + op.drop_table("alertdeduplicationfilter") + # ### end Alembic commands ### + + +def downgrade() -> None: + # ### commands auto generated by Alembic - please adjust! ### + op.create_table( + "alertdeduplicationfilter", + sa.Column("fields", sqlite.JSON(), nullable=True), + sa.Column("id", sa.CHAR(length=32), nullable=False), + sa.Column("tenant_id", sa.VARCHAR(), nullable=False), + sa.Column("matcher_cel", sa.VARCHAR(), nullable=False), + sa.ForeignKeyConstraint( + ["tenant_id"], + ["tenant.id"], + ), + sa.PrimaryKeyConstraint("id"), + ) + op.drop_index(op.f("ix_alertfield_tenant_id"), table_name="alertfield") + op.drop_index(op.f("ix_alertfield_provider_type"), table_name="alertfield") + op.drop_index(op.f("ix_alertfield_provider_id"), table_name="alertfield") + op.drop_index(op.f("ix_alertfield_field_name"), table_name="alertfield") + op.drop_index("ix_alert_field_tenant_id_field_name", table_name="alertfield") + op.drop_index("ix_alert_field_tenant_id", table_name="alertfield") + op.drop_index("ix_alert_field_provider_id_provider_type", table_name="alertfield") + op.drop_table("alertfield") + op.drop_index( + op.f("ix_alertdeduplicationrule_name"), table_name="alertdeduplicationrule" + ) + op.drop_table("alertdeduplicationrule") + op.drop_index( + op.f("ix_alertdeduplicationevent_tenant_id"), + table_name="alertdeduplicationevent", + ) + op.drop_index( + "ix_alert_deduplication_event_provider_type_date_hour", + table_name="alertdeduplicationevent", + ) + op.drop_index( + "ix_alert_deduplication_event_provider_type", + table_name="alertdeduplicationevent", + ) + op.drop_index( + "ix_alert_deduplication_event_provider_id_date_hour", + table_name="alertdeduplicationevent", + ) + op.drop_index( + "ix_alert_deduplication_event_provider_id", table_name="alertdeduplicationevent" + ) + 
op.drop_table("alertdeduplicationevent") + # ### end Alembic commands ### diff --git a/keep/api/models/provider.py b/keep/api/models/provider.py index 76307de20..b800a2890 100644 --- a/keep/api/models/provider.py +++ b/keep/api/models/provider.py @@ -44,4 +44,5 @@ class Provider(BaseModel): ] = [] alertsDistribution: dict[str, int] | None = None alertExample: dict | None = None + default_fingerprint_fields: list[str] | None = None provisioned: bool = False diff --git a/keep/api/routes/deduplications.py b/keep/api/routes/deduplications.py new file mode 100644 index 000000000..0e8b45c54 --- /dev/null +++ b/keep/api/routes/deduplications.py @@ -0,0 +1,142 @@ +import logging +import uuid + +from fastapi import APIRouter, Depends, HTTPException + +from keep.api.alert_deduplicator.alert_deduplicator import AlertDeduplicator +from keep.api.models.alert import DeduplicationRuleRequestDto as DeduplicationRule +from keep.identitymanager.authenticatedentity import AuthenticatedEntity +from keep.identitymanager.identitymanagerfactory import IdentityManagerFactory + +router = APIRouter() + +logger = logging.getLogger(__name__) + + +@router.get( + "", + description="Get Deduplications", +) +def get_deduplications( + authenticated_entity: AuthenticatedEntity = Depends( + IdentityManagerFactory.get_auth_verifier(["read:deduplications"]) + ), +): + tenant_id = authenticated_entity.tenant_id + logger.info("Getting deduplications") + + alert_deduplicator = AlertDeduplicator(tenant_id) + deduplications = alert_deduplicator.get_deduplications() + + logger.info(deduplications) + return deduplications + + +@router.get( + "/fields", + description="Get Optional Fields For Deduplications", +) +def get_deduplication_fields( + authenticated_entity: AuthenticatedEntity = Depends( + IdentityManagerFactory.get_auth_verifier(["read:deduplications"]) + ), +) -> dict[str, list[str]]: + tenant_id = authenticated_entity.tenant_id + logger.info("Getting deduplication fields") + + alert_deduplicator = 
AlertDeduplicator(tenant_id) + fields = alert_deduplicator.get_deduplication_fields() + + logger.info("Got deduplication fields") + return fields + + +@router.post( + "", + description="Create Deduplication Rule", +) +def create_deduplication_rule( + rule: DeduplicationRule, + authenticated_entity: AuthenticatedEntity = Depends( + IdentityManagerFactory.get_auth_verifier(["write:deduplications"]) + ), +): + tenant_id = authenticated_entity.tenant_id + logger.info( + "Creating deduplication rule", + extra={"tenant_id": tenant_id, "rule": rule.dict()}, + ) + alert_deduplicator = AlertDeduplicator(tenant_id) + try: + # This is a custom rule + created_rule = alert_deduplicator.create_deduplication_rule( + rule=rule, created_by=authenticated_entity.email + ) + logger.info("Created deduplication rule") + return created_rule + except HTTPException as e: + raise e + except Exception as e: + logger.exception("Error creating deduplication rule") + raise HTTPException(status_code=400, detail=str(e)) + + +@router.put( + "/{rule_id}", + description="Update Deduplication Rule", +) +def update_deduplication_rule( + rule_id: str, + rule: DeduplicationRule, + authenticated_entity: AuthenticatedEntity = Depends( + IdentityManagerFactory.get_auth_verifier(["write:deduplications"]) + ), +): + tenant_id = authenticated_entity.tenant_id + logger.info("Updating deduplication rule", extra={"rule_id": rule_id}) + alert_deduplicator = AlertDeduplicator(tenant_id) + try: + updated_rule = alert_deduplicator.update_deduplication_rule( + rule_id, rule, authenticated_entity.email + ) + logger.info("Updated deduplication rule") + return updated_rule + except Exception as e: + logger.exception("Error updating deduplication rule") + raise HTTPException(status_code=400, detail=str(e)) + + +@router.delete( + "/{rule_id}", + description="Delete Deduplication Rule", +) +def delete_deduplication_rule( + rule_id: str, + authenticated_entity: AuthenticatedEntity = Depends( + 
IdentityManagerFactory.get_auth_verifier(["write:deduplications"]) + ), +): + tenant_id = authenticated_entity.tenant_id + logger.info("Deleting deduplication rule", extra={"rule_id": rule_id}) + alert_deduplicator = AlertDeduplicator(tenant_id) + + # verify rule id is uuid + try: + uuid.UUID(rule_id) + except ValueError: + raise HTTPException(status_code=400, detail="Invalid rule id") + + try: + success = alert_deduplicator.delete_deduplication_rule(rule_id) + if success: + logger.info("Deleted deduplication rule") + return {"message": "Deduplication rule deleted successfully"} + else: + raise HTTPException(status_code=404, detail="Deduplication rule not found") + except HTTPException as e: + logger.exception("Error deleting deduplication rule") + # keep the same status code + raise e + except Exception as e: + logger.exception("Error deleting deduplication rule") + raise HTTPException(status_code=400, detail=str(e)) diff --git a/keep/api/tasks/process_event_task.py b/keep/api/tasks/process_event_task.py index 114481198..c9cbcf383 100644 --- a/keep/api/tasks/process_event_task.py +++ b/keep/api/tasks/process_event_task.py @@ -17,6 +17,7 @@ from keep.api.bl.enrichments_bl import EnrichmentsBl from keep.api.bl.maintenance_windows_bl import MaintenanceWindowsBl from keep.api.core.db import ( + bulk_upsert_alert_fields, get_alerts_by_fingerprint, get_all_presets, get_enrichment_with_session, @@ -277,16 +278,16 @@ def __handle_formatted_events( alert_deduplicator = AlertDeduplicator(tenant_id) for event in formatted_events: - event_hash, event_deduplicated = alert_deduplicator.is_deduplicated(event) - event.alert_hash = event_hash - event.isDuplicate = event_deduplicated + # apply deduplication + # apply_deduplication set alert_hash and isDuplicate on event + event = alert_deduplicator.apply_deduplication(event) # filter out the deduplicated events deduplicated_events = list( - filter(lambda event: event.isDuplicate, formatted_events) + filter(lambda event: 
event.isFullDuplicate, formatted_events) ) formatted_events = list( - filter(lambda event: not event.isDuplicate, formatted_events) + filter(lambda event: not event.isFullDuplicate, formatted_events) ) # save to db @@ -301,6 +302,39 @@ def __handle_formatted_events( timestamp_forced, ) + # let's save all fields to the DB so that we can use them in the future such in deduplication fields suggestions + # todo: also use it on correlation rules suggestions + for enriched_formatted_event in enriched_formatted_events: + logger.debug( + "Bulk upserting alert fields", + extra={ + "alert_event_id": enriched_formatted_event.event_id, + "alert_fingerprint": enriched_formatted_event.fingerprint, + }, + ) + fields = [] + for key, value in enriched_formatted_event.dict().items(): + if isinstance(value, dict): + for nested_key in value.keys(): + fields.append(f"{key}_{nested_key}") + else: + fields.append(key) + + bulk_upsert_alert_fields( + tenant_id=tenant_id, + fields=fields, + provider_id=enriched_formatted_event.providerId, + provider_type=enriched_formatted_event.providerType, + ) + + logger.debug( + "Bulk upserted alert fields", + extra={ + "alert_event_id": enriched_formatted_event.event_id, + "alert_fingerprint": enriched_formatted_event.fingerprint, + }, + ) + # after the alert enriched and mapped, lets send it to the elasticsearch elastic_client = ElasticClient(tenant_id=tenant_id) for alert in enriched_formatted_events: @@ -482,7 +516,12 @@ def process_event( if provider_type is not None and isinstance(event, dict): provider_class = ProvidersFactory.get_provider_class(provider_type) - event = provider_class.format_alert(event, None) + event = provider_class.format_alert( + tenant_id=tenant_id, + event=event, + provider_id=provider_id, + provider_type=provider_type, + ) # SHAHAR: for aws cloudwatch, we get a subscription notification message that we should skip # todo: move it to be generic if event is None and provider_type == "cloudwatch": diff --git 
a/keep/providers/appdynamics_provider/appdynamics_provider.py b/keep/providers/appdynamics_provider/appdynamics_provider.py index d33598e7d..5e44dce50 100644 --- a/keep/providers/appdynamics_provider/appdynamics_provider.py +++ b/keep/providers/appdynamics_provider/appdynamics_provider.py @@ -6,7 +6,7 @@ import json import tempfile from pathlib import Path -from typing import List, Optional +from typing import List from urllib.parse import urlencode, urljoin import pydantic @@ -322,10 +322,7 @@ def setup_webhook( self.logger.info("Webhook created") @staticmethod - def _format_alert( - event: dict, - provider_instance: Optional["AppdynamicsProvider"] = None, - ) -> AlertDto: + def _format_alert(event: dict) -> AlertDto: return AlertDto( id=event["id"], name=event["name"], diff --git a/keep/providers/azuremonitoring_provider/azuremonitoring_provider.py b/keep/providers/azuremonitoring_provider/azuremonitoring_provider.py index 352d2bf36..30c413902 100644 --- a/keep/providers/azuremonitoring_provider/azuremonitoring_provider.py +++ b/keep/providers/azuremonitoring_provider/azuremonitoring_provider.py @@ -3,7 +3,6 @@ """ import datetime -from typing import Optional from keep.api.models.alert import AlertDto, AlertSeverity, AlertStatus from keep.contextmanager.contextmanager import ContextManager @@ -62,9 +61,7 @@ def validate_config(self): pass @staticmethod - def _format_alert( - event: dict, provider_instance: Optional["AzuremonitoringProvider"] = None - ) -> AlertDto: + def _format_alert(event: dict) -> AlertDto: essentials = event.get("data", {}).get("essentials", {}) alert_context = event.get("data", {}).get("alertContext", {}) diff --git a/keep/providers/base/base_provider.py b/keep/providers/base/base_provider.py index 840869479..ba968837b 100644 --- a/keep/providers/base/base_provider.py +++ b/keep/providers/base/base_provider.py @@ -19,7 +19,7 @@ import requests from keep.api.bl.enrichments_bl import EnrichmentsBl -from keep.api.core.db import get_enrichments 
+from keep.api.core.db import get_custom_deduplication_rule, get_enrichments from keep.api.models.alert import AlertDto, AlertSeverity, AlertStatus from keep.api.models.db.alert import AlertActionType from keep.api.models.db.topology import TopologyServiceInDto @@ -36,9 +36,9 @@ class BaseProvider(metaclass=abc.ABCMeta): PROVIDER_SCOPES: list[ProviderScope] = [] PROVIDER_METHODS: list[ProviderMethod] = [] FINGERPRINT_FIELDS: list[str] = [] - PROVIDER_TAGS: list[Literal["alert", "ticketing", "messaging", "data", "queue", "topology"]] = ( - [] - ) + PROVIDER_TAGS: list[ + Literal["alert", "ticketing", "messaging", "data", "queue", "topology"] + ] = [] def __init__( self, @@ -306,12 +306,49 @@ def _format_alert( def format_alert( cls, event: dict, - provider_instance: Optional["BaseProvider"], + tenant_id: str, + provider_type: str, + provider_id: str, ) -> AlertDto | list[AlertDto]: logger = logging.getLogger(__name__) logger.debug("Formatting alert") - formatted_alert = cls._format_alert(event, provider_instance) + formatted_alert = cls._format_alert(event) logger.debug("Alert formatted") + # after the provider calculated the default fingerprint + # check if there is a custom deduplication rule and apply + custom_deduplication_rule = get_custom_deduplication_rule( + tenant_id=tenant_id, + provider_id=provider_id, + provider_type=provider_type, + ) + + if not isinstance(formatted_alert, list): + formatted_alert.providerId = provider_id + formatted_alert.providerType = provider_type + formatted_alert = [formatted_alert] + + else: + for alert in formatted_alert: + alert.providerId = provider_id + alert.providerType = provider_type + + # if there is no custom deduplication rule, return the formatted alert + if not custom_deduplication_rule: + return formatted_alert + # if there is a custom deduplication rule, apply it + # apply the custom deduplication rule to calculate the fingerprint + for alert in formatted_alert: + logger.info( + "Applying custom deduplication 
rule", + extra={ + "tenant_id": tenant_id, + "provider_id": provider_id, + "alert_id": alert.id, + }, + ) + alert.fingerprint = cls.get_alert_fingerprint( + alert, custom_deduplication_rule.fingerprint_fields + ) return formatted_alert @staticmethod diff --git a/keep/providers/cloudwatch_provider/cloudwatch_provider.py b/keep/providers/cloudwatch_provider/cloudwatch_provider.py index 67b45d822..5dcd8d344 100644 --- a/keep/providers/cloudwatch_provider/cloudwatch_provider.py +++ b/keep/providers/cloudwatch_provider/cloudwatch_provider.py @@ -9,7 +9,6 @@ import logging import os import time -from typing import Optional from urllib.parse import urlparse import boto3 @@ -503,9 +502,7 @@ def parse_event_raw_body(raw_body: bytes | dict) -> dict: return json.loads(raw_body) @staticmethod - def _format_alert( - event: dict, provider_instance: Optional["CloudwatchProvider"] = None - ) -> AlertDto: + def _format_alert(event: dict) -> AlertDto: logger = logging.getLogger(__name__) # if its confirmation event, we need to confirm the subscription if event.get("Type") == "SubscriptionConfirmation": diff --git a/keep/providers/coralogix_provider/coralogix_provider.py b/keep/providers/coralogix_provider/coralogix_provider.py index cd0cf636a..43e2741b3 100644 --- a/keep/providers/coralogix_provider/coralogix_provider.py +++ b/keep/providers/coralogix_provider/coralogix_provider.py @@ -2,13 +2,12 @@ Coralogix is a modern observability platform delivers comprehensive visibility into all your logs, metrics, traces and security events with end-to-end monitoring. 
""" -from typing import Optional - -from keep.api.models.alert import AlertDto, AlertStatus, AlertSeverity +from keep.api.models.alert import AlertDto, AlertSeverity, AlertStatus from keep.contextmanager.contextmanager import ContextManager from keep.providers.base.base_provider import BaseProvider from keep.providers.models.provider_config import ProviderConfig + class CoralogixProvider(BaseProvider): """Get alerts from Coralogix into Keep.""" @@ -49,7 +48,7 @@ class CoralogixProvider(BaseProvider): FINGERPRINT_FIELDS = ["alertUniqueIdentifier"] def __init__( - self, context_manager: ContextManager, provider_id: str, config: ProviderConfig + self, context_manager: ContextManager, provider_id: str, config: ProviderConfig ): super().__init__(context_manager, provider_id, config) @@ -59,42 +58,71 @@ def validate_config(self): """ # no config pass - + def get_value_by_key(fields: dict, key: str): for item in fields: if item["key"] == key: return item["value"] return None - + @staticmethod - def _format_alert( - event: dict, provider_instance: Optional["CoralogixProvider"] = None - ) -> AlertDto: + def _format_alert(event: dict) -> AlertDto: alert = AlertDto( - id=CoralogixProvider.get_value_by_key(event["fields"], "alertUniqueIdentifier") if "fields" in event else None, + id=( + CoralogixProvider.get_value_by_key( + event["fields"], "alertUniqueIdentifier" + ) + if "fields" in event + else None + ), alert_id=event["alert_id"] if "alert_id" in event else None, name=event["name"] if "name" in event else None, description=event["description"] if "description" in event else None, - status=CoralogixProvider.STATUS_MAP.get( - event["alert_action"]), + status=CoralogixProvider.STATUS_MAP.get(event["alert_action"]), severity=CoralogixProvider.SEVERITIES_MAP.get( - CoralogixProvider.get_value_by_key(event["fields"], "severityLowercase")), - lastReceived=CoralogixProvider.get_value_by_key(event["fields"], "timestampISO") if "fields" in event else None, - 
alertUniqueIdentifier=CoralogixProvider.get_value_by_key(event["fields"], "alertUniqueIdentifier") if "fields" in event else None, + CoralogixProvider.get_value_by_key(event["fields"], "severityLowercase") + ), + lastReceived=( + CoralogixProvider.get_value_by_key(event["fields"], "timestampISO") + if "fields" in event + else None + ), + alertUniqueIdentifier=( + CoralogixProvider.get_value_by_key( + event["fields"], "alertUniqueIdentifier" + ) + if "fields" in event + else None + ), uuid=event["uuid"] if "uuid" in event else None, threshold=event["threshold"] if "threshold" in event else None, timewindow=event["timewindow"] if "timewindow" in event else None, - group_by_labels=event["group_by_labels"] if "group_by_labels" in event else None, + group_by_labels=( + event["group_by_labels"] if "group_by_labels" in event else None + ), alert_url=event["alert_url"] if "alert_url" in event else None, log_url=event["log_url"] if "log_url" in event else None, - team=CoralogixProvider.get_value_by_key(event["fields"], "team") if "fields" in event else None, - priority=CoralogixProvider.get_value_by_key(event["fields"], "priority") if "fields" in event else None, - computer=CoralogixProvider.get_value_by_key(event["fields"], "computer") if "fields" in event else None, + team=( + CoralogixProvider.get_value_by_key(event["fields"], "team") + if "fields" in event + else None + ), + priority=( + CoralogixProvider.get_value_by_key(event["fields"], "priority") + if "fields" in event + else None + ), + computer=( + CoralogixProvider.get_value_by_key(event["fields"], "computer") + if "fields" in event + else None + ), fields=event["fields"] if "fields" in event else None, source=["coralogix"], ) return alert - + + if __name__ == "__main__": pass diff --git a/keep/providers/datadog_provider/datadog_provider.py b/keep/providers/datadog_provider/datadog_provider.py index 9b622b0fe..107a3e8ca 100644 --- a/keep/providers/datadog_provider/datadog_provider.py +++ 
b/keep/providers/datadog_provider/datadog_provider.py @@ -8,7 +8,6 @@ import logging import os import time -from typing import Optional import pydantic import requests @@ -789,9 +788,7 @@ def setup_webhook( self.logger.info("Monitors updated") @staticmethod - def _format_alert( - event: dict, provider_instance: Optional["DatadogProvider"] = None - ) -> AlertDto: + def _format_alert(event: dict) -> AlertDto: tags_list = event.get("tags", "").split(",") tags_list.remove("monitor") diff --git a/keep/providers/dynatrace_provider/dynatrace_provider.py b/keep/providers/dynatrace_provider/dynatrace_provider.py index 29b75a0bd..874aec504 100644 --- a/keep/providers/dynatrace_provider/dynatrace_provider.py +++ b/keep/providers/dynatrace_provider/dynatrace_provider.py @@ -8,7 +8,6 @@ import json import logging import os -from typing import Optional import pydantic import requests @@ -135,9 +134,9 @@ def validate_scopes(self): self.logger.info( "Failed to validate dynatrace scopes - wrong environment id" ) - scopes[ - "problems.read" - ] = "Failed to validate scope, wrong environment id (Keep got 404)" + scopes["problems.read"] = ( + "Failed to validate scope, wrong environment id (Keep got 404)" + ) scopes["settings.read"] = scopes["problems.read"] scopes["settings.write"] = scopes["problems.read"] return scopes @@ -146,9 +145,9 @@ def validate_scopes(self): self.logger.info( "Failed to validate dynatrace scopes - invalid API token" ) - scopes[ - "problems.read" - ] = "Invalid API token - authentication failed (401)" + scopes["problems.read"] = ( + "Invalid API token - authentication failed (401)" + ) scopes["settings.read"] = scopes["problems.read"] scopes["settings.write"] = scopes["problems.read"] return scopes @@ -156,9 +155,9 @@ def validate_scopes(self): self.logger.info( "Failed to validate dynatrace scopes - no problems.read scopes" ) - scopes[ - "problems.read" - ] = "Token is missing required scope - problems.read (403)" + scopes["problems.read"] = ( + "Token is 
missing required scope - problems.read (403)" + ) else: self.logger.info("Validated dynatrace scopes - problems.read") scopes["problems.read"] = True @@ -174,9 +173,9 @@ def validate_scopes(self): f"Failed to validate dynatrace scopes - settings.read: {e}" ) scopes["settings.read"] = str(e) - scopes[ - "settings.write" - ] = "Cannot validate the settings.write scope without the settings.read scope, you need to first add the settings.read scope" + scopes["settings.write"] = ( + "Cannot validate the settings.write scope without the settings.read scope, you need to first add the settings.read scope" + ) # we are done return scopes # if we have settings.read, we can try settings.write @@ -197,22 +196,20 @@ def validate_scopes(self): ) # understand if its localhost: if "The environment does not allow for site-local URLs" in str(e): - scopes[ - "settings.write" - ] = "Cannot use localhost as a webhook URL, please use a public URL when installing dynatrace webhook (you can use Keep with ngrok or similar)" + scopes["settings.write"] = ( + "Cannot use localhost as a webhook URL, please use a public URL when installing dynatrace webhook (you can use Keep with ngrok or similar)" + ) else: - scopes[ - "settings.write" - ] = f"Failed to validate the settings.write scope: {e}" + scopes["settings.write"] = ( + f"Failed to validate the settings.write scope: {e}" + ) return scopes self.logger.info(f"Validated dynatrace scopes: {scopes}") return scopes @staticmethod - def _format_alert( - event: dict, provider_instance: Optional["DynatraceProvider"] = None - ) -> AlertDto: + def _format_alert(event: dict) -> AlertDto: # alert that comes from webhook if event.get("ProblemID"): tags = event.get("Tags", []) diff --git a/keep/providers/gcpmonitoring_provider/gcpmonitoring_provider.py b/keep/providers/gcpmonitoring_provider/gcpmonitoring_provider.py index 38d678fe8..2a4409ee7 100644 --- a/keep/providers/gcpmonitoring_provider/gcpmonitoring_provider.py +++ 
b/keep/providers/gcpmonitoring_provider/gcpmonitoring_provider.py @@ -3,7 +3,6 @@ """ import datetime -from typing import Optional from keep.api.models.alert import AlertDto, AlertSeverity, AlertStatus from keep.contextmanager.contextmanager import ContextManager @@ -64,9 +63,7 @@ def validate_config(self): pass @staticmethod - def _format_alert( - event: dict, provider_instance: Optional["GcpmonitoringProvider"] = None - ) -> AlertDto: + def _format_alert(event: dict) -> AlertDto: incident = event.get("incident", {}) description = incident.pop("summary", "") status = GcpmonitoringProvider.STATUS_MAP.get( diff --git a/keep/providers/grafana_provider/grafana_provider.py b/keep/providers/grafana_provider/grafana_provider.py index 5ae4f8f31..69f2c3e71 100644 --- a/keep/providers/grafana_provider/grafana_provider.py +++ b/keep/providers/grafana_provider/grafana_provider.py @@ -4,7 +4,6 @@ import dataclasses import datetime -from typing import Optional import pydantic import requests @@ -51,6 +50,8 @@ class GrafanaProvider(BaseProvider): """Pull/Push alerts from Grafana.""" KEEP_GRAFANA_WEBHOOK_INTEGRATION_NAME = "keep-grafana-webhook-integration" + FINGERPRINT_FIELDS = ["fingerprint"] + PROVIDER_SCOPES = [ ProviderScope( name="alert.rules:read", @@ -194,9 +195,7 @@ def get_alert_schema(): return GrafanaAlertFormatDescription.schema() @staticmethod - def _format_alert( - event: dict, provider_instance: Optional["GrafanaProvider"] = None - ) -> AlertDto: + def _format_alert(event: dict) -> AlertDto: alerts = event.get("alerts", []) formatted_alerts = [] for alert in alerts: diff --git a/keep/providers/incidentmanager_provider/incidentmanager_provider.py b/keep/providers/incidentmanager_provider/incidentmanager_provider.py index 4d026f55f..ac4a69187 100644 --- a/keep/providers/incidentmanager_provider/incidentmanager_provider.py +++ b/keep/providers/incidentmanager_provider/incidentmanager_provider.py @@ -5,7 +5,6 @@ import dataclasses import logging import os -from 
typing import Optional from urllib.parse import urlparse from uuid import uuid4 @@ -413,9 +412,7 @@ def setup_webhook( self.logger.info("Webhook setup completed!") @staticmethod - def _format_alert( - event: dict, provider_instance: Optional["IncidentmanagerProvider"] - ) -> AlertDto: + def _format_alert(event: dict) -> AlertDto: logger = logging.getLogger(__name__) # if its confirmation event, we need to confirm the subscription if event.get("Type") == "SubscriptionConfirmation": diff --git a/keep/providers/keep_provider/keep_provider.py b/keep/providers/keep_provider/keep_provider.py index ec7835232..d8bafa788 100644 --- a/keep/providers/keep_provider/keep_provider.py +++ b/keep/providers/keep_provider/keep_provider.py @@ -3,7 +3,6 @@ """ import logging -from typing import Optional from keep.api.core.db import get_alerts_with_filters from keep.api.models.alert import AlertDto @@ -86,9 +85,7 @@ def validate_config(self): pass @staticmethod - def _format_alert( - event: dict, provider_instance: Optional["KeepProvider"] = None - ) -> AlertDto: + def _format_alert(event: dict) -> AlertDto: return AlertDto( **event, ) diff --git a/keep/providers/kibana_provider/kibana_provider.py b/keep/providers/kibana_provider/kibana_provider.py index 5cd5f3aa3..92d8f9adb 100644 --- a/keep/providers/kibana_provider/kibana_provider.py +++ b/keep/providers/kibana_provider/kibana_provider.py @@ -6,7 +6,7 @@ import datetime import json import uuid -from typing import Literal, Optional +from typing import Literal from urllib.parse import urlparse import pydantic @@ -471,9 +471,7 @@ def format_alert_from_watcher(event: dict) -> AlertDto | list[AlertDto]: ) @staticmethod - def _format_alert( - event: dict, provider_instance: Optional["KibanaProvider"] = None - ) -> AlertDto | list[AlertDto]: + def _format_alert(event: dict) -> AlertDto | list[AlertDto]: """ Formats an alert from Kibana to a standard format. 
diff --git a/keep/providers/netdata_provider/netdata_provider.py b/keep/providers/netdata_provider/netdata_provider.py index 3e22e2204..61939d006 100644 --- a/keep/providers/netdata_provider/netdata_provider.py +++ b/keep/providers/netdata_provider/netdata_provider.py @@ -2,13 +2,12 @@ Netdata is a cloud-based monitoring tool that provides real-time monitoring of servers, applications, and devices. """ -from typing import Optional - -from keep.api.models.alert import AlertDto, AlertStatus, AlertSeverity +from keep.api.models.alert import AlertDto, AlertSeverity, AlertStatus from keep.contextmanager.contextmanager import ContextManager from keep.providers.base.base_provider import BaseProvider from keep.providers.models.provider_config import ProviderConfig + class NetdataProvider(BaseProvider): """Get alerts from Netdata into Keep.""" @@ -60,19 +59,28 @@ def validate_config(self): pass @staticmethod - def _format_alert( - event: dict, provider_instance: Optional["NetdataProvider"] = None - ) -> AlertDto: - + def _format_alert(event: dict) -> AlertDto: alert = AlertDto( id=event["id"] if "id" in event else None, name=event["name"] if "name" in event else None, host=event["host"], message=event["message"], - severity=NetdataProvider.SEVERITIES_MAP.get(event["severity"], AlertSeverity.INFO), - status=NetdataProvider.STATUS_MAP.get(event["status"]["text"], AlertStatus.INFO) if "status" in event else AlertStatus.INFO, + severity=NetdataProvider.SEVERITIES_MAP.get( + event["severity"], AlertSeverity.INFO + ), + status=( + NetdataProvider.STATUS_MAP.get( + event["status"]["text"], AlertStatus.INFO + ) + if "status" in event + else AlertStatus.INFO + ), alert=event["alert"] if "alert" in event else None, - url=event["alert_url"] or event["url"] if "alert_url" in event or "url" in event else None, + url=( + event["alert_url"] or event["url"] + if "alert_url" in event or "url" in event + else None + ), chart=event["chart"] if "chart" in event else None, 
alert_class=event["class"] if "class" in event else None, context=event["context"] if "context" in event else None, @@ -80,12 +88,17 @@ def _format_alert( duration=event["duration"] if "duration" in event else None, info=event["info"] if "info" in event else None, space=event["space"] if "space" in event else None, - total_critical=event["total_critical"] if "total_critical" in event else None, - total_warnings=event["total_warnings"] if "total_warnings" in event else None, + total_critical=( + event["total_critical"] if "total_critical" in event else None + ), + total_warnings=( + event["total_warnings"] if "total_warnings" in event else None + ), value=event["value"] if "value" in event else None, ) return alert - + + if __name__ == "__main__": pass diff --git a/keep/providers/newrelic_provider/newrelic_provider.py b/keep/providers/newrelic_provider/newrelic_provider.py index ee13b30db..5ac143cd5 100644 --- a/keep/providers/newrelic_provider/newrelic_provider.py +++ b/keep/providers/newrelic_provider/newrelic_provider.py @@ -6,7 +6,6 @@ import json import logging from datetime import datetime -from typing import Optional import pydantic import requests @@ -430,9 +429,7 @@ def get_alerts(self) -> list[AlertDto]: return formatted_alerts @staticmethod - def _format_alert( - event: dict, provider_instance: Optional["NewrelicProvider"] = None - ) -> AlertDto: + def _format_alert(event: dict) -> AlertDto: """We are already registering template same as generic AlertDTO""" logger = logging.getLogger(__name__) logger.info("Got event from New Relic") diff --git a/keep/providers/openobserve_provider/openobserve_provider.py b/keep/providers/openobserve_provider/openobserve_provider.py index c7d1ef996..5e5f0c357 100644 --- a/keep/providers/openobserve_provider/openobserve_provider.py +++ b/keep/providers/openobserve_provider/openobserve_provider.py @@ -7,7 +7,7 @@ import logging import uuid from pathlib import Path -from typing import List, Optional +from typing import List 
from urllib.parse import urlencode, urljoin import pydantic @@ -369,10 +369,7 @@ def setup_webhook( self.logger.info("Webhook created") @staticmethod - def _format_alert( - event: dict, - provider_instance: Optional["OpenobserveProvider"] = None, - ) -> AlertDto: + def _format_alert(event: dict) -> AlertDto: logger = logging.getLogger(__name__) name = event.pop("alert_name", "") # openoboserve does not provide severity diff --git a/keep/providers/pagerduty_provider/pagerduty_provider.py b/keep/providers/pagerduty_provider/pagerduty_provider.py index 6db3cee08..664a771b9 100644 --- a/keep/providers/pagerduty_provider/pagerduty_provider.py +++ b/keep/providers/pagerduty_provider/pagerduty_provider.py @@ -99,7 +99,7 @@ def validate_config(self): "PagerdutyProvider requires either routing_key or api_key", provider_id=self.provider_id, ) - + def validate_scopes(self): """ Validate that the provider has the required scopes. @@ -306,9 +306,7 @@ def _get_alerts(self) -> list[AlertDto]: return incidents @staticmethod - def _format_alert( - event: dict, provider_instance: typing.Optional["PagerdutyProvider"] = None - ) -> AlertDto: + def _format_alert(event: dict) -> AlertDto: actual_event = event.get("event", {}) data = actual_event.get("data", {}) url = data.pop("self", data.pop("html_url")) @@ -348,7 +346,6 @@ def _format_alert( "impacted_services": service, } - return AlertDto( **data, url=url, diff --git a/keep/providers/parseable_provider/parseable_provider.py b/keep/providers/parseable_provider/parseable_provider.py index cc8d5dcc9..be3064918 100644 --- a/keep/providers/parseable_provider/parseable_provider.py +++ b/keep/providers/parseable_provider/parseable_provider.py @@ -6,7 +6,6 @@ import datetime import logging import os -from typing import Optional from uuid import uuid4 import pydantic @@ -119,9 +118,7 @@ def validate_config(self): ) @staticmethod - def _format_alert( - event: dict, provider_instance: Optional["ParseableProvider"] = None - ) -> AlertDto: + def 
_format_alert(event: dict) -> AlertDto: environment = "unknown" id = event.pop("id", str(uuid4())) name = event.pop("alert", "") diff --git a/keep/providers/pingdom_provider/pingdom_provider.py b/keep/providers/pingdom_provider/pingdom_provider.py index 24d5e5896..0c025c8e6 100644 --- a/keep/providers/pingdom_provider/pingdom_provider.py +++ b/keep/providers/pingdom_provider/pingdom_provider.py @@ -1,6 +1,5 @@ import dataclasses import datetime -from typing import Optional import pydantic import requests @@ -144,9 +143,7 @@ def _get_alerts(self) -> list[AlertDto]: return alerts_dtos @staticmethod - def _format_alert( - event: dict, provider_instance: Optional["PingdomProvider"] - ) -> AlertDto: + def _format_alert(event: dict) -> AlertDto: # https://pingdom.com/resources/webhooks/#Examples-of-webhook-JSON-output-for-uptime-checks # map severity and status to keep's format diff --git a/keep/providers/prometheus_provider/prometheus_provider.py b/keep/providers/prometheus_provider/prometheus_provider.py index 7f4ecf0c8..47568bb14 100644 --- a/keep/providers/prometheus_provider/prometheus_provider.py +++ b/keep/providers/prometheus_provider/prometheus_provider.py @@ -5,7 +5,6 @@ import dataclasses import datetime import os -from typing import Optional import pydantic import requests @@ -82,6 +81,7 @@ class PrometheusProvider(BaseProvider): name="connectivity", description="Connectivity Test", mandatory=True ) ] + FINGERPRINT_FIELDS = ["fingerprint"] def __init__( self, context_manager: ContextManager, provider_id: str, config: ProviderConfig @@ -153,16 +153,8 @@ def _get_alerts(self) -> list[AlertDto]: alert_dtos = self._format_alert(alerts_data) return alert_dtos - def get_status(event: dict) -> AlertStatus: - return PrometheusProvider.STATUS_MAP.get( - event.get("status", event.get("state", "firing")) - ) - @staticmethod - def _format_alert( - event: dict | list[AlertDto], - provider_instance: Optional["PrometheusProvider"] = None, - ) -> list[AlertDto]: + def 
_format_alert(event: dict | list[AlertDto]) -> list[AlertDto]: # TODO: need to support more than 1 alert per event alert_dtos = [] if isinstance(event, list): diff --git a/keep/providers/providers_factory.py b/keep/providers/providers_factory.py index aadddd514..0b3385625 100644 --- a/keep/providers/providers_factory.py +++ b/keep/providers/providers_factory.py @@ -19,6 +19,7 @@ get_installed_providers, get_linked_providers, ) +from keep.api.models.alert import DeduplicationRuleDto from keep.api.models.provider import Provider from keep.contextmanager.contextmanager import ContextManager from keep.providers.base.base_provider import BaseProvider, BaseTopologyProvider @@ -341,6 +342,13 @@ def get_all_providers() -> list[Provider]: # not all providers have this method (yet ^^) except Exception: alert_example = None + + # Add default fingerprint fields if available + if hasattr(provider_class, "FINGERPRINT_FIELDS"): + default_fingerprint_fields = provider_class.FINGERPRINT_FIELDS + else: + default_fingerprint_fields = [] + providers.append( Provider( type=provider_type, @@ -359,6 +367,7 @@ def get_all_providers() -> list[Provider]: methods=provider_methods, tags=provider_tags, alertExample=alert_example, + default_fingerprint_fields=default_fingerprint_fields, ) ) except ModuleNotFoundError: @@ -512,3 +521,39 @@ def get_linked_providers(tenant_id: str) -> list[Provider]: _linked_providers.append(provider) return _linked_providers + + @staticmethod + def get_default_deduplication_rules() -> list[DeduplicationRuleDto]: + """ + Get the default deduplications for all providers with FINGERPRINT_FIELDS. + + Returns: + list: The default deduplications for each provider. 
+ """ + default_deduplications = [] + all_providers = ProvidersFactory.get_all_providers() + + for provider in all_providers: + if provider.default_fingerprint_fields: + deduplication_dto = DeduplicationRuleDto( + name=f"{provider.type}_default", + description=f"{provider.display_name} default deduplication rule", + default=True, + distribution=[{"hour": i, "number": 0} for i in range(24)], + provider_type=provider.type, + last_updated="", + last_updated_by="", + created_at="", + created_by="", + ingested=0, + dedup_ratio=0.0, + enabled=True, + fingerprint_fields=provider.default_fingerprint_fields, + # default provider deduplication rules are not full deduplication + full_deduplication=False, + # not relevant for default deduplication rules + ignore_fields=[], + ) + default_deduplications.append(deduplication_dto) + + return default_deduplications diff --git a/keep/providers/providers_service.py b/keep/providers/providers_service.py index debb97662..dad1d8d37 100644 --- a/keep/providers/providers_service.py +++ b/keep/providers/providers_service.py @@ -285,7 +285,7 @@ def provision_providers_from_env(tenant_id: str): provisioned=True, validate_scopes=False, ) + logger.info(f"Provider {provider_name} provisioned") except Exception: logger.exception(f"Failed to provision provider {provider_name}") continue - logger.info(f"Provider {provider_name} provisioned") diff --git a/keep/providers/rollbar_provider/rollbar_provider.py b/keep/providers/rollbar_provider/rollbar_provider.py index 04e8457e3..db72ed27a 100644 --- a/keep/providers/rollbar_provider/rollbar_provider.py +++ b/keep/providers/rollbar_provider/rollbar_provider.py @@ -3,24 +3,25 @@ """ import dataclasses -import pydantic - import datetime -import requests - -from typing import List, Optional +from typing import List from urllib.parse import urljoin +import pydantic +import requests + from keep.api.models.alert import AlertDto, AlertSeverity from keep.contextmanager.contextmanager import ContextManager from 
keep.providers.base.base_provider import BaseProvider from keep.providers.models.provider_config import ProviderConfig, ProviderScope + @pydantic.dataclasses.dataclass class RollbarProviderAuthConfig: """ RollbarProviderAuthConfig is a class that allows to authenticate in Rollbar. """ + rollbarAccessToken: str = dataclasses.field( metadata={ "required": True, @@ -30,6 +31,7 @@ class RollbarProviderAuthConfig: default=None, ) + class RollbarProvider(BaseProvider): PROVIDER_DISPLAY_NAME = "Rollbar" PROVIDER_TAGS = ["alert"] @@ -46,11 +48,11 @@ class RollbarProvider(BaseProvider): "error": AlertSeverity.HIGH, "info": AlertSeverity.INFO, "critical": AlertSeverity.CRITICAL, - "debug": AlertSeverity.LOW + "debug": AlertSeverity.LOW, } def __init__( - self, context_manager: ContextManager, provider_id: str, config: ProviderConfig + self, context_manager: ContextManager, provider_id: str, config: ProviderConfig ): super().__init__(context_manager, provider_id, config) @@ -78,7 +80,7 @@ def __get_headers(self): return { "X-Rollbar-Access-Token": self.authentication_config.rollbarAccessToken, "accept": "application/json; charset=utf-8", - "content-type": "application/json" + "content-type": "application/json", } def validate_scopes(self) -> dict[str, bool | str]: @@ -86,50 +88,60 @@ def validate_scopes(self) -> dict[str, bool | str]: Validate the scopes of the provider. 
""" try: - response = requests.get(self.__get_url("items"), headers=self.__get_headers()) + response = requests.get( + self.__get_url("items"), headers=self.__get_headers() + ) if response.status_code == 200: - scopes = { - "authenticated": True - } + scopes = {"authenticated": True} else: - self.logger.error("Unable to read projects from Rollbar, statusCode: %s", response.status_code) + self.logger.error( + "Unable to read projects from Rollbar, statusCode: %s", + response.status_code, + ) scopes = { "authenticated": f"Unable to read projects from Rollbar, statusCode: {response.status_code}" } except Exception as e: self.logger.error("Error validating scopes for Rollbar: %s", e) - scopes = { - "authenticated": f"Error validating scopes for Rollbar: {e}" - } + scopes = {"authenticated": f"Error validating scopes for Rollbar: {e}"} return scopes def __get_occurences(self) -> List[AlertDto]: try: - response = requests.get(self.__get_url("instances"), headers=self.__get_headers()) + response = requests.get( + self.__get_url("instances"), headers=self.__get_headers() + ) if not response.ok: - self.logger.error("Failed to get occurrences from Rollbar: %s", response.json()) - raise Exception("Could not get occurrences from Rollbar") - - return [AlertDto( - id=alert["id"], - name=alert["project_id"], - environment=alert["data"]["environment"], - event_id=alert["data"]["uuid"], - language=alert["data"]["language"], - message=alert["data"]["body"]["message"]["body"], - host=alert["data"]["server"]["host"], - pid=alert["data"]["server"]["pid"], - severity=RollbarProvider.SEVERITIES_MAP[alert["data"]["level"]], - lastReceived=datetime.datetime.fromtimestamp(alert["timestamp"]).isoformat(), - ) for alert in response.json()["result"]["instances"]] - + self.logger.error( + "Failed to get occurrences from Rollbar: %s", response.json() + ) + raise Exception("Could not get occurrences from Rollbar") + + return [ + AlertDto( + id=alert["id"], + name=alert["project_id"], + 
environment=alert["data"]["environment"], + event_id=alert["data"]["uuid"], + language=alert["data"]["language"], + message=alert["data"]["body"]["message"]["body"], + host=alert["data"]["server"]["host"], + pid=alert["data"]["server"]["pid"], + severity=RollbarProvider.SEVERITIES_MAP[alert["data"]["level"]], + lastReceived=datetime.datetime.fromtimestamp( + alert["timestamp"] + ).isoformat(), + ) + for alert in response.json()["result"]["instances"] + ] + except Exception as e: self.logger.error("Error getting occurrences from Rollbar: %s", e) raise Exception(f"Error getting occurrences from Rollbar: {e}") - + def _get_alerts(self) -> List[AlertDto]: alerts = [] try: @@ -140,31 +152,32 @@ def _get_alerts(self) -> List[AlertDto]: self.logger.error("Error getting occurrences from Rollbar: %s", e) return alerts - + @staticmethod - def _format_alert( - event: dict, - provider_instance: Optional["RollbarProvider"] = None, - ) -> AlertDto: - item_data = event['data']['item'] - occurrence_data = event['data']['occurrence'] + def _format_alert(event: dict) -> AlertDto: + item_data = event["data"]["item"] + occurrence_data = event["data"]["occurrence"] return AlertDto( - id=str(item_data['id']), - name=event['event_name'], + id=str(item_data["id"]), + name=event["event_name"], severity=RollbarProvider.SEVERITIES_MAP[occurrence_data["level"]], - lastReceived=datetime.datetime.fromtimestamp(item_data['last_occurrence_timestamp']).isoformat(), - environment=item_data['environment'], - service='Rollbar', - source=[occurrence_data['framework']], - url=event['data']['url'], - message=occurrence_data['body']['message']['body'], - description=item_data['title'], - event_id=str(occurrence_data['uuid']), - labels={'level': item_data['level']}, - fingerprint=item_data['hash'], + lastReceived=datetime.datetime.fromtimestamp( + item_data["last_occurrence_timestamp"] + ).isoformat(), + environment=item_data["environment"], + service="Rollbar", + source=[occurrence_data["framework"]], + 
url=event["data"]["url"], + message=occurrence_data["body"]["message"]["body"], + description=item_data["title"], + event_id=str(occurrence_data["uuid"]), + labels={"level": item_data["level"]}, + fingerprint=item_data["hash"], ) - def setup_webhook(self, tenant_id: str, keep_api_url: str, api_key: str, setup_alerts: bool = True): + def setup_webhook( + self, tenant_id: str, keep_api_url: str, api_key: str, setup_alerts: bool = True + ): self.logger.info("Setting up webhook for Rollbar") self.logger.info("Enabling Webhook in Rollbar") try: @@ -174,7 +187,7 @@ def setup_webhook(self, tenant_id: str, keep_api_url: str, api_key: str, setup_a json={ "enabled": True, "url": f"{keep_api_url}?api_key={api_key}", - } + }, ) if response.ok: @@ -185,19 +198,22 @@ def setup_webhook(self, tenant_id: str, keep_api_url: str, api_key: str, setup_a { "trigger": "occurrence", } - } + }, ) if response.ok: self.logger.info("Created occurrence rule in Rollbar") else: - self.logger.error("Failed to enable webhook in Rollbar: %s", response.json()) + self.logger.error( + "Failed to enable webhook in Rollbar: %s", response.json() + ) raise Exception("Failed to enable webhook in Rollbar") - + self.logger.info("Webhook enabled in Rollbar") except Exception as e: self.logger.error("Error setting up webhook for Rollbar: %s", e) raise Exception(f"Error setting up webhook for Rollbar: {e}") - + + if __name__ == "__main__": import logging @@ -213,7 +229,7 @@ def setup_webhook(self, tenant_id: str, keep_api_url: str, api_key: str, setup_a if rollbar_host is None: raise Exception("ROLLBAR_HOST is not set") - + config = ProviderConfig( description="Rollbar Provider", authentication={ @@ -227,4 +243,4 @@ def setup_webhook(self, tenant_id: str, keep_api_url: str, api_key: str, setup_a config=config, ) - provider._get_alerts() \ No newline at end of file + provider._get_alerts() diff --git a/keep/providers/sentry_provider/sentry_provider.py b/keep/providers/sentry_provider/sentry_provider.py index 
ee9117136..099ad1019 100644 --- a/keep/providers/sentry_provider/sentry_provider.py +++ b/keep/providers/sentry_provider/sentry_provider.py @@ -5,7 +5,6 @@ import dataclasses import datetime import logging -from typing import Optional import pydantic import requests @@ -202,9 +201,7 @@ def validate_scopes(self) -> dict[str, bool | str]: return validated_scopes @staticmethod - def _format_alert( - event: dict, provider_instance: Optional["SentryProvider"] = None - ) -> AlertDto | list[AlertDto]: + def _format_alert(event: dict) -> AlertDto | list[AlertDto]: logger = logging.getLogger(__name__) logger.info( "Formatting Sentry alert", diff --git a/keep/providers/signalfx_provider/signalfx_provider.py b/keep/providers/signalfx_provider/signalfx_provider.py index ecfb6234f..68ed28ccb 100644 --- a/keep/providers/signalfx_provider/signalfx_provider.py +++ b/keep/providers/signalfx_provider/signalfx_provider.py @@ -1,7 +1,6 @@ import base64 import dataclasses import datetime -from typing import Optional from urllib.parse import quote, urlparse import pydantic @@ -211,9 +210,7 @@ def _format_alert_get_alert(self, incident: dict) -> AlertDto: return alert_dto @staticmethod - def _format_alert( - event: dict, provider_instance: Optional["SignalfxProvider"] = None - ) -> AlertDto: + def _format_alert(event: dict) -> AlertDto: # Transform a SignalFx event into an AlertDto object # see: https://docs.splunk.com/observability/en/admin/notif-services/webhook.html#observability-cloud-webhook-request-body-fields severity = SignalfxProvider.SEVERITIES_MAP.get( diff --git a/keep/providers/site24x7_provider/site24x7_provider.py b/keep/providers/site24x7_provider/site24x7_provider.py index c46867e9e..dba231137 100644 --- a/keep/providers/site24x7_provider/site24x7_provider.py +++ b/keep/providers/site24x7_provider/site24x7_provider.py @@ -3,7 +3,7 @@ """ import dataclasses -from typing import List, Optional +from typing import List from urllib.parse import urlencode, urljoin import 
pydantic @@ -25,6 +25,7 @@ class Site24X7ProviderAuthConfig: """ Site24x7 authentication configuration. """ + zohoRefreshToken: str = dataclasses.field( metadata={ "required": True, @@ -82,11 +83,11 @@ class Site24X7Provider(BaseProvider): "DOWN": AlertSeverity.WARNING, "TROUBLE": AlertSeverity.HIGH, "UP": AlertSeverity.INFO, - "CRITICAL": AlertSeverity.CRITICAL + "CRITICAL": AlertSeverity.CRITICAL, } def __init__( - self, context_manager: ContextManager, provider_id: str, config: ProviderConfig + self, context_manager: ContextManager, provider_id: str, config: ProviderConfig ): super().__init__(context_manager, provider_id, config) @@ -133,15 +134,17 @@ def __get_headers(self): Getting the access token from Zoho API using the permanent refresh token. """ data = { - 'client_id': self.authentication_config.zohoClientId, - 'client_secret': self.authentication_config.zohoClientSecret, - 'refresh_token': self.authentication_config.zohoRefreshToken, - 'grant_type': 'refresh_token', + "client_id": self.authentication_config.zohoClientId, + "client_secret": self.authentication_config.zohoClientSecret, + "refresh_token": self.authentication_config.zohoRefreshToken, + "grant_type": "refresh_token", } - response = requests.post(f'https://accounts.zoho{self.authentication_config.zohoAccountTLD}/oauth/v2/token', - data=data).json() + response = requests.post( + f"https://accounts.zoho{self.authentication_config.zohoAccountTLD}/oauth/v2/token", + data=data, + ).json() return { - 'Authorization': f'Bearer {response["access_token"]}', + "Authorization": f'Bearer {response["access_token"]}', } def validate_scopes(self) -> dict[str, bool | str]: @@ -150,23 +153,32 @@ def validate_scopes(self) -> dict[str, bool | str]: authentication_scope = "Validate TLD first" if self.authentication_config.zohoAccountTLD in valid_tlds: valid_tld_scope = True - response = requests.get(f'{self.__get_url(paths=["monitors"])}', headers=self.__get_headers()) + response = requests.get( + 
f'{self.__get_url(paths=["monitors"])}', headers=self.__get_headers() + ) if response.status_code == 401: authentication_scope = response.json() - self.logger.error("Failed to authenticate user", extra=authentication_scope) + self.logger.error( + "Failed to authenticate user", extra=authentication_scope + ) elif response.status_code == 200: authentication_scope = True self.logger.info("Authenticated user successfully") else: - authentication_scope = f"Error while authenticating user, {response.status_code}" - self.logger.error("Error while authenticating user", extra={"status_code": response.status_code}) + authentication_scope = ( + f"Error while authenticating user, {response.status_code}" + ) + self.logger.error( + "Error while authenticating user", + extra={"status_code": response.status_code}, + ) return { - 'authenticated': authentication_scope, - 'valid_tld': valid_tld_scope, + "authenticated": authentication_scope, + "valid_tld": valid_tld_scope, } def setup_webhook( - self, tenant_id: str, keep_api_url: str, api_key: str, setup_alerts: bool = True + self, tenant_id: str, keep_api_url: str, api_key: str, setup_alerts: bool = True ): webhook_data = { "method": "P", @@ -174,12 +186,7 @@ def setup_webhook( "is_poller_webhook": False, "type": 8, "alert_tags_id": [], - "custom_headers": [ - { - "name": "X-API-KEY", - "value": api_key - } - ], + "custom_headers": [{"name": "X-API-KEY", "value": api_key}], "url": keep_api_url, "timeout": 30, "selection_type": 0, @@ -190,25 +197,25 @@ def setup_webhook( "send_incident_parameters": True, "service_status": 0, "name": "KeepWebhook", - "manage_tickets": False + "manage_tickets": False, } - response = requests.post(self.__get_url(paths=["integration/webhooks"]), json=webhook_data, - headers=self.__get_headers()) + response = requests.post( + self.__get_url(paths=["integration/webhooks"]), + json=webhook_data, + headers=self.__get_headers(), + ) if not response.ok: response_json = response.json() self.logger.error("Error 
while creating webhook", extra=response_json) - raise Exception(response_json['message']) + raise Exception(response_json["message"]) else: self.logger.info("Webhook created successfully") @staticmethod - def _format_alert( - event: dict, - provider_instance: Optional["Site24X7Provider"] = None, - ) -> AlertDto: + def _format_alert(event: dict) -> AlertDto: return AlertDto( url=event.get("MONITORURL", ""), - lastReceived=event.get('INCIDENT_TIME', ""), + lastReceived=event.get("INCIDENT_TIME", ""), description=event.get("INCIDENT_REASON", ""), name=event.get("MONITORNAME", ""), id=event.get("MONITOR_ID", ""), @@ -216,11 +223,13 @@ def _format_alert( ) def _get_alerts(self) -> list[AlertDto]: - response = requests.get(self.__get_url(paths=['alert_logs']), headers=self.__get_headers()) + response = requests.get( + self.__get_url(paths=["alert_logs"]), headers=self.__get_headers() + ) if response.status_code == 200: alerts = [] response = response.json() - for alert in response['data']: + for alert in response["data"]: alerts.append( AlertDto( name=alert["display_name"], diff --git a/keep/providers/splunk_provider/splunk_provider.py b/keep/providers/splunk_provider/splunk_provider.py index 3334d5d07..126530519 100644 --- a/keep/providers/splunk_provider/splunk_provider.py +++ b/keep/providers/splunk_provider/splunk_provider.py @@ -1,7 +1,6 @@ import dataclasses import datetime import json -from typing import Optional import pydantic from splunklib.client import connect @@ -150,9 +149,7 @@ def setup_webhook( saved_search.update(**creation_updation_kwargs).refresh() @staticmethod - def _format_alert( - event: dict, provider_instance: Optional["SplunkProvider"] = None - ) -> AlertDto: + def _format_alert(event: dict) -> AlertDto: result: dict = event.get("result", event.get("_result", {})) try: diff --git a/keep/providers/statuscake_provider/statuscake_provider.py b/keep/providers/statuscake_provider/statuscake_provider.py index 263244f84..37fa41018 100644 --- 
a/keep/providers/statuscake_provider/statuscake_provider.py +++ b/keep/providers/statuscake_provider/statuscake_provider.py @@ -6,264 +6,297 @@ import pydantic import requests -from typing import Optional from keep.api.models.alert import AlertDto, AlertSeverity, AlertStatus from keep.contextmanager.contextmanager import ContextManager from keep.providers.base.base_provider import BaseProvider from keep.providers.models.provider_config import ProviderConfig, ProviderScope + @pydantic.dataclasses.dataclass class StatuscakeProviderAuthConfig: - """ - StatuscakeProviderAuthConfig is a class that holds the authentication information for the StatuscakeProvider. - """ - - api_key: str = dataclasses.field( - metadata={ - "required": True, - "description": "Statuscake API Key", - "sensitive": True, - }, - default=None, - ) - -class StatuscakeProvider(BaseProvider): - PROVIDER_DISPLAY_NAME = "Statuscake" - PROVIDER_TAGS = ["alert"] + """ + StatuscakeProviderAuthConfig is a class that holds the authentication information for the StatuscakeProvider. 
+ """ - PROVIDER_SCOPES = [ - ProviderScope( - name="alerts", - description="Read alerts from Statuscake", + api_key: str = dataclasses.field( + metadata={ + "required": True, + "description": "Statuscake API Key", + "sensitive": True, + }, + default=None, ) - ] - SEVERITIES_MAP = { - "high": AlertSeverity.HIGH, - } - STATUS_MAP = { - "up": AlertStatus.RESOLVED, - "down": AlertStatus.FIRING, - } +class StatuscakeProvider(BaseProvider): + PROVIDER_DISPLAY_NAME = "Statuscake" + PROVIDER_TAGS = ["alert"] + + PROVIDER_SCOPES = [ + ProviderScope( + name="alerts", + description="Read alerts from Statuscake", + ) + ] + + SEVERITIES_MAP = { + "high": AlertSeverity.HIGH, + } + + STATUS_MAP = { + "up": AlertStatus.RESOLVED, + "down": AlertStatus.FIRING, + } + + def __init__( + self, context_manager: ContextManager, provider_id: str, config: ProviderConfig + ): + super().__init__(context_manager, provider_id, config) + + def dispose(self): + pass + + def validate_scopes(self): + """ + Validate that the user has the required scopes to use the provider + """ + try: + response = requests.get( + "https://api.statuscake.com/v1/uptime/", + headers=self.__get_auth_headers(), + ) + + if response.status_code == 200: + scopes = {"alerts": True} + + else: + self.logger.error( + "Unable to read alerts from Statuscake, statusCode: %s", + response.status_code, + ) + scopes = { + "alerts": f"Unable to read alerts from Statuscake, statusCode: {response.status_code}" + } + + except Exception as e: + self.logger.error("Error validating scopes for Statuscake: %s", e) + scopes = {"alerts": f"Error validating scopes for Statuscake: {e}"} + + return scopes + + def validate_config(self): + self.authentication_config = StatuscakeProviderAuthConfig( + **self.config.authentication + ) + if self.authentication_config.api_key is None: + raise ValueError("Statuscake API Key is required") + + def __get_auth_headers(self): + if self.authentication_config.api_key is not None: + return {"Authorization": 
f"Bearer {self.authentication_config.api_key}"} + + def __get_heartbeat_alerts(self) -> list[AlertDto]: + try: + response = requests.get( + "https://api.statuscake.com/v1/uptime/", + headers=self.__get_auth_headers(), + ) + + if not response.ok: + self.logger.error( + "Failed to get heartbeat from Statuscake: %s", response.json() + ) + raise Exception("Could not get heartbeat from Statuscake") + + return [ + AlertDto( + id=alert["id"], + name=alert["name"], + status=alert["status"], + url=alert["website_url"], + uptime=alert["uptime"], + source="statuscake", + ) + for alert in response.json()["data"] + ] + + except Exception as e: + self.logger.error("Error getting heartbeat from Statuscake: %s", e) + raise Exception(f"Error getting heartbeat from Statuscake: {e}") + + def __get_pagespeed_alerts(self) -> list[AlertDto]: + try: + response = requests.get( + "https://api.statuscake.com/v1/pagespeed/", + headers=self.__get_auth_headers(), + ) + + if not response.ok: + self.logger.error( + "Failed to get pagespeed from Statuscake: %s", response.json() + ) + raise Exception("Could not get pagespeed from Statuscake") + + return [ + AlertDto( + name=alert["name"], + url=alert["website_url"], + location=alert["location"], + alert_smaller=alert["alert_smaller"], + alert_bigger=alert["alert_bigger"], + alert_slower=alert["alert_slower"], + status=alert["status"], + source="statuscake", + ) + for alert in response.json()["data"] + ] + + except Exception as e: + self.logger.error("Error getting pagespeed from Statuscake: %s", e) + raise Exception(f"Error getting pagespeed from Statuscake: {e}") + + def __get_ssl_alerts(self) -> list[AlertDto]: + try: + response = requests.get( + "https://api.statuscake.com/v1/ssl/", headers=self.__get_auth_headers() + ) + + if not response.ok: + self.logger.error( + "Failed to get ssl from Statuscake: %s", response.json() + ) + raise Exception("Could not get ssl from Statuscake") + + return [ + AlertDto( + id=alert["id"], + 
url=alert["website_url"], + issuer_common_name=alert["issuer_common_name"], + cipher=alert["cipher"], + cipher_score=alert["cipher_score"], + certificate_score=alert["certificate_score"], + certificate_status=alert["certificate_status"], + valid_from=alert["valid_from"], + valid_until=alert["valid_until"], + source="statuscake", + ) + for alert in response.json()["data"] + ] + + except Exception as e: + self.logger.error("Error getting ssl from Statuscake: %s", e) + raise Exception(f"Error getting ssl from Statuscake: {e}") + + def __get_uptime_alerts(self) -> list[AlertDto]: + try: + response = requests.get( + "https://api.statuscake.com/v1/uptime/", + headers=self.__get_auth_headers(), + ) + + if not response.ok: + self.logger.error( + "Failed to get uptime from Statuscake: %s", response.json() + ) + raise Exception("Could not get uptime from Statuscake") + + return [ + AlertDto( + id=alert["id"], + name=alert["name"], + status=alert["status"], + url=alert["website_url"], + uptime=alert["uptime"], + source="statuscake", + ) + for alert in response.json()["data"] + ] + + except Exception as e: + self.logger.error("Error getting uptime from Statuscake: %s", e) + raise Exception(f"Error getting uptime from Statuscake: {e}") + + def _get_alerts(self) -> list[AlertDto]: + alerts = [] + try: + self.logger.info("Collecting alerts (heartbeats) from Statuscake") + heartbeat_alerts = self.__get_heartbeat_alerts() + alerts.extend(heartbeat_alerts) + except Exception as e: + self.logger.error("Error getting heartbeat from Statuscake: %s", e) + + try: + self.logger.info("Collecting alerts (pagespeed) from Statuscake") + pagespeed_alerts = self.__get_pagespeed_alerts() + alerts.extend(pagespeed_alerts) + except Exception as e: + self.logger.error("Error getting pagespeed from Statuscake: %s", e) + + try: + self.logger.info("Collecting alerts (ssl) from Statuscake") + ssl_alerts = self.__get_ssl_alerts() + alerts.extend(ssl_alerts) + except Exception as e: + 
self.logger.error("Error getting ssl from Statuscake: %s", e) + + try: + self.logger.info("Collecting alerts (uptime) from Statuscake") + uptime_alerts = self.__get_uptime_alerts() + alerts.extend(uptime_alerts) + except Exception as e: + self.logger.error("Error getting uptime from Statuscake: %s", e) + + return alerts + + @staticmethod + def _format_alert(event: dict) -> AlertDto: + + status = StatuscakeProvider.STATUS_MAP.get( + event.get("status"), AlertStatus.FIRING + ) + + # Statuscake does not provide severity information + severity = AlertSeverity.HIGH + + alert = AlertDto( + id=event.get("id"), + name=event.get("name"), + status=status if status is not None else AlertStatus.FIRING, + severity=severity, + url=event["website_url"] if "website_url" in event else None, + source="statuscake", + ) + + return alert - def __init__( - self, context_manager: ContextManager, provider_id: str, config: ProviderConfig - ): - super().__init__(context_manager, provider_id, config) - def dispose(self): +if __name__ == "__main__": pass + import logging - def validate_scopes(self): - """ - Validate that the user has the required scopes to use the provider - """ - try: - response = requests.get('https://api.statuscake.com/v1/uptime/', headers=self.__get_auth_headers()) - - if response.status_code == 200: - scopes = { - "alerts": True - } - - else: - self.logger.error("Unable to read alerts from Statuscake, statusCode: %s", response.status_code) - scopes = { - "alerts": f"Unable to read alerts from Statuscake, statusCode: {response.status_code}" - } - - except Exception as e: - self.logger.error("Error validating scopes for Statuscake: %s", e) - scopes = { - "alerts": f"Error validating scopes for Statuscake: {e}" - } - - return scopes - - def validate_config(self): - self.authentication_config = StatuscakeProviderAuthConfig(**self.config.authentication) - if self.authentication_config.api_key is None: - raise ValueError("Statuscake API Key is required") - - def 
__get_auth_headers(self): - if self.authentication_config.api_key is not None: - return { - "Authorization": f"Bearer {self.authentication_config.api_key}" - } - - def __get_heartbeat_alerts(self) -> list[AlertDto]: - try: - response = requests.get('https://api.statuscake.com/v1/uptime/', headers=self.__get_auth_headers()) - - if not response.ok: - self.logger.error("Failed to get heartbeat from Statuscake: %s", response.json()) - raise Exception("Could not get heartbeat from Statuscake") - - return [AlertDto( - id=alert["id"], - name=alert["name"], - status=alert["status"], - url=alert["website_url"], - uptime=alert["uptime"], - source="statuscake" - ) for alert in response.json()["data"]] - - except Exception as e: - self.logger.error("Error getting heartbeat from Statuscake: %s", e) - raise Exception(f"Error getting heartbeat from Statuscake: {e}") - - def __get_pagespeed_alerts(self) -> list[AlertDto]: - try: - response = requests.get('https://api.statuscake.com/v1/pagespeed/', headers=self.__get_auth_headers()) - - if not response.ok: - self.logger.error("Failed to get pagespeed from Statuscake: %s", response.json()) - raise Exception("Could not get pagespeed from Statuscake") - - return [AlertDto( - name=alert["name"], - url=alert["website_url"], - location=alert["location"], - alert_smaller=alert["alert_smaller"], - alert_bigger=alert["alert_bigger"], - alert_slower=alert["alert_slower"], - status=alert["status"], - source="statuscake" - ) for alert in response.json()["data"]] - - except Exception as e: - self.logger.error("Error getting pagespeed from Statuscake: %s", e) - raise Exception(f"Error getting pagespeed from Statuscake: {e}") - - def __get_ssl_alerts(self) -> list[AlertDto]: - try: - response = requests.get('https://api.statuscake.com/v1/ssl/', headers=self.__get_auth_headers()) - - if not response.ok: - self.logger.error("Failed to get ssl from Statuscake: %s", response.json()) - raise Exception("Could not get ssl from Statuscake") - - return 
[AlertDto( - id=alert["id"], - url=alert["website_url"], - issuer_common_name=alert["issuer_common_name"], - cipher=alert["cipher"], - cipher_score=alert["cipher_score"], - certificate_score=alert["certificate_score"], - certificate_status=alert["certificate_status"], - valid_from=alert["valid_from"], - valid_until=alert["valid_until"], - source="statuscake" - ) for alert in response.json()["data"]] - - except Exception as e: - self.logger.error("Error getting ssl from Statuscake: %s", e) - raise Exception(f"Error getting ssl from Statuscake: {e}") - - def __get_uptime_alerts(self) -> list[AlertDto]: - try: - response = requests.get('https://api.statuscake.com/v1/uptime/', headers=self.__get_auth_headers()) - - if not response.ok: - self.logger.error("Failed to get uptime from Statuscake: %s", response.json()) - raise Exception("Could not get uptime from Statuscake") - - return [AlertDto( - id=alert["id"], - name=alert["name"], - status=alert["status"], - url=alert["website_url"], - uptime=alert["uptime"], - source="statuscake" - ) for alert in response.json()["data"]] - - except Exception as e: - self.logger.error("Error getting uptime from Statuscake: %s", e) - raise Exception(f"Error getting uptime from Statuscake: {e}") - - def _get_alerts(self) -> list[AlertDto]: - alerts = [] - try: - self.logger.info("Collecting alerts (heartbeats) from Statuscake") - heartbeat_alerts = self.__get_heartbeat_alerts() - alerts.extend(heartbeat_alerts) - except Exception as e: - self.logger.error("Error getting heartbeat from Statuscake: %s", e) - - try: - self.logger.info("Collecting alerts (pagespeed) from Statuscake") - pagespeed_alerts = self.__get_pagespeed_alerts() - alerts.extend(pagespeed_alerts) - except Exception as e: - self.logger.error("Error getting pagespeed from Statuscake: %s", e) - - try: - self.logger.info("Collecting alerts (ssl) from Statuscake") - ssl_alerts = self.__get_ssl_alerts() - alerts.extend(ssl_alerts) - except Exception as e: - 
self.logger.error("Error getting ssl from Statuscake: %s", e) - - try: - self.logger.info("Collecting alerts (uptime) from Statuscake") - uptime_alerts = self.__get_uptime_alerts() - alerts.extend(uptime_alerts) - except Exception as e: - self.logger.error("Error getting uptime from Statuscake: %s", e) - - return alerts - - @staticmethod - def _format_alert( - event: dict, provider_instance: Optional["StatuscakeProvider"] = None - ) -> AlertDto: - - status = StatuscakeProvider.STATUS_MAP.get(event.get("status"),AlertStatus.FIRING) - - # Statuscake does not provide severity information - severity = AlertSeverity.HIGH - - alert = AlertDto( - id=event.get("id"), - name=event.get("name"), - status=status if status is not None else AlertStatus.FIRING, - severity=severity, - url=event["website_url"] if "website_url" in event else None, - source="statuscake" + logging.basicConfig(level=logging.DEBUG, handlers=[logging.StreamHandler()]) + context_manager = ContextManager( + tenant_id="singletenant", + workflow_id="test", ) - return alert + import os -if __name__ == "__main__": - pass - import logging - - logging.basicConfig(level=logging.DEBUG, handlers=[logging.StreamHandler()]) - context_manager = ContextManager( - tenant_id="singletenant", - workflow_id="test", - ) - - import os - - statuscake_api_key = os.environ.get("STATUSCAKE_API_KEY") - - if statuscake_api_key is None: - raise Exception("STATUSCAKE_API_KEY is required") - - config = ProviderConfig( - description="Statuscake Provider", - authentication={ - "api_key": statuscake_api_key - }, - ) - - provider = StatuscakeProvider( - context_manager, - provider_id="statuscake", - config=config, - ) - - provider._get_alerts() \ No newline at end of file + statuscake_api_key = os.environ.get("STATUSCAKE_API_KEY") + + if statuscake_api_key is None: + raise Exception("STATUSCAKE_API_KEY is required") + + config = ProviderConfig( + description="Statuscake Provider", + authentication={"api_key": statuscake_api_key}, + ) + + 
provider = StatuscakeProvider( + context_manager, + provider_id="statuscake", + config=config, + ) + + provider._get_alerts() diff --git a/keep/providers/uptimekuma_provider/uptimekuma_provider.py b/keep/providers/uptimekuma_provider/uptimekuma_provider.py index 2762773cb..4aba2108e 100644 --- a/keep/providers/uptimekuma_provider/uptimekuma_provider.py +++ b/keep/providers/uptimekuma_provider/uptimekuma_provider.py @@ -5,188 +5,193 @@ import dataclasses import pydantic -from typing import Optional +from uptime_kuma_api import UptimeKumaApi from keep.api.models.alert import AlertDto, AlertStatus -from keep.exceptions.provider_exception import ProviderException from keep.contextmanager.contextmanager import ContextManager +from keep.exceptions.provider_exception import ProviderException from keep.providers.base.base_provider import BaseProvider from keep.providers.models.provider_config import ProviderConfig, ProviderScope -from uptime_kuma_api import UptimeKumaApi @pydantic.dataclasses.dataclass class UptimekumaProviderAuthConfig: - """ - UptimekumaProviderAuthConfig is a class that holds the authentication information for the UptimekumaProvider. - """ - - host_url: str = dataclasses.field( - metadata={ - "required": True, - "description": "UptimeKuma Host URL", - "sensitive": False, - }, - default=None, - ) - - username: str = dataclasses.field( - metadata={ - "required": True, - "description": "UptimeKuma Username", - "sensitive": False, - }, - default=None, - ) - - password: str = dataclasses.field( - metadata={ - "required": True, - "description": "UptimeKuma Password", - "sensitive": True, - }, - default=None, - ) + """ + UptimekumaProviderAuthConfig is a class that holds the authentication information for the UptimekumaProvider. 
+ """ -class UptimekumaProvider(BaseProvider): - PROVIDER_DISPLAY_NAME = "UptimeKuma" - PROVIDER_TAGS = ["alert"] + host_url: str = dataclasses.field( + metadata={ + "required": True, + "description": "UptimeKuma Host URL", + "sensitive": False, + }, + default=None, + ) - PROVIDER_SCOPES = [ - ProviderScope( - name="alerts", - description="Read alerts from UptimeKuma", + username: str = dataclasses.field( + metadata={ + "required": True, + "description": "UptimeKuma Username", + "sensitive": False, + }, + default=None, ) - ] - STATUS_MAP = { - "up": AlertStatus.RESOLVED, - "down": AlertStatus.FIRING, - } + password: str = dataclasses.field( + metadata={ + "required": True, + "description": "UptimeKuma Password", + "sensitive": True, + }, + default=None, + ) - def __init__( - self, context_manager: ContextManager, provider_id: str, config: ProviderConfig - ): - super().__init__(context_manager, provider_id, config) - def dispose(self): - pass +class UptimekumaProvider(BaseProvider): + PROVIDER_DISPLAY_NAME = "UptimeKuma" + PROVIDER_TAGS = ["alert"] - def validate_scopes(self): - """ - Validate that the scopes provided in the config are valid - """ - api = UptimeKumaApi(self.authentication_config.host_url) - response = api.login(self.authentication_config.username, self.authentication_config.password) - api.disconnect() - if "token" in response: - return {"alerts": True} - return {"alerts": False} - - def validate_config(self): - self.authentication_config = UptimekumaProviderAuthConfig(**self.config.authentication) - if self.authentication_config.host_url is None: - raise ProviderException("UptimeKuma Host URL is required") - if self.authentication_config.username is None: - raise ProviderException("UptimeKuma Username is required") - if self.authentication_config.password is None: - raise ProviderException("UptimeKuma Password is required") - - def _get_heartbeats(self): - try: - api = UptimeKumaApi(self.authentication_config.host_url) - 
api.login(self.authentication_config.username, self.authentication_config.password) - response = api.get_heartbeats() - api.disconnect() - - length = len(response) - - if length == 0: - return [] - - for key in response: - heartbeat = response[key][-1] - name = api.get_monitor(heartbeat["monitor_id"])['name'] - - return AlertDto( - id=heartbeat["id"], - name=name, - monitor_id=heartbeat["monitor_id"], - description=heartbeat["msg"], - status=heartbeat["status"].name.lower(), - lastReceived=heartbeat["time"], - ping=heartbeat["ping"], - source=["uptimekuma"] + PROVIDER_SCOPES = [ + ProviderScope( + name="alerts", + description="Read alerts from UptimeKuma", + ) + ] + + STATUS_MAP = { + "up": AlertStatus.RESOLVED, + "down": AlertStatus.FIRING, + } + + def __init__( + self, context_manager: ContextManager, provider_id: str, config: ProviderConfig + ): + super().__init__(context_manager, provider_id, config) + + def dispose(self): + pass + + def validate_scopes(self): + """ + Validate that the scopes provided in the config are valid + """ + api = UptimeKumaApi(self.authentication_config.host_url) + response = api.login( + self.authentication_config.username, self.authentication_config.password + ) + api.disconnect() + if "token" in response: + return {"alerts": True} + return {"alerts": False} + + def validate_config(self): + self.authentication_config = UptimekumaProviderAuthConfig( + **self.config.authentication + ) + if self.authentication_config.host_url is None: + raise ProviderException("UptimeKuma Host URL is required") + if self.authentication_config.username is None: + raise ProviderException("UptimeKuma Username is required") + if self.authentication_config.password is None: + raise ProviderException("UptimeKuma Password is required") + + def _get_heartbeats(self): + try: + api = UptimeKumaApi(self.authentication_config.host_url) + api.login( + self.authentication_config.username, self.authentication_config.password + ) + response = api.get_heartbeats() + 
api.disconnect() + + length = len(response) + + if length == 0: + return [] + + for key in response: + heartbeat = response[key][-1] + name = api.get_monitor(heartbeat["monitor_id"])["name"] + + return AlertDto( + id=heartbeat["id"], + name=name, + monitor_id=heartbeat["monitor_id"], + description=heartbeat["msg"], + status=heartbeat["status"].name.lower(), + lastReceived=heartbeat["time"], + ping=heartbeat["ping"], + source=["uptimekuma"], + ) + + except Exception as e: + self.logger.error("Error getting heartbeats from UptimeKuma: %s", e) + raise Exception(f"Error getting heartbeats from UptimeKuma: {e}") + + def _get_alerts(self) -> list[AlertDto]: + try: + self.logger.info("Collecting alerts (heartbeats) from UptimeKuma") + alerts = self._get_heartbeats() + return alerts + except Exception as e: + self.logger.error("Error getting alerts from UptimeKuma: %s", e) + raise Exception(f"Error getting alerts from UptimeKuma: {e}") + + @staticmethod + def _format_alert(event: dict) -> AlertDto: + + alert = AlertDto( + id=event["monitor"]["id"], + name=event["monitor"]["name"], + monitor_url=event["monitor"]["url"], + status=event["heartbeat"]["status"], + description=event["msg"], + lastReceived=event["heartbeat"]["localDateTime"], + msg=event["heartbeat"]["msg"], + source=["uptimekuma"], ) - - except Exception as e: - self.logger.error("Error getting heartbeats from UptimeKuma: %s", e) - raise Exception(f"Error getting heartbeats from UptimeKuma: {e}") - - def _get_alerts(self) -> list[AlertDto]: - try: - self.logger.info("Collecting alerts (heartbeats) from UptimeKuma") - alerts = self._get_heartbeats() - return alerts - except Exception as e: - self.logger.error("Error getting alerts from UptimeKuma: %s", e) - raise Exception(f"Error getting alerts from UptimeKuma: {e}") - - @staticmethod - def _format_alert( - event: dict, provider_instance: Optional["UptimekumaProvider"] = None - ) -> AlertDto: - - alert = AlertDto( - id=event['monitor']['id'], - 
name=event['monitor']['name'], - monitor_url=event['monitor']['url'], - status=event['heartbeat']['status'], - description=event['msg'], - lastReceived=event['heartbeat']['localDateTime'], - msg=event['heartbeat']['msg'], - source=["uptimekuma"] - ) - return alert - + return alert + + if __name__ == "__main__": - import logging - - logging.basicConfig(level=logging.DEBUG, handlers=[logging.StreamHandler()]) - context_manager = ContextManager( - tenant_id="singletenant", - workflow_id="test", - ) - - import os - - uptimekuma_host = os.environ.get("UPTIMEKUMA_HOST") - uptimekuma_username = os.environ.get("UPTIMEKUMA_USERNAME") - uptimekuma_password = os.environ.get("UPTIMEKUMA_PASSWORD") - - if uptimekuma_host is None: - raise Exception("UPTIMEKUMA_HOST is required") - if uptimekuma_username is None: - raise Exception("UPTIMEKUMA_USERNAME is required") - if uptimekuma_password is None: - raise Exception("UPTIMEKUMA_PASSWORD is required") - - config = ProviderConfig( - description="UptimeKuma Provider", - authentication={ - "host_url": uptimekuma_host, - "username": uptimekuma_username, - "password": uptimekuma_password, - }, - ) - - provider = UptimekumaProvider( - context_manager=context_manager, - provider_id="uptimekuma", - config=config, - ) - - alerts = provider.get_alerts() - print(alerts) - provider.dispose() + import logging + + logging.basicConfig(level=logging.DEBUG, handlers=[logging.StreamHandler()]) + context_manager = ContextManager( + tenant_id="singletenant", + workflow_id="test", + ) + + import os + + uptimekuma_host = os.environ.get("UPTIMEKUMA_HOST") + uptimekuma_username = os.environ.get("UPTIMEKUMA_USERNAME") + uptimekuma_password = os.environ.get("UPTIMEKUMA_PASSWORD") + + if uptimekuma_host is None: + raise Exception("UPTIMEKUMA_HOST is required") + if uptimekuma_username is None: + raise Exception("UPTIMEKUMA_USERNAME is required") + if uptimekuma_password is None: + raise Exception("UPTIMEKUMA_PASSWORD is required") + + config = 
ProviderConfig( + description="UptimeKuma Provider", + authentication={ + "host_url": uptimekuma_host, + "username": uptimekuma_username, + "password": uptimekuma_password, + }, + ) + + provider = UptimekumaProvider( + context_manager=context_manager, + provider_id="uptimekuma", + config=config, + ) + + alerts = provider.get_alerts() + print(alerts) + provider.dispose() diff --git a/keep/providers/victoriametrics_provider/victoriametrics_provider.py b/keep/providers/victoriametrics_provider/victoriametrics_provider.py index feb037fbf..bee7cd90e 100644 --- a/keep/providers/victoriametrics_provider/victoriametrics_provider.py +++ b/keep/providers/victoriametrics_provider/victoriametrics_provider.py @@ -4,7 +4,6 @@ import dataclasses import datetime -from typing import Optional import pydantic import requests @@ -80,7 +79,7 @@ class VictoriametricsProvider(BaseProvider): "high": AlertSeverity.HIGH, "warning": AlertSeverity.WARNING, "low": AlertSeverity.LOW, - "test": AlertSeverity.INFO + "test": AlertSeverity.INFO, } STATUS_MAP = { @@ -92,19 +91,26 @@ class VictoriametricsProvider(BaseProvider): } def validate_scopes(self) -> dict[str, bool | str]: - response = requests.get(f"{self.vmalert_host}:{self.authentication_config.VMAlertPort}") + response = requests.get( + f"{self.vmalert_host}:{self.authentication_config.VMAlertPort}" + ) if response.status_code == 200: connected_to_client = True self.logger.info("Connected to client successfully") else: - connected_to_client = f"Error while connecting to client, {response.status_code}" - self.logger.error("Error while connecting to client", extra={"status_code": response.status_code}) + connected_to_client = ( + f"Error while connecting to client, {response.status_code}" + ) + self.logger.error( + "Error while connecting to client", + extra={"status_code": response.status_code}, + ) return { - 'connected': connected_to_client, + "connected": connected_to_client, } def __init__( - self, context_manager: ContextManager, 
provider_id: str, config: ProviderConfig + self, context_manager: ContextManager, provider_id: str, config: ProviderConfig ): self._host = None super().__init__(context_manager, provider_id, config) @@ -131,7 +137,7 @@ def vmalert_host(self): # if the user explicitly supplied a host with http/https, use it if self.authentication_config.VMAlertHost.startswith( - "http://" + "http://" ) or self.authentication_config.VMAlertHost.startswith("https://"): self._host = self.authentication_config.VMAlertHost return self.authentication_config.VMAlertHost.rstrip("/") @@ -154,18 +160,16 @@ def vmalert_host(self): return self.authentication_config.VMAlertHost.rstrip("/") @staticmethod - def _format_alert( - event: dict, provider_instance: Optional["BaseProvider"] = None - ) -> AlertDto | list[AlertDto]: + def _format_alert(event: dict) -> AlertDto | list[AlertDto]: alerts = [] for alert in event["alerts"]: alerts.append( AlertDto( name=alert["labels"]["alertname"], - fingerprint=alert['fingerprint'], - id=alert['fingerprint'], - description=alert["annotations"]['description'], - message=alert["annotations"]['summary'], + fingerprint=alert["fingerprint"], + id=alert["fingerprint"], + description=alert["annotations"]["description"], + message=alert["annotations"]["summary"], status=VictoriametricsProvider.STATUS_MAP[alert["status"]], startedAt=alert["startsAt"], url=alert["generatorURL"], @@ -179,19 +183,23 @@ def _format_alert( return alerts def _get_alerts(self) -> list[AlertDto]: - response = requests.get(f"{self.vmalert_host}:{self.authentication_config.VMAlertPort}/api/v1/alerts") + response = requests.get( + f"{self.vmalert_host}:{self.authentication_config.VMAlertPort}/api/v1/alerts" + ) if response.status_code == 200: alerts = [] response = response.json() - for alert in response['data']['alerts']: + for alert in response["data"]["alerts"]: alerts.append( AlertDto( name=alert["name"], - id=alert['id'], - description=alert["annotations"]['description'], - 
message=alert["annotations"]['summary'], + id=alert["id"], + description=alert["annotations"]["description"], + message=alert["annotations"]["summary"], status=VictoriametricsProvider.STATUS_MAP[alert["state"]], - severity=VictoriametricsProvider.STATUS_MAP[alert["labels"]["severity"]], + severity=VictoriametricsProvider.STATUS_MAP[ + alert["labels"]["severity"] + ], startedAt=alert["activeAt"], url=alert["source"], source=["victoriametrics"], @@ -204,7 +212,7 @@ def _get_alerts(self) -> list[AlertDto]: self.logger.error("Failed to get alerts", extra=response.json()) raise Exception("Could not get alerts") - def _query(self, query="", start="", end="", step="", queryType="", **kwargs:dict): + def _query(self, query="", start="", end="", step="", queryType="", **kwargs: dict): if queryType == "query": response = requests.get( f"{self.vmalert_host}:{self.authentication_config.VMAlertPort}/api/v1/query", @@ -213,9 +221,11 @@ def _query(self, query="", start="", end="", step="", queryType="", **kwargs:dic if response.status_code == 200: return response.json() else: - self.logger.error("Failed to perform instant query", extra=response.json()) + self.logger.error( + "Failed to perform instant query", extra=response.json() + ) raise Exception("Could not perform instant query") - + elif queryType == "query_range": response = requests.get( f"{self.vmalert_host}:{self.authentication_config.VMAlertPort}/api/v1/query_range", @@ -224,9 +234,11 @@ def _query(self, query="", start="", end="", step="", queryType="", **kwargs:dic if response.status_code == 200: return response.json() else: - self.logger.error("Failed to perform range query", extra=response.json()) + self.logger.error( + "Failed to perform range query", extra=response.json() + ) raise Exception("Could not range query") - + else: self.logger.error("Invalid query type") raise Exception("Invalid query type") diff --git a/keep/providers/zabbix_provider/zabbix_provider.py 
b/keep/providers/zabbix_provider/zabbix_provider.py index 10caa2d4a..57faf12ab 100644 --- a/keep/providers/zabbix_provider/zabbix_provider.py +++ b/keep/providers/zabbix_provider/zabbix_provider.py @@ -8,7 +8,7 @@ import logging import os import random -from typing import Literal, Optional +from typing import Literal import pydantic import requests @@ -575,9 +575,7 @@ def setup_webhook( self.logger.info("Finished installing webhook") @staticmethod - def _format_alert( - event: dict, provider_instance: Optional["ZabbixProvider"] = None - ) -> AlertDto: + def _format_alert(event: dict) -> AlertDto: environment = "unknown" tags_raw = event.pop("tags", "[]") try: diff --git a/keep/rulesengine/rulesengine.py b/keep/rulesengine/rulesengine.py index 9cdacc1b8..095077645 100644 --- a/keep/rulesengine/rulesengine.py +++ b/keep/rulesengine/rulesengine.py @@ -4,7 +4,7 @@ import celpy from keep.api.consts import STATIC_PRESETS -from keep.api.core.db import get_incident_for_grouping_rule, assign_alert_to_incident +from keep.api.core.db import assign_alert_to_incident, get_incident_for_grouping_rule from keep.api.core.db import get_rules as get_rules_db from keep.api.models.alert import AlertDto, AlertSeverity, IncidentDto from keep.api.utils.cel_utils import preprocess_cel_expression @@ -37,15 +37,17 @@ def run_rules(self, events: list[AlertDto]) -> list[IncidentDto]: self.logger.info( f"Rule {rule.name} on event {event.id} is relevant" ) - + rule_fingerprint = self._calc_rule_fingerprint(event, rule) - incident = get_incident_for_grouping_rule(self.tenant_id, rule, rule.timeframe, rule_fingerprint) + incident = get_incident_for_grouping_rule( + self.tenant_id, rule, rule.timeframe, rule_fingerprint + ) incident = assign_alert_to_incident( alert_id=event.event_id, incident_id=incident.id, - tenant_id=self.tenant_id + tenant_id=self.tenant_id, ) incidents_dto[incident.id] = IncidentDto.from_db_incident(incident) diff --git a/keep/workflowmanager/workflowscheduler.py 
b/keep/workflowmanager/workflowscheduler.py index c7df0b534..93611d3d4 100644 --- a/keep/workflowmanager/workflowscheduler.py +++ b/keep/workflowmanager/workflowscheduler.py @@ -42,6 +42,8 @@ def __init__(self, workflow_manager): async def start(self): self.logger.info("Starting workflows scheduler") + # Shahar: fix for a bug in unit tests + self._stop = False thread = threading.Thread(target=self._start) thread.start() self.threads.append(thread) @@ -222,7 +224,7 @@ def handle_manual_event_workflow( execution_number=unique_execution_number, fingerprint=alert.fingerprint, event_id=alert.event_id, - event_type="alert" + event_type="alert", ) self.logger.info(f"Workflow execution id: {workflow_execution_id}") # This is kinda WTF exception since create_workflow_execution shouldn't fail for manual @@ -339,7 +341,6 @@ def _handle_event_workflows(self): event_type = "alert" fingerprint = event.fingerprint - # In manual, we create the workflow execution id sync so it could be tracked by the caller (UI) # In event (e.g. 
alarm), we will create it here if not workflow_execution_id: diff --git a/scripts/simulate_alerts.py b/scripts/simulate_alerts.py index 5bdc6b86b..b3b4872c5 100644 --- a/scripts/simulate_alerts.py +++ b/scripts/simulate_alerts.py @@ -18,8 +18,9 @@ def main(): + GENERATE_DEDUPLICATIONS = True keep_api_key = os.environ.get("KEEP_API_KEY") - keep_api_url = os.environ.get("KEEP_API_URL") + keep_api_url = os.environ.get("KEEP_API_URL") or "http://localhost:8080" if keep_api_key is None or keep_api_url is None: raise Exception("KEEP_API_KEY and KEEP_API_URL must be set") @@ -35,25 +36,31 @@ def main(): provider = provider_classes[provider_type] alert = provider.simulate_alert() - logger.info("Sending alert: {}".format(alert)) - try: - env = random.choice(["production", "staging", "development"]) - response = requests.post( - send_alert_url + f"?provider_id={provider_type}-{env}", - headers={"x-api-key": keep_api_key}, - json=alert, - ) - except Exception as e: - logger.error("Failed to send alert: {}".format(e)) - time.sleep(0.2) - continue - - if response.status_code != 202: - logger.error("Failed to send alert: {}".format(response.text)) - else: - logger.info("Alert sent successfully") - - time.sleep(0.2) # Wait for 10 seconds before sending the next alert + # Determine number of times to send the same alert + num_iterations = 1 + if GENERATE_DEDUPLICATIONS: + num_iterations = random.randint(1, 3) + + for _ in range(num_iterations): + logger.info("Sending alert: {}".format(alert)) + try: + env = random.choice(["production", "staging", "development"]) + response = requests.post( + send_alert_url + f"?provider_id={provider_type}-{env}", + headers={"x-api-key": keep_api_key}, + json=alert, + ) + except Exception as e: + logger.error("Failed to send alert: {}".format(e)) + time.sleep(0.2) + continue + + if response.status_code != 202: + logger.error("Failed to send alert: {}".format(response.text)) + else: + logger.info("Alert sent successfully") + + time.sleep(0.2) # Wait 
for 0.2 seconds before sending the next alert if __name__ == "__main__": diff --git a/tests/conftest.py b/tests/conftest.py index 0a5088d84..220c86518 100644 --- a/tests/conftest.py +++ b/tests/conftest.py @@ -385,7 +385,7 @@ def elastic_client(request): pass -@pytest.fixture +@pytest.fixture(scope="session") def keycloak_client(request): os.environ["KEYCLOAK_URL"] = "http://localhost:8787/auth/" os.environ["KEYCLOAK_REALM"] = "keeptest" diff --git a/tests/fixtures/client.py b/tests/fixtures/client.py index c49b6ece5..e71a3a2c1 100644 --- a/tests/fixtures/client.py +++ b/tests/fixtures/client.py @@ -47,7 +47,10 @@ def test_app(monkeypatch, request): for event_handler in app.router.on_startup: asyncio.run(event_handler()) - return app + yield app + + for event_handler in app.router.on_shutdown: + asyncio.run(event_handler()) # Fixture for TestClient using the test_app fixture diff --git a/tests/test_alert_deduplicator.py b/tests/test_alert_deduplicator.py deleted file mode 100644 index a83b988f9..000000000 --- a/tests/test_alert_deduplicator.py +++ /dev/null @@ -1,142 +0,0 @@ -from keep.api.alert_deduplicator.alert_deduplicator import AlertDeduplicator -from keep.api.core.dependencies import SINGLE_TENANT_UUID -from keep.api.models.alert import AlertDto, AlertSeverity, AlertStatus -from keep.api.models.db.alert import Alert, AlertDeduplicationFilter - -# Mocked filter data -filters = [ - {"id": 1, "matcher_cel": 'source == "sensorA"', "fields": ["field_to_remove_1"]}, - {"id": 2, "matcher_cel": 'source == "sensorB"', "fields": ["field_to_remove_2"]}, -] - -# Mocked alerts for testing -alerts = [] - - -def test_deduplication_sanity(db_session): - deduplicator = AlertDeduplicator(SINGLE_TENANT_UUID) - alert = AlertDto( - id="grafana-1", - source=["grafana"], - name="grafana-test-alert", - status=AlertStatus.FIRING, - severity=AlertSeverity.CRITICAL, - lastReceived="2021-08-01T00:00:00Z", - ) - alert_hash, deduplicated = deduplicator.is_deduplicated(alert) - # 
shouldn't be deduplicated - assert not deduplicated - # now add it to the db - db_session.add( - Alert( - tenant_id=SINGLE_TENANT_UUID, - provider_type="test", - provider_id="test", - event=alert.dict(), - fingerprint=alert.fingerprint, - alert_hash=alert_hash, - ) - ) - db_session.commit() - # Now let's re run it - should be deduplicated - _, deduplicated = deduplicator.is_deduplicated(alert) - assert deduplicated - - -def test_deduplication_with_matcher(db_session): - # add the matcher: - matcher = AlertDeduplicationFilter( - tenant_id=SINGLE_TENANT_UUID, - matcher_cel='source[0] == "grafana"', - fields=["labels.some-non-relevant-field-2"], - ) - db_session.add(matcher) - db_session.commit() - # now let's run the deduplicator - deduplicator = AlertDeduplicator(SINGLE_TENANT_UUID) - alert = AlertDto( - id="grafana-1", - source=["grafana"], - name="grafana-test-alert", - status=AlertStatus.FIRING, - severity=AlertSeverity.CRITICAL, - lastReceived="2021-08-01T00:00:00Z", - labels={ - "some-non-relevant-field-1": "1234", - "some-non-relevant-field-2": "4321", - }, - ) - alert_hash, deduplicated = deduplicator.is_deduplicated(alert) - # sanity - shouldn't be deduplicated - assert not deduplicated - # now add it to the db - db_session.add( - Alert( - tenant_id=SINGLE_TENANT_UUID, - provider_type="test", - provider_id="test", - event=alert.dict(), - fingerprint=alert.fingerprint, - alert_hash=alert_hash, - ) - ) - db_session.commit() - # Now let's re run it - should not be deduplicated (since some-non-relevant-field-1 is not in fields) - alert.labels["some-non-relevant-field-1"] = "1111" - _, deduplicated = deduplicator.is_deduplicated(alert) - # Shouldn't be deduplicated since some-non-relevant-field-1 changed - # and it is not the field we are removing in filter - assert not deduplicated - # Now let's re run it - should be deduplicated - alert.labels["some-non-relevant-field-1"] = "1234" - alert.labels["some-non-relevant-field-2"] = "1111" - alert_hash, deduplicated = 
deduplicator.is_deduplicated(alert) - # Should be deduplicated since some-non-relevant-field-2 changed - # and it is the field we are removing in filter - assert deduplicated - - -def test_deduplication_with_unrelated_filter(db_session): - # add the matcher: - matcher = AlertDeduplicationFilter( - tenant_id=SINGLE_TENANT_UUID, - matcher_cel='source[0] == "grafana"', - fields=["labels.some-non-relevant-field"], - ) - db_session.add(matcher) - db_session.commit() - # now let's run the deduplicator - deduplicator = AlertDeduplicator(SINGLE_TENANT_UUID) - alert = AlertDto( - id="grafana-1", - source=["not-grafana"], - name="grafana-test-alert", - status=AlertStatus.FIRING, - severity=AlertSeverity.CRITICAL, - lastReceived="2021-08-01T00:00:00Z", - labels={ - "some-non-relevant-field": "1234", - }, - ) - alert_hash, deduplicated = deduplicator.is_deduplicated(alert) - # sanity - shouldn't be deduplicated anyway - assert not deduplicated - # now add it to the db - db_session.add( - Alert( - tenant_id=SINGLE_TENANT_UUID, - provider_type="test", - provider_id="test", - event=alert.dict(), - fingerprint=alert.fingerprint, - alert_hash=alert_hash, - ) - ) - db_session.commit() - # Let's change the non relevant field and re run it - should not be deduplicated - # since the filter does not match - alert.labels["some-non-relevant-field"] = "1111" - _, deduplicated = deduplicator.is_deduplicated(alert) - # Shouldn't be deduplicated since some-non-relevant-field-1 changed - # and it is not the field we are removing in filter - assert not deduplicated diff --git a/tests/test_auth.py b/tests/test_auth.py index 86ad7a385..f0c506137 100644 --- a/tests/test_auth.py +++ b/tests/test_auth.py @@ -203,7 +203,11 @@ def test_api_key_impersonation_without_admin(db_session, client, test_app): @pytest.mark.parametrize( "test_app", [ - {"AUTH_TYPE": "SINGLE_TENANT", "KEEP_IMPERSONATION_ENABLED": "true"}, + { + "AUTH_TYPE": "SINGLE_TENANT", + "KEEP_IMPERSONATION_ENABLED": "true", + 
"KEEP_IMPERSONATION_AUTO_PROVISION": "false", + }, ], indirect=True, ) diff --git a/tests/test_deduplications.py b/tests/test_deduplications.py new file mode 100644 index 000000000..54bce5c40 --- /dev/null +++ b/tests/test_deduplications.py @@ -0,0 +1,739 @@ +import random +import uuid + +import pytest + +from keep.providers.providers_factory import ProvidersFactory +from tests.fixtures.client import client, setup_api_key, test_app # noqa + + +@pytest.mark.parametrize( + "test_app", + [ + { + "AUTH_TYPE": "NOAUTH", + }, + ], + indirect=True, +) +def test_default_deduplication_rule(db_session, client, test_app): + # insert an alert with some provider_id and make sure that the default deduplication rule is working + provider_classes = { + provider: ProvidersFactory.get_provider_class(provider) + for provider in ["datadog", "prometheus"] + } + for provider_type, provider in provider_classes.items(): + alert = provider.simulate_alert() + client.post( + f"/alerts/event/{provider_type}?", + json=alert, + headers={"x-api-key": "some-api-key"}, + ) + + deduplication_rules = client.get( + "/deduplications", headers={"x-api-key": "some-api-key"} + ).json() + assert len(deduplication_rules) == 3 # default + datadog + prometheus + + for dedup_rule in deduplication_rules: + # check that the default deduplication rule is working + if dedup_rule.get("provider_type") == "keep": + assert dedup_rule.get("ingested") == 0 + assert dedup_rule.get("default") + # check how many times the alert was deduplicated in the last 24 hours + assert dedup_rule.get("distribution") == [ + {"hour": i, "number": 0} for i in range(24) + ] + # check that the datadog/prometheus deduplication rule is working + else: + assert dedup_rule.get("ingested") == 1 + # the deduplication ratio is zero since the alert was not deduplicated + assert dedup_rule.get("dedup_ratio") == 0 + assert dedup_rule.get("default") + + +@pytest.mark.parametrize( + "test_app", + [ + { + "AUTH_TYPE": "NOAUTH", + }, + ], + 
indirect=True, +) +def test_deduplication_sanity(db_session, client, test_app): + # insert the same alert twice and make sure that the default deduplication rule is working + # insert an alert with some provider_id and make sure that the default deduplication rule is working + provider = ProvidersFactory.get_provider_class("datadog") + alert = provider.simulate_alert() + for i in range(2): + client.post( + "/alerts/event/datadog", json=alert, headers={"x-api-key": "some-api-key"} + ) + + deduplication_rules = client.get( + "/deduplications", headers={"x-api-key": "some-api-key"} + ).json() + + assert len(deduplication_rules) == 2 # default + datadog + + for dedup_rule in deduplication_rules: + # check that the default deduplication rule is working + if dedup_rule.get("provider_type") == "keep": + assert dedup_rule.get("ingested") == 0 + assert dedup_rule.get("default") + # check how many times the alert was deduplicated in the last 24 hours + assert dedup_rule.get("distribution") == [ + {"hour": i, "number": 0} for i in range(24) + ] + # check that the datadog/prometheus deduplication rule is working + else: + assert dedup_rule.get("ingested") == 2 + # the deduplication ratio is zero since the alert was not deduplicated + assert dedup_rule.get("dedup_ratio") == 50.0 + assert dedup_rule.get("default") + + +@pytest.mark.parametrize( + "test_app", + [ + { + "AUTH_TYPE": "NOAUTH", + }, + ], + indirect=True, +) +def test_deduplication_sanity_2(db_session, client, test_app): + # insert two different alerts, twice each, and make sure that the default deduplication rule is working + provider = ProvidersFactory.get_provider_class("datadog") + alert1 = provider.simulate_alert() + alert2 = alert1 + # datadog deduplicated by monitor_id + while alert2.get("monitor_id") == alert1.get("monitor_id"): + alert2 = provider.simulate_alert() + + for alert in [alert1, alert2]: + for _ in range(2): + client.post( + "/alerts/event/datadog", + json=alert, + headers={"x-api-key": 
"some-api-key"}, + ) + + deduplication_rules = client.get( + "/deduplications", headers={"x-api-key": "some-api-key"} + ).json() + + assert len(deduplication_rules) == 2 # default + datadog + + for dedup_rule in deduplication_rules: + if dedup_rule.get("provider_type") == "datadog": + assert dedup_rule.get("ingested") == 4 + assert dedup_rule.get("dedup_ratio") == 50.0 + assert dedup_rule.get("default") + + +@pytest.mark.parametrize( + "test_app", + [ + { + "AUTH_TYPE": "NOAUTH", + }, + ], + indirect=True, +) +def test_deduplication_sanity_3(db_session, client, test_app): + # insert many alerts and make sure that the default deduplication rule is working + provider = ProvidersFactory.get_provider_class("datadog") + alerts = [provider.simulate_alert() for _ in range(10)] + + monitor_ids = set() + for alert in alerts: + # lets make it not deduplicated by randomizing the monitor_id + while alert["monitor_id"] in monitor_ids: + alert["monitor_id"] = random.randint(0, 10**10) + monitor_ids.add(alert["monitor_id"]) + client.post( + "/alerts/event/datadog", json=alert, headers={"x-api-key": "some-api-key"} + ) + + deduplication_rules = client.get( + "/deduplications", headers={"x-api-key": "some-api-key"} + ).json() + + assert len(deduplication_rules) == 2 # default + datadog + + for dedup_rule in deduplication_rules: + if dedup_rule.get("provider_type") == "datadog": + assert dedup_rule.get("ingested") == 10 + assert dedup_rule.get("dedup_ratio") == 0 + assert dedup_rule.get("default") + + +@pytest.mark.parametrize( + "test_app", + [ + { + "AUTH_TYPE": "NOAUTH", + }, + ], + indirect=True, +) +def test_custom_deduplication_rule(db_session, client, test_app): + provider = ProvidersFactory.get_provider_class("datadog") + alert1 = provider.simulate_alert() + client.post( + "/alerts/event/datadog", json=alert1, headers={"x-api-key": "some-api-key"} + ) + + # create a custom deduplication rule and insert alerts that should be deduplicated by this + custom_rule = { + "name": 
"Custom Rule", + "description": "Custom Rule Description", + "provider_type": "datadog", + "fingerprint_fields": ["title", "message"], + "full_deduplication": False, + "ignore_fields": None, + } + + resp = client.post( + "/deduplications", json=custom_rule, headers={"x-api-key": "some-api-key"} + ) + assert resp.status_code == 200 + + provider = ProvidersFactory.get_provider_class("datadog") + alert = provider.simulate_alert() + + for _ in range(2): + client.post( + "/alerts/event/datadog", json=alert, headers={"x-api-key": "some-api-key"} + ) + + deduplication_rules = client.get( + "/deduplications", headers={"x-api-key": "some-api-key"} + ).json() + + custom_rule_found = False + for dedup_rule in deduplication_rules: + if dedup_rule.get("name") == "Custom Rule": + custom_rule_found = True + assert dedup_rule.get("ingested") == 2 + assert dedup_rule.get("dedup_ratio") == 50.0 + assert not dedup_rule.get("default") + + assert custom_rule_found + + +@pytest.mark.parametrize( + "test_app", + [ + { + "AUTH_TYPE": "NOAUTH", + }, + ], + indirect=True, +) +def test_custom_deduplication_rule_behaviour(db_session, client, test_app): + # create a custom deduplication rule and insert alerts that should be deduplicated by this + provider = ProvidersFactory.get_provider_class("datadog") + alert1 = provider.simulate_alert() + client.post( + "/alerts/event/datadog", json=alert1, headers={"x-api-key": "some-api-key"} + ) + custom_rule = { + "name": "Custom Rule", + "description": "Custom Rule Description", + "provider_type": "datadog", + "fingerprint_fields": ["title", "message"], + "full_deduplication": False, + "ignore_fields": None, + } + + resp = client.post( + "/deduplications", json=custom_rule, headers={"x-api-key": "some-api-key"} + ) + assert resp.status_code == 200 + + provider = ProvidersFactory.get_provider_class("datadog") + alert = provider.simulate_alert() + + for _ in range(2): + # the default rule should deduplicate the alert by monitor_id so let's randomize it - 
+ # if the custom rule is working, the alert should be deduplicated by title and message + alert["monitor_id"] = random.randint(0, 10**10) + client.post( + "/alerts/event/datadog", json=alert, headers={"x-api-key": "some-api-key"} + ) + + deduplication_rules = client.get( + "/deduplications", headers={"x-api-key": "some-api-key"} + ).json() + + custom_rule_found = False + for dedup_rule in deduplication_rules: + if dedup_rule.get("name") == "Custom Rule": + custom_rule_found = True + assert dedup_rule.get("ingested") == 2 + assert dedup_rule.get("dedup_ratio") == 50.0 + assert not dedup_rule.get("default") + + assert custom_rule_found + + +@pytest.mark.parametrize( + "test_app", + [ + { + "AUTH_TYPE": "NOAUTH", + "KEEP_PROVIDERS": '{"keepDatadog":{"type":"datadog","authentication":{"api_key":"1234","app_key": "1234"}}}', + }, + ], + indirect=True, +) +def test_custom_deduplication_rule_2(db_session, client, test_app): + # create a custom full deduplication rule and insert alerts that should not be deduplicated by this + providers = client.get("/providers", headers={"x-api-key": "some-api-key"}).json() + datadog_provider_id = next( + provider["id"] + for provider in providers.get("installed_providers") + if provider["type"] == "datadog" + ) + + custom_rule = { + "name": "Custom Rule", + "description": "Custom Rule Description", + "provider_type": "datadog", + "provider_id": datadog_provider_id, + "fingerprint_fields": [ + "name", + "message", + ], # title in datadog mapped to name in keep + "full_deduplication": False, + "ignore_fields": ["field_that_never_exists"], + } + + response = client.post( + "/deduplications", json=custom_rule, headers={"x-api-key": "some-api-key"} + ) + assert response.status_code == 200 + + provider = ProvidersFactory.get_provider_class("datadog") + alert1 = provider.simulate_alert() + + client.post( + f"/alerts/event/datadog?provider_id={datadog_provider_id}", + json=alert1, + headers={"x-api-key": "some-api-key"}, + ) + alert1["title"] = 
"Different title" + client.post( + f"/alerts/event/datadog?provider_id={datadog_provider_id}", + json=alert1, + headers={"x-api-key": "some-api-key"}, + ) + + deduplication_rules = client.get( + "/deduplications", headers={"x-api-key": "some-api-key"} + ).json() + + custom_rule_found = False + for dedup_rule in deduplication_rules: + if dedup_rule.get("name") == "Custom Rule": + custom_rule_found = True + assert dedup_rule.get("ingested") == 2 + assert dedup_rule.get("dedup_ratio") == 0 + assert not dedup_rule.get("default") + + assert custom_rule_found + + +@pytest.mark.parametrize( + "test_app", + [ + { + "AUTH_TYPE": "NOAUTH", + "KEEP_PROVIDERS": '{"keepDatadog":{"type":"datadog","authentication":{"api_key":"1234","app_key": "1234"}}}', + }, + ], + indirect=True, +) +def test_update_deduplication_rule(db_session, client, test_app): + # create a custom deduplication rule and update it + response = client.get("/providers", headers={"x-api-key": "some-api-key"}) + assert response.status_code == 200 + datadog_provider_id = next( + provider["id"] + for provider in response.json().get("installed_providers") + if provider["type"] == "datadog" + ) + + custom_rule = { + "name": "Custom Rule", + "description": "Custom Rule Description", + "provider_type": "datadog", + "provider_id": datadog_provider_id, + "fingerprint_fields": ["title", "message"], + "full_deduplication": False, + "ignore_fields": None, + } + + response = client.post( + "/deduplications", json=custom_rule, headers={"x-api-key": "some-api-key"} + ) + assert response.status_code == 200 + + rule_id = response.json().get("id") + updated_rule = { + "name": "Updated Custom Rule", + "description": "Updated Custom Rule", + "provider_type": "datadog", + "provider_id": datadog_provider_id, + "fingerprint_fields": ["title"], + "full_deduplication": False, + "ignore_fields": None, + } + + response = client.put( + f"/deduplications/{rule_id}", + json=updated_rule, + headers={"x-api-key": "some-api-key"}, + ) + assert 
response.status_code == 200 + + deduplication_rules = client.get( + "/deduplications", headers={"x-api-key": "some-api-key"} + ).json() + + updated_rule_found = False + for dedup_rule in deduplication_rules: + if dedup_rule.get("id") == rule_id: + updated_rule_found = True + assert dedup_rule.get("description") == "Updated Custom Rule" + assert dedup_rule.get("fingerprint_fields") == ["title"] + + assert updated_rule_found + + +@pytest.mark.parametrize( + "test_app", + [ + { + "AUTH_TYPE": "NOAUTH", + }, + ], + indirect=True, +) +def test_update_deduplication_rule_non_exist_provider(db_session, client, test_app): + # create a custom deduplication rule and update it + custom_rule = { + "name": "Custom Rule", + "description": "Custom Rule Description", + "provider_type": "datadog", + "fingerprint_fields": ["title", "message"], + "full_deduplication": False, + "ignore_fields": None, + } + response = client.post( + "/deduplications", json=custom_rule, headers={"x-api-key": "some-api-key"} + ) + assert response.status_code == 404 + assert response.json() == {"detail": "Provider datadog not found"} + + +@pytest.mark.parametrize( + "test_app", + [ + { + "AUTH_TYPE": "NOAUTH", + }, + ], + indirect=True, +) +def test_update_deduplication_rule_linked_provider(db_session, client, test_app): + provider = ProvidersFactory.get_provider_class("datadog") + alert1 = provider.simulate_alert() + response = client.post( + "/alerts/event/datadog", json=alert1, headers={"x-api-key": "some-api-key"} + ) + custom_rule = { + "name": "Custom Rule", + "description": "Custom Rule Description", + "provider_type": "datadog", + "fingerprint_fields": ["title", "message"], + "full_deduplication": False, + "ignore_fields": None, + } + response = client.post( + "/deduplications", json=custom_rule, headers={"x-api-key": "some-api-key"} + ) + # once a linked provider is created, a customization should be allowed + assert response.status_code == 200 + + +@pytest.mark.parametrize( + "test_app", + [ + { 
+ "AUTH_TYPE": "NOAUTH", + "KEEP_PROVIDERS": '{"keepDatadog":{"type":"datadog","authentication":{"api_key":"1234","app_key": "1234"}}}', + }, + ], + indirect=True, +) +def test_delete_deduplication_rule_sanity(db_session, client, test_app): + response = client.get("/providers", headers={"x-api-key": "some-api-key"}) + assert response.status_code == 200 + datadog_provider_id = next( + provider["id"] + for provider in response.json().get("installed_providers") + if provider["type"] == "datadog" + ) + # create a custom deduplication rule and delete it + custom_rule = { + "name": "Custom Rule", + "description": "Custom Rule Description", + "provider_type": "datadog", + "provider_id": datadog_provider_id, + "fingerprint_fields": ["title", "message"], + "full_deduplication": False, + "ignore_fields": None, + } + + response = client.post( + "/deduplications", json=custom_rule, headers={"x-api-key": "some-api-key"} + ) + assert response.status_code == 200 + + rule_id = response.json().get("id") + client.delete(f"/deduplications/{rule_id}", headers={"x-api-key": "some-api-key"}) + + deduplication_rules = client.get( + "/deduplications", headers={"x-api-key": "some-api-key"} + ).json() + + assert all(rule.get("id") != rule_id for rule in deduplication_rules) + + +@pytest.mark.parametrize( + "test_app", + [ + { + "AUTH_TYPE": "NOAUTH", + }, + ], + indirect=True, +) +def test_delete_deduplication_rule_invalid(db_session, client, test_app): + # try to delete a deduplication rule that does not exist + response = client.delete( + "/deduplications/non-existent-id", headers={"x-api-key": "some-api-key"} + ) + + assert response.status_code == 400 + assert response.json() == {"detail": "Invalid rule id"} + + # now use UUID + some_uuid = str(uuid.uuid4()) + response = client.delete( + f"/deduplications/{some_uuid}", headers={"x-api-key": "some-api-key"} + ) + assert response.status_code == 404 + + +@pytest.mark.parametrize( + "test_app", + [ + { + "AUTH_TYPE": "NOAUTH", + }, + ], + 
indirect=True, +) +def test_delete_deduplication_rule_default(db_session, client, test_app): + # shoot an alert to create a default deduplication rule + provider = ProvidersFactory.get_provider_class("datadog") + alert = provider.simulate_alert() + client.post( + "/alerts/event/datadog", json=alert, headers={"x-api-key": "some-api-key"} + ) + + # try to delete a default deduplication rule + deduplication_rules = client.get( + "/deduplications", headers={"x-api-key": "some-api-key"} + ).json() + + default_rule_id = next( + rule["id"] for rule in deduplication_rules if rule["default"] + ) + + response = client.delete( + f"/deduplications/{default_rule_id}", headers={"x-api-key": "some-api-key"} + ) + + assert response.status_code == 404 + + +@pytest.mark.parametrize( + "test_app", + [ + { + "AUTH_TYPE": "NOAUTH", + }, + ], + indirect=True, +) +def test_full_deduplication(db_session, client, test_app): + # create a custom deduplication rule with full deduplication and insert alerts that should be deduplicated by this + provider = ProvidersFactory.get_provider_class("datadog") + alert = provider.simulate_alert() + # send the alert so a linked provider is created + response = client.post( + "/alerts/event/datadog", json=alert, headers={"x-api-key": "some-api-key"} + ) + custom_rule = { + "name": "Full Deduplication Rule", + "description": "Full Deduplication Rule", + "provider_type": "datadog", + "fingerprint_fields": ["title", "message", "source"], + "full_deduplication": True, + "ignore_fields": list(alert.keys()), # ignore all fields + } + + response = client.post( + "/deduplications", json=custom_rule, headers={"x-api-key": "some-api-key"} + ) + assert response.status_code == 200 + + for _ in range(3): + client.post( + "/alerts/event/datadog", json=alert, headers={"x-api-key": "some-api-key"} + ) + + deduplication_rules = client.get( + "/deduplications", headers={"x-api-key": "some-api-key"} + ).json() + + full_dedup_rule_found = False + for dedup_rule in 
deduplication_rules: + if dedup_rule.get("description") == "Full Deduplication Rule": + full_dedup_rule_found = True + assert dedup_rule.get("ingested") == 3 + assert 66.667 - dedup_rule.get("dedup_ratio") < 0.1 # 0.66666666....7 + + assert full_dedup_rule_found + + +@pytest.mark.parametrize( + "test_app", + [ + { + "AUTH_TYPE": "NOAUTH", + }, + ], + indirect=True, +) +def test_partial_deduplication(db_session, client, test_app): + # insert a datadog alert with the same incident_id, group and title and make sure that the datadog default deduplication rule is working + provider = ProvidersFactory.get_provider_class("datadog") + base_alert = provider.simulate_alert() + + alerts = [ + base_alert, + {**base_alert, "message": "Different message"}, + {**base_alert, "source": "Different source"}, + ] + + for alert in alerts: + client.post( + "/alerts/event/datadog", json=alert, headers={"x-api-key": "some-api-key"} + ) + + deduplication_rules = client.get( + "/deduplications", headers={"x-api-key": "some-api-key"} + ).json() + + datadog_rule_found = False + for dedup_rule in deduplication_rules: + if dedup_rule.get("provider_type") == "datadog" and dedup_rule.get("default"): + datadog_rule_found = True + assert dedup_rule.get("ingested") == 3 + assert ( + dedup_rule.get("dedup_ratio") > 0 + and dedup_rule.get("dedup_ratio") < 100 + ) + + assert datadog_rule_found + + +@pytest.mark.parametrize( + "test_app", + [ + { + "AUTH_TYPE": "NOAUTH", + }, + ], + indirect=True, +) +def test_ingesting_alert_without_fingerprint_fields(db_session, client, test_app): + # insert a datadog alert without the required fingerprint fields and make sure that it is not deduplicated + provider = ProvidersFactory.get_provider_class("datadog") + alert = provider.simulate_alert() + alert.pop("incident_id", None) + alert.pop("group", None) + alert["title"] = str(random.randint(0, 10**10)) + + client.post( + "/alerts/event/datadog", json=alert, headers={"x-api-key": "some-api-key"} + ) + + 
deduplication_rules = client.get( + "/deduplications", headers={"x-api-key": "some-api-key"} + ).json() + + datadog_rule_found = False + for dedup_rule in deduplication_rules: + if dedup_rule.get("provider_type") == "datadog" and dedup_rule.get("default"): + datadog_rule_found = True + assert dedup_rule.get("ingested") == 1 + assert dedup_rule.get("dedup_ratio") == 0 + + assert datadog_rule_found + + +@pytest.mark.parametrize( + "test_app", + [ + { + "AUTH_TYPE": "NOAUTH", + }, + ], + indirect=True, +) +def test_deduplication_fields(db_session, client, test_app): + # insert a datadog alert with the same incident_id and make sure that the datadog default deduplication rule is working + provider = ProvidersFactory.get_provider_class("datadog") + base_alert = provider.simulate_alert() + + alerts = [ + base_alert, + {**base_alert, "group": "Different group"}, + {**base_alert, "title": "Different title"}, + ] + + for alert in alerts: + client.post( + "/alerts/event/datadog", json=alert, headers={"x-api-key": "some-api-key"} + ) + + deduplication_rules = client.get( + "/deduplications", headers={"x-api-key": "some-api-key"} + ).json() + + datadog_rule_found = False + for dedup_rule in deduplication_rules: + if dedup_rule.get("provider_type") == "datadog" and dedup_rule.get("default"): + datadog_rule_found = True + assert dedup_rule.get("ingested") == 3 + assert 66.667 - dedup_rule.get("dedup_ratio") < 0.1 # 0.66666666....7 + + assert datadog_rule_found diff --git a/tests/test_incidents.py b/tests/test_incidents.py index 6156525c6..b2ca94fdf 100644 --- a/tests/test_incidents.py +++ b/tests/test_incidents.py @@ -6,19 +6,26 @@ from sqlalchemy.orm.exc import DetachedInstanceError from keep.api.core.db import ( + IncidentSorting, add_alerts_to_incident_by_incident_id, create_incident_from_dict, get_alerts_data_for_incident, get_incident_by_id, + get_last_incidents, remove_alerts_to_incident_by_incident_id, - get_last_incidents, IncidentSorting, get_last_alerts, ) from 
keep.api.core.db_utils import get_json_extract_field from keep.api.core.dependencies import SINGLE_TENANT_UUID -from keep.api.models.alert import IncidentSeverity, AlertSeverity, AlertStatus, IncidentStatus +from keep.api.models.alert import ( + AlertSeverity, + AlertStatus, + IncidentSeverity, + IncidentStatus, +) from keep.api.models.db.alert import Alert from keep.api.utils.enrichment_helpers import convert_db_alerts_to_dto_alerts -from tests.fixtures.client import client, test_app +from tests.fixtures.client import client, test_app # noqa + def test_get_alerts_data_for_incident(db_session, setup_stress_alerts_no_elastic): alerts = setup_stress_alerts_no_elastic(100) @@ -32,35 +39,35 @@ def test_get_alerts_data_for_incident(db_session, setup_stress_alerts_no_elastic def test_add_remove_alert_to_incidents(db_session, setup_stress_alerts_no_elastic): alerts = setup_stress_alerts_no_elastic(100) - incident = create_incident_from_dict(SINGLE_TENANT_UUID, {"user_generated_name": "test", "user_summary": "test"}) + incident = create_incident_from_dict( + SINGLE_TENANT_UUID, {"user_generated_name": "test", "user_summary": "test"} + ) assert len(incident.alerts) == 0 add_alerts_to_incident_by_incident_id( - SINGLE_TENANT_UUID, - incident.id, - [a.id for a in alerts] + SINGLE_TENANT_UUID, incident.id, [a.id for a in alerts] ) incident = get_incident_by_id(SINGLE_TENANT_UUID, incident.id) - assert sorted(incident.affected_services) == sorted(["service_{}".format(i) for i in range(10)]) - assert sorted(incident.sources) == sorted(["source_{}".format(i) for i in range(10)]) + assert sorted(incident.affected_services) == sorted( + ["service_{}".format(i) for i in range(10)] + ) + assert sorted(incident.sources) == sorted( + ["source_{}".format(i) for i in range(10)] + ) - service_field = get_json_extract_field(db_session, Alert.event, 'service') + service_field = get_json_extract_field(db_session, Alert.event, "service") - service_0 = ( - db_session.query(Alert.id) - .filter( 
- service_field == "service_0" - ) - .all() - ) + service_0 = db_session.query(Alert.id).filter(service_field == "service_0").all() remove_alerts_to_incident_by_incident_id( SINGLE_TENANT_UUID, incident.id, - [service_0[0].id, ] + [ + service_0[0].id, + ], ) incident = get_incident_by_id(SINGLE_TENANT_UUID, incident.id) @@ -68,12 +75,12 @@ def test_add_remove_alert_to_incidents(db_session, setup_stress_alerts_no_elasti assert len(incident.alerts) == 99 assert "service_0" in incident.affected_services assert len(incident.affected_services) == 10 - assert sorted(incident.affected_services) == sorted(["service_{}".format(i) for i in range(10)]) + assert sorted(incident.affected_services) == sorted( + ["service_{}".format(i) for i in range(10)] + ) remove_alerts_to_incident_by_incident_id( - SINGLE_TENANT_UUID, - incident.id, - [a.id for a in service_0] + SINGLE_TENANT_UUID, incident.id, [a.id for a in service_0] ) incident = get_incident_by_id(SINGLE_TENANT_UUID, incident.id) @@ -81,20 +88,20 @@ def test_add_remove_alert_to_incidents(db_session, setup_stress_alerts_no_elasti assert len(incident.alerts) == 90 assert "service_0" not in incident.affected_services assert len(incident.affected_services) == 9 - assert sorted(incident.affected_services) == sorted(["service_{}".format(i) for i in range(1, 10)]) + assert sorted(incident.affected_services) == sorted( + ["service_{}".format(i) for i in range(1, 10)] + ) source_1 = ( - db_session.query(Alert.id) - .filter( - Alert.provider_type == "source_1" - ) - .all() + db_session.query(Alert.id).filter(Alert.provider_type == "source_1").all() ) remove_alerts_to_incident_by_incident_id( SINGLE_TENANT_UUID, incident.id, - [source_1[0].id, ] + [ + source_1[0].id, + ], ) incident = get_incident_by_id(SINGLE_TENANT_UUID, incident.id) @@ -103,18 +110,20 @@ def test_add_remove_alert_to_incidents(db_session, setup_stress_alerts_no_elasti assert "source_1" in incident.sources # source_0 was removed together with service_0 assert 
len(incident.sources) == 9 - assert sorted(incident.sources) == sorted(["source_{}".format(i) for i in range(1, 10)]) + assert sorted(incident.sources) == sorted( + ["source_{}".format(i) for i in range(1, 10)] + ) remove_alerts_to_incident_by_incident_id( - "keep", - incident.id, - [a.id for a in source_1] + "keep", incident.id, [a.id for a in source_1] ) incident = get_incident_by_id(SINGLE_TENANT_UUID, incident.id) assert len(incident.sources) == 8 - assert sorted(incident.sources) == sorted(["source_{}".format(i) for i in range(2, 10)]) + assert sorted(incident.sources) == sorted( + ["source_{}".format(i) for i in range(2, 10)] + ) def test_get_last_incidents(db_session, create_alert): @@ -123,38 +132,50 @@ def test_get_last_incidents(db_session, create_alert): for i in range(50): severity = next(severity_cycle) - incident = create_incident_from_dict(SINGLE_TENANT_UUID, { - "user_generated_name": f"test-{i}", - "user_summary": f"test-{i}", - "is_confirmed": True, - "severity": severity - }) - create_alert(f"alert-test-{i}", AlertStatus.FIRING, datetime.utcnow(), {"severity": AlertSeverity.from_number(severity)}) + incident = create_incident_from_dict( + SINGLE_TENANT_UUID, + { + "user_generated_name": f"test-{i}", + "user_summary": f"test-{i}", + "is_confirmed": True, + "severity": severity, + }, + ) + create_alert( + f"alert-test-{i}", + AlertStatus.FIRING, + datetime.utcnow(), + {"severity": AlertSeverity.from_number(severity).value}, + ) alert = db_session.query(Alert).order_by(Alert.timestamp.desc()).first() add_alerts_to_incident_by_incident_id( - SINGLE_TENANT_UUID, - incident.id, - [alert.id] + SINGLE_TENANT_UUID, incident.id, [alert.id] ) incidents_default, incidents_default_count = get_last_incidents(SINGLE_TENANT_UUID) assert len(incidents_default) == 0 assert incidents_default_count == 0 - incidents_confirmed, incidents_confirmed_count = get_last_incidents(SINGLE_TENANT_UUID, is_confirmed=True) + incidents_confirmed, incidents_confirmed_count = 
get_last_incidents( + SINGLE_TENANT_UUID, is_confirmed=True + ) assert len(incidents_confirmed) == 25 assert incidents_confirmed_count == 50 for i in range(25): assert incidents_confirmed[i].user_generated_name == f"test-{i}" - incidents_limit_5, incidents_count_limit_5 = get_last_incidents(SINGLE_TENANT_UUID, is_confirmed=True, limit=5) + incidents_limit_5, incidents_count_limit_5 = get_last_incidents( + SINGLE_TENANT_UUID, is_confirmed=True, limit=5 + ) assert len(incidents_limit_5) == 5 assert incidents_count_limit_5 == 50 for i in range(5): assert incidents_limit_5[i].user_generated_name == f"test-{i}" - incidents_limit_5_page_2, incidents_count_limit_5_page_2 = get_last_incidents(SINGLE_TENANT_UUID, is_confirmed=True, limit=5, offset=5) + incidents_limit_5_page_2, incidents_count_limit_5_page_2 = get_last_incidents( + SINGLE_TENANT_UUID, is_confirmed=True, limit=5, offset=5 + ) assert len(incidents_limit_5_page_2) == 5 assert incidents_count_limit_5_page_2 == 50 @@ -164,45 +185,69 @@ def test_get_last_incidents(db_session, create_alert): # If alerts not preloaded, we will have detached session issue during attempt to get them # Background on this error at: https://sqlalche.me/e/14/bhk3 with pytest.raises(DetachedInstanceError): - alerts = incidents_confirmed[0].alerts + alerts = incidents_confirmed[0].alerts # noqa - incidents_with_alerts, _ = get_last_incidents(SINGLE_TENANT_UUID, is_confirmed=True, with_alerts=True) + incidents_with_alerts, _ = get_last_incidents( + SINGLE_TENANT_UUID, is_confirmed=True, with_alerts=True + ) for i in range(25): assert len(incidents_with_alerts[i].alerts) == 1 # Test sorting - incidents_sorted_by_severity, _ = get_last_incidents(SINGLE_TENANT_UUID, is_confirmed=True, sorting=IncidentSorting.severity, limit=5) - assert all([i.severity == IncidentSeverity.LOW.order for i in incidents_sorted_by_severity]) + incidents_sorted_by_severity, _ = get_last_incidents( + SINGLE_TENANT_UUID, is_confirmed=True, 
sorting=IncidentSorting.severity, limit=5 + ) + assert all( + [i.severity == IncidentSeverity.LOW.order for i in incidents_sorted_by_severity] + ) - incidents_sorted_by_severity_desc, _ = get_last_incidents(SINGLE_TENANT_UUID, is_confirmed=True, sorting=IncidentSorting.severity_desc, limit=5) - assert all([i.severity == IncidentSeverity.CRITICAL.order for i in incidents_sorted_by_severity_desc]) + incidents_sorted_by_severity_desc, _ = get_last_incidents( + SINGLE_TENANT_UUID, + is_confirmed=True, + sorting=IncidentSorting.severity_desc, + limit=5, + ) + assert all( + [ + i.severity == IncidentSeverity.CRITICAL.order + for i in incidents_sorted_by_severity_desc + ] + ) -@pytest.mark.parametrize( - "test_app", ["NO_AUTH"], indirect=True -) -def test_incident_status_change(db_session, client, test_app, setup_stress_alerts_no_elastic): - alerts = setup_stress_alerts_no_elastic(100) - incident = create_incident_from_dict("keep", {"name": "test", "description": "test"}) +@pytest.mark.parametrize("test_app", ["NO_AUTH"], indirect=True) +def test_incident_status_change( + db_session, client, test_app, setup_stress_alerts_no_elastic +): - add_alerts_to_incident_by_incident_id( - "keep", - incident.id, - [a.id for a in alerts] + alerts = setup_stress_alerts_no_elastic(100) + incident = create_incident_from_dict( + "keep", {"name": "test", "description": "test"} ) + add_alerts_to_incident_by_incident_id("keep", incident.id, [a.id for a in alerts]) + incident = get_incident_by_id("keep", incident.id, with_alerts=True) alerts_dtos = convert_db_alerts_to_dto_alerts(incident.alerts) - assert len([alert for alert in alerts_dtos if alert.status == AlertStatus.RESOLVED.value]) == 0 + assert ( + len( + [ + alert + for alert in alerts_dtos + if alert.status == AlertStatus.RESOLVED.value + ] + ) + == 0 + ) response_ack = client.post( "/incidents/{}/status".format(incident.id), headers={"x-api-key": "some-key"}, json={ "status": IncidentStatus.ACKNOWLEDGED.value, - } + }, ) assert 
response_ack.status_code == 200 @@ -214,14 +259,23 @@ def test_incident_status_change(db_session, client, test_app, setup_stress_alert assert incident.status == IncidentStatus.ACKNOWLEDGED.value alerts_dtos = convert_db_alerts_to_dto_alerts(incident.alerts) - assert len([alert for alert in alerts_dtos if alert.status == AlertStatus.RESOLVED.value]) == 0 + assert ( + len( + [ + alert + for alert in alerts_dtos + if alert.status == AlertStatus.RESOLVED.value + ] + ) + == 0 + ) response_resolved = client.post( "/incidents/{}/status".format(incident.id), headers={"x-api-key": "some-key"}, json={ "status": IncidentStatus.RESOLVED.value, - } + }, ) assert response_resolved.status_code == 200 @@ -234,4 +288,13 @@ def test_incident_status_change(db_session, client, test_app, setup_stress_alert assert incident.status == IncidentStatus.RESOLVED.value # All alerts are resolved as well alerts_dtos = convert_db_alerts_to_dto_alerts(incident.alerts) - assert len([alert for alert in alerts_dtos if alert.status == AlertStatus.RESOLVED.value]) == 100 + assert ( + len( + [ + alert + for alert in alerts_dtos + if alert.status == AlertStatus.RESOLVED.value + ] + ) + == 100 + ) diff --git a/tests/test_parser.py b/tests/test_parser.py index 54c0a4a78..6e607fe37 100644 --- a/tests/test_parser.py +++ b/tests/test_parser.py @@ -1,8 +1,8 @@ # here we are going to create all needed tests for the parser.py parse function -import uuid import builtins import json import time +import uuid from pathlib import Path import pytest @@ -11,13 +11,13 @@ from fastapi import HTTPException from keep.api.core.dependencies import SINGLE_TENANT_UUID +from keep.api.models.db.action import Action from keep.contextmanager.contextmanager import ContextManager from keep.parser.parser import Parser, ParserUtils from keep.providers.mock_provider.mock_provider import MockProvider from keep.providers.models.provider_config import ProviderConfig from keep.step.step import Step from keep.workflowmanager.workflowstore 
import WorkflowStore -from keep.api.models.db.action import Action def test_parse_with_nonexistent_file(db_session): @@ -86,14 +86,6 @@ def test_parse_all_alerts(db_session): # You can add more specific assertions based on the content of mock_files and how they are parsed into alerts. -# This test depends on the previous one because of global providers configuration -@pytest.mark.xfail -def test_parse_with_alert_source_with_no_providers_file(): - parser = Parser() - with pytest.raises(TypeError): - parser.parse(str(workflow_path)) - - def parse_env_setup(context_manager): parser = Parser() parser._parse_providers_from_env(context_manager=context_manager) @@ -301,7 +293,9 @@ def test_parse_alert_steps(self): ## Test Case for reusable actions path_to_test_reusable_resources = Path(__file__).parent / "workflows" reusable_workflow_path = str(path_to_test_resources / "reusable_alert_for_testing.yml") -reusable_workflow_with_action_path = str(path_to_test_resources / "reusable_alert_with_actions_for_testing.yml") +reusable_workflow_with_action_path = str( + path_to_test_resources / "reusable_alert_with_actions_for_testing.yml" +) reusable_providers_path = str(path_to_test_resources / "providers_for_testing.yaml") reusable_actions_path = str(path_to_test_resources / "reusable_actions_for_testing.yml") @@ -397,7 +391,7 @@ def test_load_actions_config(self, db_session): class TestParserUtils: - + def test_deep_merge_dict(self): """Dictionary: if the merge combines recursively and prioritize values of source""" source = {"1": {"s11": "s11", "s12": "s12"}, "2": {"s21": "s21"}} @@ -405,16 +399,18 @@ def test_deep_merge_dict(self): expected_results = { "1": {"s11": "s11", "s12": "s12", "d11": "d11", "d12": "d12"}, "2": {"s21": "s21"}, - "3": {"d31": "d31"} + "3": {"d31": "d31"}, } results = ParserUtils.deep_merge(source, dest) assert expected_results == results def test_deep_merge_list(self): """List: if the merge combines recursively and prioritize values of source""" - source 
= {"data": [{"s1": "s1"}, {"s2": "s2"}]} - dest = {"data": [{"d1": "d1"}, {"d2": "d2"}, {"d3": "d3"}]} - expected_results = {"data": [{"s1": "s1", "d1": "d1"}, {"s2": "s2", "d2": "d2"}, {"d3": "d3"}]} + source = {"data": [{"s1": "s1"}, {"s2": "s2"}]} + dest = {"data": [{"d1": "d1"}, {"d2": "d2"}, {"d3": "d3"}]} + expected_results = { + "data": [{"s1": "s1", "d1": "d1"}, {"s2": "s2", "d2": "d2"}, {"d3": "d3"}] + } results = ParserUtils.deep_merge(source, dest) assert expected_results == results diff --git a/tests/test_workflow_execution.py b/tests/test_workflow_execution.py index af04791d4..2faabce1b 100644 --- a/tests/test_workflow_execution.py +++ b/tests/test_workflow_execution.py @@ -4,13 +4,13 @@ import pytest import pytz -from asyncio import sleep from keep.api.core.db import get_last_workflow_execution_by_workflow_id from keep.api.core.dependencies import SINGLE_TENANT_UUID -from keep.api.models.alert import AlertDto, AlertStatus, IncidentDtoIn, IncidentDto +from keep.api.models.alert import AlertDto, AlertStatus, IncidentDto from keep.api.models.db.workflow import Workflow from keep.workflowmanager.workflowmanager import WorkflowManager +from tests.fixtures.client import client, test_app # noqa # This workflow definition is used to test the execution of workflows based on alert firing times. 
# It defines two actions: @@ -78,52 +78,76 @@ def setup_workflow(db_session): @pytest.mark.parametrize( - "test_case, alert_statuses, expected_tier, db_session", + "test_app, test_case, alert_statuses, expected_tier, db_session", [ - ("No action", [[0, "firing"]], None, None), - ("Tier 1", [[20, "firing"]], 1, None), - ("Tier 2", [[35, "firing"]], 2, None), - ("Resolved before tier 1", [[10, "firing"], [11, "resolved"]], None, None), - ("Resolved after tier 1", [[20, "firing"], [25, "resolved"]], 1, None), - ("Resolved after tier 2", [[35, "firing"], [40, "resolved"]], 2, None), + ({"AUTH_TYPE": "NOAUTH"}, "No action", [[0, "firing"]], None, None), + ({"AUTH_TYPE": "NOAUTH"}, "Tier 1", [[20, "firing"]], 1, None), + ({"AUTH_TYPE": "NOAUTH"}, "Tier 2", [[35, "firing"]], 2, None), ( + {"AUTH_TYPE": "NOAUTH"}, + "Resolved before tier 1", + [[10, "firing"], [11, "resolved"]], + None, + None, + ), + ( + {"AUTH_TYPE": "NOAUTH"}, + "Resolved after tier 1", + [[20, "firing"], [25, "resolved"]], + 1, + None, + ), + ( + {"AUTH_TYPE": "NOAUTH"}, + "Resolved after tier 2", + [[35, "firing"], [40, "resolved"]], + 2, + None, + ), + ( + {"AUTH_TYPE": "NOAUTH"}, "Multiple firings, last one tier 2", [[10, "firing"], [20, "firing"], [35, "firing"]], 2, None, ), - ("No action", [[0, "firing"]], None, {"db": "mysql"}), - ("Tier 1", [[20, "firing"]], 1, {"db": "mysql"}), - ("Tier 2", [[35, "firing"]], 2, {"db": "mysql"}), + ({"AUTH_TYPE": "NOAUTH"}, "No action", [[0, "firing"]], None, {"db": "mysql"}), + ({"AUTH_TYPE": "NOAUTH"}, "Tier 1", [[20, "firing"]], 1, {"db": "mysql"}), + ({"AUTH_TYPE": "NOAUTH"}, "Tier 2", [[35, "firing"]], 2, {"db": "mysql"}), ( + {"AUTH_TYPE": "NOAUTH"}, "Resolved before tier 1", [[10, "firing"], [11, "resolved"]], None, {"db": "mysql"}, ), ( + {"AUTH_TYPE": "NOAUTH"}, "Resolved after tier 1", [[20, "firing"], [25, "resolved"]], 1, {"db": "mysql"}, ), ( + {"AUTH_TYPE": "NOAUTH"}, "Resolved after tier 2", [[35, "firing"], [40, "resolved"]], 2, {"db": "mysql"}, 
), ( + {"AUTH_TYPE": "NOAUTH"}, "Multiple firings, last one tier 2", [[10, "firing"], [20, "firing"], [35, "firing"]], 2, {"db": "mysql"}, ), ], - indirect=["db_session"], + indirect=["test_app", "db_session"], ) def test_workflow_execution( db_session, + test_app, create_alert, setup_workflow, workflow_manager, @@ -229,36 +253,59 @@ def test_workflow_execution( @pytest.mark.parametrize( - "workflow_id, test_case, alert_statuses, expected_action", + "test_app, workflow_id, test_case, alert_statuses, expected_action", [ - ("alert-first-firing", "First firing", [[0, "firing"]], True), - ("alert-second-firing", "Second firing within 24h", [[0, "firing"], [1, "firing"]], False), ( + {"AUTH_TYPE": "NOAUTH"}, + "alert-first-firing", + "First firing", + [[0, "firing"]], + True, + ), + ( + {"AUTH_TYPE": "NOAUTH"}, + "alert-second-firing", + "Second firing within 24h", + [[0, "firing"], [1, "firing"]], + False, + ), + ( + {"AUTH_TYPE": "NOAUTH"}, "firing-resolved-firing-24", "First firing, resolved, and fired again after 24h", [[0, "firing"], [1, "resolved"], [25, "firing"]], True, ), ( + {"AUTH_TYPE": "NOAUTH"}, "multiple-firings-24", "Multiple firings within 24h", [[0, "firing"], [1, "firing"], [2, "firing"], [3, "firing"]], False, ), ( + {"AUTH_TYPE": "NOAUTH"}, "resolved-fired-24", "Resolved and fired again within 24h", [[0, "firing"], [1, "resolved"], [2, "firing"]], False, ), ( + {"AUTH_TYPE": "NOAUTH"}, "first-firing-multiple-resolutions", "First firing after multiple resolutions", [[0, "resolved"], [1, "resolved"], [2, "firing"]], True, ), - ("firing-exactly-24", "Firing exactly at 24h boundary", [[0, "firing"], [24, "firing"]], True), ( + {"AUTH_TYPE": "NOAUTH"}, + "firing-exactly-24", + "Firing exactly at 24h boundary", + [[0, "firing"], [24, "firing"]], + True, + ), + ( + {"AUTH_TYPE": "NOAUTH"}, "complex-scenario", "Complex scenario with multiple status changes", [ @@ -271,9 +318,11 @@ def test_workflow_execution( False, ), ], + indirect=["test_app"], ) def 
test_workflow_execution_2( db_session, + test_app, create_alert, workflow_manager, workflow_id, @@ -342,7 +391,8 @@ def test_workflow_execution_2( status = None while workflow_execution is None and count < 30 and status != "success": workflow_execution = get_last_workflow_execution_by_workflow_id( - SINGLE_TENANT_UUID, workflow_id, + SINGLE_TENANT_UUID, + workflow_id, ) if workflow_execution is not None: status = workflow_execution.status @@ -390,22 +440,30 @@ def test_workflow_execution_2( @pytest.mark.parametrize( - "test_case, alert_statuses, expected_tier, db_session", + "test_app, test_case, alert_statuses, expected_tier, db_session", [ - ("Tier 0", [[0, "firing"]], 0, None), - ("Tier 1", [[10, "firing"], [0, "firing"]], 1, None), - ("Resolved", [[15, "firing"], [5, "firing"], [0, "resolved"]], None, None), + ({"AUTH_TYPE": "NOAUTH"}, "Tier 0", [[0, "firing"]], 0, None), + ({"AUTH_TYPE": "NOAUTH"}, "Tier 1", [[10, "firing"], [0, "firing"]], 1, None), ( + {"AUTH_TYPE": "NOAUTH"}, + "Resolved", + [[15, "firing"], [5, "firing"], [0, "resolved"]], + None, + None, + ), + ( + {"AUTH_TYPE": "NOAUTH"}, "Tier 0 again", [[20, "firing"], [10, "firing"], [5, "resolved"], [0, "firing"]], 0, None, ), ], - indirect=["db_session"], + indirect=["test_app", "db_session"], ) def test_workflow_execution3( db_session, + test_app, create_alert, workflow_manager, test_case, @@ -478,7 +536,6 @@ def test_workflow_execution3( assert "Tier 1" in workflow_execution.results["send-slack-message-tier-1"][0] - workflow_definition_for_enabled_disabled = """workflow: id: %s description: Handle alerts based on startedAt timestamp @@ -507,8 +564,16 @@ def test_workflow_execution3( """ +@pytest.mark.parametrize( + "test_app", + [ + ({"AUTH_TYPE": "NOAUTH"}), + ], + indirect=["test_app"], +) def test_workflow_execution_with_disabled_workflow( db_session, + test_app, create_alert, workflow_manager, ): @@ -521,7 +586,7 @@ def test_workflow_execution_with_disabled_workflow( 
created_by="test@keephq.dev", interval=0, is_disabled=False, - workflow_raw=workflow_definition_for_enabled_disabled % enabled_id + workflow_raw=workflow_definition_for_enabled_disabled % enabled_id, ) disabled_id = "disabled-workflow" @@ -533,7 +598,7 @@ def test_workflow_execution_with_disabled_workflow( created_by="test@keephq.dev", interval=0, is_disabled=True, - workflow_raw=workflow_definition_for_enabled_disabled % disabled_id + workflow_raw=workflow_definition_for_enabled_disabled % disabled_id, ) db_session.add(enabled_workflow) @@ -561,7 +626,9 @@ def test_workflow_execution_with_disabled_workflow( disabled_workflow_execution = None count = 0 - while (enabled_workflow_execution is None and disabled_workflow_execution is None) and count < 30: + while ( + enabled_workflow_execution is None and disabled_workflow_execution is None + ) and count < 30: enabled_workflow_execution = get_last_workflow_execution_by_workflow_id( SINGLE_TENANT_UUID, enabled_id ) @@ -578,10 +645,9 @@ def test_workflow_execution_with_disabled_workflow( assert disabled_workflow_execution is None - workflow_definition_4 = """workflow: id: incident-triggers-test-created-updated -description: test incident triggers +description: test incident triggers triggers: - type: incident events: @@ -621,8 +687,16 @@ def test_workflow_execution_with_disabled_workflow( """ +@pytest.mark.parametrize( + "test_app", + [ + ({"AUTH_TYPE": "NOAUTH"}), + ], + indirect=["test_app"], +) def test_workflow_incident_triggers( db_session, + test_app, workflow_manager, ): workflow_created = Workflow( @@ -669,18 +743,26 @@ def wait_workflow_execution(workflow_id): workflow_manager.insert_incident(SINGLE_TENANT_UUID, incident, "created") assert len(workflow_manager.scheduler.workflows_to_run) == 1 - workflow_execution_created = wait_workflow_execution("incident-triggers-test-created-updated") + workflow_execution_created = wait_workflow_execution( + "incident-triggers-test-created-updated" + ) assert 
workflow_execution_created is not None assert workflow_execution_created.status == "success" - assert workflow_execution_created.results['mock-action'] == ['"incident: incident"\n'] + assert workflow_execution_created.results["mock-action"] == [ + '"incident: incident"\n' + ] assert len(workflow_manager.scheduler.workflows_to_run) == 0 workflow_manager.insert_incident(SINGLE_TENANT_UUID, incident, "updated") assert len(workflow_manager.scheduler.workflows_to_run) == 1 - workflow_execution_updated = wait_workflow_execution("incident-triggers-test-created-updated") + workflow_execution_updated = wait_workflow_execution( + "incident-triggers-test-created-updated" + ) assert workflow_execution_updated is not None assert workflow_execution_updated.status == "success" - assert workflow_execution_updated.results['mock-action'] == ['"incident: incident"\n'] + assert workflow_execution_updated.results["mock-action"] == [ + '"incident: incident"\n' + ] # incident-triggers-test-created-updated should not be triggered workflow_manager.insert_incident(SINGLE_TENANT_UUID, incident, "deleted") @@ -702,9 +784,13 @@ def wait_workflow_execution(workflow_id): assert len(workflow_manager.scheduler.workflows_to_run) == 1 # incident-triggers-test-deleted should be triggered now - workflow_execution_deleted = wait_workflow_execution("incident-triggers-test-deleted") + workflow_execution_deleted = wait_workflow_execution( + "incident-triggers-test-deleted" + ) assert len(workflow_manager.scheduler.workflows_to_run) == 0 assert workflow_execution_deleted is not None assert workflow_execution_deleted.status == "success" - assert workflow_execution_deleted.results['mock-action'] == ['"deleted incident: incident"\n'] \ No newline at end of file + assert workflow_execution_deleted.results["mock-action"] == [ + '"deleted incident: incident"\n' + ]