Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

format_adt: refactor to manually-written formatter #116

Merged
merged 3 commits into from
Oct 19, 2023
Merged
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
206 changes: 118 additions & 88 deletions scripts/format_adt.py
Original file line number Diff line number Diff line change
@@ -1,5 +1,5 @@
#!/usr/bin/env python3
# vim: noai:ts=2:sw=2:expandtab
# vim: ts=2:sw=2:expandtab:autoindent

"""
format_adt.py implements pretty-printing of BAP .adt files
Expand All @@ -10,13 +10,15 @@
which only supports parsing of a literal Python expression.
"""

import io
import re
import ast
import sys
import time
import pprint
import logging
import argparse
import dataclasses
import contextlib

# grep -E "'([A-Z][a-zA-Z]+)'" src/main/antlr4/BAP_ADT.g4 --only-matching --no-filename | sort | uniq | xargs printf "'%s', " | fold -w80 -s
Expand All @@ -31,121 +33,149 @@
'Var', 'XOR',
]
heads_joined = '|'.join(heads)
heads_joined_b = heads_joined.encode('ascii')

log = logging.getLogger()

notspace_re = re.compile(rb'\S')
head_re = re.compile(rb'[^\s(]+')
num_re = re.compile(rb'[_0-9xa-fA-F]+')
string_re = re.compile(rb'"(?:[^"\\]|\\.)*"')

@dataclasses.dataclass
class Context:
begin: int
closer: bytes
multiline: bool

flip = {
b'(': b')',
b')': b'(',
b'[': b']',
b']': b'[',
}

def pretty(outfile, data: bytes, spaces: int):
# stack of expression beginning parentheses and their start position
stack: list[Context] = []
i = 0
depth = 0
head = b''

indent = b' ' * spaces

i0 = i - 1
length = len(data)
while i < length:
assert i0 != i
i0 = i

c = bytes((data[i],))
if c.isspace():
while i < length and data[i] in b' \t\r\n':
i += 1
elif c.isdigit():
m = num_re.match(data, i)
assert m
outfile.write(m[0])
i = m.end(0)
elif c == b',':
outfile.write(b',')
i += 1
if stack[-1].multiline:
outfile.write(b'\n')
outfile.write(indent * depth)
else:
outfile.write(b' ')
elif c.isupper():
m = head_re.match(data, i)
assert m
i = m.end(0)
head = m[0]
outfile.write(head)
elif c in b'([':
outfile.write(c)
i += 1
islist = c == b'[' and ']' != chr(data[i])
multiline = islist or head in (b'Project', b'Def', b'Goto', b'Call', b'Sub', b'Blk', b'Arg')
if multiline:
depth += 1
outfile.write(b'\n')
outfile.write(indent * depth)

stack.append(Context(i, flip[c], multiline))
elif c in b')]':
outfile.write(c)
i += 1

s = stack.pop()
if c != s.closer:
raise ValueError(f"mismatched bracket: {flip[s.closer]} at byte {s.begin} closed by {c} at {i}.")
if s.multiline:
depth -= 1
if not stack:
outfile.write(b'\n')
elif c == b'"':
string = string_re.match(data, i)
if not string:
raise ValueError(f"unclosed string beginning at byte {i+1}.")
outfile.write(string[0])
i = string.end(0)
else:
sys.stderr.buffer.write(b'\npreceding text:\n')
sys.stderr.buffer.write(data[max(0,i-100):i])
raise ValueError(f"unsupported @ {i} = {c}")

def preprocess(data: str) -> str:
"""
Preprocesses BAP ADT intrinsics (like Program, Subs, ...) into tuple syntax.
For example, Program(1, 2, 3) becomes ("Program", 1, 2, 3).
"""
heads_re = re.compile(f'({heads_joined})[(]')

data = heads_re.sub(lambda x: '(' + repr(x[1]) + ', ', data)
return data

class DoubleQuoteStr(str):
def __repr__(self):
# TODO: maybe make more robust?
r = super().__repr__()
r = r[1:-1]
r = r.replace(r"\'", r"'")
r = r.replace(r'"', r'\"')
return '"' + r + '"'

