Skip to content

Commit

Permalink
Merge pull request #190 from probberechts/fix-transform-orientation
Browse files Browse the repository at this point in the history
Fix transform orientation
  • Loading branch information
koenvo authored Jan 31, 2024
2 parents bb2dbad + 2f0e5b3 commit 7b8c03e
Show file tree
Hide file tree
Showing 28 changed files with 489 additions and 354 deletions.
2 changes: 1 addition & 1 deletion examples/datasets/statsbomb.py
Original file line number Diff line number Diff line change
Expand Up @@ -25,7 +25,7 @@ def main():
with performance_logging("transform", logger=logger):
# convert to TRACAB coordinates
dataset = dataset.transform(
to_orientation="FIXED_HOME_AWAY",
to_orientation="STATIC_HOME_AWAY",
to_pitch_dimensions=[(-5500, 5500), (-3300, 3300)],
)

Expand Down
260 changes: 169 additions & 91 deletions kloppy/domain/models/common.py
Original file line number Diff line number Diff line change
Expand Up @@ -232,86 +232,179 @@ def __repr__(self):
return self.value


class AttackingDirection(Enum):
@dataclass
class Period:
"""
AttackingDirection
Period
Attributes:
HOME_AWAY (AttackingDirection): Home team is playing from left to right
AWAY_HOME (AttackingDirection): Home team is playing from right to left
NOT_SET (AttackingDirection): not set yet
id: `1` for first half, `2` for second half, `3` for first overtime,
`4` for second overtime, and `5` for penalty shootouts
start_timestamp: timestamp given by provider (can be unix timestamp or relative)
end_timestamp: timestamp given by provider (can be unix timestamp or relative)
"""

HOME_AWAY = "home-away"
AWAY_HOME = "away-home"
NOT_SET = "not-set"
id: int
start_timestamp: float
end_timestamp: float

def __repr__(self):
return self.value
def contains(self, timestamp: float):
return self.start_timestamp <= timestamp <= self.end_timestamp

@property
def duration(self):
return self.end_timestamp - self.start_timestamp

def __eq__(self, other):
return isinstance(other, Period) and other.id == self.id


class Orientation(Enum):
"""
The attacking direction of each team in a dataset.
Attributes:
BALL_OWNING_TEAM: The team that is currently in possession of the ball
plays from left to right.
ACTION_EXECUTING_TEAM: The team that executes the action
plays from left to right. Used in event stream data only. Equivalent
to "BALL_OWNING_TEAM" for tracking data.
HOME_AWAY: The home team plays from left to right in the first period.
The away team plays from left to right in the second period.
AWAY_HOME: The away team plays from left to right in the first period.
The home team plays from left to right in the second period.
STATIC_HOME_AWAY: The home team plays from left to right in both periods.
STATIC_AWAY_HOME: The away team plays from left to right in both periods.
NOT_SET: The attacking direction is not defined.
Notes:
The attacking direction is not defined for penalty shootouts in the
`HOME_AWAY`, `AWAY_HOME`, `STATIC_HOME_AWAY`, and `STATIC_AWAY_HOME`
orientations. This period is ignored in orientation transforms
involving one of these orientations and keeps its original
attacking direction.
"""

# change when possession changes
BALL_OWNING_TEAM = "ball-owning-team"

# depends on team which executed the action
ACTION_EXECUTING_TEAM = "action-executing-team"

# changes during half-time
HOME_TEAM = "home-team"
AWAY_TEAM = "away-team"
HOME_AWAY = "home-away"
AWAY_HOME = "away-home"

# won't change during match
FIXED_HOME_AWAY = "fixed-home-away"
FIXED_AWAY_HOME = "fixed-away-home"
STATIC_HOME_AWAY = "fixed-home-away"
STATIC_AWAY_HOME = "fixed-away-home"

# Not set in dataset
NOT_SET = "not-set"

