Skip to content

Commit 65c88dd

Browse files
authored
Merge pull request #172 from PySport/feature/add-polars
Add polars support
2 parents abd9af2 + cd3e067 commit 65c88dd

File tree

12 files changed

+466
-187
lines changed

12 files changed

+466
-187
lines changed

kloppy/config.py

+6-1
Original file line numberDiff line numberDiff line change
@@ -1,7 +1,7 @@
11
import os
22
from contextlib import contextmanager
33
from copy import copy
4-
from typing import Any, Optional
4+
from typing import Any, Optional, Union
55

66
from kloppy.domain import EventFactory
77

@@ -27,6 +27,9 @@
2727
"event_factory": Optional[EventFactory],
2828
"adapters.http.basic_authentication": Optional[str],
2929
"adapters.s3.s3fs": Optional[Any],
30+
"dataframe.engine": Optional[
31+
Union[Literal["pandas"], Literal["polars"]]
32+
],
3033
},
3134
)
3235

@@ -37,6 +40,7 @@
3740
"event_factory",
3841
"adapters.http.basic_authentication",
3942
"adapters.s3.s3fs",
43+
"dataframe.engine",
4044
]
4145

4246

@@ -50,6 +54,7 @@ class PartialConfig(Config, total=False):
5054
"event_factory": None,
5155
"adapters.http.basic_authentication": None,
5256
"adapters.s3.s3fs": None,
57+
"dataframe.engine": "pandas",
5358
}
5459

5560
config = copy(_default_config)

kloppy/domain/models/code.py

+11-12
Original file line numberDiff line numberDiff line change
@@ -4,7 +4,10 @@
44
from kloppy.domain.models.common import DatasetType
55

66
from .common import Dataset, DataRecord
7-
from ...utils import docstring_inherit_attributes
7+
from kloppy.utils import (
8+
docstring_inherit_attributes,
9+
deprecated,
10+
)
811

912

1013
@dataclass
@@ -45,6 +48,9 @@ class CodeDataset(Dataset[Code]):
4548
def codes(self):
4649
return self.records
4750

