From 7410c9e525890d902a467a12a0042769c18de830 Mon Sep 17 00:00:00 2001
From: Nate Coraor <nate@bx.psu.edu>
Date: Fri, 26 Apr 2024 18:03:21 -0400
Subject: [PATCH] Do not attempt to complete pre- or post-processing if a job
 is cancelled in the middle of either stage.

---
 pulsar/managers/staging/post.py | 18 ++++++++++++------
 pulsar/managers/staging/pre.py  |  5 ++++-
 pulsar/managers/stateful.py     | 10 ++++++++--
 3 files changed, 24 insertions(+), 9 deletions(-)

diff --git a/pulsar/managers/staging/post.py b/pulsar/managers/staging/post.py
index ebe4bc6a..6ca59390 100644
--- a/pulsar/managers/staging/post.py
+++ b/pulsar/managers/staging/post.py
@@ -13,27 +13,27 @@
 log = logging.getLogger(__name__)
 
 
-def postprocess(job_directory, action_executor):
+def postprocess(job_directory, action_executor, was_cancelled):
     # Returns True if outputs were collected.
     try:
         if job_directory.has_metadata("launch_config"):
             staging_config = job_directory.load_metadata("launch_config").get("remote_staging", None)
         else:
             staging_config = None
-        collected = __collect_outputs(job_directory, staging_config, action_executor)
+        collected = __collect_outputs(job_directory, staging_config, action_executor, was_cancelled)
         return collected
     finally:
         job_directory.write_file("postprocessed", "")
     return False
 
 
-def __collect_outputs(job_directory, staging_config, action_executor):
+def __collect_outputs(job_directory, staging_config, action_executor, was_cancelled):
     collected = True
     if "action_mapper" in staging_config:
         file_action_mapper = action_mapper.FileActionMapper(config=staging_config["action_mapper"])
         client_outputs = staging.ClientOutputs.from_dict(staging_config["client_outputs"])
         pulsar_outputs = __pulsar_outputs(job_directory)
-        output_collector = PulsarServerOutputCollector(job_directory, action_executor)
+        output_collector = PulsarServerOutputCollector(job_directory, action_executor, was_cancelled)
         results_collector = ResultsCollector(output_collector, file_action_mapper, client_outputs, pulsar_outputs)
         collection_failure_exceptions = results_collector.collect()
         if collection_failure_exceptions:
@@ -62,11 +62,17 @@ def realized_dynamic_file_sources(job_directory):
 
 class PulsarServerOutputCollector:
 
-    def __init__(self, job_directory, action_executor):
+    def __init__(self, job_directory, action_executor, was_cancelled):
         self.job_directory = job_directory
         self.action_executor = action_executor
+        self.was_cancelled = was_cancelled
 
     def collect_output(self, results_collector, output_type, action, name):
+        def action_if_not_cancelled():
+            if self.was_cancelled():
+                log.info(f"Skipped output collection '{name}', job is cancelled")
+                return
+            action.write_from_path(pulsar_path)
         # Not using input path, this is because action knows it path
         # in this context.
         if action.staging_action_local:
@@ -79,7 +85,7 @@ def collect_output(self, results_collector, output_type, action, name):
 
         pulsar_path = self.job_directory.calculate_path(name, output_type)
         description = "staging out file {} via {}".format(pulsar_path, action)
-        self.action_executor.execute(lambda: action.write_from_path(pulsar_path), description)
+        self.action_executor.execute(action_if_not_cancelled, description)
 
 
 def __pulsar_outputs(job_directory):
diff --git a/pulsar/managers/staging/pre.py b/pulsar/managers/staging/pre.py
index 543e9cdf..5411237e 100644
--- a/pulsar/managers/staging/pre.py
+++ b/pulsar/managers/staging/pre.py
@@ -7,8 +7,11 @@
 log = logging.getLogger(__name__)
 
 
-def preprocess(job_directory, setup_actions, action_executor, object_store=None):
+def preprocess(job_directory, setup_actions, action_executor, was_cancelled, object_store=None):
     for setup_action in setup_actions:
+        if was_cancelled():
+            log.info("Exiting preprocessing, job is cancelled")
+            return
         name = setup_action["name"]
         input_type = setup_action["type"]
         action = from_dict(setup_action["action"])
diff --git a/pulsar/managers/stateful.py b/pulsar/managers/stateful.py
index 32b843d4..2f32ce1c 100644
--- a/pulsar/managers/stateful.py
+++ b/pulsar/managers/stateful.py
@@ -3,6 +3,7 @@
 import os
 import threading
 import time
+from functools import partial
 
 try:
     # If galaxy-lib or Galaxy 19.05 present.
@@ -103,6 +104,7 @@ def _launch_prepreprocessing_thread(self, job_id, launch_config):
         def do_preprocess():
             with self._handling_of_preprocessing_state(job_id, launch_config):
                 job_directory = self._proxied_manager.job_directory(job_id)
+                was_cancelled = partial(self._proxied_manager._was_cancelled, job_id)
                 staging_config = launch_config.get("remote_staging", {})
                 # TODO: swap out for a generic "job_extra_params"
                 if 'action_mapper' in staging_config and \
@@ -111,7 +113,7 @@ def do_preprocess():
                     for action in staging_config['setup']:
                         action['action'].update(ssh_key=staging_config['action_mapper']['ssh_key'])
                 setup_config = staging_config.get("setup", [])
-                preprocess(job_directory, setup_config, self.__preprocess_action_executor, object_store=self.object_store)
+                preprocess(job_directory, setup_config, self.__preprocess_action_executor, was_cancelled, object_store=self.object_store)
                 self.active_jobs.deactivate_job(job_id, active_status=ACTIVE_STATUS_PREPROCESSING)
 
         new_thread_for_job(self, "preprocess", job_id, do_preprocess, daemon=False)
@@ -121,6 +123,9 @@ def _handling_of_preprocessing_state(self, job_id, launch_config):
         job_directory = self._proxied_manager.job_directory(job_id)
         try:
             yield
+            if self._proxied_manager._was_cancelled(job_id):
+                log.info("Exiting job launch, job is cancelled")
+                return
             launch_kwds = {}
             if launch_config.get("dependencies_description"):
                 dependencies_description = DependenciesDescription.from_dict(launch_config["dependencies_description"])
@@ -219,8 +224,9 @@ def __handle_postprocessing(self, job_id):
         def do_postprocess():
             postprocess_success = False
             job_directory = self._proxied_manager.job_directory(job_id)
+            was_cancelled = partial(self._proxied_manager._was_cancelled, job_id)
             try:
-                postprocess_success = postprocess(job_directory, self.__postprocess_action_executor)
+                postprocess_success = postprocess(job_directory, self.__postprocess_action_executor, was_cancelled)
             except Exception:
                 log.exception("Failed to postprocess results for job id %s" % job_id)
             final_status = status.COMPLETE if postprocess_success else status.FAILED