diff --git a/oauth2_provider/models.py b/oauth2_provider/models.py index 1b0fa1f5..a41a3d80 100644 --- a/oauth2_provider/models.py +++ b/oauth2_provider/models.py @@ -6,7 +6,7 @@ from dataclasses import dataclass from datetime import datetime, timedelta from datetime import timezone as dt_timezone -from typing import Optional +from typing import Callable, Optional, Union from urllib.parse import parse_qsl, urlparse from django.apps import apps @@ -734,6 +734,7 @@ class DeviceCodeResponse: user_code: int device_code: str interval: int + verification_uri_complete: Optional[Union[str, Callable]] = None def create_device(device_request: DeviceRequest, device_response: DeviceCodeResponse) -> Device: diff --git a/oauth2_provider/settings.py b/oauth2_provider/settings.py index 5216c806..4f282624 100644 --- a/oauth2_provider/settings.py +++ b/oauth2_provider/settings.py @@ -24,7 +24,7 @@ from django.utils.module_loading import import_string from oauthlib.common import Request -from oauth2_provider.utils import user_code_generator +from oauth2_provider.utils import set_oauthlib_user_to_device_request_user, user_code_generator USER_SETTINGS = getattr(settings, "OAUTH2_PROVIDER", None) @@ -43,7 +43,9 @@ "CLIENT_SECRET_HASHER": "default", "ACCESS_TOKEN_GENERATOR": None, "OAUTH_DEVICE_VERIFICATION_URI": None, + "OAUTH_DEVICE_VERIFICATION_URI_COMPLETE": None, "OAUTH_DEVICE_USER_CODE_GENERATOR": user_code_generator, + "OAUTH_PRE_TOKEN_VALIDATION": [set_oauthlib_user_to_device_request_user], "REFRESH_TOKEN_GENERATOR": None, "EXTRA_SERVER_KWARGS": {}, "OAUTH2_SERVER_CLASS": "oauthlib.oauth2.Server", @@ -276,8 +278,10 @@ def server_kwargs(self): ("token_generator", "ACCESS_TOKEN_GENERATOR"), ("refresh_token_generator", "REFRESH_TOKEN_GENERATOR"), ("verification_uri", "OAUTH_DEVICE_VERIFICATION_URI"), + ("verification_uri_complete", "OAUTH_DEVICE_VERIFICATION_URI_COMPLETE"), ("interval", "DEVICE_FLOW_INTERVAL"), ("user_code_generator", "OAUTH_DEVICE_USER_CODE_GENERATOR"), + ("pre_token","OAUTH_PRE_TOKEN_VALIDATION") ] } kwargs.update(self.EXTRA_SERVER_KWARGS) diff --git a/oauth2_provider/utils.py b/oauth2_provider/utils.py index ef213dca..36d1de4b 100644 --- a/oauth2_provider/utils.py +++ b/oauth2_provider/utils.py @@ -3,6 +3,7 @@ from django.conf import settings from jwcrypto import jwk +from oauthlib.common import Request @functools.lru_cache() @@ -75,3 +76,24 @@ def user_code_generator(user_code_length: int = 8) -> str: user_code[i] = random.choice(character_space) return "".join(user_code) + + +def set_oauthlib_user_to_device_request_user(request: Request) -> None: + """ + The user isn't known when the device flow is initiated by a device. + All we know is the client_id. + + However, when the user logins in order to submit the user code + from the device we now know which user is trying to authenticate + their device. We update the device user field at this point + and save it in the db. + + This function is added to the pre_token stage during the device code grant's + create_token_response where we have the oauthlib Request object which is what's used + to populate the user field in the device model + """ + # Since this function is used in the settings module, it will lead to circular imports + # since django isn't fully initialised yet when settings run + from oauth2_provider.models import Device, get_device_model + device: Device = get_device_model().objects.get(device_code=request._params["device_code"]) + request.user = device.user diff --git a/oauth2_provider/views/device.py b/oauth2_provider/views/device.py index 676dfbf9..075b9bc1 100644 --- a/oauth2_provider/views/device.py +++ b/oauth2_provider/views/device.py @@ -59,6 +59,9 @@ def device_user_code_view(request): user_code: str = form.cleaned_data["user_code"] device: Device = get_device_model().objects.get(user_code=user_code) + device.user = request.user + device.save(update_fields=["user"]) + if device is None: form.add_error("user_code", "Incorrect user code") return render(request, "oauth2_provider/device/user_code.html", {"form": form}) diff --git a/tests/app/idp/idp/settings.py b/tests/app/idp/idp/settings.py index f92ba2b5..67940760 100644 --- a/tests/app/idp/idp/settings.py +++ b/tests/app/idp/idp/settings.py @@ -15,7 +15,7 @@ import environ -from oauth2_provider.utils import user_code_generator +from oauth2_provider.utils import set_oauthlib_user_to_device_request_user, user_code_generator # Build paths inside the project like this: BASE_DIR / 'subdir'. @@ -202,7 +202,9 @@ OAUTH2_PROVIDER = { "OAUTH2_VALIDATOR_CLASS": "idp.oauth.CustomOAuth2Validator", "OAUTH_DEVICE_VERIFICATION_URI": "http://127.0.0.1:8000/o/device", + "OAUTH_PRE_TOKEN_VALIDATION": [set_oauthlib_user_to_device_request_user], "OAUTH_DEVICE_USER_CODE_GENERATOR": user_code_generator, + "OAUTH_DEVICE_VERIFICATION_URI_COMPLETE": lambda x: f"http://127.0.0.1:8000/o/device?user_code={x}", "OIDC_ENABLED": env("OAUTH2_PROVIDER_OIDC_ENABLED"), "OIDC_RP_INITIATED_LOGOUT_ENABLED": env("OAUTH2_PROVIDER_OIDC_RP_INITIATED_LOGOUT_ENABLED"), # this key is just for out test app, you should never store a key like this in a production environment. diff --git a/tests/test_device.py b/tests/test_device.py index 345552a2..18f1fce3 100644 --- a/tests/test_device.py +++ b/tests/test_device.py @@ -8,7 +8,13 @@ from django.urls import reverse import oauth2_provider.models -from oauth2_provider.models import get_access_token_model, get_application_model, get_device_model +from oauth2_provider.models import ( + get_access_token_model, + get_application_model, + get_device_model, + get_refresh_token_model, +) +from oauth2_provider.utils import set_oauthlib_user_to_device_request_user from . import presets from .common_testing import OAuth2ProviderTestCase as TestCase @@ -16,6 +22,7 @@ Application = get_application_model() AccessToken = get_access_token_model() +RefreshToken = get_refresh_token_model() UserModel = get_user_model() DeviceModel: oauth2_provider.models.Device = get_device_model() @@ -122,6 +129,8 @@ def test_device_flow_authorization_user_code_confirm_and_access_token(self): # ----------------------- self.oauth2_settings.OAUTH_DEVICE_VERIFICATION_URI = "example.com/device" self.oauth2_settings.OAUTH_DEVICE_USER_CODE_GENERATOR = lambda: "xyz" + self.oauth2_settings.OAUTH_DEVICE_USER_CODE_GENERATOR = lambda: "xyz" + self.oauth2_settings.OAUTH_PRE_TOKEN_VALIDATION = [set_oauthlib_user_to_device_request_user] request_data: dict[str, str] = { "client_id": self.application.client_id, @@ -193,6 +202,7 @@ def test_device_flow_authorization_user_code_confirm_and_access_token(self): "client_id": self.application.client_id, "grant_type": "urn:ietf:params:oauth:grant-type:device_code", } + token_response = self.client.post( "/o/token/", data=urlencode(token_payload), @@ -207,6 +217,17 @@ def test_device_flow_authorization_user_code_confirm_and_access_token(self): assert token_data["token_type"].lower() == "bearer" assert "scope" in token_data + # ensure the access token and refresh token have the same user as the device that just authenticated + access_token: oauth2_provider.models.AccessToken = AccessToken.objects.get( + token=token_data["access_token"] + ) + assert access_token.user == device.user + + refresh_token: oauth2_provider.models.RefreshToken = RefreshToken.objects.get( + token=token_data["refresh_token"] + ) + assert refresh_token.user == device.user + @mock.patch( "oauthlib.oauth2.rfc8628.endpoints.device_authorization.generate_token", lambda: "abc",