Skip to content

Commit

Permalink
fix: Add alerts_count, affected_services and sources fields to the In…
Browse files Browse the repository at this point in the history
…cident (#1473)

Co-authored-by: Tal <[email protected]>
  • Loading branch information
VladimirFilonov and talboren authored Aug 8, 2024
1 parent d07c7c0 commit 1ea44bf
Show file tree
Hide file tree
Showing 7 changed files with 356 additions and 64 deletions.
137 changes: 119 additions & 18 deletions keep/api/core/db.py
Original file line number Diff line number Diff line change
Expand Up @@ -16,15 +16,15 @@
import validators
from dotenv import find_dotenv, load_dotenv
from opentelemetry.instrumentation.sqlalchemy import SQLAlchemyInstrumentor
from sqlalchemy import and_, desc, func, null, update
from sqlalchemy import and_, desc, null, update
from sqlalchemy.exc import IntegrityError, OperationalError
from sqlalchemy.orm import joinedload, selectinload, subqueryload
from sqlalchemy.orm.attributes import flag_modified
from sqlalchemy.orm.exc import StaleDataError
from sqlalchemy.sql import expression
from sqlmodel import Session, col, or_, select

from keep.api.core.db_utils import create_db_engine
from keep.api.core.db_utils import create_db_engine, get_json_extract_field

# This import is required to create the tables
from keep.api.models.alert import AlertStatus, IncidentDtoIn
Expand Down Expand Up @@ -1842,16 +1842,8 @@ def update_preset_options(tenant_id: str, preset_id: str, options: dict) -> Pres

def assign_alert_to_incident(
alert_id: UUID, incident_id: UUID, tenant_id: str
) -> AlertToIncident:
with Session(engine) as session:
assignment = AlertToIncident(
alert_id=alert_id, incident_id=incident_id, tenant_id=tenant_id
)
session.add(assignment)
session.commit()
session.refresh(assignment)

return assignment
):
return add_alerts_to_incident_by_incident_id(tenant_id, incident_id, [alert_id])


def get_incidents(tenant_id) -> List[Incident]:
Expand Down Expand Up @@ -1960,7 +1952,6 @@ def get_last_incidents(
)
.filter(Incident.tenant_id == tenant_id)
.filter(Incident.is_confirmed == is_confirmed)
.options(joinedload(Incident.alerts))
.order_by(desc(Incident.creation_time))
)

