Skip to content

Commit

Permalink
Merge pull request #116 from UQ-PAC/format-adt
Browse files Browse the repository at this point in the history
format_adt: refactor to manually-written formatter
  • Loading branch information
l-kent authored Oct 19, 2023
2 parents fcd2f8e + c31b93f commit eaa98b7
Showing 1 changed file with 118 additions and 88 deletions.
206 changes: 118 additions & 88 deletions scripts/format_adt.py
Original file line number Diff line number Diff line change
@@ -1,5 +1,5 @@
#!/usr/bin/env python3
# vim: noai:ts=2:sw=2:expandtab
# vim: ts=2:sw=2:expandtab:autoindent

"""
format_adt.py implements pretty-printing of BAP .adt files
Expand All @@ -10,13 +10,15 @@
which only supports parsing of a literal Python expression.
"""

import io
import re
import ast
import sys
import time
import pprint
import logging
import argparse
import dataclasses
import contextlib

# grep -E "'([A-Z][a-zA-Z]+)'" src/main/antlr4/BAP_ADT.g4 --only-matching --no-filename | sort | uniq | xargs printf "'%s', " | fold -w80 -s
Expand All @@ -31,121 +33,149 @@
'Var', 'XOR',
]
heads_joined = '|'.join(heads)
heads_joined_b = heads_joined.encode('ascii')

log = logging.getLogger()

notspace_re = re.compile(rb'\S')
head_re = re.compile(rb'[^\s(]+')
num_re = re.compile(rb'[_0-9xa-fA-F]+')
string_re = re.compile(rb'"(?:[^"\\]|\\.)*"')

@dataclasses.dataclass
class Context:
begin: int
closer: bytes
multiline: bool

flip = {
b'(': b')',
b')': b'(',
b'[': b']',
b']': b'[',
}

def pretty(outfile, data: bytes, spaces: int):
# stack of expression beginning parentheses and their start position
stack: list[Context] = []
i = 0
depth = 0
head = b''

indent = b' ' * spaces

i0 = i - 1
length = len(data)
while i < length:
assert i0 != i
i0 = i

c = bytes((data[i],))
if c.isspace():
while i < length and data[i] in b' \t\r\n':
i += 1
elif c.isdigit():
m = num_re.match(data, i)
assert m
outfile.write(m[0])
i = m.end(0)
elif c == b',':
outfile.write(b',')
i += 1
if stack[-1].multiline:
outfile.write(b'\n')
outfile.write(indent * depth)
else:
outfile.write(b' ')
elif c.isupper():
m = head_re.match(data, i)
assert m
i = m.end(0)
head = m[0]
outfile.write(head)
elif c in b'([':
outfile.write(c)
i += 1
islist = c == b'[' and ']' != chr(data[i])
multiline = islist or head in (b'Project', b'Def', b'Goto', b'Call', b'Sub', b'Blk', b'Arg')
if multiline:
depth += 1
outfile.write(b'\n')
outfile.write(indent * depth)

stack.append(Context(i, flip[c], multiline))
elif c in b')]':
outfile.write(c)
i += 1

s = stack.pop()
if c != s.closer:
raise ValueError(f"mismatched bracket: {flip[s.closer]} at byte {s.begin} closed by {c} at {i}.")
if s.multiline:
depth -= 1
if not stack:
outfile.write(b'\n')
elif c == b'"':
string = string_re.match(data, i)
if not string:
raise ValueError(f"unclosed string beginning at byte {i+1}.")
outfile.write(string[0])
i = string.end(0)
else:
sys.stderr.buffer.write(b'\npreceding text:\n')
sys.stderr.buffer.write(data[max(0,i-100):i])
raise ValueError(f"unsupported @ {i} = {c}")

def preprocess(data: str) -> str:
"""
Preprocesses BAP ADT intrinsics (like Program, Subs, ...) into tuple syntax.
For example, Program(1, 2, 3) becomes ("Program", 1, 2, 3).
"""
heads_re = re.compile(f'({heads_joined})[(]')

data = heads_re.sub(lambda x: '(' + repr(x[1]) + ', ', data)
return data

