Skip to content

Commit

Permalink
fix: refactoring pilot logging code
Browse files Browse the repository at this point in the history
  • Loading branch information
martynia committed Jan 6, 2025
1 parent f02efba commit e9ad537
Show file tree
Hide file tree
Showing 6 changed files with 394 additions and 2 deletions.
4 changes: 2 additions & 2 deletions diracx-routers/pyproject.toml
Original file line number Diff line number Diff line change
Expand Up @@ -48,7 +48,7 @@ types = [
]

[project.entry-points."diracx.services"]
pilotlogs = "diracx.routers.pilot_logging.remote_logger:router"
pilots = "diracx.routers.pilots:router"
jobs = "diracx.routers.jobs:router"
config = "diracx.routers.configuration:router"
auth = "diracx.routers.auth:router"
Expand All @@ -57,7 +57,7 @@ auth = "diracx.routers.auth:router"
[project.entry-points."diracx.access_policies"]
WMSAccessPolicy = "diracx.routers.jobs.access_policies:WMSAccessPolicy"
SandboxAccessPolicy = "diracx.routers.jobs.access_policies:SandboxAccessPolicy"
PilotLogsAccessPolicy = "diracx.routers.pilot_logging.access_policies:PilotLogsAccessPolicy"
PilotLogsAccessPolicy = "diracx.routers.pilots.access_policies:PilotLogsAccessPolicy"

# Minimum version of the client supported
[project.entry-points."diracx.min_client_version"]
Expand Down
11 changes: 11 additions & 0 deletions diracx-routers/src/diracx/routers/pilots/__init__.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,11 @@
from __future__ import annotations

from logging import getLogger

from ..fastapi_classes import DiracxRouter
from .logging import router as logging_router

logger = getLogger(__name__)

router = DiracxRouter()
router.include_router(logging_router)
63 changes: 63 additions & 0 deletions diracx-routers/src/diracx/routers/pilots/access_policies.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,63 @@
from __future__ import annotations

from enum import StrEnum, auto
from typing import Annotated, Callable

from fastapi import Depends, HTTPException, status

from diracx.core.properties import (
GENERIC_PILOT,
NORMAL_USER,
OPERATOR,
PILOT,
SERVICE_ADMINISTRATOR,
)
from diracx.routers.access_policies import BaseAccessPolicy

from ..utils.users import AuthorizedUserInfo


class ActionType(StrEnum):
#: Create/update pilot log records
CREATE = auto()
#: delete pilot logs
DELETE = auto()
#: Search
QUERY = auto()


class PilotLogsAccessPolicy(BaseAccessPolicy):
"""Rules:
Only PILOT, GENERIC_PILOT, SERVICE_ADMINISTRATOR and OPERATOR can process log records.
Policies for other actions to be determined.
"""

@staticmethod
async def policy(
policy_name: str,
user_info: AuthorizedUserInfo,
/,
*,
action: ActionType | None = None,
):

if action is None:
raise HTTPException(
status.HTTP_400_BAD_REQUEST, detail="Action is a mandatory argument"
)

if GENERIC_PILOT in user_info.properties and action == ActionType.CREATE:
return user_info
if PILOT in user_info.properties and action == ActionType.CREATE:
return user_info
if NORMAL_USER in user_info.properties and action == ActionType.QUERY:
return user_info
if SERVICE_ADMINISTRATOR in user_info.properties:
return user_info
if OPERATOR in user_info.properties:
return user_info

raise HTTPException(status.HTTP_403_FORBIDDEN, detail=user_info.properties)


CheckPilotLogsPolicyCallable = Annotated[Callable, Depends(PilotLogsAccessPolicy.check)]
157 changes: 157 additions & 0 deletions diracx-routers/src/diracx/routers/pilots/logging.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,157 @@
from __future__ import annotations

import datetime
import logging

from fastapi import HTTPException, status
from pydantic import BaseModel
from sqlalchemy import select
from sqlalchemy.exc import NoResultFound

from diracx.core.exceptions import InvalidQueryError
from diracx.core.properties import OPERATOR, SERVICE_ADMINISTRATOR
from diracx.db.sql.pilot_agents.schema import PilotAgents
from diracx.db.sql.utils import BaseSQLDB

from ..dependencies import PilotLogsDB
from ..fastapi_classes import DiracxRouter
from ..utils.users import AuthorizedUserInfo
from .access_policies import ActionType, CheckPilotLogsPolicyCallable

logger = logging.getLogger(__name__)
router = DiracxRouter()


class LogLine(BaseModel):
line_no: int
line: str


class LogMessage(BaseModel):
pilot_stamp: str
lines: list[LogLine]
vo: str


class DateRange(BaseModel):
min: str | None = None # expects a string in ISO 8601 ("%Y-%m-%dT%H:%M:%S.%f%z")
max: str | None = None # expects a string in ISO 8601 ("%Y-%m-%dT%H:%M:%S.%f%z")


@router.post("/")
async def send_message(
data: LogMessage,
pilot_logs_db: PilotLogsDB,
check_permissions: CheckPilotLogsPolicyCallable,
) -> int:

logger.warning(f"Message received '{data}'")
user_info = await check_permissions(action=ActionType.CREATE)
pilot_id = 0 # need to get pilot id from pilot_stamp (via PilotAgentsDB)
# also add a timestamp to be able to select and delete logs based on pilot creation dates, even if corresponding
# pilots have been already deleted from PilotAgentsDB (so the logs can live longer than pilots).
submission_time = datetime.datetime.fromtimestamp(0, datetime.timezone.utc)
pilot_agents_db = BaseSQLDB.available_implementations("PilotAgentsDB")[0]
url = BaseSQLDB.available_urls()["PilotAgentsDB"]
db = pilot_agents_db(url)

