Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Deduce synthetic carry events for event providers which don't annotate them #396

Open
wants to merge 21 commits into
base: master
Choose a base branch
from
Open
Show file tree
Hide file tree
Changes from 10 commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
11 changes: 11 additions & 0 deletions kloppy/domain/models/event.py
Original file line number Diff line number Diff line change
Expand Up @@ -1186,6 +1186,17 @@ def aggregate(self, type_: str, **aggregator_kwargs) -> List[Any]:

return aggregator.aggregate(self)

def add_deduced_event(self, event_type_: EventType):
    """Add deduced (synthetic) events of the given type to this dataset.

    Only ``EventType.CARRY`` is currently supported; the work is delegated
    to ``CarryDeducer.deduce``, which receives this dataset.

    Args:
        event_type_: The type of event to deduce.

    Raises:
        KloppyError: If no deducer exists for ``event_type_``.
    """
    if not event_type_ == EventType.CARRY:
        raise KloppyError(f"Not possible to deduce {event_type_}")

    # Imported at call time rather than at module level.
    # NOTE(review): presumably this avoids a circular import between the
    # domain models and the deducer service - confirm.
    from kloppy.domain.services.event_deducers.carry import CarryDeducer

    CarryDeducer().deduce(self)


__all__ = [
"EnumQualifier",
Expand Down
Empty file.
142 changes: 142 additions & 0 deletions kloppy/domain/services/event_deducers/carry.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,142 @@
import bisect
import uuid
from datetime import timedelta

from kloppy.domain import (
EventDataset,
EventType,
BodyPart,
CarryResult,
GenericEvent,
EventFactory,
Unit,
)
from kloppy.domain.services.event_deducers.event_deducer import (
EventDatasetDeduducer,
)


class CarryDeducer(EventDatasetDeduducer):
    """Deduce synthetic carry events between consecutive on-ball events.

    For every event whose type is in ``valid_event_types``, the next event
    that is not of a type in ``skipped_event_types`` is looked up. A carry is
    created from the end of the first event to the start of that next event
    when both belong to the same team and period, the next event is itself a
    valid on-ball event (and not a headed shot), and the distance and time
    between them fall within the configured limits.
    """

    # Shortest distance (in meters) for which a carry is created.
    min_carry_length_meters = 3
    # Longest distance (in meters) for which a carry is created.
    max_carry_length_meters = 60
    # Longest allowed gap between the timestamps of the two events.
    max_carry_duration = timedelta(seconds=10)
    event_factory = EventFactory()

    # Event types that may precede or follow a carry.
    valid_event_types = (
        EventType.PASS,
        EventType.SHOT,
        EventType.TAKE_ON,
        EventType.CLEARANCE,
        EventType.INTERCEPTION,
        EventType.DUEL,
        EventType.RECOVERY,
        EventType.MISCONTROL,
        EventType.GOALKEEPER,
    )
    # Event types that are stepped over when looking for the next event.
    skipped_event_types = (
        EventType.GENERIC,
        EventType.PRESSURE,
    )

    def deduce(self, dataset: EventDataset) -> EventDataset:
        """Insert deduced carry events into ``dataset`` in place.

        Args:
            dataset: The event dataset to extend. Its ``records`` list is
                mutated; new carries are inserted at the position given by
                their ``time``.

        Returns:
            The same ``dataset`` instance that was passed in.
        """
        pitch = dataset.metadata.pitch_dimensions
        events = dataset.events

        new_carries = []
        for idx, event in enumerate(events):
            if event.event_type not in self.valid_event_types:
                continue
            next_event = self._next_relevant_event(events, idx + 1)
            if next_event is None:
                continue
            new_carry = self._build_carry(pitch, event, next_event)
            if new_carry is not None:
                new_carries.append(new_carry)

        # Build the sorted key list once and keep it in step with the records
        # instead of rebuilding it for every inserted carry.
        # NOTE(review): assumes `dataset.events` exposes the same records
        # that `dataset.records` holds - confirm.
        times = [existing.time for existing in events]
        for new_carry in new_carries:
            pos = bisect.bisect_left(times, new_carry.time)
            times.insert(pos, new_carry.time)
            dataset.records.insert(pos, new_carry)

        return dataset

    def _next_relevant_event(self, events, start):
        """Return the first event at or after ``start`` that is not skipped.

        Returns ``None`` when only skipped event types (or nothing) remain.
        """
        for position in range(start, len(events)):
            candidate = events[position]
            if candidate.event_type not in self.skipped_event_types:
                return candidate
        return None

    @staticmethod
    def _is_headed_shot(event):
        """Tell whether ``event`` is a shot taken with the head."""
        body_part = getattr(event, "body_part", None)
        if body_part is None or event.event_type != EventType.SHOT:
            return False
        # NOTE(review): assumes `body_part.type` is a `BodyPart` member;
        # a plain membership test replaces the former `.isin(...)` call,
        # which is not a method of enum members - confirm the attribute.
        return body_part.type in (BodyPart.HEAD, BodyPart.HEAD_OTHER)

    @staticmethod
    def _carry_start_coordinates(event):
        """Return where the ball is after ``event``, i.e. the carry start."""
        if hasattr(event, "end_coordinates"):
            return event.end_coordinates
        if hasattr(event, "receiver_coordinates"):
            return event.receiver_coordinates
        return event.coordinates

    @staticmethod
    def _carry_start_timestamp(event, next_event):
        """Return the timestamp at which the carry is considered to start.

        Uses the end (or receive) timestamp of ``event`` plus 0.1 seconds
        when one is available; otherwise a tenth of the gap between the two
        events is added to the timestamp of ``event``.
        """
        for attribute in ("end_timestamp", "receive_timestamp"):
            value = getattr(event, attribute, None)
            if value is not None:
                return value + timedelta(seconds=0.1)
        return event.timestamp + (next_event.timestamp - event.timestamp) / 10

    def _build_carry(self, pitch, event, next_event):
        """Build the carry between two events, or ``None`` if not plausible."""
        # A carry needs the same team on the ball before and after.
        if event.team.team_id != next_event.team.team_id:
            return None
        if next_event.event_type not in self.valid_event_types:
            return None
        # Not a headed shot.
        if self._is_headed_shot(next_event):
            return None
        # Not same period.
        if event.period.id != next_event.period.id:
            return None
        # Not same phase.
        if next_event.timestamp - event.timestamp > self.max_carry_duration:
            return None

        last_coord = self._carry_start_coordinates(event)
        new_coord = next_event.coordinates
        # Without both locations the carry length cannot be measured.
        if last_coord is None or new_coord is None:
            return None

        distance_meters = pitch.distance_between(
            new_coord, last_coord, Unit.METERS
        )
        # Not far enough.
        if distance_meters < self.min_carry_length_meters:
            return None
        # Too far.
        if distance_meters > self.max_carry_length_meters:
            return None

        return self.event_factory.build_carry(
            result=CarryResult.COMPLETE,
            qualifiers=None,
            end_coordinates=new_coord,
            end_timestamp=next_event.timestamp,
            event_id=str(uuid.uuid4()),
            coordinates=last_coord,
            team=next_event.team,
            player=next_event.player,
            ball_owning_team=next_event.ball_owning_team,
            ball_state=event.ball_state,
            period=next_event.period,
            timestamp=self._carry_start_timestamp(event, next_event),
            raw_event={},
        )
lodevt marked this conversation as resolved.
Show resolved Hide resolved
8 changes: 8 additions & 0 deletions kloppy/domain/services/event_deducers/event_deducer.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,8 @@
from abc import ABC, abstractmethod
from kloppy.domain import EventDataset


class EventDatasetDeduducer(ABC):
    """Abstract interface for services that deduce events for a dataset."""

    @abstractmethod
    def deduce(self, dataset: EventDataset) -> EventDataset:
        """Deduce events for ``dataset``; subclasses must override this.

        Args:
            dataset: The event dataset to deduce events for.

        Raises:
            NotImplementedError: Always, in this base implementation.
        """
        raise NotImplementedError()
125 changes: 125 additions & 0 deletions kloppy/tests/test_event_deducer.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,125 @@
from datetime import timedelta
from itertools import groupby

from kloppy.domain import (
EventType,
Event,
EventDataset,
FormationType,
CarryEvent,
Unit,
)
from kloppy.domain.services.state_builder.builder import StateBuilder
from kloppy.utils import performance_logging
from kloppy import statsbomb, statsperform


class TestEventDeducer:
    """Tests for deducing events that a provider does not annotate."""

    def _load_dataset_statsperform(
        self, base_dir, base_filename="statsperform"
    ):
        """Load the Stats Perform event fixture found under ``base_dir``."""
        return statsperform.load_event(
            ma1_data=base_dir / f"files/{base_filename}_event_ma1.json",
            ma3_data=base_dir / f"files/{base_filename}_event_ma3.json",
        )

    def _load_dataset_statsbomb(
        self, base_dir, base_filename="statsbomb", event_types=None
    ):
        """Load the StatsBomb event fixture found under ``base_dir``.

        ``event_types`` is forwarded to ``statsbomb.load`` unchanged.
        """
        return statsbomb.load(
            event_data=base_dir / f"files/{base_filename}_event.json",
            lineup_data=base_dir / f"files/{base_filename}_lineup.json",
            event_types=event_types,
        )

    def calculate_carry_accuracy(
        self, real_carries, deduced_carries, real_carries_with_min_length
    ):
        """Score deduced carries against the carries annotated by a provider.

        A real and a deduced carry match when both have a player with the
        same ``player_id``, share the same period and start less than five
        seconds apart. Matching is greedy and every candidate is used at
        most once.

        Args:
            real_carries: All annotated carries; used to count true
                positives.
            deduced_carries: The carries produced by the deducer.
            real_carries_with_min_length: The annotated carries that should
                have been found; used to count false negatives.

        Returns:
            ``TP / (TP + FP + FN)``, or ``0.0`` when that denominator is
            zero (no carries at all) instead of dividing by zero.
        """

        def is_match(real_carry, deduced_carry):
            if not real_carry.player or not deduced_carry.player:
                return False
            return (
                real_carry.player.player_id == deduced_carry.player.player_id
                and real_carry.period == deduced_carry.period
                and abs(real_carry.timestamp - deduced_carry.timestamp)
                < timedelta(seconds=5)
            )

        def count_matched(items, candidates, matches):
            # Number of items that claim a not-yet-claimed candidate.
            claimed = set()
            matched = 0
            for item in items:
                for idx, candidate in enumerate(candidates):
                    if idx in claimed:
                        continue
                    if matches(item, candidate):
                        claimed.add(idx)
                        matched += 1
                        break
            return matched

        true_positives = count_matched(
            deduced_carries,
            real_carries,
            lambda deduced, real: is_match(real, deduced),
        )
        false_negatives = len(real_carries_with_min_length) - count_matched(
            real_carries_with_min_length, deduced_carries, is_match
        )
        false_positives = len(deduced_carries) - true_positives

        total = true_positives + false_positives + false_negatives
        accuracy = true_positives / total if total else 0.0

        print("TP:", true_positives)
        print("FP:", false_positives)
        print("FN:", false_negatives)
        print("accuracy:", accuracy)

        return accuracy

    def test_carry_deducer(self, base_dir):
        """Deduced carries should largely agree with StatsBomb's carries."""
        # Only annotated carries of at least this length are expected to be
        # found by the deducer.
        min_carry_length_meters = 3

        dataset_with_carries = self._load_dataset_statsbomb(base_dir)
        pitch = dataset_with_carries.metadata.pitch_dimensions

        all_statsbomb_carries = dataset_with_carries.find_all("carry")
        all_statsbomb_carries_with_min_length = [
            carry
            for carry in all_statsbomb_carries
            if pitch.distance_between(
                carry.coordinates, carry.end_coordinates, Unit.METERS
            )
            >= min_carry_length_meters
        ]

        # Load the same match again without the provider's carry events.
        dataset = self._load_dataset_statsbomb(
            base_dir,
            event_types=[
                event.value for event in EventType if event.value != "CARRY"
            ],
        )

        with performance_logging("deduce_events"):
            dataset.add_deduced_event(EventType.CARRY)

        carry = dataset.find("carry")
        assert carry is not None
        index = dataset.events.index(carry)
        next_event = dataset.events[index + 1]
        # Assert end location is equal to start location of next action
        assert carry.end_coordinates == next_event.coordinates
        assert carry.player == next_event.player

        all_carries = dataset.find_all("carry")
        assert (
            self.calculate_carry_accuracy(
                all_statsbomb_carries,
                all_carries,
                all_statsbomb_carries_with_min_length,
            )
            > 0.80
        )