From b32e938792fb25d9f468ceb9605c8d9465df0efd Mon Sep 17 00:00:00 2001 From: asistradition Date: Mon, 8 Jul 2024 11:10:16 -0400 Subject: [PATCH] Add loss index --- supirfactor_dynamical/tests/test_train.py | 10 +++++++++- .../training/train_simple_decoders.py | 7 +++++-- supirfactor_dynamical/training/train_simple_models.py | 7 +++++-- 3 files changed, 19 insertions(+), 5 deletions(-) diff --git a/supirfactor_dynamical/tests/test_train.py b/supirfactor_dynamical/tests/test_train.py index f8ab939..34364b6 100644 --- a/supirfactor_dynamical/tests/test_train.py +++ b/supirfactor_dynamical/tests/test_train.py @@ -22,7 +22,8 @@ X, A, T, - XV_tensor + XV_tensor, + X_tensor ) @@ -80,6 +81,8 @@ def test_training(self): 10 ) + model(X_tensor) + post_weights = torch.clone(model.encoder[0].weight.detach()) with self.assertRaises(AssertionError): @@ -88,6 +91,11 @@ def test_training(self): post_weights ) + torch.testing.assert_close( + pre_weights != 0, + post_weights != 0 + ) + class TestCoupledTraining(_SetupMixin, unittest.TestCase): diff --git a/supirfactor_dynamical/training/train_simple_decoders.py b/supirfactor_dynamical/training/train_simple_decoders.py index 18222a9..19880d6 100644 --- a/supirfactor_dynamical/training/train_simple_decoders.py +++ b/supirfactor_dynamical/training/train_simple_decoders.py @@ -20,7 +20,8 @@ def train_simple_multidecoder( loss_function=torch.nn.MSELoss(), optimizer=None, post_epoch_hook=None, - training_loss_weights=None + training_loss_weights=None, + loss_index=None ): """ Train this model with multiple decoders @@ -170,7 +171,9 @@ def train_simple_multidecoder( axis=0, weights=np.array(_validation_n) ), - validation_n=np.sum(_validation_n) + validation_n=np.sum(_validation_n), + training_loss_idx=loss_index, + validation_loss_idx=loss_index ) model_ref.current_epoch = epoch_num diff --git a/supirfactor_dynamical/training/train_simple_models.py b/supirfactor_dynamical/training/train_simple_models.py index 735843f..013c7c0 100644 --- a/supirfactor_dynamical/training/train_simple_models.py +++ b/supirfactor_dynamical/training/train_simple_models.py @@ -16,7 +16,8 @@ def train_simple_model( validation_dataloader=None, loss_function=torch.nn.MSELoss(), optimizer=None, - post_epoch_hook=None + post_epoch_hook=None, + loss_index=None ): """ Train this model @@ -127,7 +128,9 @@ def train_simple_model( ), training_n=np.sum(_batch_n), validation_loss=_validation_loss, - validation_n=_validation_n + validation_n=_validation_n, + training_loss_idx=loss_index, + validation_loss_idx=loss_index ) model_ref.current_epoch = epoch_num