Skip to content

Commit

Permalink
[StatsBomb] Fix player identity inference issues in freeze frame data (PySport#386)
Browse files Browse the repository at this point in the history

---------

Co-authored-by: Pieter Robberechts <[email protected]>
  • Loading branch information
AndrewRook and probberechts authored Dec 26, 2024
1 parent 2131739 commit 9cee426
Show file tree
Hide file tree
Showing 5 changed files with 88 additions and 67 deletions.
42 changes: 22 additions & 20 deletions kloppy/domain/models/common.py
Original file line number Diff line number Diff line change
Expand Up @@ -5,22 +5,21 @@
from datetime import datetime, timedelta
from enum import Enum, Flag
from typing import (
Any,
Callable,
Dict,
Generic,
Iterable,
List,
NewType,
Optional,
Callable,
Union,
Any,
TypeVar,
Generic,
NewType,
Union,
overload,
Iterable,
)

from .position import PositionType

from ...utils import deprecated, snake_case
from .position import PositionType

if sys.version_info >= (3, 8):
from typing import Literal
Expand All @@ -32,23 +31,23 @@
else:
from typing_extensions import Self

from ...exceptions import (
InvalidFilterError,
KloppyParameterError,
OrientationError,
)
from .formation import FormationType
from .pitch import (
PitchDimensions,
Unit,
Dimension,
NormalizedPitchDimensions,
MetricPitchDimensions,
ImperialPitchDimensions,
MetricPitchDimensions,
NormalizedPitchDimensions,
OptaPitchDimensions,
PitchDimensions,
Unit,
WyscoutPitchDimensions,
)
from .formation import FormationType
from .time import Time, Period, TimeContainer
from ...exceptions import (
OrientationError,
InvalidFilterError,
KloppyParameterError,
)
from .time import Period, Time, TimeContainer


@dataclass
Expand Down Expand Up @@ -264,7 +263,10 @@ def get_player_by_jersey_number(self, jersey_no: int):
def get_player_by_position(self, position: PositionType, time: Time):
for player in self.players:
if player.positions.items:
player_position = player.positions.value_at(time)
try:
player_position = player.positions.value_at(time)
except KeyError: # player that is subbed in later
continue
if player_position and player_position == position:
return player

Expand Down
14 changes: 7 additions & 7 deletions kloppy/domain/models/time.py
Original file line number Diff line number Diff line change
@@ -1,15 +1,15 @@
from dataclasses import dataclass, field
from datetime import timedelta, datetime
from datetime import datetime, timedelta
from typing import (
overload,
Union,
Optional,
TypeVar,
Generic,
List,
Tuple,
NamedTuple,
Literal,
NamedTuple,
Optional,
Tuple,
TypeVar,
Union,
overload,
)

from sortedcontainers import SortedDict
Expand Down
64 changes: 29 additions & 35 deletions kloppy/infra/serializers/event/statsbomb/deserializer.py
Original file line number Diff line number Diff line change
@@ -1,7 +1,7 @@
from typing import NamedTuple, IO, Optional
import logging
import json
import logging
from itertools import zip_longest
from typing import IO, NamedTuple, Optional

from kloppy.domain import (
DatasetFlag,
Expand All @@ -18,6 +18,7 @@
from kloppy.exceptions import DeserializationError
from kloppy.infra.serializers.event.deserializer import EventDataDeserializer
from kloppy.utils import performance_logging

from . import specification as SB
from .helpers import parse_freeze_frame, parse_str_ts
from .specification import position_types_mapping
Expand Down Expand Up @@ -78,36 +79,6 @@ def deserialize(
if self.should_include_event(event):
# Transform event to the coordinate system
event = self.transformer.transform_event(event)

# Add freeze_frame information
if "freeze_frame" in event.raw_event.get("shot", {}):
event.freeze_frame = self.transformer.transform_frame(
parse_freeze_frame(
freeze_frame=event.raw_event["shot"][
"freeze_frame"
],
home_team=teams[0],
away_team=teams[1],
event=event,
fidelity_version=data_version.shot_fidelity_version,
)
)

if (
not event.freeze_frame
and event.event_id in three_sixty_data
):
freeze_frame = three_sixty_data[event.event_id]
event.freeze_frame = self.transformer.transform_frame(
parse_freeze_frame(
freeze_frame=freeze_frame["freeze_frame"],
home_team=teams[0],
away_team=teams[1],
event=event,
fidelity_version=data_version.xy_fidelity_version,
visible_area=freeze_frame["visible_area"],
)
)
events.append(event)

metadata = Metadata(
Expand All @@ -120,9 +91,33 @@ def deserialize(
score=None,
provider=Provider.STATSBOMB,
coordinate_system=self.transformer.get_to_coordinate_system(),
**additional_metadata
**additional_metadata,
)
return EventDataset(metadata=metadata, records=events)
dataset = EventDataset(metadata=metadata, records=events)
for event in dataset:
if "freeze_frame" in event.raw_event.get("shot", {}):
event.freeze_frame = self.transformer.transform_frame(
parse_freeze_frame(
freeze_frame=event.raw_event["shot"]["freeze_frame"],
home_team=teams[0],
away_team=teams[1],
event=event,
fidelity_version=data_version.shot_fidelity_version,
)
)
if not event.freeze_frame and event.event_id in three_sixty_data:
freeze_frame = three_sixty_data[event.event_id]
event.freeze_frame = self.transformer.transform_frame(
parse_freeze_frame(
freeze_frame=freeze_frame["freeze_frame"],
home_team=teams[0],
away_team=teams[1],
event=event,
fidelity_version=data_version.xy_fidelity_version,
visible_area=freeze_frame["visible_area"],
)
)
return dataset

def load_data(self, inputs: StatsBombInputs):
raw_events = {}
Expand Down Expand Up @@ -185,7 +180,6 @@ def create_teams_and_players(self, raw_events, lineups):
for raw_event in starting_xi_events
for player in raw_event["tactics"]["lineup"]
}

starting_formations = {
raw_event["team"]["id"]: FormationType(
"-".join(list(str(raw_event["tactics"]["formation"])))
Expand Down
10 changes: 5 additions & 5 deletions kloppy/infra/serializers/event/statsbomb/helpers.py
Original file line number Diff line number Diff line change
@@ -1,18 +1,18 @@
from datetime import timedelta
from typing import List, Dict, Optional
from typing import Dict, List, Optional

from kloppy.domain import (
Point,
Point3D,
Team,
ActionValue,
Event,
Frame,
Period,
Player,
PlayerData,
Point,
Point3D,
PositionType,
ActionValue,
Provider,
Team,
)
from kloppy.domain.services.frame_factory import create_frame
from kloppy.exceptions import DeserializationError
Expand Down
25 changes: 25 additions & 0 deletions kloppy/tests/test_statsbomb.py
Original file line number Diff line number Diff line change
Expand Up @@ -176,6 +176,12 @@ def test_player_position(self, dataset):
)
assert home_ending_lam.player_id == "5633" # Yannick Ferreira Carrasco

away_starting_gk = dataset.metadata.teams[1].get_player_by_position(
PositionType.Goalkeeper,
time=Time(period=period_1, timestamp=timedelta(seconds=92)),
)
assert away_starting_gk.player_id == "5205" # Rui Patricio

def test_periods(self, dataset):
"""It should create the periods"""
assert len(dataset.metadata.periods) == 2
Expand Down Expand Up @@ -542,6 +548,25 @@ def get_color(player):
base_dir / "outputs" / "test_statsbomb_freeze_frame_360.png"
)

def test_freeze_frame_player_identities(self, dataset: EventDataset):
    """It should set the identities of the player that executed the event and the goalkeepers.

    Every other player in the freeze frame is expected to be an anonymous
    placeholder: no name, and an id generated from the team and the event.
    """
    event_id = "0f525aa9-70f4-4f85-8a8d-6103722aee50"
    event = dataset.get_event_by_id(event_id)
    home_team, away_team = dataset.metadata.teams
    frame_players = event.freeze_frame.players_coordinates

    # The goalkeeper should be identified
    keeper = next(p for p in away_team.players if p.player_id == "5205")
    assert keeper in frame_players
    # The player that executed the event should be identified
    executor = next(p for p in away_team.players if p.player_id == "5209")
    assert executor in frame_players

    # All other players should be anonymous.
    # The loop variable deliberately has its own name: reusing the name of
    # the executing player rebinds it on every iteration, which makes the
    # "is this one of the identified players?" check always true and
    # silently skips every assertion below.
    # NOTE(review): the docstring mentions "goalkeepers" (plural). If the
    # opposing goalkeeper is also visible in this frame it may need to be
    # added to the identified players — confirm against the fixture.
    for other in frame_players:
        if other in (keeper, executor):
            continue
        assert other.team in [home_team, away_team]
        # Anonymous players may belong to either team, so the team part of
        # the generated id is not pinned to a single team here.
        # NOTE(review): id format "T<team id>-E<event id>-<index>" is taken
        # from the literal previously hard-coded for team 780 — confirm
        # against parse_freeze_frame.
        assert other.player_id.startswith("T")
        assert f"-E{event_id}-" in other.player_id
        assert other.name is None

def test_correct_normalized_deserialization(self):
"""Test if the normalized deserialization is correct"""
dataset = statsbomb.load(
Expand Down

0 comments on commit 9cee426

Please sign in to comment.