diff --git a/asteroid/utils/test_utils.py b/asteroid/utils/test_utils.py
new file mode 100644
index 000000000..6de8bc77d
--- /dev/null
+++ b/asteroid/utils/test_utils.py
@@ -0,0 +1,14 @@
+import torch
+from torch.utils import data
+
+
+class DummyDataset(data.Dataset):
+    def __init__(self):
+        self.inp_dim = 10
+        self.out_dim = 10
+
+    def __len__(self):
+        return 20
+
+    def __getitem__(self, idx):
+        return torch.randn(1, self.inp_dim), torch.randn(1, self.out_dim)
diff --git a/tests/engine/scheduler_test.py b/tests/engine/scheduler_test.py
new file mode 100644
index 000000000..88fee7814
--- /dev/null
+++ b/tests/engine/scheduler_test.py
@@ -0,0 +1,64 @@
+from torch import nn, optim
+from torch.utils import data
+from pytorch_lightning import Trainer
+
+
+from asteroid.engine.system import System
+from asteroid.utils.test_utils import DummyDataset
+from asteroid.engine.schedulers import NoamScheduler, DPTNetScheduler
+
+
+def common_setup():
+    model = nn.Sequential(nn.Linear(10, 10), nn.ReLU())
+    optimizer = optim.Adam(model.parameters(), lr=1e-3)
+    dataset = DummyDataset()
+    loader = data.DataLoader(dataset, batch_size=2, num_workers=4)
+    trainer = Trainer(max_epochs=1, fast_dev_run=True)
+    return model, optimizer, loader, trainer
+
+
+def test_state_dict():
+    """ Load and serialize scheduler. """
+    model, optimizer, loader, trainer = common_setup()
+    sched = NoamScheduler(optimizer, d_model=10, warmup_steps=100)
+    state_dict = sched.state_dict()
+    sched.load_state_dict(state_dict)
+    state_dict_c = sched.state_dict()
+    assert state_dict == state_dict_c
+
+
+def test_noam_scheduler():
+    model, optimizer, loader, trainer = common_setup()
+    scheduler = {
+        "scheduler": NoamScheduler(optimizer, d_model=10, warmup_steps=100),
+        "interval": "batch",
+    }
+
+    system = System(
+        model,
+        optimizer,
+        loss_func=nn.MSELoss(),
+        train_loader=loader,
+        val_loader=loader,
+        scheduler=scheduler,
+    )
+    trainer.fit(system)
+
+
+def test_dptnet_scheduler():
+    model, optimizer, loader, trainer = common_setup()
+
+    scheduler = {
+        "scheduler": DPTNetScheduler(optimizer, d_model=10, steps_per_epoch=6, warmup_steps=4),
+        "interval": "batch",
+    }
+
+    system = System(
+        model,
+        optimizer,
+        loss_func=nn.MSELoss(),
+        train_loader=loader,
+        val_loader=loader,
+        scheduler=scheduler,
+    )
+    trainer.fit(system)
diff --git a/tests/engine/system_test.py b/tests/engine/system_test.py
index 073377402..02768df3c 100644
--- a/tests/engine/system_test.py
+++ b/tests/engine/system_test.py
@@ -1,21 +1,9 @@
-import torch
 from torch import nn, optim
 from torch.utils import data
 from pytorch_lightning import Trainer
 
 from asteroid.engine.system import System
-
-
-class DummyDataset(data.Dataset):
-    def __init__(self):
-        self.inp_dim = 10
-        self.out_dim = 10
-
-    def __len__(self):
-        return 20
-
-    def __getitem__(self, idx):
-        return torch.randn(1, self.inp_dim), torch.randn(1, self.out_dim)
+from asteroid.utils.test_utils import DummyDataset
 
 
 def test_system():