Skip to content

Commit

Permalink
formatting
Browse files Browse the repository at this point in the history
  • Loading branch information
UnravelSports [JB] committed Jul 23, 2024
1 parent 13ec2a9 commit f5de05f
Show file tree
Hide file tree
Showing 2 changed files with 34 additions and 57 deletions.
24 changes: 8 additions & 16 deletions examples/0_quick_start_guide.ipynb
Original file line number Diff line number Diff line change
Expand Up @@ -99,7 +99,9 @@
"source": [
"from spektral.data import DisjointLoader\n",
"\n",
"train, test, val = dataset.split_test_train_validation(split_train=4, split_test=1, split_validation=1, random_seed=42)"
"train, test, val = dataset.split_test_train_validation(\n",
" split_train=4, split_test=1, split_validation=1, random_seed=42\n",
")"
]
},
{
Expand Down Expand Up @@ -130,19 +132,12 @@
"\n",
"from tensorflow.keras.losses import BinaryCrossentropy\n",
"from tensorflow.keras.optimizers import Adam\n",
"from tensorflow.keras.metrics import (\n",
" AUC, BinaryAccuracy\n",
")\n",
"from tensorflow.keras.metrics import AUC, BinaryAccuracy\n",
"\n",
"model = CrystalGraphClassifier()\n",
"\n",
"model.compile(\n",
" loss=BinaryCrossentropy(), \n",
" optimizer=Adam(),\n",
" metrics=[\n",
" AUC(), \n",
" BinaryAccuracy()\n",
" ]\n",
" loss=BinaryCrossentropy(), optimizer=Adam(), metrics=[AUC(), BinaryAccuracy()]\n",
")"
]
},
Expand Down Expand Up @@ -236,14 +231,12 @@
"loader_va = DisjointLoader(val, epochs=1, shuffle=False, batch_size=batch_size)\n",
"\n",
"model.fit(\n",
" loader_tr.load(), \n",
" loader_tr.load(),\n",
" epochs=epochs,\n",
" steps_per_epoch=loader_tr.steps_per_epoch, \n",
" steps_per_epoch=loader_tr.steps_per_epoch,\n",
" use_multiprocessing=True,\n",
" validation_data=loader_va.load(),\n",
" callbacks=[\n",
" EarlyStopping(monitor='loss', patience=5, restore_best_weights=True)\n",
" ]\n",
" callbacks=[EarlyStopping(monitor=\"loss\", patience=5, restore_best_weights=True)],\n",
")"
]
},
Expand Down Expand Up @@ -390,7 +383,6 @@
}
],
"source": [
"\n",
"loader_te = DisjointLoader(test, batch_size=batch_size, epochs=1, shuffle=False)\n",
"loaded_pred = model.predict(loader_te.load(), use_multiprocessing=True)\n",
"\n",
Expand Down
67 changes: 26 additions & 41 deletions examples/1_kloppy_gnn_train.ipynb
Original file line number Diff line number Diff line change
Expand Up @@ -251,8 +251,7 @@
"\n",
"for match_id in match_ids:\n",
" match_pickle_file_path = compressed_pickle_file_path.format(\n",
" pickle_folder=pickle_folder,\n",
" match_id=match_id\n",
" pickle_folder=pickle_folder, match_id=match_id\n",
" )\n",
" # if the output file already exists, skip this whole step\n",
" if not exists(match_pickle_file_path):\n",
Expand Down Expand Up @@ -551,18 +550,19 @@
"class CrystalGraphClassifier(Model):\n",
" def __init__(\n",
" self,\n",
" n_layers: int = 3, \n",
" channels: int = 128, \n",
" n_layers: int = 3,\n",
" channels: int = 128,\n",
" drop_out: float = 0.5,\n",
" n_out: int = 1, \n",
" **kwargs):\n",
" n_out: int = 1,\n",
" **kwargs\n",
" ):\n",
" super().__init__(**kwargs)\n",
" \n",
"\n",
" self.n_layers = n_layers\n",
" self.channels = channels\n",
" self.drop_out = drop_out\n",
" self.n_out = n_out\n",
" \n",
"\n",
" self.conv1 = CrystalConv()\n",
" self.convs = [CrystalConv() for _ in range(1, self.n_layers)]\n",
" self.pool = GlobalAvgPool()\n",
Expand All @@ -571,7 +571,7 @@
" self.dense2 = Dense(self.channels, activation=\"relu\")\n",
" self.dense3 = Dense(self.n_out, activation=\"sigmoid\")\n",
"\n",
" def call(self, inputs): \n",
" def call(self, inputs):\n",
" x, a, e, i = inputs\n",
" x = self.conv1([x, a, e])\n",
" for conv in self.convs:\n",
Expand Down Expand Up @@ -601,7 +601,6 @@
"metadata": {},
"outputs": [],
"source": [
"\n",
"from spektral.data import DisjointLoader\n",
"\n",
"loader_tr = DisjointLoader(train, batch_size=batch_size, epochs=epochs)\n",
Expand Down Expand Up @@ -642,23 +641,16 @@
}
],
"source": [
"from tensorflow.keras.metrics import (\n",
" AUC, BinaryAccuracy\n",
")\n",
"from tensorflow.keras.metrics import AUC, BinaryAccuracy\n",
"from tensorflow.keras.losses import BinaryCrossentropy\n",
"from tensorflow.keras.optimizers import Adam\n",
"from tensorflow.keras.callbacks import EarlyStopping\n",
"\n",
"model = CrystalGraphClassifier()\n",
"\n",
"model.compile(\n",
" loss=BinaryCrossentropy(), \n",
" optimizer=Adam(),\n",
" metrics=[\n",
" AUC(), \n",
" BinaryAccuracy()\n",
" ]\n",
")\n"
" loss=BinaryCrossentropy(), optimizer=Adam(), metrics=[AUC(), BinaryAccuracy()]\n",
")"
]
},
{
Expand Down Expand Up @@ -750,14 +742,12 @@
],
"source": [
"model.fit(\n",
" loader_tr.load(), \n",
" steps_per_epoch=loader_tr.steps_per_epoch, \n",
" loader_tr.load(),\n",
" steps_per_epoch=loader_tr.steps_per_epoch,\n",
" epochs=10,\n",
" use_multiprocessing=True,\n",
" validation_data=loader_va.load(),\n",
" callbacks=[\n",
" EarlyStopping(monitor='loss', patience=5, restore_best_weights=True)\n",
" ]\n",
" callbacks=[EarlyStopping(monitor=\"loss\", patience=5, restore_best_weights=True)],\n",
")"
]
},
Expand Down Expand Up @@ -797,7 +787,7 @@
"source": [
"from tensorflow.keras.models import load_model\n",
"\n",
"model_path = 'models/my-first-graph-classifier'\n",
"model_path = \"models/my-first-graph-classifier\"\n",
"model.save(model_path)\n",
"loaded_model = load_model(model_path)"
]
Expand Down Expand Up @@ -852,13 +842,12 @@
"kloppy_dataset = skillcorner.load_open_data(\n",
" match_id=2068,\n",
" include_empty_frames=False,\n",
" limit=500, \n",
" limit=500,\n",
")\n",
"\n",
"preds_converter = GraphConverter(\n",
" dataset=kloppy_dataset, \n",
" dataset=kloppy_dataset,\n",
" prediction=True,\n",
" \n",
" ball_carrier_treshold=25.0,\n",
" max_player_speed=12.0,\n",
" max_ball_speed=28.0,\n",
Expand Down Expand Up @@ -908,7 +897,9 @@
"# Compute the graphs and add them to the CustomSpektralDataset\n",
"pred_dataset = CustomSpektralDataset(graphs=preds_converter.to_spektral_graphs())\n",
"\n",
"loader_pred = DisjointLoader(pred_dataset, batch_size=batch_size, epochs=1, shuffle=False)\n",
"loader_pred = DisjointLoader(\n",
" pred_dataset, batch_size=batch_size, epochs=1, shuffle=False\n",
")\n",
"preds = model.predict(loader_pred.load(), use_multiprocessing=True)"
]
},
Expand All @@ -931,17 +922,11 @@
"\n",
"kloppy_df = kloppy_dataset.to_df()\n",
"\n",
"preds_df = pd.DataFrame({\n",
" \"frame_id\": [x.id for x in pred_dataset],\n",
" \"y\": preds.flatten()\n",
"})\n",
"\n",
"kloppy_df = pd.merge(\n",
" kloppy_df,\n",
" preds_df,\n",
" on='frame_id',\n",
" how='left'\n",
")\n"
"preds_df = pd.DataFrame(\n",
" {\"frame_id\": [x.id for x in pred_dataset], \"y\": preds.flatten()}\n",
")\n",
"\n",
"kloppy_df = pd.merge(kloppy_df, preds_df, on=\"frame_id\", how=\"left\")"
]
}
],
Expand Down

0 comments on commit f5de05f

Please sign in to comment.