diff --git a/backend/staticfiles/synthesis/main.py b/backend/staticfiles/synthesis/main.py index 52587889..47e75263 100644 --- a/backend/staticfiles/synthesis/main.py +++ b/backend/staticfiles/synthesis/main.py @@ -1,22 +1,29 @@ import importlib import json +import functools import multiprocessing import random import string from multiprocessing import shared_memory, Lock from time import sleep +import signal +import sys from lib.inputs import Inputs from lib.outputs import Outputs from lib.parameters import Parameters from lib.utils import Synchronise - BLOCK_DIRECTORY = 'modules' FUNCTION_NAME = 'main' -def clean_shared_memory(names): +def clean_shared_memory(signum, frame, names, processes): + # End all processes + for process in processes: + process.terminate() + process.join() + all_names = list(names.keys()) all_names.extend([name + "_dim" for name in names]) all_names.extend([name + "_shape" for name in names]) @@ -38,6 +45,10 @@ def clean_shared_memory(names): except ValueError: pass + # Exit the program + print("Exiting program.") + sys.exit(0) + def main(): """ @@ -116,6 +127,10 @@ def main(): multiprocessing.Process(target=method, args=(inputs, outputs, parameters, Synchronise(1 / (freq if freq != 0 else 30)))) ) + # Register handler for Ctrl+C + param_func = functools.partial(clean_shared_memory, names=all_wires, processes=processes) + signal.signal(signal.SIGINT, param_func) + for process in processes: process.start() @@ -123,10 +138,7 @@ def main(): while True: sleep(10) except KeyboardInterrupt: - for process in processes: - process.terminate() - process.join() - clean_shared_memory(all_wires) + pass if __name__ == "__main__":