Skip to content

Commit

Permalink
Merge pull request #109 from UQ-PAC/format-adt
Browse files Browse the repository at this point in the history
add .adt pretty-printer in format_adt.py
  • Loading branch information
l-kent authored Oct 17, 2023
2 parents a3dbcae + 8c33194 commit 8c199a0
Showing 1 changed file with 174 additions and 0 deletions.
174 changes: 174 additions & 0 deletions scripts/format_adt.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,174 @@
#!/usr/bin/env python3
# vim: noai:ts=2:sw=2:expandtab

"""
format_adt.py implements pretty-printing of BAP .adt files
by translating the ADT into Python syntax, then parsing and
formatting the python.
Although this eval()s, it is made safe by using ast.literal_eval
which only supports parsing of a literal Python expression.
"""

import re
import ast
import sys
import time
import pprint
import logging
import argparse
import contextlib

# grep -E "'([A-Z][a-zA-Z]+)'" src/main/antlr4/BAP_ADT.g4 --only-matching --no-filename | sort | uniq | xargs printf "'%s', " | fold -w80 -s
heads = [
'AND', 'Annotation', 'Arg', 'Args', 'ARSHIFT', 'Attr', 'Attrs', 'BigEndian',
'Blk', 'Blks', 'Both', 'Call', 'Concat', 'Def', 'Defs', 'Direct', 'DIVIDE',
'EQ', 'Extract', 'Goto', 'HIGH', 'Imm', 'In', 'Indirect', 'Int', 'Jmp', 'Jmps',
'LE', 'LittleEndian', 'Load', 'LOW', 'LSHIFT', 'LT', 'Mem', 'Memmap', 'MINUS',
'MOD', 'NEG', 'NEQ', 'NOT', 'OR', 'Out', 'Phi', 'Phis', 'PLUS', 'Program',
'Project', 'Region', 'RSHIFT', 'SDIVIDE', 'Section', 'Sections', 'SIGNED',
'SLE', 'SLT', 'SMOD', 'Store', 'Sub', 'Subs', 'Tid', 'TIMES', 'UNSIGNED',
'Var', 'XOR',
]
heads_joined = '|'.join(heads)

log = logging.getLogger()


def preprocess(data: str) -> str:
"""
Preprocesses BAP ADT intrinsics (like Program, Subs, ...) into tuple syntax.
For example, Program(1, 2, 3) becomes ("Program", 1, 2, 3).
"""
heads_re = re.compile(f'({heads_joined})[(]')

data = heads_re.sub(lambda x: '(' + repr(x[1]) + ', ', data)
return data

class DoubleQuoteStr(str):
def __repr__(self):
# TODO: maybe make more robust?
r = super().__repr__()
r = r[1:-1]
r = r.replace(r"\'", r"'")
r = r.replace(r'"', r'\"')
return '"' + r + '"'

class UnderscoreInt(int):
def __repr__(self):
return f'{self:_}'

Exp = tuple | list | str | int
def clean(data: Exp) -> Exp:
"""
Intermediate step before formatting to tweak pprint's formatting.
This ensures we match BAP as close as possible with double-quoted strings
and underscores in Tid.
"""
if isinstance(data, tuple) and data[0] == 'Tid' and not isinstance(data[1], UnderscoreInt):
return clean((data[0], UnderscoreInt(data[1]), ) + data[2:])
if isinstance(data, str):
return DoubleQuoteStr(data)
if isinstance(data, (list, tuple)):
return data.__class__(map(clean, data))
return data

def postprocess(data: str) -> str:
"""
Postprocesses the formatted Python expression to restore the BAP-style intrinsics.
"""
heads_re2 = re.compile(f'\\("({heads_joined})",(\\s|\\))')

def replacement(x: re.Match) -> str:
head = x[1]
endc = x[2]
if endc not in ')\n':
endc = ''
return head + '(' + endc

data = heads_re2.sub(replacement, data)
return data


@contextlib.contextmanager
def measure_time(context: str):
log.info(f'starting {context}')
start = time.perf_counter()
yield lambda: time.perf_counter() - start
log.debug(f'... done in {time.perf_counter() - start:.3f} seconds')


def main(args):
infile = args.input
outfile = args.output
width = args.width
update = args.update

with measure_time('read'):
data = infile.read()
log.debug(f' read {len(data):,} characters')

out = data

with measure_time('preprocess'):
out = preprocess(out)

with measure_time('parse'):
out = ast.literal_eval(out)

with measure_time('clean'):
out = clean(out)
with measure_time('pprint'):
out = pprint.pformat(out, indent=width, underscore_numbers=False)

with measure_time('postprocess'):
out = postprocess(out)

with measure_time('output'):
if update:
infile.close()
with open(infile.name, 'w') as outfile:
outfile.write(out)
outfile.write('\n')
else:
outfile.write(out)
outfile.write('\n')
outfile.flush()

if __name__ == '__main__':
logging.basicConfig(format='[%(asctime)s:%(name)s@%(filename)s:%(levelname)-7s]\t%(message)s')

argp = argparse.ArgumentParser(description="pretty formats BAP ADT files.")
argp.add_argument('input', nargs='?', type=argparse.FileType('r'), default=sys.stdin,
help="input .adt file (default: stdin)")
excl = argp.add_mutually_exclusive_group()
excl.add_argument('output', nargs='?', type=argparse.FileType('w'), default=sys.stdout,
help="output file name (default: stdout)")

argp.add_argument('--width', '-w', default=1, type=int,
help="indent size in spaces (default: 1)")

excl.add_argument('--update', '-i', action='store_true',
help="write output back to the input file (default: false)")

argp.add_argument('--verbose', '-v', action='count', default=0,
help="print logging output to stderr (default: no, repeatable)")

args = argp.parse_args()

if args.input is sys.stdin and args.update:
argp.error('argument --update/-i: not allowed with stdin input')

if args.verbose == 0:
level = logging.WARN
elif args.verbose == 1:
level = logging.INFO
else:
level = logging.DEBUG
log.setLevel(level)

log.debug(str(args))

with measure_time('format_adt.main'):
main(args)

0 comments on commit 8c199a0

Please sign in to comment.