Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

smtbmc: Improvements for --incremental and .yw fixes #4377

Merged
merged 1 commit into from
May 7, 2024
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
203 changes: 126 additions & 77 deletions backends/smt2/smtbmc.py
Original file line number Diff line number Diff line change
Expand Up @@ -199,7 +199,6 @@ def help():

--minimize-assumes
when using --track-assumes, solve for a minimal set of sufficient assumptions.

""" + so.helpmsg())

def usage():
Expand Down Expand Up @@ -670,18 +669,12 @@ def print_msg(msg):

ywfile_hierwitness_cache = None

def ywfile_constraints(inywfile, constr_assumes, map_steps=None, skip_x=False):
def ywfile_hierwitness():
global ywfile_hierwitness_cache
if map_steps is None:
map_steps = {}

with open(inywfile, "r") as f:
inyw = ReadWitness(f)

if ywfile_hierwitness_cache is None:
ywfile_hierwitness_cache = smt.hierwitness(topmod, allregs=True, blackbox=True)
if ywfile_hierwitness_cache is None:
ywfile_hierwitness = smt.hierwitness(topmod, allregs=True, blackbox=True)

inits, seqs, clocks, mems = ywfile_hierwitness_cache
inits, seqs, clocks, mems = ywfile_hierwitness

smt_wires = defaultdict(list)
smt_mems = defaultdict(list)
Expand All @@ -692,91 +685,147 @@ def ywfile_constraints(inywfile, constr_assumes, map_steps=None, skip_x=False):
for mem in mems:
smt_mems[mem["path"]].append(mem)

addr_re = re.compile(r'\\\[[0-9]+\]$')
bits_re = re.compile(r'[01?]*$')
ywfile_hierwitness_cache = inits, seqs, clocks, mems, smt_wires, smt_mems

max_t = -1
return ywfile_hierwitness_cache

for t, step in inyw.steps():
present_signals, missing = step.present_signals(inyw.sigmap)
for sig in present_signals:
bits = step[sig]
if skip_x:
bits = bits.replace('x', '?')
if not bits_re.match(bits):
raise ValueError("unsupported bit value in Yosys witness file")
def_bits_re = re.compile(r'([01]+)')

sig_end = sig.offset + len(bits)
if sig.path in smt_wires:
for wire in smt_wires[sig.path]:
width, offset = wire["width"], wire["offset"]
def smt_extract_mask(smt_expr, mask):
chunks = []
def_bits = ''

smt_bool = smt.net_width(topmod, wire["smtpath"]) == 1
mask_index_order = mask[::-1]

offset = max(offset, 0)
for matched in def_bits_re.finditer(mask_index_order):
chunks.append(matched.span())
def_bits += matched[0]

end = width + offset
common_offset = max(sig.offset, offset)
common_end = min(sig_end, end)
if common_end <= common_offset:
continue
if not chunks:
return

smt_expr = smt.witness_net_expr(topmod, f"s{map_steps.get(t, t)}", wire)
if len(chunks) == 1:
start, end = chunks[0]
if start == 0 and end == len(mask_index_order):
combined_chunks = smt_expr
else:
combined_chunks = '((_ extract %d %d) %s)' % (end - 1, start, smt_expr)
else:
combined_chunks = '(let ((x %s)) (concat %s))' % (smt_expr, ' '.join(
'((_ extract %d %d) x)' % (end - 1, start)
for start, end in reversed(chunks)
))

if not smt_bool:
slice_high = common_end - offset - 1
slice_low = common_offset - offset
smt_expr = "((_ extract %d %d) %s)" % (slice_high, slice_low, smt_expr)
return combined_chunks, ''.join(mask_index_order[start:end] for start, end in chunks)[::-1]

bit_slice = bits[len(bits) - (common_end - sig.offset):len(bits) - (common_offset - sig.offset)]
def smt_concat(exprs):
if not exprs:
return ""
if len(exprs) == 1:
return exprs[1]
return "(concat %s)" % ' '.join(exprs)

if bit_slice.count("?") == len(bit_slice):
continue
def ywfile_signal(sig, step, mask=None):
assert sig.width > 0

if smt_bool:
assert width == 1
smt_constr = "(= %s %s)" % (smt_expr, "true" if bit_slice == "1" else "false")
else:
if "?" in bit_slice:
mask = bit_slice.replace("0", "1").replace("?", "0")
bit_slice = bit_slice.replace("?", "0")
smt_expr = "(bvand %s #b%s)" % (smt_expr, mask)
inits, seqs, clocks, mems, smt_wires, smt_mems = ywfile_hierwitness()
sig_end = sig.offset + sig.width

smt_constr = "(= %s #b%s)" % (smt_expr, bit_slice)
output = []

constr_assumes[t].append((inywfile, smt_constr))
if sig.path in smt_wires:
for wire in smt_wires[sig.path]:
width, offset = wire["width"], wire["offset"]

if sig.memory_path:
if sig.memory_path in smt_mems:
for mem in smt_mems[sig.memory_path]:
width, size, bv = mem["width"], mem["size"], mem["statebv"]
smt_bool = smt.net_width(topmod, wire["smtpath"]) == 1

smt_expr = smt.net_expr(topmod, f"s{map_steps.get(t, t)}", mem["smtpath"])
offset = max(offset, 0)

if bv:
word_low = sig.memory_addr * width
word_high = word_low + width - 1
smt_expr = "((_ extract %d %d) %s)" % (word_high, word_low, smt_expr)
else:
addr_width = (size - 1).bit_length()
addr_bits = f"{sig.memory_addr:0{addr_width}b}"
smt_expr = "(select %s #b%s )" % (smt_expr, addr_bits)
end = width + offset
common_offset = max(sig.offset, offset)
common_end = min(sig_end, end)
if common_end <= common_offset:
continue

if len(bits) < width:
slice_high = sig.offset + len(bits) - 1
smt_expr = "((_ extract %d %d) %s)" % (slice_high, sig.offset, smt_expr)
smt_expr = smt.witness_net_expr(topmod, f"s{step}", wire)

bit_slice = bits
if not smt_bool:
slice_high = common_end - offset - 1
slice_low = common_offset - offset
smt_expr = "((_ extract %d %d) %s)" % (slice_high, slice_low, smt_expr)
else:
smt_expr = "(ite %s #b1 #b0)" % smt_expr

if "?" in bit_slice:
mask = bit_slice.replace("0", "1").replace("?", "0")
bit_slice = bit_slice.replace("?", "0")
smt_expr = "(bvand %s #b%s)" % (smt_expr, mask)
output.append(((common_offset - sig.offset), (common_end - sig.offset), smt_expr))

smt_constr = "(= %s #b%s)" % (smt_expr, bit_slice)
constr_assumes[t].append((inywfile, smt_constr))
max_t = t
if sig.memory_path:
if sig.memory_path in smt_mems:
for mem in smt_mems[sig.memory_path]:
width, size, bv = mem["width"], mem["size"], mem["statebv"]

smt_expr = smt.net_expr(topmod, f"s{step}", mem["smtpath"])

if bv:
word_low = sig.memory_addr * width
word_high = word_low + width - 1
smt_expr = "((_ extract %d %d) %s)" % (word_high, word_low, smt_expr)
else:
addr_width = (size - 1).bit_length()
addr_bits = f"{sig.memory_addr:0{addr_width}b}"
smt_expr = "(select %s #b%s )" % (smt_expr, addr_bits)

if sig.width < width:
slice_high = sig.offset + sig.width - 1
smt_expr = "((_ extract %d %d) %s)" % (slice_high, sig.offset, smt_expr)

output.append((0, sig.width, smt_expr))

output.sort()

output = [chunk for chunk in output if chunk[0] != chunk[1]]

pos = 0

for start, end, smt_expr in output:
assert start == pos
pos = end

assert pos == sig.width

if len(output) == 1:
return output[0][-1]
return smt_concat(smt_expr for start, end, smt_expr in reversed(output))

def ywfile_constraints(inywfile, constr_assumes, map_steps=None, skip_x=False):
global ywfile_hierwitness_cache
if map_steps is None:
map_steps = {}

with open(inywfile, "r") as f:
inyw = ReadWitness(f)

inits, seqs, clocks, mems, smt_wires, smt_mems = ywfile_hierwitness()

bits_re = re.compile(r'[01?]*$')
max_t = -1

for t, step in inyw.steps():
present_signals, missing = step.present_signals(inyw.sigmap)
for sig in present_signals:
bits = step[sig]
if skip_x:
bits = bits.replace('x', '?')
if not bits_re.match(bits):
raise ValueError("unsupported bit value in Yosys witness file")

smt_expr = ywfile_signal(sig, map_steps.get(t, t))

smt_expr, bits = smt_extract_mask(smt_expr, bits)

smt_constr = "(= %s #b%s)" % (smt_expr, bits)
constr_assumes[t].append((inywfile, smt_constr))

max_t = t
return max_t

if inywfile is not None:
Expand Down Expand Up @@ -1367,11 +1416,11 @@ def write_yw_trace(steps, index, allregs=False, filename=None):

exprs.extend(smt.witness_net_expr(topmod, f"s{k}", sig) for sig in sigs)

all_sigs.append(sigs)
all_sigs.append((step_values, sigs))

bvs = iter(smt.get_list(exprs))

for sigs in all_sigs:
for (step_values, sigs) in all_sigs:
for sig in sigs:
value = smt.bv2bin(next(bvs))
step_values[sig["sig"]] = value
Expand Down
Loading
Loading