Skip to content

Commit

Permalink
Merge pull request #165 from koenvo/feature/accept-pathlib
Browse files Browse the repository at this point in the history
Refactor open_as_file to also accept Path inputs
  • Loading branch information
koenvo authored Jan 16, 2023
2 parents 099d5c4 + 7435cfc commit 1be58c1
Show file tree
Hide file tree
Showing 5 changed files with 84 additions and 81 deletions.
1 change: 1 addition & 0 deletions kloppy/domain/models/common.py
Original file line number Diff line number Diff line change
Expand Up @@ -742,6 +742,7 @@ class Metadata:
flags: DatasetFlag
provider: Provider
coordinate_system: CoordinateSystem
attributes: Optional[Dict] = field(default_factory=dict, compare=False)


class DatasetType(Enum):
Expand Down
90 changes: 38 additions & 52 deletions kloppy/infra/io/adapters/http.py
Original file line number Diff line number Diff line change
@@ -1,69 +1,55 @@
from typing import BinaryIO
import base64

from kloppy.config import get_config
from kloppy.exceptions import AdapterError, InputNotFoundError
from .adapter import Adapter

try:
from js import XMLHttpRequest

class HTTPAdapter(Adapter):
def supports(self, url: str) -> bool:
return url.startswith("http://") or url.startswith("https://")
RUNS_IN_BROWSER = True
except ImportError:
RUNS_IN_BROWSER = False

def read_to_stream(self, url: str, output: BinaryIO):
basic_authentication = get_config("adapters.http.basic_authentication")

def check_requests_patch():
if RUNS_IN_BROWSER:
try:
from js import XMLHttpRequest

_RUNS_IN_BROWSER = True
import pyodide_http
except ImportError:
try:
import requests
except ImportError:
raise AdapterError(
"Seems like you don't have requests installed. Please"
" install it using: pip install requests"
)

_RUNS_IN_BROWSER = False

if _RUNS_IN_BROWSER:
xhr = XMLHttpRequest.new()
xhr.responseType = "arraybuffer"
if basic_authentication:
authentication = base64.b64encode(
basic_authentication.join(":")
)
xhr.setRequestHeader(
"Authorization",
f"Basic {authentication}",
)
raise AdapterError(
"Seems like you don't have `pyodide-http` installed, which is required to make http requests "
"work in the browser. Please install it using: pip install pyodide-http"
)

xhr.open("GET", url, False)
xhr.send(None)
pyodide_http.patch_all()

# Borrowed from 'raise_for_status'
http_error_msg = ""
if 400 <= xhr.status < 500:
http_error_msg = f"{xhr.status} Client Error: url: {url}"

elif 500 <= xhr.status < 600:
http_error_msg = f"{xhr.status} Server Error: url: {url}"

if http_error_msg:
raise AdapterError(http_error_msg)
class HTTPAdapter(Adapter):
def supports(self, url: str) -> bool:
return url.startswith("http://") or url.startswith("https://")

output.write(xhr.response.to_py().tobytes())
else:
auth = None
if basic_authentication:
auth = requests.auth.HTTPBasicAuth(*basic_authentication)
def read_to_stream(self, url: str, output: BinaryIO):
check_requests_patch()

with requests.get(url, stream=True, auth=auth) as r:
if r.status_code == 404:
raise InputNotFoundError(f"Could not find {url}")
basic_authentication = get_config("adapters.http.basic_authentication")

r.raise_for_status()
for chunk in r.iter_content(chunk_size=8192):
output.write(chunk)
try:
import requests
except ImportError:
raise AdapterError(
"Seems like you don't have `requests` installed. Please"
" install it using: pip install requests"
)

auth = None
if basic_authentication:
auth = requests.auth.HTTPBasicAuth(*basic_authentication)

with requests.get(url, stream=True, auth=auth) as r:
if r.status_code == 404:
raise InputNotFoundError(f"Could not find {url}")

r.raise_for_status()
for chunk in r.iter_content(chunk_size=8192):
output.write(chunk)
12 changes: 7 additions & 5 deletions kloppy/infra/serializers/tracking/metrica_epts/models.py
Original file line number Diff line number Diff line change
@@ -1,5 +1,5 @@
from dataclasses import dataclass
from typing import List, Dict, Union, Set
from dataclasses import dataclass, field
from typing import List, Dict, Union

from kloppy.domain import Team, Player, Metadata

Expand Down Expand Up @@ -160,6 +160,8 @@ def to_regex(self, **kwargs) -> str:

