From 806b73bccb9ce97a917a5b43fb3a531187813eca Mon Sep 17 00:00:00 2001 From: GitHub Action <52708150+marcpinet@users.noreply.github.com> Date: Mon, 9 Dec 2024 19:21:52 +0100 Subject: [PATCH] fix(cgan): y_train when None --- neuralnetlib/models.py | 19 ++++++++++++++----- 1 file changed, 14 insertions(+), 5 deletions(-) diff --git a/neuralnetlib/models.py b/neuralnetlib/models.py index 5114dae..0e9fa9b 100644 --- a/neuralnetlib/models.py +++ b/neuralnetlib/models.py @@ -2644,10 +2644,16 @@ def fit( callback.on_epoch_begin(epoch, epoch_logs) start_time = time.time() - x_train_shuffled, y_train_shuffled = shuffle( - x_train, y_train, - random_state=random_state if random_state is not None else self.random_state - ) + if y_train: + x_train_shuffled, y_train_shuffled = shuffle( + x_train, y_train, + random_state=random_state if random_state is not None else self.random_state + ) + else: + x_train_shuffled = shuffle( + x_train, + random_state=random_state if random_state is not None else self.random_state + ) d_error = 0 g_error = 0 @@ -2662,7 +2668,10 @@ def fit( for j in range(0, x_train.shape[0], batch_size): batch_index = j // batch_size x_batch = x_train_shuffled[j:j + batch_size] - y_batch = y_train_shuffled[j:j + batch_size] if y_train is not None else None + if y_train: + y_batch = y_train_shuffled[j:j + batch_size] + else: + y_batch = None batch_logs = { 'batch': batch_index,