Skip to content

Commit

Permalink
Merge branch 'main' into feat-app-dynamics
Browse files Browse the repository at this point in the history
  • Loading branch information
talboren authored Apr 17, 2024
2 parents cecab64 + 3299e2b commit e6b83c1
Show file tree
Hide file tree
Showing 10 changed files with 118 additions and 19 deletions.
57 changes: 56 additions & 1 deletion keep/api/core/db.py
Original file line number Diff line number Diff line change
Expand Up @@ -2,6 +2,7 @@
import json
import logging
import os
import uuid
from datetime import datetime, timedelta, timezone
from typing import List, Tuple
from uuid import uuid4
Expand Down Expand Up @@ -153,6 +154,26 @@ def create_db_and_tables():
f"ALTER TABLE workflowtoalertexecution DROP FOREIGN KEY {constraint_name};"
)
logger.info(f"Dropped constraint {constraint_name}")
# now add the new column
try:
if session.bind.dialect.name == "sqlite":
session.exec("ALTER TABLE workflowtoalertexecution ADD COLUMN event_id VARCHAR(255);")
elif session.bind.dialect.name == "mysql":
session.exec("ALTER TABLE workflowtoalertexecution ADD COLUMN event_id VARCHAR(255);")
elif session.bind.dialect.name == "postgresql":
session.exec("ALTER TABLE workflowtoalertexecution ADD COLUMN event_id TEXT;")
elif session.bind.dialect.name == "mssql":
session.exec("ALTER TABLE workflowtoalertexecution ADD event_id NVARCHAR(255);")
else:
raise ValueError("Unsupported database type")
except Exception as e:
# that's ok
if "Duplicate column name" in str(e):
pass
# else, log
else:
logger.exception("Failed to migrate rule table")
pass
# also add grouping_criteria to the workflow table
logger.info("Migrating Rule table")
try:
Expand Down Expand Up @@ -274,6 +295,7 @@ def create_workflow_execution(
tenant_id: str,
triggered_by: str,
execution_number: int = 1,
event_id: str = None,
fingerprint: str = None,
) -> WorkflowExecution:
with Session(engine) as session:
Expand All @@ -295,6 +317,7 @@ def create_workflow_execution(
workflow_to_alert_execution = WorkflowToAlertExecution(
workflow_execution_id=workflow_execution.id,
alert_fingerprint=fingerprint,
event_id=event_id,
)
session.add(workflow_to_alert_execution)

Expand Down Expand Up @@ -490,6 +513,26 @@ def add_or_update_workflow(
return existing_workflow if existing_workflow else workflow


def get_workflow_to_alert_execution_by_workflow_execution_id(
workflow_execution_id: str
) -> WorkflowToAlertExecution:
"""
Get the WorkflowToAlertExecution entry for a given workflow execution ID.
Args:
workflow_execution_id (str): The workflow execution ID to filter the workflow execution by.
Returns:
WorkflowToAlertExecution: The WorkflowToAlertExecution object.
"""
with Session(engine) as session:
return (
session.query(WorkflowToAlertExecution)
.filter_by(workflow_execution_id=workflow_execution_id)
.first()
)


def get_last_workflow_workflow_to_alert_executions(
session: Session, tenant_id: str
) -> list[WorkflowToAlertExecution]:
Expand Down Expand Up @@ -541,7 +584,7 @@ def get_last_workflow_workflow_to_alert_executions(


def get_last_workflow_execution_by_workflow_id(
workflow_id: str, tenant_id: str
tenant_id: str, workflow_id: str
) -> Optional[WorkflowExecution]:
with Session(engine) as session:
workflow_execution = (
Expand Down Expand Up @@ -1037,6 +1080,18 @@ def get_alerts_by_fingerprint(tenant_id: str, fingerprint: str, limit=1) -> List
return alerts


def get_alert_by_fingerprint_and_event_id(tenant_id: str, fingerprint: str, event_id: str) -> Alert:
with Session(engine) as session:
alert = (
session.query(Alert)
.filter(Alert.tenant_id == tenant_id)
.filter(Alert.fingerprint == fingerprint)
.filter(Alert.id == uuid.UUID(event_id))
.first()
)
return alert


def get_previous_alert_by_fingerprint(tenant_id: str, fingerprint: str) -> Alert:
# get the previous alert for a given fingerprint
with Session(engine) as session:
Expand Down
1 change: 1 addition & 0 deletions keep/api/models/db/workflow.py
Original file line number Diff line number Diff line change
Expand Up @@ -56,6 +56,7 @@ class WorkflowToAlertExecution(SQLModel, table=True):
id: Optional[int] = Field(primary_key=True, default=None)
workflow_execution_id: str = Field(foreign_key="workflowexecution.id")
alert_fingerprint: str
event_id: str | None
workflow_execution: WorkflowExecution = Relationship(
back_populates="workflow_to_alert_execution"
)
Expand Down
6 changes: 2 additions & 4 deletions keep/contextmanager/contextmanager.py
Original file line number Diff line number Diff line change
Expand Up @@ -34,7 +34,7 @@ def __init__(self, tenant_id, workflow_id=None, workflow_execution_id=None):
if self.workflow_id:
try:
last_workflow_execution = get_last_workflow_execution_by_workflow_id(
workflow_id, tenant_id
tenant_id, workflow_id
)
if last_workflow_execution is not None:
self.last_workflow_execution_results = (
Expand Down Expand Up @@ -185,9 +185,7 @@ def set_step_context(self, step_id, results, foreach=False):
self.steps_context_size = asizeof(self.steps_context)

def get_last_workflow_run(self, workflow_id):
# TODO: fix for throttling
# no previous runs
return {}
return get_last_workflow_execution_by_workflow_id(self.tenant_id, workflow_id)

def dump(self):
self.logger.info("Dumping logs to db")
Expand Down
1 change: 0 additions & 1 deletion keep/providers/cloudwatch_provider/cloudwatch_provider.py
Original file line number Diff line number Diff line change
Expand Up @@ -538,7 +538,6 @@ def _format_alert(
@classmethod
def simulate_alert(cls) -> dict:
# Choose a random alert type
import hashlib
import random

from keep.providers.cloudwatch_provider.alerts_mock import ALERTS
Expand Down
5 changes: 3 additions & 2 deletions keep/step/step.py
Original file line number Diff line number Diff line change
Expand Up @@ -76,8 +76,9 @@ def _check_throttling(self, action_name):
throttle = ThrottleFactory.get_instance(
self.context_manager, throttling_type, throttling_config
)
alert_id = self.context_manager.get_workflow_id()
return throttle.check_throttling(action_name, alert_id)
workflow_id = self.context_manager.get_workflow_id()
event_id = self.context_manager.event_context.event_id
return throttle.check_throttling(action_name, workflow_id, event_id)

def _get_foreach_items(self) -> list | list[list]:
"""Get the items to iterate over, when using the `foreach` attribute (see foreach.md)"""
Expand Down
5 changes: 3 additions & 2 deletions keep/throttles/base_throttle.py
Original file line number Diff line number Diff line change
Expand Up @@ -24,12 +24,13 @@ def __init__(
self.context_manager = context_manager

@abc.abstractmethod
def check_throttling(self, action_name, alert_id, **kwargs) -> bool:
def check_throttling(self, action_name, workflow_id, event_id, **kwargs) -> bool:
"""
Validate provider configuration.
Args:
action_name (str): The name of the action to check throttling for.
alert_id (str): The id of the alert to check throttling for.
workflow_id (str): The id of the workflow to check throttling for.
event_id (str): The id of the event to check throttling for.
"""
raise NotImplementedError("apply() method not implemented")
24 changes: 20 additions & 4 deletions keep/throttles/one_until_resolved_throttle.py
Original file line number Diff line number Diff line change
@@ -1,3 +1,6 @@
from keep.api.core.db import get_alert_by_fingerprint_and_event_id, \
get_workflow_to_alert_execution_by_workflow_execution_id
from keep.api.models.alert import AlertStatus
from keep.throttles.base_throttle import BaseThrottle
from keep.contextmanager.contextmanager import ContextManager

Expand All @@ -12,12 +15,25 @@ class OneUntilResolvedThrottle(BaseThrottle):
def __init__(self, context_manager: ContextManager, throttle_type, throttle_config):
super().__init__(context_manager=context_manager, throttle_type=throttle_type, throttle_config=throttle_config)

def check_throttling(self, action_name, alert_id, **kwargs) -> bool:
last_alert_run = self.context_manager.get_last_workflow_run(alert_id)
if not last_alert_run:
def check_throttling(self, action_name, workflow_id, event_id, **kwargs) -> bool:
last_workflow_run = self.context_manager.get_last_workflow_run(workflow_id)
if not last_workflow_run:
return False

# query workflowtoalertexecution table by workflow_id and after that get the alert by fingerprint and event_id
last_workflow_alert_execution = get_workflow_to_alert_execution_by_workflow_execution_id(last_workflow_run.id)
if not last_workflow_alert_execution:
return False

alert = get_alert_by_fingerprint_and_event_id(self.context_manager.tenant_id,
last_workflow_alert_execution.alert_fingerprint,
last_workflow_alert_execution.event_id)
if not alert:
return False

# if the last time the alert were triggered it was in resolved status, return false
if last_alert_run.get("alert_status").lower() == "resolved":
if AlertStatus(alert.event.get("status")) == AlertStatus.RESOLVED:
return False

# else, return true because its already firing
return True
2 changes: 2 additions & 0 deletions keep/workflowmanager/workflowscheduler.py
Original file line number Diff line number Diff line change
Expand Up @@ -148,6 +148,7 @@ def handle_manual_event_workflow(
triggered_by=f"manually by {triggered_by_user}",
execution_number=unique_execution_number,
fingerprint=alert.fingerprint,
event_id=alert.event_id,
)
self.logger.info(f"Workflow execution id: {workflow_execution_id}")
# This is kinda WTF exception since create_workflow_execution shouldn't fail for manual
Expand Down Expand Up @@ -256,6 +257,7 @@ def _handle_event_workflows(self):
triggered_by=triggered_by,
execution_number=workflow_execution_number,
fingerprint=event.fingerprint,
event_id=event.event_id,
)
# This is kinda wtf exception since create workflow execution shouldn't fail for events other than interval
except IntegrityError:
Expand Down
15 changes: 15 additions & 0 deletions tests/conftest.py
Original file line number Diff line number Diff line change
Expand Up @@ -173,6 +173,21 @@ def db_session(request, mysql_container):
interval=0,
workflow_raw="test workflow raw",
),
WorkflowExecution(
id="test-execution-id-1",
workflow_id="mock_alert",
tenant_id=SINGLE_TENANT_UUID,
triggered_by="keep-test",
status="success",
execution_number=1,
results={},
),
WorkflowToAlertExecution(
id=1,
workflow_execution_id="test-execution-id-1",
alert_fingerprint="mock_alert",
event_id="mock_event_id",
),
# Add more data as needed
]
session.add_all(workflow_data)
Expand Down
21 changes: 16 additions & 5 deletions tests/test_contextmanager.py
Original file line number Diff line number Diff line change
@@ -1,12 +1,13 @@
"""
Test the context manager
"""

import json
import tempfile

import pytest

from keep.api.core.dependencies import SINGLE_TENANT_UUID
from keep.api.models.db.workflow import WorkflowExecution
from keep.contextmanager.contextmanager import ContextManager

STATE_FILE_MOCK_DATA = {
Expand Down Expand Up @@ -179,17 +180,27 @@ def test_context_manager_set_step_context(context_manager: ContextManager):
assert context_manager.steps_context[step_id]["results"] == results


def test_context_manager_get_last_alert_run(context_manager_with_state: ContextManager):
def test_context_manager_get_last_alert_run(context_manager_with_state: ContextManager, db_session):
alert_id = "mock_alert"
alert_context = {"mock": "mock"}
alert_status = "firing"
context_manager_with_state.tenant_id = SINGLE_TENANT_UUID
last_run = context_manager_with_state.get_last_workflow_run(alert_id)
assert last_run == {}
if last_run is None:
pytest.fail("No workflow run found with the given alert_id")
assert last_run == WorkflowExecution(
id="test-execution-id-1",
workflow_id="mock_alert",
tenant_id=SINGLE_TENANT_UUID,
started=last_run.started,
triggered_by="keep-test",
status="success",
execution_number=1,
results={},
)
context_manager_with_state.set_last_workflow_run(
alert_id, alert_context, alert_status
)
# last_run = context_manager_with_state.get_last_workflow_run(alert_id)
# assert last_run["workflow_status"] == alert_status


def test_context_manager_singleton(context_manager: ContextManager):
Expand Down

0 comments on commit e6b83c1

Please sign in to comment.