@dataclass
class EPTSMetadata(Metadata):
player_channels: List[PlayerChannel]
data_format_specifications: List[DataFormatSpecification]
sensors: List[Sensor]
player_channels: List[PlayerChannel] = field(default_factory=list)
data_format_specifications: List[DataFormatSpecification] = field(
default_factory=list
)
sensors: List[Sensor] = field(default_factory=list)
51 changes: 27 additions & 24 deletions kloppy/io.py
Original file line number Diff line number Diff line change
Expand Up @@ -3,7 +3,8 @@
import os
import urllib.parse
from dataclasses import dataclass, replace
from typing import Union, IO, BinaryIO, Tuple, Optional
from pathlib import PurePath
from typing import Union, IO, BinaryIO, Tuple

from io import BytesIO

Expand All @@ -19,7 +20,7 @@

@dataclass(frozen=True)
class Source:
data: str
data: "FileLike"
optional: bool = False
skip_if_missing: bool = False

Expand All @@ -30,7 +31,7 @@ def create(cls, input_: "FileLike", **kwargs):
return Source(data=input_, **kwargs)


FileLike = Union[str, bytes, IO[bytes], Source]
FileLike = Union[str, PurePath, bytes, IO[bytes], Source]


def get_local_cache_stream(url: str, cache_dir: str) -> Tuple[BinaryIO, bool]:
Expand Down Expand Up @@ -58,49 +59,51 @@ def dummy_context_mgr():


def open_as_file(input_: FileLike) -> IO:
if isinstance(input_, str) or isinstance(input_, Source):
if isinstance(input_, str):
input_ = Source(input_)

if isinstance(input_, Source):
if input_.data is None and input_.optional:
# This saves us some additional code in every vendor specific code
return dummy_context_mgr()

if "{" in input_.data or "<" in input_.data:
return BytesIO(input_.data.encode("utf8"))
try:
return open_as_file(input_.data)
except InputNotFoundError:
if input_.skip_if_missing:
logging.info(f"Input {input_.data} not found. Skipping")
return dummy_context_mgr()
raise
elif isinstance(input_, str) or isinstance(input_, PurePath):
if isinstance(input_, PurePath):
input_ = str(input_)
is_path = True
else:
is_path = False

if not is_path and ("{" in input_ or "<" in input_):
return BytesIO(input_.encode("utf8"))
else:
adapter = get_adapter(input_.data)
adapter = get_adapter(input_)
if adapter:
cache_dir = get_config("cache")
if cache_dir:
stream, local_cache_file = get_local_cache_stream(
input_.data, cache_dir
input_, cache_dir
)
else:
stream = BytesIO()
local_cache_file = None

if not local_cache_file:
logger.info(f"Retrieving {input_.data}")
try:
adapter.read_to_stream(input_.data, stream)
except InputNotFoundError:
if input_.skip_if_missing:
logging.info(
f"Input {input_.data} not found. Skipping"
)
return dummy_context_mgr()
raise

logger.info(f"Retrieving {input_}")
adapter.read_to_stream(input_, stream)
logger.info("Retrieval complete")
else:
logger.info(f"Using local cached file {local_cache_file}")
stream.seek(0)
else:
if not os.path.exists(input_.data):
if not os.path.exists(input_):
raise InputNotFoundError(f"File {input_} does not exist")

stream = _open(input_.data, "rb")
stream = _open(input_, "rb")
return stream
elif isinstance(input_, bytes):
return BytesIO(input_)
Expand Down
11 changes: 11 additions & 0 deletions kloppy/tests/test_helpers.py
Original file line number Diff line number Diff line change
@@ -1,4 +1,5 @@
import os
from pathlib import Path

from pandas import DataFrame
from pandas.testing import assert_frame_equal
Expand All @@ -25,6 +26,7 @@
)

from kloppy import opta, tracab, statsbomb
from kloppy.io import open_as_file


class TestHelpers:
Expand Down Expand Up @@ -326,3 +328,12 @@ def test_to_pandas_additional_columns(self):
)

assert_frame_equal(data_frame, expected_data_frame, check_like=True)


class TestOpenAsFile:
    """Tests for `open_as_file` input handling."""

    def test_path(self):
        """Open a `pathlib.Path` and compare the bytes read to the on-disk size."""
        # Fixture lives next to this test module, under "files/".
        fixture = Path(__file__).parent / "files/tracab_meta.xml"

        with open_as_file(fixture) as stream:
            content = stream.read()

        assert len(content) == os.path.getsize(fixture)

0 comments on commit 1be58c1

Please sign in to comment.