From 616bfbf5e8f3ed26a3ec30400c0abb821c725403 Mon Sep 17 00:00:00 2001 From: Romil Bhardwaj Date: Tue, 25 Feb 2025 15:50:47 -0800 Subject: [PATCH 01/21] Add admin policy to validate --- sky/client/sdk.py | 31 +++++++++++++++++++++++++++---- sky/server/requests/payloads.py | 4 ++++ sky/server/server.py | 17 +++++++++++++++++ 3 files changed, 48 insertions(+), 4 deletions(-) diff --git a/sky/client/sdk.py b/sky/client/sdk.py index a6a3340e996..6c56f96eadd 100644 --- a/sky/client/sdk.py +++ b/sky/client/sdk.py @@ -242,7 +242,12 @@ def optimize( @usage_lib.entrypoint @server_common.check_server_healthy_or_start @annotations.client_api -def validate(dag: 'sky.Dag', workdir_only: bool = False) -> None: +def validate(dag: 'sky.Dag', + workdir_only: bool = False, + cluster_name: Optional[str] = None, + idle_minutes_to_autostop: Optional[int] = None, + down: bool = False, + dryrun: bool = False) -> None: """Validates the tasks. The file paths (workdir and file_mounts) are validated on the client side @@ -254,13 +259,23 @@ def validate(dag: 'sky.Dag', workdir_only: bool = False) -> None: dag: the DAG to validate. workdir_only: whether to only validate the workdir. This is used for `exec` as it does not need other files/folders in file_mounts. + cluster_name: name of the cluster to create/reuse. Used for admin policy validation. + idle_minutes_to_autostop: autostop setting. Used for admin policy validation. + down: whether to tear down the cluster. Used for admin policy validation. + dryrun: whether this is a dryrun. Used for admin policy validation. """ for task in dag.tasks: task.expand_and_validate_workdir() if not workdir_only: task.expand_and_validate_file_mounts() dag_str = dag_utils.dump_chain_dag_to_yaml_str(dag) - body = payloads.ValidateBody(dag=dag_str) + body = payloads.ValidateBody( + dag=dag_str, + cluster_name=cluster_name, + idle_minutes_to_autostop=idle_minutes_to_autostop, + down=down, + dryrun=dryrun + ) response = requests.post(f'{server_common.get_server_url()}/validate', json=json.loads(body.model_dump_json())) if response.status_code == 400: @@ -386,7 +401,11 @@ def launch( 'Please contact the SkyPilot team if you ' 'need this feature at slack.skypilot.co.') dag = dag_utils.convert_entrypoint_to_dag(task) - validate(dag) + validate(dag, + cluster_name=cluster_name, + idle_minutes_to_autostop=idle_minutes_to_autostop, + down=down, + dryrun=dryrun) confirm_shown = False if _need_confirmation: @@ -536,7 +555,11 @@ def exec( # pylint: disable=redefined-builtin controller that does not support this operation. """ dag = dag_utils.convert_entrypoint_to_dag(task) - validate(dag, workdir_only=True) + validate(dag, + workdir_only=True, + cluster_name=cluster_name, + dryrun=dryrun, + down=down) dag = client_common.upload_mounts_to_api_server(dag, workdir_only=True) dag_str = dag_utils.dump_chain_dag_to_yaml_str(dag) body = payloads.ExecBody( diff --git a/sky/server/requests/payloads.py b/sky/server/requests/payloads.py index 02f2ab8e969..c8d5c80cb0f 100644 --- a/sky/server/requests/payloads.py +++ b/sky/server/requests/payloads.py @@ -116,6 +116,10 @@ class CheckBody(RequestBody): class ValidateBody(RequestBody): """The request body for the validate endpoint.""" dag: str + cluster_name: Optional[str] = None + idle_minutes_to_autostop: Optional[int] = None + down: bool = False + dryrun: bool = False class OptimizeBody(RequestBody): diff --git a/sky/server/server.py b/sky/server/server.py index e87bba33728..382a867f23d 100644 --- a/sky/server/server.py +++ b/sky/server/server.py @@ -21,6 +21,7 @@ import starlette.middleware.base import sky +from sky import admin_policy from sky import check as sky_check from sky import clouds from sky import core @@ -46,6 +47,7 @@ from sky.utils import common_utils from sky.utils import dag_utils from sky.utils import status_lib +from sky.utils import admin_policy_utils # pylint: disable=ungrouped-imports if sys.version_info >= (3, 10): @@ -261,6 +263,21 @@ async def validate(validate_body: payloads.ValidateBody) -> None: logger.debug(f'Validating tasks: {validate_body.dag}') try: dag = dag_utils.load_chain_dag_from_yaml_str(validate_body.dag) + + # Apply admin policies + request_options = admin_policy.RequestOptions( + cluster_name=validate_body.cluster_name, + idle_minutes_to_autostop=validate_body.idle_minutes_to_autostop, + down=validate_body.down, + dryrun=validate_body.dryrun + ) + + dag, _ = admin_policy_utils.apply( + dag, + use_mutated_config_in_current_request=True, + request_options=request_options + ) + for task in dag.tasks: # Will validate workdir and file_mounts in the backend, as those # need to be validated after the files are uploaded to the SkyPilot From 57a6ab26c4b8008babe3d469ba54b78ea366a830 Mon Sep 17 00:00:00 2001 From: Romil Bhardwaj Date: Tue, 25 Feb 2025 17:37:55 -0800 Subject: [PATCH 02/21] Add admin policy to validate --- sky/client/sdk.py | 27 +++++++++++++++++++++++---- sky/serve/client/sdk.py | 8 ++++---- sky/server/requests/payloads.py | 7 +++---- sky/server/server.py | 16 ++++++---------- 4 files changed, 36 insertions(+), 22 deletions(-) diff --git a/sky/client/sdk.py b/sky/client/sdk.py index 6c56f96eadd..94647ef70a0 100644 --- a/sky/client/sdk.py +++ b/sky/client/sdk.py @@ -25,7 +25,7 @@ import psutil import requests -from sky import backends +from sky import backends, admin_policy from sky import exceptions from sky import sky_logging from sky import skypilot_config @@ -37,6 +37,7 @@ from sky.skylet import constants from sky.usage import usage_lib from sky.utils import annotations +from sky.utils import admin_policy_utils from sky.utils import cluster_utils from sky.utils import common from sky.utils import common_utils @@ -212,13 +213,21 @@ def list_accelerator_counts( @annotations.client_api def optimize( dag: 'sky.Dag', - minimize: common.OptimizeTarget = common.OptimizeTarget.COST + minimize: common.OptimizeTarget = common.OptimizeTarget.COST, + cluster_name: Optional[str] = None, + idle_minutes_to_autostop: Optional[int] = None, + down: bool = False, + dryrun: bool = False ) -> server_common.RequestId: """Finds the best execution plan for the given DAG. Args: dag: the DAG to optimize. minimize: whether to minimize cost or time. + cluster_name: name of the cluster to create/reuse. Used for admin policy validation. + idle_minutes_to_autostop: autostop setting. Used for admin policy validation. + down: whether to tear down the cluster. Used for admin policy validation. + dryrun: whether this is a dryrun. Used for admin policy validation. Returns: The request ID of the optimize request. @@ -233,7 +242,16 @@ def optimize( """ dag_str = dag_utils.dump_chain_dag_to_yaml_str(dag) - body = payloads.OptimizeBody(dag=dag_str, minimize=minimize) + body = payloads.OptimizeBody( + dag=dag_str, + minimize=minimize, + request_options=admin_policy.RequestOptions( + cluster_name=cluster_name, + idle_minutes_to_autostop=idle_minutes_to_autostop, + down=down, + dryrun=dryrun + ) + ) response = requests.post(f'{server_common.get_server_url()}/optimize', json=json.loads(body.model_dump_json())) return server_common.get_request_id(response) @@ -271,10 +289,11 @@ def validate(dag: 'sky.Dag', dag_str = dag_utils.dump_chain_dag_to_yaml_str(dag) body = payloads.ValidateBody( dag=dag_str, + request_options=admin_policy.RequestOptions( cluster_name=cluster_name, idle_minutes_to_autostop=idle_minutes_to_autostop, down=down, - dryrun=dryrun + dryrun=dryrun) ) response = requests.post(f'{server_common.get_server_url()}/validate', json=json.loads(body.model_dump_json())) diff --git a/sky/serve/client/sdk.py b/sky/serve/client/sdk.py index 589814d5865..9173d3de3ee 100644 --- a/sky/serve/client/sdk.py +++ b/sky/serve/client/sdk.py @@ -51,8 +51,8 @@ def up( from sky.client import sdk # pylint: disable=import-outside-toplevel dag = dag_utils.convert_entrypoint_to_dag(task) - sdk.validate(dag) - request_id = sdk.optimize(dag) + sdk.validate(dag, workdir_only=False, cluster_name=service_name) + request_id = sdk.optimize(dag, cluster_name=service_name) sdk.stream_and_get(request_id) if _need_confirmation: prompt = f'Launching a new service {service_name!r}. Proceed?' @@ -107,8 +107,8 @@ def update( from sky.client import sdk # pylint: disable=import-outside-toplevel dag = dag_utils.convert_entrypoint_to_dag(task) - sdk.validate(dag) - request_id = sdk.optimize(dag) + sdk.validate(dag, workdir_only=False, cluster_name=service_name) + request_id = sdk.optimize(dag, cluster_name=service_name) sdk.stream_and_get(request_id) if _need_confirmation: click.confirm(f'Updating service {service_name!r}. Proceed?', diff --git a/sky/server/requests/payloads.py b/sky/server/requests/payloads.py index c8d5c80cb0f..81b25118f35 100644 --- a/sky/server/requests/payloads.py +++ b/sky/server/requests/payloads.py @@ -12,6 +12,7 @@ import pydantic +from sky import admin_policy from sky import serve from sky import sky_logging from sky import skypilot_config @@ -116,16 +117,14 @@ class CheckBody(RequestBody): class ValidateBody(RequestBody): """The request body for the validate endpoint.""" dag: str - cluster_name: Optional[str] = None - idle_minutes_to_autostop: Optional[int] = None - down: bool = False - dryrun: bool = False + request_options: admin_policy.RequestOptions class OptimizeBody(RequestBody): """The request body for the optimize endpoint.""" dag: str minimize: common_lib.OptimizeTarget = common_lib.OptimizeTarget.COST + request_options: admin_policy.RequestOptions def to_kwargs(self) -> Dict[str, Any]: # Import here to avoid requirement of the whole SkyPilot dependency on diff --git a/sky/server/server.py b/sky/server/server.py index 382a867f23d..32b57918e8d 100644 --- a/sky/server/server.py +++ b/sky/server/server.py @@ -263,19 +263,15 @@ async def validate(validate_body: payloads.ValidateBody) -> None: logger.debug(f'Validating tasks: {validate_body.dag}') try: dag = dag_utils.load_chain_dag_from_yaml_str(validate_body.dag) - - # Apply admin policies - request_options = admin_policy.RequestOptions( - cluster_name=validate_body.cluster_name, - idle_minutes_to_autostop=validate_body.idle_minutes_to_autostop, - down=validate_body.down, - dryrun=validate_body.dryrun - ) - + + # Apply admin policy since it may affect DAG validation. + # TODO: The admin policy may be a potentially expensive operation with + # network calls. Maybe this should be moved into a request run with the + # executor. dag, _ = admin_policy_utils.apply( dag, use_mutated_config_in_current_request=True, - request_options=request_options + request_options=validate_body.request_options ) for task in dag.tasks: From 714e65ce952051546f9c5018940023d4eb926cec Mon Sep 17 00:00:00 2001 From: Romil Bhardwaj Date: Tue, 25 Feb 2025 17:51:21 -0800 Subject: [PATCH 03/21] Add admin policy to optimize --- sky/core.py | 38 ++++++++++++++++++++++++++++++++++++++ sky/server/server.py | 2 +- 2 files changed, 39 insertions(+), 1 deletion(-) diff --git a/sky/core.py b/sky/core.py index 5869dcdd4ca..cc321611cf9 100644 --- a/sky/core.py +++ b/sky/core.py @@ -6,6 +6,7 @@ import colorama +from sky import admin_policy from sky import backends from sky import check as sky_check from sky import clouds @@ -14,6 +15,7 @@ from sky import exceptions from sky import global_user_state from sky import models +from sky import optimizer from sky import sky_logging from sky import task from sky.backends import backend_utils @@ -25,6 +27,7 @@ from sky.skylet import job_lib from sky.skylet import log_lib from sky.usage import usage_lib +from sky.utils import admin_policy_utils from sky.utils import common from sky.utils import common_utils from sky.utils import controller_utils @@ -44,6 +47,41 @@ # ====================== +@usage_lib.entrypoint +def optimize(dag: 'dag.Dag', + minimize: common.OptimizeTarget = common.OptimizeTarget.COST, + blocked_resources: Optional[List['resources_lib.Resources']] = None, + quiet: bool = False, + request_options: admin_policy.RequestOptions = None) -> 'dag.Dag': + """Finds the best execution plan for the given DAG. + + Args: + dag: the DAG to optimize. + minimize: whether to minimize cost or time. + blocked_resources: a list of resources that should not be used. + quiet: whether to suppress logging. + + Returns: + The optimized DAG. + + Raises: + exceptions.ResourcesUnavailableError: if no resources are available + for a task. + exceptions.NoCloudAccessError: if no public clouds are enabled. + """ + dag, _ = admin_policy_utils.apply( + dag, + use_mutated_config_in_current_request=True, + request_options=request_options + ) + return optimizer.Optimizer.optimize( + dag=dag, + minimize=minimize, + blocked_resources=blocked_resources, + quiet=quiet + ) + + @usage_lib.entrypoint def status( cluster_names: Optional[Union[str, List[str]]] = None, diff --git a/sky/server/server.py b/sky/server/server.py index 32b57918e8d..2bf20c662af 100644 --- a/sky/server/server.py +++ b/sky/server/server.py @@ -296,7 +296,7 @@ async def optimize(optimize_body: payloads.OptimizeBody, request_name='optimize', request_body=optimize_body, ignore_return_value=True, - func=optimizer.Optimizer.optimize, + func=core.optimize, schedule_type=requests_lib.ScheduleType.SHORT, ) From 6423aad056a2d9f3e6caa831ac27516b7bd824e8 Mon Sep 17 00:00:00 2001 From: Romil Bhardwaj Date: Tue, 25 Feb 2025 17:57:45 -0800 Subject: [PATCH 04/21] docs --- sky/core.py | 1 + 1 file changed, 1 insertion(+) diff --git a/sky/core.py b/sky/core.py index cc321611cf9..cc137da8dcb 100644 --- a/sky/core.py +++ b/sky/core.py @@ -60,6 +60,7 @@ def optimize(dag: 'dag.Dag', minimize: whether to minimize cost or time. blocked_resources: a list of resources that should not be used. quiet: whether to suppress logging. + request_options: Request options used in enforcing admin policies. Returns: The optimized DAG. From 5010494f798bd68f5af01502f822157d174fff1e Mon Sep 17 00:00:00 2001 From: Romil Bhardwaj Date: Tue, 25 Feb 2025 18:00:08 -0800 Subject: [PATCH 05/21] imports --- sky/core.py | 8 ++++---- 1 file changed, 4 insertions(+), 4 deletions(-) diff --git a/sky/core.py b/sky/core.py index cc137da8dcb..99472f82bb3 100644 --- a/sky/core.py +++ b/sky/core.py @@ -10,7 +10,7 @@ from sky import backends from sky import check as sky_check from sky import clouds -from sky import dag +from sky import dag as dag_lib from sky import data from sky import exceptions from sky import global_user_state @@ -48,11 +48,11 @@ @usage_lib.entrypoint -def optimize(dag: 'dag.Dag', +def optimize(dag: 'dag_lib.Dag', minimize: common.OptimizeTarget = common.OptimizeTarget.COST, blocked_resources: Optional[List['resources_lib.Resources']] = None, quiet: bool = False, - request_options: admin_policy.RequestOptions = None) -> 'dag.Dag': + request_options: admin_policy.RequestOptions = None) -> 'dag_lib.Dag': """Finds the best execution plan for the given DAG. Args: @@ -364,7 +364,7 @@ def _start( usage_lib.record_cluster_name_for_current_operation(cluster_name) - with dag.Dag(): + with dag_lib.Dag(): dummy_task = task.Task().set_resources(handle.launched_resources) dummy_task.num_nodes = handle.launched_nodes handle = backend.provision(dummy_task, From 0b656170c2649048a25b691006b74a55d72beb82 Mon Sep 17 00:00:00 2001 From: Romil Bhardwaj Date: Tue, 25 Feb 2025 18:09:33 -0800 Subject: [PATCH 06/21] Move dag validation to core --- sky/core.py | 32 ++++++++++++++++++++++++++++++-- sky/server/server.py | 36 ++++++++++++------------------------ 2 files changed, 42 insertions(+), 26 deletions(-) diff --git a/sky/core.py b/sky/core.py index 99472f82bb3..e7c8a29c6c6 100644 --- a/sky/core.py +++ b/sky/core.py @@ -17,7 +17,7 @@ from sky import models from sky import optimizer from sky import sky_logging -from sky import task +from sky import task as task_lib from sky.backends import backend_utils from sky.clouds import service_catalog from sky.jobs.server import core as managed_jobs_core @@ -83,6 +83,34 @@ def optimize(dag: 'dag_lib.Dag', ) +@usage_lib.entrypoint +def validate_dag(dag: 'dag_lib.Dag', + request_options: admin_policy.RequestOptions = None) -> None: + """Validates the specified DAG. + + Args: + dag: The DAG to validate. + request_options: Request options used in enforcing admin policies. + + Raises: + ValueError: if the DAG is invalid. + """ + dag, _ = admin_policy_utils.apply( + dag, + use_mutated_config_in_current_request=True, + request_options=request_options + ) + + for task in dag.tasks: + # Will validate workdir and file_mounts in the backend, as those + # need to be validated after the files are uploaded to the SkyPilot + # API server with `upload_mounts_to_api_server`. + task.validate_name() + task.validate_run() + for r in task.resources: + r.validate() + + @usage_lib.entrypoint def status( cluster_names: Optional[Union[str, List[str]]] = None, @@ -365,7 +393,7 @@ def _start( usage_lib.record_cluster_name_for_current_operation(cluster_name) with dag_lib.Dag(): - dummy_task = task.Task().set_resources(handle.launched_resources) + dummy_task = task_lib.Task().set_resources(handle.launched_resources) dummy_task.num_nodes = handle.launched_nodes handle = backend.provision(dummy_task, to_provision=handle.launched_resources, diff --git a/sky/server/server.py b/sky/server/server.py index 2bf20c662af..d66b7c87b04 100644 --- a/sky/server/server.py +++ b/sky/server/server.py @@ -256,35 +256,23 @@ async def list_accelerator_counts( @app.post('/validate') -async def validate(validate_body: payloads.ValidateBody) -> None: +async def validate(request: fastapi.Request, + validate_body: payloads.ValidateBody) -> None: """Validates the user's DAG.""" # TODO(SKY-1035): validate if existing cluster satisfies the requested # resources, e.g. sky exec --gpus V100:8 existing-cluster-with-no-gpus + + # We make the validation a separate request as it may require expensive + # network calls if an admin policy is applied. logger.debug(f'Validating tasks: {validate_body.dag}') - try: - dag = dag_utils.load_chain_dag_from_yaml_str(validate_body.dag) - - # Apply admin policy since it may affect DAG validation. - # TODO: The admin policy may be a potentially expensive operation with - # network calls. Maybe this should be moved into a request run with the - # executor. - dag, _ = admin_policy_utils.apply( - dag, - use_mutated_config_in_current_request=True, - request_options=validate_body.request_options + executor.schedule_request( + request_id=request.state.request_id, + request_name='validate', + request_body=validate_body, + ignore_return_value=True, + func=core.validate_dag, + schedule_type=requests_lib.ScheduleType.SHORT ) - - for task in dag.tasks: - # Will validate workdir and file_mounts in the backend, as those - # need to be validated after the files are uploaded to the SkyPilot - # API server with `upload_mounts_to_api_server`. - task.validate_name() - task.validate_run() - for r in task.resources: - r.validate() - except Exception as e: # pylint: disable=broad-except - raise fastapi.HTTPException( - status_code=400, detail=exceptions.serialize_exception(e)) from e @app.post('/optimize') From d37644baab62c1168074b77b625ac181181c2c8e Mon Sep 17 00:00:00 2001 From: Romil Bhardwaj Date: Tue, 25 Feb 2025 18:23:16 -0800 Subject: [PATCH 07/21] Fixes --- sky/client/sdk.py | 8 ++++---- sky/jobs/client/sdk.py | 1 + sky/serve/client/sdk.py | 8 ++++---- sky/server/server.py | 5 ----- 4 files changed, 9 insertions(+), 13 deletions(-) diff --git a/sky/client/sdk.py b/sky/client/sdk.py index 94647ef70a0..2fff9b23c61 100644 --- a/sky/client/sdk.py +++ b/sky/client/sdk.py @@ -25,7 +25,8 @@ import psutil import requests -from sky import backends, admin_policy +from sky import admin_policy +from sky import backends from sky import exceptions from sky import sky_logging from sky import skypilot_config @@ -37,7 +38,6 @@ from sky.skylet import constants from sky.usage import usage_lib from sky.utils import annotations -from sky.utils import admin_policy_utils from sky.utils import cluster_utils from sky.utils import common from sky.utils import common_utils @@ -224,7 +224,7 @@ def optimize( Args: dag: the DAG to optimize. minimize: whether to minimize cost or time. - cluster_name: name of the cluster to create/reuse. Used for admin policy validation. + cluster_name: name of the cluster. Used for admin policy validation. idle_minutes_to_autostop: autostop setting. Used for admin policy validation. down: whether to tear down the cluster. Used for admin policy validation. dryrun: whether this is a dryrun. Used for admin policy validation. @@ -277,7 +277,7 @@ def validate(dag: 'sky.Dag', dag: the DAG to validate. workdir_only: whether to only validate the workdir. This is used for `exec` as it does not need other files/folders in file_mounts. - cluster_name: name of the cluster to create/reuse. Used for admin policy validation. + cluster_name: name of the cluster. Used for admin policy validation. idle_minutes_to_autostop: autostop setting. Used for admin policy validation. down: whether to tear down the cluster. Used for admin policy validation. dryrun: whether this is a dryrun. Used for admin policy validation. diff --git a/sky/jobs/client/sdk.py b/sky/jobs/client/sdk.py index 05fefde4a22..7ae57a86182 100644 --- a/sky/jobs/client/sdk.py +++ b/sky/jobs/client/sdk.py @@ -7,6 +7,7 @@ import click import requests +from examples.airflow.workflow_clientserver.pythonapi_test import cluster_name from sky import sky_logging from sky.client import common as client_common from sky.client import sdk diff --git a/sky/serve/client/sdk.py b/sky/serve/client/sdk.py index 9173d3de3ee..589814d5865 100644 --- a/sky/serve/client/sdk.py +++ b/sky/serve/client/sdk.py @@ -51,8 +51,8 @@ def up( from sky.client import sdk # pylint: disable=import-outside-toplevel dag = dag_utils.convert_entrypoint_to_dag(task) - sdk.validate(dag, workdir_only=False, cluster_name=service_name) - request_id = sdk.optimize(dag, cluster_name=service_name) + sdk.validate(dag) + request_id = sdk.optimize(dag) sdk.stream_and_get(request_id) if _need_confirmation: prompt = f'Launching a new service {service_name!r}. Proceed?' @@ -107,8 +107,8 @@ def update( from sky.client import sdk # pylint: disable=import-outside-toplevel dag = dag_utils.convert_entrypoint_to_dag(task) - sdk.validate(dag, workdir_only=False, cluster_name=service_name) - request_id = sdk.optimize(dag, cluster_name=service_name) + sdk.validate(dag) + request_id = sdk.optimize(dag) sdk.stream_and_get(request_id) if _need_confirmation: click.confirm(f'Updating service {service_name!r}. Proceed?', diff --git a/sky/server/server.py b/sky/server/server.py index d66b7c87b04..b89a08fcc98 100644 --- a/sky/server/server.py +++ b/sky/server/server.py @@ -21,14 +21,11 @@ import starlette.middleware.base import sky -from sky import admin_policy from sky import check as sky_check from sky import clouds from sky import core -from sky import exceptions from sky import execution from sky import global_user_state -from sky import optimizer from sky import sky_logging from sky.clouds import service_catalog from sky.data import storage_utils @@ -45,9 +42,7 @@ from sky.usage import usage_lib from sky.utils import common as common_lib from sky.utils import common_utils -from sky.utils import dag_utils from sky.utils import status_lib -from sky.utils import admin_policy_utils # pylint: disable=ungrouped-imports if sys.version_info >= (3, 10): From c5164dc9dfff0dab43a7c3f3c1a5916f2b1b92fc Mon Sep 17 00:00:00 2001 From: Romil Bhardwaj Date: Tue, 25 Feb 2025 18:35:28 -0800 Subject: [PATCH 08/21] lint --- sky/client/sdk.py | 58 ++++++++++++++++++++++-------------------- sky/core.py | 35 +++++++++++++------------ sky/jobs/client/sdk.py | 1 - sky/server/server.py | 16 +++++------- 4 files changed, 54 insertions(+), 56 deletions(-) diff --git a/sky/client/sdk.py b/sky/client/sdk.py index 2fff9b23c61..55fca1d2a70 100644 --- a/sky/client/sdk.py +++ b/sky/client/sdk.py @@ -216,7 +216,7 @@ def optimize( minimize: common.OptimizeTarget = common.OptimizeTarget.COST, cluster_name: Optional[str] = None, idle_minutes_to_autostop: Optional[int] = None, - down: bool = False, + down: bool = False, # pylint: disable=redefined-outer-name dryrun: bool = False ) -> server_common.RequestId: """Finds the best execution plan for the given DAG. @@ -225,8 +225,10 @@ def optimize( dag: the DAG to optimize. minimize: whether to minimize cost or time. cluster_name: name of the cluster. Used for admin policy validation. - idle_minutes_to_autostop: autostop setting. Used for admin policy validation. - down: whether to tear down the cluster. Used for admin policy validation. + idle_minutes_to_autostop: autostop setting. Used for admin policy + validation. + down: whether to tear down the cluster. Used for admin policy + validation. dryrun: whether this is a dryrun. Used for admin policy validation. Returns: @@ -243,15 +245,13 @@ def optimize( dag_str = dag_utils.dump_chain_dag_to_yaml_str(dag) body = payloads.OptimizeBody( - dag=dag_str, + dag=dag_str, minimize=minimize, request_options=admin_policy.RequestOptions( cluster_name=cluster_name, idle_minutes_to_autostop=idle_minutes_to_autostop, down=down, - dryrun=dryrun - ) - ) + dryrun=dryrun)) response = requests.post(f'{server_common.get_server_url()}/optimize', json=json.loads(body.model_dump_json())) return server_common.get_request_id(response) @@ -260,12 +260,13 @@ def optimize( @usage_lib.entrypoint @server_common.check_server_healthy_or_start @annotations.client_api -def validate(dag: 'sky.Dag', - workdir_only: bool = False, - cluster_name: Optional[str] = None, - idle_minutes_to_autostop: Optional[int] = None, - down: bool = False, - dryrun: bool = False) -> None: +def validate( + dag: 'sky.Dag', + workdir_only: bool = False, + cluster_name: Optional[str] = None, + idle_minutes_to_autostop: Optional[int] = None, + down: bool = False, # pylint: disable=redefined-outer-name + dryrun: bool = False) -> None: """Validates the tasks. The file paths (workdir and file_mounts) are validated on the client side @@ -278,8 +279,10 @@ def validate(dag: 'sky.Dag', workdir_only: whether to only validate the workdir. This is used for `exec` as it does not need other files/folders in file_mounts. cluster_name: name of the cluster. Used for admin policy validation. - idle_minutes_to_autostop: autostop setting. Used for admin policy validation. - down: whether to tear down the cluster. Used for admin policy validation. + idle_minutes_to_autostop: autostop setting. Used for admin policy + validation. + down: whether to tear down the cluster. Used for admin policy + validation. dryrun: whether this is a dryrun. Used for admin policy validation. """ for task in dag.tasks: @@ -290,11 +293,10 @@ def validate(dag: 'sky.Dag', body = payloads.ValidateBody( dag=dag_str, request_options=admin_policy.RequestOptions( - cluster_name=cluster_name, - idle_minutes_to_autostop=idle_minutes_to_autostop, - down=down, - dryrun=dryrun) - ) + cluster_name=cluster_name, + idle_minutes_to_autostop=idle_minutes_to_autostop, + down=down, + dryrun=dryrun)) response = requests.post(f'{server_common.get_server_url()}/validate', json=json.loads(body.model_dump_json())) if response.status_code == 400: @@ -420,10 +422,10 @@ def launch( 'Please contact the SkyPilot team if you ' 'need this feature at slack.skypilot.co.') dag = dag_utils.convert_entrypoint_to_dag(task) - validate(dag, - cluster_name=cluster_name, - idle_minutes_to_autostop=idle_minutes_to_autostop, - down=down, + validate(dag, + cluster_name=cluster_name, + idle_minutes_to_autostop=idle_minutes_to_autostop, + down=down, dryrun=dryrun) confirm_shown = False @@ -574,10 +576,10 @@ def exec( # pylint: disable=redefined-builtin controller that does not support this operation. """ dag = dag_utils.convert_entrypoint_to_dag(task) - validate(dag, - workdir_only=True, - cluster_name=cluster_name, - dryrun=dryrun, + validate(dag, + workdir_only=True, + cluster_name=cluster_name, + dryrun=dryrun, down=down) dag = client_common.upload_mounts_to_api_server(dag, workdir_only=True) dag_str = dag_utils.dump_chain_dag_to_yaml_str(dag) diff --git a/sky/core.py b/sky/core.py index e7c8a29c6c6..33a4c9c5513 100644 --- a/sky/core.py +++ b/sky/core.py @@ -48,11 +48,13 @@ @usage_lib.entrypoint -def optimize(dag: 'dag_lib.Dag', - minimize: common.OptimizeTarget = common.OptimizeTarget.COST, - blocked_resources: Optional[List['resources_lib.Resources']] = None, - quiet: bool = False, - request_options: admin_policy.RequestOptions = None) -> 'dag_lib.Dag': +def optimize( + dag: 'dag_lib.Dag', + minimize: common.OptimizeTarget = common.OptimizeTarget.COST, + blocked_resources: Optional[List['resources_lib.Resources']] = None, + quiet: bool = False, + request_options: Optional[admin_policy.RequestOptions] = None +) -> 'dag_lib.Dag': """Finds the best execution plan for the given DAG. Args: @@ -73,19 +75,17 @@ def optimize(dag: 'dag_lib.Dag', dag, _ = admin_policy_utils.apply( dag, use_mutated_config_in_current_request=True, - request_options=request_options - ) - return optimizer.Optimizer.optimize( - dag=dag, - minimize=minimize, - blocked_resources=blocked_resources, - quiet=quiet - ) + request_options=request_options) + return optimizer.Optimizer.optimize(dag=dag, + minimize=minimize, + blocked_resources=blocked_resources, + quiet=quiet) @usage_lib.entrypoint -def validate_dag(dag: 'dag_lib.Dag', - request_options: admin_policy.RequestOptions = None) -> None: +def validate_dag( + dag: 'dag_lib.Dag', + request_options: Optional[admin_policy.RequestOptions] = None) -> None: """Validates the specified DAG. Args: @@ -98,9 +98,8 @@ def validate_dag(dag: 'dag_lib.Dag', dag, _ = admin_policy_utils.apply( dag, use_mutated_config_in_current_request=True, - request_options=request_options - ) - + request_options=request_options) + for task in dag.tasks: # Will validate workdir and file_mounts in the backend, as those # need to be validated after the files are uploaded to the SkyPilot diff --git a/sky/jobs/client/sdk.py b/sky/jobs/client/sdk.py index 7ae57a86182..05fefde4a22 100644 --- a/sky/jobs/client/sdk.py +++ b/sky/jobs/client/sdk.py @@ -7,7 +7,6 @@ import click import requests -from examples.airflow.workflow_clientserver.pythonapi_test import cluster_name from sky import sky_logging from sky.client import common as client_common from sky.client import sdk diff --git a/sky/server/server.py b/sky/server/server.py index b89a08fcc98..f87e372d055 100644 --- a/sky/server/server.py +++ b/sky/server/server.py @@ -252,7 +252,7 @@ async def list_accelerator_counts( @app.post('/validate') async def validate(request: fastapi.Request, - validate_body: payloads.ValidateBody) -> None: + validate_body: payloads.ValidateBody) -> None: """Validates the user's DAG.""" # TODO(SKY-1035): validate if existing cluster satisfies the requested # resources, e.g. sky exec --gpus V100:8 existing-cluster-with-no-gpus @@ -260,14 +260,12 @@ async def validate(request: fastapi.Request, # We make the validation a separate request as it may require expensive # network calls if an admin policy is applied. logger.debug(f'Validating tasks: {validate_body.dag}') - executor.schedule_request( - request_id=request.state.request_id, - request_name='validate', - request_body=validate_body, - ignore_return_value=True, - func=core.validate_dag, - schedule_type=requests_lib.ScheduleType.SHORT - ) + executor.schedule_request(request_id=request.state.request_id, + request_name='validate', + request_body=validate_body, + ignore_return_value=True, + func=core.validate_dag, + schedule_type=requests_lib.ScheduleType.SHORT) @app.post('/optimize') From 41f4d26fdc0067f50384d9c13c0460b656621bf5 Mon Sep 17 00:00:00 2001 From: Romil Bhardwaj Date: Tue, 25 Feb 2025 18:50:22 -0800 Subject: [PATCH 09/21] Add comments --- sky/core.py | 4 ++++ sky/server/server.py | 5 +++++ 2 files changed, 9 insertions(+) diff --git a/sky/core.py b/sky/core.py index 33a4c9c5513..7387651b947 100644 --- a/sky/core.py +++ b/sky/core.py @@ -72,6 +72,10 @@ def optimize( for a task. exceptions.NoCloudAccessError: if no public clouds are enabled. """ + # TODO: We apply the admin policy only on the first DAG optimization which + # is shown on `sky launch`. The optimizer is also invoked during failover, + # but we do not apply the admin policy there. We should apply the admin + # policy in the optimizer, but that will require some refactoring. dag, _ = admin_policy_utils.apply( dag, use_mutated_config_in_current_request=True, diff --git a/sky/server/server.py b/sky/server/server.py index f87e372d055..caa48e95087 100644 --- a/sky/server/server.py +++ b/sky/server/server.py @@ -259,6 +259,11 @@ async def validate(request: fastapi.Request, # We make the validation a separate request as it may require expensive # network calls if an admin policy is applied. + # TODO: Our current launch process is broken down into three calls: + # validate, optimize, and launch. This requires us to apply the admin policy + # in each step, which may be an expensive operation. We should consolidate + # these into a single call or have a TTL cache for (task, admin_policy) + # pairs. logger.debug(f'Validating tasks: {validate_body.dag}') executor.schedule_request(request_id=request.state.request_id, request_name='validate', From aa335341fe186035e0d1dd2a1d9eecd122b51621 Mon Sep 17 00:00:00 2001 From: Romil Bhardwaj Date: Tue, 25 Feb 2025 18:57:16 -0800 Subject: [PATCH 10/21] lint --- sky/core.py | 6 +++--- sky/server/server.py | 4 ++-- 2 files changed, 5 insertions(+), 5 deletions(-) diff --git a/sky/core.py b/sky/core.py index 7387651b947..4ef91fae29b 100644 --- a/sky/core.py +++ b/sky/core.py @@ -72,9 +72,9 @@ def optimize( for a task. exceptions.NoCloudAccessError: if no public clouds are enabled. """ - # TODO: We apply the admin policy only on the first DAG optimization which - # is shown on `sky launch`. The optimizer is also invoked during failover, - # but we do not apply the admin policy there. We should apply the admin + # TODO: We apply the admin policy only on the first DAG optimization which + # is shown on `sky launch`. The optimizer is also invoked during failover, + # but we do not apply the admin policy there. We should apply the admin # policy in the optimizer, but that will require some refactoring. dag, _ = admin_policy_utils.apply( dag, diff --git a/sky/server/server.py b/sky/server/server.py index caa48e95087..33e56dd4422 100644 --- a/sky/server/server.py +++ b/sky/server/server.py @@ -259,10 +259,10 @@ async def validate(request: fastapi.Request, # We make the validation a separate request as it may require expensive # network calls if an admin policy is applied. - # TODO: Our current launch process is broken down into three calls: + # TODO: Our current launch process is broken down into three calls: # validate, optimize, and launch. This requires us to apply the admin policy # in each step, which may be an expensive operation. We should consolidate - # these into a single call or have a TTL cache for (task, admin_policy) + # these into a single call or have a TTL cache for (task, admin_policy) # pairs. logger.debug(f'Validating tasks: {validate_body.dag}') executor.schedule_request(request_id=request.state.request_id, From 20828be853c2a15a33993d7f22efcf86e5241d55 Mon Sep 17 00:00:00 2001 From: Romil Bhardwaj Date: Wed, 26 Feb 2025 15:32:14 -0800 Subject: [PATCH 11/21] Fixed executor based validate implementation --- sky/client/sdk.py | 59 ++++++++++++--------------------- sky/server/requests/payloads.py | 26 +++++++++------ 2 files changed, 36 insertions(+), 49 deletions(-) diff --git a/sky/client/sdk.py b/sky/client/sdk.py index 55fca1d2a70..86708bca5f5 100644 --- a/sky/client/sdk.py +++ b/sky/client/sdk.py @@ -214,22 +214,15 @@ def list_accelerator_counts( def optimize( dag: 'sky.Dag', minimize: common.OptimizeTarget = common.OptimizeTarget.COST, - cluster_name: Optional[str] = None, - idle_minutes_to_autostop: Optional[int] = None, - down: bool = False, # pylint: disable=redefined-outer-name - dryrun: bool = False + admin_policy_request_options: Optional[admin_policy.RequestOptions] = None ) -> server_common.RequestId: """Finds the best execution plan for the given DAG. Args: dag: the DAG to optimize. minimize: whether to minimize cost or time. - cluster_name: name of the cluster. Used for admin policy validation. - idle_minutes_to_autostop: autostop setting. Used for admin policy - validation. - down: whether to tear down the cluster. Used for admin policy + admin_policy_request_options: Request options used for admin policy validation. - dryrun: whether this is a dryrun. Used for admin policy validation. Returns: The request ID of the optimize request. @@ -247,11 +240,7 @@ def optimize( body = payloads.OptimizeBody( dag=dag_str, minimize=minimize, - request_options=admin_policy.RequestOptions( - cluster_name=cluster_name, - idle_minutes_to_autostop=idle_minutes_to_autostop, - down=down, - dryrun=dryrun)) + request_options=admin_policy_request_options) response = requests.post(f'{server_common.get_server_url()}/optimize', json=json.loads(body.model_dump_json())) return server_common.get_request_id(response) @@ -263,10 +252,7 @@ def optimize( def validate( dag: 'sky.Dag', workdir_only: bool = False, - cluster_name: Optional[str] = None, - idle_minutes_to_autostop: Optional[int] = None, - down: bool = False, # pylint: disable=redefined-outer-name - dryrun: bool = False) -> None: + admin_policy_request_options: Optional[admin_policy.RequestOptions] = None) -> None: """Validates the tasks. The file paths (workdir and file_mounts) are validated on the client side @@ -278,12 +264,8 @@ def validate( dag: the DAG to validate. workdir_only: whether to only validate the workdir. This is used for `exec` as it does not need other files/folders in file_mounts. - cluster_name: name of the cluster. Used for admin policy validation. - idle_minutes_to_autostop: autostop setting. Used for admin policy + admin_policy_request_options: Request options used for admin policy validation. - down: whether to tear down the cluster. Used for admin policy - validation. - dryrun: whether this is a dryrun. Used for admin policy validation. """ for task in dag.tasks: task.expand_and_validate_workdir() @@ -292,17 +274,15 @@ def validate( dag_str = dag_utils.dump_chain_dag_to_yaml_str(dag) body = payloads.ValidateBody( dag=dag_str, - request_options=admin_policy.RequestOptions( - cluster_name=cluster_name, - idle_minutes_to_autostop=idle_minutes_to_autostop, - down=down, - dryrun=dryrun)) + request_options=admin_policy_request_options) response = requests.post(f'{server_common.get_server_url()}/validate', json=json.loads(body.model_dump_json())) if response.status_code == 400: with ux_utils.print_exception_no_traceback(): raise exceptions.deserialize_exception( response.json().get('detail')) + request_id = server_common.get_request_id(response) + get(request_id) # Make sure the request completed with no exceptions @usage_lib.entrypoint @@ -422,11 +402,13 @@ def launch( 'Please contact the SkyPilot team if you ' 'need this feature at slack.skypilot.co.') dag = dag_utils.convert_entrypoint_to_dag(task) - validate(dag, - cluster_name=cluster_name, - idle_minutes_to_autostop=idle_minutes_to_autostop, - down=down, - dryrun=dryrun) + request_options = admin_policy.RequestOptions( + cluster_name=cluster_name, + idle_minutes_to_autostop=idle_minutes_to_autostop, + down=down, + dryrun=dryrun + ) + validate(dag, admin_policy_request_options=request_options) confirm_shown = False if _need_confirmation: @@ -576,11 +558,12 @@ def exec( # pylint: disable=redefined-builtin controller that does not support this operation. """ dag = dag_utils.convert_entrypoint_to_dag(task) - validate(dag, - workdir_only=True, - cluster_name=cluster_name, - dryrun=dryrun, - down=down) + request_options = admin_policy.RequestOptions( + cluster_name=cluster_name, + down=down, + dryrun=dryrun + ) + validate(dag, workdir_only=True, admin_policy_request_options=request_options) dag = client_common.upload_mounts_to_api_server(dag, workdir_only=True) dag_str = dag_utils.dump_chain_dag_to_yaml_str(dag) body = payloads.ExecBody( diff --git a/sky/server/requests/payloads.py b/sky/server/requests/payloads.py index 81b25118f35..fe438e22bcf 100644 --- a/sky/server/requests/payloads.py +++ b/sky/server/requests/payloads.py @@ -113,18 +113,9 @@ class CheckBody(RequestBody): clouds: Optional[Tuple[str, ...]] verbose: bool - -class ValidateBody(RequestBody): - """The request body for the validate endpoint.""" - dag: str - request_options: admin_policy.RequestOptions - - -class OptimizeBody(RequestBody): - """The request body for the optimize endpoint.""" +class DagRequestBody(RequestBody): + """Request body base class for endpoints with a dag.""" dag: str - minimize: common_lib.OptimizeTarget = common_lib.OptimizeTarget.COST - request_options: admin_policy.RequestOptions def to_kwargs(self) -> Dict[str, Any]: # Import here to avoid requirement of the whole SkyPilot dependency on @@ -142,6 +133,19 @@ def to_kwargs(self) -> Dict[str, Any]: return kwargs +class ValidateBody(DagRequestBody): + """The request body for the validate endpoint.""" + dag: str + request_options: Optional[admin_policy.RequestOptions] + + +class OptimizeBody(DagRequestBody): + """The request body for the optimize endpoint.""" + dag: str + minimize: common_lib.OptimizeTarget = common_lib.OptimizeTarget.COST + request_options: Optional[admin_policy.RequestOptions] + + class LaunchBody(RequestBody): """The request body for the launch endpoint.""" task: str From 636232856e35f8e0a45cb0b75e322e176036ec64 Mon Sep 17 00:00:00 2001 From: Romil Bhardwaj Date: Wed, 26 Feb 2025 15:53:31 -0800 Subject: [PATCH 12/21] Revert executor based validate implementation --- sky/client/sdk.py | 2 -- sky/server/server.py | 24 ++++++++++++++++++------ 2 files changed, 18 insertions(+), 8 deletions(-) diff --git a/sky/client/sdk.py b/sky/client/sdk.py index 86708bca5f5..d6b2783fe41 100644 --- a/sky/client/sdk.py +++ b/sky/client/sdk.py @@ -281,8 +281,6 @@ def validate( with ux_utils.print_exception_no_traceback(): raise exceptions.deserialize_exception( response.json().get('detail')) - request_id = server_common.get_request_id(response) - get(request_id) # Make sure the request completed with no exceptions @usage_lib.entrypoint diff --git a/sky/server/server.py b/sky/server/server.py index 33e56dd4422..0aef832b031 100644 --- a/sky/server/server.py +++ b/sky/server/server.py @@ -24,9 +24,11 @@ from sky import check as sky_check from sky import clouds from sky import core +from sky import exceptions from sky import execution from sky import global_user_state from sky import sky_logging +from sky import skypilot_config from sky.clouds import service_catalog from sky.data import storage_utils from sky.jobs.server import server as jobs_rest @@ -42,7 +44,9 @@ from sky.usage import usage_lib from sky.utils import common as common_lib from sky.utils import common_utils +from sky.utils import dag_utils from sky.utils import status_lib +from sky.utils import admin_policy_utils # pylint: disable=ungrouped-imports if sys.version_info >= (3, 10): @@ -265,12 +269,20 @@ async def validate(request: fastapi.Request, # these into a single call or have a TTL cache for (task, admin_policy) # pairs. logger.debug(f'Validating tasks: {validate_body.dag}') - executor.schedule_request(request_id=request.state.request_id, - request_name='validate', - request_body=validate_body, - ignore_return_value=True, - func=core.validate_dag, - schedule_type=requests_lib.ScheduleType.SHORT) + try: + dag = dag_utils.load_chain_dag_from_yaml_str(validate_body.dag) + dag, _ = admin_policy_utils.apply(dag, request_options=validate_body.request_options) + for task in dag.tasks: + # Will validate workdir and file_mounts in the backend, as those + # need to be validated after the files are uploaded to the SkyPilot + # API server with `upload_mounts_to_api_server`. + task.validate_name() + task.validate_run() + for r in task.resources: + r.validate() + except Exception as e: # pylint: disable=broad-except + raise fastapi.HTTPException( + status_code=400, detail=exceptions.serialize_exception(e)) from e @app.post('/optimize') From 6f94e3b5e22dddce1896bb63438ee82484605469 Mon Sep 17 00:00:00 2001 From: Romil Bhardwaj Date: Wed, 26 Feb 2025 15:54:49 -0800 Subject: [PATCH 13/21] lint --- sky/client/sdk.py | 35 ++++++++++++++++----------------- sky/server/requests/payloads.py | 1 + sky/server/server.py | 5 +++-- 3 files changed, 21 insertions(+), 20 deletions(-) diff --git a/sky/client/sdk.py b/sky/client/sdk.py index d6b2783fe41..1fe1913f965 100644 --- a/sky/client/sdk.py +++ b/sky/client/sdk.py @@ -237,10 +237,9 @@ def optimize( """ dag_str = dag_utils.dump_chain_dag_to_yaml_str(dag) - body = payloads.OptimizeBody( - dag=dag_str, - minimize=minimize, - request_options=admin_policy_request_options) + body = payloads.OptimizeBody(dag=dag_str, + minimize=minimize, + request_options=admin_policy_request_options) response = requests.post(f'{server_common.get_server_url()}/optimize', json=json.loads(body.model_dump_json())) return server_common.get_request_id(response) @@ -250,9 +249,10 @@ def optimize( @server_common.check_server_healthy_or_start @annotations.client_api def validate( - dag: 'sky.Dag', - workdir_only: bool = False, - admin_policy_request_options: Optional[admin_policy.RequestOptions] = None) -> None: + dag: 'sky.Dag', + workdir_only: bool = False, + admin_policy_request_options: Optional[admin_policy.RequestOptions] = None +) -> None: """Validates the tasks. The file paths (workdir and file_mounts) are validated on the client side @@ -272,9 +272,8 @@ def validate( if not workdir_only: task.expand_and_validate_file_mounts() dag_str = dag_utils.dump_chain_dag_to_yaml_str(dag) - body = payloads.ValidateBody( - dag=dag_str, - request_options=admin_policy_request_options) + body = payloads.ValidateBody(dag=dag_str, + request_options=admin_policy_request_options) response = requests.post(f'{server_common.get_server_url()}/validate', json=json.loads(body.model_dump_json())) if response.status_code == 400: @@ -404,8 +403,7 @@ def launch( cluster_name=cluster_name, idle_minutes_to_autostop=idle_minutes_to_autostop, down=down, - dryrun=dryrun - ) + dryrun=dryrun) validate(dag, admin_policy_request_options=request_options) confirm_shown = False @@ -556,12 +554,13 @@ def exec( # pylint: disable=redefined-builtin controller that does not support this operation. """ dag = dag_utils.convert_entrypoint_to_dag(task) - request_options = admin_policy.RequestOptions( - cluster_name=cluster_name, - down=down, - dryrun=dryrun - ) - validate(dag, workdir_only=True, admin_policy_request_options=request_options) + request_options = admin_policy.RequestOptions(cluster_name=cluster_name, + idle_minutes_to_autostop=None, + down=down, + dryrun=dryrun) + validate(dag, + workdir_only=True, + admin_policy_request_options=request_options) dag = client_common.upload_mounts_to_api_server(dag, workdir_only=True) dag_str = dag_utils.dump_chain_dag_to_yaml_str(dag) body = payloads.ExecBody( diff --git a/sky/server/requests/payloads.py b/sky/server/requests/payloads.py index fe438e22bcf..efb0311439f 100644 --- a/sky/server/requests/payloads.py +++ b/sky/server/requests/payloads.py @@ -113,6 +113,7 @@ class CheckBody(RequestBody): clouds: Optional[Tuple[str, ...]] verbose: bool + class DagRequestBody(RequestBody): """Request body base class for endpoints with a dag.""" dag: str diff --git a/sky/server/server.py b/sky/server/server.py index 0aef832b031..a1f189c996a 100644 --- a/sky/server/server.py +++ b/sky/server/server.py @@ -42,11 +42,11 @@ from sky.server.requests import requests as requests_lib from sky.skylet import constants from sky.usage import usage_lib +from sky.utils import admin_policy_utils from sky.utils import common as common_lib from sky.utils import common_utils from sky.utils import dag_utils from sky.utils import status_lib -from sky.utils import admin_policy_utils # pylint: disable=ungrouped-imports if sys.version_info >= (3, 10): @@ -271,7 +271,8 @@ async def validate(request: fastapi.Request, logger.debug(f'Validating tasks: {validate_body.dag}') try: dag = dag_utils.load_chain_dag_from_yaml_str(validate_body.dag) - dag, _ = admin_policy_utils.apply(dag, request_options=validate_body.request_options) + dag, _ = admin_policy_utils.apply( + dag, request_options=validate_body.request_options) for task in dag.tasks: # Will validate workdir and file_mounts in the backend, as those # need to be validated after the files are uploaded to the SkyPilot From 78169ceea6d784cfab045b24919595f688101f95 Mon Sep 17 00:00:00 2001 From: Romil Bhardwaj Date: Wed, 26 Feb 2025 15:57:44 -0800 Subject: [PATCH 14/21] lint --- sky/server/server.py | 11 ++++++----- 1 file changed, 6 insertions(+), 5 deletions(-) diff --git a/sky/server/server.py b/sky/server/server.py index a1f189c996a..7ffe8a4051f 100644 --- a/sky/server/server.py +++ b/sky/server/server.py @@ -28,7 +28,6 @@ from sky import execution from sky import global_user_state from sky import sky_logging -from sky import skypilot_config from sky.clouds import service_catalog from sky.data import storage_utils from sky.jobs.server import server as jobs_rest @@ -255,14 +254,11 @@ async def list_accelerator_counts( @app.post('/validate') -async def validate(request: fastapi.Request, - validate_body: payloads.ValidateBody) -> None: +async def validate(validate_body: payloads.ValidateBody) -> None: """Validates the user's DAG.""" # TODO(SKY-1035): validate if existing cluster satisfies the requested # resources, e.g. sky exec --gpus V100:8 existing-cluster-with-no-gpus - # We make the validation a separate request as it may require expensive - # network calls if an admin policy is applied. # TODO: Our current launch process is broken down into three calls: # validate, optimize, and launch. This requires us to apply the admin policy # in each step, which may be an expensive operation. We should consolidate @@ -271,6 +267,11 @@ async def validate(request: fastapi.Request, logger.debug(f'Validating tasks: {validate_body.dag}') try: dag = dag_utils.load_chain_dag_from_yaml_str(validate_body.dag) + # TODO: Admin policy may contain arbitrary code, which may be expensive + # to run and may block the server thread. However, moving it into the + # executor adds a ~150ms penalty on the local API server because of + # added RTTs. For now, we stick to doing the validation inline in the + # server thread. dag, _ = admin_policy_utils.apply( dag, request_options=validate_body.request_options) for task in dag.tasks: From 0d9da0de66608f81d0a8b9e94f106fb8cc4e6065 Mon Sep 17 00:00:00 2001 From: Romil Bhardwaj Date: Wed, 26 Feb 2025 16:00:15 -0800 Subject: [PATCH 15/21] Add validation during optimize --- sky/client/sdk.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/sky/client/sdk.py b/sky/client/sdk.py index 1fe1913f965..7b5f20c46fd 100644 --- a/sky/client/sdk.py +++ b/sky/client/sdk.py @@ -418,7 +418,7 @@ def launch( if not clusters: # Show the optimize log before the prompt if the cluster does not # exist. - request_id = optimize(dag) + request_id = optimize(dag, admin_policy_request_options=request_options) stream_and_get(request_id) else: cluster_record = clusters[0] From 478b2a4f3027ea235c11a19d13be48b598e66c38 Mon Sep 17 00:00:00 2001 From: Romil Bhardwaj Date: Wed, 26 Feb 2025 16:05:01 -0800 Subject: [PATCH 16/21] lint --- sky/client/sdk.py | 3 ++- 1 file changed, 2 insertions(+), 1 deletion(-) diff --git a/sky/client/sdk.py b/sky/client/sdk.py index 7b5f20c46fd..ecc6c97f027 100644 --- a/sky/client/sdk.py +++ b/sky/client/sdk.py @@ -418,7 +418,8 @@ def launch( if not clusters: # Show the optimize log before the prompt if the cluster does not # exist. - request_id = optimize(dag, admin_policy_request_options=request_options) + request_id = optimize(dag, + admin_policy_request_options=request_options) stream_and_get(request_id) else: cluster_record = clusters[0] From 37cd372f21aeb08a4dcdc3969609dd047213f9de Mon Sep 17 00:00:00 2001 From: Romil Bhardwaj Date: Wed, 26 Feb 2025 16:10:39 -0800 Subject: [PATCH 17/21] Remove validate from core --- sky/core.py | 28 ---------------------------- 1 file changed, 28 deletions(-) diff --git a/sky/core.py b/sky/core.py index 4ef91fae29b..618ed23f71d 100644 --- a/sky/core.py +++ b/sky/core.py @@ -86,34 +86,6 @@ def optimize( quiet=quiet) -@usage_lib.entrypoint -def validate_dag( - dag: 'dag_lib.Dag', - request_options: Optional[admin_policy.RequestOptions] = None) -> None: - """Validates the specified DAG. - - Args: - dag: The DAG to validate. - request_options: Request options used in enforcing admin policies. - - Raises: - ValueError: if the DAG is invalid. - """ - dag, _ = admin_policy_utils.apply( - dag, - use_mutated_config_in_current_request=True, - request_options=request_options) - - for task in dag.tasks: - # Will validate workdir and file_mounts in the backend, as those - # need to be validated after the files are uploaded to the SkyPilot - # API server with `upload_mounts_to_api_server`. - task.validate_name() - task.validate_run() - for r in task.resources: - r.validate() - - @usage_lib.entrypoint def status( cluster_names: Optional[Union[str, List[str]]] = None, From 0c7c078fdad86f9ac3b6c4569109b8e5b06370f7 Mon Sep 17 00:00:00 2001 From: Romil Bhardwaj Date: Wed, 26 Feb 2025 16:14:18 -0800 Subject: [PATCH 18/21] Remove admin policy apply when validating dag for exec --- sky/client/sdk.py | 8 +------- 1 file changed, 1 insertion(+), 7 deletions(-) diff --git a/sky/client/sdk.py b/sky/client/sdk.py index ecc6c97f027..f83b3c8f628 100644 --- a/sky/client/sdk.py +++ b/sky/client/sdk.py @@ -555,13 +555,7 @@ def exec( # pylint: disable=redefined-builtin controller that does not support this operation. """ dag = dag_utils.convert_entrypoint_to_dag(task) - request_options = admin_policy.RequestOptions(cluster_name=cluster_name, - idle_minutes_to_autostop=None, - down=down, - dryrun=dryrun) - validate(dag, - workdir_only=True, - admin_policy_request_options=request_options) + validate(dag, workdir_only=True) dag = client_common.upload_mounts_to_api_server(dag, workdir_only=True) dag_str = dag_utils.dump_chain_dag_to_yaml_str(dag) body = payloads.ExecBody( From 07f2f3d9f3dcd15e65b0d8dccb3c285d0504b2dc Mon Sep 17 00:00:00 2001 From: Romil Bhardwaj Date: Fri, 28 Feb 2025 12:13:03 -0800 Subject: [PATCH 19/21] comments --- sky/client/sdk.py | 6 ++++-- sky/core.py | 3 ++- 2 files changed, 6 insertions(+), 3 deletions(-) diff --git a/sky/client/sdk.py b/sky/client/sdk.py index f83b3c8f628..85757e783e4 100644 --- a/sky/client/sdk.py +++ b/sky/client/sdk.py @@ -222,7 +222,8 @@ def optimize( dag: the DAG to optimize. minimize: whether to minimize cost or time. admin_policy_request_options: Request options used for admin policy - validation. + validation. This is only required when a admin policy is in use, + see: https://docs.skypilot.co/en/latest/cloud-setup/policy.html Returns: The request ID of the optimize request. @@ -265,7 +266,8 @@ def validate( workdir_only: whether to only validate the workdir. This is used for `exec` as it does not need other files/folders in file_mounts. admin_policy_request_options: Request options used for admin policy - validation. + validation. This is only required when a admin policy is in use, + see: https://docs.skypilot.co/en/latest/cloud-setup/policy.html """ for task in dag.tasks: task.expand_and_validate_workdir() diff --git a/sky/core.py b/sky/core.py index 618ed23f71d..db654d19489 100644 --- a/sky/core.py +++ b/sky/core.py @@ -63,7 +63,8 @@ def optimize( blocked_resources: a list of resources that should not be used. quiet: whether to suppress logging. request_options: Request options used in enforcing admin policies. - + This is only required when a admin policy is in use, + see: https://docs.skypilot.co/en/latest/cloud-setup/policy.html Returns: The optimized DAG. From ff27bccf1a01b450857ca1c923825a7ed0654bb4 Mon Sep 17 00:00:00 2001 From: Romil Bhardwaj Date: Fri, 28 Feb 2025 12:59:52 -0800 Subject: [PATCH 20/21] Bump API version --- sky/server/constants.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/sky/server/constants.py b/sky/server/constants.py index 979d8bafbe8..304684f7b3b 100644 --- a/sky/server/constants.py +++ b/sky/server/constants.py @@ -3,7 +3,7 @@ # API server version, whenever there is a change in API server that requires a # restart of the local API server or error out when the client does not match # the server version. -API_VERSION = '2' +API_VERSION = '3' # Prefix for API request names. REQUEST_NAME_PREFIX = 'sky.' From 6691dc143bab35a0eb3261a030e406ee34be12be Mon Sep 17 00:00:00 2001 From: Romil Bhardwaj Date: Mon, 3 Mar 2025 12:23:45 -0800 Subject: [PATCH 21/21] comments --- sky/server/server.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/sky/server/server.py b/sky/server/server.py index 16d0c3a74da..51fda3de9e5 100644 --- a/sky/server/server.py +++ b/sky/server/server.py @@ -259,7 +259,7 @@ async def validate(validate_body: payloads.ValidateBody) -> None: # TODO(SKY-1035): validate if existing cluster satisfies the requested # resources, e.g. sky exec --gpus V100:8 existing-cluster-with-no-gpus - # TODO: Our current launch process is broken down into three calls: + # TODO: Our current launch process is split into three calls: # validate, optimize, and launch. This requires us to apply the admin policy # in each step, which may be an expensive operation. We should consolidate # these into a single call or have a TTL cache for (task, admin_policy)