class UnderscoreInt(int):
def __repr__(self):
return f'{self:_}'

Exp = tuple | list | str | int
def clean(data: Exp) -> Exp:
"""
Intermediate step before formatting to tweak pprint's formatting.
This ensures we match BAP as close as possible with double-quoted strings
and underscores in Tid.
"""
if isinstance(data, tuple) and data[0] == 'Tid' and not isinstance(data[1], UnderscoreInt):
return clean((data[0], UnderscoreInt(data[1]), ) + data[2:])
if isinstance(data, str):
return DoubleQuoteStr(data)
if isinstance(data, (list, tuple)):
return data.__class__(map(clean, data))
return data

def postprocess(data: str) -> str:
"""
Postprocesses the formatted Python expression to restore the BAP-style intrinsics.
"""
heads_re2 = re.compile(f'\\("({heads_joined})",(\\s|\\))')

def replacement(x: re.Match) -> str:
head = x[1]
endc = x[2]
if endc not in ')\n':
endc = ''
return head + '(' + endc

data = heads_re2.sub(replacement, data)
return data
if stack:
closers = ''.join(chr(x.closer[0]) for x in reversed(stack))
log.warning(f"unclosed brackets. expected: '{closers}'. malformed adt?")


@contextlib.contextmanager
def measure_time(context: str):
log.info(f'starting {context}')
log.info(f'starting {context}', stacklevel=3)
start = time.perf_counter()
yield lambda: time.perf_counter() - start
log.debug(f'... done in {time.perf_counter() - start:.3f} seconds')
log.debug(f'... done in {time.perf_counter() - start:.3f} seconds', stacklevel=3)


def main(args):
infile = args.input
outfile = args.output
width = args.width
update = args.update
spaces = args.spaces

with measure_time('read'):
data = infile.read()
log.debug(f' read {len(data):,} characters')
infile.close()
log.debug(f' read {len(data):,} bytes')

out = data

with measure_time('preprocess'):
out = preprocess(out)

with measure_time('parse'):
out = ast.literal_eval(out)

with measure_time('clean'):
out = clean(out)
with measure_time('pprint'):
out = pprint.pformat(out, indent=width, underscore_numbers=False)

with measure_time('postprocess'):
out = postprocess(out)

with measure_time('output'):
if update:
infile.close()
with open(infile.name, 'w') as outfile:
outfile.write(out)
outfile.write('\n')
else:
outfile.write(out)
outfile.write('\n')
outfile.flush()
if len(data) > 5000000:
log.warning(f'large input of {len(data):,} bytes. formatting may be slow.')

outbuf = None
if update:
outfile = outbuf = io.BytesIO()

with measure_time('pretty + write' if not update else 'pretty'):
pretty(outfile, data, spaces)

if update:
with measure_time('write'), open(infile.name, 'wb') as outfile:
assert outbuf
outfile.write(outbuf.getbuffer())

if __name__ == '__main__':
logging.basicConfig(format='[%(asctime)s:%(name)s@%(filename)s:%(levelname)-7s]\t%(message)s')
logging.basicConfig(format='[%(asctime)s %(module)s:%(lineno)-3d %(levelname)-7s] %(message)s')

argp = argparse.ArgumentParser(description="pretty formats BAP ADT files.")
argp.add_argument('input', nargs='?', type=argparse.FileType('r'), default=sys.stdin,
argp.add_argument('input', nargs='?', type=argparse.FileType('rb'), default=sys.stdin.buffer,
help="input .adt file (default: stdin)")
excl = argp.add_mutually_exclusive_group()
excl.add_argument('output', nargs='?', type=argparse.FileType('w'), default=sys.stdout,
excl.add_argument('output', nargs='?', type=argparse.FileType('wb'), default=sys.stdout.buffer,
help="output file name (default: stdout)")

argp.add_argument('--width', '-w', default=1, type=int,
argp.add_argument('--spaces', '-s', default=1, type=int,
help="indent size in spaces (default: 1)")

excl.add_argument('--update', '-i', action='store_true',
Expand Down