Skip to content

Commit

Permalink
wyscout v3 - add substitutions to the event stream
Browse files Browse the repository at this point in the history
  • Loading branch information
DriesDeprest committed Nov 26, 2024
1 parent 80e16fb commit 11b1ed5
Show file tree
Hide file tree
Showing 2 changed files with 100 additions and 1 deletion.
75 changes: 74 additions & 1 deletion kloppy/infra/serializers/event/wyscout/deserializer_v3.py
Original file line number Diff line number Diff line change
Expand Up @@ -592,6 +592,75 @@ def _parse_period_id(raw_period: str) -> int:
return period_id


def insert_substitution_events(
deserializer, events, raw_events, players, teams, periods, transformer
):
# Step 1: Create substitution events
substitution_events = []
for team_id, periods_subs in raw_events["substitutions"].items():
for raw_period, sub_info in periods_subs.items():
for raw_seconds, players_info in sub_info.items():
subs_out = players_info["out"]
subs_in = players_info["in"]
for sub_out, sub_in in zip(subs_out, subs_in):
sub_out_player = players[team_id][str(sub_out["playerId"])]
sub_in_player = players[team_id][str(sub_in["playerId"])]

# Build the substitution event
sub_event = deserializer.event_factory.build_substitution(
event_id=f"substitution-{sub_out['playerId']}-{sub_in['playerId']}",
ball_owning_team=None,
ball_state=None,
coordinates=Point(x=0, y=0),
player=sub_out_player,
replacement_player=sub_in_player,
team=teams[team_id],
period=periods[int(raw_period[0]) - 1],
timestamp=timedelta(seconds=int(raw_seconds)),
result=None,
raw_event=None,
qualifiers=None,
)

if sub_event and deserializer.should_include_event(
sub_event
):
substitution_events.append(
transformer.transform_event(sub_event)
)

# Step 2: Sort substitution events globally by period and timestamp
substitution_events.sort(key=lambda e: (e.period.id, e.timestamp))

# Step 3: Merge events and substitutions in ascending order
merged_events = []
sub_index = 0
total_subs = len(substitution_events)

for event in events:
# Insert all substitution events that occur before or at the current event's timestamp
while sub_index < total_subs:
sub_event = substitution_events[sub_index]
if sub_event.period.id < event.period.id or (
sub_event.period.id == event.period.id
and sub_event.timestamp <= event.timestamp
):
merged_events.append(sub_event)
sub_index += 1
else:
break

# Add the current event to the merged list
merged_events.append(event)

# Step 4: Add any remaining substitution events
while sub_index < total_subs:
merged_events.append(substitution_events[sub_index])
sub_index += 1

return merged_events


class WyscoutDeserializerV3(EventDataDeserializer[WyscoutInputs]):
@property
def provider(self) -> Provider:
Expand Down Expand Up @@ -864,6 +933,10 @@ def deserialize(self, inputs: WyscoutInputs) -> EventDataset:
if event and self.should_include_event(event):
events.append(transformer.transform_event(event))

all_events = insert_substitution_events(
self, events, raw_events, players, teams, periods, transformer
)

metadata = Metadata(
teams=[home_team, away_team],
periods=periods,
Expand All @@ -881,4 +954,4 @@ def deserialize(self, inputs: WyscoutInputs) -> EventDataset:
away_coach=away_coach,
)

return EventDataset(metadata=metadata, records=events)
return EventDataset(metadata=metadata, records=all_events)
26 changes: 26 additions & 0 deletions kloppy/tests/test_wyscout.py
Original file line number Diff line number Diff line change
Expand Up @@ -319,3 +319,29 @@ def test_carry_event(self, dataset: EventDataset):
carry_event = dataset.get_event_by_id(1927028490)
assert carry_event.event_type == EventType.CARRY
assert carry_event.end_coordinates == Point(17.0, 4.0)

def test_sub_event(self, dataset: EventDataset):
second_period = dataset.metadata.periods[1]

sub_events = [
event
for event in dataset.events
if event.event_type == EventType.SUBSTITUTION
]
assert len(sub_events) == 9

first_sub_event = sub_events[0]
assert first_sub_event.time == Time(
period=second_period, timestamp=timedelta(seconds=4)
)
assert first_sub_event.team.team_id == "3164"
assert first_sub_event.player.player_id == "415809"
assert first_sub_event.replacement_player.player_id == "703"

last_sub_event = sub_events[-1]
assert last_sub_event.time == Time(
period=second_period, timestamp=timedelta(seconds=2192)
)
assert last_sub_event.team.team_id == "3159"
assert last_sub_event.player.player_id == "20461"
assert last_sub_event.replacement_player.player_id == "345695"

0 comments on commit 11b1ed5

Please sign in to comment.