Replace CancellationToken with shared variable
maciejka committed Nov 27, 2024
1 parent 1900d74 commit cb20442
Showing 1 changed file with 42 additions and 52 deletions.
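For context, the pattern this commit removes wraps a threading.Event in a small token object that has to be passed through every function signature; the pattern it adopts is a single module-level boolean that the signal handler sets and worker threads poll. A minimal sketch of the two approaches (the token class mirrors the removed code; worker and process are illustrative names, not from this file):

import threading

# Before: an Event-backed token threaded through every call site.
class CancellationToken:
    def __init__(self):
        self._is_cancelled = threading.Event()

    def cancel(self):
        self._is_cancelled.set()

    def is_cancelled(self):
        return self._is_cancelled.is_set()

# After: one shared flag. In CPython, assigning or reading a bool is
# atomic under the GIL, so polling it across threads needs no lock.
shutdown_requested = False

def signal_handler(signum, frame):
    global shutdown_requested  # writers need this declaration
    shutdown_requested = True

def worker(jobs):
    for job in jobs:
        if shutdown_requested:  # readers can reference the global directly
            break
        process(job)  # hypothetical per-job work function

Strictly, only functions that assign the flag need the global declaration; the diff below also adds it to readers such as run and job_consumer, which is harmless but not required.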
94 changes: 42 additions & 52 deletions scripts/data/client.py
@@ -30,31 +30,23 @@
 current_weight = 0
 weight_lock = threading.Condition()
 job_queue = queue.Queue(maxsize=QUEUE_MAX_SIZE)
 
 
-class CancellationToken:
-    def __init__(self):
-        self._is_cancelled = threading.Event()
-
-    def cancel(self):
-        self._is_cancelled.set()
-
-    def is_cancelled(self):
-        return self._is_cancelled.is_set()
-
+shutdown_requested = False
 
 class ShutdownRequested(Exception):
     """Raised when shutdown is requested during process execution"""
 
     pass
 
 
-def run(cmd, timeout=None, cancellation_token=None):
+def run(cmd, timeout=None):
     """
-    Run a subprocess with proper cancellation handling
+    Run a subprocess with proper shutdown handling
     """
-    if cancellation_token and cancellation_token.is_cancelled():
-        raise ShutdownRequested("Cancellation requested before process start")
+    global shutdown_requested
+
+    if shutdown_requested:
+        raise ShutdownRequested("Shutdown requested before process start")
 
     process = subprocess.Popen(
         cmd, stdout=subprocess.PIPE, stderr=subprocess.PIPE, text=True
@@ -63,14 +55,14 @@ def run(cmd, timeout=None, cancellation_token=None):
     try:
         stdout, stderr = process.communicate(timeout=timeout)
 
-        if cancellation_token and cancellation_token.is_cancelled():
+        if shutdown_requested:
             process.terminate()
             try:
                 process.wait(timeout=5)
             except subprocess.TimeoutExpired:
                 process.kill()
                 process.wait()
-            raise ShutdownRequested("Cancellation requested during process execution")
+            raise ShutdownRequested("Shutdown requested during process execution")
 
         return stdout, stderr, process.returncode
 
@@ -111,7 +103,7 @@ def __str__(self):
 
 # Generator function to create jobs
 def job_generator(
-    start, blocks, step, mode, strategy, execute_scripts, cancellation_token=None
+    start, blocks, step, mode, strategy, execute_scripts
 ):
     BASE_DIR.mkdir(exist_ok=True)
     end = start + blocks
@@ -123,8 +115,6 @@
     )
 
     for height in height_range:
-        if cancellation_token and cancellation_token.is_cancelled():
-            break
         try:
             batch_file = BASE_DIR / f"{mode}_{height}_{step}.json"
 
@@ -142,7 +132,7 @@
             logger.error(f"Error while generating data for: {height}:\n{e}")
 
 
-def process_batch(job, cancellation_token=None):
+def process_batch(job):
     arguments_file = job.batch_file.as_posix().replace(".json", "-arguments.json")
 
     with open(arguments_file, "w") as af:
@@ -159,8 +149,7 @@ def process_batch(job, cancellation_token=None):
             "main",
             "--arguments-file",
             str(arguments_file),
-        ],
-        cancellation_token=cancellation_token,
+        ]
     )
 
     if (
@@ -205,30 +194,29 @@ def process_batch(job, cancellation_token=None):
         logger.warning(f"{job}: no gas info found")
 
 # Producer function: Generates data and adds jobs to the queue
-def job_producer(job_gen, cancellation_token=None):
+def job_producer(job_gen):
     global current_weight
+    global shutdown_requested
 
     try:
         for job, weight in job_gen:
-            if cancellation_token and cancellation_token.is_cancelled():
-                break
+            # if shutdown_requested:
+            #     break
 
             # Wait until there is enough weight capacity to add the new block
             with weight_lock:
                 logger.debug(
                     f"Adding job: {job}, current total weight: {current_weight}..."
                 )
-                while not (
-                    cancellation_token and cancellation_token.is_cancelled()
-                ) and (
+                while not shutdown_requested and (
                     (current_weight + weight > MAX_WEIGHT_LIMIT)
                     and current_weight != 0
                     or job_queue.full()
                 ):
                     logger.debug("Producer is waiting for weight to be released.")
                     weight_lock.wait(timeout=1.0)
 
-                if cancellation_token and cancellation_token.is_cancelled():
+                if shutdown_requested:
                     break
 
                 if (current_weight + weight > MAX_WEIGHT_LIMIT) and current_weight == 0:
@@ -246,7 +234,7 @@ def job_producer(job_gen, cancellation_token=None):
     finally:
         logger.debug("Producer is exiting...")
         # Signal end of jobs to consumers
-        if not(cancellation_token and cancellation_token.is_cancelled()):
+        if not shutdown_requested:
             for _ in range(THREAD_POOL_SIZE):
                 logger.warning(f"Producer is putting None into the queue..., full: {job_queue.full()}")
                 job_queue.put(None, block=False)
@@ -258,15 +246,16 @@
 
 
 # Consumer function: Processes blocks from the queue
-def job_consumer(process_job, cancellation_token=None):
+def job_consumer(process_job):
     global current_weight
+    global shutdown_requested
 
-    while not (cancellation_token and cancellation_token.is_cancelled()):
+    while not shutdown_requested:
         try:
             logger.debug(
                 f"Consumer is waiting for a job. Queue length: {job_queue.qsize()}"
             )
-            # Get a job from the queue with timeout to check cancellation
+            # Get a job from the queue with timeout to check shutdown
             try:
                 work_to_do = job_queue.get(timeout=1.0)
             except queue.Empty:
@@ -279,7 +268,7 @@ def job_consumer(process_job, cancellation_token=None):
 
             (job, weight) = work_to_do
 
-            if cancellation_token and cancellation_token.is_cancelled():
+            if shutdown_requested:
                 with weight_lock:
                     current_weight -= weight
                     weight_lock.notify_all()
@@ -289,9 +278,9 @@
             # Process the block
             try:
                 logger.debug(f"Executing job: {job}...")
-                process_job(job, cancellation_token)
+                process_job(job)
             except ShutdownRequested:
-                logger.debug(f"Cancellation requested while processing {job}")
+                logger.debug(f"Shutdown requested while processing {job}")
                 return
             except subprocess.TimeoutExpired:
                 logger.warning(f"Timeout while terminating subprocess for {job}")
@@ -309,21 +298,22 @@
                 job_queue.task_done()
 
         except Exception as e:
-            if not (cancellation_token and cancellation_token.is_cancelled()):
+            # TODO: is this necessary?
+            if not shutdown_requested:
                 logger.error("Error in the consumer: %s", e)
                 break
 
     logger.debug("Job consumer done.")
 
 def main(start, blocks, step, mode, strategy, execute_scripts):
-    # Create a centralized cancellation mechanism
-    cancellation_token = CancellationToken()
+    global shutdown_requested
 
-    # Set up signal handlers to use the cancellation token
+    # Set up signal handlers
     def signal_handler(signum, frame):
+        global shutdown_requested
         signal_name = signal.Signals(signum).name
         logger.info(f"Received signal {signal_name}. Initiating graceful shutdown...")
-        cancellation_token.cancel()
+        shutdown_requested = True
 
     signal.signal(signal.SIGTERM, signal_handler)
     signal.signal(signal.SIGINT, signal_handler)
@@ -346,37 +336,37 @@ def signal_handler(signum, frame):
 
     # Create the job generator
     job_gen = job_generator(
-        start, blocks, step, mode, strategy, execute_scripts, cancellation_token
+        start, blocks, step, mode, strategy, execute_scripts
    )
 
     # Start the job producer thread
     producer_thread = threading.Thread(
-        target=job_producer, args=(job_gen, cancellation_token)
+        target=job_producer, args=(job_gen, )
    )
    producer_thread.start()
 
    # Start the consumer threads using ThreadPoolExecutor
    with ThreadPoolExecutor(max_workers=THREAD_POOL_SIZE) as executor:
        futures = [
-            executor.submit(job_consumer, process_batch, cancellation_token)
+            executor.submit(job_consumer, process_batch)
            for _ in range(THREAD_POOL_SIZE)
        ]
 
-        # Wait for producer to finish or cancellation
+        # Wait for producer to finish or shutdown
        producer_thread.join()
 
-        # Wait for all items in the queue to be processed or cancellation
+        # Wait for all items in the queue to be processed or shutdown
        while (
-            not (cancellation_token and cancellation_token.is_cancelled())
+            not shutdown_requested
            and not job_queue.empty()
        ):
            try:
                job_queue.join()
                break
            except KeyboardInterrupt:
-                cancellation_token.cancel()
+                shutdown_requested = True
 
-        if cancellation_token.is_cancelled():
+        if shutdown_requested:
            logger.info("Shutdown complete.")
        else:
            logger.info("All jobs have been processed.")
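A detail that makes the plain flag workable here: both loops already wake on a timeout (job_queue.get(timeout=1.0) in the consumer, weight_lock.wait(timeout=1.0) in the producer), so every thread re-checks shutdown_requested at least once per second, and the blocking Event.wait() that the token could have provided was never used. A self-contained sketch of that timeout-polling idiom (hypothetical names, standard library only):

import queue

shutdown_requested = False
work_queue = queue.Queue()

def consumer():
    while not shutdown_requested:
        try:
            item = work_queue.get(timeout=1.0)
        except queue.Empty:
            # Expected wake-up path while idle: loop around and
            # re-check the shared flag.
            continue
        if item is None:  # sentinel meaning "no more jobs", as in job_producer
            break
        handle(item)  # hypothetical per-item handler
        work_queue.task_done()

The trade-off against the token is scope: the boolean is module-global, so one process can manage only a single shutdown domain, whereas tokens could be instantiated per pipeline.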
