Skip to content

Commit

Permalink
format_adt: handle memory bytes correctly.
Browse files: browse the repository at this point in the history
  • Loading branch information
katrinafyi committed Oct 17, 2023
1 parent 8c199a0 commit b82783d
Showing 1 changed file with 33 additions and 20 deletions.
53 changes: 33 additions & 20 deletions scripts/format_adt.py
Original file line number Diff line number Diff line change
Expand Up @@ -31,18 +31,19 @@
'Var', 'XOR',
]
# All head names joined with '|' so they can be dropped into a regex as an
# alternation group.
heads_joined = '|'.join(heads)
# ASCII-encoded copy of the alternation, for building bytes regex patterns.
heads_joined_b = heads_joined.encode('ascii')

# logging.getLogger() without a name returns the root logger.
log = logging.getLogger()


def preprocess(data: bytes) -> bytes:
    """
    Preprocesses BAP ADT intrinsics (like Program, Subs, ...) into tuple syntax.
    For example, Program(1, 2, 3) becomes ("Program", 1, 2, 3).

    Works on bytes throughout: `data` is the raw ADT text as bytes and the
    result is bytes with every `Head(` occurrence rewritten to `("Head", `.
    The head name is emitted in double quotes.
    """
    # Match any known head immediately followed by an opening parenthesis;
    # group 1 captures the head name itself.
    heads_re = re.compile(b'(' + heads_joined_b + b')[(]')

    return heads_re.sub(lambda x: b'("' + x[1] + b'", ', data)

class DoubleQuoteStr(str):
Expand All @@ -54,37 +55,45 @@ def __repr__(self):
r = r.replace(r'"', r'\"')
return '"' + r + '"'

class ByteStr(str):
    """
    String whose repr is a double-quoted sequence of \\xNN escapes, one per byte.

    Assumes each character up to U+00FF stands for exactly one byte, as is the
    case for text parsed from \\xNN escapes (TODO confirm against BAP output).
    Such strings are encoded as latin-1 so that every byte value 0x00-0xff
    maps back to a single escape; encoding them as UTF-8 would expand every
    byte >= 0x80 into two bytes and change the data on each formatting pass.
    Strings containing characters above U+00FF cannot be latin-1 encoded and
    fall back to their UTF-8 bytes.
    """

    def __repr__(self):
        try:
            raw = self.encode('latin-1')
        except UnicodeEncodeError:
            # Not representable as one byte per character; keep UTF-8 bytes.
            raw = self.encode('utf-8')
        return '"' + ''.join(f'\\x{b:02x}' for b in raw) + '"'

class UnderscoreInt(int):
    """
    Integer whose repr groups its decimal digits with underscores,
    for example 1234567 is shown as 1_234_567.
    """

    def __repr__(self):
        return format(self, '_')

Exp = tuple | list | str | int
Exp = tuple | list | str | bytes | int
def clean(data: Exp) -> Exp:
"""
Intermediate step before formatting to tweak pprint's formatting.
This ensures we match BAP as close as possible with double-quoted strings
and underscores in Tid.
"""
if isinstance(data, tuple) and data[0] == 'Tid' and not isinstance(data[1], UnderscoreInt):
return clean((data[0], UnderscoreInt(data[1]), ) + data[2:])
if isinstance(data, str):
if isinstance(data, tuple):
if data[0] == 'Tid' and not isinstance(data[1], UnderscoreInt):
return clean((data[0], UnderscoreInt(data[1]), ) + data[2:])
if data[0] == 'Section' and not isinstance(data[3], ByteStr):
return clean(data[:3] + (ByteStr(data[3]), ) + data[4:])
if isinstance(data, str) and type(data) is str:
return DoubleQuoteStr(data)
if isinstance(data, (list, tuple)):
return data.__class__(map(clean, data))
return type(data)(map(clean, data))
return data

def postprocess(data: bytes) -> bytes:
    """
    Postprocesses the formatted Python expression to restore the BAP-style intrinsics.

    Works on bytes: every `("Head",` that is followed by a whitespace
    character or a closing parenthesis is rewritten to `Head(`.
    """
    # Group 1 is the head name; group 2 is the single character after the comma.
    heads_re2 = re.compile(rb'\("(' + heads_joined_b + rb')",(\s|\))')

    def replacement(x: re.Match) -> bytes:
        head = x[1]
        endc = x[2]
        # Keep a closing parenthesis or a newline after the head; any other
        # whitespace character that followed the comma is dropped.
        if endc not in b')\n':
            endc = b''
        return head + b'(' + endc

    return heads_re2.sub(replacement, data)
Expand Down Expand Up @@ -114,35 +123,39 @@ def main(args):
out = preprocess(out)

with measure_time('parse'):
out = compile(out, infile.name, 'eval', ast.PyCF_ONLY_AST)

with measure_time('eval'):
out = ast.literal_eval(out)

with measure_time('clean'):
with measure_time('preprint'):
out = clean(out)
with measure_time('pprint'):
out = pprint.pformat(out, indent=width, underscore_numbers=False)
out = out.encode('ascii')

with measure_time('postprocess'):
out = postprocess(out)

with measure_time('output'):
if update:
infile.close()
with open(infile.name, 'w') as outfile:
with open(infile.name, 'wb') as outfile:
outfile.write(out)
outfile.write('\n')
outfile.write(b'\n')
else:
outfile.write(out)
outfile.write('\n')
outfile.write(b'\n')
outfile.flush()

if __name__ == '__main__':
logging.basicConfig(format='[%(asctime)s:%(name)s@%(filename)s:%(levelname)-7s]\t%(message)s')

argp = argparse.ArgumentParser(description="pretty formats BAP ADT files.")
argp.add_argument('input', nargs='?', type=argparse.FileType('r'), default=sys.stdin,
argp.add_argument('input', nargs='?', type=argparse.FileType('rb'), default=sys.stdin.buffer,
help="input .adt file (default: stdin)")
excl = argp.add_mutually_exclusive_group()
excl.add_argument('output', nargs='?', type=argparse.FileType('w'), default=sys.stdout,
excl.add_argument('output', nargs='?', type=argparse.FileType('wb'), default=sys.stdout.buffer,
help="output file name (default: stdout)")

argp.add_argument('--width', '-w', default=1, type=int,
Expand Down

0 comments on commit b82783d

Please sign in to comment.