diff --git a/tests/test_matcher.py b/tests/test_matcher.py index 11f1f65..f0bbaec 100644 --- a/tests/test_matcher.py +++ b/tests/test_matcher.py @@ -116,6 +116,17 @@ def test_matcher_add_remove_get(): with pytest.raises(IncludePatternsWithoutDomainError): matcher.add_or_update(1, Patterns(["/no_domain_pattern"])) + # Patterns with the same domain shouldn't produce multiple matchers + patterns = Patterns(["example.com/products", "example.com/brands"]) + matcher.add_or_update(1, patterns) + assert len(matcher.matchers_by_domain) == 1 + assert len(matcher.matchers_by_domain["example.com"]) == 1 + assert len(matcher.patterns) == 1 + assert matcher.match("http://example.com") is None + assert matcher.match("http://example.com/products") == 1 + assert matcher.match("http://example.com/brands") == 1 + assert len(list(matcher.match_all("http://example.com/products"))) == 1 + def test_dedupe_unique_patterns(): diff --git a/url_matcher/matcher.py b/url_matcher/matcher.py index 5022f1d..99a3544 100644 --- a/url_matcher/matcher.py +++ b/url_matcher/matcher.py @@ -4,18 +4,7 @@ from dataclasses import dataclass, field from itertools import chain -from typing import ( - Any, - Dict, - Iterable, - Iterator, - List, - Mapping, - Optional, - Set, - Tuple, - Union, -) +from typing import Any, Dict, Iterable, Iterator, List, Mapping, Optional, Tuple, Union from url_matcher.patterns import PatternMatcher, get_pattern_domain, hierarchical_str from url_matcher.util import get_domain @@ -46,7 +35,8 @@ def __init__(self, include: List[str], exclude: Optional[List[str]] = None, prio def get_domains(self) -> List[str]: domains = [get_pattern_domain(pattern) for pattern in self.include] - return [domain for domain in domains if domain] + # remove duplicate domains preserving the order + return list(dict.fromkeys(domain for domain in domains if domain)) def get_includes_without_domain(self) -> List[str]: return [pattern for pattern in self.include if get_pattern_domain(pattern) is None] @@ -166,12 +156,9 @@ def match(self, url: str) -> Optional[Any]: def match_all(self, url: str) -> Iterator[Any]: domain = get_domain(url) - unique: Set[Any] = set() for matcher in chain(self.matchers_by_domain.get(domain) or [], self.matchers_by_domain.get("") or []): if matcher.match(url): - if matcher.identifier not in unique: - unique.add(matcher.identifier) - yield matcher.identifier + yield matcher.identifier def _sort_domain(self, domain: str): """