diff --git a/DeepCrazyhouse/src/training/train_cli.py b/DeepCrazyhouse/src/training/train_cli.py index fc1dc645..6e798a1f 100644 --- a/DeepCrazyhouse/src/training/train_cli.py +++ b/DeepCrazyhouse/src/training/train_cli.py @@ -81,7 +81,7 @@ def main(): update_train_config_via_args(args, train_config) - val_data, x_val, _ = get_validation_data(train_config) + val_data, x_val = get_validation_data(train_config) input_shape = x_val[0].shape fill_train_config(train_config, x_val) diff --git a/DeepCrazyhouse/src/training/train_cli_util.py b/DeepCrazyhouse/src/training/train_cli_util.py index f966816a..d514f8c6 100644 --- a/DeepCrazyhouse/src/training/train_cli_util.py +++ b/DeepCrazyhouse/src/training/train_cli_util.py @@ -284,7 +284,7 @@ def get_validation_data(train_config: TrainConfig): """ pgn_dataset_arrays_dict = load_pgn_dataset(dataset_type='val', part_id=0, verbose=True, normalize=train_config.normalize) val_data = get_data_loader(pgn_dataset_arrays_dict, train_config, shuffle=False) - return val_data, x_val, yp_val + return val_data, pgn_dataset_arrays_dict["x"] def print_model_summary(input_shape: tuple, model, x_val) -> None: