Skip to content

Commit

Permalink
feat: add start_consume
Browse files Browse the repository at this point in the history
Signed-off-by: 35C4n0r <jaykumar20march@gmail.com>
  • Loading branch information
35C4n0r committed Jan 12, 2025
1 parent 6f8f504 commit 5feab3f
Showing 1 changed file with 136 additions and 58 deletions.
194 changes: 136 additions & 58 deletions keep/providers/amazonsqs_provider/amazonsqs_provider.py
Original file line number Diff line number Diff line change
Expand Up @@ -7,6 +7,9 @@
from datetime import datetime
import boto3
import botocore
import logging
import inspect


import pydantic

Expand Down Expand Up @@ -54,10 +57,45 @@ class AmazonsqsProviderAuthConfig:
)


class ClientIdInjector(logging.Filter):
def filter(self, record):
# For this example, let's pretend we can obtain the client_id
# by inspecting the caller or some context. Replace the next line
# with the actual logic to get the client_id.
client_id, provider_id = self.get_client_id_from_caller()
if not hasattr(record, "extra"):
record.extra = {
"client_id": client_id,
"provider_id": provider_id,
}
return True

def get_client_id_from_caller(self):
# Here, you should implement the logic to extract client_id based on the caller.
# This can be tricky and might require you to traverse the call stack.
# Return a default or None if you can't find it.
import copy

frame = inspect.currentframe()
client_id = None
while frame:
local_vars = copy.copy(frame.f_locals)
for var_name, var_value in local_vars.items():
if isinstance(var_value, AmazonsqsProvider):
client_id = var_value.context_manager.tenant_id
provider_id = var_value.provider_id
break
if client_id:
return client_id, provider_id
frame = frame.f_back
return None, None


class AmazonsqsProvider(BaseProvider):
"""Sends and receive alerts from AmazonSQS."""

PROVIDER_CATEGORY = ["Monitoring"]
PROVIDER_CATEGORY = ["Monitoring", "Queues"]
PROVIDER_TAGS = ["queue"]

