diff --git a/CHANGELOG.md b/CHANGELOG.md index cfe5fa78f..0c4d4410a 100644 --- a/CHANGELOG.md +++ b/CHANGELOG.md @@ -48,6 +48,8 @@ have an authority field matching that of the user - CLI: list cli usage strings in alphabetical order - Helm: Fix clickhouse version - Helm: improve volumes and ingress configurations +- API: Add `RALPH_LRS_RESTRICT_BY_SCOPE` option enabling endpoint access + control by user scopes ### Fixed diff --git a/docs/api.md b/docs/api.md index 652405e2e..3e62c6a77 100644 --- a/docs/api.md +++ b/docs/api.md @@ -178,7 +178,7 @@ By default, all authenticated users have full read and write access to the serve ### Filtering results by authority (multitenancy) -In Ralph, all incoming statements are assigned an `authority` (or ownership) derived from the user that makes the call. You may restrict read access to users "own" statements (thus enabling multitenancy) by setting the following environment variable: +In Ralph LRS, all incoming statements are assigned an `authority` (or ownership) derived from the user that makes the call. You may restrict read access to users "own" statements (thus enabling multitenancy) by setting the following environment variable: ``` RALPH_LRS_RESTRICT_BY_AUTHORITY = True # Default: False @@ -190,7 +190,27 @@ NB: If not using "scopes", or for users with limited "scopes", using this option #### Scopes -(Work In Progress) +In Ralph, users are assigned scopes which may be used to restrict endpoint access or +functionalities. You may enable this option by setting the following environment variable: + +``` +RALPH_LRS_RESTRICT_BY_SCOPES = True # Default: False +``` + +Valid scopes are a slight variation on those proposed by the +[xAPI specification](https://github.com/adlnet/xAPI-Spec/blob/master/xAPI-Communication.md#details-15): + + +- statements/write +- statements/read/mine +- statements/read +- state/write +- state/read +- define +- profile/write +- profile/read +- all/read +- all ## Forwarding statements diff --git a/setup.cfg b/setup.cfg index b232ec5c4..472414dc5 100644 --- a/setup.cfg +++ b/setup.cfg @@ -142,7 +142,7 @@ match = ^(?!(setup)\.(py)$).*\.(py)$ [isort] known_ralph=ralph sections=FUTURE,STDLIB,THIRDPARTY,RALPH,FIRSTPARTY,LOCALFOLDER -skip_glob=venv +skip_glob=venv,*/.conda/* profile=black [tool:pytest] diff --git a/src/ralph/api/auth/__init__.py b/src/ralph/api/auth/__init__.py index f5e80b737..80aa52fff 100644 --- a/src/ralph/api/auth/__init__.py +++ b/src/ralph/api/auth/__init__.py @@ -1,12 +1,11 @@ """Main module for Ralph's LRS API authentication.""" -from ralph.api.auth.basic import get_authenticated_user as get_basic_user -from ralph.api.auth.oidc import get_authenticated_user as get_oidc_user +from ralph.api.auth.basic import get_basic_auth_user +from ralph.api.auth.oidc import get_oidc_user from ralph.conf import settings # At startup, select the authentication mode that will be used -get_authenticated_user = ( - get_oidc_user - if settings.RUNSERVER_AUTH_BACKEND == settings.AuthBackends.OIDC - else get_basic_user -) +if settings.RUNSERVER_AUTH_BACKEND == settings.AuthBackends.OIDC: + get_authenticated_user = get_oidc_user +else: + get_authenticated_user = get_basic_auth_user diff --git a/src/ralph/api/auth/basic.py b/src/ralph/api/auth/basic.py index ddabb5add..04dfcce59 100644 --- a/src/ralph/api/auth/basic.py +++ b/src/ralph/api/auth/basic.py @@ -9,7 +9,7 @@ import bcrypt from cachetools import TTLCache, cached from fastapi import Depends, HTTPException, status -from fastapi.security import HTTPBasic, HTTPBasicCredentials +from fastapi.security import HTTPBasic, HTTPBasicCredentials, SecurityScopes from pydantic import BaseModel, root_validator from starlette.authentication import AuthenticationError @@ -102,15 +102,17 @@ def get_stored_credentials(auth_file: Path) -> ServerUsersCredentials: @cached( TTLCache(maxsize=settings.AUTH_CACHE_MAX_SIZE, ttl=settings.AUTH_CACHE_TTL), lock=Lock(), - key=lambda credentials: ( + key=lambda credentials, security_scopes: ( credentials.username, credentials.password, + security_scopes.scope_str, ) if credentials is not None else None, ) -def get_authenticated_user( +def get_basic_auth_user( credentials: Union[HTTPBasicCredentials, None] = Depends(security), + security_scopes: SecurityScopes = SecurityScopes([]), ) -> AuthenticatedUser: """Checks valid auth parameters. @@ -119,13 +121,10 @@ def get_authenticated_user( Args: credentials (iterator): auth parameters from the Authorization header - - Return: - AuthenticatedUser (AuthenticatedUser) + security_scopes: scopes requested for access Raises: HTTPException - """ if not credentials: logger.error("The basic authentication mode requires a Basic Auth header") @@ -156,6 +155,7 @@ def get_authenticated_user( status_code=status.HTTP_403_FORBIDDEN, detail=str(exc) ) from exc + # Check that a password was passed if not hashed_password: # We're doing a bogus password check anyway to avoid timing attacks on # usernames @@ -168,6 +168,7 @@ def get_authenticated_user( headers={"WWW-Authenticate": "Basic"}, ) + # Check password validity if not bcrypt.checkpw( credentials.password.encode(settings.LOCALE_ENCODING), hashed_password.encode(settings.LOCALE_ENCODING), @@ -182,4 +183,15 @@ def get_authenticated_user( headers={"WWW-Authenticate": "Basic"}, ) - return AuthenticatedUser(scopes=user.scopes, agent=user.agent) + user = AuthenticatedUser(scopes=user.scopes, agent=dict(user.agent)) + + # Restrict access by scopes + if settings.LRS_RESTRICT_BY_SCOPES: + for requested_scope in security_scopes.scopes: + if not user.scopes.is_authorized(requested_scope): + raise HTTPException( + status_code=status.HTTP_401_UNAUTHORIZED, + detail=f'Access not authorized to scope: "{requested_scope}".', + headers={"WWW-Authenticate": "Basic"}, + ) + return user diff --git a/src/ralph/api/auth/oidc.py b/src/ralph/api/auth/oidc.py index 423cfbb5c..d4476f3f0 100644 --- a/src/ralph/api/auth/oidc.py +++ b/src/ralph/api/auth/oidc.py @@ -2,16 +2,17 @@ import logging from functools import lru_cache -from typing import Optional, Union +from typing import Optional import requests from fastapi import Depends, HTTPException, status -from fastapi.security import OpenIdConnect +from fastapi.security import OpenIdConnect, SecurityScopes from jose import ExpiredSignatureError, JWTError, jwt from jose.exceptions import JWTClaimsError from pydantic import AnyUrl, BaseModel, Extra +from typing_extensions import Annotated -from ralph.api.auth.user import AuthenticatedUser +from ralph.api.auth.user import AuthenticatedUser, UserScopes from ralph.conf import settings OPENID_CONFIGURATION_PATH = "/.well-known/openid-configuration" @@ -92,8 +93,9 @@ def get_public_keys(jwks_uri: AnyUrl) -> dict: ) from exc -def get_authenticated_user( - auth_header: Union[str, None] = Depends(oauth2_scheme) +def get_oidc_user( + auth_header: Annotated[Optional[str], Depends(oauth2_scheme)], + security_scopes: SecurityScopes = SecurityScopes([]), ) -> AuthenticatedUser: """Decode and validate OpenId Connect ID token against issuer in config. @@ -143,7 +145,19 @@ def get_authenticated_user( id_token = IDToken.parse_obj(decoded_token) - return AuthenticatedUser( - agent={"openid": id_token.sub}, - scopes=id_token.scope.split(" ") if id_token.scope else [], + user = AuthenticatedUser( + agent={"openid": f"{id_token.iss}/{id_token.sub}"}, + scopes=UserScopes(id_token.scope.split(" ") if id_token.scope else []), ) + + # Restrict access by scopes + if settings.LRS_RESTRICT_BY_SCOPES: + for requested_scope in security_scopes.scopes: + if not user.scopes.is_authorized(requested_scope): + raise HTTPException( + status_code=status.HTTP_401_UNAUTHORIZED, + detail=f'Access not authorized to scope: "{requested_scope}".', + headers={"WWW-Authenticate": "Basic"}, + ) + + return user diff --git a/src/ralph/api/auth/user.py b/src/ralph/api/auth/user.py index 6184a7611..9d61f0c4d 100644 --- a/src/ralph/api/auth/user.py +++ b/src/ralph/api/auth/user.py @@ -1,6 +1,7 @@ """Authenticated user for the Ralph API.""" -from typing import Dict, List, Literal +from functools import lru_cache +from typing import Dict, FrozenSet, Literal from pydantic import BaseModel @@ -18,6 +19,50 @@ ] +class UserScopes(FrozenSet[Scope]): + """Scopes available to users.""" + + @lru_cache(maxsize=1024) + def is_authorized(self, requested_scope: Scope): + """Check if the requested scope can be accessed based on user scopes.""" + expanded_scopes = { + "statements/read": {"statements/read/mine", "statements/read"}, + "all/read": { + "statements/read/mine", + "statements/read", + "state/read", + "profile/read", + "all/read", + }, + "all": { + "statements/write", + "statements/read/mine", + "statements/read", + "state/read", + "state/write", + "define", + "profile/read", + "profile/write", + "all/read", + "all", + }, + } + + expanded_user_scopes = set() + for scope in self: + expanded_user_scopes.update(expanded_scopes.get(scope, {scope})) + + return requested_scope in expanded_user_scopes + + @classmethod + def __get_validators__(cls): # noqa: D105 + def validate(value: FrozenSet[Scope]): + """Transform value to an instance of UserScopes.""" + return cls(value) + + yield validate + + class AuthenticatedUser(BaseModel): """Pydantic model for user authentication. @@ -27,4 +72,4 @@ class AuthenticatedUser(BaseModel): """ agent: Dict - scopes: List[Scope] + scopes: UserScopes diff --git a/src/ralph/api/routers/statements.py b/src/ralph/api/routers/statements.py index 49a3433cb..0bfb27212 100644 --- a/src/ralph/api/routers/statements.py +++ b/src/ralph/api/routers/statements.py @@ -15,6 +15,7 @@ Query, Request, Response, + Security, status, ) from fastapi.dependencies.models import Dependant @@ -101,6 +102,7 @@ def _enrich_statement_with_authority(statement: dict, current_user: Authenticate def _parse_agent_parameters(agent_obj: dict): """Parse a dict and return an AgentParameters object to use in queries.""" # Transform agent to `dict` as FastAPI cannot parse JSON (seen as string) + agent = parse_obj_as(BaseXapiAgent, agent_obj) agent_query_params = {} @@ -137,10 +139,12 @@ def strict_query_params(request: Request): @router.get("") @router.get("/") -# pylint: disable=too-many-arguments, too-many-locals async def get( request: Request, - current_user: Annotated[AuthenticatedUser, Depends(get_authenticated_user)], + current_user: Annotated[ + AuthenticatedUser, + Security(get_authenticated_user, scopes=["statements/read/mine"]), + ], ### # Query string parameters defined by the LRS specification ### @@ -170,7 +174,6 @@ async def get( "of the Statement is an Activity with the specified id" ), ), - # pylint: disable=unused-argument registration: Optional[UUID] = Query( None, description=( @@ -178,7 +181,6 @@ async def get( "Filter, only return Statements matching the specified registration id" ), ), - # pylint: disable=unused-argument related_activities: Optional[bool] = Query( False, description=( @@ -189,7 +191,6 @@ async def get( "instead of that parameter's normal behaviour" ), ), - # pylint: disable=unused-argument related_agents: Optional[bool] = Query( False, description=( @@ -221,7 +222,6 @@ async def get( "0 indicates return the maximum the server will allow" ), ), - # pylint: disable=unused-argument, redefined-builtin format: Optional[Literal["ids", "exact", "canonical"]] = Query( "exact", description=( @@ -240,7 +240,6 @@ async def get( 'as in "exact" mode.' ), ), - # pylint: disable=unused-argument attachments: Optional[bool] = Query( False, description=( @@ -286,6 +285,9 @@ async def get( LRS Specification: https://github.com/adlnet/xAPI-Spec/blob/1.0.3/xAPI-Communication.md#213-get-statements """ + # pylint: disable=unused-argument,redefined-builtin,too-many-arguments + # pylint: disable=too-many-locals + # Make sure the limit does not go above max from settings limit = min(limit, settings.RUNSERVER_MAX_SEARCH_HITS_COUNT) @@ -334,14 +336,15 @@ async def get( json.loads(query_params["agent"]) ) - if settings.LRS_RESTRICT_BY_AUTHORITY: - # If using scopes, only restrict results when appropriate - if settings.LRS_RESTRICT_BY_SCOPES: - raise NotImplementedError("Scopes are not yet implemented in Ralph.") - - # Otherwise, enforce mine for all users + # mine: If using scopes, only restrict users with limited scopes + if settings.LRS_RESTRICT_BY_SCOPES: + if not current_user.scopes.is_authorized("statements/read"): + mine = True + # mine: If using only authority, always restrict (otherwise, use the default value) + elif settings.LRS_RESTRICT_BY_AUTHORITY: mine = True + # Filter by authority if using `mine` if mine: query_params["authority"] = _parse_agent_parameters(current_user.agent) @@ -399,7 +402,10 @@ async def get( @router.put("", responses=POST_PUT_RESPONSES, status_code=status.HTTP_204_NO_CONTENT) # pylint: disable=unused-argument, too-many-branches async def put( - current_user: Annotated[AuthenticatedUser, Depends(get_authenticated_user)], + current_user: Annotated[ + AuthenticatedUser, + Security(get_authenticated_user, scopes=["statements/write"]), + ], statement: LaxStatement, background_tasks: BackgroundTasks, statement_id: UUID = Query(alias="statementId"), @@ -478,7 +484,10 @@ async def put( @router.post("", responses=POST_PUT_RESPONSES) # pylint: disable = too-many-branches async def post( - current_user: Annotated[AuthenticatedUser, Depends(get_authenticated_user)], + current_user: Annotated[ + AuthenticatedUser, + Security(get_authenticated_user, scopes=["statements/write"]), + ], statements: Union[LaxStatement, List[LaxStatement]], background_tasks: BackgroundTasks, response: Response, diff --git a/src/ralph/conf.py b/src/ralph/conf.py index 0415a5ee3..ad91785ae 100644 --- a/src/ralph/conf.py +++ b/src/ralph/conf.py @@ -19,7 +19,10 @@ from unittest.mock import Mock get_app_dir = Mock(return_value=".") -from pydantic import AnyHttpUrl, AnyUrl, BaseModel, BaseSettings, Extra + +from pydantic import AnyHttpUrl, AnyUrl, BaseModel, BaseSettings, Extra, root_validator + +from ralph.exceptions import ConfigurationException from .utils import import_string @@ -210,5 +213,18 @@ def LOCALE_ENCODING(self) -> str: # pylint: disable=invalid-name """Returns Ralph's default locale encoding.""" return self._CORE.LOCALE_ENCODING + @root_validator(allow_reuse=True) + @classmethod + def check_restriction_compatibility(cls, values): + """Raise an error if scopes are being used without authority restriction.""" + if values.get("LRS_RESTRICT_BY_SCOPES") and not values.get( + "LRS_RESTRICT_BY_AUTHORITY" + ): + raise ConfigurationException( + "LRS_RESTRICT_BY_AUTHORITY must be set to True if using " + "LRS_RESTRICT_BY_SCOPES=True" + ) + return values + settings = Settings() diff --git a/tests/api/auth/test_basic.py b/tests/api/auth/test_basic.py index ebcbf3aea..211f5e411 100644 --- a/tests/api/auth/test_basic.py +++ b/tests/api/auth/test_basic.py @@ -6,15 +6,15 @@ import bcrypt import pytest from fastapi.exceptions import HTTPException -from fastapi.security import HTTPBasicCredentials +from fastapi.security import HTTPBasicCredentials, SecurityScopes from ralph.api.auth.basic import ( ServerUsersCredentials, UserCredentials, - get_authenticated_user, + get_basic_auth_user, get_stored_credentials, ) -from ralph.api.auth.user import AuthenticatedUser +from ralph.api.auth.user import AuthenticatedUser, UserScopes from ralph.conf import Settings, settings STORED_CREDENTIALS = json.dumps( @@ -97,18 +97,21 @@ def test_api_auth_basic_caching_credentials(fs): auth_file_path = settings.APP_DIR / "auth.json" fs.create_file(auth_file_path, contents=STORED_CREDENTIALS) - get_authenticated_user.cache_clear() + get_basic_auth_user.cache_clear() + get_stored_credentials.cache_clear() credentials = HTTPBasicCredentials(username="ralph", password="admin") # Call function as in a first request with these credentials - get_authenticated_user(credentials) + get_basic_auth_user( + security_scopes=SecurityScopes(["profile/read"]), credentials=credentials + ) - assert get_authenticated_user.cache.popitem() == ( - ("ralph", "admin"), + assert get_basic_auth_user.cache.popitem() == ( + ("ralph", "admin", "profile/read"), AuthenticatedUser( agent={"mbox": "mailto:ralph@example.com"}, - scopes=["statements/read/mine", "statements/write"], + scopes=UserScopes(["statements/read/mine", "statements/write"]), ), ) @@ -118,13 +121,13 @@ def test_api_auth_basic_with_wrong_password(fs): auth_file_path = settings.APP_DIR / "auth.json" fs.create_file(auth_file_path, contents=STORED_CREDENTIALS) - get_authenticated_user.cache_clear() + get_basic_auth_user.cache_clear() credentials = HTTPBasicCredentials(username="ralph", password="wrong_password") # Call function as in a first request with these credentials with pytest.raises(HTTPException): - get_authenticated_user(credentials) + get_basic_auth_user(credentials, SecurityScopes(["all"])) def test_api_auth_basic_no_credential_file_found(fs, monkeypatch): @@ -132,12 +135,12 @@ def test_api_auth_basic_no_credential_file_found(fs, monkeypatch): monkeypatch.setenv("RALPH_AUTH_FILE", "other_file") monkeypatch.setattr("ralph.api.auth.basic.settings", Settings()) - get_stored_credentials.cache_clear() + get_basic_auth_user.cache_clear() credentials = HTTPBasicCredentials(username="ralph", password="admin") with pytest.raises(HTTPException): - get_authenticated_user(credentials) + get_basic_auth_user(credentials, SecurityScopes(["all"])) def test_get_whoami_no_credentials(basic_auth_test_client): @@ -173,7 +176,7 @@ def test_get_whoami_username_not_found(basic_auth_test_client, fs): """Whoami route returns a 401 error when the username cannot be found.""" credential_bytes = base64.b64encode("john:admin".encode("utf-8")) credentials = str(credential_bytes, "utf-8") - get_authenticated_user.cache_clear() + get_basic_auth_user.cache_clear() auth_file_path = settings.APP_DIR / "auth.json" fs.create_file(auth_file_path, contents=STORED_CREDENTIALS) @@ -195,7 +198,7 @@ def test_get_whoami_wrong_password(basic_auth_test_client, fs): auth_file_path = settings.APP_DIR / "auth.json" fs.create_file(auth_file_path, contents=STORED_CREDENTIALS) - get_authenticated_user.cache_clear() + get_basic_auth_user.cache_clear() response = basic_auth_test_client.get( "/whoami", headers={"Authorization": f"Basic {credentials}"} @@ -217,14 +220,17 @@ def test_get_whoami_correct_credentials(basic_auth_test_client, fs): auth_file_path = settings.APP_DIR / "auth.json" fs.create_file(auth_file_path, contents=STORED_CREDENTIALS) - get_authenticated_user.cache_clear() + get_basic_auth_user.cache_clear() response = basic_auth_test_client.get( "/whoami", headers={"Authorization": f"Basic {credentials}"} ) assert response.status_code == 200 - assert response.json() == { - "agent": {"mbox": "mailto:ralph@example.com"}, - "scopes": ["statements/read/mine", "statements/write"], - } + + assert len(response.json().keys()) == 2 + assert response.json()["agent"] == {"mbox": "mailto:ralph@example.com"} + assert sorted(response.json()["scopes"]) == [ + "statements/read/mine", + "statements/write", + ] diff --git a/tests/api/auth/test_oidc.py b/tests/api/auth/test_oidc.py index 0c044bfe6..a0b621f01 100644 --- a/tests/api/auth/test_oidc.py +++ b/tests/api/auth/test_oidc.py @@ -1,47 +1,29 @@ """Tests for the api.auth.oidc module.""" import responses +from pydantic import parse_obj_as from ralph.api.auth.oidc import discover_provider, get_public_keys +from ralph.models.xapi.base.agents import BaseXapiAgentWithOpenId -from tests.fixtures.auth import ISSUER_URI +from tests.fixtures.auth import ISSUER_URI, mock_oidc_user @responses.activate -def test_api_auth_oidc_valid( - oidc_auth_test_client, mock_discovery_response, mock_oidc_jwks, encoded_token -): +def test_api_auth_oidc_valid(oidc_auth_test_client): """Test a valid OpenId Connect authentication.""" - # Clear LRU cache - discover_provider.cache_clear() - get_public_keys.cache_clear() - - # Mock request to get provider configuration - responses.add( - responses.GET, - f"{ISSUER_URI}/.well-known/openid-configuration", - json=mock_discovery_response, - status=200, - ) - - # Mock request to get keys - responses.add( - responses.GET, - mock_discovery_response["jwks_uri"], - json=mock_oidc_jwks, - status=200, - ) + oidc_token = mock_oidc_user(scopes=["all", "profile/read"]) response = oidc_auth_test_client.get( "/whoami", - headers={"Authorization": f"Bearer {encoded_token}"}, + headers={"Authorization": f"Bearer {oidc_token}"}, ) assert response.status_code == 200 - assert response.json() == { - "scopes": ["all", "statements/read"], - "agent": {"openid": "123|oidc"}, - } + assert len(response.json().keys()) == 2 + assert response.json()["agent"] == {"openid": "https://iss.example.com/123|oidc"} + assert parse_obj_as(BaseXapiAgentWithOpenId, response.json()["agent"]) + assert sorted(response.json()["scopes"]) == ["all", "profile/read"] @responses.activate @@ -50,25 +32,7 @@ def test_api_auth_invalid_token( ): """Test API with an invalid audience.""" - # Clear LRU cache - discover_provider.cache_clear() - get_public_keys.cache_clear() - - # Mock request to get provider configuration - responses.add( - responses.GET, - f"{ISSUER_URI}/.well-known/openid-configuration", - json=mock_discovery_response, - status=200, - ) - - # Mock request to get keys - responses.add( - responses.GET, - mock_discovery_response["jwks_uri"], - json=mock_oidc_jwks, - status=200, - ) + mock_oidc_user() response = oidc_auth_test_client.get( "/whoami", @@ -143,34 +107,14 @@ def test_api_auth_invalid_keys( @responses.activate -def test_api_auth_invalid_header( - oidc_auth_test_client, mock_discovery_response, mock_oidc_jwks, encoded_token -): +def test_api_auth_invalid_header(oidc_auth_test_client): """Test API with an invalid request header.""" - # Clear LRU cache - discover_provider.cache_clear() - get_public_keys.cache_clear() - - # Mock request to get provider configuration - responses.add( - responses.GET, - f"{ISSUER_URI}/.well-known/openid-configuration", - json=mock_discovery_response, - status=200, - ) - - # Mock request to get keys - responses.add( - responses.GET, - mock_discovery_response["jwks_uri"], - json=mock_oidc_jwks, - status=200, - ) + oidc_token = mock_oidc_user() response = oidc_auth_test_client.get( "/whoami", - headers={"Authorization": f"Wrong header {encoded_token}"}, + headers={"Authorization": f"Wrong header {oidc_token}"}, ) assert response.status_code == 401 diff --git a/tests/api/test_statements_get.py b/tests/api/test_statements_get.py index 56cdd2baf..ec8a24085 100644 --- a/tests/api/test_statements_get.py +++ b/tests/api/test_statements_get.py @@ -5,11 +5,14 @@ from urllib.parse import parse_qs, quote_plus, urlparse import pytest +import responses from elasticsearch.helpers import bulk from fastapi.testclient import TestClient from ralph.api import app -from ralph.api.auth.basic import get_authenticated_user +from ralph.api.auth import get_authenticated_user +from ralph.api.auth.basic import get_basic_auth_user +from ralph.api.auth.oidc import get_oidc_user from ralph.backends.data.base import BaseOperationType from ralph.backends.data.clickhouse import ClickHouseDataBackend from ralph.backends.data.mongo import MongoDataBackend @@ -28,7 +31,7 @@ get_mongo_test_backend, ) -from ..fixtures.auth import mock_basic_auth_user +from ..fixtures.auth import mock_basic_auth_user, mock_oidc_user from ..helpers import mock_activity, mock_agent client = TestClient(app) @@ -81,28 +84,27 @@ def insert_clickhouse_statements(statements): @pytest.fixture(params=["es", "mongo", "clickhouse"]) -# pylint: disable=unused-argument def insert_statements_and_monkeypatch_backend( request, es, mongo, clickhouse, monkeypatch ): """(Security) Return a function that inserts statements into each backend.""" - # pylint: disable=invalid-name + # pylint: disable=invalid-name,unused-argument def _insert_statements_and_monkeypatch_backend(statements): """Inserts statements once into each backend.""" - database_client_class_path = "ralph.api.routers.statements.BACKEND_CLIENT" + backend_client_class_path = "ralph.api.routers.statements.BACKEND_CLIENT" if request.param == "mongo": insert_mongo_statements(mongo, statements) - monkeypatch.setattr(database_client_class_path, get_mongo_test_backend()) + monkeypatch.setattr(backend_client_class_path, get_mongo_test_backend()) return if request.param == "clickhouse": insert_clickhouse_statements(statements) monkeypatch.setattr( - database_client_class_path, get_clickhouse_test_backend() + backend_client_class_path, get_clickhouse_test_backend() ) return insert_es_statements(es, statements) - monkeypatch.setattr(database_client_class_path, get_es_test_backend()) + monkeypatch.setattr(backend_client_class_path, get_es_test_backend()) return _insert_statements_and_monkeypatch_backend @@ -123,8 +125,7 @@ def test_api_statements_get_mine( """(Security) Test that the get statements API route, given a "mine=True" query parameter returns a list of statements filtered by authority. """ - # pylint: disable=redefined-outer-name - # pylint: disable=invalid-name + # pylint: disable=redefined-outer-name,invalid-name # Create two distinct agents if ifi == "account_same_home_page": @@ -153,7 +154,7 @@ def test_api_statements_get_mine( ) # Clear cache before each test iteration - get_authenticated_user.cache_clear() + get_basic_auth_user.cache_clear() statements = [ { @@ -233,7 +234,7 @@ def test_api_statements_get_mine( def test_api_statements_get( - insert_statements_and_monkeypatch_backend, auth_credentials + insert_statements_and_monkeypatch_backend, basic_auth_credentials ): """Test the get statements API route without any filters set up.""" # pylint: disable=redefined-outer-name @@ -253,7 +254,7 @@ def test_api_statements_get( # Confirm that calling this with and without the trailing slash both work for path in ("/xAPI/statements", "/xAPI/statements/"): response = client.get( - path, headers={"Authorization": f"Basic {auth_credentials}"} + path, headers={"Authorization": f"Basic {basic_auth_credentials}"} ) assert response.status_code == 200 @@ -261,7 +262,7 @@ def test_api_statements_get( def test_api_statements_get_ascending( - insert_statements_and_monkeypatch_backend, auth_credentials + insert_statements_and_monkeypatch_backend, basic_auth_credentials ): """Test the get statements API route, given an "ascending" query parameter, should return statements in ascending order by their timestamp. @@ -282,7 +283,7 @@ def test_api_statements_get_ascending( response = client.get( "/xAPI/statements/?ascending=true", - headers={"Authorization": f"Basic {auth_credentials}"}, + headers={"Authorization": f"Basic {basic_auth_credentials}"}, ) assert response.status_code == 200 @@ -290,7 +291,7 @@ def test_api_statements_get_ascending( def test_api_statements_get_by_statement_id( - insert_statements_and_monkeypatch_backend, auth_credentials + insert_statements_and_monkeypatch_backend, basic_auth_credentials ): """Test the get statements API route, given a "statementId" query parameter, should return a list of statements matching the given statementId. @@ -311,7 +312,7 @@ def test_api_statements_get_by_statement_id( response = client.get( f"/xAPI/statements/?statementId={statements[1]['id']}", - headers={"Authorization": f"Basic {auth_credentials}"}, + headers={"Authorization": f"Basic {basic_auth_credentials}"}, ) assert response.status_code == 200 @@ -329,7 +330,7 @@ def test_api_statements_get_by_statement_id( ], ) def test_api_statements_get_by_agent( - ifi, insert_statements_and_monkeypatch_backend, auth_credentials + ifi, insert_statements_and_monkeypatch_backend, basic_auth_credentials ): """Test the get statements API route, given an "agent" query parameter, should return a list of statements filtered by the given agent. @@ -365,7 +366,7 @@ def test_api_statements_get_by_agent( response = client.get( f"/xAPI/statements/?agent={quote_plus(json.dumps(agent_1))}", - headers={"Authorization": f"Basic {auth_credentials}"}, + headers={"Authorization": f"Basic {basic_auth_credentials}"}, ) assert response.status_code == 200 @@ -373,7 +374,7 @@ def test_api_statements_get_by_agent( def test_api_statements_get_by_verb( - insert_statements_and_monkeypatch_backend, auth_credentials + insert_statements_and_monkeypatch_backend, basic_auth_credentials ): """Test the get statements API route, given a "verb" query parameter, should return a list of statements filtered by the given verb id. @@ -396,7 +397,7 @@ def test_api_statements_get_by_verb( response = client.get( "/xAPI/statements/?verb=" + quote_plus("http://adlnet.gov/expapi/verbs/played"), - headers={"Authorization": f"Basic {auth_credentials}"}, + headers={"Authorization": f"Basic {basic_auth_credentials}"}, ) assert response.status_code == 200 @@ -404,7 +405,7 @@ def test_api_statements_get_by_verb( def test_api_statements_get_by_activity( - insert_statements_and_monkeypatch_backend, auth_credentials + insert_statements_and_monkeypatch_backend, basic_auth_credentials ): """Test the get statements API route, given an "activity" query parameter, should return a list of statements filtered by the given activity id. @@ -430,7 +431,7 @@ def test_api_statements_get_by_activity( response = client.get( f"/xAPI/statements/?activity={activity_1['id']}", - headers={"Authorization": f"Basic {auth_credentials}"}, + headers={"Authorization": f"Basic {basic_auth_credentials}"}, ) assert response.status_code == 200 @@ -439,7 +440,7 @@ def test_api_statements_get_by_activity( # Check that badly formated activity returns an error response = client.get( "/xAPI/statements/?activity=INVALID_IRI", - headers={"Authorization": f"Basic {auth_credentials}"}, + headers={"Authorization": f"Basic {basic_auth_credentials}"}, ) assert response.status_code == 422 @@ -447,7 +448,7 @@ def test_api_statements_get_by_activity( def test_api_statements_get_since_timestamp( - insert_statements_and_monkeypatch_backend, auth_credentials + insert_statements_and_monkeypatch_backend, basic_auth_credentials ): """Test the get statements API route, given a "since" query parameter, should return a list of statements filtered by the given timestamp. @@ -469,7 +470,7 @@ def test_api_statements_get_since_timestamp( since = (datetime.now() - timedelta(minutes=30)).isoformat() response = client.get( f"/xAPI/statements/?since={since}", - headers={"Authorization": f"Basic {auth_credentials}"}, + headers={"Authorization": f"Basic {basic_auth_credentials}"}, ) assert response.status_code == 200 @@ -477,7 +478,7 @@ def test_api_statements_get_since_timestamp( def test_api_statements_get_until_timestamp( - insert_statements_and_monkeypatch_backend, auth_credentials + insert_statements_and_monkeypatch_backend, basic_auth_credentials ): """Test the get statements API route, given an "until" query parameter, should return a list of statements filtered by the given timestamp. @@ -499,7 +500,7 @@ def test_api_statements_get_until_timestamp( until = (datetime.now() - timedelta(minutes=30)).isoformat() response = client.get( f"/xAPI/statements/?until={until}", - headers={"Authorization": f"Basic {auth_credentials}"}, + headers={"Authorization": f"Basic {basic_auth_credentials}"}, ) assert response.status_code == 200 @@ -507,7 +508,7 @@ def test_api_statements_get_until_timestamp( def test_api_statements_get_with_pagination( - monkeypatch, insert_statements_and_monkeypatch_backend, auth_credentials + monkeypatch, insert_statements_and_monkeypatch_backend, basic_auth_credentials ): """Test the get statements API route, given a request leading to more results than can fit on the first page, should return a list of statements non-exceeding the page @@ -546,7 +547,8 @@ def test_api_statements_get_with_pagination( # First response gets the first two results, with a "more" entry as # we have more results to return on a later page. first_response = client.get( - "/xAPI/statements/", headers={"Authorization": f"Basic {auth_credentials}"} + "/xAPI/statements/", + headers={"Authorization": f"Basic {basic_auth_credentials}"}, ) assert first_response.status_code == 200 assert first_response.json()["statements"] == [statements[4], statements[3]] @@ -558,7 +560,7 @@ def test_api_statements_get_with_pagination( # Second response gets the missing result from the first response. second_response = client.get( first_response.json()["more"], - headers={"Authorization": f"Basic {auth_credentials}"}, + headers={"Authorization": f"Basic {basic_auth_credentials}"}, ) assert second_response.status_code == 200 assert second_response.json()["statements"] == [statements[2], statements[1]] @@ -570,14 +572,14 @@ def test_api_statements_get_with_pagination( # Third response gets the missing result from the first response third_response = client.get( second_response.json()["more"], - headers={"Authorization": f"Basic {auth_credentials}"}, + headers={"Authorization": f"Basic {basic_auth_credentials}"}, ) assert third_response.status_code == 200 assert third_response.json() == {"statements": [statements[0]]} def test_api_statements_get_with_pagination_and_query( - monkeypatch, insert_statements_and_monkeypatch_backend, auth_credentials + monkeypatch, insert_statements_and_monkeypatch_backend, basic_auth_credentials ): """Test the get statements API route, given a request with a query parameter leading to more results than can fit on the first page, should return a list @@ -623,7 +625,7 @@ def test_api_statements_get_with_pagination_and_query( first_response = client.get( "/xAPI/statements/?verb=" + quote_plus("https://w3id.org/xapi/video/verbs/played"), - headers={"Authorization": f"Basic {auth_credentials}"}, + headers={"Authorization": f"Basic {basic_auth_credentials}"}, ) assert first_response.status_code == 200 assert first_response.json()["statements"] == [statements[2], statements[1]] @@ -635,14 +637,14 @@ def test_api_statements_get_with_pagination_and_query( # Second response gets the missing result from the first response. second_response = client.get( first_response.json()["more"], - headers={"Authorization": f"Basic {auth_credentials}"}, + headers={"Authorization": f"Basic {basic_auth_credentials}"}, ) assert second_response.status_code == 200 assert second_response.json() == {"statements": [statements[0]]} def test_api_statements_get_with_no_matching_statement( - insert_statements_and_monkeypatch_backend, auth_credentials + insert_statements_and_monkeypatch_backend, basic_auth_credentials ): """Test the get statements API route, given a query yielding no matching statement, should return an empty list. @@ -663,14 +665,16 @@ def test_api_statements_get_with_no_matching_statement( response = client.get( "/xAPI/statements/?statementId=66c81e98-1763-4730-8cfc-f5ab34f1bad5", - headers={"Authorization": f"Basic {auth_credentials}"}, + headers={"Authorization": f"Basic {basic_auth_credentials}"}, ) assert response.status_code == 200 assert response.json() == {"statements": []} -def test_api_statements_get_with_database_query_failure(auth_credentials, monkeypatch): +def test_api_statements_get_with_database_query_failure( + basic_auth_credentials, monkeypatch +): """Test the get statements API route, given a query raising a BackendException, should return an error response with HTTP code 500. """ @@ -687,14 +691,14 @@ def mock_query_statements(*_): response = client.get( "/xAPI/statements/", - headers={"Authorization": f"Basic {auth_credentials}"}, + headers={"Authorization": f"Basic {basic_auth_credentials}"}, ) assert response.status_code == 500 assert response.json() == {"detail": "xAPI statements query failed"} @pytest.mark.parametrize("id_param", ["statementId", "voidedStatementId"]) -def test_api_statements_get_invalid_query_parameters(auth_credentials, id_param): +def test_api_statements_get_invalid_query_parameters(basic_auth_credentials, id_param): """Test error response for invalid query parameters""" id_1 = "be67b160-d958-4f51-b8b8-1892002dbac6" @@ -703,7 +707,7 @@ def test_api_statements_get_invalid_query_parameters(auth_credentials, id_param) # Check for 400 status code when unknown parameters are provided response = client.get( "/xAPI/statements/?mamamia=herewegoagain", - headers={"Authorization": f"Basic {auth_credentials}"}, + headers={"Authorization": f"Basic {basic_auth_credentials}"}, ) assert response.status_code == 400 assert response.json() == { @@ -713,7 +717,7 @@ def test_api_statements_get_invalid_query_parameters(auth_credentials, id_param) # Check for 400 status code when both statementId and voidedStatementId are provided response = client.get( f"/xAPI/statements/?statementId={id_1}&voidedStatementId={id_2}", - headers={"Authorization": f"Basic {auth_credentials}"}, + headers={"Authorization": f"Basic {basic_auth_credentials}"}, ) assert response.status_code == 400 @@ -725,7 +729,7 @@ def test_api_statements_get_invalid_query_parameters(auth_credentials, id_param) ]: response = client.get( f"/xAPI/statements/?{id_param}={id_1}&{invalid_param}={value}", - headers={"Authorization": f"Basic {auth_credentials}"}, + headers={"Authorization": f"Basic {basic_auth_credentials}"}, ) assert response.status_code == 400 assert response.json() == { @@ -739,6 +743,166 @@ def test_api_statements_get_invalid_query_parameters(auth_credentials, id_param) for valid_param, value in [("format", "ids"), ("attachments", "true")]: response = client.get( f"/xAPI/statements/?{id_param}={id_1}&{valid_param}={value}", - headers={"Authorization": f"Basic {auth_credentials}"}, + headers={"Authorization": f"Basic {basic_auth_credentials}"}, ) assert response.status_code != 400 + + +@responses.activate +@pytest.mark.parametrize("auth_method", ["basic", "oidc"]) +@pytest.mark.parametrize( + "scopes,is_authorized", + [ + (["all"], True), + (["all/read"], True), + (["statements/read/mine"], True), + (["statements/read"], True), + (["profile/write", "statements/read", "all/write"], True), + (["statements/write"], False), + (["profile/read"], False), + (["all/write"], False), + ([], False), + ], +) +def test_api_statements_get_scopes( + monkeypatch, fs, es, auth_method, scopes, is_authorized +): + """Test that getting statements behaves properly according to user scopes.""" + # pylint: disable=invalid-name,too-many-locals,too-many-arguments + + monkeypatch.setattr( + "ralph.api.routers.statements.settings.LRS_RESTRICT_BY_SCOPES", True + ) + monkeypatch.setattr("ralph.api.auth.basic.settings.LRS_RESTRICT_BY_SCOPES", True) + + if auth_method == "basic": + agent = mock_agent("mbox", 1) + credentials = mock_basic_auth_user(fs, scopes=scopes, agent=agent) + headers = {"Authorization": f"Basic {credentials}"} + + app.dependency_overrides[get_authenticated_user] = get_basic_auth_user + get_basic_auth_user.cache_clear() + + elif auth_method == "oidc": + sub = "123|oidc" + iss = "https://iss.example.com" + agent = {"openid": f"{iss}/{sub}"} + oidc_token = mock_oidc_user(sub=sub, scopes=scopes) + headers = {"Authorization": f"Bearer {oidc_token}"} + + monkeypatch.setattr( + "ralph.api.auth.oidc.settings.RUNSERVER_AUTH_OIDC_ISSUER_URI", + "http://providerHost:8080/auth/realms/real_name", + ) + monkeypatch.setattr( + "ralph.api.auth.oidc.settings.RUNSERVER_AUTH_OIDC_AUDIENCE", + "http://clientHost:8100", + ) + + app.dependency_overrides[get_authenticated_user] = get_oidc_user + + statements = [ + { + "id": "be67b160-d958-4f51-b8b8-1892002dbac6", + "timestamp": (datetime.now() - timedelta(hours=1)).isoformat(), + "actor": agent, + "authority": agent, + }, + { + "id": "72c81e98-1763-4730-8cfc-f5ab34f1bad2", + "timestamp": datetime.now().isoformat(), + "actor": agent, + "authority": agent, + }, + ] + + # NB: scopes are not linked to statements and backends, we therefore test with ES + backend_client_class_path = "ralph.api.routers.statements.BACKEND_CLIENT" + insert_es_statements(es, statements) + monkeypatch.setattr(backend_client_class_path, get_es_test_backend()) + + response = client.get( + "/xAPI/statements/", + headers=headers, + ) + + if is_authorized: + assert response.status_code == 200 + assert response.json() == {"statements": [statements[1], statements[0]]} + else: + assert response.status_code == 401 + assert response.json() == { + "detail": 'Access not authorized to scope: "statements/read/mine".' + } + + app.dependency_overrides.pop(get_authenticated_user, None) + + +@pytest.mark.parametrize( + "scopes,read_all_access", + [ + (["all"], True), + (["all/read", "statements/read/mine"], True), + (["statements/read"], True), + (["statements/read/mine"], False), + ], +) +def test_api_statements_get_scopes_with_authority( + monkeypatch, fs, es, scopes, read_all_access +): + """Test that restricting by scope and by authority behaves properly. + Getting statements should be restricted to mine for users which only have + `statements/read/mine` scope but should not be restricted when the user + has wider scopes. + """ + # pylint: disable=invalid-name + monkeypatch.setattr( + "ralph.api.routers.statements.settings.LRS_RESTRICT_BY_AUTHORITY", True + ) + monkeypatch.setattr( + "ralph.api.routers.statements.settings.LRS_RESTRICT_BY_SCOPES", True + ) + monkeypatch.setattr("ralph.api.auth.basic.settings.LRS_RESTRICT_BY_SCOPES", True) + + agent = mock_agent("mbox", 1) + agent_2 = mock_agent("mbox", 2) + username = "jane" + password = "janepwd" + credentials = mock_basic_auth_user(fs, username, password, scopes, agent) + headers = {"Authorization": f"Basic {credentials}"} + + get_basic_auth_user.cache_clear() + + statements = [ + { + "id": "be67b160-d958-4f51-b8b8-1892002dbac6", + "timestamp": (datetime.now() - timedelta(hours=1)).isoformat(), + "actor": agent, + "authority": agent, + }, + { + "id": "72c81e98-1763-4730-8cfc-f5ab34f1bad2", + "timestamp": datetime.now().isoformat(), + "actor": agent, + "authority": agent_2, + }, + ] + + # NB: scopes are not linked to statements and backends, we therefore test with ES + backend_client_class_path = "ralph.api.routers.statements.BACKEND_CLIENT" + insert_es_statements(es, statements) + monkeypatch.setattr(backend_client_class_path, get_es_test_backend()) + + response = client.get( + "/xAPI/statements/", + headers=headers, + ) + + assert response.status_code == 200 + + if read_all_access: + assert response.json() == {"statements": [statements[1], statements[0]]} + else: + assert response.json() == {"statements": [statements[0]]} + + app.dependency_overrides.pop(get_authenticated_user, None) diff --git a/tests/api/test_statements_post.py b/tests/api/test_statements_post.py index 350e4a11c..fe3e63691 100644 --- a/tests/api/test_statements_post.py +++ b/tests/api/test_statements_post.py @@ -4,15 +4,20 @@ from uuid import uuid4 import pytest +import responses from fastapi.testclient import TestClient from httpx import AsyncClient from ralph.api import app +from ralph.api.auth import get_authenticated_user +from ralph.api.auth.basic import get_basic_auth_user +from ralph.api.auth.oidc import get_oidc_user from ralph.backends.lrs.es import ESLRSBackend from ralph.backends.lrs.mongo import MongoLRSBackend from ralph.conf import XapiForwardingConfigurationSettings from ralph.exceptions import BackendException +from tests.fixtures.auth import mock_basic_auth_user, mock_oidc_user from tests.fixtures.backends import ( ES_TEST_FORWARDING_INDEX, ES_TEST_HOSTS, @@ -28,6 +33,7 @@ from ..helpers import ( assert_statement_get_responses_are_equivalent, + mock_agent, mock_statement, string_is_date, string_is_uuid, @@ -36,7 +42,7 @@ client = TestClient(app) -def test_api_statements_post_invalid_parameters(auth_credentials): +def test_api_statements_post_invalid_parameters(basic_auth_credentials): """Test that using invalid parameters returns the proper status code.""" statement = mock_statement() @@ -44,7 +50,7 @@ def test_api_statements_post_invalid_parameters(auth_credentials): # Check for 400 status code when unknown parameters are provided response = client.post( "/xAPI/statements/?mamamia=herewegoagain", - headers={"Authorization": f"Basic {auth_credentials}"}, + headers={"Authorization": f"Basic {basic_auth_credentials}"}, json=statement, ) assert response.status_code == 400 @@ -59,7 +65,7 @@ def test_api_statements_post_invalid_parameters(auth_credentials): ) # pylint: disable=too-many-arguments def test_api_statements_post_single_statement_directly( - backend, monkeypatch, auth_credentials, es, mongo, clickhouse + backend, monkeypatch, basic_auth_credentials, es, mongo, clickhouse ): """Test the post statements API route with one statement.""" # pylint: disable=invalid-name,unused-argument @@ -69,7 +75,7 @@ def test_api_statements_post_single_statement_directly( response = client.post( "/xAPI/statements/", - headers={"Authorization": f"Basic {auth_credentials}"}, + headers={"Authorization": f"Basic {basic_auth_credentials}"}, json=statement, ) @@ -79,7 +85,8 @@ def test_api_statements_post_single_statement_directly( es.indices.refresh() response = client.get( - "/xAPI/statements/", headers={"Authorization": f"Basic {auth_credentials}"} + "/xAPI/statements/", + headers={"Authorization": f"Basic {basic_auth_credentials}"}, ) assert response.status_code == 200 assert_statement_get_responses_are_equivalent( @@ -89,7 +96,7 @@ def test_api_statements_post_single_statement_directly( # pylint: disable=too-many-arguments def test_api_statements_post_enriching_without_existing_values( - monkeypatch, auth_credentials, es + monkeypatch, basic_auth_credentials, es ): """Test that statements are properly enriched when statement provides no values.""" # pylint: disable=invalid-name,unused-argument @@ -111,7 +118,7 @@ def test_api_statements_post_enriching_without_existing_values( response = client.post( "/xAPI/statements/", - headers={"Authorization": f"Basic {auth_credentials}"}, + headers={"Authorization": f"Basic {basic_auth_credentials}"}, json=statement, ) @@ -120,7 +127,8 @@ def test_api_statements_post_enriching_without_existing_values( es.indices.refresh() response = client.get( - "/xAPI/statements/", headers={"Authorization": f"Basic {auth_credentials}"} + "/xAPI/statements/", + headers={"Authorization": f"Basic {basic_auth_credentials}"}, ) statement = response.json()["statements"][0] @@ -153,7 +161,7 @@ def test_api_statements_post_enriching_without_existing_values( ) # pylint: disable=too-many-arguments def test_api_statements_post_enriching_with_existing_values( - field, value, status, monkeypatch, auth_credentials, es + field, value, status, monkeypatch, basic_auth_credentials, es ): """Test that statements are properly enriched when values are provided.""" # pylint: disable=invalid-name,unused-argument @@ -168,7 +176,7 @@ def test_api_statements_post_enriching_with_existing_values( response = client.post( "/xAPI/statements/", - headers={"Authorization": f"Basic {auth_credentials}"}, + headers={"Authorization": f"Basic {basic_auth_credentials}"}, json=statement, ) @@ -178,7 +186,8 @@ def test_api_statements_post_enriching_with_existing_values( if status == 200: es.indices.refresh() response = client.get( - "/xAPI/statements/", headers={"Authorization": f"Basic {auth_credentials}"} + "/xAPI/statements/", + headers={"Authorization": f"Basic {basic_auth_credentials}"}, ) statement = response.json()["statements"][0] @@ -198,7 +207,7 @@ def test_api_statements_post_enriching_with_existing_values( ) # pylint: disable=too-many-arguments def test_api_statements_post_single_statement_no_trailing_slash( - backend, monkeypatch, auth_credentials, es, mongo, clickhouse + backend, monkeypatch, basic_auth_credentials, es, mongo, clickhouse ): """Test that the statements endpoint also works without the trailing slash.""" # pylint: disable=invalid-name,unused-argument @@ -208,7 +217,7 @@ def test_api_statements_post_single_statement_no_trailing_slash( response = client.post( "/xAPI/statements", - headers={"Authorization": f"Basic {auth_credentials}"}, + headers={"Authorization": f"Basic {basic_auth_credentials}"}, json=statement, ) @@ -222,7 +231,7 @@ def test_api_statements_post_single_statement_no_trailing_slash( ) # pylint: disable=too-many-arguments def test_api_statements_post_list_of_one( - backend, monkeypatch, auth_credentials, es, mongo, clickhouse + backend, monkeypatch, basic_auth_credentials, es, mongo, clickhouse ): """Test the post statements API route with one statement in a list.""" # pylint: disable=invalid-name,unused-argument @@ -232,7 +241,7 @@ def test_api_statements_post_list_of_one( response = client.post( "/xAPI/statements/", - headers={"Authorization": f"Basic {auth_credentials}"}, + headers={"Authorization": f"Basic {basic_auth_credentials}"}, json=[statement], ) @@ -241,7 +250,8 @@ def test_api_statements_post_list_of_one( es.indices.refresh() response = client.get( - "/xAPI/statements/", headers={"Authorization": f"Basic {auth_credentials}"} + "/xAPI/statements/", + headers={"Authorization": f"Basic {basic_auth_credentials}"}, ) assert response.status_code == 200 assert_statement_get_responses_are_equivalent( @@ -255,7 +265,7 @@ def test_api_statements_post_list_of_one( ) # pylint: disable=too-many-arguments def test_api_statements_post_list( - backend, monkeypatch, auth_credentials, es, mongo, clickhouse + backend, monkeypatch, basic_auth_credentials, es, mongo, clickhouse ): """Test the post statements API route with two statements in a list.""" # pylint: disable=invalid-name,unused-argument @@ -272,7 +282,7 @@ def test_api_statements_post_list( response = client.post( "/xAPI/statements/", - headers={"Authorization": f"Basic {auth_credentials}"}, + headers={"Authorization": f"Basic {basic_auth_credentials}"}, json=statements, ) @@ -284,7 +294,8 @@ def test_api_statements_post_list( es.indices.refresh() get_response = client.get( - "/xAPI/statements/", headers={"Authorization": f"Basic {auth_credentials}"} + "/xAPI/statements/", + headers={"Authorization": f"Basic {basic_auth_credentials}"}, ) assert get_response.status_code == 200 @@ -306,7 +317,7 @@ def test_api_statements_post_list( ) # pylint: disable=too-many-arguments def test_api_statements_post_list_with_duplicates( - backend, monkeypatch, auth_credentials, es_data_stream, mongo, clickhouse + backend, monkeypatch, basic_auth_credentials, es_data_stream, mongo, clickhouse ): """Test the post statements API route with duplicate statement IDs should fail.""" # pylint: disable=invalid-name,unused-argument @@ -316,7 +327,7 @@ def test_api_statements_post_list_with_duplicates( response = client.post( "/xAPI/statements/", - headers={"Authorization": f"Basic {auth_credentials}"}, + headers={"Authorization": f"Basic {basic_auth_credentials}"}, json=[statement, statement], ) @@ -327,7 +338,8 @@ def test_api_statements_post_list_with_duplicates( # The failure should imply no statement insertion. es_data_stream.indices.refresh() response = client.get( - "/xAPI/statements/", headers={"Authorization": f"Basic {auth_credentials}"} + "/xAPI/statements/", + headers={"Authorization": f"Basic {basic_auth_credentials}"}, ) assert response.status_code == 200 assert response.json() == {"statements": []} @@ -339,7 +351,7 @@ def test_api_statements_post_list_with_duplicates( ) # pylint: disable=too-many-arguments def test_api_statements_post_list_with_duplicate_of_existing_statement( - backend, monkeypatch, auth_credentials, es, mongo, clickhouse + backend, monkeypatch, basic_auth_credentials, es, mongo, clickhouse ): """Test the post statements API route, given a statement that already exist in the database (has the same ID), should fail. @@ -354,7 +366,7 @@ def test_api_statements_post_list_with_duplicate_of_existing_statement( # Post the statement once. response = client.post( "/xAPI/statements/", - headers={"Authorization": f"Basic {auth_credentials}"}, + headers={"Authorization": f"Basic {basic_auth_credentials}"}, json=statement, ) assert response.status_code == 200 @@ -366,7 +378,7 @@ def test_api_statements_post_list_with_duplicate_of_existing_statement( # include the ID in the response as it wasn't inserted. response = client.post( "/xAPI/statements/", - headers={"Authorization": f"Basic {auth_credentials}"}, + headers={"Authorization": f"Basic {basic_auth_credentials}"}, json=statement, ) assert response.status_code == 204 @@ -376,7 +388,7 @@ def test_api_statements_post_list_with_duplicate_of_existing_statement( # Post the statement again, trying to change the timestamp which is not allowed. response = client.post( "/xAPI/statements/", - headers={"Authorization": f"Basic {auth_credentials}"}, + headers={"Authorization": f"Basic {basic_auth_credentials}"}, json=[dict(statement, **{"timestamp": "2023-03-15T14:07:51Z"})], ) @@ -387,7 +399,8 @@ def test_api_statements_post_list_with_duplicate_of_existing_statement( } response = client.get( - "/xAPI/statements/", headers={"Authorization": f"Basic {auth_credentials}"} + "/xAPI/statements/", + headers={"Authorization": f"Basic {basic_auth_credentials}"}, ) assert response.status_code == 200 assert_statement_get_responses_are_equivalent( @@ -400,7 +413,7 @@ def test_api_statements_post_list_with_duplicate_of_existing_statement( [get_es_test_backend, get_clickhouse_test_backend, get_mongo_test_backend], ) def test_api_statements_post_with_failure_during_storage( - backend, monkeypatch, auth_credentials, es, mongo, clickhouse + backend, monkeypatch, basic_auth_credentials, es, mongo, clickhouse ): """Test the post statements API route with a failure happening during storage.""" # pylint: disable=invalid-name,unused-argument, too-many-arguments @@ -416,7 +429,7 @@ def write_mock(*args, **kwargs): response = client.post( "/xAPI/statements/", - headers={"Authorization": f"Basic {auth_credentials}"}, + headers={"Authorization": f"Basic {basic_auth_credentials}"}, json=statement, ) @@ -429,7 +442,7 @@ def write_mock(*args, **kwargs): [get_es_test_backend, get_clickhouse_test_backend, get_mongo_test_backend], ) def test_api_statements_post_with_failure_during_id_query( - backend, monkeypatch, auth_credentials, es, mongo, clickhouse + backend, monkeypatch, basic_auth_credentials, es, mongo, clickhouse ): """Test the post statements API route with a failure during query execution.""" # pylint: disable=invalid-name,unused-argument,too-many-arguments @@ -447,7 +460,7 @@ def query_statements_by_ids_mock(*args, **kwargs): response = client.post( "/xAPI/statements/", - headers={"Authorization": f"Basic {auth_credentials}"}, + headers={"Authorization": f"Basic {basic_auth_credentials}"}, json=statement, ) @@ -461,7 +474,7 @@ def query_statements_by_ids_mock(*args, **kwargs): ) # pylint: disable=too-many-arguments def test_api_statements_post_list_without_forwarding( - backend, auth_credentials, monkeypatch, es, mongo, clickhouse + backend, basic_auth_credentials, monkeypatch, es, mongo, clickhouse ): """Test the post statements API route, given an empty forwarding configuration, should not start the forwarding background task. @@ -487,7 +500,7 @@ def spy_mock_forward_xapi_statements(_): response = client.post( "/xAPI/statements/", - headers={"Authorization": f"Basic {auth_credentials}"}, + headers={"Authorization": f"Basic {basic_auth_credentials}"}, json=statement, ) @@ -520,7 +533,7 @@ async def test_api_statements_post_list_with_forwarding( receiving_backend, forwarding_backend, monkeypatch, - auth_credentials, + basic_auth_credentials, es, es_forwarding, mongo, @@ -594,7 +607,7 @@ async def test_api_statements_post_list_with_forwarding( # The statement should be stored on the forwarding client response = await forwarding_client.get( "/xAPI/statements/", - headers={"Authorization": f"Basic {auth_credentials}"}, + headers={"Authorization": f"Basic {basic_auth_credentials}"}, ) assert response.status_code == 200 assert_statement_get_responses_are_equivalent( @@ -605,7 +618,7 @@ async def test_api_statements_post_list_with_forwarding( async with AsyncClient() as receiving_client: response = await receiving_client.get( f"http://{RUNSERVER_TEST_HOST}:{RUNSERVER_TEST_PORT}/xAPI/statements/", - headers={"Authorization": f"Basic {auth_credentials}"}, + headers={"Authorization": f"Basic {basic_auth_credentials}"}, ) assert response.status_code == 200 assert_statement_get_responses_are_equivalent( @@ -614,3 +627,74 @@ async def test_api_statements_post_list_with_forwarding( # Stop receiving LRS client await lrs_context.__aexit__(None, None, None) + + +@responses.activate +@pytest.mark.parametrize("auth_method", ["basic", "oidc"]) +@pytest.mark.parametrize( + "scopes,is_authorized", + [ + (["all"], True), + (["profile/read", "statements/write"], True), + (["all/read"], False), + (["statements/read/mine"], False), + (["profile/write"], False), + ([], False), + ], +) +def test_api_statements_post_scopes( + monkeypatch, fs, es, auth_method, scopes, is_authorized +): + """Test that posting statements behaves properly according to user scopes.""" + # pylint: disable=invalid-name,unused-argument + monkeypatch.setattr( + "ralph.api.routers.statements.settings.LRS_RESTRICT_BY_SCOPES", True + ) + monkeypatch.setattr("ralph.api.auth.basic.settings.LRS_RESTRICT_BY_SCOPES", True) + + if auth_method == "basic": + agent = mock_agent("mbox", 1) + credentials = mock_basic_auth_user(fs, scopes=scopes, agent=agent) + headers = {"Authorization": f"Basic {credentials}"} + + app.dependency_overrides[get_authenticated_user] = get_basic_auth_user + get_basic_auth_user.cache_clear() + + elif auth_method == "oidc": + sub = "123|oidc" + agent = {"openid": sub} + oidc_token = mock_oidc_user(sub=sub, scopes=scopes) + headers = {"Authorization": f"Bearer {oidc_token}"} + + monkeypatch.setattr( + "ralph.api.auth.oidc.settings.RUNSERVER_AUTH_OIDC_ISSUER_URI", + "http://providerHost:8080/auth/realms/real_name", + ) + monkeypatch.setattr( + "ralph.api.auth.oidc.settings.RUNSERVER_AUTH_OIDC_AUDIENCE", + "http://clientHost:8100", + ) + + app.dependency_overrides[get_authenticated_user] = get_oidc_user + + statement = mock_statement() + + # NB: scopes are not linked to statements and backends, we therefore test with ES + backend_client_class_path = "ralph.api.routers.statements.BACKEND_CLIENT" + monkeypatch.setattr(backend_client_class_path, get_es_test_backend()) + + response = client.post( + "/xAPI/statements/", + headers=headers, + json=statement, + ) + + if is_authorized: + assert response.status_code == 200 + else: + assert response.status_code == 401 + assert response.json() == { + "detail": 'Access not authorized to scope: "statements/write".' + } + + app.dependency_overrides.pop(get_authenticated_user, None) diff --git a/tests/api/test_statements_put.py b/tests/api/test_statements_put.py index 330bccd0f..ae30b2b73 100644 --- a/tests/api/test_statements_put.py +++ b/tests/api/test_statements_put.py @@ -1,17 +1,23 @@ """Tests for the PUT statements endpoint of the Ralph API.""" - +from importlib import reload from uuid import uuid4 import pytest +import responses from fastapi.testclient import TestClient from httpx import AsyncClient +from ralph import api from ralph.api import app +from ralph.api.auth import get_authenticated_user +from ralph.api.auth.basic import get_basic_auth_user +from ralph.api.auth.oidc import get_oidc_user from ralph.backends.lrs.es import ESLRSBackend from ralph.backends.lrs.mongo import MongoLRSBackend from ralph.conf import XapiForwardingConfigurationSettings from ralph.exceptions import BackendException +from tests.fixtures.auth import mock_basic_auth_user, mock_oidc_user from tests.fixtures.backends import ( ES_TEST_FORWARDING_INDEX, ES_TEST_HOSTS, @@ -27,21 +33,23 @@ from ..helpers import ( assert_statement_get_responses_are_equivalent, + mock_agent, mock_statement, string_is_date, ) +reload(api) client = TestClient(app) -def test_api_statements_put_invalid_parameters(auth_credentials): +def test_api_statements_put_invalid_parameters(basic_auth_credentials): """Test that using invalid parameters returns the proper status code.""" statement = mock_statement() # Check for 400 status code when unknown parameters are provided response = client.put( "/xAPI/statements/?mamamia=herewegoagain", - headers={"Authorization": f"Basic {auth_credentials}"}, + headers={"Authorization": f"Basic {basic_auth_credentials}"}, json=statement, ) assert response.status_code == 400 @@ -56,7 +64,7 @@ def test_api_statements_put_invalid_parameters(auth_credentials): ) # pylint: disable=too-many-arguments def test_api_statements_put_single_statement_directly( - backend, monkeypatch, auth_credentials, es, mongo, clickhouse + backend, monkeypatch, basic_auth_credentials, es, mongo, clickhouse ): """Test the put statements API route with one statement.""" # pylint: disable=invalid-name,unused-argument @@ -66,7 +74,7 @@ def test_api_statements_put_single_statement_directly( response = client.put( f"/xAPI/statements/?statementId={statement['id']}", - headers={"Authorization": f"Basic {auth_credentials}"}, + headers={"Authorization": f"Basic {basic_auth_credentials}"}, json=statement, ) @@ -75,7 +83,8 @@ def test_api_statements_put_single_statement_directly( es.indices.refresh() response = client.get( - "/xAPI/statements/", headers={"Authorization": f"Basic {auth_credentials}"} + "/xAPI/statements/", + headers={"Authorization": f"Basic {basic_auth_credentials}"}, ) assert response.status_code == 200 assert_statement_get_responses_are_equivalent( @@ -85,7 +94,7 @@ def test_api_statements_put_single_statement_directly( # pylint: disable=too-many-arguments def test_api_statements_put_enriching_without_existing_values( - monkeypatch, auth_credentials, es + monkeypatch, basic_auth_credentials, es ): """Test that statements are properly enriched when statement provides no values.""" # pylint: disable=invalid-name,unused-argument @@ -97,7 +106,7 @@ def test_api_statements_put_enriching_without_existing_values( response = client.put( f"/xAPI/statements/?statementId={statement['id']}", - headers={"Authorization": f"Basic {auth_credentials}"}, + headers={"Authorization": f"Basic {basic_auth_credentials}"}, json=statement, ) assert response.status_code == 204 @@ -105,7 +114,8 @@ def test_api_statements_put_enriching_without_existing_values( es.indices.refresh() response = client.get( - "/xAPI/statements/", headers={"Authorization": f"Basic {auth_credentials}"} + "/xAPI/statements/", + headers={"Authorization": f"Basic {basic_auth_credentials}"}, ) statement = response.json()["statements"][0] @@ -137,7 +147,7 @@ def test_api_statements_put_enriching_without_existing_values( ) # pylint: disable=too-many-arguments def test_api_statements_put_enriching_with_existing_values( - field, value, status, monkeypatch, auth_credentials, es + field, value, status, monkeypatch, basic_auth_credentials, es ): """Test that statements are properly enriched when values are provided.""" # pylint: disable=invalid-name,unused-argument @@ -152,7 +162,7 @@ def test_api_statements_put_enriching_with_existing_values( response = client.put( f"/xAPI/statements/?statementId={statement['id']}", - headers={"Authorization": f"Basic {auth_credentials}"}, + headers={"Authorization": f"Basic {basic_auth_credentials}"}, json=statement, ) @@ -162,7 +172,8 @@ def test_api_statements_put_enriching_with_existing_values( if status == 204: es.indices.refresh() response = client.get( - "/xAPI/statements/", headers={"Authorization": f"Basic {auth_credentials}"} + "/xAPI/statements/", + headers={"Authorization": f"Basic {basic_auth_credentials}"}, ) statement = response.json()["statements"][0] @@ -181,7 +192,7 @@ def test_api_statements_put_enriching_with_existing_values( ) # pylint: disable=too-many-arguments def test_api_statements_put_single_statement_no_trailing_slash( - backend, monkeypatch, auth_credentials, es, mongo, clickhouse + backend, monkeypatch, basic_auth_credentials, es, mongo, clickhouse ): """Test that the statements endpoint also works without the trailing slash.""" # pylint: disable=invalid-name,unused-argument @@ -191,7 +202,7 @@ def test_api_statements_put_single_statement_no_trailing_slash( response = client.put( f"/xAPI/statements?statementId={statement['id']}", - headers={"Authorization": f"Basic {auth_credentials}"}, + headers={"Authorization": f"Basic {basic_auth_credentials}"}, json=statement, ) @@ -204,7 +215,7 @@ def test_api_statements_put_single_statement_no_trailing_slash( ) # pylint: disable=too-many-arguments def test_api_statements_put_id_mismatch( - backend, monkeypatch, auth_credentials, es, mongo, clickhouse + backend, monkeypatch, basic_auth_credentials, es, mongo, clickhouse ): # pylint: disable=invalid-name,unused-argument """Test the put statements API route when the statementId doesn't match.""" @@ -214,7 +225,7 @@ def test_api_statements_put_id_mismatch( different_statement_id = str(uuid4()) response = client.put( f"/xAPI/statements/?statementId={different_statement_id}", - headers={"Authorization": f"Basic {auth_credentials}"}, + headers={"Authorization": f"Basic {basic_auth_credentials}"}, json=statement, ) @@ -230,7 +241,7 @@ def test_api_statements_put_id_mismatch( ) # pylint: disable=too-many-arguments def test_api_statements_put_list_of_one( - backend, monkeypatch, auth_credentials, es, mongo, clickhouse + backend, monkeypatch, basic_auth_credentials, es, mongo, clickhouse ): # pylint: disable=invalid-name,unused-argument """Test that we fail on PUTs with a list, even if it's one statement.""" @@ -239,7 +250,7 @@ def test_api_statements_put_list_of_one( response = client.put( f"/xAPI/statements/?statementId={statement['id']}", - headers={"Authorization": f"Basic {auth_credentials}"}, + headers={"Authorization": f"Basic {basic_auth_credentials}"}, json=[statement], ) @@ -252,7 +263,7 @@ def test_api_statements_put_list_of_one( ) # pylint: disable=too-many-arguments def test_api_statements_put_duplicate_of_existing_statement( - backend, monkeypatch, auth_credentials, es, mongo, clickhouse + backend, monkeypatch, basic_auth_credentials, es, mongo, clickhouse ): """Test the put statements API route, given a statement that already exist in the database (has the same ID), should fail. @@ -265,7 +276,7 @@ def test_api_statements_put_duplicate_of_existing_statement( # Put the statement once. response = client.put( f"/xAPI/statements/?statementId={statement['id']}", - headers={"Authorization": f"Basic {auth_credentials}"}, + headers={"Authorization": f"Basic {basic_auth_credentials}"}, json=statement, ) assert response.status_code == 204 @@ -275,7 +286,7 @@ def test_api_statements_put_duplicate_of_existing_statement( # Put the statement twice, trying to change the timestamp, which is not allowed response = client.put( f"/xAPI/statements/?statementId={statement['id']}", - headers={"Authorization": f"Basic {auth_credentials}"}, + headers={"Authorization": f"Basic {basic_auth_credentials}"}, json=dict(statement, **{"timestamp": "2023-03-15T14:07:51Z"}), ) @@ -286,7 +297,7 @@ def test_api_statements_put_duplicate_of_existing_statement( response = client.get( f"/xAPI/statements/?statementId={statement['id']}", - headers={"Authorization": f"Basic {auth_credentials}"}, + headers={"Authorization": f"Basic {basic_auth_credentials}"}, ) assert response.status_code == 200 assert_statement_get_responses_are_equivalent( @@ -299,7 +310,7 @@ def test_api_statements_put_duplicate_of_existing_statement( [get_es_test_backend, get_clickhouse_test_backend, get_mongo_test_backend], ) def test_api_statements_put_with_failure_during_storage( - backend, monkeypatch, auth_credentials, es, mongo, clickhouse + backend, monkeypatch, basic_auth_credentials, es, mongo, clickhouse ): """Test the put statements API route with a failure happening during storage.""" # pylint: disable=invalid-name,unused-argument, too-many-arguments @@ -315,7 +326,7 @@ def write_mock(*args, **kwargs): response = client.put( f"/xAPI/statements/?statementId={statement['id']}", - headers={"Authorization": f"Basic {auth_credentials}"}, + headers={"Authorization": f"Basic {basic_auth_credentials}"}, json=statement, ) @@ -328,7 +339,7 @@ def write_mock(*args, **kwargs): [get_es_test_backend, get_clickhouse_test_backend, get_mongo_test_backend], ) def test_api_statements_put_with_a_failure_during_id_query( - backend, monkeypatch, auth_credentials, es, mongo, clickhouse + backend, monkeypatch, basic_auth_credentials, es, mongo, clickhouse ): """Test the put statements API route with a failure during query execution.""" # pylint: disable=invalid-name,unused-argument,too-many-arguments @@ -346,7 +357,7 @@ def query_statements_by_ids_mock(*args, **kwargs): response = client.put( f"/xAPI/statements/?statementId={statement['id']}", - headers={"Authorization": f"Basic {auth_credentials}"}, + headers={"Authorization": f"Basic {basic_auth_credentials}"}, json=statement, ) @@ -360,7 +371,7 @@ def query_statements_by_ids_mock(*args, **kwargs): ) # pylint: disable=too-many-arguments def test_api_statements_put_without_forwarding( - backend, auth_credentials, monkeypatch, es, mongo, clickhouse + backend, basic_auth_credentials, monkeypatch, es, mongo, clickhouse ): """Test the put statements API route, given an empty forwarding configuration, should not start the forwarding background task. @@ -386,7 +397,7 @@ def spy_mock_forward_xapi_statements(_): response = client.put( f"/xAPI/statements/?statementId={statement['id']}", - headers={"Authorization": f"Basic {auth_credentials}"}, + headers={"Authorization": f"Basic {basic_auth_credentials}"}, json=statement, ) @@ -418,7 +429,7 @@ async def test_api_statements_put_with_forwarding( receiving_backend, forwarding_backend, monkeypatch, - auth_credentials, + basic_auth_credentials, es, es_forwarding, mongo, @@ -495,7 +506,7 @@ async def test_api_statements_put_with_forwarding( # The statement should be stored on the forwarding client response = await forwarding_client.get( "/xAPI/statements/", - headers={"Authorization": f"Basic {auth_credentials}"}, + headers={"Authorization": f"Basic {basic_auth_credentials}"}, ) assert response.status_code == 200 assert_statement_get_responses_are_equivalent( @@ -506,7 +517,7 @@ async def test_api_statements_put_with_forwarding( async with AsyncClient() as receiving_client: response = await receiving_client.get( f"http://{RUNSERVER_TEST_HOST}:{RUNSERVER_TEST_PORT}/xAPI/statements/", - headers={"Authorization": f"Basic {auth_credentials}"}, + headers={"Authorization": f"Basic {basic_auth_credentials}"}, ) assert response.status_code == 200 assert_statement_get_responses_are_equivalent( @@ -515,3 +526,74 @@ async def test_api_statements_put_with_forwarding( # Stop receiving LRS client await lrs_context.__aexit__(None, None, None) + + +@responses.activate +@pytest.mark.parametrize("auth_method", ["basic", "oidc"]) +@pytest.mark.parametrize( + "scopes,is_authorized", + [ + (["all"], True), + (["profile/read", "statements/write"], True), + (["all/read"], False), + (["statements/read/mine"], False), + (["profile/write"], False), + ([], False), + ], +) +def test_api_statements_put_scopes( + monkeypatch, fs, es, auth_method, scopes, is_authorized +): + """Test that putting statements behaves properly according to user scopes.""" + # pylint: disable=invalid-name,unused-argument,duplicate-code + monkeypatch.setattr( + "ralph.api.routers.statements.settings.LRS_RESTRICT_BY_SCOPES", True + ) + monkeypatch.setattr("ralph.api.auth.basic.settings.LRS_RESTRICT_BY_SCOPES", True) + + if auth_method == "basic": + agent = mock_agent("mbox", 1) + credentials = mock_basic_auth_user(fs, scopes=scopes, agent=agent) + headers = {"Authorization": f"Basic {credentials}"} + + app.dependency_overrides[get_authenticated_user] = get_basic_auth_user + get_basic_auth_user.cache_clear() + + elif auth_method == "oidc": + sub = "123|oidc" + agent = {"openid": sub} + oidc_token = mock_oidc_user(sub=sub, scopes=scopes) + headers = {"Authorization": f"Bearer {oidc_token}"} + + monkeypatch.setattr( + "ralph.api.auth.oidc.settings.RUNSERVER_AUTH_OIDC_ISSUER_URI", + "http://providerHost:8080/auth/realms/real_name", + ) + monkeypatch.setattr( + "ralph.api.auth.oidc.settings.RUNSERVER_AUTH_OIDC_AUDIENCE", + "http://clientHost:8100", + ) + + app.dependency_overrides[get_authenticated_user] = get_oidc_user + + statement = mock_statement() + + # NB: scopes are not linked to statements and backends, we therefore test with ES + backend_client_class_path = "ralph.api.routers.statements.BACKEND_CLIENT" + monkeypatch.setattr(backend_client_class_path, get_es_test_backend()) + + response = client.put( + f"/xAPI/statements/?statementId={statement['id']}", + headers=headers, + json=statement, + ) + + if is_authorized: + assert response.status_code == 204 + else: + assert response.status_code == 401 + assert response.json() == { + "detail": 'Access not authorized to scope: "statements/write".' + } + + app.dependency_overrides.pop(get_authenticated_user, None) diff --git a/tests/conftest.py b/tests/conftest.py index 10b819ee3..8165d3458 100644 --- a/tests/conftest.py +++ b/tests/conftest.py @@ -5,7 +5,7 @@ from .fixtures import hypothesis_configuration # noqa: F401 from .fixtures import hypothesis_strategies # noqa: F401 from .fixtures.auth import ( # noqa: F401 - auth_credentials, + basic_auth_credentials, basic_auth_test_client, encoded_token, mock_discovery_response, diff --git a/tests/fixtures/auth.py b/tests/fixtures/auth.py index da4c83868..7e44149b3 100644 --- a/tests/fixtures/auth.py +++ b/tests/fixtures/auth.py @@ -2,9 +2,11 @@ import base64 import json import os +from typing import Optional import bcrypt import pytest +import responses from cryptography.hazmat.primitives import serialization from fastapi.testclient import TestClient from jose import jwt @@ -12,6 +14,7 @@ from ralph.api import app, get_authenticated_user from ralph.api.auth.basic import get_stored_credentials +from ralph.api.auth.oidc import discover_provider, get_public_keys from ralph.conf import settings from . import private_key, public_key @@ -24,10 +27,10 @@ def mock_basic_auth_user( fs_, - username: str, - password: str, - scopes: list, - agent: dict, + username: str = "jane", + password: str = "pwd", + scopes: Optional[list] = None, + agent: Optional[dict] = None, ): """Create a user using Basic Auth in the (fake) file system. @@ -39,6 +42,12 @@ def mock_basic_auth_user( agent (dict): an agent that represents the user and may be used as authority """ + # Default values for `scopes` and `agent` + if scopes is None: + scopes = [] + if agent is None: + agent = {"mbox": "mailto:jane@ralphlrs.com"} + # Basic HTTP auth credential_bytes = base64.b64encode(f"{username}:{password}".encode("utf-8")) credentials = str(credential_bytes, "utf-8") @@ -71,7 +80,7 @@ def mock_basic_auth_user( # pylint: disable=invalid-name @pytest.fixture -def auth_credentials(fs, user_scopes=None, agent=None): +def basic_auth_credentials(fs, user_scopes=None, agent=None): """Set up the credentials file for request authentication. Args: @@ -92,7 +101,6 @@ def auth_credentials(fs, user_scopes=None, agent=None): agent = {"mbox": "mailto:test_ralph@example.com"} credentials = mock_basic_auth_user(fs, username, password, user_scopes, agent) - return credentials @@ -101,10 +109,10 @@ def basic_auth_test_client(): """Return a TestClient with HTTP basic authentication mode.""" # pylint:disable=import-outside-toplevel from ralph.api.auth.basic import ( - get_authenticated_user as get_basic, # pylint:disable=import-outside-toplevel + get_basic_auth_user, # pylint:disable=import-outside-toplevel ) - app.dependency_overrides[get_authenticated_user] = get_basic + app.dependency_overrides[get_authenticated_user] = get_basic_auth_user with TestClient(app) as test_client: yield test_client @@ -122,15 +130,14 @@ def oidc_auth_test_client(monkeypatch): "ralph.api.auth.oidc.settings.RUNSERVER_AUTH_OIDC_AUDIENCE", AUDIENCE, ) - from ralph.api.auth.oidc import get_authenticated_user as get_oidc + from ralph.api.auth.oidc import get_oidc_user - app.dependency_overrides[get_authenticated_user] = get_oidc + app.dependency_overrides[get_authenticated_user] = get_oidc_user with TestClient(app) as test_client: yield test_client -@pytest.fixture -def mock_discovery_response(): +def _mock_discovery_response(): """Return an example discovery response.""" return { "issuer": "http://providerHost", @@ -219,6 +226,12 @@ def mock_discovery_response(): } +@pytest.fixture +def mock_discovery_response(): + """Return an example discovery response (fixture).""" + return _mock_discovery_response() + + def get_jwk(pub_key): """Return a JWK representation of the public key.""" public_numbers = pub_key.public_numbers() @@ -233,23 +246,27 @@ def get_jwk(pub_key): } -@pytest.fixture -def mock_oidc_jwks(): +def _mock_oidc_jwks(): """Mock OpenID Connect keys.""" return {"keys": [get_jwk(public_key)]} @pytest.fixture -def encoded_token(): +def mock_oidc_jwks(): + """Mock OpenID Connect keys (fixture).""" + return _mock_oidc_jwks() + + +def _create_oidc_token(sub, scopes): """Encode token with the private key.""" return jwt.encode( claims={ - "sub": "123|oidc", + "sub": sub, "iss": "https://iss.example.com", "aud": AUDIENCE, "iat": 0, # Issued the 1/1/1970 "exp": 9999999999, # Expiring in 11/20/2286 - "scope": "all statements/read", + "scope": " ".join(scopes), }, key=private_key.private_bytes( serialization.Encoding.PEM, @@ -261,3 +278,39 @@ def encoded_token(): "kid": PUBLIC_KEY_ID, }, ) + + +def mock_oidc_user(sub="123|oidc", scopes=None): + """Instantiate mock oidc user and return auth token.""" + # Default value for scope + if scopes is None: + scopes = ["all", "statements/read"] + + # Clear LRU cache + discover_provider.cache_clear() + get_public_keys.cache_clear() + + # Mock request to get provider configuration + responses.add( + responses.GET, + f"{ISSUER_URI}/.well-known/openid-configuration", + json=_mock_discovery_response(), + status=200, + ) + + # Mock request to get keys + responses.add( + responses.GET, + _mock_discovery_response()["jwks_uri"], + json=_mock_oidc_jwks(), + status=200, + ) + + oidc_token = _create_oidc_token(sub=sub, scopes=scopes) + return oidc_token + + +@pytest.fixture +def encoded_token(): + """Encode token with the private key (fixture).""" + return _create_oidc_token(sub="123|oidc", scopes=["all", "statements/read"]) diff --git a/tests/test_cli.py b/tests/test_cli.py index 6a303884c..890576027 100644 --- a/tests/test_cli.py +++ b/tests/test_cli.py @@ -212,7 +212,7 @@ def _assert_matching_basic_auth_credentials( assert "hash" in credentials if hash_: assert credentials["hash"] == hash_ - assert credentials["scopes"] == scopes + assert sorted(credentials["scopes"]) == sorted(scopes) assert "agent" in credentials if agent_name is not None: diff --git a/tests/test_conf.py b/tests/test_conf.py index 670bf5ba6..e9c681d79 100644 --- a/tests/test_conf.py +++ b/tests/test_conf.py @@ -7,6 +7,7 @@ from ralph import conf from ralph.backends.conf import BackendSettings from ralph.conf import CommaSeparatedTuple, Settings, settings +from ralph.exceptions import ConfigurationException def test_conf_settings_field_value_priority(fs, monkeypatch): @@ -73,3 +74,20 @@ def test_conf_core_settings_should_impact_settings_defaults(monkeypatch): # Defaults. assert str(conf.settings.AUTH_FILE) == "/foo/auth.json" + + +def test_conf_forbidden_scopes_without_authority(monkeypatch): + """Test that using RESTRICT_BY_SCOPES without RESTRICT_BY_AUTHORITY raises an + error.""" + + monkeypatch.setenv("RALPH_LRS_RESTRICT_BY_AUTHORITY", False) + monkeypatch.setenv("RALPH_LRS_RESTRICT_BY_SCOPES", True) + + with pytest.raises( + ConfigurationException, + match=( + "LRS_RESTRICT_BY_AUTHORITY must be set to True if using " + "LRS_RESTRICT_BY_SCOPES=True" + ), + ): + reload(conf)