From a52088b6af5c4467835de173d83a00645546eddf Mon Sep 17 00:00:00 2001 From: Jannis Harder Date: Tue, 7 May 2024 17:57:37 +0200 Subject: [PATCH] smtbmc: Improvements for --incremental and .yw fixes This extends the experimental incremental JSON API to allow arbitrary smtlib subexpressions, defining smtlib constants and to allow access of signals by their .yw path. It also fixes a bug during .yw writing where values would be re-emitted in later cycles if they have no newer defined value and a potential crash when using --track-assumes. --- backends/smt2/smtbmc.py | 203 +++++++++++++++++----------- backends/smt2/smtbmc_incremental.py | 147 ++++++++++++++++++-- backends/smt2/smtio.py | 21 ++- 3 files changed, 279 insertions(+), 92 deletions(-) diff --git a/backends/smt2/smtbmc.py b/backends/smt2/smtbmc.py index e6b4088dbd7..995a714c926 100644 --- a/backends/smt2/smtbmc.py +++ b/backends/smt2/smtbmc.py @@ -199,7 +199,6 @@ def help(): --minimize-assumes when using --track-assumes, solve for a minimal set of sufficient assumptions. - """ + so.helpmsg()) def usage(): @@ -670,18 +669,12 @@ def print_msg(msg): ywfile_hierwitness_cache = None -def ywfile_constraints(inywfile, constr_assumes, map_steps=None, skip_x=False): +def ywfile_hierwitness(): global ywfile_hierwitness_cache - if map_steps is None: - map_steps = {} - - with open(inywfile, "r") as f: - inyw = ReadWitness(f) - - if ywfile_hierwitness_cache is None: - ywfile_hierwitness_cache = smt.hierwitness(topmod, allregs=True, blackbox=True) + if ywfile_hierwitness_cache is None: + ywfile_hierwitness = smt.hierwitness(topmod, allregs=True, blackbox=True) - inits, seqs, clocks, mems = ywfile_hierwitness_cache + inits, seqs, clocks, mems = ywfile_hierwitness smt_wires = defaultdict(list) smt_mems = defaultdict(list) @@ -692,91 +685,147 @@ def ywfile_constraints(inywfile, constr_assumes, map_steps=None, skip_x=False): for mem in mems: smt_mems[mem["path"]].append(mem) - addr_re = re.compile(r'\\\[[0-9]+\]$') - bits_re = re.compile(r'[01?]*$') + ywfile_hierwitness_cache = inits, seqs, clocks, mems, smt_wires, smt_mems - max_t = -1 + return ywfile_hierwitness_cache - for t, step in inyw.steps(): - present_signals, missing = step.present_signals(inyw.sigmap) - for sig in present_signals: - bits = step[sig] - if skip_x: - bits = bits.replace('x', '?') - if not bits_re.match(bits): - raise ValueError("unsupported bit value in Yosys witness file") +def_bits_re = re.compile(r'([01]+)') - sig_end = sig.offset + len(bits) - if sig.path in smt_wires: - for wire in smt_wires[sig.path]: - width, offset = wire["width"], wire["offset"] +def smt_extract_mask(smt_expr, mask): + chunks = [] + def_bits = '' - smt_bool = smt.net_width(topmod, wire["smtpath"]) == 1 + mask_index_order = mask[::-1] - offset = max(offset, 0) + for matched in def_bits_re.finditer(mask_index_order): + chunks.append(matched.span()) + def_bits += matched[0] - end = width + offset - common_offset = max(sig.offset, offset) - common_end = min(sig_end, end) - if common_end <= common_offset: - continue + if not chunks: + return - smt_expr = smt.witness_net_expr(topmod, f"s{map_steps.get(t, t)}", wire) + if len(chunks) == 1: + start, end = chunks[0] + if start == 0 and end == len(mask_index_order): + combined_chunks = smt_expr + else: + combined_chunks = '((_ extract %d %d) %s)' % (end - 1, start, smt_expr) + else: + combined_chunks = '(let ((x %s)) (concat %s))' % (smt_expr, ' '.join( + '((_ extract %d %d) x)' % (end - 1, start) + for start, end in reversed(chunks) + )) - if not smt_bool: - slice_high = common_end - offset - 1 - slice_low = common_offset - offset - smt_expr = "((_ extract %d %d) %s)" % (slice_high, slice_low, smt_expr) + return combined_chunks, ''.join(mask_index_order[start:end] for start, end in chunks)[::-1] - bit_slice = bits[len(bits) - (common_end - sig.offset):len(bits) - (common_offset - sig.offset)] +def smt_concat(exprs): + if not exprs: + return "" + if len(exprs) == 1: + return exprs[1] + return "(concat %s)" % ' '.join(exprs) - if bit_slice.count("?") == len(bit_slice): - continue +def ywfile_signal(sig, step, mask=None): + assert sig.width > 0 - if smt_bool: - assert width == 1 - smt_constr = "(= %s %s)" % (smt_expr, "true" if bit_slice == "1" else "false") - else: - if "?" in bit_slice: - mask = bit_slice.replace("0", "1").replace("?", "0") - bit_slice = bit_slice.replace("?", "0") - smt_expr = "(bvand %s #b%s)" % (smt_expr, mask) + inits, seqs, clocks, mems, smt_wires, smt_mems = ywfile_hierwitness() + sig_end = sig.offset + sig.width - smt_constr = "(= %s #b%s)" % (smt_expr, bit_slice) + output = [] - constr_assumes[t].append((inywfile, smt_constr)) + if sig.path in smt_wires: + for wire in smt_wires[sig.path]: + width, offset = wire["width"], wire["offset"] - if sig.memory_path: - if sig.memory_path in smt_mems: - for mem in smt_mems[sig.memory_path]: - width, size, bv = mem["width"], mem["size"], mem["statebv"] + smt_bool = smt.net_width(topmod, wire["smtpath"]) == 1 - smt_expr = smt.net_expr(topmod, f"s{map_steps.get(t, t)}", mem["smtpath"]) + offset = max(offset, 0) - if bv: - word_low = sig.memory_addr * width - word_high = word_low + width - 1 - smt_expr = "((_ extract %d %d) %s)" % (word_high, word_low, smt_expr) - else: - addr_width = (size - 1).bit_length() - addr_bits = f"{sig.memory_addr:0{addr_width}b}" - smt_expr = "(select %s #b%s )" % (smt_expr, addr_bits) + end = width + offset + common_offset = max(sig.offset, offset) + common_end = min(sig_end, end) + if common_end <= common_offset: + continue - if len(bits) < width: - slice_high = sig.offset + len(bits) - 1 - smt_expr = "((_ extract %d %d) %s)" % (slice_high, sig.offset, smt_expr) + smt_expr = smt.witness_net_expr(topmod, f"s{step}", wire) - bit_slice = bits + if not smt_bool: + slice_high = common_end - offset - 1 + slice_low = common_offset - offset + smt_expr = "((_ extract %d %d) %s)" % (slice_high, slice_low, smt_expr) + else: + smt_expr = "(ite %s #b1 #b0)" % smt_expr - if "?" in bit_slice: - mask = bit_slice.replace("0", "1").replace("?", "0") - bit_slice = bit_slice.replace("?", "0") - smt_expr = "(bvand %s #b%s)" % (smt_expr, mask) + output.append(((common_offset - sig.offset), (common_end - sig.offset), smt_expr)) - smt_constr = "(= %s #b%s)" % (smt_expr, bit_slice) - constr_assumes[t].append((inywfile, smt_constr)) - max_t = t + if sig.memory_path: + if sig.memory_path in smt_mems: + for mem in smt_mems[sig.memory_path]: + width, size, bv = mem["width"], mem["size"], mem["statebv"] + + smt_expr = smt.net_expr(topmod, f"s{step}", mem["smtpath"]) + + if bv: + word_low = sig.memory_addr * width + word_high = word_low + width - 1 + smt_expr = "((_ extract %d %d) %s)" % (word_high, word_low, smt_expr) + else: + addr_width = (size - 1).bit_length() + addr_bits = f"{sig.memory_addr:0{addr_width}b}" + smt_expr = "(select %s #b%s )" % (smt_expr, addr_bits) + + if sig.width < width: + slice_high = sig.offset + sig.width - 1 + smt_expr = "((_ extract %d %d) %s)" % (slice_high, sig.offset, smt_expr) + output.append((0, sig.width, smt_expr)) + + output.sort() + + output = [chunk for chunk in output if chunk[0] != chunk[1]] + + pos = 0 + + for start, end, smt_expr in output: + assert start == pos + pos = end + + assert pos == sig.width + + if len(output) == 1: + return output[0][-1] + return smt_concat(smt_expr for start, end, smt_expr in reversed(output)) + +def ywfile_constraints(inywfile, constr_assumes, map_steps=None, skip_x=False): + global ywfile_hierwitness_cache + if map_steps is None: + map_steps = {} + + with open(inywfile, "r") as f: + inyw = ReadWitness(f) + + inits, seqs, clocks, mems, smt_wires, smt_mems = ywfile_hierwitness() + + bits_re = re.compile(r'[01?]*$') + max_t = -1 + + for t, step in inyw.steps(): + present_signals, missing = step.present_signals(inyw.sigmap) + for sig in present_signals: + bits = step[sig] + if skip_x: + bits = bits.replace('x', '?') + if not bits_re.match(bits): + raise ValueError("unsupported bit value in Yosys witness file") + + smt_expr = ywfile_signal(sig, map_steps.get(t, t)) + + smt_expr, bits = smt_extract_mask(smt_expr, bits) + + smt_constr = "(= %s #b%s)" % (smt_expr, bits) + constr_assumes[t].append((inywfile, smt_constr)) + + max_t = t return max_t if inywfile is not None: @@ -1367,11 +1416,11 @@ def write_yw_trace(steps, index, allregs=False, filename=None): exprs.extend(smt.witness_net_expr(topmod, f"s{k}", sig) for sig in sigs) - all_sigs.append(sigs) + all_sigs.append((step_values, sigs)) bvs = iter(smt.get_list(exprs)) - for sigs in all_sigs: + for (step_values, sigs) in all_sigs: for sig in sigs: value = smt.bv2bin(next(bvs)) step_values[sig["sig"]] = value diff --git a/backends/smt2/smtbmc_incremental.py b/backends/smt2/smtbmc_incremental.py index f43e878f31c..0bd280b4a48 100644 --- a/backends/smt2/smtbmc_incremental.py +++ b/backends/smt2/smtbmc_incremental.py @@ -1,7 +1,7 @@ from collections import defaultdict import json import typing -from functools import partial +import ywio if typing.TYPE_CHECKING: import smtbmc @@ -34,6 +34,7 @@ def __init__(self): self._witness_index = None self._yw_constraints = {} + self._define_sorts = {} def setup(self): generic_assert_map = smtbmc.get_assert_map( @@ -175,11 +176,7 @@ def expr_andor(self, expr, smt_out): if len(expr) == 1: smt_out.push({"and": "true", "or": "false"}[expr[0]]) elif len(expr) == 2: - arg_sort = self.expr(expr[1], smt_out) - if arg_sort != "Bool": - raise InteractiveError( - f"arguments of {json.dumps(expr[0])} must have sort Bool" - ) + self.expr(expr[1], smt_out, required_sort="Bool") else: sep = f"({expr[0]} " for arg in expr[1:]: @@ -189,7 +186,51 @@ def expr_andor(self, expr, smt_out): smt_out.append(")") return "Bool" + def expr_bv_binop(self, expr, smt_out): + self.expr_arg_len(expr, 2) + + smt_out.append(f"({expr[0]} ") + arg_sort = self.expr(expr[1], smt_out, required_sort=("BitVec", None)) + smt_out.append(" ") + self.expr(expr[2], smt_out, required_sort=arg_sort) + smt_out.append(")") + return arg_sort + + def expr_extract(self, expr, smt_out): + self.expr_arg_len(expr, 3) + + hi = expr[1] + lo = expr[2] + + smt_out.append(f"((_ extract {hi} {lo}) ") + + arg_sort = self.expr(expr[3], smt_out, required_sort=("BitVec", None)) + smt_out.append(")") + + if not (isinstance(hi, int) and 0 <= hi < arg_sort[1]): + raise InteractiveError( + f"high bit index must be 0 <= index < {arg_sort[1]}, is {hi!r}" + ) + if not (isinstance(lo, int) and 0 <= lo <= hi): + raise InteractiveError( + f"low bit index must be 0 <= index < {hi}, is {lo!r}" + ) + + return "BitVec", hi - lo + 1 + + def expr_bv(self, expr, smt_out): + self.expr_arg_len(expr, 1) + + arg = expr[1] + if not isinstance(arg, str) or arg.count("0") + arg.count("1") != len(arg): + raise InteractiveError("bv argument must contain only 0 or 1 bits") + + smt_out.append("#b" + arg) + + return "BitVec", len(arg) + def expr_yw(self, expr, smt_out): + self.expr_arg_len(expr, 1, 2) if len(expr) == 2: name = None step = expr[1] @@ -219,6 +260,40 @@ def expr_yw(self, expr, smt_out): return "Bool" + def expr_yw_sig(self, expr, smt_out): + self.expr_arg_len(expr, 3, 4) + + step = expr[1] + path = expr[2] + offset = expr[3] + width = expr[4] if len(expr) == 5 else 1 + + if not isinstance(offset, int) or offset < 0: + raise InteractiveError( + f"offset must be a non-negative integer, got {json.dumps(offset)}" + ) + + if not isinstance(width, int) or width <= 0: + raise InteractiveError( + f"width must be a positive integer, got {json.dumps(width)}" + ) + + if not isinstance(path, list) or not all(isinstance(s, str) for s in path): + raise InteractiveError( + f"path must be a string list, got {json.dumps(path)}" + ) + + if step not in self.state_set: + raise InteractiveError(f"step {step} not declared") + + smt_expr = smtbmc.ywfile_signal( + ywio.WitnessSig(path=path, offset=offset, width=width), step + ) + + smt_out.append(smt_expr) + + return "BitVec", width + def expr_smtlib(self, expr, smt_out): self.expr_arg_len(expr, 2) @@ -231,10 +306,15 @@ def expr_smtlib(self, expr, smt_out): f"got {json.dumps(smtlib_expr)}" ) - if not isinstance(sort, str): - raise InteractiveError( - f"raw SMT-LIB sort has to be a string, got {json.dumps(sort)}" - ) + if ( + isinstance(sort, list) + and len(sort) == 2 + and sort[0] == "BitVec" + and (sort[1] is None or isinstance(sort[1], int)) + ): + sort = tuple(sort) + elif not isinstance(sort, str): + raise InteractiveError(f"unsupported raw SMT-LIB sort {json.dumps(sort)}") smt_out.append(smtlib_expr) return sort @@ -258,6 +338,14 @@ def expr_label(self, expr, smt_out): return sort + def expr_def(self, expr, smt_out): + self.expr_arg_len(expr, 1) + sort = self._define_sorts.get(expr[1]) + if sort is None: + raise InteractiveError(f"unknown definition {json.dumps(expr)}") + smt_out.append(expr[1]) + return sort + expr_handlers = { "step": expr_step, "cell": expr_cell, @@ -270,8 +358,15 @@ def expr_label(self, expr, smt_out): "not": expr_not, "and": expr_andor, "or": expr_andor, + "bv": expr_bv, + "bvand": expr_bv_binop, + "bvor": expr_bv_binop, + "bvxor": expr_bv_binop, + "extract": expr_extract, + "def": expr_def, "=": expr_eq, "yw": expr_yw, + "yw_sig": expr_yw_sig, "smtlib": expr_smtlib, "!": expr_label, } @@ -305,10 +400,13 @@ def expr(self, expr, smt_out, required_sort=None): raise InteractiveError(f"unknown expression {json.dumps(expr[0])}") def expr_smt(self, expr, required_sort): + return self.expr_smt_and_sort(expr, required_sort)[0] + + def expr_smt_and_sort(self, expr, required_sort=None): smt_out = [] - self.expr(expr, smt_out, required_sort=required_sort) + output_sort = self.expr(expr, smt_out, required_sort=required_sort) out = "".join(smt_out) - return out + return out, output_sort def cmd_new_step(self, cmd): step = self.arg_step(cmd, declare=True) @@ -338,7 +436,6 @@ def cmd_update_assumptions(self, cmd): expr = cmd.get("expr") key = cmd.get("key") - key = mkkey(key) result = smtbmc.smt.smt2_assumptions.pop(key, None) @@ -348,7 +445,7 @@ def cmd_update_assumptions(self, cmd): return result def cmd_get_unsat_assumptions(self, cmd): - return smtbmc.smt.get_unsat_assumptions(minimize=bool(cmd.get('minimize'))) + return smtbmc.smt.get_unsat_assumptions(minimize=bool(cmd.get("minimize"))) def cmd_push(self, cmd): smtbmc.smt_push() @@ -370,6 +467,27 @@ def cmd_smtlib(self, cmd): if response: return smtbmc.smt.read() + def cmd_define(self, cmd): + expr = cmd.get("expr") + if expr is None: + raise InteractiveError("'define' copmmand requires 'expr' parameter") + + expr, sort = self.expr_smt_and_sort(expr) + + if isinstance(sort, tuple) and sort[0] == "module": + raise InteractiveError("'define' does not support module sorts") + + define_name = f"|inc def {len(self._define_sorts)}|" + + self._define_sorts[define_name] = sort + + if isinstance(sort, tuple): + sort = f"(_ {' '.join(map(str, sort))})" + + smtbmc.smt.write(f"(define-const {define_name} {sort} {expr})") + + return {"name": define_name} + def cmd_design_hierwitness(self, cmd=None): allregs = (cmd is None) or bool(cmd.get("allreges", False)) if self._cached_hierwitness[allregs] is not None: @@ -451,6 +569,7 @@ def cmd_ping(self, cmd): "pop": cmd_pop, "check": cmd_check, "smtlib": cmd_smtlib, + "define": cmd_define, "design_hierwitness": cmd_design_hierwitness, "write_yw_trace": cmd_write_yw_trace, "read_yw_trace": cmd_read_yw_trace, diff --git a/backends/smt2/smtio.py b/backends/smt2/smtio.py index e32f43c60a0..5fc3ab5a424 100644 --- a/backends/smt2/smtio.py +++ b/backends/smt2/smtio.py @@ -160,6 +160,7 @@ def __init__(self, opts=None): self.noincr = opts.noincr self.info_stmts = opts.info_stmts self.nocomments = opts.nocomments + self.smt2_options.update(opts.smt2_options) else: self.solver = "yices" @@ -959,6 +960,8 @@ def bv2int(self, v): return int(self.bv2bin(v), 2) def get_raw_unsat_assumptions(self): + if not self.smt2_assumptions: + return [] self.write("(get-unsat-assumptions)") exprs = set(self.unparse(part) for part in self.parse(self.read())) unsat_assumptions = [] @@ -973,6 +976,10 @@ def get_raw_unsat_assumptions(self): def get_unsat_assumptions(self, minimize=False): if not minimize: return self.get_raw_unsat_assumptions() + orig_assumptions = self.smt2_assumptions + + self.smt2_assumptions = dict(orig_assumptions) + required_assumptions = {} while True: @@ -998,6 +1005,7 @@ def get_unsat_assumptions(self, minimize=False): required_assumptions[candidate_key] = candidate_assume if candidate_assumptions is not None: + self.smt2_assumptions = orig_assumptions return list(required_assumptions) def get(self, expr): @@ -1146,7 +1154,7 @@ def wait(self): class SmtOpts: def __init__(self): self.shortopts = "s:S:v" - self.longopts = ["unroll", "noincr", "noprogress", "timeout=", "dump-smt2=", "logic=", "dummy=", "info=", "nocomments"] + self.longopts = ["unroll", "noincr", "noprogress", "timeout=", "dump-smt2=", "logic=", "dummy=", "info=", "nocomments", "smt2-option="] self.solver = "yices" self.solver_opts = list() self.debug_print = False @@ -1159,6 +1167,7 @@ def __init__(self): self.logic = None self.info_stmts = list() self.nocomments = False + self.smt2_options = {} def handle(self, o, a): if o == "-s": @@ -1185,6 +1194,13 @@ def handle(self, o, a): self.info_stmts.append(a) elif o == "--nocomments": self.nocomments = True + elif o == "--smt2-option": + args = a.split('=', 1) + if len(args) != 2: + print("--smt2-option expects an