diff --git a/src/zenml/config/server_config.py b/src/zenml/config/server_config.py index 399c00f6cfe..b3c3bd046e4 100644 --- a/src/zenml/config/server_config.py +++ b/src/zenml/config/server_config.py @@ -19,13 +19,14 @@ from typing import Any, Dict, List, Optional, Union from uuid import UUID -from pydantic import BaseModel, ConfigDict, Field, model_validator +from pydantic import BaseModel, ConfigDict, Field, PositiveInt, model_validator from zenml.constants import ( DEFAULT_ZENML_JWT_TOKEN_ALGORITHM, DEFAULT_ZENML_JWT_TOKEN_LEEWAY, DEFAULT_ZENML_SERVER_DEVICE_AUTH_POLLING, DEFAULT_ZENML_SERVER_DEVICE_AUTH_TIMEOUT, + DEFAULT_ZENML_SERVER_GENERIC_API_TOKEN_LIFETIME, DEFAULT_ZENML_SERVER_LOGIN_RATE_LIMIT_DAY, DEFAULT_ZENML_SERVER_LOGIN_RATE_LIMIT_MINUTE, DEFAULT_ZENML_SERVER_MAX_DEVICE_AUTH_ATTEMPTS, @@ -119,6 +120,8 @@ class ServerConfiguration(BaseModel): time of the JWT tokens issued to clients after they have authenticated with the ZenML server using an OAuth 2.0 device that has been marked as trusted. + generic_api_token_lifetime: The lifetime in seconds that generic + short-lived API tokens issued for automation purposes are valid. external_login_url: The login URL of an external authenticator service to use with the `EXTERNAL` authentication scheme. external_user_info_url: The user info URL of an external authenticator @@ -230,6 +233,12 @@ class ServerConfiguration(BaseModel): deployment. max_request_body_size_in_bytes: The maximum size of the request body in bytes. If not specified, the default value of 256 Kb will be used. + memcache_max_capacity: The maximum number of entries that the memory + cache can hold. If not specified, the default value of 1000 will be + used. + memcache_default_expiry: The default expiry time in seconds for cache + entries. If not specified, the default value of 30 seconds will be + used. """ deployment_type: ServerDeploymentType = ServerDeploymentType.OTHER @@ -257,6 +266,10 @@ class ServerConfiguration(BaseModel): device_expiration_minutes: Optional[int] = None trusted_device_expiration_minutes: Optional[int] = None + generic_api_token_lifetime: PositiveInt = ( + DEFAULT_ZENML_SERVER_GENERIC_API_TOKEN_LIFETIME + ) + external_login_url: Optional[str] = None external_user_info_url: Optional[str] = None external_cookie_name: Optional[str] = None @@ -321,6 +334,9 @@ class ServerConfiguration(BaseModel): DEFAULT_ZENML_SERVER_MAX_REQUEST_BODY_SIZE_IN_BYTES ) + memcache_max_capacity: int = 1000 + memcache_default_expiry: int = 30 + _deployment_id: Optional[UUID] = None @model_validator(mode="before") diff --git a/src/zenml/constants.py b/src/zenml/constants.py index e8fb57a078a..3e9abe8935f 100644 --- a/src/zenml/constants.py +++ b/src/zenml/constants.py @@ -167,9 +167,6 @@ def handle_int_env_var(var: str, default: int = 0) -> int: ENV_ZENML_ENFORCE_TYPE_ANNOTATIONS = "ZENML_ENFORCE_TYPE_ANNOTATIONS" ENV_ZENML_ENABLE_IMPLICIT_AUTH_METHODS = "ZENML_ENABLE_IMPLICIT_AUTH_METHODS" ENV_ZENML_DISABLE_STEP_LOGS_STORAGE = "ZENML_DISABLE_STEP_LOGS_STORAGE" -ENV_ZENML_PIPELINE_API_TOKEN_EXPIRES_MINUTES = ( - "ZENML_PIPELINE_API_TOKEN_EXPIRES_MINUTES" -) ENV_ZENML_IGNORE_FAILURE_HOOK = "ZENML_IGNORE_FAILURE_HOOK" ENV_ZENML_CUSTOM_SOURCE_ROOT = "ZENML_CUSTOM_SOURCE_ROOT" ENV_ZENML_WHEEL_PACKAGE_NAME = "ZENML_WHEEL_PACKAGE_NAME" @@ -270,6 +267,7 @@ def handle_int_env_var(var: str, default: int = 0) -> int: DEFAULT_ZENML_SERVER_PIPELINE_RUN_AUTH_WINDOW = 60 * 48 # 48 hours DEFAULT_ZENML_SERVER_LOGIN_RATE_LIMIT_MINUTE = 5 DEFAULT_ZENML_SERVER_LOGIN_RATE_LIMIT_DAY = 1000 +DEFAULT_ZENML_SERVER_GENERIC_API_TOKEN_LIFETIME = 60 * 60 # 1 hour DEFAULT_ZENML_SERVER_SECURE_HEADERS_HSTS = ( "max-age=63072000; includeSubdomains" @@ -410,10 +408,6 @@ def handle_int_env_var(var: str, default: int = 0) -> int: # orchestrator constants ORCHESTRATOR_DOCKER_IMAGE_KEY = "orchestrator" -PIPELINE_API_TOKEN_EXPIRES_MINUTES = handle_int_env_var( - ENV_ZENML_PIPELINE_API_TOKEN_EXPIRES_MINUTES, - default=60 * 24, # 24 hours -) # Secret constants SECRET_VALUES = "values" diff --git a/src/zenml/enums.py b/src/zenml/enums.py index a3a1a7fec56..2e8e77f2dbb 100644 --- a/src/zenml/enums.py +++ b/src/zenml/enums.py @@ -245,6 +245,13 @@ class OAuthDeviceStatus(StrEnum): LOCKED = "locked" +class APITokenType(StrEnum): + """The API token type.""" + + GENERIC = "generic" + WORKLOAD = "workload" + + class GenericFilterOps(StrEnum): """Ops for all filters for string values on list methods.""" diff --git a/src/zenml/orchestrators/base_orchestrator.py b/src/zenml/orchestrators/base_orchestrator.py index 51e501ead99..d5ab036f1d2 100644 --- a/src/zenml/orchestrators/base_orchestrator.py +++ b/src/zenml/orchestrators/base_orchestrator.py @@ -15,6 +15,7 @@ from abc import ABC, abstractmethod from typing import TYPE_CHECKING, Any, Dict, Iterator, Optional, Type, cast +from uuid import UUID from pydantic import model_validator @@ -186,7 +187,17 @@ def run( """ self._prepare_run(deployment=deployment) - environment = get_config_environment_vars(deployment=deployment) + pipeline_run_id: Optional[UUID] = None + schedule_id: Optional[UUID] = None + if deployment.schedule: + schedule_id = deployment.schedule.id + if placeholder_run: + pipeline_run_id = placeholder_run.id + + environment = get_config_environment_vars( + schedule_id=schedule_id, + pipeline_run_id=pipeline_run_id, + ) prevent_client_side_caching = handle_bool_env_var( ENV_ZENML_PREVENT_CLIENT_SIDE_CACHING, default=False diff --git a/src/zenml/orchestrators/step_launcher.py b/src/zenml/orchestrators/step_launcher.py index 84b0450672b..c701d90e0bb 100644 --- a/src/zenml/orchestrators/step_launcher.py +++ b/src/zenml/orchestrators/step_launcher.py @@ -422,7 +422,8 @@ def _run_step_with_step_operator( ) ) environment = orchestrator_utils.get_config_environment_vars( - deployment=self._deployment + pipeline_run_id=step_run_info.run_id, + step_run_id=step_run_info.step_run_id, ) if last_retry: environment[ENV_ZENML_IGNORE_FAILURE_HOOK] = str(False) diff --git a/src/zenml/orchestrators/utils.py b/src/zenml/orchestrators/utils.py index 8f028198440..b1bbef60fa4 100644 --- a/src/zenml/orchestrators/utils.py +++ b/src/zenml/orchestrators/utils.py @@ -28,16 +28,16 @@ ENV_ZENML_DISABLE_CREDENTIALS_DISK_CACHING, ENV_ZENML_SERVER, ENV_ZENML_STORE_PREFIX, - PIPELINE_API_TOKEN_EXPIRES_MINUTES, ) from zenml.enums import AuthScheme, StackComponentType, StoreType from zenml.logger import get_logger from zenml.stack import StackComponent from zenml.utils.string_utils import format_name_template +logger = get_logger(__name__) + if TYPE_CHECKING: from zenml.artifact_stores.base_artifact_store import BaseArtifactStore - from zenml.models import PipelineDeploymentResponse def get_orchestrator_run_name(pipeline_name: str) -> str: @@ -80,16 +80,23 @@ def is_setting_enabled( def get_config_environment_vars( - deployment: Optional["PipelineDeploymentResponse"] = None, + schedule_id: Optional[UUID] = None, + pipeline_run_id: Optional[UUID] = None, + step_run_id: Optional[UUID] = None, ) -> Dict[str, str]: """Gets environment variables to set for mirroring the active config. - If a pipeline deployment is given, the environment variables will be set to - include a newly generated API token valid for the duration of the pipeline - run instead of the API token from the global config. + If a schedule ID, pipeline run ID or step run ID is given, and the current + client is not authenticated to a server with an API key, the environment + variables will be updated to include a newly generated workload API token + that will be valid for the duration of the schedule, pipeline run, or step + run instead of the current API token used to authenticate the client. Args: - deployment: Optional deployment to use for the environment variables. + schedule_id: Optional schedule ID to use to generate a new API token. + pipeline_run_id: Optional pipeline run ID to use to generate a new API + token. + step_run_id: Optional step run ID to use to generate a new API token. Returns: Environment variable dict. @@ -107,34 +114,46 @@ def get_config_environment_vars( ): credentials_store = get_credentials_store() url = global_config.store_configuration.url - api_key = credentials_store.get_api_key(url) api_token = credentials_store.get_token(url, allow_expired=False) - if api_key: - environment_vars[ENV_ZENML_STORE_PREFIX + "API_KEY"] = api_key - elif deployment: - # When connected to an authenticated ZenML server, if a pipeline - # deployment is supplied, we need to fetch an API token that will be - # valid for the duration of the pipeline run. + if schedule_id or pipeline_run_id or step_run_id: + # When connected to an authenticated ZenML server, if a schedule ID, + # pipeline run ID or step run ID is supplied, we need to fetch a new + # workload API token scoped to the schedule, pipeline run or step + # run. assert isinstance(global_config.zen_store, RestZenStore) - pipeline_id: Optional[UUID] = None - if deployment.pipeline: - pipeline_id = deployment.pipeline.id - schedule_id: Optional[UUID] = None - expires_minutes: Optional[int] = PIPELINE_API_TOKEN_EXPIRES_MINUTES - if deployment.schedule: - schedule_id = deployment.schedule.id - # If a schedule is given, this is a long running pipeline that - # should not have an API token that expires. - expires_minutes = None + + # If only a schedule is given, the pipeline run credentials will + # be valid for the entire duration of the schedule. + api_key = credentials_store.get_api_key(url) + if not api_key and not pipeline_run_id and not step_run_id: + logger.warning( + "An API token without an expiration time will be generated " + "and used to run this pipeline on a schedule. This is very " + "insecure because the API token will be valid for the " + "entire lifetime of the schedule and can be used to access " + "your user account if accidentally leaked. When deploying " + "a pipeline on a schedule, it is strongly advised to use a " + "service account API key to authenticate to the ZenML " + "server instead of your regular user account. For more " + "information, see " + "https://docs.zenml.io/how-to/connecting-to-zenml/connect-with-a-service-account" + ) + + # The schedule, pipeline run or step run credentials are scoped to + # the schedule, pipeline run or step run and will only be valid for + # the duration of the schedule/pipeline run/step run. new_api_token = global_config.zen_store.get_api_token( - pipeline_id=pipeline_id, schedule_id=schedule_id, - expires_minutes=expires_minutes, + pipeline_run_id=pipeline_run_id, + step_run_id=step_run_id, ) + environment_vars[ENV_ZENML_STORE_PREFIX + "API_TOKEN"] = ( new_api_token ) elif api_token: + # For all other cases, the pipeline run environment is configured + # with the current access token. environment_vars[ENV_ZENML_STORE_PREFIX + "API_TOKEN"] = ( api_token.access_token ) diff --git a/src/zenml/zen_server/auth.py b/src/zenml/zen_server/auth.py index 4f600f1ca6f..fd9adf4647b 100644 --- a/src/zenml/zen_server/auth.py +++ b/src/zenml/zen_server/auth.py @@ -14,13 +14,13 @@ """Authentication module for ZenML server.""" from contextvars import ContextVar -from datetime import datetime +from datetime import datetime, timedelta from typing import Callable, Optional, Union from urllib.parse import urlencode from uuid import UUID import requests -from fastapi import Depends +from fastapi import Depends, Response from fastapi.security import ( HTTPBasic, HTTPBasicCredentials, @@ -37,7 +37,7 @@ LOGIN, VERSION_1, ) -from zenml.enums import AuthScheme, OAuthDeviceStatus +from zenml.enums import AuthScheme, ExecutionStatus, OAuthDeviceStatus from zenml.exceptions import ( AuthorizationException, CredentialsNotValid, @@ -51,11 +51,13 @@ ExternalUserModel, OAuthDeviceInternalResponse, OAuthDeviceInternalUpdate, + OAuthTokenResponse, UserAuthModel, UserRequest, UserResponse, UserUpdate, ) +from zenml.zen_server.cache import cache_result from zenml.zen_server.exceptions import http_exception_from_error from zenml.zen_server.jwt import JWTToken from zenml.zen_server.utils import server_config, zen_store @@ -329,6 +331,148 @@ def authenticate_credentials( ), ) + if decoded_token.schedule_id: + # If the token contains a schedule ID, we need to check if the + # schedule still exists in the database. We use a cached version + # of the schedule active status to avoid unnecessary database + # queries. + + @cache_result(expiry=30) + def get_schedule_active(schedule_id: UUID) -> Optional[bool]: + """Get the active status of a schedule. + + Args: + schedule_id: The schedule ID. + + Returns: + The schedule active status or None if the schedule does not + exist. + """ + try: + schedule = zen_store().get_schedule( + schedule_id, hydrate=False + ) + except KeyError: + return False + + return schedule.active + + schedule_active = get_schedule_active(decoded_token.schedule_id) + if schedule_active is None: + error = ( + f"Authentication error: error retrieving token schedule " + f"{decoded_token.schedule_id}" + ) + logger.error(error) + raise CredentialsNotValid(error) + + if not schedule_active: + error = ( + f"Authentication error: schedule {decoded_token.schedule_id} " + "is not active" + ) + logger.error(error) + raise CredentialsNotValid(error) + + if decoded_token.pipeline_run_id: + # If the token contains a pipeline run ID, we need to check if the + # pipeline run exists in the database and the pipeline run has + # not concluded. We use a cached version of the pipeline run status + # to avoid unnecessary database queries. + + @cache_result(expiry=30) + def get_pipeline_run_status( + pipeline_run_id: UUID, + ) -> Optional[ExecutionStatus]: + """Get the status of a pipeline run. + + Args: + pipeline_run_id: The pipeline run ID. + + Returns: + The pipeline run status or None if the pipeline run does not + exist. + """ + try: + pipeline_run = zen_store().get_run( + pipeline_run_id, hydrate=False + ) + except KeyError: + return None + + return pipeline_run.status + + pipeline_run_status = get_pipeline_run_status( + decoded_token.pipeline_run_id + ) + if pipeline_run_status is None: + error = ( + f"Authentication error: error retrieving token pipeline run " + f"{decoded_token.pipeline_run_id}" + ) + logger.error(error) + raise CredentialsNotValid(error) + + if pipeline_run_status in [ + ExecutionStatus.FAILED, + ExecutionStatus.COMPLETED, + ]: + error = ( + f"The execution of pipeline run " + f"{decoded_token.pipeline_run_id} has already concluded and " + "API tokens scoped to it are no longer valid." + ) + logger.error(error) + raise CredentialsNotValid(error) + + if decoded_token.step_run_id: + # If the token contains a step run ID, we need to check if the + # step run exists in the database and the step run has not concluded. + # We use a cached version of the step run status to avoid unnecessary + # database queries. + + @cache_result(expiry=30) + def get_step_run_status( + step_run_id: UUID, + ) -> Optional[ExecutionStatus]: + """Get the status of a step run. + + Args: + step_run_id: The step run ID. + + Returns: + The step run status or None if the step run does not exist. + """ + try: + step_run = zen_store().get_run_step( + step_run_id, hydrate=False + ) + except KeyError: + return None + + return step_run.status + + step_run_status = get_step_run_status(decoded_token.step_run_id) + if step_run_status is None: + error = ( + f"Authentication error: error retrieving token step run " + f"{decoded_token.step_run_id}" + ) + logger.error(error) + raise CredentialsNotValid(error) + + if step_run_status in [ + ExecutionStatus.FAILED, + ExecutionStatus.COMPLETED, + ]: + error = ( + f"The execution of step run " + f"{decoded_token.step_run_id} has already concluded and " + "API tokens scoped to it are no longer valid." + ) + logger.error(error) + raise CredentialsNotValid(error) + auth_context = AuthContext( user=user_model, access_token=decoded_token, @@ -660,6 +804,80 @@ def authenticate_api_key( return AuthContext(user=user_model, api_key=internal_api_key) +def generate_access_token( + user_id: UUID, + response: Optional[Response] = None, + device: Optional[OAuthDeviceInternalResponse] = None, + api_key: Optional[APIKeyInternalResponse] = None, + expires_in: Optional[int] = None, + schedule_id: Optional[UUID] = None, + pipeline_run_id: Optional[UUID] = None, + step_run_id: Optional[UUID] = None, +) -> OAuthTokenResponse: + """Generates an access token for the given user. + + Args: + user_id: The ID of the user. + response: The FastAPI response object. + device: The device used for authentication. + api_key: The service account API key used for authentication. + expires_in: The number of seconds until the token expires. + schedule_id: The ID of the schedule to scope the token to. + pipeline_run_id: The ID of the pipeline run to scope the token to. + step_run_id: The ID of the step run to scope the token to. + + Returns: + An authentication response with an access token. + """ + config = server_config() + + # If the expiration time is not supplied, the JWT tokens are set to expire + # according to the values configured in the server config. Device tokens are + # handled separately from regular user tokens. + expires: Optional[datetime] = None + if expires_in: + expires = datetime.utcnow() + timedelta(seconds=expires_in) + elif device: + # If a device was used for authentication, the token will expire + # at the same time as the device. + expires = device.expires + if expires: + expires_in = max( + int(expires.timestamp() - datetime.utcnow().timestamp()), 0 + ) + elif config.jwt_token_expire_minutes: + expires = datetime.utcnow() + timedelta( + minutes=config.jwt_token_expire_minutes + ) + expires_in = config.jwt_token_expire_minutes * 60 + + access_token = JWTToken( + user_id=user_id, + device_id=device.id if device else None, + api_key_id=api_key.id if api_key else None, + schedule_id=schedule_id, + pipeline_run_id=pipeline_run_id, + step_run_id=step_run_id, + ).encode(expires=expires) + + if not device and response: + # Also set the access token as an HTTP only cookie in the response + response.set_cookie( + key=config.get_auth_cookie_name(), + value=access_token, + httponly=True, + samesite="lax", + max_age=config.jwt_token_expire_minutes * 60 + if config.jwt_token_expire_minutes + else None, + domain=config.auth_cookie_domain, + ) + + return OAuthTokenResponse( + access_token=access_token, expires_in=expires_in, token_type="bearer" + ) + + def http_authentication( credentials: HTTPBasicCredentials = Depends(HTTPBasic()), ) -> AuthContext: diff --git a/src/zenml/zen_server/cache.py b/src/zenml/zen_server/cache.py new file mode 100644 index 00000000000..f6463d0484e --- /dev/null +++ b/src/zenml/zen_server/cache.py @@ -0,0 +1,208 @@ +# Copyright (c) ZenML GmbH 2024. All Rights Reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at: +# +# https://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express +# or implied. See the License for the specific language governing +# permissions and limitations under the License. +"""Memory cache module for the ZenML server.""" + +import time +from collections import OrderedDict +from threading import Lock +from typing import Any, Callable, Optional +from typing import OrderedDict as OrderedDictType +from uuid import UUID + +from zenml.logger import get_logger +from zenml.utils.singleton import SingletonMetaClass + +logger = get_logger(__name__) + + +class MemoryCacheEntry: + """Simple class to hold cache entry data.""" + + def __init__(self, value: Any, expiry: int) -> None: + """Initialize a cache entry with value and expiry time. + + Args: + value: The value to store in the cache. + expiry: The expiry time in seconds. + """ + self.value: Any = value + self.expiry: int = expiry + self.timestamp: float = time.time() + + @property + def expired(self) -> bool: + """Check if the cache entry has expired. + + Returns: + True if the cache entry has expired; otherwise, False. + """ + return time.time() - self.timestamp >= self.expiry + + +class MemoryCache(metaclass=SingletonMetaClass): + """Simple in-memory cache with expiry and capacity management. + + This cache is thread-safe and can be used in both synchronous and + asynchronous contexts. It uses a simple LRU (Least Recently Used) eviction + strategy to manage the cache size. + + Each cache entry has a key, value, timestamp, and expiry. The cache + automatically removes expired entries and evicts the oldest entry when + the cache reaches its maximum capacity. + + + Usage Example: + + cache = MemoryCache() + uuid_key = UUID("12345678123456781234567812345678") + + if not cache.get(uuid_key): + # Get the value from the database or other source + value = get_value_from_database() + cache.set(uuid_key, value, expiry=60) + + Usage Example with decorator: + + @cache_result(expiry=60) + def get_cached_value(key: UUID) -> Any: + return get_value_from_database(key) + + uuid_key = UUID("12345678123456781234567812345678") + + value = get_cached_value(uuid_key) + """ + + def __init__(self, max_capacity: int, default_expiry: int) -> None: + """Initialize the cache with a maximum capacity and default expiry time. + + Args: + max_capacity: The maximum number of entries the cache can hold. + default_expiry: The default expiry time in seconds. + """ + self.cache: OrderedDictType[UUID, MemoryCacheEntry] = OrderedDict() + self.max_capacity = max_capacity + self.default_expiry = default_expiry + self._lock = Lock() + + def set(self, key: UUID, value: Any, expiry: Optional[int] = None) -> None: + """Insert value into cache with optional custom expiry time in seconds. + + Args: + key: The key to insert the value with. + value: The value to insert into the cache. + expiry: The expiry time in seconds. If None, uses the default expiry. + """ + with self._lock: + self.cache[key] = MemoryCacheEntry( + value=value, expiry=expiry or self.default_expiry + ) + self._cleanup() + + def get(self, key: UUID) -> Optional[Any]: + """Retrieve value if it's still valid; otherwise, return None. + + Args: + key: The key to retrieve the value for. + + Returns: + The value if it's still valid; otherwise, None. + """ + with self._lock: + return self._get_internal(key) + + def _get_internal(self, key: UUID) -> Optional[Any]: + """Helper to retrieve a value without lock (internal use only). + + Args: + key: The key to retrieve the value for. + + Returns: + The value if it's still valid; otherwise, None. + """ + entry = self.cache.get(key) + if entry and not entry.expired: + return entry.value + elif entry: + del self.cache[key] # Invalidate expired entry + return None + + def _cleanup(self) -> None: + """Remove expired or excess entries.""" + # Remove expired entries + keys_to_remove = [k for k, v in self.cache.items() if v.expired] + for k in keys_to_remove: + del self.cache[k] + + # Ensure we don't exceed max capacity + while len(self.cache) > self.max_capacity: + self.cache.popitem(last=False) + + +F = Callable[[UUID], Any] + + +def cache_result( + expiry: Optional[int] = None, +) -> Callable[[F], F]: + """A decorator to cache the result of a function based on a UUID key argument. + + Args: + expiry: Custom time in seconds for the cache entry to expire. If None, + uses the default expiry time. + + Returns: + A decorator that wraps a function, caching its results based on a UUID + key. + """ + + def decorator(func: F) -> F: + """The actual decorator that wraps the function with caching logic. + + Args: + func: The function to wrap. + + Returns: + The wrapped function with caching logic. + """ + + def wrapper(key: UUID) -> Any: + """The wrapped function with caching logic. + + Args: + key: The key to use for caching. + + Returns: + The result of the original function, either from cache or + freshly computed. + """ + from zenml.zen_server.utils import memcache + + cache = memcache() + + # Attempt to retrieve the result from cache + cached_value = cache.get(key) + if cached_value is not None: + logger.debug( + f"Memory cache hit for key: {key} and func: {func.__name__}" + ) + return cached_value + + # Call the original function and cache its result + result = func(key) + cache.set(key, result, expiry) + return result + + return wrapper + + return decorator diff --git a/src/zenml/zen_server/jwt.py b/src/zenml/zen_server/jwt.py index 59d9a99d118..eafa61a784b 100644 --- a/src/zenml/zen_server/jwt.py +++ b/src/zenml/zen_server/jwt.py @@ -40,16 +40,20 @@ class JWTToken(BaseModel): device_id: The id of the authenticated device. api_key_id: The id of the authenticated API key for which this token was issued. - pipeline_id: The id of the pipeline for which the token was issued. schedule_id: The id of the schedule for which the token was issued. + pipeline_run_id: The id of the pipeline run for which the token was + issued. + step_run_id: The id of the step run for which the token was + issued. claims: The original token claims. """ user_id: UUID device_id: Optional[UUID] = None api_key_id: Optional[UUID] = None - pipeline_id: Optional[UUID] = None schedule_id: Optional[UUID] = None + pipeline_run_id: Optional[UUID] = None + step_run_id: Optional[UUID] = None claims: Dict[str, Any] = {} @classmethod @@ -122,23 +126,33 @@ def decode_token( "UUID" ) - pipeline_id: Optional[UUID] = None - if "pipeline_id" in claims: + schedule_id: Optional[UUID] = None + if "schedule_id" in claims: try: - pipeline_id = UUID(claims.pop("pipeline_id")) + schedule_id = UUID(claims.pop("schedule_id")) except ValueError: raise CredentialsNotValid( - "Invalid JWT token: the pipeline_id claim is not a valid " + "Invalid JWT token: the schedule_id claim is not a valid " "UUID" ) - schedule_id: Optional[UUID] = None - if "schedule_id" in claims: + pipeline_run_id: Optional[UUID] = None + if "pipeline_run_id" in claims: try: - schedule_id = UUID(claims.pop("schedule_id")) + pipeline_run_id = UUID(claims.pop("pipeline_run_id")) except ValueError: raise CredentialsNotValid( - "Invalid JWT token: the schedule_id claim is not a valid " + "Invalid JWT token: the pipeline_run_id claim is not a valid " + "UUID" + ) + + step_run_id: Optional[UUID] = None + if "step_run_id" in claims: + try: + step_run_id = UUID(claims.pop("step_run_id")) + except ValueError: + raise CredentialsNotValid( + "Invalid JWT token: the step_run_id claim is not a valid " "UUID" ) @@ -146,8 +160,9 @@ def decode_token( user_id=user_id, device_id=device_id, api_key_id=api_key_id, - pipeline_id=pipeline_id, schedule_id=schedule_id, + pipeline_run_id=pipeline_run_id, + step_run_id=step_run_id, claims=claims, ) @@ -180,10 +195,12 @@ def encode(self, expires: Optional[datetime] = None) -> str: claims["device_id"] = str(self.device_id) if self.api_key_id: claims["api_key_id"] = str(self.api_key_id) - if self.pipeline_id: - claims["pipeline_id"] = str(self.pipeline_id) if self.schedule_id: claims["schedule_id"] = str(self.schedule_id) + if self.pipeline_run_id: + claims["pipeline_run_id"] = str(self.pipeline_run_id) + if self.step_run_id: + claims["step_run_id"] = str(self.step_run_id) return jwt.encode( claims, diff --git a/src/zenml/zen_server/routers/auth_endpoints.py b/src/zenml/zen_server/routers/auth_endpoints.py index a4db4b12762..34ba9ab1de7 100644 --- a/src/zenml/zen_server/routers/auth_endpoints.py +++ b/src/zenml/zen_server/routers/auth_endpoints.py @@ -13,7 +13,6 @@ # permissions and limitations under the License. """Endpoint definitions for authentication (login).""" -from datetime import datetime, timedelta from typing import Optional, Union from urllib.parse import urlencode from uuid import UUID @@ -40,17 +39,17 @@ VERSION_1, ) from zenml.enums import ( + APITokenType, AuthScheme, + ExecutionStatus, OAuthDeviceStatus, OAuthGrantTypes, ) from zenml.exceptions import AuthorizationException from zenml.logger import get_logger from zenml.models import ( - APIKeyInternalResponse, OAuthDeviceAuthorizationResponse, OAuthDeviceInternalRequest, - OAuthDeviceInternalResponse, OAuthDeviceInternalUpdate, OAuthDeviceUserAgentHeader, OAuthRedirectResponse, @@ -63,9 +62,9 @@ authenticate_device, authenticate_external_user, authorize, + generate_access_token, ) from zenml.zen_server.exceptions import error_response -from zenml.zen_server.jwt import JWTToken from zenml.zen_server.rate_limit import rate_limit_requests from zenml.zen_server.rbac.models import Action, ResourceType from zenml.zen_server.rbac.utils import verify_permission @@ -211,68 +210,6 @@ def __init__( ) -def generate_access_token( - user_id: UUID, - response: Response, - device: Optional[OAuthDeviceInternalResponse] = None, - api_key: Optional[APIKeyInternalResponse] = None, -) -> OAuthTokenResponse: - """Generates an access token for the given user. - - Args: - user_id: The ID of the user. - response: The FastAPI response object. - device: The device used for authentication. - api_key: The service account API key used for authentication. - - Returns: - An authentication response with an access token. - """ - config = server_config() - - # The JWT tokens are set to expire according to the values configured - # in the server config. Device tokens are handled separately from regular - # user tokens. - expires: Optional[datetime] = None - expires_in: Optional[int] = None - if device: - # If a device was used for authentication, the token will expire - # at the same time as the device. - expires = device.expires - if expires: - expires_in = max( - int(expires.timestamp() - datetime.utcnow().timestamp()), 0 - ) - elif config.jwt_token_expire_minutes: - expires = datetime.utcnow() + timedelta( - minutes=config.jwt_token_expire_minutes - ) - expires_in = config.jwt_token_expire_minutes * 60 - - access_token = JWTToken( - user_id=user_id, - device_id=device.id if device else None, - api_key_id=api_key.id if api_key else None, - ).encode(expires=expires) - - if not device: - # Also set the access token as an HTTP only cookie in the response - response.set_cookie( - key=config.get_auth_cookie_name(), - value=access_token, - httponly=True, - samesite="lax", - max_age=config.jwt_token_expire_minutes * 60 - if config.jwt_token_expire_minutes - else None, - domain=config.auth_cookie_domain, - ) - - return OAuthTokenResponse( - access_token=access_token, expires_in=expires_in, token_type="bearer" - ) - - @router.post( LOGIN, response_model=Union[OAuthTokenResponse, OAuthRedirectResponse], @@ -514,45 +451,74 @@ def device_authorization( ) @handle_exceptions def api_token( - pipeline_id: Optional[UUID] = None, + token_type: APITokenType = APITokenType.GENERIC, schedule_id: Optional[UUID] = None, - expires_minutes: Optional[int] = None, + pipeline_run_id: Optional[UUID] = None, + step_run_id: Optional[UUID] = None, auth_context: AuthContext = Security(authorize), ) -> str: - """Get a workload API token for the current user. + """Generate an API token for the current user. + + Use this endpoint to generate an API token for the current user. Two types + of API tokens are supported: + + * Generic API token: This token is short-lived and can be used for + generic automation tasks. + * Workload API token: This token is scoped to a specific pipeline run, step + run or schedule and is used by pipeline workloads to authenticate with the + server. A pipeline run ID, step run ID or schedule ID must be provided and + the generated token will only be valid for the indicated pipeline run, step + run or schedule. No time limit is imposed on the validity of the token. + A workload API token can be used to authenticate and generate another + workload API token, but only for the same schedule, pipeline run ID or step + run ID, in that order. Args: - pipeline_id: The ID of the pipeline to get the API token for. - schedule_id: The ID of the schedule to get the API token for. - expires_minutes: The number of minutes for which the API token should - be valid. If not provided, the API token will be valid indefinitely. + token_type: The type of API token to generate. + schedule_id: The ID of the schedule to scope the workload API token to. + pipeline_run_id: The ID of the pipeline run to scope the workload API + token to. + step_run_id: The ID of the step run to scope the workload API token to. auth_context: The authentication context. Returns: The API token. Raises: - HTTPException: If the user is not authenticated. - AuthorizationException: If trying to scope the API token to a different - pipeline/schedule than the token used to authorize this request. + AuthorizationException: If not authorized to generate the API token. + ValueError: If the request is invalid. """ token = auth_context.access_token if not token or not auth_context.encoded_access_token: # Should not happen - raise HTTPException( - status_code=status.HTTP_401_UNAUTHORIZED, - detail="Not authenticated.", - ) + raise AuthorizationException("Not authenticated.") + + if token_type == APITokenType.GENERIC: + if schedule_id or pipeline_run_id or step_run_id: + raise ValueError( + "Generic API tokens cannot be scoped to a schedule, pipeline " + "run or step run." + ) + + config = server_config() + + return generate_access_token( + user_id=token.user_id, + expires_in=config.generic_api_token_lifetime, + ).access_token verify_permission( resource_type=ResourceType.PIPELINE_RUN, action=Action.CREATE ) - if pipeline_id and token.pipeline_id and pipeline_id != token.pipeline_id: - raise AuthorizationException( - f"Unable to scope API token to pipeline {pipeline_id}. The " - f"token used to authorize this request is already scoped to " - f"pipeline {token.pipeline_id}." + schedule_id = schedule_id or token.schedule_id + pipeline_run_id = pipeline_run_id or token.pipeline_run_id + step_run_id = step_run_id or token.step_run_id + + if not pipeline_run_id and not schedule_id and not step_run_id: + raise ValueError( + "Workload API tokens must be scoped to a schedule, pipeline run " + "or step run." ) if schedule_id and token.schedule_id and schedule_id != token.schedule_id: @@ -562,21 +528,87 @@ def api_token( f"schedule {token.schedule_id}." ) - if not token.device_id and not token.api_key_id: - # If not authenticated with a device or a service account, the current - # API token is returned as is, without any modifications. Issuing - # workload tokens is only supported for device authenticated users and - # service accounts, because device tokens can be revoked at any time and - # service accounts can be disabled. - return auth_context.encoded_access_token - - # If authenticated with a device, a new API token is generated for the - # pipeline and/or schedule. - if pipeline_id: - token.pipeline_id = pipeline_id + if ( + pipeline_run_id + and token.pipeline_run_id + and pipeline_run_id != token.pipeline_run_id + ): + raise AuthorizationException( + f"Unable to scope API token to pipeline run {pipeline_run_id}. The " + f"token used to authorize this request is already scoped to " + f"pipeline run {token.pipeline_run_id}." + ) + + if step_run_id and token.step_run_id and step_run_id != token.step_run_id: + raise AuthorizationException( + f"Unable to scope API token to step run {step_run_id}. The " + f"token used to authorize this request is already scoped to " + f"step run {token.step_run_id}." + ) + if schedule_id: - token.schedule_id = schedule_id - expires: Optional[datetime] = None - if expires_minutes: - expires = datetime.utcnow() + timedelta(minutes=expires_minutes) - return token.encode(expires=expires) + # The schedule must exist + try: + schedule = zen_store().get_schedule(schedule_id, hydrate=False) + except KeyError: + raise ValueError( + f"Schedule {schedule_id} does not exist and API tokens cannot " + "be generated for non-existent schedules for security reasons." + ) + + if not schedule.active: + raise ValueError( + f"Schedule {schedule_id} is not active and API tokens cannot " + "be generated for inactive schedules for security reasons." + ) + + if pipeline_run_id: + # The pipeline run must exist and the run must not be concluded + try: + pipeline_run = zen_store().get_run(pipeline_run_id, hydrate=False) + except KeyError: + raise ValueError( + f"Pipeline run {pipeline_run_id} does not exist and API tokens " + "cannot be generated for non-existent pipeline runs for " + "security reasons." + ) + + if pipeline_run.status in [ + ExecutionStatus.FAILED, + ExecutionStatus.COMPLETED, + ]: + raise ValueError( + f"The execution of pipeline run {pipeline_run_id} has already " + "concluded and API tokens can no longer be generated for it " + "for security reasons." + ) + + if step_run_id: + # The step run must exist and the step must not be concluded + try: + step_run = zen_store().get_run_step(step_run_id, hydrate=False) + except KeyError: + raise ValueError( + f"Step run {step_run_id} does not exist and API tokens cannot " + "be generated for non-existent step runs for security reasons." + ) + + if step_run.status in [ + ExecutionStatus.FAILED, + ExecutionStatus.COMPLETED, + ]: + raise ValueError( + f"The execution of step run {step_run_id} has already " + "concluded and API tokens can no longer be generated for it " + "for security reasons." + ) + + return generate_access_token( + user_id=token.user_id, + # Keep the original API key and device token scopes + api_key=auth_context.api_key, + device=auth_context.device, + schedule_id=schedule_id, + pipeline_run_id=pipeline_run_id, + step_run_id=step_run_id, + ).access_token diff --git a/src/zenml/zen_server/template_execution/utils.py b/src/zenml/zen_server/template_execution/utils.py index cdf77e14891..0bc5620cf48 100644 --- a/src/zenml/zen_server/template_execution/utils.py +++ b/src/zenml/zen_server/template_execution/utils.py @@ -42,7 +42,7 @@ ) from zenml.stack.flavor import Flavor from zenml.utils import dict_utils, requirements_utils, settings_utils -from zenml.zen_server.auth import AuthContext +from zenml.zen_server.auth import AuthContext, generate_access_token from zenml.zen_server.template_execution.runner_entrypoint_configuration import ( RunnerEntrypointConfiguration, ) @@ -111,17 +111,6 @@ def run_template( new_deployment = zen_store().create_deployment(deployment_request) - if auth_context.access_token: - token = auth_context.access_token - token.pipeline_id = deployment_request.pipeline - - # We create a non-expiring token to make sure its active for the entire - # duration of the pipeline run - api_token = token.encode(expires=None) - else: - assert auth_context.encoded_access_token - api_token = auth_context.encoded_access_token - server_url = server_config().server_url if not server_url: raise RuntimeError( @@ -130,6 +119,18 @@ def run_template( assert build.zenml_version zenml_version = build.zenml_version + placeholder_run = create_placeholder_run(deployment=new_deployment) + assert placeholder_run + + # We create an API token scoped to the pipeline run + api_token = generate_access_token( + user_id=auth_context.user.id, + pipeline_run_id=placeholder_run.id, + # Keep the original API key or device scopes, if any + api_key=auth_context.api_key, + device=auth_context.device, + ).access_token + environment = { ENV_ZENML_ACTIVE_WORKSPACE_ID: str(new_deployment.workspace.id), ENV_ZENML_ACTIVE_STACK_ID: str(stack.id), @@ -145,9 +146,6 @@ def run_template( deployment_id=new_deployment.id ) - placeholder_run = create_placeholder_run(deployment=new_deployment) - assert placeholder_run - def _task() -> None: pypi_requirements, apt_packages = ( requirements_utils.get_requirements_for_stack(stack=stack) diff --git a/src/zenml/zen_server/utils.py b/src/zenml/zen_server/utils.py index 4aea7aebfe5..ff96c7a640c 100644 --- a/src/zenml/zen_server/utils.py +++ b/src/zenml/zen_server/utils.py @@ -43,6 +43,7 @@ from zenml.exceptions import IllegalOperationError, OAuthError from zenml.logger import get_logger from zenml.plugins.plugin_flavor_registry import PluginFlavorRegistry +from zenml.zen_server.cache import MemoryCache from zenml.zen_server.deploy.deployment import ( LocalServerDeployment, ) @@ -67,6 +68,7 @@ _feature_gate: Optional[FeatureGateInterface] = None _workload_manager: Optional[WorkloadManagerInterface] = None _plugin_flavor_registry: Optional[PluginFlavorRegistry] = None +_memcache: Optional[MemoryCache] = None def zen_store() -> "SqlZenStore": @@ -222,6 +224,31 @@ def initialize_zen_store() -> None: _zen_store = zen_store_ +def initialize_memcache(max_capacity: int, default_expiry: int) -> None: + """Initialize the memory cache. + + Args: + max_capacity: The maximum capacity of the cache. + default_expiry: The default expiry time in seconds. + """ + global _memcache + _memcache = MemoryCache(max_capacity, default_expiry) + + +def memcache() -> MemoryCache: + """Return the memory cache. + + Returns: + The memory cache. + + Raises: + RuntimeError: If the memory cache is not initialized. + """ + if _memcache is None: + raise RuntimeError("Memory cache not initialized") + return _memcache + + _server_config: Optional[ServerConfiguration] = None diff --git a/src/zenml/zen_server/zen_server_api.py b/src/zenml/zen_server/zen_server_api.py index 1d10fd2d0e8..d93a174df0e 100644 --- a/src/zenml/zen_server/zen_server_api.py +++ b/src/zenml/zen_server/zen_server_api.py @@ -95,6 +95,7 @@ ) from zenml.zen_server.utils import ( initialize_feature_gate, + initialize_memcache, initialize_plugins, initialize_rbac, initialize_workload_manager, @@ -335,9 +336,10 @@ async def infer_source_context(request: Request, call_next: Any) -> Any: @app.on_event("startup") def initialize() -> None: """Initialize the ZenML server.""" + cfg = server_config() # Set the maximum number of worker threads to_thread.current_default_thread_limiter().total_tokens = ( - server_config().thread_pool_size + cfg.thread_pool_size ) # IMPORTANT: these need to be run before the fastapi app starts, to avoid # race conditions @@ -347,6 +349,7 @@ def initialize() -> None: initialize_workload_manager() initialize_plugins() initialize_secure_headers() + initialize_memcache(cfg.memcache_max_capacity, cfg.memcache_default_expiry) DASHBOARD_REDIRECT_URL = None diff --git a/src/zenml/zen_stores/rest_zen_store.py b/src/zenml/zen_stores/rest_zen_store.py index b8288343256..d3231d65eb0 100644 --- a/src/zenml/zen_stores/rest_zen_store.py +++ b/src/zenml/zen_stores/rest_zen_store.py @@ -110,6 +110,7 @@ WORKSPACES, ) from zenml.enums import ( + APITokenType, OAuthGrantTypes, StackDeploymentProvider, StoreType, @@ -3872,17 +3873,16 @@ def delete_authorized_device(self, device_id: UUID) -> None: def get_api_token( self, - pipeline_id: Optional[UUID] = None, schedule_id: Optional[UUID] = None, - expires_minutes: Optional[int] = None, + pipeline_run_id: Optional[UUID] = None, + step_run_id: Optional[UUID] = None, ) -> str: """Get an API token for a workload. Args: - pipeline_id: The ID of the pipeline to get a token for. schedule_id: The ID of the schedule to get a token for. - expires_minutes: The number of minutes for which the token should - be valid. If not provided, the token will be valid indefinitely. + pipeline_run_id: The ID of the pipeline run to get a token for. + step_run_id: The ID of the step run to get a token for. Returns: The API token. @@ -3890,13 +3890,16 @@ def get_api_token( Raises: ValueError: if the server response is not valid. """ - params: Dict[str, Any] = {} - if pipeline_id: - params["pipeline_id"] = pipeline_id + params: Dict[str, Any] = { + # Python clients may only request workload tokens. + "token_type": APITokenType.WORKLOAD.value, + } if schedule_id: params["schedule_id"] = schedule_id - if expires_minutes: - params["expires_minutes"] = expires_minutes + if pipeline_run_id: + params["pipeline_run_id"] = pipeline_run_id + if step_run_id: + params["step_run_id"] = step_run_id response_body = self.get(API_TOKEN, params=params) if not isinstance(response_body, str): raise ValueError( diff --git a/tests/unit/zen_server/test_jwt.py b/tests/unit/zen_server/test_jwt.py index 2a67bba85de..480569f2925 100644 --- a/tests/unit/zen_server/test_jwt.py +++ b/tests/unit/zen_server/test_jwt.py @@ -35,8 +35,9 @@ def test_encode_decode_works(): user_id = uuid.uuid4() device_id = uuid.uuid4() api_key_id = uuid.uuid4() - pipeline_id = uuid.uuid4() schedule_id = uuid.uuid4() + pipeline_run_id = uuid.uuid4() + step_run_id = uuid.uuid4() claims = { "foo": "bar", "baz": "qux", @@ -46,8 +47,9 @@ def test_encode_decode_works(): user_id=user_id, device_id=device_id, api_key_id=api_key_id, - pipeline_id=pipeline_id, schedule_id=schedule_id, + pipeline_run_id=pipeline_run_id, + step_run_id=step_run_id, claims=claims, ) @@ -57,8 +59,9 @@ def test_encode_decode_works(): assert decoded_token.user_id == user_id assert decoded_token.device_id == device_id assert decoded_token.api_key_id == api_key_id - assert decoded_token.pipeline_id == pipeline_id assert decoded_token.schedule_id == schedule_id + assert decoded_token.pipeline_run_id == pipeline_run_id + assert decoded_token.step_run_id == step_run_id # Check that the configured custom claims are included in the decoded claims assert decoded_token.claims["foo"] == "bar" assert decoded_token.claims["baz"] == "qux"