diff --git a/tests/test_matcher.py b/tests/test_matcher.py
index f0bbaec..56fdc7e 100644
--- a/tests/test_matcher.py
+++ b/tests/test_matcher.py
@@ -74,6 +74,7 @@ def test_matcher_add_remove_get():
     matcher.add_or_update(1, patterns)
     assert matcher.match("http://example.com") == 1
     assert matcher.get(1) is patterns
+    assert list(matcher.match_universal()) == []
 
     patterns_3 = Patterns(["example.com/articles"])
     matcher.add_or_update(3, patterns_3)
@@ -93,12 +94,14 @@ def test_matcher_add_remove_get():
     assert matcher.match("http://example.com") == 2
     assert matcher.match("http://example.com/products") == 1
     assert matcher.get(2) is univ_patterns
+    assert list(matcher.match_universal()) == [2]
 
     # Removing a universal pattern
     matcher.remove(2)
     assert matcher.match("http://example.com") is None
     assert matcher.match("http://example.com/products") == 1
     assert matcher.get(2) is None
+    assert list(matcher.match_universal()) == []
 
     # Removing regular patterns
     matcher.remove(3)
@@ -161,3 +164,26 @@ def test_match_all():
     assert list(matcher.match_all("http://example.com/products")) == [1]
     assert list(matcher.match_all("http://foo.example.com/products")) == [2, 1]
     assert list(matcher.match_all("http://bar.example.com/products")) == [3, 4, 1]
+
+
+def test_match_all_include_universal():
+    matcher = URLMatcher()
+    matcher.add_or_update(1, Patterns(include=["example.com"]))
+    matcher.add_or_update(2, Patterns(include=[]))
+    matcher.add_or_update(3, Patterns(include=["foo.example.com"]))
+    matcher.add_or_update(4, Patterns(include=[""]))
+    assert list(matcher.match_all("http://example.com")) == [1, 4, 2]
+    assert list(matcher.match_all("http://example.com", include_universal=False)) == [1]
+    assert list(matcher.match_all("http://foo.example.com")) == [3, 1, 4, 2]
+    assert list(matcher.match_all("http://foo.example.com", include_universal=False)) == [3, 1]
+    assert list(matcher.match_all("http://example.net")) == [4, 2]
+    assert list(matcher.match_all("http://example.net", include_universal=False)) == []
+
+
+def test_match_universal():
+    matcher = URLMatcher()
+    matcher.add_or_update(1, Patterns(include=["example.com"]))
+    matcher.add_or_update(2, Patterns(include=[]))
+    matcher.add_or_update(3, Patterns(include=["foo.example.com"]))
+    matcher.add_or_update(4, Patterns(include=[""]))
+    assert list(matcher.match_universal()) == [4, 2]
diff --git a/url_matcher/matcher.py b/url_matcher/matcher.py
index 99a3544..2c8c269 100644
--- a/url_matcher/matcher.py
+++ b/url_matcher/matcher.py
@@ -107,6 +107,7 @@ def __init__(self, data: Union[Mapping[Any, Patterns], Iterable[Tuple[Any, Patterns]]] = None):
                      initialize the object from
         """
         self.matchers_by_domain: Dict[str, List[PatternsMatcher]] = {}
+        self.matchers_universal: List[PatternsMatcher] = []
         self.patterns: Dict[Any, Patterns] = {}
 
         if data:
@@ -151,15 +152,21 @@ def remove(self, identifier: Any):
     def get(self, identifier: Any) -> Optional[Patterns]:
         return self.patterns.get(identifier)
 
-    def match(self, url: str) -> Optional[Any]:
-        return next(self.match_all(url), None)
+    def match(self, url: str, *, include_universal=True) -> Optional[Any]:
+        return next(self.match_all(url, include_universal=include_universal), None)
 
-    def match_all(self, url: str) -> Iterator[Any]:
+    def match_all(self, url: str, *, include_universal=True) -> Iterator[Any]:
         domain = get_domain(url)
-        for matcher in chain(self.matchers_by_domain.get(domain) or [], self.matchers_by_domain.get("") or []):
+        matchers: Iterable[PatternsMatcher] = self.matchers_by_domain.get(domain) or []
+        if include_universal:
+            matchers = chain(matchers, self.matchers_universal)
+        for matcher in matchers:
             if matcher.match(url):
                 yield matcher.identifier
 
+    def match_universal(self) -> Iterator[Any]:
+        return (m.identifier for m in self.matchers_universal)
+
     def _sort_domain(self, domain: str):
         """
         Sort all the rules within a domain so that the matching can be done in sequence:
@@ -179,6 +186,7 @@ def sort_key(matcher: PatternsMatcher) -> Tuple:
             return (matcher.patterns.priority, sorted_includes, matcher.identifier)
 
         self.matchers_by_domain[domain].sort(key=sort_key, reverse=True)
+        self.matchers_universal.sort(key=sort_key, reverse=True)
 
     def _del_matcher(self, domain: str, identifier: Any):
         matchers = self.matchers_by_domain[domain]
@@ -188,10 +196,16 @@ def _del_matcher(self, domain: str, identifier: Any):
         for idx in range(len(matchers)):
             if matchers[idx].identifier == identifier:
                 del matchers[idx]
                 break
         if not matchers:
             del self.matchers_by_domain[domain]
+        for idx in range(len(self.matchers_universal)):
+            if self.matchers_universal[idx].identifier == identifier:
+                del self.matchers_universal[idx]
+                break
 
     def _add_matcher(self, domain: str, matcher: PatternsMatcher):
         # FIXME: This can be made much more efficient if we insert the data directly in order instead of resorting.
         # The bisect module could be used for this purpose.
         # I'm leaving it for the future as insertion time is not critical.
         self.matchers_by_domain.setdefault(domain, []).append(matcher)
+        if domain == "":
+            self.matchers_universal.append(matcher)
         self._sort_domain(domain)
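
For reviewers, a short usage sketch of the behaviour this patch introduces, pieced together from the diff and the new tests. The top-level import path is assumed from the library's public API, and the identifiers and patterns below simply mirror test_match_all_include_universal; this is an illustration, not part of the patch.

# Usage sketch (assumptions noted above): "universal" patterns are Patterns
# whose include list is empty or contains only "", so they match any URL.
from url_matcher import Patterns, URLMatcher  # import path assumed from the library's public API

matcher = URLMatcher()
matcher.add_or_update(1, Patterns(include=["example.com"]))  # domain-scoped
matcher.add_or_update(2, Patterns(include=[]))               # universal
matcher.add_or_update(4, Patterns(include=[""]))             # universal

# match_all() still yields universal matchers after the domain-specific ones,
# unless they are excluded with the new include_universal=False flag.
assert list(matcher.match_all("http://example.com")) == [1, 4, 2]
assert list(matcher.match_all("http://example.com", include_universal=False)) == [1]

# match_universal() lists the universal matchers without needing a URL.
assert list(matcher.match_universal()) == [4, 2]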