From e9ad537e61812b2c8c099ba76e5082ce02e04cb9 Mon Sep 17 00:00:00 2001 From: martynia Date: Mon, 6 Jan 2025 15:56:06 +0100 Subject: [PATCH] fix: refactoring pilot logging code --- diracx-routers/pyproject.toml | 4 +- .../src/diracx/routers/pilots/__init__.py | 11 ++ .../diracx/routers/pilots/access_policies.py | 63 +++++++ .../src/diracx/routers/pilots/logging.py | 157 ++++++++++++++++++ .../tests/pilots/test_access_policies.py | 61 +++++++ .../tests/pilots/test_pilot_logger.py | 100 +++++++++++ 6 files changed, 394 insertions(+), 2 deletions(-) create mode 100644 diracx-routers/src/diracx/routers/pilots/__init__.py create mode 100644 diracx-routers/src/diracx/routers/pilots/access_policies.py create mode 100644 diracx-routers/src/diracx/routers/pilots/logging.py create mode 100644 diracx-routers/tests/pilots/test_access_policies.py create mode 100644 diracx-routers/tests/pilots/test_pilot_logger.py diff --git a/diracx-routers/pyproject.toml b/diracx-routers/pyproject.toml index c444fc3c..e76cda92 100644 --- a/diracx-routers/pyproject.toml +++ b/diracx-routers/pyproject.toml @@ -48,7 +48,7 @@ types = [ ] [project.entry-points."diracx.services"] -pilotlogs = "diracx.routers.pilot_logging.remote_logger:router" +pilots = "diracx.routers.pilots:router" jobs = "diracx.routers.jobs:router" config = "diracx.routers.configuration:router" auth = "diracx.routers.auth:router" @@ -57,7 +57,7 @@ auth = "diracx.routers.auth:router" [project.entry-points."diracx.access_policies"] WMSAccessPolicy = "diracx.routers.jobs.access_policies:WMSAccessPolicy" SandboxAccessPolicy = "diracx.routers.jobs.access_policies:SandboxAccessPolicy" -PilotLogsAccessPolicy = "diracx.routers.pilot_logging.access_policies:PilotLogsAccessPolicy" +PilotLogsAccessPolicy = "diracx.routers.pilots.access_policies:PilotLogsAccessPolicy" # Minimum version of the client supported [project.entry-points."diracx.min_client_version"] diff --git a/diracx-routers/src/diracx/routers/pilots/__init__.py 
class ActionType(StrEnum):
    """Operations on pilot logs that the access policy can authorise."""

    #: Create/update pilot log records
    CREATE = auto()
    #: Delete pilot logs
    DELETE = auto()
    #: Search
    QUERY = auto()


class PilotLogsAccessPolicy(BaseAccessPolicy):
    """Access policy for pilot log records.

    Rules, evaluated in this order:

    * GENERIC_PILOT or PILOT may CREATE log records.
    * NORMAL_USER may QUERY log records.
    * SERVICE_ADMINISTRATOR and OPERATOR may perform any action.

    Every other combination is rejected with HTTP 403. Calling the policy
    without an action is rejected with HTTP 400.
    Policies for other actions are still to be determined.
    """

    @staticmethod
    async def policy(
        policy_name: str,
        user_info: AuthorizedUserInfo,
        /,
        *,
        action: ActionType | None = None,
    ):
        """Return ``user_info`` when its properties allow ``action``.

        Raises:
            HTTPException: 400 if ``action`` is missing, 403 if none of the
                user's properties grants the requested action.
        """
        if action is None:
            raise HTTPException(
                status.HTTP_400_BAD_REQUEST, detail="Action is a mandatory argument"
            )

        granted = user_info.properties

        # Pilots (generic or not) are only allowed to push log records.
        if action == ActionType.CREATE and (
            GENERIC_PILOT in granted or PILOT in granted
        ):
            return user_info
        # Ordinary users are only allowed to read.
        if action == ActionType.QUERY and NORMAL_USER in granted:
            return user_info
        # Administrators and operators are allowed everything.
        if SERVICE_ADMINISTRATOR in granted or OPERATOR in granted:
            return user_info

        raise HTTPException(status.HTTP_403_FORBIDDEN, detail=granted)


