diff --git a/tests/test_bigdb.py b/tests/test_bigdb.py
index eb2caca..98de174 100644
--- a/tests/test_bigdb.py
+++ b/tests/test_bigdb.py
@@ -130,8 +130,8 @@ def node_feature_values(self):
             "is_possession_team": 0.0,
             "is_qb": 0.0,
             "is_ball": 0.0,
-            "weight_normed": 0.21941714285714287,
-            "height_normed": 0.4722666666666665,
+            "weight_normed": 0.21428571428571427,
+            "height_normed": 0.5333333333333333,
         }
         return item_idx, assert_values
diff --git a/unravel/american_football/graphs/dataset.py b/unravel/american_football/graphs/dataset.py
index 5273b4a..29c4fd6 100644
--- a/unravel/american_football/graphs/dataset.py
+++ b/unravel/american_football/graphs/dataset.py
@@ -49,19 +49,23 @@ def __init__(
         tracking_file_path: str,
         players_file_path: str,
         plays_file_path: str,
+        sample_rate: float = None,
         **kwargs,
     ):
         super().__init__(**kwargs)
         self.tracking_file_path = tracking_file_path
         self.players_file_path = players_file_path
         self.plays_file_path = plays_file_path
+        self.sample_rate = 1 if sample_rate is None else sample_rate
         self.pitch_dimensions = AmericanFootballPitchDimensions()
 
     def load(self):
         pitch_length = self.pitch_dimensions.pitch_length
         pitch_width = self.pitch_dimensions.pitch_width
 
-        df = pl.read_csv(
+        sample = 1.0 / self.sample_rate
+
+        df = pl.scan_csv(
             self.tracking_file_path,
             separator=",",
             encoding="utf8",
@@ -71,7 +75,7 @@ def load(self):
 
         play_direction = "left"
 
-        if "club" in df.columns:
+        if "club" in df.collect_schema().names():
             df = df.with_columns(pl.col(Column.CLUB).alias(Column.TEAM))
             df = df.drop(Column.CLUB)
 
@@ -119,7 +123,14 @@ def load(self):
                 .alias(Column.OBJECT_ID),
             ]
         )
-        )
+            .with_columns(
+                [
+                    pl.lit(play_direction).alias("playDirection"),
+                ]
+            )
+            .filter((pl.col(Column.FRAME_ID) % sample) == 0)
+        ).collect()
+
         players = pl.read_csv(
             self.players_file_path,
             separator=",",
@@ -206,14 +217,25 @@ def _convert_weight_height_to_metric(df: pl.DataFrame):
             .alias("inches"),  # Extract inches and cast to float
         ]
     )
-    df = df.with_columns(
-        [
-            (pl.col("feet") * 30.48 + pl.col("inches") * 2.54).alias(
-                Column.HEIGHT_CM
-            ),
-            (pl.col("weight") * 0.453592).alias(
-                Column.WEIGHT_KG
-            ),  # Convert pounds to kilograms
-        ]
-    ).drop(["height", "feet", "inches", "weight"])
+    # Convert height and weight to centimeters and kilograms
+    # Round them to the nearest 10 to make sure we don't leak any player-specific info
+    df = (
+        df.with_columns(
+            [
+                ((pl.col("feet") * 30.48 + pl.col("inches") * 2.54) / 10)
+                .round(0)
+                .alias(Column.HEIGHT_CM),
+                ((pl.col("weight") * 0.453592) / 10)
+                .round(0)
+                .alias(Column.WEIGHT_KG),
+            ]
+        )
+        .with_columns(
+            [
+                (pl.col(Column.HEIGHT_CM) * 10).alias(Column.HEIGHT_CM),
+                (pl.col(Column.WEIGHT_KG) * 10).alias(Column.WEIGHT_KG),
+            ]
+        )
+        .drop(["height", "feet", "inches", "weight"])
+    )
     return df
diff --git a/unravel/american_football/graphs/features/node_features.py b/unravel/american_football/graphs/features/node_features.py
index dbf74f2..59737a3 100644
--- a/unravel/american_football/graphs/features/node_features.py
+++ b/unravel/american_football/graphs/features/node_features.py
@@ -27,6 +27,7 @@ def compute_node_features(
     possession_team,
     height,
     weight,
+    graph_features,
     settings,
 ):
     ball_id = Constant.BALL
@@ -141,4 +142,9 @@ def compute_node_features(
         )
     )
 
+    if graph_features is not None:
+        eg = np.zeros((X.shape[0], graph_features.shape[0]))
+        eg[ball_index] = graph_features
+        X = np.hstack((X, eg))
+
     return X
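Note on the node_features.py hunk above: graph-level features are appended as extra node-feature columns that are zero for every player node and carry the actual values only on the "football" node. A minimal standalone sketch of that pattern (the shapes and values here are illustrative, not taken from the library):

import numpy as np

X = np.random.rand(23, 12)                  # 22 players + ball, 12 node features
graph_features = np.array([0.61, 1500.0])   # e.g. win probability, team rating
ball_index = 22

eg = np.zeros((X.shape[0], graph_features.shape[0]))  # all-zero extra columns
eg[ball_index] = graph_features             # only the ball row carries the values
X = np.hstack((X, eg))                      # X is now (23, 14)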
diff --git a/unravel/american_football/graphs/graph_converter.py b/unravel/american_football/graphs/graph_converter.py
index 5f899b7..d6df259 100644
--- a/unravel/american_football/graphs/graph_converter.py
+++ b/unravel/american_football/graphs/graph_converter.py
@@ -1,9 +1,11 @@
+from warnings import warn
+
 from dataclasses import dataclass
 
 import polars as pl
 import numpy as np
 
-from typing import List
+from typing import List, Optional
 
 from spektral.data import Graph
 
@@ -19,7 +21,7 @@
     compute_adjacency_matrix,
 )
 
-from ...utils import DefaultGraphConverter, reshape_array, make_sparse
+from ...utils import *
 
 
 @dataclass(repr=True)
@@ -28,12 +30,12 @@ class AmericanFootballGraphConverter(DefaultGraphConverter):
     Converts our dataset TrackingDataset into an internal structure
 
     Attributes:
-        dataset (TrackingDataset): Kloppy TrackingDataset.
-        label_col (str): Column name that contains labels in the dataset.data Polars dataframe
-        graph_id_col (str): Column name that contains graph ids in the dataset.data Polars dataframe
-
+        dataset (BigDataBowlDataset): BigDataBowlDataset.
         chunk_size (int): Used to batch convert Polars into Graphs
         attacking_non_qb_node_value (float): Value between 0 and 1 to assign any attacking team player who is not the QB
+        graph_feature_cols (list):
+            List of columns in the dataset that are graph-level features (e.g. team strength rating, win probability, etc.)
+            that we want to add to our model. They will be recorded as node features on the "football" node.
     """
 
     def __init__(
@@ -41,6 +43,7 @@ def __init__(
         dataset: BigDataBowlDataset,
         chunk_size: int = 2_000,
         attacking_non_qb_node_value: float = 0.1,
+        graph_feature_cols: Optional[List[str]] = None,
         **kwargs,
     ):
         super().__init__(**kwargs)
@@ -63,12 +66,51 @@ def __init__(
         )
         self.chunk_size = chunk_size
         self.attacking_non_qb_node_value = attacking_non_qb_node_value
-
-        self._sport_specific_checks()
+        self.graph_feature_cols = graph_feature_cols
 
         self.settings = self._apply_settings()
 
+        self._sport_specific_checks()
+
     def _sport_specific_checks(self):
+        def __remove_with_missing_values(min_object_count: int = 10):
+            cs = (
+                self.dataset.group_by(Group.BY_FRAME)
+                .agg(pl.len().alias("size"))
+                .filter(
+                    pl.col("size") < min_object_count
+                )  # Select frames with fewer than min_object_count objects
+            )
+
+            self.dataset = self.dataset.join(cs, on=Group.BY_FRAME, how="anti")
+            if len(cs) > 0:
+                warn(
+                    f"Removed {len(cs)} frames with less than {min_object_count} objects...",
+                    UserWarning,
+                )
+
+        def __remove_with_missing_football():
+            cs = (
+                self.dataset.group_by(Group.BY_FRAME)
+                .agg(
+                    [
+                        pl.len().alias("size"),  # Count total rows in each group
+                        pl.col(Column.TEAM)
+                        .filter(pl.col(Column.TEAM) == Constant.BALL)
+                        .count()
+                        .alias("football_count"),  # Count rows where team == 'football'
+                    ]
+                )
+                .filter(
+                    (pl.col("football_count") == 0)
+                )  # Select frames without a "football" object
+            )
+            self.dataset = self.dataset.join(cs, on=Group.BY_FRAME, how="anti")
+            if len(cs) > 0:
+                warn(
+                    f"Removed {len(cs)} frames with a missing '{Constant.BALL}' object...",
+                    UserWarning,
+                )
+
         if not isinstance(self.label_column, str):
             raise Exception("'label_col' should be of type string (str)")
 
@@ -95,6 +137,9 @@ def _sport_specific_checks(self):
                 "'attacking_non_qb_node_value' should be of type float or integer (int)"
             )
 
+        __remove_with_missing_values(min_object_count=10)
+        __remove_with_missing_football()
+
     def _apply_settings(self):
         return AmericanFootballGraphSettings(
             pitch_dimensions=self.pitch_dimensions,
@@ -115,7 +160,7 @@ def _apply_settings(self):
 
     @property
     def __exprs_variables(self):
-        return [
+        exprs_variables = [
             Column.X,
             Column.Y,
             Column.SPEED,
@@ -130,10 +175,22 @@
             self.graph_id_column,
             self.label_column,
         ]
+        exprs = (
+            exprs_variables
+            if self.graph_feature_cols is None
+            else exprs_variables + self.graph_feature_cols
+        )
+        return exprs
 
     def __compute(self, args: List[pl.Series]) -> dict:
         d = {col: args[i].to_numpy() for i, col in enumerate(self.__exprs_variables)}
 
+        graph_features = (
+            np.asarray([d[col] for col in self.graph_feature_cols]).T[0]
+            if self.graph_feature_cols
+            else None
+        )
+
         if not np.all(d[self.graph_id_column] == d[self.graph_id_column][0]):
             raise Exception(
                 "GraphId selection contains multiple different values. Make sure each graph_id is unique by at least game_id and frame_id..."
@@ -174,6 +231,7 @@ def __compute(self, args: List[pl.Series]) -> dict:
             possession_team=d[Column.POSSESSION_TEAM],
             height=d[Column.HEIGHT_CM],
             weight=d[Column.WEIGHT_KG],
+            graph_features=graph_features,
             settings=self.settings,
         )
         return {
@@ -196,13 +254,33 @@ def __compute(self, args: List[pl.Series]) -> dict:
             self.label_column: d[self.label_column][0],
         }
 
+    @property
+    def return_dtypes(self):
+        return pl.Struct(
+            {
+                "e": pl.List(pl.List(pl.Float64)),
+                "x": pl.List(pl.List(pl.Float64)),
+                "a": pl.List(pl.List(pl.Float64)),
+                "e_shape_0": pl.Int64,
+                "e_shape_1": pl.Int64,
+                "x_shape_0": pl.Int64,
+                "x_shape_1": pl.Int64,
+                "a_shape_0": pl.Int64,
+                "a_shape_1": pl.Int64,
+                self.graph_id_column: pl.String,
+                self.label_column: pl.Int64,
+            }
+        )
+
     def _convert(self):
         # Group and aggregate in one step
         return (
             self.dataset.group_by(Group.BY_FRAME, maintain_order=True)
             .agg(
                 pl.map_groups(
-                    exprs=self.__exprs_variables, function=self.__compute
+                    exprs=self.__exprs_variables,
+                    function=self.__compute,
+                    return_dtype=self.return_dtypes,
                 ).alias("result_dict")
             )
             .with_columns(
@@ -273,7 +351,7 @@ def to_spektral_graphs(self) -> List[Graph]:
             for d in self.graph_frames
         ]
 
-    def to_pickle(self, file_path: str) -> None:
+    def to_pickle(self, file_path: str, verbose: bool = False) -> None:
         """
         We store the 'dict' version of the Graphs to pickle each graph is now a dict with keys x, a, e, and y
         To use for training with Spektral feed the loaded pickle data to CustomDataset(data=pickled_data)
@@ -286,6 +364,9 @@ def to_pickle(self, file_path: str) -> None:
         if not self.graph_frames:
             self.to_graph_frames()
 
+        if verbose:
+            print(f"Storing {len(self.graph_frames)} Graphs in {file_path}...")
+
         import pickle
         import gzip
         from pathlib import Path
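Taken together, the changes above enable usage along these lines. This is a hedged sketch: the module paths, file names, and the "win_probability" column are assumptions, and required base-converter kwargs (e.g. the label and graph-id columns) are omitted for brevity; only the keyword arguments themselves come from the diff:

from unravel.american_football.graphs.dataset import BigDataBowlDataset
from unravel.american_football.graphs.graph_converter import (
    AmericanFootballGraphConverter,
)

# Load every 10th frame of the BigDataBowl tracking data
dataset = BigDataBowlDataset(
    tracking_file_path="tracking_week_1.csv",
    players_file_path="players.csv",
    plays_file_path="plays.csv",
    sample_rate=0.1,
)
dataset.load()

# Record a graph-level column as node features on the "football" node
converter = AmericanFootballGraphConverter(
    dataset=dataset,
    graph_feature_cols=["win_probability"],
)
converter.to_pickle("graphs.pickle.gz", verbose=True)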
diff --git a/unravel/soccer/graphs/graph_converter_pl.py b/unravel/soccer/graphs/graph_converter_pl.py
index 703c57d..b3d21f5 100644
--- a/unravel/soccer/graphs/graph_converter_pl.py
+++ b/unravel/soccer/graphs/graph_converter_pl.py
@@ -369,13 +369,33 @@ def __compute(self, args: List[pl.Series]) -> dict:
             self.label_column: d[self.label_column][0],
         }
 
+    @property
+    def return_dtypes(self):
+        return pl.Struct(
+            {
+                "e": pl.List(pl.List(pl.Float64)),
+                "x": pl.List(pl.List(pl.Float64)),
+                "a": pl.List(pl.List(pl.Float64)),
+                "e_shape_0": pl.Int64,
+                "e_shape_1": pl.Int64,
+                "x_shape_0": pl.Int64,
+                "x_shape_1": pl.Int64,
+                "a_shape_0": pl.Int64,
+                "a_shape_1": pl.Int64,
+                self.graph_id_column: pl.String,
+                self.label_column: pl.Int64,
+            }
+        )
+
     def _convert(self):
         # Group and aggregate in one step
         return (
             self.dataset.group_by(Group.BY_FRAME, maintain_order=True)
             .agg(
                 pl.map_groups(
-                    exprs=self.__exprs_variables, function=self.__compute
+                    exprs=self.__exprs_variables,
+                    function=self.__compute,
+                    return_dtype=self.return_dtypes,
                 ).alias("result_dict")
             )
             .with_columns(
@@ -446,7 +466,7 @@ def to_spektral_graphs(self) -> List[Graph]:
             for d in self.graph_frames
         ]
 
-    def to_pickle(self, file_path: str) -> None:
+    def to_pickle(self, file_path: str, verbose: bool = False) -> None:
         """
         We store the 'dict' version of the Graphs to pickle each graph is now a dict with keys x, a, e, and y
         To use for training with Spektral feed the loaded pickle data to CustomDataset(data=pickled_data)
@@ -459,6 +479,9 @@ def to_pickle(self, file_path: str) -> None:
         if not self.graph_frames:
             self.to_graph_frames()
 
+        if verbose:
+            print(f"Storing {len(self.graph_frames)} Graphs in {file_path}...")
+
         import pickle
         import gzip
         from pathlib import Path
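Both converters now declare an explicit return_dtype for pl.map_groups, so Polars no longer has to infer the schema of the aggregated struct. A toy standalone version of the same pattern, mirroring __compute's dict return (the column names here are made up):

import polars as pl

df = pl.DataFrame({"game_id": [1, 1, 2], "x": [0.1, 0.2, 0.3]})

out = df.group_by("game_id", maintain_order=True).agg(
    pl.map_groups(
        exprs=["x"],
        # Return a plain dict whose keys match the declared struct fields,
        # just like __compute does above
        function=lambda series: {"x_mean": float(series[0].mean()), "n": len(series[0])},
        return_dtype=pl.Struct({"x_mean": pl.Float64, "n": pl.Int64}),
    ).alias("result_dict")
)
print(out)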
diff --git a/unravel/utils/objects/custom_disjoint_loader.py b/unravel/utils/objects/custom_disjoint_loader.py
new file mode 100644
index 0000000..b3c7648
--- /dev/null
+++ b/unravel/utils/objects/custom_disjoint_loader.py
@@ -0,0 +1,25 @@
+from spektral.data import DisjointLoader
+import tensorflow as tf
+
+version = tf.__version__.split(".")
+major, minor = int(version[0]), int(version[1])
+tf_loader_available = major > 2 or (major == 2 and minor >= 4)
+
+
+class CustomDisjointLoader(DisjointLoader):
+    def __init__(
+        self, dataset, node_level=False, batch_size=1, epochs=None, shuffle=True
+    ):
+        self.node_level = node_level
+        super().__init__(dataset, batch_size=batch_size, epochs=epochs, shuffle=shuffle)
+
+    def load(self):
+        if not tf_loader_available:
+            raise RuntimeError(
+                "Calling DisjointLoader.load() requires TensorFlow 2.4 or greater."
+            )
+        dataset = tf.data.Dataset.from_generator(
+            lambda: self, output_signature=self.tf_signature()
+        )
+        dataset = dataset.shuffle(buffer_size=1000)
+        return dataset.repeat()
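A possible training-loop usage of the new loader (the import path follows the file location above; `dataset` and `model` are assumed to exist, e.g. a CustomSpektralDataset and a compiled Keras GNN):

from unravel.utils.objects.custom_disjoint_loader import CustomDisjointLoader

loader = CustomDisjointLoader(dataset, batch_size=32, epochs=None, shuffle=True)

# load() returns a shuffled, endlessly repeating tf.data.Dataset,
# so steps_per_epoch must bound each epoch
model.fit(loader.load(), steps_per_epoch=loader.steps_per_epoch, epochs=10)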
""" - # super().__init__(**kwargs) - self._kwargs = kwargs # Store kwargs for serialization + self._kwargs = kwargs + + sample_rate = kwargs.get("sample_rate", 1.0) + self.sample = 1.0 / sample_rate if kwargs.get("pickle_folder", None): pickle_folder = Path(kwargs["pickle_folder"]) @@ -52,6 +55,17 @@ def __init__(self, **kwargs): self.graphs = self.__convert(data) else: self.add(data) + + elif kwargs.get("pickle_file", None): + pickle_file = Path(kwargs["pickle_file"]) + self.graphs = None + data = load_pickle_gz(pickle_file) + + if not self.graphs: + self.graphs = self.__convert(data) + else: + self.add(data) + elif kwargs.get("graphs", None): if not isinstance(kwargs["graphs"], list): raise NotImplementedError("""data should be of type list""") @@ -59,7 +73,7 @@ def __init__(self, **kwargs): self.graphs = kwargs["graphs"] else: raise NotImplementedError( - "Please provide either 'pickle_folder' or 'graphs' as parameter to CustomSpektralDataset" + "Please provide either 'pickle_folder', 'pickle_file' or 'graphs' as parameter to CustomSpektralDataset" ) super().__init__(**kwargs) @@ -69,12 +83,18 @@ def __convert(self, data) -> List[Graph]: Convert incoming data to correct List[Graph] format """ if isinstance(data[0], Graph): - return data + return [g for i, g in enumerate(data) if i % self.sample == 0] elif isinstance(data[0], DefaultGraphFrame): - return [g.to_spektral_graph() for g in self.data] + return [ + g.to_spektral_graph() + for i, g in enumerate(self.data) + if i % self.sample == 0 + ] elif isinstance(data[0], dict): return [ - Graph(x=g["x"], a=g["a"], e=g["e"], y=g["y"], id=g["id"]) for g in data + Graph(x=g["x"], a=g["a"], e=g["e"], y=g["y"], id=g["id"]) + for i, g in enumerate(data) + if i % self.sample == 0 ] else: raise NotImplementedError() @@ -242,3 +262,50 @@ def __handle_graph_id(i): return self[train_idxs], self[test_idxs], self[validation_idxs] else: return self[train_idxs], self[test_idxs] + + @property + def signature(self): + """ + This property computes the signature of the dataset, which can be + passed to `spektral.data.utils.to_tf_signature(signature)` to compute + the TensorFlow signature. + + The signature includes TensorFlow TypeSpec, shape, and dtype for all + characteristic matrices of the graphs in the Dataset. + """ + if len(self.graphs) == 0: + return None + signature = {} + graph = self.graphs[0] # This is always non-empty + + if graph.x is not None: + signature["x"] = dict() + signature["x"]["spec"] = get_spec(graph.x) + signature["x"]["shape"] = (None, self.n_node_features) + signature["x"]["dtype"] = tf.as_dtype(graph.x.dtype) + + if graph.a is not None: + signature["a"] = dict() + signature["a"]["spec"] = get_spec(graph.a) + signature["a"]["shape"] = (None, None) + signature["a"]["dtype"] = tf.as_dtype(graph.a.dtype) + + if graph.e is not None: + signature["e"] = dict() + signature["e"]["spec"] = get_spec(graph.e) + signature["e"]["shape"] = (None, self.n_edge_features) + signature["e"]["dtype"] = tf.as_dtype(graph.e.dtype) + + if graph.y is not None: + signature["y"] = dict() + signature["y"]["spec"] = get_spec(graph.y) + signature["y"]["shape"] = (self.n_labels,) + signature["y"]["dtype"] = tf.as_dtype(np.array(graph.y).dtype) + + if hasattr(graph, "g") and graph.g is not None: + signature["g"] = dict() + signature["g"]["spec"] = get_spec(graph.g) + signature["g"]["shape"] = graph.g.shape + signature["g"]["dtype"] = tf.as_dtype(np.array(graph.g).dtype) + + return signature