Skip to content

Commit

Permalink
emit offsets from pattern matchers
Browse files: browse the repository at this point in the history
  • Loading branch information
huettenhain committed Nov 28, 2024
1 parent 325bf84 commit 19d2ade
Showing 1 changed file with 17 additions and 17 deletions.
34 changes: 17 additions & 17 deletions refinery/units/pattern/__init__.py
Original file line number | Diff line number | Diff line change
Expand Up @@ -5,7 +5,7 @@
"""
import re

from typing import Iterable, Optional, Callable, Union, ByteString, Dict
from typing import Iterable, Optional, Callable, Union, Tuple, ByteString, Dict
from itertools import islice
from hashlib import blake2b

Expand Down Expand Up @@ -50,18 +50,21 @@ def matches(self, data: ByteString, pattern: Union[ByteString, re.Pattern]):
if not isinstance(pattern, re.Pattern):
pattern = re.compile(pattern)
if self.args.ascii:
yield from pattern.finditer(data)
for match in pattern.finditer(data):
yield match.start(), match
if self.args.utf16:
for zm in re.finditer(BR'(.?)((?:.\0)+)', data, flags=re.DOTALL):
a, b = zm.span(2)
# Look one character further if there is evidence that this is UTF16-BE
b += bool(zm[1] and data[a])
yield from pattern.finditer(bytes(data[a:b:2]))
for match in pattern.finditer(bytes(data[a:b:2])):
start = a + match.start() * 2
yield start, match

def _prefilter(self, matches: Iterable[re.Match]) -> Iterable[re.Match]:
def _prefilter(self, matches: Iterable[Tuple[int, re.Match[bytes]]]) -> Iterable[Tuple[int, re.Match[bytes]]]:
barrier = set()
taken = 0
for match in matches:
for offset, match in matches:
hit = memoryview(match[0])
if not hit or len(hit) != self.args.len or len(hit) < self.args.min or len(hit) > self.args.max:
continue
Expand All @@ -70,28 +73,28 @@ def _prefilter(self, matches: Iterable[re.Match]) -> Iterable[re.Match]:
if uid in barrier:
continue
barrier.add(uid)
yield match
yield offset, match
taken += 1
if not self.args.longest and taken >= self.args.take:
break

def _postfilter(self, matches: Iterable[re.Match]) -> Iterable[re.Match]:
def _postfilter(self, matches: Iterable[Tuple[int, re.Match[bytes]]]) -> Iterable[Tuple[int, re.Match[bytes]]]:
result = matches
if self.args.longest and self.args.take and self.args.take is not INF:
try:
length = len(result)
except TypeError:
result = list(result)
length = len(result)
indices = sorted(range(length), key=lambda k: len(result[k][0]), reverse=True)
indices = sorted(range(length), key=lambda k: len(result[k][1][0]), reverse=True)
for k in sorted(islice(indices, abs(self.args.take))):
yield result[k]
elif self.args.longest:
yield from sorted(result, key=lambda m: m.end() - m.start(), reverse=True)
yield from sorted(result, key=lambda m: m[1].end() - m[1].start(), reverse=True)
elif self.args.take:
yield from islice(result, abs(self.args.take))

def matchfilter(self, matches: Iterable[re.Match]) -> Iterable[re.Match]:
def matchfilter(self, matches: Iterable[Tuple[int, re.Match[bytes]]]) -> Iterable[Tuple[int, re.Match[bytes]]]:
yield from self._postfilter(self._prefilter(matches))

def matches_filtered(
Expand All @@ -106,20 +109,17 @@ def matches_filtered(
dictionary mapping its position (start, end) in the input data to the
filtered and transformed match that was found at this position.
"""
def funcify(t):
def const(m): return t
return t if callable(t) else const

transforms = [funcify(f) for f in transforms] or [lambda m: m[0]]
transforms = [(f if callable(f) else lambda _: f) for f in transforms]
transforms = transforms or [lambda m: m[0]]

if self.args.stripspace:
data = re.sub(BR'\s+', B'', data)
for k, match in enumerate(self.matchfilter(self.matches(memoryview(data), pattern))):
for k, (offset, match) in enumerate(self.matchfilter(self.matches(memoryview(data), pattern))):
for transform in transforms:
t = transform(match)
if t is None:
continue
t = self.labelled(t)
t = self.labelled(t, offset=offset)
t.set_next_batch(k)
yield t

Expand Down

0 comments on commit 19d2ade

Please sign in to comment.