Skip to content

Commit

Permalink
Merge pull request #16 from bdagnino/fix-epts-loading
Browse files Browse the repository at this point in the history
Fix EPTS serialization issues with period setting and players not pre…
  • Loading branch information
koenvo authored Jun 13, 2020
2 parents 1fca251 + 36c4214 commit 3d249f9
Show file tree
Hide file tree
Showing 6 changed files with 51 additions and 18 deletions.
1 change: 1 addition & 0 deletions .gitignore
Original file line number Diff line number Diff line change
@@ -1,2 +1,3 @@
.idea/
__pycache__/
.vscode
2 changes: 1 addition & 1 deletion kloppy/helpers.py
Original file line number Diff line number Diff line change
Expand Up @@ -83,7 +83,7 @@ def transform(dataset: DatasetType, to_orientation=None, to_pitch_dimensions=Non

def _frame_to_pandas_row_converter(frame: Frame) -> Dict:
row = dict(
period_id=frame.period.id,
period_id=frame.period.id if frame.period else None,
timestamp=frame.timestamp,
ball_state=frame.ball_state.value if frame.ball_state else None,
ball_owning_team=frame.ball_owning_team.value if frame.ball_owning_team else None,
Expand Down
10 changes: 5 additions & 5 deletions kloppy/infra/serializers/tracking/epts/reader.py
Original file line number Diff line number Diff line change
Expand Up @@ -58,7 +58,6 @@ def _set_current_data_spec(idx):
_set_current_data_spec(0)

periods = meta_data.periods
period_idx = 0
n = 0
sample = 1. / sample_rate

Expand All @@ -79,10 +78,11 @@ def _set_current_data_spec(idx):
row['frame_id'] = frame_id
row['timestamp'] = timestamp

if period_idx > len(periods):
if timestamp > periods[period_idx].end_timestamp:
period_idx += 1
row['period_id'] = periods[period_idx].id
row['period_id'] = None
for period in periods:
if period.start_timestamp <= timestamp <= period.end_timestamp:
row['period_id'] = period.id
break

yield row

Expand Down
22 changes: 12 additions & 10 deletions kloppy/infra/serializers/tracking/epts/serializer.py
Original file line number Diff line number Diff line change
Expand Up @@ -31,7 +31,7 @@ def __validate_inputs(inputs: Dict[str, Readable]):
@staticmethod
def _frame_from_row(row: dict, meta_data: EPTSMetaData) -> Frame:
timestamp = row['timestamp']
if meta_data.periods:
if meta_data.periods and row['period_id']:
# might want to search for it instead
period = meta_data.periods[row['period_id'] - 1]
else:
Expand All @@ -41,15 +41,17 @@ def _frame_from_row(row: dict, meta_data: EPTSMetaData) -> Frame:
away_team_player_positions = {}
for player in meta_data.players:
if player.team == Team.HOME:
home_team_player_positions[player.jersey_no] = Point(
x=row[f'player_home_{player.jersey_no}_x'],
y=row[f'player_home_{player.jersey_no}_y']
)
if f'player_home_{player.jersey_no}_x' in row:
home_team_player_positions[player.jersey_no] = Point(
x=row[f'player_home_{player.jersey_no}_x'],
y=row[f'player_home_{player.jersey_no}_y']
)
elif player.team == Team.AWAY:
home_team_player_positions[player.jersey_no] = Point(
x=row[f'player_away_{player.jersey_no}_x'],
y=row[f'player_away_{player.jersey_no}_y']
)
if f'player_away_{player.jersey_no}_x' in row:
away_team_player_positions[player.jersey_no] = Point(
x=row[f'player_away_{player.jersey_no}_x'],
y=row[f'player_away_{player.jersey_no}_y']
)

return Frame(
frame_id=row['frame_id'],
Expand Down Expand Up @@ -140,7 +142,7 @@ def deserialize(self, inputs: Dict[str, Readable], options: Dict = None) -> Trac
Orientation.FIXED_HOME_AWAY
if start_attacking_direction == AttackingDirection.HOME_AWAY else
Orientation.FIXED_AWAY_HOME
) if start_attacking_direction else None
) if start_attacking_direction != AttackingDirection.NOT_SET else None

return TrackingDataset(
flags=~(DatasetFlag.BALL_STATE | DatasetFlag.BALL_OWNING_TEAM),
Expand Down
30 changes: 30 additions & 0 deletions kloppy/tests/files/epts_meta.xml
Original file line number Diff line number Diff line change
Expand Up @@ -11,6 +11,16 @@
<TrackingType>GPS</TrackingType>
<ProviderName>Example</ProviderName>
<FrameRate>25</FrameRate>
<ProviderGlobalParameters>
<ProviderParameter>
<Name>first_half_start</Name>
<Value>100</Value>
</ProviderParameter>
<ProviderParameter>
<Name>first_half_end</Name>
<Value>14900</Value>
</ProviderParameter>
</ProviderGlobalParameters>
</GlobalConfig>

<Sessions>
Expand Down Expand Up @@ -69,6 +79,16 @@
</ProviderParameter>
</ProviderPlayerParameters>
</Player>
<Player id="4" teamId="1">
<Name>Juan Perez</Name>
<ShirtNumber>14</ShirtNumber>
<ProviderPlayerParameters>
<ProviderParameter>
<Name>position</Name>
<Value>Goalkeeper</Value>
</ProviderParameter>
</ProviderPlayerParameters>
</Player>
</Players>

<Devices>
Expand Down Expand Up @@ -175,6 +195,16 @@
<PlayerChannel id="player3_max_acceleration" playerId="3" channelId="max_acceleration"></PlayerChannel>
<PlayerChannel id="player3_heartbeat" playerId="3" channelId="heartbeat"></PlayerChannel>
<PlayerChannel id="player3_max_heartbeat" playerId="3" channelId="max_heartbeat"></PlayerChannel>
<PlayerChannel id="player4_x" playerId="4" channelId="x"></PlayerChannel>
<PlayerChannel id="player4_y" playerId="4" channelId="y"></PlayerChannel>
<PlayerChannel id="player4_z" playerId="4" channelId="z"></PlayerChannel>
<PlayerChannel id="player4_distance" playerId="4" channelId="distance"></PlayerChannel>
<PlayerChannel id="player4_avg_speed" playerId="4" channelId="avg_speed"></PlayerChannel>
<PlayerChannel id="player4_max_speed" playerId="4" channelId="max_speed"></PlayerChannel>
<PlayerChannel id="player4_acceleration" playerId="4" channelId="acceleration"></PlayerChannel>
<PlayerChannel id="player4_max_acceleration" playerId="4" channelId="max_acceleration"></PlayerChannel>
<PlayerChannel id="player4_heartbeat" playerId="4" channelId="heartbeat"></PlayerChannel>
<PlayerChannel id="player4_max_heartbeat" playerId="4" channelId="max_heartbeat"></PlayerChannel>
</PlayerChannels>
</Metadata>
<DataFormatSpecifications>
Expand Down
4 changes: 2 additions & 2 deletions kloppy/tests/test_epts.py
Original file line number Diff line number Diff line change
Expand Up @@ -82,8 +82,8 @@ def test_correct_deserialization(self):
)

assert len(dataset.records) == 2
assert len(dataset.periods) == 0
assert dataset.orientation == Orientation.FIXED_HOME_AWAY
assert len(dataset.periods) == 1
assert dataset.orientation is None

assert dataset.records[0].home_team_player_positions['22'] == Point(x=-769, y=-2013)
assert dataset.records[0].away_team_player_positions == {}
Expand Down

0 comments on commit 3d249f9

Please sign in to comment.