Skip to content

Commit

Permalink
Merge pull request #50 from eujern/add-dataset-type
Browse files Browse the repository at this point in the history
Add DatasetType.
  • Loading branch information
koenvo authored Aug 7, 2020
2 parents a64c55a + 39cfa57 commit cb954ea
Show file tree
Hide file tree
Showing 10 changed files with 40 additions and 10 deletions.
12 changes: 11 additions & 1 deletion kloppy/domain/models/common.py
Original file line number Diff line number Diff line change
@@ -1,4 +1,4 @@
from abc import ABC
from abc import ABC, abstractmethod
from dataclasses import dataclass, field
from enum import Enum, Flag
from typing import Optional, List, Dict
Expand Down Expand Up @@ -224,7 +224,17 @@ class Metadata:
provider: Provider


class DatasetType(Enum):
TRACKING = "TRACKING"
EVENT = "EVENT"


@dataclass
class Dataset(ABC):
records: List[DataRecord]
metadata: Metadata

@property
@abstractmethod
def dataset_type(self) -> DatasetType:
raise NotImplementedError
8 changes: 6 additions & 2 deletions kloppy/domain/models/event.py
Original file line number Diff line number Diff line change
@@ -1,11 +1,13 @@
# Metrica Documentation https://github.com/metrica-sports/sample-data/blob/master/documentation/events-definitions.pdf
from abc import ABC, abstractmethod, abstractproperty, ABCMeta
from abc import ABC, abstractmethod
from dataclasses import dataclass
from enum import Enum
from typing import List, Union, Dict

from .pitch import Point
from kloppy.domain.models.common import DatasetType

from .common import DataRecord, Dataset, Team, Player
from .pitch import Point


class ResultType(Enum):
Expand Down Expand Up @@ -139,6 +141,8 @@ class EventDataset(Dataset):
Union[GenericEvent, ShotEvent, PassEvent, TakeOnEvent, CarryEvent]
]

dataset_type: DatasetType = DatasetType.EVENT

@property
def events(self):
return self.records
Expand Down
6 changes: 5 additions & 1 deletion kloppy/domain/models/tracking.py
Original file line number Diff line number Diff line change
@@ -1,7 +1,9 @@
from dataclasses import dataclass
from typing import List, Dict

from .common import Dataset, DataRecord, Ground, Player
from kloppy.domain.models.common import DatasetType

from .common import Dataset, DataRecord, Player
from .pitch import Point


Expand All @@ -16,6 +18,8 @@ class Frame(DataRecord):
class TrackingDataset(Dataset):
records: List[Frame]

dataset_type: DatasetType = DatasetType.TRACKING

@property
def frames(self):
return self.records
Expand Down
6 changes: 3 additions & 3 deletions kloppy/domain/services/transformers/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -109,15 +109,15 @@ def transform_event(self, event: EventType) -> EventType:

return replace(event, **position_changes)

DatasetType = TypeVar("DatasetType")
DatasetT = TypeVar("DatasetT")

@classmethod
def transform_dataset(
cls,
dataset: DatasetType,
dataset: DatasetT,
to_pitch_dimensions: PitchDimensions = None,
to_orientation: Orientation = None,
) -> DatasetType:
) -> DatasetT:
if not to_pitch_dimensions and not to_orientation:
return dataset
elif not to_orientation:
Expand Down
6 changes: 3 additions & 3 deletions kloppy/helpers.py
Original file line number Diff line number Diff line change
Expand Up @@ -115,12 +115,12 @@ def load_metrica_json_event_data(
)


DatasetType = TypeVar("DatasetType")
DatasetT = TypeVar("DatasetT")


def transform(
dataset: DatasetType, to_orientation=None, to_pitch_dimensions=None
) -> DatasetType:
dataset: DatasetT, to_orientation=None, to_pitch_dimensions=None
) -> DatasetT:
if to_orientation and isinstance(to_orientation, str):
to_orientation = Orientation[to_orientation]
if to_pitch_dimensions and (
Expand Down
3 changes: 3 additions & 0 deletions kloppy/tests/test_helpers.py
Original file line number Diff line number Diff line change
Expand Up @@ -28,6 +28,7 @@
Ground,
Player,
)
from kloppy.domain.models.common import DatasetType


