diff --git a/keep/api/core/demo_mode.py b/keep/api/core/demo_mode.py index 8f53a6992..a94571e63 100644 --- a/keep/api/core/demo_mode.py +++ b/keep/api/core/demo_mode.py @@ -7,9 +7,11 @@ from datetime import timezone from uuid import uuid4 +import aiohttp import requests from dateutil import parser from requests.models import PreparedRequest +from sqlalchemy.util import asyncio from keep.api.core.db import get_session_sync from keep.api.core.dependencies import SINGLE_TENANT_UUID @@ -24,6 +26,9 @@ logger = logging.getLogger(__name__) KEEP_LIVE_DEMO_MODE = os.environ.get("KEEP_LIVE_DEMO_MODE", "false").lower() == "true" +GENERATE_DEDUPLICATIONS = False + +REQUESTS_QUEUE = asyncio.Queue() correlation_rules_to_create = [ { @@ -412,7 +417,12 @@ def perform_demo_ai(keep_api_key, keep_api_url): response.raise_for_status() -def simulate_alerts( +def simulate_alerts(*args, **kwargs): + asyncio.create_task(simulate_alerts_worker(0, keep_api_key, 0)) + asyncio.run(simulate_alerts_async(*args, **kwargs)) + + +async def simulate_alerts_async( keep_api_url=None, keep_api_key=None, sleep_interval=5, @@ -420,11 +430,10 @@ def simulate_alerts( demo_topology=False, clean_old_incidents=False, demo_ai=False, + target_rps=0 ): logger.info("Simulating alerts...") - GENERATE_DEDUPLICATIONS = False - providers_config = [ {"type": "prometheus", "weight": 3}, {"type": "grafana", "weight": 1}, @@ -474,6 +483,7 @@ def simulate_alerts( get_or_create_topology(keep_api_key, keep_api_url) logger.info("Topology created.") + shoot = 1 while True: try: logger.info("Looping to send alerts...") @@ -486,66 +496,58 @@ def simulate_alerts( if demo_ai: perform_demo_ai(keep_api_key, keep_api_url) - send_alert_url_params = {} + # If we want to make stress-testing, we want to prepare more data for faster requesting in workers + if target_rps: + shoot = target_rps * 100 - # choose provider based on weights - provider_type = random.choices(providers, weights=normalized_weights, k=1)[ - 0 - ] - send_alert_url = "{}/alerts/event/{}".format(keep_api_url, provider_type) + for _ in range(shoot): - if provider_type in existing_providers_to_their_ids: - send_alert_url_params["provider_id"] = existing_providers_to_their_ids[ - provider_type - ] - logger.info( - f"Provider type: {provider_type}, send_alert_url_params now are: {send_alert_url_params}" - ) + send_alert_url_params = {} + + # choose provider based on weights + provider_type = random.choices(providers, weights=normalized_weights, k=1)[0] + send_alert_url = "{}/alerts/event/{}".format(keep_api_url, provider_type) - provider = provider_classes[provider_type] - alert = provider.simulate_alert() + if provider_type in existing_providers_to_their_ids: + send_alert_url_params["provider_id"] = existing_providers_to_their_ids[ + provider_type + ] + logger.info( + f"Provider type: {provider_type}, send_alert_url_params now are: {send_alert_url_params}" + ) - if provider_type in providers_to_randomize_fingerprint_for: - send_alert_url_params["fingerprint"] = str(uuid4()) + provider = provider_classes[provider_type] + alert = provider.simulate_alert() - # Determine number of times to send the same alert - num_iterations = 1 - if GENERATE_DEDUPLICATIONS: - num_iterations = random.randint(1, 3) + if provider_type in providers_to_randomize_fingerprint_for: + send_alert_url_params["fingerprint"] = str(uuid4()) - for _ in range(num_iterations): - logger.info("Sending alert: {}".format(alert)) - try: - env = random.choice(["production", "staging", "development"]) + # Determine number of times to send the same alert + num_iterations = 1 + if GENERATE_DEDUPLICATIONS: + num_iterations = random.randint(1, 3) - if "provider_id" not in send_alert_url_params: - send_alert_url_params["provider_id"] = f"{provider_type}-{env}" - else: - alert["environment"] = random.choice( - ["prod-01", "prod-02", "prod-03"] - ) + env = random.choice(["production", "staging", "development"]) + + if "provider_id" not in send_alert_url_params: + send_alert_url_params["provider_id"] = f"{provider_type}-{env}" + else: + alert["environment"] = random.choice( + ["prod-01", "prod-02", "prod-03"] + ) + + for _ in range(num_iterations): prepared_request = PreparedRequest() prepared_request.prepare_url(send_alert_url, send_alert_url_params) - logger.info( - f"Sending alert to {prepared_request.url} with url params {send_alert_url_params}" - ) + await REQUESTS_QUEUE.put((prepared_request.url, alert)) + if not target_rps: + await asyncio.sleep(sleep_interval) - response = requests.post( - prepared_request.url, - headers={"x-api-key": keep_api_key}, - json=alert, - ) - response.raise_for_status() - except requests.exceptions.RequestException as e: - logger.error("Failed to send alert: {}".format(e)) - time.sleep(sleep_interval) - continue + # Wait until almost prepopulated data was consumed + while not REQUESTS_QUEUE.empty(): + await asyncio.sleep(sleep_interval) - if not response.ok: - logger.error("Failed to send alert: {}".format(response.text)) - else: - logger.info("Alert sent successfully") except Exception as e: logger.exception( "Error in simulate_alerts", extra={"exception_str": str(e)} @@ -554,7 +556,7 @@ def simulate_alerts( logger.info( "Sleeping for {} seconds before next iteration".format(sleep_interval) ) - time.sleep(sleep_interval) + await asyncio.sleep(sleep_interval) def launch_demo_mode_thread( @@ -597,6 +599,32 @@ def launch_demo_mode_thread( return thread +async def simulate_alerts_worker(worker_id, keep_api_key, rps=1): + + headers = {"x-api-key": keep_api_key, "Content-type": "application/json"} + + async with aiohttp.ClientSession() as session: + total_start = time.time() + total_requests = 0 + while True: + start = time.time() + url, alert = await REQUESTS_QUEUE.get() + + async with session.post(url, json=alert, headers=headers) as response: + total_requests += 1 + if not response.ok: + logger.error("Failed to send alert: {}".format(response.text)) + else: + logger.info("Alert sent successfully") + + if rps: + delay = 1/rps - (time.time() - start) + if delay > 0: + logger.debug('worker %d sleeps, %f', worker_id, delay) + await asyncio.sleep(delay) + logger.info('Worker %d RPS: %.2f', worker_id, total_requests / (time.time() - total_start)) + + if __name__ == "__main__": keep_api_url = os.environ.get("KEEP_API_URL") or "http://localhost:8080" keep_api_key = os.environ.get("KEEP_READ_ONLY_BYPASS_KEY") diff --git a/scripts/simulate_alerts.py b/scripts/simulate_alerts.py index 885b0e727..29a38ca4b 100644 --- a/scripts/simulate_alerts.py +++ b/scripts/simulate_alerts.py @@ -2,7 +2,9 @@ import logging import argparse -from keep.api.core.demo_mode import simulate_alerts +import asyncio + +from keep.api.core.demo_mode import simulate_alerts, simulate_alerts_worker, simulate_alerts_async logging.basicConfig( level=logging.DEBUG, @@ -13,25 +15,34 @@ logger = logging.getLogger(__name__) -def main(): +async def main(): parser = argparse.ArgumentParser(description="Simulate alerts for Keep API.") parser.add_argument( "--full-demo", action="store_true", help="Run the full demo including correlation rules and topology.", ) + parser.add_argument("--rps", type=int, help="Base requests per second") + parser.add_argument("--workers", "-w", type=int, default=1, help="Amount of background workers to send alerts") + args = parser.parse_args() + rps = args.rps default_sleep_interval = 0.2 if args.full_demo: default_sleep_interval = 5 + rps = 0 SLEEP_INTERVAL = float( os.environ.get("SLEEP_INTERVAL", default_sleep_interval) ) keep_api_key = os.environ.get("KEEP_API_KEY") keep_api_url = os.environ.get("KEEP_API_URL") or "http://localhost:8080" - simulate_alerts( + + for i in range(args.workers): + asyncio.create_task(simulate_alerts_worker(i, keep_api_key, rps)) + + await simulate_alerts_async( keep_api_key=keep_api_key, keep_api_url=keep_api_url, sleep_interval=SLEEP_INTERVAL, @@ -39,7 +50,14 @@ def main(): demo_topology=args.full_demo, clean_old_incidents=args.full_demo, demo_ai=args.full_demo, + target_rps=rps ) + if __name__ == "__main__": - main() + try: + asyncio.run(main()) + except KeyboardInterrupt: + pass + finally: + print("Closing Loop")