From 77a2a8f4af074f217269c78b7de94192688d8b91 Mon Sep 17 00:00:00 2001 From: Shahar Glazner Date: Mon, 8 Jul 2024 14:25:23 +0300 Subject: [PATCH] fix: several bug fixes (#1335) --- keep/api/core/tenant_configuration.py | 6 +- keep/api/logging.py | 4 +- keep/api/routes/preset.py | 25 ++++--- keep/api/tasks/process_event_task.py | 15 ----- keep/parser/parser.py | 66 +++++++++++++------ .../ilert_provider/ilert_provider.py | 6 +- keep/workflowmanager/workflowmanager.py | 20 ++++-- keep/workflowmanager/workflowscheduler.py | 13 +++- tests/test_search_alerts.py | 7 +- 9 files changed, 103 insertions(+), 59 deletions(-) diff --git a/keep/api/core/tenant_configuration.py b/keep/api/core/tenant_configuration.py index fae8382f8..b948e597f 100644 --- a/keep/api/core/tenant_configuration.py +++ b/keep/api/core/tenant_configuration.py @@ -19,9 +19,9 @@ def __init__(self): ) def _load_tenant_configurations(self): - self.logger.info("Loading tenants configurations") + self.logger.debug("Loading tenants configurations") tenants_configuration = get_tenants_configurations() - self.logger.info( + self.logger.debug( "Tenants configurations loaded", extra={ "number_of_tenants": len(tenants_configuration), @@ -41,7 +41,7 @@ def get_configuration(self, tenant_id, config_name): # tenant_config = self.configurations.get(tenant_id, {}) tenant_config = self.configurations.get(tenant_id) if not tenant_config: - self.logger.info(f"Tenant {tenant_id} not found in memory, loading it") + self.logger.debug(f"Tenant {tenant_id} not found in memory, loading it") self.configurations = self._load_tenant_configurations() tenant_config = self.configurations.get(tenant_id, {}) diff --git a/keep/api/logging.py b/keep/api/logging.py index 36c227312..2d488b3cf 100644 --- a/keep/api/logging.py +++ b/keep/api/logging.py @@ -105,13 +105,13 @@ def dump(self): "handlers": { "default": { "level": "DEBUG", - "formatter": "json" if LOG_FORMAT == LOG_FORMAT_OPEN_TELEMETRY else None, + "formatter": "json" if 
LOG_FORMAT == LOG_FORMAT_OPEN_TELEMETRY else None, "class": "logging.StreamHandler", "stream": "ext://sys.stdout", }, "context": { "level": "DEBUG", - "formatter": "json" if LOG_FORMAT == LOG_FORMAT_OPEN_TELEMETRY else None, + "formatter": "json" if LOG_FORMAT == LOG_FORMAT_OPEN_TELEMETRY else None, "class": "keep.api.logging.WorkflowDBHandler", }, }, diff --git a/keep/api/routes/preset.py b/keep/api/routes/preset.py index 29097fbfa..36fb69d4a 100644 --- a/keep/api/routes/preset.py +++ b/keep/api/routes/preset.py @@ -1,6 +1,6 @@ import logging -from fastapi import APIRouter, BackgroundTasks, Depends, HTTPException +from fastapi import APIRouter, BackgroundTasks, Depends, HTTPException, Request from pydantic import BaseModel from sqlmodel import Session, select @@ -20,15 +20,17 @@ logger = logging.getLogger(__name__) -async def pull_alerts_from_providers( +# SHAHAR: this function runs as a background task in a separate thread +# DO NOT ADD async HERE as it will run in the main thread and block the whole server +def pull_alerts_from_providers( tenant_id: str, + trace_id: str, ) -> list[AlertDto]: """ Pulls alerts from providers and record the to the DB. "Get or create logics". 
""" - context_manager = ContextManager( tenant_id=tenant_id, workflow_id=None, @@ -53,16 +55,15 @@ async def pull_alerts_from_providers( provider_class.get_alerts_by_fingerprint(tenant_id=tenant_id) ) for fingerprint, alert in sorted_provider_alerts_by_fingerprint.items(): - await process_event( + process_event( {}, tenant_id, - None, - None, + provider.type, + provider.id, fingerprint, None, - None, + trace_id, alert, - save_if_duplicate=False, ) @@ -194,6 +195,7 @@ def update_preset( description="Get a preset for tenant", ) async def get_preset_alerts( + request: Request, bg_tasks: BackgroundTasks, preset_name: str, authenticated_entity: AuthenticatedEntity = Depends(AuthVerifier()), @@ -201,7 +203,12 @@ async def get_preset_alerts( # Gathering alerts may take a while and we don't care if it will finish before we return the response. # In the worst case, gathered alerts will be pulled in the next request. - bg_tasks.add_task(pull_alerts_from_providers, authenticated_entity.tenant_id) + + bg_tasks.add_task( + pull_alerts_from_providers, + authenticated_entity.tenant_id, + request.state.trace_id, + ) tenant_id = authenticated_entity.tenant_id logger.info("Getting preset alerts", extra={"preset_name": preset_name}) diff --git a/keep/api/tasks/process_event_task.py b/keep/api/tasks/process_event_task.py index 879fcb18d..f92eca0d3 100644 --- a/keep/api/tasks/process_event_task.py +++ b/keep/api/tasks/process_event_task.py @@ -157,7 +157,6 @@ def __handle_formatted_events( raw_events: list[dict], formatted_events: list[AlertDto], provider_id: str | None = None, - save_if_duplicate: bool = True, ): """ this is super important function and does five things: @@ -190,18 +189,6 @@ def __handle_formatted_events( event.alert_hash = event_hash event.isDuplicate = event_deduplicated - if event.isDuplicate and not save_if_duplicate: - logger.info( - "Alert is not saved as a duplicate", - extra={ - "provider_type": provider_type, - "num_of_alerts": len(formatted_events), - 
"provider_id": provider_id, - "tenant_id": tenant_id, - }, - ) - return None - # filter out the deduplicated events formatted_events = list( filter(lambda event: not event.isDuplicate, formatted_events) @@ -356,7 +343,6 @@ async def process_event( event: ( AlertDto | list[AlertDto] | dict ), # the event to process, either plain (generic) or from a specific provider - save_if_duplicate: bool = True, ): extra_dict = { "tenant_id": tenant_id, @@ -395,7 +381,6 @@ async def process_event( event, event, provider_id, - save_if_duplicate, ) except Exception: logger.exception("Error processing event", extra=extra_dict) diff --git a/keep/parser/parser.py b/keep/parser/parser.py index 57765a25a..f6d31bae4 100644 --- a/keep/parser/parser.py +++ b/keep/parser/parser.py @@ -1,15 +1,16 @@ +import copy import json import logging import os import typing -import copy + import yaml +from keep.actions.actions_factory import ActionsCRUD from keep.api.core.db import get_workflow_id from keep.contextmanager.contextmanager import ContextManager from keep.providers.base.base_provider import BaseProvider from keep.providers.providers_factory import ProvidersFactory -from keep.actions.actions_factory import ActionsCRUD from keep.step.step import Step, StepType from keep.step.step_provider_parameter import StepProviderParameter from keep.workflowmanager.workflow import Workflow, WorkflowStrategy @@ -48,7 +49,11 @@ def _get_workflow_id(self, tenant_id, workflow: dict) -> str: return workflow_id def parse( - self, tenant_id, parsed_workflow_yaml: dict, providers_file: str = None, actions_file: str = None + self, + tenant_id, + parsed_workflow_yaml: dict, + providers_file: str = None, + actions_file: str = None, ) -> typing.List[Workflow]: """_summary_ @@ -68,7 +73,12 @@ def parse( ) or parsed_workflow_yaml.get("alerts") workflows = [ self._parse_workflow( - tenant_id, workflow, providers_file, workflow_providers, actions_file, workflow_actions + tenant_id, + workflow, + providers_file, + 
workflow_providers, + actions_file, + workflow_actions, ) for workflow in raw_workflows ] @@ -78,13 +88,23 @@ def parse( "workflow" ) or parsed_workflow_yaml.get("alert") workflow = self._parse_workflow( - tenant_id, raw_workflow, providers_file, workflow_providers, actions_file, workflow_actions + tenant_id, + raw_workflow, + providers_file, + workflow_providers, + actions_file, + workflow_actions, ) workflows = [workflow] # else, if it stored in the db, it stored without the "workflow" key else: workflow = self._parse_workflow( - tenant_id, parsed_workflow_yaml, providers_file, workflow_providers, actions_file, workflow_actions + tenant_id, + parsed_workflow_yaml, + providers_file, + workflow_providers, + actions_file, + workflow_actions, ) workflows = [workflow] return workflows @@ -113,7 +133,7 @@ def _parse_workflow( providers_file: str, workflow_providers: dict = None, actions_file: str = None, - workflow_actions: dict = None + workflow_actions: dict = None, ) -> Workflow: self.logger.debug("Parsing workflow") workflow_id = self._get_workflow_id(tenant_id, workflow) @@ -381,9 +401,9 @@ def _load_actions_config( # if the workflow file itself contain actions (mainly backward compatibility) if workflow_actions: for action in workflow_actions: - context_manager.actions_context.update({ - action.get('use') or action.get('name'): action - }) + context_manager.actions_context.update( + {action.get("use") or action.get("name"): action} + ) self._load_actions_from_db(context_manager, tenant_id) self.logger.debug("Actions parsed and loaded successfully") @@ -399,16 +419,16 @@ def _parse_actions_from_file( self.logger.exception(f"Error parsing actions file {actions_file}") raise # create a hashmap -> action - for action in actions_content.get('actions', []) : - context_manager.actions_context.update({ - action.get('use') or action.get('name'): action - }) + for action in actions_content.get("actions", []): + context_manager.actions_context.update( + {action.get("use") 
or action.get("name"): action} + ) def _load_actions_from_db( self, context_manager: ContextManager, tenant_id: str = None ): # If there is no tenant id, e.g. running from CLI, no db here - if not tenant_id: + if not tenant_id: return # Load actions from db actions = ActionsCRUD.get_all_actions(tenant_id) @@ -460,7 +480,7 @@ def _parse_actions( self, context_manager: ContextManager, workflow: dict ) -> typing.List[Step]: self.logger.debug("Parsing actions") - workflow_actions_raw = workflow.get("actions", []) + workflow_actions_raw = workflow.get("actions", []) workflow_actions = self._merge_action_by_use( workflow_actions=workflow_actions_raw, actions_context=context_manager.actions_context, @@ -496,7 +516,9 @@ def _load_actions_from_file( f"action defined in {actions_file} should have id as unique field" ) else: - self.logger.warning(f"No action located at {actions_file}, skip loading reusable actions") + self.logger.warning( + f"No action located at {actions_file}, skip loading reusable actions" + ) return actions_set def _merge_action_by_use( @@ -584,7 +606,13 @@ def _parse_provider_config( provider_config = context_manager.providers_context.get(config_id) if not provider_config: self.logger.warning( - f"Provider {config_id} not found in configuration, did you configure it?" 
+ "Provider not found in configuration, did you configure it?", + extra={ + "provider_id": config_id, + "provider_type": provider_type, + "provider_config": provider_config, + "tenant_id": context_manager.tenant_id, + }, ) provider_config = {"authentication": {}} return config_id, provider_config @@ -643,7 +671,7 @@ def deep_merge(source: dict, dest: dict) -> dict: Example: source = {"deep1": {"deep2": 1}} dest = {"deep1", {"deep2": 2, "deep3": 3}} - returns -> {"deep1": {"deep2": 1, "deep3": 3}} + returns -> {"deep1": {"deep2": 1, "deep3": 3}} Returns: dict: The new object contains merged results diff --git a/keep/providers/ilert_provider/ilert_provider.py b/keep/providers/ilert_provider/ilert_provider.py index 52d4aa158..5674f8068 100644 --- a/keep/providers/ilert_provider/ilert_provider.py +++ b/keep/providers/ilert_provider/ilert_provider.py @@ -194,6 +194,10 @@ def _get_alerts(self) -> list[AlertDto]: f"Failed to get alerts: {response.status_code} {response.text}" ) + alerts = response.json() + self.logger.info( + "Got alerts from ilert", extra={"number_of_alerts": len(alerts)} + ) return [ AlertDto( id=alert["id"], @@ -211,7 +215,7 @@ def _get_alerts(self) -> list[AlertDto]: lastHistoryUpdatedAt=alert["lastHistoryUpdatedAt"], lastReceived=alert["updatedAt"], ) - for alert in response.json() + for alert in alerts ] def __create_or_update_incident( diff --git a/keep/workflowmanager/workflowmanager.py b/keep/workflowmanager/workflowmanager.py index 2c4e41612..7f10ef8ae 100644 --- a/keep/workflowmanager/workflowmanager.py +++ b/keep/workflowmanager/workflowmanager.py @@ -74,14 +74,24 @@ def insert_events(self, tenant_id, events: typing.List[AlertDto]): # the provider is not configured, hence the workflow cannot be triggered # todo - handle it better # todo2 - handle if more than one provider is not configured - except ProviderConfigurationException as e: - self.logger.warning( - f"Workflow have a provider that is not configured: {e}" + except 
ProviderConfigurationException: + self.logger.exception( + "Workflow have a provider that is not configured", + extra={ + "workflow_id": workflow_model.workflow_id, + "tenant_id": tenant_id, + }, ) continue - except Exception as e: + except Exception: # TODO: how to handle workflows that aren't properly parsed/configured? - self.logger.error(f"Error getting workflow: {e}") + self.logger.exception( + "Error getting workflow", + extra={ + "workflow_id": workflow_model.workflow_id, + "tenant_id": tenant_id, + }, + ) continue for trigger in workflow.workflow_triggers: # TODO: handle it better diff --git a/keep/workflowmanager/workflowscheduler.py b/keep/workflowmanager/workflowscheduler.py index c2d5e7281..f5cdc4f54 100644 --- a/keep/workflowmanager/workflowscheduler.py +++ b/keep/workflowmanager/workflowscheduler.py @@ -1,8 +1,8 @@ import enum import hashlib import logging -import threading import queue +import threading import time import typing import uuid @@ -62,8 +62,15 @@ def _handle_interval_workflows(self): tenant_id = workflow.get("tenant_id") workflow_id = workflow.get("workflow_id") workflow = self.workflow_store.get_workflow(tenant_id, workflow_id) - except ProviderConfigurationException as e: - self.logger.error(f"Provider configuration is invalid: {e}") + except ProviderConfigurationException: + self.logger.exception( + "Provider configuration is invalid", + extra={ + "workflow_id": workflow_id, + "workflow_execution_id": workflow_execution_id, + "tenant_id": tenant_id, + }, + ) self._finish_workflow_execution( tenant_id=tenant_id, workflow_id=workflow_id, diff --git a/tests/test_search_alerts.py b/tests/test_search_alerts.py index a0b96a26d..bc4f46e85 100644 --- a/tests/test_search_alerts.py +++ b/tests/test_search_alerts.py @@ -168,8 +168,11 @@ def test_search_sanity2(db_session, setup_alerts): assert len(db_filtered_alerts) == 2 # compare the results - assert elastic_filtered_alerts[0] == db_filtered_alerts[0] - assert elastic_filtered_alerts[1] == 
db_filtered_alerts[1] + sorted_elastic_alerts = sorted( + elastic_filtered_alerts, key=lambda x: x.lastReceived + ) + sorted_db_alerts = sorted(db_filtered_alerts, key=lambda x: x.lastReceived) + assert sorted_elastic_alerts == sorted_db_alerts @pytest.mark.parametrize(