diff --git a/src/engine/src/core/ioc/IOCContainerFactory.py b/src/engine/src/core/ioc/IOCContainerFactory.py
index ec6d2308..aedaef18 100644
--- a/src/engine/src/core/ioc/IOCContainerFactory.py
+++ b/src/engine/src/core/ioc/IOCContainerFactory.py
@@ -1,5 +1,3 @@
-from functools import partial
-
 from core.ioc import IOCContainer
 from core.state import ReactiveState
 from core.daos import (
@@ -24,6 +22,7 @@
     TaskRepository,
     TemplateRepository
 )
+from core.tasks.TaskInputFileStagingService import TaskInputFileStagingService
 from core.workflows import (
     GraphValidator,
     ValueFromService
@@ -119,6 +118,12 @@ def build(self):
         #     )
         # )
 
+        container.register("TaskInputFileStagingService",
+            lambda: TaskInputFileStagingService(
+                container.load("ValueFromService")
+            )
+        )
+
         container.register("TaskOutputRepository",
             lambda: TaskOutputRepository(
                 container.load("TaskOutputMapper")
diff --git a/src/engine/src/core/tasks/TaskInputFileStagingService.py b/src/engine/src/core/tasks/TaskInputFileStagingService.py
new file mode 100644
index 00000000..5e165026
--- /dev/null
+++ b/src/engine/src/core/tasks/TaskInputFileStagingService.py
@@ -0,0 +1,68 @@
+import os
+
+from core.workflows import ValueFromService
+from owe_python_sdk.schema import Task
+from errors.tasks import TaskInputStagingError
+
+
+class TaskInputFileStagingService:
+    """Responsible for creating the actual files that will be used as inputs
+    during a task execution. Possible sources of data for the input files may be
+    one or all of the following:
+    1) The value property of an input specification
+    2) An argument passed as part of a pipeline invocation
+    3) A variable from the environment with which the pipeline was invoked
+    4) An output file from another task upon which the staging task is dependent
+    """
+    def __init__(
+        self,
+        value_from_service: ValueFromService
+    ):
+        self._value_from_service = value_from_service
+
+    def stage(self, task: Task):
+        """Iterates over all of the items in the task input dictionary, fetches
+        the values from their sources, then creates the files in the task's
+        working directory"""
+        for input_id, input_ in task.input.items():
+            # Literal values need no lookup. Write the file and move on.
+            if input_.value is not None:
+                self._create_input_(task, input_id, input_.value)
+                continue
+
+            value_from = input_.value_from
+            key = list(value_from.keys())[0] # NOTE Should only have 1 key
+            if key == "task_output":
+                try:
+                    value = self._value_from_service.get_task_output_value_by_id(
+                        task_id=value_from[key].task_id,
+                        _id=value_from[key].output_id
+                    )
+                except Exception:
+                    raise TaskInputStagingError(f"No output found for task '{value_from[key].task_id}' with output id of '{value_from[key].output_id}'")
+            elif key == "args":
+                try:
+                    value = self._value_from_service.get_arg_value_by_key(
+                        value_from[key]
+                    )
+                except Exception:
+                    raise TaskInputStagingError(f"Error attempting to fetch value from args at key '{value_from[key]}'")
+            elif key == "env":
+                try:
+                    value = self._value_from_service.get_env_value_by_key(
+                        value_from[key]
+                    )
+                except Exception:
+                    raise TaskInputStagingError(f"Error attempting to fetch value from env at key '{value_from[key]}'")
+            else:
+                raise TaskInputStagingError(f"Unsupported value_from source '{key}'")
+
+            self._create_input_(task, input_id, value)
+
+    def _create_input_(self, task, input_id, value):
+        try:
+            with open(os.path.join(task.input_dir, input_id), mode="w") as file:
+                if value is None: value = ""
+                file.write(str(value))
+        except Exception as e:
+            raise TaskInputStagingError(f"Error while staging input: {e}")
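For readers who want the new service's contract at a glance, here is a minimal usage sketch. Everything except `TaskInputFileStagingService` itself is a stand-in: `StubValueFromService` and the `SimpleNamespace` task mimic only the attributes `stage()` touches, and the import assumes `src/engine/src` is on the Python path.

```python
import os
import tempfile
from types import SimpleNamespace

from core.tasks.TaskInputFileStagingService import TaskInputFileStagingService

# Stand-in for core.workflows.ValueFromService -- only the three lookup
# methods stage() calls are stubbed here.
class StubValueFromService:
    def get_task_output_value_by_id(self, task_id, _id):
        return f"output of {task_id}.{_id}"
    def get_arg_value_by_key(self, key):
        return f"arg {key}"
    def get_env_value_by_key(self, key):
        return f"env {key}"

# Stand-in task: one literal-valued input, one resolved from a task output
task = SimpleNamespace(
    input_dir=tempfile.mkdtemp(),
    input={
        "config": SimpleNamespace(value="verbose=true", value_from=None),
        "upstream": SimpleNamespace(
            value=None,
            value_from={"task_output": SimpleNamespace(task_id="t1", output_id="o1")},
        ),
    },
)

TaskInputFileStagingService(StubValueFromService()).stage(task)
print(sorted(os.listdir(task.input_dir)))  # ['config', 'upstream']
```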
diff --git a/src/engine/src/core/tasks/executors/Function.py b/src/engine/src/core/tasks/executors/Function.py
index cbaebc69..66b63530 100644
--- a/src/engine/src/core/tasks/executors/Function.py
+++ b/src/engine/src/core/tasks/executors/Function.py
@@ -15,7 +15,6 @@
 from utils.k8s import flavor_to_k8s_resource_reqs, input_to_k8s_env_vars, gen_resource_name
 from core.tasks import function_bootstrap
 from core.repositories import GitCacheRepository
-from errors import WorkflowTerminated
 
 
 class ContainerDetails:
diff --git a/src/engine/src/core/workflows/WorkflowExecutor.py b/src/engine/src/core/workflows/WorkflowExecutor.py
index b149ef68..81aee2b1 100644
--- a/src/engine/src/core/workflows/WorkflowExecutor.py
+++ b/src/engine/src/core/workflows/WorkflowExecutor.py
@@ -40,8 +40,8 @@
 server_logger = logging.getLogger("server")
 
 
-def interceptable(rollback=None): # Decorator factory
-    def interceptable_decorator(fn): # Decorator
+def interruptable(rollback=None): # Decorator factory
+    def interruptable_decorator(fn): # Decorator
         def wrapper(self, *args, **kwargs): # Wrapper
             rollback_fn = getattr(self, (rollback or ""), None)
             try:
@@ -56,14 +56,14 @@ def wrapper(self, *args, **kwargs): # Wrapper
                 server_logger.debug(f"Workflow Termination Signal Detected: Terminating:{self.state.terminating}/Terminated:{self.state.terminated}")
                 if self.state.terminating or self.state.terminated:
                     # Run the rollback function by the name provided in the
-                    # interceptable decorator factory args
+                    # interruptable decorator factory args
                     rollback_fn and rollback_fn()
                     return
 
                 raise e
 
         return wrapper
-    return interceptable_decorator
+    return interruptable_decorator
 
 
 class WorkflowExecutor(Worker, EventPublisher):
     """The Workflow Executor is responsible for processing and executing tasks for
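The `interceptable` to `interruptable` rename is mechanical, but the three-layer shape is easy to misread, so here is a standalone sketch of the same decorator-factory pattern. `FakeState` and `Demo` are illustrative stand-ins, not the engine's real classes.

```python
# Minimal sketch of the interruptable() decorator-factory pattern used above.
class FakeState:
    terminating = False
    terminated = False

def interruptable(rollback=None):            # Decorator factory: captures config
    def interruptable_decorator(fn):         # Decorator: receives the method
        def wrapper(self, *args, **kwargs):  # Wrapper: runs at call time
            rollback_fn = getattr(self, (rollback or ""), None)
            try:
                return fn(self, *args, **kwargs)
            except Exception:
                # Swallow the error only when a termination signal is set,
                # optionally running the named rollback; otherwise re-raise.
                if self.state.terminating or self.state.terminated:
                    rollback_fn and rollback_fn()
                    return
                raise
        return wrapper
    return interruptable_decorator

class Demo:
    state = FakeState()

    def _cleanup(self):
        print("rolled back")

    @interruptable(rollback="_cleanup")
    def run(self):
        self.state.terminating = True
        raise RuntimeError("boom")

Demo().run()  # prints "rolled back" instead of raising
```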
@@ -141,7 +141,7 @@ def __init__(self, _id=None, plugins=[]):
     def p_str(self, status): return f"{lbuf('[PIPELINE]')} {self.state.ctx.idempotency_key} {status} {self.state.ctx.pipeline.id}"
     def t_str(self, task, status): return f"{lbuf('[TASK]')} {self.state.ctx.idempotency_key} {status} {self.state.ctx.pipeline.id}.{task.id}"
 
-    @interceptable()
+    @interruptable()
     def start(self, ctx, threads):
         """This method is the entrypoint for a workflow exection. It's
         invoked by the main Server instance when a workflow submission is
@@ -178,7 +178,7 @@ def start(self, ctx, threads):
             # Trigger the terminal state callback.
             self._on_pipeline_terminal_state(event=PIPELINE_FAILED, message=str(e))
 
-    @interceptable()
+    @interruptable()
     def _staging(self, ctx):
         # Resets the workflow executor to its initial state
         self._set_initial_state()
@@ -219,7 +219,7 @@ def _staging(self, ctx):
         except Exception as e:
             self._on_pipeline_terminal_state(PIPELINE_FAILED, message=str(e))
 
-    @interceptable()
+    @interruptable()
     def _prepare_tasks(self):
         """This function adds information about the pipeline context to the task
         objects, prepares the file system for each task execution, handles task templating,
@@ -279,7 +279,7 @@ def _prepare_tasks(self):
             # Register the task executor
             self._register_executor(self.state.ctx.pipeline_run.uuid, task, executor)
 
-    @interceptable()
+    @interruptable()
     def _prepare_task_fs(self, task):
         # Create the base directory for all files and output created during this task execution
         os.makedirs(task.work_dir, exist_ok=True)
@@ -297,7 +297,7 @@ def _prepare_task_fs(self, task):
         Path(task.stdout).touch()
         Path(task.stderr).touch()
 
-    @interceptable()
+    @interruptable()
     def _start_task(self, task):
         # Check if any of the previous tasks were skipped. If yes, and the current
         # task's dependency specifies a can_skip == False for any of the skipped tasks,
@@ -324,6 +324,12 @@ def _start_task(self, task):
 
         # Execute the task
         if not skip and not expression_error:
+            # Stage task inputs
+            task_input_file_staging_service = self.container.load(
+                "TaskInputFileStagingService"
+            )
+            task_input_file_staging_service.stage(task)
+
             # Log the task status
             self.state.ctx.logger.info(self.t_str(task, "ACTIVE"))
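`_start_task` resolves the staging service through the IoC container registered in `IOCContainerFactory`. The container's exact semantics (e.g., whether `load` caches instances) aren't shown in this diff, so the sketch below is a toy `register`/`load` pair that illustrates only the deferred-construction pattern the registration lambdas rely on.

```python
# Toy IoC container illustrating the register/load pattern in this diff.
# The real IOCContainer's behavior is assumed, not confirmed.
class ToyContainer:
    def __init__(self):
        self._factories = {}

    def register(self, name, factory):
        # Store a zero-arg factory; construction is deferred until load()
        self._factories[name] = factory

    def load(self, name):
        return self._factories[name]()

container = ToyContainer()
container.register("ValueFromService", lambda: object())  # stand-in dependency
container.register(
    "TaskInputFileStagingService",
    lambda: ("staging-service-with", container.load("ValueFromService"))
)
print(container.load("TaskInputFileStagingService"))
```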
@@ -332,7 +338,10 @@ def _start_task(self, task):
 
             try:
                 # Fetch the executor
-                executor: TaskExecutor = self._get_executor(self.state.ctx.pipeline_run.uuid, task)
+                executor: TaskExecutor = self._get_executor(
+                    self.state.ctx.pipeline_run.uuid,
+                    task
+                )
 
                 # Run the task executor and get the task result
                 task_result = executor.execute()
@@ -355,7 +364,7 @@ def _start_task(self, task):
             # NOTE Triggers hook _on_change_ready_task
             self.state.ready_tasks += unstarted_threads
 
-    @interceptable()
+    @interruptable()
     def _on_task_terminal_state(self, task, task_result):
         # Determine the correct callback to use.
         callback = self._on_task_completed
@@ -389,7 +398,7 @@ def _on_task_terminal_state(self, task, task_result):
         # Execute all possible queued tasks
         return self._fetch_ready_tasks()
 
-    @interceptable()
+    @interruptable()
     def _on_pipeline_terminal_state(self, event=None, message=""):
         # No event was provided. Determine if complete or failed from number
         # of failed tasks
@@ -419,7 +428,7 @@ def _on_pipeline_terminal_state(self, event=None, message=""):
 
         self._set_initial_state()
 
-    @interceptable()
+    @interruptable()
     def _on_task_completed(self, task, task_result):
         # Log the completion
         self.state.ctx.logger.info(self.t_str(task, "COMPLETED"))
@@ -432,7 +441,7 @@ def _on_task_completed(self, task, task_result):
         self.state.finished.append(task.id)
         self.state.succeeded.append(task.id)
 
-    @interceptable()
+    @interruptable()
     def _on_task_skipped(self, task, _):
         # Log the task active
         self.state.ctx.logger.info(self.t_str(task, "SKIPPED"))
@@ -444,7 +453,7 @@ def _on_task_skipped(self, task, _):
         self.state.finished.append(task.id)
         self.state.skipped.append(task.id)
 
-    @interceptable()
+    @interruptable()
     def _on_task_failed(self, task, task_result):
         # Log the failure
         self.state.ctx.logger.info(self.t_str(task, f"FAILED: {task_result.errors}"))
@@ -456,7 +465,7 @@ def _on_task_failed(self, task, task_result):
         self.state.finished.append(task.id)
         self.state.failed.append(task.id)
 
-    @interceptable()
+    @interruptable()
     def _get_initial_tasks(self, tasks):
         initial_tasks = [task for task in tasks if len(task.depends_on) == 0]
 
@@ -467,7 +476,7 @@ def _get_initial_tasks(self, tasks):
 
         return initial_tasks
 
-    @interceptable()
+    @interruptable()
     def _set_tasks(self, tasks):
         # Create a list of the ids of the tasks
         task_ids = [task.id for task in tasks]
@@ -542,36 +551,14 @@ def _set_tasks(self, tasks):
         # Add all tasks to the queue
         self.state.queue = [ task for task in self.state.tasks ]
 
-    @interceptable()
+    @interruptable()
     def _prepare_pipeline(self):
         # Create all of the directories needed for the pipeline to run and persist results and cache
         self._prepare_pipeline_fs()
-
-        # Persist each arg to files in the pipeline file system
-        arg_value_file_repo = self.container.load("ArgValueFileRepository")
-        for key in self.state.ctx.args:
-            arg_value_file_repo.save(
-                self.state.ctx.pipeline.args_dir + key,
-                self.state.ctx.args[key].value
-            )
-
-        # Persist each arg to files in the pipeline file system
-        env_var_value_file_repo = self.container.load("EnvVarValueFileRepository")
-        for key in self.state.ctx.pipeline.env:
-            env_var_value_file_repo.save(
-                self.state.ctx.pipeline.env_dir + key,
-                self.state.ctx.env[key].value
-            )
+
+        # TODO Perform template mapping at the pipeline level here.
 
-        # template_mapper = TemplateMapper(cache_dir=self.state.ctx.pipeline.git_cache_dir)
-        # if self.state.ctx.pipeline.uses != None:
-        #     self.state.ctx.pipeline = template_mapper.map(
-        #         self.state.ctx.pipeline,
-        #         self.state.ctx.pipeline.uses
-        #     )
-
-    @interceptable()
+    @interruptable()
     def _prepare_pipeline_fs(self):
         """Creates all of the directories necessary to run the pipeline,
         store temp files, and cache data"""
@@ -627,7 +614,25 @@ def _prepare_pipeline_fs(self):
         # the state is reset. (Which means that ther will be no self.state.ctx.pipeline.work_dir)
         self.work_dir = self.state.ctx.pipeline.work_dir
 
-    @interceptable()
+        # Persist each arg to files in the pipeline file system
+        arg_value_file_repo = self.container.load("ArgValueFileRepository")
+        arg_repo = self.container.load("ArgRepository")
+        for key in self.state.ctx.args:
+            arg_value_file_repo.save(
+                self.state.ctx.pipeline.args_dir + key,
+                arg_repo.get_value_by_key(key)
+            )
+
+        # Persist each env var to files in the pipeline file system
+        env_var_value_file_repo = self.container.load("EnvVarValueFileRepository")
+        env_repo = self.container.load("EnvRepository")
+        for key in self.state.ctx.pipeline.env:
+            env_var_value_file_repo.save(
+                self.state.ctx.pipeline.env_dir + key,
+                env_repo.get_value_by_key(key)
+            )
+
+    @interruptable()
     def _fetch_ready_tasks(self):
         ready_tasks = []
         threads = []
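The block added to `_prepare_pipeline_fs` writes one file per arg/env key, with the value as the file's content. Below is a sketch of that layout, assuming a stand-in repository (the real `ArgRepository`/`EnvRepository` expose only `get_value_by_key` as far as this diff shows) and using `os.path.join` where the diff concatenates `args_dir + key`:

```python
import os
import tempfile

# Stand-in for ArgRepository/EnvRepository -- only get_value_by_key is assumed.
class FakeRepo:
    def __init__(self, values):
        self._values = values

    def get_value_by_key(self, key):
        return self._values[key]

def persist(values_dir, keys, repo):
    os.makedirs(values_dir, exist_ok=True)
    for key in keys:
        # Mirrors `args_dir + key`: one file per key, value as its content
        with open(os.path.join(values_dir, key), "w") as f:
            f.write(str(repo.get_value_by_key(key)))

args_dir = tempfile.mkdtemp()
persist(args_dir, ["threads"], FakeRepo({"threads": 4}))
print(open(os.path.join(args_dir, "threads")).read())  # "4"
```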
@@ -645,7 +650,7 @@ def _fetch_ready_tasks(self):
 
         return threads
 
-    @interceptable()
+    @interruptable()
     def _task_is_ready(self, task):
         # All tasks without dependencies are ready immediately
         if len(task.depends_on) == 0: return True
@@ -663,23 +668,23 @@ def _task_is_ready(self, task):
 
         return True
 
-    @interceptable()
+    @interruptable()
     def _get_task_by_id(self, task_id):
         return next(filter(lambda t: t.id == task_id,self.state.tasks), None)
 
-    @interceptable()
+    @interruptable()
     def _remove_from_queue(self, task):
         len(self.state.queue) == 0 or self.state.queue.pop(self.state.queue.index(task))
 
-    @interceptable()
+    @interruptable()
     def _register_executor(self, run_uuid, task, executor):
         self.state.executors[f"{run_uuid}.{task.id}"] = executor
 
-    @interceptable()
+    @interruptable()
     def _get_executor(self, run_uuid, task, default=None):
         return self.state.executors.get(f"{run_uuid}.{task.id}", None)
 
-    @interceptable()
+    @interruptable()
     def _deregister_executor(self, run_uuid, task):
         # Clean up the resources created by the task executor
         executor = self._get_executor(run_uuid, task)
@@ -688,7 +693,7 @@ def _deregister_executor(self, run_uuid, task):
         # TODO use server logger below
         # self.state.ctx.logger.debug(self.t_str(task, "EXECUTOR DEREGISTERED"))
 
-    @interceptable()
+    @interruptable()
     def _get_executor(self, run_uuid, task):
         return self.state.executors[f"{run_uuid}.{task.id}"]
 
@@ -709,7 +714,7 @@ def reset(self, terminated=False):
         if terminated:
             self.state.terminated = True
 
-    @interceptable()
+    @interruptable()
     def _setup_loggers(self):
         # Create the logger. NOTE Directly instantiating a Logger object
         # is recommended against in the documentation, however it makes sense
@@ -724,7 +729,7 @@ def _setup_loggers(self):
         run_logger.addHandler(handler)
 
         self.state.ctx.logger = CompositeLogger([server_logger, run_logger])
 
-    @interceptable(rollback="_reset_event_exchange")
+    @interruptable(rollback="_reset_event_exchange")
     def _initialize_notification_handlers(self):
         self.state.ctx.logger.debug(self.p_str("INITIALIZING NOTIFICATION HANDLERS"))
         # Initialize the notification event handlers from plugins. Notification event handlers are used to persist updates about the
@@ -744,7 +749,7 @@ def _initialize_notification_handlers(self):
                 middleware.subscriptions
             )
 
-    @interceptable(rollback="_reset_event_exchange")
+    @interruptable(rollback="_reset_event_exchange")
     def _initialize_archivers(self):
         # No archivers specified. Return
         if len(self.state.ctx.archives) < 1: return
@@ -798,7 +803,7 @@ def _set_initial_state(self):
         self.work_dir = None
         self.can_start = False
 
-    @interceptable()
+    @interruptable()
     def _set_context(self, ctx):
         self.state.ctx = ctx
 
diff --git a/src/engine/src/errors/tasks.py b/src/engine/src/errors/tasks.py
index 38df6aeb..ad896dcc 100644
--- a/src/engine/src/errors/tasks.py
+++ b/src/engine/src/errors/tasks.py
@@ -25,3 +25,6 @@ class OperandResolutionError(WorkflowsBaseException):
 
 class ConditionalExpressionEvalError(WorkflowsBaseException):
     pass
+
+class TaskInputStagingError(WorkflowsBaseException):
+    pass
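The new exception extends the existing `WorkflowsBaseException` hierarchy, which lets callers catch staging failures specifically (as `TaskInputFileStagingService` raises them) or any workflow error generically. A sketch, with the base class stubbed since its definition isn't part of this diff:

```python
# Assumed shape of the hierarchy -- WorkflowsBaseException's real definition
# lives in errors/ and is not shown in this diff.
class WorkflowsBaseException(Exception):
    pass

class TaskInputStagingError(WorkflowsBaseException):
    pass

def stage_or_fail():
    raise TaskInputStagingError("No output found for task 't1' with output id of 'o1'")

try:
    stage_or_fail()
except TaskInputStagingError as e:
    print(f"staging failed, marking task failed: {e}")
except WorkflowsBaseException:
    print("some other workflow error")
```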