From 19d2aded7b2201f3b988987d92104591d3ee61c1 Mon Sep 17 00:00:00 2001 From: jesko Date: Thu, 28 Nov 2024 16:32:47 +0100 Subject: [PATCH] emit offsets from pattern matchers --- refinery/units/pattern/__init__.py | 34 +++++++++++++++--------------- 1 file changed, 17 insertions(+), 17 deletions(-) diff --git a/refinery/units/pattern/__init__.py b/refinery/units/pattern/__init__.py index f690522c2b..83d859e2a6 100644 --- a/refinery/units/pattern/__init__.py +++ b/refinery/units/pattern/__init__.py @@ -5,7 +5,7 @@ """ import re -from typing import Iterable, Optional, Callable, Union, ByteString, Dict +from typing import Iterable, Optional, Callable, Union, Tuple, ByteString, Dict from itertools import islice from hashlib import blake2b @@ -50,18 +50,21 @@ def matches(self, data: ByteString, pattern: Union[ByteString, re.Pattern]): if not isinstance(pattern, re.Pattern): pattern = re.compile(pattern) if self.args.ascii: - yield from pattern.finditer(data) + for match in pattern.finditer(data): + yield match.start(), match if self.args.utf16: for zm in re.finditer(BR'(.?)((?:.\0)+)', data, flags=re.DOTALL): a, b = zm.span(2) # Look one character further if there is evidence that this is UTF16-BE b += bool(zm[1] and data[a]) - yield from pattern.finditer(bytes(data[a:b:2])) + for match in pattern.finditer(bytes(data[a:b:2])): + start = a + match.start() * 2 + yield start, match - def _prefilter(self, matches: Iterable[re.Match]) -> Iterable[re.Match]: + def _prefilter(self, matches: Iterable[Tuple[int, re.Match[bytes]]]) -> Iterable[Tuple[int, re.Match[bytes]]]: barrier = set() taken = 0 - for match in matches: + for offset, match in matches: hit = memoryview(match[0]) if not hit or len(hit) != self.args.len or len(hit) < self.args.min or len(hit) > self.args.max: continue @@ -70,12 +73,12 @@ def _prefilter(self, matches: Iterable[re.Match]) -> Iterable[re.Match]: if uid in barrier: continue barrier.add(uid) - yield match + yield offset, match taken += 1 if not self.args.longest and taken >= self.args.take: break - def _postfilter(self, matches: Iterable[re.Match]) -> Iterable[re.Match]: + def _postfilter(self, matches: Iterable[Tuple[int, re.Match[bytes]]]) -> Iterable[Tuple[int, re.Match[bytes]]]: result = matches if self.args.longest and self.args.take and self.args.take is not INF: try: @@ -83,15 +86,15 @@ def _postfilter(self, matches: Iterable[re.Match]) -> Iterable[re.Match]: except TypeError: result = list(result) length = len(result) - indices = sorted(range(length), key=lambda k: len(result[k][0]), reverse=True) + indices = sorted(range(length), key=lambda k: len(result[k][1][0]), reverse=True) for k in sorted(islice(indices, abs(self.args.take))): yield result[k] elif self.args.longest: - yield from sorted(result, key=lambda m: m.end() - m.start(), reverse=True) + yield from sorted(result, key=lambda m: m[1].end() - m[1].start(), reverse=True) elif self.args.take: yield from islice(result, abs(self.args.take)) - def matchfilter(self, matches: Iterable[re.Match]) -> Iterable[re.Match]: + def matchfilter(self, matches: Iterable[Tuple[int, re.Match[bytes]]]) -> Iterable[Tuple[int, re.Match[bytes]]]: yield from self._postfilter(self._prefilter(matches)) def matches_filtered( @@ -106,20 +109,17 @@ def matches_filtered( dictionary mapping its position (start, end) in the input data to the filtered and transformed match that was found at this position. """ - def funcify(t): - def const(m): return t - return t if callable(t) else const - - transforms = [funcify(f) for f in transforms] or [lambda m: m[0]] + transforms = [(f if callable(f) else lambda _: f) for f in transforms] + transforms = transforms or [lambda m: m[0]] if self.args.stripspace: data = re.sub(BR'\s+', B'', data) - for k, match in enumerate(self.matchfilter(self.matches(memoryview(data), pattern))): + for k, (offset, match) in enumerate(self.matchfilter(self.matches(memoryview(data), pattern))): for transform in transforms: t = transform(match) if t is None: continue - t = self.labelled(t) + t = self.labelled(t, offset=offset) t.set_next_batch(k) yield t