Skip to content

Commit

Permalink
Refactored workload API token management for better security and impl…
Browse files Browse the repository at this point in the history
…emented generic API token dispenser (#3154)

* Rework the api_token endpoint to issue generic API tokens

* Pass pipeline and schedule id

* Fix warning message

* Implement proper authorization checks for workload API tokens

* Verify if schedule is active for schedule scoped tokens

* Remove all restrictions concerning token types

* Transfer the api key and authorized device scopes to generated workload API tokens

* Actually throw the errors for concluded pipeline runs/steps

* Don't pass the API key in the pipeline environment anymore

* Add in-memory cache to store token validation related objects

* Fix linter issues and applied code review suggestions

* Fix docstrings

---------

Co-authored-by: Michael Schuster <[email protected]>
  • Loading branch information
stefannica and schustmi authored Nov 13, 2024
1 parent 496a0d5 commit 0cf59cb
Show file tree
Hide file tree
Showing 15 changed files with 740 additions and 183 deletions.
18 changes: 17 additions & 1 deletion src/zenml/config/server_config.py
Original file line number Diff line number Diff line change
Expand Up @@ -19,13 +19,14 @@
from typing import Any, Dict, List, Optional, Union
from uuid import UUID

from pydantic import BaseModel, ConfigDict, Field, model_validator
from pydantic import BaseModel, ConfigDict, Field, PositiveInt, model_validator

from zenml.constants import (
DEFAULT_ZENML_JWT_TOKEN_ALGORITHM,
DEFAULT_ZENML_JWT_TOKEN_LEEWAY,
DEFAULT_ZENML_SERVER_DEVICE_AUTH_POLLING,
DEFAULT_ZENML_SERVER_DEVICE_AUTH_TIMEOUT,
DEFAULT_ZENML_SERVER_GENERIC_API_TOKEN_LIFETIME,
DEFAULT_ZENML_SERVER_LOGIN_RATE_LIMIT_DAY,
DEFAULT_ZENML_SERVER_LOGIN_RATE_LIMIT_MINUTE,
DEFAULT_ZENML_SERVER_MAX_DEVICE_AUTH_ATTEMPTS,
Expand Down Expand Up @@ -119,6 +120,8 @@ class ServerConfiguration(BaseModel):
time of the JWT tokens issued to clients after they have
authenticated with the ZenML server using an OAuth 2.0 device
that has been marked as trusted.
generic_api_token_lifetime: The lifetime in seconds that generic
short-lived API tokens issued for automation purposes are valid.
external_login_url: The login URL of an external authenticator service
to use with the `EXTERNAL` authentication scheme.
external_user_info_url: The user info URL of an external authenticator
Expand Down Expand Up @@ -230,6 +233,12 @@ class ServerConfiguration(BaseModel):
deployment.
max_request_body_size_in_bytes: The maximum size of the request body in
bytes. If not specified, the default value of 256 Kb will be used.
memcache_max_capacity: The maximum number of entries that the memory
cache can hold. If not specified, the default value of 1000 will be
used.
memcache_default_expiry: The default expiry time in seconds for cache
entries. If not specified, the default value of 30 seconds will be
used.
"""

deployment_type: ServerDeploymentType = ServerDeploymentType.OTHER
Expand Down Expand Up @@ -257,6 +266,10 @@ class ServerConfiguration(BaseModel):
device_expiration_minutes: Optional[int] = None
trusted_device_expiration_minutes: Optional[int] = None

generic_api_token_lifetime: PositiveInt = (
DEFAULT_ZENML_SERVER_GENERIC_API_TOKEN_LIFETIME
)

external_login_url: Optional[str] = None
external_user_info_url: Optional[str] = None
external_cookie_name: Optional[str] = None
Expand Down Expand Up @@ -321,6 +334,9 @@ class ServerConfiguration(BaseModel):
DEFAULT_ZENML_SERVER_MAX_REQUEST_BODY_SIZE_IN_BYTES
)

memcache_max_capacity: int = 1000
memcache_default_expiry: int = 30

_deployment_id: Optional[UUID] = None

@model_validator(mode="before")
Expand Down
8 changes: 1 addition & 7 deletions src/zenml/constants.py
Original file line number Diff line number Diff line change
Expand Up @@ -167,9 +167,6 @@ def handle_int_env_var(var: str, default: int = 0) -> int:
ENV_ZENML_ENFORCE_TYPE_ANNOTATIONS = "ZENML_ENFORCE_TYPE_ANNOTATIONS"
ENV_ZENML_ENABLE_IMPLICIT_AUTH_METHODS = "ZENML_ENABLE_IMPLICIT_AUTH_METHODS"
ENV_ZENML_DISABLE_STEP_LOGS_STORAGE = "ZENML_DISABLE_STEP_LOGS_STORAGE"
ENV_ZENML_PIPELINE_API_TOKEN_EXPIRES_MINUTES = (
"ZENML_PIPELINE_API_TOKEN_EXPIRES_MINUTES"
)
ENV_ZENML_IGNORE_FAILURE_HOOK = "ZENML_IGNORE_FAILURE_HOOK"
ENV_ZENML_CUSTOM_SOURCE_ROOT = "ZENML_CUSTOM_SOURCE_ROOT"
ENV_ZENML_WHEEL_PACKAGE_NAME = "ZENML_WHEEL_PACKAGE_NAME"
Expand Down Expand Up @@ -270,6 +267,7 @@ def handle_int_env_var(var: str, default: int = 0) -> int:
DEFAULT_ZENML_SERVER_PIPELINE_RUN_AUTH_WINDOW = 60 * 48 # 48 hours
DEFAULT_ZENML_SERVER_LOGIN_RATE_LIMIT_MINUTE = 5
DEFAULT_ZENML_SERVER_LOGIN_RATE_LIMIT_DAY = 1000
DEFAULT_ZENML_SERVER_GENERIC_API_TOKEN_LIFETIME = 60 * 60 # 1 hour

DEFAULT_ZENML_SERVER_SECURE_HEADERS_HSTS = (
"max-age=63072000; includeSubdomains"
Expand Down Expand Up @@ -410,10 +408,6 @@ def handle_int_env_var(var: str, default: int = 0) -> int:

# orchestrator constants
ORCHESTRATOR_DOCKER_IMAGE_KEY = "orchestrator"
PIPELINE_API_TOKEN_EXPIRES_MINUTES = handle_int_env_var(
ENV_ZENML_PIPELINE_API_TOKEN_EXPIRES_MINUTES,
default=60 * 24, # 24 hours
)

# Secret constants
SECRET_VALUES = "values"
Expand Down
7 changes: 7 additions & 0 deletions src/zenml/enums.py
Original file line number Diff line number Diff line change
Expand Up @@ -245,6 +245,13 @@ class OAuthDeviceStatus(StrEnum):
LOCKED = "locked"


class APITokenType(StrEnum):
"""The API token type."""

GENERIC = "generic"
WORKLOAD = "workload"


class GenericFilterOps(StrEnum):
"""Ops for all filters for string values on list methods."""

Expand Down
13 changes: 12 additions & 1 deletion src/zenml/orchestrators/base_orchestrator.py
Original file line number Diff line number Diff line change
Expand Up @@ -15,6 +15,7 @@

from abc import ABC, abstractmethod
from typing import TYPE_CHECKING, Any, Dict, Iterator, Optional, Type, cast
from uuid import UUID

from pydantic import model_validator

Expand Down Expand Up @@ -186,7 +187,17 @@ def run(
"""
self._prepare_run(deployment=deployment)

environment = get_config_environment_vars(deployment=deployment)
pipeline_run_id: Optional[UUID] = None
schedule_id: Optional[UUID] = None
if deployment.schedule:
schedule_id = deployment.schedule.id
if placeholder_run:
pipeline_run_id = placeholder_run.id

environment = get_config_environment_vars(
schedule_id=schedule_id,
pipeline_run_id=pipeline_run_id,
)

prevent_client_side_caching = handle_bool_env_var(
ENV_ZENML_PREVENT_CLIENT_SIDE_CACHING, default=False
Expand Down
3 changes: 2 additions & 1 deletion src/zenml/orchestrators/step_launcher.py
Original file line number Diff line number Diff line change
Expand Up @@ -422,7 +422,8 @@ def _run_step_with_step_operator(
)
)
environment = orchestrator_utils.get_config_environment_vars(
deployment=self._deployment
pipeline_run_id=step_run_info.run_id,
step_run_id=step_run_info.step_run_id,
)
if last_retry:
environment[ENV_ZENML_IGNORE_FAILURE_HOOK] = str(False)
Expand Down
71 changes: 45 additions & 26 deletions src/zenml/orchestrators/utils.py
Original file line number Diff line number Diff line change
Expand Up @@ -28,16 +28,16 @@
ENV_ZENML_DISABLE_CREDENTIALS_DISK_CACHING,
ENV_ZENML_SERVER,
ENV_ZENML_STORE_PREFIX,
PIPELINE_API_TOKEN_EXPIRES_MINUTES,
)
from zenml.enums import AuthScheme, StackComponentType, StoreType
from zenml.logger import get_logger
from zenml.stack import StackComponent
from zenml.utils.string_utils import format_name_template

logger = get_logger(__name__)

if TYPE_CHECKING:
from zenml.artifact_stores.base_artifact_store import BaseArtifactStore
from zenml.models import PipelineDeploymentResponse


def get_orchestrator_run_name(pipeline_name: str) -> str:
Expand Down Expand Up @@ -80,16 +80,23 @@ def is_setting_enabled(


def get_config_environment_vars(
deployment: Optional["PipelineDeploymentResponse"] = None,
schedule_id: Optional[UUID] = None,
pipeline_run_id: Optional[UUID] = None,
step_run_id: Optional[UUID] = None,
) -> Dict[str, str]:
"""Gets environment variables to set for mirroring the active config.
If a pipeline deployment is given, the environment variables will be set to
include a newly generated API token valid for the duration of the pipeline
run instead of the API token from the global config.
If a schedule ID, pipeline run ID or step run ID is given, and the current
client is not authenticated to a server with an API key, the environment
variables will be updated to include a newly generated workload API token
that will be valid for the duration of the schedule, pipeline run, or step
run instead of the current API token used to authenticate the client.
Args:
deployment: Optional deployment to use for the environment variables.
schedule_id: Optional schedule ID to use to generate a new API token.
pipeline_run_id: Optional pipeline run ID to use to generate a new API
token.
step_run_id: Optional step run ID to use to generate a new API token.
Returns:
Environment variable dict.
Expand All @@ -107,34 +114,46 @@ def get_config_environment_vars(
):
credentials_store = get_credentials_store()
url = global_config.store_configuration.url
api_key = credentials_store.get_api_key(url)
api_token = credentials_store.get_token(url, allow_expired=False)
if api_key:
environment_vars[ENV_ZENML_STORE_PREFIX + "API_KEY"] = api_key
elif deployment:
# When connected to an authenticated ZenML server, if a pipeline
# deployment is supplied, we need to fetch an API token that will be
# valid for the duration of the pipeline run.
if schedule_id or pipeline_run_id or step_run_id:
# When connected to an authenticated ZenML server, if a schedule ID,
# pipeline run ID or step run ID is supplied, we need to fetch a new
# workload API token scoped to the schedule, pipeline run or step
# run.
assert isinstance(global_config.zen_store, RestZenStore)
pipeline_id: Optional[UUID] = None
if deployment.pipeline:
pipeline_id = deployment.pipeline.id
schedule_id: Optional[UUID] = None
expires_minutes: Optional[int] = PIPELINE_API_TOKEN_EXPIRES_MINUTES
if deployment.schedule:
schedule_id = deployment.schedule.id
# If a schedule is given, this is a long running pipeline that
# should not have an API token that expires.
expires_minutes = None

# If only a schedule is given, the pipeline run credentials will
# be valid for the entire duration of the schedule.
api_key = credentials_store.get_api_key(url)
if not api_key and not pipeline_run_id and not step_run_id:
logger.warning(
"An API token without an expiration time will be generated "
"and used to run this pipeline on a schedule. This is very "
"insecure because the API token will be valid for the "
"entire lifetime of the schedule and can be used to access "
"your user account if accidentally leaked. When deploying "
"a pipeline on a schedule, it is strongly advised to use a "
"service account API key to authenticate to the ZenML "
"server instead of your regular user account. For more "
"information, see "
"https://docs.zenml.io/how-to/connecting-to-zenml/connect-with-a-service-account"
)

# The schedule, pipeline run or step run credentials are scoped to
# the schedule, pipeline run or step run and will only be valid for
# the duration of the schedule/pipeline run/step run.
new_api_token = global_config.zen_store.get_api_token(
pipeline_id=pipeline_id,
schedule_id=schedule_id,
expires_minutes=expires_minutes,
pipeline_run_id=pipeline_run_id,
step_run_id=step_run_id,
)

environment_vars[ENV_ZENML_STORE_PREFIX + "API_TOKEN"] = (
new_api_token
)
elif api_token:
# For all other cases, the pipeline run environment is configured
# with the current access token.
environment_vars[ENV_ZENML_STORE_PREFIX + "API_TOKEN"] = (
api_token.access_token
)
Expand Down
Loading

0 comments on commit 0cf59cb

Please sign in to comment.