def get_orientation_factor(
self,
attacking_direction: AttackingDirection,
ball_owning_team: Team,
action_executing_team: Team,
) -> int:
if self == Orientation.FIXED_HOME_AWAY:
return -1
elif self == Orientation.FIXED_AWAY_HOME:
return 1
elif self == Orientation.HOME_TEAM:
if attacking_direction == AttackingDirection.HOME_AWAY:
return -1
elif attacking_direction == AttackingDirection.AWAY_HOME:
return 1
else:
raise OrientationError("AttackingDirection not set")
elif self == Orientation.AWAY_TEAM:
if attacking_direction == AttackingDirection.AWAY_HOME:
return -1
elif attacking_direction == AttackingDirection.HOME_AWAY:
return 1
else:
raise OrientationError("AttackingDirection not set")
elif self == Orientation.BALL_OWNING_TEAM:
if ball_owning_team.ground == Ground.HOME:
return -1
elif ball_owning_team.ground == Ground.AWAY:
return 1
else:
def __repr__(self):
return self.value


class AttackingDirection(Enum):
"""
AttackingDirection
Attributes:
LTR (AttackingDirection): Home team is playing from left to right
RTL (AttackingDirection): Home team is playing from right to left
NOT_SET (AttackingDirection): not set yet
"""

LTR = "left-to-right"
RTL = "right-to-left"
NOT_SET = "not-set"

@staticmethod
def from_orientation(
orientation: Orientation,
period: Optional[Period] = None,
ball_owning_team: Optional[Team] = None,
action_executing_team: Optional[Team] = None,
) -> "AttackingDirection":
"""Determines the attacking direction for a specific data record.
Args:
orientation: The orientation of the dataset.
period: The period of the data record.
ball_owning_team: The team that is in possession of the ball.
action_executing_team: The team that executes the action.
Raises:
OrientationError: If the attacking direction cannot be determined
from the given data.
Returns:
The attacking direction for the given data record.
"""
if orientation == Orientation.STATIC_HOME_AWAY:
return AttackingDirection.LTR
if orientation == Orientation.STATIC_AWAY_HOME:
return AttackingDirection.RTL
if orientation == Orientation.HOME_AWAY:
if period is None:
raise OrientationError(
f"Invalid ball_owning_team: {ball_owning_team}"
"You must provide a period to determine the attacking direction"
)
elif self == Orientation.ACTION_EXECUTING_TEAM:
if action_executing_team.ground == Ground.HOME:
return -1
elif action_executing_team.ground == Ground.AWAY:
return 1
else:
dirmap = {
1: AttackingDirection.LTR,
2: AttackingDirection.RTL,
3: AttackingDirection.LTR,
4: AttackingDirection.RTL,
}
if period.id in dirmap:
return dirmap[period.id]
raise OrientationError(
"This orientation is not defined for period %s" % period.id
)
if orientation == Orientation.AWAY_HOME:
if period is None:
raise OrientationError(
f"Invalid action_executing_team: {action_executing_team}"
"You must provide a period to determine the attacking direction"
)
else:
raise OrientationError(f"Unknown orientation: {self}")
dirmap = {
1: AttackingDirection.RTL,
2: AttackingDirection.LTR,
3: AttackingDirection.RTL,
4: AttackingDirection.LTR,
}
if period.id in dirmap:
return dirmap[period.id]
raise OrientationError(
"This orientation is not defined for period %s" % period.id
)
if orientation == Orientation.BALL_OWNING_TEAM:
if ball_owning_team is None:
raise OrientationError(
"You must provide the ball owning team to determine the attacking direction"
)
if ball_owning_team is not None:
if ball_owning_team.ground == Ground.HOME:
return AttackingDirection.LTR
if ball_owning_team.ground == Ground.AWAY:
return AttackingDirection.RTL
raise OrientationError(
"Invalid ball_owning_team: %s", ball_owning_team
)
return AttackingDirection.NOT_SET
if orientation == Orientation.ACTION_EXECUTING_TEAM:
if action_executing_team is None:
raise ValueError(
"You must provide the action executing team to determine the attacking direction"
)
if action_executing_team.ground == Ground.HOME:
return AttackingDirection.LTR
if action_executing_team.ground == Ground.AWAY:
return AttackingDirection.RTL
raise OrientationError(
"Invalid action_executing_team: %s", action_executing_team
)
raise OrientationError("Unknown orientation: %s", orientation)

