Skip to content

Commit

Permalink
smtbmc: Improvements for --incremental and .yw fixes
Browse files Browse the repository at this point in the history
This extends the experimental incremental JSON API to allow arbitrary
smtlib subexpressions, defining smtlib constants and to allow access of
signals by their .yw path.

It also fixes a bug during .yw writing where values would be re-emitted
in later cycles if they have no newer defined value and a potential
crash when using --track-assumes.
  • Loading branch information
jix committed May 7, 2024
1 parent 71f2540 commit a52088b
Show file tree
Hide file tree
Showing 3 changed files with 279 additions and 92 deletions.
203 changes: 126 additions & 77 deletions backends/smt2/smtbmc.py
Original file line number Diff line number Diff line change
Expand Up @@ -199,7 +199,6 @@ def help():
--minimize-assumes
when using --track-assumes, solve for a minimal set of sufficient assumptions.
""" + so.helpmsg())

def usage():
Expand Down Expand Up @@ -670,18 +669,12 @@ def print_msg(msg):

ywfile_hierwitness_cache = None

def ywfile_constraints(inywfile, constr_assumes, map_steps=None, skip_x=False):
def ywfile_hierwitness():
global ywfile_hierwitness_cache
if map_steps is None:
map_steps = {}

with open(inywfile, "r") as f:
inyw = ReadWitness(f)

if ywfile_hierwitness_cache is None:
ywfile_hierwitness_cache = smt.hierwitness(topmod, allregs=True, blackbox=True)
if ywfile_hierwitness_cache is None:
ywfile_hierwitness = smt.hierwitness(topmod, allregs=True, blackbox=True)

inits, seqs, clocks, mems = ywfile_hierwitness_cache
inits, seqs, clocks, mems = ywfile_hierwitness

smt_wires = defaultdict(list)
smt_mems = defaultdict(list)
Expand All @@ -692,91 +685,147 @@ def ywfile_constraints(inywfile, constr_assumes, map_steps=None, skip_x=False):
for mem in mems:
smt_mems[mem["path"]].append(mem)

addr_re = re.compile(r'\\\[[0-9]+\]$')
bits_re = re.compile(r'[01?]*$')
ywfile_hierwitness_cache = inits, seqs, clocks, mems, smt_wires, smt_mems

max_t = -1
return ywfile_hierwitness_cache

for t, step in inyw.steps():
present_signals, missing = step.present_signals(inyw.sigmap)
for sig in present_signals:
bits = step[sig]
if skip_x:
bits = bits.replace('x', '?')
if not bits_re.match(bits):
raise ValueError("unsupported bit value in Yosys witness file")
def_bits_re = re.compile(r'([01]+)')

sig_end = sig.offset + len(bits)
if sig.path in smt_wires:
for wire in smt_wires[sig.path]:
width, offset = wire["width"], wire["offset"]
def smt_extract_mask(smt_expr, mask):
chunks = []
def_bits = ''

smt_bool = smt.net_width(topmod, wire["smtpath"]) == 1
mask_index_order = mask[::-1]

offset = max(offset, 0)
for matched in def_bits_re.finditer(mask_index_order):
chunks.append(matched.span())
def_bits += matched[0]

end = width + offset
common_offset = max(sig.offset, offset)
common_end = min(sig_end, end)
if common_end <= common_offset:
continue
if not chunks:
return

smt_expr = smt.witness_net_expr(topmod, f"s{map_steps.get(t, t)}", wire)
if len(chunks) == 1:
start, end = chunks[0]
if start == 0 and end == len(mask_index_order):
combined_chunks = smt_expr
else:
combined_chunks = '((_ extract %d %d) %s)' % (end - 1, start, smt_expr)
else:
combined_chunks = '(let ((x %s)) (concat %s))' % (smt_expr, ' '.join(
'((_ extract %d %d) x)' % (end - 1, start)
for start, end in reversed(chunks)
))

if not smt_bool:
slice_high = common_end - offset - 1
slice_low = common_offset - offset
smt_expr = "((_ extract %d %d) %s)" % (slice_high, slice_low, smt_expr)
return combined_chunks, ''.join(mask_index_order[start:end] for start, end in chunks)[::-1]

bit_slice = bits[len(bits) - (common_end - sig.offset):len(bits) - (common_offset - sig.offset)]
def smt_concat(exprs):
if not exprs:
return ""
if len(exprs) == 1:
return exprs[1]
return "(concat %s)" % ' '.join(exprs)

if bit_slice.count("?") == len(bit_slice):
continue
def ywfile_signal(sig, step, mask=None):
assert sig.width > 0

if smt_bool:
assert width == 1
smt_constr = "(= %s %s)" % (smt_expr, "true" if bit_slice == "1" else "false")
else:
if "?" in bit_slice:
mask = bit_slice.replace("0", "1").replace("?", "0")
bit_slice = bit_slice.replace("?", "0")
smt_expr = "(bvand %s #b%s)" % (smt_expr, mask)
inits, seqs, clocks, mems, smt_wires, smt_mems = ywfile_hierwitness()
sig_end = sig.offset + sig.width

smt_constr = "(= %s #b%s)" % (smt_expr, bit_slice)
output = []

constr_assumes[t].append((inywfile, smt_constr))
if sig.path in smt_wires:
for wire in smt_wires[sig.path]:
width, offset = wire["width"], wire["offset"]

if sig.memory_path:
if sig.memory_path in smt_mems:
for mem in smt_mems[sig.memory_path]:
width, size, bv = mem["width"], mem["size"], mem["statebv"]
smt_bool = smt.net_width(topmod, wire["smtpath"]) == 1

smt_expr = smt.net_expr(topmod, f"s{map_steps.get(t, t)}", mem["smtpath"])
offset = max(offset, 0)

if bv:
word_low = sig.memory_addr * width
word_high = word_low + width - 1
smt_expr = "((_ extract %d %d) %s)" % (word_high, word_low, smt_expr)
else:
addr_width = (size - 1).bit_length()
addr_bits = f"{sig.memory_addr:0{addr_width}b}"
smt_expr = "(select %s #b%s )" % (smt_expr, addr_bits)
end = width + offset
common_offset = max(sig.offset, offset)
common_end = min(sig_end, end)
if common_end <= common_offset:
continue

if len(bits) < width:
slice_high = sig.offset + len(bits) - 1
smt_expr = "((_ extract %d %d) %s)" % (slice_high, sig.offset, smt_expr)
smt_expr = smt.witness_net_expr(topmod, f"s{step}", wire)

bit_slice = bits
if not smt_bool:
slice_high = common_end - offset - 1
slice_low = common_offset - offset
smt_expr = "((_ extract %d %d) %s)" % (slice_high, slice_low, smt_expr)
else:
smt_expr = "(ite %s #b1 #b0)" % smt_expr

if "?" in bit_slice:
mask = bit_slice.replace("0", "1").replace("?", "0")
bit_slice = bit_slice.replace("?", "0")
smt_expr = "(bvand %s #b%s)" % (smt_expr, mask)
output.append(((common_offset - sig.offset), (common_end - sig.offset), smt_expr))

smt_constr = "(= %s #b%s)" % (smt_expr, bit_slice)
constr_assumes[t].append((inywfile, smt_constr))
max_t = t
if sig.memory_path:
if sig.memory_path in smt_mems:
for mem in smt_mems[sig.memory_path]:
width, size, bv = mem["width"], mem["size"], mem["statebv"]

smt_expr = smt.net_expr(topmod, f"s{step}", mem["smtpath"])

if bv:
word_low = sig.memory_addr * width
word_high = word_low + width - 1
smt_expr = "((_ extract %d %d) %s)" % (word_high, word_low, smt_expr)
else:
addr_width = (size - 1).bit_length()
addr_bits = f"{sig.memory_addr:0{addr_width}b}"
smt_expr = "(select %s #b%s )" % (smt_expr, addr_bits)

if sig.width < width:
slice_high = sig.offset + sig.width - 1
smt_expr = "((_ extract %d %d) %s)" % (slice_high, sig.offset, smt_expr)

output.append((0, sig.width, smt_expr))

output.sort()

output = [chunk for chunk in output if chunk[0] != chunk[1]]

pos = 0

for start, end, smt_expr in output:
assert start == pos
pos = end

assert pos == sig.width

if len(output) == 1:
return output[0][-1]
return smt_concat(smt_expr for start, end, smt_expr in reversed(output))

def ywfile_constraints(inywfile, constr_assumes, map_steps=None, skip_x=False):
global ywfile_hierwitness_cache
if map_steps is None:
map_steps = {}

with open(inywfile, "r") as f:
inyw = ReadWitness(f)

inits, seqs, clocks, mems, smt_wires, smt_mems = ywfile_hierwitness()

bits_re = re.compile(r'[01?]*$')
max_t = -1

for t, step in inyw.steps():
present_signals, missing = step.present_signals(inyw.sigmap)
for sig in present_signals:
bits = step[sig]
if skip_x:
bits = bits.replace('x', '?')
if not bits_re.match(bits):
raise ValueError("unsupported bit value in Yosys witness file")

smt_expr = ywfile_signal(sig, map_steps.get(t, t))

smt_expr, bits = smt_extract_mask(smt_expr, bits)

smt_constr = "(= %s #b%s)" % (smt_expr, bits)
constr_assumes[t].append((inywfile, smt_constr))

max_t = t
return max_t

if inywfile is not None:
Expand Down Expand Up @@ -1367,11 +1416,11 @@ def write_yw_trace(steps, index, allregs=False, filename=None):

exprs.extend(smt.witness_net_expr(topmod, f"s{k}", sig) for sig in sigs)

all_sigs.append(sigs)
all_sigs.append((step_values, sigs))

bvs = iter(smt.get_list(exprs))

for sigs in all_sigs:
for (step_values, sigs) in all_sigs:
for sig in sigs:
value = smt.bv2bin(next(bvs))
step_values[sig["sig"]] = value
Expand Down
Loading

0 comments on commit a52088b

Please sign in to comment.