Commit 47ede49: document

jakevc committed Jul 20, 2024
1 parent 7019d75 commit 47ede49
Showing 2 changed files with 167 additions and 103 deletions.
200 changes: 109 additions & 91 deletions snakemake_executor_plugin_azure_batch/__init__.py
@@ -5,8 +5,6 @@


import datetime
import io
import os
import shlex
import uuid
from dataclasses import dataclass, field
@@ -60,6 +58,7 @@
)
from snakemake_executor_plugin_azure_batch.util import (
AzureIdentityCredentialAdapter,
read_stream_as_string,
unpack_compute_node_errors,
unpack_task_failure_information,
)
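Note: read_stream_as_string is the former static method Executor._read_stream_as_string (removed further down in this diff) relocated to the util module. Assuming the implementation moved verbatim, the helper reads as:

import io

def read_stream_as_string(stream, encoding):
    """Collect a generator of byte chunks and decode it as one string."""
    output = io.BytesIO()
    try:
        for data in stream:
            output.write(data)
        if encoding is None:
            encoding = "utf-8"
        return output.getvalue().decode(encoding)
    finally:
        output.close()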
@@ -279,68 +278,51 @@ def batch_account_name(self):

class Executor(RemoteExecutor):
def __post_init__(self):
# the snakemake/snakemake:latest container image
# the snakemake/snakemake:latest container image to run the workflow remotely
self.container_image = self.workflow.remote_execution_settings.container_image
self.settings: ExecutorSettings = self.workflow.executor_settings
self.logger.debug(
f"ExecutorSettings: {pformat(self.workflow.executor_settings, indent=2)}"
)

# handle case on OSX with /var/ symlinked to /private/var/ causing
# issues with workdir not matching other workflow file dirs
dirname = os.path.dirname(self.workflow.persistence.path)
osxprefix = "/private"
if osxprefix in dirname:
dirname = dirname.removeprefix(osxprefix)

self.workdir = dirname

# Pool ids can only contain any combination of alphanumeric characters along
# with dash and underscore.
# Pool ids can only contain alphanumeric characters, dashes, and underscores.
ts = datetime.datetime.now().strftime("%Y-%m-%dT%H-%M-%S")
self.pool_id = f"snakepool-{ts:s}"
self.job_id = f"snakejob-{ts:s}"
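# For example, a run started at 2024-07-20 14:03:07 would produce
# (illustrative values, not taken from this commit):
#   pool_id = "snakepool-2024-07-20T14-03-07"
#   job_id  = "snakejob-2024-07-20T14-03-07"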

self.envvars = self.workflow.spawned_job_args_factory.envvars()
# requires managed identity resource id to be set for the azure storage plugin
if self.workflow.storage_settings.default_storage_provider == "azure":
if self.settings.managed_identity_resource_id is None:
raise WorkflowError(
"Azure Storage plugin requires a managed identity "
"resource and client ids to be set."
)
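# The resource id above is the identity's full ARM path, e.g.
# (placeholders, not from this repo):
#   /subscriptions/<subscription-id>/resourceGroups/<resource-group>
#   /providers/Microsoft.ManagedIdentity/userAssignedIdentities/<identity-name>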

self.init_batch_client()
self.create_batch_pool()
self.create_batch_job()

# override real.py until bugfix is released
def get_envvar_declarations(self):
if self.common_settings.pass_envvar_declarations_to_cmd:
defs = " ".join(
f"{var}={repr(value)}" for var, value in self.envvars.items()
)
if defs:
return f"export {defs} &&"
else:
return ""
else:
return ""

def init_batch_client(self):
"""
Initialize the batch service client from the given credentials
Initialize the BatchServiceClient and BatchManagementClient using
DefaultAzureCredential.
Sets:
self.batch_client
self.batch_mgmt_client
- self.batch_client
- self.batch_mgmt_client
"""
try:
# alias these variables here to save space
batch_url = self.settings.account_url

# else authenticate with managed identity client id
# initialize BatchServiceClient
default_credential = DefaultAzureCredential(
exclude_managed_identity_credential=True
)
adapted_credential = AzureIdentityCredentialAdapter(
credential=default_credential, resource_id=AZURE_BATCH_RESOURCE_ENDPOINT
)

# initialize batch client with creds
self.batch_client = BatchServiceClient(adapted_credential, batch_url)
self.batch_client = BatchServiceClient(
adapted_credential, self.settings.account_url
)

# initialize BatchManagementClient
self.batch_mgmt_client = BatchManagementClient(
@@ -349,11 +331,18 @@ def init_batch_client(self):
)

except Exception as e:
raise WorkflowError("Failed to initialize batch client", e)
raise WorkflowError("Failed to initialize batch clients", e)

def shutdown(self):
# perform additional steps on shutdown
# if necessary (jobs were cancelled already)
def cleanup_resources(self):
"""Cleanup Azure Batch resources.
This method is responsible for cleaning up Azure Batch resources, including
deleting the Batch job and pool. If the `keep_pool` setting is enabled, the
pool will not be deleted.
Raises:
WorkflowError: If there is an error deleting the Batch job or pool.
"""
if not self.settings.keep_pool:
try:
self.logger.debug("Deleting AzBatch job")
Expand All @@ -369,6 +358,15 @@ def shutdown(self):
if be.error.code == "PoolBeingDeleted":
pass

def shutdown(self):
"""Shutdown the executor, cleaning up resources if necessary.
This method is responsible for shutting down the executor and cleaning up any
resources that were created during the execution. If the `keep_pool` setting is
enabled, the pool will not be deleted. This method should be called before
exiting the program or when the executor is no longer needed.
"""
self.cleanup_resources()
super().shutdown()

def run_job(self, job: JobExecutorInterface):
@@ -379,12 +377,9 @@ def run_job(self, job: JobExecutorInterface):
# self.report_job_submission(job_info).
# with job_info being of type
# snakemake_interface_executor_plugins.executors.base.SubmittedJobInfo.
envsettings = []
for key, value in self.envvars.items():
try:
envsettings.append(bm.EnvironmentSetting(name=key, value=value))
except KeyError:
continue
env_settings = []
for key, value in self.envvars().items():
env_settings.append(bm.EnvironmentSetting(name=key, value=value))

exec_job = self.format_job_exec(job)
remote_command = f"/bin/bash -c {shlex.quote(exec_job)}"
@@ -416,7 +411,7 @@ def run_job(self, job: JobExecutorInterface):
command_line=remote_command,
container_settings=task_container_settings,
user_identity=bm.UserIdentity(auto_user=user),
environment_settings=envsettings,
environment_settings=env_settings,
)

job_info = SubmittedJobInfo(job, external_jobid=task_id)
@@ -430,17 +425,25 @@ def run_job(self, job: JobExecutorInterface):
self.report_job_submission(job_info)

def _report_pool_errors(self, job: SubmittedJobInfo):
"""report batch pool errors"""
"""Report batch pool errors.
This method is responsible for reporting any resize errors that are detected
from the Azure Batch pool.
"""
errors = []
pool = self.batch_client.pool.get(self.pool_id)
if pool.resize_errors:
for e in pool.resize_errors:
err_dict = {"code": e.code, "message": e.message}
errors.append(err_dict)
self.report_job_error(job, msg=f"Batch pool error: {e}")
self.report_job_error(job, msg=f"Batch pool error: {e}. ")

def _report_task_status(self, job: SubmittedJobInfo):
"""report batch task status. Return True if still running, False if not"""
"""Report batch task status.
Returns:
bool: True if the task is still running, False otherwise.
"""
try:
task: bm.CloudTask = self.batch_client.task.get(
job_id=self.job_id, task_id=job.external_jobid
@@ -450,9 +453,11 @@ def _report_task_status(self, job: SubmittedJobInfo):
return True

self.logger.debug(
f"task {task.id}: "
f"creation_time={task.creation_time} "
f"state={task.state} node_info={task.node_info}\n"
{
"task": task.id,
"state": str(task.state),
"creation_time": str(task.creation_time),
}
)

if task.state == bm.TaskState.completed:
Expand All @@ -469,14 +474,6 @@ def _report_task_status(self, job: SubmittedJobInfo):
self.report_job_error(job, msg=msg, stderr=stderr, stdout=stdout)
elif ei.result == bm.TaskExecutionResult.success:
self.report_job_success(job)
else:
msg = f"\nUnknown task execution result: {ei.__dict__}\n"
self.report_job_error(
job,
msg=msg,
stderr=stderr,
stdout=stdout,
)
return False
else:
return True
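Neither _report_pool_errors nor _report_task_status is called within this diff; in plugins built on snakemake-interface-executor-plugins they are typically driven from the executor's check_active_jobs loop. A hedged sketch of such a loop (an assumption, not part of this commit):

async def check_active_jobs(self, active_jobs):
    # Yield back each job that is still running so the scheduler keeps
    # polling it; completed jobs are reported inside _report_task_status()
    # and simply not yielded again.
    for job in active_jobs:
        async with self.status_rate_limiter:
            self._report_pool_errors(job)
            if self._report_task_status(job):
                yield job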
@@ -504,15 +501,15 @@ def _report_node_errors(self):
stderr_file = self.batch_client.file.get_from_compute_node(
self.pool_id, n.id, "/startup/stderr.txt"
)
stderr_stream = self._read_stream_as_string(stderr_file, "utf-8")
stderr_stream = read_stream_as_string(stderr_file, "utf-8")
except Exception:
stderr_stream = ""

try:
stdout_file = self.batch_client.file.get_from_compute_node(
self.pool_id, n.id, "/startup/stdout.txt"
)
stdout_stream = self._read_stream_as_string(stdout_file, "utf-8")
stdout_stream = read_stream_as_string(stdout_file, "utf-8")
except Exception:
stdout_stream = ""

@@ -563,7 +560,16 @@ def cancel_jobs(self, active_jobs: List[SubmittedJobInfo]):
self.batch_client.task.terminate(self.job_id, task.id)

def create_batch_pool(self):
"""Creates a pool of compute nodes"""
"""Creates a pool of compute nodes.
This method is responsible for creating a pool of compute nodes in Azure Batch.
Returns:
None
Raises:
WorkflowError: If there is an error creating the pool.
"""

image_ref = ImageReference(
publisher=self.settings.pool_image_publisher,
@@ -615,9 +621,11 @@ def create_batch_pool(self):
else:
raise WorkflowError(
"No container registry authentication scheme set. Please set the "
"BATCH_CONTAINER_REGISTRY_USER and BATCH_CONTAINER_REGISTRY_PASS "
"or set MANAGED_IDENTITY_CLIENT_ID and "
"MANAGED_IDENTITY_RESOURCE_ID."
"SNAKEMAKE_AZURE_BATCH_CONTAINER_REGISTRY_USER and "
"SNAKEMAKE_AZURE_BATCH_CONTAINER_REGISTRY_PASS "
"or set SNAKEMAKE_AZURE_BATCH_MANAGED_IDENTITY_CLIENT_ID and "
"SNAKEMAKE_AZURE_BATCH_MANAGED_IDENTITY_RESOURCE_ID "
"and Grant it permissions to the Azure Container Registry."
)

registry_conf = [
@@ -699,7 +707,7 @@ def create_batch_pool(self):
target_node_communication_mode=NodeCommunicationMode.CLASSIC,
)
try:
self.logger.info(f"Creating pool: {self.pool_id}")
self.logger.info(f"Creating Batch Pool: {self.pool_id}")
# we use the azure.mgmt.batch client to create the pool here because if you
# configure a managed identity for the batch nodes, the azure.batch client
# does not correctly apply it to the pool
@@ -719,9 +727,20 @@ def create_batch_pool(self):
self.logger.info(f"Pool {self.pool_id} exists.")

def create_batch_job(self):
"""Creates a job with the specified ID, associated with the specified pool"""
"""Creates a job with the specified ID, associated with the specified pool.
self.logger.info(f"Creating batch job {self.job_id}")
Args:
job_id (str): The ID of the job.
pool_id (str): The ID of the pool associated with the job.
Returns:
None
Raises:
WorkflowError: If there is an error creating the job.
"""

self.logger.info(f"Creating Batch Job: {self.job_id}")

try:
self.batch_client.job.add(
@@ -734,33 +753,32 @@ def create_batch_job(self):
except bm.BatchErrorException as e:
raise WorkflowError("Error adding batch job", e)

# from https://github.com/Azure-Samples/batch-python-quickstart/blob/master/src/python_quickstart_client.py # noqa
@staticmethod
def _read_stream_as_string(stream, encoding):
"""Read stream as string
:param stream: input stream generator
:param str encoding: The encoding of the file. The default is utf-8.
:return: The file content.
:rtype: str
"""
output = io.BytesIO()
try:
for data in stream:
output.write(data)
if encoding is None:
encoding = "utf-8"
return output.getvalue().decode(encoding)
finally:
output.close()

# adapted from
# https://github.com/Azure-Samples/batch-python-quickstart/blob/master/src/python_quickstart_client.py # noqa
def _get_task_output(self, job_id, task_id, stdout_or_stderr, encoding=None):
"""
Retrieves the content of the specified task's stdout or stderr file.
Args:
job_id (str): The ID of the job that contains the task.
task_id (str): The ID of the task.
stdout_or_stderr (str): Specifies whether to retrieve the stdout or stderr
file content. Must be either "stdout" or "stderr".
encoding (str, optional): The encoding to use when reading the file content.
Defaults to None.
Returns:
str: The content of the specified stdout or stderr file, or an empty string
if the file does not exist.
Raises:
Exception: If an error occurs while retrieving the file content.
"""
assert stdout_or_stderr in ["stdout", "stderr"]
fname = stdout_or_stderr + ".txt"
try:
stream = self.batch_client.file.get_from_task(job_id, task_id, fname)
content = self._read_stream_as_string(stream, encoding)
content = read_stream_as_string(stream, encoding)
except Exception:
content = ""
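The remainder of the method is collapsed in this view (it presumably returns content). For context, a typical call site matching the completed-task branch earlier in this diff (illustrative, not part of the commit):

stdout = self._get_task_output(self.job_id, job.external_jobid, "stdout")
stderr = self._get_task_output(self.job_id, job.external_jobid, "stderr")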