def __repr__(self):
return self.value
Expand All @@ -325,43 +418,6 @@ class VerticalOrientation(Enum):
BOTTOM_TO_TOP = "bottom-to-top"


@dataclass
class Period:
"""
Period
Attributes:
id: `1` for first half, `2` for second half
start_timestamp: timestamp given by provider (can be unix timestamp or relative)
end_timestamp: timestamp given by provider (can be unix timestamp or relative)
attacking_direction: See [`AttackingDirection`][kloppy.domain.models.common.AttackingDirection]
"""

id: int
start_timestamp: float
end_timestamp: float
attacking_direction: Optional[
AttackingDirection
] = AttackingDirection.NOT_SET

def contains(self, timestamp: float):
return self.start_timestamp <= timestamp <= self.end_timestamp

@property
def attacking_direction_set(self):
return self.attacking_direction != AttackingDirection.NOT_SET

def set_attacking_direction(self, attacking_direction: AttackingDirection):
self.attacking_direction = attacking_direction

@property
def duration(self):
return self.end_timestamp - self.start_timestamp

def __eq__(self, other):
return isinstance(other, Period) and other.id == self.id


class Origin(Enum):
"""
Attributes:
Expand Down Expand Up @@ -788,6 +844,23 @@ def set_refs(
self.prev_record = prev
self.next_record = next_

@property
def attacking_direction(self):
if (
self.dataset
and self.dataset.metadata
and self.dataset.metadata.orientation is not None
):
try:
return AttackingDirection.from_orientation(
self.dataset.metadata.orientation,
period=self.period,
ball_owning_team=self.ball_owning_team,
)
except OrientationError:
return AttackingDirection.NOT_SET
return AttackingDirection.NOT_SET

def matches(self, filter_) -> bool:
if filter_ is None:
return True
Expand Down Expand Up @@ -848,6 +921,11 @@ class Metadata:
frame_rate: Optional[float] = None
attributes: Optional[Dict] = field(default_factory=dict, compare=False)

def __post_init__(self):
if self.coordinate_system is not None:
# set the pitch dimensions from the coordinate system
self.pitch_dimensions = self.coordinate_system.pitch_dimensions


T = TypeVar("T", bound="DataRecord")

Expand Down
24 changes: 23 additions & 1 deletion kloppy/domain/models/event.py
Original file line number Diff line number Diff line change
Expand Up @@ -12,7 +12,11 @@
TYPE_CHECKING,
)

from kloppy.domain.models.common import DatasetType
from kloppy.domain.models.common import (
DatasetType,
AttackingDirection,
OrientationError,
)
from kloppy.utils import (
camelcase_to_snakecase,
removes_suffix,
Expand Down Expand Up @@ -530,6 +534,24 @@ def event_type(self) -> EventType:
def event_name(self) -> str:
raise NotImplementedError

@property
def attacking_direction(self):
if (
self.dataset
and self.dataset.metadata
and self.dataset.metadata.orientation is not None
):
try:
return AttackingDirection.from_orientation(
self.dataset.metadata.orientation,
period=self.period,
ball_owning_team=self.ball_owning_team,
action_executing_team=self.team,
)
except OrientationError:
return AttackingDirection.NOT_SET
return AttackingDirection.NOT_SET

def get_qualifier_value(self, qualifier_type: Type[Qualifier]):
"""
Returns the Qualifier of a certain type, or None if qualifier is not present.
Expand Down
6 changes: 3 additions & 3 deletions kloppy/domain/services/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -15,7 +15,7 @@ def avg(items: List[float]) -> float:


def attacking_direction_from_frame(frame: Frame) -> AttackingDirection:
"""This method should only be called for the first frame of a"""
"""This method should only be called for the first frame of a period."""
avg_x_home = avg(
[
player_data.coordinates.x
Expand All @@ -32,6 +32,6 @@ def attacking_direction_from_frame(frame: Frame) -> AttackingDirection:
)

if avg_x_home < avg_x_away:
return AttackingDirection.HOME_AWAY
return AttackingDirection.LTR
else:
return AttackingDirection.AWAY_HOME
return AttackingDirection.RTL
Loading

0 comments on commit 7b8c03e

Please sign in to comment.