Skip to content

Commit

Permalink
Merge pull request #4268 from jix/smtbmc-track-assumes
Browse files Browse the repository at this point in the history
smtbmc: Add --track-assumes and --minimize-assumes options
  • Loading branch information
nakengelhardt authored Mar 11, 2024
2 parents e4f11eb + 42122e2 commit 0909c2e
Show file tree
Hide file tree
Showing 3 changed files with 219 additions and 24 deletions.
87 changes: 78 additions & 9 deletions backends/smt2/smtbmc.py
Original file line number Diff line number Diff line change
Expand Up @@ -57,6 +57,8 @@
check_witness = False
detect_loops = False
incremental = None
track_assumes = False
minimize_assumes = False
so = SmtOpts()


Expand Down Expand Up @@ -189,6 +191,15 @@ def help():
--incremental
run in incremental mode (experimental)
--track-assumes
track individual assumptions and report a subset of used
assumptions that are sufficient for the reported outcome. This
can be used to debug PREUNSAT failures as well as to find a
smaller set of sufficient assumptions.
--minimize-assumes
when using --track-assumes, solve for a minimal set of sufficient assumptions.
""" + so.helpmsg())

def usage():
Expand All @@ -200,7 +211,8 @@ def usage():
opts, args = getopt.getopt(sys.argv[1:], so.shortopts + "t:higcm:", so.longopts +
["help", "final-only", "assume-skipped=", "smtc=", "cex=", "aig=", "aig-noheader", "yw=", "btorwit=", "presat",
"dump-vcd=", "dump-yw=", "dump-vlogtb=", "vlogtb-top=", "dump-smtc=", "dump-all", "noinfo", "append=",
"smtc-init", "smtc-top=", "noinit", "binary", "keep-going", "check-witness", "detect-loops", "incremental"])
"smtc-init", "smtc-top=", "noinit", "binary", "keep-going", "check-witness", "detect-loops", "incremental",
"track-assumes", "minimize-assumes"])
except:
usage()

Expand Down Expand Up @@ -289,6 +301,10 @@ def usage():
elif o == "--incremental":
from smtbmc_incremental import Incremental
incremental = Incremental()
elif o == "--track-assumes":
track_assumes = True
elif o == "--minimize-assumes":
minimize_assumes = True
elif so.handle(o, a):
pass
else:
Expand Down Expand Up @@ -447,6 +463,9 @@ def replace_netref(match):

smt = SmtIo(opts=so)

if track_assumes:
smt.smt2_options[':produce-unsat-assumptions'] = 'true'

if noinfo and vcdfile is None and vlogtbfile is None and outconstr is None:
smt.produce_models = False

Expand Down Expand Up @@ -1497,6 +1516,44 @@ def get_active_assert_map(step, active):

return assert_map

assume_enables = {}

def declare_assume_enables():
def recurse(mod, path, key_base=()):
for expr, desc in smt.modinfo[mod].assumes.items():
enable = f"|assume_enable {len(assume_enables)}|"
smt.smt2_assumptions[(expr, key_base)] = enable
smt.write(f"(declare-const {enable} Bool)")
assume_enables[(expr, key_base)] = (enable, path, desc)

for cell, submod in smt.modinfo[mod].cells.items():
recurse(submod, f"{path}.{cell}", (mod, cell, key_base))

recurse(topmod, topmod)

if track_assumes:
declare_assume_enables()

def smt_assert_design_assumes(step):
if not track_assumes:
smt_assert_consequent("(|%s_u| s%d)" % (topmod, step))
return

if not assume_enables:
return

def expr_for_assume(assume_key, base=None):
expr, key_base = assume_key
expr_prefix = f"(|{expr}| "
expr_suffix = ")"
while key_base:
mod, cell, key_base = key_base
expr_prefix += f"(|{mod}_h {cell}| "
expr_suffix += ")"
return f"{expr_prefix} s{step}{expr_suffix}"

for assume_key, (enable, path, desc) in assume_enables.items():
smt_assert_consequent(f"(=> {enable} {expr_for_assume(assume_key)})")

states = list()
asserts_antecedent_cache = [list()]
Expand Down Expand Up @@ -1651,6 +1708,13 @@ def smt_check_sat(expected=["sat", "unsat"]):
smt_forall_assert()
return smt.check_sat(expected=expected)

def report_tracked_assumptions(msg):
if track_assumes:
print_msg(msg)
for key in smt.get_unsat_assumptions(minimize=minimize_assumes):
enable, path, descr = assume_enables[key]
print_msg(f" In {path}: {descr}")


if incremental:
incremental.mainloop()
Expand All @@ -1664,7 +1728,7 @@ def smt_check_sat(expected=["sat", "unsat"]):
break

smt_state(step)
smt_assert_consequent("(|%s_u| s%d)" % (topmod, step))
smt_assert_design_assumes(step)
smt_assert_antecedent("(|%s_h| s%d)" % (topmod, step))
smt_assert_antecedent("(not (|%s_is| s%d))" % (topmod, step))
smt_assert_consequent(get_constr_expr(constr_assumes, step))
Expand Down Expand Up @@ -1707,6 +1771,7 @@ def smt_check_sat(expected=["sat", "unsat"]):

else:
print_msg("Temporal induction successful.")
report_tracked_assumptions("Used assumptions:")
retstatus = "PASSED"
break

Expand All @@ -1732,7 +1797,7 @@ def smt_check_sat(expected=["sat", "unsat"]):

while step < num_steps:
smt_state(step)
smt_assert_consequent("(|%s_u| s%d)" % (topmod, step))
smt_assert_design_assumes(step)
smt_assert_antecedent("(|%s_h| s%d)" % (topmod, step))
smt_assert_consequent(get_constr_expr(constr_assumes, step))

Expand All @@ -1753,6 +1818,7 @@ def smt_check_sat(expected=["sat", "unsat"]):
smt_assert("(distinct (covers_%d s%d) #b%s)" % (coveridx, step, "0" * len(cover_desc)))

if smt_check_sat() == "unsat":
report_tracked_assumptions("Used assumptions:")
smt_pop()
break

Expand All @@ -1761,13 +1827,14 @@ def smt_check_sat(expected=["sat", "unsat"]):
print_msg("Appending additional step %d." % i)
smt_state(i)
smt_assert_antecedent("(not (|%s_is| s%d))" % (topmod, i))
smt_assert_consequent("(|%s_u| s%d)" % (topmod, i))
smt_assert_design_assumes(i)
smt_assert_antecedent("(|%s_h| s%d)" % (topmod, i))
smt_assert_antecedent("(|%s_t| s%d s%d)" % (topmod, i-1, i))
smt_assert_consequent(get_constr_expr(constr_assumes, i))
print_msg("Re-solving with appended steps..")
if smt_check_sat() == "unsat":
print("%s Cannot appended steps without violating assumptions!" % smt.timestamp())
report_tracked_assumptions("Conflicting assumptions:")
found_failed_assert = True
retstatus = "FAILED"
break
Expand Down Expand Up @@ -1823,7 +1890,7 @@ def smt_check_sat(expected=["sat", "unsat"]):
retstatus = "PASSED"
while step < num_steps:
smt_state(step)
smt_assert_consequent("(|%s_u| s%d)" % (topmod, step))
smt_assert_design_assumes(step)
smt_assert_antecedent("(|%s_h| s%d)" % (topmod, step))
smt_assert_consequent(get_constr_expr(constr_assumes, step))

Expand Down Expand Up @@ -1853,7 +1920,7 @@ def smt_check_sat(expected=["sat", "unsat"]):
if step+i < num_steps:
smt_state(step+i)
smt_assert_antecedent("(not (|%s_is| s%d))" % (topmod, step+i))
smt_assert_consequent("(|%s_u| s%d)" % (topmod, step+i))
smt_assert_design_assumes(step + i)
smt_assert_antecedent("(|%s_h| s%d)" % (topmod, step+i))
smt_assert_antecedent("(|%s_t| s%d s%d)" % (topmod, step+i-1, step+i))
smt_assert_consequent(get_constr_expr(constr_assumes, step+i))
Expand All @@ -1867,7 +1934,8 @@ def smt_check_sat(expected=["sat", "unsat"]):
print_msg("Checking assumptions in steps %d to %d.." % (step, last_check_step))

if smt_check_sat() == "unsat":
print("%s Assumptions are unsatisfiable!" % smt.timestamp())
print_msg("Assumptions are unsatisfiable!")
report_tracked_assumptions("Conficting assumptions:")
retstatus = "PREUNSAT"
break

Expand Down Expand Up @@ -1920,13 +1988,14 @@ def smt_check_sat(expected=["sat", "unsat"]):
print_msg("Appending additional step %d." % i)
smt_state(i)
smt_assert_antecedent("(not (|%s_is| s%d))" % (topmod, i))
smt_assert_consequent("(|%s_u| s%d)" % (topmod, i))
smt_assert_design_assumes(i)
smt_assert_antecedent("(|%s_h| s%d)" % (topmod, i))
smt_assert_antecedent("(|%s_t| s%d s%d)" % (topmod, i-1, i))
smt_assert_consequent(get_constr_expr(constr_assumes, i))
print_msg("Re-solving with appended steps..")
if smt_check_sat() == "unsat":
print("%s Cannot append steps without violating assumptions!" % smt.timestamp())
print_msg("Cannot append steps without violating assumptions!")
report_tracked_assumptions("Conflicting assumptions:")
retstatus = "FAILED"
break
print_anyconsts(step)
Expand Down
95 changes: 83 additions & 12 deletions backends/smt2/smtbmc_incremental.py
Original file line number Diff line number Diff line change
Expand Up @@ -15,6 +15,14 @@ class InteractiveError(Exception):
pass


def mkkey(data):
if isinstance(data, list):
return tuple(map(mkkey, data))
elif isinstance(data, dict):
raise InteractiveError(f"JSON objects found in assumption key: {data!r}")
return data


class Incremental:
def __init__(self):
self.traceidx = 0
Expand Down Expand Up @@ -73,17 +81,17 @@ def expr_arg_len(self, expr, min_len, max_len=-1):

if min_len is not None and arg_len < min_len:
if min_len == max_len:
raise (
raise InteractiveError(
f"{json.dumps(expr[0])} expression must have "
f"{min_len} argument{'s' if min_len != 1 else ''}"
)
else:
raise (
raise InteractiveError(
f"{json.dumps(expr[0])} expression must have at least "
f"{min_len} argument{'s' if min_len != 1 else ''}"
)
if max_len is not None and arg_len > max_len:
raise (
raise InteractiveError(
f"{json.dumps(expr[0])} expression can have at most "
f"{min_len} argument{'s' if max_len != 1 else ''}"
)
Expand All @@ -96,14 +104,31 @@ def expr_step(self, expr, smt_out):
smt_out.append(f"s{step}")
return "module", smtbmc.topmod

def expr_mod_constraint(self, expr, smt_out):
self.expr_arg_len(expr, 1)
def expr_cell(self, expr, smt_out):
self.expr_arg_len(expr, 2)
position = len(smt_out)
smt_out.append(None)
arg_sort = self.expr(expr[1], smt_out, required_sort=["module", None])
arg_sort = self.expr(expr[2], smt_out, required_sort=["module", None])
smt_out.append(")")
module = arg_sort[1]
cell = expr[1]
submod = smtbmc.smt.modinfo[module].cells.get(cell)
if submod is None:
raise InteractiveError(f"module {module!r} has no cell {cell!r}")
smt_out[position] = f"(|{module}_h {cell}| "
return ("module", submod)

def expr_mod_constraint(self, expr, smt_out):
suffix = expr[0][3:]
smt_out[position] = f"(|{module}{suffix}| "
self.expr_arg_len(expr, 1, 2 if suffix in ["_a", "_u", "_c"] else 1)
position = len(smt_out)
smt_out.append(None)
arg_sort = self.expr(expr[-1], smt_out, required_sort=["module", None])
module = arg_sort[1]
if len(expr) == 3:
smt_out[position] = f"(|{module}{suffix} {expr[1]}| "
else:
smt_out[position] = f"(|{module}{suffix}| "
smt_out.append(")")
return "Bool"

Expand Down Expand Up @@ -223,20 +248,19 @@ def expr_label(self, expr, smt_out):
subexpr = expr[2]

if not isinstance(label, str):
raise InteractiveError(f"expression label has to be a string")
raise InteractiveError("expression label has to be a string")

smt_out.append("(! ")
smt_out.appedd(label)
smt_out.append(" ")

sort = self.expr(subexpr, smt_out)

smt_out.append(" :named ")
smt_out.append(label)
smt_out.append(")")

return sort

expr_handlers = {
"step": expr_step,
"cell": expr_cell,
"mod_h": expr_mod_constraint,
"mod_is": expr_mod_constraint,
"mod_i": expr_mod_constraint,
Expand Down Expand Up @@ -302,6 +326,30 @@ def cmd_assert(self, cmd):

assert_fn(self.expr_smt(cmd.get("expr"), "Bool"))

def cmd_assert_design_assumes(self, cmd):
step = self.arg_step(cmd)
smtbmc.smt_assert_design_assumes(step)

def cmd_get_design_assume(self, cmd):
key = mkkey(cmd.get("key"))
return smtbmc.assume_enables.get(key)

def cmd_update_assumptions(self, cmd):
expr = cmd.get("expr")
key = cmd.get("key")


key = mkkey(key)

result = smtbmc.smt.smt2_assumptions.pop(key, None)
if expr is not None:
expr = self.expr_smt(expr, "Bool")
smtbmc.smt.smt2_assumptions[key] = expr
return result

def cmd_get_unsat_assumptions(self, cmd):
return smtbmc.smt.get_unsat_assumptions(minimize=bool(cmd.get('minimize')))

def cmd_push(self, cmd):
smtbmc.smt_push()

Expand All @@ -313,11 +361,14 @@ def cmd_check(self, cmd):

def cmd_smtlib(self, cmd):
command = cmd.get("command")
response = cmd.get("response", False)
if not isinstance(command, str):
raise InteractiveError(
f"raw SMT-LIB command must be a string, found {json.dumps(command)}"
)
smtbmc.smt.write(command)
if response:
return smtbmc.smt.read()

def cmd_design_hierwitness(self, cmd=None):
allregs = (cmd is None) or bool(cmd.get("allreges", False))
Expand Down Expand Up @@ -369,6 +420,21 @@ def cmd_read_yw_trace(self, cmd):

return dict(last_step=last_step)

def cmd_modinfo(self, cmd):
fields = cmd.get("fields", [])

mod = cmd.get("mod")
if mod is None:
mod = smtbmc.topmod
modinfo = smtbmc.smt.modinfo.get(mod)
if modinfo is None:
return None

result = dict(name=mod)
for field in fields:
result[field] = getattr(modinfo, field, None)
return result

def cmd_ping(self, cmd):
return cmd

Expand All @@ -377,13 +443,18 @@ def cmd_ping(self, cmd):
"assert": cmd_assert,
"assert_antecedent": cmd_assert,
"assert_consequent": cmd_assert,
"assert_design_assumes": cmd_assert_design_assumes,
"get_design_assume": cmd_get_design_assume,
"update_assumptions": cmd_update_assumptions,
"get_unsat_assumptions": cmd_get_unsat_assumptions,
"push": cmd_push,
"pop": cmd_pop,
"check": cmd_check,
"smtlib": cmd_smtlib,
"design_hierwitness": cmd_design_hierwitness,
"write_yw_trace": cmd_write_yw_trace,
"read_yw_trace": cmd_read_yw_trace,
"modinfo": cmd_modinfo,
"ping": cmd_ping,
}

Expand Down
Loading

0 comments on commit 0909c2e

Please sign in to comment.