Commit 79f7937

Format

john-kurkowski committed Oct 24, 2020
1 parent 638476f commit 79f7937

Showing 4 changed files with 75 additions and 65 deletions.
1 change: 0 additions & 1 deletion tests/test_cache.py
@@ -1,6 +1,5 @@
 """Test the caching functionality"""
 import pytest
-
 from tldextract.cache import DiskCache
 
 
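Since the hunk above shows only the imports, here is a hedged sketch of the kind of round-trip this test module exercises, using just the DiskCache API visible in the next file (a constructor taking cache_dir, plus get/set with a namespace and key). The test name is hypothetical; tmp_path is pytest's built-in temporary-directory fixture.

    from tldextract.cache import DiskCache


    def test_disk_cache_roundtrip(tmp_path):
        """A value set under a namespace and key should read back from disk."""
        cache = DiskCache(cache_dir=str(tmp_path))
        cache.set(namespace="testing", key="foo", value="bar")
        assert cache.get(namespace="testing", key="foo") == "bar"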
31 changes: 13 additions & 18 deletions tldextract/cache.py
@@ -11,9 +11,11 @@
 try:
     FileNotFoundError
 except NameError:
+
     class FileNotFoundError(Exception):
         pass
 
+
 LOG = logging.getLogger(__name__)
 
@@ -22,7 +24,7 @@ class DiskCache(object):
 
     def __init__(self, cache_dir, lock_timeout=20):
         self.enabled = bool(cache_dir)
-        self.cache_dir = os.path.expanduser(str(cache_dir) or '')
+        self.cache_dir = os.path.expanduser(str(cache_dir) or "")
         self.lock_timeout = lock_timeout
         # using a unique extension provides some safety that an incorrectly set cache_dir
         # combined with a call to `.clear()` wont wipe someones hard drive
@@ -40,11 +42,7 @@ def get(self, namespace, key):
             with open(cache_filepath) as cache_file:
                 return json.load(cache_file)
         except (OSError, ValueError) as exc:
-            LOG.error(
-                "error reading TLD cache file %s: %s",
-                cache_filepath,
-                exc
-            )
+            LOG.error("error reading TLD cache file %s: %s", cache_filepath, exc)
             raise KeyError("namespace: " + namespace + " key: " + repr(key))
 
     def set(self, namespace, key, value):
@@ -54,7 +52,7 @@ def set(self, namespace, key, value):
         cache_filepath = self._key_to_cachefile_path(namespace, key)
 
         try:
-            with open(cache_filepath, 'w') as cache_file:
+            with open(cache_filepath, "w") as cache_file:
                 json.dump(value, cache_file)
         except OSError as ioe:
             LOG.warning(
@@ -75,7 +73,9 @@ def clear(self):
         """Clear the disk cache"""
         for root, _, files in os.walk(self.cache_dir):
             for filename in files:
-                if filename.endswith(self.file_ext) or filename.endswith(self.file_ext + ".lock"):
+                if filename.endswith(self.file_ext) or filename.endswith(
+                    self.file_ext + ".lock"
+                ):
                     try:
                         os.unlink(os.path.join(root, filename))
                     except FileNotFoundError:
@@ -102,8 +102,7 @@ def run_and_cache(self, func, namespace, kwargs, hashed_argnames):
 
         key_args = {k: v for k, v in kwargs.items() if k in hashed_argnames}
         cache_filepath = self._key_to_cachefile_path(namespace, key_args)
-        lock_path = cache_filepath + '.lock'
-        # print(lock_path)
+        lock_path = cache_filepath + ".lock"
         with FileLock(lock_path, timeout=self.lock_timeout):
             try:
                 result = self.get(namespace=namespace, key=key_args)
@@ -118,12 +117,8 @@ def cached_fetch_url(self, session, url, timeout):
         return self.run_and_cache(
             func=_fetch_url,
             namespace="urls",
-            kwargs={
-                "session": session,
-                "url": url,
-                "timeout": timeout
-            },
-            hashed_argnames=["url"]
+            kwargs={"session": session, "url": url, "timeout": timeout},
+            hashed_argnames=["url"],
         )
 
 
@@ -134,7 +129,7 @@ def _fetch_url(session, url, timeout):
     text = response.text
 
     if not isinstance(text, str):
-        text = str(text, 'utf-8')
+        text = str(text, "utf-8")
 
     return text
 
@@ -144,7 +139,7 @@ def _make_cache_key(inputs):
     try:
         key = md5(key).hexdigest()
     except TypeError:
-        key = md5(key.encode('utf8')).hexdigest()
+        key = md5(key.encode("utf8")).hexdigest()
     return key
 
 
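As context for the reformatted code above, here is a minimal usage sketch of DiskCache.run_and_cache, the method cached_fetch_url delegates to. It relies only on names shown in this diff; the cache directory and the stand-in function are hypothetical.

    from tldextract.cache import DiskCache

    cache = DiskCache(cache_dir="/tmp/tldextract-demo")  # hypothetical path


    def slow_lookup(session, url, timeout):
        # Stand-in for _fetch_url; imagine an expensive HTTP fetch.
        return "payload for " + url


    # The first call runs slow_lookup and writes the JSON-serializable result
    # to a cache file under a file lock; repeat calls with the same `url` are
    # served from disk. Per hashed_argnames, only "url" feeds the cache key.
    result = cache.run_and_cache(
        func=slow_lookup,
        namespace="urls",
        kwargs={"session": None, "url": "https://example.com", "timeout": 5},
        hashed_argnames=["url"],
    )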
105 changes: 60 additions & 45 deletions tldextract/tldextract.py
@@ -57,27 +57,25 @@
 import idna
 
 from .cache import DiskCache
-from .remote import IP_RE
-from .remote import SCHEME_RE
-from .remote import looks_like_ip
+from .remote import IP_RE, SCHEME_RE, looks_like_ip
 from .suffix_list import get_suffix_lists
 
 LOG = logging.getLogger("tldextract")
 
-CACHE_DIR_DEFAULT = os.path.join(os.path.dirname(__file__), '.suffix_cache/')
+CACHE_DIR_DEFAULT = os.path.join(os.path.dirname(__file__), ".suffix_cache/")
 CACHE_DIR = os.path.expanduser(os.environ.get("TLDEXTRACT_CACHE", CACHE_DIR_DEFAULT))
-CACHE_TIMEOUT = os.environ.get('TLDEXTRACT_CACHE_TIMEOUT')
+CACHE_TIMEOUT = os.environ.get("TLDEXTRACT_CACHE_TIMEOUT")
 
 PUBLIC_SUFFIX_LIST_URLS = (
-    'https://publicsuffix.org/list/public_suffix_list.dat',
-    'https://raw.githubusercontent.com/publicsuffix/list/master/public_suffix_list.dat',
+    "https://publicsuffix.org/list/public_suffix_list.dat",
+    "https://raw.githubusercontent.com/publicsuffix/list/master/public_suffix_list.dat",
 )
 
 CACHE = DiskCache(cache_dir=CACHE_DIR)
 
 
-class ExtractResult(collections.namedtuple('ExtractResult', 'subdomain domain suffix')):
-    '''namedtuple of a URL's subdomain, domain, and suffix.'''
+class ExtractResult(collections.namedtuple("ExtractResult", "subdomain domain suffix")):
+    """namedtuple of a URL's subdomain, domain, and suffix."""
 
     # Necessary for __dict__ member to get populated in Python 3+
     __slots__ = ()
@@ -93,8 +91,8 @@ def registered_domain(self):
         ''
         """
         if self.domain and self.suffix:
-            return self.domain + '.' + self.suffix
-        return ''
+            return self.domain + "." + self.suffix
+        return ""
 
     @property
     def fqdn(self):
@@ -108,8 +106,8 @@ def fqdn(self):
         """
         if self.domain and self.suffix:
             # self is the namedtuple (subdomain domain suffix)
-            return '.'.join(i for i in self if i)
-        return ''
+            return ".".join(i for i in self if i)
+        return ""
 
     @property
     def ipv4(self):
@@ -125,17 +123,23 @@ def ipv4(self):
         """
         if not (self.suffix or self.subdomain) and IP_RE.match(self.domain):
             return self.domain
-        return ''
+        return ""
 
 
 class TLDExtract(object):
-    '''A callable for extracting, subdomain, domain, and suffix components from
-    a URL.'''
+    """A callable for extracting, subdomain, domain, and suffix components from
+    a URL."""
 
     # TODO: Agreed with Pylint: too-many-arguments
-    def __init__(self, cache_dir=CACHE_DIR, suffix_list_urls=PUBLIC_SUFFIX_LIST_URLS,  # pylint: disable=too-many-arguments
-                 fallback_to_snapshot=True, include_psl_private_domains=False, extra_suffixes=(),
-                 cache_fetch_timeout=CACHE_TIMEOUT):
+    def __init__(
+        self,
+        cache_dir=CACHE_DIR,
+        suffix_list_urls=PUBLIC_SUFFIX_LIST_URLS,  # pylint: disable=too-many-arguments
+        fallback_to_snapshot=True,
+        include_psl_private_domains=False,
+        extra_suffixes=(),
+        cache_fetch_timeout=CACHE_TIMEOUT,
+    ):
         """
         Constructs a callable for extracting subdomain, domain, and suffix
         components from a URL.
@@ -178,13 +182,17 @@ def __init__(self, cache_dir=CACHE_DIR, suffix_list_urls=PUBLIC_SUFFIX_LIST_URLS
         and read timeouts
         """
         suffix_list_urls = suffix_list_urls or ()
-        self.suffix_list_urls = tuple(url.strip() for url in suffix_list_urls if url.strip())
+        self.suffix_list_urls = tuple(
+            url.strip() for url in suffix_list_urls if url.strip()
+        )
 
         self.fallback_to_snapshot = fallback_to_snapshot
         if not (self.suffix_list_urls or cache_dir or self.fallback_to_snapshot):
-            raise ValueError("The arguments you have provided disable all ways for tldextract "
-                             "to obtain data. Please provide a suffix list data, a cache_dir, "
-                             "or set `fallback_to_snapshot` to `True`.")
+            raise ValueError(
+                "The arguments you have provided disable all ways for tldextract "
+                "to obtain data. Please provide a suffix list data, a cache_dir, "
+                "or set `fallback_to_snapshot` to `True`."
+            )
 
         self.include_psl_private_domains = include_psl_private_domains
         self.extra_suffixes = extra_suffixes
@@ -207,28 +215,29 @@ def __call__(self, url, include_psl_private_domains=None):
         ExtractResult(subdomain='forums', domain='bbc', suffix='co.uk')
         """
 
-        netloc = SCHEME_RE.sub("", url) \
-            .partition("/")[0] \
-            .partition("?")[0] \
-            .partition("#")[0] \
-            .split("@")[-1] \
-            .partition(":")[0] \
-            .strip() \
+        netloc = (
+            SCHEME_RE.sub("", url)
+            .partition("/")[0]
+            .partition("?")[0]
+            .partition("#")[0]
+            .split("@")[-1]
+            .partition(":")[0]
+            .strip()
             .rstrip(".")
+        )
 
         labels = netloc.split(".")
 
         translations = [_decode_punycode(label) for label in labels]
         suffix_index = self._get_tld_extractor().suffix_index(
-            translations,
-            include_psl_private_domains=include_psl_private_domains
+            translations, include_psl_private_domains=include_psl_private_domains
         )
 
         suffix = ".".join(labels[suffix_index:])
         if not suffix and netloc and looks_like_ip(netloc):
-            return ExtractResult('', netloc, '')
+            return ExtractResult("", netloc, "")
 
-        subdomain = ".".join(labels[:suffix_index - 1]) if suffix_index else ""
+        subdomain = ".".join(labels[: suffix_index - 1]) if suffix_index else ""
         domain = labels[suffix_index - 1] if suffix_index else ""
         return ExtractResult(subdomain, domain, suffix)
 
@@ -248,14 +257,14 @@ def tlds(self):
         return list(self._get_tld_extractor().tlds())
 
     def _get_tld_extractor(self):
-        '''Get or compute this object's TLDExtractor. Looks up the TLDExtractor
+        """Get or compute this object's TLDExtractor. Looks up the TLDExtractor
         in roughly the following order, based on the settings passed to
         __init__:
 
         1. Memoized on `self`
         2. Local system _cache file
         3. Remote PSL, over HTTP
-        4. Bundled PSL snapshot file'''
+        4. Bundled PSL snapshot file"""
         # pylint: disable=no-else-return
 
         if self._extractor:
@@ -265,7 +274,7 @@ def _get_tld_extractor(self):
             cache=self._cache,
             urls=self.suffix_list_urls,
             cache_fetch_timeout=self.cache_fetch_timeout,
-            fallback_to_snapshot=self.fallback_to_snapshot
+            fallback_to_snapshot=self.fallback_to_snapshot,
         )
 
         if not any([public_tlds, private_tlds, self.extra_suffixes]):
@@ -275,7 +284,7 @@ def _get_tld_extractor(self):
             public_tlds=public_tlds,
             private_tlds=private_tlds,
             extra_tlds=list(self.extra_suffixes),
-            include_psl_private_domains=self.include_psl_private_domains
+            include_psl_private_domains=self.include_psl_private_domains,
         )
         return self._extractor
 
@@ -298,7 +307,9 @@ class _PublicSuffixListTLDExtractor:
     lookups.
     """
 
-    def __init__(self, public_tlds, private_tlds, extra_tlds, include_psl_private_domains=False):
+    def __init__(
+        self, public_tlds, private_tlds, extra_tlds, include_psl_private_domains=False
+    ):
         # set the default value
         self.include_psl_private_domains = include_psl_private_domains
         self.public_tlds = public_tlds
@@ -310,7 +321,11 @@ def tlds(self, include_psl_private_domains=None):
         if include_psl_private_domains is None:
             include_psl_private_domains = self.include_psl_private_domains
 
-        return self.tlds_incl_private if include_psl_private_domains else self.tlds_excl_private
+        return (
+            self.tlds_incl_private
+            if include_psl_private_domains
+            else self.tlds_excl_private
+        )
 
     def suffix_index(self, lower_spl, include_psl_private_domains=None):
         """Returns the index of the first suffix label.
@@ -319,15 +334,15 @@ def suffix_index(self, lower_spl, include_psl_private_domains=None):
         tlds = self.tlds(include_psl_private_domains)
         length = len(lower_spl)
         for i in range(length):
-            maybe_tld = '.'.join(lower_spl[i:])
-            exception_tld = '!' + maybe_tld
+            maybe_tld = ".".join(lower_spl[i:])
+            exception_tld = "!" + maybe_tld
             if exception_tld in tlds:
                 return i + 1
 
             if maybe_tld in tlds:
                 return i
 
-            wildcard_tld = '*.' + '.'.join(lower_spl[i + 1:])
+            wildcard_tld = "*." + ".".join(lower_spl[i + 1 :])
             if wildcard_tld in tlds:
                 return i
 
@@ -336,10 +351,10 @@
 
 def _decode_punycode(label):
     lowered = label.lower()
-    looks_like_puny = lowered.startswith('xn--')
+    looks_like_puny = lowered.startswith("xn--")
     if looks_like_puny:
         try:
-            return idna.decode(label.encode('ascii')).lower()
+            return idna.decode(label.encode("ascii")).lower()
         except (UnicodeError, IndexError):
             pass
     return lowered
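Tying the pieces above together: __call__ peels the scheme, path, query, fragment, userinfo, and port off the URL to isolate the netloc, then suffix_index walks the labels against the suffix list, honoring "!" exception and "*." wildcard entries. A short usage sketch follows; the expected outputs for the first URL come from the doctest visible in this diff, and the IP example assumes the ipv4 behavior shown above.

    import tldextract

    extract = tldextract.TLDExtract()  # default cache and suffix list sources

    result = extract("http://forums.bbc.co.uk/path?q=1#frag")
    print(result)                    # ExtractResult(subdomain='forums', domain='bbc', suffix='co.uk')
    print(result.registered_domain)  # bbc.co.uk
    print(result.fqdn)               # forums.bbc.co.uk

    # A bare IP yields no suffix; the address lands in `domain`.
    print(extract("http://127.0.0.1:8080/deployed/"))
    # ExtractResult(subdomain='', domain='127.0.0.1', suffix='')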
3 changes: 2 additions & 1 deletion tox.ini
@@ -18,5 +18,6 @@ deps = pycodestyle
 commands = pycodestyle tldextract tests {posargs}
 
 [pycodestyle]
+# E203 - whitespace before ':'; disagrees with PEP8 https://github.com/psf/black/issues/354#issuecomment-397684838
 # E501 - line too long
-ignore = E501
+ignore = E203,E501
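The new E203 ignore pairs with the slice formatting introduced above: pycodestyle's E203 flags whitespace before ':', but the formatter (Black, per the linked issue) deliberately writes a space before the colon when a slice bound is a compound expression, as in suffix_index. A minimal illustration:

    lower_spl = ["forums", "bbc", "co", "uk"]
    i = 1

    tail = lower_spl[i:]  # simple bound: no space, no E203
    wildcard_tail = lower_spl[i + 1 :]  # compound bound: the space pycodestyle calls E203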
