Skip to content

Commit

Permalink
Fix transform orientation
Browse files Browse the repository at this point in the history
Fixes #175
  • Loading branch information
probberechts committed May 15, 2023
1 parent 64d9f1e commit b50bbe9
Show file tree
Hide file tree
Showing 3 changed files with 309 additions and 94 deletions.
169 changes: 102 additions & 67 deletions kloppy/domain/models/common.py
Original file line number Diff line number Diff line change
Expand Up @@ -243,6 +243,42 @@ def __repr__(self):
return self.value


@dataclass
class Period:
"""
Period
Attributes:
id: `1` for first half, `2` for second half, `3` for first overtime,
`4` for second overtime, and `5` for penalty shootouts
start_timestamp: timestamp given by provider (can be unix timestamp or relative)
end_timestamp: timestamp given by provider (can be unix timestamp or relative)
attacking_direction: See [`AttackingDirection`][kloppy.domain.models.common.AttackingDirection]
"""

id: int
start_timestamp: float
end_timestamp: float
attacking_direction: AttackingDirection = AttackingDirection.NOT_SET

def contains(self, timestamp: float):
return self.start_timestamp <= timestamp <= self.end_timestamp

@property
def attacking_direction_set(self):
return self.attacking_direction != AttackingDirection.NOT_SET

def set_attacking_direction(self, attacking_direction: AttackingDirection):
self.attacking_direction = attacking_direction

@property
def duration(self):
return self.end_timestamp - self.start_timestamp

def __eq__(self, other):
return isinstance(other, Period) and other.id == self.id


class Orientation(Enum):
# change when possession changes
BALL_OWNING_TEAM = "ball-owning-team"
Expand All @@ -261,50 +297,74 @@ class Orientation(Enum):
# Not set in dataset
NOT_SET = "not-set"

def get_attacking_direction(self, period: Period) -> AttackingDirection:
if self == Orientation.FIXED_HOME_AWAY:
return AttackingDirection.HOME_AWAY
if self == Orientation.FIXED_AWAY_HOME:
return AttackingDirection.AWAY_HOME
if self == Orientation.HOME_TEAM:
dirmap = {
1: AttackingDirection.HOME_AWAY,
2: AttackingDirection.AWAY_HOME,
3: AttackingDirection.HOME_AWAY,
4: AttackingDirection.AWAY_HOME,
}
return dirmap.get(period.id, period.attacking_direction)
if self == Orientation.AWAY_TEAM:
dirmap = {
1: AttackingDirection.AWAY_HOME,
2: AttackingDirection.HOME_AWAY,
3: AttackingDirection.AWAY_HOME,
4: AttackingDirection.HOME_AWAY,
}
return dirmap.get(period.id, period.attacking_direction)
return AttackingDirection.NOT_SET

def get_orientation_factor(
self,
attacking_direction: AttackingDirection,
period: Period,
ball_owning_team: Team,
action_executing_team: Team,
) -> int:
if period.id == 5:
return 1 # the orientation of penalty shootouts is not transformed
if self == Orientation.FIXED_HOME_AWAY:
return -1
elif self == Orientation.FIXED_AWAY_HOME:
return 1
elif self == Orientation.HOME_TEAM:
if attacking_direction == AttackingDirection.HOME_AWAY:
return -1
elif attacking_direction == AttackingDirection.AWAY_HOME:
if self == Orientation.FIXED_AWAY_HOME:
return -1
if self == Orientation.HOME_TEAM:
if period.id == 1 or period.id == 3:
return 1
else:
raise OrientationError("AttackingDirection not set")
elif self == Orientation.AWAY_TEAM:
if attacking_direction == AttackingDirection.AWAY_HOME:
if period.id == 2 or period.id == 4:
return -1
elif attacking_direction == AttackingDirection.HOME_AWAY:
raise OrientationError(
f"AttackingDirection not defined for period with id {period.id}"
)
if self == Orientation.AWAY_TEAM:
if period.id == 1 or period.id == 3:
return -1
if period.id == 2 or period.id == 4:
return 1
else:
raise OrientationError("AttackingDirection not set")
elif self == Orientation.BALL_OWNING_TEAM:
raise OrientationError(
f"AttackingDirection not defined for period with id {period.id}"
)
if self == Orientation.BALL_OWNING_TEAM:
if ball_owning_team.ground == Ground.HOME:
return -1
elif ball_owning_team.ground == Ground.AWAY:
return 1
else:
raise OrientationError(
f"Invalid ball_owning_team: {ball_owning_team}"
)
elif self == Orientation.ACTION_EXECUTING_TEAM:
if action_executing_team.ground == Ground.HOME:
if ball_owning_team.ground == Ground.AWAY:
return -1
elif action_executing_team.ground == Ground.AWAY:
raise OrientationError(
f"Invalid ball_owning_team: {ball_owning_team}"
)
if self == Orientation.ACTION_EXECUTING_TEAM:
if action_executing_team.ground == Ground.HOME:
return 1
else:
raise OrientationError(
f"Invalid action_executing_team: {action_executing_team}"
)
else:
raise OrientationError(f"Unknown orientation: {self}")
if action_executing_team.ground == Ground.AWAY:
return -1
raise OrientationError(
f"Invalid action_executing_team: {action_executing_team}"
)
raise OrientationError(f"Unknown orientation: {self}")

def __repr__(self):
return self.value
Expand All @@ -318,43 +378,6 @@ class VerticalOrientation(Enum):
BOTTOM_TO_TOP = "bottom-to-top"


