diff --git a/.github/workflows/test-pr-e2e.yml b/.github/workflows/test-pr-e2e.yml index 9d390af14..6638ac00d 100644 --- a/.github/workflows/test-pr-e2e.yml +++ b/.github/workflows/test-pr-e2e.yml @@ -22,6 +22,8 @@ env: POSTGRES_USER: keepuser POSTGRES_PASSWORD: keeppassword POSTGRES_DB: keepdb + # To test if imports are working properly + EE_ENABLED: true jobs: tests: diff --git a/.gitignore b/.gitignore index a4c1a8531..592e55b10 100644 --- a/.gitignore +++ b/.gitignore @@ -10,6 +10,9 @@ __pycache__/ # C extensions *.so +# .csv files +*.csv + # Distribution / packaging .Python build/ diff --git a/docker/Dockerfile.api b/docker/Dockerfile.api index 488877461..291210d25 100644 --- a/docker/Dockerfile.api +++ b/docker/Dockerfile.api @@ -19,6 +19,7 @@ RUN python -m venv /venv COPY pyproject.toml poetry.lock ./ RUN poetry export -f requirements.txt --output requirements.txt --without-hashes && /venv/bin/python -m pip install --upgrade -r requirements.txt COPY keep keep +COPY ee keep/ee COPY examples examples COPY README.md README.md RUN poetry build && /venv/bin/pip install --use-deprecated=legacy-resolver dist/*.whl @@ -26,6 +27,7 @@ RUN poetry build && /venv/bin/pip install --use-deprecated=legacy-resolver dist/ FROM base as final ENV PATH="/venv/bin:${PATH}" ENV VIRTUAL_ENV="/venv" +ENV EE_PATH="ee" COPY --from=builder /venv /venv COPY --from=builder /app/examples /examples # as per Openshift guidelines, https://docs.openshift.com/container-platform/4.11/openshift_images/create-images.html#use-uid_create-images diff --git a/ee/experimental/graph_utils.py b/ee/experimental/graph_utils.py new file mode 100644 index 000000000..d8e058104 --- /dev/null +++ b/ee/experimental/graph_utils.py @@ -0,0 +1,99 @@ +import numpy as np +import networkx as nx + +from typing import List, Tuple + +from keep.api.core.db import get_pmi_values + + +def detect_knee_1d_auto_increasing(y: List[float]) -> Tuple[int, float]: + """ + This function detects the knee point in an increasing 1D curve. Knee point is the point where a curve + starts to flatten out (https://en.wikipedia.org/wiki/Knee_of_a_curve). + + Parameters: + y (List[float]): a list of float values + + Returns: + tuple: knee_index, knee_y + """ + + def detect_knee_1d(y: List[float], curve: str, direction: str = 'increasing') -> Tuple[int, float, List[float]]: + x = np.arange(len(y)) + + x_norm = (x - np.min(x)) / (np.max(x) - np.min(x)) + y_norm = (y - np.min(y)) / (np.max(y) - np.min(y)) + + diff_curve = y_norm - x_norm + + if curve == 'concave': + knee_index = np.argmax(diff_curve) + else: + knee_index = np.argmin(diff_curve) + + knee_y = y[knee_index] + + return knee_index, knee_y, diff_curve + + knee_index_concave, knee_y_concave, diff_curve_concave = detect_knee_1d(y, 'concave') + knee_index_convex, knee_y_convex, diff_curve_convex = detect_knee_1d(y, 'convex') + max_diff_concave = np.max(np.abs(diff_curve_concave)) + max_diff_convex = np.max(np.abs(diff_curve_convex)) + + if max_diff_concave > max_diff_convex: + return knee_index_concave, knee_y_concave + else: + return knee_index_convex, knee_y_convex + + +def create_graph(tenant_id: str, fingerprints: List[str], pmi_threshold: float = 0., knee_threshold: float = 0.8) -> nx.Graph: + """ + This function creates a graph from a list of fingerprints. The graph is created based on the PMI values between + the fingerprints. The edges are created between the fingerprints that have a PMI value greater than the threshold. + The nodes are removed if the knee point of the PMI values of the edges connected to the node is less than the threshold. + + Parameters: + tenant_id (str): tenant id + fingerprints (List[str]): a list of fingerprints + pmi_threshold (float): PMI threshold + knee_threshold (float): knee threshold + + Returns: + nx.Graph: a graph + """ + + graph = nx.Graph() + + if len(fingerprints) == 1: + graph.add_node(fingerprints[0]) + return graph + + # Load all PMI values at once + pmi_values = get_pmi_values(tenant_id, fingerprints) + + for idx_i, fingerprint_i in enumerate(fingerprints): + if not isinstance(pmi_values[(fingerprint_i, fingerprint_i)], float): + continue + + for idx_j in range(idx_i + 1, len(fingerprints)): + fingerprint_j = fingerprints[idx_j] + weight = pmi_values[(fingerprint_i, fingerprint_j)] + if not isinstance(weight, float): + continue + + if weight > pmi_threshold: + graph.add_edge(fingerprint_i, fingerprint_j, weight=weight) + + nodes_to_delete = [] + + for node in graph.nodes: + weights = sorted([edge['weight'] for edge in graph[node].values()]) + + knee_index, knee_statistic = detect_knee_1d_auto_increasing(weights) + + if knee_statistic < knee_threshold: + nodes_to_delete.append(node) + + graph.remove_nodes_from(nodes_to_delete) + + return graph \ No newline at end of file diff --git a/ee/experimental/incident_utils.py b/ee/experimental/incident_utils.py index 975e64ff2..66b734755 100644 --- a/ee/experimental/incident_utils.py +++ b/ee/experimental/incident_utils.py @@ -1,10 +1,197 @@ +import os +import logging + import numpy as np import pandas as pd import networkx as nx -from typing import List - -from keep.api.models.db.alert import Alert +from typing import List, Dict +from openai import OpenAI + +from datetime import datetime, timedelta + +from fastapi import Depends + +from ee.experimental.node_utils import NodeCandidateQueue, NodeCandidate +from ee.experimental.graph_utils import create_graph +from ee.experimental.statistical_utils import get_alert_pmi_matrix + +from pusher import Pusher + +from keep.api.models.db.alert import Alert, Incident +from keep.api.core.db import ( + assign_alert_to_incident, + is_alert_assigned_to_incident, + add_alerts_to_incident_by_incident_id, + get_last_alerts, + get_last_incidents, + get_incident_by_id, + write_pmi_matrix_to_db, + create_incident_from_dict, + update_incident_summary, +) + +from keep.api.core.dependencies import ( + AuthenticatedEntity, + AuthVerifier, + get_pusher_client, +) + +logger = logging.getLogger(__name__) + +ALGORITHM_VERBOSE_NAME = "Basic correlation algorithm v0.2" +USE_N_HISTORICAL_ALERTS = 10e10 +USE_N_HISTORICAL_INCIDENTS = 10e10 + + +def calculate_pmi_matrix( + ctx: dict | None, # arq context + tenant_id: str, + upper_timestamp: datetime = None, + use_n_historical_alerts: int = USE_N_HISTORICAL_ALERTS, + sliding_window: int = None, + stride: int = None, +) -> dict: + logger.info( + "Calculating PMI coefficients for alerts", + extra={ + "tenant_id": tenant_id, + }, + ) + + if not upper_timestamp: + upper_timestamp = os.environ.get('PMI_ALERT_UPPER_TIMESTAMP', datetime.now()) + + if not sliding_window: + sliding_window = os.environ.get('PMI_SLIDING_WINDOW', 4 * 60 * 60) + + if not stride: + stride = os.environ.get('PMI_STRIDE', 60 * 60) + + alerts=get_last_alerts(tenant_id, limit=use_n_historical_alerts, upper_timestamp=upper_timestamp) + pmi_matrix = get_alert_pmi_matrix(alerts, 'fingerprint', sliding_window, stride) + write_pmi_matrix_to_db(tenant_id, pmi_matrix) + + return {"status": "success"} + + +async def mine_incidents_and_create_objects( + ctx: dict | None, # arq context + tenant_id: str, + alert_lower_timestamp: datetime = None, + alert_upper_timestamp: datetime = None, + use_n_historical_alerts: int = USE_N_HISTORICAL_ALERTS, + incident_lower_timestamp: datetime = None, + incident_upper_timestamp: datetime = None, + use_n_hist_incidents: int = USE_N_HISTORICAL_INCIDENTS, + pmi_threshold: float = None, + knee_threshold: float = None, + min_incident_size: int = None, + incident_similarity_threshold: float = None, + ) -> Dict[str, List[Incident]]: + + """ + This function mines incidents from alerts and creates incidents in the database. + + Parameters: + tenant_id (str): tenant id + alert_lower_timestamp (datetime): lower timestamp for alerts + alert_upper_timestamp (datetime): upper timestamp for alerts + use_n_historical_alerts (int): number of historical alerts to use + incident_lower_timestamp (datetime): lower timestamp for incidents + incident_upper_timestamp (datetime): upper timestamp for incidents + use_n_hist_incidents (int): number of historical incidents to use + pmi_threshold (float): PMI threshold used for incident graph edges creation + knee_threshold (float): knee threshold used for incident graph nodes creation + min_incident_size (int): minimum incident size + incident_similarity_threshold (float): incident similarity threshold + + Returns: + Dict[str, List[Incident]]: a dictionary containing the created incidents + """ + + if not incident_upper_timestamp: + incident_upper_timestamp = os.environ.get('MINE_INCIDENT_UPPER_TIMESTAMP', datetime.now()) + + if not incident_lower_timestamp: + incident_validity = os.environ.get('MINE_INCIDENT_VALIDITY', timedelta(days=1)) + incident_lower_timestamp = incident_upper_timestamp - incident_validity + + if not alert_upper_timestamp: + alert_upper_timestamp = os.environ.get('MINE_ALERT_UPPER_TIMESTAMP', datetime.now()) + + if not alert_lower_timestamp: + alert_window = os.environ.get('MINE_ALERT_WINDOW', timedelta(hours=12)) + alert_lower_timestamp = alert_upper_timestamp - alert_window + + if not pmi_threshold: + pmi_threshold = os.environ.get('PMI_THRESHOLD', 0.0) + + if not knee_threshold: + knee_threshold = os.environ.get('KNEE_THRESHOLD', 0.8) + + if not min_incident_size: + min_incident_size = os.environ.get('MIN_INCIDENT_SIZE', 5) + + if not incident_similarity_threshold: + incident_similarity_threshold = os.environ.get('INCIDENT_SIMILARITY_THRESHOLD', 0.8) + + calculate_pmi_matrix(ctx, tenant_id) + + alerts = get_last_alerts(tenant_id, limit=use_n_historical_alerts, upper_timestamp=alert_upper_timestamp, lower_timestamp=alert_lower_timestamp) + incidents, _ = get_last_incidents(tenant_id, limit=use_n_hist_incidents, upper_timestamp=incident_upper_timestamp, lower_timestamp=incident_lower_timestamp) + nc_queue = NodeCandidateQueue() + + for candidate in [NodeCandidate(alert.fingerprint, alert.timestamp) for alert in alerts]: + nc_queue.push_candidate(candidate) + candidates = nc_queue.get_candidates() + + graph = create_graph(tenant_id, [candidate.fingerprint for candidate in candidates], pmi_threshold, knee_threshold) + ids = [] + + for component in nx.connected_components(graph): + if len(component) > min_incident_size: + alerts_appended = False + for incident in incidents: + incident_fingerprints = set([alert.fingerprint for alert in incident.Incident.alerts]) + intersection = incident_fingerprints.intersection(component) + + if len(intersection) / len(component) >= incident_similarity_threshold: + alerts_appended = True + + add_alerts_to_incident_by_incident_id(tenant_id, incident.Incident.id, [alert.id for alert in alerts if alert.fingerprint in component]) + + summary = generate_incident_summary(incident.Incident) + update_incident_summary(incident.Incident.id, summary) + + if not alerts_appended: + incident_start_time = min([alert.timestamp for alert in alerts if alert.fingerprint in component]) + incident_start_time = incident_start_time.replace(microsecond=0) + + incident = create_incident_from_dict(tenant_id, + {"name": f"Incident started at {incident_start_time}", + "description": "Summarization is Disabled", "is_predicted": True}) + ids.append(incident.id) + + add_alerts_to_incident_by_incident_id(tenant_id, incident.id, [alert.id for alert in alerts if alert.fingerprint in component]) + + summary = generate_incident_summary(incident) + update_incident_summary(incident.id, summary) + + pusher_client = get_pusher_client() + if pusher_client: + pusher_client.trigger( + f"private-{tenant_id}", + "ai-logs-change", + {"log": ALGORITHM_VERBOSE_NAME + " successfully executed."}, + ) + logger.info( + "Client notified on new AI log", + extra={"tenant_id": tenant_id}, + ) + + + return {"incidents": [get_incident_by_id(tenant_id, incident_id) for incident_id in ids]} def mine_incidents(alerts: List[Alert], incident_sliding_window_size: int=6*24*60*60, statistic_sliding_window_size: int=60*60, @@ -145,4 +332,56 @@ def shape_incidents(alerts: pd.DataFrame, unique_alert_identifier: str, incident 'alert_fingerprints': local_alerts[unique_alert_identifier].unique().tolist(), }) - return incidents \ No newline at end of file + return incidents + + +def generate_incident_summary(incident: Incident, use_n_alerts_for_summary: int = -1) -> str: + if "OPENAI_API_KEY" not in os.environ: + logger.error("OpenAI API key is not set. Incident summary generation is not available.") + return "Summarization is Disabled" + + try: + client = OpenAI(api_key=os.environ["OPENAI_API_KEY"]) + + prompt_addition = '' + if incident.user_summary: + prompt_addition = f'When generating, you must rely on the summary provided by human: {incident.user_summary}' + + description_strings = np.unique([f'{alert.event["name"]}' for alert in incident.alerts]).tolist() + + if use_n_alerts_for_summary > 0: + incident_description = "\n".join(description_strings[:use_n_alerts_for_summary]) + else: + incident_description = "\n".join(description_strings) + + timestamps = [alert.timestamp for alert in incident.alerts] + incident_start = min(timestamps).replace(microsecond=0) + incident_end = max(timestamps).replace(microsecond=0) + + model = os.environ.get("OPENAI_MODEL", "gpt-4o-mini") + + summary = client.chat.completions.create(model=model, messages=[ + { + "role": "system", + "content": """You are a very skilled DevOps specialist who can summarize any incident based on alert descriptions. + When provided with information, summarize it in a 2-3 sentences explaining what happened and when. + ONLY SUMMARIZE WHAT YOU SEE. In the end add information about potential scenario of the incident. + + EXAMPLE: + An incident occurred between 2022-11-17 14:11:04.955070 and 2022-11-22 22:19:04.837526, involving a + total of 200 alerts. The alerts indicated critical and warning issues such as high CPU and memory + usage in pods and nodes, as well as stuck Kubernetes Daemonset rollout. Potential incident scenario: + Kubernetes Daemonset rollout stuck due to high CPU and memory usage in pods and nodes. This caused a + long tail of alerts on various topics.""" + }, + { + "role": "user", + "content": f"""Here are alerts of an incident for summarization:\n{incident_description}\n This incident started on + {incident_start}, ended on {incident_end}, included {len(description_strings)} alerts. {prompt_addition}""" + } + ]).choices[0].message.content + + return summary + except Exception as e: + logger.error(f"Error in generating incident summary: {e}") + return "Summarization is Disabled" \ No newline at end of file diff --git a/ee/experimental/node_utils.py b/ee/experimental/node_utils.py new file mode 100644 index 000000000..c61c98ace --- /dev/null +++ b/ee/experimental/node_utils.py @@ -0,0 +1,77 @@ +import heapq + +from datetime import timedelta, datetime +from typing import List, Dict, Tuple + + +class NodeCandidate: + def __init__(self, fingerpint: str, timestamps: datetime): + self.fingerprint = fingerpint + self.timestamps = set([timestamps]) + + @property + def first_timestamp(self): + if self.timestamps: + return min(self.timestamps) + return None + + @property + def last_timestamp(self): + if self.timestamps: + return max(self.timestamps) + return None + + def __lt__(self, other): + return self.last_timestamp < other.last_timestamp + + def __str__(self): + return f'NodeCandidate(fingerprint={self.fingerprint}, first_timestamp={self.first_timestamp}, last_timestamp={self.last_timestamp}, timestamps={self.timestamps})' + + +class NodeCandidateQueue: + def __init__(self, candidate_validity_window: int = None): + self.queue = [] + self.candidate_validity_window = candidate_validity_window + + def push_candidate(self, candidate: NodeCandidate): + for c in self.queue: + if c.fingerprint == candidate.fingerprint: + c.timestamps.update(candidate.timestamps) + heapq.heapify(self.queue) + return + heapq.heappush(self.queue, candidate) + + def push_candidates(self, candidates: List[NodeCandidate]): + for candidate in candidates: + self.push_candidate(candidate) + + def pop_invalid_candidates(self, current_timestamp: datetime): + # check incident-wise consistency + validity_threshold = current_timestamp - \ + timedelta(seconds=self.candidate_validity_window) + + while self.queue and self.queue[0].last_timestamp <= validity_threshold: + heapq.heappop(self.queue) + + for c in self.queue: + c.timestamps = { + ts for ts in c.timestamps if ts > validity_threshold} + heapq.heapify(self.queue) + + def get_candidates(self): + return self.queue + + def copy(self): + new_queue = NodeCandidateQueue(self.candidate_validity_window) + new_queue.queue = self.queue.copy() + return new_queue + + def __str__(self): + candidates_str = "\n".join(str(candidate) for candidate in self.queue) + return f'NodeCandidateQueue(\ncandidate_validity_window={self.candidate_validity_window}, \nqueue=[\n{candidates_str}\n])' + + def __iter__(self): + return iter(self.queue) + + def __len__(self): + return len(self.queue) diff --git a/ee/experimental/statistical_utils.py b/ee/experimental/statistical_utils.py new file mode 100644 index 000000000..0aa45477a --- /dev/null +++ b/ee/experimental/statistical_utils.py @@ -0,0 +1,127 @@ +import numpy as np +import pandas as pd + +from typing import List, Tuple + +def get_batched_alert_counts(alerts: pd.DataFrame, unique_alert_identifier: str, sliding_window_size: int, step_size: int) -> pd.DataFrame: + """ + This function calculates number of alerts per sliding window. + + Parameters: + alerts (pd.DataFrame): a DataFrame containing alerts + unique_alert_identifier (str): a unique identifier for alerts + sliding_window_size (int): sliding window size in seconds + step_size (int): step size in seconds + + Returns: + rolling_counts (pd.DataFrame): a DataFrame containing the number of alerts per sliding window + """ + + resampled_alert_counts = alerts.set_index('starts_at').resample(f'{step_size}s')[unique_alert_identifier].value_counts().unstack(fill_value=0) + rolling_counts = resampled_alert_counts.rolling(window=f'{sliding_window_size}s', min_periods=1).sum() + + return rolling_counts + + +def get_batched_alert_occurrences(alerts: pd.DataFrame, unique_alert_identifier: str, sliding_window_size: int, step_size: int) -> pd.DataFrame: + """ + This function calculates occurrences of alerts per sliding window. + + Parameters: + alerts (pd.DataFrame): a DataFrame containing alerts + unique_alert_identifier (str): a unique identifier for alerts + sliding_window_size (int): sliding window size in seconds + step_size (int): step size in seconds + + Returns: + alert_occurences (pd.DataFrame): a DataFrame containing the occurrences of alerts per sliding window + """ + + alert_counts = get_batched_alert_counts(alerts, unique_alert_identifier, sliding_window_size, step_size) + alert_occurences = pd.DataFrame(np.where(alert_counts > 0, 1, 0), index=alert_counts.index, columns=alert_counts.columns) + + return alert_occurences + +def get_jaccard_scores(P_a: np.array, P_aa: np.array) -> np.array: + """ + This function calculates the Jaccard similarity scores between recurring events. + + Parameters: + P_a (np.array): a 1D array containing the probabilities of events + P_aa (np.array): a 2D array containing the probabilities of joint events + + Returns: + jaccard_matrix (np.array): a 2D array containing the Jaccard similarity scores between events + """ + + P_a_matrix = P_a[:, None] + P_a + union_matrix = P_a_matrix - P_aa + + with np.errstate(divide='ignore', invalid='ignore'): + jaccard_matrix = np.where(union_matrix != 0, P_aa / union_matrix, 0) + + np.fill_diagonal(jaccard_matrix, 1) + + return jaccard_matrix + + +def get_alert_jaccard_matrix(alerts: pd.DataFrame, unique_alert_identifier: str, sliding_window_size: int, step_size: int) -> pd.DataFrame: + """ + This function calculates Jaccard similarity scores between alert groups (fingerprints). + + Parameters: + alerts (pd.DataFrame): a DataFrame containing alerts + unique_alert_identifier (str): a unique identifier for alerts + sliding_window_size (int): sliding window size in seconds + step_size (int): step size in seconds + + Returns: + jaccard_scores_df (pd.DataFrame): a DataFrame containing the Jaccard similarity scores between alert groups + """ + + alert_occurrences_df = get_batched_alert_occurrences(alerts, unique_alert_identifier, sliding_window_size, step_size) + alert_occurrences = alert_occurrences_df.to_numpy() + + alert_probabilities = np.mean(alert_occurrences, axis=0) + joint_alert_occurrences = np.dot(alert_occurrences.T, alert_occurrences) + pairwise_alert_probabilities = joint_alert_occurrences / alert_occurrences.shape[0] + + jaccard_scores = get_jaccard_scores(alert_probabilities, pairwise_alert_probabilities) + jaccard_scores_df = pd.DataFrame(jaccard_scores, index=alert_occurrences_df.columns, columns=alert_occurrences_df.columns) + + return jaccard_scores_df + + +def get_alert_pmi_matrix(alerts: pd.DataFrame, unique_alert_identifier: str, sliding_window_size: int, step_size: int) -> pd.DataFrame: + """ + This funciton calculates PMI scores between alert groups (fingerprints). + + Parameters: + alerts (pd.DataFrame): a DataFrame containing alerts + unique_alert_identifier (str): a unique identifier for alerts + sliding_window_size (int): sliding window size in seconds + step_size (int): step size in seconds + + Returns: + pmi_matrix_df (pd.DataFrame): a DataFrame containing the PMI scores between + """ + + alert_dict = { + 'fingerprint': [alert.fingerprint for alert in alerts], + 'starts_at': [alert.timestamp for alert in alerts], + } + + alert_df = pd.DataFrame(alert_dict) + alert_occurences_df = get_batched_alert_occurrences(alert_df, unique_alert_identifier, sliding_window_size, step_size) + alert_occurrences = alert_occurences_df.to_numpy() + alert_probabilities = np.mean(alert_occurrences, axis=0) + joint_alert_occurrences = np.dot(alert_occurrences.T, alert_occurrences) + pairwise_alert_probabilities = joint_alert_occurrences / alert_occurrences.shape[0] + + pmi_matrix = np.log(pairwise_alert_probabilities / (alert_probabilities[:, None] * alert_probabilities)) + pmi_matrix[np.isnan(pmi_matrix)] = 0 + np.fill_diagonal(pmi_matrix, 0) + + pmi_matrix_df = pd.DataFrame(pmi_matrix, index=alert_occurences_df.columns, columns=alert_occurences_df.columns) + + return pmi_matrix_df \ No newline at end of file diff --git a/keep-ui/app/ai/ai.tsx b/keep-ui/app/ai/ai.tsx index 632facaab..f0006e6d2 100644 --- a/keep-ui/app/ai/ai.tsx +++ b/keep-ui/app/ai/ai.tsx @@ -1,19 +1,26 @@ "use client"; import { Card, List, ListItem, Title, Subtitle } from "@tremor/react"; -import { useAIStats } from "utils/hooks/useAIStats"; +import { useAIStats, usePollAILogs } from "utils/hooks/useAI"; import { useSession } from "next-auth/react"; import { getApiURL } from "utils/apiUrl"; import { toast } from "react-toastify"; import { useEffect, useState, useRef, FormEvent } from "react"; +import { AILogs } from "./model"; export default function Ai() { const { data: aistats, isLoading } = useAIStats(); const { data: session } = useSession(); const [text, setText] = useState(""); + const [basicAlgorithmLog, setBasicAlgorithmLog] = useState(""); const [newText, setNewText] = useState("Mine incidents"); const [animate, setAnimate] = useState(false); const onlyOnce = useRef(false); + const mutateAILogs = (logs: AILogs) => { + setBasicAlgorithmLog(logs.log); + }; + usePollAILogs(mutateAILogs); + useEffect(() => { let index = 0; @@ -42,14 +49,14 @@ export default function Ai() { Authorization: `Bearer ${session?.accessToken}`, "Content-Type": "application/json", }, - body: JSON.stringify({ - }), + body: JSON.stringify({}), }); if (!response.ok) { toast.error( "Failed to mine incidents, please contact us if this issue persists." ); } + setAnimate(false); setNewText("Mine incidents"); }; @@ -68,7 +75,8 @@ export default function Ai() {
👋 You are almost there!
- AI Correlation is coming soon. Make sure you have enough data collected to prepare. + AI Correlation is coming soon. Make sure you have enough data + collected to prepare.
@@ -98,7 +106,9 @@ export default function Ai() { Collect alerts for more than 3 days - {aistats?.first_alert_datetime && new Date(aistats.first_alert_datetime) < new Date(Date.now() - 3 * 24 * 60 * 60 * 1000) ? ( + {aistats?.first_alert_datetime && + new Date(aistats.first_alert_datetime) < + new Date(Date.now() - 3 * 24 * 60 * 60 * 1000) ? (
) : (
@@ -107,41 +117,84 @@ export default function Ai() {
- {(aistats?.is_mining_enabled && +
-
{text}
- )} + )}
diff --git a/keep-ui/app/ai/model.ts b/keep-ui/app/ai/model.ts index a0d51d359..3d78cbb9f 100644 --- a/keep-ui/app/ai/model.ts +++ b/keep-ui/app/ai/model.ts @@ -3,4 +3,9 @@ export interface AIStats { incidents_count: number; first_alert_datetime?: Date; is_mining_enabled: boolean; + algorithm_verbose_name: string } + +export interface AILogs { + log: string; +} \ No newline at end of file diff --git a/keep-ui/app/incidents/[id]/incident-info.tsx b/keep-ui/app/incidents/[id]/incident-info.tsx index 35c1a2877..8d46dac2a 100644 --- a/keep-ui/app/incidents/[id]/incident-info.tsx +++ b/keep-ui/app/incidents/[id]/incident-info.tsx @@ -1,3 +1,4 @@ + import {Button, Title} from "@tremor/react"; import { IncidentDto } from "../model"; import CreateOrUpdateIncident from "../create-or-update-incident"; diff --git a/keep-ui/app/workflows/mockworkflows.tsx b/keep-ui/app/workflows/mockworkflows.tsx index 2cb7e572b..8a62b6c45 100644 --- a/keep-ui/app/workflows/mockworkflows.tsx +++ b/keep-ui/app/workflows/mockworkflows.tsx @@ -151,7 +151,7 @@ export default function MockWorkflowCardSection({ return (
diff --git a/keep-ui/components/navbar/AILink.tsx b/keep-ui/components/navbar/AILink.tsx index 9c05fd986..e28ee5c7c 100644 --- a/keep-ui/components/navbar/AILink.tsx +++ b/keep-ui/components/navbar/AILink.tsx @@ -5,11 +5,17 @@ import { LinkWithIcon } from "components/LinkWithIcon"; import { RiSparkling2Line } from "react-icons/ri"; import { useEffect, useState } from "react"; +import { usePollAILogs } from "utils/hooks/useAI"; export const AILink = () => { const [text, setText] = useState(""); const [newText, setNewText] = useState("AI correlation"); + const mutateAILogs = (logs: any) => { + setNewText("AI iterated 🎉") + } + + usePollAILogs(mutateAILogs); useEffect(() => { let index = 0; @@ -31,7 +37,7 @@ export const AILink = () => { return (
- + {text}
diff --git a/keep-ui/components/navbar/IncidentLinks.tsx b/keep-ui/components/navbar/IncidentLinks.tsx index 74c83cac0..15a310d76 100644 --- a/keep-ui/components/navbar/IncidentLinks.tsx +++ b/keep-ui/components/navbar/IncidentLinks.tsx @@ -62,6 +62,7 @@ export const IncidentsLinks = ({ session }: IncidentsLinksProps) => { "bg-gray-200": currentPath === `/incidents/${incident.id}`, })} > + {incident.name}
diff --git a/keep-ui/utils/hooks/useAI.ts b/keep-ui/utils/hooks/useAI.ts new file mode 100644 index 000000000..57af67d73 --- /dev/null +++ b/keep-ui/utils/hooks/useAI.ts @@ -0,0 +1,41 @@ +import { AILogs, AIStats } from "app/ai/model"; +import { useSession } from "next-auth/react"; +import useSWR, { SWRConfiguration } from "swr"; +import { getApiURL } from "utils/apiUrl"; +import { fetcher } from "utils/fetcher"; + +import { useWebsocket } from "./usePusher"; +import { useCallback, useEffect } from "react"; + + +export const useAIStats = ( + options: SWRConfiguration = { + revalidateOnFocus: false, + } +) => { + const apiUrl = getApiURL(); + const { data: session } = useSession(); + + return useSWR( + () => (session ? `${apiUrl}/ai/stats` : null), + (url) => fetcher(url, session?.accessToken), + options + ); +}; + +export const usePollAILogs = (mutateAILogs: (logs: AILogs) => void) => { + const { bind, unbind } = useWebsocket(); + const handleIncoming = useCallback( + (data: AILogs) => { + mutateAILogs(data); + }, + [mutateAILogs] + ); + + useEffect(() => { + bind("ai-logs-change", handleIncoming); + return () => { + unbind("ai-logs-change", handleIncoming); + }; + }, [bind, unbind, handleIncoming]); +}; \ No newline at end of file diff --git a/keep-ui/utils/hooks/useAIStats.ts b/keep-ui/utils/hooks/useAIStats.ts deleted file mode 100644 index 514786f42..000000000 --- a/keep-ui/utils/hooks/useAIStats.ts +++ /dev/null @@ -1,20 +0,0 @@ -import { AIStats } from "app/ai/model"; -import { useSession } from "next-auth/react"; -import useSWR, { SWRConfiguration } from "swr"; -import { getApiURL } from "utils/apiUrl"; -import { fetcher } from "utils/fetcher"; - -export const useAIStats = ( - options: SWRConfiguration = { - revalidateOnFocus: false, - } -) => { - const apiUrl = getApiURL(); - const { data: session } = useSession(); - - return useSWR( - () => (session ? `${apiUrl}/ai/stats` : null), - (url) => fetcher(url, session?.accessToken), - options - ); -}; diff --git a/keep/api/arq_worker.py b/keep/api/arq_worker.py index f62a826d4..7519b0b3e 100644 --- a/keep/api/arq_worker.py +++ b/keep/api/arq_worker.py @@ -10,6 +10,7 @@ # internals from keep.api.core.config import config +from keep.api.tasks.process_background_ai_task import process_background_ai_task from keep.api.tasks.healthcheck_task import healthcheck_task ARQ_BACKGROUND_FUNCTIONS: Optional[CommaSeparatedStrings] = config( @@ -18,6 +19,8 @@ default=[ "keep.api.tasks.process_event_task.async_process_event", "keep.api.tasks.process_topology_task.async_process_topology", + "keep.api.tasks.process_background_ai_task.process_background_ai_task", + "keep.api.tasks.process_background_ai_task.process_correlation", "keep.api.tasks.healthcheck_task.healthcheck_task", ], ) @@ -57,14 +60,15 @@ def get_worker() -> Worker: "ARQ_KEEP_RESULT", cast=int, default=3600 ) # duration to keep job results for expires = config( - "ARQ_EXPIRES", cast=int, default=86_400_000 - ) # the default length of time from when a job is expected to start after which the job expires, defaults to 1 day in ms + "ARQ_EXPIRES", cast=int, default=3600 + ) # the default length of time from when a job is expected to start after which the job expires, making it shorter to avoid clogging return create_worker( WorkerSettings, keep_result=keep_result, expires_extra_ms=expires ) def at_every_x_minutes(x: int, start: int = 0, end: int = 59): return {*list(range(start, end, x))} + class WorkerSettings: """ Settings for the ARQ worker. @@ -83,6 +87,7 @@ class WorkerSettings: ) functions: list = FUNCTIONS cron_jobs = [ + cron( healthcheck_task, minute=at_every_x_minutes(1), @@ -91,4 +96,12 @@ class WorkerSettings: max_tries=1, run_at_startup=True, ), + cron( + process_background_ai_task, + minute=at_every_x_minutes(1), + unique=True, + timeout=30, + max_tries=1, + run_at_startup=True, + ), ] diff --git a/keep/api/core/db.py b/keep/api/core/db.py index 352fa0563..15736cc42 100644 --- a/keep/api/core/db.py +++ b/keep/api/core/db.py @@ -9,6 +9,9 @@ import logging import random import uuid + +import pandas as pd + from datetime import datetime, timedelta, timezone from typing import Any, Dict, List, Tuple, Union from uuid import uuid4 @@ -39,6 +42,7 @@ from keep.api.models.db.tenant import * # pylint: disable=unused-wildcard-import from keep.api.models.db.topology import * # pylint: disable=unused-wildcard-import from keep.api.models.db.workflow import * # pylint: disable=unused-wildcard-import +from keep.api.models.db.statistics import * # pylint: disable=unused-wildcard-import logger = logging.getLogger(__name__) @@ -922,7 +926,7 @@ def get_alerts_with_filters( def get_last_alerts( - tenant_id, provider_id=None, limit=1000, timeframe=None + tenant_id, provider_id=None, limit=1000, timeframe=None, upper_timestamp=None, lower_timestamp=None ) -> list[Alert]: """ Get the last alert for each fingerprint along with the first time the alert was triggered. @@ -958,6 +962,22 @@ def get_last_alerts( ) .subquery() ) + + filter_conditions = [] + + if upper_timestamp is not None: + filter_conditions.append(subquery.c.max_timestamp < upper_timestamp) + + if lower_timestamp is not None: + filter_conditions.append(subquery.c.max_timestamp >= lower_timestamp) + + # Apply the filter conditions + if filter_conditions: + subquery = ( + session.query(subquery) + .filter(*filter_conditions) # Unpack and apply all conditions + .subquery() + ) # Main query joins the subquery to select alerts with their first and last occurrence. query = ( session.query( @@ -1846,6 +1866,16 @@ def update_preset_options(tenant_id: str, preset_id: str, options: dict) -> Pres def assign_alert_to_incident(alert_id: UUID, incident_id: UUID, tenant_id: str): return add_alerts_to_incident_by_incident_id(tenant_id, incident_id, [alert_id]) +def is_alert_assigned_to_incident(alert_id: UUID, incident_id: UUID, tenant_id: str) -> bool: + with Session(engine) as session: + assigned = session.exec( + select(AlertToIncident) + .where(AlertToIncident.alert_id == alert_id) + .where(AlertToIncident.incident_id == incident_id) + .where(AlertToIncident.tenant_id == tenant_id) + ).first() + return assigned is not None + def get_incidents(tenant_id) -> List[Incident]: with Session(engine) as session: @@ -1931,8 +1961,10 @@ def get_last_incidents( limit: int = 25, offset: int = 0, timeframe: int = None, + upper_timestamp: datetime = None, + lower_timestamp: datetime = None, is_confirmed: bool = False, -) -> (list[Incident], int): +) -> Tuple[list[Incident], int]: """ Get the last incidents and total amount of incidents. @@ -1947,12 +1979,25 @@ def get_last_incidents( List[Incident]: A list of Incident objects. """ with Session(engine) as session: + subquery = ( + select( + AlertToIncident.incident_id, + func.max(Alert.timestamp).label('last_updated_time') + ) + .join(Alert, Alert.id == AlertToIncident.alert_id) + .group_by(AlertToIncident.incident_id) + .subquery() + ) + query = ( session.query( Incident, + subquery.c.last_updated_time ) + .join(subquery, subquery.c.incident_id == Incident.id) .filter(Incident.tenant_id == tenant_id) .filter(Incident.is_confirmed == is_confirmed) + .options(joinedload(Incident.alerts)) .order_by(desc(Incident.creation_time)) ) @@ -1962,9 +2007,22 @@ def get_last_incidents( >= datetime.now(tz=timezone.utc) - timedelta(days=timeframe) ) + if upper_timestamp and lower_timestamp: + query = query.filter( + subquery.c.last_updated_time.between(lower_timestamp, upper_timestamp) + ) + elif upper_timestamp: + query = query.filter( + subquery.c.last_updated_time <= upper_timestamp + ) + elif lower_timestamp: + query = query.filter( + subquery.c.last_updated_time >= lower_timestamp + ) + total_count = query.count() - # Order by timestamp in descending order and limit the results + # Order by start_time in descending order and limit the results query = query.order_by(desc(Incident.start_time)).limit(limit).offset(offset) # Execute the query incidents = query.all() @@ -2029,7 +2087,7 @@ def update_incident_from_dto_by_id( ).update( { "name": updated_incident_dto.name, - "description": updated_incident_dto.description, + "user_summary": updated_incident_dto.user_summary, "assignee": updated_incident_dto.assignee, } ) @@ -2344,6 +2402,52 @@ def confirm_predicted_incident_by_id( session.refresh(incident) return incident + + +def write_pmi_matrix_to_db(tenant_id: str, pmi_matrix_df: pd.DataFrame) -> bool: + # TODO: add handlers for sequential launches + with Session(engine) as session: + for fingerprint_i in pmi_matrix_df.index: + for fingerprint_j in pmi_matrix_df.columns: + pmi = pmi_matrix_df.at[fingerprint_i, fingerprint_j] + + pmi_entry = PMIMatrix( + tenant_id=tenant_id, + fingerprint_i=fingerprint_i, + fingerprint_j=fingerprint_j, + pmi=pmi + ) + session.merge(pmi_entry) + + session.commit() + + return True + +def get_pmi_value(tenant_id: str, fingerprint_i: str, fingerprint_j: str) -> Optional[float]: + with Session(engine) as session: + pmi_entry = session.exec( + select(PMIMatrix) + .where(PMIMatrix.tenant_id == tenant_id) + .where(PMIMatrix.fingerprint_i == fingerprint_i) + .where(PMIMatrix.fingerprint_j == fingerprint_j) + ).first() + + return pmi_entry.pmi if pmi_entry else None + +def get_pmi_values(tenant_id: str, fingerprints: List[str]) -> Dict[Tuple[str, str], Optional[float]]: + pmi_values = {} + with Session(engine) as session: + for idx_i, fingerprint_i in enumerate(fingerprints): + for idx_j in range(idx_i, len(fingerprints)): + fingerprint_j = fingerprints[idx_j] + pmi_entry = session.exec( + select(PMIMatrix) + .where(PMIMatrix.tenant_id == tenant_id) + .where(PMIMatrix.fingerprint_i == fingerprint_i) + .where(PMIMatrix.fingerprint_j == fingerprint_j) + ).first() + pmi_values[(fingerprint_i, fingerprint_j)] = pmi_entry.pmi if pmi_entry else None + return pmi_values def get_alert_firing_time(tenant_id: str, fingerprint: str) -> timedelta: @@ -2408,6 +2512,21 @@ def get_alert_firing_time(tenant_id: str, fingerprint: str) -> timedelta: tzinfo=timezone.utc ) +def update_incident_summary(incident_id: UUID, summary: str) -> Incident: + with Session(engine) as session: + incident = session.exec( + select(Incident) + .where(Incident.id == incident_id) + ).first() + + if not incident: + return None + + incident.generated_summary = summary + session.commit() + session.refresh(incident) + + return incident # Fetch all topology data def get_all_topology_data( diff --git a/keep/api/core/db_on_start.py b/keep/api/core/db_on_start.py index 8a2f69c67..cff23b493 100644 --- a/keep/api/core/db_on_start.py +++ b/keep/api/core/db_on_start.py @@ -35,6 +35,7 @@ from keep.api.models.db.rule import * # pylint: disable=unused-wildcard-import from keep.api.models.db.tenant import * # pylint: disable=unused-wildcard-import from keep.api.models.db.workflow import * # pylint: disable=unused-wildcard-import +from keep.api.models.db.statistics import * # pylint: disable=unused-wildcard-import logger = logging.getLogger(__name__) diff --git a/keep/api/models/alert.py b/keep/api/models/alert.py index 69b4a5820..1d80b7844 100644 --- a/keep/api/models/alert.py +++ b/keep/api/models/alert.py @@ -344,7 +344,7 @@ class Config: { "id": "c2509cb3-6168-4347-b83b-a41da9df2d5b", "name": "Incident name", - "description": "Keep: Incident description", + "user_summary": "Keep: Incident description", } ] } @@ -387,7 +387,8 @@ def from_db_incident(cls, db_incident): return cls( id=db_incident.id, name=db_incident.name, - description=db_incident.description, + user_summary=db_incident.user_summary, + generated_summary=db_incident.generated_summary, is_predicted=db_incident.is_predicted, is_confirmed=db_incident.is_confirmed, creation_time=db_incident.creation_time, diff --git a/keep/api/models/db/alert.py b/keep/api/models/db/alert.py index e97be8965..08e037959 100644 --- a/keep/api/models/db/alert.py +++ b/keep/api/models/db/alert.py @@ -99,7 +99,9 @@ class Incident(SQLModel, table=True): tenant_id: str = Field(foreign_key="tenant.id") tenant: Tenant = Relationship() name: str - description: str + + user_summary: str | None + generated_summary: str | None assignee: str | None diff --git a/keep/api/models/db/migrations/env.py b/keep/api/models/db/migrations/env.py index 149f267fd..61c58f9ab 100644 --- a/keep/api/models/db/migrations/env.py +++ b/keep/api/models/db/migrations/env.py @@ -19,6 +19,7 @@ from keep.api.models.db.topology import * from keep.api.models.db.user import * from keep.api.models.db.workflow import * +from keep.api.models.db.statistics import * target_metadata = SQLModel.metadata diff --git a/keep/api/models/db/migrations/versions/2024-07-24-13-39_9ba0aeecd4d0.py b/keep/api/models/db/migrations/versions/2024-07-24-13-39_9ba0aeecd4d0.py new file mode 100644 index 000000000..4557c943e --- /dev/null +++ b/keep/api/models/db/migrations/versions/2024-07-24-13-39_9ba0aeecd4d0.py @@ -0,0 +1,40 @@ +"""For AI + +Revision ID: 9ba0aeecd4d0 +Revises: dcbd2873dcfd +Create Date: 2024-07-24 13:39:10.576538 + +""" + +import sqlalchemy as sa +import sqlmodel +from alembic import op + +# revision identifiers, used by Alembic. +revision = "9ba0aeecd4d0" +down_revision = "dcbd2873dcfd" +branch_labels = None +depends_on = None + + +def upgrade() -> None: + # ### commands auto generated by Alembic - please adjust! ### + op.create_table( + "pmimatrix", + sa.Column("tenant_id", sqlmodel.sql.sqltypes.AutoString(), nullable=False), + sa.Column("fingerprint_i", sqlmodel.sql.sqltypes.AutoString(), nullable=False), + sa.Column("fingerprint_j", sqlmodel.sql.sqltypes.AutoString(), nullable=False), + sa.Column("pmi", sa.Float(), nullable=False), + sa.ForeignKeyConstraint( + ["tenant_id"], + ["tenant.id"], + ), + sa.PrimaryKeyConstraint("fingerprint_i", "fingerprint_j"), + ) + # ### end Alembic commands ### + + +def downgrade() -> None: + # ### commands auto generated by Alembic - please adjust! ### + op.drop_table("pmimatrix") + # ### end Alembic commands ### diff --git a/keep/api/models/db/migrations/versions/2024-07-25-17-13_67f1efb93c99.py b/keep/api/models/db/migrations/versions/2024-07-25-17-13_67f1efb93c99.py index baadaf919..7bae04f45 100644 --- a/keep/api/models/db/migrations/versions/2024-07-25-17-13_67f1efb93c99.py +++ b/keep/api/models/db/migrations/versions/2024-07-25-17-13_67f1efb93c99.py @@ -9,9 +9,8 @@ import sqlalchemy as sa from alembic import op from pydantic import BaseModel -from sqlalchemy.orm import Session, joinedload - -from keep.api.models.db.alert import Incident +from sqlalchemy.dialects.postgresql import UUID +from sqlalchemy.orm import Session # revision identifiers, used by Alembic. revision = "67f1efb93c99" @@ -19,39 +18,76 @@ branch_labels = None depends_on = None +# Define a completely separate metadata for the migration +migration_metadata = sa.MetaData() + +# Direct table definition for AlertToIncident +alert_to_incident_table = sa.Table( + 'alerttoincident', + migration_metadata, + sa.Column('alert_id', UUID(as_uuid=False), sa.ForeignKey('alert.id', ondelete='CASCADE'), primary_key=True), + sa.Column('incident_id', UUID(as_uuid=False), sa.ForeignKey('incident.id', ondelete='CASCADE'), primary_key=True) +) + +# Direct table definition for Incident +incident_table = sa.Table( + 'incident', + migration_metadata, + sa.Column('id', UUID(as_uuid=False), primary_key=True), + sa.Column('alerts_count', sa.Integer, default=0), + sa.Column('affected_services', sa.JSON, default_factory=list), + sa.Column('sources', sa.JSON, default_factory=list) +) + +# Direct table definition for Alert +alert_table = sa.Table( + 'alert', + migration_metadata, + sa.Column('id', UUID(as_uuid=False), primary_key=True), + sa.Column('provider_type', sa.String), + sa.Column('event', sa.JSON) +) + class AlertDtoLocal(BaseModel): service: str | None = None source: list[str] | None = [] -def populate_db(session): +def populate_db(): + session = Session(op.get_bind()) - incidents = session.query(Incident).options(joinedload(Incident.alerts)).all() + incidents = session.execute(sa.select(incident_table)).fetchall() for incident in incidents: - alerts_dto = [AlertDtoLocal(**alert.event) for alert in incident.alerts] - - incident.sources = list( - set([source for alert_dto in alerts_dto for source in alert_dto.source]) + stmt = ( + sa.select(alert_table).select_from(alert_table) + .join(alert_to_incident_table, alert_table.c.id == alert_to_incident_table.c.alert_id) + .where(alert_to_incident_table.c.incident_id == str(incident.id)) ) - incident.affected_services = list( - set([alert.service for alert in alerts_dto if alert.service is not None]) + + alerts = session.execute(stmt).all() + alerts_dto = [AlertDtoLocal(**alert.event) for alert in alerts] + + stmt = ( + sa.update(incident_table).where(incident_table.c.id == incident.id).values( + sources=list(set([source for alert_dto in alerts_dto for source in alert_dto.source])), + affected_services=list(set([alert.service for alert in alerts_dto if alert.service is not None])), + alerts_count=len(alerts) + ) ) - incident.alerts_count = len(incident.alerts) - session.add(incident) + session.execute(stmt) session.commit() def upgrade() -> None: + # ### commands auto generated by Alembic - please adjust! ### op.add_column("incident", sa.Column("affected_services", sa.JSON(), nullable=True)) op.add_column("incident", sa.Column("sources", sa.JSON(), nullable=True)) op.add_column("incident", sa.Column("alerts_count", sa.Integer(), nullable=False, server_default="0")) - session = Session(op.get_bind()) - populate_db(session) - + populate_db() # ### end Alembic commands ### diff --git a/keep/api/models/db/migrations/versions/2024-07-28-16-24_8e5942040de6.py b/keep/api/models/db/migrations/versions/2024-07-28-16-24_8e5942040de6.py new file mode 100644 index 000000000..fd0c2dfbb --- /dev/null +++ b/keep/api/models/db/migrations/versions/2024-07-28-16-24_8e5942040de6.py @@ -0,0 +1,39 @@ +"""Summaries added + +Revision ID: 8e5942040de6 +Revises: 9ba0aeecd4d0 +Create Date: 2024-07-28 16:24:58.364281 + +""" + +import sqlalchemy as sa +import sqlmodel +from alembic import op + +# revision identifiers, used by Alembic. +revision = "8e5942040de6" +down_revision = "9ba0aeecd4d0" +branch_labels = None +depends_on = None + + +def upgrade() -> None: + # ### commands auto generated by Alembic - please adjust! ### + op.add_column( + "incident", + sa.Column("user_summary", sqlmodel.sql.sqltypes.AutoString(), nullable=True), + ) + op.add_column( + "incident", + sa.Column( + "generated_summary", sqlmodel.sql.sqltypes.AutoString(), nullable=True + ), + ) + # ### end Alembic commands ### + + +def downgrade() -> None: + # ### commands auto generated by Alembic - please adjust! ### + op.drop_column("incident", "generated_summary") + op.drop_column("incident", "user_summary") + # ### end Alembic commands ### diff --git a/keep/api/models/db/migrations/versions/2024-07-29-12-51_c91b348b94f2.py b/keep/api/models/db/migrations/versions/2024-07-29-12-51_c91b348b94f2.py new file mode 100644 index 000000000..181119744 --- /dev/null +++ b/keep/api/models/db/migrations/versions/2024-07-29-12-51_c91b348b94f2.py @@ -0,0 +1,63 @@ +"""Description replaced w/ user_summary + +Revision ID: c91b348b94f2 +Revises: 8e5942040de6 +Create Date: 2024-07-29 12:51:24.496126 + +""" + +import sqlalchemy as sa +from alembic import op +from sqlalchemy.dialects.postgresql import UUID +from sqlalchemy.orm import Session + +# revision identifiers, used by Alembic. +revision = "c91b348b94f2" +down_revision = "8e5942040de6" +branch_labels = None +depends_on = None + + +# Define a completely separate metadata for the migration +migration_metadata = sa.MetaData() + +# Direct table definition for Incident +incident_table = sa.Table( + 'incident', + migration_metadata, + sa.Column('id', UUID(as_uuid=False), primary_key=True), + sa.Column('description', sa.String), + sa.Column('user_summary', sa.String), +) + + +def populate_db(session): + # we need to populate the user_summary field with the description + session.execute(sa.update(incident_table).values(user_summary=incident_table.c.description)) + session.commit() + + +def depopulate_db(session): + # we need to populate the description field with the user_summary + session.execute(sa.update(incident_table).values(description=incident_table.c.user_summary)) + session.commit() + + +def upgrade() -> None: + # ### commands auto generated by Alembic - please adjust! ### + + session = Session(op.get_bind()) + populate_db(session) + + op.drop_column("incident", "description") + # ### end Alembic commands ### + + +def downgrade() -> None: + # ### commands auto generated by Alembic - please adjust! ### + op.add_column("incident", sa.Column("description", sa.VARCHAR(), nullable=False, default="", server_default="")) + + session = Session(op.get_bind()) + depopulate_db(session) + + # ### end Alembic commands ### diff --git a/keep/api/models/db/migrations/versions/2024-08-09-10-53_6e353161f5a8.py b/keep/api/models/db/migrations/versions/2024-08-09-10-53_6e353161f5a8.py new file mode 100644 index 000000000..9a977d509 --- /dev/null +++ b/keep/api/models/db/migrations/versions/2024-08-09-10-53_6e353161f5a8.py @@ -0,0 +1,21 @@ +"""Merge + +Revision ID: 6e353161f5a8 +Revises: c91b348b94f2, 42098785763c +Create Date: 2024-08-09 10:53:33.363763 + +""" + +# revision identifiers, used by Alembic. +revision = "6e353161f5a8" +down_revision = ("c91b348b94f2", "42098785763c") +branch_labels = None +depends_on = None + + +def upgrade() -> None: + pass + + +def downgrade() -> None: + pass diff --git a/keep/api/models/db/migrations/versions/2024-08-11-19-45_005efc57cc1c.py b/keep/api/models/db/migrations/versions/2024-08-11-19-45_005efc57cc1c.py new file mode 100644 index 000000000..efeb9d42e --- /dev/null +++ b/keep/api/models/db/migrations/versions/2024-08-11-19-45_005efc57cc1c.py @@ -0,0 +1,21 @@ +"""empty message + +Revision ID: 005efc57cc1c +Revises: 9453855f3ba0, 6e353161f5a8 +Create Date: 2024-08-11 19:45:08.308034 + +""" + +# revision identifiers, used by Alembic. +revision = "005efc57cc1c" +down_revision = ("9453855f3ba0", "6e353161f5a8") +branch_labels = None +depends_on = None + + +def upgrade() -> None: + pass + + +def downgrade() -> None: + pass diff --git a/keep/api/models/db/statistics.py b/keep/api/models/db/statistics.py new file mode 100644 index 000000000..780183850 --- /dev/null +++ b/keep/api/models/db/statistics.py @@ -0,0 +1,9 @@ +from sqlmodel import Field, SQLModel + + +class PMIMatrix(SQLModel, table=True): + tenant_id: str = Field(foreign_key="tenant.id") + fingerprint_i: str = Field(primary_key=True) + fingerprint_j: str = Field(primary_key=True) + pmi: float + \ No newline at end of file diff --git a/keep/api/routes/ai.py b/keep/api/routes/ai.py index 880128cc8..f7be43f67 100644 --- a/keep/api/routes/ai.py +++ b/keep/api/routes/ai.py @@ -8,6 +8,7 @@ from keep.api.core.dependencies import AuthenticatedEntity, AuthVerifier from keep.api.core.db import get_incidents_count, get_alerts_count, get_first_alert_datetime +from keep.api.utils.import_ee import ALGORITHM_VERBOSE_NAME router = APIRouter() @@ -26,5 +27,6 @@ def get_stats( "alerts_count": get_alerts_count(tenant_id), "first_alert_datetime": get_first_alert_datetime(tenant_id), "incidents_count": get_incidents_count(tenant_id), - "is_mining_enabled": os.environ.get("EE_ENABLED", "false") == "true" + "is_mining_enabled": os.environ.get("EE_ENABLED", "false") == "true", + "algorithm_verbose_name": str(ALGORITHM_VERBOSE_NAME) } diff --git a/keep/api/routes/alerts.py b/keep/api/routes/alerts.py index 27044864b..dfcb08664 100644 --- a/keep/api/routes/alerts.py +++ b/keep/api/routes/alerts.py @@ -52,6 +52,7 @@ ) def get_all_alerts( authenticated_entity: AuthenticatedEntity = Depends(AuthVerifier(["read:alert"])), + limit: int = 1000, ) -> list[AlertDto]: tenant_id = authenticated_entity.tenant_id logger.info( @@ -60,7 +61,7 @@ def get_all_alerts( "tenant_id": tenant_id, }, ) - db_alerts = get_last_alerts(tenant_id=tenant_id) + db_alerts = get_last_alerts(tenant_id=tenant_id, limit=limit) enriched_alerts_dto = convert_db_alerts_to_dto_alerts(db_alerts) logger.info( "Fetched alerts from DB", diff --git a/keep/api/routes/incidents.py b/keep/api/routes/incidents.py index 44bb56e37..ca094442d 100644 --- a/keep/api/routes/incidents.py +++ b/keep/api/routes/incidents.py @@ -2,7 +2,10 @@ import os import pathlib import sys +import asyncio + from typing import List +from datetime import datetime from fastapi import APIRouter, Depends, HTTPException, Response from pusher import Pusher @@ -10,14 +13,11 @@ from keep.api.core.db import ( add_alerts_to_incident_by_incident_id, - assign_alert_to_incident, confirm_predicted_incident_by_id, - create_incident_from_dict, create_incident_from_dto, delete_incident_by_id, get_incident_alerts_by_incident_id, get_incident_by_id, - get_last_alerts, get_last_incidents, remove_alerts_to_incident_by_incident_id, update_incident_from_dto_by_id, @@ -29,6 +29,7 @@ ) from keep.api.models.alert import AlertDto, IncidentDto, IncidentDtoIn from keep.api.utils.enrichment_helpers import convert_db_alerts_to_dto_alerts +from keep.api.utils.import_ee import mine_incidents_and_create_objects from keep.api.utils.pagination import IncidentsPaginatedResultsDto, AlertPaginatedResultsDto router = APIRouter() @@ -131,7 +132,7 @@ def get_all_incidents( incidents_dto = [] for incident in incidents: - incidents_dto.append(IncidentDto.from_db_incident(incident)) + incidents_dto.append(IncidentDto.from_db_incident(incident.Incident)) logger.info( "Fetched incidents from DB", @@ -327,49 +328,39 @@ def delete_alerts_from_incident( return Response(status_code=202) - @router.post( "/mine", description="Create incidents using historical alerts", ) def mine( - authenticated_entity: AuthenticatedEntity = Depends(AuthVerifier()), - use_n_historical_alerts: int = 10000, - incident_sliding_window_size: int = 6 * 24 * 60 * 60, - statistic_sliding_window_size: int = 60 * 60, - jaccard_threshold: float = 0.0, - fingerprint_threshold: int = 1, -) -> dict: - tenant_id = authenticated_entity.tenant_id - alerts = get_last_alerts(tenant_id, limit=use_n_historical_alerts) - - if len(alerts) == 0: - return {"incidents": []} - - incidents = mine_incidents( - alerts, - incident_sliding_window_size, - statistic_sliding_window_size, - jaccard_threshold, - fingerprint_threshold, - ) - if len(incidents) == 0: - return {"incidents": []} - - for incident in incidents: - incident_id = create_incident_from_dict( - tenant_id=tenant_id, - incident_data={ - "name": "Mined using algorithm", - "description": "Candidate", - "is_predicted": True, - }, - ).id - - for alert in incident["alerts"]: - assign_alert_to_incident(alert.id, incident_id, tenant_id) - - return {"incidents": incidents} + authenticated_entity: AuthenticatedEntity = Depends(AuthVerifier(["read:alert"])), + alert_lower_timestamp: datetime = None, + alert_upper_timestamp: datetime = None, + use_n_historical_alerts: int = 10e10, + incident_lower_timestamp: datetime = None, + incident_upper_timestamp: datetime = None, + use_n_hist_incidents: int = 10e10, + pmi_threshold: float = 0.0, + knee_threshold: float = 0.8, + min_incident_size: int = 5, + incident_similarity_threshold: float = 0.8, +) -> dict: + result = asyncio.run(mine_incidents_and_create_objects( + None, + authenticated_entity.tenant_id, + alert_lower_timestamp, + alert_upper_timestamp, + use_n_historical_alerts, + incident_lower_timestamp, + incident_upper_timestamp, + use_n_hist_incidents, + pmi_threshold, + knee_threshold, + min_incident_size, + incident_similarity_threshold, + )) + + return result @router.post( diff --git a/keep/api/tasks/process_background_ai_task.py b/keep/api/tasks/process_background_ai_task.py new file mode 100644 index 000000000..aee599b60 --- /dev/null +++ b/keep/api/tasks/process_background_ai_task.py @@ -0,0 +1,71 @@ +import time +import asyncio +import logging +import datetime + +from keep.api.utils.import_ee import mine_incidents_and_create_objects, ALGORITHM_VERBOSE_NAME +from keep.api.core.db import get_tenants_configurations + +logger = logging.getLogger(__name__) + + +async def process_correlation(ctx, tenant_id:str): + await asyncio.sleep(180) + logger.info( + f"Background AI task started, {ALGORITHM_VERBOSE_NAME}", + extra={"algorithm": ALGORITHM_VERBOSE_NAME, "tenant_id": tenant_id}, + ) + start_time = datetime.datetime.now() + await mine_incidents_and_create_objects( + ctx, + tenant_id=tenant_id + ) + end_time = datetime.datetime.now() + logger.info( + f"Background AI task finished, {ALGORITHM_VERBOSE_NAME}, took {(end_time - start_time).total_seconds()} seconds", + extra={ + "algorithm": ALGORITHM_VERBOSE_NAME, + "tenant_id": tenant_id, + "duration_ms": (end_time - start_time).total_seconds() * 1000 + }, + ) + + +async def process_background_ai_task( + ctx: dict | None, # arq context + ): + """ + This job will schedule the process_correlation job for each tenant with strict ID's. + This ensures that the job is not scheduled multiple times for the same tenant. + """ + pool = ctx["redis"] + try: + all_jobs = await pool.queued_jobs() + except Exception as e: + logger.error(f"Error getting queued jobs, happens sometimes with unknown reason: {e}") + return None + + if mine_incidents_and_create_objects is not NotImplemented: + for tenant in get_tenants_configurations(): + + # Because of https://github.com/python-arq/arq/issues/432 we need to check if the job is already running + # The other option would be to twick "keep_result" but it will make debugging harder + job_prefix = 'process_correlation_tenant_id_' + str(tenant) + jobs_with_same_prefix = [job for job in all_jobs if job.job_id.startswith(job_prefix)] + if len(jobs_with_same_prefix) > 0: + logger.info( + f"No {ALGORITHM_VERBOSE_NAME} for tenant {tenant} scheduled because there is already one running", + extra={"algorithm": ALGORITHM_VERBOSE_NAME, "tenant_id": tenant}, + ) + else: + job = await pool.enqueue_job( + "process_correlation", + tenant_id=tenant, + _job_id=job_prefix + ":" + str(time.time()), # Strict ID ensures uniqueness + _job_try=1 + ) + logger.info( + f"{ALGORITHM_VERBOSE_NAME} for tenant {tenant} scheduled, job: {job}", + extra={"algorithm": ALGORITHM_VERBOSE_NAME, "tenant_id": tenant}, + ) + diff --git a/keep/api/tasks/process_event_task.py b/keep/api/tasks/process_event_task.py index 52c6af138..a3ae023d6 100644 --- a/keep/api/tasks/process_event_task.py +++ b/keep/api/tasks/process_event_task.py @@ -60,6 +60,7 @@ def __save_to_db( formatted_events: list[AlertDto], deduplicated_events: list[AlertDto], provider_id: str | None = None, + timestamp_forced: datetime.datetime | None = None, ): try: # keep raw events in the DB if the user wants to @@ -115,15 +116,21 @@ def __save_to_db( tz=datetime.timezone.utc ).isoformat() - alert = Alert( - tenant_id=tenant_id, - provider_type=( + alert_args = { + "tenant_id": tenant_id, + "provider_type": ( provider_type if provider_type else formatted_event.source[0] ), - event=formatted_event.dict(), - provider_id=provider_id, - fingerprint=formatted_event.fingerprint, - alert_hash=formatted_event.alert_hash, + "event": formatted_event.dict(), + "provider_id": provider_id, + "fingerprint": formatted_event.fingerprint, + "alert_hash": formatted_event.alert_hash, + } + if timestamp_forced is not None: + alert_args['timestamp'] = timestamp_forced + + alert = Alert( + **alert_args ) session.add(alert) audit = AlertAudit( @@ -190,6 +197,7 @@ def __handle_formatted_events( formatted_events: list[AlertDto], provider_id: str | None = None, notify_client: bool = True, + timestamp_forced: datetime.datetime | None = None, ): """ this is super important function and does five things: @@ -239,6 +247,7 @@ def __handle_formatted_events( formatted_events, deduplicated_events, provider_id, + timestamp_forced, ) # after the alert enriched and mapped, lets send it to the elasticsearch diff --git a/keep/api/utils/import_ee.py b/keep/api/utils/import_ee.py new file mode 100644 index 000000000..9dfc2e913 --- /dev/null +++ b/keep/api/utils/import_ee.py @@ -0,0 +1,21 @@ +import os +import sys +import pathlib + +EE_ENABLED = os.environ.get("EE_ENABLED", "false") == "true" +EE_PATH = os.environ.get("EE_PATH", "../ee") # Path related to the fastapi root directory + +if EE_ENABLED: + path_with_ee = ( + str(pathlib.Path(__file__).parent.resolve()) + + "/../../" + # To go to the fastapi root directory + EE_PATH + + "/../" # To go to the parent directory of the ee directory to allow imports like ee.abc.abc + ) + sys.path.insert(0, path_with_ee) + + from ee.experimental.incident_utils import mine_incidents_and_create_objects # noqa + from ee.experimental.incident_utils import ALGORITHM_VERBOSE_NAME # noqa +else: + mine_incidents_and_create_objects = NotImplemented + ALGORITHM_VERBOSE_NAME = NotImplemented diff --git a/poetry.lock b/poetry.lock index 689483c5b..57b5695ae 100644 --- a/poetry.lock +++ b/poetry.lock @@ -1024,6 +1024,17 @@ files = [ {file = "distlib-0.3.8.tar.gz", hash = "sha256:1530ea13e350031b6312d8580ddb6b27a104275a31106523b8f123787f494f64"}, ] +[[package]] +name = "distro" +version = "1.9.0" +description = "Distro - an OS platform information API" +optional = false +python-versions = ">=3.6" +files = [ + {file = "distro-1.9.0-py3-none-any.whl", hash = "sha256:7bffd925d65168f85027d8da9af6bddab658135b840670a223589bc0c8ef02b2"}, + {file = "distro-1.9.0.tar.gz", hash = "sha256:2fa77c6fd8940f116ee1d6b94a2f90b13b5ea8d019b98bc8bafdcabcdd9bdbed"}, +] + [[package]] name = "dnspython" version = "2.6.1" @@ -2529,25 +2540,26 @@ signedtoken = ["cryptography (>=3.0.0)", "pyjwt (>=2.0.0,<3)"] [[package]] name = "openai" -version = "0.27.10" -description = "Python client library for the OpenAI API" +version = "1.37.1" +description = "The official Python library for the openai API" optional = false python-versions = ">=3.7.1" files = [ - {file = "openai-0.27.10-py3-none-any.whl", hash = "sha256:beabd1757e3286fa166dde3b70ebb5ad8081af046876b47c14c41e203ed22a14"}, - {file = "openai-0.27.10.tar.gz", hash = "sha256:60e09edf7100080283688748c6803b7b3b52d5a55d21890f3815292a0552d83b"}, + {file = "openai-1.37.1-py3-none-any.whl", hash = "sha256:9a6adda0d6ae8fce02d235c5671c399cfa40d6a281b3628914c7ebf244888ee3"}, + {file = "openai-1.37.1.tar.gz", hash = "sha256:faf87206785a6b5d9e34555d6a3242482a6852bc802e453e2a891f68ee04ce55"}, ] [package.dependencies] -aiohttp = "*" -requests = ">=2.20" -tqdm = "*" +anyio = ">=3.5.0,<5" +distro = ">=1.7.0,<2" +httpx = ">=0.23.0,<1" +pydantic = ">=1.9.0,<3" +sniffio = "*" +tqdm = ">4" +typing-extensions = ">=4.7,<5" [package.extras] -datalib = ["numpy", "openpyxl (>=3.0.7)", "pandas (>=1.2.3)", "pandas-stubs (>=1.1.0.11)"] -dev = ["black (>=21.6b0,<22.0)", "pytest (==6.*)", "pytest-asyncio", "pytest-mock"] -embeddings = ["matplotlib", "numpy", "openpyxl (>=3.0.7)", "pandas (>=1.2.3)", "pandas-stubs (>=1.1.0.11)", "plotly", "scikit-learn (>=1.0.2)", "scipy", "tenacity (>=8.0.1)"] -wandb = ["numpy", "openpyxl (>=3.0.7)", "pandas (>=1.2.3)", "pandas-stubs (>=1.1.0.11)", "wandb"] +datalib = ["numpy (>=1)", "pandas (>=1.2.3)", "pandas-stubs (>=1.1.0.11)"] [[package]] name = "openshift-client" @@ -5057,4 +5069,4 @@ test = ["big-O", "importlib-resources", "jaraco.functools", "jaraco.itertools", [metadata] lock-version = "2.0" python-versions = ">=3.11,<3.12" -content-hash = "b02155f81098680deab94ff5d9d126f696e11837f0a95442ed78b3d25bbc0554" +content-hash = "fee3645d18637ff6a6c0f4d08e20abfb556bf3d5c301111168c678ed2dbc3ad2" diff --git a/pyproject.toml b/pyproject.toml index 960f99335..821099d40 100644 --- a/pyproject.toml +++ b/pyproject.toml @@ -42,7 +42,7 @@ python-jose = "^3.3.0" jwcrypto = "^1.5.6" sqlalchemy = "1.4.41" snowflake-connector-python = "3.1.0" -openai = "^0.27.7" +openai = "1.37.1" opentelemetry-sdk = ">=1.20.0,<1.22" opentelemetry-instrumentation-fastapi = "^0.41b0" opentelemetry-instrumentation-logging = "^0.41b0" diff --git a/scripts/shoot_alerts_from_dump.py b/scripts/shoot_alerts_from_dump.py index 4e3db0476..ecb5db518 100644 --- a/scripts/shoot_alerts_from_dump.py +++ b/scripts/shoot_alerts_from_dump.py @@ -54,6 +54,7 @@ def shoot_tenants_alerts(file, tenant_id): session=session, raw_events=raw_event, formatted_events=[alert], + timestamp_forced=alert.lastReceived ) session.close()