From cb20442200b380d676231b7701ae05b7e7853e11 Mon Sep 17 00:00:00 2001
From: =?UTF-8?q?Maciej=20Kami=C5=84ski?=
Date: Wed, 27 Nov 2024 16:02:35 +0000
Subject: [PATCH] Replace CancellationToken with shared variable

CancellationToken was a thin wrapper around threading.Event used only as
a boolean flag, and the token had to be threaded through every function
signature. Replace it with a single module-level shutdown_requested flag:
the signal handlers (and the KeyboardInterrupt path in main) set it,
worker threads only poll it. In CPython a plain read or write of such a
flag is atomic under the GIL, so no extra locking is needed.
---
 scripts/data/client.py | 94 +++++++++++++++++++-----------------------
 1 file changed, 42 insertions(+), 52 deletions(-)

diff --git a/scripts/data/client.py b/scripts/data/client.py
index b40e5832..468a68dd 100755
--- a/scripts/data/client.py
+++ b/scripts/data/client.py
@@ -30,18 +30,7 @@
 current_weight = 0
 weight_lock = threading.Condition()
 job_queue = queue.Queue(maxsize=QUEUE_MAX_SIZE)
-
-
-class CancellationToken:
-    def __init__(self):
-        self._is_cancelled = threading.Event()
-
-    def cancel(self):
-        self._is_cancelled.set()
-
-    def is_cancelled(self):
-        return self._is_cancelled.is_set()
-
+shutdown_requested = False
 
 class ShutdownRequested(Exception):
     """Raised when shutdown is requested during process execution"""
@@ -49,12 +38,15 @@
     pass
 
 
-def run(cmd, timeout=None, cancellation_token=None):
+def run(cmd, timeout=None):
     """
-    Run a subprocess with proper cancellation handling
+    Run a subprocess with proper shutdown handling
     """
-    if cancellation_token and cancellation_token.is_cancelled():
-        raise ShutdownRequested("Cancellation requested before process start")
+
+    global shutdown_requested
+
+    if shutdown_requested:
+        raise ShutdownRequested("Shutdown requested before process start")
 
     process = subprocess.Popen(
         cmd, stdout=subprocess.PIPE, stderr=subprocess.PIPE, text=True
@@ -63,14 +55,14 @@
     try:
         stdout, stderr = process.communicate(timeout=timeout)
 
-        if cancellation_token and cancellation_token.is_cancelled():
+        if shutdown_requested:
             process.terminate()
             try:
                 process.wait(timeout=5)
             except subprocess.TimeoutExpired:
                 process.kill()
                 process.wait()
-            raise ShutdownRequested("Cancellation requested during process execution")
+            raise ShutdownRequested("Shutdown requested during process execution")
 
         return stdout, stderr, process.returncode
 
@@ -111,7 +103,7 @@ def __str__(self):
 
 # Generator function to create jobs
 def job_generator(
-    start, blocks, step, mode, strategy, execute_scripts, cancellation_token=None
+    start, blocks, step, mode, strategy, execute_scripts
 ):
     BASE_DIR.mkdir(exist_ok=True)
     end = start + blocks
@@ -123,8 +115,6 @@ def job_generator(
     )
 
     for height in height_range:
-        if cancellation_token and cancellation_token.is_cancelled():
-            break
         try:
             batch_file = BASE_DIR / f"{mode}_{height}_{step}.json"
 
@@ -142,7 +132,7 @@
             logger.error(f"Error while generating data for: {height}:\n{e}")
 
 
-def process_batch(job, cancellation_token=None):
+def process_batch(job):
     arguments_file = job.batch_file.as_posix().replace(".json", "-arguments.json")
 
     with open(arguments_file, "w") as af:
@@ -159,8 +149,7 @@ def process_batch(job, cancellation_token=None):
             "main",
             "--arguments-file",
             str(arguments_file),
-        ],
-        cancellation_token=cancellation_token,
+        ]
     )
 
     if (
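
A note on the hunks above: dropping the Event in favour of reads of a bare
module-level bool is safe here because CPython's GIL makes a single load or
store of a global atomic. With one writer (the signal handler) and many
polling readers, the worst case is that a worker notices the flip one poll
interval late. A minimal standalone sketch of the pattern, illustrative only
and not code from client.py (the worker and flag names are made up):

    import threading
    import time

    shutdown_requested = False  # single writer: the main thread

    def worker(n):
        # Readers only poll the flag; no lock is needed for a plain bool.
        while not shutdown_requested:
            time.sleep(0.1)  # stand-in for one unit of real work
        print(f"worker {n} exiting")

    threads = [threading.Thread(target=worker, args=(i,)) for i in range(3)]
    for t in threads:
        t.start()
    time.sleep(0.5)
    shutdown_requested = True  # a single atomic store under the GIL
    for t in threads:
        t.join()
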
@@ -205,22 +194,21 @@ def process_batch(job, cancellation_token=None):
         logger.warning(f"{job}: no gas info found")
 
 # Producer function: Generates data and adds jobs to the queue
-def job_producer(job_gen, cancellation_token=None):
+def job_producer(job_gen):
     global current_weight
-
+    global shutdown_requested
+
     try:
         for job, weight in job_gen:
-            if cancellation_token and cancellation_token.is_cancelled():
-                break
+            if shutdown_requested:
+                break
 
             # Wait until there is enough weight capacity to add the new block
             with weight_lock:
                 logger.debug(
                     f"Adding job: {job}, current total weight: {current_weight}..."
                 )
-                while not (
-                    cancellation_token and cancellation_token.is_cancelled()
-                ) and (
+                while not shutdown_requested and (
                     (current_weight + weight > MAX_WEIGHT_LIMIT)
                     and current_weight != 0
                     or job_queue.full()
@@ -228,7 +216,7 @@
                     logger.debug("Producer is waiting for weight to be released.")
                     weight_lock.wait(timeout=1.0)
 
-                if cancellation_token and cancellation_token.is_cancelled():
+                if shutdown_requested:
                     break
 
                 if (current_weight + weight > MAX_WEIGHT_LIMIT) and current_weight == 0:
@@ -246,7 +234,7 @@
     finally:
         logger.debug("Producer is exiting...")
        # Signal end of jobs to consumers
-        if not(cancellation_token and cancellation_token.is_cancelled()):
+        if not shutdown_requested:
             for _ in range(THREAD_POOL_SIZE):
                 logger.warning(f"Producer is putting None into the queue..., full: {job_queue.full()}")
                 job_queue.put(None, block=False)
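
A note on the producer and consumer loops: a bare bool, unlike the Event the
old token wrapped, cannot wake a blocked thread. That is why every blocking
call near the flag keeps a timeout (weight_lock.wait(timeout=1.0) above,
job_queue.get(timeout=1.0) below), guaranteeing each thread re-checks the
flag at least once per second. A sketch of the consumer half of that
pattern, including the None sentinel the producer enqueues; illustrative
only, and handle() is a hypothetical stand-in for the real batch processing:

    import queue

    shutdown_requested = False
    job_queue = queue.Queue()

    def handle(item):
        print("processing", item)  # hypothetical stand-in for real work

    def consume():
        while not shutdown_requested:
            try:
                # Block for at most one second, then re-check the flag.
                item = job_queue.get(timeout=1.0)
            except queue.Empty:
                continue
            if item is None:  # the producer's end-of-work sentinel
                job_queue.task_done()
                break
            handle(item)
            job_queue.task_done()

    job_queue.put("batch_1")
    job_queue.put(None)  # sentinel: no more work
    consume()

The one-second timeout is the latency ceiling for shutdown: a thread parked
in one of these calls reacts to the flag within a second at worst.
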
@@ -258,15 +246,16 @@
 
 # Consumer function: Processes blocks from the queue
-def job_consumer(process_job, cancellation_token=None):
+def job_consumer(process_job):
     global current_weight
+    global shutdown_requested
 
-    while not (cancellation_token and cancellation_token.is_cancelled()):
+    while not shutdown_requested:
         try:
             logger.debug(
                 f"Consumer is waiting for a job. Queue length: {job_queue.qsize()}"
             )
-            # Get a job from the queue with timeout to check cancellation
+            # Get a job from the queue with timeout to check shutdown
             try:
                 work_to_do = job_queue.get(timeout=1.0)
             except queue.Empty:
                 continue
@@ -279,7 +268,7 @@
 
             (job, weight) = work_to_do
 
-            if cancellation_token and cancellation_token.is_cancelled():
+            if shutdown_requested:
                 with weight_lock:
                     current_weight -= weight
                     weight_lock.notify_all()
@@ -289,9 +278,9 @@
             # Process the block
             try:
                 logger.debug(f"Executing job: {job}...")
-                process_job(job, cancellation_token)
+                process_job(job)
             except ShutdownRequested:
-                logger.debug(f"Cancellation requested while processing {job}")
+                logger.debug(f"Shutdown requested while processing {job}")
                 return
             except subprocess.TimeoutExpired:
                 logger.warning(f"Timeout while terminating subprocess for {job}")
@@ -309,21 +298,22 @@
                 job_queue.task_done()
 
         except Exception as e:
-            if not (cancellation_token and cancellation_token.is_cancelled()):
+            # Errors surfacing mid-shutdown are expected; log only real failures.
+            if not shutdown_requested:
                 logger.error("Error in the consumer: %s", e)
             break
 
     logger.debug("Job consumer done.")
 
 
 def main(start, blocks, step, mode, strategy, execute_scripts):
-    # Create a centralized cancellation mechanism
-    cancellation_token = CancellationToken()
+    global shutdown_requested
 
-    # Set up signal handlers to use the cancellation token
+    # Set up signal handlers
     def signal_handler(signum, frame):
+        global shutdown_requested
         signal_name = signal.Signals(signum).name
         logger.info(f"Received signal {signal_name}. Initiating graceful shutdown...")
-        cancellation_token.cancel()
+        shutdown_requested = True
 
     signal.signal(signal.SIGTERM, signal_handler)
     signal.signal(signal.SIGINT, signal_handler)
@@ -346,37 +336,37 @@ def signal_handler(signum, frame):
     # Create the job generator
     job_gen = job_generator(
-        start, blocks, step, mode, strategy, execute_scripts, cancellation_token
+        start, blocks, step, mode, strategy, execute_scripts
     )
 
     # Start the job producer thread
     producer_thread = threading.Thread(
-        target=job_producer, args=(job_gen, cancellation_token)
+        target=job_producer, args=(job_gen,)
     )
     producer_thread.start()
 
     # Start the consumer threads using ThreadPoolExecutor
     with ThreadPoolExecutor(max_workers=THREAD_POOL_SIZE) as executor:
         futures = [
-            executor.submit(job_consumer, process_batch, cancellation_token)
+            executor.submit(job_consumer, process_batch)
             for _ in range(THREAD_POOL_SIZE)
         ]
 
-        # Wait for producer to finish or cancellation
+        # Wait for producer to finish or shutdown
         producer_thread.join()
 
-        # Wait for all items in the queue to be processed or cancellation
+        # Wait for all items in the queue to be processed or shutdown
         while (
-            not (cancellation_token and cancellation_token.is_cancelled())
+            not shutdown_requested
             and not job_queue.empty()
         ):
             try:
                 job_queue.join()
                 break
             except KeyboardInterrupt:
-                cancellation_token.cancel()
+                shutdown_requested = True
 
-        if cancellation_token.is_cancelled():
+        if shutdown_requested:
             logger.info("Shutdown complete.")
         else:
             logger.info("All jobs have been processed.")
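
One trade-off worth recording: the deleted CancellationToken wrapped
threading.Event, and Event remains the stdlib tool if prompt wake-ups ever
matter more than the simplicity of a bool, since event.wait(timeout) returns
the moment the flag is set instead of sleeping out the full poll interval.
A sketch of the same signal wiring built on Event, illustrative only and not
code from client.py:

    import signal
    import threading

    shutdown_event = threading.Event()

    def signal_handler(signum, frame):
        # Keep the handler minimal: record the request, let threads react.
        shutdown_event.set()

    signal.signal(signal.SIGTERM, signal_handler)
    signal.signal(signal.SIGINT, signal_handler)

    def worker():
        while not shutdown_event.is_set():
            pass  # one unit of real work would go here
            shutdown_event.wait(timeout=1.0)  # returns early once set

Either primitive works with the polling loops in this patch; the bool just
keeps the module surface smaller.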