Skip to content

Commit

Permalink
fixes type hints
Browse files Browse the repository at this point in the history
  • Loading branch information
huettenhain committed Nov 28, 2024
1 parent 19d2ade commit f833931
Showing 1 changed file with 9 additions and 4 deletions.
13 changes: 9 additions & 4 deletions refinery/units/pattern/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -3,16 +3,21 @@
"""
Pattern matching based extraction and substitution units.
"""
from __future__ import annotations

import re

from typing import Iterable, Optional, Callable, Union, Tuple, ByteString, Dict
from typing import Iterable, Optional, Callable, Union, Tuple, ByteString, Dict, TYPE_CHECKING
from itertools import islice
from hashlib import blake2b

from refinery.lib.types import INF, AST, BufferOrStr
from refinery.lib.argformats import regexp
from refinery.units import Arg, Unit

if TYPE_CHECKING:
MT = Tuple[int, re.Match[bytes]]


class PatternExtractorBase(Unit, abstract=True):

Expand Down Expand Up @@ -61,7 +66,7 @@ def matches(self, data: ByteString, pattern: Union[ByteString, re.Pattern]):
start = a + match.start() * 2
yield start, match

def _prefilter(self, matches: Iterable[Tuple[int, re.Match[bytes]]]) -> Iterable[Tuple[int, re.Match[bytes]]]:
def _prefilter(self, matches: Iterable[MT]) -> Iterable[MT]:
barrier = set()
taken = 0
for offset, match in matches:
Expand All @@ -78,7 +83,7 @@ def _prefilter(self, matches: Iterable[Tuple[int, re.Match[bytes]]]) -> Iterable
if not self.args.longest and taken >= self.args.take:
break

def _postfilter(self, matches: Iterable[Tuple[int, re.Match[bytes]]]) -> Iterable[Tuple[int, re.Match[bytes]]]:
def _postfilter(self, matches: Iterable[MT]) -> Iterable[MT]:
result = matches
if self.args.longest and self.args.take and self.args.take is not INF:
try:
Expand All @@ -94,7 +99,7 @@ def _postfilter(self, matches: Iterable[Tuple[int, re.Match[bytes]]]) -> Iterabl
elif self.args.take:
yield from islice(result, abs(self.args.take))

def matchfilter(self, matches: Iterable[Tuple[int, re.Match[bytes]]]) -> Iterable[Tuple[int, re.Match[bytes]]]:
def matchfilter(self, matches: Iterable[MT]) -> Iterable[MT]:
yield from self._postfilter(self._prefilter(matches))

def matches_filtered(
Expand Down

0 comments on commit f833931

Please sign in to comment.