Skip to content

Commit

Permalink
Merge pull request #4377 from jix/smtbmc-incremental-improvements
Browse files Browse the repository at this point in the history
smtbmc: Improvements for --incremental and .yw fixes
  • Loading branch information
mmicko authored May 7, 2024
2 parents 8735107 + a52088b commit c9d87d5
Show file tree
Hide file tree
Showing 3 changed files with 279 additions and 92 deletions.
203 changes: 126 additions & 77 deletions backends/smt2/smtbmc.py
Original file line number Diff line number Diff line change
Expand Up @@ -199,7 +199,6 @@ def help():
--minimize-assumes
when using --track-assumes, solve for a minimal set of sufficient assumptions.
""" + so.helpmsg())

def usage():
Expand Down Expand Up @@ -670,18 +669,12 @@ def print_msg(msg):

ywfile_hierwitness_cache = None

def ywfile_constraints(inywfile, constr_assumes, map_steps=None, skip_x=False):
def ywfile_hierwitness():
global ywfile_hierwitness_cache
if map_steps is None:
map_steps = {}

with open(inywfile, "r") as f:
inyw = ReadWitness(f)

if ywfile_hierwitness_cache is None:
ywfile_hierwitness_cache = smt.hierwitness(topmod, allregs=True, blackbox=True)
if ywfile_hierwitness_cache is None:
ywfile_hierwitness = smt.hierwitness(topmod, allregs=True, blackbox=True)

inits, seqs, clocks, mems = ywfile_hierwitness_cache
inits, seqs, clocks, mems = ywfile_hierwitness

smt_wires = defaultdict(list)
smt_mems = defaultdict(list)
Expand All @@ -692,91 +685,147 @@ def ywfile_constraints(inywfile, constr_assumes, map_steps=None, skip_x=False):
for mem in mems:
smt_mems[mem["path"]].append(mem)

addr_re = re.compile(r'\\\[[0-9]+\]$')
bits_re = re.compile(r'[01?]*$')
ywfile_hierwitness_cache = inits, seqs, clocks, mems, smt_wires, smt_mems

max_t = -1
return ywfile_hierwitness_cache

for t, step in inyw.steps():
present_signals, missing = step.present_signals(inyw.sigmap)
for sig in present_signals:
bits = step[sig]
if skip_x:
bits = bits.replace('x', '?')
if not bits_re.match(bits):
raise ValueError("unsupported bit value in Yosys witness file")
def_bits_re = re.compile(r'([01]+)')

sig_end = sig.offset + len(bits)
if sig.path in smt_wires:
for wire in smt_wires[sig.path]:
width, offset = wire["width"], wire["offset"]
def smt_extract_mask(smt_expr, mask):
chunks = []
def_bits = ''

smt_bool = smt.net_width(topmod, wire["smtpath"]) == 1
mask_index_order = mask[::-1]

offset = max(offset, 0)
for matched in def_bits_re.finditer(mask_index_order):
chunks.append(matched.span())
def_bits += matched[0]

end = width + offset
common_offset = max(sig.offset, offset)
common_end = min(sig_end, end)
if common_end <= common_offset:
continue
if not chunks:
return

smt_expr = smt.witness_net_expr(topmod, f"s{map_steps.get(t, t)}", wire)
if len(chunks) == 1:
start, end = chunks[0]
if start == 0 and end == len(mask_index_order):
combined_chunks = smt_expr
else:
combined_chunks = '((_ extract %d %d) %s)' % (end - 1, start, smt_expr)
else:
combined_chunks = '(let ((x %s)) (concat %s))' % (smt_expr, ' '.join(
'((_ extract %d %d) x)' % (end - 1, start)
for start, end in reversed(chunks)
))

if not smt_bool:
slice_high = common_end - offset - 1
slice_low = common_offset - offset
smt_expr = "((_ extract %d %d) %s)" % (slice_high, slice_low, smt_expr)
return combined_chunks, ''.join(mask_index_order[start:end] for start, end in chunks)[::-1]

bit_slice = bits[len(bits) - (common_end - sig.offset):len(bits) - (common_offset - sig.offset)]
def smt_concat(exprs):
if not exprs:
return ""
if len(exprs) == 1:
return exprs[1]
return "(concat %s)" % ' '.join(exprs)

if bit_slice.count("?") == len(bit_slice):
continue
def ywfile_signal(sig, step, mask=None):
assert sig.width > 0

if smt_bool:
assert width == 1
smt_constr = "(= %s %s)" % (smt_expr, "true" if bit_slice == "1" else "false")
else:
if "?" in bit_slice:
mask = bit_slice.replace("0", "1").replace("?", "0")
bit_slice = bit_slice.replace("?", "0")
smt_expr = "(bvand %s #b%s)" % (smt_expr, mask)
inits, seqs, clocks, mems, smt_wires, smt_mems = ywfile_hierwitness()
sig_end = sig.offset + sig.width

smt_constr = "(= %s #b%s)" % (smt_expr, bit_slice)
output = []

constr_assumes[t].append((inywfile, smt_constr))
if sig.path in smt_wires:
for wire in smt_wires[sig.path]:
width, offset = wire["width"], wire["offset"]

if sig.memory_path:
if sig.memory_path in smt_mems:
for mem in smt_mems[sig.memory_path]:
width, size, bv = mem["width"], mem["size"], mem["statebv"]
smt_bool = smt.net_width(topmod, wire["smtpath"]) == 1

smt_expr = smt.net_expr(topmod, f"s{map_steps.get(t, t)}", mem["smtpath"])
offset = max(offset, 0)

if bv:
word_low = sig.memory_addr * width
word_high = word_low + width - 1
smt_expr = "((_ extract %d %d) %s)" % (word_high, word_low, smt_expr)
else:
addr_width = (size - 1).bit_length()
addr_bits = f"{sig.memory_addr:0{addr_width}b}"
smt_expr = "(select %s #b%s )" % (smt_expr, addr_bits)
end = width + offset
common_offset = max(sig.offset, offset)
common_end = min(sig_end, end)
if common_end <= common_offset:
continue

if len(bits) < width:
slice_high = sig.offset + len(bits) - 1
smt_expr = "((_ extract %d %d) %s)" % (slice_high, sig.offset, smt_expr)
smt_expr = smt.witness_net_expr(topmod, f"s{step}", wire)

bit_slice = bits
if not smt_bool:
slice_high = common_end - offset - 1
slice_low = common_offset - offset
smt_expr = "((_ extract %d %d) %s)" % (slice_high, slice_low, smt_expr)
else:
smt_expr = "(ite %s #b1 #b0)" % smt_expr

if "?" in bit_slice:
mask = bit_slice.replace("0", "1").replace("?", "0")
bit_slice = bit_slice.replace("?", "0")
smt_expr = "(bvand %s #b%s)" % (smt_expr, mask)
output.append(((common_offset - sig.offset), (common_end - sig.offset), smt_expr))

smt_constr = "(= %s #b%s)" % (smt_expr, bit_slice)
constr_assumes[t].append((inywfile, smt_constr))
max_t = t
if sig.memory_path:
if sig.memory_path in smt_mems:
for mem in smt_mems[sig.memory_path]:
width, size, bv = mem["width"], mem["size"], mem["statebv"]

smt_expr = smt.net_expr(topmod, f"s{step}", mem["smtpath"])

if bv:
word_low = sig.memory_addr * width
word_high = word_low + width - 1
smt_expr = "((_ extract %d %d) %s)" % (word_high, word_low, smt_expr)
else:
addr_width = (size - 1).bit_length()
addr_bits = f"{sig.memory_addr:0{addr_width}b}"
smt_expr = "(select %s #b%s )" % (smt_expr, addr_bits)

if sig.width < width:
slice_high = sig.offset + sig.width - 1
smt_expr = "((_ extract %d %d) %s)" % (slice_high, sig.offset, smt_expr)

output.append((0, sig.width, smt_expr))

output.sort()

output = [chunk for chunk in output if chunk[0] != chunk[1]]

pos = 0

for start, end, smt_expr in output:
assert start == pos
pos = end

assert pos == sig.width

if len(output) == 1:
return output[0][-1]
return smt_concat(smt_expr for start, end, smt_expr in reversed(output))

def ywfile_constraints(inywfile, constr_assumes, map_steps=None, skip_x=False):
global ywfile_hierwitness_cache
if map_steps is None:
map_steps = {}

with open(inywfile, "r") as f:
inyw = ReadWitness(f)

inits, seqs, clocks, mems, smt_wires, smt_mems = ywfile_hierwitness()

bits_re = re.compile(r'[01?]*$')
max_t = -1

for t, step in inyw.steps():
present_signals, missing = step.present_signals(inyw.sigmap)
for sig in present_signals:
bits = step[sig]
if skip_x:
bits = bits.replace('x', '?')
if not bits_re.match(bits):
raise ValueError("unsupported bit value in Yosys witness file")

smt_expr = ywfile_signal(sig, map_steps.get(t, t))

smt_expr, bits = smt_extract_mask(smt_expr, bits)

smt_constr = "(= %s #b%s)" % (smt_expr, bits)
constr_assumes[t].append((inywfile, smt_constr))

max_t = t
return max_t

if inywfile is not None:
Expand Down Expand Up @@ -1367,11 +1416,11 @@ def write_yw_trace(steps, index, allregs=False, filename=None):

exprs.extend(smt.witness_net_expr(topmod, f"s{k}", sig) for sig in sigs)

all_sigs.append(sigs)
all_sigs.append((step_values, sigs))

bvs = iter(smt.get_list(exprs))

for sigs in all_sigs:
for (step_values, sigs) in all_sigs:
for sig in sigs:
value = smt.bv2bin(next(bvs))
step_values[sig["sig"]] = value
Expand Down
Loading

0 comments on commit c9d87d5

Please sign in to comment.