From c31b93f027ccec3ba03eab0306f055b6f61f91df Mon Sep 17 00:00:00 2001 From: rina Date: Wed, 18 Oct 2023 20:07:54 +1000 Subject: [PATCH] fix crash with trailing whitespace, add more error checking. --- scripts/format_adt.py | 49 +++++++++++++++++++++++++++++-------------- 1 file changed, 33 insertions(+), 16 deletions(-) diff --git a/scripts/format_adt.py b/scripts/format_adt.py index 4d6028f57..3a9da6a74 100755 --- a/scripts/format_adt.py +++ b/scripts/format_adt.py @@ -10,6 +10,7 @@ which only supports parsing of a literal Python expression. """ +import io import re import ast import sys @@ -39,10 +40,10 @@ notspace_re = re.compile(rb'\S') head_re = re.compile(rb'[^\s(]+') num_re = re.compile(rb'[_0-9xa-fA-F]+') -string_re = re.compile(rb'"(?:[^"\\]|\\.)+"') +string_re = re.compile(rb'"(?:[^"\\]|\\.)*"') @dataclasses.dataclass -class Elem: +class Context: begin: int closer: bytes multiline: bool @@ -56,7 +57,7 @@ class Elem: def pretty(outfile, data: bytes, spaces: int): # stack of expression beginning parentheses and their start position - stack: list[Elem] = [] + stack: list[Context] = [] i = 0 depth = 0 head = b'' @@ -64,13 +65,14 @@ def pretty(outfile, data: bytes, spaces: int): indent = b' ' * spaces i0 = i - 1 - while i < len(data): + length = len(data) + while i < length: assert i0 != i i0 = i c = bytes((data[i],)) if c.isspace(): - while data[i] in b' \t\r\n': + while i < length and data[i] in b' \t\r\n': i += 1 elif c.isdigit(): m = num_re.match(data, i) @@ -101,18 +103,22 @@ def pretty(outfile, data: bytes, spaces: int): outfile.write(b'\n') outfile.write(indent * depth) - stack.append(Elem(i, flip[c], multiline)) + stack.append(Context(i, flip[c], multiline)) elif c in b')]': outfile.write(c) i += 1 s = stack.pop() - assert c == s.closer + if c != s.closer: + raise ValueError(f"mismatched bracket: {flip[s.closer]} at byte {s.begin} closed by {c} at {i}.") if s.multiline: depth -= 1 + if not stack: + outfile.write(b'\n') elif c == b'"': string = string_re.match(data, i) - assert string + if not string: + raise ValueError(f"unclosed string beginning at byte {i+1}.") outfile.write(string[0]) i = string.end(0) else: @@ -120,13 +126,17 @@ def pretty(outfile, data: bytes, spaces: int): sys.stderr.buffer.write(data[max(0,i-100):i]) raise ValueError(f"unsupported @ {i} = {c}") + if stack: + closers = ''.join(chr(x.closer[0]) for x in reversed(stack)) + log.warning(f"unclosed brackets. expected: '{closers}'. malformed adt?") + @contextlib.contextmanager def measure_time(context: str): - log.info(f'starting {context}') + log.info(f'starting {context}', stacklevel=3) start = time.perf_counter() yield lambda: time.perf_counter() - start - log.debug(f'... done in {time.perf_counter() - start:.3f} seconds') + log.debug(f'... done in {time.perf_counter() - start:.3f} seconds', stacklevel=3) def main(args): @@ -138,18 +148,25 @@ def main(args): with measure_time('read'): data = infile.read() infile.close() - log.debug(f' read {len(data):,} characters') + log.debug(f' read {len(data):,} bytes') + if len(data) > 5000000: + log.warning(f'large input of {len(data):,} bytes. formatting may be slow.') + + outbuf = None if update: - outfile = open(infile.name, 'wb') + outfile = outbuf = io.BytesIO() - with measure_time('pretty'): - out = pretty(outfile, data, spaces) + with measure_time('pretty + write' if not update else 'pretty'): + pretty(outfile, data, spaces) - outfile.close() + if update: + with measure_time('write'), open(infile.name, 'wb') as outfile: + assert outbuf + outfile.write(outbuf.getbuffer()) if __name__ == '__main__': - logging.basicConfig(format='[%(asctime)s:%(name)s@%(filename)s:%(levelname)-7s]\t%(message)s') + logging.basicConfig(format='[%(asctime)s %(module)s:%(lineno)-3d %(levelname)-7s] %(message)s') argp = argparse.ArgumentParser(description="pretty formats BAP ADT files.") argp.add_argument('input', nargs='?', type=argparse.FileType('rb'), default=sys.stdin.buffer,