class TestHelpers:
Expand All @@ -40,6 +41,7 @@ def test_load_metrica_tracking_data(self):
assert len(dataset.records) == 6
assert len(dataset.metadata.periods) == 2
assert dataset.metadata.provider == Provider.METRICA
assert dataset.dataset_type == DatasetType.TRACKING

def test_load_tracab_tracking_data(self):
base_dir = os.path.dirname(__file__)
Expand All @@ -50,6 +52,7 @@ def test_load_tracab_tracking_data(self):
assert len(dataset.records) == 5 # only alive=True
assert len(dataset.metadata.periods) == 2
assert dataset.metadata.provider == Provider.TRACAB
assert dataset.dataset_type == DatasetType.TRACKING

def _get_tracking_dataset(self):
home_team = Team(team_id="home", name="home", ground=Ground.HOME)
Expand Down
3 changes: 3 additions & 0 deletions kloppy/tests/test_metrica.py
Original file line number Diff line number Diff line change
Expand Up @@ -8,6 +8,7 @@
Orientation,
Point,
)
from kloppy.domain.models.common import DatasetType


class TestMetricaTracking:
Expand All @@ -28,6 +29,7 @@ def test_correct_deserialization(self):
}
)
assert dataset.metadata.provider == Provider.METRICA
assert dataset.dataset_type == DatasetType.TRACKING
assert len(dataset.records) == 6
assert len(dataset.metadata.periods) == 2
assert dataset.metadata.orientation == Orientation.FIXED_HOME_AWAY
Expand Down Expand Up @@ -88,6 +90,7 @@ def test_correct_deserialization(self):
)

assert dataset.metadata.provider == Provider.METRICA
assert dataset.dataset_type == DatasetType.EVENT
assert len(dataset.events) == 3620
assert len(dataset.metadata.periods) == 2
assert dataset.metadata.orientation is None
Expand Down
2 changes: 2 additions & 0 deletions kloppy/tests/test_opta.py
Original file line number Diff line number Diff line change
Expand Up @@ -10,6 +10,7 @@
Position,
Ground,
)
from kloppy.domain.models.common import DatasetType


class TestOpta:
Expand All @@ -26,6 +27,7 @@ def test_correct_deserialization(self):
inputs={"f24_data": f24_data, "f7_data": f7_data}
)
assert dataset.metadata.provider == Provider.OPTA
assert dataset.dataset_type == DatasetType.EVENT
assert len(dataset.events) == 17
assert len(dataset.metadata.periods) == 2
assert dataset.events[10].ball_owning_team == dataset.metadata.teams[1]
Expand Down
2 changes: 2 additions & 0 deletions kloppy/tests/test_statsbomb.py
Original file line number Diff line number Diff line change
Expand Up @@ -9,6 +9,7 @@
Position,
Provider,
)
from kloppy.domain.models.common import DatasetType


class TestStatsbomb:
Expand All @@ -31,6 +32,7 @@ def test_correct_deserialization(self):
)

assert dataset.metadata.provider == Provider.STATSBOMB
assert dataset.dataset_type == DatasetType.EVENT
assert len(dataset.events) == 4002
assert len(dataset.metadata.periods) == 2
assert (
Expand Down
2 changes: 2 additions & 0 deletions kloppy/tests/test_tracab.py
Original file line number Diff line number Diff line change
Expand Up @@ -11,6 +11,7 @@
Team,
Ground,
)
from kloppy.domain.models.common import DatasetType


class TestTracabTracking:
Expand All @@ -29,6 +30,7 @@ def test_correct_deserialization(self):
)

assert dataset.metadata.provider == Provider.TRACAB
assert dataset.dataset_type == DatasetType.TRACKING
assert len(dataset.records) == 6
assert len(dataset.metadata.periods) == 2
assert dataset.metadata.orientation == Orientation.FIXED_HOME_AWAY
Expand Down

0 comments on commit cb954ea

Please sign in to comment.