|
| 1 | +import sys |
1 | 2 | from abc import ABC, abstractmethod
|
2 | 3 | from dataclasses import dataclass, field, replace
|
3 | 4 | from enum import Enum, Flag
|
4 |
| -from typing import Dict, List, Optional, Callable, Union, Any, TypeVar, Generic |
| 5 | +from typing import ( |
| 6 | + Dict, |
| 7 | + List, |
| 8 | + Optional, |
| 9 | + Callable, |
| 10 | + Union, |
| 11 | + Any, |
| 12 | + TypeVar, |
| 13 | + Generic, |
| 14 | + NewType, |
| 15 | + overload, |
| 16 | + Iterable, |
| 17 | +) |
| 18 | + |
| 19 | +if sys.version_info >= (3, 8): |
| 20 | + from typing import Literal |
| 21 | +else: |
| 22 | + from typing_extensions import Literal |
| 23 | + |
5 | 24 |
|
6 | 25 | from .pitch import PitchDimensions, Point, Dimension
|
7 | 26 | from .formation import FormationType
|
8 | 27 | from ...exceptions import (
|
9 | 28 | OrientationError,
|
10 |
| - OrphanedRecordError, |
11 | 29 | InvalidFilterError,
|
| 30 | + KloppyParameterError, |
12 | 31 | )
|
13 | 32 |
|
14 | 33 |
|
@@ -777,6 +796,8 @@ class Dataset(ABC, Generic[T]):
|
777 | 796 |
|
778 | 797 | """
|
779 | 798 |
|
| 799 | + Column = NewType("Column", Union[str, Callable[[T], Any]]) |
| 800 | + |
780 | 801 | records: List[T]
|
781 | 802 | metadata: Metadata
|
782 | 803 |
|
@@ -882,3 +903,65 @@ def get_record_by_id(self, record_id: Union[int, str]) -> Optional[T]:
|
882 | 903 | for record in self.records:
|
883 | 904 | if record.record_id == record_id:
|
884 | 905 | return record
|
| 906 | + |
| 907 | + @overload |
| 908 | + def to_records( |
| 909 | + self, |
| 910 | + *columns: "Column", |
| 911 | + as_list: Literal[True] = True, |
| 912 | + **named_columns: "Column", |
| 913 | + ) -> List[Dict[str, Any]]: |
| 914 | + ... |
| 915 | + |
| 916 | + @overload |
| 917 | + def to_records( |
| 918 | + self, |
| 919 | + *columns: "Column", |
| 920 | + as_list: Literal[False] = False, |
| 921 | + **named_columns: "Column", |
| 922 | + ) -> Iterable[Dict[str, Any]]: |
| 923 | + ... |
| 924 | + |
| 925 | + def to_records( |
| 926 | + self, |
| 927 | + *columns: "Column", |
| 928 | + as_list: bool = True, |
| 929 | + **named_columns: "Column", |
| 930 | + ) -> Union[List[Dict[str, Any]], Iterable[Dict[str, Any]]]: |
| 931 | + |
| 932 | + from ..services.transformers.data_record import get_transformer_cls |
| 933 | + |
| 934 | + transformer = get_transformer_cls(self.dataset_type)( |
| 935 | + *columns, **named_columns |
| 936 | + ) |
| 937 | + iterator = map(transformer, self.records) |
| 938 | + if as_list: |
| 939 | + return list(iterator) |
| 940 | + else: |
| 941 | + return iterator |
| 942 | + |
| 943 | + def to_df( |
| 944 | + self, |
| 945 | + *columns: "Column", |
| 946 | + engine: Optional[Union[Literal["polars"], Literal["pandas"]]] = None, |
| 947 | + **named_columns: "Column", |
| 948 | + ): |
| 949 | + from kloppy.config import get_config |
| 950 | + |
| 951 | + if not engine: |
| 952 | + engine = get_config("dataframe.engine") |
| 953 | + |
| 954 | + if engine == "pandas": |
| 955 | + from pandas import DataFrame |
| 956 | + |
| 957 | + return DataFrame.from_records( |
| 958 | + self.to_records(*columns, **named_columns, as_list=False) |
| 959 | + ) |
| 960 | + elif engine == "polars": |
| 961 | + from polars import DataFrame |
| 962 | + |
| 963 | + return DataFrame( |
| 964 | + self.to_records(*columns, **named_columns, as_list=False) |
| 965 | + ) |
| 966 | + else: |
| 967 | + raise KloppyParameterError(f"Engine {engine} is not valid") |
0 commit comments