diff --git a/scripts/format_adt.py b/scripts/format_adt.py index 7b18888b7..3a9da6a74 100755 --- a/scripts/format_adt.py +++ b/scripts/format_adt.py @@ -1,5 +1,5 @@ #!/usr/bin/env python3 -# vim: noai:ts=2:sw=2:expandtab +# vim: ts=2:sw=2:expandtab:autoindent """ format_adt.py implements pretty-printing of BAP .adt files @@ -10,6 +10,7 @@ which only supports parsing of a literal Python expression. """ +import io import re import ast import sys @@ -17,6 +18,7 @@ import pprint import logging import argparse +import dataclasses import contextlib # grep -E "'([A-Z][a-zA-Z]+)'" src/main/antlr4/BAP_ADT.g4 --only-matching --no-filename | sort | uniq | xargs printf "'%s', " | fold -w80 -s @@ -31,121 +33,149 @@ 'Var', 'XOR', ] heads_joined = '|'.join(heads) +heads_joined_b = heads_joined.encode('ascii') log = logging.getLogger() +notspace_re = re.compile(rb'\S') +head_re = re.compile(rb'[^\s(]+') +num_re = re.compile(rb'[_0-9xa-fA-F]+') +string_re = re.compile(rb'"(?:[^"\\]|\\.)*"') + +@dataclasses.dataclass +class Context: + begin: int + closer: bytes + multiline: bool + +flip = { + b'(': b')', + b')': b'(', + b'[': b']', + b']': b'[', +} + +def pretty(outfile, data: bytes, spaces: int): + # stack of expression beginning parentheses and their start position + stack: list[Context] = [] + i = 0 + depth = 0 + head = b'' + + indent = b' ' * spaces + + i0 = i - 1 + length = len(data) + while i < length: + assert i0 != i + i0 = i + + c = bytes((data[i],)) + if c.isspace(): + while i < length and data[i] in b' \t\r\n': + i += 1 + elif c.isdigit(): + m = num_re.match(data, i) + assert m + outfile.write(m[0]) + i = m.end(0) + elif c == b',': + outfile.write(b',') + i += 1 + if stack[-1].multiline: + outfile.write(b'\n') + outfile.write(indent * depth) + else: + outfile.write(b' ') + elif c.isupper(): + m = head_re.match(data, i) + assert m + i = m.end(0) + head = m[0] + outfile.write(head) + elif c in b'([': + outfile.write(c) + i += 1 + islist = c == b'[' and ']' != chr(data[i]) + multiline = islist or head in (b'Project', b'Def', b'Goto', b'Call', b'Sub', b'Blk', b'Arg') + if multiline: + depth += 1 + outfile.write(b'\n') + outfile.write(indent * depth) + + stack.append(Context(i, flip[c], multiline)) + elif c in b')]': + outfile.write(c) + i += 1 + + s = stack.pop() + if c != s.closer: + raise ValueError(f"mismatched bracket: {flip[s.closer]} at byte {s.begin} closed by {c} at {i}.") + if s.multiline: + depth -= 1 + if not stack: + outfile.write(b'\n') + elif c == b'"': + string = string_re.match(data, i) + if not string: + raise ValueError(f"unclosed string beginning at byte {i+1}.") + outfile.write(string[0]) + i = string.end(0) + else: + sys.stderr.buffer.write(b'\npreceding text:\n') + sys.stderr.buffer.write(data[max(0,i-100):i]) + raise ValueError(f"unsupported @ {i} = {c}") -def preprocess(data: str) -> str: - """ - Preprocesses BAP ADT intrinsics (like Program, Subs, ...) into tuple syntax. - For example, Program(1, 2, 3) becomes ("Program", 1, 2, 3). - """ - heads_re = re.compile(f'({heads_joined})[(]') - - data = heads_re.sub(lambda x: '(' + repr(x[1]) + ', ', data) - return data - -class DoubleQuoteStr(str): - def __repr__(self): - # TODO: maybe make more robust? - r = super().__repr__() - r = r[1:-1] - r = r.replace(r"\'", r"'") - r = r.replace(r'"', r'\"') - return '"' + r + '"' - -class UnderscoreInt(int): - def __repr__(self): - return f'{self:_}' - -Exp = tuple | list | str | int -def clean(data: Exp) -> Exp: - """ - Intermediate step before formatting to tweak pprint's formatting. - This ensures we match BAP as close as possible with double-quoted strings - and underscores in Tid. - """ - if isinstance(data, tuple) and data[0] == 'Tid' and not isinstance(data[1], UnderscoreInt): - return clean((data[0], UnderscoreInt(data[1]), ) + data[2:]) - if isinstance(data, str): - return DoubleQuoteStr(data) - if isinstance(data, (list, tuple)): - return data.__class__(map(clean, data)) - return data - -def postprocess(data: str) -> str: - """ - Postprocesses the formatted Python expression to restore the BAP-style intrinsics. - """ - heads_re2 = re.compile(f'\\("({heads_joined})",(\\s|\\))') - - def replacement(x: re.Match) -> str: - head = x[1] - endc = x[2] - if endc not in ')\n': - endc = '' - return head + '(' + endc - - data = heads_re2.sub(replacement, data) - return data + if stack: + closers = ''.join(chr(x.closer[0]) for x in reversed(stack)) + log.warning(f"unclosed brackets. expected: '{closers}'. malformed adt?") @contextlib.contextmanager def measure_time(context: str): - log.info(f'starting {context}') + log.info(f'starting {context}', stacklevel=3) start = time.perf_counter() yield lambda: time.perf_counter() - start - log.debug(f'... done in {time.perf_counter() - start:.3f} seconds') + log.debug(f'... done in {time.perf_counter() - start:.3f} seconds', stacklevel=3) def main(args): infile = args.input outfile = args.output - width = args.width update = args.update + spaces = args.spaces with measure_time('read'): data = infile.read() - log.debug(f' read {len(data):,} characters') + infile.close() + log.debug(f' read {len(data):,} bytes') - out = data - - with measure_time('preprocess'): - out = preprocess(out) - - with measure_time('parse'): - out = ast.literal_eval(out) - - with measure_time('clean'): - out = clean(out) - with measure_time('pprint'): - out = pprint.pformat(out, indent=width, underscore_numbers=False) - - with measure_time('postprocess'): - out = postprocess(out) - - with measure_time('output'): - if update: - infile.close() - with open(infile.name, 'w') as outfile: - outfile.write(out) - outfile.write('\n') - else: - outfile.write(out) - outfile.write('\n') - outfile.flush() + if len(data) > 5000000: + log.warning(f'large input of {len(data):,} bytes. formatting may be slow.') + + outbuf = None + if update: + outfile = outbuf = io.BytesIO() + + with measure_time('pretty + write' if not update else 'pretty'): + pretty(outfile, data, spaces) + + if update: + with measure_time('write'), open(infile.name, 'wb') as outfile: + assert outbuf + outfile.write(outbuf.getbuffer()) if __name__ == '__main__': - logging.basicConfig(format='[%(asctime)s:%(name)s@%(filename)s:%(levelname)-7s]\t%(message)s') + logging.basicConfig(format='[%(asctime)s %(module)s:%(lineno)-3d %(levelname)-7s] %(message)s') argp = argparse.ArgumentParser(description="pretty formats BAP ADT files.") - argp.add_argument('input', nargs='?', type=argparse.FileType('r'), default=sys.stdin, + argp.add_argument('input', nargs='?', type=argparse.FileType('rb'), default=sys.stdin.buffer, help="input .adt file (default: stdin)") excl = argp.add_mutually_exclusive_group() - excl.add_argument('output', nargs='?', type=argparse.FileType('w'), default=sys.stdout, + excl.add_argument('output', nargs='?', type=argparse.FileType('wb'), default=sys.stdout.buffer, help="output file name (default: stdout)") - argp.add_argument('--width', '-w', default=1, type=int, + argp.add_argument('--spaces', '-s', default=1, type=int, help="indent size in spaces (default: 1)") excl.add_argument('--update', '-i', action='store_true',