From 231f68a054da3753e0c657d6eca2122e29803da6 Mon Sep 17 00:00:00 2001 From: Jack P Date: Mon, 27 Nov 2023 20:15:41 -0600 Subject: [PATCH] WorkerV2 - Cloud Run V2 API (#220) MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit * 👌 IMPROVE: started on cloud run v2 block, added foundation for all, building out abstract cmds now * 👌 IMPROVE: rough draft, still missing kill method, but ready to start on workerv2 * 🤖 TEST: testing * ‼️ BREAKING: removing memory string test * 👌 IMPROVE: creates successfully * 🐛 FIX: removed client * 🐛 FIX: fixed validation bug * 🐛 FIX: fix mispelling for var * 🐛 FIX: fixed and simplified job name generation * 🐛 FIX: fixed issue with observedGeneration keyerror * 🤖 TEST: test config job name as var * 🐛 FIX: made IDs use whole Prefect Name * 👌 IMPROVE: working cancelation functionality * 🤖 TEST: added cloud run v2 block unit tests * 📖 DOC: added some unit tests and docstrings * 🐛 FIX: applied some ci test fixes * 🐛 FIX: fixed some failing unit tests * 🐛 FIX: fix missing optional and remove | None * 🐛 FIX: added the Pydantic Version conditional import logic to cr2 worker and block * 🐛 FIX: PR feedback implemented, separate cloud run v2 block infra to models directory * 🐛 FIX: attempted fix on 3.8 test changing list to typing.List * 👌 IMPROVE: added get_prefect_image_name as default factory to image cloud run v2 * 🐛 FIX: fix ci test and hardcode default to prefect 2 latest image --------- Co-authored-by: nate nowack --- .gitignore | 3 + CHANGELOG.md | 2 + docs/cloud_run_worker_v2.md | 1 + mkdocs.yml | 2 + prefect_gcp/__init__.py | 1 + prefect_gcp/models/__init__.py | 0 prefect_gcp/models/cloud_run_v2.py | 393 +++++++++++++ prefect_gcp/workers/cloud_run_v2.py | 845 ++++++++++++++++++++++++++++ tests/test_cloud_run_v2.py | 191 +++++++ tests/test_cloud_run_worker_v2.py | 94 ++++ 10 files changed, 1532 insertions(+) create mode 100644 docs/cloud_run_worker_v2.md create mode 100644 prefect_gcp/models/__init__.py create mode 100644 prefect_gcp/models/cloud_run_v2.py create mode 100644 prefect_gcp/workers/cloud_run_v2.py create mode 100644 tests/test_cloud_run_v2.py create mode 100644 tests/test_cloud_run_worker_v2.py diff --git a/.gitignore b/.gitignore index b96a3be4..d3d36ecb 100644 --- a/.gitignore +++ b/.gitignore @@ -1,3 +1,6 @@ +# JetBrains +.idea + # Byte-compiled / optimized / DLL files __pycache__/ *.py[cod] diff --git a/CHANGELOG.md b/CHANGELOG.md index 75619f78..1241ff6b 100644 --- a/CHANGELOG.md +++ b/CHANGELOG.md @@ -9,6 +9,8 @@ and this project adheres to [Semantic Versioning](https://semver.org/spec/v2.0.0 ### Added +- Added `CloudRunJobV2` and `CloudRunWorkerV2` for executing Prefect flows via Google Cloud Run - [#220](https://github.com/PrefectHQ/prefect-gcp/pull/220) + ### Changed ### Deprecated diff --git a/docs/cloud_run_worker_v2.md b/docs/cloud_run_worker_v2.md new file mode 100644 index 00000000..25dc3033 --- /dev/null +++ b/docs/cloud_run_worker_v2.md @@ -0,0 +1 @@ +::: prefect_gcp.workers.cloud_run_v2 diff --git a/mkdocs.yml b/mkdocs.yml index 00e8ce04..d69a0973 100644 --- a/mkdocs.yml +++ b/mkdocs.yml @@ -84,10 +84,12 @@ nav: - BigQuery: bigquery.md - Secret Manager: secret_manager.md - Cloud Run: cloud_run.md + - Cloud Run V2: cloud_run_v2.md - AI Platform: aiplatform.md - Deployment Steps: deployments/steps.md - Workers: - Cloud Run: cloud_run_worker.md + - Cloud Run V2: cloud_run_worker_v2.md - Vertex AI: vertex_worker.md extra: diff --git a/prefect_gcp/__init__.py b/prefect_gcp/__init__.py index 
5e616175..37c97aad 100644 --- a/prefect_gcp/__init__.py +++ b/prefect_gcp/__init__.py @@ -11,6 +11,7 @@ from .secret_manager import GcpSecret # noqa from .workers.vertex import VertexAIWorker # noqa from .workers.cloud_run import CloudRunWorker # noqa +from .workers.cloud_run_v2 import CloudRunWorkerV2 # noqa register_renamed_module( "prefect_gcp.projects", "prefect_gcp.deployments", start_date="Jun 2023" diff --git a/prefect_gcp/models/__init__.py b/prefect_gcp/models/__init__.py new file mode 100644 index 00000000..e69de29b diff --git a/prefect_gcp/models/cloud_run_v2.py b/prefect_gcp/models/cloud_run_v2.py new file mode 100644 index 00000000..664faf9a --- /dev/null +++ b/prefect_gcp/models/cloud_run_v2.py @@ -0,0 +1,393 @@ +import time +from typing import Dict, List, Literal, Optional + +# noinspection PyProtectedMember +from googleapiclient.discovery import Resource +from prefect.infrastructure.base import InfrastructureResult +from pydantic import VERSION as PYDANTIC_VERSION + +if PYDANTIC_VERSION.startswith("2."): + from pydantic.v1 import BaseModel +else: + from pydantic import BaseModel + + +class JobV2(BaseModel): + """ + JobV2 is a data model for a job that will be run on Cloud Run with the V2 API. + """ + + name: str + uid: str + generation: str + labels: Dict[str, str] + annotations: Dict[str, str] + createTime: str + updateTime: str + deleteTime: Optional[str] + expireTime: Optional[str] + creator: Optional[str] + lastModifier: Optional[str] + client: Optional[str] + clientVersion: Optional[str] + launchStage: Literal[ + "ALPHA", + "BETA", + "GA", + "DEPRECATED", + "EARLY_ACCESS", + "PRELAUNCH", + "UNIMPLEMENTED", + "LAUNCH_TAG_UNSPECIFIED", + ] + binaryAuthorization: Dict + template: Dict + observedGeneration: Optional[str] + terminalCondition: Dict + conditions: List[Dict] + executionCount: int + latestCreatedExecution: Dict + reconciling: bool + satisfiesPzs: bool + etag: str + + def is_ready(self) -> bool: + """ + Check if the job is ready to run. + + Returns: + Whether the job is ready to run. + """ + ready_condition = self.get_ready_condition() + + if self._is_missing_container(ready_condition=ready_condition): + raise Exception(f"{ready_condition.get('message')}") + + return ready_condition.get("state") == "CONDITION_SUCCEEDED" + + def get_ready_condition(self) -> Dict: + """ + Get the ready condition for the job. + + Returns: + The ready condition for the job. + """ + if self.terminalCondition.get("type") == "Ready": + return self.terminalCondition + + return {} + + @classmethod + def get( + cls, + cr_client: Resource, + project: str, + location: str, + job_name: str, + ): + """ + Get a job from Cloud Run with the V2 API. + + Args: + cr_client: The base client needed for interacting with GCP + Cloud Run V2 API. + project: The GCP project ID. + location: The GCP region. + job_name: The name of the job to get. 
+ """ + # noinspection PyUnresolvedReferences + request = cr_client.jobs().get( + name=f"projects/{project}/locations/{location}/jobs/{job_name}", + ) + + response = request.execute() + + return cls( + name=response["name"], + uid=response["uid"], + generation=response["generation"], + labels=response.get("labels", {}), + annotations=response.get("annotations", {}), + createTime=response["createTime"], + updateTime=response["updateTime"], + deleteTime=response.get("deleteTime"), + expireTime=response.get("expireTime"), + creator=response.get("creator"), + lastModifier=response.get("lastModifier"), + client=response.get("client"), + clientVersion=response.get("clientVersion"), + launchStage=response.get("launchStage", "GA"), + binaryAuthorization=response.get("binaryAuthorization", {}), + template=response.get("template"), + observedGeneration=response.get("observedGeneration"), + terminalCondition=response.get("terminalCondition", {}), + conditions=response.get("conditions", []), + executionCount=response.get("executionCount", 0), + latestCreatedExecution=response["latestCreatedExecution"], + reconciling=response.get("reconciling", False), + satisfiesPzs=response.get("satisfiesPzs", False), + etag=response["etag"], + ) + + @staticmethod + def create( + cr_client: Resource, + project: str, + location: str, + job_id: str, + body: Dict, + ) -> Dict: + """ + Create a job on Cloud Run with the V2 API. + + Args: + cr_client: The base client needed for interacting with GCP + Cloud Run V2 API. + project: The GCP project ID. + location: The GCP region. + job_id: The ID of the job to create. + body: The job body. + + Returns: + The response from the Cloud Run V2 API. + """ + # noinspection PyUnresolvedReferences + request = cr_client.jobs().create( + parent=f"projects/{project}/locations/{location}", + jobId=job_id, + body=body, + ) + + response = request.execute() + + return response + + @staticmethod + def delete( + cr_client: Resource, + project: str, + location: str, + job_name: str, + ) -> Dict: + """ + Delete a job on Cloud Run with the V2 API. + + Args: + cr_client (Resource): The base client needed for interacting with GCP + Cloud Run V2 API. + project: The GCP project ID. + location: The GCP region. + job_name: The name of the job to delete. + + Returns: + Dict: The response from the Cloud Run V2 API. + """ + # noinspection PyUnresolvedReferences + list_executions_request = ( + cr_client.jobs() + .executions() + .list( + parent=f"projects/{project}/locations/{location}/jobs/{job_name}", + ) + ) + list_executions_response = list_executions_request.execute() + + for execution_to_delete in list_executions_response.get("executions", []): + # noinspection PyUnresolvedReferences + delete_execution_request = ( + cr_client.jobs() + .executions() + .delete( + name=execution_to_delete["name"], + ) + ) + delete_execution_request.execute() + + # Sleep 3 seconds so that the execution is deleted before deleting the job + time.sleep(3) + + # noinspection PyUnresolvedReferences + request = cr_client.jobs().delete( + name=f"projects/{project}/locations/{location}/jobs/{job_name}", + ) + + response = request.execute() + + return response + + @staticmethod + def run( + cr_client: Resource, + project: str, + location: str, + job_name: str, + ): + """ + Run a job on Cloud Run with the V2 API. + + Args: + cr_client: The base client needed for interacting with GCP + Cloud Run V2 API. + project: The GCP project ID. + location: The GCP region. + job_name: The name of the job to run. 
+ """ + # noinspection PyUnresolvedReferences + request = cr_client.jobs().run( + name=f"projects/{project}/locations/{location}/jobs/{job_name}", + ) + + response = request.execute() + + return response + + @staticmethod + def _is_missing_container(ready_condition: Dict) -> bool: + """ + Check if the job is missing a container. + + Args: + ready_condition: The ready condition for the job. + + Returns: + Whether the job is missing a container. + """ + if ( + ready_condition.get("state") == "CONTAINER_FAILED" + and ready_condition.get("reason") == "ContainerMissing" + ): + return True + + return False + + +class ExecutionV2(BaseModel): + """ + ExecutionV2 is a data model for an execution of a job that will be run on + Cloud Run API v2. + """ + + name: str + uid: str + generation: str + labels: Dict[str, str] + annotations: Dict[str, str] + createTime: str + startTime: Optional[str] + completionTime: Optional[str] + deleteTime: Optional[str] + expireTime: Optional[str] + launchStage: Literal[ + "ALPHA", + "BETA", + "GA", + "DEPRECATED", + "EARLY_ACCESS", + "PRELAUNCH", + "UNIMPLEMENTED", + "LAUNCH_TAGE_UNSPECIFIED", + ] + job: str + parallelism: int + taskCount: int + template: Dict + reconciling: bool + conditions: List[Dict] + observedGeneration: Optional[str] + runningCount: Optional[int] + succeededCount: Optional[int] + failedCount: Optional[int] + cancelledCount: Optional[int] + retriedCount: Optional[int] + logUri: str + satisfiesPzs: bool + etag: str + + def is_running(self) -> bool: + """ + Return whether the execution is running. + + Returns: + Whether the execution is running. + """ + return self.completionTime is None + + def succeeded(self) -> bool: + """ + Return whether the execution succeeded. + + Returns: + Whether the execution succeeded. + """ + return True if self.condition_after_completion() else False + + def condition_after_completion(self) -> Dict: + """ + Return the condition after completion. + + Returns: + The condition after completion. + """ + if isinstance(self.conditions, List): + for condition in self.conditions: + if ( + condition["state"] == "CONDITION_SUCCEEDED" + and condition["type"] == "Completed" + ): + return condition + + return {} + + @classmethod + def get( + cls, + cr_client: Resource, + execution_id: str, + ): + """ + Get an execution from Cloud Run with the V2 API. + + Args: + cr_client: The base client needed for interacting with GCP + Cloud Run V2 API. 
+ execution_id: The name of the execution to get, in the form of + projects/{project}/locations/{location}/jobs/{job}/executions + /{execution} + """ + # noinspection PyUnresolvedReferences + request = cr_client.jobs().executions().get(name=execution_id) + + response = request.execute() + + return cls( + name=response["name"], + uid=response["uid"], + generation=response["generation"], + labels=response.get("labels", {}), + annotations=response.get("annotations", {}), + createTime=response["createTime"], + startTime=response.get("startTime"), + completionTime=response.get("completionTime"), + deleteTime=response.get("deleteTime"), + expireTime=response.get("expireTime"), + launchStage=response.get("launchStage", "GA"), + job=response["job"], + parallelism=response["parallelism"], + taskCount=response["taskCount"], + template=response["template"], + reconciling=response.get("reconciling", False), + conditions=response.get("conditions", []), + observedGeneration=response.get("observedGeneration"), + runningCount=response.get("runningCount"), + succeededCount=response.get("succeededCount"), + failedCount=response.get("failedCount"), + cancelledCount=response.get("cancelledCount"), + retriedCount=response.get("retriedCount"), + logUri=response["logUri"], + satisfiesPzs=response.get("satisfiesPzs", False), + etag=response["etag"], + ) + + +class CloudRunJobV2Result(InfrastructureResult): + """Result from a Cloud Run Job.""" diff --git a/prefect_gcp/workers/cloud_run_v2.py b/prefect_gcp/workers/cloud_run_v2.py new file mode 100644 index 00000000..68cb8c49 --- /dev/null +++ b/prefect_gcp/workers/cloud_run_v2.py @@ -0,0 +1,845 @@ +import re +import shlex +import time +from typing import TYPE_CHECKING, Any, Dict, List, Literal, Optional + +from anyio.abc import TaskStatus +from google.api_core.client_options import ClientOptions +from googleapiclient import discovery + +# noinspection PyProtectedMember +from googleapiclient.discovery import Resource +from googleapiclient.errors import HttpError +from prefect.exceptions import InfrastructureNotFound +from prefect.logging.loggers import PrefectLogAdapter +from prefect.utilities.asyncutils import run_sync_in_worker_thread +from prefect.utilities.dockerutils import get_prefect_image_name +from prefect.utilities.pydantic import JsonPatch +from prefect.workers.base import ( + BaseJobConfiguration, + BaseVariables, + BaseWorker, + BaseWorkerResult, +) +from pydantic import VERSION as PYDANTIC_VERSION + +if PYDANTIC_VERSION.startswith("2."): + from pydantic.v1 import Field, validator +else: + from pydantic import Field, validator + +from prefect_gcp.credentials import GcpCredentials +from prefect_gcp.models.cloud_run_v2 import CloudRunJobV2Result, ExecutionV2, JobV2 + +if TYPE_CHECKING: + from prefect.client.schemas import FlowRun + from prefect.server.schemas.core import Flow + from prefect.server.schemas.responses import DeploymentResponse + + +def _get_default_job_body_template() -> Dict[str, Any]: + """ + Returns the default job body template for the Cloud Run worker. + + Returns: + The default job body template. 
+ """ + return { + "client": "prefect", + "launchStage": "{{ launch_stage }}", + "template": { + "template": { + "maxRetries": "{{ max_retries }}", + "timeout": "{{ timeout }}", + "containers": [ + { + "env": [], + "image": "{{ image }}", + "command": "{{ command }}", + "args": "{{ args }}", + "resources": { + "limits": { + "cpu": "{{ cpu }}", + "memory": "{{ memory }}", + }, + }, + }, + ], + } + }, + } + + +def _get_base_job_body() -> Dict[str, Any]: + """ + Returns the base job body for the Cloud Run worker's job body validation. + + Returns: + The base job body. + """ + return { + "template": { + "template": { + "containers": [], + }, + }, + } + + +class CloudRunWorkerJobV2Configuration(BaseJobConfiguration): + """ + The configuration for the Cloud Run worker V2. + + The schema for this class is used to populate the `job_body` section of the + default base job template. + """ + + credentials: GcpCredentials = Field( + title="GCP Credentials", + default_factory=GcpCredentials, + description=( + "The GCP Credentials used to connect to Cloud Run. " + "If not provided credentials will be inferred from " + "the local environment." + ), + ) + job_body: Dict[str, Any] = Field( + template=_get_default_job_body_template(), + ) + keep_job: bool = Field( + default=False, + title="Keep Job After Completion", + description="Keep the completed Cloud run job on Google Cloud Platform.", + ) + region: str = Field( + default="us-central1", + description="The region in which to run the Cloud Run job", + ) + timeout: int = Field( + default=600, + gt=0, + le=86400, + description=( + "The length of time that Prefect will wait for a Cloud Run Job to " + "complete before raising an exception." + ), + ) + + @property + def project(self) -> str: + """ + Returns the GCP project associated with the credentials. + + Returns: + str: The GCP project associated with the credentials. + """ + return self.credentials.project + + @property + def job_name(self): + """ + Returns the job name, if it does not exist, it creates it. + """ + pre_trim_cr_job_name = f"prefect-{self.name}" + + if len(pre_trim_cr_job_name) > 40: + pre_trim_cr_job_name = pre_trim_cr_job_name[:40] + + pre_trim_cr_job_name = pre_trim_cr_job_name.rstrip("-") + + return pre_trim_cr_job_name + + def prepare_for_flow_run( + self, + flow_run: "FlowRun", + deployment: Optional["DeploymentResponse"] = None, + flow: Optional["Flow"] = None, + ): + """ + Prepares the job configuration for a flow run. + + Ensures that necessary values are present in the job body and that the + job body is valid. + + Args: + flow_run: The flow run to prepare the job configuration for + deployment: The deployment associated with the flow run used for + preparation. + flow: The flow associated with the flow run used for preparation. + """ + super().prepare_for_flow_run( + flow_run=flow_run, + deployment=deployment, + flow=flow, + ) + + self._populate_env() + self._populate_or_format_command() + self._format_args_if_present() + self._populate_image_if_not_present() + self._populate_timeout() + + def _populate_timeout(self): + """ + Populates the job body with the timeout. + """ + self.job_body["template"]["template"]["timeout"] = f"{self.timeout}s" + + def _populate_env(self): + """ + Populates the job body with environment variables. + """ + envs = [{"name": k, "value": v} for k, v in self.env.items()] + + self.job_body["template"]["template"]["containers"][0]["env"] = envs + + def _populate_image_if_not_present(self): + """ + Populates the job body with the image if not present. 
+ """ + if "image" not in self.job_body["template"]["template"]["containers"][0]: + self.job_body["template"]["template"]["containers"][0][ + "image" + ] = f"docker.io/{get_prefect_image_name()}" + + def _populate_or_format_command(self): + """ + Populates the job body with the command if not present. + """ + command = self.job_body["template"]["template"]["containers"][0].get("command") + + if command is None: + self.job_body["template"]["template"]["containers"][0]["command"] = [ + "python", + "-m", + "prefect.engine", + ] + elif isinstance(command, str): + self.job_body["template"]["template"]["containers"][0][ + "command" + ] = shlex.split(command) + + def _format_args_if_present(self): + """ + Formats the job body args if present. + """ + args = self.job_body["template"]["template"]["containers"][0].get("args") + + if args is not None and isinstance(args, str): + self.job_body["template"]["template"]["containers"][0][ + "args" + ] = shlex.split(args) + + # noinspection PyMethodParameters + @validator("job_body") + def _ensure_job_includes_all_required_components(cls, value: Dict[str, Any]): + """ + Ensures that the job body includes all required components. + + Args: + value: The job body to validate. + Returns: + The validated job body. + """ + patch = JsonPatch.from_diff(value, _get_base_job_body()) + + missing_paths = sorted([op["path"] for op in patch if op["op"] == "add"]) + + if missing_paths: + raise ValueError( + f"Job body is missing required components: {', '.join(missing_paths)}" + ) + + return value + + # noinspection PyMethodParameters + @validator("job_body") + def _ensure_job_has_compatible_values(cls, value: Dict[str, Any]): + """Ensure that the job body has compatible values.""" + patch = JsonPatch.from_diff(value, _get_base_job_body()) + incompatible = sorted( + [ + f"{op['path']} must have value {op['value']!r}" + for op in patch + if op["op"] == "replace" + ] + ) + if incompatible: + raise ValueError( + "Job has incompatible values for the following attributes: " + f"{', '.join(incompatible)}" + ) + return value + + +class CloudRunWorkerV2Variables(BaseVariables): + """ + Default variables for the Cloud Run worker V2. + + The schema for this class is used to populate the `variables` section of the + default base job template. + """ + + credentials: GcpCredentials = Field( + title="GCP Credentials", + default_factory=GcpCredentials, + description=( + "The GCP Credentials used to connect to Cloud Run. " + "If not provided credentials will be inferred from " + "the local environment." + ), + ) + region: str = Field( + default="us-central1", + description="The region in which to run the Cloud Run job", + ) + image: Optional[str] = Field( + default="prefecthq/prefect:2-latest", + title="Image Name", + description=( + "The image to use for the Cloud Run job. " + "If not provided the default Prefect image will be used." + ), + ) + args: List[str] = Field( + default_factory=list, + description=( + "The arguments to pass to the Cloud Run Job V2's entrypoint command." + ), + ) + keep_job: bool = Field( + default=False, + title="Keep Job After Completion", + description="Keep the completed Cloud run job on Google Cloud Platform.", + ) + launch_stage: Literal[ + "ALPHA", + "BETA", + "GA", + "DEPRECATED", + "EARLY_ACCESS", + "PRELAUNCH", + "UNIMPLEMENTED", + "LAUNCH_TAG_UNSPECIFIED", + ] = Field( + "BETA", + description=( + "The launch stage of the Cloud Run Job V2. " + "See https://cloud.google.com/run/docs/about-features-categories " + "for additional details." 
+ ), + ) + max_retries: int = Field( + default=0, + title="Max Retries", + description="The number of times to retry the Cloud Run job.", + ) + cpu: str = Field( + default="1000m", + title="CPU", + description="The CPU to allocate to the Cloud Run job.", + ) + memory: str = Field( + default="512Mi", + title="Memory", + description=( + "The memory to allocate to the Cloud Run job along with the units, which" + "could be: G, Gi, M, Mi." + ), + example="512Mi", + pattern=r"^\d+(?:G|Gi|M|Mi)$", + ) + timeout: int = Field( + default=600, + gt=0, + le=86400, + title="Job Timeout", + description=( + "The length of time that Prefect will wait for a Cloud Run Job to " + "complete before raising an exception (maximum of 86400 seconds, 1 day)." + ), + ) + vpc_connector_name: Optional[str] = Field( + default=None, + title="VPC Connector Name", + description="The name of the VPC connector to use for the Cloud Run job.", + ) + + +class CloudRunWorkerV2Result(BaseWorkerResult): + """ + The result of a Cloud Run worker V2 job. + """ + + +class CloudRunWorkerV2(BaseWorker): + """ + The Cloud Run worker V2. + """ + + type = "cloud-run-v2" + job_configuration = CloudRunWorkerJobV2Configuration + job_configuration_variables = CloudRunWorkerV2Variables + _description = "A worker which runs flow runs on Google Cloud Run (API v2)." + _display_name = "Cloud Run Worker V2" + _documentation_url = "https://prefecthq.github.io/prefect-gcp/worker_v2/" + _logo_url = "https://images.ctfassets.net/gm98wzqotmnx/4SpnOBvMYkHp6z939MDKP6/549a91bc1ce9afd4fb12c68db7b68106/social-icon-google-cloud-1200-630.png?h=250" # noqa + + async def run( + self, + flow_run: "FlowRun", + configuration: CloudRunWorkerJobV2Configuration, + task_status: Optional[TaskStatus] = None, + ) -> CloudRunJobV2Result: + """ + Runs the flow run on Cloud Run and waits for it to complete. + + Args: + flow_run: The flow run to run. + configuration: The configuration for the job. + task_status: The task status to update. + + Returns: + The result of the job. + """ + logger = self.get_flow_run_logger(flow_run) + + with self._get_client(configuration=configuration) as cr_client: + await run_sync_in_worker_thread( + self._create_job_and_wait_for_registration, + configuration=configuration, + cr_client=cr_client, + logger=logger, + ) + + execution = await run_sync_in_worker_thread( + self._begin_job_execution, + configuration=configuration, + cr_client=cr_client, + logger=logger, + ) + + if task_status: + task_status.started(configuration.job_name) + + result = await run_sync_in_worker_thread( + self._watch_job_execution_and_get_result, + configuration=configuration, + cr_client=cr_client, + execution=execution, + logger=logger, + ) + + return result + + async def kill_infrastructure( + self, + infrastructure_pid: str, + configuration: CloudRunWorkerJobV2Configuration, + grace_seconds: int = 30, + ): + """ + Stops the Cloud Run job. + + Args: + infrastructure_pid: The ID of the infrastructure to stop. + configuration: The configuration for the job. + grace_seconds: The number of seconds to wait before stopping the job. + """ + if grace_seconds != 30: + self._logger.warning( + f"Kill grace period of {grace_seconds}s requested, but GCP does not " + "support dynamic grace period configuration. 
See here for more info: " + "https://cloud.google.com/run/docs/reference/rest/v1/namespaces.jobs/delete" # noqa + ) + + with self._get_client(configuration=configuration) as cr_client: + await run_sync_in_worker_thread( + self._stop_job, + cr_client=cr_client, + configuration=configuration, + job_name=infrastructure_pid, + ) + + @staticmethod + def _get_client( + configuration: CloudRunWorkerJobV2Configuration, + ) -> Resource: + """ + Get the base client needed for interacting with GCP Cloud Run V2 API. + + Returns: + Resource: The base client needed for interacting with GCP Cloud Run V2 API. + """ + api_endpoint = "https://run.googleapis.com" + gcp_creds = configuration.credentials.get_credentials_from_service_account() + + options = ClientOptions(api_endpoint=api_endpoint) + + return ( + discovery.build( + "run", + "v2", + client_options=options, + credentials=gcp_creds, + num_retries=3, # Set to 3 in case of intermittent/connection issues + ) + .projects() + .locations() + ) + + def _create_job_and_wait_for_registration( + self, + configuration: CloudRunWorkerJobV2Configuration, + cr_client: Resource, + logger: PrefectLogAdapter, + ): + """ + Creates the Cloud Run job and waits for it to register. + + Args: + configuration: The configuration for the job. + cr_client: The Cloud Run client. + logger: The logger to use. + """ + try: + logger.info(f"Creating Cloud Run JobV2 {configuration.job_name}") + + JobV2.create( + cr_client=cr_client, + project=configuration.project, + location=configuration.region, + job_id=configuration.job_name, + body=configuration.job_body, + ) + except HttpError as exc: + self._create_job_error( + exc=exc, + configuration=configuration, + ) + + try: + self._wait_for_job_creation( + cr_client=cr_client, + configuration=configuration, + logger=logger, + ) + except Exception as exc: + logger.critical( + f"Failed to create Cloud Run JobV2 {configuration.job_name}.\n{exc}" + ) + + if not configuration.keep_job: + try: + JobV2.delete( + cr_client=cr_client, + project=configuration.project, + location=configuration.region, + job_name=configuration.job_name, + ) + except Exception as exc2: + logger.critical( + f"Failed to delete Cloud Run JobV2 {configuration.job_name}." + f"\n{exc2}" + ) + + raise + + @staticmethod + def _wait_for_job_creation( + cr_client: Resource, + configuration: CloudRunWorkerJobV2Configuration, + logger: PrefectLogAdapter, + poll_interval: int = 5, + ): + """ + Waits for the Cloud Run job to be created. + + Args: + cr_client: The Cloud Run client. + configuration: The configuration for the job. + logger: The logger to use. + poll_interval: The interval to poll the Cloud Run job, defaults to 5 + seconds. + """ + job = JobV2.get( + cr_client=cr_client, + project=configuration.project, + location=configuration.region, + job_name=configuration.job_name, + ) + + t0 = time.time() + + while not job.is_ready(): + if not (ready_condition := job.get_ready_condition()): + ready_condition = "waiting for condition update" + + logger.info(f"Current Job Condition: {ready_condition}") + + job = JobV2.get( + cr_client=cr_client, + project=configuration.project, + location=configuration.region, + job_name=configuration.job_name, + ) + + elapsed_time = time.time() - t0 + + if elapsed_time > configuration.timeout: + raise RuntimeError( + f"Timeout of {configuration.timeout} seconds reached while " + f"waiting for Cloud Run Job V2 {configuration.job_name} to be " + "created." 
+ ) + + time.sleep(poll_interval) + + @staticmethod + def _create_job_error( + exc: HttpError, + configuration: CloudRunWorkerJobV2Configuration, + ): + """ + Creates a formatted error message for the Cloud Run V2 API errors + """ + # noinspection PyUnresolvedReferences + if exc.status_code == 404: + raise RuntimeError( + f"Failed to find resources at {exc.uri}. Confirm that region" + f" '{configuration.region}' is the correct region for your Cloud" + f" Run Job and that {configuration.project} is the correct GCP " + f" project. If your project ID is not correct, you are using a " + f"Credentials block with permissions for the wrong project." + ) from exc + + raise exc + + def _begin_job_execution( + self, + cr_client: Resource, + configuration: CloudRunWorkerJobV2Configuration, + logger: PrefectLogAdapter, + ) -> ExecutionV2: + """ + Begins the Cloud Run job execution. + + Args: + cr_client: The Cloud Run client. + configuration: The configuration for the job. + logger: The logger to use. + + Returns: + The Cloud Run job execution. + """ + try: + logger.info( + f"Submitting Cloud Run Job V2 {configuration.job_name} for execution..." + ) + + submission = JobV2.run( + cr_client=cr_client, + project=configuration.project, + location=configuration.region, + job_name=configuration.job_name, + ) + + job_execution = ExecutionV2.get( + cr_client=cr_client, + execution_id=submission["metadata"]["name"], + ) + + command = ( + " ".join(configuration.command) + if configuration.command + else "default container command" + ) + + logger.info( + f"Cloud Run Job V2 {configuration.job_name} submitted for execution " + f"with command: {command}" + ) + + return job_execution + except Exception as exc: + self._job_run_submission_error( + exc=exc, + configuration=configuration, + ) + raise + + def _watch_job_execution_and_get_result( + self, + cr_client: Resource, + configuration: CloudRunWorkerJobV2Configuration, + execution: ExecutionV2, + logger: PrefectLogAdapter, + poll_interval: int = 5, + ) -> CloudRunJobV2Result: + """ + Watch the job execution and get the result. + + Args: + cr_client (Resource): The base client needed for interacting with GCP + Cloud Run V2 API. + configuration (CloudRunWorkerJobV2Configuration): The configuration for + the job. + execution (ExecutionV2): The execution to watch. + logger (PrefectLogAdapter): The logger to use. + poll_interval (int): The number of seconds to wait between polls. + Defaults to 5 seconds. + + Returns: + The result of the job. + """ + try: + execution = self._watch_job_execution( + cr_client=cr_client, + configuration=configuration, + execution=execution, + poll_interval=poll_interval, + ) + except Exception as exc: + logger.critical( + f"Encountered an exception while waiting for job run completion - " + f"{exc}" + ) + raise + + if execution.succeeded(): + status_code = 0 + logger.info(f"Cloud Run Job V2 {configuration.job_name} succeeded") + else: + status_code = 1 + error_mg = execution.condition_after_completion().get("message") + logger.error( + f"Cloud Run Job V2 {configuration.job_name} failed - {error_mg}" + ) + + logger.info(f"Job run logs can be found on GCP at: {execution.logUri}") + + if not configuration.keep_job: + logger.info( + f"Deleting completed Cloud Run Job {configuration.job_name!r} from " + "Google Cloud Run..." 
+ ) + + try: + JobV2.delete( + cr_client=cr_client, + project=configuration.project, + location=configuration.region, + job_name=configuration.job_name, + ) + except Exception as exc: + logger.critical( + "Received an exception while deleting the Cloud Run Job V2 " + f"- {configuration.job_name} - {exc}" + ) + + return CloudRunJobV2Result( + identifier=configuration.job_name, + status_code=status_code, + ) + + # noinspection DuplicatedCode + @staticmethod + def _watch_job_execution( + cr_client: Resource, + configuration: CloudRunWorkerJobV2Configuration, + execution: ExecutionV2, + poll_interval: int, + ) -> ExecutionV2: + """ + Update execution status until it is no longer running or timeout is reached. + + Args: + cr_client (Resource): The base client needed for interacting with GCP + Cloud Run V2 API. + configuration (CloudRunWorkerJobV2Configuration): The configuration for + the job. + execution (ExecutionV2): The execution to watch. + poll_interval (int): The number of seconds to wait between polls. + + Returns: + The execution. + """ + t0 = time.time() + + while execution.is_running(): + execution = ExecutionV2.get( + cr_client=cr_client, + execution_id=execution.name, + ) + + elapsed_time = time.time() - t0 + + if elapsed_time > configuration.timeout: + raise RuntimeError( + f"Timeout of {configuration.timeout} seconds reached while " + f"waiting for Cloud Run Job V2 {configuration.job_name} to " + "complete." + ) + + time.sleep(poll_interval) + + return execution + + @staticmethod + def _job_run_submission_error( + exc: Exception, + configuration: CloudRunWorkerJobV2Configuration, + ): + """ + Creates a formatted error message for the Cloud Run V2 API errors + + Args: + exc: The exception to format. + configuration: The configuration for the job. + """ + # noinspection PyUnresolvedReferences + if exc.status_code == 404: + pat1 = r"The requested URL [^ ]+ was not found on this server" + + if re.findall(pat1, str(exc)): + # noinspection PyUnresolvedReferences + raise RuntimeError( + f"Failed to find resources at {exc.uri}. " + f"Confirm that region '{configuration.region}' is " + f"the correct region for your Cloud Run Job " + f"and that '{configuration.project}' is the " + f"correct GCP project. If your project ID is not " + f"correct, you are using a Credentials " + f"block with permissions for the wrong project." + ) from exc + else: + raise exc + + @staticmethod + def _stop_job( + cr_client: Resource, + configuration: CloudRunWorkerJobV2Configuration, + job_name: str, + ): + """ + Stops/deletes the Cloud Run job. + + Args: + cr_client: The Cloud Run client. + configuration: The configuration for the job. + job_name: The name of the job to stop. + """ + try: + JobV2.delete( + cr_client=cr_client, + project=configuration.project, + location=configuration.region, + job_name=job_name, + ) + except Exception as exc: + if "does not exist" in str(exc): + raise InfrastructureNotFound( + f"Cannot stop Cloud Run Job; the job name {job_name!r} " + "could not be found." 
+ ) from exc + raise diff --git a/tests/test_cloud_run_v2.py b/tests/test_cloud_run_v2.py new file mode 100644 index 00000000..092204e2 --- /dev/null +++ b/tests/test_cloud_run_v2.py @@ -0,0 +1,191 @@ +import pytest + +from prefect_gcp.models.cloud_run_v2 import JobV2 + +jobs_return_value = { + "name": "test-job-name", + "uid": "uid-123", + "generation": "1", + "labels": {}, + "createTime": "create-time", + "updateTime": "update-time", + "deleteTime": "delete-time", + "expireTime": "expire-time", + "creator": "creator", + "lastModifier": "last-modifier", + "client": "client", + "clientVersion": "client-version", + "launchStage": "BETA", + "binaryAuthorization": {}, + "template": {}, + "observedGeneration": "1", + "terminalCondition": {}, + "conditions": [], + "executionCount": 1, + "latestCreatedExecution": {}, + "reconciling": True, + "satisfiesPzs": False, + "etag": "etag-123", +} + + +class TestJobV2: + @pytest.mark.parametrize( + "state,expected", + [("CONDITION_SUCCEEDED", True), ("CONDITION_FAILED", False)], + ) + def test_is_ready(self, state, expected): + job = JobV2( + name="test-job", + uid="12345", + generation="2", + labels={}, + annotations={}, + createTime="2021-08-31T18:00:00Z", + updateTime="2021-08-31T18:00:00Z", + launchStage="BETA", + binaryAuthorization={}, + template={}, + terminalCondition={ + "type": "Ready", + "state": state, + }, + conditions=[], + executionCount=1, + latestCreatedExecution={}, + reconciling=False, + satisfiesPzs=False, + etag="etag-12345", + ) + + assert job.is_ready() == expected + + def test_is_ready_raises_exception(self): + job = JobV2( + name="test-job", + uid="12345", + generation="2", + labels={}, + annotations={}, + createTime="2021-08-31T18:00:00Z", + updateTime="2021-08-31T18:00:00Z", + launchStage="BETA", + binaryAuthorization={}, + template={}, + terminalCondition={ + "type": "Ready", + "state": "CONTAINER_FAILED", + "reason": "ContainerMissing", + }, + conditions=[], + executionCount=1, + latestCreatedExecution={}, + reconciling=False, + satisfiesPzs=False, + etag="etag-12345", + ) + + with pytest.raises(Exception): + job.is_ready() + + @pytest.mark.parametrize( + "terminal_condition,expected", + [ + ( + { + "type": "Ready", + "state": "CONDITION_SUCCEEDED", + }, + { + "type": "Ready", + "state": "CONDITION_SUCCEEDED", + }, + ), + ( + { + "type": "Failed", + "state": "CONDITION_FAILED", + }, + {}, + ), + ], + ) + def test_get_ready_condition(self, terminal_condition, expected): + job = JobV2( + name="test-job", + uid="12345", + generation="2", + labels={}, + annotations={}, + createTime="2021-08-31T18:00:00Z", + updateTime="2021-08-31T18:00:00Z", + launchStage="BETA", + binaryAuthorization={}, + template={}, + terminalCondition=terminal_condition, + conditions=[], + executionCount=1, + latestCreatedExecution={}, + reconciling=False, + satisfiesPzs=False, + etag="etag-12345", + ) + + assert job.get_ready_condition() == expected + + @pytest.mark.parametrize( + "ready_condition,expected", + [ + ( + { + "state": "CONTAINER_FAILED", + "reason": "ContainerMissing", + }, + True, + ), + ( + { + "state": "CONDITION_SUCCEEDED", + }, + False, + ), + ], + ) + def test_is_missing_container(self, ready_condition, expected): + job = JobV2( + name="test-job", + uid="12345", + generation="2", + labels={}, + annotations={}, + createTime="2021-08-31T18:00:00Z", + updateTime="2021-08-31T18:00:00Z", + launchStage="BETA", + binaryAuthorization={}, + template={}, + terminalCondition={}, + conditions=[], + executionCount=1, + latestCreatedExecution={}, + 
reconciling=False, + satisfiesPzs=False, + etag="etag-12345", + ) + + assert job._is_missing_container(ready_condition=ready_condition) == expected + + +def remove_server_url_from_env(env): + """ + For convenience since the testing database URL is non-deterministic. + """ + return [ + env_var + for env_var in env + if env_var["name"] + not in [ + "PREFECT_API_DATABASE_CONNECTION_URL", + "PREFECT_ORION_DATABASE_CONNECTION_URL", + "PREFECT_SERVER_DATABASE_CONNECTION_URL", + ] + ] diff --git a/tests/test_cloud_run_worker_v2.py b/tests/test_cloud_run_worker_v2.py new file mode 100644 index 00000000..de717901 --- /dev/null +++ b/tests/test_cloud_run_worker_v2.py @@ -0,0 +1,94 @@ +import pytest +from prefect.utilities.dockerutils import get_prefect_image_name + +from prefect_gcp.credentials import GcpCredentials +from prefect_gcp.workers.cloud_run_v2 import CloudRunWorkerJobV2Configuration + + +@pytest.fixture +def job_body(): + return { + "client": "prefect", + "launchStage": None, + "template": { + "template": { + "maxRetries": None, + "timeout": None, + "containers": [ + { + "env": [], + "command": None, + "args": "-m prefect.engine", + "resources": { + "limits": { + "cpu": None, + "memory": None, + }, + }, + }, + ], + } + }, + } + + +@pytest.fixture +def cloud_run_worker_v2_job_config(service_account_info, job_body): + return CloudRunWorkerJobV2Configuration( + name="my-job-name", + job_body=job_body, + credentials=GcpCredentials(service_account_info=service_account_info), + region="us-central1", + timeout=86400, + env={"ENV1": "VALUE1", "ENV2": "VALUE2"}, + ) + + +class TestCloudRunWorkerJobV2Configuration: + def test_project(self, cloud_run_worker_v2_job_config): + assert cloud_run_worker_v2_job_config.project == "my_project" + + def test_job_name(self, cloud_run_worker_v2_job_config): + assert cloud_run_worker_v2_job_config.job_name == "prefect-my-job-name" + + def test_populate_timeout(self, cloud_run_worker_v2_job_config): + cloud_run_worker_v2_job_config._populate_timeout() + + assert ( + cloud_run_worker_v2_job_config.job_body["template"]["template"]["timeout"] + == "86400s" + ) + + def test_populate_env(self, cloud_run_worker_v2_job_config): + cloud_run_worker_v2_job_config._populate_env() + + assert cloud_run_worker_v2_job_config.job_body["template"]["template"][ + "containers" + ][0]["env"] == [ + {"name": "ENV1", "value": "VALUE1"}, + {"name": "ENV2", "value": "VALUE2"}, + ] + + def test_populate_image_if_not_present(self, cloud_run_worker_v2_job_config): + cloud_run_worker_v2_job_config._populate_image_if_not_present() + + assert ( + cloud_run_worker_v2_job_config.job_body["template"]["template"][ + "containers" + ][0]["image"] + == f"docker.io/{get_prefect_image_name()}" + ) + + def test_populate_or_format_command(self, cloud_run_worker_v2_job_config): + cloud_run_worker_v2_job_config._populate_or_format_command() + + assert cloud_run_worker_v2_job_config.job_body["template"]["template"][ + "containers" + ][0]["command"] == ["python", "-m", "prefect.engine"] + + def test_format_args_if_present(self, cloud_run_worker_v2_job_config): + cloud_run_worker_v2_job_config._format_args_if_present() + + assert cloud_run_worker_v2_job_config.job_body["template"]["template"][ + "containers" + ][0]["args"] == ["-m", "prefect.engine"]
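
For reference, a minimal usage sketch of the JobV2 model and the client wiring introduced in this patch (mirroring CloudRunWorkerV2._get_client). The project, region, and job name below are hypothetical placeholders, and credentials are assumed to be inferable from the local environment; this sketch is not part of the patch itself.

from google.api_core.client_options import ClientOptions
from googleapiclient import discovery

from prefect_gcp.credentials import GcpCredentials
from prefect_gcp.models.cloud_run_v2 import JobV2

# Hypothetical values -- substitute your own project, region, and job name.
project = "my-gcp-project"
location = "us-central1"
job_name = "prefect-my-flow"

# Build the Cloud Run V2 client the same way the worker does.
credentials = GcpCredentials()  # assumes credentials are inferred from the environment
cr_client = (
    discovery.build(
        "run",
        "v2",
        client_options=ClientOptions(api_endpoint="https://run.googleapis.com"),
        credentials=credentials.get_credentials_from_service_account(),
        num_retries=3,
    )
    .projects()
    .locations()
)

# Fetch an existing Cloud Run V2 job and check whether it is ready to run.
job = JobV2.get(
    cr_client=cr_client,
    project=project,
    location=location,
    job_name=job_name,
)
print(job.name, job.is_ready())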