Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

[API Server] Fix admin policy enforcement on validate and optimize #4820

Merged
merged 22 commits into from
Mar 3, 2025
Merged
Show file tree
Hide file tree
Changes from 10 commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
56 changes: 50 additions & 6 deletions sky/client/sdk.py
Original file line number Diff line number Diff line change
Expand Up @@ -25,6 +25,7 @@
import psutil
import requests

from sky import admin_policy
from sky import backends
from sky import exceptions
from sky import sky_logging
Expand Down Expand Up @@ -212,13 +213,23 @@ def list_accelerator_counts(
@annotations.client_api
def optimize(
dag: 'sky.Dag',
minimize: common.OptimizeTarget = common.OptimizeTarget.COST
minimize: common.OptimizeTarget = common.OptimizeTarget.COST,
cluster_name: Optional[str] = None,
idle_minutes_to_autostop: Optional[int] = None,
down: bool = False, # pylint: disable=redefined-outer-name
dryrun: bool = False
Copy link
Collaborator

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Should we emphasize that these are only required when admin_policy is specified, or more aggressively we just make this:

Suggested change
cluster_name: Optional[str] = None,
idle_minutes_to_autostop: Optional[int] = None,
down: bool = False, # pylint: disable=redefined-outer-name
dryrun: bool = False
admin_policy_request_options: Optional[admin_policy.RequestOptions] = None

Copy link
Collaborator Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Yep good idea, I was also thinking about it but wasn't sure if we want to expose admin_policy.RequestOptions in the SDK to end users. Should be ok.

) -> server_common.RequestId:
"""Finds the best execution plan for the given DAG.

Args:
dag: the DAG to optimize.
minimize: whether to minimize cost or time.
cluster_name: name of the cluster. Used for admin policy validation.
idle_minutes_to_autostop: autostop setting. Used for admin policy
validation.
down: whether to tear down the cluster. Used for admin policy
validation.
dryrun: whether this is a dryrun. Used for admin policy validation.

Returns:
The request ID of the optimize request.
Expand All @@ -233,7 +244,14 @@ def optimize(
"""
dag_str = dag_utils.dump_chain_dag_to_yaml_str(dag)

body = payloads.OptimizeBody(dag=dag_str, minimize=minimize)
body = payloads.OptimizeBody(
dag=dag_str,
minimize=minimize,
request_options=admin_policy.RequestOptions(
cluster_name=cluster_name,
idle_minutes_to_autostop=idle_minutes_to_autostop,
down=down,
dryrun=dryrun))
response = requests.post(f'{server_common.get_server_url()}/optimize',
json=json.loads(body.model_dump_json()))
return server_common.get_request_id(response)
Expand All @@ -242,7 +260,13 @@ def optimize(
@usage_lib.entrypoint
@server_common.check_server_healthy_or_start
@annotations.client_api
def validate(dag: 'sky.Dag', workdir_only: bool = False) -> None:
def validate(
dag: 'sky.Dag',
workdir_only: bool = False,
cluster_name: Optional[str] = None,
idle_minutes_to_autostop: Optional[int] = None,
down: bool = False, # pylint: disable=redefined-outer-name
dryrun: bool = False) -> None:
"""Validates the tasks.

The file paths (workdir and file_mounts) are validated on the client side
Expand All @@ -254,13 +278,25 @@ def validate(dag: 'sky.Dag', workdir_only: bool = False) -> None:
dag: the DAG to validate.
workdir_only: whether to only validate the workdir. This is used for
`exec` as it does not need other files/folders in file_mounts.
cluster_name: name of the cluster. Used for admin policy validation.
idle_minutes_to_autostop: autostop setting. Used for admin policy
validation.
down: whether to tear down the cluster. Used for admin policy
validation.
dryrun: whether this is a dryrun. Used for admin policy validation.
"""
for task in dag.tasks:
task.expand_and_validate_workdir()
if not workdir_only:
task.expand_and_validate_file_mounts()
dag_str = dag_utils.dump_chain_dag_to_yaml_str(dag)
body = payloads.ValidateBody(dag=dag_str)
body = payloads.ValidateBody(
dag=dag_str,
request_options=admin_policy.RequestOptions(
cluster_name=cluster_name,
idle_minutes_to_autostop=idle_minutes_to_autostop,
down=down,
dryrun=dryrun))
response = requests.post(f'{server_common.get_server_url()}/validate',
json=json.loads(body.model_dump_json()))
if response.status_code == 400:
Expand Down Expand Up @@ -386,7 +422,11 @@ def launch(
'Please contact the SkyPilot team if you '
'need this feature at slack.skypilot.co.')
dag = dag_utils.convert_entrypoint_to_dag(task)
validate(dag)
validate(dag,
cluster_name=cluster_name,
idle_minutes_to_autostop=idle_minutes_to_autostop,
down=down,
dryrun=dryrun)

confirm_shown = False
if _need_confirmation:
Expand Down Expand Up @@ -536,7 +576,11 @@ def exec( # pylint: disable=redefined-builtin
controller that does not support this operation.
"""
dag = dag_utils.convert_entrypoint_to_dag(task)
validate(dag, workdir_only=True)
validate(dag,
workdir_only=True,
cluster_name=cluster_name,
dryrun=dryrun,
down=down)
dag = client_common.upload_mounts_to_api_server(dag, workdir_only=True)
dag_str = dag_utils.dump_chain_dag_to_yaml_str(dag)
body = payloads.ExecBody(
Expand Down
78 changes: 74 additions & 4 deletions sky/core.py
Original file line number Diff line number Diff line change
Expand Up @@ -6,16 +6,18 @@

import colorama

from sky import admin_policy
from sky import backends
from sky import check as sky_check
from sky import clouds
from sky import dag
from sky import dag as dag_lib
from sky import data
from sky import exceptions
from sky import global_user_state
from sky import models
from sky import optimizer
from sky import sky_logging
from sky import task
from sky import task as task_lib
from sky.backends import backend_utils
from sky.clouds import service_catalog
from sky.jobs.server import core as managed_jobs_core
Expand All @@ -25,6 +27,7 @@
from sky.skylet import job_lib
from sky.skylet import log_lib
from sky.usage import usage_lib
from sky.utils import admin_policy_utils
from sky.utils import common
from sky.utils import common_utils
from sky.utils import controller_utils
Expand All @@ -44,6 +47,73 @@
# ======================


@usage_lib.entrypoint
def optimize(
dag: 'dag_lib.Dag',
minimize: common.OptimizeTarget = common.OptimizeTarget.COST,
blocked_resources: Optional[List['resources_lib.Resources']] = None,
quiet: bool = False,
request_options: Optional[admin_policy.RequestOptions] = None
) -> 'dag_lib.Dag':
"""Finds the best execution plan for the given DAG.

Args:
dag: the DAG to optimize.
minimize: whether to minimize cost or time.
blocked_resources: a list of resources that should not be used.
quiet: whether to suppress logging.
request_options: Request options used in enforcing admin policies.

Returns:
The optimized DAG.

Raises:
exceptions.ResourcesUnavailableError: if no resources are available
for a task.
exceptions.NoCloudAccessError: if no public clouds are enabled.
"""
# TODO: We apply the admin policy only on the first DAG optimization which
# is shown on `sky launch`. The optimizer is also invoked during failover,
# but we do not apply the admin policy there. We should apply the admin
# policy in the optimizer, but that will require some refactoring.
dag, _ = admin_policy_utils.apply(
dag,
use_mutated_config_in_current_request=True,
request_options=request_options)
return optimizer.Optimizer.optimize(dag=dag,
minimize=minimize,
blocked_resources=blocked_resources,
quiet=quiet)


@usage_lib.entrypoint
def validate_dag(
dag: 'dag_lib.Dag',
request_options: Optional[admin_policy.RequestOptions] = None) -> None:
"""Validates the specified DAG.

Args:
dag: The DAG to validate.
request_options: Request options used in enforcing admin policies.

Raises:
ValueError: if the DAG is invalid.
"""
dag, _ = admin_policy_utils.apply(
dag,
use_mutated_config_in_current_request=True,
request_options=request_options)

for task in dag.tasks:
# Will validate workdir and file_mounts in the backend, as those
# need to be validated after the files are uploaded to the SkyPilot
# API server with `upload_mounts_to_api_server`.
task.validate_name()
task.validate_run()
for r in task.resources:
r.validate()


@usage_lib.entrypoint
def status(
cluster_names: Optional[Union[str, List[str]]] = None,
Expand Down Expand Up @@ -325,8 +395,8 @@ def _start(

usage_lib.record_cluster_name_for_current_operation(cluster_name)

with dag.Dag():
dummy_task = task.Task().set_resources(handle.launched_resources)
with dag_lib.Dag():
dummy_task = task_lib.Task().set_resources(handle.launched_resources)
dummy_task.num_nodes = handle.launched_nodes
handle = backend.provision(dummy_task,
to_provision=handle.launched_resources,
Expand Down
3 changes: 3 additions & 0 deletions sky/server/requests/payloads.py
Original file line number Diff line number Diff line change
Expand Up @@ -12,6 +12,7 @@

import pydantic

from sky import admin_policy
from sky import serve
from sky import sky_logging
from sky import skypilot_config
Expand Down Expand Up @@ -116,12 +117,14 @@ class CheckBody(RequestBody):
class ValidateBody(RequestBody):
"""The request body for the validate endpoint."""
dag: str
request_options: admin_policy.RequestOptions
Copy link
Collaborator

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Can this be made optional and we can error out on the server side if admin_policy is set but RequestOptions are not supplied?

Copy link
Collaborator Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Yep, made optional now. When None, we will now follow the default behavior in admin_policy_utils.apply().

(Current behavior is it will silently accept request_options as None, but we should update that behavior in a future PR. That's a deeper change since serve and jobs launch also invoke admin_policy_utils.apply() without request_options)



class OptimizeBody(RequestBody):
"""The request body for the optimize endpoint."""
dag: str
minimize: common_lib.OptimizeTarget = common_lib.OptimizeTarget.COST
request_options: admin_policy.RequestOptions

def to_kwargs(self) -> Dict[str, Any]:
# Import here to avoid requirement of the whole SkyPilot dependency on
Expand Down
35 changes: 17 additions & 18 deletions sky/server/server.py
Original file line number Diff line number Diff line change
Expand Up @@ -24,10 +24,8 @@
from sky import check as sky_check
from sky import clouds
from sky import core
from sky import exceptions
from sky import execution
from sky import global_user_state
from sky import optimizer
from sky import sky_logging
from sky.clouds import service_catalog
from sky.data import storage_utils
Expand All @@ -44,7 +42,6 @@
from sky.usage import usage_lib
from sky.utils import common as common_lib
from sky.utils import common_utils
from sky.utils import dag_utils
from sky.utils import status_lib

# pylint: disable=ungrouped-imports
Expand Down Expand Up @@ -254,24 +251,26 @@ async def list_accelerator_counts(


@app.post('/validate')
async def validate(validate_body: payloads.ValidateBody) -> None:
async def validate(request: fastapi.Request,
validate_body: payloads.ValidateBody) -> None:
"""Validates the user's DAG."""
# TODO(SKY-1035): validate if existing cluster satisfies the requested
# resources, e.g. sky exec --gpus V100:8 existing-cluster-with-no-gpus

# We make the validation a separate request as it may require expensive
# network calls if an admin policy is applied.
# TODO: Our current launch process is broken down into three calls:
# validate, optimize, and launch. This requires us to apply the admin policy
# in each step, which may be an expensive operation. We should consolidate
# these into a single call or have a TTL cache for (task, admin_policy)
# pairs.
logger.debug(f'Validating tasks: {validate_body.dag}')
try:
dag = dag_utils.load_chain_dag_from_yaml_str(validate_body.dag)
for task in dag.tasks:
# Will validate workdir and file_mounts in the backend, as those
# need to be validated after the files are uploaded to the SkyPilot
# API server with `upload_mounts_to_api_server`.
task.validate_name()
task.validate_run()
for r in task.resources:
r.validate()
except Exception as e: # pylint: disable=broad-except
raise fastapi.HTTPException(
status_code=400, detail=exceptions.serialize_exception(e)) from e
executor.schedule_request(request_id=request.state.request_id,
request_name='validate',
request_body=validate_body,
ignore_return_value=True,
func=core.validate_dag,
schedule_type=requests_lib.ScheduleType.SHORT)
Copy link
Collaborator

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

We need to double check how much additional overhead we add to a normal launch with this modification. It may not worth it to slow down all the launch requests just for the admin policy, which is a less common codepath : )

Copy link
Collaborator Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Thanks for flagging, it does add a significant penalty:

[Master] validate without executor: 0.007498 seconds
[This branch] validate with executor: 0.150482 seconds

I'm going to revert this change in favor of keeping the codepath fast for the common case when there's no admin policy.

Copy link
Collaborator Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Benchmark script:

from sky.client import sdk
from sky import Task, Dag, Resources
from sky.utils import dag_utils
import timeit

task = Task(name='test', run="echo 'Hello, World!'")
# resources = Resources(labels={'invalid'*50: 'bar'})
# task.set_resources(resources)
dag = dag_utils.convert_entrypoint_to_dag(task)

# First run once to get the result and warm up any caches
print("Running initial validation...")
sdk.validate(dag)
print("Initial validation complete.")

# Now benchmark with timeit
print("\nStarting benchmark...")
number_of_runs = 10
setup = "from __main__ import sdk, dag"

# Run the benchmark
print(f"Running sdk.validate() {number_of_runs} times...")
execution_time = timeit.timeit("sdk.validate(dag)", setup=setup, number=number_of_runs)
average_time = execution_time / number_of_runs

# Print results
print("\nBenchmark Results:")
print(f"Total time for {number_of_runs} runs: {execution_time:.6f} seconds")
print(f"Average time per run: {average_time:.6f} seconds")



@app.post('/optimize')
Expand All @@ -283,7 +282,7 @@ async def optimize(optimize_body: payloads.OptimizeBody,
request_name='optimize',
request_body=optimize_body,
ignore_return_value=True,
func=optimizer.Optimizer.optimize,
func=core.optimize,
schedule_type=requests_lib.ScheduleType.SHORT,
)

Expand Down