Expand All @@ -1980,7 +1971,7 @@ def get_last_incidents(
return incidents, total_count


def get_incident_by_id(tenant_id: str, incident_id: str) -> Optional[Incident]:
def get_incident_by_id(tenant_id: str, incident_id: str | UUID) -> Optional[Incident]:
with Session(engine) as session:
query = session.query(
Incident,
Expand Down Expand Up @@ -2108,8 +2099,61 @@ def get_incident_alerts_by_incident_id(tenant_id: str, incident_id: str) -> List
return query.all()


def get_alerts_data_for_incident(
alert_ids: list[str | UUID],
session: Optional[Session] = None
) -> dict:

"""
Function to prepare aggregated data for incidents from the given list of alert_ids
Logic is wrapped to the inner function for better usability with an optional database session
Args:
alert_ids (list[str | UUID]): list of alert ids for aggregation
session (Optional[Session]): The database session or None
Returns: dict {sources: list[str], services: list[str], count: int}
"""

def inner(db_session: Session):

fields = (
get_json_extract_field(session, Alert.event, 'service'),
Alert.provider_type
)

alerts_data = db_session.exec(
select(
*fields
).where(
col(Alert.id).in_(alert_ids),
)
).all()

sources = []
services = []

for (service, source) in alerts_data:
if source:
sources.append(source)
if service:
services.append(service)

return {
"sources": set(sources),
"services": set(services),
"count": len(alerts_data)
}

# Ensure that we have a session to execute the query. If not - make new one
if not session:
with Session(engine) as session:
return inner(session)
return inner(session)


def add_alerts_to_incident_by_incident_id(
tenant_id: str, incident_id: str, alert_ids: List[UUID]
tenant_id: str, incident_id: str | UUID, alert_ids: List[UUID]
):
with Session(engine) as session:
incident = session.exec(
Expand All @@ -2130,21 +2174,34 @@ def add_alerts_to_incident_by_incident_id(
)
).all()

new_alert_ids = [alert_id for alert_id in alert_ids
if alert_id not in existed_alert_ids]

alerts_data_for_incident = get_alerts_data_for_incident(new_alert_ids, session)

incident.sources = list(
set(incident.sources) | set(alerts_data_for_incident["sources"])
)
incident.affected_services = list(
set(incident.affected_services) | set(alerts_data_for_incident["services"])
)
incident.alerts_count += alerts_data_for_incident["count"]

alert_to_incident_entries = [
AlertToIncident(
alert_id=alert_id, incident_id=incident.id, tenant_id=tenant_id
)
for alert_id in alert_ids
if alert_id not in existed_alert_ids
for alert_id in new_alert_ids
]

session.bulk_save_objects(alert_to_incident_entries)
session.add(incident)
session.commit()
return True


def remove_alerts_to_incident_by_incident_id(
tenant_id: str, incident_id: str, alert_ids: List[UUID]
tenant_id: str, incident_id: str | UUID, alert_ids: List[UUID]
) -> Optional[int]:
with Session(engine) as session:
incident = session.exec(
Expand All @@ -2157,6 +2214,7 @@ def remove_alerts_to_incident_by_incident_id(
if not incident:
return None

# Removing alerts-to-incident relation for provided alerts_ids
deleted = (
session.query(AlertToIncident)
.where(
Expand All @@ -2166,8 +2224,51 @@ def remove_alerts_to_incident_by_incident_id(
)
.delete()
)
session.commit()

    # Get aggregated data for the alerts that were just removed
alerts_data_for_incident = get_alerts_data_for_incident(alert_ids, session)

service_field = get_json_extract_field(session, Alert.event, 'service')

    # Check whether the services of the removed alerts are still present in
    # alerts that remain assigned to the incident
services_existed = session.exec(
session.query(func.distinct(service_field))
.join(AlertToIncident, Alert.id == AlertToIncident.alert_id)
.filter(
AlertToIncident.incident_id == incident_id,
service_field.in_(alerts_data_for_incident["services"])
)
).scalars()

    # Check whether the sources (providers) of the removed alerts are still
    # present in alerts that remain assigned to the incident
sources_existed = session.exec(
session.query(col(Alert.provider_type).distinct())
.join(AlertToIncident, Alert.id == AlertToIncident.alert_id)
.filter(
AlertToIncident.incident_id == incident_id,
col(Alert.provider_type).in_(alerts_data_for_incident["sources"])
)
).scalars()

# Making lists of services and sources to remove from the incident
services_to_remove = [service for service in alerts_data_for_incident["services"]
if service not in services_existed]
sources_to_remove = [source for source in alerts_data_for_incident["sources"]
if source not in sources_existed]

# filtering removed entities from affected services and sources in the incident
incident.affected_services = [service for service in incident.affected_services
if service not in services_to_remove]
incident.sources = [source for source in incident.sources
if source not in sources_to_remove]

incident.alerts_count -= alerts_data_for_incident["count"]
session.add(incident)
session.commit()

return deleted


Expand Down
11 changes: 11 additions & 0 deletions keep/api/core/db_utils.py
Original file line number Diff line number Diff line change
Expand Up @@ -11,6 +11,7 @@
import pymysql
from dotenv import find_dotenv, load_dotenv
from google.cloud.sql.connector import Connector
from sqlalchemy import func
from sqlmodel import create_engine

# This import is required to create the tables
Expand Down Expand Up @@ -150,3 +151,13 @@ def create_db_engine():
echo=DB_ECHO,
)
return engine


def get_json_extract_field(session, base_field, key):
    """Return a dialect-appropriate SQL expression that extracts `key`
    from the JSON column `base_field` as text."""
    dialect = session.bind.dialect.name
    json_path = "$.{}".format(key)

    if dialect == "postgresql":
        # PostgreSQL takes the bare key, not a JSON path.
        return func.json_extract_path_text(base_field, key)
    if dialect == "mysql":
        # MySQL returns a quoted JSON scalar; unquote it to plain text.
        return func.json_unquote(func.json_extract(base_field, json_path))
    # Fallback (e.g. SQLite): plain json_extract on the $.key path.
    return func.json_extract(base_field, json_path)
19 changes: 6 additions & 13 deletions keep/api/models/alert.py
Original file line number Diff line number Diff line change
Expand Up @@ -334,8 +334,8 @@ class UnEnrichAlertRequestBody(BaseModel):

class IncidentDtoIn(BaseModel):
name: str
description: str
assignee: str | None
user_summary: str | None

class Config:
extra = Extra.allow
Expand Down Expand Up @@ -364,6 +364,8 @@ class IncidentDto(IncidentDtoIn):

is_predicted: bool

generated_summary: str | None

def __str__(self) -> str:
# Convert the model instance to a dictionary
model_dict = self.dict()
Expand All @@ -381,15 +383,6 @@ class Config:
@classmethod
def from_db_incident(cls, db_incident):

alerts_dto = [AlertDto(**alert.event) for alert in db_incident.alerts]

unique_sources_list = list(
set([source for alert_dto in alerts_dto for source in alert_dto.source])
)
unique_service_list = list(
set([alert.service for alert in alerts_dto if alert.service is not None])
)

return cls(
id=db_incident.id,
name=db_incident.name,
Expand All @@ -398,9 +391,9 @@ def from_db_incident(cls, db_incident):
creation_time=db_incident.creation_time,
start_time=db_incident.start_time,
end_time=db_incident.end_time,
number_of_alerts=len(db_incident.alerts),
alert_sources=unique_sources_list,
number_of_alerts=db_incident.alerts_count,
alert_sources=db_incident.sources,
severity=IncidentSeverity.CRITICAL,
assignee=db_incident.assignee,
services=unique_service_list,
services=db_incident.affected_services,
)
4 changes: 4 additions & 0 deletions keep/api/models/db/alert.py
Original file line number Diff line number Diff line change
Expand Up @@ -118,6 +118,10 @@ class Incident(SQLModel, table=True):
is_predicted: bool = Field(default=False)
is_confirmed: bool = Field(default=False)

alerts_count: int = Field(default=0)
affected_services: list = Field(sa_column=Column(JSON), default_factory=list)
sources: list = Field(sa_column=Column(JSON), default_factory=list)

def __init__(self, **kwargs):
super().__init__(**kwargs)
if "alerts" not in kwargs:
Expand Down
Original file line number Diff line number Diff line change
@@ -0,0 +1,58 @@
"""Add fields for prepopulated data from alerts
Revision ID: 67f1efb93c99
Revises: dcbd2873dcfd
Create Date: 2024-07-25 17:13:04.428633
"""

import sqlalchemy as sa
from alembic import op
from sqlalchemy.orm import Session, joinedload

from keep.api.models.alert import AlertDto
from keep.api.models.db.alert import Incident

# revision identifiers, used by Alembic.
revision = "67f1efb93c99"
down_revision = "dcbd2873dcfd"
branch_labels = None
depends_on = None


def populate_db(session):
    """Backfill sources, affected_services and alerts_count on all existing
    incidents from their currently-assigned alerts."""
    incidents = session.query(Incident).options(joinedload(Incident.alerts)).all()

    for incident in incidents:
        dtos = [AlertDto(**alert.event) for alert in incident.alerts]

        # Unique providers across all alerts, and unique non-null services.
        unique_sources = {source for dto in dtos for source in dto.source}
        unique_services = {dto.service for dto in dtos if dto.service is not None}

        incident.sources = list(unique_sources)
        incident.affected_services = list(unique_services)
        incident.alerts_count = len(incident.alerts)
        session.add(incident)
    session.commit()


def upgrade() -> None:
    """Add the new incident aggregate columns and backfill them from the
    alerts already assigned to each incident."""
    # ### commands auto generated by Alembic - please adjust! ###
    new_columns = (
        sa.Column("affected_services", sa.JSON(), nullable=True),
        sa.Column("sources", sa.JSON(), nullable=True),
        sa.Column("alerts_count", sa.Integer(), nullable=False, server_default="0"),
    )
    for column in new_columns:
        op.add_column("incident", column)

    # Populate the freshly added columns from existing data.
    populate_db(Session(op.get_bind()))

    # ### end Alembic commands ###


def downgrade() -> None:
    """Drop the columns added by this revision (reverse order of creation)."""
    # ### commands auto generated by Alembic - please adjust! ###
    for column_name in ("alerts_count", "sources", "affected_services"):
        op.drop_column("incident", column_name)
    # ### end Alembic commands ###
Loading

0 comments on commit 1ea44bf

Please sign in to comment.