# FastAPI dependency: routes receive the policy's ``check`` callable and must
# await it with the ``action`` they are about to perform.
CheckPilotLogsPolicyCallable = Annotated[Callable, Depends(PilotLogsAccessPolicy.check)]
@router.post("/")
async def send_message(
    data: LogMessage,
    pilot_logs_db: PilotLogsDB,
    check_permissions: CheckPilotLogsPolicyCallable,
) -> int:
    """Store the log lines sent by a pilot and return that pilot's ID.

    The pilot ID and submission time are looked up in PilotAgentsDB from
    ``data.pilot_stamp``. One document per log line is then bulk-inserted into
    the pilot logs DB. The VO stored with each line is the caller's
    (``user_info.vo``), not ``data.vo``.

    Raises:
        HTTPException: 400 when no pilot matches ``data.pilot_stamp``.
    """
    logger.warning(f"Message received '{data}'")
    user_info = await check_permissions(action=ActionType.CREATE)

    # The pilot ID has to be derived from the pilot stamp via PilotAgentsDB,
    # which is resolved by hand from the BaseSQLDB registries.
    agents_db_cls = BaseSQLDB.available_implementations("PilotAgentsDB")[0]
    agents_db = agents_db_cls(BaseSQLDB.available_urls()["PilotAgentsDB"])

    try:
        async with agents_db.engine_context(), agents_db:
            lookup = select(
                PilotAgents.pilot_id, PilotAgents.submission_time
            ).where(PilotAgents.pilot_stamp == data.pilot_stamp)
            pilot_id, submission_time = (await agents_db.conn.execute(lookup)).one()
    except NoResultFound as exc:
        logger.error(
            f"Cannot determine PilotID for requested PilotStamp: {data.pilot_stamp}, Error: {exc}."
        )
        raise HTTPException(status.HTTP_400_BAD_REQUEST, detail=str(exc)) from exc

    # The submission time is stored with every line so that logs can be
    # selected and deleted by pilot creation date even after the pilot itself
    # has been removed from PilotAgentsDB (logs may outlive pilots).
    documents = [
        {
            "PilotStamp": data.pilot_stamp,
            "PilotID": pilot_id,
            "SubmissionTime": submission_time,
            "VO": user_info.vo,
            "LineNumber": entry.line_no,
            "Message": entry.line,
        }
        for entry in data.lines
    ]
    await pilot_logs_db.bulk_insert(pilot_logs_db.index_name(pilot_id), documents)
    return pilot_id
@router.get("/logs")
async def get_logs(
    pilot_id: int,
    db: PilotLogsDB,
    check_permissions: CheckPilotLogsPolicyCallable,
) -> list[dict]:
    """Return the log messages of one pilot, ordered by line number.

    Non-privileged callers only get records from their own VO. When nothing
    matches, a single placeholder record saying so is returned instead of an
    empty list.
    """
    logger.warning(f"Retrieving logs for pilot ID '{pilot_id}'")
    user_info = await check_permissions(action=ActionType.QUERY)

    filters = [{"parameter": "PilotID", "operator": "eq", "value": pilot_id}]
    # Users with privileged properties see logs from all VOs (open question
    # from the author: is that what we want?); everyone else is restricted to
    # their own VO.
    if _non_privileged(user_info):
        filters += [{"parameter": "VO", "operator": "eq", "value": user_info.vo}]

    ordering = [{"parameter": "LineNumber", "direction": "asc"}]
    records = await db.search(["Message"], filters, ordering)
    return records or [{"Message": f"No logs for pilot ID = {pilot_id}"}]
