Skip to content

Commit

Permalink
Training simple test
Browse files Browse the repository at this point in the history
  • Loading branch information
asistradition committed Jul 5, 2024
1 parent 120f239 commit 848737e
Show file tree
Hide file tree
Showing 2 changed files with 39 additions and 8 deletions.
27 changes: 26 additions & 1 deletion supirfactor_dynamical/tests/test_train.py
Original file line number Diff line number Diff line change
Expand Up @@ -12,7 +12,8 @@
dynamical_model_training as model_training,
pretrain_and_tune_dynamic_model,
process_results_to_dataframes,
process_combined_results
process_combined_results,
train_simple_model
)

from supirfactor_dynamical.models import _CLASS_DICT
Expand Down Expand Up @@ -64,6 +65,30 @@ def setUp(self) -> None:
)


class TestSimpleTraining(_SetupMixin, unittest.TestCase):

def test_training(self):

model = get_model('static')(
self.prior
)
pre_weights = torch.clone(model.encoder[0].weight.detach())

train_simple_model(
model,
self.static_dataloader,
10
)

post_weights = torch.clone(model.encoder[0].weight.detach())

with self.assertRaises(AssertionError):
torch.testing.assert_close(
pre_weights,
post_weights
)


class TestCoupledTraining(_SetupMixin, unittest.TestCase):

def test_training(self):
Expand Down
20 changes: 13 additions & 7 deletions supirfactor_dynamical/training/train_simple_models.py
Original file line number Diff line number Diff line change
Expand Up @@ -108,20 +108,26 @@ def train_simple_model(
_validation_loss.append(loss.item())
_validation_n.append(_nobs(val_x))

_validation_loss = np.average(
np.array(_validation_loss),
axis=0,
weights=np.array(_validation_n)
)
_validation_n = np.sum(_validation_n)

else:
_validation_loss = None
_validation_n = None

model_ref.append_loss(
training_loss=np.average(
np.array(_batch_losses),
axis=0,
weights=np.array(_batch_n)
),
training_n=np.sum(_batch_n),
validation_loss=np.average(
np.array(_validation_loss),
axis=0,
weights=np.array(_validation_n)
),
validation_n=np.sum(_validation_n),

validation_loss=_validation_loss,
validation_n=_validation_n
)

model_ref.current_epoch = epoch_num
Expand Down

0 comments on commit 848737e

Please sign in to comment.