class DoubleQuoteStr(str):
def __repr__(self):
# TODO: maybe make more robust?
r = super().__repr__()
r = r[1:-1]
r = r.replace(r"\'", r"'")
r = r.replace(r'"', r'\"')
return '"' + r + '"'

class UnderscoreInt(int):
def __repr__(self):
return f'{self:_}'

Exp = tuple | list | str | int
def clean(data: Exp) -> Exp:
"""
Intermediate step before formatting to tweak pprint's formatting.
This ensures we match BAP as close as possible with double-quoted strings
and underscores in Tid.
"""
if isinstance(data, tuple) and data[0] == 'Tid' and not isinstance(data[1], UnderscoreInt):
return clean((data[0], UnderscoreInt(data[1]), ) + data[2:])
if isinstance(data, str):
return DoubleQuoteStr(data)
if isinstance(data, (list, tuple)):
return data.__class__(map(clean, data))
return data

def postprocess(data: str) -> str:
"""
Postprocesses the formatted Python expression to restore the BAP-style intrinsics.
"""
heads_re2 = re.compile(f'\\("({heads_joined})",(\\s|\\))')

def replacement(x: re.Match) -> str:
head = x[1]
endc = x[2]
if endc not in ')\n':
endc = ''
return head + '(' + endc

data = heads_re2.sub(replacement, data)
return data
if stack:
closers = ''.join(chr(x.closer[0]) for x in reversed(stack))
log.warning(f"unclosed brackets. expected: '{closers}'. malformed adt?")


@contextlib.contextmanager
def measure_time(context: str):
log.info(f'starting {context}')
log.info(f'starting {context}', stacklevel=3)
start = time.perf_counter()
yield lambda: time.perf_counter() - start
log.debug(f'... done in {time.perf_counter() - start:.3f} seconds')
log.debug(f'... done in {time.perf_counter() - start:.3f} seconds', stacklevel=3)


def main(args):
infile = args.input
outfile = args.output
width = args.width
update = args.update
spaces = args.spaces

with measure_time('read'):
data = infile.read()
log.debug(f' read {len(data):,} characters')
infile.close()
log.debug(f' read {len(data):,} bytes')

out = data

with measure_time('preprocess'):
out = preprocess(out)

with measure_time('parse'):
out = ast.literal_eval(out)

with measure_time('clean'):
out = clean(out)
with measure_time('pprint'):
out = pprint.pformat(out, indent=width, underscore_numbers=False)

with measure_time('postprocess'):
out = postprocess(out)

with measure_time('output'):
if update:
infile.close()
with open(infile.name, 'w') as outfile:
outfile.write(out)
outfile.write('\n')
else:
outfile.write(out)
outfile.write('\n')
outfile.flush()
if len(data) > 5000000:
log.warning(f'large input of {len(data):,} bytes. formatting may be slow.')

outbuf = None
if update:
outfile = outbuf = io.BytesIO()

with measure_time('pretty + write' if not update else 'pretty'):
pretty(outfile, data, spaces)

if update:
with measure_time('write'), open(infile.name, 'wb') as outfile:
assert outbuf
outfile.write(outbuf.getbuffer())

if __name__ == '__main__':
logging.basicConfig(format='[%(asctime)s:%(name)s@%(filename)s:%(levelname)-7s]\t%(message)s')
logging.basicConfig(format='[%(asctime)s %(module)s:%(lineno)-3d %(levelname)-7s] %(message)s')

argp = argparse.ArgumentParser(description="pretty formats BAP ADT files.")
argp.add_argument('input', nargs='?', type=argparse.FileType('r'), default=sys.stdin,
argp.add_argument('input', nargs='?', type=argparse.FileType('rb'), default=sys.stdin.buffer,
help="input .adt file (default: stdin)")
excl = argp.add_mutually_exclusive_group()
excl.add_argument('output', nargs='?', type=argparse.FileType('w'), default=sys.stdout,
excl.add_argument('output', nargs='?', type=argparse.FileType('wb'), default=sys.stdout.buffer,
help="output file name (default: stdout)")

argp.add_argument('--width', '-w', default=1, type=int,
argp.add_argument('--spaces', '-s', default=1, type=int,
help="indent size in spaces (default: 1)")

excl.add_argument('--update', '-i', action='store_true',
Expand Down

0 comments on commit eaa98b7

Please sign in to comment.