Skip to content

Commit

Permalink
examples [skip ci]
Browse files Browse the repository at this point in the history
  • Loading branch information
UnravelSports [JB] committed Aug 22, 2024
1 parent 7858818 commit b7d7429
Show file tree
Hide file tree
Showing 6 changed files with 444 additions and 32 deletions.
4 changes: 2 additions & 2 deletions .gitignore
Original file line number Diff line number Diff line change
Expand Up @@ -183,9 +183,9 @@ build.py
tests/files/models/my-test-gnn/*
tests/files/test.pickle.gz

examples/models/*
/models
/dev

*.ipynb
/*.ipynb

.env
2 changes: 1 addition & 1 deletion examples/0_quick_start_guide.ipynb
Original file line number Diff line number Diff line change
Expand Up @@ -80,7 +80,7 @@
"converter = GraphConverter(dataset=kloppy_dataset, labels=dummy_labels(kloppy_dataset))\n",
"\n",
"# Compute the graphs and add them to the CustomSpektralDataset\n",
"dataset = CustomSpektralDataset(graph=converter.to_spektral_graphs())"
"dataset = CustomSpektralDataset(graphs=converter.to_spektral_graphs())"
]
},
{
Expand Down
192 changes: 164 additions & 28 deletions examples/1_kloppy_gnn_train.ipynb
Original file line number Diff line number Diff line change
Expand Up @@ -75,7 +75,7 @@
},
{
"cell_type": "code",
"execution_count": null,
"execution_count": 4,
"metadata": {},
"outputs": [],
"source": [
Expand Down Expand Up @@ -184,9 +184,20 @@
},
{
"cell_type": "code",
"execution_count": null,
"metadata": {},
"outputs": [],
"execution_count": 5,
"metadata": {},
"outputs": [
{
"name": "stderr",
"output_type": "stream",
"text": [
"Processing frames: 100%|██████████| 500/500 [00:02<00:00, 244.81it/s]\n",
"Processing frames: 100%|██████████| 500/500 [00:01<00:00, 285.65it/s]\n",
"Processing frames: 100%|██████████| 500/500 [00:01<00:00, 343.58it/s] \n",
"Processing frames: 100%|██████████| 500/500 [00:01<00:00, 285.17it/s]\n"
]
}
],
"source": [
"from os.path import exists\n",
"\n",
Expand Down Expand Up @@ -266,7 +277,7 @@
},
{
"cell_type": "code",
"execution_count": null,
"execution_count": 6,
"metadata": {},
"outputs": [],
"source": [
Expand Down Expand Up @@ -302,9 +313,19 @@
},
{
"cell_type": "code",
"execution_count": null,
"metadata": {},
"outputs": [],
"execution_count": 7,
"metadata": {},
"outputs": [
{
"name": "stdout",
"output_type": "stream",
"text": [
"Train: CustomSpektralDataset(n_graphs=791)\n",
"Test: CustomSpektralDataset(n_graphs=477)\n",
"Validation: CustomSpektralDataset(n_graphs=336)\n"
]
}
],
"source": [
"train, test, val = dataset.split_test_train_validation(\n",
" split_train=4, split_test=1, split_validation=1, by_graph_id=True, random_seed=42\n",
Expand All @@ -330,7 +351,7 @@
},
{
"cell_type": "code",
"execution_count": null,
"execution_count": 8,
"metadata": {},
"outputs": [],
"source": [
Expand Down Expand Up @@ -358,7 +379,7 @@
},
{
"cell_type": "code",
"execution_count": null,
"execution_count": 9,
"metadata": {},
"outputs": [],
"source": [
Expand Down Expand Up @@ -417,7 +438,7 @@
},
{
"cell_type": "code",
"execution_count": null,
"execution_count": 10,
"metadata": {},
"outputs": [],
"source": [
Expand Down Expand Up @@ -479,7 +500,9 @@
"\n",
"1. We have a [`DisjointLoader`](https://graphneural.network/loaders/#disjointloader) for training and validation sets.\n",
"2. Fit the model. \n",
"3. We add `EarlyStopping` and a `validation_data` dataset to monitor performance, and set `use_multiprocessing=True` to improve training speed."
"3. We add `EarlyStopping` and a `validation_data` dataset to monitor performance, and set `use_multiprocessing=True` to improve training speed.\n",
"\n",
"⚠️ When trying to fit the model _again_ make sure to reload Data Loaders in [Section 6.4](#64-create-dataloaders), because they are generators."
]
},
{
Expand All @@ -491,7 +514,7 @@
"model.fit(\n",
" loader_tr.load(),\n",
" steps_per_epoch=loader_tr.steps_per_epoch,\n",
" epochs=10,\n",
" epochs=5,\n",
" use_multiprocessing=True,\n",
" validation_data=loader_va.load(),\n",
" callbacks=[EarlyStopping(monitor=\"loss\", patience=5, restore_best_weights=True)],\n",
Expand Down Expand Up @@ -529,14 +552,24 @@
"1. Create another `DisjointLoader`, this time for the test set.\n",
"2. Evaluate model performance on the test set. This evaluation function uses the `metrics` passed to `model.compile`\n",
"\n",
"Note: Our performance is really bad because we're using random labels, very few epochs and a small dataset."
"🗒️ Our performance is really bad because we're using random labels, very few epochs and a small dataset.\n",
"\n",
"📖 For more information on evaluation in sports analytics see: [Methodology and evaluation in sports analytics: challenges, approaches, and lessons learned {J. Davis et al. (2024)}](https://link.springer.com/article/10.1007/s10994-024-06585-0)\n"
]
},
{
"cell_type": "code",
"execution_count": null,
"metadata": {},
"outputs": [],
"execution_count": 14,
"metadata": {},
"outputs": [
{
"name": "stdout",
"output_type": "stream",
"text": [
"15/15 [==============================] - 0s 4ms/step - loss: 0.7250 - auc: 0.5309 - binary_accuracy: 0.5241\n"
]
}
],
"source": [
"loader_te = DisjointLoader(test, epochs=1, shuffle=False, batch_size=batch_size)\n",
"results = model.evaluate(loader_te.load())"
Expand All @@ -555,7 +588,7 @@
},
{
"cell_type": "code",
"execution_count": null,
"execution_count": 15,
"metadata": {},
"outputs": [],
"source": [
Expand Down Expand Up @@ -595,9 +628,24 @@
},
{
"cell_type": "code",
"execution_count": null,
"metadata": {},
"outputs": [],
"execution_count": 16,
"metadata": {},
"outputs": [
{
"name": "stderr",
"output_type": "stream",
"text": [
"Processing frames: 100%|██████████| 500/500 [00:01<00:00, 326.02it/s]\n"
]
},
{
"name": "stdout",
"output_type": "stream",
"text": [
"11/11 [==============================] - 0s 4ms/step\n"
]
}
],
"source": [
"# Compute the graphs and add them to the CustomSpektralDataset\n",
"pred_dataset = CustomSpektralDataset(graphs=preds_converter.to_spektral_graphs())\n",
Expand All @@ -612,16 +660,95 @@
"cell_type": "markdown",
"metadata": {},
"source": [
"5. Convert Kloppy dataset to a dataframe and merge back the predictions using the frame_ids.\n",
"\n",
"Note: Not all frames have a prediction because of missing (ball) data."
"5. Convert Kloppy dataset to a dataframe and merge back the predictions using the frame_ids."
]
},
{
"cell_type": "code",
"execution_count": null,
"metadata": {},
"outputs": [],
"execution_count": 17,
"metadata": {},
"outputs": [
{
"data": {
"text/html": [
"<div>\n",
"<style scoped>\n",
" .dataframe tbody tr th:only-of-type {\n",
" vertical-align: middle;\n",
" }\n",
"\n",
" .dataframe tbody tr th {\n",
" vertical-align: top;\n",
" }\n",
"\n",
" .dataframe thead th {\n",
" text-align: right;\n",
" }\n",
"</style>\n",
"<table border=\"1\" class=\"dataframe\">\n",
" <thead>\n",
" <tr style=\"text-align: right;\">\n",
" <th></th>\n",
" <th>frame_id</th>\n",
" <th>period_id</th>\n",
" <th>timestamp</th>\n",
" <th>y</th>\n",
" </tr>\n",
" </thead>\n",
" <tbody>\n",
" <tr>\n",
" <th>300</th>\n",
" <td>2166</td>\n",
" <td>1</td>\n",
" <td>0 days 00:00:33.300000</td>\n",
" <td>0.259016</td>\n",
" </tr>\n",
" <tr>\n",
" <th>301</th>\n",
" <td>2167</td>\n",
" <td>1</td>\n",
" <td>0 days 00:00:33.400000</td>\n",
" <td>0.251124</td>\n",
" </tr>\n",
" <tr>\n",
" <th>302</th>\n",
" <td>2168</td>\n",
" <td>1</td>\n",
" <td>0 days 00:00:33.500000</td>\n",
" <td>0.258305</td>\n",
" </tr>\n",
" <tr>\n",
" <th>303</th>\n",
" <td>2169</td>\n",
" <td>1</td>\n",
" <td>0 days 00:00:33.600000</td>\n",
" <td>0.256378</td>\n",
" </tr>\n",
" <tr>\n",
" <th>304</th>\n",
" <td>2170</td>\n",
" <td>1</td>\n",
" <td>0 days 00:00:33.700000</td>\n",
" <td>0.305434</td>\n",
" </tr>\n",
" </tbody>\n",
"</table>\n",
"</div>"
],
"text/plain": [
" frame_id period_id timestamp y\n",
"300 2166 1 0 days 00:00:33.300000 0.259016\n",
"301 2167 1 0 days 00:00:33.400000 0.251124\n",
"302 2168 1 0 days 00:00:33.500000 0.258305\n",
"303 2169 1 0 days 00:00:33.600000 0.256378\n",
"304 2170 1 0 days 00:00:33.700000 0.305434"
]
},
"execution_count": 17,
"metadata": {},
"output_type": "execute_result"
}
],
"source": [
"import pandas as pd\n",
"\n",
Expand All @@ -631,7 +758,16 @@
" {\"frame_id\": [x.id for x in pred_dataset], \"y\": preds.flatten()}\n",
")\n",
"\n",
"kloppy_df = pd.merge(kloppy_df, preds_df, on=\"frame_id\", how=\"left\")"
"kloppy_df = pd.merge(kloppy_df, preds_df, on=\"frame_id\", how=\"left\")\n",
"\n",
"kloppy_df[300: 305][['frame_id', 'period_id', 'timestamp', 'y']]"
]
},
{
"cell_type": "markdown",
"metadata": {},
"source": [
"🗒️ Not all frames have a prediction because of missing (ball) data, so we look at the 300th frame."
]
}
],
Expand Down
Loading

0 comments on commit b7d7429

Please sign in to comment.