diff --git a/backends/smt2/smtbmc.py b/backends/smt2/smtbmc.py index 02e15a1b502..dd3f9ac48dc 100644 --- a/backends/smt2/smtbmc.py +++ b/backends/smt2/smtbmc.py @@ -17,7 +17,7 @@ # OR IN CONNECTION WITH THE USE OR PERFORMANCE OF THIS SOFTWARE. # -import os, sys, getopt, re, bisect +import os, sys, getopt, re, bisect, json ##yosys-sys-path## from smtio import SmtIo, SmtOpts, MkVcd from ywio import ReadWitness, WriteWitness, WitnessValues @@ -56,6 +56,7 @@ keep_going = False check_witness = False detect_loops = False +incremental = None so = SmtOpts() @@ -185,6 +186,9 @@ def help(): check if states are unique in temporal induction counter examples (this feature is experimental and incomplete) + --incremental + run in incremental mode (experimental) + """ + so.helpmsg()) def usage(): @@ -196,7 +200,7 @@ def usage(): opts, args = getopt.getopt(sys.argv[1:], so.shortopts + "t:higcm:", so.longopts + ["help", "final-only", "assume-skipped=", "smtc=", "cex=", "aig=", "aig-noheader", "yw=", "btorwit=", "presat", "dump-vcd=", "dump-yw=", "dump-vlogtb=", "vlogtb-top=", "dump-smtc=", "dump-all", "noinfo", "append=", - "smtc-init", "smtc-top=", "noinit", "binary", "keep-going", "check-witness", "detect-loops"]) + "smtc-init", "smtc-top=", "noinit", "binary", "keep-going", "check-witness", "detect-loops", "incremental"]) except: usage() @@ -282,6 +286,9 @@ def usage(): check_witness = True elif o == "--detect-loops": detect_loops = True + elif o == "--incremental": + from smtbmc_incremental import Incremental + incremental = Incremental() elif so.handle(o, a): pass else: @@ -290,7 +297,7 @@ def usage(): if len(args) != 1: usage() -if sum([tempind, gentrace, covermode]) > 1: +if sum([tempind, gentrace, covermode, incremental is not None]) > 1: usage() constr_final_start = None @@ -444,8 +451,10 @@ def replace_netref(match): smt.produce_models = False def print_msg(msg): - print("%s %s" % (smt.timestamp(), msg)) - sys.stdout.flush() + if incremental: + incremental.print_msg(msg) + else: + print("%s %s" % (smt.timestamp(), msg), flush=True) print_msg("Solver: %s" % (so.solver)) @@ -640,10 +649,9 @@ def print_msg(msg): num_steps = max(num_steps, step+2) step += 1 -if inywfile is not None: - if not got_topt: - skip_steps = 0 - num_steps = 0 +def ywfile_constraints(inywfile, constr_assumes, map_steps=None, skip_x=False): + if map_steps is None: + map_steps = {} with open(inywfile, "r") as f: inyw = ReadWitness(f) @@ -662,10 +670,14 @@ def print_msg(msg): addr_re = re.compile(r'\\\[[0-9]+\]$') bits_re = re.compile(r'[01?]*$') + max_t = -1 + for t, step in inyw.steps(): present_signals, missing = step.present_signals(inyw.sigmap) for sig in present_signals: bits = step[sig] + if skip_x: + bits = bits.replace('x', '?') if not bits_re.match(bits): raise ValueError("unsupported bit value in Yosys witness file") @@ -684,7 +696,7 @@ def print_msg(msg): if common_end <= common_offset: continue - smt_expr = smt.witness_net_expr(topmod, f"s{t}", wire) + smt_expr = smt.witness_net_expr(topmod, f"s{map_steps.get(t, t)}", wire) if not smt_bool: slice_high = common_end - offset - 1 @@ -714,7 +726,7 @@ def print_msg(msg): for mem in smt_mems[sig.memory_path]: width, size, bv = mem["width"], mem["size"], mem["statebv"] - smt_expr = smt.net_expr(topmod, f"s{t}", mem["smtpath"]) + smt_expr = smt.net_expr(topmod, f"s{map_steps.get(t, t)}", mem["smtpath"]) if bv: word_low = sig.memory_addr * width @@ -738,11 +750,21 @@ def print_msg(msg): smt_constr = "(= %s #b%s)" % (smt_expr, bit_slice) constr_assumes[t].append((inywfile, smt_constr)) + max_t = t - if not got_topt: - if not check_witness: - skip_steps = max(skip_steps, t) - num_steps = max(num_steps, t+1) + return max_t + +if inywfile is not None: + if not got_topt: + skip_steps = 0 + num_steps = 0 + + max_t = ywfile_constraints(inywfile, constr_assumes) + + if not got_topt: + if not check_witness: + skip_steps = max(skip_steps, max_t) + num_steps = max(num_steps, max_t+1) if btorwitfile is not None: with open(btorwitfile, "r") as f: @@ -841,7 +863,7 @@ def print_msg(msg): skip_steps = step num_steps = step+1 -def collect_mem_trace_data(steps_start, steps_stop, vcd=None): +def collect_mem_trace_data(steps, vcd=None): mem_trace_data = dict() for mempath in sorted(smt.hiermems(topmod)): @@ -849,16 +871,16 @@ def collect_mem_trace_data(steps_start, steps_stop, vcd=None): expr_id = list() expr_list = list() - for i in range(steps_start, steps_stop): + for seq, i in enumerate(steps): for j in range(rports): - expr_id.append(('R', i-steps_start, j, 'A')) - expr_id.append(('R', i-steps_start, j, 'D')) + expr_id.append(('R', seq, j, 'A')) + expr_id.append(('R', seq, j, 'D')) expr_list.append(smt.mem_expr(topmod, "s%d" % i, mempath, "R%dA" % j)) expr_list.append(smt.mem_expr(topmod, "s%d" % i, mempath, "R%dD" % j)) for j in range(wports): - expr_id.append(('W', i-steps_start, j, 'A')) - expr_id.append(('W', i-steps_start, j, 'D')) - expr_id.append(('W', i-steps_start, j, 'M')) + expr_id.append(('W', seq, j, 'A')) + expr_id.append(('W', seq, j, 'D')) + expr_id.append(('W', seq, j, 'M')) expr_list.append(smt.mem_expr(topmod, "s%d" % i, mempath, "W%dA" % j)) expr_list.append(smt.mem_expr(topmod, "s%d" % i, mempath, "W%dD" % j)) expr_list.append(smt.mem_expr(topmod, "s%d" % i, mempath, "W%dM" % j)) @@ -943,14 +965,14 @@ def collect_mem_trace_data(steps_start, steps_stop, vcd=None): netpath[-1] += "<%0*x>" % ((len(addr)+3) // 4, int_addr) vcd.add_net([topmod] + netpath, width) - for i in range(steps_start, steps_stop): + for seq, i in enumerate(steps): if i not in mem_trace_data: mem_trace_data[i] = list() - mem_trace_data[i].append((netpath, int_addr, "".join(tdata[i-steps_start]))) + mem_trace_data[i].append((netpath, int_addr, "".join(tdata[seq]))) return mem_trace_data -def write_vcd_trace(steps_start, steps_stop, index): +def write_vcd_trace(steps, index, seq_time=False): filename = vcdfile.replace("%", index) print_msg("Writing trace to VCD file: %s" % (filename)) @@ -971,10 +993,10 @@ def write_vcd_trace(steps_start, steps_stop, index): vcd.add_clock([topmod] + netpath, edge) path_list.append(netpath) - mem_trace_data = collect_mem_trace_data(steps_start, steps_stop, vcd) + mem_trace_data = collect_mem_trace_data(steps, vcd) - for i in range(steps_start, steps_stop): - vcd.set_time(i) + for seq, i in enumerate(steps): + vcd.set_time(seq if seq_time else i) value_list = smt.get_net_bin_list(topmod, path_list, "s%d" % i) for path, value in zip(path_list, value_list): vcd.set_net([topmod] + path, value) @@ -982,7 +1004,14 @@ def write_vcd_trace(steps_start, steps_stop, index): for path, addr, value in mem_trace_data[i]: vcd.set_net([topmod] + path, value) - vcd.set_time(steps_stop) + if seq_time: + end_time = len(steps) + elif steps: + end_time = steps[-1] + 1 + else: + end_time = 0 + + vcd.set_time(end_time) def detect_state_loop(steps_start, steps_stop): print_msg(f"Checking for loops in found induction counter example") @@ -1027,7 +1056,7 @@ def escape_identifier(identifier): -def write_vlogtb_trace(steps_start, steps_stop, index): +def write_vlogtb_trace(steps, index): filename = vlogtbfile.replace("%", index) print_msg("Writing trace to Verilog testbench: %s" % (filename)) @@ -1092,7 +1121,7 @@ def write_vlogtb_trace(steps_start, steps_stop, index): print(" initial begin", file=f) regs = sorted(smt.hiernets(vlogtb_topmod, regs_only=True)) - regvals = smt.get_net_bin_list(vlogtb_topmod, regs, vlogtb_state.replace("@@step_idx@@", str(steps_start))) + regvals = smt.get_net_bin_list(vlogtb_topmod, regs, vlogtb_state.replace("@@step_idx@@", str(steps[0]))) print("`ifndef VERILATOR", file=f) print(" #1;", file=f) @@ -1107,7 +1136,7 @@ def write_vlogtb_trace(steps_start, steps_stop, index): anyconsts = sorted(smt.hieranyconsts(vlogtb_topmod)) for info in anyconsts: if info[3] is not None: - modstate = smt.net_expr(vlogtb_topmod, vlogtb_state.replace("@@step_idx@@", str(steps_start)), info[0]) + modstate = smt.net_expr(vlogtb_topmod, vlogtb_state.replace("@@step_idx@@", str(steps[0])), info[0]) value = smt.bv2bin(smt.get("(|%s| %s)" % (info[1], modstate))) print(" UUT.%s = %d'b%s;" % (".".join(escape_identifier(info[0] + [info[3]])), len(value), value), file=f); @@ -1117,7 +1146,7 @@ def write_vlogtb_trace(steps_start, steps_stop, index): addr_expr_list = list() data_expr_list = list() - for i in range(steps_start, steps_stop): + for i in steps: for j in range(rports): addr_expr_list.append(smt.mem_expr(vlogtb_topmod, vlogtb_state.replace("@@step_idx@@", str(i)), mempath, "R%dA" % j)) data_expr_list.append(smt.mem_expr(vlogtb_topmod, vlogtb_state.replace("@@step_idx@@", str(i)), mempath, "R%dD" % j)) @@ -1138,7 +1167,7 @@ def write_vlogtb_trace(steps_start, steps_stop, index): print("", file=f) anyseqs = sorted(smt.hieranyseqs(vlogtb_topmod)) - for i in range(steps_start, steps_stop): + for i in steps: pi_names = [[name] for name, _ in primary_inputs if name not in clock_inputs] pi_values = smt.get_net_bin_list(vlogtb_topmod, pi_names, vlogtb_state.replace("@@step_idx@@", str(i))) @@ -1170,14 +1199,14 @@ def write_vlogtb_trace(steps_start, steps_stop, index): print(" end", file=f) print(" always @(posedge clock) begin", file=f) - print(" genclock <= cycle < %d;" % (steps_stop-1), file=f) + print(" genclock <= cycle < %d;" % (steps[-1]), file=f) print(" cycle <= cycle + 1;", file=f) print(" end", file=f) print("endmodule", file=f) -def write_constr_trace(steps_start, steps_stop, index): +def write_constr_trace(steps, index): filename = outconstr.replace("%", index) print_msg("Writing trace to constraints file: %s" % (filename)) @@ -1194,7 +1223,7 @@ def write_constr_trace(steps_start, steps_stop, index): constr_prefix = smtctop[1] + "." if smtcinit: - steps_start = steps_stop - 1 + steps = [steps[-1]] with open(filename, "w") as f: primary_inputs = list() @@ -1203,13 +1232,13 @@ def write_constr_trace(steps_start, steps_stop, index): width = smt.modinfo[constr_topmod].wsize[name] primary_inputs.append((name, width)) - if steps_start == 0 or smtcinit: + if steps[0] == 0 or smtcinit: print("initial", file=f) else: - print("state %d" % steps_start, file=f) + print("state %d" % steps[0], file=f) regnames = sorted(smt.hiernets(constr_topmod, regs_only=True)) - regvals = smt.get_net_list(constr_topmod, regnames, constr_state.replace("@@step_idx@@", str(steps_start))) + regvals = smt.get_net_list(constr_topmod, regnames, constr_state.replace("@@step_idx@@", str(steps[0]))) for name, val in zip(regnames, regvals): print("assume (= [%s%s] %s)" % (constr_prefix, ".".join(name), val), file=f) @@ -1220,7 +1249,7 @@ def write_constr_trace(steps_start, steps_stop, index): addr_expr_list = list() data_expr_list = list() - for i in range(steps_start, steps_stop): + for i in steps: for j in range(rports): addr_expr_list.append(smt.mem_expr(constr_topmod, constr_state.replace("@@step_idx@@", str(i)), mempath, "R%dA" % j)) data_expr_list.append(smt.mem_expr(constr_topmod, constr_state.replace("@@step_idx@@", str(i)), mempath, "R%dD" % j)) @@ -1236,7 +1265,7 @@ def write_constr_trace(steps_start, steps_stop, index): for addr, data in addr_data.items(): print("assume (= (select [%s%s] %s) %s)" % (constr_prefix, ".".join(mempath), addr, data), file=f) - for k in range(steps_start, steps_stop): + for k in steps: if not smtcinit: print("", file=f) print("state %d" % k, file=f) @@ -1247,11 +1276,14 @@ def write_constr_trace(steps_start, steps_stop, index): for name, val in zip(pi_names, pi_values): print("assume (= [%s%s] %s)" % (constr_prefix, ".".join(name), val), file=f) -def write_yw_trace(steps_start, steps_stop, index, allregs=False): - filename = outywfile.replace("%", index) - print_msg("Writing trace to Yosys witness file: %s" % (filename)) +def write_yw_trace(steps, index, allregs=False, filename=None): + if filename is None: + if outywfile is None: + return + filename = outywfile.replace("%", index) + print_msg("Writing trace to Yosys witness file: %s" % (filename)) - mem_trace_data = collect_mem_trace_data(steps_start, steps_stop) + mem_trace_data = collect_mem_trace_data(steps) with open(filename, "w") as f: inits, seqs, clocks, mems = smt.hierwitness(topmod, allregs) @@ -1295,10 +1327,10 @@ def write_yw_trace(steps_start, steps_stop, index, allregs=False): sig = yw.add_sig(word_path, overlap_start, overlap_end - overlap_start, True) mem_init_values.append((sig, overlap_bits.replace("x", "?"))) - for k in range(steps_start, steps_stop): + for i, k in enumerate(steps): step_values = WitnessValues() - if k == steps_start: + if not i: for sig, value in mem_init_values: step_values[sig] = value sigs = inits + seqs @@ -1314,17 +1346,24 @@ def write_yw_trace(steps_start, steps_stop, index, allregs=False): def write_trace(steps_start, steps_stop, index, allregs=False): + if steps_stop is None: + steps = steps_start + seq_time = True + else: + steps = list(range(steps_start, steps_stop)) + seq_time = False + if vcdfile is not None: - write_vcd_trace(steps_start, steps_stop, index) + write_vcd_trace(steps, index, seq_time=seq_time) if vlogtbfile is not None: - write_vlogtb_trace(steps_start, steps_stop, index) + write_vlogtb_trace(steps, index) if outconstr is not None: - write_constr_trace(steps_start, steps_stop, index) + write_constr_trace(steps, index) if outywfile is not None: - write_yw_trace(steps_start, steps_stop, index, allregs) + write_yw_trace(steps, index, allregs) def print_failed_asserts_worker(mod, state, path, extrainfo, infomap, infokey=()): @@ -1596,7 +1635,11 @@ def smt_check_sat(expected=["sat", "unsat"]): smt_forall_assert() return smt.check_sat(expected=expected) -if tempind: + +if incremental: + incremental.mainloop() + +elif tempind: retstatus = "FAILED" skip_counter = step_size for step in range(num_steps, -1, -1): @@ -1954,5 +1997,6 @@ def smt_check_sat(expected=["sat", "unsat"]): smt.write("(exit)") smt.wait() -print_msg("Status: %s" % retstatus) -sys.exit(0 if retstatus == "PASSED" else 1) +if not incremental: + print_msg("Status: %s" % retstatus) + sys.exit(0 if retstatus == "PASSED" else 1) diff --git a/backends/smt2/smtbmc_incremental.py b/backends/smt2/smtbmc_incremental.py new file mode 100644 index 00000000000..2be4fb6799f --- /dev/null +++ b/backends/smt2/smtbmc_incremental.py @@ -0,0 +1,389 @@ +from collections import defaultdict +import json +import typing +from functools import partial + +if typing.TYPE_CHECKING: + import smtbmc +else: + import sys + + smtbmc = sys.modules["__main__"] + + +class InteractiveError(Exception): + pass + + +class Incremental: + def __init__(self): + self.traceidx = 0 + + self.state_set = set() + self.map_cache = {} + + self._cached_hierwitness = {} + self._witness_index = None + + self._yw_constraints = {} + + def setup(self): + generic_assert_map = smtbmc.get_assert_map( + smtbmc.topmod, "state", smtbmc.topmod + ) + self.inv_generic_assert_map = { + tuple(data[1:]): key for key, data in generic_assert_map.items() + } + assert len(self.inv_generic_assert_map) == len(generic_assert_map) + + def print_json(self, **kwargs): + print(json.dumps(kwargs), flush=True) + + def print_msg(self, msg): + self.print_json(msg=msg) + + def get_cached_assert(self, step, name): + try: + assert_map = self.map_cache[step] + except KeyError: + assert_map = self.map_cache[step] = smtbmc.get_assert_map( + smtbmc.topmod, f"s{step}", smtbmc.topmod + ) + return assert_map[self.inv_generic_assert_map[name]][0] + + def arg_step(self, cmd, declare=False, name="step", optional=False): + step = cmd.get(name, None) + if step is None and optional: + return None + if not isinstance(step, int): + if optional: + raise InteractiveError(f"{name} must be an integer") + else: + raise InteractiveError(f"integer {name} argument required") + if declare and step in self.state_set: + raise InteractiveError(f"step {step} already declared") + if not declare and step not in self.state_set: + raise InteractiveError(f"step {step} not declared") + return step + + def expr_arg_len(self, expr, min_len, max_len=-1): + if max_len == -1: + max_len = min_len + arg_len = len(expr) - 1 + + if min_len is not None and arg_len < min_len: + if min_len == max_len: + raise ( + f"{json.dumps(expr[0])} expression must have " + f"{min_len} argument{'s' if min_len != 1 else ''}" + ) + else: + raise ( + f"{json.dumps(expr[0])} expression must have at least " + f"{min_len} argument{'s' if min_len != 1 else ''}" + ) + if max_len is not None and arg_len > max_len: + raise ( + f"{json.dumps(expr[0])} expression can have at most " + f"{min_len} argument{'s' if max_len != 1 else ''}" + ) + + def expr_step(self, expr, smt_out): + self.expr_arg_len(expr, 1) + step = expr[1] + if step not in self.state_set: + raise InteractiveError(f"step {step} not declared") + smt_out.append(f"s{step}") + return "module", smtbmc.topmod + + def expr_mod_constraint(self, expr, smt_out): + self.expr_arg_len(expr, 1) + position = len(smt_out) + smt_out.append(None) + arg_sort = self.expr(expr[1], smt_out, required_sort=["module", None]) + module = arg_sort[1] + suffix = expr[0][3:] + smt_out[position] = f"(|{module}{suffix}| " + smt_out.append(")") + return "Bool" + + def expr_mod_constraint2(self, expr, smt_out): + self.expr_arg_len(expr, 2) + + position = len(smt_out) + smt_out.append(None) + arg_sort = self.expr(expr[1], smt_out, required_sort=["module", None]) + smt_out.append(" ") + self.expr(expr[2], smt_out, required_sort=arg_sort) + module = arg_sort[1] + suffix = expr[0][3:] + smt_out[position] = f"(|{module}{suffix}| " + smt_out.append(")") + return "Bool" + + def expr_not(self, expr, smt_out): + self.expr_arg_len(expr, 1) + + smt_out.append("(not ") + self.expr(expr[1], smt_out, required_sort="Bool") + smt_out.append(")") + return "Bool" + + def expr_eq(self, expr, smt_out): + self.expr_arg_len(expr, 2) + + smt_out.append("(= ") + arg_sort = self.expr(expr[1], smt_out) + if ( + smtbmc.smt.unroll + and isinstance(arg_sort, (list, tuple)) + and arg_sort[0] == "module" + ): + raise InteractiveError("state equality not supported in unroll mode") + + smt_out.append(" ") + self.expr(expr[2], smt_out, required_sort=arg_sort) + smt_out.append(")") + return "Bool" + + def expr_andor(self, expr, smt_out): + if len(expr) == 1: + smt_out.push({"and": "true", "or": "false"}[expr[0]]) + elif len(expr) == 2: + arg_sort = self.expr(expr[1], smt_out) + if arg_sort != "Bool": + raise InteractiveError( + f"arguments of {json.dumps(expr[0])} must have sort Bool" + ) + else: + sep = f"({expr[0]} " + for arg in expr[1:]: + smt_out.append(sep) + sep = " " + self.expr(arg, smt_out, required_sort="Bool") + smt_out.append(")") + return "Bool" + + def expr_yw(self, expr, smt_out): + if len(expr) == 2: + name = None + step = expr[1] + elif len(expr) == 3: + name = expr[1] + step = expr[2] + + if step not in self.state_set: + raise InteractiveError(f"step {step} not declared") + + if name not in self._yw_constraints: + raise InteractiveError(f"no yw file loaded as name {name!r}") + + constraints = self._yw_constraints[name].get(step, []) + + if len(constraints) == 0: + smt_out.append("true") + elif len(constraints) == 1: + smt_out.append(constraints[0]) + else: + sep = "(and " + for constraint in constraints: + smt_out.append(sep) + sep = " " + smt_out.append(constraint) + smt_out.append(")") + + return "Bool" + + def expr_label(self, expr, smt_out): + if len(expr) != 3: + raise InteractiveError(f'expected ["!", label, sub_expr], got {expr!r}') + label = expr[1] + subexpr = expr[2] + + if not isinstance(label, str): + raise InteractiveError(f"expression label has to be a string") + + smt_out.append("(! ") + smt_out.appedd(label) + smt_out.append(" ") + + sort = self.expr(subexpr, smt_out) + + smt_out.append(")") + + return sort + + expr_handlers = { + "step": expr_step, + "mod_h": expr_mod_constraint, + "mod_is": expr_mod_constraint, + "mod_i": expr_mod_constraint, + "mod_a": expr_mod_constraint, + "mod_u": expr_mod_constraint, + "mod_t": expr_mod_constraint2, + "not": expr_not, + "and": expr_andor, + "or": expr_andor, + "=": expr_eq, + "yw": expr_yw, + "!": expr_label, + } + + def expr(self, expr, smt_out, required_sort=None): + if not isinstance(expr, (list, tuple)) or not expr: + raise InteractiveError( + f"expression must be a non-empty JSON array, found: {json.dumps(expr)}" + ) + name = expr[0] + + handler = self.expr_handlers.get(name) + if handler: + sort = handler(self, expr, smt_out) + + if required_sort is not None: + if isinstance(required_sort, (list, tuple)): + if ( + not isinstance(sort, (list, tuple)) + or len(sort) != len(required_sort) + or any( + r is not None and r != s + for r, s in zip(required_sort, sort) + ) + ): + raise InteractiveError( + f"required sort {json.dumps(required_sort)} found sort {json.dumps(sort)}" + ) + return sort + raise InteractiveError(f"unknown expression {json.dumps(expr[0])}") + + def expr_smt(self, expr, required_sort): + smt_out = [] + self.expr(expr, smt_out, required_sort=required_sort) + out = "".join(smt_out) + return out + + def cmd_new_step(self, cmd): + step = self.arg_step(cmd, declare=True) + self.state_set.add(step) + smtbmc.smt_state(step) + + def cmd_assert(self, cmd): + name = cmd.get("cmd") + + assert_fn = { + "assert_antecedent": smtbmc.smt_assert_antecedent, + "assert_consequent": smtbmc.smt_assert_consequent, + "assert": smtbmc.smt_assert, + }[name] + + assert_fn(self.expr_smt(cmd.get("expr"), "Bool")) + + def cmd_push(self, cmd): + smtbmc.smt_push() + + def cmd_pop(self, cmd): + smtbmc.smt_pop() + + def cmd_check(self, cmd): + return smtbmc.smt_check_sat() + + def cmd_design_hierwitness(self, cmd=None): + allregs = (cmd is None) or bool(cmd.get("allreges", False)) + if self._cached_hierwitness[allregs] is not None: + return self._cached_hierwitness[allregs] + inits, seqs, clocks, mems = smtbmc.smt.hierwitness(smtbmc.topmod, allregs) + self._cached_hierwitness[allregs] = result = dict( + inits=inits, seqs=seqs, clocks=clocks, mems=mems + ) + return result + + def cmd_write_yw_trace(self, cmd): + steps = cmd.get("steps") + allregs = bool(cmd.get("allregs", False)) + + if steps is None: + steps = sorted(self.state_set) + + path = cmd.get("path") + + smtbmc.write_yw_trace(steps, self.traceidx, allregs=allregs, filename=path) + + if path is None: + self.traceidx += 1 + + def cmd_read_yw_trace(self, cmd): + steps = cmd.get("steps") + path = cmd.get("path") + name = cmd.get("name") + skip_x = cmd.get("skip_x", False) + if path is None: + raise InteractiveError("path required") + + constraints = defaultdict(list) + + if steps is None: + steps = sorted(self.state_set) + + map_steps = {i: int(j) for i, j in enumerate(steps)} + + smtbmc.ywfile_constraints(path, constraints, map_steps=map_steps, skip_x=skip_x) + + self._yw_constraints[name] = { + map_steps.get(i, i): [smtexpr for cexfile, smtexpr in constraint_list] + for i, constraint_list in constraints.items() + } + + def cmd_ping(self, cmd): + return cmd + + cmd_handlers = { + "new_step": cmd_new_step, + "assert": cmd_assert, + "assert_antecedent": cmd_assert, + "assert_consequent": cmd_assert, + "push": cmd_push, + "pop": cmd_pop, + "check": cmd_check, + "design_hierwitness": cmd_design_hierwitness, + "write_yw_trace": cmd_write_yw_trace, + "read_yw_trace": cmd_read_yw_trace, + "ping": cmd_ping, + } + + def handle_command(self, cmd): + if not isinstance(cmd, dict) or "cmd" not in cmd: + raise InteractiveError('object with "cmd" key required') + + name = cmd.get("cmd", None) + + handler = self.cmd_handlers.get(name) + if handler: + return handler(self, cmd) + else: + raise InteractiveError(f"unknown command: {name}") + + def mainloop(self): + self.setup() + while True: + try: + cmd = input().strip() + if not cmd or cmd.startswith("#") or cmd.startswith("//"): + continue + try: + cmd = json.loads(cmd) + except json.decoder.JSONDecodeError as e: + self.print_json(err=f"invalid JSON: {e}") + continue + except EOFError: + break + + try: + result = self.handle_command(cmd) + except InteractiveError as e: + self.print_json(err=str(e)) + continue + except Exception as e: + self.print_json(err=f"internal error: {e}") + raise + else: + self.print_json(ok=result) diff --git a/backends/smt2/witness.py b/backends/smt2/witness.py index 0977f4532d5..a39500c2dc1 100644 --- a/backends/smt2/witness.py +++ b/backends/smt2/witness.py @@ -33,10 +33,14 @@ def cli(): Display a Yosys witness trace in a human readable format. """) @click.argument("input", type=click.File("r")) -def display(input): +@click.option("--skip-x", help="Treat x bits as unassigned.", is_flag=True) +def display(input, skip_x): click.echo(f"Reading Yosys witness trace {input.name!r}...") inyw = ReadWitness(input) + if skip_x: + inyw.skip_x() + def output(): yield click.style("*** RTLIL bit-order below may differ from source level declarations ***", fg="red") @@ -91,7 +95,11 @@ def stats(input): @click.option("--append", "-p", type=int, multiple=True, help="Number of steps (+ve or -ve) to append to end of input trace. " +"Can be defined multiple times, following the same order as input traces. ") -def yw2yw(inputs, output, append): +@click.option("--skip-x", help="Leave input x bits unassigned.", is_flag=True) +def yw2yw(inputs, output, append, skip_x): + if len(inputs) == 0: + raise click.ClickException(f"no inputs specified") + outyw = WriteWitness(output, "yosys-witness yw2yw") join_inputs = len(inputs) > 1 inyws = {} @@ -129,12 +137,12 @@ def yw2yw(inputs, output, append): click.echo(f"Copying yosys witness trace from {input.name!r} to {output.name!r}...") if first_witness: - outyw.step(init_values) + outyw.step(init_values, skip_x=skip_x) else: - outyw.step(inyw.first_step()) + outyw.step(inyw.first_step(), skip_x=skip_x) for t, values in inyw.steps(1): - outyw.step(values) + outyw.step(values, skip_x=skip_x) click.echo(f" copied {t + 1} time steps.") first_witness = False @@ -174,7 +182,8 @@ def __init__(self, mapfile): @click.argument("input", type=click.File("r")) @click.argument("mapfile", type=click.File("r")) @click.argument("output", type=click.File("w")) -def aiw2yw(input, mapfile, output): +@click.option("--skip-x", help="Leave input x bits unassigned.", is_flag=True) +def aiw2yw(input, mapfile, output, skip_x): input_name = input.name click.echo(f"Converting AIGER witness trace {input_name!r} to Yosys witness trace {output.name!r}...") click.echo(f"Using Yosys witness AIGER map file {mapfile.name!r}") @@ -245,7 +254,7 @@ def aiw2yw(input, mapfile, output): values[bit] = v - outyw.step(values) + outyw.step(values, skip_x=skip_x) outyw.end_trace() diff --git a/backends/smt2/ywio.py b/backends/smt2/ywio.py index 4e95f8c33d2..023a2d351b3 100644 --- a/backends/smt2/ywio.py +++ b/backends/smt2/ywio.py @@ -351,11 +351,14 @@ def write_header(self): self.out.name("steps") self.out.begin_array() - def step(self, values): + def step(self, values, skip_x=False): if not self.header_written: self.write_header() - self.out.value({"bits": values.pack(self.sigmap)}) + packed = values.pack(self.sigmap) + if skip_x: + packed = packed.replace('x', '?') + self.out.value({"bits": packed}) self.t += 1 @@ -390,6 +393,9 @@ def __init__(self, f): self.bits = [step["bits"] for step in data["steps"]] + def skip_x(self): + self.bits = [step.replace('x', '?') for step in self.bits] + def init_step(self): return self.step(0)