diff --git a/keep/api/core/db.py b/keep/api/core/db.py index d29b785c7..75817f455 100644 --- a/keep/api/core/db.py +++ b/keep/api/core/db.py @@ -1760,7 +1760,7 @@ def create_incident_for_grouping_rule( rule_id=rule.id, rule_fingerprint=rule_fingerprint, is_predicted=False, - is_confirmed=not rule.require_approve, + is_confirmed=rule.create_on == CreateIncidentOn.ANY.value and not rule.require_approve, ) session.add(incident) session.commit() @@ -4665,31 +4665,3 @@ def set_last_alert( transaction.commit() -def get_or_create_rule_group_by_rule_id( - tenant_id: str, - rule_id: str | UUID, - timeframe: int, - session: Optional[Session] = None -): - - with existed_or_new_session(session) as session: - group = session.query(RuleEventGroup).where( - and_( - RuleEventGroup.tenant_id == tenant_id, - RuleEventGroup.rule_id == rule_id, - RuleEventGroup.expires > datetime.utcnow(), - ) - ).first() - - if group is None: - group = RuleEventGroup( - tenant_id=tenant_id, - rule_id=rule_id, - expires=datetime.utcnow() + timedelta(seconds=timeframe), - state={} - ) - session.add(group) - session.commit() - session.refresh(group) - - return group diff --git a/keep/api/models/db/alert.py b/keep/api/models/db/alert.py index 10d1dab71..774d5eede 100644 --- a/keep/api/models/db/alert.py +++ b/keep/api/models/db/alert.py @@ -205,6 +205,10 @@ class Incident(SQLModel, table=True): class Config: arbitrary_types_allowed = True + @property + def alerts(self): + return self._alerts + class Alert(SQLModel, table=True): id: UUID = Field(default_factory=uuid4, primary_key=True) diff --git a/keep/api/models/db/migrations/versions/2024-12-07-22-18_8d4dc7d44a9c.py b/keep/api/models/db/migrations/versions/2024-12-07-22-18_8d4dc7d44a9c.py deleted file mode 100644 index d41f2a1d4..000000000 --- a/keep/api/models/db/migrations/versions/2024-12-07-22-18_8d4dc7d44a9c.py +++ /dev/null @@ -1,53 +0,0 @@ -"""Add RuleEventGroup - -Revision ID: 8d4dc7d44a9c -Revises: c6e5594c99f8 -Create Date: 2024-12-07 22:18:23.704507 - -""" - -import sqlalchemy as sa -import sqlmodel -from alembic import op - -# revision identifiers, used by Alembic. -revision = "8d4dc7d44a9c" -down_revision = "c6e5594c99f8" -branch_labels = None -depends_on = None - - -def upgrade() -> None: - # ### commands auto generated by Alembic - please adjust! ### - op.create_table( - "ruleeventgroup", - sa.Column("state", sa.JSON(), nullable=True), - sa.Column("id", sqlmodel.sql.sqltypes.GUID(), nullable=False), - sa.Column("rule_id", sqlmodel.sql.sqltypes.GUID(), nullable=False), - sa.Column("tenant_id", sqlmodel.sql.sqltypes.AutoString(), nullable=False), - sa.Column("expires", sa.DateTime(), nullable=False), - sa.ForeignKeyConstraint( - ["rule_id"], - ["rule.id"], - ), - sa.ForeignKeyConstraint( - ["tenant_id"], - ["tenant.id"], - ), - sa.PrimaryKeyConstraint("id"), - ) - with op.batch_alter_table("rule", schema=None) as batch_op: - batch_op.add_column( - sa.Column( - "create_on", - sqlmodel.sql.sqltypes.AutoString(), - nullable=False, - server_default="any", - ), - ) - - -def downgrade() -> None: - op.drop_table("ruleeventgroup") - with op.batch_alter_table("rule", schema=None) as batch_op: - batch_op.drop_column("create_on") diff --git a/keep/api/models/db/rule.py b/keep/api/models/db/rule.py index 4b626df1b..0a81d6ddf 100644 --- a/keep/api/models/db/rule.py +++ b/keep/api/models/db/rule.py @@ -52,29 +52,3 @@ class Rule(SQLModel, table=True): require_approve: bool = False resolve_on: str = ResolveOn.NEVER.value create_on: str = CreateIncidentOn.ANY.value - - -class RuleEventGroup(SQLModel, table=True): - - id: UUID = Field(default_factory=uuid4, primary_key=True) - rule_id: UUID = Field(foreign_key="rule.id") - tenant_id: str = Field(foreign_key="tenant.id") - state: Dict[str, List[UUID | str]] = Field( - sa_column=Column(JSON), - default_factory=lambda: defaultdict(list) - ) - expires: datetime - - def is_all_conditions_met(self, rule_groups: List[str]): - return all([ - len(self.state.get(condition, [])) - for condition in rule_groups - ]) - - def add_alert(self, condition, fingerprint): - self.state.setdefault(condition, []) - self.state[condition].append(fingerprint) - flag_modified(self, "state") - - def get_all_alerts(self): - return list(set(chain(*self.state.values()))) diff --git a/keep/rulesengine/rulesengine.py b/keep/rulesengine/rulesengine.py index 200097930..ce45698e6 100644 --- a/keep/rulesengine/rulesengine.py +++ b/keep/rulesengine/rulesengine.py @@ -15,16 +15,15 @@ get_incident_for_grouping_rule, create_incident_for_grouping_rule, get_or_create_rule_group_by_rule_id, - add_alerts_to_incident, is_all_alerts_resolved, is_first_incident_alert_resolved, is_last_incident_alert_resolved, - is_all_alerts_in_status, + is_all_alerts_in_status, enrich_incidents_with_alerts, ) from keep.api.core.db import get_rules as get_rules_db from keep.api.models.alert import AlertDto, AlertSeverity, IncidentDto, IncidentStatus, AlertStatus from keep.api.models.db.alert import Incident -from keep.api.models.db.rule import ResolveOn, RuleEventGroup, Rule +from keep.api.models.db.rule import ResolveOn, Rule from keep.api.utils.cel_utils import preprocess_cel_expression # Shahar: this is performance enhancment https://github.com/cloud-custodian/cel-python/issues/68 @@ -66,34 +65,33 @@ def run_rules( f"Checking if rule {rule.name} apply to event {event.id}" ) try: - rule_result, sub_rule = self._check_if_rule_apply(rule, event) + matched_rules = self._check_if_rule_apply(rule, event) except Exception: self.logger.exception( f"Failed to evaluate rule {rule.name} on event {event.id}" ) continue - if rule_result: + if matched_rules: self.logger.info( f"Rule {rule.name} on event {event.id} is relevant" ) rule_fingerprint = self._calc_rule_fingerprint(event, rule) - incident = get_incident_for_grouping_rule( - self.tenant_id, + incident = self._get_or_create_incident( rule, rule_fingerprint, + session, + ) + incident = assign_alert_to_incident( + fingerprint=event.fingerprint, + incident=incident, + tenant_id=self.tenant_id, session=session, ) - if incident: - incident = assign_alert_to_incident( - fingerprint=event.fingerprint, - incident=incident, - tenant_id=self.tenant_id, - session=session, - ) - else: + + if not incident.is_confirmed: self.logger.info( f"No existing incidents for rule {rule.name}. Checking incident creation conditions" @@ -101,20 +99,18 @@ def run_rules( rule_groups = self._extract_subrules(rule.definition_cel) - if rule.create_on == "any" or (rule.create_on == "all" and len(rule_groups) == 1): + if rule.create_on == "any" or (rule.create_on == "all" and len(rule_groups) == len(matched_rules)): self.logger.info("Single event is enough, so creating incident") - incident = self._create_incident_with_alerts( - rule, rule_fingerprint, [event.fingerprint], session=session - ) + incident.is_confirmed = True elif rule.create_on == "all": incident = self._process_event_for_history_based_rule( - event, rule, sub_rule, rule_groups, rule_fingerprint, session + incident, rule, session ) - if incident: - - incident = self._resolve_incident_if_require(rule, incident, session) - incidents_dto[incident.id] = IncidentDto.from_db_incident(incident) + incident = self._resolve_incident_if_require(rule, incident, session) + session.add(incident) + session.commit() + incidents_dto[incident.id] = IncidentDto.from_db_incident(incident) else: self.logger.info( @@ -130,62 +126,58 @@ def run_rules( return list(incidents_dto.values()) - def _create_incident_with_alerts( - self, - rule: Rule, - rule_fingerprint: str, - fingerprints: List[str], - session: Session, - ) -> Incident: - incident = create_incident_for_grouping_rule( + def _get_or_create_incident(self, rule, rule_fingerprint, session): + incident = get_incident_for_grouping_rule( self.tenant_id, rule, rule_fingerprint, session=session, ) - incident = add_alerts_to_incident( - self.tenant_id, - incident, - fingerprints, - session=session, - ) - + if not incident: + incident = create_incident_for_grouping_rule( + self.tenant_id, + rule, + rule_fingerprint, + session=session, + ) return incident def _process_event_for_history_based_rule( self, - event: AlertDto, + incident: Incident, rule: Rule, - sub_rule: str, - rule_groups: List[str], - rule_fingerprint: str, session: Session - ) -> Optional[Incident]: + ) -> Incident: self.logger.info( "Multiple events required for the incident to start" ) - rule_group = self._get_rule_group(rule, session) - rule_group.add_alert(sub_rule, event.fingerprint) + enrich_incidents_with_alerts( + tenant_id=self.tenant_id, + incidents=[incident], + session=session, + ) + + fingerprints = [alert.fingerprint for alert in incident.alerts] - fingerprints = rule_group.get_all_alerts() + is_all_conditions_met = False - incident = None - if rule_group.is_all_conditions_met(rule_groups) and is_all_alerts_in_status( - fingerprints=fingerprints, status=AlertStatus.FIRING, session=session - ): - self.logger.info("All required events are in the system, so creating incident") - incident = self._create_incident_with_alerts(rule, rule_fingerprint, fingerprints, session=session) - session.delete(rule_group) - session.commit() - - else: - self.logger.info(f"Updating state for rule `{rule.name}` events group") - # Updating rule_group expiration+ - rule_group.expires = datetime.datetime.utcnow() + datetime.timedelta(seconds=rule.timeframe) - session.add(rule_group) - session.commit() + all_sub_rules = set(self._extract_subrules(rule.definition_cel)) + matched_sub_rules = set() + + for alert in incident.alerts: + matched_sub_rules = matched_sub_rules.union(self._check_if_rule_apply(rule, AlertDto(**alert.event))) + if all_sub_rules == matched_sub_rules: + is_all_conditions_met = True + break + + if is_all_conditions_met: + all_alerts_firing = is_all_alerts_in_status( + fingerprints=fingerprints, status=AlertStatus.FIRING, session=session + ) + if all_alerts_firing: + incident.is_confirmed = True return incident @@ -214,8 +206,6 @@ def _resolve_incident_if_require(rule: Rule, incident: Incident, session: Sessio if should_resolve: incident.status = IncidentStatus.RESOLVED.value - session.add(incident) - session.commit() return incident @@ -236,7 +226,7 @@ def _extract_subrules(expression): return sub_rules # TODO: a lot of unit tests to write here - def _check_if_rule_apply(self, rule, event: AlertDto): + def _check_if_rule_apply(self, rule: Rule, event: AlertDto) -> List[str]: sub_rules = self._extract_subrules(rule.definition_cel) payload = event.dict() # workaround since source is a list @@ -246,6 +236,7 @@ def _check_if_rule_apply(self, rule, event: AlertDto): # what we do here is to compile the CEL rule and evaluate it # https://github.com/cloud-custodian/cel-python # https://github.com/google/cel-spec + sub_rules_matched = [] for sub_rule in sub_rules: ast = self.env.compile(sub_rule) prgm = self.env.program(ast) @@ -255,13 +246,13 @@ def _check_if_rule_apply(self, rule, event: AlertDto): except celpy.evaluation.CELEvalError as e: # this is ok, it means that the subrule is not relevant for this event if "no such member" in str(e): - return False, None + continue # unknown raise if r: - return True, sub_rule + sub_rules_matched.append(sub_rule) # no subrules matched - return False, None + return sub_rules_matched def _calc_rule_fingerprint(self, event: AlertDto, rule): # extract all the grouping criteria from the event @@ -383,7 +374,3 @@ def filter_alerts( filtered_alerts.append(alert) return filtered_alerts - - @staticmethod - def _get_rule_group(rule: Rule, session: Session) -> RuleEventGroup: - return get_or_create_rule_group_by_rule_id(rule.tenant_id, rule.id, rule.timeframe, session) diff --git a/tests/test_rules_engine.py b/tests/test_rules_engine.py index 7401550ce..b80337f3d 100644 --- a/tests/test_rules_engine.py +++ b/tests/test_rules_engine.py @@ -7,7 +7,7 @@ import pytest -from keep.api.core.db import create_rule as create_rule_db +from keep.api.core.db import create_rule as create_rule_db, enrich_incidents_with_alerts from keep.api.core.db import get_incident_alerts_by_incident_id, get_last_incidents, set_last_alert from keep.api.core.db import get_rules as get_rules_db from keep.api.core.dependencies import SINGLE_TENANT_UUID @@ -19,7 +19,7 @@ IncidentStatus, ) from keep.api.models.db.alert import Alert, Incident -from keep.api.models.db.rule import ResolveOn, CreateIncidentOn, RuleEventGroup +from keep.api.models.db.rule import ResolveOn, CreateIncidentOn from keep.rulesengine.rulesengine import RulesEngine @@ -73,7 +73,7 @@ def test_sanity(db_session): set_last_alert(SINGLE_TENANT_UUID, alert, db_session) # run the rules engine alerts[0].event_id = alert.id - results = rules_engine.run_rules(alerts) + results = rules_engine.run_rules(alerts, session=db_session) # check that there are results assert len(results) > 0 @@ -121,7 +121,7 @@ def test_sanity_2(db_session): set_last_alert(SINGLE_TENANT_UUID, alert, db_session) # run the rules engine alerts[0].event_id = alert.id - results = rules_engine.run_rules(alerts) + results = rules_engine.run_rules(alerts, session=db_session) # check that there are results assert len(results) > 0 @@ -170,7 +170,7 @@ def test_sanity_3(db_session): set_last_alert(SINGLE_TENANT_UUID, alert, db_session) # run the rules engine alerts[0].event_id = alert.id - results = rules_engine.run_rules(alerts) + results = rules_engine.run_rules(alerts, session=db_session) # check that there are results assert len(results) > 0 @@ -219,7 +219,7 @@ def test_sanity_4(db_session): set_last_alert(SINGLE_TENANT_UUID, alert, db_session) # run the rules engine alerts[0].event_id = alert.id - results = rules_engine.run_rules(alerts) + results = rules_engine.run_rules(alerts, session=db_session) # check that there are results assert results == [] @@ -275,7 +275,7 @@ def test_incident_attributes(db_session): for i, alert in enumerate(alerts_dto): alert.event_id = alerts[i].id - results = rules_engine.run_rules([alert]) + results = rules_engine.run_rules([alert], session=db_session) # check that there are results assert results is not None assert len(results) == 1 @@ -338,7 +338,7 @@ def test_incident_severity(db_session): for i, alert in enumerate(alerts_dto): alert.event_id = alerts[i].id - results = rules_engine.run_rules(alerts_dto) + results = rules_engine.run_rules(alerts_dto, session=db_session) # check that there are results assert results is not None assert len(results) == 1 @@ -583,7 +583,7 @@ def test_incident_resolution_on_edge( assert incident.status == IncidentStatus.RESOLVED.value -def test_rule_event_groups(db_session, create_alert): +def test_rule_multiple_alerts(db_session, create_alert): create_rule_db( tenant_id=SINGLE_TENANT_UUID, @@ -609,16 +609,17 @@ def test_rule_event_groups(db_session, create_alert): ) # No incident yet - assert db_session.query(Incident).count() == 0 - # But RuleEventGroup - assert db_session.query(RuleEventGroup).count() == 1 - event_group = db_session.query(RuleEventGroup).first() + assert db_session.query(Incident).filter(Incident.is_confirmed == True).count() == 0 + # But candidate is there + assert db_session.query(Incident).filter(Incident.is_confirmed == False).count() == 1 + incident = db_session.query(Incident).first() alert_1 = db_session.query(Alert).order_by(Alert.timestamp.desc()).first() - assert isinstance(event_group.state, dict) - assert 'severity == "critical"' in event_group.state - assert len(event_group.state['severity == "critical"']) == 1 - assert event_group.state['severity == "critical"'][0] == alert_1.fingerprint + enrich_incidents_with_alerts(SINGLE_TENANT_UUID, [incident], db_session) + + assert incident.alerts_count == 1 + assert len(incident.alerts) == 1 + assert incident.alerts[0].id == alert_1.id create_alert( "Critical Alert 2", @@ -629,19 +630,20 @@ def test_rule_event_groups(db_session, create_alert): }, ) - db_session.refresh(event_group) + db_session.refresh(incident) alert_2 = db_session.query(Alert).order_by(Alert.timestamp.desc()).first() # Still no incident yet - assert db_session.query(Incident).count() == 0 - # And still one RuleEventGroup - assert db_session.query(RuleEventGroup).count() == 1 + assert db_session.query(Incident).filter(Incident.is_confirmed == True).count() == 0 + # And still one candidate is there + assert db_session.query(Incident).filter(Incident.is_confirmed == False).count() == 1 + + enrich_incidents_with_alerts(SINGLE_TENANT_UUID, [incident], db_session) - assert isinstance(event_group.state, dict) - assert 'severity == "critical"' in event_group.state - assert len(event_group.state['severity == "critical"']) == 2 - assert event_group.state['severity == "critical"'][0] == alert_1.fingerprint - assert event_group.state['severity == "critical"'][1] == alert_2.fingerprint + assert incident.alerts_count == 2 + assert len(incident.alerts) == 2 + assert incident.alerts[0].id == alert_1.id + assert incident.alerts[1].id == alert_2.id create_alert( "High Alert", @@ -651,15 +653,14 @@ def test_rule_event_groups(db_session, create_alert): "severity": AlertSeverity.HIGH.value, }, ) - alert_3 = db_session.query(Alert).order_by(Alert.timestamp.desc()).first() - # RuleEventGroup was removed - assert db_session.query(RuleEventGroup).count() == 0 + alert_3 = db_session.query(Alert).order_by(Alert.timestamp.desc()).first() + enrich_incidents_with_alerts(SINGLE_TENANT_UUID, [incident], db_session) - # And incident was started - assert db_session.query(Incident).count() == 1 + # And incident was official started + assert db_session.query(Incident).filter(Incident.is_confirmed == True).count() == 1 - incident = db_session.query(Incident).first() + db_session.refresh(incident) assert incident.alerts_count == 3 alerts, alert_count = get_incident_alerts_by_incident_id( @@ -702,10 +703,10 @@ def test_rule_event_groups_expires(db_session, create_alert): }, ) - # No incident yet - assert db_session.query(Incident).count() == 0 - # One RuleEventGroup - assert db_session.query(RuleEventGroup).count() == 1 + # Still no incident yet + assert db_session.query(Incident).filter(Incident.is_confirmed == True).count() == 0 + # And still one candidate is there + assert db_session.query(Incident).filter(Incident.is_confirmed == False).count() == 1 sleep(1) @@ -718,10 +719,10 @@ def test_rule_event_groups_expires(db_session, create_alert): }, ) - # Still no incident - assert db_session.query(Incident).count() == 0 - # And now two RuleEventGroup - first one was expired - assert db_session.query(RuleEventGroup).count() == 2 + # Still no incident yet + assert db_session.query(Incident).filter(Incident.is_confirmed == True).count() == 0 + # And now two candidates is there + assert db_session.query(Incident).filter(Incident.is_confirmed == False).count() == 2