alert_severity_dict = {
"critical": AlertSeverity.CRITICAL,
Expand Down Expand Up @@ -93,7 +131,16 @@ def __init__(
self, context_manager: ContextManager, provider_id: str, config: ProviderConfig
):
super().__init__(context_manager, provider_id, config)
self._client = None
self.consume = False
self.consumer = None
self.err = ""
# patch all AmazonSQS loggers to contain the tenant_id
for logger_name in logging.Logger.manager.loggerDict:
if logger_name.startswith("amazonsqs"):
logger = logging.getLogger(logger_name)
if not any(isinstance(f, ClientIdInjector) for f in logger.filters):
self.logger.info(f"Patching amazonsqs logger {logger_name}")
logger.addFilter(ClientIdInjector())

def dispose(self):
"""
Expand All @@ -112,14 +159,14 @@ def validate_config(self):

@property
def __get_sqs_client(self):
if self._client is None:
self._client = boto3.client(
if self.consumer is None:
self.consumer = boto3.client(
"sqs",
region_name=self.authentication_config.region_name,
aws_access_key_id=self.authentication_config.access_key_id,
aws_secret_access_key=self.authentication_config.secret_access_key,
)
return self._client
return self.consumer

def validate_scopes(self) -> dict[str, bool | str]:
self.logger.info("Validating user scopes for AmazonSQS provider")
Expand All @@ -136,28 +183,42 @@ def validate_scopes(self) -> dict[str, bool | str]:
)
try:
sts.get_caller_identity()
self.logger.info("User identity fetched successfully, user is authenticated.")
self.logger.info(
"User identity fetched successfully, user is authenticated."
)
scopes["authenticated"] = True
except botocore.exceptions.ClientError as e:
self.logger.error("Error while getting user identity, authentication failed", extra={"exception": str(e)})
self.logger.error(
"Error while getting user identity, authentication failed",
extra={"exception": str(e)},
)
scopes["authenticated"] = str(e)
return scopes

try:
self.__write_to_queue(message="KEEP_SCOPE_TEST_MSG_PLEASE_IGNORE", dedup_id=str(uuid.uuid4()), group_id="keep")
self.__write_to_queue(
message="KEEP_SCOPE_TEST_MSG_PLEASE_IGNORE",
dedup_id=str(uuid.uuid4()),
group_id="keep",
)
self.logger.info("All scopes verified successfully")
scopes["sqs::write"] = True
scopes["sqs::read"] = True
except botocore.exceptions.ClientError as e:
self.logger.error("User does not have permission to write to SQS queue", extra={"exception": str(e)})
self.logger.error(
"User does not have permission to write to SQS queue",
extra={"exception": str(e)},
)
scopes["sqs::write"] = str(e)
try:
self.__read_from_queue()
self.logger.info("User has permission to read from SQS Queue")
scopes["sqs::read"] = True
except botocore.exceptions.ClientError as e:
self.logger.error("User does not have permission to read from SQS queue",
extra={"exception": str(e)})
self.logger.error(
"User does not have permission to read from SQS queue",
extra={"exception": str(e)},
)
scopes["sqs::read"] = str(e)
return scopes

Expand All @@ -172,7 +233,9 @@ def __read_from_queue(self):
WaitTimeSeconds=10,
)
except Exception as e:
self.logger.error("Error while reading from SQS Queue", extra={"exception": str(e)})
self.logger.error(
"Error while reading from SQS Queue", extra={"exception": str(e)}
)

def __write_to_queue(self, message, group_id, dedup_id, **kwargs):
try:
Expand All @@ -183,11 +246,14 @@ def __write_to_queue(self, message, group_id, dedup_id, **kwargs):
is_fifo = self.authentication_config.sqs_queue_url.endswith(".fifo")
self.logger.info("Building MessageAttributes")
msg_attrs = {
key: {"StringValue": kwargs[key], "DataType": "String"} for key in kwargs
key: {"StringValue": kwargs[key], "DataType": "String"}
for key in kwargs
}
if is_fifo:
if not dedup_id or group_id:
self.logger.error()
if not dedup_id or not group_id:
self.logger.error(
"Mandatory to provide dedup_id (Message deduplication ID) & group_id (Message group ID) when pushing to fifo queue"
)
raise Exception(
"Mandatory to provide dedup_id (Message deduplication ID) & group_id (Message group ID) when pushing to fifo queue"
)
Expand All @@ -206,11 +272,14 @@ def __write_to_queue(self, message, group_id, dedup_id, **kwargs):
)

self.logger.info(
"Successfully pushed the message to SQS", extra={"response": str(response)}
"Successfully pushed the message to SQS",
extra={"response": str(response)},
)
return response
except Exception as e:
self.logger.error("Error while writing to SQS queue", extra={"exception": str(e)})
self.logger.error(
"Error while writing to SQS queue", extra={"exception": str(e)}
)
raise e

def __delete_from_queue(self, receipt: str):
Expand All @@ -221,7 +290,10 @@ def __delete_from_queue(self, receipt: str):
)
self.logger.info("Successfully deleted message from SQS Queue")
except Exception as e:
self.logger.error("Error while deleting message from SQS queue", extra={"exception": str(e)})
self.logger.error(
"Error while deleting message from SQS queue",
extra={"exception": str(e)},
)
raise e

@staticmethod
Expand All @@ -233,47 +305,53 @@ def get_status_or_default(status_value):
# If not, return the default AlertStatus.FIRING
return AlertStatus.FIRING

def _get_alerts(self) -> list[AlertDto]:
self.logger.info("Getting Alerts from Amazon SQS provider")
alerts = []
response = self.__read_from_queue()
messages = response.get("Messages", [])
if not messages:
self.logger.info("No messages found. Queue is empty!")

for message in messages:
labels = {}
attrs = message.get("MessageAttributes", {})
for msg_attr in attrs:
labels[msg_attr.lower()] = attrs[msg_attr].get(
"StringValue", attrs[msg_attr].get("BinaryValue", "")
)
alerts.append(
AlertDto(
id=message["MessageId"],
name=labels.get("name", message["Body"]),
description=labels.get("description", message["Body"]),
message=message["Body"],
status=AmazonsqsProvider.get_status_or_default(
labels.get("status", "firing")
),
severity=self.alert_severity_dict.get(
labels.get("severity", "high"), AlertSeverity.HIGH
),
lastReceived=datetime.fromtimestamp(
float(message["Attributes"]["SentTimestamp"]) / 1000
).isoformat(),
firingStartTime=datetime.fromtimestamp(
float(message["Attributes"]["SentTimestamp"]) / 1000
).isoformat(),
labels=labels,
source=["amazonsqs"]
)
)
self.__delete_from_queue(receipt=message["ReceiptHandle"])
return alerts

def _notify(self, message, group_id, dedup_id, **kwargs):
return self.__write_to_queue(
message=message, group_id=group_id, dedup_id=dedup_id, **kwargs
)

def start_consume(self):
self.consume = True
while self.consume:
response = self.__read_from_queue()
messages = response.get("Messages", [])
if not messages:
self.logger.info("No messages found. Queue is empty!")

for message in messages:
try:
labels = {}
attrs = message.get("MessageAttributes", {})
for msg_attr in attrs:
labels[msg_attr.lower()] = attrs[msg_attr].get(
"StringValue", attrs[msg_attr].get("BinaryValue", "")
)

alert_dict = {
"id": message["MessageId"],
"name": labels.get("name", message["Body"]),
"description": labels.get("description", message["Body"]),
"message": message["Body"],
"status": AmazonsqsProvider.get_status_or_default(
labels.get("status", "firing")
),
"severity": self.alert_severity_dict.get(
labels.get("severity", "high"), AlertSeverity.HIGH
),
"lastReceived": datetime.fromtimestamp(
float(message["Attributes"]["SentTimestamp"]) / 1000
).isoformat(),
"firingStartTime": datetime.fromtimestamp(
float(message["Attributes"]["SentTimestamp"]) / 1000
).isoformat(),
"labels": labels,
"source": ["amazonsqs"],
}
self._push_alert(alert_dict)
self.__delete_from_queue(receipt=message["ReceiptHandle"])
except Exception as e:
self.logger.error(f"Error processing message: {e}")
self.logger.info("Consuming stopped")

def stop_consume(self):
self.consume = False

0 comments on commit 5feab3f

Please sign in to comment.