Skip to content

Commit

Permalink
Merge pull request #16 from zytedata/skip-domainless
Browse files Browse the repository at this point in the history
Add a flag for skipping universal patterns.
  • Loading branch information
wRAR authored Apr 12, 2024
2 parents a99793e + b7107a8 commit c39f11b
Show file tree
Hide file tree
Showing 2 changed files with 44 additions and 4 deletions.
26 changes: 26 additions & 0 deletions tests/test_matcher.py
Original file line number Diff line number Diff line change
Expand Up @@ -74,6 +74,7 @@ def test_matcher_add_remove_get():
matcher.add_or_update(1, patterns)
assert matcher.match("http://example.com") == 1
assert matcher.get(1) is patterns
assert list(matcher.match_universal()) == []

patterns_3 = Patterns(["example.com/articles"])
matcher.add_or_update(3, patterns_3)
Expand All @@ -93,12 +94,14 @@ def test_matcher_add_remove_get():
assert matcher.match("http://example.com") == 2
assert matcher.match("http://example.com/products") == 1
assert matcher.get(2) is univ_patterns
assert list(matcher.match_universal()) == [2]

# Removing a universal pattern
matcher.remove(2)
assert matcher.match("http://example.com") is None
assert matcher.match("http://example.com/products") == 1
assert matcher.get(2) is None
assert list(matcher.match_universal()) == []

# Removing regular patterns
matcher.remove(3)
Expand Down Expand Up @@ -161,3 +164,26 @@ def test_match_all():
assert list(matcher.match_all("http://example.com/products")) == [1]
assert list(matcher.match_all("http://foo.example.com/products")) == [2, 1]
assert list(matcher.match_all("http://bar.example.com/products")) == [3, 4, 1]


def test_match_all_include_universal():
    """match_all() must skip universal rules when include_universal=False."""
    rules = {
        1: Patterns(include=["example.com"]),
        2: Patterns(include=[]),
        3: Patterns(include=["foo.example.com"]),
        4: Patterns(include=[""]),
    }
    matcher = URLMatcher()
    for identifier, patterns in rules.items():
        matcher.add_or_update(identifier, patterns)

    # (url, expected with universal rules, expected without them)
    cases = [
        ("http://example.com", [1, 4, 2], [1]),
        ("http://foo.example.com", [3, 1, 4, 2], [3, 1]),
        ("http://example.net", [4, 2], []),
    ]
    for url, with_universal, without_universal in cases:
        assert list(matcher.match_all(url)) == with_universal
        assert list(matcher.match_all(url, include_universal=False)) == without_universal


def test_match_universal():
    """match_universal() must yield only the domain-less rules, best first."""
    matcher = URLMatcher()
    for identifier, include in [
        (1, ["example.com"]),
        (2, []),
        (3, ["foo.example.com"]),
        (4, [""]),
    ]:
        matcher.add_or_update(identifier, Patterns(include=include))
    # Only rules 2 and 4 are universal; 4 sorts first.
    assert list(matcher.match_universal()) == [4, 2]
22 changes: 18 additions & 4 deletions url_matcher/matcher.py
Original file line number Diff line number Diff line change
Expand Up @@ -107,6 +107,7 @@ def __init__(self, data: Union[Mapping[Any, Patterns], Iterable[Tuple[Any, Patte
initialize the object from
"""
self.matchers_by_domain: Dict[str, List[PatternsMatcher]] = {}
self.matchers_universal: List[PatternsMatcher] = []
self.patterns: Dict[Any, Patterns] = {}

if data:
Expand Down Expand Up @@ -151,15 +152,21 @@ def remove(self, identifier: Any):
def get(self, identifier: Any) -> Optional[Patterns]:
    """Return the ``Patterns`` registered under *identifier*, or ``None``."""
    try:
        return self.patterns[identifier]
    except KeyError:
        return None

def match(self, url: str) -> Optional[Any]:
return next(self.match_all(url), None)
def match(self, url: str, *, include_universal: bool = True) -> Optional[Any]:
    """Return the identifier of the highest-priority rule matching *url*.

    Returns ``None`` when nothing matches. Universal (domain-less) rules
    are considered unless ``include_universal`` is False.
    """
    return next(self.match_all(url, include_universal=include_universal), None)

def match_all(self, url: str) -> Iterator[Any]:
def match_all(self, url: str, *, include_universal=True) -> Iterator[Any]:
    """Yield the identifiers of every rule whose patterns match *url*.

    Domain-specific rules come first, then (unless ``include_universal``
    is False) the universal, domain-less ones.
    """
    candidates: Iterable[PatternsMatcher] = self.matchers_by_domain.get(get_domain(url)) or []
    if include_universal:
        candidates = chain(candidates, self.matchers_universal)
    for candidate in candidates:
        if candidate.match(url):
            yield candidate.identifier

def match_universal(self) -> Iterator[Any]:
    """Yield the identifiers of all universal (domain-less) rules, in order."""
    for universal_matcher in self.matchers_universal:
        yield universal_matcher.identifier

def _sort_domain(self, domain: str):
"""
Sort all the rules within a domain so that the matching can be done in sequence:
Expand All @@ -179,6 +186,7 @@ def sort_key(matcher: PatternsMatcher) -> Tuple:
return (matcher.patterns.priority, sorted_includes, matcher.identifier)

self.matchers_by_domain[domain].sort(key=sort_key, reverse=True)
self.matchers_universal.sort(key=sort_key, reverse=True)

def _del_matcher(self, domain: str, identifier: Any):
matchers = self.matchers_by_domain[domain]
Expand All @@ -188,10 +196,16 @@ def _del_matcher(self, domain: str, identifier: Any):
break
if not matchers:
del self.matchers_by_domain[domain]
for idx in range(len(self.matchers_universal)):
if self.matchers_universal[idx].identifier == identifier:
del self.matchers_universal[idx]
break

def _add_matcher(self, domain: str, matcher: PatternsMatcher):
    """Register *matcher* under *domain* and restore the sorted order.

    An empty *domain* marks the matcher as universal (domain-less).
    """
    # FIXME: re-sorting after every append is wasteful; ordered insertion
    # (e.g. via the bisect module) would be more efficient. Left as-is
    # because insertion time is not critical.
    self.matchers_by_domain.setdefault(domain, []).append(matcher)
    if not domain:
        self.matchers_universal.append(matcher)
    self._sort_domain(domain)

0 comments on commit c39f11b

Please sign in to comment.