-
Notifications
You must be signed in to change notification settings - Fork 65
Commit
This commit does not belong to any branch on this repository, and may belong to a fork outside of the repository.
- Loading branch information
1 parent
cfcb52a
commit a4a40c5
Showing
7 changed files
with
205 additions
and
0 deletions.
There are no files selected for viewing
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,17 @@ | ||
from abc import abstractmethod, ABC | ||
from typing import TypeVar | ||
|
||
from kloppy.domain import EventDataset, Event | ||
from .registered import RegisteredStateBuilder | ||
|
||
T = TypeVar("T") | ||
|
||
|
||
class StateBuilder(metaclass=RegisteredStateBuilder): | ||
@abstractmethod | ||
def initial_state(self, dataset: EventDataset) -> T: | ||
pass | ||
|
||
@abstractmethod | ||
def reduce(self, state: T, event: Event) -> T: | ||
pass |
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,3 @@ | ||
from .lineup import LineupStateBuilder | ||
from .score import ScoreStateBuilder | ||
from .sequence import SequenceStateBuilder |
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,58 @@ | ||
from dataclasses import dataclass | ||
from typing import Set | ||
|
||
from kloppy.domain import ( | ||
Event, | ||
EventDataset, | ||
Player, | ||
SubstitutionEvent, | ||
PlayerOffEvent, | ||
PlayerOnEvent, | ||
CardEvent, | ||
CardType, | ||
Provider, | ||
) | ||
from ..builder import StateBuilder | ||
|
||
|
||
@dataclass | ||
class Lineup: | ||
players: Set[Player] | ||
|
||
|
||
class LineupStateBuilder(StateBuilder): | ||
def initial_state(self, dataset: EventDataset) -> Lineup: | ||
if dataset.metadata.provider != Provider.STATSBOMB: | ||
raise Exception( | ||
"Lineup state can only be applied to statsbomb data" | ||
) | ||
|
||
return Lineup( | ||
players=( | ||
set( | ||
player | ||
for player in dataset.metadata.teams[0].players | ||
if player.starting | ||
) | ||
| set( | ||
player | ||
for player in dataset.metadata.teams[1].players | ||
if player.starting | ||
) | ||
) | ||
) | ||
|
||
def reduce(self, state: Lineup, event: Event) -> Lineup: | ||
if isinstance(event, SubstitutionEvent): | ||
state = Lineup( | ||
players=state.players - {event.player} | ||
| {event.replacement_player} | ||
) | ||
elif isinstance(event, PlayerOffEvent): | ||
state = Lineup(players=state.players - {event.player}) | ||
elif isinstance(event, PlayerOnEvent): | ||
state = Lineup(players=state.players | {event.player}) | ||
elif isinstance(event, CardEvent): | ||
if event.card_type in (CardType.SECOND_YELLOW, CardType.RED): | ||
state = Lineup(players=state.players - {event.player}) | ||
return state |
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,27 @@ | ||
from dataclasses import replace, dataclass | ||
|
||
from kloppy.domain import ShotEvent, Event, Ground, ShotResult, EventDataset | ||
from ..builder import StateBuilder | ||
|
||
|
||
@dataclass | ||
class Score: | ||
home: int | ||
away: int | ||
|
||
def __str__(self): | ||
return f"{self.home}-{self.away}" | ||
|
||
|
||
class ScoreStateBuilder(StateBuilder): | ||
def initial_state(self, dataset: EventDataset) -> Score: | ||
return Score(home=0, away=0) | ||
|
||
def reduce(self, state: Score, event: Event) -> Score: | ||
if isinstance(event, ShotEvent): | ||
if event.result == ShotResult.GOAL: | ||
if event.team.ground == Ground.HOME: | ||
state = replace(state, home=state.home + 1) | ||
else: | ||
state = replace(state, away=state.away + 1) | ||
return state |
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,25 @@ | ||
from dataclasses import replace, dataclass | ||
|
||
from kloppy.domain import Event, Team, EventDataset, PassEvent | ||
from ..builder import StateBuilder | ||
|
||
|
||
@dataclass | ||
class Sequence: | ||
sequence_id: int | ||
team: Team | ||
|
||
|
||
class SequenceStateBuilder(StateBuilder): | ||
def initial_state(self, dataset: EventDataset) -> Sequence: | ||
for event in dataset.events: | ||
if isinstance(event, PassEvent): | ||
return Sequence(sequence_id=0, team=event.team) | ||
return Sequence(sequence_id=0, team=None) | ||
|
||
def reduce(self, state: Sequence, event: Event) -> Sequence: | ||
if state.team != event.team: | ||
state = replace( | ||
state, sequence_id=state.sequence_id + 1, team=event.team | ||
) | ||
return state |
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,29 @@ | ||
import abc | ||
import inspect | ||
from typing import Dict, Type | ||
|
||
from kloppy.utils import camelcase_to_snakecase | ||
|
||
_STATE_BUILDER_REGISTRY: Dict[str, Type["StateBuilder"]] = {} | ||
|
||
|
||
class RegisteredStateBuilder(abc.ABCMeta): | ||
def __new__(mcs, cls_name, bases, class_dict): | ||
name = camelcase_to_snakecase(cls_name) | ||
class_dict["name"] = name | ||
builder_cls = super(RegisteredStateBuilder, mcs).__new__( | ||
mcs, cls_name, bases, class_dict | ||
) | ||
if not inspect.isabstract(builder_cls): | ||
_STATE_BUILDER_REGISTRY[ | ||
name.replace("_state_builder", "") | ||
] = builder_cls | ||
return builder_cls | ||
|
||
|
||
def create_state_builder(builder_key: str): | ||
if builder_key not in _STATE_BUILDER_REGISTRY: | ||
raise ValueError( | ||
f"StateBuilder {builder_key} not found. Known builders: {', '.join(_STATE_BUILDER_REGISTRY.keys())}" | ||
) | ||
return _STATE_BUILDER_REGISTRY[builder_key]() |
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,46 @@ | ||
import re | ||
import time | ||
from contextlib import contextmanager | ||
from io import BytesIO | ||
from typing import BinaryIO, Union | ||
|
||
Readable = Union[bytes, BinaryIO] | ||
|
||
|
||
def to_file_object(s: Readable) -> BinaryIO: | ||
if isinstance(s, bytes): | ||
return BytesIO(s) | ||
return s | ||
|
||
|
||
@contextmanager | ||
def performance_logging(description: str, counter: int = None, logger=None): | ||
start = time.time() | ||
try: | ||
yield | ||
finally: | ||
took = (time.time() - start) * 1000 | ||
extra = "" | ||
if counter is not None: | ||
extra = f" ({int(counter / took * 1000)}items/sec)" | ||
|
||
unit = "ms" | ||
if took < 0.1: | ||
took *= 1000 | ||
unit = "us" | ||
|
||
msg = f"{description} took: {took:.2f}{unit} {extra}" | ||
if logger: | ||
logger.info(msg) | ||
else: | ||
print(msg) | ||
|
||
|
||
_first_cap_re = re.compile("(.)([A-Z][a-z0-9]+)") | ||
_all_cap_re = re.compile("([a-z0-9])([A-Z])") | ||
|
||
|
||
def camelcase_to_snakecase(name): | ||
"""Convert camel-case string to snake-case.""" | ||
s1 = _first_cap_re.sub(r"\1_\2", name) | ||
return _all_cap_re.sub(r"\1_\2", s1).lower() |