From f5de05f476a9a4a1132395bf704bff9f10646270 Mon Sep 17 00:00:00 2001 From: "UnravelSports [JB]" Date: Tue, 23 Jul 2024 14:37:53 +0200 Subject: [PATCH] formatting --- examples/0_quick_start_guide.ipynb | 24 ++++------- examples/1_kloppy_gnn_train.ipynb | 67 ++++++++++++------------------ 2 files changed, 34 insertions(+), 57 deletions(-) diff --git a/examples/0_quick_start_guide.ipynb b/examples/0_quick_start_guide.ipynb index 871b0f6..0f8daa0 100644 --- a/examples/0_quick_start_guide.ipynb +++ b/examples/0_quick_start_guide.ipynb @@ -99,7 +99,9 @@ "source": [ "from spektral.data import DisjointLoader\n", "\n", - "train, test, val = dataset.split_test_train_validation(split_train=4, split_test=1, split_validation=1, random_seed=42)" + "train, test, val = dataset.split_test_train_validation(\n", + " split_train=4, split_test=1, split_validation=1, random_seed=42\n", + ")" ] }, { @@ -130,19 +132,12 @@ "\n", "from tensorflow.keras.losses import BinaryCrossentropy\n", "from tensorflow.keras.optimizers import Adam\n", - "from tensorflow.keras.metrics import (\n", - " AUC, BinaryAccuracy\n", - ")\n", + "from tensorflow.keras.metrics import AUC, BinaryAccuracy\n", "\n", "model = CrystalGraphClassifier()\n", "\n", "model.compile(\n", - " loss=BinaryCrossentropy(), \n", - " optimizer=Adam(),\n", - " metrics=[\n", - " AUC(), \n", - " BinaryAccuracy()\n", - " ]\n", + " loss=BinaryCrossentropy(), optimizer=Adam(), metrics=[AUC(), BinaryAccuracy()]\n", ")" ] }, @@ -236,14 +231,12 @@ "loader_va = DisjointLoader(val, epochs=1, shuffle=False, batch_size=batch_size)\n", "\n", "model.fit(\n", - " loader_tr.load(), \n", + " loader_tr.load(),\n", " epochs=epochs,\n", - " steps_per_epoch=loader_tr.steps_per_epoch, \n", + " steps_per_epoch=loader_tr.steps_per_epoch,\n", " use_multiprocessing=True,\n", " validation_data=loader_va.load(),\n", - " callbacks=[\n", - " EarlyStopping(monitor='loss', patience=5, restore_best_weights=True)\n", - " ]\n", + " callbacks=[EarlyStopping(monitor=\"loss\", patience=5, restore_best_weights=True)],\n", ")" ] }, @@ -390,7 +383,6 @@ } ], "source": [ - "\n", "loader_te = DisjointLoader(test, batch_size=batch_size, epochs=1, shuffle=False)\n", "loaded_pred = model.predict(loader_te.load(), use_multiprocessing=True)\n", "\n", diff --git a/examples/1_kloppy_gnn_train.ipynb b/examples/1_kloppy_gnn_train.ipynb index 96a69ad..9576db0 100644 --- a/examples/1_kloppy_gnn_train.ipynb +++ b/examples/1_kloppy_gnn_train.ipynb @@ -251,8 +251,7 @@ "\n", "for match_id in match_ids:\n", " match_pickle_file_path = compressed_pickle_file_path.format(\n", - " pickle_folder=pickle_folder,\n", - " match_id=match_id\n", + " pickle_folder=pickle_folder, match_id=match_id\n", " )\n", " # if the output file already exists, skip this whole step\n", " if not exists(match_pickle_file_path):\n", @@ -551,18 +550,19 @@ "class CrystalGraphClassifier(Model):\n", " def __init__(\n", " self,\n", - " n_layers: int = 3, \n", - " channels: int = 128, \n", + " n_layers: int = 3,\n", + " channels: int = 128,\n", " drop_out: float = 0.5,\n", - " n_out: int = 1, \n", - " **kwargs):\n", + " n_out: int = 1,\n", + " **kwargs\n", + " ):\n", " super().__init__(**kwargs)\n", - " \n", + "\n", " self.n_layers = n_layers\n", " self.channels = channels\n", " self.drop_out = drop_out\n", " self.n_out = n_out\n", - " \n", + "\n", " self.conv1 = CrystalConv()\n", " self.convs = [CrystalConv() for _ in range(1, self.n_layers)]\n", " self.pool = GlobalAvgPool()\n", @@ -571,7 +571,7 @@ " self.dense2 = Dense(self.channels, activation=\"relu\")\n", " self.dense3 = Dense(self.n_out, activation=\"sigmoid\")\n", "\n", - " def call(self, inputs): \n", + " def call(self, inputs):\n", " x, a, e, i = inputs\n", " x = self.conv1([x, a, e])\n", " for conv in self.convs:\n", @@ -601,7 +601,6 @@ "metadata": {}, "outputs": [], "source": [ - "\n", "from spektral.data import DisjointLoader\n", "\n", "loader_tr = DisjointLoader(train, batch_size=batch_size, epochs=epochs)\n", @@ -642,9 +641,7 @@ } ], "source": [ - "from tensorflow.keras.metrics import (\n", - " AUC, BinaryAccuracy\n", - ")\n", + "from tensorflow.keras.metrics import AUC, BinaryAccuracy\n", "from tensorflow.keras.losses import BinaryCrossentropy\n", "from tensorflow.keras.optimizers import Adam\n", "from tensorflow.keras.callbacks import EarlyStopping\n", @@ -652,13 +649,8 @@ "model = CrystalGraphClassifier()\n", "\n", "model.compile(\n", - " loss=BinaryCrossentropy(), \n", - " optimizer=Adam(),\n", - " metrics=[\n", - " AUC(), \n", - " BinaryAccuracy()\n", - " ]\n", - ")\n" + " loss=BinaryCrossentropy(), optimizer=Adam(), metrics=[AUC(), BinaryAccuracy()]\n", + ")" ] }, { @@ -750,14 +742,12 @@ ], "source": [ "model.fit(\n", - " loader_tr.load(), \n", - " steps_per_epoch=loader_tr.steps_per_epoch, \n", + " loader_tr.load(),\n", + " steps_per_epoch=loader_tr.steps_per_epoch,\n", " epochs=10,\n", " use_multiprocessing=True,\n", " validation_data=loader_va.load(),\n", - " callbacks=[\n", - " EarlyStopping(monitor='loss', patience=5, restore_best_weights=True)\n", - " ]\n", + " callbacks=[EarlyStopping(monitor=\"loss\", patience=5, restore_best_weights=True)],\n", ")" ] }, @@ -797,7 +787,7 @@ "source": [ "from tensorflow.keras.models import load_model\n", "\n", - "model_path = 'models/my-first-graph-classifier'\n", + "model_path = \"models/my-first-graph-classifier\"\n", "model.save(model_path)\n", "loaded_model = load_model(model_path)" ] @@ -852,13 +842,12 @@ "kloppy_dataset = skillcorner.load_open_data(\n", " match_id=2068,\n", " include_empty_frames=False,\n", - " limit=500, \n", + " limit=500,\n", ")\n", "\n", "preds_converter = GraphConverter(\n", - " dataset=kloppy_dataset, \n", + " dataset=kloppy_dataset,\n", " prediction=True,\n", - " \n", " ball_carrier_treshold=25.0,\n", " max_player_speed=12.0,\n", " max_ball_speed=28.0,\n", @@ -908,7 +897,9 @@ "# Compute the graphs and add them to the CustomSpektralDataset\n", "pred_dataset = CustomSpektralDataset(graphs=preds_converter.to_spektral_graphs())\n", "\n", - "loader_pred = DisjointLoader(pred_dataset, batch_size=batch_size, epochs=1, shuffle=False)\n", + "loader_pred = DisjointLoader(\n", + " pred_dataset, batch_size=batch_size, epochs=1, shuffle=False\n", + ")\n", "preds = model.predict(loader_pred.load(), use_multiprocessing=True)" ] }, @@ -931,17 +922,11 @@ "\n", "kloppy_df = kloppy_dataset.to_df()\n", "\n", - "preds_df = pd.DataFrame({\n", - " \"frame_id\": [x.id for x in pred_dataset],\n", - " \"y\": preds.flatten()\n", - "})\n", - "\n", - "kloppy_df = pd.merge(\n", - " kloppy_df,\n", - " preds_df,\n", - " on='frame_id',\n", - " how='left'\n", - ")\n" + "preds_df = pd.DataFrame(\n", + " {\"frame_id\": [x.id for x in pred_dataset], \"y\": preds.flatten()}\n", + ")\n", + "\n", + "kloppy_df = pd.merge(kloppy_df, preds_df, on=\"frame_id\", how=\"left\")" ] } ],