From b0914001830ba30c4e1a085a54f0909b9fa9bb2b Mon Sep 17 00:00:00 2001 From: Tasko Olevski Date: Thu, 25 Jan 2024 16:02:50 +0100 Subject: [PATCH] squashme: minor fixes --- renku/ui/service/entrypoint.py | 2 +- renku/ui/service/serializers/headers.py | 6 +++--- renku/ui/service/utils/__init__.py | 23 ++++++++++++++++------- 3 files changed, 20 insertions(+), 11 deletions(-) diff --git a/renku/ui/service/entrypoint.py b/renku/ui/service/entrypoint.py index 7449516497..4cb24d3084 100644 --- a/renku/ui/service/entrypoint.py +++ b/renku/ui/service/entrypoint.py @@ -40,8 +40,8 @@ ) from renku.ui.service.logger import service_log from renku.ui.service.serializers.headers import JWT_TOKEN_SECRET -from renku.ui.service.utils.json_encoder import SvcJSONProvider from renku.ui.service.utils import jwk_client +from renku.ui.service.utils.json_encoder import SvcJSONProvider from renku.ui.service.views import error_response from renku.ui.service.views.apispec import apispec_blueprint from renku.ui.service.views.cache import cache_blueprint diff --git a/renku/ui/service/serializers/headers.py b/renku/ui/service/serializers/headers.py index 51ef6c3851..3f8c10dc0c 100644 --- a/renku/ui/service/serializers/headers.py +++ b/renku/ui/service/serializers/headers.py @@ -20,7 +20,7 @@ from typing import cast import jwt -from flask import app +from flask import current_app from marshmallow import Schema, ValidationError, fields, post_load from werkzeug.utils import secure_filename @@ -96,9 +96,9 @@ def decode_token(token): def decode_user(data): """Extract renku user from the Keycloak ID token which is a JWT.""" try: - jwk = cast(jwt.PyJWKClient, app.config["KEYCLOAK_JWK_CLIENT"]) + jwk = cast(jwt.PyJWKClient, current_app.config["KEYCLOAK_JWK_CLIENT"]) key = jwk.get_signing_key_from_jwt(data) - decoded = jwt.decode(data, key=key, algorithms=["RS256"], audience="renku") + decoded = jwt.decode(data, key=key.key, algorithms=["RS256"], audience="renku") except jwt.PyJWTError: # NOTE: older tokens used to be signed with HS256 so use this as a backup if the validation with RS256 # above fails. We used to need HS256 because a step that is now removed was generating an ID token and diff --git a/renku/ui/service/utils/__init__.py b/renku/ui/service/utils/__init__.py index b315701c54..b84bb6bbfe 100644 --- a/renku/ui/service/utils/__init__.py +++ b/renku/ui/service/utils/__init__.py @@ -14,17 +14,18 @@ # See the License for the specific language governing permissions and # limitations under the License. """Renku service utility functions.""" +import os +import urllib from time import sleep from typing import Any, Dict, Optional, overload import requests -import urllib from jwt import PyJWKClient +from renku.core.util.requests import get from renku.ui.service.config import CACHE_PROJECTS_PATH, CACHE_UPLOADS_PATH, OIDC_URL from renku.ui.service.errors import ProgramInternalError from renku.ui.service.logger import service_log -from renku.core.util.requests import get def make_project_path(user, project): @@ -101,28 +102,36 @@ def oidc_discovery() -> Dict[str, Any]: retries = 0 max_retries = 30 sleep_seconds = 2 + renku_domain = os.environ.get("RENKU_DOMAIN") + if not renku_domain: + raise ProgramInternalError( + error_message="Cannot perform OIDC discovery without the renku domain expected " + "to be found in the RENKU_DOMAIN environment variable." + ) + full_oidc_url = f"http://{renku_domain}{OIDC_URL}" while True: retries += 1 try: - res: requests.Response = get(OIDC_URL) + res: requests.Response = get(full_oidc_url) except (requests.exceptions.HTTPError, urllib.error.HTTPError) as e: if not retries < max_retries: service_log.error("Failed to get OIDC discovery data after all retries - the server cannot start.") raise e service_log.info( - f"Failed to get OIDC discovery data from {OIDC_URL}, sleeping for {sleep_seconds} seconds and retrying" + f"Failed to get OIDC discovery data from {full_oidc_url}, " + f"sleeping for {sleep_seconds} seconds and retrying" ) sleep(sleep_seconds) else: - service_log.info(f"Successfully fetched OIDC discovery data from {OIDC_URL}") + service_log.info(f"Successfully fetched OIDC discovery data from {full_oidc_url}") return res.json() def jwk_client() -> PyJWKClient: - """Return a JWK client for Keycloak that can be used to provide JWT keys for JWT signature validation""" + """Return a JWK client for Keycloak that can be used to provide JWT keys for JWT signature validation.""" oidc_data = oidc_discovery() jwks_uri = oidc_data.get("jwks_uri") if not jwks_uri: - raise ProgramInternalError(error_message="Could not find JWK URI in the OIDC discovery data") + raise ProgramInternalError(error_message="Could not find jwks_uri in the OIDC discovery data") jwk = PyJWKClient(jwks_uri) return jwk