From 524e04848ac848e1372835213c45fcad404d0ef3 Mon Sep 17 00:00:00 2001 From: HelpstoneX Date: Mon, 23 Sep 2024 17:07:13 +0200 Subject: [PATCH] Fix additional dataloader creation (#220) --- DeepCrazyhouse/src/training/train_cnn.ipynb | 15 ++------------- 1 file changed, 2 insertions(+), 13 deletions(-) diff --git a/DeepCrazyhouse/src/training/train_cnn.ipynb b/DeepCrazyhouse/src/training/train_cnn.ipynb index 46952807..1f82e346 100644 --- a/DeepCrazyhouse/src/training/train_cnn.ipynb +++ b/DeepCrazyhouse/src/training/train_cnn.ipynb @@ -399,18 +399,7 @@ " for phase in [str(phase) for phase in to.phase_weights.keys()] + [\"None\"]:\n", " pgn_dataset_arrays_dict = load_pgn_dataset(dataset_type='test', part_id=0,\n", " verbose=True, normalize=tc.normalize, phase=phase)\n", - " s_idcs_val_tmp = pgn_dataset_arrays_dict[\"start_indices\"]\n", - " x_val_tmp = pgn_dataset_arrays_dict[\"x\"]\n", - " yv_val_tmp = pgn_dataset_arrays_dict[\"y_value\"]\n", - " yp_val_tmp = pgn_dataset_arrays_dict[\"y_policy\"]\n", - " plys_to_end_tmp = pgn_dataset_arrays_dict[\"plys_to_end\"]\n", - " pgn_datasets_val_tmp = pgn_dataset_arrays_dict[\"pgn_dataset\"]\n", - " phase_vector_tmp = pgn_dataset_arrays_dict[\"phase_vector\"]\n", - "\n", - " if tc.discount != 1:\n", - " yv_val_tmp *= tc.discount**plys_to_end_tmp\n", - "\n", - " data_loader = get_data_loader(x_val_tmp, yv_val_tmp, yp_val_tmp, plys_to_end_tmp, phase_vector_tmp, tc, shuffle=False)\n", + " data_loader = get_data_loader(pgn_dataset_arrays_dict, tc, shuffle=False)\n", " additional_data_loaders[f\"Phase{phase}Test\"] = data_loader" ] }, @@ -1726,4 +1715,4 @@ }, "nbformat": 4, "nbformat_minor": 4 -} +} \ No newline at end of file