From 79e758eb16d3d06e481c918f7954d9c4cbf1e7a7 Mon Sep 17 00:00:00 2001
From: Luis Pineda
Date: Thu, 22 Jul 2021 11:29:59 -0400
Subject: [PATCH] Added an argument to the batch_callback of ModelTrainer to
 indicate if it's training or evaluation. Also added a unit test for the
 callback.

---
 mbrl/models/model_trainer.py | 12 ++++++---
 tests/core/test_models.py    | 47 +++++++++++++++++++++++++++++++++++-
 2 files changed, 54 insertions(+), 5 deletions(-)

diff --git a/mbrl/models/model_trainer.py b/mbrl/models/model_trainer.py
index 93b86346..a19fb2af 100644
--- a/mbrl/models/model_trainer.py
+++ b/mbrl/models/model_trainer.py
@@ -110,7 +110,9 @@ def train(
             batch_callback (callable, optional): if provided, this function will be called
                 for every batch with the output of ``model.update()`` (during training),
                 and ``model.eval_score()`` (during evaluation). It will be called
-                with three arguments ``(epoch_index, loss/score, meta)``.
+                with four arguments ``(epoch_index, loss/score, meta, mode)``, where
+                ``mode`` is one of ``"train"`` or ``"eval"``, indicating whether the
+                callback was called during training or evaluation.
 
         Returns:
             (tuple of two list(float)): the history of training losses and validation losses.
@@ -141,7 +143,7 @@ def train(
                     meta = None
                 batch_losses.append(loss)
                 if batch_callback_epoch:
-                    batch_callback_epoch(loss, meta)
+                    batch_callback_epoch(loss, meta, "train")
             total_avg_loss = np.mean(batch_losses).mean().item()
             training_losses.append(total_avg_loss)
 
@@ -207,7 +209,9 @@ def evaluate(
             dataset (transition iterator): the transition iterator to use.
             batch_callback (callable, optional): if provided, this function will be called
                 for every batch with the output of ``model.eval_score()`` (the score will
-                be passed as a float, reduced using mean()).
+                be passed as a float, reduced using mean()). It will be called
+                with four arguments ``(epoch_index, loss/score, meta, mode)``, where
+                ``mode`` is the string ``"eval"``.
 
         Returns:
             (tensor): The average score of the model over the dataset (and for ensembles, per
@@ -229,7 +233,7 @@ def evaluate(
                 meta = None
             batch_scores_list.append(batch_score)
             if batch_callback:
-                batch_callback(batch_score.mean(), meta)
+                batch_callback(batch_score.mean(), meta, "eval")
         batch_scores = torch.cat(batch_scores_list, dim=batch_scores_list[0].ndim - 2)
 
         if isinstance(dataset, BootstrapIterator):
diff --git a/tests/core/test_models.py b/tests/core/test_models.py
index 8c4e0871..962698b8 100644
--- a/tests/core/test_models.py
+++ b/tests/core/test_models.py
@@ -2,6 +2,7 @@
 #
 # This source code is licensed under the MIT license found in the
 # LICENSE file in the root directory of this source tree.
+import collections
 import functools
 
 import numpy as np
@@ -11,7 +12,9 @@
 import torch.nn as nn
 
 import mbrl.models
+import mbrl.util.replay_buffer
 from mbrl.env.termination_fns import no_termination
+from mbrl.types import TransitionBatch
 
 _DEVICE = "cuda:0" if torch.cuda.is_available() else "cpu"
 
@@ -298,6 +301,7 @@ def test_model_env_expectation_fixed():
 class DummyModel(mbrl.models.Model):
     def __init__(self):
         super().__init__()
+        self.param = nn.Parameter(torch.ones(1))
         self.device = torch.device(_DEVICE)
         self.to(self.device)
 
@@ -312,9 +316,12 @@ def sample(self, x, deterministic=False, rng=None):
         return (self.forward(x),)
 
     def loss(self, _input, target=None):
-        pass
+        return 0.0 * self.param, {"loss": 0}
 
     def eval_score(self, _input, target=None):
+        return torch.zeros_like(_input), {"score": 0}
+
+    def set_elite(self, _indices):
         pass
 
 
@@ -339,3 +346,41 @@ def test_model_env_evaluate_action_sequences():
         num_particles=num_particles,
     )
     assert torch.allclose(expected_returns, returns)
+
+
+def test_model_trainer_batch_callback():
+    model = DummyModel()
+    wrapper = mbrl.models.OneDTransitionRewardModel(model, target_is_delta=False)
+    trainer = mbrl.models.ModelTrainer(wrapper)
+    num_batches = 10
+    dummy_data = torch.zeros(num_batches, 1)
+    mock_dataset = mbrl.util.replay_buffer.TransitionIterator(
+        TransitionBatch(
+            dummy_data,
+            dummy_data,
+            dummy_data,
+            dummy_data.squeeze(1),
+            dummy_data.squeeze(1),
+        ),
+        1,
+    )
+
+    train_counter = collections.Counter()
+    val_counter = collections.Counter()
+
+    def batch_callback(epoch, val, meta, mode):
+        assert mode in ["train", "eval"]
+        if mode == "train":
+            assert "loss" in meta
+            train_counter[epoch] += 1
+        else:
+            assert "score" in meta
+            val_counter[epoch] += 1
+
+    num_epochs = 20
+    trainer.train(mock_dataset, num_epochs=num_epochs, batch_callback=batch_callback)
+
+    for counter in [train_counter, val_counter]:
+        assert set(counter.keys()) == set(range(num_epochs))
+        for i in range(num_epochs):
+            assert counter[i] == num_batches
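
For reference, a minimal sketch of how a caller might use the new
four-argument ``batch_callback`` to track training losses and evaluation
scores separately. The names ``trainer``, ``dataset_train``, and
``dataset_val`` are assumptions here: an already-constructed ``ModelTrainer``
and its transition iterators, created elsewhere, not defined by this patch.
Note that when ``batch_callback`` is passed to ``train()``, the trainer binds
the epoch index via ``functools.partial`` before forwarding the callback to
``evaluate()``, so the callback receives four arguments in both modes.

    # Sketch only: route per-batch metrics using the new ``mode`` argument.
    train_losses, eval_scores = [], []

    def log_batch(epoch, loss_or_score, meta, mode):
        # "train" corresponds to model.update() outputs during training,
        # "eval" to model.eval_score() outputs during evaluation.
        if mode == "train":
            train_losses.append((epoch, loss_or_score))
        else:
            eval_scores.append((epoch, loss_or_score))

    # Assumes trainer and the transition iterators were created elsewhere.
    trainer.train(
        dataset_train,
        dataset_val=dataset_val,
        num_epochs=20,
        batch_callback=log_batch,
    )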