Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Graceful interrupts #119

Merged
merged 13 commits into from
Jul 8, 2020
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
7 changes: 7 additions & 0 deletions cozy/common.py
Original file line number Diff line number Diff line change
Expand Up @@ -553,6 +553,13 @@ def compare_with_lt(x, y):
else:
return 0

def never_stop():
    """Report that no stop has been requested.

    Takes no arguments and always returns False.

    This is the default "stop callback" value for procedures that require one.

    :returns: False, unconditionally.
    """
    return False

class StopException(Exception):
"""
Used to indicate that a process should stop operation.
Expand Down
5 changes: 3 additions & 2 deletions cozy/cost_model.py
Original file line number Diff line number Diff line change
Expand Up @@ -118,7 +118,8 @@ def __init__(self,
examples = (),
funcs = (),
freebies : [Exp] = [],
ops : [Op] = []):
ops : [Op] = [],
solver_args : dict = {}):
"""
assumptions : assumed to be true when comparing expressions
examples : initial examples (the right set of examples can speed up
Expand All @@ -129,7 +130,7 @@ def __init__(self,
ops : mutators which are used to determine how expensive it is
to maintain a state variable
"""
self.solver = ModelCachingSolver(vars=(), funcs=funcs, examples=examples, assumptions=assumptions)
self.solver = ModelCachingSolver(vars=(), funcs=funcs, examples=examples, assumptions=assumptions, **solver_args)
self.assumptions = assumptions
# self.examples = list(examples)
self.funcs = OrderedDict(funcs)
Expand Down
127 changes: 113 additions & 14 deletions cozy/jobs.py
Original file line number Diff line number Diff line change
Expand Up @@ -17,15 +17,58 @@
`multiprocessing` module.
"""

from multiprocessing import Process, Array, Queue
import os
import multiprocessing
from queue import Queue as PlainQueue, Empty, Full
import threading
import signal

from cozy.common import partition
from cozy.opts import Option

do_profiling = Option("profile", bool, False, description="Profile Cozy itself")

# Process-wide flag: becomes True once SIGINT has been delivered to this
# process after the graceful handler was installed. Nothing in this section
# ever resets it to False.
_interrupted = False


def _set_interrupt_flag(signal_number, stack_frame):
    """Signal handler that records the interrupt and does nothing else.

    :param signal_number: The number of the delivered signal (unused).
    :param stack_frame: The frame that was executing when the signal arrived
                        (unused).
    """
    global _interrupted
    _interrupted = True


def install_graceful_sigint_handler():
    """Install a graceful handler for SIGINT.

    The handler sets a flag to true when SIGINT happens and does nothing else.
    Use `was_interrupted()` to check the flag.

    Note: do not rely on child processes inheriting the installed handler.
    Whether they do depends on the multiprocessing start method ("fork"
    copies the parent's handlers; "spawn" starts a fresh interpreter), so
    Job._run() calls this procedure itself in the job process. The
    Job.stop_requested property checks the SIGINT flag in addition to its own
    private flag, giving an additional cross-process way to stop a running job
    gracefully.
    """
    signal.signal(signal.SIGINT, _set_interrupt_flag)


def was_interrupted():
    """Determine if SIGINT was sent to this process.

    Precisely, this procedure returns true if a SIGINT signal was ever
    delivered to this process after the first time
    install_graceful_sigint_handler() was called.

    :returns: The current value of the interrupt flag.
    """
    return _interrupted

# This module uses the "spawn" method for multiprocessing interaction. This is
# a little bit of forward-compatibility. The "spawn" context is the default on
# Windows (always) and MacOS (in Python 3.8+). It was introduced in Python
# 3.4. The "spawn" context behaves a bit differently from the "fork" context
# used by default on Linux. In particular:
# - It is allegedly a bit slower (but I haven't seen much difference).
# - More objects need to be pickled.
# - It is less likely to cause crashes due to MacOS's bad fork()
# implementation (https://bugs.python.org/issue33725).
multiprocessing_context = multiprocessing.get_context("spawn")

class Job(object):
"""An interruptable job.

Expand Down Expand Up @@ -65,12 +108,36 @@ def run(self):
way to retrieve the thrown exception, if any.
"""

# -------------------------------------------------------------------------
# Names for cell indexes in a Job's flags array. For efficiency, the flags
# are represented as a compact boolean array in shared memory. These
# constants give human-readable names to the elements of that array. All
# flags are initially false.

# The parent process sets this flag to true when it wants to request that
# a Job should stop gracefully.
_STOP_REQUESTED_FLAG = 0

# The Job sets this flag to true when it finishes, regardless of the
# outcome.
_DONE_FLAG = 1

# The Job sets this flag to true when it finishes without an exception.
_DONE_NORMALLY_FLAG = 2

# The Job sets this flag to true after it installs a handler for SIGINT.
# The parent should only send SIGINT if this flag is true. If this flag is
# false, then the handler MAY OR MAY NOT have been installed yet.
_SIGINT_HANDLER_INSTALLED_FLAG = 3

# The total number of flags.
_FLAG_COUNT = 4

# -------------------------------------------------------------------------

def __init__(self):
self._thread = Process(target=self._run, daemon=True)
self._flags = Array("b", [False] * 3)
# flags[0] - stop_requested?
# flags[1] - done?
# flags[2] - true iff completed with no exception
self._thread = multiprocessing_context.Process(target=self._run, daemon=True)
self._flags = multiprocessing_context.Array("b", [False] * Job._FLAG_COUNT)

def start(self):
"""Start the job by invoking its .run() method asynchronously."""
Expand All @@ -82,6 +149,8 @@ def run(self):

def _run(self):
"""Private helper that wraps .run() and sets various exit flags."""
install_graceful_sigint_handler()
self._flags[Job._SIGINT_HANDLER_INSTALLED_FLAG] = True
try:
if do_profiling.value:
import cProfile
Expand All @@ -91,39 +160,60 @@ def _run(self):
cProfile.runctx("self.run()", globals(), locals(), filename=filename)
else:
self.run()
self._flags[2] = True
self._flags[Job._DONE_NORMALLY_FLAG] = True
except Exception as e:
import traceback
traceback.print_exc()
self._flags[1] = True
finally:
self._flags[Job._DONE_FLAG] = True

@property
def stop_requested(self):
"""True if the job has been asked to stop.

The implementation of .run() should check this periodically and return
when it becomes True."""
return self._flags[0]
return was_interrupted() or self._flags[Job._STOP_REQUESTED_FLAG]

@property
def done(self):
"""True if the job has stopped."""
return self._flags[1] or (self._thread.exitcode is not None)
return self._flags[Job._DONE_FLAG] or (self._thread.exitcode is not None)

@property
def successful(self):
"""True if the job has stopped without throwing an uncaught exception."""
return self._flags[2]
return self._flags[Job._DONE_NORMALLY_FLAG]

def request_stop(self):
"""Request a graceful stop.

Causes this Job's .stop_requested property to become True.

Clients can call .join() to wait for the job to wrap up.

This method delivers a SIGINT to the job process, interrupting any
running Z3 solver call.
"""
print("requesting stop for {}".format(self))
self._flags[0] = True
self._flags[Job._STOP_REQUESTED_FLAG] = True

# Ah, there's a bit of danger here (time-of-check to time-of-use bug):
# (1) is_alive() returns true
# (2) the job process exits
# (3) its PID is reassigned to a different process
# (4) oops, we deliver SIGINT to the wrong process!
# Sadly, Python doesn't give us a way to access the actual underlying
# process handle, so (I think) this is the best we can do. In fact,
# the actual Python source code has the same bug:
# https://github.com/python/cpython/blob/3.8/Lib/multiprocessing/popen_fork.py#L50
if self._flags[Job._SIGINT_HANDLER_INSTALLED_FLAG] and self._thread.is_alive():
try:
os.kill(self._thread.pid, signal.SIGINT)
except ProcessLookupError:
# This can happen if the job finished in the background between
# the check and the kill call.
pass

def join(self, timeout=None):
"""Wait for the job to finish and clean up its resources.
Expand Down Expand Up @@ -220,15 +310,15 @@ class SafeQueue(object):
- This queue needs to be closed.
Proper usage example:
with SafeQueue() as q:
# spawn processes to insert items into q
# spawn processes to insert items into q.handle_for_subjobs()
# get items from q
# join spawned processes

[1]: https://docs.python.org/3/library/multiprocessing.html#pipes-and-queues
"""
def __init__(self, queue_to_wrap=None):
if queue_to_wrap is None:
queue_to_wrap = Queue()
queue_to_wrap = multiprocessing_context.Queue()
self.q = queue_to_wrap
self.sideq = PlainQueue()
self.stop_requested = False
Expand Down Expand Up @@ -273,3 +363,12 @@ def drain(self, block=False, timeout=None):
res.append(self.get(block=False))
except Empty:
return res
def handle_for_subjobs(self):
    """Obtain a handle that can be passed to a Job.

    Due to the limitations of Python's multiprocessing module, a SafeQueue
    cannot be passed as an argument to a Job. This method returns a Queue
    object that can. The parent is still responsible for holding onto this
    object and cleaning it up.

    :returns: The underlying queue wrapped by this object (`self.q`).
    """
    return self.q
10 changes: 7 additions & 3 deletions cozy/main.py
Original file line number Diff line number Diff line change
Expand Up @@ -10,8 +10,6 @@
import datetime
import pickle

from multiprocessing import Value

from cozy import parse
from cozy import codegen
from cozy import common
Expand All @@ -22,6 +20,7 @@
from cozy import synthesis
from cozy.structures import rewriting
from cozy import opts
from cozy import jobs

save_failed_codegen_inputs = opts.Option("save-failed-codegen-inputs", str, "/tmp/failed_codegen.py", metavar="PATH")
checkpoint_prefix = opts.Option("checkpoint-prefix", str, "")
Expand Down Expand Up @@ -60,7 +59,12 @@ def run():
args = parser.parse_args()
opts.read(args)

improve_count = Value('i', 0)
# Install a handler for SIGINT, the signal that is delivered when you
# Ctrl+C a process. This allows Cozy to exit cleanly when it is
# interrupted. If you need to stop Cozy forcibly, use SIGTERM or SIGKILL.
jobs.install_graceful_sigint_handler()

improve_count = jobs.multiprocessing_context.Value('i', 0)

if args.resume:
with common.open_maybe_stdin(args.file or "-", mode="rb") as f:
Expand Down
39 changes: 31 additions & 8 deletions cozy/solver.py
Original file line number Diff line number Diff line change
Expand Up @@ -12,13 +12,14 @@
from datetime import datetime, timedelta
from functools import lru_cache
import threading
from typing import Callable

import z3

from cozy.target_syntax import *
from cozy.syntax_tools import BottomUpExplorer, pprint, free_vars, free_funcs, cse, all_exps, purify
from cozy.typecheck import is_collection, is_numeric
from cozy.common import declare_case, fresh_name, Visitor, FrozenDict, typechecked, extend, OrderedSet, make_random_access
from cozy.common import declare_case, fresh_name, Visitor, FrozenDict, typechecked, extend, OrderedSet, make_random_access, StopException, never_stop
from cozy import evaluation
from cozy.opts import Option
from cozy.structures import extension_handler
Expand Down Expand Up @@ -224,13 +225,22 @@ def grid(rows, cols):
return [[None for c in range(cols)] for r in range(rows)]

class ToZ3(Visitor):
def __init__(self, z3ctx, z3solver):
def __init__(self, z3ctx, z3solver, stop_callback : Callable[[], bool]):
"""Create a "ToZ3" object that can convert ASTs to Z3 queries.

:param z3ctx: A Z3 context object to use.
:param z3solver: A Z3 solver object to use.
:param stop_callback: A zero-argument function that will be checked
     periodically. The solver will raise a
StopException when the callback returns True.
"""
self.ctx = z3ctx
self.solver = z3solver
self.int_zero = z3.IntVal(0, self.ctx)
self.int_one = z3.IntVal(1, self.ctx)
self.true = z3.BoolVal(True, self.ctx)
self.false = z3.BoolVal(False, self.ctx)
self.stop_callback = stop_callback
assert to_bool(self.true) is True
assert to_bool(self.false) is False
assert to_bool(self.int_zero) is None
Expand Down Expand Up @@ -285,6 +295,9 @@ def gt(self, t, e1, e2, env, deep=False):
else:
raise NotImplementedError(t)
def eq(self, t, e1, e2, deep=False):
if self.stop_callback():
raise StopException("interrupted while encoding equality constraint")

if e1 is e2:
return self.true

Expand Down Expand Up @@ -831,6 +844,8 @@ def visit_AstRef(self, e, env):
def visit_bool(self, e, env):
return z3.BoolVal(e, self.ctx)
def visit(self, e, *args):
if self.stop_callback():
raise StopException("interrupted while encoding {}".format(pprint(e)))
try:
return super().visit(e, *args)
except KeyboardInterrupt:
Expand Down Expand Up @@ -1005,7 +1020,8 @@ def __init__(self,
model_callback = None,
logic : str = None,
timeout : float = None,
do_cse : bool = True):
do_cse : bool = True,
stop_callback : Callable[[], bool] = never_stop):

if collection_depth is None:
collection_depth = collection_depth_opt.value
Expand All @@ -1016,6 +1032,7 @@ def __init__(self,
self.collection_depth = collection_depth
self.validate_model = validate_model
self.model_callback = model_callback
self.stop_callback = stop_callback
self._env = OrderedDict()
self.stk = []
self.do_cse = do_cse
Expand All @@ -1026,7 +1043,7 @@ def __init__(self,
if timeout is not None:
solver.set("timeout", int(timeout * 1000))
solver.set("core.validate", validate_model)
visitor = ToZ3(ctx, solver)
visitor = ToZ3(ctx, solver, stop_callback=stop_callback)

self.visitor = visitor
self.z3_solver = solver
Expand Down Expand Up @@ -1155,6 +1172,11 @@ def reconstruct(model, value, type):
with task("invoke Z3"):
res = solver.check()
_tock(e, "solve")

if self.stop_callback():
solver.pop()
raise StopException("stop requested during Z3 solver call")

if res == z3.unsat:
solver.pop()
return None
Expand Down Expand Up @@ -1267,13 +1289,13 @@ class ModelCachingSolver(object):
calls can often be avoided using a counterexample found on a previous call.
"""

def __init__(self, vars : [EVar], funcs : { str : TFunc }, examples : [dict] = (), assumptions : Exp = ETRUE):
def __init__(self, vars : [EVar], funcs : { str : TFunc }, examples : [dict] = (), assumptions : Exp = ETRUE, **kwargs):
self.vars = list(vars)
self.funcs = OrderedDict(funcs)
self.calls = 0
self.hits = 0
self.examples = list(examples)
self.solver = IncrementalSolver(vars=vars, funcs=funcs)
self.solver = IncrementalSolver(vars=vars, funcs=funcs, **kwargs)
self.solver.add_assumption(assumptions)

def satisfy(self, e):
Expand All @@ -1295,8 +1317,9 @@ def valid(self, e):
return not self.satisfiable(ENot(e))

@lru_cache()
def solver_for_context(context : Context, assumptions : Exp = ETRUE):
def solver_for_context(context : Context, assumptions : Exp = ETRUE, **kwargs):
return ModelCachingSolver(
vars = [v for v, _ in context.vars()],
funcs = context.funcs(),
assumptions = assumptions)
assumptions = assumptions,
**kwargs)
Loading