Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Refactored workload API token management for better security and implemented generic API token dispenser #3154

Merged
merged 15 commits into from
Nov 13, 2024
Merged
Show file tree
Hide file tree
Changes from 3 commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
7 changes: 7 additions & 0 deletions src/zenml/config/server_config.py
Original file line number Diff line number Diff line change
Expand Up @@ -26,6 +26,7 @@
DEFAULT_ZENML_JWT_TOKEN_LEEWAY,
DEFAULT_ZENML_SERVER_DEVICE_AUTH_POLLING,
DEFAULT_ZENML_SERVER_DEVICE_AUTH_TIMEOUT,
DEFAULT_ZENML_SERVER_GENERIC_API_TOKEN_LIFETIME,
DEFAULT_ZENML_SERVER_LOGIN_RATE_LIMIT_DAY,
DEFAULT_ZENML_SERVER_LOGIN_RATE_LIMIT_MINUTE,
DEFAULT_ZENML_SERVER_MAX_DEVICE_AUTH_ATTEMPTS,
Expand Down Expand Up @@ -119,6 +120,8 @@ class ServerConfiguration(BaseModel):
time of the JWT tokens issued to clients after they have
authenticated with the ZenML server using an OAuth 2.0 device
that has been marked as trusted.
generic_api_token_lifetime: The lifetime in seconds that generic
short-lived API tokens issued for automation purposes are valid.
external_login_url: The login URL of an external authenticator service
to use with the `EXTERNAL` authentication scheme.
external_user_info_url: The user info URL of an external authenticator
Expand Down Expand Up @@ -257,6 +260,10 @@ class ServerConfiguration(BaseModel):
device_expiration_minutes: Optional[int] = None
trusted_device_expiration_minutes: Optional[int] = None

generic_api_token_lifetime: int = (
stefannica marked this conversation as resolved.
Show resolved Hide resolved
DEFAULT_ZENML_SERVER_GENERIC_API_TOKEN_LIFETIME
)

Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

⚠️ Potential issue

Add validation for the token lifetime value.

The generic_api_token_lifetime should be validated to ensure it's a positive integer to prevent potential security issues with invalid token lifetimes.

Add the following validation to the _validate_config method:

     @model_validator(mode="before")
     @classmethod
     @before_validator_handler
     def _validate_config(cls, data: Dict[str, Any]) -> Dict[str, Any]:
+        # Validate generic API token lifetime
+        if "generic_api_token_lifetime" in data:
+            lifetime = data["generic_api_token_lifetime"]
+            if not isinstance(lifetime, int) or lifetime <= 0:
+                raise ValueError(
+                    "generic_api_token_lifetime must be a positive integer"
+                )

Committable suggestion was skipped due to low confidence.


💡 Codebase verification

Add validation for maximum token lifetime

The current implementation allows setting arbitrary token lifetimes without an upper bound. The default value is 1 hour, which is a good security practice, but there's no enforcement to prevent setting much longer lifetimes. Consider adding a validation rule to limit the maximum token lifetime to a reasonable duration (e.g., 24 hours) to prevent potential security risks from long-lived tokens.

  • In src/zenml/config/server_config.py, add a validator for generic_api_token_lifetime to enforce a maximum value
🔗 Analysis chain

Consider adding a maximum lifetime limit.

For security best practices, consider adding an upper bound to the token lifetime to prevent the creation of tokens with excessive lifetimes.

Let's check if there are any existing token lifetime limits in the codebase:


Let me check if there are any validation rules for the token lifetime:


Let me check the default value for the token lifetime:

🏁 Scripts executed

The following scripts were executed for the analysis:

Script:

#!/bin/bash
# Search for token lifetime related constants and validations
rg -i "token.*lifetime|lifetime.*token" -g "!*.{log,md,txt}"

Length of output: 752


Script:

#!/bin/bash
# Check for any validation rules around token lifetime
ast-grep --pattern 'generic_api_token_lifetime'

# Also get the content of the server config file
cat src/zenml/config/server_config.py

Length of output: 25570


Script:

#!/bin/bash
# Check the default token lifetime value
cat src/zenml/constants.py | grep -A 1 "DEFAULT_ZENML_SERVER_GENERIC_API_TOKEN_LIFETIME"

Length of output: 157

external_login_url: Optional[str] = None
external_user_info_url: Optional[str] = None
external_cookie_name: Optional[str] = None
Expand Down
1 change: 1 addition & 0 deletions src/zenml/constants.py
Original file line number Diff line number Diff line change
Expand Up @@ -269,6 +269,7 @@ def handle_int_env_var(var: str, default: int = 0) -> int:
DEFAULT_ZENML_SERVER_PIPELINE_RUN_AUTH_WINDOW = 60 * 48 # 48 hours
DEFAULT_ZENML_SERVER_LOGIN_RATE_LIMIT_MINUTE = 5
DEFAULT_ZENML_SERVER_LOGIN_RATE_LIMIT_DAY = 1000
DEFAULT_ZENML_SERVER_GENERIC_API_TOKEN_LIFETIME = 60 * 60 # 1 hour

DEFAULT_ZENML_SERVER_SECURE_HEADERS_HSTS = (
"max-age=63072000; includeSubdomains"
Expand Down
13 changes: 13 additions & 0 deletions src/zenml/orchestrators/utils.py
Original file line number Diff line number Diff line change
Expand Up @@ -34,6 +34,8 @@
from zenml.stack import StackComponent
from zenml.utils.string_utils import format_name_template

logger = get_logger(__name__)

if TYPE_CHECKING:
from zenml.artifact_stores.base_artifact_store import BaseArtifactStore
from zenml.models import PipelineDeploymentResponse
Expand Down Expand Up @@ -113,6 +115,17 @@ def get_config_environment_vars(
# If a schedule is given, this is a long running pipeline that
# should not have an API token that expires.
expires_minutes = None
logger.warning(
stefannica marked this conversation as resolved.
Show resolved Hide resolved
"An API token without an expiration time will be generated "
"and used to run this pipeline on a schedule. This is very "
"insecure because the API token cannot be revoked in case "
"of potential theft without disabling the entire user "
"account. When deploying a pipeline on a schedule, it is "
"strongly advised to use a service account API key to "
"authenticate to the ZenML server instead of your regular "
"user account. For more information, see "
"https://docs.zenml.io/how-to/connecting-to-zenml/connect-with-a-service-account"
)
api_token = global_config.zen_store.get_api_token(
pipeline_id=pipeline_id,
schedule_id=schedule_id,
Expand Down
61 changes: 36 additions & 25 deletions src/zenml/zen_server/routers/auth_endpoints.py
Original file line number Diff line number Diff line change
Expand Up @@ -193,9 +193,12 @@ def __init__(

def generate_access_token(
user_id: UUID,
response: Response,
response: Optional[Response] = None,
device: Optional[OAuthDeviceInternalResponse] = None,
api_key: Optional[APIKeyInternalResponse] = None,
expires_in: Optional[int] = None,
pipeline_id: Optional[UUID] = None,
schedule_id: Optional[UUID] = None,
Copy link
Contributor

@coderabbitai coderabbitai bot Oct 28, 2024

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

⚠️ Potential issue

Unused parameters 'pipeline_id' and 'schedule_id' in generate_access_token

The parameters pipeline_id and schedule_id are added to the generate_access_token function signature but are not utilized within the function body. If these parameters are intended to scope the access token, they should be incorporated into the JWTToken or otherwise appropriately used.

Apply this diff to include pipeline_id and schedule_id in the JWTToken:

     access_token = JWTToken(
         user_id=user_id,
         device_id=device.id if device else None,
         api_key_id=api_key.id if api_key else None,
+        pipeline_id=pipeline_id,
+        schedule_id=schedule_id,
     ).encode(expires=expires)

Committable suggestion was skipped due to low confidence.

Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

This seems to be missing indeed

Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

@schustmi, glad we could identify this missing part. Including pipeline_id and schedule_id in the JWTToken will ensure proper token scoping.

(^‿^)/


If you found this review helpful, would you consider giving us a shout-out on X?
Thank you for using CodeRabbit!

) -> OAuthTokenResponse:
"""Generates an access token for the given user.

Expand All @@ -204,18 +207,22 @@ def generate_access_token(
response: The FastAPI response object.
device: The device used for authentication.
api_key: The service account API key used for authentication.
expires_in: The number of seconds until the token expires.
pipeline_id: The ID of the pipeline to scope the token to.
schedule_id: The ID of the schedule to scope the token to.

Returns:
An authentication response with an access token.
"""
config = server_config()

# The JWT tokens are set to expire according to the values configured
# in the server config. Device tokens are handled separately from regular
# user tokens.
# If the expiration time is not supplied, the JWT tokens are set to expire
# according to the values configured in the server config. Device tokens are
# handled separately from regular user tokens.
expires: Optional[datetime] = None
expires_in: Optional[int] = None
if device:
if expires_in:
expires = datetime.utcnow() + timedelta(seconds=expires_in)
elif device:
# If a device was used for authentication, the token will expire
# at the same time as the device.
expires = device.expires
Expand All @@ -233,9 +240,11 @@ def generate_access_token(
user_id=user_id,
device_id=device.id if device else None,
api_key_id=api_key.id if api_key else None,
pipeline_id=pipeline_id,
schedule_id=schedule_id
).encode(expires=expires)

if not device:
if not device and response:
# Also set the access token as an HTTP only cookie in the response
response.set_cookie(
key=config.get_auth_cookie_name(),
Expand Down Expand Up @@ -522,6 +531,20 @@ def api_token(
detail="Not authenticated.",
)

if not token.device_id and not token.api_key_id:
config = server_config()

# If not authenticated with a device or a service account, then a
# short-lived generic API token is returned.
return generate_access_token(
schustmi marked this conversation as resolved.
Show resolved Hide resolved
user_id=token.user_id,
expires_in=config.generic_api_token_lifetime,
).access_token

# Issuing workload tokens is only supported for device authenticated users
# and service accounts, because device tokens can be revoked at any time and
# service accounts can be disabled.

verify_permission(
resource_type=ResourceType.PIPELINE_RUN, action=Action.CREATE
)
Expand All @@ -540,21 +563,9 @@ def api_token(
f"schedule {token.schedule_id}."
)

if not token.device_id and not token.api_key_id:
# If not authenticated with a device or a service account, the current
# API token is returned as is, without any modifications. Issuing
# workload tokens is only supported for device authenticated users and
# service accounts, because device tokens can be revoked at any time and
# service accounts can be disabled.
return auth_context.encoded_access_token

# If authenticated with a device, a new API token is generated for the
# pipeline and/or schedule.
if pipeline_id:
token.pipeline_id = pipeline_id
if schedule_id:
token.schedule_id = schedule_id
expires: Optional[datetime] = None
if expires_minutes:
expires = datetime.utcnow() + timedelta(minutes=expires_minutes)
return token.encode(expires=expires)
return generate_access_token(
user_id=token.user_id,
expires_in=expires_minutes * 60 if expires_minutes else None,
pipeline_id=pipeline_id,
schedule_id=schedule_id,
).access_token
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

⚠️ Potential issue

Passing unused pipeline_id and schedule_id to generate_access_token

In this call to generate_access_token, the parameters pipeline_id and schedule_id are provided but not utilized within the function. Since generate_access_token does not handle these parameters, this could lead to confusion or unintended behavior.

Consider removing these parameters from the function call or updating generate_access_token to handle them appropriately.

Loading