From 0b2439b539414d70b8484aac596d37ccb1ed9cfe Mon Sep 17 00:00:00 2001 From: Pieter Robberechts Date: Tue, 17 Dec 2024 20:30:59 +0100 Subject: [PATCH] chore: remove CRLF line endings --- kloppy/_providers/tracab.py | 122 +-- .../serializers/tracking/tracab/helpers.py | 480 ++++----- .../serializers/tracking/tracab/tracab_dat.py | 452 ++++----- .../tracking/tracab/tracab_json.py | 432 ++++---- kloppy/io.py | 938 +++++++++--------- kloppy/tests/test_tracab.py | 796 +++++++-------- 6 files changed, 1610 insertions(+), 1610 deletions(-) diff --git a/kloppy/_providers/tracab.py b/kloppy/_providers/tracab.py index 7c21f603..5afdc392 100644 --- a/kloppy/_providers/tracab.py +++ b/kloppy/_providers/tracab.py @@ -1,61 +1,61 @@ -from typing import Optional, Union, Type - - -from kloppy.domain import TrackingDataset -from kloppy.infra.serializers.tracking.tracab.tracab_dat import ( - TRACABDatDeserializer, -) -from kloppy.infra.serializers.tracking.tracab.tracab_json import ( - TRACABJSONDeserializer, - TRACABInputs, -) -from kloppy.io import FileLike, open_as_file, get_file_extension - - -def load( - meta_data: FileLike, - raw_data: FileLike, - sample_rate: Optional[float] = None, - limit: Optional[int] = None, - coordinates: Optional[str] = None, - only_alive: Optional[bool] = True, - file_format: Optional[str] = None, -) -> TrackingDataset: - if file_format == "dat": - deserializer_class = TRACABDatDeserializer - elif file_format == "json": - deserializer_class = TRACABJSONDeserializer - else: - deserializer_class = identify_deserializer(raw_data) - - deserializer = deserializer_class( - sample_rate=sample_rate, - limit=limit, - coordinate_system=coordinates, - only_alive=only_alive, - meta_data_extension=get_file_extension(meta_data), - ) - with open_as_file(meta_data) as meta_data_fp, open_as_file( - raw_data - ) as raw_data_fp: - return deserializer.deserialize( - inputs=TRACABInputs(meta_data=meta_data_fp, raw_data=raw_data_fp) - ) - - -def identify_deserializer( - raw_data: FileLike, -) -> Union[Type[TRACABDatDeserializer], Type[TRACABJSONDeserializer]]: - - raw_data_extension = get_file_extension(raw_data) - - if raw_data_extension == ".dat": - deserializer = TRACABDatDeserializer - elif raw_data_extension == ".json": - deserializer = TRACABJSONDeserializer - else: - raise ValueError( - "Tracab file format could not be recognized, please specify" - ) - - return deserializer +from typing import Optional, Union, Type + + +from kloppy.domain import TrackingDataset +from kloppy.infra.serializers.tracking.tracab.tracab_dat import ( + TRACABDatDeserializer, +) +from kloppy.infra.serializers.tracking.tracab.tracab_json import ( + TRACABJSONDeserializer, + TRACABInputs, +) +from kloppy.io import FileLike, open_as_file, get_file_extension + + +def load( + meta_data: FileLike, + raw_data: FileLike, + sample_rate: Optional[float] = None, + limit: Optional[int] = None, + coordinates: Optional[str] = None, + only_alive: Optional[bool] = True, + file_format: Optional[str] = None, +) -> TrackingDataset: + if file_format == "dat": + deserializer_class = TRACABDatDeserializer + elif file_format == "json": + deserializer_class = TRACABJSONDeserializer + else: + deserializer_class = identify_deserializer(raw_data) + + deserializer = deserializer_class( + sample_rate=sample_rate, + limit=limit, + coordinate_system=coordinates, + only_alive=only_alive, + meta_data_extension=get_file_extension(meta_data), + ) + with open_as_file(meta_data) as meta_data_fp, open_as_file( + raw_data + ) as raw_data_fp: + return 
deserializer.deserialize( + inputs=TRACABInputs(meta_data=meta_data_fp, raw_data=raw_data_fp) + ) + + +def identify_deserializer( + raw_data: FileLike, +) -> Union[Type[TRACABDatDeserializer], Type[TRACABJSONDeserializer]]: + + raw_data_extension = get_file_extension(raw_data) + + if raw_data_extension == ".dat": + deserializer = TRACABDatDeserializer + elif raw_data_extension == ".json": + deserializer = TRACABJSONDeserializer + else: + raise ValueError( + "Tracab file format could not be recognized, please specify" + ) + + return deserializer diff --git a/kloppy/infra/serializers/tracking/tracab/helpers.py b/kloppy/infra/serializers/tracking/tracab/helpers.py index 7503b203..9e2ca919 100644 --- a/kloppy/infra/serializers/tracking/tracab/helpers.py +++ b/kloppy/infra/serializers/tracking/tracab/helpers.py @@ -1,240 +1,240 @@ -import logging -import warnings -import json -import html -from datetime import timedelta, timezone -from typing import Dict -from dateutil.parser import parse - -from lxml import objectify - -from kloppy.domain import ( - Team, - Period, - Orientation, - Ground, - Player, -) -from kloppy.domain.models import PositionType -from kloppy.exceptions import DeserializationError - -from kloppy.utils import performance_logging -from .common import position_types_mapping - - -logger = logging.getLogger(__name__) - - -def load_meta_data_json(meta_data): - meta_data = json.load(meta_data) - - def __create_team(team_data, ground): - team = Team( - team_id=str(team_data["TeamID"]), - name=html.unescape(team_data["ShortName"]), - ground=ground, - ) - - team.players = [ - Player( - player_id=str(player["PlayerID"]), - team=team, - first_name=html.unescape(player["FirstName"]), - last_name=html.unescape(player["LastName"]), - name=html.unescape( - player["FirstName"] + " " + player["LastName"] - ), - jersey_no=int(player["JerseyNo"]), - starting=True if player["StartingPosition"] != "S" else False, - starting_position=position_types_mapping.get( - player["StartingPosition"], PositionType.Unknown - ), - ) - for player in team_data["Players"] - ] - - return team - - with performance_logging("Loading metadata", logger=logger): - frame_rate = meta_data["FrameRate"] - pitch_size_width = meta_data["PitchShortSide"] / 100 - pitch_size_length = meta_data["PitchLongSide"] / 100 - - periods = [] - for period_id in (1, 2, 3, 4): - period_start_frame = meta_data[f"Phase{period_id}StartFrame"] - period_end_frame = meta_data[f"Phase{period_id}EndFrame"] - if period_start_frame != 0 or period_end_frame != 0: - periods.append( - Period( - id=period_id, - start_timestamp=timedelta( - seconds=period_start_frame / frame_rate - ), - end_timestamp=timedelta( - seconds=period_end_frame / frame_rate - ), - ) - ) - - home_team = __create_team(meta_data["HomeTeam"], Ground.HOME) - away_team = __create_team(meta_data["AwayTeam"], Ground.AWAY) - teams = [home_team, away_team] - - date = meta_data.get("Kickoff", None) - if date is not None: - date = parse(date).astimezone(timezone.utc) - game_id = meta_data.get("GameID", None) - - return ( - pitch_size_length, - pitch_size_width, - teams, - periods, - frame_rate, - date, - game_id, - ) - - -def load_meta_data_xml(meta_data): - def __create_team( - team_data, ground, start_frame_id, id_suffix="Id", player_item="Player" - ): - team = Team( - team_id=str(team_data[f"Team{id_suffix}"]), - name=html.unescape(team_data["ShortName"]), - ground=ground, - ) - - team.players = [ - Player( - player_id=str(player[f"Player{id_suffix}"]), - team=team, - 
first_name=html.unescape(player["FirstName"]), - last_name=html.unescape(player["LastName"]), - name=html.unescape( - player["FirstName"] + " " + player["LastName"] - ), - jersey_no=int(player["JerseyNo"]), - starting=player["StartFrameCount"] == start_frame_id, - starting_position=position_types_mapping.get( - player.get("StartingPosition"), PositionType.Unknown - ), - ) - for player in team_data["Players"][player_item] - ] - - return team - - with performance_logging("Loading metadata", logger=logger): - meta_data = objectify.fromstring(meta_data.read()) - - periods = [] - - if hasattr(meta_data, "match"): - id_suffix = "Id" - player_item = "Player" - - match = meta_data.match - frame_rate = int(match.attrib["iFrameRateFps"]) - pitch_size_width = float( - match.attrib["fPitchXSizeMeters"].replace(",", ".") - ) - pitch_size_height = float( - match.attrib["fPitchYSizeMeters"].replace(",", ".") - ) - date = parse(meta_data.match.attrib["dtDate"]).replace( - tzinfo=timezone.utc - ) - game_id = meta_data.match.attrib["iId"] - - for period in match.iterchildren(tag="period"): - start_frame_id = int(period.attrib["iStartFrame"]) - end_frame_id = int(period.attrib["iEndFrame"]) - if start_frame_id != 0 or end_frame_id != 0: - periods.append( - Period( - id=int(period.attrib["iId"]), - start_timestamp=timedelta( - seconds=start_frame_id / frame_rate - ), - end_timestamp=timedelta( - seconds=end_frame_id / frame_rate - ), - ) - ) - elif hasattr(meta_data, "Phase1StartFrame"): - date = parse(str(meta_data["Kickoff"])) - game_id = str(meta_data["GameID"]) - id_suffix = "ID" - player_item = "item" - - frame_rate = int(meta_data["FrameRate"]) - pitch_size_width = float(meta_data["PitchLongSide"]) / 100 - pitch_size_height = float(meta_data["PitchShortSide"]) / 100 - for i in [1, 2, 3, 4, 5]: - start_frame_id = int(meta_data[f"Phase{i}StartFrame"]) - end_frame_id = int(meta_data[f"Phase{i}EndFrame"]) - if start_frame_id != 0 or end_frame_id != 0: - periods.append( - Period( - id=i, - start_timestamp=timedelta( - seconds=start_frame_id / frame_rate - ), - end_timestamp=timedelta( - seconds=end_frame_id / frame_rate - ), - ) - ) - - orientation = ( - Orientation.HOME_AWAY - if bool(meta_data["Phase1HomeGKLeft"]) - else Orientation.AWAY_HOME - ) - else: - raise NotImplementedError( - """This 'meta_data' format is currently not supported...""" - ) - - if hasattr(meta_data, "HomeTeam") and hasattr(meta_data, "AwayTeam"): - home_team = __create_team( - meta_data["HomeTeam"], - Ground.HOME, - start_frame_id=start_frame_id, - id_suffix=id_suffix, - player_item=player_item, - ) - away_team = __create_team( - meta_data["AwayTeam"], - Ground.AWAY, - start_frame_id=start_frame_id, - id_suffix=id_suffix, - player_item=player_item, - ) - else: - home_team = Team(team_id="home", name="home", ground=Ground.HOME) - away_team = Team(team_id="away", name="away", ground=Ground.AWAY) - teams = [home_team, away_team] - return ( - pitch_size_height, - pitch_size_width, - teams, - periods, - frame_rate, - date, - game_id, - ) - - -def load_meta_data(meta_data_extension, meta_data): - if meta_data_extension == ".xml": - return load_meta_data_xml(meta_data) - elif meta_data_extension == ".json": - return load_meta_data_json(meta_data) - else: - raise ValueError( - "Tracab meta data file format could not be recognized, it should be either .xml or .json" - ) +import logging +import warnings +import json +import html +from datetime import timedelta, timezone +from typing import Dict +from dateutil.parser import parse + +from lxml 
import objectify + +from kloppy.domain import ( + Team, + Period, + Orientation, + Ground, + Player, +) +from kloppy.domain.models import PositionType +from kloppy.exceptions import DeserializationError + +from kloppy.utils import performance_logging +from .common import position_types_mapping + + +logger = logging.getLogger(__name__) + + +def load_meta_data_json(meta_data): + meta_data = json.load(meta_data) + + def __create_team(team_data, ground): + team = Team( + team_id=str(team_data["TeamID"]), + name=html.unescape(team_data["ShortName"]), + ground=ground, + ) + + team.players = [ + Player( + player_id=str(player["PlayerID"]), + team=team, + first_name=html.unescape(player["FirstName"]), + last_name=html.unescape(player["LastName"]), + name=html.unescape( + player["FirstName"] + " " + player["LastName"] + ), + jersey_no=int(player["JerseyNo"]), + starting=True if player["StartingPosition"] != "S" else False, + starting_position=position_types_mapping.get( + player["StartingPosition"], PositionType.Unknown + ), + ) + for player in team_data["Players"] + ] + + return team + + with performance_logging("Loading metadata", logger=logger): + frame_rate = meta_data["FrameRate"] + pitch_size_width = meta_data["PitchShortSide"] / 100 + pitch_size_length = meta_data["PitchLongSide"] / 100 + + periods = [] + for period_id in (1, 2, 3, 4): + period_start_frame = meta_data[f"Phase{period_id}StartFrame"] + period_end_frame = meta_data[f"Phase{period_id}EndFrame"] + if period_start_frame != 0 or period_end_frame != 0: + periods.append( + Period( + id=period_id, + start_timestamp=timedelta( + seconds=period_start_frame / frame_rate + ), + end_timestamp=timedelta( + seconds=period_end_frame / frame_rate + ), + ) + ) + + home_team = __create_team(meta_data["HomeTeam"], Ground.HOME) + away_team = __create_team(meta_data["AwayTeam"], Ground.AWAY) + teams = [home_team, away_team] + + date = meta_data.get("Kickoff", None) + if date is not None: + date = parse(date).astimezone(timezone.utc) + game_id = meta_data.get("GameID", None) + + return ( + pitch_size_length, + pitch_size_width, + teams, + periods, + frame_rate, + date, + game_id, + ) + + +def load_meta_data_xml(meta_data): + def __create_team( + team_data, ground, start_frame_id, id_suffix="Id", player_item="Player" + ): + team = Team( + team_id=str(team_data[f"Team{id_suffix}"]), + name=html.unescape(team_data["ShortName"]), + ground=ground, + ) + + team.players = [ + Player( + player_id=str(player[f"Player{id_suffix}"]), + team=team, + first_name=html.unescape(player["FirstName"]), + last_name=html.unescape(player["LastName"]), + name=html.unescape( + player["FirstName"] + " " + player["LastName"] + ), + jersey_no=int(player["JerseyNo"]), + starting=player["StartFrameCount"] == start_frame_id, + starting_position=position_types_mapping.get( + player.get("StartingPosition"), PositionType.Unknown + ), + ) + for player in team_data["Players"][player_item] + ] + + return team + + with performance_logging("Loading metadata", logger=logger): + meta_data = objectify.fromstring(meta_data.read()) + + periods = [] + + if hasattr(meta_data, "match"): + id_suffix = "Id" + player_item = "Player" + + match = meta_data.match + frame_rate = int(match.attrib["iFrameRateFps"]) + pitch_size_width = float( + match.attrib["fPitchXSizeMeters"].replace(",", ".") + ) + pitch_size_height = float( + match.attrib["fPitchYSizeMeters"].replace(",", ".") + ) + date = parse(meta_data.match.attrib["dtDate"]).replace( + tzinfo=timezone.utc + ) + game_id = 
meta_data.match.attrib["iId"] + + for period in match.iterchildren(tag="period"): + start_frame_id = int(period.attrib["iStartFrame"]) + end_frame_id = int(period.attrib["iEndFrame"]) + if start_frame_id != 0 or end_frame_id != 0: + periods.append( + Period( + id=int(period.attrib["iId"]), + start_timestamp=timedelta( + seconds=start_frame_id / frame_rate + ), + end_timestamp=timedelta( + seconds=end_frame_id / frame_rate + ), + ) + ) + elif hasattr(meta_data, "Phase1StartFrame"): + date = parse(str(meta_data["Kickoff"])) + game_id = str(meta_data["GameID"]) + id_suffix = "ID" + player_item = "item" + + frame_rate = int(meta_data["FrameRate"]) + pitch_size_width = float(meta_data["PitchLongSide"]) / 100 + pitch_size_height = float(meta_data["PitchShortSide"]) / 100 + for i in [1, 2, 3, 4, 5]: + start_frame_id = int(meta_data[f"Phase{i}StartFrame"]) + end_frame_id = int(meta_data[f"Phase{i}EndFrame"]) + if start_frame_id != 0 or end_frame_id != 0: + periods.append( + Period( + id=i, + start_timestamp=timedelta( + seconds=start_frame_id / frame_rate + ), + end_timestamp=timedelta( + seconds=end_frame_id / frame_rate + ), + ) + ) + + orientation = ( + Orientation.HOME_AWAY + if bool(meta_data["Phase1HomeGKLeft"]) + else Orientation.AWAY_HOME + ) + else: + raise NotImplementedError( + """This 'meta_data' format is currently not supported...""" + ) + + if hasattr(meta_data, "HomeTeam") and hasattr(meta_data, "AwayTeam"): + home_team = __create_team( + meta_data["HomeTeam"], + Ground.HOME, + start_frame_id=start_frame_id, + id_suffix=id_suffix, + player_item=player_item, + ) + away_team = __create_team( + meta_data["AwayTeam"], + Ground.AWAY, + start_frame_id=start_frame_id, + id_suffix=id_suffix, + player_item=player_item, + ) + else: + home_team = Team(team_id="home", name="home", ground=Ground.HOME) + away_team = Team(team_id="away", name="away", ground=Ground.AWAY) + teams = [home_team, away_team] + return ( + pitch_size_height, + pitch_size_width, + teams, + periods, + frame_rate, + date, + game_id, + ) + + +def load_meta_data(meta_data_extension, meta_data): + if meta_data_extension == ".xml": + return load_meta_data_xml(meta_data) + elif meta_data_extension == ".json": + return load_meta_data_json(meta_data) + else: + raise ValueError( + "Tracab meta data file format could not be recognized, it should be either .xml or .json" + ) diff --git a/kloppy/infra/serializers/tracking/tracab/tracab_dat.py b/kloppy/infra/serializers/tracking/tracab/tracab_dat.py index 2652861c..74548570 100644 --- a/kloppy/infra/serializers/tracking/tracab/tracab_dat.py +++ b/kloppy/infra/serializers/tracking/tracab/tracab_dat.py @@ -1,226 +1,226 @@ -import logging -from datetime import timedelta, timezone -import warnings -from typing import Dict, Optional, Union, Literal -import html -from dateutil.parser import parse - -from lxml import objectify - -from kloppy.domain import ( - TrackingDataset, - DatasetFlag, - AttackingDirection, - Frame, - Point, - Point3D, - Team, - BallState, - Period, - Orientation, - attacking_direction_from_frame, - Metadata, - Ground, - Player, - Provider, - PlayerData, - PositionType, -) -from kloppy.exceptions import DeserializationError - -from kloppy.utils import Readable, performance_logging - -from .common import TRACABInputs, position_types_mapping -from .helpers import load_meta_data -from ..deserializer import TrackingDataDeserializer - -logger = logging.getLogger(__name__) - - -class TRACABDatDeserializer(TrackingDataDeserializer[TRACABInputs]): - def __init__( - self, - 
limit: Optional[int] = None, - sample_rate: Optional[float] = None, - coordinate_system: Optional[Union[str, Provider]] = None, - only_alive: Optional[bool] = True, - meta_data_extension: Literal[".xml", ".json"] = None, - ): - super().__init__(limit, sample_rate, coordinate_system) - self.only_alive = only_alive - self.meta_data_extension = meta_data_extension - - @property - def provider(self) -> Provider: - return Provider.TRACAB - - @classmethod - def _frame_from_line(cls, teams, period, line, frame_rate): - line = str(line) - frame_id, players, ball = line.strip().split(":")[:3] - - players_data = {} - - for player_data in players.split(";")[:-1]: - team_id, target_id, jersey_no, x, y, speed = player_data.split(",") - team_id = int(team_id) - - if team_id == 1: - team = teams[0] - elif team_id == 0: - team = teams[1] - elif team_id in (-1, 3, 4): - continue - else: - raise DeserializationError( - f"Unknown Player Team ID: {team_id}" - ) - - player = team.get_player_by_jersey_number(jersey_no) - - if not player: - player = Player( - player_id=f"{team.ground}_{jersey_no}", - team=team, - jersey_no=int(jersey_no), - ) - team.players.append(player) - - players_data[player] = PlayerData( - coordinates=Point(float(x), float(y)), speed=float(speed) - ) - - ( - ball_x, - ball_y, - ball_z, - ball_speed, - ball_owning_team, - ball_state, - ) = ball.rstrip(";").split(",")[:6] - - frame_id = int(frame_id) - - if ball_owning_team == "H": - ball_owning_team = teams[0] - elif ball_owning_team == "A": - ball_owning_team = teams[1] - else: - raise DeserializationError( - f"Unknown ball owning team: {ball_owning_team}" - ) - - if ball_state == "Alive": - ball_state = BallState.ALIVE - elif ball_state == "Dead": - ball_state = BallState.DEAD - else: - raise DeserializationError(f"Unknown ball state: {ball_state}") - - return Frame( - frame_id=frame_id, - timestamp=timedelta(seconds=frame_id / frame_rate) - - period.start_timestamp, - ball_coordinates=Point3D( - float(ball_x), float(ball_y), float(ball_z) - ), - ball_state=ball_state, - ball_owning_team=ball_owning_team, - players_data=players_data, - period=period, - other_data={}, - ) - - @staticmethod - def __validate_inputs(inputs: Dict[str, Readable]): - if "metadata" not in inputs: - raise ValueError("Please specify a value for 'metadata'") - if "raw_data" not in inputs: - raise ValueError("Please specify a value for 'raw_data'") - - def deserialize(self, inputs: TRACABInputs) -> TrackingDataset: - ( - pitch_size_height, - pitch_size_width, - teams, - periods, - frame_rate, - date, - game_id, - ) = load_meta_data(self.meta_data_extension, inputs.meta_data) - - orientation = None - - with performance_logging("Loading data", logger=logger): - transformer = self.get_transformer( - pitch_length=pitch_size_width, pitch_width=pitch_size_height - ) - - def _iter(): - n = 0 - sample = 1.0 / self.sample_rate - - for line_ in inputs.raw_data.readlines(): - line_ = line_.strip().decode("ascii") - if not line_: - continue - - frame_id = int(line_[:10].split(":", 1)[0]) - if self.only_alive and not line_.endswith("Alive;:"): - continue - - for period_ in periods: - if ( - period_.start_timestamp - <= timedelta(seconds=frame_id / frame_rate) - <= period_.end_timestamp - ): - if n % sample == 0: - yield period_, line_ - n += 1 - - frames = [] - for n, (period, line) in enumerate(_iter()): - frame = self._frame_from_line(teams, period, line, frame_rate) - - frame = transformer.transform_frame(frame) - frames.append(frame) - - if self.limit and n >= self.limit: - 
break - - if not orientation: - try: - first_frame = next( - frame for frame in frames if frame.period.id == 1 - ) - orientation = ( - Orientation.HOME_AWAY - if attacking_direction_from_frame(first_frame) - == AttackingDirection.LTR - else Orientation.AWAY_HOME - ) - except StopIteration: - warnings.warn( - "Could not determine orientation of dataset, defaulting to NOT_SET" - ) - orientation = Orientation.NOT_SET - - metadata = Metadata( - teams=teams, - periods=periods, - pitch_dimensions=transformer.get_to_coordinate_system().pitch_dimensions, - score=None, - frame_rate=frame_rate, - orientation=orientation, - provider=Provider.TRACAB, - flags=DatasetFlag.BALL_OWNING_TEAM | DatasetFlag.BALL_STATE, - coordinate_system=transformer.get_to_coordinate_system(), - date=date, - game_id=game_id, - ) - - return TrackingDataset( - records=frames, - metadata=metadata, - ) +import logging +from datetime import timedelta, timezone +import warnings +from typing import Dict, Optional, Union, Literal +import html +from dateutil.parser import parse + +from lxml import objectify + +from kloppy.domain import ( + TrackingDataset, + DatasetFlag, + AttackingDirection, + Frame, + Point, + Point3D, + Team, + BallState, + Period, + Orientation, + attacking_direction_from_frame, + Metadata, + Ground, + Player, + Provider, + PlayerData, + PositionType, +) +from kloppy.exceptions import DeserializationError + +from kloppy.utils import Readable, performance_logging + +from .common import TRACABInputs, position_types_mapping +from .helpers import load_meta_data +from ..deserializer import TrackingDataDeserializer + +logger = logging.getLogger(__name__) + + +class TRACABDatDeserializer(TrackingDataDeserializer[TRACABInputs]): + def __init__( + self, + limit: Optional[int] = None, + sample_rate: Optional[float] = None, + coordinate_system: Optional[Union[str, Provider]] = None, + only_alive: Optional[bool] = True, + meta_data_extension: Literal[".xml", ".json"] = None, + ): + super().__init__(limit, sample_rate, coordinate_system) + self.only_alive = only_alive + self.meta_data_extension = meta_data_extension + + @property + def provider(self) -> Provider: + return Provider.TRACAB + + @classmethod + def _frame_from_line(cls, teams, period, line, frame_rate): + line = str(line) + frame_id, players, ball = line.strip().split(":")[:3] + + players_data = {} + + for player_data in players.split(";")[:-1]: + team_id, target_id, jersey_no, x, y, speed = player_data.split(",") + team_id = int(team_id) + + if team_id == 1: + team = teams[0] + elif team_id == 0: + team = teams[1] + elif team_id in (-1, 3, 4): + continue + else: + raise DeserializationError( + f"Unknown Player Team ID: {team_id}" + ) + + player = team.get_player_by_jersey_number(jersey_no) + + if not player: + player = Player( + player_id=f"{team.ground}_{jersey_no}", + team=team, + jersey_no=int(jersey_no), + ) + team.players.append(player) + + players_data[player] = PlayerData( + coordinates=Point(float(x), float(y)), speed=float(speed) + ) + + ( + ball_x, + ball_y, + ball_z, + ball_speed, + ball_owning_team, + ball_state, + ) = ball.rstrip(";").split(",")[:6] + + frame_id = int(frame_id) + + if ball_owning_team == "H": + ball_owning_team = teams[0] + elif ball_owning_team == "A": + ball_owning_team = teams[1] + else: + raise DeserializationError( + f"Unknown ball owning team: {ball_owning_team}" + ) + + if ball_state == "Alive": + ball_state = BallState.ALIVE + elif ball_state == "Dead": + ball_state = BallState.DEAD + else: + raise 
DeserializationError(f"Unknown ball state: {ball_state}") + + return Frame( + frame_id=frame_id, + timestamp=timedelta(seconds=frame_id / frame_rate) + - period.start_timestamp, + ball_coordinates=Point3D( + float(ball_x), float(ball_y), float(ball_z) + ), + ball_state=ball_state, + ball_owning_team=ball_owning_team, + players_data=players_data, + period=period, + other_data={}, + ) + + @staticmethod + def __validate_inputs(inputs: Dict[str, Readable]): + if "metadata" not in inputs: + raise ValueError("Please specify a value for 'metadata'") + if "raw_data" not in inputs: + raise ValueError("Please specify a value for 'raw_data'") + + def deserialize(self, inputs: TRACABInputs) -> TrackingDataset: + ( + pitch_size_height, + pitch_size_width, + teams, + periods, + frame_rate, + date, + game_id, + ) = load_meta_data(self.meta_data_extension, inputs.meta_data) + + orientation = None + + with performance_logging("Loading data", logger=logger): + transformer = self.get_transformer( + pitch_length=pitch_size_width, pitch_width=pitch_size_height + ) + + def _iter(): + n = 0 + sample = 1.0 / self.sample_rate + + for line_ in inputs.raw_data.readlines(): + line_ = line_.strip().decode("ascii") + if not line_: + continue + + frame_id = int(line_[:10].split(":", 1)[0]) + if self.only_alive and not line_.endswith("Alive;:"): + continue + + for period_ in periods: + if ( + period_.start_timestamp + <= timedelta(seconds=frame_id / frame_rate) + <= period_.end_timestamp + ): + if n % sample == 0: + yield period_, line_ + n += 1 + + frames = [] + for n, (period, line) in enumerate(_iter()): + frame = self._frame_from_line(teams, period, line, frame_rate) + + frame = transformer.transform_frame(frame) + frames.append(frame) + + if self.limit and n >= self.limit: + break + + if not orientation: + try: + first_frame = next( + frame for frame in frames if frame.period.id == 1 + ) + orientation = ( + Orientation.HOME_AWAY + if attacking_direction_from_frame(first_frame) + == AttackingDirection.LTR + else Orientation.AWAY_HOME + ) + except StopIteration: + warnings.warn( + "Could not determine orientation of dataset, defaulting to NOT_SET" + ) + orientation = Orientation.NOT_SET + + metadata = Metadata( + teams=teams, + periods=periods, + pitch_dimensions=transformer.get_to_coordinate_system().pitch_dimensions, + score=None, + frame_rate=frame_rate, + orientation=orientation, + provider=Provider.TRACAB, + flags=DatasetFlag.BALL_OWNING_TEAM | DatasetFlag.BALL_STATE, + coordinate_system=transformer.get_to_coordinate_system(), + date=date, + game_id=game_id, + ) + + return TrackingDataset( + records=frames, + metadata=metadata, + ) diff --git a/kloppy/infra/serializers/tracking/tracab/tracab_json.py b/kloppy/infra/serializers/tracking/tracab/tracab_json.py index 5a05d72e..6413330e 100644 --- a/kloppy/infra/serializers/tracking/tracab/tracab_json.py +++ b/kloppy/infra/serializers/tracking/tracab/tracab_json.py @@ -1,216 +1,216 @@ -import logging -import warnings -import json -import html -from datetime import timedelta -from typing import Dict, Optional, Union, Literal - -from kloppy.domain import ( - TrackingDataset, - DatasetFlag, - AttackingDirection, - Frame, - Point, - Point3D, - Team, - BallState, - Period, - Orientation, - Metadata, - Ground, - Player, - Provider, - PlayerData, - attacking_direction_from_frame, -) -from kloppy.domain.models import PositionType -from kloppy.exceptions import DeserializationError - -from kloppy.utils import Readable, performance_logging - -from .common import TRACABInputs, 
position_types_mapping -from .helpers import load_meta_data -from ..deserializer import TrackingDataDeserializer - -logger = logging.getLogger(__name__) - - -class TRACABJSONDeserializer(TrackingDataDeserializer[TRACABInputs]): - def __init__( - self, - limit: Optional[int] = None, - sample_rate: Optional[float] = None, - coordinate_system: Optional[Union[str, Provider]] = None, - only_alive: Optional[bool] = True, - meta_data_extension: Literal[".xml", ".json"] = None, - ): - super().__init__(limit, sample_rate, coordinate_system) - self.only_alive = only_alive - self.meta_data_extension = meta_data_extension - - @property - def provider(self) -> Provider: - return Provider.TRACAB - - @classmethod - def _create_frame(cls, teams, period, raw_frame, frame_rate): - frame_id = raw_frame["FrameCount"] - raw_players_data = raw_frame["PlayerPositions"] - raw_ball_position = raw_frame["BallPosition"][0] - - players_data = {} - for player_data in raw_players_data: - if player_data["Team"] == 1: - team = teams[0] - elif player_data["Team"] == 0: - team = teams[1] - elif player_data["Team"] in (-1, 3, 4): - continue - else: - raise DeserializationError( - f"Unknown Player Team ID: {player_data['Team']}" - ) - - jersey_no = player_data["JerseyNumber"] - x = player_data["X"] - y = player_data["Y"] - speed = player_data["Speed"] - - player = team.get_player_by_jersey_number(jersey_no) - if player: - players_data[player] = PlayerData( - coordinates=Point(x, y), speed=speed - ) - else: - # continue - raise DeserializationError( - f"Player not found for player jersey no {jersey_no} of team: {team.name}" - ) - - ball_x = raw_ball_position["X"] - ball_y = raw_ball_position["Y"] - ball_z = raw_ball_position["Z"] - ball_speed = raw_ball_position["Speed"] - if raw_ball_position["BallOwningTeam"] == "H": - ball_owning_team = teams[0] - elif raw_ball_position["BallOwningTeam"] == "A": - ball_owning_team = teams[1] - else: - raise DeserializationError( - f"Unknown ball owning team: {raw_ball_position['BallOwningTeam']}" - ) - if raw_ball_position["BallStatus"] == "Alive": - ball_state = BallState.ALIVE - elif raw_ball_position["BallStatus"] == "Dead": - ball_state = BallState.DEAD - else: - raise DeserializationError( - f"Unknown ball state: {raw_ball_position['BallStatus']}" - ) - - return Frame( - frame_id=frame_id, - timestamp=timedelta(seconds=frame_id / frame_rate) - - period.start_timestamp, - ball_coordinates=Point3D(ball_x, ball_y, ball_z), - ball_state=ball_state, - ball_owning_team=ball_owning_team, - ball_speed=ball_speed, - players_data=players_data, - period=period, - other_data={}, - ) - - @staticmethod - def __validate_inputs(inputs: Dict[str, Readable]): - if "metadata" not in inputs: - raise ValueError("Please specify a value for 'metadata'") - if "raw_data" not in inputs: - raise ValueError("Please specify a value for 'raw_data'") - - def deserialize(self, inputs: TRACABInputs) -> TrackingDataset: - ( - pitch_size_length, - pitch_size_width, - teams, - periods, - frame_rate, - date, - game_id, - ) = load_meta_data(self.meta_data_extension, inputs.meta_data) - raw_data = json.load(inputs.raw_data) - - transformer = self.get_transformer( - pitch_length=pitch_size_length, pitch_width=pitch_size_width - ) - - with performance_logging("Loading data", logger=logger): - raw_data = raw_data["FrameData"] - - def _iter(): - n = 0 - sample = 1.0 / self.sample_rate - - for frame in raw_data: - if ( - self.only_alive - and frame["BallPosition"][0]["BallStatus"] == "Dead" - ): - continue - - frame_id = 
frame["FrameCount"] - for _period in periods: - if ( - _period.start_timestamp - <= timedelta(seconds=frame_id / frame_rate) - <= _period.end_timestamp - ): - if n % sample == 0: - yield _period, frame - n += 1 - - frames = [] - for n, (_period, _frame) in enumerate(_iter()): - frame = self._create_frame(teams, _period, _frame, frame_rate) - - frame = transformer.transform_frame(frame) - - frames.append(frame) - - if self.limit and n >= self.limit: - break - - try: - first_frame = next( - frame for frame in frames if frame.period.id == 1 - ) - orientation = ( - Orientation.HOME_AWAY - if attacking_direction_from_frame(first_frame) - == AttackingDirection.LTR - else Orientation.AWAY_HOME - ) - except StopIteration: - warnings.warn( - "Could not determine orientation of dataset, defaulting to NOT_SET" - ) - orientation = Orientation.NOT_SET - - metadata = Metadata( - teams=teams, - periods=periods, - pitch_dimensions=transformer.get_to_coordinate_system().pitch_dimensions, - score=None, - frame_rate=frame_rate, - orientation=orientation, - provider=Provider.TRACAB, - flags=DatasetFlag.BALL_OWNING_TEAM | DatasetFlag.BALL_STATE, - coordinate_system=transformer.get_to_coordinate_system(), - date=date, - game_id=game_id, - ) - - return TrackingDataset( - records=frames, - metadata=metadata, - ) +import logging +import warnings +import json +import html +from datetime import timedelta +from typing import Dict, Optional, Union, Literal + +from kloppy.domain import ( + TrackingDataset, + DatasetFlag, + AttackingDirection, + Frame, + Point, + Point3D, + Team, + BallState, + Period, + Orientation, + Metadata, + Ground, + Player, + Provider, + PlayerData, + attacking_direction_from_frame, +) +from kloppy.domain.models import PositionType +from kloppy.exceptions import DeserializationError + +from kloppy.utils import Readable, performance_logging + +from .common import TRACABInputs, position_types_mapping +from .helpers import load_meta_data +from ..deserializer import TrackingDataDeserializer + +logger = logging.getLogger(__name__) + + +class TRACABJSONDeserializer(TrackingDataDeserializer[TRACABInputs]): + def __init__( + self, + limit: Optional[int] = None, + sample_rate: Optional[float] = None, + coordinate_system: Optional[Union[str, Provider]] = None, + only_alive: Optional[bool] = True, + meta_data_extension: Literal[".xml", ".json"] = None, + ): + super().__init__(limit, sample_rate, coordinate_system) + self.only_alive = only_alive + self.meta_data_extension = meta_data_extension + + @property + def provider(self) -> Provider: + return Provider.TRACAB + + @classmethod + def _create_frame(cls, teams, period, raw_frame, frame_rate): + frame_id = raw_frame["FrameCount"] + raw_players_data = raw_frame["PlayerPositions"] + raw_ball_position = raw_frame["BallPosition"][0] + + players_data = {} + for player_data in raw_players_data: + if player_data["Team"] == 1: + team = teams[0] + elif player_data["Team"] == 0: + team = teams[1] + elif player_data["Team"] in (-1, 3, 4): + continue + else: + raise DeserializationError( + f"Unknown Player Team ID: {player_data['Team']}" + ) + + jersey_no = player_data["JerseyNumber"] + x = player_data["X"] + y = player_data["Y"] + speed = player_data["Speed"] + + player = team.get_player_by_jersey_number(jersey_no) + if player: + players_data[player] = PlayerData( + coordinates=Point(x, y), speed=speed + ) + else: + # continue + raise DeserializationError( + f"Player not found for player jersey no {jersey_no} of team: {team.name}" + ) + + ball_x = 
raw_ball_position["X"] + ball_y = raw_ball_position["Y"] + ball_z = raw_ball_position["Z"] + ball_speed = raw_ball_position["Speed"] + if raw_ball_position["BallOwningTeam"] == "H": + ball_owning_team = teams[0] + elif raw_ball_position["BallOwningTeam"] == "A": + ball_owning_team = teams[1] + else: + raise DeserializationError( + f"Unknown ball owning team: {raw_ball_position['BallOwningTeam']}" + ) + if raw_ball_position["BallStatus"] == "Alive": + ball_state = BallState.ALIVE + elif raw_ball_position["BallStatus"] == "Dead": + ball_state = BallState.DEAD + else: + raise DeserializationError( + f"Unknown ball state: {raw_ball_position['BallStatus']}" + ) + + return Frame( + frame_id=frame_id, + timestamp=timedelta(seconds=frame_id / frame_rate) + - period.start_timestamp, + ball_coordinates=Point3D(ball_x, ball_y, ball_z), + ball_state=ball_state, + ball_owning_team=ball_owning_team, + ball_speed=ball_speed, + players_data=players_data, + period=period, + other_data={}, + ) + + @staticmethod + def __validate_inputs(inputs: Dict[str, Readable]): + if "metadata" not in inputs: + raise ValueError("Please specify a value for 'metadata'") + if "raw_data" not in inputs: + raise ValueError("Please specify a value for 'raw_data'") + + def deserialize(self, inputs: TRACABInputs) -> TrackingDataset: + ( + pitch_size_length, + pitch_size_width, + teams, + periods, + frame_rate, + date, + game_id, + ) = load_meta_data(self.meta_data_extension, inputs.meta_data) + raw_data = json.load(inputs.raw_data) + + transformer = self.get_transformer( + pitch_length=pitch_size_length, pitch_width=pitch_size_width + ) + + with performance_logging("Loading data", logger=logger): + raw_data = raw_data["FrameData"] + + def _iter(): + n = 0 + sample = 1.0 / self.sample_rate + + for frame in raw_data: + if ( + self.only_alive + and frame["BallPosition"][0]["BallStatus"] == "Dead" + ): + continue + + frame_id = frame["FrameCount"] + for _period in periods: + if ( + _period.start_timestamp + <= timedelta(seconds=frame_id / frame_rate) + <= _period.end_timestamp + ): + if n % sample == 0: + yield _period, frame + n += 1 + + frames = [] + for n, (_period, _frame) in enumerate(_iter()): + frame = self._create_frame(teams, _period, _frame, frame_rate) + + frame = transformer.transform_frame(frame) + + frames.append(frame) + + if self.limit and n >= self.limit: + break + + try: + first_frame = next( + frame for frame in frames if frame.period.id == 1 + ) + orientation = ( + Orientation.HOME_AWAY + if attacking_direction_from_frame(first_frame) + == AttackingDirection.LTR + else Orientation.AWAY_HOME + ) + except StopIteration: + warnings.warn( + "Could not determine orientation of dataset, defaulting to NOT_SET" + ) + orientation = Orientation.NOT_SET + + metadata = Metadata( + teams=teams, + periods=periods, + pitch_dimensions=transformer.get_to_coordinate_system().pitch_dimensions, + score=None, + frame_rate=frame_rate, + orientation=orientation, + provider=Provider.TRACAB, + flags=DatasetFlag.BALL_OWNING_TEAM | DatasetFlag.BALL_STATE, + coordinate_system=transformer.get_to_coordinate_system(), + date=date, + game_id=game_id, + ) + + return TrackingDataset( + records=frames, + metadata=metadata, + ) diff --git a/kloppy/io.py b/kloppy/io.py index 7834bddb..0726f99c 100644 --- a/kloppy/io.py +++ b/kloppy/io.py @@ -1,469 +1,469 @@ -"""I/O utilities for reading raw data.""" - -import bz2 -import contextlib -import gzip -import logging -import lzma -import os -import urllib.parse -from dataclasses import dataclass, replace 
-from io import BufferedWriter, BytesIO, TextIOWrapper -from typing import ( - IO, - BinaryIO, - ContextManager, - Generator, - Optional, - Tuple, - Union, -) - -from kloppy.config import get_config -from kloppy.exceptions import InputNotFoundError -from kloppy.infra.io.adapters import get_adapter - -logger = logging.getLogger(__name__) - -DEFAULT_GZIP_COMPRESSION = 1 -DEFAULT_BZ2_COMPRESSION = 9 -DEFAULT_XZ_COMPRESSION = 6 - - -FilePath = Union[str, bytes, os.PathLike] -FileOrPath = Union[FilePath, IO] - - -@dataclass(frozen=True) -class Source: - """A wrapper around a file-like object to enable optional inputs. - - Args: - data (FileLike): The file-like object. - optional (bool): Whether the file is optional. Defaults to False. - skip_if_missing (bool): Whether to skip the file if it is missing. Defaults to False. - - Example: - >>> open_as_file(Source.create("example.csv", optional=True)) - """ - - data: FileOrPath - optional: bool = False - skip_if_missing: bool = False - - @classmethod - def create(cls, input_: "FileLike", **kwargs): - if isinstance(input_, Source): - return replace(input_, **kwargs) - return Source(data=input_, **kwargs) - - -FileLike = Union[FileOrPath, Source] - - -def _file_or_path_to_binary_stream( - file_or_path: FileOrPath, binary_mode: str -) -> Tuple[BinaryIO, bool]: - """ - Converts a file path or a file-like object to a binary stream. - - Args: - file_or_path: The file path or file-like object to convert. - binary_mode: The binary mode to open the file in. Must be one of 'rb', 'wb', or 'ab'. - - Returns: - A tuple containing the binary stream and a boolean indicating whether - a new file was opened (True) or an existing file-like object was used (False). - """ - assert binary_mode in ("rb", "wb", "ab") - - if isinstance(file_or_path, (str, bytes)) or hasattr( - file_or_path, "__fspath__" - ): - # If file_or_path is a path-like object, open it and return the binary stream - return open(os.fspath(file_or_path), binary_mode), True # type: ignore - - if isinstance(file_or_path, TextIOWrapper): - # If file_or_path is a TextIOWrapper, return its underlying binary buffer - return file_or_path.buffer, False - - if hasattr(file_or_path, "readinto") or hasattr(file_or_path, "write"): - # If file_or_path is a file-like object, return it as is - return file_or_path, False # type: ignore - - raise TypeError( - f"Unsupported type for {file_or_path}, " - f"{file_or_path.__class__.__name__}." - ) - - -def _detect_format_from_content(file_or_path: FileOrPath) -> Optional[str]: - """ - Attempts to detect file format from the content by reading the first - 6 bytes. Returns None if no format could be detected. - """ - fileobj, closefd = _file_or_path_to_binary_stream(file_or_path, "rb") - try: - if not fileobj.readable(): - return None - if hasattr(fileobj, "peek"): - bs = fileobj.peek(6) - elif hasattr(fileobj, "seekable") and fileobj.seekable(): - current_pos = fileobj.tell() - bs = fileobj.read(6) - fileobj.seek(current_pos) - else: - return None - - if bs[:2] == b"\x1f\x8b": - # https://tools.ietf.org/html/rfc1952#page-6 - return "gz" - elif bs[:3] == b"\x42\x5a\x68": - # https://en.wikipedia.org/wiki/List_of_file_signatures - return "bz2" - elif bs[:6] == b"\xfd\x37\x7a\x58\x5a\x00": - # https://tukaani.org/xz/xz-file-format.txt - return "xz" - return None - finally: - if closefd: - fileobj.close() - - -def _detect_format_from_extension(filename: FilePath) -> Optional[str]: - """ - Attempt to detect file format from the filename extension. 
- Return None if no format could be detected. - """ - extensions = ("bz2", "xz", "gz") - - if isinstance(filename, bytes): - for ext in extensions: - if filename.endswith(b"." + ext.encode()): - return ext - - if isinstance(filename, str): - for ext in extensions: - if filename.endswith("." + ext): - return ext - - if hasattr(filename, "name"): - return _detect_format_from_extension(filename.name) - - return None - - -def _filepath_from_path_or_filelike(file_or_path: FileOrPath) -> str: - try: - return os.fspath(file_or_path) # type: ignore - except TypeError: - pass - - if hasattr(file_or_path, "name"): - name = file_or_path.name - if isinstance(name, str): - return name - elif isinstance(name, bytes): - return name.decode() - - return "" - - -def _open( - filename: FileOrPath, - mode: str = "rb", - compresslevel: Optional[int] = None, - format: Optional[str] = None, -) -> BinaryIO: - """ - A replacement for the "open" function that can also read and write - compressed files transparently. The supported compression formats are gzip, - bzip2 and xz. Filename can be a string, a Path or a file object. - - When writing, the file format is chosen based on the file name extension: - - .gz uses gzip compression - - .bz2 uses bzip2 compression - - .xz uses xz/lzma compression - - otherwise, no compression is used - - When reading, if a file name extension is available, the format is detected - using it, but if not, the format is detected from the contents. - - mode can be: 'rb', 'ab', or 'wb'. - - compresslevel is the compression level for writing to gzip, and xz. - This parameter is ignored for the other compression formats. - If set to None, a default depending on the format is used: - gzip: 6, xz: 6. - - format overrides the autodetection of input and output formats. This can be - useful when compressed output needs to be written to a file without an - extension. Possible values are "gz", "xz", "bz2" and "raw". In case of - "raw", no compression is used. - """ - if mode not in ("rb", "wb", "ab"): - raise ValueError("Mode '{}' not supported".format(mode)) - filepath = _filepath_from_path_or_filelike(filename) - - if format not in (None, "gz", "xz", "bz2", "raw"): - raise ValueError( - f"Format not supported: {format}. 
Choose one of: 'gz', 'xz', 'bz2'" - ) - - if format == "raw": - detected_format = None - else: - detected_format = format or _detect_format_from_extension(filepath) - if detected_format is None and "r" in mode: - detected_format = _detect_format_from_content(filename) - - if detected_format == "gz": - opened_file = _open_gz(filename, mode, compresslevel) - elif detected_format == "xz": - opened_file = _open_xz(filename, mode, compresslevel) - elif detected_format == "bz2": - opened_file = _open_bz2(filename, mode, compresslevel) - else: - opened_file, _ = _file_or_path_to_binary_stream(filename, mode) - - return opened_file - - -def _open_bz2( - filename: FileOrPath, - mode: str, - compresslevel: Optional[int] = None, -) -> BinaryIO: - assert mode in ("rb", "ab", "wb") - if compresslevel is None: - compresslevel = DEFAULT_BZ2_COMPRESSION - - if "r" in mode: - return bz2.open(filename, mode) # type: ignore - return BufferedWriter(bz2.open(filename, mode, compresslevel)) # type: ignore - - -def _open_xz( - filename: FileOrPath, - mode: str, - compresslevel: Optional[int] = None, -) -> BinaryIO: - assert mode in ("rb", "ab", "wb") - if compresslevel is None: - compresslevel = DEFAULT_XZ_COMPRESSION - - if "r" in mode: - return lzma.open(filename, mode) # type: ignore - return BufferedWriter(lzma.open(filename, mode, preset=compresslevel)) # type: ignore - - -def _open_gz( - filename: FileOrPath, - mode: str, - compresslevel: Optional[int] = None, -) -> BinaryIO: - assert mode in ("rb", "ab", "wb") - if compresslevel is None: - compresslevel = DEFAULT_GZIP_COMPRESSION - - if "r" in mode: - return gzip.open(filename, mode) # type: ignore - return BufferedWriter(gzip.open(filename, mode, compresslevel=compresslevel)) # type: ignore - - -def get_file_extension(file_or_path: FileLike) -> str: - """Determine the file extension of the given file-like object. - - If the file has compression extensions such as '.gz', '.xz', or '.bz2', - they will be stripped before determining the extension. - - Args: - file_or_path (FileLike): The file-like object whose extension needs to be determined. - - Returns: - str: The file extension, including the dot ('.') if present. - - Raises: - Exception: If the extension cannot be determined. - - Example: - >>> get_file_extension("example.xml.gz") - '.xml' - >>> get_file_extension(Path("example.txt")) - '.txt' - >>> get_file_extension(Source(data="example.csv")) - '.csv' - """ - if isinstance(file_or_path, (str, bytes)) or hasattr( - file_or_path, "__fspath__" - ): - path = os.fspath(file_or_path) # type: ignore - for ext in [".gz", ".xz", ".bz2"]: - if path.endswith(ext): - path = path[: -len(ext)] - return os.path.splitext(path)[1] - - if isinstance(file_or_path, Source): - return get_file_extension(file_or_path.data) - - raise TypeError( - f"Could not determine extension for input type: {type(file_or_path)}" - ) - - -def get_local_cache_stream( - url: str, cache_dir: str, mode: str = "rb", format: Optional[str] = None -) -> Tuple[BinaryIO, Union[bool, str]]: - """Get a stream to the local cache file for the given URL. - - Compressed files are read transparently. The supported compression formats - are gzip, bzip2 and xz. - - Args: - url (str): The URL to cache. - cache_dir (str): The directory where the cache file will be stored. - mode (str): The mode in which to open the cache file. Must be one of - 'rb', 'wb', or 'ab'. Defaults to 'ab'. - format (str): Overrides the autodetection of input and output formats. - Possible values are "gz", "xz", "bz2" and "raw". 
In case of "raw", - no compression is used.. - - Returns: - Tuple[BinaryIO, bool | str]: A tuple containing a binary stream to the - local cache file and the path to the cache file if it already - exists and is non-empty, otherwise False. - - Note: - - If the specified cache directory does not exist, it will be created. - - If the cache file does not exist, it will be created and will be - named after the URL. - - Example: - >>> stream, exists = get_local_cache_stream("https://example.com/data", "./cache") - """ - assert mode in ("rb", "wb", "ab") - - # Ensure the cache directory exists - if not os.path.exists(cache_dir): - os.makedirs(cache_dir) - - # Generate the local filename based on the URL - filename = urllib.parse.quote_plus(url) - local_filename = f"{cache_dir}/{filename}" - - # Ensure the file exists by opening it in append-binary mode, creating it if necessary - file_exists_and_non_empty = ( - os.path.exists(local_filename) and os.path.getsize(local_filename) > 0 - ) - file = _open(local_filename, mode, format=format) - - return file, file_exists_and_non_empty - - -@contextlib.contextmanager -def dummy_context_mgr() -> Generator[None, None, None]: - yield - - -def open_as_file(input_: FileLike) -> ContextManager[Optional[BinaryIO]]: - """Open a byte stream to the given input object. - - The following input types are supported: - - A string or `pathlib.Path` object representing a local file path. - - A string representing a URL. It should start with 'http://' or - 'https://'. - - A string representing a path to a file in a Amazon S3 cloud storage - bucket. It should start with 's3://'. - - A xml or json string containing the data. The string should contain - a '{' or '<' character. Otherwise, it will be treated as a file path. - - A bytes object containing the data. - - A buffered binary stream that inherits from `io.BufferedIOBase`. - - A [Source](`kloppy.io.Source`) object that wraps any of the above - input types. - - Args: - input_ (FileLike): The input object to be opened. - - Returns: - BinaryIO: A binary stream to the input object. - - Raises: - InputNotFoundError: If the input file is not found. - TypeError: If the input type is not supported. - - Example: - >>> with open_as_file("example.txt") as f: - ... contents = f.read() - - Note: - To support reading data from other sources, see the - [Adapter](`kloppy.io.adapters.Adapter`) class. - - If the given file path or URL ends with '.gz', '.xz', or '.bz2', the - file will be decompressed before being read. - """ - if isinstance(input_, Source): - if input_.data is None and input_.optional: - # This saves us some additional code in every vendor specific code - return dummy_context_mgr() - - try: - return open_as_file(input_.data) - except InputNotFoundError: - if input_.skip_if_missing: - logging.info(f"Input {input_.data} not found. 
Skipping") - return dummy_context_mgr() - raise - - if isinstance(input_, str) and ("{" in input_ or "<" in input_): - # If input_ is a JSON or XML string, return it as a binary stream - return BytesIO(input_.encode("utf8")) - - if isinstance(input_, bytes): - # If input_ is a bytes object, return it as a binary stream - return BytesIO(input_) - - if isinstance(input_, str) or hasattr(input_, "__fspath__"): - # If input_ is a path-like object, open it and return the binary stream - uri = _filepath_from_path_or_filelike(input_) - - adapter = get_adapter(uri) - if adapter: - cache_dir = get_config("cache") - assert cache_dir is None or isinstance(cache_dir, str) - if cache_dir: - stream, local_cache_file = get_local_cache_stream( - uri, cache_dir, "ab", format="raw" - ) - else: - stream, local_cache_file = BytesIO(), None - - if not local_cache_file: - logger.info(f"Retrieving {uri}") - adapter.read_to_stream(uri, stream) - logger.info("Retrieval complete") - else: - logger.info(f"Using local cached file {local_cache_file}") - - if cache_dir: - stream.close() - stream, _ = get_local_cache_stream(uri, cache_dir, "rb") - else: - stream.seek(0) - - else: - if not os.path.exists(uri): - raise InputNotFoundError(f"File {uri} does not exist") - - stream = _open(uri, "rb") - return stream - - if isinstance(input_, TextIOWrapper): - # If file_or_path is a TextIOWrapper, return its underlying binary buffer - return input_.buffer - - if hasattr(input_, "readinto"): - # If file_or_path is a file-like object, return it as is - return _open(input_) # type: ignore - - raise TypeError(f"Unsupported input type: {type(input_)}") +"""I/O utilities for reading raw data.""" + +import bz2 +import contextlib +import gzip +import logging +import lzma +import os +import urllib.parse +from dataclasses import dataclass, replace +from io import BufferedWriter, BytesIO, TextIOWrapper +from typing import ( + IO, + BinaryIO, + ContextManager, + Generator, + Optional, + Tuple, + Union, +) + +from kloppy.config import get_config +from kloppy.exceptions import InputNotFoundError +from kloppy.infra.io.adapters import get_adapter + +logger = logging.getLogger(__name__) + +DEFAULT_GZIP_COMPRESSION = 1 +DEFAULT_BZ2_COMPRESSION = 9 +DEFAULT_XZ_COMPRESSION = 6 + + +FilePath = Union[str, bytes, os.PathLike] +FileOrPath = Union[FilePath, IO] + + +@dataclass(frozen=True) +class Source: + """A wrapper around a file-like object to enable optional inputs. + + Args: + data (FileLike): The file-like object. + optional (bool): Whether the file is optional. Defaults to False. + skip_if_missing (bool): Whether to skip the file if it is missing. Defaults to False. + + Example: + >>> open_as_file(Source.create("example.csv", optional=True)) + """ + + data: FileOrPath + optional: bool = False + skip_if_missing: bool = False + + @classmethod + def create(cls, input_: "FileLike", **kwargs): + if isinstance(input_, Source): + return replace(input_, **kwargs) + return Source(data=input_, **kwargs) + + +FileLike = Union[FileOrPath, Source] + + +def _file_or_path_to_binary_stream( + file_or_path: FileOrPath, binary_mode: str +) -> Tuple[BinaryIO, bool]: + """ + Converts a file path or a file-like object to a binary stream. + + Args: + file_or_path: The file path or file-like object to convert. + binary_mode: The binary mode to open the file in. Must be one of 'rb', 'wb', or 'ab'. + + Returns: + A tuple containing the binary stream and a boolean indicating whether + a new file was opened (True) or an existing file-like object was used (False). 
+    """
+    assert binary_mode in ("rb", "wb", "ab")
+
+    if isinstance(file_or_path, (str, bytes)) or hasattr(
+        file_or_path, "__fspath__"
+    ):
+        # If file_or_path is a path-like object, open it and return the binary stream
+        return open(os.fspath(file_or_path), binary_mode), True  # type: ignore
+
+    if isinstance(file_or_path, TextIOWrapper):
+        # If file_or_path is a TextIOWrapper, return its underlying binary buffer
+        return file_or_path.buffer, False
+
+    if hasattr(file_or_path, "readinto") or hasattr(file_or_path, "write"):
+        # If file_or_path is a file-like object, return it as is
+        return file_or_path, False  # type: ignore
+
+    raise TypeError(
+        f"Unsupported type for {file_or_path}, "
+        f"{file_or_path.__class__.__name__}."
+    )
+
+
+def _detect_format_from_content(file_or_path: FileOrPath) -> Optional[str]:
+    """
+    Attempts to detect the file format from the content by reading the first
+    6 bytes. Returns None if no format could be detected.
+    """
+    fileobj, closefd = _file_or_path_to_binary_stream(file_or_path, "rb")
+    try:
+        if not fileobj.readable():
+            return None
+        if hasattr(fileobj, "peek"):
+            bs = fileobj.peek(6)
+        elif hasattr(fileobj, "seekable") and fileobj.seekable():
+            current_pos = fileobj.tell()
+            bs = fileobj.read(6)
+            fileobj.seek(current_pos)
+        else:
+            return None
+
+        if bs[:2] == b"\x1f\x8b":
+            # https://tools.ietf.org/html/rfc1952#page-6
+            return "gz"
+        elif bs[:3] == b"\x42\x5a\x68":
+            # https://en.wikipedia.org/wiki/List_of_file_signatures
+            return "bz2"
+        elif bs[:6] == b"\xfd\x37\x7a\x58\x5a\x00":
+            # https://tukaani.org/xz/xz-file-format.txt
+            return "xz"
+        return None
+    finally:
+        if closefd:
+            fileobj.close()
+
+
+def _detect_format_from_extension(filename: FilePath) -> Optional[str]:
+    """
+    Attempts to detect the file format from the filename extension.
+    Returns None if no format could be detected.
+    """
+    extensions = ("bz2", "xz", "gz")
+
+    if isinstance(filename, bytes):
+        for ext in extensions:
+            if filename.endswith(b"." + ext.encode()):
+                return ext
+
+    if isinstance(filename, str):
+        for ext in extensions:
+            if filename.endswith("." + ext):
+                return ext
+
+    if hasattr(filename, "name"):
+        return _detect_format_from_extension(filename.name)
+
+    return None
+
+
+def _filepath_from_path_or_filelike(file_or_path: FileOrPath) -> str:
+    try:
+        return os.fspath(file_or_path)  # type: ignore
+    except TypeError:
+        pass
+
+    if hasattr(file_or_path, "name"):
+        name = file_or_path.name
+        if isinstance(name, str):
+            return name
+        elif isinstance(name, bytes):
+            return name.decode()
+
+    return ""
+
+
+def _open(
+    filename: FileOrPath,
+    mode: str = "rb",
+    compresslevel: Optional[int] = None,
+    format: Optional[str] = None,
+) -> BinaryIO:
+    """
+    A replacement for the "open" function that can also read and write
+    compressed files transparently. The supported compression formats are gzip,
+    bzip2 and xz. Filename can be a string, a Path or a file object.
+
+    When writing, the file format is chosen based on the file name extension:
+    - .gz uses gzip compression
+    - .bz2 uses bzip2 compression
+    - .xz uses xz/lzma compression
+    - otherwise, no compression is used
+
+    When reading, if a file name extension is available, the format is detected
+    using it, but if not, the format is detected from the contents.
+
+    mode can be: 'rb', 'ab', or 'wb'.
+
+    compresslevel is the compression level for writing to gzip and xz.
+    This parameter is ignored for the other compression formats.
+    If set to None, a default depending on the format is used:
+    gzip: 1, xz: 6.
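+
+    Example (a minimal sketch; "tracking.dat.gz" is a hypothetical local
+    file, so the doctest is skipped):
+        >>> with _open("tracking.dat.gz", "rb") as f:  # doctest: +SKIP
+        ...     first_line = f.readline()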
+
+    format overrides the autodetection of input and output formats. This can be
+    useful when compressed output needs to be written to a file without an
+    extension. Possible values are "gz", "xz", "bz2" and "raw". In case of
+    "raw", no compression is used.
+    """
+    if mode not in ("rb", "wb", "ab"):
+        raise ValueError("Mode '{}' not supported".format(mode))
+    filepath = _filepath_from_path_or_filelike(filename)
+
+    if format not in (None, "gz", "xz", "bz2", "raw"):
+        raise ValueError(
+            f"Format not supported: {format}. Choose one of: 'gz', 'xz', 'bz2', 'raw'"
+        )
+
+    if format == "raw":
+        detected_format = None
+    else:
+        detected_format = format or _detect_format_from_extension(filepath)
+        if detected_format is None and "r" in mode:
+            detected_format = _detect_format_from_content(filename)
+
+    if detected_format == "gz":
+        opened_file = _open_gz(filename, mode, compresslevel)
+    elif detected_format == "xz":
+        opened_file = _open_xz(filename, mode, compresslevel)
+    elif detected_format == "bz2":
+        opened_file = _open_bz2(filename, mode, compresslevel)
+    else:
+        opened_file, _ = _file_or_path_to_binary_stream(filename, mode)
+
+    return opened_file
+
+
+def _open_bz2(
+    filename: FileOrPath,
+    mode: str,
+    compresslevel: Optional[int] = None,
+) -> BinaryIO:
+    assert mode in ("rb", "ab", "wb")
+    if compresslevel is None:
+        compresslevel = DEFAULT_BZ2_COMPRESSION
+
+    if "r" in mode:
+        return bz2.open(filename, mode)  # type: ignore
+    return BufferedWriter(bz2.open(filename, mode, compresslevel))  # type: ignore
+
+
+def _open_xz(
+    filename: FileOrPath,
+    mode: str,
+    compresslevel: Optional[int] = None,
+) -> BinaryIO:
+    assert mode in ("rb", "ab", "wb")
+    if compresslevel is None:
+        compresslevel = DEFAULT_XZ_COMPRESSION
+
+    if "r" in mode:
+        return lzma.open(filename, mode)  # type: ignore
+    return BufferedWriter(lzma.open(filename, mode, preset=compresslevel))  # type: ignore
+
+
+def _open_gz(
+    filename: FileOrPath,
+    mode: str,
+    compresslevel: Optional[int] = None,
+) -> BinaryIO:
+    assert mode in ("rb", "ab", "wb")
+    if compresslevel is None:
+        compresslevel = DEFAULT_GZIP_COMPRESSION
+
+    if "r" in mode:
+        return gzip.open(filename, mode)  # type: ignore
+    return BufferedWriter(gzip.open(filename, mode, compresslevel=compresslevel))  # type: ignore
+
+
+def get_file_extension(file_or_path: FileLike) -> str:
+    """Determine the file extension of the given file-like object.
+
+    If the file has compression extensions such as '.gz', '.xz', or '.bz2',
+    they will be stripped before determining the extension.
+
+    Args:
+        file_or_path (FileLike): The file-like object whose extension needs to be determined.
+
+    Returns:
+        str: The file extension, including the dot ('.') if present.
+
+    Raises:
+        TypeError: If the extension cannot be determined.
+
+    Example:
+        >>> get_file_extension("example.xml.gz")
+        '.xml'
+        >>> get_file_extension(Path("example.txt"))
+        '.txt'
+        >>> get_file_extension(Source(data="example.csv"))
+        '.csv'
+    """
+    if isinstance(file_or_path, (str, bytes)) or hasattr(
+        file_or_path, "__fspath__"
+    ):
+        path = os.fspath(file_or_path)  # type: ignore
+        for ext in [".gz", ".xz", ".bz2"]:
+            if path.endswith(ext):
+                path = path[: -len(ext)]
+        return os.path.splitext(path)[1]
+
+    if isinstance(file_or_path, Source):
+        return get_file_extension(file_or_path.data)
+
+    raise TypeError(
+        f"Could not determine extension for input type: {type(file_or_path)}"
+    )
+
+
+def get_local_cache_stream(
+    url: str, cache_dir: str, mode: str = "rb", format: Optional[str] = None
+) -> Tuple[BinaryIO, Union[bool, str]]:
+    """Get a stream to the local cache file for the given URL.
+
+    Compressed files are read transparently. The supported compression formats
+    are gzip, bzip2 and xz.
+
+    Args:
+        url (str): The URL to cache.
+        cache_dir (str): The directory where the cache file will be stored.
+        mode (str): The mode in which to open the cache file. Must be one of
+            'rb', 'wb', or 'ab'. Defaults to 'rb'.
+        format (str): Overrides the autodetection of input and output formats.
+            Possible values are "gz", "xz", "bz2" and "raw". In case of "raw",
+            no compression is used.
+
+    Returns:
+        Tuple[BinaryIO, bool | str]: A tuple containing a binary stream to the
+            local cache file and the path to the cache file if it already
+            exists and is non-empty, otherwise False.
+
+    Note:
+        - If the specified cache directory does not exist, it will be created.
+        - If the cache file does not exist, it will be created (when opened
+          in 'wb' or 'ab' mode) and will be named after the URL.
+
+    Example:
+        >>> stream, local_file = get_local_cache_stream("https://example.com/data", "./cache")
+    """
+    assert mode in ("rb", "wb", "ab")
+
+    # Ensure the cache directory exists
+    if not os.path.exists(cache_dir):
+        os.makedirs(cache_dir)
+
+    # Generate the local filename based on the URL
+    filename = urllib.parse.quote_plus(url)
+    local_filename = f"{cache_dir}/{filename}"
+
+    # Check for an existing non-empty cache file before opening (and, in a writable mode, creating) it
+    file_exists_and_non_empty = (
+        os.path.exists(local_filename) and os.path.getsize(local_filename) > 0
+    )
+    file = _open(local_filename, mode, format=format)
+
+    return file, local_filename if file_exists_and_non_empty else False
+
+
+@contextlib.contextmanager
+def dummy_context_mgr() -> Generator[None, None, None]:
+    yield
+
+
+def open_as_file(input_: FileLike) -> ContextManager[Optional[BinaryIO]]:
+    """Open a byte stream to the given input object.
+
+    The following input types are supported:
+        - A string or `pathlib.Path` object representing a local file path.
+        - A string representing a URL. It should start with 'http://' or
+          'https://'.
+        - A string representing a path to a file in an Amazon S3 cloud storage
+          bucket. It should start with 's3://'.
+        - An XML or JSON string containing the data. The string should contain
+          a '{' or '<' character. Otherwise, it will be treated as a file path.
+        - A bytes object containing the data.
+        - A buffered binary stream that inherits from `io.BufferedIOBase`.
+        - A [Source](`kloppy.io.Source`) object that wraps any of the above
+          input types.
+
+    Args:
+        input_ (FileLike): The input object to be opened.
+
+    Returns:
+        BinaryIO: A binary stream to the input object.
+
+    Raises:
+        InputNotFoundError: If the input file is not found.
+        TypeError: If the input type is not supported.
+
+    Example:
+        >>> with open_as_file("example.txt") as f:
+        ...     contents = f.read()
+
+    Note:
+        To support reading data from other sources, see the
+        [Adapter](`kloppy.io.adapters.Adapter`) class.
+
+        If the given file path or URL ends with '.gz', '.xz', or '.bz2', the
+        file will be decompressed before being read.
+    """
+    if isinstance(input_, Source):
+        if input_.data is None and input_.optional:
+            # A no-op context manager saves vendor-specific code from handling missing optional inputs
+            return dummy_context_mgr()
+
+        try:
+            return open_as_file(input_.data)
+        except InputNotFoundError:
+            if input_.skip_if_missing:
+                logger.info(f"Input {input_.data} not found. Skipping")
+                return dummy_context_mgr()
+            raise
+
+    if isinstance(input_, str) and ("{" in input_ or "<" in input_):
+        # If input_ is a JSON or XML string, return it as a binary stream
+        return BytesIO(input_.encode("utf8"))
+
+    if isinstance(input_, bytes):
+        # If input_ is a bytes object, return it as a binary stream
+        return BytesIO(input_)
+
+    if isinstance(input_, str) or hasattr(input_, "__fspath__"):
+        # If input_ is a path-like object, open it and return the binary stream
+        uri = _filepath_from_path_or_filelike(input_)
+
+        adapter = get_adapter(uri)
+        if adapter:
+            cache_dir = get_config("cache")
+            assert cache_dir is None or isinstance(cache_dir, str)
+            if cache_dir:
+                stream, local_cache_file = get_local_cache_stream(
+                    uri, cache_dir, "ab", format="raw"
+                )
+            else:
+                stream, local_cache_file = BytesIO(), None
+
+            if not local_cache_file:
+                logger.info(f"Retrieving {uri}")
+                adapter.read_to_stream(uri, stream)
+                logger.info("Retrieval complete")
+            else:
+                logger.info(f"Using local cached file {local_cache_file}")
+
+            if cache_dir:
+                stream.close()
+                stream, _ = get_local_cache_stream(uri, cache_dir, "rb")
+            else:
+                stream.seek(0)
+
+        else:
+            if not os.path.exists(uri):
+                raise InputNotFoundError(f"File {uri} does not exist")
+
+            stream = _open(uri, "rb")
+        return stream
+
+    if isinstance(input_, TextIOWrapper):
+        # If input_ is a TextIOWrapper, return its underlying binary buffer
+        return input_.buffer
+
+    if hasattr(input_, "readinto"):
+        # If input_ is a file-like object, open it (decompressing transparently if needed)
+        return _open(input_)  # type: ignore
+
+    raise TypeError(f"Unsupported input type: {type(input_)}")
diff --git a/kloppy/tests/test_tracab.py b/kloppy/tests/test_tracab.py
index d67f485e..2fe5b425 100644
--- a/kloppy/tests/test_tracab.py
+++ b/kloppy/tests/test_tracab.py
@@ -1,398 +1,398 @@
-from pathlib import Path
-from datetime import datetime, timedelta, timezone
-
-import pytest
-
-from kloppy._providers.tracab import (
-    identify_deserializer,
-    TRACABJSONDeserializer,
-    TRACABDatDeserializer,
-)
-from kloppy.domain import (
-    Period,
-    AttackingDirection,
-    Orientation,
-    Provider,
-    Point,
-    Point3D,
-    BallState,
-    Team,
-    Ground,
-    DatasetType,
-)
-
-from kloppy import tracab
-
-
-@pytest.fixture(scope="session")
-def json_meta_data(base_dir: Path) -> Path:
-    return base_dir / "files" / "tracab_meta.json"
-
-
-@pytest.fixture(scope="session")
-def json_raw_data(base_dir: Path) -> Path:
-    return base_dir / "files" / "tracab_raw.json"
-
-
-@pytest.fixture(scope="session")
-def xml_meta_data(base_dir: Path) -> Path:
-    return base_dir / "files" / "tracab_meta.xml"
-
-
-@pytest.fixture(scope="session")
-def xml_meta2_data(base_dir: Path) -> Path:
-    return base_dir / "files" / "tracab_meta_2.xml"
-
-
-@pytest.fixture(scope="session")
-def xml_meta3_data(base_dir: Path) -> Path:
-    return base_dir / "files" / "tracab_meta_3.xml"
-
-
-@pytest.fixture(scope="session") -def xml_meta4_data(base_dir: Path) -> Path: - return base_dir / "files" / "tracab_meta_4.xml" - - -@pytest.fixture(scope="session") -def dat_raw_data(base_dir: Path) -> Path: - return base_dir / "files" / "tracab_raw.dat" - - -def test_correct_auto_recognize_deserialization( - json_raw_data: Path, - dat_raw_data: Path, -): - tracab_json_deserializer = identify_deserializer(raw_data=json_raw_data) - assert tracab_json_deserializer == TRACABJSONDeserializer - tracab_dat_deserializer = identify_deserializer(raw_data=dat_raw_data) - assert tracab_dat_deserializer == TRACABDatDeserializer - - -def meta_tracking_assertions(dataset): - assert dataset.metadata.provider == Provider.TRACAB - assert dataset.dataset_type == DatasetType.TRACKING - assert len(dataset.records) == 7 - assert len(dataset.metadata.periods) == 2 - assert dataset.metadata.periods[0].id == 1 - assert dataset.metadata.periods[0].start_timestamp == timedelta( - seconds=73940.32 - ) - assert dataset.metadata.periods[0].end_timestamp == timedelta( - seconds=76656.32 - ) - assert dataset.metadata.periods[1].id == 2 - assert dataset.metadata.periods[1].start_timestamp == timedelta( - seconds=77684.56 - ) - assert dataset.metadata.periods[1].end_timestamp == timedelta( - seconds=80717.32 - ) - assert dataset.metadata.orientation == Orientation.AWAY_HOME - - player_home_1 = dataset.metadata.teams[0].get_player_by_jersey_number(1) - assert dataset.records[0].players_data[player_home_1].coordinates == Point( - x=5270.0, y=27.0 - ) - - player_away_12 = dataset.metadata.teams[1].get_player_by_jersey_number(12) - assert dataset.records[0].players_data[ - player_away_12 - ].coordinates == Point(x=-4722.0, y=28.0) - assert dataset.records[0].ball_state == BallState.DEAD - assert dataset.records[1].ball_state == BallState.ALIVE - # Shouldn't this be closer to (0,0,0)? 
- assert dataset.records[1].ball_coordinates == Point3D( - x=2710.0, y=3722.0, z=11.0 - ) - - # make sure player data is only in the frame when the player is at the pitch - assert "12170" in [ - player.player_id for player in dataset.records[0].players_data.keys() - ] - assert "12170" not in [ - player.player_id for player in dataset.records[6].players_data.keys() - ] - - -class TestTracabJSONTracking: - def test_correct_deserialization( - self, json_meta_data: Path, json_raw_data: Path - ): - dataset = tracab.load( - meta_data=json_meta_data, - raw_data=json_raw_data, - coordinates="tracab", - only_alive=False, - file_format="json", - ) - meta_tracking_assertions(dataset) - - def test_correct_normalized_deserialization( - self, json_meta_data: Path, json_raw_data: Path - ): - dataset = tracab.load( - meta_data=json_meta_data, raw_data=json_raw_data, only_alive=False - ) - player_home_1 = dataset.metadata.teams[0].get_player_by_jersey_number( - 1 - ) - assert dataset.records[0].players_data[ - player_home_1 - ].coordinates == Point(x=1.0019047619047619, y=0.49602941176470583) - - -class TestTracabDATTracking: - def test_correct_deserialization( - self, xml_meta_data: Path, dat_raw_data: Path - ): - dataset = tracab.load( - meta_data=xml_meta_data, - raw_data=dat_raw_data, - coordinates="tracab", - only_alive=False, - ) - - meta_tracking_assertions(dataset) - - def test_correct_normalized_deserialization( - self, xml_meta_data: Path, dat_raw_data: Path - ): - dataset = tracab.load( - meta_data=xml_meta_data, raw_data=dat_raw_data, only_alive=False - ) - - player_home_1 = dataset.metadata.teams[0].get_player_by_jersey_number( - 1 - ) - - assert dataset.records[0].players_data[ - player_home_1 - ].coordinates == Point(x=1.0019047619047619, y=0.49602941176470583) - - date = dataset.metadata.date - if date: - assert isinstance(date, datetime) - assert date == datetime( - 2023, 12, 15, 20, 32, 20, tzinfo=timezone.utc - ) - - game_week = dataset.metadata.game_week - if game_week: - assert isinstance(game_week, str) - - game_id = dataset.metadata.game_id - if game_id: - assert isinstance(game_id, str) - assert game_id == "1" - - -class TestTracabMeta2: - def test_correct_deserialization( - self, xml_meta2_data: Path, dat_raw_data: Path - ): - dataset = tracab.load( - meta_data=xml_meta2_data, - raw_data=dat_raw_data, - coordinates="tracab", - only_alive=False, - ) - - # Check metadata - assert dataset.metadata.provider == Provider.TRACAB - assert dataset.dataset_type == DatasetType.TRACKING - assert len(dataset.records) == 7 - assert len(dataset.metadata.periods) == 2 - assert dataset.metadata.orientation == Orientation.AWAY_HOME - assert dataset.metadata.periods[0].id == 1 - assert dataset.metadata.periods[0].start_timestamp == timedelta( - seconds=73940, microseconds=320000 - ) - assert dataset.metadata.periods[0].end_timestamp == timedelta( - seconds=76656, microseconds=320000 - ) - assert dataset.metadata.periods[1].id == 2 - assert dataset.metadata.periods[1].start_timestamp == timedelta( - seconds=77684, microseconds=560000 - ) - assert dataset.metadata.periods[1].end_timestamp == timedelta( - seconds=80717, microseconds=320000 - ) - - # No need to check frames, since we do that in TestTracabDATTracking - # The only difference in this test is the meta data file structure - - # make sure player data is only in the frame when the player is at the pitch - assert "home_20" in [ - player.player_id - for player in dataset.records[0].players_data.keys() - ] - assert "home_20" not in [ - 
player.player_id - for player in dataset.records[6].players_data.keys() - ] - - def test_correct_normalized_deserialization( - self, xml_meta2_data: Path, dat_raw_data: Path - ): - dataset = tracab.load( - meta_data=xml_meta2_data, raw_data=dat_raw_data, only_alive=False - ) - - player_home_1 = dataset.metadata.teams[0].get_player_by_jersey_number( - 1 - ) - - assert dataset.records[0].players_data[ - player_home_1 - ].coordinates == Point(x=1.0019047619047619, y=0.49602941176470583) - - -class TestTracabMeta3: - def test_correct_deserialization( - self, xml_meta3_data: Path, dat_raw_data: Path - ): - dataset = tracab.load( - meta_data=xml_meta3_data, - raw_data=dat_raw_data, - coordinates="tracab", - only_alive=False, - ) - - # Check metadata - assert dataset.metadata.provider == Provider.TRACAB - assert dataset.dataset_type == DatasetType.TRACKING - assert len(dataset.records) == 7 - assert len(dataset.metadata.periods) == 2 - assert dataset.metadata.orientation == Orientation.AWAY_HOME - assert dataset.metadata.periods[0].id == 1 - assert dataset.metadata.periods[0].start_timestamp == timedelta( - seconds=73940, microseconds=320000 - ) - assert dataset.metadata.periods[0].end_timestamp == timedelta( - seconds=76656, microseconds=320000 - ) - assert dataset.metadata.periods[1].id == 2 - assert dataset.metadata.periods[1].start_timestamp == timedelta( - seconds=77684, microseconds=560000 - ) - assert dataset.metadata.periods[1].end_timestamp == timedelta( - seconds=80717, microseconds=320000 - ) - - # No need to check frames, since we do that in TestTracabDATTracking - # The only difference in this test is the meta data file structure - - # make sure player data is only in the frame when the player is at the pitch - assert "home_20" in [ - player.player_id - for player in dataset.records[0].players_data.keys() - ] - assert "home_20" not in [ - player.player_id - for player in dataset.records[6].players_data.keys() - ] - - def test_correct_normalized_deserialization( - self, xml_meta3_data: Path, dat_raw_data: Path - ): - dataset = tracab.load( - meta_data=xml_meta3_data, raw_data=dat_raw_data, only_alive=False - ) - - player_home_1 = dataset.metadata.teams[0].get_player_by_jersey_number( - 1 - ) - - assert dataset.records[0].players_data[ - player_home_1 - ].coordinates == Point(x=1.0019047619047619, y=0.49602941176470583) - - -class TestTracabMeta4: - def test_correct_deserialization( - self, xml_meta4_data: Path, dat_raw_data: Path - ): - dataset = tracab.load( - meta_data=xml_meta4_data, - raw_data=dat_raw_data, - coordinates="tracab", - only_alive=False, - ) - - # Check metadata - assert dataset.metadata.provider == Provider.TRACAB - assert dataset.dataset_type == DatasetType.TRACKING - assert len(dataset.records) == 7 - assert len(dataset.metadata.periods) == 2 - assert dataset.metadata.orientation == Orientation.AWAY_HOME - assert dataset.metadata.periods[0].id == 1 - assert dataset.metadata.periods[0].start_timestamp == timedelta( - seconds=73940, microseconds=320000 - ) - assert dataset.metadata.periods[0].end_timestamp == timedelta( - seconds=76656, microseconds=320000 - ) - assert dataset.metadata.periods[1].id == 2 - assert dataset.metadata.periods[1].start_timestamp == timedelta( - seconds=77684, microseconds=560000 - ) - assert dataset.metadata.periods[1].end_timestamp == timedelta( - seconds=80717, microseconds=320000 - ) - - # No need to check frames, since we do that in TestTracabDATTracking - # The only difference in this test is the meta data file structure - - # make 
sure player data is only in the frame when the player is at the pitch - assert "12170" in [ - player.player_id - for player in dataset.records[0].players_data.keys() - ] - assert "12170" not in [ - player.player_id - for player in dataset.records[6].players_data.keys() - ] - - def test_correct_normalized_deserialization( - self, xml_meta4_data: Path, dat_raw_data: Path - ): - dataset = tracab.load( - meta_data=xml_meta4_data, raw_data=dat_raw_data, only_alive=False - ) - - player_home_1 = dataset.metadata.teams[0].get_player_by_jersey_number( - 1 - ) - - assert dataset.records[0].players_data[ - player_home_1 - ].coordinates == Point(x=1.0019047619047619, y=0.49602941176470583) - - -class TestTracabDATTrackingJSONMeta: - def test_correct_deserialization( - self, json_meta_data: Path, dat_raw_data: Path - ): - dataset = tracab.load( - meta_data=json_meta_data, - raw_data=dat_raw_data, - coordinates="tracab", - only_alive=False, - ) - - meta_tracking_assertions(dataset) - - -class TestTracabJSONTrackingXMLNMeta: - def test_correct_deserialization( - self, xml_meta_data: Path, json_raw_data: Path - ): - dataset = tracab.load( - meta_data=xml_meta_data, - raw_data=json_raw_data, - coordinates="tracab", - only_alive=False, - ) - - meta_tracking_assertions(dataset) +from pathlib import Path +from datetime import datetime, timedelta, timezone + +import pytest + +from kloppy._providers.tracab import ( + identify_deserializer, + TRACABJSONDeserializer, + TRACABDatDeserializer, +) +from kloppy.domain import ( + Period, + AttackingDirection, + Orientation, + Provider, + Point, + Point3D, + BallState, + Team, + Ground, + DatasetType, +) + +from kloppy import tracab + + +@pytest.fixture(scope="session") +def json_meta_data(base_dir: Path) -> Path: + return base_dir / "files" / "tracab_meta.json" + + +@pytest.fixture(scope="session") +def json_raw_data(base_dir: Path) -> Path: + return base_dir / "files" / "tracab_raw.json" + + +@pytest.fixture(scope="session") +def xml_meta_data(base_dir: Path) -> Path: + return base_dir / "files" / "tracab_meta.xml" + + +@pytest.fixture(scope="session") +def xml_meta2_data(base_dir: Path) -> Path: + return base_dir / "files" / "tracab_meta_2.xml" + + +@pytest.fixture(scope="session") +def xml_meta3_data(base_dir: Path) -> Path: + return base_dir / "files" / "tracab_meta_3.xml" + + +@pytest.fixture(scope="session") +def xml_meta4_data(base_dir: Path) -> Path: + return base_dir / "files" / "tracab_meta_4.xml" + + +@pytest.fixture(scope="session") +def dat_raw_data(base_dir: Path) -> Path: + return base_dir / "files" / "tracab_raw.dat" + + +def test_correct_auto_recognize_deserialization( + json_raw_data: Path, + dat_raw_data: Path, +): + tracab_json_deserializer = identify_deserializer(raw_data=json_raw_data) + assert tracab_json_deserializer == TRACABJSONDeserializer + tracab_dat_deserializer = identify_deserializer(raw_data=dat_raw_data) + assert tracab_dat_deserializer == TRACABDatDeserializer + + +def meta_tracking_assertions(dataset): + assert dataset.metadata.provider == Provider.TRACAB + assert dataset.dataset_type == DatasetType.TRACKING + assert len(dataset.records) == 7 + assert len(dataset.metadata.periods) == 2 + assert dataset.metadata.periods[0].id == 1 + assert dataset.metadata.periods[0].start_timestamp == timedelta( + seconds=73940.32 + ) + assert dataset.metadata.periods[0].end_timestamp == timedelta( + seconds=76656.32 + ) + assert dataset.metadata.periods[1].id == 2 + assert dataset.metadata.periods[1].start_timestamp == timedelta( + seconds=77684.56 
+    )
+    assert dataset.metadata.periods[1].end_timestamp == timedelta(
+        seconds=80717.32
+    )
+    assert dataset.metadata.orientation == Orientation.AWAY_HOME
+
+    player_home_1 = dataset.metadata.teams[0].get_player_by_jersey_number(1)
+    assert dataset.records[0].players_data[player_home_1].coordinates == Point(
+        x=5270.0, y=27.0
+    )
+
+    player_away_12 = dataset.metadata.teams[1].get_player_by_jersey_number(12)
+    assert dataset.records[0].players_data[
+        player_away_12
+    ].coordinates == Point(x=-4722.0, y=28.0)
+    assert dataset.records[0].ball_state == BallState.DEAD
+    assert dataset.records[1].ball_state == BallState.ALIVE
+    # Shouldn't this be closer to (0,0,0)?
+    assert dataset.records[1].ball_coordinates == Point3D(
+        x=2710.0, y=3722.0, z=11.0
+    )
+
+    # make sure player data is only in the frame when the player is on the pitch
+    assert "12170" in [
+        player.player_id for player in dataset.records[0].players_data.keys()
+    ]
+    assert "12170" not in [
+        player.player_id for player in dataset.records[6].players_data.keys()
+    ]
+
+
+class TestTracabJSONTracking:
+    def test_correct_deserialization(
+        self, json_meta_data: Path, json_raw_data: Path
+    ):
+        dataset = tracab.load(
+            meta_data=json_meta_data,
+            raw_data=json_raw_data,
+            coordinates="tracab",
+            only_alive=False,
+            file_format="json",
+        )
+        meta_tracking_assertions(dataset)
+
+    def test_correct_normalized_deserialization(
+        self, json_meta_data: Path, json_raw_data: Path
+    ):
+        dataset = tracab.load(
+            meta_data=json_meta_data, raw_data=json_raw_data, only_alive=False
+        )
+        player_home_1 = dataset.metadata.teams[0].get_player_by_jersey_number(
+            1
+        )
+        assert dataset.records[0].players_data[
+            player_home_1
+        ].coordinates == Point(x=1.0019047619047619, y=0.49602941176470583)
+
+
+class TestTracabDATTracking:
+    def test_correct_deserialization(
+        self, xml_meta_data: Path, dat_raw_data: Path
+    ):
+        dataset = tracab.load(
+            meta_data=xml_meta_data,
+            raw_data=dat_raw_data,
+            coordinates="tracab",
+            only_alive=False,
+        )
+
+        meta_tracking_assertions(dataset)
+
+    def test_correct_normalized_deserialization(
+        self, xml_meta_data: Path, dat_raw_data: Path
+    ):
+        dataset = tracab.load(
+            meta_data=xml_meta_data, raw_data=dat_raw_data, only_alive=False
+        )
+
+        player_home_1 = dataset.metadata.teams[0].get_player_by_jersey_number(
+            1
+        )
+
+        assert dataset.records[0].players_data[
+            player_home_1
+        ].coordinates == Point(x=1.0019047619047619, y=0.49602941176470583)
+
+        date = dataset.metadata.date
+        if date:
+            assert isinstance(date, datetime)
+            assert date == datetime(
+                2023, 12, 15, 20, 32, 20, tzinfo=timezone.utc
+            )
+
+        game_week = dataset.metadata.game_week
+        if game_week:
+            assert isinstance(game_week, str)
+
+        game_id = dataset.metadata.game_id
+        if game_id:
+            assert isinstance(game_id, str)
+            assert game_id == "1"
+
+
+class TestTracabMeta2:
+    def test_correct_deserialization(
+        self, xml_meta2_data: Path, dat_raw_data: Path
+    ):
+        dataset = tracab.load(
+            meta_data=xml_meta2_data,
+            raw_data=dat_raw_data,
+            coordinates="tracab",
+            only_alive=False,
+        )
+
+        # Check metadata
+        assert dataset.metadata.provider == Provider.TRACAB
+        assert dataset.dataset_type == DatasetType.TRACKING
+        assert len(dataset.records) == 7
+        assert len(dataset.metadata.periods) == 2
+        assert dataset.metadata.orientation == Orientation.AWAY_HOME
+        assert dataset.metadata.periods[0].id == 1
+        assert dataset.metadata.periods[0].start_timestamp == timedelta(
+            seconds=73940, microseconds=320000
+        )
+        assert dataset.metadata.periods[0].end_timestamp == timedelta(
+            seconds=76656, microseconds=320000
+        )
+        assert dataset.metadata.periods[1].id == 2
+        assert dataset.metadata.periods[1].start_timestamp == timedelta(
+            seconds=77684, microseconds=560000
+        )
+        assert dataset.metadata.periods[1].end_timestamp == timedelta(
+            seconds=80717, microseconds=320000
+        )
+
+        # No need to check frames, since we do that in TestTracabDATTracking
+        # The only difference in this test is the metadata file structure
+
+        # make sure player data is only in the frame when the player is on the pitch
+        assert "home_20" in [
+            player.player_id
+            for player in dataset.records[0].players_data.keys()
+        ]
+        assert "home_20" not in [
+            player.player_id
+            for player in dataset.records[6].players_data.keys()
+        ]
+
+    def test_correct_normalized_deserialization(
+        self, xml_meta2_data: Path, dat_raw_data: Path
+    ):
+        dataset = tracab.load(
+            meta_data=xml_meta2_data, raw_data=dat_raw_data, only_alive=False
+        )
+
+        player_home_1 = dataset.metadata.teams[0].get_player_by_jersey_number(
+            1
+        )
+
+        assert dataset.records[0].players_data[
+            player_home_1
+        ].coordinates == Point(x=1.0019047619047619, y=0.49602941176470583)
+
+
+class TestTracabMeta3:
+    def test_correct_deserialization(
+        self, xml_meta3_data: Path, dat_raw_data: Path
+    ):
+        dataset = tracab.load(
+            meta_data=xml_meta3_data,
+            raw_data=dat_raw_data,
+            coordinates="tracab",
+            only_alive=False,
+        )
+
+        # Check metadata
+        assert dataset.metadata.provider == Provider.TRACAB
+        assert dataset.dataset_type == DatasetType.TRACKING
+        assert len(dataset.records) == 7
+        assert len(dataset.metadata.periods) == 2
+        assert dataset.metadata.orientation == Orientation.AWAY_HOME
+        assert dataset.metadata.periods[0].id == 1
+        assert dataset.metadata.periods[0].start_timestamp == timedelta(
+            seconds=73940, microseconds=320000
+        )
+        assert dataset.metadata.periods[0].end_timestamp == timedelta(
+            seconds=76656, microseconds=320000
+        )
+        assert dataset.metadata.periods[1].id == 2
+        assert dataset.metadata.periods[1].start_timestamp == timedelta(
+            seconds=77684, microseconds=560000
+        )
+        assert dataset.metadata.periods[1].end_timestamp == timedelta(
+            seconds=80717, microseconds=320000
+        )
+
+        # No need to check frames, since we do that in TestTracabDATTracking
+        # The only difference in this test is the metadata file structure
+
+        # make sure player data is only in the frame when the player is on the pitch
+        assert "home_20" in [
+            player.player_id
+            for player in dataset.records[0].players_data.keys()
+        ]
+        assert "home_20" not in [
+            player.player_id
+            for player in dataset.records[6].players_data.keys()
+        ]
+
+    def test_correct_normalized_deserialization(
+        self, xml_meta3_data: Path, dat_raw_data: Path
+    ):
+        dataset = tracab.load(
+            meta_data=xml_meta3_data, raw_data=dat_raw_data, only_alive=False
+        )
+
+        player_home_1 = dataset.metadata.teams[0].get_player_by_jersey_number(
+            1
+        )
+
+        assert dataset.records[0].players_data[
+            player_home_1
+        ].coordinates == Point(x=1.0019047619047619, y=0.49602941176470583)
+
+
+class TestTracabMeta4:
+    def test_correct_deserialization(
+        self, xml_meta4_data: Path, dat_raw_data: Path
+    ):
+        dataset = tracab.load(
+            meta_data=xml_meta4_data,
+            raw_data=dat_raw_data,
+            coordinates="tracab",
+            only_alive=False,
+        )
+
+        # Check metadata
+        assert dataset.metadata.provider == Provider.TRACAB
+        assert dataset.dataset_type == DatasetType.TRACKING
+        assert len(dataset.records) == 7
+        assert len(dataset.metadata.periods) == 2
+        assert dataset.metadata.orientation == Orientation.AWAY_HOME
+        assert dataset.metadata.periods[0].id == 1
+        assert dataset.metadata.periods[0].start_timestamp == timedelta(
+            seconds=73940, microseconds=320000
+        )
+        assert dataset.metadata.periods[0].end_timestamp == timedelta(
+            seconds=76656, microseconds=320000
+        )
+        assert dataset.metadata.periods[1].id == 2
+        assert dataset.metadata.periods[1].start_timestamp == timedelta(
+            seconds=77684, microseconds=560000
+        )
+        assert dataset.metadata.periods[1].end_timestamp == timedelta(
+            seconds=80717, microseconds=320000
+        )
+
+        # No need to check frames, since we do that in TestTracabDATTracking
+        # The only difference in this test is the metadata file structure
+
+        # make sure player data is only in the frame when the player is on the pitch
+        assert "12170" in [
+            player.player_id
+            for player in dataset.records[0].players_data.keys()
+        ]
+        assert "12170" not in [
+            player.player_id
+            for player in dataset.records[6].players_data.keys()
+        ]
+
+    def test_correct_normalized_deserialization(
+        self, xml_meta4_data: Path, dat_raw_data: Path
+    ):
+        dataset = tracab.load(
+            meta_data=xml_meta4_data, raw_data=dat_raw_data, only_alive=False
+        )
+
+        player_home_1 = dataset.metadata.teams[0].get_player_by_jersey_number(
+            1
+        )
+
+        assert dataset.records[0].players_data[
+            player_home_1
+        ].coordinates == Point(x=1.0019047619047619, y=0.49602941176470583)
+
+
+class TestTracabDATTrackingJSONMeta:
+    def test_correct_deserialization(
+        self, json_meta_data: Path, dat_raw_data: Path
+    ):
+        dataset = tracab.load(
+            meta_data=json_meta_data,
+            raw_data=dat_raw_data,
+            coordinates="tracab",
+            only_alive=False,
+        )
+
+        meta_tracking_assertions(dataset)
+
+
+class TestTracabJSONTrackingXMLMeta:
+    def test_correct_deserialization(
+        self, xml_meta_data: Path, json_raw_data: Path
+    ):
+        dataset = tracab.load(
+            meta_data=xml_meta_data,
+            raw_data=json_raw_data,
+            coordinates="tracab",
+            only_alive=False,
+        )
+
+        meta_tracking_assertions(dataset)
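
Reviewer note: the normalized files above cover every public I/O entry point, so a
quick round-trip check is a useful sanity test after applying this patch. A minimal
sketch only, not part of the change itself; the file names "tracab_meta.json" and
"tracab_raw.dat.gz" are hypothetical stand-ins (any matching pair of fixtures from
kloppy/tests/files works as well):

    from kloppy import tracab
    from kloppy.domain import Provider
    from kloppy.io import Source, get_file_extension, open_as_file

    # Compression suffixes are stripped before the extension is determined,
    # so a gzipped raw file is still routed to the .dat deserializer.
    assert get_file_extension("tracab_raw.dat.gz") == ".dat"

    # open_as_file() decompresses .gz/.xz/.bz2 transparently; wrapping the
    # input in a Source with skip_if_missing=True yields None when the file
    # is absent instead of raising InputNotFoundError.
    with open_as_file(
        Source.create("tracab_raw.dat.gz", skip_if_missing=True)
    ) as fp:
        header = fp.read(16) if fp is not None else b""

    # file_format is optional: identify_deserializer() falls back to the
    # raw-data extension to pick between the .dat and .json deserializers.
    dataset = tracab.load(
        meta_data="tracab_meta.json",
        raw_data="tracab_raw.dat.gz",
        coordinates="tracab",
        only_alive=False,
    )
    assert dataset.metadata.provider == Provider.TRACAB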