-
Notifications
You must be signed in to change notification settings - Fork 63
Commit
This commit does not belong to any branch on this repository, and may belong to a fork outside of the repository.
Merge pull request #55 from PySport/add-state
Add StateBuilder
- Loading branch information
Showing
30 changed files
with
171,880 additions
and
53 deletions.
There are no files selected for viewing
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,30 @@ | ||
from dataclasses import replace | ||
|
||
from kloppy.domain import List, EventDataset | ||
|
||
# register all of them | ||
from . import builders | ||
|
||
from .registered import create_state_builder | ||
|
||
|
||
def add_state(dataset: EventDataset, builder_keys: List[str]) -> EventDataset: | ||
builders = { | ||
builder_key: create_state_builder(builder_key) | ||
for builder_key in builder_keys | ||
} | ||
|
||
state = { | ||
builder_key: builder.initial_state(dataset) | ||
for builder_key, builder in builders.items() | ||
} | ||
|
||
events = [] | ||
for event in dataset.events: | ||
events.append(replace(event, state=state)) | ||
state = { | ||
builder_key: builder.reduce(state[builder_key], event) | ||
for builder_key, builder in builders.items() | ||
} | ||
|
||
return replace(dataset, records=events) |
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,17 @@ | ||
from abc import abstractmethod, ABC | ||
from typing import TypeVar | ||
|
||
from kloppy.domain import EventDataset, Event | ||
from .registered import RegisteredStateBuilder | ||
|
||
T = TypeVar("T") | ||
|
||
|
||
class StateBuilder(metaclass=RegisteredStateBuilder): | ||
@abstractmethod | ||
def initial_state(self, dataset: EventDataset) -> T: | ||
pass | ||
|
||
@abstractmethod | ||
def reduce(self, state: T, event: Event) -> T: | ||
pass |
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,3 @@ | ||
from .lineup import LineupStateBuilder | ||
from .score import ScoreStateBuilder | ||
from .sequence import SequenceStateBuilder |
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,58 @@ | ||
from dataclasses import dataclass | ||
from typing import Set | ||
|
||
from kloppy.domain import ( | ||
Event, | ||
EventDataset, | ||
Player, | ||
SubstitutionEvent, | ||
PlayerOffEvent, | ||
PlayerOnEvent, | ||
CardEvent, | ||
CardType, | ||
Provider, | ||
) | ||
from ..builder import StateBuilder | ||
|
||
|
||
@dataclass | ||
class Lineup: | ||
players: Set[Player] | ||
|
||
|
||
class LineupStateBuilder(StateBuilder): | ||
def initial_state(self, dataset: EventDataset) -> Lineup: | ||
if dataset.metadata.provider != Provider.STATSBOMB: | ||
raise Exception( | ||
"Lineup state can only be applied to statsbomb data" | ||
) | ||
|
||
return Lineup( | ||
players=( | ||
set( | ||
player | ||
for player in dataset.metadata.teams[0].players | ||
if player.starting | ||
) | ||
| set( | ||
player | ||
for player in dataset.metadata.teams[1].players | ||
if player.starting | ||
) | ||
) | ||
) | ||
|
||
def reduce(self, state: Lineup, event: Event) -> Lineup: | ||
if isinstance(event, SubstitutionEvent): | ||
state = Lineup( | ||
players=state.players - {event.player} | ||
| {event.replacement_player} | ||
) | ||
elif isinstance(event, PlayerOffEvent): | ||
state = Lineup(players=state.players - {event.player}) | ||
elif isinstance(event, PlayerOnEvent): | ||
state = Lineup(players=state.players | {event.player}) | ||
elif isinstance(event, CardEvent): | ||
if event.card_type in (CardType.SECOND_YELLOW, CardType.RED): | ||
state = Lineup(players=state.players - {event.player}) | ||
return state |
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,27 @@ | ||
from dataclasses import replace, dataclass | ||
|
||
from kloppy.domain import ShotEvent, Event, Ground, ShotResult, EventDataset | ||
from ..builder import StateBuilder | ||
|
||
|
||
@dataclass | ||
class Score: | ||
home: int | ||
away: int | ||
|
||
def __str__(self): | ||
return f"{self.home}-{self.away}" | ||
|
||
|
||
class ScoreStateBuilder(StateBuilder): | ||
def initial_state(self, dataset: EventDataset) -> Score: | ||
return Score(home=0, away=0) | ||
|
||
def reduce(self, state: Score, event: Event) -> Score: | ||
if isinstance(event, ShotEvent): | ||
if event.result == ShotResult.GOAL: | ||
if event.team.ground == Ground.HOME: | ||
state = replace(state, home=state.home + 1) | ||
else: | ||
state = replace(state, away=state.away + 1) | ||
return state |
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,25 @@ | ||
from dataclasses import replace, dataclass | ||
|
||
from kloppy.domain import Event, Team, EventDataset, PassEvent | ||
from ..builder import StateBuilder | ||
|
||
|
||
@dataclass | ||
class Sequence: | ||
sequence_id: int | ||
team: Team | ||
|
||
|
||
class SequenceStateBuilder(StateBuilder): | ||
def initial_state(self, dataset: EventDataset) -> Sequence: | ||
for event in dataset.events: | ||
if isinstance(event, PassEvent): | ||
return Sequence(sequence_id=0, team=event.team) | ||
return Sequence(sequence_id=0, team=None) | ||
|
||
def reduce(self, state: Sequence, event: Event) -> Sequence: | ||
if state.team != event.team: | ||
state = replace( | ||
state, sequence_id=state.sequence_id + 1, team=event.team | ||
) | ||
return state |
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,29 @@ | ||
import abc | ||
import inspect | ||
from typing import Dict, Type | ||
|
||
from kloppy.utils import camelcase_to_snakecase | ||
|
||
_STATE_BUILDER_REGISTRY: Dict[str, Type["StateBuilder"]] = {} | ||
|
||
|
||
class RegisteredStateBuilder(abc.ABCMeta): | ||
def __new__(mcs, cls_name, bases, class_dict): | ||
name = camelcase_to_snakecase(cls_name) | ||
class_dict["name"] = name | ||
builder_cls = super(RegisteredStateBuilder, mcs).__new__( | ||
mcs, cls_name, bases, class_dict | ||
) | ||
if not inspect.isabstract(builder_cls): | ||
_STATE_BUILDER_REGISTRY[ | ||
name.replace("_state_builder", "") | ||
] = builder_cls | ||
return builder_cls | ||
|
||
|
||
def create_state_builder(builder_key: str): | ||
if builder_key not in _STATE_BUILDER_REGISTRY: | ||
raise ValueError( | ||
f"StateBuilder {builder_key} not found. Known builders: {', '.join(_STATE_BUILDER_REGISTRY.keys())}" | ||
) | ||
return _STATE_BUILDER_REGISTRY[builder_key]() |
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Oops, something went wrong.