diff --git a/docs/deployment/provision/overview.mdx b/docs/deployment/provision/overview.mdx new file mode 100644 index 000000000..f48d4d81e --- /dev/null +++ b/docs/deployment/provision/overview.mdx @@ -0,0 +1,29 @@ +--- +title: "Overview" +--- + +Keep supports various deployment and provisioning strategies to accommodate different environments and use cases, from development setups to production deployments. + +### Provisioning Options + +Keep offers two main provisioning options: + +1. [**Provider Provisioning**](/deployment/provision/provider) - Set up and manage data providers for Keep. +2. [**Workflow Provisioning**](/deployment/provision/workflow) - Configure and manage workflows within Keep. + +Choosing the right provisioning strategy depends on your specific use case, deployment environment, and scalability requirements. You can read more about each provisioning option in their respective sections. + +### How To Configure Provisioning + + +Some provisioning options require additional environment variables. These will be covered in detail on the specific provisioning pages. + + +Provisioning in Keep is controlled through environment variables and configuration files. The main environment variables for provisioning are: + +| Provisioning Type | Environment Variable | Purpose | +|-------------------|----------------------|---------| +| **Provider** | `KEEP_PROVIDERS` | JSON string containing provider configurations | +| **Workflow** | `KEEP_WORKFLOWS_DIRECTORY` | Directory path containing workflow configuration files | + +For more details on each provisioning strategy, including setup instructions and implications, refer to the respective sections. diff --git a/docs/deployment/provision/provider.mdx b/docs/deployment/provision/provider.mdx new file mode 100644 index 000000000..f6993aabf --- /dev/null +++ b/docs/deployment/provision/provider.mdx @@ -0,0 +1,58 @@ +--- +title: "Providers Provisioning" +--- + +For any questions or issues related to provider provisioning, please join our [Slack](https://slack.keephq.dev) community. + +Provider provisioning in Keep allows you to set up and manage data providers dynamically. This feature enables you to configure various data sources that Keep can interact with, such as monitoring systems, databases, or other services. + +### Configuring Providers + +To provision providers, set the `KEEP_PROVIDERS` environment variable with a JSON string containing the provider configurations. Here's an example: + +```json +{ + "keepVictoriaMetrics": { + "type": "victoriametrics", + "authentication": { + "VMAlertHost": "http://localhost", + "VMAlertPort": 1234 + } + }, + "keepClickhouse1": { + "type": "clickhouse", + "authentication": { + "host": "http://localhost", + "port": 1234, + "username": "keep", + "password": "keep", + "database": "keep-db" + } + } +} +``` + +Spin up Keep with this `KEEP_PROVIDERS` value: +```json +# ENV +KEEP_PROVIDERS={"keepVictoriaMetrics":{"type":"victoriametrics","authentication":{"VMAlertHost":"http://localhost","VMAlertPort": 1234}},"keepClickhouse1":{"type":"clickhouse","authentication":{"host":"http://localhost","port":"4321","username":"keep","password":"1234","database":"keepdb"}}} +``` + +### Supported Providers + +Keep supports a wide range of provider types. Each provider type has its own specific configuration requirements. +To see the full list of supported providers and their detailed configuration options, please refer to our comprehensive provider documentation. + + +### Update Provisioned Providers + +Provider configurations can be updated dynamically by changing the `KEEP_PROVIDERS` environment variable. + +On every restart, Keep reads this environment variable and determines which providers need to be added or removed. + +This process allows for flexible management of data sources without requiring manual intervention. By simply updating the `KEEP_PROVIDERS` variable and restarting the application, you can efficiently add new providers, remove existing ones, or modify their configurations. + +The high-level provisioning mechanism: +1. Keep reads the `KEEP_PROVIDERS` value. +2. Keep checks if there are any provisioned providers that are no longer in the `KEEP_PROVIDERS` value, and deletes them. +3. Keep installs all providers from the `KEEP_PROVIDERS` value. diff --git a/docs/deployment/provision/workflow.mdx b/docs/deployment/provision/workflow.mdx new file mode 100644 index 000000000..134704dc4 --- /dev/null +++ b/docs/deployment/provision/workflow.mdx @@ -0,0 +1,36 @@ +--- +title: "Workflow Provisioning" +--- + +For any questions or issues related to workflow provisioning, please join our [Slack](https://slack.keephq.dev) community. + +Workflow provisioning in Keep allows you to set up and manage workflows dynamically. This feature enables you to configure various automated processes and tasks within your Keep deployment. + +### Configuring Workflows + +To provision workflows, follow these steps: + +1. Set the `KEEP_WORKFLOWS_DIRECTORY` environment variable to the path of your workflow configuration directory. +2. Create workflow configuration files in the specified directory. + +Example directory structure: +``` +/path/to/workflows/ +├── workflow1.yaml +├── workflow2.yaml +└── workflow3.yaml +``` +### Update Provisioned Workflows + +On every restart, Keep reads the `KEEP_WORKFLOWS_DIRECTORY` environment variable and determines which workflows need to be added, removed, or updated. + +This process allows for flexible management of workflows without requiring manual intervention. By simply updating the workflow files in the `KEEP_WORKFLOWS_DIRECTORY` and restarting the application, you can efficiently add new workflows, remove existing ones, or modify their configurations. + +The high-level provisioning mechanism: +1. Keep reads the `KEEP_WORKFLOWS_DIRECTORY` value. +2. Keep lists all workflow files under the `KEEP_WORKFLOWS_DIRECTORY` directory. +3. Keep compares the current workflow files with the previously provisioned workflows: + - New workflow files are provisioned. + - Missing workflow files are deprovisioned. + - Updated workflow files are re-provisioned with the new configuration. +4. Keep updates its internal state to reflect the current set of provisioned workflows. diff --git a/docs/mint.json b/docs/mint.json index 99e9a8b8a..1f874c84e 100644 --- a/docs/mint.json +++ b/docs/mint.json @@ -65,6 +65,14 @@ "deployment/authentication/keycloak-auth" ] }, + { + "group": "Provision", + "pages": [ + "deployment/provision/overview", + "deployment/provision/provider", + "deployment/provision/workflow" + ] + }, "deployment/secret-manager", "deployment/docker", "deployment/kubernetes", diff --git a/keep-ui/app/providers/provider-form.tsx b/keep-ui/app/providers/provider-form.tsx index 3e2da8ea2..f3f8e1ce6 100644 --- a/keep-ui/app/providers/provider-form.tsx +++ b/keep-ui/app/providers/provider-form.tsx @@ -26,7 +26,7 @@ import { Accordion, AccordionHeader, AccordionBody, - + Badge, } from "@tremor/react"; import { ExclamationCircleIcon, @@ -520,6 +520,7 @@ const ProviderForm = ({ onChange={(value) => handleInputChange({ target: { name: configKey, value } })} placeholder={method.placeholder || `Select ${configKey}`} error={Object.keys(inputErrors).includes(configKey)} + disabled={provider.provisioned} > {method.options.map((option) => ( @@ -541,6 +542,7 @@ const ProviderForm = ({ color="orange" size="xs" onClick={addEntry(configKey)} + disabled={provider.provisioned} > Add Entry @@ -550,6 +552,7 @@ const ProviderForm = ({ value={formValues[configKey] || []} onChange={(value) => handleDictInputChange(configKey, value)} error={Object.keys(inputErrors).includes(configKey)} + disabled={provider.provisioned} /> ); @@ -565,6 +568,7 @@ const ProviderForm = ({ inputFileRef.current.click(); }} icon={ArrowDownOnSquareIcon} + disabled={provider.provisioned} > {selectedFile ? `File Chosen: ${selectedFile}` : `Upload a ${method.name}`} @@ -581,6 +585,7 @@ const ProviderForm = ({ } handleInputChange(e); }} + disabled={provider.provisioned} /> ); @@ -597,6 +602,7 @@ const ProviderForm = ({ autoComplete="off" error={Object.keys(inputErrors).includes(configKey)} placeholder={method.placeholder || `Enter ${configKey}`} + disabled={provider.provisioned} /> ); @@ -694,6 +700,13 @@ const ProviderForm = ({
Connect to {provider.display_name} + {/* Display the Provisioned Badge if the provider is provisioned */} + {provider.provisioned && ( + + Provisioned + + )} +
+ {provider.provisioned && +
+ + + Editing provisioned providers is not possible from UI. + + +
+ } + {provider.provider_description && ( {provider.provider_description} )} @@ -885,7 +912,7 @@ const ProviderForm = ({ variant="secondary" color="orange" className="mt-2.5" - disabled={!installOrUpdateWebhookEnabled} + disabled={!installOrUpdateWebhookEnabled || provider.provisioned} tooltip={ !installOrUpdateWebhookEnabled ? "Fix required webhook scopes and refresh scopes to enable" @@ -928,16 +955,20 @@ const ProviderForm = ({ {installedProvidersMode && Object.keys(provider.config).length > 0 && ( <> - - +
+ +
)} {!installedProvidersMode && Object.keys(provider.config).length > 0 && ( diff --git a/keep-ui/app/providers/provider-tile.tsx b/keep-ui/app/providers/provider-tile.tsx index 51d0a4865..cd59f0790 100644 --- a/keep-ui/app/providers/provider-tile.tsx +++ b/keep-ui/app/providers/provider-tile.tsx @@ -19,6 +19,7 @@ import { import "./provider-tile.css"; import moment from "moment"; import ImageWithFallback from "@/components/ImageWithFallback"; +import { FaCode } from "react-icons/fa"; interface Props { provider: Provider; @@ -200,6 +201,15 @@ export default function ProviderTile({ provider, onClick }: Props) { Linked ) : null} + {provider.provisioned ? ( + + ) : null}
diff --git a/keep-ui/app/providers/providers.tsx b/keep-ui/app/providers/providers.tsx index a75229645..ed060ead2 100644 --- a/keep-ui/app/providers/providers.tsx +++ b/keep-ui/app/providers/providers.tsx @@ -90,6 +90,7 @@ export interface Provider { tags: TProviderLabels[]; alertsDistribution?: AlertDistritbuionData[]; alertExample?: { [key: string]: string }; + provisioned?: boolean; } export type Providers = Provider[]; diff --git a/keep-ui/app/workflows/models.tsx b/keep-ui/app/workflows/models.tsx index 7b548a518..5e77ddcfa 100644 --- a/keep-ui/app/workflows/models.tsx +++ b/keep-ui/app/workflows/models.tsx @@ -44,6 +44,7 @@ export type Workflow = { WorkflowExecution, "execution_time" | "status" | "started" >[]; + provisioned?: boolean; }; export type MockProvider = { diff --git a/keep-ui/app/workflows/workflow-menu.tsx b/keep-ui/app/workflows/workflow-menu.tsx index d13e85aa7..2362fa094 100644 --- a/keep-ui/app/workflows/workflow-menu.tsx +++ b/keep-ui/app/workflows/workflow-menu.tsx @@ -12,7 +12,8 @@ interface WorkflowMenuProps { onDownload?: () => void; onBuilder?: () => void; isRunButtonDisabled: boolean; - runButtonToolTip?: string; + runButtonToolTip?: string; + provisioned?: boolean; } @@ -24,6 +25,7 @@ export default function WorkflowMenu({ onBuilder, isRunButtonDisabled, runButtonToolTip, + provisioned, }: WorkflowMenuProps) { const stopPropagation = (e: React.MouseEvent<HTMLButtonElement>) => { e.stopPropagation(); @@ -115,15 +117,23 @@ export default function WorkflowMenu({ </Menu.Item> <Menu.Item> {({ active }) => ( - <button - onClick={(e) => { stopPropagation(e); onDelete?.(); }} - className={`${ - active ? "bg-slate-200" : "text-gray-900" - } group flex w-full items-center rounded-md px-2 py-2 text-xs`} - > - <TrashIcon className="mr-2 h-4 w-4" aria-hidden="true" /> - Delete - </button> + <div className="relative group"> + <button + disabled={provisioned} + onClick={(e) => { stopPropagation(e); onDelete?.(); }} + className={`${ + active ? 'bg-slate-200' : 'text-gray-900' + } flex w-full items-center rounded-md px-2 py-2 text-xs ${provisioned ? 'cursor-not-allowed opacity-50' : ''}`} + > + <TrashIcon className="mr-2 h-4 w-4" aria-hidden="true" /> + Delete + </button> + {provisioned && ( + <div className="absolute bottom-full transform -translate-x-1/2 bg-black text-white text-xs rounded px-4 py-1 z-10 opacity-0 group-hover:opacity-100"> + Cannot delete a provisioned workflow + </div> + )} + </div> )} </Menu.Item> </div> diff --git a/keep-ui/app/workflows/workflow-tile.tsx b/keep-ui/app/workflows/workflow-tile.tsx index c462454d1..73da76321 100644 --- a/keep-ui/app/workflows/workflow-tile.tsx +++ b/keep-ui/app/workflows/workflow-tile.tsx @@ -50,6 +50,7 @@ function WorkflowMenuSection({ onBuilder, isRunButtonDisabled, runButtonToolTip, + provisioned, }: { onDelete: () => Promise<void>; onRun: () => Promise<void>; @@ -58,6 +59,7 @@ function WorkflowMenuSection({ onBuilder: () => void; isRunButtonDisabled: boolean; runButtonToolTip?: string; + provisioned?: boolean; }) { // Determine if all providers are installed @@ -70,6 +72,7 @@ function WorkflowMenuSection({ onBuilder={onBuilder} isRunButtonDisabled={isRunButtonDisabled} runButtonToolTip={runButtonToolTip} + provisioned={provisioned} /> ); } @@ -547,7 +550,7 @@ function WorkflowTile({ workflow }: { workflow: Workflow }) { <Loading /> </div> )} - <Card + <Card className="relative flex flex-col justify-between bg-white rounded shadow p-2 h-full hover:border-orange-400 hover:border-2" onClick={(e)=>{ e.stopPropagation(); @@ -557,7 +560,12 @@ function WorkflowTile({ workflow }: { workflow: Workflow }) { } }} > - <div className="absolute top-0 right-0 mt-2 mr-2 mb-2"> + <div className="absolute top-0 right-0 mt-2 mr-2 mb-2 flex items-center"> + {workflow.provisioned && ( + <Badge color="orange" size="xs" className="mr-2"> + Provisioned + </Badge> + )} {!!handleRunClick && WorkflowMenuSection({ onDelete: handleDeleteClick, onRun: handleRunClick, @@ -566,6 +574,7 @@ function WorkflowTile({ workflow }: { workflow: Workflow }) { onBuilder: handleBuilderClick, runButtonToolTip: message, isRunButtonDisabled: !!isRunButtonDisabled, + provisioned: workflow.provisioned, })} </div> <div className="m-2 flex flex-col justify-around item-start flex-wrap"> @@ -862,6 +871,7 @@ export function WorkflowTileOld({ workflow }: { workflow: Workflow }) { onBuilder: handleBuilderClick, runButtonToolTip: message, isRunButtonDisabled: !!isRunButtonDisabled, + provisioned: workflow.provisioned, })} </div> diff --git a/keep/api/api.py b/keep/api/api.py index f04d01365..bed995b86 100644 --- a/keep/api/api.py +++ b/keep/api/api.py @@ -59,7 +59,12 @@ from keep.event_subscriber.event_subscriber import EventSubscriber from keep.identitymanager.identitymanagerfactory import IdentityManagerFactory from keep.posthog.posthog import get_posthog_client + +# load all providers into cache +from keep.providers.providers_factory import ProvidersFactory +from keep.providers.providers_service import ProvidersService from keep.workflowmanager.workflowmanager import WorkflowManager +from keep.workflowmanager.workflowstore import WorkflowStore load_dotenv(find_dotenv()) keep.api.logging.setup_logging() @@ -242,15 +247,14 @@ def get_app( @app.on_event("startup") async def on_startup(): - # load all providers into cache - from keep.providers.providers_factory import ProvidersFactory - from keep.providers.providers_service import ProvidersService - logger.info("Loading providers into cache") ProvidersFactory.get_all_providers() # provision providers from env. relevant only on single tenant. + logger.info("Provisioning providers and workflows") ProvidersService.provision_providers_from_env(SINGLE_TENANT_UUID) logger.info("Providers loaded successfully") + WorkflowStore.provision_workflows_from_directory(SINGLE_TENANT_UUID) + logger.info("Workflows provisioned successfully") # Start the services logger.info("Starting the services") # Start the scheduler diff --git a/keep/api/core/db.py b/keep/api/core/db.py index 9d7d5c40b..0e0111f54 100644 --- a/keep/api/core/db.py +++ b/keep/api/core/db.py @@ -278,6 +278,8 @@ def add_or_update_workflow( interval, workflow_raw, is_disabled, + provisioned=False, + provisioned_file=None, updated_by=None, ) -> Workflow: with Session(engine, expire_on_commit=False) as session: @@ -301,7 +303,9 @@ def add_or_update_workflow( existing_workflow.revision += 1 # Increment the revision existing_workflow.last_updated = datetime.now() # Update last_updated existing_workflow.is_deleted = False - existing_workflow.is_disabled= is_disabled + existing_workflow.is_disabled = is_disabled + existing_workflow.provisioned = provisioned + existing_workflow.provisioned_file = provisioned_file else: # Create a new workflow @@ -313,8 +317,10 @@ def add_or_update_workflow( created_by=created_by, updated_by=updated_by, # Set updated_by to the provided value interval=interval, - is_disabled =is_disabled, + is_disabled=is_disabled, workflow_raw=workflow_raw, + provisioned=provisioned, + provisioned_file=provisioned_file, ) session.add(workflow) @@ -461,6 +467,27 @@ def get_all_workflows(tenant_id: str) -> List[Workflow]: return workflows +def get_all_provisioned_workflows(tenant_id: str) -> List[Workflow]: + with Session(engine) as session: + workflows = session.exec( + select(Workflow) + .where(Workflow.tenant_id == tenant_id) + .where(Workflow.provisioned == True) + .where(Workflow.is_deleted == False) + ).all() + return workflows + + +def get_all_provisioned_providers(tenant_id: str) -> List[Provider]: + with Session(engine) as session: + providers = session.exec( + select(Provider) + .where(Provider.tenant_id == tenant_id) + .where(Provider.provisioned == True) + ).all() + return providers + + def get_all_workflows_yamls(tenant_id: str) -> List[str]: with Session(engine) as session: workflows = session.exec( @@ -499,6 +526,7 @@ def get_raw_workflow(tenant_id: str, workflow_id: str) -> str: return None return workflow.workflow_raw + def update_provider_last_pull_time(tenant_id: str, provider_id: str): extra = {"tenant_id": tenant_id, "provider_id": provider_id} logger.info("Updating provider last pull time", extra=extra) @@ -568,18 +596,22 @@ def finish_workflow_execution(tenant_id, workflow_id, execution_id, status, erro session.commit() -def get_workflow_executions(tenant_id, workflow_id, limit=50, offset=0, tab=2, status: Optional[Union[str, List[str]]] = None, +def get_workflow_executions( + tenant_id, + workflow_id, + limit=50, + offset=0, + tab=2, + status: Optional[Union[str, List[str]]] = None, trigger: Optional[Union[str, List[str]]] = None, - execution_id: Optional[str] = None): + execution_id: Optional[str] = None, +): with Session(engine) as session: - query = ( - session.query( - WorkflowExecution, - ) - .filter( - WorkflowExecution.tenant_id == tenant_id, - WorkflowExecution.workflow_id == workflow_id - ) + query = session.query( + WorkflowExecution, + ).filter( + WorkflowExecution.tenant_id == tenant_id, + WorkflowExecution.workflow_id == workflow_id, ) now = datetime.now(tz=timezone.utc) @@ -593,48 +625,51 @@ def get_workflow_executions(tenant_id, workflow_id, limit=50, offset=0, tab=2, s start_of_day = now.replace(hour=0, minute=0, second=0, microsecond=0) query = query.filter( WorkflowExecution.started >= start_of_day, - WorkflowExecution.started <= now + WorkflowExecution.started <= now, ) if timeframe: - query = query.filter( - WorkflowExecution.started >= timeframe - ) + query = query.filter(WorkflowExecution.started >= timeframe) if isinstance(status, str): status = [status] elif status is None: - status = [] - + status = [] + # Normalize trigger to a list if isinstance(trigger, str): trigger = [trigger] - if execution_id: query = query.filter(WorkflowExecution.id == execution_id) if status and len(status) > 0: query = query.filter(WorkflowExecution.status.in_(status)) if trigger and len(trigger) > 0: - conditions = [WorkflowExecution.triggered_by.like(f"{trig}%") for trig in trigger] + conditions = [ + WorkflowExecution.triggered_by.like(f"{trig}%") for trig in trigger + ] query = query.filter(or_(*conditions)) - total_count = query.count() status_count_query = query.with_entities( - WorkflowExecution.status, - func.count().label('count') + WorkflowExecution.status, func.count().label("count") ).group_by(WorkflowExecution.status) status_counts = status_count_query.all() statusGroupbyMap = {status: count for status, count in status_counts} - pass_count = statusGroupbyMap.get('success', 0) - fail_count = statusGroupbyMap.get('error', 0) + statusGroupbyMap.get('timeout', 0) - avgDuration = query.with_entities(func.avg(WorkflowExecution.execution_time)).scalar() + pass_count = statusGroupbyMap.get("success", 0) + fail_count = statusGroupbyMap.get("error", 0) + statusGroupbyMap.get( + "timeout", 0 + ) + avgDuration = query.with_entities( + func.avg(WorkflowExecution.execution_time) + ).scalar() avgDuration = avgDuration if avgDuration else 0.0 - query = query.order_by(desc(WorkflowExecution.started)).limit(limit).offset(offset) - + query = ( + query.order_by(desc(WorkflowExecution.started)).limit(limit).offset(offset) + ) + # Execute the query workflow_executions = query.all() @@ -654,6 +689,19 @@ def delete_workflow(tenant_id, workflow_id): session.commit() +def delete_workflow_by_provisioned_file(tenant_id, provisioned_file): + with Session(engine) as session: + workflow = session.exec( + select(Workflow) + .where(Workflow.tenant_id == tenant_id) + .where(Workflow.provisioned_file == provisioned_file) + ).first() + + if workflow: + workflow.is_deleted = True + session.commit() + + def get_workflow_id(tenant_id, workflow_name): with Session(engine) as session: workflow = session.exec( @@ -1532,10 +1580,7 @@ def get_rule_incidents_count_db(tenant_id): query = ( session.query(Incident.rule_id, func.count(Incident.id)) .select_from(Incident) - .filter( - Incident.tenant_id == tenant_id, - col(Incident.rule_id).isnot(None) - ) + .filter(Incident.tenant_id == tenant_id, col(Incident.rule_id).isnot(None)) .group_by(Incident.rule_id) ) return dict(query.all()) @@ -1611,15 +1656,26 @@ def get_all_filters(tenant_id): def get_last_alert_hash_by_fingerprint(tenant_id, fingerprint): + from sqlalchemy.dialects import mssql + # get the last alert for a given fingerprint # to check deduplication with Session(engine) as session: - alert_hash = session.exec( + query = ( select(Alert.alert_hash) .where(Alert.tenant_id == tenant_id) .where(Alert.fingerprint == fingerprint) .order_by(Alert.timestamp.desc()) - ).first() + .limit(1) # Add LIMIT 1 for MSSQL + ) + + # Compile the query and log it + compiled_query = query.compile( + dialect=mssql.dialect(), compile_kwargs={"literal_binds": True} + ) + logger.info(f"Compiled query: {compiled_query}") + + alert_hash = session.exec(query).first() return alert_hash @@ -2110,8 +2166,7 @@ def get_last_incidents( # .options(joinedload(Incident.alerts)) .filter( Incident.tenant_id == tenant_id, Incident.is_confirmed == is_confirmed - ) - .order_by(desc(Incident.creation_time)) + ).order_by(desc(Incident.creation_time)) ) if timeframe: diff --git a/keep/api/models/db/migrations/versions/2024-09-13-10-48_938b1aa62d5c.py b/keep/api/models/db/migrations/versions/2024-09-13-10-48_938b1aa62d5c.py new file mode 100644 index 000000000..d7cacc71a --- /dev/null +++ b/keep/api/models/db/migrations/versions/2024-09-13-10-48_938b1aa62d5c.py @@ -0,0 +1,52 @@ +"""Provisioned + +Revision ID: 938b1aa62d5c +Revises: 710b4ff1d19e +Create Date: 2024-09-13 10:48:16.112419 + +""" + +import sqlalchemy as sa +import sqlmodel +from alembic import op + +# revision identifiers, used by Alembic. +revision = "938b1aa62d5c" +down_revision = "710b4ff1d19e" +branch_labels = None +depends_on = None + + +def upgrade() -> None: + # ### commands auto generated by Alembic - please adjust! ### + op.add_column( + "provider", + sa.Column( + "provisioned", sa.Boolean(), nullable=False, server_default=sa.false() + ), + ) + op.add_column( + "workflow", + sa.Column( + "provisioned", sa.Boolean(), nullable=False, server_default=sa.false() + ), + ) + op.add_column( + "workflow", + sa.Column( + "provisioned_file", sqlmodel.sql.sqltypes.AutoString(), nullable=True + ), + ) + + # ### end Alembic commands ### + + +def downgrade() -> None: + # ### commands auto generated by Alembic - please adjust! ### + with op.batch_alter_table("workflow", schema=None) as batch_op: + batch_op.drop_column("provisioned") + + with op.batch_alter_table("provider", schema=None) as batch_op: + batch_op.drop_column("provisioned") + + # ### end Alembic commands ### diff --git a/keep/api/models/db/provider.py b/keep/api/models/db/provider.py index 78ef3c68e..37d8d05fa 100644 --- a/keep/api/models/db/provider.py +++ b/keep/api/models/db/provider.py @@ -21,6 +21,7 @@ class Provider(SQLModel, table=True): ) # scope name is key and value is either True if validated or string with error message, e.g: {"read": True, "write": "error message"} consumer: bool = False last_pull_time: Optional[datetime] + provisioned: bool = Field(default=False) class Config: orm_mode = True diff --git a/keep/api/models/db/workflow.py b/keep/api/models/db/workflow.py index f2c575863..3426a9560 100644 --- a/keep/api/models/db/workflow.py +++ b/keep/api/models/db/workflow.py @@ -19,6 +19,8 @@ class Workflow(SQLModel, table=True): is_disabled: bool = Field(default=False) revision: int = Field(default=1, nullable=False) last_updated: datetime = Field(default_factory=datetime.utcnow) + provisioned: bool = Field(default=False) + provisioned_file: Optional[str] = None class Config: orm_mode = True diff --git a/keep/api/models/provider.py b/keep/api/models/provider.py index 78df4eb62..76307de20 100644 --- a/keep/api/models/provider.py +++ b/keep/api/models/provider.py @@ -44,3 +44,4 @@ class Provider(BaseModel): ] = [] alertsDistribution: dict[str, int] | None = None alertExample: dict | None = None + provisioned: bool = False diff --git a/keep/api/models/workflow.py b/keep/api/models/workflow.py index 6beb16c80..68965a2fc 100644 --- a/keep/api/models/workflow.py +++ b/keep/api/models/workflow.py @@ -39,6 +39,8 @@ class WorkflowDTO(BaseModel): invalid: bool = False # whether the workflow is invalid or not (for UI purposes) last_executions: List[dict] = None last_execution_started: datetime = None + provisioned: bool = False + provisioned_file: str = None @property def workflow_raw_id(self): diff --git a/keep/api/routes/workflows.py b/keep/api/routes/workflows.py index 47d44f6ae..621e6d843 100644 --- a/keep/api/routes/workflows.py +++ b/keep/api/routes/workflows.py @@ -107,30 +107,39 @@ def get_workflows( try: providers_dto, triggers = workflowstore.get_workflow_meta_data( - tenant_id=tenant_id, workflow=workflow, installed_providers_by_type=installed_providers_by_type) + tenant_id=tenant_id, + workflow=workflow, + installed_providers_by_type=installed_providers_by_type, + ) except Exception as e: logger.error(f"Error fetching workflow meta data: {e}") providers_dto, triggers = [], [] # Default in case of failure # create the workflow DTO - workflow_dto = WorkflowDTO( - id=workflow.id, - name=workflow.name, - description=workflow.description or "[This workflow has no description]", - created_by=workflow.created_by, - creation_time=workflow.creation_time, - last_execution_time=workflow_last_run_time, - last_execution_status=workflow_last_run_status, - interval=workflow.interval, - providers=providers_dto, - triggers=triggers, - workflow_raw=workflow.workflow_raw, - revision=workflow.revision, - last_updated=workflow.last_updated, - last_executions=last_executions, - last_execution_started=last_execution_started, - disabled=workflow.is_disabled, - ) + try: + workflow_dto = WorkflowDTO( + id=workflow.id, + name=workflow.name, + description=workflow.description + or "[This workflow has no description]", + created_by=workflow.created_by, + creation_time=workflow.creation_time, + last_execution_time=workflow_last_run_time, + last_execution_status=workflow_last_run_status, + interval=workflow.interval, + providers=providers_dto, + triggers=triggers, + workflow_raw=workflow.workflow_raw, + revision=workflow.revision, + last_updated=workflow.last_updated, + last_executions=last_executions, + last_execution_started=last_execution_started, + disabled=workflow.is_disabled, + provisioned=workflow.provisioned, + ) + except Exception as e: + logger.error(f"Error creating workflow DTO: {e}") + continue workflows_dto.append(workflow_dto) return workflows_dto @@ -422,6 +431,10 @@ async def update_workflow_by_id( extra={"tenant_id": tenant_id}, ) raise HTTPException(404, "Workflow not found") + + if workflow_from_db.provisioned: + raise HTTPException(403, detail="Cannot update a provisioned workflow") + workflow = await __get_workflow_raw_data(request, None) parser = Parser() workflow_interval = parser.parse_interval(workflow) @@ -543,24 +556,27 @@ def get_workflow_by_id( workflowstore = WorkflowStore() try: providers_dto, triggers = workflowstore.get_workflow_meta_data( - tenant_id=tenant_id, workflow=workflow, installed_providers_by_type=installed_providers_by_type) + tenant_id=tenant_id, + workflow=workflow, + installed_providers_by_type=installed_providers_by_type, + ) except Exception as e: logger.error(f"Error fetching workflow meta data: {e}") providers_dto, triggers = [], [] # Default in case of failure - + final_workflow = WorkflowDTO( - id=workflow.id, - name=workflow.name, - description=workflow.description or "[This workflow has no description]", - created_by=workflow.created_by, - creation_time=workflow.creation_time, - interval=workflow.interval, - providers=providers_dto, - triggers=triggers, - workflow_raw=workflow.workflow_raw, - last_updated=workflow.last_updated, - disabled=workflow.is_disabled, - ) + id=workflow.id, + name=workflow.name, + description=workflow.description or "[This workflow has no description]", + created_by=workflow.created_by, + creation_time=workflow.creation_time, + interval=workflow.interval, + providers=providers_dto, + triggers=triggers, + workflow_raw=workflow.workflow_raw, + last_updated=workflow.last_updated, + disabled=workflow.is_disabled, + ) return WorkflowExecutionsPaginatedResultsDto( limit=limit, offset=offset, @@ -569,7 +585,7 @@ def get_workflow_by_id( passCount=pass_count, failCount=fail_count, avgDuration=avgDuration, - workflow=final_workflow + workflow=final_workflow, ) diff --git a/keep/providers/providers_factory.py b/keep/providers/providers_factory.py index 6e9f6a70e..5fb75e755 100644 --- a/keep/providers/providers_factory.py +++ b/keep/providers/providers_factory.py @@ -399,6 +399,7 @@ def get_installed_providers( provider_copy.installed_by = p.installed_by provider_copy.installation_time = p.installation_time provider_copy.last_pull_time = p.last_pull_time + provider_copy.provisioned = p.provisioned try: provider_auth = {"name": p.name} if include_details: diff --git a/keep/providers/providers_service.py b/keep/providers/providers_service.py index 804fd17dc..9dbb0a3ba 100644 --- a/keep/providers/providers_service.py +++ b/keep/providers/providers_service.py @@ -8,7 +8,7 @@ from sqlalchemy.exc import IntegrityError from sqlmodel import Session, select -from keep.api.core.db import engine, get_provider_by_name +from keep.api.core.db import engine, get_all_provisioned_providers, get_provider_by_name from keep.api.models.db.provider import Provider from keep.api.models.provider import Provider as ProviderModel from keep.contextmanager.contextmanager import ContextManager @@ -45,6 +45,8 @@ def install_provider( provider_name: str, provider_type: str, provider_config: Dict[str, Any], + provisioned: bool = False, + validate_scopes: bool = True, ) -> Dict[str, Any]: provider_unique_id = uuid.uuid4().hex logger.info( @@ -69,7 +71,10 @@ def install_provider( except Exception as e: raise HTTPException(status_code=400, detail=str(e)) - validated_scopes = provider.validate_scopes() + if validate_scopes: + validated_scopes = provider.validate_scopes() + else: + validated_scopes = {} secret_manager = SecretManagerFactory.get_secret_manager(context_manager) secret_name = f"{tenant_id}_{provider_type}_{provider_unique_id}" @@ -89,6 +94,7 @@ def install_provider( configuration_key=secret_name, validatedScopes=validated_scopes, consumer=provider.is_consumer, + provisioned=provisioned, ) try: session.add(provider_model) @@ -129,6 +135,9 @@ def update_provider( if not provider: raise HTTPException(404, detail="Provider not found") + if provider.provisioned: + raise HTTPException(403, detail="Cannot update a provisioned provider") + provider_config = { "authentication": provider_info, "name": provider.name, @@ -160,7 +169,9 @@ def update_provider( } @staticmethod - def delete_provider(tenant_id: str, provider_id: str, session: Session): + def delete_provider( + tenant_id: str, provider_id: str, session: Session, allow_provisioned=False + ): provider = session.exec( select(Provider).where( (Provider.tenant_id == tenant_id) & (Provider.id == provider_id) @@ -170,6 +181,9 @@ def delete_provider(tenant_id: str, provider_id: str, session: Session): if not provider: raise HTTPException(404, detail="Provider not found") + if provider.provisioned and not allow_provisioned: + raise HTTPException(403, detail="Cannot delete a provisioned provider") + context_manager = ContextManager(tenant_id=tenant_id) secret_manager = SecretManagerFactory.get_secret_manager(context_manager) @@ -231,6 +245,18 @@ def provision_providers_from_env(tenant_id: str): context_manager = ContextManager(tenant_id=tenant_id) parser._parse_providers_from_env(context_manager) env_providers = context_manager.providers_context + + # first, remove any provisioned providers that are not in the env + prev_provisioned_providers = get_all_provisioned_providers(tenant_id) + for provider in prev_provisioned_providers: + if provider.name not in env_providers: + with Session(engine) as session: + logger.info(f"Deleting provider {provider.name}") + ProvidersService.delete_provider( + tenant_id, provider.id, session, allow_provisioned=True + ) + logger.info(f"Provider {provider.name} deleted") + for provider_name, provider_config in env_providers.items(): logger.info(f"Provisioning provider {provider_name}") # if its already installed, skip @@ -238,12 +264,18 @@ def provision_providers_from_env(tenant_id: str): logger.info(f"Provider {provider_name} already installed") continue logger.info(f"Installing provider {provider_name}") - ProvidersService.install_provider( - tenant_id=tenant_id, - installed_by="system", - provider_id=provider_config["type"], - provider_name=provider_name, - provider_type=provider_config["type"], - provider_config=provider_config["authentication"], - ) + try: + ProvidersService.install_provider( + tenant_id=tenant_id, + installed_by="system", + provider_id=provider_config["type"], + provider_name=provider_name, + provider_type=provider_config["type"], + provider_config=provider_config["authentication"], + provisioned=True, + validate_scopes=False, + ) + except Exception: + logger.exception(f"Failed to provision provider {provider_name}") + continue logger.info(f"Provider {provider_name} provisioned") diff --git a/keep/workflowmanager/workflowstore.py b/keep/workflowmanager/workflowstore.py index cfeb7021d..27fd8c898 100644 --- a/keep/workflowmanager/workflowstore.py +++ b/keep/workflowmanager/workflowstore.py @@ -1,8 +1,8 @@ import io import logging import os -import uuid import random +import uuid import requests import validators @@ -12,20 +12,21 @@ from keep.api.core.db import ( add_or_update_workflow, delete_workflow, + delete_workflow_by_provisioned_file, + get_all_provisioned_workflows, get_all_workflows, get_all_workflows_yamls, get_raw_workflow, + get_workflow, get_workflow_execution, get_workflows_with_last_execution, get_workflows_with_last_executions_v2, ) from keep.api.models.db.workflow import Workflow as WorkflowModel +from keep.api.models.workflow import ProviderDTO from keep.parser.parser import Parser -from keep.workflowmanager.workflow import Workflow from keep.providers.providers_factory import ProvidersFactory -from keep.api.models.workflow import ( - ProviderDTO, -) +from keep.workflowmanager.workflow import Workflow class WorkflowStore: @@ -62,11 +63,19 @@ def create_workflow(self, tenant_id: str, created_by, workflow: dict): def delete_workflow(self, tenant_id, workflow_id): self.logger.info(f"Deleting workflow {workflow_id}") + workflow = get_workflow(tenant_id, workflow_id) + if not workflow: + raise HTTPException( + status_code=404, detail=f"Workflow {workflow_id} not found" + ) + if workflow.provisioned: + raise HTTPException(403, detail="Cannot delete a provisioned workflow") try: delete_workflow(tenant_id, workflow_id) - except Exception: + except Exception as e: + self.logger.exception(f"Error deleting workflow {workflow_id}: {str(e)}") raise HTTPException( - status_code=404, detail=f"Workflow {workflow_id} not found" + status_code=500, detail=f"Failed to delete workflow {workflow_id}" ) def _parse_workflow_to_dict(self, workflow_path: str) -> dict: @@ -133,7 +142,9 @@ def get_all_workflows(self, tenant_id: str) -> list[WorkflowModel]: workflows = get_all_workflows(tenant_id) return workflows - def get_all_workflows_with_last_execution(self, tenant_id: str, is_v2: bool = False) -> list[dict]: + def get_all_workflows_with_last_execution( + self, tenant_id: str, is_v2: bool = False + ) -> list[dict]: # list all tenant's workflows if is_v2: workflows = get_workflows_with_last_executions_v2(tenant_id, 15) @@ -226,6 +237,101 @@ def _get_workflows_from_directory( ) return workflows + @staticmethod + def provision_workflows_from_directory( + tenant_id: str, workflows_dir: str = None + ) -> list[Workflow]: + """ + Provision workflows from a directory. + + Args: + tenant_id (str): The tenant ID. + workflows_dir (str, optional): A directory containing workflow YAML files. + If not provided, it will be read from the WORKFLOWS_DIR environment variable. + + Returns: + list[Workflow]: A list of provisioned Workflow objects. + """ + logger = logging.getLogger(__name__) + parser = Parser() + provisioned_workflows = [] + + if not workflows_dir: + workflows_dir = os.environ.get("KEEP_WORKFLOWS_DIRECTORY") + if not workflows_dir: + logger.info( + "No workflows directory provided - no provisioning will be done" + ) + return [] + + if not os.path.isdir(workflows_dir): + raise FileNotFoundError(f"Directory {workflows_dir} does not exist") + + # Get all existing provisioned workflows + provisioned_workflows = get_all_provisioned_workflows(tenant_id) + + # Check for workflows that are no longer in the directory or outside the workflows_dir and delete them + for workflow in provisioned_workflows: + if ( + not os.path.exists(workflow.provisioned_file) + or not os.path.commonpath([workflows_dir, workflow.provisioned_file]) + == workflows_dir + ): + logger.info( + f"Deprovisioning workflow {workflow.id} as its file no longer exists or is outside the workflows directory" + ) + delete_workflow_by_provisioned_file( + tenant_id, workflow.provisioned_file + ) + logger.info(f"Workflow {workflow.id} deprovisioned successfully") + + # Provision new workflows + for file in os.listdir(workflows_dir): + if file.endswith((".yaml", ".yml")): + logger.info(f"Provisioning workflow from {file}") + workflow_path = os.path.join(workflows_dir, file) + + try: + with open(workflow_path, "r") as yaml_file: + workflow_yaml = yaml.safe_load(yaml_file) + if "workflow" in workflow_yaml: + workflow_yaml = workflow_yaml["workflow"] + # backward compatibility + elif "alert" in workflow_yaml: + workflow_yaml = workflow_yaml["alert"] + + workflow_name = workflow_yaml.get("name") or workflow_yaml.get("id") + if not workflow_name: + logger.error(f"Workflow from {file} does not have a name or id") + continue + workflow_id = str(uuid.uuid4()) + workflow_description = workflow_yaml.get("description") + workflow_interval = parser.parse_interval(workflow_yaml) + workflow_disabled = parser.parse_disabled(workflow_yaml) + + add_or_update_workflow( + id=workflow_id, + name=workflow_name, + tenant_id=tenant_id, + description=workflow_description, + created_by="system", + interval=workflow_interval, + is_disabled=workflow_disabled, + workflow_raw=yaml.dump(workflow_yaml), + provisioned=True, + provisioned_file=workflow_path, + ) + provisioned_workflows.append(workflow_yaml) + + logger.info(f"Workflow from {file} provisioned successfully") + except Exception as e: + logger.error( + f"Error provisioning workflow from {file}", + extra={"exception": e}, + ) + + return provisioned_workflows + def _read_workflow_from_stream(self, stream) -> dict: """ Parse a workflow from an IO stream. @@ -247,7 +353,9 @@ def _read_workflow_from_stream(self, stream) -> dict: raise e return workflow - def get_random_workflow_templates(self, tenant_id: str, workflows_dir: str, limit: int) -> list[dict]: + def get_random_workflow_templates( + self, tenant_id: str, workflows_dir: str, limit: int + ) -> list[dict]: """ Get random workflows from a directory. Args: @@ -261,7 +369,9 @@ def get_random_workflow_templates(self, tenant_id: str, workflows_dir: str, limi if not os.path.isdir(workflows_dir): raise FileNotFoundError(f"Directory {workflows_dir} does not exist") - workflow_yaml_files = [f for f in os.listdir(workflows_dir) if f.endswith(('.yaml', '.yml'))] + workflow_yaml_files = [ + f for f in os.listdir(workflows_dir) if f.endswith((".yaml", ".yml")) + ] if not workflow_yaml_files: raise FileNotFoundError(f"No workflows found in directory {workflows_dir}") @@ -275,15 +385,17 @@ def get_random_workflow_templates(self, tenant_id: str, workflows_dir: str, limi file_path = os.path.join(workflows_dir, file) workflow_yaml = self._parse_workflow_to_dict(file_path) if "workflow" in workflow_yaml: - workflow_yaml['name'] = workflow_yaml['workflow']['id'] - workflow_yaml['workflow_raw'] = yaml.dump(workflow_yaml) - workflow_yaml['workflow_raw_id'] = workflow_yaml['workflow']['id'] + workflow_yaml["name"] = workflow_yaml["workflow"]["id"] + workflow_yaml["workflow_raw"] = yaml.dump(workflow_yaml) + workflow_yaml["workflow_raw_id"] = workflow_yaml["workflow"]["id"] workflows.append(workflow_yaml) count += 1 self.logger.info(f"Workflow from {file} fetched successfully") except Exception as e: - self.logger.error(f"Error parsing or fetching workflow from {file}: {e}") + self.logger.error( + f"Error parsing or fetching workflow from {file}: {e}" + ) return workflows def group_last_workflow_executions(self, workflows: list[dict]) -> list[dict]: @@ -294,7 +406,7 @@ def group_last_workflow_executions(self, workflows: list[dict]) -> list[dict]: self.logger.info(f"workflow_executions: {workflows}") workflow_dict = {} for item in workflows: - workflow,started,execution_time,status = item + workflow, started, execution_time, status = item workflow_id = workflow.id # Initialize the workflow if not already in the dictionary @@ -304,14 +416,14 @@ def group_last_workflow_executions(self, workflows: list[dict]) -> list[dict]: "workflow_last_run_started": None, "workflow_last_run_time": None, "workflow_last_run_status": None, - "workflow_last_executions": [] + "workflow_last_executions": [], } # Update the latest execution details if available - if workflow_dict[workflow_id]["workflow_last_run_started"] is None : + if workflow_dict[workflow_id]["workflow_last_run_started"] is None: workflow_dict[workflow_id]["workflow_last_run_status"] = status workflow_dict[workflow_id]["workflow_last_run_started"] = started - workflow_dict[workflow_id]["workflow_last_run_time"] = started + workflow_dict[workflow_id]["workflow_last_run_time"] = started # Add the execution to the list of executions if started is not None: @@ -319,7 +431,7 @@ def group_last_workflow_executions(self, workflows: list[dict]) -> list[dict]: { "status": status, "execution_time": execution_time, - "started": started + "started": started, } ) # Convert the dictionary to a list of results @@ -329,14 +441,16 @@ def group_last_workflow_executions(self, workflows: list[dict]) -> list[dict]: "workflow_last_run_status": workflow_info["workflow_last_run_status"], "workflow_last_run_time": workflow_info["workflow_last_run_time"], "workflow_last_run_started": workflow_info["workflow_last_run_started"], - "workflow_last_executions": workflow_info["workflow_last_executions"] + "workflow_last_executions": workflow_info["workflow_last_executions"], } for workflow_id, workflow_info in workflow_dict.items() ] return results - def get_workflow_meta_data(self, tenant_id: str, workflow: dict, installed_providers_by_type: dict): + def get_workflow_meta_data( + self, tenant_id: str, workflow: dict, installed_providers_by_type: dict + ): providers_dto = [] triggers = [] @@ -354,19 +468,28 @@ def get_workflow_meta_data(self, tenant_id: str, workflow: dict, installed_provi # Parse the workflow YAML safely workflow_yaml = yaml.safe_load(workflow_raw_data) if not workflow_yaml: - self.logger.error(f"Parsed workflow_yaml is empty or invalid: {workflow_raw_data}") + self.logger.error( + f"Parsed workflow_yaml is empty or invalid: {workflow_raw_data}" + ) return providers_dto, triggers providers = self.parser.get_providers_from_workflow(workflow_yaml) except Exception as e: # Improved logging to capture more details about the error - self.logger.error(f"Failed to parse workflow in get_workflow_meta_data: {e}, workflow: {workflow}") - return providers_dto, triggers # Return empty providers and triggers in case of error + self.logger.error( + f"Failed to parse workflow in get_workflow_meta_data: {e}, workflow: {workflow}" + ) + return ( + providers_dto, + triggers, + ) # Return empty providers and triggers in case of error # Step 2: Process providers and add them to DTO for provider in providers: try: - provider_data = installed_providers_by_type[provider.get("type")][provider.get("name")] + provider_data = installed_providers_by_type[provider.get("type")][ + provider.get("name") + ] provider_dto = ProviderDTO( name=provider_data.name, type=provider_data.type, @@ -377,9 +500,13 @@ def get_workflow_meta_data(self, tenant_id: str, workflow: dict, installed_provi except KeyError: # Handle case where the provider is not installed try: - conf = ProvidersFactory.get_provider_required_config(provider.get("type")) + conf = ProvidersFactory.get_provider_required_config( + provider.get("type") + ) except ModuleNotFoundError: - self.logger.warning(f"Non-existing provider in workflow: {provider.get('type')}") + self.logger.warning( + f"Non-existing provider in workflow: {provider.get('type')}" + ) conf = None # Handle providers based on whether they require config @@ -387,11 +514,13 @@ def get_workflow_meta_data(self, tenant_id: str, workflow: dict, installed_provi name=provider.get("name"), type=provider.get("type"), id=None, - installed=(conf is None), # Consider it installed if no config is required + installed=( + conf is None + ), # Consider it installed if no config is required ) providers_dto.append(provider_dto) # Step 3: Extract triggers from workflow triggers = self.parser.get_triggers_from_workflow(workflow_yaml) - return providers_dto, triggers \ No newline at end of file + return providers_dto, triggers diff --git a/tests/fixtures/client.py b/tests/fixtures/client.py index bc7df1872..c49b6ece5 100644 --- a/tests/fixtures/client.py +++ b/tests/fixtures/client.py @@ -1,3 +1,4 @@ +import asyncio import hashlib import importlib import sys @@ -29,6 +30,11 @@ def test_app(monkeypatch, request): for module in list(sys.modules): if module.startswith("keep.api.routes"): del sys.modules[module] + + # this is a fucking bug in db patching ffs it ruined my saturday + elif module.startswith("keep.providers.providers_service"): + importlib.reload(sys.modules[module]) + if "keep.api.api" in sys.modules: importlib.reload(sys.modules["keep.api.api"]) @@ -36,6 +42,11 @@ def test_app(monkeypatch, request): from keep.api.api import get_app app = get_app() + + # Manually trigger the startup event + for event_handler in app.router.on_startup: + asyncio.run(event_handler()) + return app diff --git a/tests/provision/workflows_1/provision_example_1.yml b/tests/provision/workflows_1/provision_example_1.yml new file mode 100644 index 000000000..aeeb7f140 --- /dev/null +++ b/tests/provision/workflows_1/provision_example_1.yml @@ -0,0 +1,20 @@ +workflow: + id: aks-example + description: aks-example + triggers: + - type: manual + steps: + # get all pods + - name: get-pods + provider: + type: aks + config: "{{ providers.aks }}" + with: + command_type: get_pods + actions: + - name: echo-pod-status + foreach: "{{ steps.get-pods.results }}" + provider: + type: console + with: + alert_message: "Pod name: {{ foreach.value.metadata.name }} || Namespace: {{ foreach.value.metadata.namespace }} || Status: {{ foreach.value.status.phase }}" diff --git a/tests/provision/workflows_1/provision_example_2.yml b/tests/provision/workflows_1/provision_example_2.yml new file mode 100644 index 000000000..4b5518ef9 --- /dev/null +++ b/tests/provision/workflows_1/provision_example_2.yml @@ -0,0 +1,29 @@ +workflow: + id: Resend-Python-service + description: Python Resend Mail + triggers: + - type: manual + owners: [] + services: [] + steps: + - name: run-script + provider: + config: '{{ providers.default-bash }}' + type: bash + with: + command: python3 test.py + timeout: 5 + actions: + - condition: + - assert: '{{ steps.run-script.results.return_code }} == 0' + name: assert-condition + type: assert + name: trigger-resend + provider: + type: resend + config: "{{ providers.resend-test }}" + with: + _from: "onboarding@resend.dev" + to: "youremail.dev@gmail.com" + subject: "Python test is up!" + html: <p>Python test is up!</p> diff --git a/tests/provision/workflows_1/provision_example_3.yml b/tests/provision/workflows_1/provision_example_3.yml new file mode 100644 index 000000000..3e6e85cf4 --- /dev/null +++ b/tests/provision/workflows_1/provision_example_3.yml @@ -0,0 +1,14 @@ +workflow: + id: autosupress + strategy: parallel + description: demonstrates how to automatically suppress alerts + triggers: + - type: alert + actions: + - name: dismiss-alert + provider: + type: mock + with: + enrich_alert: + - key: dismissed + value: "true" diff --git a/tests/provision/workflows_2/provision_example_1.yml b/tests/provision/workflows_2/provision_example_1.yml new file mode 100644 index 000000000..aeeb7f140 --- /dev/null +++ b/tests/provision/workflows_2/provision_example_1.yml @@ -0,0 +1,20 @@ +workflow: + id: aks-example + description: aks-example + triggers: + - type: manual + steps: + # get all pods + - name: get-pods + provider: + type: aks + config: "{{ providers.aks }}" + with: + command_type: get_pods + actions: + - name: echo-pod-status + foreach: "{{ steps.get-pods.results }}" + provider: + type: console + with: + alert_message: "Pod name: {{ foreach.value.metadata.name }} || Namespace: {{ foreach.value.metadata.namespace }} || Status: {{ foreach.value.status.phase }}" diff --git a/tests/provision/workflows_2/provision_example_2.yml b/tests/provision/workflows_2/provision_example_2.yml new file mode 100644 index 000000000..4b5518ef9 --- /dev/null +++ b/tests/provision/workflows_2/provision_example_2.yml @@ -0,0 +1,29 @@ +workflow: + id: Resend-Python-service + description: Python Resend Mail + triggers: + - type: manual + owners: [] + services: [] + steps: + - name: run-script + provider: + config: '{{ providers.default-bash }}' + type: bash + with: + command: python3 test.py + timeout: 5 + actions: + - condition: + - assert: '{{ steps.run-script.results.return_code }} == 0' + name: assert-condition + type: assert + name: trigger-resend + provider: + type: resend + config: "{{ providers.resend-test }}" + with: + _from: "onboarding@resend.dev" + to: "youremail.dev@gmail.com" + subject: "Python test is up!" + html: <p>Python test is up!</p> diff --git a/tests/test_auth.py b/tests/test_auth.py index 09ce7956a..86ad7a385 100644 --- a/tests/test_auth.py +++ b/tests/test_auth.py @@ -53,7 +53,7 @@ def get_mock_jwt_payload(token, *args, **kwargs): @pytest.mark.parametrize( "test_app", ["SINGLE_TENANT", "MULTI_TENANT", "NO_AUTH"], indirect=True ) -def test_api_key_with_header(client, db_session, test_app): +def test_api_key_with_header(db_session, client, test_app): """Tests the API key authentication with the x-api-key/digest""" auth_type = os.getenv("AUTH_TYPE") valid_api_key = "valid_api_key" @@ -95,7 +95,7 @@ def test_api_key_with_header(client, db_session, test_app): @pytest.mark.parametrize( "test_app", ["SINGLE_TENANT", "MULTI_TENANT", "NO_AUTH"], indirect=True ) -def test_bearer_token(client, db_session, test_app): +def test_bearer_token(db_session, client, test_app): """Tests the bearer token authentication""" auth_type = os.getenv("AUTH_TYPE") # Test bearer tokens @@ -121,7 +121,7 @@ def test_bearer_token(client, db_session, test_app): @pytest.mark.parametrize( "test_app", ["SINGLE_TENANT", "MULTI_TENANT", "NO_AUTH"], indirect=True ) -def test_webhook_api_key(client, db_session, test_app): +def test_webhook_api_key(db_session, client, test_app): """Tests the webhook API key authentication""" auth_type = os.getenv("AUTH_TYPE") valid_api_key = "valid_api_key" @@ -167,7 +167,7 @@ def test_webhook_api_key(client, db_session, test_app): # sanity check with keycloak @pytest.mark.parametrize("test_app", ["KEYCLOAK"], indirect=True) -def test_keycloak_sanity(keycloak_client, keycloak_token, client, test_app): +def test_keycloak_sanity(db_session, keycloak_client, keycloak_token, client, test_app): """Tests the keycloak sanity check""" # Use the token to make a request to the Keep API headers = {"Authorization": f"Bearer {keycloak_token}"} @@ -182,7 +182,7 @@ def test_keycloak_sanity(keycloak_client, keycloak_token, client, test_app): ], indirect=True, ) -def test_api_key_impersonation_without_admin(client, db_session, test_app): +def test_api_key_impersonation_without_admin(db_session, client, test_app): """Tests the API key impersonation with different environment settings""" valid_api_key = "valid_admin_api_key" @@ -207,7 +207,7 @@ def test_api_key_impersonation_without_admin(client, db_session, test_app): ], indirect=True, ) -def test_api_key_impersonation_without_user_provision(client, db_session, test_app): +def test_api_key_impersonation_without_user_provision(db_session, client, test_app): """Tests the API key impersonation with different environment settings""" valid_api_key = "valid_admin_api_key" @@ -239,7 +239,7 @@ def test_api_key_impersonation_without_user_provision(client, db_session, test_a ], indirect=True, ) -def test_api_key_impersonation_with_user_provision(client, db_session, test_app): +def test_api_key_impersonation_with_user_provision(db_session, client, test_app): """Tests the API key impersonation with different environment settings""" valid_api_key = "valid_admin_api_key" @@ -272,7 +272,7 @@ def test_api_key_impersonation_with_user_provision(client, db_session, test_app) indirect=True, ) def test_api_key_impersonation_provisioned_user_cant_login( - client, db_session, test_app + db_session, client, test_app ): """Tests the API key impersonation with different environment settings""" diff --git a/tests/test_enrichments.py b/tests/test_enrichments.py index 74ea6878c..e75c99a24 100644 --- a/tests/test_enrichments.py +++ b/tests/test_enrichments.py @@ -457,7 +457,7 @@ def test_mapping_rule_with_elsatic(mock_session, mock_alert_dto, setup_alerts): @pytest.mark.parametrize("test_app", ["NO_AUTH"], indirect=True) -def test_enrichment(client, db_session, test_app, mock_alert_dto, elastic_client): +def test_enrichment(db_session, client, test_app, mock_alert_dto, elastic_client): # add some rule rule = MappingRule( id=1, @@ -495,7 +495,7 @@ def test_enrichment(client, db_session, test_app, mock_alert_dto, elastic_client @pytest.mark.parametrize("test_app", ["NO_AUTH"], indirect=True) -def test_disposable_enrichment(client, db_session, test_app, mock_alert_dto): +def test_disposable_enrichment(db_session, client, test_app, mock_alert_dto): # SHAHAR: there is a voodoo so that you must do something with the db_session to kick it off rule = MappingRule( id=1, diff --git a/tests/test_extraction_rules.py b/tests/test_extraction_rules.py index 24759b6e2..190220a34 100644 --- a/tests/test_extraction_rules.py +++ b/tests/test_extraction_rules.py @@ -1,18 +1,15 @@ from time import sleep import pytest - from isodate import parse_datetime -from tests.fixtures.client import client, test_app, setup_api_key +from tests.fixtures.client import client, setup_api_key, test_app # noqa VALID_API_KEY = "valid_api_key" -@pytest.mark.parametrize( - "test_app", ["NO_AUTH"], indirect=True -) -def test_create_extraction_rule(client, test_app, db_session): +@pytest.mark.parametrize("test_app", ["NO_AUTH"], indirect=True) +def test_create_extraction_rule(db_session, client, test_app): setup_api_key(db_session, VALID_API_KEY, role="webhook") # Try to create invalid extraction @@ -33,10 +30,8 @@ def test_create_extraction_rule(client, test_app, db_session): assert response.status_code == 200 -@pytest.mark.parametrize( - "test_app", ["NO_AUTH"], indirect=True -) -def test_extraction_rule_updated_at(client, test_app, db_session): +@pytest.mark.parametrize("test_app", ["NO_AUTH"], indirect=True) +def test_extraction_rule_updated_at(db_session, client, test_app): setup_api_key(db_session, VALID_API_KEY, role="webhook") rule_dict = { @@ -66,7 +61,9 @@ def test_extraction_rule_updated_at(client, test_app, db_session): # Without it update can happen in the same second, so we will not see any changes sleep(1) updated_response = client.put( - f"/extraction/{rule_id}", json=updated_rule_dict, headers={"x-api-key": VALID_API_KEY} + f"/extraction/{rule_id}", + json=updated_rule_dict, + headers={"x-api-key": VALID_API_KEY}, ) assert updated_response.status_code == 200 @@ -75,7 +72,3 @@ def test_extraction_rule_updated_at(client, test_app, db_session): new_updated_at = parse_datetime(updated_response_data["updated_at"]) assert new_updated_at > updated_at - - - - diff --git a/tests/test_metrics.py b/tests/test_metrics.py index e9c400022..deb0b1621 100644 --- a/tests/test_metrics.py +++ b/tests/test_metrics.py @@ -1,36 +1,32 @@ - import pytest from keep.api.core.db import ( add_alerts_to_incident_by_incident_id, - create_incident_from_dict + create_incident_from_dict, ) - -from tests.fixtures.client import client, setup_api_key, test_app +from tests.fixtures.client import client, setup_api_key, test_app # noqa -@pytest.mark.parametrize( - "test_app", ["NO_AUTH"], indirect=True -) -def test_add_remove_alert_to_incidents(client, db_session, test_app, setup_stress_alerts_no_elastic): +@pytest.mark.parametrize("test_app", ["NO_AUTH"], indirect=True) +def test_add_remove_alert_to_incidents( + db_session, client, test_app, setup_stress_alerts_no_elastic +): alerts = setup_stress_alerts_no_elastic(14) - incident = create_incident_from_dict("keep", {"user_generated_name": "test", "description": "test"}) + incident = create_incident_from_dict( + "keep", {"user_generated_name": "test", "description": "test"} + ) valid_api_key = "valid_api_key" setup_api_key(db_session, valid_api_key) - add_alerts_to_incident_by_incident_id( - "keep", - incident.id, - [a.id for a in alerts] - ) + add_alerts_to_incident_by_incident_id("keep", incident.id, [a.id for a in alerts]) - response = client.get( - "/metrics", - headers={"X-API-KEY": "valid_api_key"} - ) + response = client.get("/metrics", headers={"X-API-KEY": "valid_api_key"}) # Checking for alert_total metric - assert f"alerts_total{{incident_name=\"test\" incident_id=\"{incident.id}\"}} 14" in response.text.split("\n") + assert ( + f'alerts_total{{incident_name="test" incident_id="{incident.id}"}} 14' + in response.text.split("\n") + ) # Checking for open_incidents_total metric assert "open_incidents_total 1" in response.text.split("\n") diff --git a/tests/test_provisioning.py b/tests/test_provisioning.py new file mode 100644 index 000000000..4b49ca854 --- /dev/null +++ b/tests/test_provisioning.py @@ -0,0 +1,223 @@ +import asyncio +import importlib +import sys + +import pytest +from fastapi.testclient import TestClient + +from tests.fixtures.client import client, setup_api_key, test_app # noqa + +# Mock data for workflows +MOCK_WORKFLOW_ID = "123e4567-e89b-12d3-a456-426614174000" +MOCK_PROVISIONED_WORKFLOW = { + "id": MOCK_WORKFLOW_ID, + "name": "Test Workflow", + "description": "A provisioned test workflow", + "provisioned": True, +} + + +# Test for deleting a provisioned workflow +@pytest.mark.parametrize( + "test_app", + [ + { + "AUTH_TYPE": "NOAUTH", + "KEEP_WORKFLOWS_DIRECTORY": "./tests/provision/workflows_1", + }, + ], + indirect=True, +) +def test_provisioned_workflows(db_session, client, test_app): + response = client.get("/workflows", headers={"x-api-key": "someapikey"}) + assert response.status_code == 200 + # 3 workflows and 3 provisioned workflows + workflows = response.json() + provisioned_workflows = [w for w in workflows if w.get("provisioned")] + assert len(provisioned_workflows) == 3 + + +# Test for deleting a provisioned workflow +@pytest.mark.parametrize( + "test_app", + [ + { + "AUTH_TYPE": "NOAUTH", + "KEEP_WORKFLOWS_DIRECTORY": "./tests/provision/workflows_2", + }, + ], + indirect=True, +) +def test_provisioned_workflows_2(db_session, client, test_app): + response = client.get("/workflows", headers={"x-api-key": "someapikey"}) + assert response.status_code == 200 + # 3 workflows and 3 provisioned workflows + workflows = response.json() + provisioned_workflows = [w for w in workflows if w.get("provisioned")] + assert len(provisioned_workflows) == 2 + + +# Test for deleting a provisioned workflow +@pytest.mark.parametrize( + "test_app", + [ + { + "AUTH_TYPE": "NOAUTH", + "KEEP_WORKFLOWS_DIRECTORY": "./tests/provision/workflows_1", + }, + ], + indirect=True, +) +def test_delete_provisioned_workflow(db_session, client, test_app): + response = client.get("/workflows", headers={"x-api-key": "someapikey"}) + assert response.status_code == 200 + # 3 workflows and 3 provisioned workflows + workflows = response.json() + provisioned_workflows = [w for w in workflows if w.get("provisioned")] + workflow_id = provisioned_workflows[0].get("id") + response = client.delete( + f"/workflows/{workflow_id}", headers={"x-api-key": "someapikey"} + ) + # can't delete a provisioned workflow + assert response.status_code == 403 + assert response.json() == {"detail": "Cannot delete a provisioned workflow"} + + +@pytest.mark.parametrize( + "test_app", + [ + { + "AUTH_TYPE": "NOAUTH", + "KEEP_WORKFLOWS_DIRECTORY": "./tests/provision/workflows_1", + }, + ], + indirect=True, +) +def test_update_provisioned_workflow(db_session, client, test_app): + response = client.get("/workflows", headers={"x-api-key": "someapikey"}) + assert response.status_code == 200 + # 3 workflows and 3 provisioned workflows + workflows = response.json() + provisioned_workflows = [w for w in workflows if w.get("provisioned")] + workflow_id = provisioned_workflows[0].get("id") + response = client.put( + f"/workflows/{workflow_id}", headers={"x-api-key": "someapikey"} + ) + # can't delete a provisioned workflow + assert response.status_code == 403 + assert response.json() == {"detail": "Cannot update a provisioned workflow"} + + +@pytest.mark.parametrize( + "test_app", + [ + { + "AUTH_TYPE": "NOAUTH", + "KEEP_WORKFLOWS_DIRECTORY": "./tests/provision/workflows_1", + }, + ], + indirect=True, +) +def test_reprovision_workflow(monkeypatch, db_session, client, test_app): + response = client.get("/workflows", headers={"x-api-key": "someapikey"}) + assert response.status_code == 200 + # 3 workflows and 3 provisioned workflows + workflows = response.json() + provisioned_workflows = [w for w in workflows if w.get("provisioned")] + assert len(provisioned_workflows) == 3 + + # Step 2: Change environment variables (simulating new provisioning) + monkeypatch.setenv("KEEP_WORKFLOWS_DIRECTORY", "./tests/provision/workflows_2") + + # Reload the app to apply the new environment changes + importlib.reload(sys.modules["keep.api.api"]) + + # Reinitialize the TestClient with the new app instance + from keep.api.api import get_app + + app = get_app() + + # Manually trigger the startup event + for event_handler in app.router.on_startup: + asyncio.run(event_handler()) + + client = TestClient(get_app()) + + response = client.get("/workflows", headers={"x-api-key": "someapikey"}) + assert response.status_code == 200 + # 2 workflows and 2 provisioned workflows + workflows = response.json() + provisioned_workflows = [w for w in workflows if w.get("provisioned")] + assert len(provisioned_workflows) == 2 + + +@pytest.mark.parametrize( + "test_app", + [ + { + "AUTH_TYPE": "NOAUTH", + "KEEP_PROVIDERS": '{"keepVictoriaMetrics":{"type":"victoriametrics","authentication":{"VMAlertHost":"http://localhost","VMAlertPort": 1234}},"keepClickhouse1":{"type":"clickhouse","authentication":{"host":"http://localhost","port":1234,"username":"keep","password":"keep","database":"keep-db"}}}', + }, + ], + indirect=True, +) +def test_provision_provider(db_session, client, test_app): + response = client.get("/providers", headers={"x-api-key": "someapikey"}) + assert response.status_code == 200 + # 3 workflows and 3 provisioned workflows + providers = response.json() + provisioned_providers = [ + p for p in providers.get("installed_providers") if p.get("provisioned") + ] + assert len(provisioned_providers) == 2 + + +@pytest.mark.parametrize( + "test_app", + [ + { + "AUTH_TYPE": "NOAUTH", + "KEEP_PROVIDERS": '{"keepVictoriaMetrics":{"type":"victoriametrics","authentication":{"VMAlertHost":"http://localhost","VMAlertPort": 1234}},"keepClickhouse1":{"type":"clickhouse","authentication":{"host":"http://localhost","port":1234,"username":"keep","password":"keep","database":"keep-db"}}}', + }, + ], + indirect=True, +) +def test_reprovision_provider(monkeypatch, db_session, client, test_app): + response = client.get("/providers", headers={"x-api-key": "someapikey"}) + assert response.status_code == 200 + # 3 workflows and 3 provisioned workflows + providers = response.json() + provisioned_providers = [ + p for p in providers.get("installed_providers") if p.get("provisioned") + ] + assert len(provisioned_providers) == 2 + + # Step 2: Change environment variables (simulating new provisioning) + monkeypatch.setenv( + "KEEP_PROVIDERS", + '{"keepPrometheus":{"type":"prometheus","authentication":{"url":"http://localhost","port":9090}}}', + ) + + # Reload the app to apply the new environment changes + importlib.reload(sys.modules["keep.api.api"]) + + # Reinitialize the TestClient with the new app instance + from keep.api.api import get_app + + app = get_app() + + # Manually trigger the startup event + for event_handler in app.router.on_startup: + asyncio.run(event_handler()) + + client = TestClient(app) + + # Step 3: Verify if the new provider is provisioned after reloading + response = client.get("/providers", headers={"x-api-key": "someapikey"}) + assert response.status_code == 200 + providers = response.json() + provisioned_providers = [ + p for p in providers.get("installed_providers") if p.get("provisioned") + ] + assert len(provisioned_providers) == 1 + assert provisioned_providers[0]["type"] == "prometheus" diff --git a/tests/test_rules_api.py b/tests/test_rules_api.py index f29e2992d..ff039950d 100644 --- a/tests/test_rules_api.py +++ b/tests/test_rules_api.py @@ -1,8 +1,7 @@ import pytest -from keep.api.core.dependencies import SINGLE_TENANT_UUID from keep.api.core.db import create_rule as create_rule_db - +from keep.api.core.dependencies import SINGLE_TENANT_UUID from tests.fixtures.client import client, setup_api_key, test_app # noqa TEST_RULE_DATA = { @@ -18,28 +17,36 @@ "created_by": "test@keephq.dev", } -INVALID_DATA_STEPS = [{ - "update": {"sqlQuery": {"sql": "", "params": []}}, - "error": "SQL is required", -}, { - "update": {"sqlQuery": {"sql": "SELECT", "params": []}}, - "error": "Params are required", -}, { - "update": {"celQuery": ""}, - "error": "CEL is required", -}, { - "update": {"ruleName": ""}, - "error": "Rule name is required", -}, { - "update": {"timeframeInSeconds": 0}, - "error": "Timeframe is required", -}, { - "update": {"timeUnit": ""}, - "error": "Timeunit is required", -}] +INVALID_DATA_STEPS = [ + { + "update": {"sqlQuery": {"sql": "", "params": []}}, + "error": "SQL is required", + }, + { + "update": {"sqlQuery": {"sql": "SELECT", "params": []}}, + "error": "Params are required", + }, + { + "update": {"celQuery": ""}, + "error": "CEL is required", + }, + { + "update": {"ruleName": ""}, + "error": "Rule name is required", + }, + { + "update": {"timeframeInSeconds": 0}, + "error": "Timeframe is required", + }, + { + "update": {"timeUnit": ""}, + "error": "Timeunit is required", + }, +] + @pytest.mark.parametrize("test_app", ["NO_AUTH"], indirect=True) -def test_get_rules_api(client, db_session, test_app): +def test_get_rules_api(db_session, client, test_app): rule = create_rule_db(**TEST_RULE_DATA) response = client.get( @@ -50,7 +57,7 @@ def test_get_rules_api(client, db_session, test_app): assert response.status_code == 200 data = response.json() assert len(data) == 1 - assert data[0]['id'] == str(rule.id) + assert data[0]["id"] == str(rule.id) rule2 = create_rule_db(**TEST_RULE_DATA) @@ -62,27 +69,26 @@ def test_get_rules_api(client, db_session, test_app): assert response2.status_code == 200 data = response2.json() assert len(data) == 2 - assert data[0]['id'] == str(rule.id) - assert data[1]['id'] == str(rule2.id) + assert data[0]["id"] == str(rule.id) + assert data[1]["id"] == str(rule2.id) @pytest.mark.parametrize("test_app", ["NO_AUTH"], indirect=True) -def test_create_rule_api(client, db_session, test_app): +def test_create_rule_api(db_session, client, test_app): rule_data = { "ruleName": "test rule", - "sqlQuery": {"sql": "SELECT * FROM alert where severity = %s", "params": ["critical"]}, + "sqlQuery": { + "sql": "SELECT * FROM alert where severity = %s", + "params": ["critical"], + }, "celQuery": "severity = 'critical'", "timeframeInSeconds": 300, "timeUnit": "seconds", "requireApprove": False, } - response = client.post( - "/rules", - headers={"x-api-key": "some-key"}, - json=rule_data - ) + response = client.post("/rules", headers={"x-api-key": "some-key"}, json=rule_data) assert response.status_code == 200 data = response.json() @@ -92,9 +98,7 @@ def test_create_rule_api(client, db_session, test_app): invalid_rule_data = {k: v for k, v in rule_data.items() if k != "ruleName"} invalid_data_response = client.post( - "/rules", - headers={"x-api-key": "some-key"}, - json=invalid_rule_data + "/rules", headers={"x-api-key": "some-key"}, json=invalid_rule_data ) assert invalid_data_response.status_code == 422 @@ -109,7 +113,7 @@ def test_create_rule_api(client, db_session, test_app): invalid_data_response_2 = client.post( "/rules", headers={"x-api-key": "some-key"}, - json=dict(rule_data, **invalid_data_step["update"]) + json=dict(rule_data, **invalid_data_step["update"]), ) assert invalid_data_response_2.status_code == 400, current_step @@ -119,7 +123,7 @@ def test_create_rule_api(client, db_session, test_app): @pytest.mark.parametrize("test_app", ["NO_AUTH"], indirect=True) -def test_delete_rule_api(client, db_session, test_app): +def test_delete_rule_api(db_session, client, test_app): rule = create_rule_db(**TEST_RULE_DATA) response = client.delete( @@ -143,15 +147,17 @@ def test_delete_rule_api(client, db_session, test_app): assert data["detail"] == "Rule not found" - @pytest.mark.parametrize("test_app", ["NO_AUTH"], indirect=True) -def test_update_rule_api(client, db_session, test_app): +def test_update_rule_api(db_session, client, test_app): rule = create_rule_db(**TEST_RULE_DATA) rule_data = { "ruleName": "test rule", - "sqlQuery": {"sql": "SELECT * FROM alert where severity = %s", "params": ["critical"]}, + "sqlQuery": { + "sql": "SELECT * FROM alert where severity = %s", + "params": ["critical"], + }, "celQuery": "severity = 'critical'", "timeframeInSeconds": 300, "timeUnit": "seconds", @@ -159,9 +165,7 @@ def test_update_rule_api(client, db_session, test_app): } response = client.put( - "/rules/{}".format(rule.id), - headers={"x-api-key": "some-key"}, - json=rule_data + "/rules/{}".format(rule.id), headers={"x-api-key": "some-key"}, json=rule_data ) assert response.status_code == 200 @@ -174,7 +178,7 @@ def test_update_rule_api(client, db_session, test_app): invalid_data_response_2 = client.put( "/rules/{}".format(rule.id), headers={"x-api-key": "some-key"}, - json=dict(rule_data, **invalid_data_step["update"]) + json=dict(rule_data, **invalid_data_step["update"]), ) assert invalid_data_response_2.status_code == 400, current_step diff --git a/tests/test_search_alerts.py b/tests/test_search_alerts.py index bb45436ff..a7cc3b867 100644 --- a/tests/test_search_alerts.py +++ b/tests/test_search_alerts.py @@ -736,7 +736,7 @@ def test_special_characters_in_strings(db_session, setup_alerts): # tests 10k alerts @pytest.mark.parametrize( - "setup_stress_alerts", [{"num_alerts": 10000}], indirect=True + "setup_stress_alerts", [{"num_alerts": 1000}], indirect=True ) # Generate 10,000 alerts def test_filter_large_dataset(db_session, setup_stress_alerts): search_query = SearchQuery( @@ -745,7 +745,7 @@ def test_filter_large_dataset(db_session, setup_stress_alerts): "params": {"source_1": "source_1", "severity_1": "critical"}, }, cel_query='(source == "source_1") && (severity == "critical")', - limit=10000, + limit=1000, ) # first, use elastic os.environ["ELASTIC_ENABLED"] = "true" @@ -764,10 +764,10 @@ def test_filter_large_dataset(db_session, setup_stress_alerts): # compare assert len(elastic_filtered_alerts) == len(db_filtered_alerts) print( - "time taken for 10k alerts with elastic: ", + "time taken for 1k alerts with elastic: ", elastic_end_time - elastic_start_time, ) - print("time taken for 10k alerts with db: ", db_end_time - db_start_time) + print("time taken for 1k alerts with db: ", db_end_time - db_start_time) @pytest.mark.parametrize("setup_stress_alerts", [{"num_alerts": 10000}], indirect=True) @@ -1312,7 +1312,7 @@ def test_severity_comparisons( @pytest.mark.parametrize("test_app", ["NO_AUTH"], indirect=True) -def test_alerts_enrichment_in_search(client, db_session, test_app, elastic_client): +def test_alerts_enrichment_in_search(db_session, client, test_app, elastic_client): rule = MappingRule( id=1, diff --git a/tests/test_search_alerts_configuration.py b/tests/test_search_alerts_configuration.py index 7d3bb700e..bf4b1ab9a 100644 --- a/tests/test_search_alerts_configuration.py +++ b/tests/test_search_alerts_configuration.py @@ -12,7 +12,7 @@ @pytest.mark.parametrize("test_app", ["SINGLE_TENANT"], indirect=True) def test_single_tenant_configuration_with_elastic( - client, elastic_client, db_session, test_app + db_session, client, elastic_client, test_app ): valid_api_key = "valid_api_key" setup_api_key(db_session, valid_api_key) @@ -30,7 +30,7 @@ def test_single_tenant_configuration_with_elastic( ], indirect=True, ) -def test_single_tenant_configuration_without_elastic(client, db_session, test_app): +def test_single_tenant_configuration_without_elastic(db_session, client, test_app): valid_api_key = "valid_api_key" setup_api_key(db_session, valid_api_key) response = client.get("/preset/feed/alerts", headers={"x-api-key": valid_api_key}) @@ -39,7 +39,7 @@ def test_single_tenant_configuration_without_elastic(client, db_session, test_ap @pytest.mark.parametrize("test_app", ["MULTI_TENANT"], indirect=True) def test_multi_tenant_configuration_with_elastic( - client, elastic_client, db_session, test_app + db_session, client, elastic_client, test_app ): valid_api_key = "valid_api_key" valid_api_key_2 = "valid_api_key_2"