Skip to content

Commit

Permalink
smtbmc: Add --track-assumes and --minimize-assumes options
Browse files Browse the repository at this point in the history
The --track-assumes option makes smtbmc keep track of which assumptions
were used by the solver when reaching an unsat case and to output that
set of assumptions. This is particularly useful to debug PREUNSAT
failures.

The --minimize-assumes option can be used in addition to --track-assumes
which will cause smtbmc to spend additional solving effort to produce a
minimal set of assumptions that are sufficient to cause the unsat
result.
  • Loading branch information
jix committed Mar 7, 2024
1 parent e9cd6ca commit a718238
Show file tree
Hide file tree
Showing 3 changed files with 219 additions and 24 deletions.
87 changes: 78 additions & 9 deletions backends/smt2/smtbmc.py
Original file line number Diff line number Diff line change
Expand Up @@ -57,6 +57,8 @@
check_witness = False
detect_loops = False
incremental = None
track_assumes = False
minimize_assumes = False
so = SmtOpts()


Expand Down Expand Up @@ -189,6 +191,15 @@ def help():
--incremental
run in incremental mode (experimental)
--track-assumes
track individual assumptions and report a subset of used
assumptions that are sufficient for the reported outcome. This
can be used to debug PREUNSAT failures as well as to find a
smaller set of sufficient assumptions.
--minimize-assumes
when using --track-assumes, solve for a minimal set of sufficient assumptions.
""" + so.helpmsg())

def usage():
Expand All @@ -200,7 +211,8 @@ def usage():
opts, args = getopt.getopt(sys.argv[1:], so.shortopts + "t:higcm:", so.longopts +
["help", "final-only", "assume-skipped=", "smtc=", "cex=", "aig=", "aig-noheader", "yw=", "btorwit=", "presat",
"dump-vcd=", "dump-yw=", "dump-vlogtb=", "vlogtb-top=", "dump-smtc=", "dump-all", "noinfo", "append=",
"smtc-init", "smtc-top=", "noinit", "binary", "keep-going", "check-witness", "detect-loops", "incremental"])
"smtc-init", "smtc-top=", "noinit", "binary", "keep-going", "check-witness", "detect-loops", "incremental",
"track-assumes", "minimize-assumes"])
except:
usage()

Expand Down Expand Up @@ -289,6 +301,10 @@ def usage():
elif o == "--incremental":
from smtbmc_incremental import Incremental
incremental = Incremental()
elif o == "--track-assumes":
track_assumes = True
elif o == "--minimize-assumes":
minimize_assumes = True
elif so.handle(o, a):
pass
else:
Expand Down Expand Up @@ -447,6 +463,9 @@ def replace_netref(match):

smt = SmtIo(opts=so)

if track_assumes:
smt.smt2_options[':produce-unsat-assumptions'] = 'true'

if noinfo and vcdfile is None and vlogtbfile is None and outconstr is None:
smt.produce_models = False

Expand Down Expand Up @@ -1497,6 +1516,44 @@ def get_active_assert_map(step, active):

return assert_map

assume_enables = {}

def declare_assume_enables():
def recurse(mod, path, key_base=()):
for expr, desc in smt.modinfo[mod].assumes.items():
enable = f"|assume_enable {len(assume_enables)}|"
smt.smt2_assumptions[(expr, key_base)] = enable
smt.write(f"(declare-const {enable} Bool)")
assume_enables[(expr, key_base)] = (enable, path, desc)

for cell, submod in smt.modinfo[mod].cells.items():
recurse(submod, f"{path}.{cell}", (mod, cell, key_base))

recurse(topmod, topmod)

if track_assumes:
declare_assume_enables()

def smt_assert_design_assumes(step):
if not track_assumes:
smt_assert_consequent("(|%s_u| s%d)" % (topmod, step))
return

if not assume_enables:
return

def expr_for_assume(assume_key, base=None):
expr, key_base = assume_key
expr_prefix = f"(|{expr}| "
expr_suffix = ")"
while key_base:
mod, cell, key_base = key_base
expr_prefix += f"(|{mod}_h {cell}| "
expr_suffix += ")"
return f"{expr_prefix} s{step}{expr_suffix}"

for assume_key, (enable, path, desc) in assume_enables.items():
smt_assert_consequent(f"(=> {enable} {expr_for_assume(assume_key)})")

states = list()
asserts_antecedent_cache = [list()]
Expand Down Expand Up @@ -1651,6 +1708,13 @@ def smt_check_sat(expected=["sat", "unsat"]):
smt_forall_assert()
return smt.check_sat(expected=expected)

def report_tracked_assumptions(msg):
if track_assumes:
print_msg(msg)
for key in smt.get_unsat_assumptions(minimize=minimize_assumes):
enable, path, descr = assume_enables[key]
print_msg(f" In {path}: {descr}")


if incremental:
incremental.mainloop()
Expand All @@ -1664,7 +1728,7 @@ def smt_check_sat(expected=["sat", "unsat"]):
break

smt_state(step)
smt_assert_consequent("(|%s_u| s%d)" % (topmod, step))
smt_assert_design_assumes(step)
smt_assert_antecedent("(|%s_h| s%d)" % (topmod, step))
smt_assert_antecedent("(not (|%s_is| s%d))" % (topmod, step))
smt_assert_consequent(get_constr_expr(constr_assumes, step))
Expand Down Expand Up @@ -1707,6 +1771,7 @@ def smt_check_sat(expected=["sat", "unsat"]):

else:
print_msg("Temporal induction successful.")
report_tracked_assumptions("Used assumptions:")
retstatus = "PASSED"
break

Expand All @@ -1732,7 +1797,7 @@ def smt_check_sat(expected=["sat", "unsat"]):

while step < num_steps:
smt_state(step)
smt_assert_consequent("(|%s_u| s%d)" % (topmod, step))
smt_assert_design_assumes(step)
smt_assert_antecedent("(|%s_h| s%d)" % (topmod, step))
smt_assert_consequent(get_constr_expr(constr_assumes, step))

Expand All @@ -1753,6 +1818,7 @@ def smt_check_sat(expected=["sat", "unsat"]):
smt_assert("(distinct (covers_%d s%d) #b%s)" % (coveridx, step, "0" * len(cover_desc)))

if smt_check_sat() == "unsat":
report_tracked_assumptions("Used assumptions:")
smt_pop()
break

Expand All @@ -1761,13 +1827,14 @@ def smt_check_sat(expected=["sat", "unsat"]):
print_msg("Appending additional step %d." % i)
smt_state(i)
smt_assert_antecedent("(not (|%s_is| s%d))" % (topmod, i))
smt_assert_consequent("(|%s_u| s%d)" % (topmod, i))
smt_assert_design_assumes(i)
smt_assert_antecedent("(|%s_h| s%d)" % (topmod, i))
smt_assert_antecedent("(|%s_t| s%d s%d)" % (topmod, i-1, i))
smt_assert_consequent(get_constr_expr(constr_assumes, i))
print_msg("Re-solving with appended steps..")
if smt_check_sat() == "unsat":
print("%s Cannot appended steps without violating assumptions!" % smt.timestamp())
report_tracked_assumptions("Conflicting assumptions:")
found_failed_assert = True
retstatus = "FAILED"
break
Expand Down Expand Up @@ -1823,7 +1890,7 @@ def smt_check_sat(expected=["sat", "unsat"]):
retstatus = "PASSED"
while step < num_steps:
smt_state(step)
smt_assert_consequent("(|%s_u| s%d)" % (topmod, step))
smt_assert_design_assumes(step)
smt_assert_antecedent("(|%s_h| s%d)" % (topmod, step))
smt_assert_consequent(get_constr_expr(constr_assumes, step))

Expand Down Expand Up @@ -1853,7 +1920,7 @@ def smt_check_sat(expected=["sat", "unsat"]):
if step+i < num_steps:
smt_state(step+i)
smt_assert_antecedent("(not (|%s_is| s%d))" % (topmod, step+i))
smt_assert_consequent("(|%s_u| s%d)" % (topmod, step+i))
smt_assert_design_assumes(step + i)
smt_assert_antecedent("(|%s_h| s%d)" % (topmod, step+i))
smt_assert_antecedent("(|%s_t| s%d s%d)" % (topmod, step+i-1, step+i))
smt_assert_consequent(get_constr_expr(constr_assumes, step+i))
Expand All @@ -1867,7 +1934,8 @@ def smt_check_sat(expected=["sat", "unsat"]):
print_msg("Checking assumptions in steps %d to %d.." % (step, last_check_step))

if smt_check_sat() == "unsat":
print("%s Assumptions are unsatisfiable!" % smt.timestamp())
print_msg("Assumptions are unsatisfiable!")
report_tracked_assumptions("Conficting assumptions:")
retstatus = "PREUNSAT"
break

Expand Down Expand Up @@ -1920,13 +1988,14 @@ def smt_check_sat(expected=["sat", "unsat"]):
print_msg("Appending additional step %d." % i)
smt_state(i)
smt_assert_antecedent("(not (|%s_is| s%d))" % (topmod, i))
smt_assert_consequent("(|%s_u| s%d)" % (topmod, i))
smt_assert_design_assumes(i)
smt_assert_antecedent("(|%s_h| s%d)" % (topmod, i))
smt_assert_antecedent("(|%s_t| s%d s%d)" % (topmod, i-1, i))
smt_assert_consequent(get_constr_expr(constr_assumes, i))
print_msg("Re-solving with appended steps..")
if smt_check_sat() == "unsat":
print("%s Cannot append steps without violating assumptions!" % smt.timestamp())
print_msg("Cannot append steps without violating assumptions!")
report_tracked_assumptions("Conflicting assumptions:")
retstatus = "FAILED"
break
print_anyconsts(step)
Expand Down
95 changes: 83 additions & 12 deletions backends/smt2/smtbmc_incremental.py
Original file line number Diff line number Diff line change
Expand Up @@ -15,6 +15,14 @@ class InteractiveError(Exception):
pass


def mkkey(data):
if isinstance(data, list):
return tuple(map(mkkey, data))
elif isinstance(data, dict):
raise InteractiveError(f"JSON objects found in assumption key: {key!r}")
return data


class Incremental:
def __init__(self):
self.traceidx = 0
Expand Down Expand Up @@ -73,17 +81,17 @@ def expr_arg_len(self, expr, min_len, max_len=-1):

if min_len is not None and arg_len < min_len:
if min_len == max_len:
raise (
raise InteractiveError(
f"{json.dumps(expr[0])} expression must have "
f"{min_len} argument{'s' if min_len != 1 else ''}"
)
else:
raise (
raise InteractiveError(
f"{json.dumps(expr[0])} expression must have at least "
f"{min_len} argument{'s' if min_len != 1 else ''}"
)
if max_len is not None and arg_len > max_len:
raise (
raise InteractiveError(
f"{json.dumps(expr[0])} expression can have at most "
f"{min_len} argument{'s' if max_len != 1 else ''}"
)
Expand All @@ -96,14 +104,31 @@ def expr_step(self, expr, smt_out):
smt_out.append(f"s{step}")
return "module", smtbmc.topmod

def expr_mod_constraint(self, expr, smt_out):
self.expr_arg_len(expr, 1)
def expr_cell(self, expr, smt_out):
self.expr_arg_len(expr, 2)
position = len(smt_out)
smt_out.append(None)
arg_sort = self.expr(expr[1], smt_out, required_sort=["module", None])
arg_sort = self.expr(expr[2], smt_out, required_sort=["module", None])
smt_out.append(")")
module = arg_sort[1]
cell = expr[1]
submod = smtbmc.smt.modinfo[module].cells.get(cell)
if submod is None:
raise InteractiveError(f"module {module!r} has no cell {cell!r}")
smt_out[position] = f"(|{module}_h {cell}| "
return ("module", submod)

def expr_mod_constraint(self, expr, smt_out):
suffix = expr[0][3:]
smt_out[position] = f"(|{module}{suffix}| "
self.expr_arg_len(expr, 1, 2 if suffix in ["_a", "_u", "_c"] else 1)
position = len(smt_out)
smt_out.append(None)
arg_sort = self.expr(expr[-1], smt_out, required_sort=["module", None])
module = arg_sort[1]
if len(expr) == 3:
smt_out[position] = f"(|{module}{suffix} {expr[1]}| "
else:
smt_out[position] = f"(|{module}{suffix}| "
smt_out.append(")")
return "Bool"

Expand Down Expand Up @@ -223,20 +248,19 @@ def expr_label(self, expr, smt_out):
subexpr = expr[2]

if not isinstance(label, str):
raise InteractiveError(f"expression label has to be a string")
raise InteractiveError("expression label has to be a string")

smt_out.append("(! ")
smt_out.appedd(label)
smt_out.append(" ")

sort = self.expr(subexpr, smt_out)

smt_out.append(" :named ")
smt_out.append(label)
smt_out.append(")")

return sort

expr_handlers = {
"step": expr_step,
"cell": expr_cell,
"mod_h": expr_mod_constraint,
"mod_is": expr_mod_constraint,
"mod_i": expr_mod_constraint,
Expand Down Expand Up @@ -302,6 +326,30 @@ def cmd_assert(self, cmd):

assert_fn(self.expr_smt(cmd.get("expr"), "Bool"))

def cmd_assert_design_assumes(self, cmd):
step = self.arg_step(cmd)
smtbmc.smt_assert_design_assumes(step)

def cmd_get_design_assume(self, cmd):
key = mkkey(cmd.get("key"))
return smtbmc.assume_enables.get(key)

def cmd_update_assumptions(self, cmd):
expr = cmd.get("expr")
key = cmd.get("key")


key = mkkey(key)

result = smtbmc.smt.smt2_assumptions.pop(key, None)
if expr is not None:
expr = self.expr_smt(expr, "Bool")
smtbmc.smt.smt2_assumptions[key] = expr
return result

def cmd_get_unsat_assumptions(self, cmd):
return smtbmc.smt.get_unsat_assumptions(minimize=bool(cmd.get('minimize')))

def cmd_push(self, cmd):
smtbmc.smt_push()

Expand All @@ -313,11 +361,14 @@ def cmd_check(self, cmd):

def cmd_smtlib(self, cmd):
command = cmd.get("command")
response = cmd.get("response", False)
if not isinstance(command, str):
raise InteractiveError(
f"raw SMT-LIB command must be a string, found {json.dumps(command)}"
)
smtbmc.smt.write(command)
if response:
return smtbmc.smt.read()

def cmd_design_hierwitness(self, cmd=None):
allregs = (cmd is None) or bool(cmd.get("allreges", False))
Expand Down Expand Up @@ -369,6 +420,21 @@ def cmd_read_yw_trace(self, cmd):

return dict(last_step=last_step)

def cmd_modinfo(self, cmd):
fields = cmd.get("fields", [])

mod = cmd.get("mod")
if mod is None:
mod = smtbmc.topmod
modinfo = smtbmc.smt.modinfo.get(mod)
if modinfo is None:
return None

result = dict(name=mod)
for field in fields:
result[field] = getattr(modinfo, field, None)
return result

def cmd_ping(self, cmd):
return cmd

Expand All @@ -377,13 +443,18 @@ def cmd_ping(self, cmd):
"assert": cmd_assert,
"assert_antecedent": cmd_assert,
"assert_consequent": cmd_assert,
"assert_design_assumes": cmd_assert_design_assumes,
"get_design_assume": cmd_get_design_assume,
"update_assumptions": cmd_update_assumptions,
"get_unsat_assumptions": cmd_get_unsat_assumptions,
"push": cmd_push,
"pop": cmd_pop,
"check": cmd_check,
"smtlib": cmd_smtlib,
"design_hierwitness": cmd_design_hierwitness,
"write_yw_trace": cmd_write_yw_trace,
"read_yw_trace": cmd_read_yw_trace,
"modinfo": cmd_modinfo,
"ping": cmd_ping,
}

Expand Down
Loading

0 comments on commit a718238

Please sign in to comment.