-
Notifications
You must be signed in to change notification settings - Fork 21
Commit
This commit does not belong to any branch on this repository, and may belong to a fork outside of the repository.
- Loading branch information
Showing
6 changed files
with
394 additions
and
2 deletions.
There are no files selected for viewing
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,11 @@ | ||
from __future__ import annotations | ||
|
||
from logging import getLogger | ||
|
||
from ..fastapi_classes import DiracxRouter | ||
from .logging import router as logging_router | ||
|
||
logger = getLogger(__name__) | ||
|
||
router = DiracxRouter() | ||
router.include_router(logging_router) |
63 changes: 63 additions & 0 deletions
63
diracx-routers/src/diracx/routers/pilots/access_policies.py
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,63 @@ | ||
from __future__ import annotations | ||
|
||
from enum import StrEnum, auto | ||
from typing import Annotated, Callable | ||
|
||
from fastapi import Depends, HTTPException, status | ||
|
||
from diracx.core.properties import ( | ||
GENERIC_PILOT, | ||
NORMAL_USER, | ||
OPERATOR, | ||
PILOT, | ||
SERVICE_ADMINISTRATOR, | ||
) | ||
from diracx.routers.access_policies import BaseAccessPolicy | ||
|
||
from ..utils.users import AuthorizedUserInfo | ||
|
||
|
||
class ActionType(StrEnum): | ||
#: Create/update pilot log records | ||
CREATE = auto() | ||
#: delete pilot logs | ||
DELETE = auto() | ||
#: Search | ||
QUERY = auto() | ||
|
||
|
||
class PilotLogsAccessPolicy(BaseAccessPolicy): | ||
"""Rules: | ||
Only PILOT, GENERIC_PILOT, SERVICE_ADMINISTRATOR and OPERATOR can process log records. | ||
Policies for other actions to be determined. | ||
""" | ||
|
||
@staticmethod | ||
async def policy( | ||
policy_name: str, | ||
user_info: AuthorizedUserInfo, | ||
/, | ||
*, | ||
action: ActionType | None = None, | ||
): | ||
|
||
if action is None: | ||
raise HTTPException( | ||
status.HTTP_400_BAD_REQUEST, detail="Action is a mandatory argument" | ||
) | ||
|
||
if GENERIC_PILOT in user_info.properties and action == ActionType.CREATE: | ||
return user_info | ||
if PILOT in user_info.properties and action == ActionType.CREATE: | ||
return user_info | ||
if NORMAL_USER in user_info.properties and action == ActionType.QUERY: | ||
return user_info | ||
if SERVICE_ADMINISTRATOR in user_info.properties: | ||
return user_info | ||
if OPERATOR in user_info.properties: | ||
return user_info | ||
|
||
raise HTTPException(status.HTTP_403_FORBIDDEN, detail=user_info.properties) | ||
|
||
|
||
CheckPilotLogsPolicyCallable = Annotated[Callable, Depends(PilotLogsAccessPolicy.check)] |
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,157 @@ | ||
from __future__ import annotations | ||
|
||
import datetime | ||
import logging | ||
|
||
from fastapi import HTTPException, status | ||
from pydantic import BaseModel | ||
from sqlalchemy import select | ||
from sqlalchemy.exc import NoResultFound | ||
|
||
from diracx.core.exceptions import InvalidQueryError | ||
from diracx.core.properties import OPERATOR, SERVICE_ADMINISTRATOR | ||
from diracx.db.sql.pilot_agents.schema import PilotAgents | ||
from diracx.db.sql.utils import BaseSQLDB | ||
|
||
from ..dependencies import PilotLogsDB | ||
from ..fastapi_classes import DiracxRouter | ||
from ..utils.users import AuthorizedUserInfo | ||
from .access_policies import ActionType, CheckPilotLogsPolicyCallable | ||
|
||
logger = logging.getLogger(__name__) | ||
router = DiracxRouter() | ||
|
||
|
||
class LogLine(BaseModel): | ||
line_no: int | ||
line: str | ||
|
||
|
||
class LogMessage(BaseModel): | ||
pilot_stamp: str | ||
lines: list[LogLine] | ||
vo: str | ||
|
||
|
||
class DateRange(BaseModel): | ||
min: str | None = None # expects a string in ISO 8601 ("%Y-%m-%dT%H:%M:%S.%f%z") | ||
max: str | None = None # expects a string in ISO 8601 ("%Y-%m-%dT%H:%M:%S.%f%z") | ||
|
||
|
||
@router.post("/") | ||
async def send_message( | ||
data: LogMessage, | ||
pilot_logs_db: PilotLogsDB, | ||
check_permissions: CheckPilotLogsPolicyCallable, | ||
) -> int: | ||
|
||
logger.warning(f"Message received '{data}'") | ||
user_info = await check_permissions(action=ActionType.CREATE) | ||
pilot_id = 0 # need to get pilot id from pilot_stamp (via PilotAgentsDB) | ||
# also add a timestamp to be able to select and delete logs based on pilot creation dates, even if corresponding | ||
# pilots have been already deleted from PilotAgentsDB (so the logs can live longer than pilots). | ||
submission_time = datetime.datetime.fromtimestamp(0, datetime.timezone.utc) | ||
pilot_agents_db = BaseSQLDB.available_implementations("PilotAgentsDB")[0] | ||
url = BaseSQLDB.available_urls()["PilotAgentsDB"] | ||
db = pilot_agents_db(url) | ||
|
||
try: | ||
async with db.engine_context(): | ||
async with db: | ||
stmt = select(PilotAgents.pilot_id, PilotAgents.submission_time).where( | ||
PilotAgents.pilot_stamp == data.pilot_stamp | ||
) | ||
pilot_id, submission_time = (await db.conn.execute(stmt)).one() | ||
except NoResultFound as exc: | ||
logger.error( | ||
f"Cannot determine PilotID for requested PilotStamp: {data.pilot_stamp}, Error: {exc}." | ||
) | ||
raise HTTPException(status.HTTP_400_BAD_REQUEST, detail=str(exc)) from exc | ||
|
||
docs = [] | ||
for line in data.lines: | ||
docs.append( | ||
{ | ||
"PilotStamp": data.pilot_stamp, | ||
"PilotID": pilot_id, | ||
"SubmissionTime": submission_time, | ||
"VO": user_info.vo, | ||
"LineNumber": line.line_no, | ||
"Message": line.line, | ||
} | ||
) | ||
await pilot_logs_db.bulk_insert(pilot_logs_db.index_name(pilot_id), docs) | ||
return pilot_id | ||
|
||
|
||
@router.get("/logs") | ||
async def get_logs( | ||
pilot_id: int, | ||
db: PilotLogsDB, | ||
check_permissions: CheckPilotLogsPolicyCallable, | ||
) -> list[dict]: | ||
|
||
logger.warning(f"Retrieving logs for pilot ID '{pilot_id}'") | ||
user_info = await check_permissions(action=ActionType.QUERY) | ||
|
||
# here, users with privileged properties will see logs from all VOs. Is it what we want ? | ||
search_params = [{"parameter": "PilotID", "operator": "eq", "value": pilot_id}] | ||
if _non_privileged(user_info): | ||
search_params.append( | ||
{"parameter": "VO", "operator": "eq", "value": user_info.vo} | ||
) | ||
result = await db.search( | ||
["Message"], | ||
search_params, | ||
[{"parameter": "LineNumber", "direction": "asc"}], | ||
) | ||
if not result: | ||
return [{"Message": f"No logs for pilot ID = {pilot_id}"}] | ||
return result | ||
|
||
|
||
@router.delete("/logs") | ||
async def delete( | ||
pilot_id: int, | ||
data: DateRange, | ||
db: PilotLogsDB, | ||
check_permissions: CheckPilotLogsPolicyCallable, | ||
) -> str: | ||
"""Delete either logs for a specific PilotID or a creation date range. | ||
Non-privileged users can only delete log files within their own VO. | ||
""" | ||
message = "no-op" | ||
user_info = await check_permissions(action=ActionType.DELETE) | ||
non_privil_params = {"parameter": "VO", "operator": "eq", "value": user_info.vo} | ||
|
||
# id pilot_id is provided we ignore data.min and data.max | ||
if data.min and data.max and not pilot_id: | ||
raise InvalidQueryError( | ||
"This query requires a range operator definition in DiracX" | ||
) | ||
|
||
if pilot_id: | ||
search_params = [{"parameter": "PilotID", "operator": "eq", "value": pilot_id}] | ||
if _non_privileged(user_info): | ||
search_params.append(non_privil_params) | ||
await db.delete(search_params) | ||
message = f"Logs for pilot ID '{pilot_id}' successfully deleted" | ||
|
||
elif data.min: | ||
logger.warning(f"Deleting logs for pilots with submission data >='{data.min}'") | ||
search_params = [ | ||
{"parameter": "SubmissionTime", "operator": "gt", "value": data.min} | ||
] | ||
if _non_privileged(user_info): | ||
search_params.append(non_privil_params) | ||
await db.delete(search_params) | ||
message = f"Logs for for pilots with submission data >='{data.min}' successfully deleted" | ||
|
||
return message | ||
|
||
|
||
def _non_privileged(user_info: AuthorizedUserInfo): | ||
return ( | ||
SERVICE_ADMINISTRATOR not in user_info.properties | ||
and OPERATOR not in user_info.properties | ||
) |
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,61 @@ | ||
from __future__ import annotations | ||
|
||
from contextlib import nullcontext | ||
from unittest.mock import MagicMock | ||
|
||
import pytest | ||
from fastapi import HTTPException | ||
|
||
from diracx.core.properties import ( | ||
GENERIC_PILOT, | ||
NORMAL_USER, | ||
OPERATOR, | ||
PILOT, | ||
SERVICE_ADMINISTRATOR, | ||
) | ||
from diracx.routers.pilots.access_policies import ( | ||
ActionType, | ||
PilotLogsAccessPolicy, | ||
) | ||
|
||
|
||
@pytest.mark.parametrize( | ||
"user, action, expectation", | ||
[ | ||
(PILOT, ActionType.CREATE, nullcontext()), | ||
(PILOT, ActionType.QUERY, pytest.raises(HTTPException, match="403")), | ||
(PILOT, ActionType.DELETE, pytest.raises(HTTPException, match="403")), | ||
(GENERIC_PILOT, ActionType.CREATE, nullcontext()), | ||
(GENERIC_PILOT, ActionType.QUERY, pytest.raises(HTTPException, match="403")), | ||
(GENERIC_PILOT, ActionType.DELETE, pytest.raises(HTTPException, match="403")), | ||
(SERVICE_ADMINISTRATOR, ActionType.CREATE, nullcontext()), | ||
(SERVICE_ADMINISTRATOR, ActionType.QUERY, nullcontext()), | ||
(SERVICE_ADMINISTRATOR, ActionType.DELETE, nullcontext()), | ||
(OPERATOR, ActionType.CREATE, nullcontext()), | ||
(OPERATOR, ActionType.QUERY, nullcontext()), | ||
(OPERATOR, ActionType.DELETE, nullcontext()), | ||
(NORMAL_USER, ActionType.CREATE, pytest.raises(HTTPException, match="403")), | ||
(NORMAL_USER, ActionType.QUERY, nullcontext()), | ||
(NORMAL_USER, ActionType.DELETE, pytest.raises(HTTPException, match="403")), | ||
( | ||
"malicious_user", | ||
ActionType.CREATE, | ||
pytest.raises(HTTPException, match="403"), | ||
), | ||
("malicious_user", ActionType.QUERY, pytest.raises(HTTPException, match="403")), | ||
( | ||
"malicious_user", | ||
ActionType.DELETE, | ||
pytest.raises(HTTPException, match="403"), | ||
), | ||
("any_user", None, pytest.raises(HTTPException, match="400")), | ||
], | ||
) | ||
async def test_access_policies(user, action, expectation): | ||
user_info = MagicMock() | ||
user_info.properties = [user] | ||
with expectation: | ||
ret = await PilotLogsAccessPolicy.policy( | ||
"PilotLogsAccessPolicy", user_info, action=action | ||
) | ||
assert user in ret.properties |
Oops, something went wrong.