@router.delete("/logs")
async def delete(
    pilot_id: int,
    data: DateRange,
    db: PilotLogsDB,
    check_permissions: CheckPilotLogsPolicyCallable,
) -> str:
    """Delete the logs of one pilot, or of all pilots submitted after a date.

    ``pilot_id`` is a required integer, so the date-based branch is only
    reached when it is falsy (0):

    * truthy ``pilot_id``: delete that pilot's logs; ``data`` is ignored.
    * ``data.min`` only: delete logs whose SubmissionTime is strictly greater
      than ``data.min``.
    * ``data.min`` and ``data.max`` together: not supported, raises.
    * anything else: nothing is deleted.

    Non-privileged users can only delete logs within their own VO.

    Returns:
        A human-readable summary of what was deleted, or ``"no-op"``.

    Raises:
        InvalidQueryError: when a closed date range is requested.
    """
    user_info = await check_permissions(action=ActionType.DELETE)

    # If pilot_id is provided, data.min and data.max are ignored.
    if data.min and data.max and not pilot_id:
        raise InvalidQueryError(
            "This query requires a range operator definition in DiracX"
        )

    if pilot_id:
        search_params = [{"parameter": "PilotID", "operator": "eq", "value": pilot_id}]
        message = f"Logs for pilot ID '{pilot_id}' successfully deleted"
    elif data.min:
        # The filter uses the strict "gt" operator, so the messages say ">"
        # rather than ">=".
        logger.warning(
            f"Deleting logs for pilots with submission date > '{data.min}'"
        )
        search_params = [
            {"parameter": "SubmissionTime", "operator": "gt", "value": data.min}
        ]
        message = (
            f"Logs for pilots with submission date > '{data.min}' successfully deleted"
        )
    else:
        return "no-op"

    if _non_privileged(user_info):
        search_params.append(
            {"parameter": "VO", "operator": "eq", "value": user_info.vo}
        )
    await db.delete(search_params)
    return message


def _non_privileged(user_info: AuthorizedUserInfo) -> bool:
    """Return True when the user is neither SERVICE_ADMINISTRATOR nor OPERATOR."""
    return not (
        SERVICE_ADMINISTRATOR in user_info.properties
        or OPERATOR in user_info.properties
    )
def _denied():
    """Return a fresh context expecting the policy to answer HTTP 403."""
    return pytest.raises(HTTPException, match="403")


@pytest.mark.parametrize(
    "user, action, expectation",
    [
        # Pilots may only create log records.
        (PILOT, ActionType.CREATE, nullcontext()),
        (PILOT, ActionType.QUERY, _denied()),
        (PILOT, ActionType.DELETE, _denied()),
        (GENERIC_PILOT, ActionType.CREATE, nullcontext()),
        (GENERIC_PILOT, ActionType.QUERY, _denied()),
        (GENERIC_PILOT, ActionType.DELETE, _denied()),
        # Administrators and operators may do everything.
        (SERVICE_ADMINISTRATOR, ActionType.CREATE, nullcontext()),
        (SERVICE_ADMINISTRATOR, ActionType.QUERY, nullcontext()),
        (SERVICE_ADMINISTRATOR, ActionType.DELETE, nullcontext()),
        (OPERATOR, ActionType.CREATE, nullcontext()),
        (OPERATOR, ActionType.QUERY, nullcontext()),
        (OPERATOR, ActionType.DELETE, nullcontext()),
        # Ordinary users may only query.
        (NORMAL_USER, ActionType.CREATE, _denied()),
        (NORMAL_USER, ActionType.QUERY, nullcontext()),
        (NORMAL_USER, ActionType.DELETE, _denied()),
        # Unknown properties grant nothing.
        ("malicious_user", ActionType.CREATE, _denied()),
        ("malicious_user", ActionType.QUERY, _denied()),
        ("malicious_user", ActionType.DELETE, _denied()),
        # The action is mandatory.
        ("any_user", None, pytest.raises(HTTPException, match="400")),
    ],
)
async def test_access_policies(user, action, expectation):
    """Check which (property, action) pairs PilotLogsAccessPolicy accepts."""
    user_info = MagicMock(properties=[user])
    with expectation:
        granted = await PilotLogsAccessPolicy.policy(
            "PilotLogsAccessPolicy", user_info, action=action
        )
        assert user in granted.properties
