diff --git a/tf_shell_ml/model_base.py b/tf_shell_ml/model_base.py index 5622044..f6ce9f6 100644 --- a/tf_shell_ml/model_base.py +++ b/tf_shell_ml/model_base.py @@ -19,6 +19,7 @@ import tf_shell import tf_shell_ml import time +import gc class SequentialBase(keras.Sequential): @@ -213,6 +214,8 @@ def fit( features_dataset, labels_dataset = self.prep_dataset_for_model( features_dataset, labels_dataset ) + tf.keras.backend.clear_session() + gc.collect() # Calculate samples if possible. if steps_per_epoch is None: @@ -251,6 +254,7 @@ def fit( callback_list.on_train_batch_begin(step, logs) logs, num_slots = self.train_step_tf_func(batch_x, batch_y) callback_list.on_train_batch_end(step, logs) + gc.collect() if steps_per_epoch is not None and step + 1 >= steps_per_epoch: break