51+
@deprecated(
52+
"to_pandas will be removed in the future. Please use to_df instead."
53+
)
4854
def to_pandas(
4955
self,
5056
record_converter: Callable[[Code], Dict] = None,
@@ -61,18 +67,11 @@ def to_pandas(
6167
)
6268

6369
if not record_converter:
70+
from ..services.transformers.attribute import (
71+
DefaultCodeTransformer,
72+
)
6473

65-
def record_converter(code: Code) -> Dict:
66-
row = dict(
67-
code_id=code.code_id,
68-
period_id=code.period.id if code.period else None,
69-
timestamp=code.timestamp,
70-
end_timestamp=code.end_timestamp,
71-
code=code.code,
72-
)
73-
row.update(code.labels)
74-
75-
return row
74+
record_converter = DefaultCodeTransformer()
7675

7776
def generic_record_converter(code: Code):
7877
row = record_converter(code)

kloppy/domain/models/common.py

+85-2
Original file line numberDiff line numberDiff line change
@@ -1,14 +1,33 @@
1+
import sys
12
from abc import ABC, abstractmethod
23
from dataclasses import dataclass, field, replace
34
from enum import Enum, Flag
4-
from typing import Dict, List, Optional, Callable, Union, Any, TypeVar, Generic
5+
from typing import (
6+
Dict,
7+
List,
8+
Optional,
9+
Callable,
10+
Union,
11+
Any,
12+
TypeVar,
13+
Generic,
14+
NewType,
15+
overload,
16+
Iterable,
17+
)
18+
19+
if sys.version_info >= (3, 8):
20+
from typing import Literal
21+
else:
22+
from typing_extensions import Literal
23+
524

625
from .pitch import PitchDimensions, Point, Dimension
726
from .formation import FormationType
827
from ...exceptions import (
928
OrientationError,
10-
OrphanedRecordError,
1129
InvalidFilterError,
30+
KloppyParameterError,
1231
)
1332

1433

@@ -777,6 +796,8 @@ class Dataset(ABC, Generic[T]):
777796
778797
"""
779798

799+
Column = NewType("Column", Union[str, Callable[[T], Any]])
800+
780801
records: List[T]
781802
metadata: Metadata
782803

@@ -882,3 +903,65 @@ def get_record_by_id(self, record_id: Union[int, str]) -> Optional[T]:
882903
for record in self.records:
883904
if record.record_id == record_id:
884905
return record
906+
907+
@overload
def to_records(
    self,
    *columns: "Column",
    as_list: Literal[True] = True,
    **named_columns: "Column",
) -> List[Dict[str, Any]]:
    ...


@overload
def to_records(
    self,
    *columns: "Column",
    # No default here: the implementation defaults to True, so the lazy
    # variant is only selected when the caller passes as_list=False.
    as_list: Literal[False],
    **named_columns: "Column",
) -> Iterable[Dict[str, Any]]:
    ...


def to_records(
    self,
    *columns: "Column",
    as_list: bool = True,
    **named_columns: "Column",
) -> Union[List[Dict[str, Any]], Iterable[Dict[str, Any]]]:
    """Convert every record of this dataset into a dict.

    A transformer class is looked up for ``self.dataset_type``,
    instantiated with the requested columns, and applied to each entry
    of ``self.records``.

    Args:
        *columns: Positional column selectors, forwarded unchanged to the
            transformer.
        as_list: When true (the default) the rows are materialised into a
            list; when false a lazy, single-pass iterator is returned.
        **named_columns: Keyword column selectors, forwarded unchanged to
            the transformer.

    Returns:
        A list of row dicts, or a lazy iterator producing them.
    """
    # Function-level import, as in the original implementation.
    from ..services.transformers.data_record import get_transformer_cls

    transformer_cls = get_transformer_cls(self.dataset_type)
    to_row = transformer_cls(*columns, **named_columns)
    rows = map(to_row, self.records)
    return list(rows) if as_list else rows
942+
943+
def to_df(
    self,
    *columns: "Column",
    engine: Optional[Union[Literal["polars"], Literal["pandas"]]] = None,
    **named_columns: "Column",
):
    """Build a dataframe from the records of this dataset.

    Args:
        *columns: Positional column selectors, forwarded to ``to_records``.
        engine: Dataframe library to use, ``"pandas"`` or ``"polars"``.
            When omitted (or falsy) the ``"dataframe.engine"`` config
            value is used.
        **named_columns: Keyword column selectors, forwarded to
            ``to_records``.

    Returns:
        A ``pandas.DataFrame`` or a ``polars.DataFrame``, depending on
        the selected engine.

    Raises:
        KloppyParameterError: If the resolved engine is neither
            ``"pandas"`` nor ``"polars"``.
    """
    if not engine:
        # Only touch the config module when a fallback is actually needed.
        from kloppy.config import get_config

        engine = get_config("dataframe.engine")

    if engine == "pandas":
        from pandas import DataFrame

        # from_records consumes iterators, so the rows can stay lazy.
        return DataFrame.from_records(
            self.to_records(*columns, **named_columns, as_list=False)
        )

    if engine == "polars":
        from polars import DataFrame

        # Hand polars a concrete list of dicts rather than a one-shot
        # ``map`` iterator: a sequence of rows is a documented constructor
        # input, whereas lazy iterables are not accepted by every release.
        return DataFrame(
            self.to_records(*columns, **named_columns, as_list=True)
        )

    raise KloppyParameterError(f"Engine {engine} is not valid")

kloppy/domain/models/event.py

+8-52
Original file line numberDiff line numberDiff line change
@@ -1,5 +1,3 @@
1-
import sys
2-
31
from abc import ABC, abstractmethod
42
from dataclasses import dataclass
53
from enum import Enum
@@ -11,21 +9,15 @@
119
Any,
1210
Callable,
1311
Optional,
14-
Iterable,
15-
overload,
1612
TYPE_CHECKING,
1713
)
1814

19-
if sys.version_info >= (3, 8):
20-
from typing import Literal
21-
else:
22-
from typing_extensions import Literal
23-
2415
from kloppy.domain.models.common import DatasetType
2516
from kloppy.utils import (
2617
camelcase_to_snakecase,
2718
removes_suffix,
2819
docstring_inherit_attributes,
20+
deprecated,
2921
)
3022

3123
from .common import DataRecord, Dataset, Player, Team
@@ -35,7 +27,6 @@
3527
from ...exceptions import OrphanedRecordError, InvalidFilterError
3628

3729
if TYPE_CHECKING:
38-
from ..services.transformers.event import Column
3930
from .tracking import Frame
4031

4132

@@ -819,6 +810,9 @@ def add_state(self, *builder_keys):
819810

820811
return add_state(self, *builder_keys)
821812

813+
@deprecated(
814+
"to_pandas will be removed in the future. Please use to_df instead."
815+
)
822816
def to_pandas(
823817
self,
824818
record_converter: Callable[[Event], Dict] = None,
@@ -835,9 +829,11 @@ def to_pandas(
835829
)
836830

837831
if not record_converter:
838-
from ..services.transformers.attribute import DefaultTransformer
832+
from ..services.transformers.attribute import (
833+
DefaultEventTransformer,
834+
)
839835

840-
record_converter = DefaultTransformer()
836+
record_converter = DefaultEventTransformer()
841837

842838
def generic_record_converter(event: Event):
843839
row = record_converter(event)
@@ -854,46 +850,6 @@ def generic_record_converter(event: Event):
854850
map(generic_record_converter, self.records)
855851
)
856852

857-
@overload
858-
def to_records(
859-
self,
860-
*columns: "Column",
861-
as_list: Literal[True] = True,
862-
**named_columns: "Column",
863-
) -> List[Dict[str, Any]]:
864-
...
865-
866-
@overload
867-
def to_records(
868-
self,
869-
*columns: "Column",
870-
as_list: Literal[False] = False,
871-
**named_columns: "Column",
872-
) -> Iterable[Dict[str, Any]]:
873-
...
874-
875-
def to_records(
876-
self,
877-
*columns: "Column",
878-
as_list: bool = True,
879-
**named_columns: "Column",
880-
) -> Union[List[Dict[str, Any]], Iterable[Dict[str, Any]]]:
881-
from ..services.transformers.event import EventToRecordTransformer
882-
883-
transformer = EventToRecordTransformer(*columns, **named_columns)
884-
iterator = map(transformer, self.events)
885-
if as_list:
886-
return list(iterator)
887-
else:
888-
return iterator
889-
890-
def to_df(self, *columns: "Column", **named_columns: "Column"):
891-
from pandas import DataFrame
892-
893-
return DataFrame.from_records(
894-
self.to_records(*columns, **named_columns, as_list=False)
895-
)
896-
897853

898854
__all__ = [
899855
"EnumQualifier",

kloppy/domain/models/tracking.py

+10-52
Original file line numberDiff line numberDiff line change
@@ -5,6 +5,9 @@
55

66
from .common import Dataset, DataRecord, Player
77
from .pitch import Point, Point3D
8+
from kloppy.utils import (
9+
deprecated,
10+
)
811

912

1013
@dataclass
@@ -48,6 +51,9 @@ def frames(self):
4851
def frame_rate(self):
4952
return self.metadata.frame_rate
5053

54+
@deprecated(
55+
"to_pandas will be removed in the future. Please use to_df instead."
56+
)
5157
def to_pandas(
5258
self,
5359
record_converter: Callable[[Frame], Dict] = None,
@@ -64,59 +70,11 @@ def to_pandas(
6470
)
6571

6672
if not record_converter:
73+
from ..services.transformers.attribute import (
74+
DefaultFrameTransformer,
75+
)
6776

68-
def record_converter(frame: Frame) -> Dict:
69-
row = dict(
70-
period_id=frame.period.id if frame.period else None,
71-
timestamp=frame.timestamp,
72-
ball_state=frame.ball_state.value
73-
if frame.ball_state
74-
else None,
75-
ball_owning_team_id=frame.ball_owning_team.team_id
76-
if frame.ball_owning_team
77-
else None,
78-
ball_x=frame.ball_coordinates.x
79-
if frame.ball_coordinates
80-
else None,
81-
ball_y=frame.ball_coordinates.y
82-
if frame.ball_coordinates
83-
else None,
84-
ball_z=getattr(frame.ball_coordinates, "z", None)
85-
if frame.ball_coordinates
86-
else None,
87-
)
88-
for player, player_data in frame.players_data.items():
89-
90-
row.update(
91-
{
92-
f"{player.player_id}_x": player_data.coordinates.x
93-
if player_data.coordinates
94-
else None,
95-
f"{player.player_id}_y": player_data.coordinates.y
96-
if player_data.coordinates
97-
else None,
98-
f"{player.player_id}_d": player_data.distance,
99-
f"{player.player_id}_s": player_data.speed,
100-
}
101-
)
102-
103-
if player_data.other_data:
104-
for name, value in player_data.other_data.items():
105-
row.update(
106-
{
107-
f"{player.player_id}_{name}": value,
108-
}
109-
)
110-
111-
if frame.other_data:
112-
for name, value in frame.other_data.items():
113-
row.update(
114-
{
115-
name: value,
116-
}
117-
)
118-
119-
return row
77+
record_converter = DefaultFrameTransformer()
12078

12179
def generic_record_converter(frame: Frame):
12280
row = record_converter(frame)

0 commit comments

Comments
 (0)