diff --git a/scripts/format_adt.py b/scripts/format_adt.py index 638cf5df7..4d6028f57 100755 --- a/scripts/format_adt.py +++ b/scripts/format_adt.py @@ -1,5 +1,5 @@ #!/usr/bin/env python3 -# vim: noai:ts=2:sw=2:expandtab +# vim: ts=2:sw=2:expandtab:autoindent """ format_adt.py implements pretty-printing of BAP .adt files @@ -17,6 +17,7 @@ import pprint import logging import argparse +import dataclasses import contextlib # grep -E "'([A-Z][a-zA-Z]+)'" src/main/antlr4/BAP_ADT.g4 --only-matching --no-filename | sort | uniq | xargs printf "'%s', " | fold -w80 -s @@ -35,68 +36,89 @@ log = logging.getLogger() - -def preprocess(data: bytes) -> bytes: - """ - Preprocesses BAP ADT intrinsics (like Program, Subs, ...) into tuple syntax. - For example, Program(1, 2, 3) becomes ("Program", 1, 2, 3). - """ - heads_re = re.compile(b'(' + heads_joined_b + b')[(]') - - data = heads_re.sub(lambda x: b'("' + x[1] + b'", ', data) - return data - -class DoubleQuoteStr(str): - def __repr__(self): - # TODO: maybe make more robust? - r = super().__repr__() - r = r[1:-1] - r = r.replace(r"\'", r"'") - r = r.replace(r'"', r'\"') - return '"' + r + '"' - -class ByteStr(str): - def __repr__(self): - bytes = self.encode('utf-8') - return '"' + ''.join(f'\\x{b:02x}' for b in bytes) + '"' - -class UnderscoreInt(int): - def __repr__(self): - return f'{self:_}' - -Exp = tuple | list | str | bytes | int -def clean(data: Exp) -> Exp: - """ - Intermediate step before formatting to tweak pprint's formatting. - This ensures we match BAP as close as possible with double-quoted strings - and underscores in Tid. - """ - if isinstance(data, tuple): - if data[0] == 'Tid' and not isinstance(data[1], UnderscoreInt): - return clean((data[0], UnderscoreInt(data[1]), ) + data[2:]) - if data[0] == 'Section' and not isinstance(data[3], ByteStr): - return clean(data[:3] + (ByteStr(data[3]), ) + data[4:]) - if isinstance(data, str) and type(data) is str: - return DoubleQuoteStr(data) - if isinstance(data, (list, tuple)): - return type(data)(map(clean, data)) - return data - -def postprocess(data: bytes) -> bytes: - """ - Postprocesses the formatted Python expression to restore the BAP-style intrinsics. - """ - heads_re2 = re.compile(rb'\("(' + heads_joined_b + rb')",(\s|\))') - - def replacement(x: re.Match) -> bytes: - head = x[1] - endc = x[2] - if endc not in b')\n': - endc = b'' - return head + b'(' + endc - - data = heads_re2.sub(replacement, data) - return data +notspace_re = re.compile(rb'\S') +head_re = re.compile(rb'[^\s(]+') +num_re = re.compile(rb'[_0-9xa-fA-F]+') +string_re = re.compile(rb'"(?:[^"\\]|\\.)+"') + +@dataclasses.dataclass +class Elem: + begin: int + closer: bytes + multiline: bool + +flip = { + b'(': b')', + b')': b'(', + b'[': b']', + b']': b'[', +} + +def pretty(outfile, data: bytes, spaces: int): + # stack of expression beginning parentheses and their start position + stack: list[Elem] = [] + i = 0 + depth = 0 + head = b'' + + indent = b' ' * spaces + + i0 = i - 1 + while i < len(data): + assert i0 != i + i0 = i + + c = bytes((data[i],)) + if c.isspace(): + while data[i] in b' \t\r\n': + i += 1 + elif c.isdigit(): + m = num_re.match(data, i) + assert m + outfile.write(m[0]) + i = m.end(0) + elif c == b',': + outfile.write(b',') + i += 1 + if stack[-1].multiline: + outfile.write(b'\n') + outfile.write(indent * depth) + else: + outfile.write(b' ') + elif c.isupper(): + m = head_re.match(data, i) + assert m + i = m.end(0) + head = m[0] + outfile.write(head) + elif c in b'([': + outfile.write(c) + i += 1 + islist = c == b'[' and ']' != chr(data[i]) + multiline = islist or head in (b'Project', b'Def', b'Goto', b'Call', b'Sub', b'Blk', b'Arg') + if multiline: + depth += 1 + outfile.write(b'\n') + outfile.write(indent * depth) + + stack.append(Elem(i, flip[c], multiline)) + elif c in b')]': + outfile.write(c) + i += 1 + + s = stack.pop() + assert c == s.closer + if s.multiline: + depth -= 1 + elif c == b'"': + string = string_re.match(data, i) + assert string + outfile.write(string[0]) + i = string.end(0) + else: + sys.stderr.buffer.write(b'\npreceding text:\n') + sys.stderr.buffer.write(data[max(0,i-100):i]) + raise ValueError(f"unsupported @ {i} = {c}") @contextlib.contextmanager @@ -110,43 +132,21 @@ def measure_time(context: str): def main(args): infile = args.input outfile = args.output - width = args.width update = args.update + spaces = args.spaces with measure_time('read'): data = infile.read() + infile.close() log.debug(f' read {len(data):,} characters') - out = data - - with measure_time('preprocess'): - out = preprocess(out) - - with measure_time('parse'): - out = compile(out, infile.name, 'eval', ast.PyCF_ONLY_AST) - - with measure_time('eval'): - out = ast.literal_eval(out) - - with measure_time('preprint'): - out = clean(out) - with measure_time('pprint'): - out = pprint.pformat(out, indent=width, underscore_numbers=False) - out = out.encode('ascii') - - with measure_time('postprocess'): - out = postprocess(out) - - with measure_time('output'): - if update: - infile.close() - with open(infile.name, 'wb') as outfile: - outfile.write(out) - outfile.write(b'\n') - else: - outfile.write(out) - outfile.write(b'\n') - outfile.flush() + if update: + outfile = open(infile.name, 'wb') + + with measure_time('pretty'): + out = pretty(outfile, data, spaces) + + outfile.close() if __name__ == '__main__': logging.basicConfig(format='[%(asctime)s:%(name)s@%(filename)s:%(levelname)-7s]\t%(message)s') @@ -158,7 +158,7 @@ def main(args): excl.add_argument('output', nargs='?', type=argparse.FileType('wb'), default=sys.stdout.buffer, help="output file name (default: stdout)") - argp.add_argument('--width', '-w', default=1, type=int, + argp.add_argument('--spaces', '-s', default=1, type=int, help="indent size in spaces (default: 1)") excl.add_argument('--update', '-i', action='store_true',