Skip to content

Commit

Permalink
Merge pull request #55 from PySport/add-state
Browse files Browse the repository at this point in the history
Add StateBuilder
  • Loading branch information
koenvo authored Sep 8, 2020
2 parents 19419ec + a4a40c5 commit b702a01
Show file tree
Hide file tree
Showing 30 changed files with 171,880 additions and 53 deletions.
2 changes: 1 addition & 1 deletion examples/datasets/statsbomb.py
Original file line number Diff line number Diff line change
Expand Up @@ -2,7 +2,7 @@
import sys

from kloppy import datasets, transform, to_pandas
from kloppy.infra.utils import performance_logging
from kloppy.utils import performance_logging


def main():
Expand Down
2 changes: 1 addition & 1 deletion examples/pattern_matching/ball_recovery.py
Original file line number Diff line number Diff line change
Expand Up @@ -3,7 +3,7 @@
from collections import Counter

from kloppy import datasets, event_pattern_matching as pm
from kloppy.infra.utils import performance_logging
from kloppy.utils import performance_logging


def main():
Expand Down
1 change: 1 addition & 0 deletions kloppy/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -2,3 +2,4 @@
from .helpers import *
from .infra import datasets
from .domain.services.matchers.pattern import event as event_pattern_matching
from .domain.services.state_builder import add_state
2 changes: 1 addition & 1 deletion kloppy/cmdline.py
Original file line number Diff line number Diff line change
Expand Up @@ -12,7 +12,7 @@
load_opta_event_data,
event_pattern_matching as pm,
)
from kloppy.infra.utils import performance_logging
from kloppy.utils import performance_logging

sys.path.append(".")

Expand Down
4 changes: 4 additions & 0 deletions kloppy/domain/models/common.py
Original file line number Diff line number Diff line change
Expand Up @@ -46,7 +46,11 @@ class Player:
name: str = None
first_name: str = None
last_name: str = None

# match specific
starting: bool = None
position: Position = None

attributes: Optional[Dict] = field(default_factory=dict, compare=False)

@property
Expand Down
50 changes: 49 additions & 1 deletion kloppy/domain/models/event.py
Original file line number Diff line number Diff line change
Expand Up @@ -59,13 +59,23 @@ def is_success(self):
return self == self.COMPLETE


class CardType(Enum):
FIRST_YELLOW = "FIRST_YELLOW"
SECOND_YELLOW = "SECOND_YELLOW"
RED = "RED"


class EventType(Enum):
GENERIC = "generic"

PASS = "PASS"
SHOT = "SHOT"
TAKE_ON = "TAKE_ON"
CARRY = "CARRY"
SUBSTITUTION = "SUBSTITUTION"
CARD = "CARD"
PLAYER_ON = "PLAYER_ON"
PLAYER_OFF = "PLAYER_OFF"


@dataclass
Expand All @@ -75,9 +85,10 @@ class Event(DataRecord, ABC):
player: Player
coordinates: Point

result: ResultType
result: Union[ResultType, None]

raw_event: Dict
state: Dict[str, any]

@property
@abstractmethod
Expand All @@ -89,6 +100,10 @@ def event_type(self) -> EventType:
def event_name(self) -> str:
raise NotImplementedError

@classmethod
def create(cls, **kwargs):
return cls(**kwargs, state={})


@dataclass
class GenericEvent(Event):
Expand Down Expand Up @@ -135,6 +150,34 @@ class CarryEvent(Event):
event_name: str = "carry"


@dataclass
class SubstitutionEvent(Event):
replacement_player: Player

event_type: EventType = EventType.SUBSTITUTION
event_name: str = "substitution"


@dataclass
class PlayerOffEvent(Event):
event_type: EventType = EventType.PLAYER_OFF
event_name: str = "player_off"


@dataclass
class PlayerOnEvent(Event):
event_type: EventType = EventType.PLAYER_ON
event_name: str = "player_on"


@dataclass
class CardEvent(Event):
card_type: CardType

event_type: EventType = EventType.CARD
event_name: str = "card"


@dataclass
class EventDataset(Dataset):
records: List[
Expand All @@ -161,5 +204,10 @@ def events(self):
"PassEvent",
"TakeOnEvent",
"CarryEvent",
"SubstitutionEvent",
"PlayerOnEvent",
"PlayerOffEvent",
"CardEvent",
"CardType",
"EventDataset",
]
30 changes: 30 additions & 0 deletions kloppy/domain/services/state_builder/__init__.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,30 @@
from dataclasses import replace

from kloppy.domain import List, EventDataset

# register all of them
from . import builders

from .registered import create_state_builder


def add_state(dataset: EventDataset, builder_keys: List[str]) -> EventDataset:
builders = {
builder_key: create_state_builder(builder_key)
for builder_key in builder_keys
}

state = {
builder_key: builder.initial_state(dataset)
for builder_key, builder in builders.items()
}

events = []
for event in dataset.events:
events.append(replace(event, state=state))
state = {
builder_key: builder.reduce(state[builder_key], event)
for builder_key, builder in builders.items()
}

return replace(dataset, records=events)
17 changes: 17 additions & 0 deletions kloppy/domain/services/state_builder/builder.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,17 @@
from abc import abstractmethod, ABC
from typing import TypeVar

from kloppy.domain import EventDataset, Event
from .registered import RegisteredStateBuilder

T = TypeVar("T")


class StateBuilder(metaclass=RegisteredStateBuilder):
@abstractmethod
def initial_state(self, dataset: EventDataset) -> T:
pass

