From 676014ebfcd6831e001dc1e7f6c848bb8a54f6b6 Mon Sep 17 00:00:00 2001 From: "UnravelSports [JB]" Date: Sun, 26 Jan 2025 17:26:23 +0100 Subject: [PATCH] minor --- unravel/__init__.py | 2 +- unravel/american_football/graphs/dataset.py | 2 +- .../graphs/graph_converter.py | 42 +++++++++++------ unravel/soccer/graphs/graph_converter.py | 1 - unravel/soccer/graphs/graph_converter_pl.py | 46 ++++++++++++------- .../utils/objects/default_graph_converter.py | 3 ++ 6 files changed, 61 insertions(+), 35 deletions(-) diff --git a/unravel/__init__.py b/unravel/__init__.py index b0cda09..b235f04 100644 --- a/unravel/__init__.py +++ b/unravel/__init__.py @@ -1,4 +1,4 @@ -__version__ = "0.2.0" +__version__ = "0.3.0" from .soccer import * from .american_football import * diff --git a/unravel/american_football/graphs/dataset.py b/unravel/american_football/graphs/dataset.py index 4b8ccff..5273b4a 100644 --- a/unravel/american_football/graphs/dataset.py +++ b/unravel/american_football/graphs/dataset.py @@ -125,7 +125,7 @@ def load(self): separator=",", encoding="utf8", null_values=["NA", "NULL", ""], - dtypes={"birthDate": pl.Date}, + schema_overrides={"birthDate": pl.Date}, ignore_errors=True, ) if "position" in players.columns: diff --git a/unravel/american_football/graphs/graph_converter.py b/unravel/american_football/graphs/graph_converter.py index 07b01e1..5f899b7 100644 --- a/unravel/american_football/graphs/graph_converter.py +++ b/unravel/american_football/graphs/graph_converter.py @@ -48,8 +48,14 @@ def __init__( if not isinstance(dataset, BigDataBowlDataset): raise Exception("'dataset' should be an instance of BigDataBowlDataset") - self.label_col = dataset._label_column - self.graph_id_col = dataset._graph_id_column + self.label_column: str = ( + self.label_col if self.label_col is not None else dataset._label_column + ) + self.graph_id_column: str = ( + self.graph_id_col + if self.graph_id_col is not None + else dataset._graph_id_column + ) self.dataset: pl.DataFrame = dataset.data self.pitch_dimensions: AmericanFootballPitchDimensions = ( @@ -64,21 +70,21 @@ def __init__( def _sport_specific_checks(self): - if not isinstance(self.label_col, str): + if not isinstance(self.label_column, str): raise Exception("'label_col' should be of type string (str)") - if not isinstance(self.graph_id_col, str): + if not isinstance(self.graph_id_column, str): raise Exception("'graph_id_col' should be of type string (str)") if not isinstance(self.chunk_size, int): raise Exception("chunk_size should be of type integer (int)") - if not self.label_col in self.dataset.columns and not self.prediction: + if not self.label_column in self.dataset.columns and not self.prediction: raise Exception( "Please specify a 'label_col' and add that column to your 'dataset' or set 'prediction=True' if you want to use the converted dataset to make predictions on." ) - if not self.graph_id_col in self.dataset.columns: + if not self.graph_id_column in self.dataset.columns: raise Exception( "Please specify a 'graph_id_col' and add that column to your 'dataset' ..." ) @@ -121,20 +127,20 @@ def __exprs_variables(self): Column.POSSESSION_TEAM, Column.HEIGHT_CM, Column.WEIGHT_KG, - self.graph_id_col, - self.label_col, + self.graph_id_column, + self.label_column, ] def __compute(self, args: List[pl.Series]) -> dict: d = {col: args[i].to_numpy() for i, col in enumerate(self.__exprs_variables)} - if not np.all(d[self.graph_id_col] == d[self.graph_id_col][0]): + if not np.all(d[self.graph_id_column] == d[self.graph_id_column][0]): raise Exception( "GraphId selection contains multiple different values. Make sure each graph_id is unique by at least game_id and frame_id..." ) if not self.prediction and not np.all( - d[self.label_col] == d[self.label_col][0] + d[self.label_column] == d[self.label_column][0] ): raise Exception( """Label selection contains multiple different values for a single selection (group by) of game_id and frame_id, @@ -186,8 +192,8 @@ def __compute(self, args: List[pl.Series]) -> dict: "x_shape_1": node_features.shape[1], "a_shape_0": adjacency_matrix.shape[0], "a_shape_1": adjacency_matrix.shape[1], - self.graph_id_col: d[self.graph_id_col][0], - self.label_col: d[self.label_col][0], + self.graph_id_column: d[self.graph_id_column][0], + self.label_column: d[self.label_column][0], } def _convert(self): @@ -203,7 +209,13 @@ def _convert(self): [ *[ pl.col("result_dict").struct.field(f).alias(f) - for f in ["a", "e", "x", self.graph_id_col, self.label_col] + for f in [ + "a", + "e", + "x", + self.graph_id_column, + self.label_column, + ] ], *[ pl.col("result_dict") @@ -232,8 +244,8 @@ def process_chunk(chunk: pl.DataFrame) -> List[dict]: "e": reshape_array( chunk["e"][i], chunk["e_shape_0"][i], chunk["e_shape_1"][i] ), - "y": np.asarray([chunk[self.label_col][i]]), - "id": chunk[self.graph_id_col][i], + "y": np.asarray([chunk[self.label_column][i]]), + "id": chunk[self.graph_id_column][i], } for i in range(len(chunk)) ] diff --git a/unravel/soccer/graphs/graph_converter.py b/unravel/soccer/graphs/graph_converter.py index d57cb5c..1262598 100644 --- a/unravel/soccer/graphs/graph_converter.py +++ b/unravel/soccer/graphs/graph_converter.py @@ -71,7 +71,6 @@ class SoccerGraphConverter(DefaultGraphConverter): dataset: TrackingDataset = None labels: dict = None - labels: dict = None graph_id: Union[str, int, dict] = None graph_ids: dict = None diff --git a/unravel/soccer/graphs/graph_converter_pl.py b/unravel/soccer/graphs/graph_converter_pl.py index 50d07b6..5beeb82 100644 --- a/unravel/soccer/graphs/graph_converter_pl.py +++ b/unravel/soccer/graphs/graph_converter_pl.py @@ -45,8 +45,14 @@ class SoccerGraphConverterPolars(DefaultGraphConverter): def __post_init__(self): self.pitch_dimensions: MetricPitchDimensions = self.dataset.pitch_dimensions - self.label_col = self.dataset._label_column - self.graph_id_col = self.dataset._graph_id_column + self.label_column: str = ( + self.label_col if self.label_col is not None else self.dataset._label_column + ) + self.graph_id_column: str = ( + self.graph_id_col + if self.graph_id_col is not None + else self.dataset._graph_id_column + ) self.dataset = self.dataset.data @@ -76,8 +82,8 @@ def _apply_padding(self) -> pl.DataFrame: Column.TIMESTAMP, Column.BALL_STATE, Column.POSITION_NAME, - self.label_col, - self.graph_id_col, + self.label_column, + self.graph_id_column, ] empty_columns = [ Column.OBJECT_ID, @@ -240,21 +246,21 @@ def _apply_settings(self): ) def _sport_specific_checks(self): - if not isinstance(self.label_col, str): + if not isinstance(self.label_column, str): raise Exception("'label_col' should be of type string (str)") - if not isinstance(self.graph_id_col, str): + if not isinstance(self.graph_id_column, str): raise Exception("'graph_id_col' should be of type string (str)") if not isinstance(self.chunk_size, int): raise Exception("chunk_size should be of type integer (int)") - if not self.label_col in self.dataset.columns and not self.prediction: + if not self.label_column in self.dataset.columns and not self.prediction: raise Exception( "Please specify a 'label_col' and add that column to your 'dataset' or set 'prediction=True' if you want to use the converted dataset to make predictions on." ) - if not self.graph_id_col in self.dataset.columns: + if not self.graph_id_column in self.dataset.columns: raise Exception( "Please specify a 'graph_id_col' and add that column to your 'dataset' ..." ) @@ -284,20 +290,20 @@ def __exprs_variables(self): Column.POSITION_NAME, Column.BALL_OWNING_TEAM_ID, Column.IS_BALL_CARRIER, - self.graph_id_col, - self.label_col, + self.graph_id_column, + self.label_column, ] def __compute(self, args: List[pl.Series]) -> dict: d = {col: args[i].to_numpy() for i, col in enumerate(self.__exprs_variables)} - if not np.all(d[self.graph_id_col] == d[self.graph_id_col][0]): + if not np.all(d[self.graph_id_column] == d[self.graph_id_column][0]): raise Exception( "GraphId selection contains multiple different values. Make sure each graph_id is unique by at least game_id and frame_id..." ) if not self.prediction and not np.all( - d[self.label_col] == d[self.label_col][0] + d[self.label_column] == d[self.label_column][0] ): raise Exception( """Label selection contains multiple different values for a single selection (group by) of game_id and frame_id, @@ -356,8 +362,8 @@ def __compute(self, args: List[pl.Series]) -> dict: "x_shape_1": node_features.shape[1], "a_shape_0": adjacency_matrix.shape[0], "a_shape_1": adjacency_matrix.shape[1], - self.graph_id_col: d[self.graph_id_col][0], - self.label_col: d[self.label_col][0], + self.graph_id_column: d[self.graph_id_column][0], + self.label_column: d[self.label_column][0], } def _convert(self): @@ -373,7 +379,13 @@ def _convert(self): [ *[ pl.col("result_dict").struct.field(f).alias(f) - for f in ["a", "e", "x", self.graph_id_col, self.label_col] + for f in [ + "a", + "e", + "x", + self.graph_id_column, + self.label_column, + ] ], *[ pl.col("result_dict") @@ -402,8 +414,8 @@ def process_chunk(chunk: pl.DataFrame) -> List[dict]: "e": reshape_array( chunk["e"][i], chunk["e_shape_0"][i], chunk["e_shape_1"][i] ), - "y": np.asarray([chunk[self.label_col][i]]), - "id": chunk[self.graph_id_col][i], + "y": np.asarray([chunk[self.label_column][i]]), + "id": chunk[self.graph_id_column][i], } for i in range(len(chunk)) ] diff --git a/unravel/utils/objects/default_graph_converter.py b/unravel/utils/objects/default_graph_converter.py index 79bc16e..dfc9133 100644 --- a/unravel/utils/objects/default_graph_converter.py +++ b/unravel/utils/objects/default_graph_converter.py @@ -87,6 +87,9 @@ class DefaultGraphConverter: pad: bool = False verbose: bool = False + label_col: str = None + graph_id_col: str = None + graph_frames: dict = field(init=False, repr=False, default=None) settings: DefaultGraphSettings = field( init=False, repr=False, default_factory=DefaultGraphSettings