+
+ {workflow.provisioned && (
+
+ Provisioned
+
+ )}
{!!handleRunClick && WorkflowMenuSection({
onDelete: handleDeleteClick,
onRun: handleRunClick,
@@ -566,6 +574,7 @@ function WorkflowTile({ workflow }: { workflow: Workflow }) {
onBuilder: handleBuilderClick,
runButtonToolTip: message,
isRunButtonDisabled: !!isRunButtonDisabled,
+ provisioned: workflow.provisioned,
})}
@@ -862,6 +871,7 @@ export function WorkflowTileOld({ workflow }: { workflow: Workflow }) {
onBuilder: handleBuilderClick,
runButtonToolTip: message,
isRunButtonDisabled: !!isRunButtonDisabled,
+ provisioned: workflow.provisioned,
})}
diff --git a/keep/api/api.py b/keep/api/api.py
index f04d01365..bed995b86 100644
--- a/keep/api/api.py
+++ b/keep/api/api.py
@@ -59,7 +59,12 @@
from keep.event_subscriber.event_subscriber import EventSubscriber
from keep.identitymanager.identitymanagerfactory import IdentityManagerFactory
from keep.posthog.posthog import get_posthog_client
+
+# imported at module level (moved out of on_startup) so startup can load the
+# provider cache and provision providers/workflows for the single tenant
+from keep.providers.providers_factory import ProvidersFactory
+from keep.providers.providers_service import ProvidersService
from keep.workflowmanager.workflowmanager import WorkflowManager
+from keep.workflowmanager.workflowstore import WorkflowStore
load_dotenv(find_dotenv())
keep.api.logging.setup_logging()
@@ -242,15 +247,14 @@ def get_app(
@app.on_event("startup")
async def on_startup():
- # load all providers into cache
- from keep.providers.providers_factory import ProvidersFactory
- from keep.providers.providers_service import ProvidersService
-
logger.info("Loading providers into cache")
ProvidersFactory.get_all_providers()
# provision providers from env. relevant only on single tenant.
+ logger.info("Provisioning providers and workflows")
ProvidersService.provision_providers_from_env(SINGLE_TENANT_UUID)
logger.info("Providers loaded successfully")
+ WorkflowStore.provision_workflows_from_directory(SINGLE_TENANT_UUID)
+ logger.info("Workflows provisioned successfully")
# Start the services
logger.info("Starting the services")
# Start the scheduler
diff --git a/keep/api/core/db.py b/keep/api/core/db.py
index 9d7d5c40b..0e0111f54 100644
--- a/keep/api/core/db.py
+++ b/keep/api/core/db.py
@@ -278,6 +278,8 @@ def add_or_update_workflow(
interval,
workflow_raw,
is_disabled,
+ provisioned=False,
+ provisioned_file=None,
updated_by=None,
) -> Workflow:
with Session(engine, expire_on_commit=False) as session:
@@ -301,7 +303,9 @@ def add_or_update_workflow(
existing_workflow.revision += 1 # Increment the revision
existing_workflow.last_updated = datetime.now() # Update last_updated
existing_workflow.is_deleted = False
- existing_workflow.is_disabled= is_disabled
+ existing_workflow.is_disabled = is_disabled
+ existing_workflow.provisioned = provisioned
+ existing_workflow.provisioned_file = provisioned_file
else:
# Create a new workflow
@@ -313,8 +317,10 @@ def add_or_update_workflow(
created_by=created_by,
updated_by=updated_by, # Set updated_by to the provided value
interval=interval,
- is_disabled =is_disabled,
+ is_disabled=is_disabled,
workflow_raw=workflow_raw,
+ provisioned=provisioned,
+ provisioned_file=provisioned_file,
)
session.add(workflow)
@@ -461,6 +467,27 @@ def get_all_workflows(tenant_id: str) -> List[Workflow]:
return workflows
+def get_all_provisioned_workflows(tenant_id: str) -> List[Workflow]:
+ with Session(engine) as session:
+ workflows = session.exec(
+ select(Workflow)
+ .where(Workflow.tenant_id == tenant_id)
+ .where(Workflow.provisioned == True)
+ .where(Workflow.is_deleted == False)
+ ).all()
+ return workflows
+
+
+def get_all_provisioned_providers(tenant_id: str) -> List[Provider]:
+ with Session(engine) as session:
+ providers = session.exec(
+ select(Provider)
+ .where(Provider.tenant_id == tenant_id)
+ .where(Provider.provisioned == True)
+ ).all()
+ return providers
+
+
def get_all_workflows_yamls(tenant_id: str) -> List[str]:
with Session(engine) as session:
workflows = session.exec(
@@ -499,6 +526,7 @@ def get_raw_workflow(tenant_id: str, workflow_id: str) -> str:
return None
return workflow.workflow_raw
+
def update_provider_last_pull_time(tenant_id: str, provider_id: str):
extra = {"tenant_id": tenant_id, "provider_id": provider_id}
logger.info("Updating provider last pull time", extra=extra)
@@ -568,18 +596,22 @@ def finish_workflow_execution(tenant_id, workflow_id, execution_id, status, erro
session.commit()
-def get_workflow_executions(tenant_id, workflow_id, limit=50, offset=0, tab=2, status: Optional[Union[str, List[str]]] = None,
+def get_workflow_executions(
+ tenant_id,
+ workflow_id,
+ limit=50,
+ offset=0,
+ tab=2,
+ status: Optional[Union[str, List[str]]] = None,
trigger: Optional[Union[str, List[str]]] = None,
- execution_id: Optional[str] = None):
+ execution_id: Optional[str] = None,
+):
with Session(engine) as session:
- query = (
- session.query(
- WorkflowExecution,
- )
- .filter(
- WorkflowExecution.tenant_id == tenant_id,
- WorkflowExecution.workflow_id == workflow_id
- )
+ query = session.query(
+ WorkflowExecution,
+ ).filter(
+ WorkflowExecution.tenant_id == tenant_id,
+ WorkflowExecution.workflow_id == workflow_id,
)
now = datetime.now(tz=timezone.utc)
@@ -593,48 +625,51 @@ def get_workflow_executions(tenant_id, workflow_id, limit=50, offset=0, tab=2, s
start_of_day = now.replace(hour=0, minute=0, second=0, microsecond=0)
query = query.filter(
WorkflowExecution.started >= start_of_day,
- WorkflowExecution.started <= now
+ WorkflowExecution.started <= now,
)
if timeframe:
- query = query.filter(
- WorkflowExecution.started >= timeframe
- )
+ query = query.filter(WorkflowExecution.started >= timeframe)
if isinstance(status, str):
status = [status]
elif status is None:
- status = []
-
+ status = []
+
# Normalize trigger to a list
if isinstance(trigger, str):
trigger = [trigger]
-
if execution_id:
query = query.filter(WorkflowExecution.id == execution_id)
if status and len(status) > 0:
query = query.filter(WorkflowExecution.status.in_(status))
if trigger and len(trigger) > 0:
- conditions = [WorkflowExecution.triggered_by.like(f"{trig}%") for trig in trigger]
+ conditions = [
+ WorkflowExecution.triggered_by.like(f"{trig}%") for trig in trigger
+ ]
query = query.filter(or_(*conditions))
-
total_count = query.count()
status_count_query = query.with_entities(
- WorkflowExecution.status,
- func.count().label('count')
+ WorkflowExecution.status, func.count().label("count")
).group_by(WorkflowExecution.status)
status_counts = status_count_query.all()
statusGroupbyMap = {status: count for status, count in status_counts}
- pass_count = statusGroupbyMap.get('success', 0)
- fail_count = statusGroupbyMap.get('error', 0) + statusGroupbyMap.get('timeout', 0)
- avgDuration = query.with_entities(func.avg(WorkflowExecution.execution_time)).scalar()
+ pass_count = statusGroupbyMap.get("success", 0)
+ fail_count = statusGroupbyMap.get("error", 0) + statusGroupbyMap.get(
+ "timeout", 0
+ )
+ avgDuration = query.with_entities(
+ func.avg(WorkflowExecution.execution_time)
+ ).scalar()
avgDuration = avgDuration if avgDuration else 0.0
- query = query.order_by(desc(WorkflowExecution.started)).limit(limit).offset(offset)
-
+ query = (
+ query.order_by(desc(WorkflowExecution.started)).limit(limit).offset(offset)
+ )
+
# Execute the query
workflow_executions = query.all()
@@ -654,6 +689,19 @@ def delete_workflow(tenant_id, workflow_id):
session.commit()
+def delete_workflow_by_provisioned_file(tenant_id, provisioned_file):
+ with Session(engine) as session:
+ workflow = session.exec(
+ select(Workflow)
+ .where(Workflow.tenant_id == tenant_id)
+ .where(Workflow.provisioned_file == provisioned_file)
+ ).first()
+
+ if workflow:
+ workflow.is_deleted = True
+ session.commit()
+
+
def get_workflow_id(tenant_id, workflow_name):
with Session(engine) as session:
workflow = session.exec(
@@ -1532,10 +1580,7 @@ def get_rule_incidents_count_db(tenant_id):
query = (
session.query(Incident.rule_id, func.count(Incident.id))
.select_from(Incident)
- .filter(
- Incident.tenant_id == tenant_id,
- col(Incident.rule_id).isnot(None)
- )
+ .filter(Incident.tenant_id == tenant_id, col(Incident.rule_id).isnot(None))
.group_by(Incident.rule_id)
)
return dict(query.all())
@@ -1611,15 +1656,26 @@ def get_all_filters(tenant_id):


def get_last_alert_hash_by_fingerprint(tenant_id, fingerprint):
    # get the last alert for a given fingerprint
    # to check deduplication
    with Session(engine) as session:
-        alert_hash = session.exec(
+        query = (
             select(Alert.alert_hash)
             .where(Alert.tenant_id == tenant_id)
             .where(Alert.fingerprint == fingerprint)
             .order_by(Alert.timestamp.desc())
-        ).first()
+            # LIMIT 1 is required for MSSQL and avoids fetching more than the
+            # single row we need on every deduplication lookup
+            .limit(1)
+        )
+        alert_hash = session.exec(query).first()
         return alert_hash
@@ -2110,8 +2166,7 @@ def get_last_incidents(
# .options(joinedload(Incident.alerts))
.filter(
Incident.tenant_id == tenant_id, Incident.is_confirmed == is_confirmed
- )
- .order_by(desc(Incident.creation_time))
+ ).order_by(desc(Incident.creation_time))
)
if timeframe:
diff --git a/keep/api/models/db/migrations/versions/2024-09-13-10-48_938b1aa62d5c.py b/keep/api/models/db/migrations/versions/2024-09-13-10-48_938b1aa62d5c.py
new file mode 100644
index 000000000..d7cacc71a
--- /dev/null
+++ b/keep/api/models/db/migrations/versions/2024-09-13-10-48_938b1aa62d5c.py
@@ -0,0 +1,52 @@
+"""Provisioned
+
+Revision ID: 938b1aa62d5c
+Revises: 710b4ff1d19e
+Create Date: 2024-09-13 10:48:16.112419
+
+"""
+
+import sqlalchemy as sa
+import sqlmodel
+from alembic import op
+
+# revision identifiers, used by Alembic.
+revision = "938b1aa62d5c"
+down_revision = "710b4ff1d19e"
+branch_labels = None
+depends_on = None
+
+
+def upgrade() -> None:
+ # ### commands auto generated by Alembic - please adjust! ###
+ op.add_column(
+ "provider",
+ sa.Column(
+ "provisioned", sa.Boolean(), nullable=False, server_default=sa.false()
+ ),
+ )
+ op.add_column(
+ "workflow",
+ sa.Column(
+ "provisioned", sa.Boolean(), nullable=False, server_default=sa.false()
+ ),
+ )
+ op.add_column(
+ "workflow",
+ sa.Column(
+ "provisioned_file", sqlmodel.sql.sqltypes.AutoString(), nullable=True
+ ),
+ )
+
+ # ### end Alembic commands ###
+
+
+def downgrade() -> None:
+    # ### commands auto generated by Alembic - please adjust! ###
+    # Drop every column added in upgrade(); without also dropping
+    # "provisioned_file" the migration is not reversible and a subsequent
+    # upgrade() would fail on the duplicate column.
+    with op.batch_alter_table("workflow", schema=None) as batch_op:
+        batch_op.drop_column("provisioned_file")
+        batch_op.drop_column("provisioned")
+
+    with op.batch_alter_table("provider", schema=None) as batch_op:
+        batch_op.drop_column("provisioned")
+
+    # ### end Alembic commands ###
diff --git a/keep/api/models/db/provider.py b/keep/api/models/db/provider.py
index 78ef3c68e..37d8d05fa 100644
--- a/keep/api/models/db/provider.py
+++ b/keep/api/models/db/provider.py
@@ -21,6 +21,7 @@ class Provider(SQLModel, table=True):
) # scope name is key and value is either True if validated or string with error message, e.g: {"read": True, "write": "error message"}
consumer: bool = False
last_pull_time: Optional[datetime]
+ provisioned: bool = Field(default=False)
class Config:
orm_mode = True
diff --git a/keep/api/models/db/workflow.py b/keep/api/models/db/workflow.py
index f2c575863..3426a9560 100644
--- a/keep/api/models/db/workflow.py
+++ b/keep/api/models/db/workflow.py
@@ -19,6 +19,8 @@ class Workflow(SQLModel, table=True):
is_disabled: bool = Field(default=False)
revision: int = Field(default=1, nullable=False)
last_updated: datetime = Field(default_factory=datetime.utcnow)
+ provisioned: bool = Field(default=False)
+ provisioned_file: Optional[str] = None
class Config:
orm_mode = True
diff --git a/keep/api/models/provider.py b/keep/api/models/provider.py
index 78df4eb62..76307de20 100644
--- a/keep/api/models/provider.py
+++ b/keep/api/models/provider.py
@@ -44,3 +44,4 @@ class Provider(BaseModel):
] = []
alertsDistribution: dict[str, int] | None = None
alertExample: dict | None = None
+ provisioned: bool = False
diff --git a/keep/api/models/workflow.py b/keep/api/models/workflow.py
index 6beb16c80..68965a2fc 100644
--- a/keep/api/models/workflow.py
+++ b/keep/api/models/workflow.py
@@ -39,6 +39,8 @@ class WorkflowDTO(BaseModel):
invalid: bool = False # whether the workflow is invalid or not (for UI purposes)
last_executions: List[dict] = None
last_execution_started: datetime = None
+ provisioned: bool = False
+ provisioned_file: str = None
@property
def workflow_raw_id(self):
diff --git a/keep/api/routes/workflows.py b/keep/api/routes/workflows.py
index 47d44f6ae..621e6d843 100644
--- a/keep/api/routes/workflows.py
+++ b/keep/api/routes/workflows.py
@@ -107,30 +107,39 @@ def get_workflows(
try:
providers_dto, triggers = workflowstore.get_workflow_meta_data(
- tenant_id=tenant_id, workflow=workflow, installed_providers_by_type=installed_providers_by_type)
+ tenant_id=tenant_id,
+ workflow=workflow,
+ installed_providers_by_type=installed_providers_by_type,
+ )
except Exception as e:
logger.error(f"Error fetching workflow meta data: {e}")
providers_dto, triggers = [], [] # Default in case of failure
# create the workflow DTO
- workflow_dto = WorkflowDTO(
- id=workflow.id,
- name=workflow.name,
- description=workflow.description or "[This workflow has no description]",
- created_by=workflow.created_by,
- creation_time=workflow.creation_time,
- last_execution_time=workflow_last_run_time,
- last_execution_status=workflow_last_run_status,
- interval=workflow.interval,
- providers=providers_dto,
- triggers=triggers,
- workflow_raw=workflow.workflow_raw,
- revision=workflow.revision,
- last_updated=workflow.last_updated,
- last_executions=last_executions,
- last_execution_started=last_execution_started,
- disabled=workflow.is_disabled,
- )
+ try:
+ workflow_dto = WorkflowDTO(
+ id=workflow.id,
+ name=workflow.name,
+ description=workflow.description
+ or "[This workflow has no description]",
+ created_by=workflow.created_by,
+ creation_time=workflow.creation_time,
+ last_execution_time=workflow_last_run_time,
+ last_execution_status=workflow_last_run_status,
+ interval=workflow.interval,
+ providers=providers_dto,
+ triggers=triggers,
+ workflow_raw=workflow.workflow_raw,
+ revision=workflow.revision,
+ last_updated=workflow.last_updated,
+ last_executions=last_executions,
+ last_execution_started=last_execution_started,
+ disabled=workflow.is_disabled,
+ provisioned=workflow.provisioned,
+ )
+ except Exception as e:
+ logger.error(f"Error creating workflow DTO: {e}")
+ continue
workflows_dto.append(workflow_dto)
return workflows_dto
@@ -422,6 +431,10 @@ async def update_workflow_by_id(
extra={"tenant_id": tenant_id},
)
raise HTTPException(404, "Workflow not found")
+
+ if workflow_from_db.provisioned:
+ raise HTTPException(403, detail="Cannot update a provisioned workflow")
+
workflow = await __get_workflow_raw_data(request, None)
parser = Parser()
workflow_interval = parser.parse_interval(workflow)
@@ -543,24 +556,27 @@ def get_workflow_by_id(
workflowstore = WorkflowStore()
try:
providers_dto, triggers = workflowstore.get_workflow_meta_data(
- tenant_id=tenant_id, workflow=workflow, installed_providers_by_type=installed_providers_by_type)
+ tenant_id=tenant_id,
+ workflow=workflow,
+ installed_providers_by_type=installed_providers_by_type,
+ )
except Exception as e:
logger.error(f"Error fetching workflow meta data: {e}")
providers_dto, triggers = [], [] # Default in case of failure
-
+
         final_workflow = WorkflowDTO(
-            id=workflow.id,
-            name=workflow.name,
-            description=workflow.description or "[This workflow has no description]",
-            created_by=workflow.created_by,
-            creation_time=workflow.creation_time,
-            interval=workflow.interval,
-            providers=providers_dto,
-            triggers=triggers,
-            workflow_raw=workflow.workflow_raw,
-            last_updated=workflow.last_updated,
-            disabled=workflow.is_disabled,
-        )
+        final_workflow = WorkflowDTO(
+            id=workflow.id,
+            name=workflow.name,
+            description=workflow.description or "[This workflow has no description]",
+            created_by=workflow.created_by,
+            creation_time=workflow.creation_time,
+            interval=workflow.interval,
+            providers=providers_dto,
+            triggers=triggers,
+            workflow_raw=workflow.workflow_raw,
+            last_updated=workflow.last_updated,
+            disabled=workflow.is_disabled,
+            # keep the detail endpoint consistent with the list endpoint so the
+            # UI can also disable edit/delete for provisioned workflows here
+            provisioned=workflow.provisioned,
+        )
return WorkflowExecutionsPaginatedResultsDto(
limit=limit,
offset=offset,
@@ -569,7 +585,7 @@ def get_workflow_by_id(
passCount=pass_count,
failCount=fail_count,
avgDuration=avgDuration,
- workflow=final_workflow
+ workflow=final_workflow,
)
diff --git a/keep/providers/providers_factory.py b/keep/providers/providers_factory.py
index 6e9f6a70e..5fb75e755 100644
--- a/keep/providers/providers_factory.py
+++ b/keep/providers/providers_factory.py
@@ -399,6 +399,7 @@ def get_installed_providers(
provider_copy.installed_by = p.installed_by
provider_copy.installation_time = p.installation_time
provider_copy.last_pull_time = p.last_pull_time
+ provider_copy.provisioned = p.provisioned
try:
provider_auth = {"name": p.name}
if include_details:
diff --git a/keep/providers/providers_service.py b/keep/providers/providers_service.py
index 804fd17dc..9dbb0a3ba 100644
--- a/keep/providers/providers_service.py
+++ b/keep/providers/providers_service.py
@@ -8,7 +8,7 @@
from sqlalchemy.exc import IntegrityError
from sqlmodel import Session, select
-from keep.api.core.db import engine, get_provider_by_name
+from keep.api.core.db import engine, get_all_provisioned_providers, get_provider_by_name
from keep.api.models.db.provider import Provider
from keep.api.models.provider import Provider as ProviderModel
from keep.contextmanager.contextmanager import ContextManager
@@ -45,6 +45,8 @@ def install_provider(
provider_name: str,
provider_type: str,
provider_config: Dict[str, Any],
+ provisioned: bool = False,
+ validate_scopes: bool = True,
) -> Dict[str, Any]:
provider_unique_id = uuid.uuid4().hex
logger.info(
@@ -69,7 +71,10 @@ def install_provider(
except Exception as e:
raise HTTPException(status_code=400, detail=str(e))
- validated_scopes = provider.validate_scopes()
+ if validate_scopes:
+ validated_scopes = provider.validate_scopes()
+ else:
+ validated_scopes = {}
secret_manager = SecretManagerFactory.get_secret_manager(context_manager)
secret_name = f"{tenant_id}_{provider_type}_{provider_unique_id}"
@@ -89,6 +94,7 @@ def install_provider(
configuration_key=secret_name,
validatedScopes=validated_scopes,
consumer=provider.is_consumer,
+ provisioned=provisioned,
)
try:
session.add(provider_model)
@@ -129,6 +135,9 @@ def update_provider(
if not provider:
raise HTTPException(404, detail="Provider not found")
+ if provider.provisioned:
+ raise HTTPException(403, detail="Cannot update a provisioned provider")
+
provider_config = {
"authentication": provider_info,
"name": provider.name,
@@ -160,7 +169,9 @@ def update_provider(
}
@staticmethod
- def delete_provider(tenant_id: str, provider_id: str, session: Session):
+ def delete_provider(
+ tenant_id: str, provider_id: str, session: Session, allow_provisioned=False
+ ):
provider = session.exec(
select(Provider).where(
(Provider.tenant_id == tenant_id) & (Provider.id == provider_id)
@@ -170,6 +181,9 @@ def delete_provider(tenant_id: str, provider_id: str, session: Session):
if not provider:
raise HTTPException(404, detail="Provider not found")
+ if provider.provisioned and not allow_provisioned:
+ raise HTTPException(403, detail="Cannot delete a provisioned provider")
+
context_manager = ContextManager(tenant_id=tenant_id)
secret_manager = SecretManagerFactory.get_secret_manager(context_manager)
@@ -231,6 +245,18 @@ def provision_providers_from_env(tenant_id: str):
context_manager = ContextManager(tenant_id=tenant_id)
parser._parse_providers_from_env(context_manager)
env_providers = context_manager.providers_context
+
+ # first, remove any provisioned providers that are not in the env
+ prev_provisioned_providers = get_all_provisioned_providers(tenant_id)
+ for provider in prev_provisioned_providers:
+ if provider.name not in env_providers:
+ with Session(engine) as session:
+ logger.info(f"Deleting provider {provider.name}")
+ ProvidersService.delete_provider(
+ tenant_id, provider.id, session, allow_provisioned=True
+ )
+ logger.info(f"Provider {provider.name} deleted")
+
for provider_name, provider_config in env_providers.items():
logger.info(f"Provisioning provider {provider_name}")
# if its already installed, skip
@@ -238,12 +264,18 @@ def provision_providers_from_env(tenant_id: str):
logger.info(f"Provider {provider_name} already installed")
continue
logger.info(f"Installing provider {provider_name}")
- ProvidersService.install_provider(
- tenant_id=tenant_id,
- installed_by="system",
- provider_id=provider_config["type"],
- provider_name=provider_name,
- provider_type=provider_config["type"],
- provider_config=provider_config["authentication"],
- )
+ try:
+ ProvidersService.install_provider(
+ tenant_id=tenant_id,
+ installed_by="system",
+ provider_id=provider_config["type"],
+ provider_name=provider_name,
+ provider_type=provider_config["type"],
+ provider_config=provider_config["authentication"],
+ provisioned=True,
+ validate_scopes=False,
+ )
+ except Exception:
+ logger.exception(f"Failed to provision provider {provider_name}")
+ continue
logger.info(f"Provider {provider_name} provisioned")
diff --git a/keep/workflowmanager/workflowstore.py b/keep/workflowmanager/workflowstore.py
index cfeb7021d..27fd8c898 100644
--- a/keep/workflowmanager/workflowstore.py
+++ b/keep/workflowmanager/workflowstore.py
@@ -1,8 +1,8 @@
import io
import logging
import os
-import uuid
import random
+import uuid
import requests
import validators
@@ -12,20 +12,21 @@
from keep.api.core.db import (
add_or_update_workflow,
delete_workflow,
+ delete_workflow_by_provisioned_file,
+ get_all_provisioned_workflows,
get_all_workflows,
get_all_workflows_yamls,
get_raw_workflow,
+ get_workflow,
get_workflow_execution,
get_workflows_with_last_execution,
get_workflows_with_last_executions_v2,
)
from keep.api.models.db.workflow import Workflow as WorkflowModel
+from keep.api.models.workflow import ProviderDTO
from keep.parser.parser import Parser
-from keep.workflowmanager.workflow import Workflow
from keep.providers.providers_factory import ProvidersFactory
-from keep.api.models.workflow import (
- ProviderDTO,
-)
+from keep.workflowmanager.workflow import Workflow
class WorkflowStore:
@@ -62,11 +63,19 @@ def create_workflow(self, tenant_id: str, created_by, workflow: dict):
def delete_workflow(self, tenant_id, workflow_id):
self.logger.info(f"Deleting workflow {workflow_id}")
+ workflow = get_workflow(tenant_id, workflow_id)
+ if not workflow:
+ raise HTTPException(
+ status_code=404, detail=f"Workflow {workflow_id} not found"
+ )
+ if workflow.provisioned:
+ raise HTTPException(403, detail="Cannot delete a provisioned workflow")
try:
delete_workflow(tenant_id, workflow_id)
- except Exception:
+ except Exception as e:
+ self.logger.exception(f"Error deleting workflow {workflow_id}: {str(e)}")
raise HTTPException(
- status_code=404, detail=f"Workflow {workflow_id} not found"
+ status_code=500, detail=f"Failed to delete workflow {workflow_id}"
)
def _parse_workflow_to_dict(self, workflow_path: str) -> dict:
@@ -133,7 +142,9 @@ def get_all_workflows(self, tenant_id: str) -> list[WorkflowModel]:
workflows = get_all_workflows(tenant_id)
return workflows
- def get_all_workflows_with_last_execution(self, tenant_id: str, is_v2: bool = False) -> list[dict]:
+ def get_all_workflows_with_last_execution(
+ self, tenant_id: str, is_v2: bool = False
+ ) -> list[dict]:
# list all tenant's workflows
if is_v2:
workflows = get_workflows_with_last_executions_v2(tenant_id, 15)
@@ -226,6 +237,101 @@ def _get_workflows_from_directory(
)
return workflows
+    @staticmethod
+    def provision_workflows_from_directory(
+        tenant_id: str, workflows_dir: str = None
+    ) -> list[Workflow]:
+        """
+        Provision workflows from a directory.
+
+        First deprovisions (soft-deletes) previously provisioned workflows
+        whose backing file no longer exists or now lives outside the
+        directory, then provisions a workflow for every YAML file found.
+
+        Args:
+            tenant_id (str): The tenant ID.
+            workflows_dir (str, optional): A directory containing workflow YAML files.
+                If not provided, it will be read from the KEEP_WORKFLOWS_DIRECTORY
+                environment variable.
+
+        Returns:
+            list[Workflow]: The parsed workflow definitions that were provisioned.
+        """
+        logger = logging.getLogger(__name__)
+        parser = Parser()
+
+        if not workflows_dir:
+            workflows_dir = os.environ.get("KEEP_WORKFLOWS_DIRECTORY")
+            if not workflows_dir:
+                logger.info(
+                    "No workflows directory provided - no provisioning will be done"
+                )
+                return []
+
+        if not os.path.isdir(workflows_dir):
+            raise FileNotFoundError(f"Directory {workflows_dir} does not exist")
+
+        # Normalize once so the containment check below compares like with like
+        # (a relative workflows_dir vs. an absolute stored path would never match).
+        workflows_dir = os.path.abspath(workflows_dir)
+
+        # Deprovision workflows whose file vanished or moved outside the directory.
+        for workflow in get_all_provisioned_workflows(tenant_id):
+            provisioned_file = os.path.abspath(workflow.provisioned_file)
+            try:
+                inside_dir = (
+                    os.path.commonpath([workflows_dir, provisioned_file])
+                    == workflows_dir
+                )
+            except ValueError:
+                # commonpath raises on different drives (Windows); treat as outside
+                inside_dir = False
+            if not os.path.exists(provisioned_file) or not inside_dir:
+                logger.info(
+                    f"Deprovisioning workflow {workflow.id} as its file no longer exists or is outside the workflows directory"
+                )
+                delete_workflow_by_provisioned_file(
+                    tenant_id, workflow.provisioned_file
+                )
+                logger.info(f"Workflow {workflow.id} deprovisioned successfully")
+
+        # Provision (or update) a workflow for each YAML file in the directory.
+        # Kept separate from the DB query result above so the return value is
+        # homogeneous (parsed workflow definitions only).
+        provisioned_workflows = []
+        for file in os.listdir(workflows_dir):
+            if not file.endswith((".yaml", ".yml")):
+                continue
+            logger.info(f"Provisioning workflow from {file}")
+            workflow_path = os.path.join(workflows_dir, file)
+
+            try:
+                with open(workflow_path, "r") as yaml_file:
+                    workflow_yaml = yaml.safe_load(yaml_file)
+                if "workflow" in workflow_yaml:
+                    workflow_yaml = workflow_yaml["workflow"]
+                # backward compatibility
+                elif "alert" in workflow_yaml:
+                    workflow_yaml = workflow_yaml["alert"]
+
+                workflow_name = workflow_yaml.get("name") or workflow_yaml.get("id")
+                if not workflow_name:
+                    logger.error(f"Workflow from {file} does not have a name or id")
+                    continue
+
+                add_or_update_workflow(
+                    id=str(uuid.uuid4()),
+                    name=workflow_name,
+                    tenant_id=tenant_id,
+                    description=workflow_yaml.get("description"),
+                    created_by="system",
+                    interval=parser.parse_interval(workflow_yaml),
+                    is_disabled=parser.parse_disabled(workflow_yaml),
+                    workflow_raw=yaml.dump(workflow_yaml),
+                    provisioned=True,
+                    provisioned_file=workflow_path,
+                )
+                provisioned_workflows.append(workflow_yaml)
+
+                logger.info(f"Workflow from {file} provisioned successfully")
+            except Exception as e:
+                logger.error(
+                    f"Error provisioning workflow from {file}",
+                    extra={"exception": e},
+                )
+
+        return provisioned_workflows
+
def _read_workflow_from_stream(self, stream) -> dict:
"""
Parse a workflow from an IO stream.
@@ -247,7 +353,9 @@ def _read_workflow_from_stream(self, stream) -> dict:
raise e
return workflow
- def get_random_workflow_templates(self, tenant_id: str, workflows_dir: str, limit: int) -> list[dict]:
+ def get_random_workflow_templates(
+ self, tenant_id: str, workflows_dir: str, limit: int
+ ) -> list[dict]:
"""
Get random workflows from a directory.
Args:
@@ -261,7 +369,9 @@ def get_random_workflow_templates(self, tenant_id: str, workflows_dir: str, limi
if not os.path.isdir(workflows_dir):
raise FileNotFoundError(f"Directory {workflows_dir} does not exist")
- workflow_yaml_files = [f for f in os.listdir(workflows_dir) if f.endswith(('.yaml', '.yml'))]
+ workflow_yaml_files = [
+ f for f in os.listdir(workflows_dir) if f.endswith((".yaml", ".yml"))
+ ]
if not workflow_yaml_files:
raise FileNotFoundError(f"No workflows found in directory {workflows_dir}")
@@ -275,15 +385,17 @@ def get_random_workflow_templates(self, tenant_id: str, workflows_dir: str, limi
file_path = os.path.join(workflows_dir, file)
workflow_yaml = self._parse_workflow_to_dict(file_path)
if "workflow" in workflow_yaml:
- workflow_yaml['name'] = workflow_yaml['workflow']['id']
- workflow_yaml['workflow_raw'] = yaml.dump(workflow_yaml)
- workflow_yaml['workflow_raw_id'] = workflow_yaml['workflow']['id']
+ workflow_yaml["name"] = workflow_yaml["workflow"]["id"]
+ workflow_yaml["workflow_raw"] = yaml.dump(workflow_yaml)
+ workflow_yaml["workflow_raw_id"] = workflow_yaml["workflow"]["id"]
workflows.append(workflow_yaml)
count += 1
self.logger.info(f"Workflow from {file} fetched successfully")
except Exception as e:
- self.logger.error(f"Error parsing or fetching workflow from {file}: {e}")
+ self.logger.error(
+ f"Error parsing or fetching workflow from {file}: {e}"
+ )
return workflows
def group_last_workflow_executions(self, workflows: list[dict]) -> list[dict]:
@@ -294,7 +406,7 @@ def group_last_workflow_executions(self, workflows: list[dict]) -> list[dict]:
self.logger.info(f"workflow_executions: {workflows}")
workflow_dict = {}
for item in workflows:
- workflow,started,execution_time,status = item
+ workflow, started, execution_time, status = item
workflow_id = workflow.id
# Initialize the workflow if not already in the dictionary
@@ -304,14 +416,14 @@ def group_last_workflow_executions(self, workflows: list[dict]) -> list[dict]:
"workflow_last_run_started": None,
"workflow_last_run_time": None,
"workflow_last_run_status": None,
- "workflow_last_executions": []
+ "workflow_last_executions": [],
}
# Update the latest execution details if available
- if workflow_dict[workflow_id]["workflow_last_run_started"] is None :
+ if workflow_dict[workflow_id]["workflow_last_run_started"] is None:
workflow_dict[workflow_id]["workflow_last_run_status"] = status
workflow_dict[workflow_id]["workflow_last_run_started"] = started
- workflow_dict[workflow_id]["workflow_last_run_time"] = started
+ workflow_dict[workflow_id]["workflow_last_run_time"] = started
# Add the execution to the list of executions
if started is not None:
@@ -319,7 +431,7 @@ def group_last_workflow_executions(self, workflows: list[dict]) -> list[dict]:
{
"status": status,
"execution_time": execution_time,
- "started": started
+ "started": started,
}
)
# Convert the dictionary to a list of results
@@ -329,14 +441,16 @@ def group_last_workflow_executions(self, workflows: list[dict]) -> list[dict]:
"workflow_last_run_status": workflow_info["workflow_last_run_status"],
"workflow_last_run_time": workflow_info["workflow_last_run_time"],
"workflow_last_run_started": workflow_info["workflow_last_run_started"],
- "workflow_last_executions": workflow_info["workflow_last_executions"]
+ "workflow_last_executions": workflow_info["workflow_last_executions"],
}
for workflow_id, workflow_info in workflow_dict.items()
]
return results
- def get_workflow_meta_data(self, tenant_id: str, workflow: dict, installed_providers_by_type: dict):
+ def get_workflow_meta_data(
+ self, tenant_id: str, workflow: dict, installed_providers_by_type: dict
+ ):
providers_dto = []
triggers = []
@@ -354,19 +468,28 @@ def get_workflow_meta_data(self, tenant_id: str, workflow: dict, installed_provi
# Parse the workflow YAML safely
workflow_yaml = yaml.safe_load(workflow_raw_data)
if not workflow_yaml:
- self.logger.error(f"Parsed workflow_yaml is empty or invalid: {workflow_raw_data}")
+ self.logger.error(
+ f"Parsed workflow_yaml is empty or invalid: {workflow_raw_data}"
+ )
return providers_dto, triggers
providers = self.parser.get_providers_from_workflow(workflow_yaml)
except Exception as e:
# Improved logging to capture more details about the error
- self.logger.error(f"Failed to parse workflow in get_workflow_meta_data: {e}, workflow: {workflow}")
- return providers_dto, triggers # Return empty providers and triggers in case of error
+ self.logger.error(
+ f"Failed to parse workflow in get_workflow_meta_data: {e}, workflow: {workflow}"
+ )
+ return (
+ providers_dto,
+ triggers,
+ ) # Return empty providers and triggers in case of error
# Step 2: Process providers and add them to DTO
for provider in providers:
try:
- provider_data = installed_providers_by_type[provider.get("type")][provider.get("name")]
+ provider_data = installed_providers_by_type[provider.get("type")][
+ provider.get("name")
+ ]
provider_dto = ProviderDTO(
name=provider_data.name,
type=provider_data.type,
@@ -377,9 +500,13 @@ def get_workflow_meta_data(self, tenant_id: str, workflow: dict, installed_provi
except KeyError:
# Handle case where the provider is not installed
try:
- conf = ProvidersFactory.get_provider_required_config(provider.get("type"))
+ conf = ProvidersFactory.get_provider_required_config(
+ provider.get("type")
+ )
except ModuleNotFoundError:
- self.logger.warning(f"Non-existing provider in workflow: {provider.get('type')}")
+ self.logger.warning(
+ f"Non-existing provider in workflow: {provider.get('type')}"
+ )
conf = None
# Handle providers based on whether they require config
@@ -387,11 +514,13 @@ def get_workflow_meta_data(self, tenant_id: str, workflow: dict, installed_provi
name=provider.get("name"),
type=provider.get("type"),
id=None,
- installed=(conf is None), # Consider it installed if no config is required
+ installed=(
+ conf is None
+ ), # Consider it installed if no config is required
)
providers_dto.append(provider_dto)
# Step 3: Extract triggers from workflow
triggers = self.parser.get_triggers_from_workflow(workflow_yaml)
- return providers_dto, triggers
\ No newline at end of file
+ return providers_dto, triggers
diff --git a/tests/fixtures/client.py b/tests/fixtures/client.py
index bc7df1872..c49b6ece5 100644
--- a/tests/fixtures/client.py
+++ b/tests/fixtures/client.py
@@ -1,3 +1,4 @@
+import asyncio
import hashlib
import importlib
import sys
@@ -29,6 +30,11 @@ def test_app(monkeypatch, request):
for module in list(sys.modules):
if module.startswith("keep.api.routes"):
del sys.modules[module]
+
+ # NOTE: reload providers_service so it re-binds to the freshly patched db session
+ elif module.startswith("keep.providers.providers_service"):
+ importlib.reload(sys.modules[module])
+
if "keep.api.api" in sys.modules:
importlib.reload(sys.modules["keep.api.api"])
@@ -36,6 +42,11 @@ def test_app(monkeypatch, request):
from keep.api.api import get_app
app = get_app()
+
+ # Manually trigger the startup event
+ for event_handler in app.router.on_startup:
+ asyncio.run(event_handler())
+
return app
diff --git a/tests/provision/workflows_1/provision_example_1.yml b/tests/provision/workflows_1/provision_example_1.yml
new file mode 100644
index 000000000..aeeb7f140
--- /dev/null
+++ b/tests/provision/workflows_1/provision_example_1.yml
@@ -0,0 +1,20 @@
+workflow:
+ id: aks-example
+ description: aks-example
+ triggers:
+ - type: manual
+ steps:
+ # get all pods
+ - name: get-pods
+ provider:
+ type: aks
+ config: "{{ providers.aks }}"
+ with:
+ command_type: get_pods
+ actions:
+ - name: echo-pod-status
+ foreach: "{{ steps.get-pods.results }}"
+ provider:
+ type: console
+ with:
+ alert_message: "Pod name: {{ foreach.value.metadata.name }} || Namespace: {{ foreach.value.metadata.namespace }} || Status: {{ foreach.value.status.phase }}"
diff --git a/tests/provision/workflows_1/provision_example_2.yml b/tests/provision/workflows_1/provision_example_2.yml
new file mode 100644
index 000000000..4b5518ef9
--- /dev/null
+++ b/tests/provision/workflows_1/provision_example_2.yml
@@ -0,0 +1,29 @@
+workflow:
+ id: Resend-Python-service
+ description: Python Resend Mail
+ triggers:
+ - type: manual
+ owners: []
+ services: []
+ steps:
+ - name: run-script
+ provider:
+ config: '{{ providers.default-bash }}'
+ type: bash
+ with:
+ command: python3 test.py
+ timeout: 5
+ actions:
+ - condition:
+ - assert: '{{ steps.run-script.results.return_code }} == 0'
+ name: assert-condition
+ type: assert
+ name: trigger-resend
+ provider:
+ type: resend
+ config: "{{ providers.resend-test }}"
+ with:
+ _from: "onboarding@resend.dev"
+ to: "youremail.dev@gmail.com"
+ subject: "Python test is up!"
+ html:
Python test is up!
diff --git a/tests/provision/workflows_1/provision_example_3.yml b/tests/provision/workflows_1/provision_example_3.yml
new file mode 100644
index 000000000..3e6e85cf4
--- /dev/null
+++ b/tests/provision/workflows_1/provision_example_3.yml
@@ -0,0 +1,14 @@
+workflow:
+ id: autosupress
+ strategy: parallel
+ description: demonstrates how to automatically suppress alerts
+ triggers:
+ - type: alert
+ actions:
+ - name: dismiss-alert
+ provider:
+ type: mock
+ with:
+ enrich_alert:
+ - key: dismissed
+ value: "true"
diff --git a/tests/provision/workflows_2/provision_example_1.yml b/tests/provision/workflows_2/provision_example_1.yml
new file mode 100644
index 000000000..aeeb7f140
--- /dev/null
+++ b/tests/provision/workflows_2/provision_example_1.yml
@@ -0,0 +1,20 @@
+workflow:
+ id: aks-example
+ description: aks-example
+ triggers:
+ - type: manual
+ steps:
+ # get all pods
+ - name: get-pods
+ provider:
+ type: aks
+ config: "{{ providers.aks }}"
+ with:
+ command_type: get_pods
+ actions:
+ - name: echo-pod-status
+ foreach: "{{ steps.get-pods.results }}"
+ provider:
+ type: console
+ with:
+ alert_message: "Pod name: {{ foreach.value.metadata.name }} || Namespace: {{ foreach.value.metadata.namespace }} || Status: {{ foreach.value.status.phase }}"
diff --git a/tests/provision/workflows_2/provision_example_2.yml b/tests/provision/workflows_2/provision_example_2.yml
new file mode 100644
index 000000000..4b5518ef9
--- /dev/null
+++ b/tests/provision/workflows_2/provision_example_2.yml
@@ -0,0 +1,29 @@
+workflow:
+ id: Resend-Python-service
+ description: Python Resend Mail
+ triggers:
+ - type: manual
+ owners: []
+ services: []
+ steps:
+ - name: run-script
+ provider:
+ config: '{{ providers.default-bash }}'
+ type: bash
+ with:
+ command: python3 test.py
+ timeout: 5
+ actions:
+ - condition:
+ - assert: '{{ steps.run-script.results.return_code }} == 0'
+ name: assert-condition
+ type: assert
+ name: trigger-resend
+ provider:
+ type: resend
+ config: "{{ providers.resend-test }}"
+ with:
+ _from: "onboarding@resend.dev"
+ to: "youremail.dev@gmail.com"
+ subject: "Python test is up!"
+ html:
Python test is up!
diff --git a/tests/test_auth.py b/tests/test_auth.py
index 09ce7956a..86ad7a385 100644
--- a/tests/test_auth.py
+++ b/tests/test_auth.py
@@ -53,7 +53,7 @@ def get_mock_jwt_payload(token, *args, **kwargs):
@pytest.mark.parametrize(
"test_app", ["SINGLE_TENANT", "MULTI_TENANT", "NO_AUTH"], indirect=True
)
-def test_api_key_with_header(client, db_session, test_app):
+def test_api_key_with_header(db_session, client, test_app):
"""Tests the API key authentication with the x-api-key/digest"""
auth_type = os.getenv("AUTH_TYPE")
valid_api_key = "valid_api_key"
@@ -95,7 +95,7 @@ def test_api_key_with_header(client, db_session, test_app):
@pytest.mark.parametrize(
"test_app", ["SINGLE_TENANT", "MULTI_TENANT", "NO_AUTH"], indirect=True
)
-def test_bearer_token(client, db_session, test_app):
+def test_bearer_token(db_session, client, test_app):
"""Tests the bearer token authentication"""
auth_type = os.getenv("AUTH_TYPE")
# Test bearer tokens
@@ -121,7 +121,7 @@ def test_bearer_token(client, db_session, test_app):
@pytest.mark.parametrize(
"test_app", ["SINGLE_TENANT", "MULTI_TENANT", "NO_AUTH"], indirect=True
)
-def test_webhook_api_key(client, db_session, test_app):
+def test_webhook_api_key(db_session, client, test_app):
"""Tests the webhook API key authentication"""
auth_type = os.getenv("AUTH_TYPE")
valid_api_key = "valid_api_key"
@@ -167,7 +167,7 @@ def test_webhook_api_key(client, db_session, test_app):
# sanity check with keycloak
@pytest.mark.parametrize("test_app", ["KEYCLOAK"], indirect=True)
-def test_keycloak_sanity(keycloak_client, keycloak_token, client, test_app):
+def test_keycloak_sanity(db_session, keycloak_client, keycloak_token, client, test_app):
"""Tests the keycloak sanity check"""
# Use the token to make a request to the Keep API
headers = {"Authorization": f"Bearer {keycloak_token}"}
@@ -182,7 +182,7 @@ def test_keycloak_sanity(keycloak_client, keycloak_token, client, test_app):
],
indirect=True,
)
-def test_api_key_impersonation_without_admin(client, db_session, test_app):
+def test_api_key_impersonation_without_admin(db_session, client, test_app):
"""Tests the API key impersonation with different environment settings"""
valid_api_key = "valid_admin_api_key"
@@ -207,7 +207,7 @@ def test_api_key_impersonation_without_admin(client, db_session, test_app):
],
indirect=True,
)
-def test_api_key_impersonation_without_user_provision(client, db_session, test_app):
+def test_api_key_impersonation_without_user_provision(db_session, client, test_app):
"""Tests the API key impersonation with different environment settings"""
valid_api_key = "valid_admin_api_key"
@@ -239,7 +239,7 @@ def test_api_key_impersonation_without_user_provision(client, db_session, test_a
],
indirect=True,
)
-def test_api_key_impersonation_with_user_provision(client, db_session, test_app):
+def test_api_key_impersonation_with_user_provision(db_session, client, test_app):
"""Tests the API key impersonation with different environment settings"""
valid_api_key = "valid_admin_api_key"
@@ -272,7 +272,7 @@ def test_api_key_impersonation_with_user_provision(client, db_session, test_app)
indirect=True,
)
def test_api_key_impersonation_provisioned_user_cant_login(
- client, db_session, test_app
+ db_session, client, test_app
):
"""Tests the API key impersonation with different environment settings"""
diff --git a/tests/test_enrichments.py b/tests/test_enrichments.py
index 74ea6878c..e75c99a24 100644
--- a/tests/test_enrichments.py
+++ b/tests/test_enrichments.py
@@ -457,7 +457,7 @@ def test_mapping_rule_with_elsatic(mock_session, mock_alert_dto, setup_alerts):
@pytest.mark.parametrize("test_app", ["NO_AUTH"], indirect=True)
-def test_enrichment(client, db_session, test_app, mock_alert_dto, elastic_client):
+def test_enrichment(db_session, client, test_app, mock_alert_dto, elastic_client):
# add some rule
rule = MappingRule(
id=1,
@@ -495,7 +495,7 @@ def test_enrichment(client, db_session, test_app, mock_alert_dto, elastic_client
@pytest.mark.parametrize("test_app", ["NO_AUTH"], indirect=True)
-def test_disposable_enrichment(client, db_session, test_app, mock_alert_dto):
+def test_disposable_enrichment(db_session, client, test_app, mock_alert_dto):
# SHAHAR: there is a voodoo so that you must do something with the db_session to kick it off
rule = MappingRule(
id=1,
diff --git a/tests/test_extraction_rules.py b/tests/test_extraction_rules.py
index 24759b6e2..190220a34 100644
--- a/tests/test_extraction_rules.py
+++ b/tests/test_extraction_rules.py
@@ -1,18 +1,15 @@
from time import sleep
import pytest
-
from isodate import parse_datetime
-from tests.fixtures.client import client, test_app, setup_api_key
+from tests.fixtures.client import client, setup_api_key, test_app # noqa
VALID_API_KEY = "valid_api_key"
-@pytest.mark.parametrize(
- "test_app", ["NO_AUTH"], indirect=True
-)
-def test_create_extraction_rule(client, test_app, db_session):
+@pytest.mark.parametrize("test_app", ["NO_AUTH"], indirect=True)
+def test_create_extraction_rule(db_session, client, test_app):
setup_api_key(db_session, VALID_API_KEY, role="webhook")
# Try to create invalid extraction
@@ -33,10 +30,8 @@ def test_create_extraction_rule(client, test_app, db_session):
assert response.status_code == 200
-@pytest.mark.parametrize(
- "test_app", ["NO_AUTH"], indirect=True
-)
-def test_extraction_rule_updated_at(client, test_app, db_session):
+@pytest.mark.parametrize("test_app", ["NO_AUTH"], indirect=True)
+def test_extraction_rule_updated_at(db_session, client, test_app):
setup_api_key(db_session, VALID_API_KEY, role="webhook")
rule_dict = {
@@ -66,7 +61,9 @@ def test_extraction_rule_updated_at(client, test_app, db_session):
# Without it update can happen in the same second, so we will not see any changes
sleep(1)
updated_response = client.put(
- f"/extraction/{rule_id}", json=updated_rule_dict, headers={"x-api-key": VALID_API_KEY}
+ f"/extraction/{rule_id}",
+ json=updated_rule_dict,
+ headers={"x-api-key": VALID_API_KEY},
)
assert updated_response.status_code == 200
@@ -75,7 +72,3 @@ def test_extraction_rule_updated_at(client, test_app, db_session):
new_updated_at = parse_datetime(updated_response_data["updated_at"])
assert new_updated_at > updated_at
-
-
-
-
diff --git a/tests/test_metrics.py b/tests/test_metrics.py
index e9c400022..deb0b1621 100644
--- a/tests/test_metrics.py
+++ b/tests/test_metrics.py
@@ -1,36 +1,32 @@
-
import pytest
from keep.api.core.db import (
add_alerts_to_incident_by_incident_id,
- create_incident_from_dict
+ create_incident_from_dict,
)
-
-from tests.fixtures.client import client, setup_api_key, test_app
+from tests.fixtures.client import client, setup_api_key, test_app # noqa
-@pytest.mark.parametrize(
- "test_app", ["NO_AUTH"], indirect=True
-)
-def test_add_remove_alert_to_incidents(client, db_session, test_app, setup_stress_alerts_no_elastic):
+@pytest.mark.parametrize("test_app", ["NO_AUTH"], indirect=True)
+def test_add_remove_alert_to_incidents(
+ db_session, client, test_app, setup_stress_alerts_no_elastic
+):
alerts = setup_stress_alerts_no_elastic(14)
- incident = create_incident_from_dict("keep", {"user_generated_name": "test", "description": "test"})
+ incident = create_incident_from_dict(
+ "keep", {"user_generated_name": "test", "description": "test"}
+ )
valid_api_key = "valid_api_key"
setup_api_key(db_session, valid_api_key)
- add_alerts_to_incident_by_incident_id(
- "keep",
- incident.id,
- [a.id for a in alerts]
- )
+ add_alerts_to_incident_by_incident_id("keep", incident.id, [a.id for a in alerts])
- response = client.get(
- "/metrics",
- headers={"X-API-KEY": "valid_api_key"}
- )
+ response = client.get("/metrics", headers={"X-API-KEY": "valid_api_key"})
# Checking for alert_total metric
- assert f"alerts_total{{incident_name=\"test\" incident_id=\"{incident.id}\"}} 14" in response.text.split("\n")
+ assert (
+ f'alerts_total{{incident_name="test" incident_id="{incident.id}"}} 14'
+ in response.text.split("\n")
+ )
# Checking for open_incidents_total metric
assert "open_incidents_total 1" in response.text.split("\n")
diff --git a/tests/test_provisioning.py b/tests/test_provisioning.py
new file mode 100644
index 000000000..4b49ca854
--- /dev/null
+++ b/tests/test_provisioning.py
@@ -0,0 +1,223 @@
+import asyncio
+import importlib
+import sys
+
+import pytest
+from fastapi.testclient import TestClient
+
+from tests.fixtures.client import client, setup_api_key, test_app # noqa
+
+# Mock data for workflows
+MOCK_WORKFLOW_ID = "123e4567-e89b-12d3-a456-426614174000"
+MOCK_PROVISIONED_WORKFLOW = {
+ "id": MOCK_WORKFLOW_ID,
+ "name": "Test Workflow",
+ "description": "A provisioned test workflow",
+ "provisioned": True,
+}
+
+
+# Test that workflows are provisioned from KEEP_WORKFLOWS_DIRECTORY on startup
+@pytest.mark.parametrize(
+ "test_app",
+ [
+ {
+ "AUTH_TYPE": "NOAUTH",
+ "KEEP_WORKFLOWS_DIRECTORY": "./tests/provision/workflows_1",
+ },
+ ],
+ indirect=True,
+)
+def test_provisioned_workflows(db_session, client, test_app):
+ response = client.get("/workflows", headers={"x-api-key": "someapikey"})
+ assert response.status_code == 200
+ # 3 workflows and 3 provisioned workflows
+ workflows = response.json()
+ provisioned_workflows = [w for w in workflows if w.get("provisioned")]
+ assert len(provisioned_workflows) == 3
+
+
+# Test provisioning from a directory containing two workflows
+@pytest.mark.parametrize(
+ "test_app",
+ [
+ {
+ "AUTH_TYPE": "NOAUTH",
+ "KEEP_WORKFLOWS_DIRECTORY": "./tests/provision/workflows_2",
+ },
+ ],
+ indirect=True,
+)
+def test_provisioned_workflows_2(db_session, client, test_app):
+ response = client.get("/workflows", headers={"x-api-key": "someapikey"})
+ assert response.status_code == 200
+ # workflows_2 directory contains 2 provisioned workflows
+ workflows = response.json()
+ provisioned_workflows = [w for w in workflows if w.get("provisioned")]
+ assert len(provisioned_workflows) == 2
+
+
+# Test for deleting a provisioned workflow
+@pytest.mark.parametrize(
+ "test_app",
+ [
+ {
+ "AUTH_TYPE": "NOAUTH",
+ "KEEP_WORKFLOWS_DIRECTORY": "./tests/provision/workflows_1",
+ },
+ ],
+ indirect=True,
+)
+def test_delete_provisioned_workflow(db_session, client, test_app):
+ response = client.get("/workflows", headers={"x-api-key": "someapikey"})
+ assert response.status_code == 200
+ # 3 workflows and 3 provisioned workflows
+ workflows = response.json()
+ provisioned_workflows = [w for w in workflows if w.get("provisioned")]
+ workflow_id = provisioned_workflows[0].get("id")
+ response = client.delete(
+ f"/workflows/{workflow_id}", headers={"x-api-key": "someapikey"}
+ )
+ # can't delete a provisioned workflow
+ assert response.status_code == 403
+ assert response.json() == {"detail": "Cannot delete a provisioned workflow"}
+
+
+@pytest.mark.parametrize(
+ "test_app",
+ [
+ {
+ "AUTH_TYPE": "NOAUTH",
+ "KEEP_WORKFLOWS_DIRECTORY": "./tests/provision/workflows_1",
+ },
+ ],
+ indirect=True,
+)
+def test_update_provisioned_workflow(db_session, client, test_app):
+ response = client.get("/workflows", headers={"x-api-key": "someapikey"})
+ assert response.status_code == 200
+ # 3 workflows and 3 provisioned workflows
+ workflows = response.json()
+ provisioned_workflows = [w for w in workflows if w.get("provisioned")]
+ workflow_id = provisioned_workflows[0].get("id")
+ response = client.put(
+ f"/workflows/{workflow_id}", headers={"x-api-key": "someapikey"}
+ )
+ # can't update a provisioned workflow
+ assert response.status_code == 403
+ assert response.json() == {"detail": "Cannot update a provisioned workflow"}
+
+
+@pytest.mark.parametrize(
+ "test_app",
+ [
+ {
+ "AUTH_TYPE": "NOAUTH",
+ "KEEP_WORKFLOWS_DIRECTORY": "./tests/provision/workflows_1",
+ },
+ ],
+ indirect=True,
+)
+def test_reprovision_workflow(monkeypatch, db_session, client, test_app):
+ response = client.get("/workflows", headers={"x-api-key": "someapikey"})
+ assert response.status_code == 200
+ # 3 workflows and 3 provisioned workflows
+ workflows = response.json()
+ provisioned_workflows = [w for w in workflows if w.get("provisioned")]
+ assert len(provisioned_workflows) == 3
+
+ # Step 2: Change environment variables (simulating new provisioning)
+ monkeypatch.setenv("KEEP_WORKFLOWS_DIRECTORY", "./tests/provision/workflows_2")
+
+ # Reload the app to apply the new environment changes
+ importlib.reload(sys.modules["keep.api.api"])
+
+ # Reinitialize the TestClient with the new app instance
+ from keep.api.api import get_app
+
+ app = get_app()
+
+ # Manually trigger the startup event
+ for event_handler in app.router.on_startup:
+ asyncio.run(event_handler())
+
+ client = TestClient(get_app())
+
+ response = client.get("/workflows", headers={"x-api-key": "someapikey"})
+ assert response.status_code == 200
+ # 2 workflows and 2 provisioned workflows
+ workflows = response.json()
+ provisioned_workflows = [w for w in workflows if w.get("provisioned")]
+ assert len(provisioned_workflows) == 2
+
+
+@pytest.mark.parametrize(
+ "test_app",
+ [
+ {
+ "AUTH_TYPE": "NOAUTH",
+ "KEEP_PROVIDERS": '{"keepVictoriaMetrics":{"type":"victoriametrics","authentication":{"VMAlertHost":"http://localhost","VMAlertPort": 1234}},"keepClickhouse1":{"type":"clickhouse","authentication":{"host":"http://localhost","port":1234,"username":"keep","password":"keep","database":"keep-db"}}}',
+ },
+ ],
+ indirect=True,
+)
+def test_provision_provider(db_session, client, test_app):
+ response = client.get("/providers", headers={"x-api-key": "someapikey"})
+ assert response.status_code == 200
+ # 2 providers provisioned via the KEEP_PROVIDERS env var
+ providers = response.json()
+ provisioned_providers = [
+ p for p in providers.get("installed_providers") if p.get("provisioned")
+ ]
+ assert len(provisioned_providers) == 2
+
+
+@pytest.mark.parametrize(
+ "test_app",
+ [
+ {
+ "AUTH_TYPE": "NOAUTH",
+ "KEEP_PROVIDERS": '{"keepVictoriaMetrics":{"type":"victoriametrics","authentication":{"VMAlertHost":"http://localhost","VMAlertPort": 1234}},"keepClickhouse1":{"type":"clickhouse","authentication":{"host":"http://localhost","port":1234,"username":"keep","password":"keep","database":"keep-db"}}}',
+ },
+ ],
+ indirect=True,
+)
+def test_reprovision_provider(monkeypatch, db_session, client, test_app):
+ response = client.get("/providers", headers={"x-api-key": "someapikey"})
+ assert response.status_code == 200
+ # 2 providers provisioned via the KEEP_PROVIDERS env var
+ providers = response.json()
+ provisioned_providers = [
+ p for p in providers.get("installed_providers") if p.get("provisioned")
+ ]
+ assert len(provisioned_providers) == 2
+
+ # Step 2: Change environment variables (simulating new provisioning)
+ monkeypatch.setenv(
+ "KEEP_PROVIDERS",
+ '{"keepPrometheus":{"type":"prometheus","authentication":{"url":"http://localhost","port":9090}}}',
+ )
+
+ # Reload the app to apply the new environment changes
+ importlib.reload(sys.modules["keep.api.api"])
+
+ # Reinitialize the TestClient with the new app instance
+ from keep.api.api import get_app
+
+ app = get_app()
+
+ # Manually trigger the startup event
+ for event_handler in app.router.on_startup:
+ asyncio.run(event_handler())
+
+ client = TestClient(app)
+
+ # Step 3: Verify if the new provider is provisioned after reloading
+ response = client.get("/providers", headers={"x-api-key": "someapikey"})
+ assert response.status_code == 200
+ providers = response.json()
+ provisioned_providers = [
+ p for p in providers.get("installed_providers") if p.get("provisioned")
+ ]
+ assert len(provisioned_providers) == 1
+ assert provisioned_providers[0]["type"] == "prometheus"
diff --git a/tests/test_rules_api.py b/tests/test_rules_api.py
index f29e2992d..ff039950d 100644
--- a/tests/test_rules_api.py
+++ b/tests/test_rules_api.py
@@ -1,8 +1,7 @@
import pytest
-from keep.api.core.dependencies import SINGLE_TENANT_UUID
from keep.api.core.db import create_rule as create_rule_db
-
+from keep.api.core.dependencies import SINGLE_TENANT_UUID
from tests.fixtures.client import client, setup_api_key, test_app # noqa
TEST_RULE_DATA = {
@@ -18,28 +17,36 @@
"created_by": "test@keephq.dev",
}
-INVALID_DATA_STEPS = [{
- "update": {"sqlQuery": {"sql": "", "params": []}},
- "error": "SQL is required",
-}, {
- "update": {"sqlQuery": {"sql": "SELECT", "params": []}},
- "error": "Params are required",
-}, {
- "update": {"celQuery": ""},
- "error": "CEL is required",
-}, {
- "update": {"ruleName": ""},
- "error": "Rule name is required",
-}, {
- "update": {"timeframeInSeconds": 0},
- "error": "Timeframe is required",
-}, {
- "update": {"timeUnit": ""},
- "error": "Timeunit is required",
-}]
+INVALID_DATA_STEPS = [
+ {
+ "update": {"sqlQuery": {"sql": "", "params": []}},
+ "error": "SQL is required",
+ },
+ {
+ "update": {"sqlQuery": {"sql": "SELECT", "params": []}},
+ "error": "Params are required",
+ },
+ {
+ "update": {"celQuery": ""},
+ "error": "CEL is required",
+ },
+ {
+ "update": {"ruleName": ""},
+ "error": "Rule name is required",
+ },
+ {
+ "update": {"timeframeInSeconds": 0},
+ "error": "Timeframe is required",
+ },
+ {
+ "update": {"timeUnit": ""},
+ "error": "Timeunit is required",
+ },
+]
+
@pytest.mark.parametrize("test_app", ["NO_AUTH"], indirect=True)
-def test_get_rules_api(client, db_session, test_app):
+def test_get_rules_api(db_session, client, test_app):
rule = create_rule_db(**TEST_RULE_DATA)
response = client.get(
@@ -50,7 +57,7 @@ def test_get_rules_api(client, db_session, test_app):
assert response.status_code == 200
data = response.json()
assert len(data) == 1
- assert data[0]['id'] == str(rule.id)
+ assert data[0]["id"] == str(rule.id)
rule2 = create_rule_db(**TEST_RULE_DATA)
@@ -62,27 +69,26 @@ def test_get_rules_api(client, db_session, test_app):
assert response2.status_code == 200
data = response2.json()
assert len(data) == 2
- assert data[0]['id'] == str(rule.id)
- assert data[1]['id'] == str(rule2.id)
+ assert data[0]["id"] == str(rule.id)
+ assert data[1]["id"] == str(rule2.id)
@pytest.mark.parametrize("test_app", ["NO_AUTH"], indirect=True)
-def test_create_rule_api(client, db_session, test_app):
+def test_create_rule_api(db_session, client, test_app):
rule_data = {
"ruleName": "test rule",
- "sqlQuery": {"sql": "SELECT * FROM alert where severity = %s", "params": ["critical"]},
+ "sqlQuery": {
+ "sql": "SELECT * FROM alert where severity = %s",
+ "params": ["critical"],
+ },
"celQuery": "severity = 'critical'",
"timeframeInSeconds": 300,
"timeUnit": "seconds",
"requireApprove": False,
}
- response = client.post(
- "/rules",
- headers={"x-api-key": "some-key"},
- json=rule_data
- )
+ response = client.post("/rules", headers={"x-api-key": "some-key"}, json=rule_data)
assert response.status_code == 200
data = response.json()
@@ -92,9 +98,7 @@ def test_create_rule_api(client, db_session, test_app):
invalid_rule_data = {k: v for k, v in rule_data.items() if k != "ruleName"}
invalid_data_response = client.post(
- "/rules",
- headers={"x-api-key": "some-key"},
- json=invalid_rule_data
+ "/rules", headers={"x-api-key": "some-key"}, json=invalid_rule_data
)
assert invalid_data_response.status_code == 422
@@ -109,7 +113,7 @@ def test_create_rule_api(client, db_session, test_app):
invalid_data_response_2 = client.post(
"/rules",
headers={"x-api-key": "some-key"},
- json=dict(rule_data, **invalid_data_step["update"])
+ json=dict(rule_data, **invalid_data_step["update"]),
)
assert invalid_data_response_2.status_code == 400, current_step
@@ -119,7 +123,7 @@ def test_create_rule_api(client, db_session, test_app):
@pytest.mark.parametrize("test_app", ["NO_AUTH"], indirect=True)
-def test_delete_rule_api(client, db_session, test_app):
+def test_delete_rule_api(db_session, client, test_app):
rule = create_rule_db(**TEST_RULE_DATA)
response = client.delete(
@@ -143,15 +147,17 @@ def test_delete_rule_api(client, db_session, test_app):
assert data["detail"] == "Rule not found"
-
@pytest.mark.parametrize("test_app", ["NO_AUTH"], indirect=True)
-def test_update_rule_api(client, db_session, test_app):
+def test_update_rule_api(db_session, client, test_app):
rule = create_rule_db(**TEST_RULE_DATA)
rule_data = {
"ruleName": "test rule",
- "sqlQuery": {"sql": "SELECT * FROM alert where severity = %s", "params": ["critical"]},
+ "sqlQuery": {
+ "sql": "SELECT * FROM alert where severity = %s",
+ "params": ["critical"],
+ },
"celQuery": "severity = 'critical'",
"timeframeInSeconds": 300,
"timeUnit": "seconds",
@@ -159,9 +165,7 @@ def test_update_rule_api(client, db_session, test_app):
}
response = client.put(
- "/rules/{}".format(rule.id),
- headers={"x-api-key": "some-key"},
- json=rule_data
+ "/rules/{}".format(rule.id), headers={"x-api-key": "some-key"}, json=rule_data
)
assert response.status_code == 200
@@ -174,7 +178,7 @@ def test_update_rule_api(client, db_session, test_app):
invalid_data_response_2 = client.put(
"/rules/{}".format(rule.id),
headers={"x-api-key": "some-key"},
- json=dict(rule_data, **invalid_data_step["update"])
+ json=dict(rule_data, **invalid_data_step["update"]),
)
assert invalid_data_response_2.status_code == 400, current_step
diff --git a/tests/test_search_alerts.py b/tests/test_search_alerts.py
index bb45436ff..a7cc3b867 100644
--- a/tests/test_search_alerts.py
+++ b/tests/test_search_alerts.py
@@ -736,7 +736,7 @@ def test_special_characters_in_strings(db_session, setup_alerts):
# tests 10k alerts
@pytest.mark.parametrize(
- "setup_stress_alerts", [{"num_alerts": 10000}], indirect=True
+ "setup_stress_alerts", [{"num_alerts": 1000}], indirect=True
) # Generate 10,000 alerts
def test_filter_large_dataset(db_session, setup_stress_alerts):
search_query = SearchQuery(
@@ -745,7 +745,7 @@ def test_filter_large_dataset(db_session, setup_stress_alerts):
"params": {"source_1": "source_1", "severity_1": "critical"},
},
cel_query='(source == "source_1") && (severity == "critical")',
- limit=10000,
+ limit=1000,
)
# first, use elastic
os.environ["ELASTIC_ENABLED"] = "true"
@@ -764,10 +764,10 @@ def test_filter_large_dataset(db_session, setup_stress_alerts):
# compare
assert len(elastic_filtered_alerts) == len(db_filtered_alerts)
print(
- "time taken for 10k alerts with elastic: ",
+ "time taken for 1k alerts with elastic: ",
elastic_end_time - elastic_start_time,
)
- print("time taken for 10k alerts with db: ", db_end_time - db_start_time)
+ print("time taken for 1k alerts with db: ", db_end_time - db_start_time)
@pytest.mark.parametrize("setup_stress_alerts", [{"num_alerts": 10000}], indirect=True)
@@ -1312,7 +1312,7 @@ def test_severity_comparisons(
@pytest.mark.parametrize("test_app", ["NO_AUTH"], indirect=True)
-def test_alerts_enrichment_in_search(client, db_session, test_app, elastic_client):
+def test_alerts_enrichment_in_search(db_session, client, test_app, elastic_client):
rule = MappingRule(
id=1,
diff --git a/tests/test_search_alerts_configuration.py b/tests/test_search_alerts_configuration.py
index 7d3bb700e..bf4b1ab9a 100644
--- a/tests/test_search_alerts_configuration.py
+++ b/tests/test_search_alerts_configuration.py
@@ -12,7 +12,7 @@
@pytest.mark.parametrize("test_app", ["SINGLE_TENANT"], indirect=True)
def test_single_tenant_configuration_with_elastic(
- client, elastic_client, db_session, test_app
+ db_session, client, elastic_client, test_app
):
valid_api_key = "valid_api_key"
setup_api_key(db_session, valid_api_key)
@@ -30,7 +30,7 @@ def test_single_tenant_configuration_with_elastic(
],
indirect=True,
)
-def test_single_tenant_configuration_without_elastic(client, db_session, test_app):
+def test_single_tenant_configuration_without_elastic(db_session, client, test_app):
valid_api_key = "valid_api_key"
setup_api_key(db_session, valid_api_key)
response = client.get("/preset/feed/alerts", headers={"x-api-key": valid_api_key})
@@ -39,7 +39,7 @@ def test_single_tenant_configuration_without_elastic(client, db_session, test_ap
@pytest.mark.parametrize("test_app", ["MULTI_TENANT"], indirect=True)
def test_multi_tenant_configuration_with_elastic(
- client, elastic_client, db_session, test_app
+ db_session, client, elastic_client, test_app
):
valid_api_key = "valid_api_key"
valid_api_key_2 = "valid_api_key_2"