@dataclass
class Period:
"""
Period
Attributes:
id: `1` for first half, `2` for second half
start_timestamp: timestamp given by provider (can be unix timestamp or relative)
end_timestamp: timestamp given by provider (can be unix timestamp or relative)
attacking_direction: See [`AttackingDirection`][kloppy.domain.models.common.AttackingDirection]
"""

id: int
start_timestamp: float
end_timestamp: float
attacking_direction: Optional[
AttackingDirection
] = AttackingDirection.NOT_SET

def contains(self, timestamp: float):
return self.start_timestamp <= timestamp <= self.end_timestamp

@property
def attacking_direction_set(self):
return self.attacking_direction != AttackingDirection.NOT_SET

def set_attacking_direction(self, attacking_direction: AttackingDirection):
self.attacking_direction = attacking_direction

@property
def duration(self):
return self.end_timestamp - self.start_timestamp

def __eq__(self, other):
return isinstance(other, Period) and other.id == self.id


class Origin(Enum):
"""
Attributes:
Expand Down Expand Up @@ -764,6 +787,18 @@ class Metadata:
coordinate_system: CoordinateSystem
attributes: Optional[Dict] = field(default_factory=dict, compare=False)

def __post_init__(self):
if self.coordinate_system is not None:
# set the pitch dimensions from the coordinate system
self.pitch_dimensions = self.coordinate_system.pitch_dimensions

if self.orientation is not None:
# set the attacking directions from the orientation
for period in self.periods:
period.attacking_direction = (
self.orientation.get_attacking_direction(period)
)


class DatasetType(Enum):
"""
Expand Down
59 changes: 40 additions & 19 deletions kloppy/domain/services/transformers/dataset.py
Original file line number Diff line number Diff line change
Expand Up @@ -6,11 +6,13 @@
AttackingDirection,
Dataset,
DatasetFlag,
DataRecord,
EventDataset,
Frame,
Metadata,
Orientation,
PitchDimensions,
Period,
Point,
Point3D,
Team,
Expand Down Expand Up @@ -73,6 +75,10 @@ def _needs_pitch_dimensions_change(self):
if self._from_pitch_dimensions and self._to_pitch_dimensions:
return self._from_pitch_dimensions != self._to_pitch_dimensions

@property
def _needs_orientation_change(self):
return self._from_orientation != self._to_orientation

def change_point_dimensions(self, point: Union[Point, Point3D]) -> Point:

if point is None:
Expand Down Expand Up @@ -111,7 +117,7 @@ def flip_point(self, point: Union[Point, Point3D]):
def __needs_flip(
self,
ball_owning_team: Team,
attacking_direction: AttackingDirection,
period: Period,
action_executing_team: Team = None,
) -> bool:
if self._from_orientation == self._to_orientation:
Expand All @@ -123,35 +129,38 @@ def __needs_flip(
orientation_factor_from = (
self._from_orientation.get_orientation_factor(
ball_owning_team=ball_owning_team,
attacking_direction=attacking_direction,
period=period,
action_executing_team=action_executing_team,
)
)
orientation_factor_to = (
self._to_orientation.get_orientation_factor(
ball_owning_team=ball_owning_team,
attacking_direction=attacking_direction,
period=period,
action_executing_team=action_executing_team,
)
)
flip = orientation_factor_from != orientation_factor_to
return flip

def transform_frame(self, frame: Frame) -> Frame:

# Change coordinate system
if self._needs_coordinate_system_change:
frame = self.__change_frame_coordinate_system(frame)

# Change dimensions
elif self._needs_pitch_dimensions_change:
frame = self.__change_frame_dimensions(frame)

# Flip frame based on orientation
if self.__needs_flip(
ball_owning_team=frame.ball_owning_team,
attacking_direction=frame.period.attacking_direction,
):
frame = self.__flip_frame(frame)
if self._needs_orientation_change:
if self.__needs_flip(
ball_owning_team=frame.ball_owning_team,
period=frame.period,
):
frame = self.__flip_frame(frame)

frame = self.__change_attacking_direction(frame)

return frame

Expand Down Expand Up @@ -236,6 +245,15 @@ def __change_point_coordinate_system(self, point: Union[Point, Point3D]):
else:
return Point(x=x, y=y)

def __change_attacking_direction(self, record: DataRecord):
new_attacking_direction = self._to_orientation.get_attacking_direction(
record.period
)
period = replace(
record.period, attacking_direction=new_attacking_direction
)
return replace(record, period=period)

def __flip_frame(self, frame: Frame):

players_data = {}
Expand All @@ -261,24 +279,27 @@ def __flip_frame(self, frame: Frame):
)

def transform_event(self, event: Event) -> Event:

# Change coordinate system
if self._needs_coordinate_system_change:
event = self.__change_event_coordinate_system(event)

# Change dimensions
elif self._needs_pitch_dimensions_change:
event = self.__change_event_dimensions(event)

# Flip event based on orientation
if self.__needs_flip(
ball_owning_team=event.ball_owning_team,
attacking_direction=event.period.attacking_direction,
action_executing_team=event.team,
):
event = self.__flip_event(event)

if event.freeze_frame:
event.freeze_frame = self.transform_frame(event.freeze_frame)
if self._needs_orientation_change:
if self.__needs_flip(
ball_owning_team=event.ball_owning_team,
period=event.period,
action_executing_team=event.team,
):
event = self.__flip_event(event)

if event.freeze_frame:
event.freeze_frame = self.transform_frame(event.freeze_frame)

event = self.__change_attacking_direction(event)

return event

Expand Down
Loading

0 comments on commit b50bbe9

Please sign in to comment.