diff --git a/cozy/common.py b/cozy/common.py index b05ff399..54c84c1a 100644 --- a/cozy/common.py +++ b/cozy/common.py @@ -553,6 +553,13 @@ def compare_with_lt(x, y): else: return 0 +def never_stop(): + """Takes no arguments, always returns False. + + This is the default "stop callback" value for procedures that require one. + """ + return False + class StopException(Exception): """ Used to indicate that a process should stop operation. diff --git a/cozy/cost_model.py b/cozy/cost_model.py index 82d2e15e..9b68ed1b 100644 --- a/cozy/cost_model.py +++ b/cozy/cost_model.py @@ -118,7 +118,8 @@ def __init__(self, examples = (), funcs = (), freebies : [Exp] = [], - ops : [Op] = []): + ops : [Op] = [], + solver_args : dict = {}): """ assumptions : assumed to be true when comparing expressions examples : initial examples (the right set of examples can speed up @@ -129,7 +130,7 @@ def __init__(self, ops : mutators which are used to determine how expensive it is to maintain a state variable """ - self.solver = ModelCachingSolver(vars=(), funcs=funcs, examples=examples, assumptions=assumptions) + self.solver = ModelCachingSolver(vars=(), funcs=funcs, examples=examples, assumptions=assumptions, **solver_args) self.assumptions = assumptions # self.examples = list(examples) self.funcs = OrderedDict(funcs) diff --git a/cozy/jobs.py b/cozy/jobs.py index 731748e8..af5094e7 100644 --- a/cozy/jobs.py +++ b/cozy/jobs.py @@ -17,15 +17,58 @@ `multiprocessing` module. 
""" -from multiprocessing import Process, Array, Queue +import os +import multiprocessing from queue import Queue as PlainQueue, Empty, Full import threading +import signal from cozy.common import partition from cozy.opts import Option do_profiling = Option("profile", bool, False, description="Profile Cozy itself") +_interrupted = False +def _set_interrupt_flag(signal_number, stack_frame): + global _interrupted + # print("GOT INTERRUPTED") + # import traceback + # traceback.print_stack(stack_frame) + _interrupted = True + +def install_graceful_sigint_handler(): + """Install a graceful handler for SIGINT. + + The handler sets a flag to true when SIGINT happens and does nothing else. + Use `was_interrupted()` to check the flag. + + Note: Job processes are spawned, so each one re-installs this handler + itself when it starts. The Job.stop_requested property checks the SIGINT + flag in addition to its own private flag, giving an additional + cross-process way to stop a running job gracefully. + """ + signal.signal(signal.SIGINT, _set_interrupt_flag) + +def was_interrupted(): + """Determine if SIGINT was sent to this process. + + Precisely, this procedure returns true if a SIGINT signal was ever + delivered to this process after the first time install_graceful_sigint_handler() + was called. + """ + return _interrupted + +# This module uses the "spawn" method for multiprocessing interaction. This is +# a little bit of forward-compatibility. The "spawn" context is the default on +# Windows (always) and MacOS (in Python 3.8+). It was introduced in Python +# 3.4. The "spawn" context behaves a bit differently from the "fork" context +# used by default on Linux. In particular: +# - It is allegedly a bit slower (but I haven't seen much difference). +# - More objects need to be pickled. +# - It is less likely to cause crashes due to MacOS's bad fork() +# implementation (https://bugs.python.org/issue33725). 
+multiprocessing_context = multiprocessing.get_context("spawn") + class Job(object): """An interruptable job. @@ -65,12 +108,36 @@ def run(self): way to retrieve the thrown exception, if any. """ + # ------------------------------------------------------------------------- + # Names for cell indexes in a Job's flags array. For efficiency, the flags + # are represented as a compact boolean array in shared memory. These + # constants give human-readable names to the elements of that array. All + # flags are initially false. + + # The parent process sets this flag to true when it wants to request that + # a Job should stop gracefully. + _STOP_REQUESTED_FLAG = 0 + + # The Job sets this flag to true when it finishes, regardless of the + # outcome. + _DONE_FLAG = 1 + + # The Job sets this flag to true when it finishes without an exception. + _DONE_NORMALLY_FLAG = 2 + + # The Job sets this flag to true after it installs a handler for SIGINT. + # The parent should only send SIGINT if this flag is true. If this flag is + # false, then the handler MAY OR MAY NOT have been installed yet. + _SIGINT_HANDLER_INSTALLED_FLAG = 3 + + # The total number of flags. + _FLAG_COUNT = 4 + + # ------------------------------------------------------------------------- + def __init__(self): - self._thread = Process(target=self._run, daemon=True) - self._flags = Array("b", [False] * 3) - # flags[0] - stop_requested? - # flags[1] - done? 
- # flags[2] - true iff completed with no exception + self._thread = multiprocessing_context.Process(target=self._run, daemon=True) + self._flags = multiprocessing_context.Array("b", [False] * Job._FLAG_COUNT) def start(self): """Start the job by invoking its .run() method asynchronously.""" @@ -82,6 +149,8 @@ def run(self): def _run(self): """Private helper that wraps .run() and sets various exit flags.""" + install_graceful_sigint_handler() + self._flags[Job._SIGINT_HANDLER_INSTALLED_FLAG] = True try: if do_profiling.value: import cProfile @@ -91,11 +160,12 @@ def _run(self): cProfile.runctx("self.run()", globals(), locals(), filename=filename) else: self.run() - self._flags[2] = True + self._flags[Job._DONE_NORMALLY_FLAG] = True except Exception as e: import traceback traceback.print_exc() - self._flags[1] = True + finally: + self._flags[Job._DONE_FLAG] = True @property def stop_requested(self): @@ -103,17 +173,17 @@ def stop_requested(self): The implementation of .run() should check this periodically and return when it becomes True.""" - return self._flags[0] + return was_interrupted() or self._flags[Job._STOP_REQUESTED_FLAG] @property def done(self): """True if the job has stopped.""" - return self._flags[1] or (self._thread.exitcode is not None) + return self._flags[Job._DONE_FLAG] or (self._thread.exitcode is not None) @property def successful(self): """True if the job has stopped without throwing an uncaught exception.""" - return self._flags[2] + return self._flags[Job._DONE_NORMALLY_FLAG] def request_stop(self): """Request a graceful stop. @@ -121,9 +191,29 @@ def request_stop(self): Causes this Job's .stop_requested property to become True. Clients can call .join() to wait for the job to wrap up. + + This method delivers a SIGINT to the job process, interrupting any + running Z3 solver call. 
""" print("requesting stop for {}".format(self)) - self._flags[0] = True + self._flags[Job._STOP_REQUESTED_FLAG] = True + + # Ah, there's a bit of danger here (time-of-check to time-of-use bug): + # (1) is_alive() returns true + # (2) the job process exits + # (3) its PID is reassigned to a different process + # (4) oops, we deliver SIGINT to the wrong process! + # Sadly, Python doesn't give us a way to access the actual underlying + # process handle, so (I think) this is the best we can do. In fact, + # the actual Python source code has the same bug: + # https://github.com/python/cpython/blob/3.8/Lib/multiprocessing/popen_fork.py#L50 + if self._flags[Job._SIGINT_HANDLER_INSTALLED_FLAG] and self._thread.is_alive(): + try: + os.kill(self._thread.pid, signal.SIGINT) + except ProcessLookupError: + # This can happen if the job finished in the background between + # the check and the kill call. + pass def join(self, timeout=None): """Wait for the job to finish and clean up its resources. @@ -220,7 +310,7 @@ class SafeQueue(object): - This queue needs to be closed. Proper usage example: with SafeQueue() as q: - # spawn processes to insert items into q + # spawn processes to insert items into q.handle_for_subjobs() # get items from q # join spawned processes @@ -228,7 +318,7 @@ class SafeQueue(object): """ def __init__(self, queue_to_wrap=None): if queue_to_wrap is None: - queue_to_wrap = Queue() + queue_to_wrap = multiprocessing_context.Queue() self.q = queue_to_wrap self.sideq = PlainQueue() self.stop_requested = False @@ -273,3 +363,12 @@ def drain(self, block=False, timeout=None): res.append(self.get(block=False)) except Empty: return res + def handle_for_subjobs(self): + """Obtain a handle that can be passed to a Job. + + Due to the limitations of Python's multiprocessing module, a SafeQueue + cannot be passed as an argument to a Job. This method returns a Queue + object that can. The parent is still responsible for holding onto this + object and cleaning it up. 
+ """ + return self.q diff --git a/cozy/main.py b/cozy/main.py index 8983d81c..cb7ed8a4 100755 --- a/cozy/main.py +++ b/cozy/main.py @@ -10,8 +10,6 @@ import datetime import pickle -from multiprocessing import Value - from cozy import parse from cozy import codegen from cozy import common @@ -22,6 +20,7 @@ from cozy import synthesis from cozy.structures import rewriting from cozy import opts +from cozy import jobs save_failed_codegen_inputs = opts.Option("save-failed-codegen-inputs", str, "/tmp/failed_codegen.py", metavar="PATH") checkpoint_prefix = opts.Option("checkpoint-prefix", str, "") @@ -60,7 +59,12 @@ def run(): args = parser.parse_args() opts.read(args) - improve_count = Value('i', 0) + # Install a handler for SIGINT, the signal that is delivered when you + # Ctrl+C a process. This allows Cozy to exit cleanly when it is + # interrupted. If you need to stop Cozy forcibly, use SIGTERM or SIGKILL. + jobs.install_graceful_sigint_handler() + + improve_count = jobs.multiprocessing_context.Value('i', 0) if args.resume: with common.open_maybe_stdin(args.file or "-", mode="rb") as f: diff --git a/cozy/solver.py b/cozy/solver.py index 8f9842c3..53d8c6d4 100644 --- a/cozy/solver.py +++ b/cozy/solver.py @@ -12,13 +12,14 @@ from datetime import datetime, timedelta from functools import lru_cache import threading +from typing import Callable import z3 from cozy.target_syntax import * from cozy.syntax_tools import BottomUpExplorer, pprint, free_vars, free_funcs, cse, all_exps, purify from cozy.typecheck import is_collection, is_numeric -from cozy.common import declare_case, fresh_name, Visitor, FrozenDict, typechecked, extend, OrderedSet, make_random_access +from cozy.common import declare_case, fresh_name, Visitor, FrozenDict, typechecked, extend, OrderedSet, make_random_access, StopException, never_stop from cozy import evaluation from cozy.opts import Option from cozy.structures import extension_handler @@ -224,13 +225,22 @@ def grid(rows, cols): return [[None for c 
in range(cols)] for r in range(rows)] class ToZ3(Visitor): - def __init__(self, z3ctx, z3solver): + def __init__(self, z3ctx, z3solver, stop_callback : Callable[[], bool]): + """Create a "ToZ3" object that can convert ASTs to Z3 queries. + + :param z3ctx: A Z3 context object to use. + :param z3solver: A Z3 solver object to use. + :param stop_callback: A zero-argument function that will be checked + periodically. The solver will raise a + StopException when the callback returns True. + """ self.ctx = z3ctx self.solver = z3solver self.int_zero = z3.IntVal(0, self.ctx) self.int_one = z3.IntVal(1, self.ctx) self.true = z3.BoolVal(True, self.ctx) self.false = z3.BoolVal(False, self.ctx) + self.stop_callback = stop_callback assert to_bool(self.true) is True assert to_bool(self.false) is False assert to_bool(self.int_zero) is None @@ -285,6 +295,9 @@ def gt(self, t, e1, e2, env, deep=False): else: raise NotImplementedError(t) def eq(self, t, e1, e2, deep=False): + if self.stop_callback(): + raise StopException("interrupted while encoding equality constraint") + if e1 is e2: return self.true @@ -831,6 +844,8 @@ def visit_AstRef(self, e, env): def visit_bool(self, e, env): return z3.BoolVal(e, self.ctx) def visit(self, e, *args): + if self.stop_callback(): + raise StopException("interrupted while encoding {}".format(pprint(e))) try: return super().visit(e, *args) except KeyboardInterrupt: @@ -1005,7 +1020,8 @@ def __init__(self, model_callback = None, logic : str = None, timeout : float = None, - do_cse : bool = True): + do_cse : bool = True, + stop_callback : Callable[[], bool] = never_stop): if collection_depth is None: collection_depth = collection_depth_opt.value @@ -1016,6 +1032,7 @@ def __init__(self, self.collection_depth = collection_depth self.validate_model = validate_model self.model_callback = model_callback + self.stop_callback = stop_callback self._env = OrderedDict() self.stk = [] self.do_cse = do_cse @@ -1026,7 +1043,7 @@ def __init__(self, if timeout 
is not None: solver.set("timeout", int(timeout * 1000)) solver.set("core.validate", validate_model) - visitor = ToZ3(ctx, solver) + visitor = ToZ3(ctx, solver, stop_callback=stop_callback) self.visitor = visitor self.z3_solver = solver @@ -1155,6 +1172,11 @@ def reconstruct(model, value, type): with task("invoke Z3"): res = solver.check() _tock(e, "solve") + + if self.stop_callback(): + solver.pop() + raise StopException("stop requested during Z3 solver call") + if res == z3.unsat: solver.pop() return None @@ -1267,13 +1289,13 @@ class ModelCachingSolver(object): calls can often be avoided using a counterexample found on a previous call. """ - def __init__(self, vars : [EVar], funcs : { str : TFunc }, examples : [dict] = (), assumptions : Exp = ETRUE): + def __init__(self, vars : [EVar], funcs : { str : TFunc }, examples : [dict] = (), assumptions : Exp = ETRUE, **kwargs): self.vars = list(vars) self.funcs = OrderedDict(funcs) self.calls = 0 self.hits = 0 self.examples = list(examples) - self.solver = IncrementalSolver(vars=vars, funcs=funcs) + self.solver = IncrementalSolver(vars=vars, funcs=funcs, **kwargs) self.solver.add_assumption(assumptions) def satisfy(self, e): @@ -1295,8 +1317,9 @@ def valid(self, e): return not self.satisfiable(ENot(e)) @lru_cache() -def solver_for_context(context : Context, assumptions : Exp = ETRUE): +def solver_for_context(context : Context, assumptions : Exp = ETRUE, **kwargs): return ModelCachingSolver( vars = [v for v, _ in context.vars()], funcs = context.funcs(), - assumptions = assumptions) + assumptions = assumptions, + **kwargs) diff --git a/cozy/synthesis/core.py b/cozy/synthesis/core.py index 94d99b45..d6fd9a5f 100644 --- a/cozy/synthesis/core.py +++ b/cozy/synthesis/core.py @@ -31,7 +31,7 @@ from cozy.typecheck import is_collection, is_scalar from cozy.syntax_tools import subst, pprint, free_vars, fresh_var, alpha_equivalent, strip_EStateVar, freshen_binders, wrap_naked_statevars, break_conj, inline_lets from cozy.wf import 
exp_wf -from cozy.common import No, unique, OrderedSet, StopException +from cozy.common import No, unique, OrderedSet, StopException, never_stop from cozy.solver import valid, solver_for_context, ModelCachingSolver from cozy.evaluation import construct_value from cozy.cost_model import CostModel, Order, LINEAR_TIME_UOPS @@ -100,10 +100,6 @@ allow_random_assignment_heuristic = Option("allow-random-assignment-heuristic", bool, True, description="Use a random assignment heuristic instead of solver to solve sat/unsat problem") -def never_stop(): - """Takes no arguments, always returns False.""" - return False - def improve( target : Exp, context : Context, @@ -197,7 +193,10 @@ def improve( vars = list(v for (v, p) in context.vars()) funcs = context.funcs() - solver = solver_for_context(context, assumptions=assumptions) + solver = solver_for_context( + context, + assumptions=assumptions, + stop_callback=stop_callback) if not solver.satisfiable(ETRUE): print("assumptions are unsat; this query will never be called") diff --git a/cozy/synthesis/high_level_interface.py b/cozy/synthesis/high_level_interface.py index bc28d1c2..fb5ce747 100644 --- a/cozy/synthesis/high_level_interface.py +++ b/cozy/synthesis/high_level_interface.py @@ -42,7 +42,7 @@ def __init__(self, assumptions : [Exp], q : Query, context : Context, - k, + solutions_q, hints : [Exp] = [], freebies : [Exp] = [], ops : [Op] = [], @@ -57,45 +57,54 @@ def __init__(self, self.hints = hints self.freebies = freebies self.ops = ops - self.k = k + self.solutions_q = solutions_q self.improve_count = improve_count def __str__(self): return "ImproveQueryJob[{}]".format(self.q.name) def run(self): - print("STARTING IMPROVEMENT JOB {}".format(self.q.name)) os.makedirs(log_dir.value, exist_ok=True) with open(os.path.join(log_dir.value, "{}.log".format(self.q.name)), "w", buffering=LINE_BUFFER_MODE) as f: + original_stdout = sys.stdout sys.stdout = f + print("STARTING IMPROVEMENT JOB {}".format(self.q.name)) 
print(pprint(self.q)) if nice_children.value: os.nice(20) - cost_model = CostModel( - funcs=self.context.funcs(), - assumptions=EAll(self.assumptions), - freebies=self.freebies, - ops=self.ops) + stop_callback = lambda: self.stop_requested try: + cost_model = CostModel( + funcs=self.context.funcs(), + assumptions=EAll(self.assumptions), + freebies=self.freebies, + ops=self.ops, + solver_args={"stop_callback": stop_callback}) + for expr in itertools.chain((self.q.ret,), core.improve( target=self.q.ret, assumptions=EAll(self.assumptions), context=self.context, hints=self.hints, - stop_callback=lambda: self.stop_requested, + stop_callback=stop_callback, cost_model=cost_model, ops=self.ops, improve_count=self.improve_count)): new_rep, new_ret = unpack_representation(expr) - self.k(new_rep, new_ret) + self.solutions_q.put((self.q, new_rep, new_ret)) print("PROVED OPTIMALITY FOR {}".format(self.q.name)) except core.StopException: print("stopping synthesis of {}".format(self.q.name)) return + # Restore the original stdout handle. Python multiprocessing does + # some stream flushing as the process exits, and if we leave stdout + # unchanged then it will refer to a closed file when that happens. 
+ sys.stdout = original_stdout + def improve_implementation( impl : Implementation, timeout : datetime.timedelta = datetime.timedelta(seconds=60), @@ -145,12 +154,13 @@ def reconcile_jobs(): for q in impl.query_specs: if q.name not in job_query_names: states_maintained_by_q = impl.states_maintained_by(q) + print("STARTING IMPROVEMENT JOB {}".format(q.name)) new.append(ImproveQueryJob( impl.abstract_state, list(impl.spec.assumptions) + list(q.assumptions), q, context=impl.context_for_method(q), - k=(lambda q: lambda new_rep, new_ret: solutions_q.put((q, new_rep, new_ret)))(q), + solutions_q=solutions_q.handle_for_subjobs(), hints=[EStateVar(c).with_type(c.type) for c in impl.concretization_functions.values()], freebies=[e for (v, e) in impl.concretization_functions.items() if EVar(v) in states_maintained_by_q], ops=impl.op_specs, @@ -172,7 +182,7 @@ def reconcile_jobs(): # wait for results timeout = Timeout(timeout) done = False - while not done and not timeout.is_timed_out(): + while not done and not timeout.is_timed_out() and not jobs.was_interrupted(): for j in improvement_jobs: if j.done: if j.successful: