Skip to content

Commit

Permalink
refactor insertion of subs
Browse files Browse the repository at this point in the history
  • Loading branch information
DriesDeprest committed Dec 19, 2024
1 parent 1a9787c commit 268d08f
Show file tree
Hide file tree
Showing 2 changed files with 25 additions and 40 deletions.
51 changes: 15 additions & 36 deletions kloppy/infra/serializers/event/wyscout/deserializer_v3.py
Original file line number Diff line number Diff line change
@@ -1,3 +1,4 @@
import bisect
import json
import logging
from dataclasses import replace
Expand Down Expand Up @@ -38,6 +39,7 @@
ShotResult,
TakeOnResult,
Team,
BallState,
)
from kloppy.exceptions import DeserializationError
from kloppy.utils import performance_logging
Expand Down Expand Up @@ -672,10 +674,9 @@ def _parse_period_id(raw_period: str) -> int:
return period_id


def insert_substitution_events(
deserializer, events, raw_events, players, teams, periods, transformer
def _parse_substitutions(
deserializer, raw_events, players, teams, periods, transformer
):
# Step 1: Create substitution events
substitution_events = []
for team_id, periods_subs in raw_events["substitutions"].items():
for raw_period, sub_info in periods_subs.items():
Expand All @@ -690,8 +691,8 @@ def insert_substitution_events(
sub_event = deserializer.event_factory.build_substitution(
event_id=f"substitution-{sub_out['playerId']}-{sub_in['playerId']}",
ball_owning_team=None,
ball_state=None,
coordinates=Point(x=0, y=0),
ball_state=BallState.DEAD,
coordinates=None,
player=sub_out_player,
replacement_player=sub_in_player,
team=teams[team_id],
Expand All @@ -709,36 +710,12 @@ def insert_substitution_events(
transformer.transform_event(sub_event)
)

# Step 2: Sort substitution events globally by period and timestamp
substitution_events.sort(key=lambda e: (e.period.id, e.timestamp))

# Step 3: Merge events and substitutions in ascending order
merged_events = []
sub_index = 0
total_subs = len(substitution_events)

for event in events:
# Insert all substitution events that occur before or at the current event's timestamp
while sub_index < total_subs:
sub_event = substitution_events[sub_index]
if sub_event.period.id < event.period.id or (
sub_event.period.id == event.period.id
and sub_event.timestamp <= event.timestamp
):
merged_events.append(sub_event)
sub_index += 1
else:
break

# Add the current event to the merged list
merged_events.append(event)
return substitution_events

# Step 4: Add any remaining substitution events
while sub_index < total_subs:
merged_events.append(substitution_events[sub_index])
sub_index += 1

return merged_events
def insert(event, sorted_events):
    """Insert ``event`` into ``sorted_events`` in place, keeping time order.

    ``sorted_events`` is expected to already be ordered ascending by each
    element's ``time`` attribute. The new event goes at the leftmost
    position that preserves that order, so it lands *before* any existing
    events with an equal ``time``.

    The list is mutated; nothing is returned.
    """
    # bisect needs a plain sequence of comparable keys, so project the
    # events onto their ``time`` values first.
    timeline = [existing.time for existing in sorted_events]
    index = bisect.bisect_left(timeline, event.time)
    sorted_events.insert(index, event)


class WyscoutDeserializerV3(EventDataDeserializer[WyscoutInputs]):
Expand Down Expand Up @@ -1019,9 +996,11 @@ def deserialize(self, inputs: WyscoutInputs) -> EventDataset:
if event and self.should_include_event(event):
events.append(transformer.transform_event(event))

all_events = insert_substitution_events(
self, events, raw_events, players, teams, periods, transformer
substitution_events = _parse_substitutions(
self, raw_events, players, teams, periods, transformer
)
for sub_event in substitution_events:
insert(sub_event, events)

metadata = Metadata(
teams=[home_team, away_team],
Expand All @@ -1040,4 +1019,4 @@ def deserialize(self, inputs: WyscoutInputs) -> EventDataset:
away_coach=away_coach,
)

return EventDataset(metadata=metadata, records=all_events)
return EventDataset(metadata=metadata, records=events)
14 changes: 10 additions & 4 deletions kloppy/tests/test_wyscout.py
Original file line number Diff line number Diff line change
Expand Up @@ -344,18 +344,24 @@ def test_sub_event(self, dataset: EventDataset):
]
assert len(sub_events) == 9

assert all(
sub_events[i].time < sub_events[i + 1].time
or sub_events[i].time == sub_events[i + 1].time
for i in range(len(sub_events) - 1)
), "Substitution events are not in ascending order by time"

first_sub_event = sub_events[0]
assert first_sub_event.time == Time(
period=second_period, timestamp=timedelta(seconds=4)
)
assert first_sub_event.team.team_id == "3164"
assert first_sub_event.player.player_id == "415809"
assert first_sub_event.replacement_player.player_id == "703"
assert first_sub_event.player.player_id == "21006"
assert first_sub_event.replacement_player.player_id == "20689"

last_sub_event = sub_events[-1]
assert last_sub_event.time == Time(
period=second_period, timestamp=timedelta(seconds=2192)
)
assert last_sub_event.team.team_id == "3159"
assert last_sub_event.player.player_id == "20461"
assert last_sub_event.replacement_player.player_id == "345695"
assert last_sub_event.player.player_id == "472363"
assert last_sub_event.replacement_player.player_id == "105334"

0 comments on commit 268d08f

Please sign in to comment.