From a718238cbd3af662ec0a2fcf3b29664dcd3f171e Mon Sep 17 00:00:00 2001 From: Jannis Harder Date: Thu, 7 Mar 2024 13:27:03 +0100 Subject: [PATCH] smtbmc: Add --track-assumes and --minimize-assumes options The --track-assumes option makes smtbmc keep track of which assumptions were used by the solver when reaching an unsat case and to output that set of assumptions. This is particularly useful to debug PREUNSAT failures. The --minimize-assumes option can be used in addition to --track-assumes which will cause smtbmc to spend additional solving effort to produce a minimal set of assumptions that are sufficient to cause the unsat result. --- backends/smt2/smtbmc.py | 87 +++++++++++++++++++++++--- backends/smt2/smtbmc_incremental.py | 95 +++++++++++++++++++++++++---- backends/smt2/smtio.py | 61 +++++++++++++++++- 3 files changed, 219 insertions(+), 24 deletions(-) diff --git a/backends/smt2/smtbmc.py b/backends/smt2/smtbmc.py index cc47bc3762a..e6b4088dbd7 100644 --- a/backends/smt2/smtbmc.py +++ b/backends/smt2/smtbmc.py @@ -57,6 +57,8 @@ check_witness = False detect_loops = False incremental = None +track_assumes = False +minimize_assumes = False so = SmtOpts() @@ -189,6 +191,15 @@ def help(): --incremental run in incremental mode (experimental) + --track-assumes + track individual assumptions and report a subset of used + assumptions that are sufficient for the reported outcome. This + can be used to debug PREUNSAT failures as well as to find a + smaller set of sufficient assumptions. + + --minimize-assumes + when using --track-assumes, solve for a minimal set of sufficient assumptions. + """ + so.helpmsg()) def usage(): @@ -200,7 +211,8 @@ def usage(): opts, args = getopt.getopt(sys.argv[1:], so.shortopts + "t:higcm:", so.longopts + ["help", "final-only", "assume-skipped=", "smtc=", "cex=", "aig=", "aig-noheader", "yw=", "btorwit=", "presat", "dump-vcd=", "dump-yw=", "dump-vlogtb=", "vlogtb-top=", "dump-smtc=", "dump-all", "noinfo", "append=", - "smtc-init", "smtc-top=", "noinit", "binary", "keep-going", "check-witness", "detect-loops", "incremental"]) + "smtc-init", "smtc-top=", "noinit", "binary", "keep-going", "check-witness", "detect-loops", "incremental", + "track-assumes", "minimize-assumes"]) except: usage() @@ -289,6 +301,10 @@ def usage(): elif o == "--incremental": from smtbmc_incremental import Incremental incremental = Incremental() + elif o == "--track-assumes": + track_assumes = True + elif o == "--minimize-assumes": + minimize_assumes = True elif so.handle(o, a): pass else: @@ -447,6 +463,9 @@ def replace_netref(match): smt = SmtIo(opts=so) +if track_assumes: + smt.smt2_options[':produce-unsat-assumptions'] = 'true' + if noinfo and vcdfile is None and vlogtbfile is None and outconstr is None: smt.produce_models = False @@ -1497,6 +1516,44 @@ def get_active_assert_map(step, active): return assert_map +assume_enables = {} + +def declare_assume_enables(): + def recurse(mod, path, key_base=()): + for expr, desc in smt.modinfo[mod].assumes.items(): + enable = f"|assume_enable {len(assume_enables)}|" + smt.smt2_assumptions[(expr, key_base)] = enable + smt.write(f"(declare-const {enable} Bool)") + assume_enables[(expr, key_base)] = (enable, path, desc) + + for cell, submod in smt.modinfo[mod].cells.items(): + recurse(submod, f"{path}.{cell}", (mod, cell, key_base)) + + recurse(topmod, topmod) + +if track_assumes: + declare_assume_enables() + +def smt_assert_design_assumes(step): + if not track_assumes: + smt_assert_consequent("(|%s_u| s%d)" % (topmod, step)) + return + + if not assume_enables: + return + + def expr_for_assume(assume_key, base=None): + expr, key_base = assume_key + expr_prefix = f"(|{expr}| " + expr_suffix = ")" + while key_base: + mod, cell, key_base = key_base + expr_prefix += f"(|{mod}_h {cell}| " + expr_suffix += ")" + return f"{expr_prefix} s{step}{expr_suffix}" + + for assume_key, (enable, path, desc) in assume_enables.items(): + smt_assert_consequent(f"(=> {enable} {expr_for_assume(assume_key)})") states = list() asserts_antecedent_cache = [list()] @@ -1651,6 +1708,13 @@ def smt_check_sat(expected=["sat", "unsat"]): smt_forall_assert() return smt.check_sat(expected=expected) +def report_tracked_assumptions(msg): + if track_assumes: + print_msg(msg) + for key in smt.get_unsat_assumptions(minimize=minimize_assumes): + enable, path, descr = assume_enables[key] + print_msg(f" In {path}: {descr}") + if incremental: incremental.mainloop() @@ -1664,7 +1728,7 @@ def smt_check_sat(expected=["sat", "unsat"]): break smt_state(step) - smt_assert_consequent("(|%s_u| s%d)" % (topmod, step)) + smt_assert_design_assumes(step) smt_assert_antecedent("(|%s_h| s%d)" % (topmod, step)) smt_assert_antecedent("(not (|%s_is| s%d))" % (topmod, step)) smt_assert_consequent(get_constr_expr(constr_assumes, step)) @@ -1707,6 +1771,7 @@ def smt_check_sat(expected=["sat", "unsat"]): else: print_msg("Temporal induction successful.") + report_tracked_assumptions("Used assumptions:") retstatus = "PASSED" break @@ -1732,7 +1797,7 @@ def smt_check_sat(expected=["sat", "unsat"]): while step < num_steps: smt_state(step) - smt_assert_consequent("(|%s_u| s%d)" % (topmod, step)) + smt_assert_design_assumes(step) smt_assert_antecedent("(|%s_h| s%d)" % (topmod, step)) smt_assert_consequent(get_constr_expr(constr_assumes, step)) @@ -1753,6 +1818,7 @@ def smt_check_sat(expected=["sat", "unsat"]): smt_assert("(distinct (covers_%d s%d) #b%s)" % (coveridx, step, "0" * len(cover_desc))) if smt_check_sat() == "unsat": + report_tracked_assumptions("Used assumptions:") smt_pop() break @@ -1761,13 +1827,14 @@ def smt_check_sat(expected=["sat", "unsat"]): print_msg("Appending additional step %d." % i) smt_state(i) smt_assert_antecedent("(not (|%s_is| s%d))" % (topmod, i)) - smt_assert_consequent("(|%s_u| s%d)" % (topmod, i)) + smt_assert_design_assumes(i) smt_assert_antecedent("(|%s_h| s%d)" % (topmod, i)) smt_assert_antecedent("(|%s_t| s%d s%d)" % (topmod, i-1, i)) smt_assert_consequent(get_constr_expr(constr_assumes, i)) print_msg("Re-solving with appended steps..") if smt_check_sat() == "unsat": print("%s Cannot appended steps without violating assumptions!" % smt.timestamp()) + report_tracked_assumptions("Conflicting assumptions:") found_failed_assert = True retstatus = "FAILED" break @@ -1823,7 +1890,7 @@ def smt_check_sat(expected=["sat", "unsat"]): retstatus = "PASSED" while step < num_steps: smt_state(step) - smt_assert_consequent("(|%s_u| s%d)" % (topmod, step)) + smt_assert_design_assumes(step) smt_assert_antecedent("(|%s_h| s%d)" % (topmod, step)) smt_assert_consequent(get_constr_expr(constr_assumes, step)) @@ -1853,7 +1920,7 @@ def smt_check_sat(expected=["sat", "unsat"]): if step+i < num_steps: smt_state(step+i) smt_assert_antecedent("(not (|%s_is| s%d))" % (topmod, step+i)) - smt_assert_consequent("(|%s_u| s%d)" % (topmod, step+i)) + smt_assert_design_assumes(step + i) smt_assert_antecedent("(|%s_h| s%d)" % (topmod, step+i)) smt_assert_antecedent("(|%s_t| s%d s%d)" % (topmod, step+i-1, step+i)) smt_assert_consequent(get_constr_expr(constr_assumes, step+i)) @@ -1867,7 +1934,8 @@ def smt_check_sat(expected=["sat", "unsat"]): print_msg("Checking assumptions in steps %d to %d.." % (step, last_check_step)) if smt_check_sat() == "unsat": - print("%s Assumptions are unsatisfiable!" % smt.timestamp()) + print_msg("Assumptions are unsatisfiable!") + report_tracked_assumptions("Conficting assumptions:") retstatus = "PREUNSAT" break @@ -1920,13 +1988,14 @@ def smt_check_sat(expected=["sat", "unsat"]): print_msg("Appending additional step %d." % i) smt_state(i) smt_assert_antecedent("(not (|%s_is| s%d))" % (topmod, i)) - smt_assert_consequent("(|%s_u| s%d)" % (topmod, i)) + smt_assert_design_assumes(i) smt_assert_antecedent("(|%s_h| s%d)" % (topmod, i)) smt_assert_antecedent("(|%s_t| s%d s%d)" % (topmod, i-1, i)) smt_assert_consequent(get_constr_expr(constr_assumes, i)) print_msg("Re-solving with appended steps..") if smt_check_sat() == "unsat": - print("%s Cannot append steps without violating assumptions!" % smt.timestamp()) + print_msg("Cannot append steps without violating assumptions!") + report_tracked_assumptions("Conflicting assumptions:") retstatus = "FAILED" break print_anyconsts(step) diff --git a/backends/smt2/smtbmc_incremental.py b/backends/smt2/smtbmc_incremental.py index 1a2a4570312..a1f793c9a41 100644 --- a/backends/smt2/smtbmc_incremental.py +++ b/backends/smt2/smtbmc_incremental.py @@ -15,6 +15,14 @@ class InteractiveError(Exception): pass +def mkkey(data): + if isinstance(data, list): + return tuple(map(mkkey, data)) + elif isinstance(data, dict): + raise InteractiveError(f"JSON objects found in assumption key: {key!r}") + return data + + class Incremental: def __init__(self): self.traceidx = 0 @@ -73,17 +81,17 @@ def expr_arg_len(self, expr, min_len, max_len=-1): if min_len is not None and arg_len < min_len: if min_len == max_len: - raise ( + raise InteractiveError( f"{json.dumps(expr[0])} expression must have " f"{min_len} argument{'s' if min_len != 1 else ''}" ) else: - raise ( + raise InteractiveError( f"{json.dumps(expr[0])} expression must have at least " f"{min_len} argument{'s' if min_len != 1 else ''}" ) if max_len is not None and arg_len > max_len: - raise ( + raise InteractiveError( f"{json.dumps(expr[0])} expression can have at most " f"{min_len} argument{'s' if max_len != 1 else ''}" ) @@ -96,14 +104,31 @@ def expr_step(self, expr, smt_out): smt_out.append(f"s{step}") return "module", smtbmc.topmod - def expr_mod_constraint(self, expr, smt_out): - self.expr_arg_len(expr, 1) + def expr_cell(self, expr, smt_out): + self.expr_arg_len(expr, 2) position = len(smt_out) smt_out.append(None) - arg_sort = self.expr(expr[1], smt_out, required_sort=["module", None]) + arg_sort = self.expr(expr[2], smt_out, required_sort=["module", None]) + smt_out.append(")") module = arg_sort[1] + cell = expr[1] + submod = smtbmc.smt.modinfo[module].cells.get(cell) + if submod is None: + raise InteractiveError(f"module {module!r} has no cell {cell!r}") + smt_out[position] = f"(|{module}_h {cell}| " + return ("module", submod) + + def expr_mod_constraint(self, expr, smt_out): suffix = expr[0][3:] - smt_out[position] = f"(|{module}{suffix}| " + self.expr_arg_len(expr, 1, 2 if suffix in ["_a", "_u", "_c"] else 1) + position = len(smt_out) + smt_out.append(None) + arg_sort = self.expr(expr[-1], smt_out, required_sort=["module", None]) + module = arg_sort[1] + if len(expr) == 3: + smt_out[position] = f"(|{module}{suffix} {expr[1]}| " + else: + smt_out[position] = f"(|{module}{suffix}| " smt_out.append(")") return "Bool" @@ -223,20 +248,19 @@ def expr_label(self, expr, smt_out): subexpr = expr[2] if not isinstance(label, str): - raise InteractiveError(f"expression label has to be a string") + raise InteractiveError("expression label has to be a string") smt_out.append("(! ") - smt_out.appedd(label) - smt_out.append(" ") - sort = self.expr(subexpr, smt_out) - + smt_out.append(" :named ") + smt_out.append(label) smt_out.append(")") return sort expr_handlers = { "step": expr_step, + "cell": expr_cell, "mod_h": expr_mod_constraint, "mod_is": expr_mod_constraint, "mod_i": expr_mod_constraint, @@ -302,6 +326,30 @@ def cmd_assert(self, cmd): assert_fn(self.expr_smt(cmd.get("expr"), "Bool")) + def cmd_assert_design_assumes(self, cmd): + step = self.arg_step(cmd) + smtbmc.smt_assert_design_assumes(step) + + def cmd_get_design_assume(self, cmd): + key = mkkey(cmd.get("key")) + return smtbmc.assume_enables.get(key) + + def cmd_update_assumptions(self, cmd): + expr = cmd.get("expr") + key = cmd.get("key") + + + key = mkkey(key) + + result = smtbmc.smt.smt2_assumptions.pop(key, None) + if expr is not None: + expr = self.expr_smt(expr, "Bool") + smtbmc.smt.smt2_assumptions[key] = expr + return result + + def cmd_get_unsat_assumptions(self, cmd): + return smtbmc.smt.get_unsat_assumptions(minimize=bool(cmd.get('minimize'))) + def cmd_push(self, cmd): smtbmc.smt_push() @@ -313,11 +361,14 @@ def cmd_check(self, cmd): def cmd_smtlib(self, cmd): command = cmd.get("command") + response = cmd.get("response", False) if not isinstance(command, str): raise InteractiveError( f"raw SMT-LIB command must be a string, found {json.dumps(command)}" ) smtbmc.smt.write(command) + if response: + return smtbmc.smt.read() def cmd_design_hierwitness(self, cmd=None): allregs = (cmd is None) or bool(cmd.get("allreges", False)) @@ -369,6 +420,21 @@ def cmd_read_yw_trace(self, cmd): return dict(last_step=last_step) + def cmd_modinfo(self, cmd): + fields = cmd.get("fields", []) + + mod = cmd.get("mod") + if mod is None: + mod = smtbmc.topmod + modinfo = smtbmc.smt.modinfo.get(mod) + if modinfo is None: + return None + + result = dict(name=mod) + for field in fields: + result[field] = getattr(modinfo, field, None) + return result + def cmd_ping(self, cmd): return cmd @@ -377,6 +443,10 @@ def cmd_ping(self, cmd): "assert": cmd_assert, "assert_antecedent": cmd_assert, "assert_consequent": cmd_assert, + "assert_design_assumes": cmd_assert_design_assumes, + "get_design_assume": cmd_get_design_assume, + "update_assumptions": cmd_update_assumptions, + "get_unsat_assumptions": cmd_get_unsat_assumptions, "push": cmd_push, "pop": cmd_pop, "check": cmd_check, @@ -384,6 +454,7 @@ def cmd_ping(self, cmd): "design_hierwitness": cmd_design_hierwitness, "write_yw_trace": cmd_write_yw_trace, "read_yw_trace": cmd_read_yw_trace, + "modinfo": cmd_modinfo, "ping": cmd_ping, } diff --git a/backends/smt2/smtio.py b/backends/smt2/smtio.py index c904aea9531..e32f43c60a0 100644 --- a/backends/smt2/smtio.py +++ b/backends/smt2/smtio.py @@ -114,6 +114,7 @@ def __init__(self): self.clocks = dict() self.cells = dict() self.asserts = dict() + self.assumes = dict() self.covers = dict() self.maximize = set() self.minimize = set() @@ -141,6 +142,7 @@ def __init__(self, opts=None): self.recheck = False self.smt2cache = [list()] self.smt2_options = dict() + self.smt2_assumptions = dict() self.p = None self.p_index = solvers_index solvers_index += 1 @@ -602,6 +604,12 @@ def info(self, stmt): else: self.modinfo[self.curmod].covers["%s_c %s" % (self.curmod, fields[2])] = fields[3] + if fields[1] == "yosys-smt2-assume": + if len(fields) > 4: + self.modinfo[self.curmod].assumes["%s_u %s" % (self.curmod, fields[2])] = f'{fields[4]} ({fields[3]})' + else: + self.modinfo[self.curmod].assumes["%s_u %s" % (self.curmod, fields[2])] = fields[3] + if fields[1] == "yosys-smt2-maximize": self.modinfo[self.curmod].maximize.add(fields[2]) @@ -785,8 +793,13 @@ def read(self): return stmt def check_sat(self, expected=["sat", "unsat", "unknown", "timeout", "interrupted"]): + if self.smt2_assumptions: + assume_exprs = " ".join(self.smt2_assumptions.values()) + check_stmt = f"(check-sat-assuming ({assume_exprs}))" + else: + check_stmt = "(check-sat)" if self.debug_print: - print("> (check-sat)") + print(f"> {check_stmt}") if self.debug_file and not self.nocomments: print("; running check-sat..", file=self.debug_file) self.debug_file.flush() @@ -800,7 +813,7 @@ def check_sat(self, expected=["sat", "unsat", "unknown", "timeout", "interrupted for cache_stmt in cache_ctx: self.p_write(cache_stmt + "\n", False) - self.p_write("(check-sat)\n", True) + self.p_write(f"{check_stmt}\n", True) if self.timeinfo: i = 0 @@ -868,7 +881,7 @@ def check_sat(self, expected=["sat", "unsat", "unknown", "timeout", "interrupted if self.debug_file: print("(set-info :status %s)" % result, file=self.debug_file) - print("(check-sat)", file=self.debug_file) + print(check_stmt, file=self.debug_file) self.debug_file.flush() if result not in expected: @@ -945,6 +958,48 @@ def bv2bin(self, v): def bv2int(self, v): return int(self.bv2bin(v), 2) + def get_raw_unsat_assumptions(self): + self.write("(get-unsat-assumptions)") + exprs = set(self.unparse(part) for part in self.parse(self.read())) + unsat_assumptions = [] + for key, value in self.smt2_assumptions.items(): + # normalize expression + value = self.unparse(self.parse(value)) + if value in exprs: + exprs.remove(value) + unsat_assumptions.append(key) + return unsat_assumptions + + def get_unsat_assumptions(self, minimize=False): + if not minimize: + return self.get_raw_unsat_assumptions() + required_assumptions = {} + + while True: + candidate_assumptions = {} + for key in self.get_raw_unsat_assumptions(): + if key not in required_assumptions: + candidate_assumptions[key] = self.smt2_assumptions[key] + + while candidate_assumptions: + + candidate_key, candidate_assume = candidate_assumptions.popitem() + + self.smt2_assumptions = {} + for key, assume in candidate_assumptions.items(): + self.smt2_assumptions[key] = assume + for key, assume in required_assumptions.items(): + self.smt2_assumptions[key] = assume + result = self.check_sat() + + if result == 'unsat': + candidate_assumptions = None + else: + required_assumptions[candidate_key] = candidate_assume + + if candidate_assumptions is not None: + return list(required_assumptions) + def get(self, expr): self.write("(get-value (%s))" % (expr)) return self.parse(self.read())[0][1]