Skip to content

Commit

Permalink
fix crash with trailing whitespace, add more error checking.
Browse files Browse the repository at this point in the history
  • Loading branch information
katrinafyi committed Oct 18, 2023
1 parent f6543dc commit c31b93f
Showing 1 changed file with 33 additions and 16 deletions.
49 changes: 33 additions & 16 deletions scripts/format_adt.py
Original file line number Diff line number Diff line change
Expand Up @@ -10,6 +10,7 @@
which only supports parsing of a literal Python expression.
"""

import io
import re
import ast
import sys
Expand Down Expand Up @@ -39,10 +40,10 @@
notspace_re = re.compile(rb'\S')
head_re = re.compile(rb'[^\s(]+')
num_re = re.compile(rb'[_0-9xa-fA-F]+')
string_re = re.compile(rb'"(?:[^"\\]|\\.)+"')
string_re = re.compile(rb'"(?:[^"\\]|\\.)*"')

@dataclasses.dataclass
class Elem:
class Context:
begin: int
closer: bytes
multiline: bool
Expand All @@ -56,21 +57,22 @@ class Elem:

def pretty(outfile, data: bytes, spaces: int):
# stack of expression beginning parentheses and their start position
stack: list[Elem] = []
stack: list[Context] = []
i = 0
depth = 0
head = b''

indent = b' ' * spaces

i0 = i - 1
while i < len(data):
length = len(data)
while i < length:
assert i0 != i
i0 = i

c = bytes((data[i],))
if c.isspace():
while data[i] in b' \t\r\n':
while i < length and data[i] in b' \t\r\n':
i += 1
elif c.isdigit():
m = num_re.match(data, i)
Expand Down Expand Up @@ -101,32 +103,40 @@ def pretty(outfile, data: bytes, spaces: int):
outfile.write(b'\n')
outfile.write(indent * depth)

stack.append(Elem(i, flip[c], multiline))
stack.append(Context(i, flip[c], multiline))
elif c in b')]':
outfile.write(c)
i += 1

s = stack.pop()
assert c == s.closer
if c != s.closer:
raise ValueError(f"mismatched bracket: {flip[s.closer]} at byte {s.begin} closed by {c} at {i}.")
if s.multiline:
depth -= 1
if not stack:
outfile.write(b'\n')
elif c == b'"':
string = string_re.match(data, i)
assert string
if not string:
raise ValueError(f"unclosed string beginning at byte {i+1}.")
outfile.write(string[0])
i = string.end(0)
else:
sys.stderr.buffer.write(b'\npreceding text:\n')
sys.stderr.buffer.write(data[max(0,i-100):i])
raise ValueError(f"unsupported @ {i} = {c}")

if stack:
closers = ''.join(chr(x.closer[0]) for x in reversed(stack))
log.warning(f"unclosed brackets. expected: '{closers}'. malformed adt?")


@contextlib.contextmanager
def measure_time(context: str):
log.info(f'starting {context}')
log.info(f'starting {context}', stacklevel=3)
start = time.perf_counter()
yield lambda: time.perf_counter() - start
log.debug(f'... done in {time.perf_counter() - start:.3f} seconds')
log.debug(f'... done in {time.perf_counter() - start:.3f} seconds', stacklevel=3)


def main(args):
Expand All @@ -138,18 +148,25 @@ def main(args):
with measure_time('read'):
data = infile.read()
infile.close()
log.debug(f' read {len(data):,} characters')
log.debug(f' read {len(data):,} bytes')

if len(data) > 5000000:
log.warning(f'large input of {len(data):,} bytes. formatting may be slow.')

outbuf = None
if update:
outfile = open(infile.name, 'wb')
outfile = outbuf = io.BytesIO()

with measure_time('pretty'):
out = pretty(outfile, data, spaces)
with measure_time('pretty + write' if not update else 'pretty'):
pretty(outfile, data, spaces)

outfile.close()
if update:
with measure_time('write'), open(infile.name, 'wb') as outfile:
assert outbuf
outfile.write(outbuf.getbuffer())

if __name__ == '__main__':
logging.basicConfig(format='[%(asctime)s:%(name)s@%(filename)s:%(levelname)-7s]\t%(message)s')
logging.basicConfig(format='[%(asctime)s %(module)s:%(lineno)-3d %(levelname)-7s] %(message)s')

argp = argparse.ArgumentParser(description="pretty formats BAP ADT files.")
argp.add_argument('input', nargs='?', type=argparse.FileType('rb'), default=sys.stdin.buffer,
Expand Down

0 comments on commit c31b93f

Please sign in to comment.