Skip to content

Commit

Permalink
minor
Browse files (browse the repository at this point in the history)
  • Loading branch information
UnravelSports [JB] committed Jan 26, 2025
1 parent ce14e38 commit 676014e
Show file tree
Hide file tree
Showing 6 changed files with 61 additions and 35 deletions.
2 changes: 1 addition & 1 deletion unravel/__init__.py
Original file line number Diff line number Diff line change
@@ -1,4 +1,4 @@
__version__ = "0.2.0"
__version__ = "0.3.0"

from .soccer import *
from .american_football import *
Expand Down
2 changes: 1 addition & 1 deletion unravel/american_football/graphs/dataset.py
Original file line number Diff line number Diff line change
Expand Up @@ -125,7 +125,7 @@ def load(self):
separator=",",
encoding="utf8",
null_values=["NA", "NULL", ""],
dtypes={"birthDate": pl.Date},
schema_overrides={"birthDate": pl.Date},
ignore_errors=True,
)
if "position" in players.columns:
Expand Down
42 changes: 27 additions & 15 deletions unravel/american_football/graphs/graph_converter.py
Original file line number Diff line number Diff line change
Expand Up @@ -48,8 +48,14 @@ def __init__(
if not isinstance(dataset, BigDataBowlDataset):
raise Exception("'dataset' should be an instance of BigDataBowlDataset")

self.label_col = dataset._label_column
self.graph_id_col = dataset._graph_id_column
self.label_column: str = (
self.label_col if self.label_col is not None else dataset._label_column
)
self.graph_id_column: str = (
self.graph_id_col
if self.graph_id_col is not None
else dataset._graph_id_column
)

self.dataset: pl.DataFrame = dataset.data
self.pitch_dimensions: AmericanFootballPitchDimensions = (
Expand All @@ -64,21 +70,21 @@ def __init__(

def _sport_specific_checks(self):

if not isinstance(self.label_col, str):
if not isinstance(self.label_column, str):
raise Exception("'label_col' should be of type string (str)")

if not isinstance(self.graph_id_col, str):
if not isinstance(self.graph_id_column, str):
raise Exception("'graph_id_col' should be of type string (str)")

if not isinstance(self.chunk_size, int):
raise Exception("chunk_size should be of type integer (int)")

if not self.label_col in self.dataset.columns and not self.prediction:
if not self.label_column in self.dataset.columns and not self.prediction:
raise Exception(
"Please specify a 'label_col' and add that column to your 'dataset' or set 'prediction=True' if you want to use the converted dataset to make predictions on."
)

if not self.graph_id_col in self.dataset.columns:
if not self.graph_id_column in self.dataset.columns:
raise Exception(
"Please specify a 'graph_id_col' and add that column to your 'dataset' ..."
)
Expand Down Expand Up @@ -121,20 +127,20 @@ def __exprs_variables(self):
Column.POSSESSION_TEAM,
Column.HEIGHT_CM,
Column.WEIGHT_KG,
self.graph_id_col,
self.label_col,
self.graph_id_column,
self.label_column,
]

def __compute(self, args: List[pl.Series]) -> dict:
d = {col: args[i].to_numpy() for i, col in enumerate(self.__exprs_variables)}

if not np.all(d[self.graph_id_col] == d[self.graph_id_col][0]):
if not np.all(d[self.graph_id_column] == d[self.graph_id_column][0]):
raise Exception(
"GraphId selection contains multiple different values. Make sure each graph_id is unique by at least game_id and frame_id..."
)

if not self.prediction and not np.all(
d[self.label_col] == d[self.label_col][0]
d[self.label_column] == d[self.label_column][0]
):
raise Exception(
"""Label selection contains multiple different values for a single selection (group by) of game_id and frame_id,
Expand Down Expand Up @@ -186,8 +192,8 @@ def __compute(self, args: List[pl.Series]) -> dict:
"x_shape_1": node_features.shape[1],
"a_shape_0": adjacency_matrix.shape[0],
"a_shape_1": adjacency_matrix.shape[1],
self.graph_id_col: d[self.graph_id_col][0],
self.label_col: d[self.label_col][0],
self.graph_id_column: d[self.graph_id_column][0],
self.label_column: d[self.label_column][0],
}

def _convert(self):
Expand All @@ -203,7 +209,13 @@ def _convert(self):
[
*[
pl.col("result_dict").struct.field(f).alias(f)
for f in ["a", "e", "x", self.graph_id_col, self.label_col]
for f in [
"a",
"e",
"x",
self.graph_id_column,
self.label_column,
]
],
*[
pl.col("result_dict")
Expand Down Expand Up @@ -232,8 +244,8 @@ def process_chunk(chunk: pl.DataFrame) -> List[dict]:
"e": reshape_array(
chunk["e"][i], chunk["e_shape_0"][i], chunk["e_shape_1"][i]
),
"y": np.asarray([chunk[self.label_col][i]]),
"id": chunk[self.graph_id_col][i],
"y": np.asarray([chunk[self.label_column][i]]),
"id": chunk[self.graph_id_column][i],
}
for i in range(len(chunk))
]
Expand Down
1 change: 0 additions & 1 deletion unravel/soccer/graphs/graph_converter.py
Original file line number Diff line number Diff line change
Expand Up @@ -71,7 +71,6 @@ class SoccerGraphConverter(DefaultGraphConverter):
dataset: TrackingDataset = None
labels: dict = None

labels: dict = None
graph_id: Union[str, int, dict] = None
graph_ids: dict = None

Expand Down
46 changes: 29 additions & 17 deletions unravel/soccer/graphs/graph_converter_pl.py
Original file line number Diff line number Diff line change
Expand Up @@ -45,8 +45,14 @@ class SoccerGraphConverterPolars(DefaultGraphConverter):

def __post_init__(self):
self.pitch_dimensions: MetricPitchDimensions = self.dataset.pitch_dimensions
self.label_col = self.dataset._label_column
self.graph_id_col = self.dataset._graph_id_column
self.label_column: str = (
self.label_col if self.label_col is not None else self.dataset._label_column
)
self.graph_id_column: str = (
self.graph_id_col
if self.graph_id_col is not None
else self.dataset._graph_id_column
)

self.dataset = self.dataset.data

Expand Down Expand Up @@ -76,8 +82,8 @@ def _apply_padding(self) -> pl.DataFrame:
Column.TIMESTAMP,
Column.BALL_STATE,
Column.POSITION_NAME,
self.label_col,
self.graph_id_col,
self.label_column,
self.graph_id_column,
]
empty_columns = [
Column.OBJECT_ID,
Expand Down Expand Up @@ -240,21 +246,21 @@ def _apply_settings(self):
)

def _sport_specific_checks(self):
if not isinstance(self.label_col, str):
if not isinstance(self.label_column, str):
raise Exception("'label_col' should be of type string (str)")

if not isinstance(self.graph_id_col, str):
if not isinstance(self.graph_id_column, str):
raise Exception("'graph_id_col' should be of type string (str)")

if not isinstance(self.chunk_size, int):
raise Exception("chunk_size should be of type integer (int)")

if not self.label_col in self.dataset.columns and not self.prediction:
if not self.label_column in self.dataset.columns and not self.prediction:
raise Exception(
"Please specify a 'label_col' and add that column to your 'dataset' or set 'prediction=True' if you want to use the converted dataset to make predictions on."
)

if not self.graph_id_col in self.dataset.columns:
if not self.graph_id_column in self.dataset.columns:
raise Exception(
"Please specify a 'graph_id_col' and add that column to your 'dataset' ..."
)
Expand Down Expand Up @@ -284,20 +290,20 @@ def __exprs_variables(self):
Column.POSITION_NAME,
Column.BALL_OWNING_TEAM_ID,
Column.IS_BALL_CARRIER,
self.graph_id_col,
self.label_col,
self.graph_id_column,
self.label_column,
]

def __compute(self, args: List[pl.Series]) -> dict:
d = {col: args[i].to_numpy() for i, col in enumerate(self.__exprs_variables)}

if not np.all(d[self.graph_id_col] == d[self.graph_id_col][0]):
if not np.all(d[self.graph_id_column] == d[self.graph_id_column][0]):
raise Exception(
"GraphId selection contains multiple different values. Make sure each graph_id is unique by at least game_id and frame_id..."
)

if not self.prediction and not np.all(
d[self.label_col] == d[self.label_col][0]
d[self.label_column] == d[self.label_column][0]
):
raise Exception(
"""Label selection contains multiple different values for a single selection (group by) of game_id and frame_id,
Expand Down Expand Up @@ -356,8 +362,8 @@ def __compute(self, args: List[pl.Series]) -> dict:
"x_shape_1": node_features.shape[1],
"a_shape_0": adjacency_matrix.shape[0],
"a_shape_1": adjacency_matrix.shape[1],
self.graph_id_col: d[self.graph_id_col][0],
self.label_col: d[self.label_col][0],
self.graph_id_column: d[self.graph_id_column][0],
self.label_column: d[self.label_column][0],
}

def _convert(self):
Expand All @@ -373,7 +379,13 @@ def _convert(self):
[
*[
pl.col("result_dict").struct.field(f).alias(f)
for f in ["a", "e", "x", self.graph_id_col, self.label_col]
for f in [
"a",
"e",
"x",
self.graph_id_column,
self.label_column,
]
],
*[
pl.col("result_dict")
Expand Down Expand Up @@ -402,8 +414,8 @@ def process_chunk(chunk: pl.DataFrame) -> List[dict]:
"e": reshape_array(
chunk["e"][i], chunk["e_shape_0"][i], chunk["e_shape_1"][i]
),
"y": np.asarray([chunk[self.label_col][i]]),
"id": chunk[self.graph_id_col][i],
"y": np.asarray([chunk[self.label_column][i]]),
"id": chunk[self.graph_id_column][i],
}
for i in range(len(chunk))
]
Expand Down
3 changes: 3 additions & 0 deletions unravel/utils/objects/default_graph_converter.py
Original file line number Diff line number Diff line change
Expand Up @@ -87,6 +87,9 @@ class DefaultGraphConverter:
pad: bool = False
verbose: bool = False

label_col: str = None
graph_id_col: str = None

graph_frames: dict = field(init=False, repr=False, default=None)
settings: DefaultGraphSettings = field(
init=False, repr=False, default_factory=DefaultGraphSettings
Expand Down

0 comments on commit 676014e

Please sign in to comment.