Refactored workload API token management for better security and implemented generic API token dispenser #3154

Merged · 15 commits · Nov 13, 2024
Changes from 1 commit
9 changes: 9 additions & 0 deletions src/zenml/config/server_config.py
@@ -233,6 +233,12 @@ class ServerConfiguration(BaseModel):
deployment.
max_request_body_size_in_bytes: The maximum size of the request body in
bytes. If not specified, the default value of 256 Kb will be used.
memcache_max_capacity: The maximum number of entries that the memory
cache can hold. If not specified, the default value of 1000 will be
used.
memcache_default_expiry: The default expiry time in seconds for cache
entries. If not specified, the default value of 30 seconds will be
used.
"""

deployment_type: ServerDeploymentType = ServerDeploymentType.OTHER
@@ -328,6 +334,9 @@ class ServerConfiguration(BaseModel):
DEFAULT_ZENML_SERVER_MAX_REQUEST_BODY_SIZE_IN_BYTES
)

memcache_max_capacity: int = 1000
memcache_default_expiry: int = 30

_deployment_id: Optional[UUID] = None

@model_validator(mode="before")
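For context, a minimal sketch of how the two new settings could feed the in-memory cache added by this PR. The wiring shown is illustrative only, not the code path the server actually uses, and it assumes it runs inside the server process where server_config() is available:

    from zenml.zen_server.cache import MemoryCache
    from zenml.zen_server.utils import server_config

    # Illustrative wiring: read the two new settings and size the cache with
    # them. MemoryCache is a singleton, so this fixes the capacity and the
    # default expiry for the whole server process.
    config = server_config()
    cache = MemoryCache(
        max_capacity=config.memcache_max_capacity,
        default_expiry=config.memcache_default_expiry,
    )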
105 changes: 85 additions & 20 deletions src/zenml/zen_server/auth.py
@@ -57,6 +57,7 @@
UserResponse,
UserUpdate,
)
from zenml.zen_server.cache import cache_result
from zenml.zen_server.exceptions import http_exception_from_error
from zenml.zen_server.jwt import JWTToken
from zenml.zen_server.utils import server_config, zen_store
@@ -332,20 +333,40 @@ def authenticate_credentials(

if decoded_token.schedule_id:
# If the token contains a schedule ID, we need to check if the
# schedule still exists in the database.
try:
schedule = zen_store().get_schedule(
decoded_token.schedule_id, hydrate=False
)
except KeyError:
# schedule still exists in the database. We use a cached version
# of the schedule active status to avoid unnecessary database
# queries.

@cache_result(expiry=30)
def get_schedule_active(schedule_id: UUID) -> Optional[bool]:
"""Get the active status of a schedule.

Args:
schedule_id: The schedule ID.

Returns:
The schedule active status or None if the schedule does not
exist.
"""
try:
schedule = zen_store().get_schedule(
schedule_id, hydrate=False
)
except KeyError:
return None

return schedule.active

schedule_active = get_schedule_active(decoded_token.schedule_id)
if schedule_active is None:
error = (
f"Authentication error: error retrieving token schedule "
f"{decoded_token.schedule_id}"
)
logger.error(error)
raise CredentialsNotValid(error)

if not schedule.active:
if not schedule_active:
error = (
f"Authentication error: schedule {decoded_token.schedule_id} "
"is not active"
@@ -356,20 +377,43 @@ def authenticate_credentials(
if decoded_token.pipeline_run_id:
# If the token contains a pipeline run ID, we need to check if the
# pipeline run exists in the database and the pipeline run has
# not concluded.
try:
pipeline_run = zen_store().get_run(
decoded_token.pipeline_run_id, hydrate=False
)
except KeyError:
# not concluded. We use a cached version of the pipeline run status
# to avoid unnecessary database queries.

@cache_result(expiry=30)
def get_pipeline_run_status(
pipeline_run_id: UUID,
) -> Optional[ExecutionStatus]:
"""Get the status of a pipeline run.

Args:
pipeline_run_id: The pipeline run ID.

Returns:
The pipeline run status or None if the pipeline run does not
exist.
"""
try:
pipeline_run = zen_store().get_run(
pipeline_run_id, hydrate=False
)
except KeyError:
return None

return pipeline_run.status

pipeline_run_status = get_pipeline_run_status(
decoded_token.pipeline_run_id
)
if pipeline_run_status is None:
error = (
f"Authentication error: error retrieving token pipeline run "
f"{decoded_token.pipeline_run_id}"
)
logger.error(error)
raise CredentialsNotValid(error)

if pipeline_run.status in [
if pipeline_run_status in [
ExecutionStatus.FAILED,
ExecutionStatus.COMPLETED,
]:
@@ -384,19 +428,40 @@ def authenticate_credentials(
if decoded_token.step_run_id:
# If the token contains a step run ID, we need to check if the
# step run exists in the database and the step run has not concluded.
try:
step_run = zen_store().get_run_step(
decoded_token.step_run_id, hydrate=False
)
except KeyError:
# We use a cached version of the step run status to avoid unnecessary
# database queries.

@cache_result(expiry=30)
def get_step_run_status(
step_run_id: UUID,
) -> Optional[ExecutionStatus]:
"""Get the status of a step run.

Args:
step_run_id: The step run ID.

Returns:
The step run status or None if the step run does not exist.
"""
try:
step_run = zen_store().get_run_step(
step_run_id, hydrate=False
)
except KeyError:
return None

return step_run.status

step_run_status = get_step_run_status(decoded_token.step_run_id)
if step_run_status is None:
error = (
f"Authentication error: error retrieving token step run "
f"{decoded_token.step_run_id}"
)
logger.error(error)
raise CredentialsNotValid(error)

if step_run.status in [
if step_run_status in [
ExecutionStatus.FAILED,
ExecutionStatus.COMPLETED,
]:
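The three checks above share one pattern: a small lookup keyed by a UUID, wrapped in cache_result, that returns None when the entity no longer exists and its status otherwise. A condensed sketch of that pattern, modeled on the diff (it assumes it runs inside the server process, where zen_store() and the memory cache are initialized):

    from typing import Optional
    from uuid import UUID

    from zenml.enums import ExecutionStatus
    from zenml.zen_server.cache import cache_result
    from zenml.zen_server.utils import zen_store


    @cache_result(expiry=30)
    def get_run_status(run_id: UUID) -> Optional[ExecutionStatus]:
        """None means the run no longer exists; otherwise its last known status."""
        try:
            return zen_store().get_run(run_id, hydrate=False).status
        except KeyError:
            return None

Within the 30-second expiry, repeated authentications for the same run reuse the cached status instead of querying the database again.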
195 changes: 195 additions & 0 deletions src/zenml/zen_server/cache.py
@@ -0,0 +1,195 @@
# Copyright (c) ZenML GmbH 2024. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at:
#
# https://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express
# or implied. See the License for the specific language governing
# permissions and limitations under the License.
"""Memory cache module for the ZenML server."""

import time
from collections import OrderedDict
from threading import Lock
from typing import Any, Callable, Dict, Optional
from uuid import UUID

from zenml.logger import get_logger
from zenml.utils.singleton import SingletonMetaClass

logger = get_logger(__name__)


class MemoryCacheEntry:
"""Simple class to hold cache entry data."""

def __init__(self, value: Any, expiry: int) -> None:
"""Initialize a cache entry with value and expiry time.

Args:
value: The value to store in the cache.
expiry: The expiry time in seconds.
"""
self.value: Any = value
self.expiry: int = expiry
self.timestamp: float = time.time()

@property
def expired(self) -> bool:
"""Check if the cache entry has expired."""
return time.time() - self.timestamp >= self.expiry


class MemoryCache(metaclass=SingletonMetaClass):
"""Simple in-memory cache with expiry and capacity management.

This cache is thread-safe and can be used in both synchronous and
asynchronous contexts. It uses a simple LRU (Least Recently Used) eviction
strategy to manage the cache size.

Each cache entry has a key, value, timestamp, and expiry. The cache
automatically removes expired entries and evicts the oldest entry when
the cache reaches its maximum capacity.


Usage Example:

cache = MemoryCache()
uuid_key = UUID("12345678123456781234567812345678")

cached_or_real_object = cache.get_or_cache(
uuid_key, lambda: "sync_data", expiry=120
)
print(cached_or_real_object)
"""

Review thread on the get_or_cache call in the docstring above:
Contributor: I don't see the get_or_cache method anywhere.
Contributor Author: Ah, you're right, this was left behind from an older implementation version.
Contributor Author: done

def __init__(self, max_capacity: int, default_expiry: int) -> None:
"""Initialize the cache with a maximum capacity and default expiry time.

Args:
max_capacity: The maximum number of entries the cache can hold.
default_expiry: The default expiry time in seconds.
"""
self.cache: Dict[UUID, MemoryCacheEntry] = OrderedDict()
self.max_capacity = max_capacity
self.default_expiry = default_expiry
self._lock = Lock()

def set(self, key: UUID, value: Any, expiry: Optional[int] = None) -> None:
"""Insert value into cache with optional custom expiry time in seconds.

Args:
key: The key to insert the value with.
value: The value to insert into the cache.
expiry: The expiry time in seconds. If None, uses the default expiry.
"""
with self._lock:
self.cache[key] = MemoryCacheEntry(
value=value, expiry=expiry or self.default_expiry
)
self._cleanup()

def get(self, key: UUID) -> Optional[Any]:
"""Retrieve value if it's still valid; otherwise, return None.

Args:
key: The key to retrieve the value for.

Returns:
The value if it's still valid; otherwise, None.
"""
with self._lock:
return self._get_internal(key)

def _get_internal(self, key: UUID) -> Optional[Any]:
"""Helper to retrieve a value without lock (internal use only).

Args:
key: The key to retrieve the value for.

Returns:
The value if it's still valid; otherwise, None.
"""
entry = self.cache.get(key)
if entry and not entry.expired:
return entry.value
elif entry:
del self.cache[key] # Invalidate expired entry
return None

def _cleanup(self) -> None:
"""Remove expired or excess entries."""
# Remove expired entries
keys_to_remove = [k for k, v in self.cache.items() if v.expired]
for k in keys_to_remove:
del self.cache[k]

# Ensure we don't exceed max capacity
while len(self.cache) > self.max_capacity:
self.cache.popitem(last=False) # type: ignore[call-arg]
Review thread on the type-ignore above:
Contributor: Is it possible to annotate self.cache: OrderedDict[...] to get rid of this type ignore?
Contributor Author: I read that OrderedDict isn't available as a type annotation in Python 3.8 and decided to use Dict instead. But then I realized we don't support Python 3.8 anymore so... yeah, sure :)
Contributor Author: done
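For clarity, a minimal sketch of the annotation the reviewer is suggesting (class and method names here are illustrative, not part of the PR):

    from collections import OrderedDict
    from typing import Any
    from uuid import UUID


    class CacheSketch:
        def __init__(self) -> None:
            # With a subscriptable OrderedDict annotation (Python 3.9+), mypy
            # knows popitem() accepts the `last` keyword, so no "type: ignore"
            # is needed.
            self.cache: OrderedDict[UUID, Any] = OrderedDict()

        def evict_oldest(self) -> None:
            self.cache.popitem(last=False)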



F = Callable[[UUID], Optional[Any]]
Review thread on the type alias above:
Contributor: I think Any already includes Optional, so this is sort of redundant. Can also leave it like this if needed for clarity though, up to you.
Contributor Author: done



def cache_result(
expiry: Optional[int] = None,
) -> Callable[[F], F]:
"""A decorator to cache the result of a function based on a UUID key argument.

Args:
expiry: Custom time in seconds for the cache entry to expire. If None,
uses the default expiry time.

Returns:
A decorator that wraps a function, caching its results based on a UUID
key.
"""

def decorator(func: F) -> F:
"""The actual decorator that wraps the function with caching logic.

Args:
func: The function to wrap.

Returns:
The wrapped function with caching logic.
"""

def wrapper(key: UUID) -> Optional[Any]:
"""The wrapped function with caching logic.

Args:
key: The key to use for caching.

Returns:
The result of the original function, either from cache or
freshly computed.
"""
from zenml.zen_server.utils import memcache

cache = memcache()

# Attempt to retrieve the result from cache
cached_value = cache.get(key)
if cached_value is not None:
logger.debug(
f"Memory cache hit for key: {key} and func: {func.__name__}"
)
return cached_value

# Call the original function and cache its result
result = func(key)
cache.set(key, result, expiry)
return result

return wrapper

return decorator
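A short usage sketch of the module as it stands in this commit, using the set/get API (the get_or_cache call shown in the class docstring is left over from an earlier design, per the review thread above); the key value is arbitrary:

    from uuid import uuid4

    from zenml.zen_server.cache import MemoryCache

    # Construct the singleton with the ServerConfiguration defaults.
    cache = MemoryCache(max_capacity=1000, default_expiry=30)

    key = uuid4()
    cache.set(key, "expensive-to-compute value", expiry=120)
    assert cache.get(key) == "expensive-to-compute value"
    # Once the 120-second expiry passes, get() returns None again.

One consequence of the wrapper's `if cached_value is not None` check is that a cached None result is indistinguishable from a cache miss, so lookups that return None (entity not found) hit the database on every call; only non-None results are actually served from the cache.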