Skip to content

Commit

Permalink
fix: Each join where fingerprint is used MUST have EXPLICIT tenant_id (
Browse files Browse the repository at this point in the history
  • Loading branch information
VladimirFilonov authored Dec 5, 2024
1 parent 86386b3 commit d08f5d1
Show file tree
Hide file tree
Showing 3 changed files with 136 additions and 19 deletions.
89 changes: 72 additions & 17 deletions keep/api/core/db.py
Original file line number Diff line number Diff line change
Expand Up @@ -1288,7 +1288,10 @@ def get_last_alerts(
cast(LastAlertToIncident.incident_id, String)
).label("incidents"),
)
.filter(LastAlertToIncident.deleted_at == NULL_FOR_DELETED_AT)
.filter(
LastAlertToIncident.tenant_id == tenant_id,
LastAlertToIncident.deleted_at == NULL_FOR_DELETED_AT
)
.group_by(LastAlertToIncident.fingerprint)
.subquery()
)
Expand All @@ -1302,7 +1305,10 @@ def get_last_alerts(
cast(LastAlertToIncident.incident_id, String)
).label("incidents"),
)
.filter(LastAlertToIncident.deleted_at == NULL_FOR_DELETED_AT)
.filter(
LastAlertToIncident.tenant_id == tenant_id,
LastAlertToIncident.deleted_at == NULL_FOR_DELETED_AT
)
.group_by(LastAlertToIncident.fingerprint)
.subquery()
)
Expand All @@ -1317,7 +1323,10 @@ def get_last_alerts(
",",
).label("incidents"),
)
.filter(LastAlertToIncident.deleted_at == NULL_FOR_DELETED_AT)
.filter(
LastAlertToIncident.tenant_id == tenant_id,
LastAlertToIncident.deleted_at == NULL_FOR_DELETED_AT
)
.group_by(LastAlertToIncident.fingerprint)
.subquery()
)
Expand Down Expand Up @@ -3063,6 +3072,7 @@ def enrich_incidents_with_alerts(tenant_id: str, incidents: List[Incident], sess
select(LastAlertToIncident.incident_id, Alert)
.select_from(LastAlert)
.join(LastAlertToIncident, and_(
LastAlertToIncident.tenant_id == LastAlert.tenant_id,
LastAlertToIncident.fingerprint == LastAlert.fingerprint,
LastAlertToIncident.deleted_at == NULL_FOR_DELETED_AT,
))
Expand All @@ -3089,6 +3099,7 @@ def enrich_alerts_with_incidents(tenant_id: str, alerts: List[Alert], session: O
select(LastAlertToIncident.fingerprint, Incident)
.select_from(LastAlert)
.join(LastAlertToIncident, and_(
LastAlertToIncident.tenant_id == LastAlert.tenant_id,
LastAlertToIncident.fingerprint == LastAlert.fingerprint,
LastAlertToIncident.deleted_at == NULL_FOR_DELETED_AT,
))
Expand Down Expand Up @@ -3377,7 +3388,11 @@ def get_incident_alerts_and_links_by_incident_id(
LastAlertToIncident,
)
.select_from(LastAlertToIncident)
.join(LastAlert, LastAlert.fingerprint == LastAlertToIncident.fingerprint)
.join(LastAlert, and_(
LastAlert.tenant_id == LastAlertToIncident.tenant_id,
LastAlert.fingerprint == LastAlertToIncident.fingerprint
)
)
.join(Alert, LastAlert.alert_id == Alert.id)
.filter(
LastAlertToIncident.tenant_id == tenant_id,
Expand Down Expand Up @@ -3458,7 +3473,10 @@ def get_alerts_data_for_incident(
alerts_data = session.exec(
select(*fields)
.select_from(LastAlert)
.join(Alert, LastAlert.alert_id == Alert.id)
.join(Alert, and_(
LastAlert.tenant_id == Alert.tenant_id,
LastAlert.alert_id == Alert.id,
))
.where(
LastAlert.tenant_id == tenant_id,
col(LastAlert.fingerprint).in_(fingerprints),
Expand Down Expand Up @@ -3531,7 +3549,10 @@ def add_alerts_to_incident(
existing_fingerprints = set(
session.exec(
select(LastAlert.fingerprint)
.join(LastAlertToIncident, LastAlertToIncident.fingerprint == LastAlert.fingerprint)
.join(LastAlertToIncident, and_(
LastAlertToIncident.tenant_id == LastAlert.tenant_id,
LastAlertToIncident.fingerprint == LastAlert.fingerprint
))
.where(
LastAlertToIncident.deleted_at == NULL_FOR_DELETED_AT,
LastAlertToIncident.tenant_id == tenant_id,
Expand Down Expand Up @@ -3595,7 +3616,10 @@ def add_alerts_to_incident(

started_at, last_seen_at = session.exec(
select(func.min(Alert.timestamp), func.max(Alert.timestamp))
.join(LastAlertToIncident, LastAlertToIncident.fingerprint == Alert.fingerprint)
.join(LastAlertToIncident, and_(
LastAlertToIncident.tenant_id == Alert.tenant_id,
LastAlertToIncident.fingerprint == Alert.fingerprint
))
.where(
LastAlertToIncident.deleted_at == NULL_FOR_DELETED_AT,
LastAlertToIncident.tenant_id == tenant_id,
Expand Down Expand Up @@ -3635,7 +3659,10 @@ def get_last_alerts_for_incidents(
LastAlertToIncident.incident_id,
)
.select_from(LastAlert)
.join(LastAlertToIncident, LastAlert.fingerprint == LastAlertToIncident.fingerprint)
.join(LastAlertToIncident, and_(
LastAlert.tenant_id == LastAlertToIncident.tenant_id,
LastAlert.fingerprint == LastAlertToIncident.fingerprint
))
.join(Alert, LastAlert.alert_id == Alert.id)
.filter(
LastAlertToIncident.deleted_at == NULL_FOR_DELETED_AT,
Expand Down Expand Up @@ -3696,7 +3723,10 @@ def remove_alerts_to_incident_by_incident_id(
existed_services_query = (
select(func.distinct(service_field))
.select_from(LastAlert)
.join(LastAlertToIncident, LastAlert.fingerprint == LastAlertToIncident.fingerprint)
.join(LastAlertToIncident, and_(
LastAlert.tenant_id == LastAlertToIncident.tenant_id,
LastAlert.fingerprint == LastAlertToIncident.fingerprint
))
.join(Alert, LastAlert.alert_id == Alert.id)
.filter(
LastAlertToIncident.deleted_at == NULL_FOR_DELETED_AT,
Expand All @@ -3711,7 +3741,10 @@ def remove_alerts_to_incident_by_incident_id(
existed_sources_query = (
select(col(Alert.provider_type).distinct())
.select_from(LastAlert)
.join(LastAlertToIncident, LastAlert.fingerprint == LastAlertToIncident.fingerprint)
.join(LastAlertToIncident, and_(
LastAlert.tenant_id == LastAlertToIncident.tenant_id,
LastAlert.fingerprint == LastAlertToIncident.fingerprint
))
.join(Alert, LastAlert.alert_id == Alert.id)
.filter(
LastAlertToIncident.deleted_at == NULL_FOR_DELETED_AT,
Expand All @@ -3736,7 +3769,10 @@ def remove_alerts_to_incident_by_incident_id(
started_at, last_seen_at = session.exec(
select(func.min(Alert.timestamp), func.max(Alert.timestamp))
.select_from(LastAlert)
.join(LastAlertToIncident, LastAlert.fingerprint == LastAlertToIncident.fingerprint)
.join(LastAlertToIncident, and_(
LastAlert.tenant_id == LastAlertToIncident.tenant_id,
LastAlert.fingerprint == LastAlertToIncident.fingerprint,
))
.join(Alert, LastAlert.alert_id == Alert.id)
.where(
LastAlertToIncident.tenant_id == tenant_id,
Expand Down Expand Up @@ -4186,10 +4222,14 @@ def get_workflow_executions_for_incident_or_alert(
LastAlert, WorkflowToAlertExecution.alert_fingerprint == LastAlert.fingerprint
)
.join(Alert, LastAlert.alert_id == Alert.id)
.join(LastAlertToIncident, Alert.fingerprint == LastAlertToIncident.fingerprint)
.join(LastAlertToIncident, and_(
LastAlert.tenant_id == LastAlertToIncident.tenant_id,
LastAlert.fingerprint == LastAlertToIncident.fingerprint
))
.where(
LastAlertToIncident.deleted_at == NULL_FOR_DELETED_AT,
LastAlertToIncident.incident_id == incident_id,
LastAlert.tenant_id == tenant_id,
)
)

Expand Down Expand Up @@ -4235,9 +4275,15 @@ def is_all_incident_alerts_resolved(
.select_from(LastAlert)
.join(Alert, LastAlert.alert_id == Alert.id)
.outerjoin(
AlertEnrichment, Alert.fingerprint == AlertEnrichment.alert_fingerprint
AlertEnrichment, and_(
Alert.tenant_id == AlertEnrichment.tenant_id,
Alert.fingerprint == AlertEnrichment.alert_fingerprint
),
)
.join(LastAlertToIncident, LastAlertToIncident.fingerprint == LastAlert.fingerprint)
.join(LastAlertToIncident, and_(
LastAlertToIncident.tenant_id == LastAlert.tenant_id,
LastAlertToIncident.fingerprint == LastAlert.fingerprint
))
.where(
LastAlertToIncident.deleted_at == NULL_FOR_DELETED_AT,
LastAlertToIncident.incident_id == incident.id,
Expand Down Expand Up @@ -4296,9 +4342,15 @@ def is_edge_incident_alert_resolved(
select(Alert.fingerprint, enriched_status_field, status_field)
.select_from(Alert)
.outerjoin(
AlertEnrichment, Alert.fingerprint == AlertEnrichment.alert_fingerprint
AlertEnrichment, and_(
Alert.tenant_id == AlertEnrichment.tenant_id,
Alert.fingerprint == AlertEnrichment.alert_fingerprint
)
)
.join(LastAlertToIncident, LastAlertToIncident.fingerprint == Alert.fingerprint)
.join(LastAlertToIncident, and_(
LastAlertToIncident.tenant_id == Alert.tenant_id,
LastAlertToIncident.fingerprint == Alert.fingerprint
))
.where(LastAlertToIncident.incident_id == incident.id)
.group_by(Alert.fingerprint)
.having(func.max(Alert.timestamp))
Expand Down Expand Up @@ -4345,7 +4397,10 @@ def get_alerts_metrics_by_provider(
*dynamic_field_sums,
)
.join(LastAlert, Alert.id == LastAlert.alert_id)
.outerjoin(LastAlertToIncident, LastAlert.fingerprint == LastAlertToIncident.fingerprint)
.outerjoin(LastAlertToIncident, and_(
LastAlert.tenant_id == LastAlertToIncident.tenant_id,
LastAlert.fingerprint == LastAlertToIncident.fingerprint
))
.filter(
Alert.tenant_id == tenant_id,
)
Expand Down
4 changes: 2 additions & 2 deletions tests/conftest.py
Original file line number Diff line number Diff line change
Expand Up @@ -625,13 +625,13 @@ def setup_stress_alerts(

@pytest.fixture
def create_alert(db_session):
def _create_alert(fingerprint, status, timestamp, details=None):
def _create_alert(fingerprint, status, timestamp, details=None, tenant_id=SINGLE_TENANT_UUID):
details = details or {}
random_name = "test-{}".format(fingerprint)
process_event(
ctx={"job_try": 1},
trace_id="test",
tenant_id=SINGLE_TENANT_UUID,
tenant_id=tenant_id,
provider_id="test",
provider_type=(
details["source"][0] if details and "source" in details and details["source"] else None
Expand Down
62 changes: 62 additions & 0 deletions tests/test_incidents.py
Original file line number Diff line number Diff line change
Expand Up @@ -24,6 +24,7 @@
IncidentStatus,
)
from keep.api.models.db.alert import Alert, LastAlertToIncident
from keep.api.models.db.tenant import Tenant
from keep.api.utils.enrichment_helpers import convert_db_alerts_to_dto_alerts
from tests.fixtures.client import client, test_app # noqa

Expand Down Expand Up @@ -798,3 +799,64 @@ def test_merge_incidents_app(
assert incident_3_via_api["status"] == IncidentStatus.MERGED.value
assert incident_3_via_api["merged_into_incident_id"] == str(incident_1.id)
"""

def test_cross_tenant_exposure_issue_2768(db_session, create_alert):


tenant_data = [
Tenant(id="tenant_1", name="test-tenant-1", created_by="[email protected]"),
Tenant(id="tenant_2", name="test-tenant-2", created_by="[email protected]")
]
db_session.add_all(tenant_data)
db_session.commit()

incident_tenant_1 = create_incident_from_dict(
"tenant_1", {"user_generated_name": "test", "user_summary": "test"}
)
incident_tenant_2 = create_incident_from_dict(
"tenant_2", {"user_generated_name": "test", "user_summary": "test"}
)

create_alert(
"non-unique-fingerprint",
AlertStatus.FIRING,
datetime.utcnow(),
{},
tenant_id="tenant_1"
)

create_alert(
"non-unique-fingerprint",
AlertStatus.FIRING,
datetime.utcnow(),
{},
tenant_id="tenant_2"
)

alert_tenant_1 = db_session.query(Alert).filter(Alert.tenant_id == 'tenant_1').first()
alert_tenant_2 = db_session.query(Alert).filter(Alert.tenant_id == 'tenant_2').first()

add_alerts_to_incident_by_incident_id(
"tenant_1", incident_tenant_1.id, [alert_tenant_1.fingerprint]
)
add_alerts_to_incident_by_incident_id(
"tenant_2", incident_tenant_2.id, [alert_tenant_2.fingerprint]
)

incident_tenant_1 = get_incident_by_id("tenant_1", incident_tenant_1.id)
incident_tenant_1_alerts, total_incident_tenant_1_alerts = get_incident_alerts_by_incident_id(
tenant_id="tenant_1",
incident_id=incident_tenant_1.id,
)
assert incident_tenant_1.alerts_count == 1
assert total_incident_tenant_1_alerts == 1
assert len(incident_tenant_1_alerts) == 1

incident_tenant_2 = get_incident_by_id("tenant_2", incident_tenant_2.id)
incident_tenant_2_alerts, total_incident_tenant_2_alerts = get_incident_alerts_by_incident_id(
tenant_id="tenant_2",
incident_id=incident_tenant_2.id,
)
assert incident_tenant_2.alerts_count == 1
assert total_incident_tenant_2_alerts == 1
assert len(incident_tenant_2_alerts) == 1

0 comments on commit d08f5d1

Please sign in to comment.