Skip to content

Commit

Permalink
Add loss index
Browse files Browse the repository at this point in the history
  • Loading branch information
asistradition committed Jul 8, 2024
1 parent 848737e commit b32e938
Show file tree
Hide file tree
Showing 3 changed files with 19 additions and 5 deletions.
10 changes: 9 additions & 1 deletion supirfactor_dynamical/tests/test_train.py
Original file line number Diff line number Diff line change
Expand Up @@ -22,7 +22,8 @@
X,
A,
T,
XV_tensor
XV_tensor,
X_tensor
)


Expand Down Expand Up @@ -80,6 +81,8 @@ def test_training(self):
10
)

model(X_tensor)

post_weights = torch.clone(model.encoder[0].weight.detach())

with self.assertRaises(AssertionError):
Expand All @@ -88,6 +91,11 @@ def test_training(self):
post_weights
)

torch.testing.assert_close(
pre_weights != 0,
post_weights != 0
)


class TestCoupledTraining(_SetupMixin, unittest.TestCase):

Expand Down
7 changes: 5 additions & 2 deletions supirfactor_dynamical/training/train_simple_decoders.py
Original file line number Diff line number Diff line change
Expand Up @@ -20,7 +20,8 @@ def train_simple_multidecoder(
loss_function=torch.nn.MSELoss(),
optimizer=None,
post_epoch_hook=None,
training_loss_weights=None
training_loss_weights=None,
loss_index=None
):
"""
Train this model with multiple decoders
Expand Down Expand Up @@ -170,7 +171,9 @@ def train_simple_multidecoder(
axis=0,
weights=np.array(_validation_n)
),
validation_n=np.sum(_validation_n)
validation_n=np.sum(_validation_n),
training_loss_idx=loss_index,
validation_loss_idx=loss_index
)

model_ref.current_epoch = epoch_num
Expand Down
7 changes: 5 additions & 2 deletions supirfactor_dynamical/training/train_simple_models.py
Original file line number Diff line number Diff line change
Expand Up @@ -16,7 +16,8 @@ def train_simple_model(
validation_dataloader=None,
loss_function=torch.nn.MSELoss(),
optimizer=None,
post_epoch_hook=None
post_epoch_hook=None,
loss_index=None
):
"""
Train this model
Expand Down Expand Up @@ -127,7 +128,9 @@ def train_simple_model(
),
training_n=np.sum(_batch_n),
validation_loss=_validation_loss,
validation_n=_validation_n
validation_n=_validation_n,
training_loss_idx=loss_index,
validation_loss_idx=loss_index
)

model_ref.current_epoch = epoch_num
Expand Down

0 comments on commit b32e938

Please sign in to comment.