try:
async with db.engine_context():
async with db:
stmt = select(PilotAgents.pilot_id, PilotAgents.submission_time).where(
PilotAgents.pilot_stamp == data.pilot_stamp
)
pilot_id, submission_time = (await db.conn.execute(stmt)).one()
except NoResultFound as exc:
logger.error(
f"Cannot determine PilotID for requested PilotStamp: {data.pilot_stamp}, Error: {exc}."
)
raise HTTPException(status.HTTP_400_BAD_REQUEST, detail=str(exc)) from exc

docs = []
for line in data.lines:
docs.append(
{
"PilotStamp": data.pilot_stamp,
"PilotID": pilot_id,
"SubmissionTime": submission_time,
"VO": user_info.vo,
"LineNumber": line.line_no,
"Message": line.line,
}
)
await pilot_logs_db.bulk_insert(pilot_logs_db.index_name(pilot_id), docs)
return pilot_id


@router.get("/logs")
async def get_logs(
pilot_id: int,
db: PilotLogsDB,
check_permissions: CheckPilotLogsPolicyCallable,
) -> list[dict]:

logger.warning(f"Retrieving logs for pilot ID '{pilot_id}'")
user_info = await check_permissions(action=ActionType.QUERY)

# here, users with privileged properties will see logs from all VOs. Is it what we want ?
search_params = [{"parameter": "PilotID", "operator": "eq", "value": pilot_id}]
if _non_privileged(user_info):
search_params.append(
{"parameter": "VO", "operator": "eq", "value": user_info.vo}
)
result = await db.search(
["Message"],
search_params,
[{"parameter": "LineNumber", "direction": "asc"}],
)
if not result:
return [{"Message": f"No logs for pilot ID = {pilot_id}"}]
return result


@router.delete("/logs")
async def delete(
pilot_id: int,
data: DateRange,
db: PilotLogsDB,
check_permissions: CheckPilotLogsPolicyCallable,
) -> str:
"""Delete either logs for a specific PilotID or a creation date range.
Non-privileged users can only delete log files within their own VO.
"""
message = "no-op"
user_info = await check_permissions(action=ActionType.DELETE)
non_privil_params = {"parameter": "VO", "operator": "eq", "value": user_info.vo}

# id pilot_id is provided we ignore data.min and data.max
if data.min and data.max and not pilot_id:
raise InvalidQueryError(
"This query requires a range operator definition in DiracX"
)

if pilot_id:
search_params = [{"parameter": "PilotID", "operator": "eq", "value": pilot_id}]
if _non_privileged(user_info):
search_params.append(non_privil_params)
await db.delete(search_params)
message = f"Logs for pilot ID '{pilot_id}' successfully deleted"

elif data.min:
logger.warning(f"Deleting logs for pilots with submission data >='{data.min}'")
search_params = [
{"parameter": "SubmissionTime", "operator": "gt", "value": data.min}
]
if _non_privileged(user_info):
search_params.append(non_privil_params)
await db.delete(search_params)
message = f"Logs for for pilots with submission data >='{data.min}' successfully deleted"

return message


def _non_privileged(user_info: AuthorizedUserInfo):
return (
SERVICE_ADMINISTRATOR not in user_info.properties
and OPERATOR not in user_info.properties
)
61 changes: 61 additions & 0 deletions diracx-routers/tests/pilots/test_access_policies.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,61 @@
from __future__ import annotations

from contextlib import nullcontext
from unittest.mock import MagicMock

import pytest
from fastapi import HTTPException

from diracx.core.properties import (
GENERIC_PILOT,
NORMAL_USER,
OPERATOR,
PILOT,
SERVICE_ADMINISTRATOR,
)
from diracx.routers.pilots.access_policies import (
ActionType,
PilotLogsAccessPolicy,
)


@pytest.mark.parametrize(
"user, action, expectation",
[
(PILOT, ActionType.CREATE, nullcontext()),
(PILOT, ActionType.QUERY, pytest.raises(HTTPException, match="403")),
(PILOT, ActionType.DELETE, pytest.raises(HTTPException, match="403")),
(GENERIC_PILOT, ActionType.CREATE, nullcontext()),
(GENERIC_PILOT, ActionType.QUERY, pytest.raises(HTTPException, match="403")),
(GENERIC_PILOT, ActionType.DELETE, pytest.raises(HTTPException, match="403")),
(SERVICE_ADMINISTRATOR, ActionType.CREATE, nullcontext()),
(SERVICE_ADMINISTRATOR, ActionType.QUERY, nullcontext()),
(SERVICE_ADMINISTRATOR, ActionType.DELETE, nullcontext()),
(OPERATOR, ActionType.CREATE, nullcontext()),
(OPERATOR, ActionType.QUERY, nullcontext()),
(OPERATOR, ActionType.DELETE, nullcontext()),
(NORMAL_USER, ActionType.CREATE, pytest.raises(HTTPException, match="403")),
(NORMAL_USER, ActionType.QUERY, nullcontext()),
(NORMAL_USER, ActionType.DELETE, pytest.raises(HTTPException, match="403")),
(
"malicious_user",
ActionType.CREATE,
pytest.raises(HTTPException, match="403"),
),
("malicious_user", ActionType.QUERY, pytest.raises(HTTPException, match="403")),
(
"malicious_user",
ActionType.DELETE,
pytest.raises(HTTPException, match="403"),
),
("any_user", None, pytest.raises(HTTPException, match="400")),
],
)
async def test_access_policies(user, action, expectation):
user_info = MagicMock()
user_info.properties = [user]
with expectation:
ret = await PilotLogsAccessPolicy.policy(
"PilotLogsAccessPolicy", user_info, action=action
)
assert user in ret.properties
Loading

0 comments on commit e9ad537

Please sign in to comment.