Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

smtbmc: Add --track-assumes and --minimize-assumes options #4268

Merged
merged 1 commit into from
Mar 11, 2024
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
87 changes: 78 additions & 9 deletions backends/smt2/smtbmc.py
Original file line number Diff line number Diff line change
Expand Up @@ -57,6 +57,8 @@
check_witness = False
detect_loops = False
incremental = None
track_assumes = False
minimize_assumes = False
so = SmtOpts()


Expand Down Expand Up @@ -189,6 +191,15 @@ def help():
--incremental
run in incremental mode (experimental)

--track-assumes
track individual assumptions and report a subset of used
assumptions that are sufficient for the reported outcome. This
can be used to debug PREUNSAT failures as well as to find a
smaller set of sufficient assumptions.

--minimize-assumes
when using --track-assumes, solve for a minimal set of sufficient assumptions.

""" + so.helpmsg())

def usage():
Expand All @@ -200,7 +211,8 @@ def usage():
opts, args = getopt.getopt(sys.argv[1:], so.shortopts + "t:higcm:", so.longopts +
["help", "final-only", "assume-skipped=", "smtc=", "cex=", "aig=", "aig-noheader", "yw=", "btorwit=", "presat",
"dump-vcd=", "dump-yw=", "dump-vlogtb=", "vlogtb-top=", "dump-smtc=", "dump-all", "noinfo", "append=",
"smtc-init", "smtc-top=", "noinit", "binary", "keep-going", "check-witness", "detect-loops", "incremental"])
"smtc-init", "smtc-top=", "noinit", "binary", "keep-going", "check-witness", "detect-loops", "incremental",
"track-assumes", "minimize-assumes"])
except:
usage()

Expand Down Expand Up @@ -289,6 +301,10 @@ def usage():
elif o == "--incremental":
from smtbmc_incremental import Incremental
incremental = Incremental()
elif o == "--track-assumes":
track_assumes = True
elif o == "--minimize-assumes":
minimize_assumes = True
elif so.handle(o, a):
pass
else:
Expand Down Expand Up @@ -447,6 +463,9 @@ def replace_netref(match):

smt = SmtIo(opts=so)

if track_assumes:
smt.smt2_options[':produce-unsat-assumptions'] = 'true'

if noinfo and vcdfile is None and vlogtbfile is None and outconstr is None:
smt.produce_models = False

Expand Down Expand Up @@ -1497,6 +1516,44 @@ def get_active_assert_map(step, active):

return assert_map

assume_enables = {}

def declare_assume_enables():
def recurse(mod, path, key_base=()):
for expr, desc in smt.modinfo[mod].assumes.items():
enable = f"|assume_enable {len(assume_enables)}|"
smt.smt2_assumptions[(expr, key_base)] = enable
smt.write(f"(declare-const {enable} Bool)")
assume_enables[(expr, key_base)] = (enable, path, desc)

for cell, submod in smt.modinfo[mod].cells.items():
recurse(submod, f"{path}.{cell}", (mod, cell, key_base))

recurse(topmod, topmod)

if track_assumes:
declare_assume_enables()

def smt_assert_design_assumes(step):
if not track_assumes:
smt_assert_consequent("(|%s_u| s%d)" % (topmod, step))
return

if not assume_enables:
return

def expr_for_assume(assume_key, base=None):
expr, key_base = assume_key
expr_prefix = f"(|{expr}| "
expr_suffix = ")"
while key_base:
mod, cell, key_base = key_base
expr_prefix += f"(|{mod}_h {cell}| "
expr_suffix += ")"
return f"{expr_prefix} s{step}{expr_suffix}"

for assume_key, (enable, path, desc) in assume_enables.items():
smt_assert_consequent(f"(=> {enable} {expr_for_assume(assume_key)})")

states = list()
asserts_antecedent_cache = [list()]
Expand Down Expand Up @@ -1651,6 +1708,13 @@ def smt_check_sat(expected=["sat", "unsat"]):
smt_forall_assert()
return smt.check_sat(expected=expected)

def report_tracked_assumptions(msg):
if track_assumes:
print_msg(msg)
for key in smt.get_unsat_assumptions(minimize=minimize_assumes):
enable, path, descr = assume_enables[key]
print_msg(f" In {path}: {descr}")


if incremental:
incremental.mainloop()
Expand All @@ -1664,7 +1728,7 @@ def smt_check_sat(expected=["sat", "unsat"]):
break

smt_state(step)
smt_assert_consequent("(|%s_u| s%d)" % (topmod, step))
smt_assert_design_assumes(step)
smt_assert_antecedent("(|%s_h| s%d)" % (topmod, step))
smt_assert_antecedent("(not (|%s_is| s%d))" % (topmod, step))
smt_assert_consequent(get_constr_expr(constr_assumes, step))
Expand Down Expand Up @@ -1707,6 +1771,7 @@ def smt_check_sat(expected=["sat", "unsat"]):

else:
print_msg("Temporal induction successful.")
report_tracked_assumptions("Used assumptions:")
retstatus = "PASSED"
break

Expand All @@ -1732,7 +1797,7 @@ def smt_check_sat(expected=["sat", "unsat"]):

while step < num_steps:
smt_state(step)
smt_assert_consequent("(|%s_u| s%d)" % (topmod, step))
smt_assert_design_assumes(step)
smt_assert_antecedent("(|%s_h| s%d)" % (topmod, step))
smt_assert_consequent(get_constr_expr(constr_assumes, step))

Expand All @@ -1753,6 +1818,7 @@ def smt_check_sat(expected=["sat", "unsat"]):
smt_assert("(distinct (covers_%d s%d) #b%s)" % (coveridx, step, "0" * len(cover_desc)))

if smt_check_sat() == "unsat":
report_tracked_assumptions("Used assumptions:")
smt_pop()
break

Expand All @@ -1761,13 +1827,14 @@ def smt_check_sat(expected=["sat", "unsat"]):
print_msg("Appending additional step %d." % i)
smt_state(i)
smt_assert_antecedent("(not (|%s_is| s%d))" % (topmod, i))
smt_assert_consequent("(|%s_u| s%d)" % (topmod, i))
smt_assert_design_assumes(i)
smt_assert_antecedent("(|%s_h| s%d)" % (topmod, i))
smt_assert_antecedent("(|%s_t| s%d s%d)" % (topmod, i-1, i))
smt_assert_consequent(get_constr_expr(constr_assumes, i))
print_msg("Re-solving with appended steps..")
if smt_check_sat() == "unsat":
print("%s Cannot appended steps without violating assumptions!" % smt.timestamp())
report_tracked_assumptions("Conflicting assumptions:")
found_failed_assert = True
retstatus = "FAILED"
break
Expand Down Expand Up @@ -1823,7 +1890,7 @@ def smt_check_sat(expected=["sat", "unsat"]):
retstatus = "PASSED"
while step < num_steps:
smt_state(step)
smt_assert_consequent("(|%s_u| s%d)" % (topmod, step))
smt_assert_design_assumes(step)
smt_assert_antecedent("(|%s_h| s%d)" % (topmod, step))
smt_assert_consequent(get_constr_expr(constr_assumes, step))

Expand Down Expand Up @@ -1853,7 +1920,7 @@ def smt_check_sat(expected=["sat", "unsat"]):
if step+i < num_steps:
smt_state(step+i)
smt_assert_antecedent("(not (|%s_is| s%d))" % (topmod, step+i))
smt_assert_consequent("(|%s_u| s%d)" % (topmod, step+i))
smt_assert_design_assumes(step + i)
smt_assert_antecedent("(|%s_h| s%d)" % (topmod, step+i))
smt_assert_antecedent("(|%s_t| s%d s%d)" % (topmod, step+i-1, step+i))
smt_assert_consequent(get_constr_expr(constr_assumes, step+i))
Expand All @@ -1867,7 +1934,8 @@ def smt_check_sat(expected=["sat", "unsat"]):
print_msg("Checking assumptions in steps %d to %d.." % (step, last_check_step))

if smt_check_sat() == "unsat":
print("%s Assumptions are unsatisfiable!" % smt.timestamp())
print_msg("Assumptions are unsatisfiable!")
report_tracked_assumptions("Conficting assumptions:")
retstatus = "PREUNSAT"
break

Expand Down Expand Up @@ -1920,13 +1988,14 @@ def smt_check_sat(expected=["sat", "unsat"]):
print_msg("Appending additional step %d." % i)
smt_state(i)
smt_assert_antecedent("(not (|%s_is| s%d))" % (topmod, i))
smt_assert_consequent("(|%s_u| s%d)" % (topmod, i))
smt_assert_design_assumes(i)
smt_assert_antecedent("(|%s_h| s%d)" % (topmod, i))
smt_assert_antecedent("(|%s_t| s%d s%d)" % (topmod, i-1, i))
smt_assert_consequent(get_constr_expr(constr_assumes, i))
print_msg("Re-solving with appended steps..")
if smt_check_sat() == "unsat":
print("%s Cannot append steps without violating assumptions!" % smt.timestamp())
print_msg("Cannot append steps without violating assumptions!")
report_tracked_assumptions("Conflicting assumptions:")
retstatus = "FAILED"
break
print_anyconsts(step)
Expand Down
95 changes: 83 additions & 12 deletions backends/smt2/smtbmc_incremental.py
Original file line number Diff line number Diff line change
Expand Up @@ -15,6 +15,14 @@ class InteractiveError(Exception):
pass


def mkkey(data):
if isinstance(data, list):
return tuple(map(mkkey, data))
elif isinstance(data, dict):
raise InteractiveError(f"JSON objects found in assumption key: {data!r}")
return data


class Incremental:
def __init__(self):
self.traceidx = 0
Expand Down Expand Up @@ -73,17 +81,17 @@ def expr_arg_len(self, expr, min_len, max_len=-1):

if min_len is not None and arg_len < min_len:
if min_len == max_len:
raise (
raise InteractiveError(
f"{json.dumps(expr[0])} expression must have "
f"{min_len} argument{'s' if min_len != 1 else ''}"
)
else:
raise (
raise InteractiveError(
f"{json.dumps(expr[0])} expression must have at least "
f"{min_len} argument{'s' if min_len != 1 else ''}"
)
if max_len is not None and arg_len > max_len:
raise (
raise InteractiveError(
f"{json.dumps(expr[0])} expression can have at most "
f"{min_len} argument{'s' if max_len != 1 else ''}"
)
Expand All @@ -96,14 +104,31 @@ def expr_step(self, expr, smt_out):
smt_out.append(f"s{step}")
return "module", smtbmc.topmod

def expr_mod_constraint(self, expr, smt_out):
self.expr_arg_len(expr, 1)
def expr_cell(self, expr, smt_out):
self.expr_arg_len(expr, 2)
position = len(smt_out)
smt_out.append(None)
arg_sort = self.expr(expr[1], smt_out, required_sort=["module", None])
arg_sort = self.expr(expr[2], smt_out, required_sort=["module", None])
smt_out.append(")")
module = arg_sort[1]
cell = expr[1]
submod = smtbmc.smt.modinfo[module].cells.get(cell)
if submod is None:
raise InteractiveError(f"module {module!r} has no cell {cell!r}")
smt_out[position] = f"(|{module}_h {cell}| "
return ("module", submod)

def expr_mod_constraint(self, expr, smt_out):
suffix = expr[0][3:]
smt_out[position] = f"(|{module}{suffix}| "
self.expr_arg_len(expr, 1, 2 if suffix in ["_a", "_u", "_c"] else 1)
position = len(smt_out)
smt_out.append(None)
arg_sort = self.expr(expr[-1], smt_out, required_sort=["module", None])
module = arg_sort[1]
if len(expr) == 3:
smt_out[position] = f"(|{module}{suffix} {expr[1]}| "
else:
smt_out[position] = f"(|{module}{suffix}| "
smt_out.append(")")
return "Bool"

Expand Down Expand Up @@ -223,20 +248,19 @@ def expr_label(self, expr, smt_out):
subexpr = expr[2]

if not isinstance(label, str):
raise InteractiveError(f"expression label has to be a string")
raise InteractiveError("expression label has to be a string")

smt_out.append("(! ")
smt_out.appedd(label)
smt_out.append(" ")

sort = self.expr(subexpr, smt_out)

smt_out.append(" :named ")
smt_out.append(label)
smt_out.append(")")

return sort

expr_handlers = {
"step": expr_step,
"cell": expr_cell,
"mod_h": expr_mod_constraint,
"mod_is": expr_mod_constraint,
"mod_i": expr_mod_constraint,
Expand Down Expand Up @@ -302,6 +326,30 @@ def cmd_assert(self, cmd):

assert_fn(self.expr_smt(cmd.get("expr"), "Bool"))

def cmd_assert_design_assumes(self, cmd):
step = self.arg_step(cmd)
smtbmc.smt_assert_design_assumes(step)

def cmd_get_design_assume(self, cmd):
key = mkkey(cmd.get("key"))
return smtbmc.assume_enables.get(key)

def cmd_update_assumptions(self, cmd):
expr = cmd.get("expr")
key = cmd.get("key")


key = mkkey(key)

result = smtbmc.smt.smt2_assumptions.pop(key, None)
if expr is not None:
expr = self.expr_smt(expr, "Bool")
smtbmc.smt.smt2_assumptions[key] = expr
return result

def cmd_get_unsat_assumptions(self, cmd):
return smtbmc.smt.get_unsat_assumptions(minimize=bool(cmd.get('minimize')))

def cmd_push(self, cmd):
smtbmc.smt_push()

Expand All @@ -313,11 +361,14 @@ def cmd_check(self, cmd):

def cmd_smtlib(self, cmd):
command = cmd.get("command")
response = cmd.get("response", False)
if not isinstance(command, str):
raise InteractiveError(
f"raw SMT-LIB command must be a string, found {json.dumps(command)}"
)
smtbmc.smt.write(command)
if response:
return smtbmc.smt.read()

def cmd_design_hierwitness(self, cmd=None):
allregs = (cmd is None) or bool(cmd.get("allreges", False))
Expand Down Expand Up @@ -369,6 +420,21 @@ def cmd_read_yw_trace(self, cmd):

return dict(last_step=last_step)

def cmd_modinfo(self, cmd):
fields = cmd.get("fields", [])

mod = cmd.get("mod")
if mod is None:
mod = smtbmc.topmod
modinfo = smtbmc.smt.modinfo.get(mod)
if modinfo is None:
return None

result = dict(name=mod)
for field in fields:
result[field] = getattr(modinfo, field, None)
return result

def cmd_ping(self, cmd):
return cmd

Expand All @@ -377,13 +443,18 @@ def cmd_ping(self, cmd):
"assert": cmd_assert,
"assert_antecedent": cmd_assert,
"assert_consequent": cmd_assert,
"assert_design_assumes": cmd_assert_design_assumes,
"get_design_assume": cmd_get_design_assume,
"update_assumptions": cmd_update_assumptions,
"get_unsat_assumptions": cmd_get_unsat_assumptions,
"push": cmd_push,
"pop": cmd_pop,
"check": cmd_check,
"smtlib": cmd_smtlib,
"design_hierwitness": cmd_design_hierwitness,
"write_yw_trace": cmd_write_yw_trace,
"read_yw_trace": cmd_read_yw_trace,
"modinfo": cmd_modinfo,
"ping": cmd_ping,
}

Expand Down
Loading
Loading