@pytest.fixture
async def pilot_agents_db() -> PilotAgentsDB:
    """Yield a PilotAgentsDB backed by an in-memory SQLite DB with its schema created."""
    agents_db = PilotAgentsDB("sqlite+aiosqlite:///:memory:")
    async with agents_db.engine_context():
        async with agents_db.engine.begin() as conn:
            await conn.run_sync(agents_db.metadata.create_all)
        yield agents_db


@pytest.fixture
async def pilot_logs_db():
    """Yield a PilotLogsDB whose OpenSearch backend is replaced by SQLite."""
    # Build a class that keeps the PilotLogsDB definition but takes its
    # storage behaviour from the SQLite-backed mock mixin.
    mocked_cls = type("MockPilotLogsDB", (MockOSDBMixin, PilotLogsDB), {})

    db = mocked_cls(
        connection_kwargs={"sqlalchemy_dsn": "sqlite+aiosqlite:///:memory:"}
    )
    async with db.client_context():
        await db.create_index_template()
        yield db


@patch("diracx.routers.pilots.logging.BaseSQLDB.available_implementations")
@patch("diracx.routers.pilots.logging.BaseSQLDB.available_urls")
async def test_logging(
    mock_url, mock_impl, pilot_logs_db: "PilotLogsDB", pilot_agents_db: PilotAgentsDB
):
    """Send log lines for one pilot through the route and read them back."""
    n_pilots = 5
    pilot_numbers = range(1, n_pilots + 1)

    async with pilot_agents_db as db:
        # Register the pilots: ref_i <-> stamp_i.
        refs = [f"ref_{i}" for i in pilot_numbers]
        stamps = [f"stamp_{i}" for i in pilot_numbers]
        await db.add_pilot_references(
            refs, "test_vo", grid_type="DIRAC", pilot_stamps=dict(zip(refs, stamps))
        )
        tables = await db.conn.run_sync(
            lambda sync_conn: inspect(sync_conn).get_table_names()
        )
        assert "PilotAgents" in tables

        # Move the submission times back in time (1h, 3h, 5h, ... ago).
        # NOTE(review): the route module addresses these columns as
        # PilotAgents.pilot_stamp / PilotAgents.submission_time - confirm which
        # attribute names the schema actually exposes.
        now = datetime.now(tz=timezone.utc)
        for i, stamp in enumerate(stamps, start=1):
            stmt = (
                update(PilotAgents)
                .where(PilotAgents.PilotStamp == stamp)
                .values(SubmissionTime=now - timedelta(hours=2 * i - 1))
            )
            await db.conn.execute(stmt)
    # The DB context is closed here: send_message opens its own context on the
    # same PilotAgentsDB object.

    # 3 message records for the first pilot.
    expected_records = [{"Message": f"Message_no_{i}"} for i in range(1, 4)]
    log_lines = [
        LogLine(line_no=line_no, line=record["Message"])
        for line_no, record in enumerate(expected_records, start=1)
    ]
    message = LogMessage(pilot_stamp="stamp_1", lines=log_lines, vo="gridpp")

    check_permissions_mock = AsyncMock()
    check_permissions_mock.return_value.vo = "gridpp"
    # TODO add user properties dict return_value above

    # Make the route resolve "PilotAgentsDB" to the fixture instance. Its
    # engine context already exists, so it is replaced by a no-op context.
    mock_url.return_value = {"PilotAgentsDB": "sqlite+aiosqlite:///:memory:"}
    pilot_agents_db.engine_context = nullcontext
    mock_impl.return_value = [lambda x: pilot_agents_db]

    # send logs for stamp_1, pilot id = 1
    pilot_id = await send_message(message, pilot_logs_db, check_permissions_mock)
    assert pilot_id == 1

    # get logs for pilot_id=1
    log_records = await get_logs(pilot_id, pilot_logs_db, check_permissions_mock)
    assert log_records == expected_records

    # delete logs for pilot_id = 1
    check_permissions_mock.return_value.properties = [PILOT]
    # TODO: await mock_osdb delete implementation...
    # res = await delete(pilot_id, DateRange(), pilot_logs_db, check_permissions_mock)