Skip to content

Commit

Permalink
fix test (#2528)
Browse files Browse the repository at this point in the history
  • Loading branch information
hadifawaz1999 authored Feb 18, 2025
1 parent 6f75a45 commit 4707d3e
Show file tree
Hide file tree
Showing 3 changed files with 33 additions and 12 deletions.
9 changes: 6 additions & 3 deletions aeon/clustering/deep_learning/_ae_fcn.py
Original file line number Diff line number Diff line change
Expand Up @@ -317,6 +317,7 @@ def _fit(self, X):
outputs=X,
batch_size=mini_batch_size,
epochs=self.n_epochs,
verbose=self.verbose,
)

try:
Expand Down Expand Up @@ -345,6 +346,7 @@ def _fit_multi_rec_model(
outputs,
batch_size,
epochs,
verbose,
):
import tensorflow as tf

Expand Down Expand Up @@ -451,9 +453,10 @@ def loss(y_true, y_pred):
epoch_loss /= num_batches
history["loss"].append(epoch_loss)

sys.stdout.write(
"Training loss at epoch %d: %.4f\n" % (epoch, float(epoch_loss))
)
if verbose:
sys.stdout.write(
"Training loss at epoch %d: %.4f\n" % (epoch, float(epoch_loss))
)

for callback in self.callbacks_:
callback.on_epoch_end(epoch, {"loss": float(epoch_loss)})
Expand Down
9 changes: 6 additions & 3 deletions aeon/clustering/deep_learning/_ae_resnet.py
Original file line number Diff line number Diff line change
Expand Up @@ -329,6 +329,7 @@ def _fit(self, X):
outputs=X,
batch_size=mini_batch_size,
epochs=self.n_epochs,
verbose=self.verbose,
)

try:
Expand Down Expand Up @@ -359,6 +360,7 @@ def _fit_multi_rec_model(
outputs,
batch_size,
epochs,
verbose,
):
import tensorflow as tf

Expand Down Expand Up @@ -463,9 +465,10 @@ def loss(y_true, y_pred):
epoch_loss /= num_batches
history["loss"].append(epoch_loss)

sys.stdout.write(
"Training loss at epoch %d: %.4f\n" % (epoch, float(epoch_loss))
)
if verbose:
sys.stdout.write(
"Training loss at epoch %d: %.4f\n" % (epoch, float(epoch_loss))
)

for callback in self.callbacks_:
callback.on_epoch_end(epoch, {"loss": float(epoch_loss)})
Expand Down
27 changes: 21 additions & 6 deletions aeon/testing/estimator_checking/_yield_clustering_checks.py
Original file line number Diff line number Diff line change
Expand Up @@ -77,18 +77,33 @@ def check_clustering_random_state_deep_learning(estimator, datatype):
deep_clr1 = _clone_estimator(estimator, random_state=random_state)
deep_clr1.fit(FULL_TEST_DATA_DICT[datatype]["train"][0])

layers1 = deep_clr1.training_model_.layers[1:]
encoder_layers1 = deep_clr1.training_model_.layers[1].layers[1:]
decoder_layers1 = deep_clr1.training_model_.layers[2].layers[1:]

deep_clr2 = _clone_estimator(estimator, random_state=random_state)
deep_clr2.fit(FULL_TEST_DATA_DICT[datatype]["train"][0])

layers2 = deep_clr2.training_model_.layers[1:]
encoder_layers2 = deep_clr2.training_model_.layers[1].layers[1:]
decoder_layers2 = deep_clr2.training_model_.layers[2].layers[1:]

assert len(layers1) == len(layers2)
assert len(encoder_layers1) == len(encoder_layers2)
assert len(decoder_layers1) == len(decoder_layers2)

for i in range(len(layers1)):
weights1 = layers1[i].get_weights()
weights2 = layers2[i].get_weights()
for i in range(len(encoder_layers1)):
weights1 = encoder_layers1[i].get_weights()
weights2 = encoder_layers2[i].get_weights()

assert len(weights1) == len(weights2)

for j in range(len(weights1)):
_weight1 = np.asarray(weights1[j])
_weight2 = np.asarray(weights2[j])

np.testing.assert_almost_equal(_weight1, _weight2, 4)

for i in range(len(decoder_layers1)):
weights1 = decoder_layers1[i].get_weights()
weights2 = decoder_layers2[i].get_weights()

assert len(weights1) == len(weights2)

Expand Down

0 comments on commit 4707d3e

Please sign in to comment.