@abstractmethod
def reduce(self, state: T, event: Event) -> T:
pass
3 changes: 3 additions & 0 deletions kloppy/domain/services/state_builder/builders/__init__.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,3 @@
from .lineup import LineupStateBuilder
from .score import ScoreStateBuilder
from .sequence import SequenceStateBuilder
58 changes: 58 additions & 0 deletions kloppy/domain/services/state_builder/builders/lineup.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,58 @@
from dataclasses import dataclass
from typing import Set

from kloppy.domain import (
Event,
EventDataset,
Player,
SubstitutionEvent,
PlayerOffEvent,
PlayerOnEvent,
CardEvent,
CardType,
Provider,
)
from ..builder import StateBuilder


@dataclass
class Lineup:
players: Set[Player]


class LineupStateBuilder(StateBuilder):
def initial_state(self, dataset: EventDataset) -> Lineup:
if dataset.metadata.provider != Provider.STATSBOMB:
raise Exception(
"Lineup state can only be applied to statsbomb data"
)

return Lineup(
players=(
set(
player
for player in dataset.metadata.teams[0].players
if player.starting
)
| set(
player
for player in dataset.metadata.teams[1].players
if player.starting
)
)
)

def reduce(self, state: Lineup, event: Event) -> Lineup:
if isinstance(event, SubstitutionEvent):
state = Lineup(
players=state.players - {event.player}
| {event.replacement_player}
)
elif isinstance(event, PlayerOffEvent):
state = Lineup(players=state.players - {event.player})
elif isinstance(event, PlayerOnEvent):
state = Lineup(players=state.players | {event.player})
elif isinstance(event, CardEvent):
if event.card_type in (CardType.SECOND_YELLOW, CardType.RED):
state = Lineup(players=state.players - {event.player})
return state
27 changes: 27 additions & 0 deletions kloppy/domain/services/state_builder/builders/score.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,27 @@
from dataclasses import replace, dataclass

from kloppy.domain import ShotEvent, Event, Ground, ShotResult, EventDataset
from ..builder import StateBuilder


@dataclass
class Score:
home: int
away: int

def __str__(self):
return f"{self.home}-{self.away}"


class ScoreStateBuilder(StateBuilder):
def initial_state(self, dataset: EventDataset) -> Score:
return Score(home=0, away=0)

def reduce(self, state: Score, event: Event) -> Score:
if isinstance(event, ShotEvent):
if event.result == ShotResult.GOAL:
if event.team.ground == Ground.HOME:
state = replace(state, home=state.home + 1)
else:
state = replace(state, away=state.away + 1)
return state
25 changes: 25 additions & 0 deletions kloppy/domain/services/state_builder/builders/sequence.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,25 @@
from dataclasses import replace, dataclass

from kloppy.domain import Event, Team, EventDataset, PassEvent
from ..builder import StateBuilder


@dataclass
class Sequence:
sequence_id: int
team: Team


class SequenceStateBuilder(StateBuilder):
def initial_state(self, dataset: EventDataset) -> Sequence:
for event in dataset.events:
if isinstance(event, PassEvent):
return Sequence(sequence_id=0, team=event.team)
return Sequence(sequence_id=0, team=None)

def reduce(self, state: Sequence, event: Event) -> Sequence:
if state.team != event.team:
state = replace(
state, sequence_id=state.sequence_id + 1, team=event.team
)
return state
29 changes: 29 additions & 0 deletions kloppy/domain/services/state_builder/registered.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,29 @@
import abc
import inspect
from typing import Dict, Type

from kloppy.utils import camelcase_to_snakecase

_STATE_BUILDER_REGISTRY: Dict[str, Type["StateBuilder"]] = {}


class RegisteredStateBuilder(abc.ABCMeta):
def __new__(mcs, cls_name, bases, class_dict):
name = camelcase_to_snakecase(cls_name)
class_dict["name"] = name
builder_cls = super(RegisteredStateBuilder, mcs).__new__(
mcs, cls_name, bases, class_dict
)
if not inspect.isabstract(builder_cls):
_STATE_BUILDER_REGISTRY[
name.replace("_state_builder", "")
] = builder_cls
return builder_cls


def create_state_builder(builder_key: str):
if builder_key not in _STATE_BUILDER_REGISTRY:
raise ValueError(
f"StateBuilder {builder_key} not found. Known builders: {', '.join(_STATE_BUILDER_REGISTRY.keys())}"
)
return _STATE_BUILDER_REGISTRY[builder_key]()
12 changes: 2 additions & 10 deletions kloppy/infra/datasets/core/registered.py
Original file line number Diff line number Diff line change
@@ -1,20 +1,12 @@
import inspect
import re
import abc
from typing import Type, Dict


_first_cap_re = re.compile("(.)([A-Z][a-z0-9]+)")
_all_cap_re = re.compile("([a-z0-9])([A-Z])")

# from .builder import DatasetBuilder
_DATASET_REGISTRY: Dict[str, Type["DatasetBuilder"]] = {}
from kloppy.utils import camelcase_to_snakecase


def camelcase_to_snakecase(name):
"""Convert camel-case string to snake-case."""
s1 = _first_cap_re.sub(r"\1_\2", name)
return _all_cap_re.sub(r"\1_\2", s1).lower()
_DATASET_REGISTRY: Dict[str, Type["DatasetBuilder"]] = {}


class RegisteredDataset(abc.ABCMeta):
Expand Down
2 changes: 1 addition & 1 deletion kloppy/infra/serializers/event/base.py
Original file line number Diff line number Diff line change
@@ -1,7 +1,7 @@
from abc import ABC, abstractmethod
from typing import Tuple, Dict

from kloppy.infra.utils import Readable
from kloppy.utils import Readable
from kloppy.domain import EventDataset


Expand Down
Loading

0 comments on commit b702a01

Please sign in to comment.