Skip to content

Commit

Permalink
fix(cgan): y_train when None
Browse files Browse the repository at this point in the history
  • Loading branch information
marcpinet committed Dec 9, 2024
1 parent 7493c3a commit 806b73b
Showing 1 changed file with 14 additions and 5 deletions.
19 changes: 14 additions & 5 deletions neuralnetlib/models.py
Original file line number Diff line number Diff line change
Expand Up @@ -2644,10 +2644,16 @@ def fit(
callback.on_epoch_begin(epoch, epoch_logs)

start_time = time.time()
x_train_shuffled, y_train_shuffled = shuffle(
x_train, y_train,
random_state=random_state if random_state is not None else self.random_state
)
if y_train:
x_train_shuffled, y_train_shuffled = shuffle(
x_train, y_train,
random_state=random_state if random_state is not None else self.random_state
)
else:
x_train_shuffled = shuffle(
x_train,
random_state=random_state if random_state is not None else self.random_state
)
d_error = 0
g_error = 0

Expand All @@ -2662,7 +2668,10 @@ def fit(
for j in range(0, x_train.shape[0], batch_size):
batch_index = j // batch_size
x_batch = x_train_shuffled[j:j + batch_size]
y_batch = y_train_shuffled[j:j + batch_size] if y_train is not None else None
if y_train:
y_batch = y_train_shuffled[j:j + batch_size]
else:
y_batch = None

batch_logs = {
'batch': batch_index,
Expand Down

0 comments on commit 806b73b

Please sign in to comment.