fix: Add SIGTERM handling (#297)
Co-authored-by: Maciej Kamiński <[email protected]>
raizo07 and maciejka authored Nov 27, 2024
1 parent 4a68be6 commit 9025ce4
Showing 1 changed file with 119 additions and 32 deletions.
scripts/data/client.py: 151 changes (119 additions, 32 deletions)
@@ -12,9 +12,9 @@
 from concurrent.futures import ThreadPoolExecutor
 from pathlib import Path
 import random
+import signal
 from generate_data import generate_data
 from format_args import format_args
 import logging
 from logging.handlers import TimedRotatingFileHandler
 
 logger = logging.getLogger(__name__)
@@ -30,6 +30,51 @@
 current_weight = 0
 weight_lock = threading.Condition()
 job_queue = queue.Queue(maxsize=QUEUE_MAX_SIZE)
+shutdown_requested = False
+
+
+class ShutdownRequested(Exception):
+    """Raised when shutdown is requested during process execution"""
+
+    pass
+
+
+def run(cmd, timeout=None):
+    """
+    Run a subprocess with proper shutdown handling
+    """
+
+    global shutdown_requested
+
+    if shutdown_requested:
+        raise ShutdownRequested("Shutdown requested before process start")
+
+    process = subprocess.Popen(
+        cmd, stdout=subprocess.PIPE, stderr=subprocess.PIPE, text=True
+    )
+
+    try:
+        stdout, stderr = process.communicate(timeout=timeout)
+
+        if shutdown_requested:
+            process.terminate()
+            try:
+                process.wait(timeout=5)
+            except subprocess.TimeoutExpired:
+                process.kill()
+                process.wait()
+            raise ShutdownRequested("Shutdown requested during process execution")
+
+        return stdout, stderr, process.returncode
+
+    except subprocess.TimeoutExpired:
+        process.terminate()
+        try:
+            process.wait(timeout=5)
+        except subprocess.TimeoutExpired:
+            process.kill()
+            process.wait()
+        raise
 
 
 # Function to calculate weight of a block
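Note: the new run() helper wraps subprocess.Popen so that a shutdown observed around communicate() surfaces as an exception the worker threads can catch, with a terminate/kill escalation for stubborn children. A minimal usage sketch (the command and timeout below are illustrative placeholders, not values from this commit):

    import subprocess

    try:
        stdout, stderr, returncode = run(["echo", "hello"], timeout=60)
        if returncode != 0:
            print(f"subprocess failed: {stderr}")
    except ShutdownRequested:
        # SIGTERM/SIGINT arrived; let the worker thread unwind cleanly.
        pass
    except subprocess.TimeoutExpired:
        # The child exceeded the timeout and was terminated (killed if stuck).
        pass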
@@ -86,14 +131,13 @@ def job_generator(start, blocks, step, mode, strategy, execute_scripts):
             logger.error(f"Error while generating data for: {height}:\n{e}")
 
 
 # Function to process a batch
 def process_batch(job):
     arguments_file = job.batch_file.as_posix().replace(".json", "-arguments.json")
 
     with open(arguments_file, "w") as af:
         af.write(str(format_args(job.batch_file, job.execute_scripts, False)))
 
-    result = subprocess.run(
+    stdout, stderr, returncode = run(
         [
             "scarb",
             "cairo-run",
@@ -104,20 +148,13 @@ def process_batch(job):
             "main",
             "--arguments-file",
             str(arguments_file),
-        ],
-        capture_output=True,
-        text=True,
+        ]
     )
 
-    if (
-        result.returncode != 0
-        or "FAIL" in result.stdout
-        or "error" in result.stdout
-        or "panicked" in result.stdout
-    ):
-        error = result.stdout or result.stderr
-        if result.returncode == -9:
-            match = re.search(r"gas_spent=(\d+)", result.stdout)
+    if returncode != 0 or "FAIL" in stdout or "error" in stdout or "panicked" in stdout:
+        error = stdout or stderr
+        if returncode == -9:
+            match = re.search(r"gas_spent=(\d+)", stdout)
             gas_info = (
                 f", gas spent: {int(match.group(1))}"
                 if match
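Note: both the failure and success branches recover the gas figure with the same regular expression over the runner's stdout. A standalone illustration of the pattern (the sample string is made up):

    import re

    sample = "Run completed successfully, gas_spent=123456"  # made-up runner output
    match = re.search(r"gas_spent=(\d+)", sample)
    print(int(match.group(1)) if match else "no gas info found")  # 123456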
@@ -142,31 +179,38 @@ def process_batch(job):
         logger.error(f"{job} error: {message}")
         logger.debug(f"Full error while processing: {job}:\n{error}")
     else:
-        match = re.search(r"gas_spent=(\d+)", result.stdout)
+        match = re.search(r"gas_spent=(\d+)", stdout)
         gas_info = f"gas spent: {int(match.group(1))}" if match else "no gas info found"
         logger.info(f"{job} done, {gas_info}")
         if not match:
-            logger.warning(f"{job}: not gas info found")
+            logger.warning(f"{job}: no gas info found")
 
 
 # Producer function: Generates data and adds jobs to the queue
 def job_producer(job_gen):
     global current_weight
+    global shutdown_requested
 
     try:
         for job, weight in job_gen:
+            # if shutdown_requested:
+            #     break
+
             # Wait until there is enough weight capacity to add the new block
             with weight_lock:
                 logger.debug(
                     f"Adding job: {job}, current total weight: {current_weight}..."
                 )
-                while (
+                while not shutdown_requested and (
                     (current_weight + weight > MAX_WEIGHT_LIMIT)
                     and current_weight != 0
                     or job_queue.full()
                 ):
                     logger.debug("Producer is waiting for weight to be released.")
-                    weight_lock.wait()  # Wait for the condition to be met
+                    weight_lock.wait(timeout=1.0)
+
+                if shutdown_requested:
+                    break
 
                 if (current_weight + weight > MAX_WEIGHT_LIMIT) and current_weight == 0:
                     logger.warning(f"{job} over the weight limit: {MAX_WEIGHT_LIMIT}")
@@ -182,8 +226,13 @@ def job_producer(job_gen):
                 weight_lock.notify_all()
     finally:
         logger.debug("Producer is exiting...")
-        for _ in range(THREAD_POOL_SIZE):
-            job_queue.put(None)
+        # Signal end of jobs to consumers
+        if not shutdown_requested:
+            for _ in range(THREAD_POOL_SIZE):
+                logger.warning(
+                    f"Producer is putting None into the queue..., full: {job_queue.full()}"
+                )
+                job_queue.put(None, block=False)
 
         with weight_lock:
             weight_lock.notify_all()
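Note: replacing the indefinite weight_lock.wait() with wait(timeout=1.0) makes the producer wake at least once a second to re-check shutdown_requested, instead of blocking until someone calls notify. The same polling-wait pattern in isolation (all names here are illustrative):

    import threading

    shutdown = False
    cond = threading.Condition()

    def wait_for(predicate):
        # Wake on notify OR after 1s, so a flag set by a signal handler
        # is noticed promptly even if no notify ever arrives.
        with cond:
            while not shutdown and not predicate():
                cond.wait(timeout=1.0)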
@@ -194,14 +243,18 @@ def job_consumer(process_job):
 # Consumer function: Processes blocks from the queue
 def job_consumer(process_job):
     global current_weight
+    global shutdown_requested
 
-    while True:
+    while not shutdown_requested:
         try:
             logger.debug(
                 f"Consumer is waiting for a job. Queue length: {job_queue.qsize()}"
             )
-            # Get a job from the queue
-            work_to_do = job_queue.get(block=True)
+            # Get a job from the queue with timeout to check shutdown
+            try:
+                work_to_do = job_queue.get(timeout=1.0)
+            except queue.Empty:
+                continue
 
             if work_to_do is None:
                 logger.debug("No more work to do, consumer is exiting.")
@@ -210,10 +263,22 @@ def job_consumer(process_job):
 
             (job, weight) = work_to_do
 
+            if shutdown_requested:
+                with weight_lock:
+                    current_weight -= weight
+                    weight_lock.notify_all()
+                job_queue.task_done()
+                break
+
             # Process the block
             try:
                 logger.debug(f"Executing job: {job}...")
                 process_job(job)
+            except ShutdownRequested:
+                logger.debug(f"Shutdown requested while processing {job}")
+                return
+            except subprocess.TimeoutExpired:
+                logger.warning(f"Timeout while terminating subprocess for {job}")
             except Exception as e:
                 logger.error(f"Error while processing job: {job}:\n{e}")
 
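Note: job_queue.join() only returns once every put() has been matched by a task_done() call, which is why the shutdown path above still calls task_done() before breaking out. A minimal illustration of that accounting:

    import queue
    import threading

    q = queue.Queue()
    q.put("item")

    def worker():
        q.get()
        q.task_done()  # without this, q.join() below would block forever

    threading.Thread(target=worker).start()
    q.join()
    print("all items accounted for")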
@@ -228,11 +293,26 @@ def job_consumer(process_job):
                 job_queue.task_done()
 
         except Exception as e:
-            logger.error("Error in the consumer: %s", e)
+            # TODO: is this necessary?
+            if not shutdown_requested:
+                logger.error("Error in the consumer: %s", e)
             break
 
     logger.debug("Job consumer done.")
 
 
 def main(start, blocks, step, mode, strategy, execute_scripts):
+    global shutdown_requested
+
+    # Set up signal handlers
+    def signal_handler(signum, frame):
+        global shutdown_requested
+        signal_name = signal.Signals(signum).name
+        logger.info(f"Received signal {signal_name}. Initiating graceful shutdown...")
+        shutdown_requested = True
+
+    signal.signal(signal.SIGTERM, signal_handler)
+    signal.signal(signal.SIGINT, signal_handler)
+
     logger.info(
         "Starting client, initial height: %d, blocks: %d, step: %d, mode: %s, strategy: %s, execute_scripts: %s",
@@ -264,17 +344,24 @@ def main(start, blocks, step, mode, strategy, execute_scripts):
         for _ in range(THREAD_POOL_SIZE)
     ]
 
-    # Wait for producer to finish
-    producer_thread.join()
+    # Wait for producer to finish or shutdown
+    producer_thread.join()
 
-    # Wait for all items in the queue to be processed
-    job_queue.join()
+    # Wait for all items in the queue to be processed or shutdown
+    while not shutdown_requested and not job_queue.empty():
+        try:
+            job_queue.join()
+            break
+        except KeyboardInterrupt:
+            shutdown_requested = True
 
-    logger.info("All jobs have been processed.")
+    if shutdown_requested:
+        logger.info("Shutdown complete.")
+    else:
+        logger.info("All jobs have been processed.")
 
 
 if __name__ == "__main__":
 
     parser = argparse.ArgumentParser(description="Run client script")
     parser.add_argument("--start", type=int, required=True, help="Start block height")
     parser.add_argument(
@@ -313,7 +400,7 @@ def main(start, blocks, step, mode, strategy, execute_scripts):
 
     MAX_WEIGHT_LIMIT = args.maxweight
 
-    # file_handler = logging.FileHandler("client.log")
+    # Logging setup
     file_handler = TimedRotatingFileHandler(
         filename="client.log",
         when="midnight",
