From 0758a12230fcb0a8c85545cffdbb5aaa05836df3 Mon Sep 17 00:00:00 2001 From: Stefan Nica Date: Mon, 28 Oct 2024 21:50:17 +0100 Subject: [PATCH 01/12] Rework the api_token endpoint to issue generic API tokens --- src/zenml/config/server_config.py | 7 +++ src/zenml/constants.py | 1 + src/zenml/orchestrators/utils.py | 13 ++++ .../zen_server/routers/auth_endpoints.py | 59 +++++++++++-------- 4 files changed, 55 insertions(+), 25 deletions(-) diff --git a/src/zenml/config/server_config.py b/src/zenml/config/server_config.py index 399c00f6cfe..8f521900d74 100644 --- a/src/zenml/config/server_config.py +++ b/src/zenml/config/server_config.py @@ -26,6 +26,7 @@ DEFAULT_ZENML_JWT_TOKEN_LEEWAY, DEFAULT_ZENML_SERVER_DEVICE_AUTH_POLLING, DEFAULT_ZENML_SERVER_DEVICE_AUTH_TIMEOUT, + DEFAULT_ZENML_SERVER_GENERIC_API_TOKEN_LIFETIME, DEFAULT_ZENML_SERVER_LOGIN_RATE_LIMIT_DAY, DEFAULT_ZENML_SERVER_LOGIN_RATE_LIMIT_MINUTE, DEFAULT_ZENML_SERVER_MAX_DEVICE_AUTH_ATTEMPTS, @@ -119,6 +120,8 @@ class ServerConfiguration(BaseModel): time of the JWT tokens issued to clients after they have authenticated with the ZenML server using an OAuth 2.0 device that has been marked as trusted. + generic_api_token_lifetime: The lifetime in seconds that generic + short-lived API tokens issued for automation purposes are valid. external_login_url: The login URL of an external authenticator service to use with the `EXTERNAL` authentication scheme. external_user_info_url: The user info URL of an external authenticator @@ -257,6 +260,10 @@ class ServerConfiguration(BaseModel): device_expiration_minutes: Optional[int] = None trusted_device_expiration_minutes: Optional[int] = None + generic_api_token_lifetime: int = ( + DEFAULT_ZENML_SERVER_GENERIC_API_TOKEN_LIFETIME + ) + external_login_url: Optional[str] = None external_user_info_url: Optional[str] = None external_cookie_name: Optional[str] = None diff --git a/src/zenml/constants.py b/src/zenml/constants.py index 158f1c2da5d..38a2bba78a4 100644 --- a/src/zenml/constants.py +++ b/src/zenml/constants.py @@ -269,6 +269,7 @@ def handle_int_env_var(var: str, default: int = 0) -> int: DEFAULT_ZENML_SERVER_PIPELINE_RUN_AUTH_WINDOW = 60 * 48 # 48 hours DEFAULT_ZENML_SERVER_LOGIN_RATE_LIMIT_MINUTE = 5 DEFAULT_ZENML_SERVER_LOGIN_RATE_LIMIT_DAY = 1000 +DEFAULT_ZENML_SERVER_GENERIC_API_TOKEN_LIFETIME = 60 * 60 # 1 hour DEFAULT_ZENML_SERVER_SECURE_HEADERS_HSTS = ( "max-age=63072000; includeSubdomains" diff --git a/src/zenml/orchestrators/utils.py b/src/zenml/orchestrators/utils.py index 158f7302e69..ad4fd77ae5d 100644 --- a/src/zenml/orchestrators/utils.py +++ b/src/zenml/orchestrators/utils.py @@ -34,6 +34,8 @@ from zenml.stack import StackComponent from zenml.utils.string_utils import format_name_template +logger = get_logger(__name__) + if TYPE_CHECKING: from zenml.artifact_stores.base_artifact_store import BaseArtifactStore from zenml.models import PipelineDeploymentResponse @@ -113,6 +115,17 @@ def get_config_environment_vars( # If a schedule is given, this is a long running pipeline that # should not have an API token that expires. expires_minutes = None + logger.warning( + "An API token without an expiration time will be generated " + "and used to run this pipeline on a schedule. This is very " + "insecure because the API token cannot be revoked in case " + "of potential theft without disabling the entire user " + "accountWhen deploying a pipeline on a schedule, it is " + "strongly advised to use a service account API key to " + "authenticate to the ZenML server instead of your regular " + "user account. For more information, see " + "https://docs.zenml.io/how-to/connecting-to-zenml/connect-with-a-service-account" + ) api_token = global_config.zen_store.get_api_token( pipeline_id=pipeline_id, schedule_id=schedule_id, diff --git a/src/zenml/zen_server/routers/auth_endpoints.py b/src/zenml/zen_server/routers/auth_endpoints.py index e0661603541..e41966f2613 100644 --- a/src/zenml/zen_server/routers/auth_endpoints.py +++ b/src/zenml/zen_server/routers/auth_endpoints.py @@ -193,9 +193,12 @@ def __init__( def generate_access_token( user_id: UUID, - response: Response, + response: Optional[Response] = None, device: Optional[OAuthDeviceInternalResponse] = None, api_key: Optional[APIKeyInternalResponse] = None, + expires_in: Optional[int] = None, + pipeline_id: Optional[UUID] = None, + schedule_id: Optional[UUID] = None, ) -> OAuthTokenResponse: """Generates an access token for the given user. @@ -204,18 +207,22 @@ def generate_access_token( response: The FastAPI response object. device: The device used for authentication. api_key: The service account API key used for authentication. + expires_in: The number of seconds until the token expires. + pipeline_id: The ID of the pipeline to scope the token to. + schedule_id: The ID of the schedule to scope the token to. Returns: An authentication response with an access token. """ config = server_config() - # The JWT tokens are set to expire according to the values configured - # in the server config. Device tokens are handled separately from regular - # user tokens. + # If the expiration time is not supplied, the JWT tokens are set to expire + # according to the values configured in the server config. Device tokens are + # handled separately from regular user tokens. expires: Optional[datetime] = None - expires_in: Optional[int] = None - if device: + if expires_in: + expires = datetime.utcnow() + timedelta(seconds=expires_in) + elif device: # If a device was used for authentication, the token will expire # at the same time as the device. expires = device.expires @@ -235,7 +242,7 @@ def generate_access_token( api_key_id=api_key.id if api_key else None, ).encode(expires=expires) - if not device: + if not device and response: # Also set the access token as an HTTP only cookie in the response response.set_cookie( key=config.get_auth_cookie_name(), @@ -522,6 +529,20 @@ def api_token( detail="Not authenticated.", ) + if not token.device_id and not token.api_key_id: + config = server_config() + + # If not authenticated with a device or a service account, then a + # short-lived generic API token is returned. + return generate_access_token( + user_id=token.user_id, + expires_in=config.generic_api_token_lifetime, + ).access_token + + # Issuing workload tokens is only supported for device authenticated users + # and service accounts, because device tokens can be revoked at any time and + # service accounts can be disabled. + verify_permission( resource_type=ResourceType.PIPELINE_RUN, action=Action.CREATE ) @@ -540,21 +561,9 @@ def api_token( f"schedule {token.schedule_id}." ) - if not token.device_id and not token.api_key_id: - # If not authenticated with a device or a service account, the current - # API token is returned as is, without any modifications. Issuing - # workload tokens is only supported for device authenticated users and - # service accounts, because device tokens can be revoked at any time and - # service accounts can be disabled. - return auth_context.encoded_access_token - - # If authenticated with a device, a new API token is generated for the - # pipeline and/or schedule. - if pipeline_id: - token.pipeline_id = pipeline_id - if schedule_id: - token.schedule_id = schedule_id - expires: Optional[datetime] = None - if expires_minutes: - expires = datetime.utcnow() + timedelta(minutes=expires_minutes) - return token.encode(expires=expires) + return generate_access_token( + user_id=token.user_id, + expires_in=expires_minutes * 60 if expires_minutes else None, + pipeline_id=pipeline_id, + schedule_id=schedule_id, + ).access_token From bd372154f99689bd0952f11c761414bdf7c4564b Mon Sep 17 00:00:00 2001 From: Michael Schuster Date: Tue, 29 Oct 2024 11:49:01 +0100 Subject: [PATCH 02/12] Pass pipeline and schedule id --- src/zenml/zen_server/routers/auth_endpoints.py | 2 ++ 1 file changed, 2 insertions(+) diff --git a/src/zenml/zen_server/routers/auth_endpoints.py b/src/zenml/zen_server/routers/auth_endpoints.py index e41966f2613..2e8b6e0b7de 100644 --- a/src/zenml/zen_server/routers/auth_endpoints.py +++ b/src/zenml/zen_server/routers/auth_endpoints.py @@ -240,6 +240,8 @@ def generate_access_token( user_id=user_id, device_id=device.id if device else None, api_key_id=api_key.id if api_key else None, + pipeline_id=pipeline_id, + schedule_id=schedule_id ).encode(expires=expires) if not device and response: From 538bd57ccca296c8da6a6b5e55a37aee7ac9ea65 Mon Sep 17 00:00:00 2001 From: Michael Schuster Date: Tue, 29 Oct 2024 11:50:50 +0100 Subject: [PATCH 03/12] Fix warning message --- src/zenml/orchestrators/utils.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/src/zenml/orchestrators/utils.py b/src/zenml/orchestrators/utils.py index ad4fd77ae5d..e9846ee5379 100644 --- a/src/zenml/orchestrators/utils.py +++ b/src/zenml/orchestrators/utils.py @@ -120,7 +120,7 @@ def get_config_environment_vars( "and used to run this pipeline on a schedule. This is very " "insecure because the API token cannot be revoked in case " "of potential theft without disabling the entire user " - "accountWhen deploying a pipeline on a schedule, it is " + "account. When deploying a pipeline on a schedule, it is " "strongly advised to use a service account API key to " "authenticate to the ZenML server instead of your regular " "user account. For more information, see " From 041be8884e5ce9ebb91951278ac0fa7ed1ad1197 Mon Sep 17 00:00:00 2001 From: Stefan Nica Date: Mon, 11 Nov 2024 17:25:34 +0100 Subject: [PATCH 04/12] Implement proper authorization checks for workload API tokens --- src/zenml/config/server_config.py | 4 +- src/zenml/constants.py | 7 - src/zenml/enums.py | 7 + src/zenml/orchestrators/base_orchestrator.py | 13 +- src/zenml/orchestrators/step_launcher.py | 3 +- src/zenml/orchestrators/utils.py | 76 +++--- src/zenml/zen_server/auth.py | 147 ++++++++++- src/zenml/zen_server/jwt.py | 43 +++- .../zen_server/routers/auth_endpoints.py | 240 ++++++++++-------- .../zen_server/template_execution/utils.py | 25 +- src/zenml/zen_stores/rest_zen_store.py | 23 +- 11 files changed, 402 insertions(+), 186 deletions(-) diff --git a/src/zenml/config/server_config.py b/src/zenml/config/server_config.py index 8f521900d74..97e29520196 100644 --- a/src/zenml/config/server_config.py +++ b/src/zenml/config/server_config.py @@ -19,7 +19,7 @@ from typing import Any, Dict, List, Optional, Union from uuid import UUID -from pydantic import BaseModel, ConfigDict, Field, model_validator +from pydantic import BaseModel, ConfigDict, Field, PositiveInt, model_validator from zenml.constants import ( DEFAULT_ZENML_JWT_TOKEN_ALGORITHM, @@ -260,7 +260,7 @@ class ServerConfiguration(BaseModel): device_expiration_minutes: Optional[int] = None trusted_device_expiration_minutes: Optional[int] = None - generic_api_token_lifetime: int = ( + generic_api_token_lifetime: PositiveInt = ( DEFAULT_ZENML_SERVER_GENERIC_API_TOKEN_LIFETIME ) diff --git a/src/zenml/constants.py b/src/zenml/constants.py index e63ca31a4d0..3e9abe8935f 100644 --- a/src/zenml/constants.py +++ b/src/zenml/constants.py @@ -167,9 +167,6 @@ def handle_int_env_var(var: str, default: int = 0) -> int: ENV_ZENML_ENFORCE_TYPE_ANNOTATIONS = "ZENML_ENFORCE_TYPE_ANNOTATIONS" ENV_ZENML_ENABLE_IMPLICIT_AUTH_METHODS = "ZENML_ENABLE_IMPLICIT_AUTH_METHODS" ENV_ZENML_DISABLE_STEP_LOGS_STORAGE = "ZENML_DISABLE_STEP_LOGS_STORAGE" -ENV_ZENML_PIPELINE_API_TOKEN_EXPIRES_MINUTES = ( - "ZENML_PIPELINE_API_TOKEN_EXPIRES_MINUTES" -) ENV_ZENML_IGNORE_FAILURE_HOOK = "ZENML_IGNORE_FAILURE_HOOK" ENV_ZENML_CUSTOM_SOURCE_ROOT = "ZENML_CUSTOM_SOURCE_ROOT" ENV_ZENML_WHEEL_PACKAGE_NAME = "ZENML_WHEEL_PACKAGE_NAME" @@ -411,10 +408,6 @@ def handle_int_env_var(var: str, default: int = 0) -> int: # orchestrator constants ORCHESTRATOR_DOCKER_IMAGE_KEY = "orchestrator" -PIPELINE_API_TOKEN_EXPIRES_MINUTES = handle_int_env_var( - ENV_ZENML_PIPELINE_API_TOKEN_EXPIRES_MINUTES, - default=60 * 24, # 24 hours -) # Secret constants SECRET_VALUES = "values" diff --git a/src/zenml/enums.py b/src/zenml/enums.py index c39b39c43ea..99f55b15f4a 100644 --- a/src/zenml/enums.py +++ b/src/zenml/enums.py @@ -245,6 +245,13 @@ class OAuthDeviceStatus(StrEnum): LOCKED = "locked" +class APITokenType(StrEnum): + """The API token type.""" + + GENERIC = "generic" + WORKLOAD = "workload" + + class GenericFilterOps(StrEnum): """Ops for all filters for string values on list methods.""" diff --git a/src/zenml/orchestrators/base_orchestrator.py b/src/zenml/orchestrators/base_orchestrator.py index 51e501ead99..d5ab036f1d2 100644 --- a/src/zenml/orchestrators/base_orchestrator.py +++ b/src/zenml/orchestrators/base_orchestrator.py @@ -15,6 +15,7 @@ from abc import ABC, abstractmethod from typing import TYPE_CHECKING, Any, Dict, Iterator, Optional, Type, cast +from uuid import UUID from pydantic import model_validator @@ -186,7 +187,17 @@ def run( """ self._prepare_run(deployment=deployment) - environment = get_config_environment_vars(deployment=deployment) + pipeline_run_id: Optional[UUID] = None + schedule_id: Optional[UUID] = None + if deployment.schedule: + schedule_id = deployment.schedule.id + if placeholder_run: + pipeline_run_id = placeholder_run.id + + environment = get_config_environment_vars( + schedule_id=schedule_id, + pipeline_run_id=pipeline_run_id, + ) prevent_client_side_caching = handle_bool_env_var( ENV_ZENML_PREVENT_CLIENT_SIDE_CACHING, default=False diff --git a/src/zenml/orchestrators/step_launcher.py b/src/zenml/orchestrators/step_launcher.py index 84b0450672b..c701d90e0bb 100644 --- a/src/zenml/orchestrators/step_launcher.py +++ b/src/zenml/orchestrators/step_launcher.py @@ -422,7 +422,8 @@ def _run_step_with_step_operator( ) ) environment = orchestrator_utils.get_config_environment_vars( - deployment=self._deployment + pipeline_run_id=step_run_info.run_id, + step_run_id=step_run_info.step_run_id, ) if last_retry: environment[ENV_ZENML_IGNORE_FAILURE_HOOK] = str(False) diff --git a/src/zenml/orchestrators/utils.py b/src/zenml/orchestrators/utils.py index 5448e5c677f..2fb1fb78c1d 100644 --- a/src/zenml/orchestrators/utils.py +++ b/src/zenml/orchestrators/utils.py @@ -28,7 +28,6 @@ ENV_ZENML_DISABLE_CREDENTIALS_DISK_CACHING, ENV_ZENML_SERVER, ENV_ZENML_STORE_PREFIX, - PIPELINE_API_TOKEN_EXPIRES_MINUTES, ) from zenml.enums import AuthScheme, StackComponentType, StoreType from zenml.logger import get_logger @@ -39,7 +38,6 @@ if TYPE_CHECKING: from zenml.artifact_stores.base_artifact_store import BaseArtifactStore - from zenml.models import PipelineDeploymentResponse def get_orchestrator_run_name(pipeline_name: str) -> str: @@ -82,16 +80,23 @@ def is_setting_enabled( def get_config_environment_vars( - deployment: Optional["PipelineDeploymentResponse"] = None, + schedule_id: Optional[UUID] = None, + pipeline_run_id: Optional[UUID] = None, + step_run_id: Optional[UUID] = None, ) -> Dict[str, str]: """Gets environment variables to set for mirroring the active config. - If a pipeline deployment is given, the environment variables will be set to - include a newly generated API token valid for the duration of the pipeline - run instead of the API token from the global config. + If a schedule ID, pipeline run ID or step run ID is given, and the current + client is not authenticated to a server with an API key, the environment + variables will be updated to include a newly generated workload API token + that will be valid for the duration of the schedule, pipeline run, or step + run instead of the current API token used to authenticate the client. Args: - deployment: Optional deployment to use for the environment variables. + schedule_id: Optional schedule ID to use to generate a new API token. + pipeline_run_id: Optional pipeline run ID to use to generate a new API + token. + step_run_id: Optional step run ID to use to generate a new API token. Returns: Environment variable dict. @@ -112,42 +117,49 @@ def get_config_environment_vars( api_key = credentials_store.get_api_key(url) api_token = credentials_store.get_token(url, allow_expired=False) if api_key: + # If an API key is available, it is used to authenticate the + # pipeline run environment. environment_vars[ENV_ZENML_STORE_PREFIX + "API_KEY"] = api_key - elif deployment: - # When connected to an authenticated ZenML server, if a pipeline - # deployment is supplied, we need to fetch an API token that will be - # valid for the duration of the pipeline run. + elif schedule_id or pipeline_run_id or step_run_id: + # When connected to an authenticated ZenML server, if a schedule ID, + # pipeline run ID or step run ID is supplied, we need to fetch a new + # workload API token scoped to the schedule, pipeline run or step + # run. assert isinstance(global_config.zen_store, RestZenStore) - pipeline_id: Optional[UUID] = None - if deployment.pipeline: - pipeline_id = deployment.pipeline.id - schedule_id: Optional[UUID] = None - expires_minutes: Optional[int] = PIPELINE_API_TOKEN_EXPIRES_MINUTES - if deployment.schedule: - schedule_id = deployment.schedule.id - # If a schedule is given, this is a long running pipeline that - # should not have an API token that expires. - expires_minutes = None + + if pipeline_run_id or step_run_id: + # If a pipeline run or step run is given, the pipeline run + # or step run credentials are scoped to the pipeline run or step + # run and will only be valid for the duration of the run/step. + new_api_token = global_config.zen_store.get_api_token( + schedule_id=schedule_id, + pipeline_run_id=pipeline_run_id, + step_run_id=step_run_id, + ) + else: + # If a schedule is given, the pipeline run credentials is + # configured with a token that is scoped to the given schedule. logger.warning( "An API token without an expiration time will be generated " "and used to run this pipeline on a schedule. This is very " - "insecure because the API token cannot be revoked in case " - "of potential theft without disabling the entire user " - "account. When deploying a pipeline on a schedule, it is " - "strongly advised to use a service account API key to " - "authenticate to the ZenML server instead of your regular " - "user account. For more information, see " + "insecure because the API token will be valid for the " + "entire lifetime of the schedule and can be used to access " + "your account if leaked. When deploying a pipeline on a " + "schedule, it is strongly advised to use a service account " + "API key to authenticate to the ZenML server instead of " + "your regular user account. For more information, see " "https://docs.zenml.io/how-to/connecting-to-zenml/connect-with-a-service-account" ) - new_api_token = global_config.zen_store.get_api_token( - pipeline_id=pipeline_id, - schedule_id=schedule_id, - expires_minutes=expires_minutes, - ) + new_api_token = global_config.zen_store.get_api_token( + schedule_id=schedule_id, + ) + environment_vars[ENV_ZENML_STORE_PREFIX + "API_TOKEN"] = ( new_api_token ) elif api_token: + # For all other cases, the pipeline run environment is configured + # with the current access token. environment_vars[ENV_ZENML_STORE_PREFIX + "API_TOKEN"] = ( api_token.access_token ) diff --git a/src/zenml/zen_server/auth.py b/src/zenml/zen_server/auth.py index 4f600f1ca6f..06c2fe7f620 100644 --- a/src/zenml/zen_server/auth.py +++ b/src/zenml/zen_server/auth.py @@ -14,13 +14,13 @@ """Authentication module for ZenML server.""" from contextvars import ContextVar -from datetime import datetime +from datetime import datetime, timedelta from typing import Callable, Optional, Union from urllib.parse import urlencode from uuid import UUID import requests -from fastapi import Depends +from fastapi import Depends, Response from fastapi.security import ( HTTPBasic, HTTPBasicCredentials, @@ -37,7 +37,7 @@ LOGIN, VERSION_1, ) -from zenml.enums import AuthScheme, OAuthDeviceStatus +from zenml.enums import AuthScheme, ExecutionStatus, OAuthDeviceStatus from zenml.exceptions import ( AuthorizationException, CredentialsNotValid, @@ -51,6 +51,7 @@ ExternalUserModel, OAuthDeviceInternalResponse, OAuthDeviceInternalUpdate, + OAuthTokenResponse, UserAuthModel, UserRequest, UserResponse, @@ -329,6 +330,72 @@ def authenticate_credentials( ), ) + if decoded_token.schedule_id: + # If the token contains a schedule ID, we need to check if the + # schedule still exists in the database. + try: + zen_store().get_schedule( + decoded_token.schedule_id, hydrate=False + ) + except KeyError: + error = ( + f"Authentication error: error retrieving token schedule " + f"{decoded_token.schedule_id}" + ) + logger.error(error) + raise CredentialsNotValid(error) + + if decoded_token.pipeline_run_id: + # If the token contains a pipeline run ID, we need to check if the + # pipeline run exists in the database and the pipeline run has + # not concluded. + try: + pipeline_run = zen_store().get_run( + decoded_token.pipeline_run_id, hydrate=False + ) + except KeyError: + error = ( + f"Authentication error: error retrieving token pipeline run " + f"{decoded_token.pipeline_run_id}" + ) + logger.error(error) + raise CredentialsNotValid(error) + + if pipeline_run.status in [ + ExecutionStatus.FAILED, + ExecutionStatus.COMPLETED, + ]: + error = ( + f"The execution of pipeline run " + f"{decoded_token.pipeline_run_id} has already concluded and " + "API tokens scoped to it are no longer valid." + ) + + if decoded_token.step_run_id: + # If the token contains a step run ID, we need to check if the + # step run exists in the database and the step run has not concluded. + try: + step_run = zen_store().get_run_step( + decoded_token.step_run_id, hydrate=False + ) + except KeyError: + error = ( + f"Authentication error: error retrieving token step run " + f"{decoded_token.step_run_id}" + ) + logger.error(error) + raise CredentialsNotValid(error) + + if step_run.status in [ + ExecutionStatus.FAILED, + ExecutionStatus.COMPLETED, + ]: + error = ( + f"The execution of step run " + f"{decoded_token.step_run_id} has already concluded and " + "API tokens scoped to it are no longer valid." + ) + auth_context = AuthContext( user=user_model, access_token=decoded_token, @@ -660,6 +727,80 @@ def authenticate_api_key( return AuthContext(user=user_model, api_key=internal_api_key) +def generate_access_token( + user_id: UUID, + response: Optional[Response] = None, + device: Optional[OAuthDeviceInternalResponse] = None, + api_key: Optional[APIKeyInternalResponse] = None, + expires_in: Optional[int] = None, + schedule_id: Optional[UUID] = None, + pipeline_run_id: Optional[UUID] = None, + step_run_id: Optional[UUID] = None, +) -> OAuthTokenResponse: + """Generates an access token for the given user. + + Args: + user_id: The ID of the user. + response: The FastAPI response object. + device: The device used for authentication. + api_key: The service account API key used for authentication. + expires_in: The number of seconds until the token expires. + schedule_id: The ID of the schedule to scope the token to. + pipeline_run_id: The ID of the pipeline run to scope the token to. + step_run_id: The ID of the step run to scope the token to. + + Returns: + An authentication response with an access token. + """ + config = server_config() + + # If the expiration time is not supplied, the JWT tokens are set to expire + # according to the values configured in the server config. Device tokens are + # handled separately from regular user tokens. + expires: Optional[datetime] = None + if expires_in: + expires = datetime.utcnow() + timedelta(seconds=expires_in) + elif device: + # If a device was used for authentication, the token will expire + # at the same time as the device. + expires = device.expires + if expires: + expires_in = max( + int(expires.timestamp() - datetime.utcnow().timestamp()), 0 + ) + elif config.jwt_token_expire_minutes: + expires = datetime.utcnow() + timedelta( + minutes=config.jwt_token_expire_minutes + ) + expires_in = config.jwt_token_expire_minutes * 60 + + access_token = JWTToken( + user_id=user_id, + device_id=device.id if device else None, + api_key_id=api_key.id if api_key else None, + schedule_id=schedule_id, + pipeline_run_id=pipeline_run_id, + step_run_id=step_run_id, + ).encode(expires=expires) + + if not device and response: + # Also set the access token as an HTTP only cookie in the response + response.set_cookie( + key=config.get_auth_cookie_name(), + value=access_token, + httponly=True, + samesite="lax", + max_age=config.jwt_token_expire_minutes * 60 + if config.jwt_token_expire_minutes + else None, + domain=config.auth_cookie_domain, + ) + + return OAuthTokenResponse( + access_token=access_token, expires_in=expires_in, token_type="bearer" + ) + + def http_authentication( credentials: HTTPBasicCredentials = Depends(HTTPBasic()), ) -> AuthContext: diff --git a/src/zenml/zen_server/jwt.py b/src/zenml/zen_server/jwt.py index 59d9a99d118..eafa61a784b 100644 --- a/src/zenml/zen_server/jwt.py +++ b/src/zenml/zen_server/jwt.py @@ -40,16 +40,20 @@ class JWTToken(BaseModel): device_id: The id of the authenticated device. api_key_id: The id of the authenticated API key for which this token was issued. - pipeline_id: The id of the pipeline for which the token was issued. schedule_id: The id of the schedule for which the token was issued. + pipeline_run_id: The id of the pipeline run for which the token was + issued. + step_run_id: The id of the step run for which the token was + issued. claims: The original token claims. """ user_id: UUID device_id: Optional[UUID] = None api_key_id: Optional[UUID] = None - pipeline_id: Optional[UUID] = None schedule_id: Optional[UUID] = None + pipeline_run_id: Optional[UUID] = None + step_run_id: Optional[UUID] = None claims: Dict[str, Any] = {} @classmethod @@ -122,23 +126,33 @@ def decode_token( "UUID" ) - pipeline_id: Optional[UUID] = None - if "pipeline_id" in claims: + schedule_id: Optional[UUID] = None + if "schedule_id" in claims: try: - pipeline_id = UUID(claims.pop("pipeline_id")) + schedule_id = UUID(claims.pop("schedule_id")) except ValueError: raise CredentialsNotValid( - "Invalid JWT token: the pipeline_id claim is not a valid " + "Invalid JWT token: the schedule_id claim is not a valid " "UUID" ) - schedule_id: Optional[UUID] = None - if "schedule_id" in claims: + pipeline_run_id: Optional[UUID] = None + if "pipeline_run_id" in claims: try: - schedule_id = UUID(claims.pop("schedule_id")) + pipeline_run_id = UUID(claims.pop("pipeline_run_id")) except ValueError: raise CredentialsNotValid( - "Invalid JWT token: the schedule_id claim is not a valid " + "Invalid JWT token: the pipeline_run_id claim is not a valid " + "UUID" + ) + + step_run_id: Optional[UUID] = None + if "step_run_id" in claims: + try: + step_run_id = UUID(claims.pop("step_run_id")) + except ValueError: + raise CredentialsNotValid( + "Invalid JWT token: the step_run_id claim is not a valid " "UUID" ) @@ -146,8 +160,9 @@ def decode_token( user_id=user_id, device_id=device_id, api_key_id=api_key_id, - pipeline_id=pipeline_id, schedule_id=schedule_id, + pipeline_run_id=pipeline_run_id, + step_run_id=step_run_id, claims=claims, ) @@ -180,10 +195,12 @@ def encode(self, expires: Optional[datetime] = None) -> str: claims["device_id"] = str(self.device_id) if self.api_key_id: claims["api_key_id"] = str(self.api_key_id) - if self.pipeline_id: - claims["pipeline_id"] = str(self.pipeline_id) if self.schedule_id: claims["schedule_id"] = str(self.schedule_id) + if self.pipeline_run_id: + claims["pipeline_run_id"] = str(self.pipeline_run_id) + if self.step_run_id: + claims["step_run_id"] = str(self.step_run_id) return jwt.encode( claims, diff --git a/src/zenml/zen_server/routers/auth_endpoints.py b/src/zenml/zen_server/routers/auth_endpoints.py index cd50bd20aa0..f802121235a 100644 --- a/src/zenml/zen_server/routers/auth_endpoints.py +++ b/src/zenml/zen_server/routers/auth_endpoints.py @@ -13,7 +13,6 @@ # permissions and limitations under the License. """Endpoint definitions for authentication (login).""" -from datetime import datetime, timedelta from typing import Optional, Union from urllib.parse import urlencode from uuid import UUID @@ -40,17 +39,17 @@ VERSION_1, ) from zenml.enums import ( + APITokenType, AuthScheme, + ExecutionStatus, OAuthDeviceStatus, OAuthGrantTypes, ) from zenml.exceptions import AuthorizationException from zenml.logger import get_logger from zenml.models import ( - APIKeyInternalResponse, OAuthDeviceAuthorizationResponse, OAuthDeviceInternalRequest, - OAuthDeviceInternalResponse, OAuthDeviceInternalUpdate, OAuthDeviceUserAgentHeader, OAuthRedirectResponse, @@ -63,9 +62,9 @@ authenticate_device, authenticate_external_user, authorize, + generate_access_token, ) from zenml.zen_server.exceptions import error_response -from zenml.zen_server.jwt import JWTToken from zenml.zen_server.rate_limit import rate_limit_requests from zenml.zen_server.rbac.models import Action, ResourceType from zenml.zen_server.rbac.utils import verify_permission @@ -211,77 +210,6 @@ def __init__( ) -def generate_access_token( - user_id: UUID, - response: Optional[Response] = None, - device: Optional[OAuthDeviceInternalResponse] = None, - api_key: Optional[APIKeyInternalResponse] = None, - expires_in: Optional[int] = None, - pipeline_id: Optional[UUID] = None, - schedule_id: Optional[UUID] = None, -) -> OAuthTokenResponse: - """Generates an access token for the given user. - - Args: - user_id: The ID of the user. - response: The FastAPI response object. - device: The device used for authentication. - api_key: The service account API key used for authentication. - expires_in: The number of seconds until the token expires. - pipeline_id: The ID of the pipeline to scope the token to. - schedule_id: The ID of the schedule to scope the token to. - - Returns: - An authentication response with an access token. - """ - config = server_config() - - # If the expiration time is not supplied, the JWT tokens are set to expire - # according to the values configured in the server config. Device tokens are - # handled separately from regular user tokens. - expires: Optional[datetime] = None - if expires_in: - expires = datetime.utcnow() + timedelta(seconds=expires_in) - elif device: - # If a device was used for authentication, the token will expire - # at the same time as the device. - expires = device.expires - if expires: - expires_in = max( - int(expires.timestamp() - datetime.utcnow().timestamp()), 0 - ) - elif config.jwt_token_expire_minutes: - expires = datetime.utcnow() + timedelta( - minutes=config.jwt_token_expire_minutes - ) - expires_in = config.jwt_token_expire_minutes * 60 - - access_token = JWTToken( - user_id=user_id, - device_id=device.id if device else None, - api_key_id=api_key.id if api_key else None, - pipeline_id=pipeline_id, - schedule_id=schedule_id - ).encode(expires=expires) - - if not device and response: - # Also set the access token as an HTTP only cookie in the response - response.set_cookie( - key=config.get_auth_cookie_name(), - value=access_token, - httponly=True, - samesite="lax", - max_age=config.jwt_token_expire_minutes * 60 - if config.jwt_token_expire_minutes - else None, - domain=config.auth_cookie_domain, - ) - - return OAuthTokenResponse( - access_token=access_token, expires_in=expires_in, token_type="bearer" - ) - - @router.post( LOGIN, response_model=Union[OAuthTokenResponse, OAuthRedirectResponse], @@ -523,59 +451,98 @@ def device_authorization( ) @handle_exceptions def api_token( - pipeline_id: Optional[UUID] = None, + token_type: APITokenType = APITokenType.GENERIC, schedule_id: Optional[UUID] = None, - expires_minutes: Optional[int] = None, + pipeline_run_id: Optional[UUID] = None, + step_run_id: Optional[UUID] = None, auth_context: AuthContext = Security(authorize), ) -> str: - """Get a workload API token for the current user. + """Generate an API token for the current user. + + Use this endpoint to generate an API token for the current user. Two types + of API tokens are supported: + + * Generic API token: This token is short-lived and can be used for + generic automation tasks. + * Workload API token: This token is scoped to a specific pipeline run, step + run or schedule and is used by pipeline workloads to authenticate with the + server. A pipeline run ID, step run ID or schedule ID must be provided and + the generated token will only be valid for the indicated pipeline run, step + run or schedule. No time limit is imposed on the validity of the token. + + Generic API tokens can only be generated by users authenticated via the + dashboard. + + Workload API tokens can only be generated for clients that are authenticated + with an authorized device or a service account and API key and the client + must be authorized to create pipeline runs. A workload API token can be + used to authenticate and generate another workload API token, but only for + the same schedule, pipeline run ID or step run ID, in that order. Args: - pipeline_id: The ID of the pipeline to get the API token for. - schedule_id: The ID of the schedule to get the API token for. - expires_minutes: The number of minutes for which the API token should - be valid. If not provided, the API token will be valid indefinitely. + token_type: The type of API token to generate. + schedule_id: The ID of the schedule to scope the workload API token to. + pipeline_run_id: The ID of the pipeline run to scope the workload API + token to. + step_run_id: The ID of the step run to scope the workload API token to. auth_context: The authentication context. Returns: The API token. Raises: - HTTPException: If the user is not authenticated. - AuthorizationException: If trying to scope the API token to a different - pipeline/schedule than the token used to authorize this request. + AuthorizationException: If not authorized to generate the API token. + ValueError: If the request is invalid. """ token = auth_context.access_token if not token or not auth_context.encoded_access_token: # Should not happen - raise HTTPException( - status_code=status.HTTP_401_UNAUTHORIZED, - detail="Not authenticated.", - ) + raise AuthorizationException("Not authenticated.") + + if token_type == APITokenType.GENERIC: + if schedule_id or pipeline_run_id or step_run_id: + raise ValueError( + "Generic API tokens cannot be scoped to a schedule, pipeline " + "run or step run." + ) + if token.device_id or token.api_key_id: + raise AuthorizationException( + "Generic API tokens can only be generated for users " + "authenticated via the ZenML dashboard." + ) - if not token.device_id and not token.api_key_id: config = server_config() - # If not authenticated with a device or a service account, then a - # short-lived generic API token is returned. return generate_access_token( user_id=token.user_id, expires_in=config.generic_api_token_lifetime, ).access_token - # Issuing workload tokens is only supported for device authenticated users - # and service accounts, because device tokens can be revoked at any time and - # service accounts can be disabled. + if ( + not token.device_id + and not token.api_key_id + and not token.schedule_id + and not token.pipeline_run_id + and not token.step_run_id + ): + raise AuthorizationException( + "Workload API tokens can only be generated by clients " + "authenticated with an authorized device, a service account API " + "key or another workload API token." + ) verify_permission( resource_type=ResourceType.PIPELINE_RUN, action=Action.CREATE ) - if pipeline_id and token.pipeline_id and pipeline_id != token.pipeline_id: - raise AuthorizationException( - f"Unable to scope API token to pipeline {pipeline_id}. The " - f"token used to authorize this request is already scoped to " - f"pipeline {token.pipeline_id}." + schedule_id = schedule_id or token.schedule_id + pipeline_run_id = pipeline_run_id or token.pipeline_run_id + step_run_id = step_run_id or token.step_run_id + + if not pipeline_run_id and not schedule_id and not step_run_id: + raise ValueError( + "Workload API tokens must be scoped to a schedule, pipeline run " + "or step run." ) if schedule_id and token.schedule_id and schedule_id != token.schedule_id: @@ -585,9 +552,78 @@ def api_token( f"schedule {token.schedule_id}." ) + if ( + pipeline_run_id + and token.pipeline_run_id + and pipeline_run_id != token.pipeline_run_id + ): + raise AuthorizationException( + f"Unable to scope API token to pipeline run {pipeline_run_id}. The " + f"token used to authorize this request is already scoped to " + f"pipeline run {token.pipeline_run_id}." + ) + + if step_run_id and token.step_run_id and step_run_id != token.step_run_id: + raise AuthorizationException( + f"Unable to scope API token to step run {step_run_id}. The " + f"token used to authorize this request is already scoped to " + f"step run {token.step_run_id}." + ) + + if schedule_id: + # The schedule must exist + try: + zen_store().get_schedule(schedule_id, hydrate=False) + except KeyError: + raise ValueError( + f"Schedule {schedule_id} does not exist and API tokens cannot " + "be generated for non-existent schedules for security reasons." + ) + + if pipeline_run_id: + # The pipeline run must exist and the run must not be concluded + try: + pipeline_run = zen_store().get_run(pipeline_run_id, hydrate=False) + except KeyError: + raise ValueError( + f"Pipeline run {pipeline_run_id} does not exist and API tokens " + "cannot be generated for non-existent pipeline runs for " + "security reasons." + ) + + if pipeline_run.status in [ + ExecutionStatus.FAILED, + ExecutionStatus.COMPLETED, + ]: + raise ValueError( + f"The execution of pipeline run {pipeline_run_id} has already " + "concluded and API tokens can no longer be generated for it " + "for security reasons." + ) + + if step_run_id: + # The step run must exist and the step must not be concluded + try: + step_run = zen_store().get_run_step(step_run_id, hydrate=False) + except KeyError: + raise ValueError( + f"Step run {step_run_id} does not exist and API tokens cannot " + "be generated for non-existent step runs for security reasons." + ) + + if step_run.status in [ + ExecutionStatus.FAILED, + ExecutionStatus.COMPLETED, + ]: + raise ValueError( + f"The execution of step run {step_run_id} has already " + "concluded and API tokens can no longer be generated for it " + "for security reasons." + ) + return generate_access_token( user_id=token.user_id, - expires_in=expires_minutes * 60 if expires_minutes else None, - pipeline_id=pipeline_id, schedule_id=schedule_id, + pipeline_run_id=pipeline_run_id, + step_run_id=step_run_id, ).access_token diff --git a/src/zenml/zen_server/template_execution/utils.py b/src/zenml/zen_server/template_execution/utils.py index 34831d89029..119a0a23a14 100644 --- a/src/zenml/zen_server/template_execution/utils.py +++ b/src/zenml/zen_server/template_execution/utils.py @@ -42,7 +42,7 @@ ) from zenml.stack.flavor import Flavor from zenml.utils import dict_utils, requirements_utils, settings_utils -from zenml.zen_server.auth import AuthContext +from zenml.zen_server.auth import AuthContext, generate_access_token from zenml.zen_server.template_execution.runner_entrypoint_configuration import ( RunnerEntrypointConfiguration, ) @@ -111,17 +111,6 @@ def run_template( new_deployment = zen_store().create_deployment(deployment_request) - if auth_context.access_token: - token = auth_context.access_token - token.pipeline_id = deployment_request.pipeline - - # We create a non-expiring token to make sure its active for the entire - # duration of the pipeline run - api_token = token.encode(expires=None) - else: - assert auth_context.encoded_access_token - api_token = auth_context.encoded_access_token - server_url = server_config().server_url if not server_url: raise RuntimeError( @@ -130,6 +119,15 @@ def run_template( assert build.zenml_version zenml_version = build.zenml_version + placeholder_run = create_placeholder_run(deployment=new_deployment) + assert placeholder_run + + # We create an API token scoped to the pipeline run + api_token = generate_access_token( + user_id=auth_context.user.id, + pipeline_run_id=placeholder_run.id, + ).access_token + environment = { ENV_ZENML_ACTIVE_WORKSPACE_ID: str(new_deployment.workspace.id), ENV_ZENML_ACTIVE_STACK_ID: str(stack.id), @@ -145,9 +143,6 @@ def run_template( deployment_id=new_deployment.id ) - placeholder_run = create_placeholder_run(deployment=new_deployment) - assert placeholder_run - def _task() -> None: pypi_requirements, apt_packages = ( requirements_utils.get_requirements_for_stack(stack=stack) diff --git a/src/zenml/zen_stores/rest_zen_store.py b/src/zenml/zen_stores/rest_zen_store.py index b8288343256..d3231d65eb0 100644 --- a/src/zenml/zen_stores/rest_zen_store.py +++ b/src/zenml/zen_stores/rest_zen_store.py @@ -110,6 +110,7 @@ WORKSPACES, ) from zenml.enums import ( + APITokenType, OAuthGrantTypes, StackDeploymentProvider, StoreType, @@ -3872,17 +3873,16 @@ def delete_authorized_device(self, device_id: UUID) -> None: def get_api_token( self, - pipeline_id: Optional[UUID] = None, schedule_id: Optional[UUID] = None, - expires_minutes: Optional[int] = None, + pipeline_run_id: Optional[UUID] = None, + step_run_id: Optional[UUID] = None, ) -> str: """Get an API token for a workload. Args: - pipeline_id: The ID of the pipeline to get a token for. schedule_id: The ID of the schedule to get a token for. - expires_minutes: The number of minutes for which the token should - be valid. If not provided, the token will be valid indefinitely. + pipeline_run_id: The ID of the pipeline run to get a token for. + step_run_id: The ID of the step run to get a token for. Returns: The API token. @@ -3890,13 +3890,16 @@ def get_api_token( Raises: ValueError: if the server response is not valid. """ - params: Dict[str, Any] = {} - if pipeline_id: - params["pipeline_id"] = pipeline_id + params: Dict[str, Any] = { + # Python clients may only request workload tokens. + "token_type": APITokenType.WORKLOAD.value, + } if schedule_id: params["schedule_id"] = schedule_id - if expires_minutes: - params["expires_minutes"] = expires_minutes + if pipeline_run_id: + params["pipeline_run_id"] = pipeline_run_id + if step_run_id: + params["step_run_id"] = step_run_id response_body = self.get(API_TOKEN, params=params) if not isinstance(response_body, str): raise ValueError( From 0c2fffed77b2a89f0a8a727f6f17b9d26b331680 Mon Sep 17 00:00:00 2001 From: Stefan Nica Date: Mon, 11 Nov 2024 18:22:02 +0100 Subject: [PATCH 05/12] Verify if schedule is active for schedule scoped tokens --- src/zenml/zen_server/auth.py | 10 +++++++++- src/zenml/zen_server/routers/auth_endpoints.py | 8 +++++++- 2 files changed, 16 insertions(+), 2 deletions(-) diff --git a/src/zenml/zen_server/auth.py b/src/zenml/zen_server/auth.py index 06c2fe7f620..df6604dd89d 100644 --- a/src/zenml/zen_server/auth.py +++ b/src/zenml/zen_server/auth.py @@ -334,7 +334,7 @@ def authenticate_credentials( # If the token contains a schedule ID, we need to check if the # schedule still exists in the database. try: - zen_store().get_schedule( + schedule = zen_store().get_schedule( decoded_token.schedule_id, hydrate=False ) except KeyError: @@ -345,6 +345,14 @@ def authenticate_credentials( logger.error(error) raise CredentialsNotValid(error) + if not schedule.active: + error = ( + f"Authentication error: schedule {decoded_token.schedule_id} " + "is not active" + ) + logger.error(error) + raise CredentialsNotValid(error) + if decoded_token.pipeline_run_id: # If the token contains a pipeline run ID, we need to check if the # pipeline run exists in the database and the pipeline run has diff --git a/src/zenml/zen_server/routers/auth_endpoints.py b/src/zenml/zen_server/routers/auth_endpoints.py index f802121235a..c4204c4cb71 100644 --- a/src/zenml/zen_server/routers/auth_endpoints.py +++ b/src/zenml/zen_server/routers/auth_endpoints.py @@ -573,13 +573,19 @@ def api_token( if schedule_id: # The schedule must exist try: - zen_store().get_schedule(schedule_id, hydrate=False) + schedule = zen_store().get_schedule(schedule_id, hydrate=False) except KeyError: raise ValueError( f"Schedule {schedule_id} does not exist and API tokens cannot " "be generated for non-existent schedules for security reasons." ) + if not schedule.active: + raise ValueError( + f"Schedule {schedule_id} is not active and API tokens cannot " + "be generated for inactive schedules for security reasons." + ) + if pipeline_run_id: # The pipeline run must exist and the run must not be concluded try: From 215baa70df6367622cd1fe8a566efb9d131f58a9 Mon Sep 17 00:00:00 2001 From: Stefan Nica Date: Mon, 11 Nov 2024 20:23:03 +0100 Subject: [PATCH 06/12] Remove all restrictions concerning token types --- .../zen_server/routers/auth_endpoints.py | 30 ++----------------- 1 file changed, 3 insertions(+), 27 deletions(-) diff --git a/src/zenml/zen_server/routers/auth_endpoints.py b/src/zenml/zen_server/routers/auth_endpoints.py index c4204c4cb71..2bd78fe4ddf 100644 --- a/src/zenml/zen_server/routers/auth_endpoints.py +++ b/src/zenml/zen_server/routers/auth_endpoints.py @@ -469,15 +469,9 @@ def api_token( server. A pipeline run ID, step run ID or schedule ID must be provided and the generated token will only be valid for the indicated pipeline run, step run or schedule. No time limit is imposed on the validity of the token. - - Generic API tokens can only be generated by users authenticated via the - dashboard. - - Workload API tokens can only be generated for clients that are authenticated - with an authorized device or a service account and API key and the client - must be authorized to create pipeline runs. A workload API token can be - used to authenticate and generate another workload API token, but only for - the same schedule, pipeline run ID or step run ID, in that order. + A workload API token can be used to authenticate and generate another + workload API token, but only for the same schedule, pipeline run ID or step + run ID, in that order. Args: token_type: The type of API token to generate. @@ -505,11 +499,6 @@ def api_token( "Generic API tokens cannot be scoped to a schedule, pipeline " "run or step run." ) - if token.device_id or token.api_key_id: - raise AuthorizationException( - "Generic API tokens can only be generated for users " - "authenticated via the ZenML dashboard." - ) config = server_config() @@ -518,19 +507,6 @@ def api_token( expires_in=config.generic_api_token_lifetime, ).access_token - if ( - not token.device_id - and not token.api_key_id - and not token.schedule_id - and not token.pipeline_run_id - and not token.step_run_id - ): - raise AuthorizationException( - "Workload API tokens can only be generated by clients " - "authenticated with an authorized device, a service account API " - "key or another workload API token." - ) - verify_permission( resource_type=ResourceType.PIPELINE_RUN, action=Action.CREATE ) From c81946149f0b1ca13244bae88af7bee1a45902e6 Mon Sep 17 00:00:00 2001 From: Stefan Nica Date: Mon, 11 Nov 2024 21:09:45 +0100 Subject: [PATCH 07/12] Transfer the api key and authorized device scopes to generated workload API tokens --- src/zenml/zen_server/routers/auth_endpoints.py | 3 +++ src/zenml/zen_server/template_execution/utils.py | 3 +++ 2 files changed, 6 insertions(+) diff --git a/src/zenml/zen_server/routers/auth_endpoints.py b/src/zenml/zen_server/routers/auth_endpoints.py index 2bd78fe4ddf..34ba9ab1de7 100644 --- a/src/zenml/zen_server/routers/auth_endpoints.py +++ b/src/zenml/zen_server/routers/auth_endpoints.py @@ -605,6 +605,9 @@ def api_token( return generate_access_token( user_id=token.user_id, + # Keep the original API key and device token scopes + api_key=auth_context.api_key, + device=auth_context.device, schedule_id=schedule_id, pipeline_run_id=pipeline_run_id, step_run_id=step_run_id, diff --git a/src/zenml/zen_server/template_execution/utils.py b/src/zenml/zen_server/template_execution/utils.py index 119a0a23a14..56391d62334 100644 --- a/src/zenml/zen_server/template_execution/utils.py +++ b/src/zenml/zen_server/template_execution/utils.py @@ -126,6 +126,9 @@ def run_template( api_token = generate_access_token( user_id=auth_context.user.id, pipeline_run_id=placeholder_run.id, + # Keep the original API key or device scopes, if any + api_key=auth_context.api_key, + device=auth_context.device, ).access_token environment = { From c79436b0df76f8c7efd958a5ca226e2f7836ad2d Mon Sep 17 00:00:00 2001 From: Stefan Nica Date: Mon, 11 Nov 2024 22:04:32 +0100 Subject: [PATCH 08/12] Actually throw the errors for concluded pipeline runs/steps --- src/zenml/zen_server/auth.py | 4 ++++ 1 file changed, 4 insertions(+) diff --git a/src/zenml/zen_server/auth.py b/src/zenml/zen_server/auth.py index df6604dd89d..478974b22c2 100644 --- a/src/zenml/zen_server/auth.py +++ b/src/zenml/zen_server/auth.py @@ -378,6 +378,8 @@ def authenticate_credentials( f"{decoded_token.pipeline_run_id} has already concluded and " "API tokens scoped to it are no longer valid." ) + logger.error(error) + raise CredentialsNotValid(error) if decoded_token.step_run_id: # If the token contains a step run ID, we need to check if the @@ -403,6 +405,8 @@ def authenticate_credentials( f"{decoded_token.step_run_id} has already concluded and " "API tokens scoped to it are no longer valid." ) + logger.error(error) + raise CredentialsNotValid(error) auth_context = AuthContext( user=user_model, From b123bdc92a48283e3eaad7987dd7b3b58b7a2808 Mon Sep 17 00:00:00 2001 From: Stefan Nica Date: Tue, 12 Nov 2024 11:24:02 +0100 Subject: [PATCH 09/12] Don't pass the API key in the pipeline environment anymore --- src/zenml/orchestrators/utils.py | 44 +++++++++++++------------------ tests/unit/zen_server/test_jwt.py | 9 ++++--- 2 files changed, 25 insertions(+), 28 deletions(-) diff --git a/src/zenml/orchestrators/utils.py b/src/zenml/orchestrators/utils.py index 2fb1fb78c1d..b1bbef60fa4 100644 --- a/src/zenml/orchestrators/utils.py +++ b/src/zenml/orchestrators/utils.py @@ -114,45 +114,39 @@ def get_config_environment_vars( ): credentials_store = get_credentials_store() url = global_config.store_configuration.url - api_key = credentials_store.get_api_key(url) api_token = credentials_store.get_token(url, allow_expired=False) - if api_key: - # If an API key is available, it is used to authenticate the - # pipeline run environment. - environment_vars[ENV_ZENML_STORE_PREFIX + "API_KEY"] = api_key - elif schedule_id or pipeline_run_id or step_run_id: + if schedule_id or pipeline_run_id or step_run_id: # When connected to an authenticated ZenML server, if a schedule ID, # pipeline run ID or step run ID is supplied, we need to fetch a new # workload API token scoped to the schedule, pipeline run or step # run. assert isinstance(global_config.zen_store, RestZenStore) - if pipeline_run_id or step_run_id: - # If a pipeline run or step run is given, the pipeline run - # or step run credentials are scoped to the pipeline run or step - # run and will only be valid for the duration of the run/step. - new_api_token = global_config.zen_store.get_api_token( - schedule_id=schedule_id, - pipeline_run_id=pipeline_run_id, - step_run_id=step_run_id, - ) - else: - # If a schedule is given, the pipeline run credentials is - # configured with a token that is scoped to the given schedule. + # If only a schedule is given, the pipeline run credentials will + # be valid for the entire duration of the schedule. + api_key = credentials_store.get_api_key(url) + if not api_key and not pipeline_run_id and not step_run_id: logger.warning( "An API token without an expiration time will be generated " "and used to run this pipeline on a schedule. This is very " "insecure because the API token will be valid for the " "entire lifetime of the schedule and can be used to access " - "your account if leaked. When deploying a pipeline on a " - "schedule, it is strongly advised to use a service account " - "API key to authenticate to the ZenML server instead of " - "your regular user account. For more information, see " + "your user account if accidentally leaked. When deploying " + "a pipeline on a schedule, it is strongly advised to use a " + "service account API key to authenticate to the ZenML " + "server instead of your regular user account. For more " + "information, see " "https://docs.zenml.io/how-to/connecting-to-zenml/connect-with-a-service-account" ) - new_api_token = global_config.zen_store.get_api_token( - schedule_id=schedule_id, - ) + + # The schedule, pipeline run or step run credentials are scoped to + # the schedule, pipeline run or step run and will only be valid for + # the duration of the schedule/pipeline run/step run. + new_api_token = global_config.zen_store.get_api_token( + schedule_id=schedule_id, + pipeline_run_id=pipeline_run_id, + step_run_id=step_run_id, + ) environment_vars[ENV_ZENML_STORE_PREFIX + "API_TOKEN"] = ( new_api_token diff --git a/tests/unit/zen_server/test_jwt.py b/tests/unit/zen_server/test_jwt.py index 2a67bba85de..480569f2925 100644 --- a/tests/unit/zen_server/test_jwt.py +++ b/tests/unit/zen_server/test_jwt.py @@ -35,8 +35,9 @@ def test_encode_decode_works(): user_id = uuid.uuid4() device_id = uuid.uuid4() api_key_id = uuid.uuid4() - pipeline_id = uuid.uuid4() schedule_id = uuid.uuid4() + pipeline_run_id = uuid.uuid4() + step_run_id = uuid.uuid4() claims = { "foo": "bar", "baz": "qux", @@ -46,8 +47,9 @@ def test_encode_decode_works(): user_id=user_id, device_id=device_id, api_key_id=api_key_id, - pipeline_id=pipeline_id, schedule_id=schedule_id, + pipeline_run_id=pipeline_run_id, + step_run_id=step_run_id, claims=claims, ) @@ -57,8 +59,9 @@ def test_encode_decode_works(): assert decoded_token.user_id == user_id assert decoded_token.device_id == device_id assert decoded_token.api_key_id == api_key_id - assert decoded_token.pipeline_id == pipeline_id assert decoded_token.schedule_id == schedule_id + assert decoded_token.pipeline_run_id == pipeline_run_id + assert decoded_token.step_run_id == step_run_id # Check that the configured custom claims are included in the decoded claims assert decoded_token.claims["foo"] == "bar" assert decoded_token.claims["baz"] == "qux" From d91e326ef372404535e9c93fe8e9cf8f26f2cef3 Mon Sep 17 00:00:00 2001 From: Stefan Nica Date: Tue, 12 Nov 2024 17:03:16 +0100 Subject: [PATCH 10/12] Add in-memory cache to store token validation related objects --- src/zenml/config/server_config.py | 9 ++ src/zenml/zen_server/auth.py | 105 ++++++++++--- src/zenml/zen_server/cache.py | 195 +++++++++++++++++++++++++ src/zenml/zen_server/utils.py | 24 +++ src/zenml/zen_server/zen_server_api.py | 5 +- 5 files changed, 317 insertions(+), 21 deletions(-) create mode 100644 src/zenml/zen_server/cache.py diff --git a/src/zenml/config/server_config.py b/src/zenml/config/server_config.py index 97e29520196..b3c3bd046e4 100644 --- a/src/zenml/config/server_config.py +++ b/src/zenml/config/server_config.py @@ -233,6 +233,12 @@ class ServerConfiguration(BaseModel): deployment. max_request_body_size_in_bytes: The maximum size of the request body in bytes. If not specified, the default value of 256 Kb will be used. + memcache_max_capacity: The maximum number of entries that the memory + cache can hold. If not specified, the default value of 1000 will be + used. + memcache_default_expiry: The default expiry time in seconds for cache + entries. If not specified, the default value of 30 seconds will be + used. """ deployment_type: ServerDeploymentType = ServerDeploymentType.OTHER @@ -328,6 +334,9 @@ class ServerConfiguration(BaseModel): DEFAULT_ZENML_SERVER_MAX_REQUEST_BODY_SIZE_IN_BYTES ) + memcache_max_capacity: int = 1000 + memcache_default_expiry: int = 30 + _deployment_id: Optional[UUID] = None @model_validator(mode="before") diff --git a/src/zenml/zen_server/auth.py b/src/zenml/zen_server/auth.py index 478974b22c2..fd9adf4647b 100644 --- a/src/zenml/zen_server/auth.py +++ b/src/zenml/zen_server/auth.py @@ -57,6 +57,7 @@ UserResponse, UserUpdate, ) +from zenml.zen_server.cache import cache_result from zenml.zen_server.exceptions import http_exception_from_error from zenml.zen_server.jwt import JWTToken from zenml.zen_server.utils import server_config, zen_store @@ -332,12 +333,32 @@ def authenticate_credentials( if decoded_token.schedule_id: # If the token contains a schedule ID, we need to check if the - # schedule still exists in the database. - try: - schedule = zen_store().get_schedule( - decoded_token.schedule_id, hydrate=False - ) - except KeyError: + # schedule still exists in the database. We use a cached version + # of the schedule active status to avoid unnecessary database + # queries. + + @cache_result(expiry=30) + def get_schedule_active(schedule_id: UUID) -> Optional[bool]: + """Get the active status of a schedule. + + Args: + schedule_id: The schedule ID. + + Returns: + The schedule active status or None if the schedule does not + exist. + """ + try: + schedule = zen_store().get_schedule( + schedule_id, hydrate=False + ) + except KeyError: + return False + + return schedule.active + + schedule_active = get_schedule_active(decoded_token.schedule_id) + if schedule_active is None: error = ( f"Authentication error: error retrieving token schedule " f"{decoded_token.schedule_id}" @@ -345,7 +366,7 @@ def authenticate_credentials( logger.error(error) raise CredentialsNotValid(error) - if not schedule.active: + if not schedule_active: error = ( f"Authentication error: schedule {decoded_token.schedule_id} " "is not active" @@ -356,12 +377,35 @@ def authenticate_credentials( if decoded_token.pipeline_run_id: # If the token contains a pipeline run ID, we need to check if the # pipeline run exists in the database and the pipeline run has - # not concluded. - try: - pipeline_run = zen_store().get_run( - decoded_token.pipeline_run_id, hydrate=False - ) - except KeyError: + # not concluded. We use a cached version of the pipeline run status + # to avoid unnecessary database queries. + + @cache_result(expiry=30) + def get_pipeline_run_status( + pipeline_run_id: UUID, + ) -> Optional[ExecutionStatus]: + """Get the status of a pipeline run. + + Args: + pipeline_run_id: The pipeline run ID. + + Returns: + The pipeline run status or None if the pipeline run does not + exist. + """ + try: + pipeline_run = zen_store().get_run( + pipeline_run_id, hydrate=False + ) + except KeyError: + return None + + return pipeline_run.status + + pipeline_run_status = get_pipeline_run_status( + decoded_token.pipeline_run_id + ) + if pipeline_run_status is None: error = ( f"Authentication error: error retrieving token pipeline run " f"{decoded_token.pipeline_run_id}" @@ -369,7 +413,7 @@ def authenticate_credentials( logger.error(error) raise CredentialsNotValid(error) - if pipeline_run.status in [ + if pipeline_run_status in [ ExecutionStatus.FAILED, ExecutionStatus.COMPLETED, ]: @@ -384,11 +428,32 @@ def authenticate_credentials( if decoded_token.step_run_id: # If the token contains a step run ID, we need to check if the # step run exists in the database and the step run has not concluded. - try: - step_run = zen_store().get_run_step( - decoded_token.step_run_id, hydrate=False - ) - except KeyError: + # We use a cached version of the step run status to avoid unnecessary + # database queries. + + @cache_result(expiry=30) + def get_step_run_status( + step_run_id: UUID, + ) -> Optional[ExecutionStatus]: + """Get the status of a step run. + + Args: + step_run_id: The step run ID. + + Returns: + The step run status or None if the step run does not exist. + """ + try: + step_run = zen_store().get_run_step( + step_run_id, hydrate=False + ) + except KeyError: + return None + + return step_run.status + + step_run_status = get_step_run_status(decoded_token.step_run_id) + if step_run_status is None: error = ( f"Authentication error: error retrieving token step run " f"{decoded_token.step_run_id}" @@ -396,7 +461,7 @@ def authenticate_credentials( logger.error(error) raise CredentialsNotValid(error) - if step_run.status in [ + if step_run_status in [ ExecutionStatus.FAILED, ExecutionStatus.COMPLETED, ]: diff --git a/src/zenml/zen_server/cache.py b/src/zenml/zen_server/cache.py new file mode 100644 index 00000000000..30bbd4cb0d2 --- /dev/null +++ b/src/zenml/zen_server/cache.py @@ -0,0 +1,195 @@ +# Copyright (c) ZenML GmbH 2024. All Rights Reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at: +# +# https://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express +# or implied. See the License for the specific language governing +# permissions and limitations under the License. +"""Memory cache module for the ZenML server.""" + +import time +from collections import OrderedDict +from threading import Lock +from typing import Any, Callable, Dict, Optional +from uuid import UUID + +from zenml.logger import get_logger +from zenml.utils.singleton import SingletonMetaClass + +logger = get_logger(__name__) + + +class MemoryCacheEntry: + """Simple class to hold cache entry data.""" + + def __init__(self, value: Any, expiry: int) -> None: + """Initialize a cache entry with value and expiry time. + + Args: + value: The value to store in the cache. + expiry: The expiry time in seconds. + """ + self.value: Any = value + self.expiry: int = expiry + self.timestamp: float = time.time() + + @property + def expired(self) -> bool: + """Check if the cache entry has expired.""" + return time.time() - self.timestamp >= self.expiry + + +class MemoryCache(metaclass=SingletonMetaClass): + """Simple in-memory cache with expiry and capacity management. + + This cache is thread-safe and can be used in both synchronous and + asynchronous contexts. It uses a simple LRU (Least Recently Used) eviction + strategy to manage the cache size. + + Each cache entry has a key, value, timestamp, and expiry. The cache + automatically removes expired entries and evicts the oldest entry when + the cache reaches its maximum capacity. + + + Usage Example: + + cache = MemoryCache() + uuid_key = UUID("12345678123456781234567812345678") + + cached_or_real_object = cache.get_or_cache( + uuid_key, lambda: "sync_data", expiry=120 + ) + print(cached_or_real_object) + """ + + def __init__(self, max_capacity: int, default_expiry: int) -> None: + """Initialize the cache with a maximum capacity and default expiry time. + + Args: + max_capacity: The maximum number of entries the cache can hold. + default_expiry: The default expiry time in seconds. + """ + self.cache: Dict[UUID, MemoryCacheEntry] = OrderedDict() + self.max_capacity = max_capacity + self.default_expiry = default_expiry + self._lock = Lock() + + def set(self, key: UUID, value: Any, expiry: Optional[int] = None) -> None: + """Insert value into cache with optional custom expiry time in seconds. + + Args: + key: The key to insert the value with. + value: The value to insert into the cache. + expiry: The expiry time in seconds. If None, uses the default expiry. + """ + with self._lock: + self.cache[key] = MemoryCacheEntry( + value=value, expiry=expiry or self.default_expiry + ) + self._cleanup() + + def get(self, key: UUID) -> Optional[Any]: + """Retrieve value if it's still valid; otherwise, return None. + + Args: + key: The key to retrieve the value for. + + Returns: + The value if it's still valid; otherwise, None. + """ + with self._lock: + return self._get_internal(key) + + def _get_internal(self, key: UUID) -> Optional[Any]: + """Helper to retrieve a value without lock (internal use only). + + Args: + key: The key to retrieve the value for. + + Returns: + The value if it's still valid; otherwise, None. + """ + entry = self.cache.get(key) + if entry and not entry.expired: + return entry.value + elif entry: + del self.cache[key] # Invalidate expired entry + return None + + def _cleanup(self) -> None: + """Remove expired or excess entries.""" + # Remove expired entries + keys_to_remove = [k for k, v in self.cache.items() if v.expired] + for k in keys_to_remove: + del self.cache[k] + + # Ensure we don't exceed max capacity + while len(self.cache) > self.max_capacity: + self.cache.popitem(last=False) # type: ignore[call-arg] + + +F = Callable[[UUID], Optional[Any]] + + +def cache_result( + expiry: Optional[int] = None, +) -> Callable[[F], F]: + """A decorator to cache the result of a function based on a UUID key argument. + + Args: + expiry: Custom time in seconds for the cache entry to expire. If None, + uses the default expiry time. + + Returns: + A decorator that wraps a function, caching its results based on a UUID + key. + """ + + def decorator(func: F) -> F: + """The actual decorator that wraps the function with caching logic. + + Args: + func: The function to wrap. + + Returns: + The wrapped function with caching logic. + """ + + def wrapper(key: UUID) -> Optional[Any]: + """The wrapped function with caching logic. + + Args: + key: The key to use for caching. + *args: Additional positional arguments. + **kwargs: Additional keyword arguments. + + Returns: + The result of the original function, either from cache or + freshly computed. + """ + from zenml.zen_server.utils import memcache + + cache = memcache() + + # Attempt to retrieve the result from cache + cached_value = cache.get(key) + if cached_value is not None: + logger.debug( + f"Memory cache hit for key: {key} and func: {func.__name__}" + ) + return cached_value + + # Call the original function and cache its result + result = func(key) + cache.set(key, result, expiry) + return result + + return wrapper + + return decorator diff --git a/src/zenml/zen_server/utils.py b/src/zenml/zen_server/utils.py index 4aea7aebfe5..152018087c4 100644 --- a/src/zenml/zen_server/utils.py +++ b/src/zenml/zen_server/utils.py @@ -43,6 +43,7 @@ from zenml.exceptions import IllegalOperationError, OAuthError from zenml.logger import get_logger from zenml.plugins.plugin_flavor_registry import PluginFlavorRegistry +from zenml.zen_server.cache import MemoryCache from zenml.zen_server.deploy.deployment import ( LocalServerDeployment, ) @@ -67,6 +68,7 @@ _feature_gate: Optional[FeatureGateInterface] = None _workload_manager: Optional[WorkloadManagerInterface] = None _plugin_flavor_registry: Optional[PluginFlavorRegistry] = None +_memcache: Optional[MemoryCache] = None def zen_store() -> "SqlZenStore": @@ -222,6 +224,28 @@ def initialize_zen_store() -> None: _zen_store = zen_store_ +def initialize_memcache(max_capacity: int, default_expiry: int) -> None: + """Initialize the memory cache. + + Args: + max_capacity: The maximum capacity of the cache. + default_expiry: The default expiry time in seconds. + """ + global _memcache + _memcache = MemoryCache(max_capacity, default_expiry) + + +def memcache() -> MemoryCache: + """Return the memory cache. + + Returns: + The memory cache. + """ + if _memcache is None: + raise RuntimeError("Memory cache not initialized") + return _memcache + + _server_config: Optional[ServerConfiguration] = None diff --git a/src/zenml/zen_server/zen_server_api.py b/src/zenml/zen_server/zen_server_api.py index 1d10fd2d0e8..d93a174df0e 100644 --- a/src/zenml/zen_server/zen_server_api.py +++ b/src/zenml/zen_server/zen_server_api.py @@ -95,6 +95,7 @@ ) from zenml.zen_server.utils import ( initialize_feature_gate, + initialize_memcache, initialize_plugins, initialize_rbac, initialize_workload_manager, @@ -335,9 +336,10 @@ async def infer_source_context(request: Request, call_next: Any) -> Any: @app.on_event("startup") def initialize() -> None: """Initialize the ZenML server.""" + cfg = server_config() # Set the maximum number of worker threads to_thread.current_default_thread_limiter().total_tokens = ( - server_config().thread_pool_size + cfg.thread_pool_size ) # IMPORTANT: these need to be run before the fastapi app starts, to avoid # race conditions @@ -347,6 +349,7 @@ def initialize() -> None: initialize_workload_manager() initialize_plugins() initialize_secure_headers() + initialize_memcache(cfg.memcache_max_capacity, cfg.memcache_default_expiry) DASHBOARD_REDIRECT_URL = None From 77a591cb32c4c8c7465eb1bd83cdeaca801c04cf Mon Sep 17 00:00:00 2001 From: Stefan Nica Date: Wed, 13 Nov 2024 09:59:29 +0100 Subject: [PATCH 11/12] Fix linter issues and applied code review suggestions --- src/zenml/zen_server/cache.py | 37 +++++++++++++++++++++++------------ 1 file changed, 25 insertions(+), 12 deletions(-) diff --git a/src/zenml/zen_server/cache.py b/src/zenml/zen_server/cache.py index 30bbd4cb0d2..f6463d0484e 100644 --- a/src/zenml/zen_server/cache.py +++ b/src/zenml/zen_server/cache.py @@ -16,7 +16,8 @@ import time from collections import OrderedDict from threading import Lock -from typing import Any, Callable, Dict, Optional +from typing import Any, Callable, Optional +from typing import OrderedDict as OrderedDictType from uuid import UUID from zenml.logger import get_logger @@ -41,7 +42,11 @@ def __init__(self, value: Any, expiry: int) -> None: @property def expired(self) -> bool: - """Check if the cache entry has expired.""" + """Check if the cache entry has expired. + + Returns: + True if the cache entry has expired; otherwise, False. + """ return time.time() - self.timestamp >= self.expiry @@ -62,10 +67,20 @@ class MemoryCache(metaclass=SingletonMetaClass): cache = MemoryCache() uuid_key = UUID("12345678123456781234567812345678") - cached_or_real_object = cache.get_or_cache( - uuid_key, lambda: "sync_data", expiry=120 - ) - print(cached_or_real_object) + if not cache.get(uuid_key): + # Get the value from the database or other source + value = get_value_from_database() + cache.set(uuid_key, value, expiry=60) + + Usage Example with decorator: + + @cache_result(expiry=60) + def get_cached_value(key: UUID) -> Any: + return get_value_from_database(key) + + uuid_key = UUID("12345678123456781234567812345678") + + value = get_cached_value(uuid_key) """ def __init__(self, max_capacity: int, default_expiry: int) -> None: @@ -75,7 +90,7 @@ def __init__(self, max_capacity: int, default_expiry: int) -> None: max_capacity: The maximum number of entries the cache can hold. default_expiry: The default expiry time in seconds. """ - self.cache: Dict[UUID, MemoryCacheEntry] = OrderedDict() + self.cache: OrderedDictType[UUID, MemoryCacheEntry] = OrderedDict() self.max_capacity = max_capacity self.default_expiry = default_expiry self._lock = Lock() @@ -131,10 +146,10 @@ def _cleanup(self) -> None: # Ensure we don't exceed max capacity while len(self.cache) > self.max_capacity: - self.cache.popitem(last=False) # type: ignore[call-arg] + self.cache.popitem(last=False) -F = Callable[[UUID], Optional[Any]] +F = Callable[[UUID], Any] def cache_result( @@ -161,13 +176,11 @@ def decorator(func: F) -> F: The wrapped function with caching logic. """ - def wrapper(key: UUID) -> Optional[Any]: + def wrapper(key: UUID) -> Any: """The wrapped function with caching logic. Args: key: The key to use for caching. - *args: Additional positional arguments. - **kwargs: Additional keyword arguments. Returns: The result of the original function, either from cache or From b14d3a1cc3fb3df2075068e9f1924f984177794e Mon Sep 17 00:00:00 2001 From: Stefan Nica Date: Wed, 13 Nov 2024 10:26:20 +0100 Subject: [PATCH 12/12] Fix docstrings --- src/zenml/zen_server/utils.py | 3 +++ 1 file changed, 3 insertions(+) diff --git a/src/zenml/zen_server/utils.py b/src/zenml/zen_server/utils.py index 152018087c4..ff96c7a640c 100644 --- a/src/zenml/zen_server/utils.py +++ b/src/zenml/zen_server/utils.py @@ -240,6 +240,9 @@ def memcache() -> MemoryCache: Returns: The memory cache. + + Raises: + RuntimeError: If the memory cache is not initialized. """ if _memcache is None: raise RuntimeError("Memory cache not initialized")