Commit 822d1d7

minor internals

UnravelSports [JB] committed Jan 30, 2025
1 parent 7c8fc7e commit 822d1d7
Showing 7 changed files with 258 additions and 34 deletions.
4 changes: 2 additions & 2 deletions tests/test_bigdb.py
@@ -130,8 +130,8 @@ def node_feature_values(self):
"is_possession_team": 0.0,
"is_qb": 0.0,
"is_ball": 0.0,
"weight_normed": 0.21941714285714287,
"height_normed": 0.4722666666666665,
"weight_normed": 0.21428571428571427,
"height_normed": 0.5333333333333333,
}
return item_idx, assert_values

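The two asserted constants change because _convert_weight_height_to_metric (see the dataset.py diff below) now rounds height and weight to the nearest 10 cm / 10 kg before they are normalized. A rough standalone sketch of that coarsening, using a made-up input rather than the test's actual player data:

    # hypothetical example input, not taken from the test fixtures
    weight_lbs = 225.0
    weight_kg = weight_lbs * 0.453592        # 102.058...
    rounded_kg = round(weight_kg / 10) * 10  # 100.0 after coarsening to the nearest 10
    print(weight_kg, rounded_kg)             # any normalized value downstream shifts accordingly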
48 changes: 35 additions & 13 deletions unravel/american_football/graphs/dataset.py
@@ -49,19 +49,23 @@ def __init__(
tracking_file_path: str,
players_file_path: str,
plays_file_path: str,
sample_rate: float = None,
**kwargs,
):
super().__init__(**kwargs)
self.tracking_file_path = tracking_file_path
self.players_file_path = players_file_path
self.plays_file_path = plays_file_path
self.sample_rate = 1 if sample_rate is None else sample_rate
self.pitch_dimensions = AmericanFootballPitchDimensions()

def load(self):
pitch_length = self.pitch_dimensions.pitch_length
pitch_width = self.pitch_dimensions.pitch_width

- df = pl.read_csv(
+ sample = 1.0 / self.sample_rate
+
+ df = pl.scan_csv(
self.tracking_file_path,
separator=",",
encoding="utf8",
@@ -71,7 +75,7 @@ def load(self):

play_direction = "left"

if "club" in df.columns:
if "club" in df.collect_schema().names():
df = df.with_columns(pl.col(Column.CLUB).alias(Column.TEAM))
df = df.drop(Column.CLUB)

@@ -119,7 +123,14 @@ def load(self):
.alias(Column.OBJECT_ID),
]
)
- )
+ .with_columns(
+     [
+         pl.lit(play_direction).alias("playDirection"),
+     ]
+ )
+ .filter((pl.col(Column.FRAME_ID) % sample) == 0)
+ ).collect()

players = pl.read_csv(
self.players_file_path,
separator=",",
@@ -206,14 +217,25 @@ def _convert_weight_height_to_metric(df: pl.DataFrame):
.alias("inches"), # Extract inches and cast to float
]
)
- df = df.with_columns(
-     [
-         (pl.col("feet") * 30.48 + pl.col("inches") * 2.54).alias(
-             Column.HEIGHT_CM
-         ),
-         (pl.col("weight") * 0.453592).alias(
-             Column.WEIGHT_KG
-         ),  # Convert pounds to kilograms
-     ]
- ).drop(["height", "feet", "inches", "weight"])
+ # Convert height and weight to centimeters and kilograms,
+ # rounded to the nearest 10 cm / 10 kg so we don't leak player-specific info
+ df = (
+     df.with_columns(
+         [
+             ((pl.col("feet") * 30.48 + pl.col("inches") * 2.54) / 10)
+             .round(0)
+             .alias(Column.HEIGHT_CM),
+             ((pl.col("weight") * 0.453592) / 10)
+             .round(0)
+             .alias(Column.WEIGHT_KG),
+         ]
+     )
+     .with_columns(
+         [
+             (pl.col(Column.HEIGHT_CM) * 10).alias(Column.HEIGHT_CM),
+             (pl.col(Column.WEIGHT_KG) * 10).alias(Column.WEIGHT_KG),
+         ]
+     )
+     .drop(["height", "feet", "inches", "weight"])
+ )
return df
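The divide-round-multiply steps coarsen each value to the nearest 10 cm / 10 kg. A standalone sketch of the transform with assumed column names and values:

    import polars as pl

    df = pl.DataFrame({"height_cm": [187.96], "weight_kg": [102.06]})
    df = df.with_columns(
        [
            ((pl.col("height_cm") / 10).round(0) * 10).alias("height_cm"),  # 190.0
            ((pl.col("weight_kg") / 10).round(0) * 10).alias("weight_kg"),  # 100.0
        ]
    )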
6 changes: 6 additions & 0 deletions unravel/american_football/graphs/features/node_features.py
@@ -27,6 +27,7 @@ def compute_node_features(
possession_team,
height,
weight,
graph_features,
settings,
):
ball_id = Constant.BALL
@@ -141,4 +142,9 @@
)
)

if graph_features is not None:
eg = np.ones((X.shape[0], graph_features.shape[0])) * 0.0
eg[ball_index] = graph_features
X = np.hstack((X, eg))

return X
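The graph-level features enter the node matrix through a block of columns that is zero for every node except the ball, so each value appears exactly once per graph. A NumPy sketch of that layout with assumed shapes (np.zeros is equivalent to the np.ones(...) * 0.0 used above):

    import numpy as np

    X = np.random.rand(23, 12)               # assumed: 22 players + ball, 12 node features
    graph_features = np.array([0.62, 1.5])   # made-up graph-level values
    ball_index = 22                          # assumed row of the football node

    eg = np.zeros((X.shape[0], graph_features.shape[0]))
    eg[ball_index] = graph_features
    X = np.hstack((X, eg))                   # X.shape is now (23, 14)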
103 changes: 92 additions & 11 deletions unravel/american_football/graphs/graph_converter.py
@@ -1,9 +1,11 @@
from warnings import warn

from dataclasses import dataclass

import polars as pl
import numpy as np

- from typing import List
+ from typing import List, Optional

from spektral.data import Graph

@@ -19,7 +21,7 @@
compute_adjacency_matrix,
)

- from ...utils import DefaultGraphConverter, reshape_array, make_sparse
+ from ...utils import *


@dataclass(repr=True)
@@ -28,19 +30,20 @@ class AmericanFootballGraphConverter(DefaultGraphConverter):
Converts our dataset TrackingDataset into an internal structure
Attributes:
-     dataset (TrackingDataset): Kloppy TrackingDataset.
-     label_col (str): Column name that contains labels in the dataset.data Polars dataframe
-     graph_id_col (str): Column name that contains graph ids in the dataset.data Polars dataframe
+     dataset (BigDataBowlDataset): BigDataBowlDataset.
      chunk_size (int): Used to batch convert Polars into Graphs
      attacking_non_qb_node_value (float): Value between 0 and 1 to assign any attacking team player who is not the QB
+     graph_features_as_node_features_columns (list):
+         List of columns in the dataset that are Graph level features (e.g. team strength rating, win probabilities etc)
+         we want to add to our model. They will be recorded as Node Features on the "football" node.
"""

def __init__(
self,
dataset: BigDataBowlDataset,
chunk_size: int = 2_000,
attacking_non_qb_node_value: float = 0.1,
graph_feature_cols: Optional[List[str]] = None,
**kwargs,
):
super().__init__(**kwargs)
@@ -63,12 +66,51 @@ def __init__(
)
self.chunk_size = chunk_size
self.attacking_non_qb_node_value = attacking_non_qb_node_value

- self._sport_specific_checks()
+ self.graph_feature_cols = graph_feature_cols

self.settings = self._apply_settings()

+ self._sport_specific_checks()

def _sport_specific_checks(self):
def __remove_with_missing_values(min_object_count: int = 10):
cs = (
self.dataset.group_by(Group.BY_FRAME)
.agg(pl.len().alias("size"))
.filter(
pl.col("size") < min_object_count
)  # frames with fewer than min_object_count objects
)

self.dataset = self.dataset.join(cs, on=Group.BY_FRAME, how="anti")
if len(cs) > 0:
warn(
f"Removed {len(cs)} frames with less than {min_object_count} objects...",
UserWarning,
)

def __remove_with_missing_football():
cs = (
self.dataset.group_by(Group.BY_FRAME)
.agg(
[
pl.len().alias("size"), # Count total rows in each group
pl.col(Column.TEAM)
.filter(pl.col(Column.TEAM) == Constant.BALL)
.count()
.alias("football_count"), # Count rows where team == 'football'
]
)
.filter(
(pl.col("football_count") == 0)
)  # frames with no "football" row
)
self.dataset = self.dataset.join(cs, on=Group.BY_FRAME, how="anti")
if len(cs) > 0:
warn(
f"Removed {len(cs)} frames with a missing '{Constant.BALL}' object...",
UserWarning,
)

if not isinstance(self.label_column, str):
raise Exception("'label_col' should be of type string (str)")
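Both cleanup helpers above follow the same Polars idiom: aggregate per frame, keep only the offending frames, then anti-join them away. A minimal standalone version of the football check (toy frame and team columns assumed):

    import polars as pl

    df = pl.DataFrame({
        "game_id":  [1, 1, 1],
        "frame_id": [1, 1, 2],
        "team":     ["home", "football", "home"],
    })
    by_frame = ["game_id", "frame_id"]

    # frames in which no row has team == "football"
    bad = (
        df.group_by(by_frame)
        .agg((pl.col("team") == "football").sum().alias("football_count"))
        .filter(pl.col("football_count") == 0)
    )
    df = df.join(bad, on=by_frame, how="anti")  # drop every frame listed in `bad`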
@@ -95,6 +137,9 @@ def _sport_specific_checks(self):
"'attacking_non_qb_node_value' should be of type float or integer (int)"
)

__remove_with_missing_values(min_object_count=10)
__remove_with_missing_football()

def _apply_settings(self):
return AmericanFootballGraphSettings(
pitch_dimensions=self.pitch_dimensions,
@@ -115,7 +160,7 @@ def _apply_settings(self):

@property
def __exprs_variables(self):
- return [
+ exprs_variables = [
Column.X,
Column.Y,
Column.SPEED,
@@ -130,10 +175,22 @@ def __exprs_variables(self):
self.graph_id_column,
self.label_column,
]
exprs = (
exprs_variables
if self.graph_feature_cols is None
else exprs_variables + self.graph_feature_cols
)
return exprs

def __compute(self, args: List[pl.Series]) -> dict:
d = {col: args[i].to_numpy() for i, col in enumerate(self.__exprs_variables)}

graph_features = (
np.asarray([d[col] for col in self.graph_feature_cols]).T[0]
if self.graph_feature_cols
else None
)

if not np.all(d[self.graph_id_column] == d[self.graph_id_column][0]):
raise Exception(
"GraphId selection contains multiple different values. Make sure each graph_id is unique by at least game_id and frame_id..."
@@ -174,6 +231,7 @@ def __compute(self, args: List[pl.Series]) -> dict:
possession_team=d[Column.POSSESSION_TEAM],
height=d[Column.HEIGHT_CM],
weight=d[Column.WEIGHT_KG],
graph_features=graph_features,
settings=self.settings,
)
return {
@@ -196,13 +254,33 @@ def __compute(self, args: List[pl.Series]) -> dict:
self.label_column: d[self.label_column][0],
}

@property
def return_dtypes(self):
return pl.Struct(
{
"e": pl.List(pl.List(pl.Float64)),
"x": pl.List(pl.List(pl.Float64)),
"a": pl.List(pl.List(pl.Float64)),
"e_shape_0": pl.Int64,
"e_shape_1": pl.Int64,
"x_shape_0": pl.Int64,
"x_shape_1": pl.Int64,
"a_shape_0": pl.Int64,
"a_shape_1": pl.Int64,
self.graph_id_column: pl.String,
self.label_column: pl.Int64,
}
)

def _convert(self):
# Group and aggregate in one step
return (
self.dataset.group_by(Group.BY_FRAME, maintain_order=True)
.agg(
pl.map_groups(
- exprs=self.__exprs_variables, function=self.__compute
+ exprs=self.__exprs_variables,
+ function=self.__compute,
+ return_dtype=self.return_dtypes,
).alias("result_dict")
)
.with_columns(
@@ -273,7 +351,7 @@ def to_spektral_graphs(self) -> List[Graph]:
for d in self.graph_frames
]

- def to_pickle(self, file_path: str) -> None:
+ def to_pickle(self, file_path: str, verbose: bool = False) -> None:
"""
We store the 'dict' version of the Graphs to pickle; each graph is now a dict with keys x, a, e, and y.
To use for training with Spektral, feed the loaded pickle data to CustomDataset(data=pickled_data).
@@ -286,6 +364,9 @@ def to_pickle(self, file_path: str) -> None:
if not self.graph_frames:
self.to_graph_frames()

if verbose:
print(f"Storing {len(self.graph_frames)} Graphs in {file_path}...")

import pickle
import gzip
from pathlib import Path
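The new return_dtypes property gives pl.map_groups an explicit schema, so Polars no longer has to infer the struct layout from the first group. A condensed standalone sketch of the pattern (toy columns, not the converter's real schema):

    import polars as pl

    df = pl.DataFrame({"frame_id": [1, 1, 2, 2], "x": [0.1, 0.2, 0.3, 0.4]})
    dtype = pl.Struct({"n": pl.Int64, "mean_x": pl.Float64})

    def compute(args: list) -> dict:
        x = args[0].to_numpy()
        return {"n": len(x), "mean_x": float(x.mean())}

    out = df.group_by("frame_id", maintain_order=True).agg(
        pl.map_groups(exprs=["x"], function=compute, return_dtype=dtype).alias("result")
    )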
27 changes: 25 additions & 2 deletions unravel/soccer/graphs/graph_converter_pl.py
@@ -369,13 +369,33 @@ def __compute(self, args: List[pl.Series]) -> dict:
self.label_column: d[self.label_column][0],
}

@property
def return_dtypes(self):
return pl.Struct(
{
"e": pl.List(pl.List(pl.Float64)),
"x": pl.List(pl.List(pl.Float64)),
"a": pl.List(pl.List(pl.Float64)),
"e_shape_0": pl.Int64,
"e_shape_1": pl.Int64,
"x_shape_0": pl.Int64,
"x_shape_1": pl.Int64,
"a_shape_0": pl.Int64,
"a_shape_1": pl.Int64,
self.graph_id_column: pl.String,
self.label_column: pl.Int64,
}
)

def _convert(self):
# Group and aggregate in one step
return (
self.dataset.group_by(Group.BY_FRAME, maintain_order=True)
.agg(
pl.map_groups(
- exprs=self.__exprs_variables, function=self.__compute
+ exprs=self.__exprs_variables,
+ function=self.__compute,
+ return_dtype=self.return_dtypes,
).alias("result_dict")
)
.with_columns(
@@ -446,7 +466,7 @@ def to_spektral_graphs(self) -> List[Graph]:
for d in self.graph_frames
]

- def to_pickle(self, file_path: str) -> None:
+ def to_pickle(self, file_path: str, verbose: bool = False) -> None:
"""
We store the 'dict' version of the Graphs to pickle; each graph is now a dict with keys x, a, e, and y.
To use for training with Spektral, feed the loaded pickle data to CustomDataset(data=pickled_data).
@@ -459,6 +479,9 @@ def to_pickle(self, file_path: str) -> None:
if not self.graph_frames:
self.to_graph_frames()

if verbose:
print(f"Storing {len(self.graph_frames)} Graphs in {file_path}...")

import pickle
import gzip
from pathlib import Path
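On the loading side, the gzip-pickled graphs come back as the dict form described in the docstring. A sketch of reading such a file for training, assuming to_pickle wrote one gzip-compressed pickle (path hypothetical; CustomDataset is the project's wrapper mentioned above):

    import gzip
    import pickle

    with gzip.open("graphs.pickle.gz", "rb") as f:  # hypothetical path
        data = pickle.load(f)

    # each element is a dict with keys "x", "a", "e" and "y"
    # dataset = CustomDataset(data=data)   # per the docstring above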
25 changes: 25 additions & 0 deletions unravel/utils/objects/custom_disjoint_loader.py
@@ -0,0 +1,25 @@
from spektral.data import DisjointLoader
import tensorflow as tf

version = tf.__version__.split(".")
major, minor = int(version[0]), int(version[1])
tf_loader_available = major > 2 or (major == 2 and minor >= 4)  # TF 2.4 or newer


class CustomDisjointLoader(DisjointLoader):
def __init__(
self, dataset, node_level=False, batch_size=1, epochs=None, shuffle=True
):
self.node_level = node_level
super().__init__(dataset, batch_size=batch_size, epochs=epochs, shuffle=shuffle)

def load(self):
if not tf_loader_available:
raise RuntimeError(
"Calling DisjointLoader.load() requires TensorFlow 2.4 or greater."
)
dataset = tf.data.Dataset.from_generator(
lambda: self, output_signature=self.tf_signature()
)
dataset = dataset.shuffle(buffer_size=1000)
return dataset.repeat()
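A usage sketch for the new loader (model and dataset are assumed to exist; because load() shuffles and repeats indefinitely, Keras needs steps_per_epoch to know where an epoch ends):

    # hypothetical usage, not from the repo
    loader = CustomDisjointLoader(dataset, batch_size=32, epochs=None, shuffle=True)

    model.fit(
        loader.load(),                           # shuffled, infinitely repeating tf.data.Dataset
        steps_per_epoch=loader.steps_per_epoch,  # provided by Spektral's base Loader
        epochs=10,
    )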