Commit 79e758e
Added an argument to the batch_callback of ModelTrainer to indicate if it's training or evaluation. Also added a unit test for the callback.
luisenp committed Jul 22, 2021
1 parent c666174 commit 79e758e
Showing 2 changed files with 54 additions and 5 deletions.
12 changes: 8 additions & 4 deletions mbrl/models/model_trainer.py
@@ -110,7 +110,9 @@ def train(
             batch_callback (callable, optional): if provided, this function will be called
                 for every batch with the output of ``model.update()`` (during training),
                 and ``model.eval_score()`` (during evaluation). It will be called
-                with three arguments ``(epoch_index, loss/score, meta)``.
+                with four arguments ``(epoch_index, loss/score, meta, mode)``, where
+                ``mode`` is one of ``"train"`` or ``"eval"``, indicating if the callback
+                was called during training or evaluation.
 
         Returns:
             (tuple of two list(float)): the history of training losses and validation losses.
@@ -141,7 +143,7 @@ def train(
                     meta = None
                 batch_losses.append(loss)
                 if batch_callback_epoch:
-                    batch_callback_epoch(loss, meta)
+                    batch_callback_epoch(loss, meta, "train")
             total_avg_loss = np.mean(batch_losses).mean().item()
             training_losses.append(total_avg_loss)

@@ -207,7 +209,9 @@ def evaluate(
             dataset (bool): the transition iterator to use.
             batch_callback (callable, optional): if provided, this function will be called
                 for every batch with the output of ``model.eval_score()`` (the score will
-                be passed as a float, reduced using mean()).
+                be passed as a float, reduced using mean()). It will be called
+                with four arguments ``(epoch_index, loss/score, meta, mode)``, where
+                ``mode`` is the string ``"eval"``.
 
         Returns:
             (tensor): The average score of the model over the dataset (and for ensembles, per
@@ -229,7 +233,7 @@ def evaluate(
                     meta = None
                 batch_scores_list.append(batch_score)
                 if batch_callback:
-                    batch_callback(batch_score.mean(), meta)
+                    batch_callback(batch_score.mean(), meta, "eval")
         batch_scores = torch.cat(batch_scores_list, dim=batch_scores_list[0].ndim - 2)
 
         if isinstance(dataset, BootstrapIterator):
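For orientation, here is a minimal sketch of a user-side callback written against the new four-argument contract. It is not part of the commit; the class name and the averaging logic are illustrative only:

    import collections

    class BatchStats:
        """Accumulates per-epoch averages, keyed by the new ``mode`` argument."""

        def __init__(self):
            self._sums = collections.defaultdict(float)
            self._counts = collections.defaultdict(int)

        def __call__(self, epoch, loss_or_score, meta, mode):
            # mode is "train" for model.update() outputs and "eval" for
            # model.eval_score() outputs, per the docstrings above.
            key = (mode, epoch)
            self._sums[key] += float(loss_or_score)
            self._counts[key] += 1

        def average(self, mode, epoch):
            return self._sums[(mode, epoch)] / self._counts[(mode, epoch)]

Passed as ``trainer.train(dataset, num_epochs=..., batch_callback=BatchStats())``, the callback receives the epoch index as its first argument; judging by the ``batch_callback_epoch`` name in the hunk above, ``train()`` appears to bind the epoch before forwarding the callback, including when it delegates to ``evaluate()``.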
47 changes: 46 additions & 1 deletion tests/core/test_models.py
@@ -2,6 +2,7 @@
 #
 # This source code is licensed under the MIT license found in the
 # LICENSE file in the root directory of this source tree.
+import collections
 import functools
 
 import numpy as np
@@ -11,7 +12,9 @@
 import torch.nn as nn
 
 import mbrl.models
+import mbrl.util.replay_buffer
 from mbrl.env.termination_fns import no_termination
+from mbrl.types import TransitionBatch
 
 _DEVICE = "cuda:0" if torch.cuda.is_available() else "cpu"

@@ -298,6 +301,7 @@ def test_model_env_expectation_fixed():
 class DummyModel(mbrl.models.Model):
     def __init__(self):
         super().__init__()
+        self.param = nn.Parameter(torch.ones(1))
         self.device = torch.device(_DEVICE)
         self.to(self.device)

@@ -312,9 +316,12 @@ def sample(self, x, deterministic=False, rng=None):
         return (self.forward(x),)
 
     def loss(self, _input, target=None):
-        pass
+        return 0.0 * self.param, {"loss": 0}
+
+    def eval_score(self, _input, target=None):
+        return torch.zeros_like(_input), {"score": 0}
 
     def set_elite(self, _indices):
         pass
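A note on the two DummyModel changes, since the commit message doesn't explain them: ``loss()`` now returns ``0.0 * self.param`` rather than a bare ``pass``, which (presumably) keeps the dummy loss attached to a parameter so the trainer's backward pass has a graph to traverse, and the returned ``meta`` dictionaries carry the ``"loss"``/``"score"`` keys that the new test asserts on.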


@@ -339,3 +346,41 @@ def test_model_env_evaluate_action_sequences():
         num_particles=num_particles,
     )
     assert torch.allclose(expected_returns, returns)
+
+
+def test_model_trainer_batch_callback():
+    model = DummyModel()
+    wrapper = mbrl.models.OneDTransitionRewardModel(model, target_is_delta=False)
+    trainer = mbrl.models.ModelTrainer(wrapper)
+    num_batches = 10
+    dummy_data = torch.zeros(num_batches, 1)
+    mock_dataset = mbrl.util.replay_buffer.TransitionIterator(
+        TransitionBatch(
+            dummy_data,
+            dummy_data,
+            dummy_data,
+            dummy_data.squeeze(1),
+            dummy_data.squeeze(1),
+        ),
+        1,
+    )
+
+    train_counter = collections.Counter()
+    val_counter = collections.Counter()
+
+    def batch_callback(epoch, val, meta, mode):
+        assert mode in ["train", "eval"]
+        if mode == "train":
+            assert "loss" in meta
+            train_counter[epoch] += 1
+        else:
+            assert "score" in meta
+            val_counter[epoch] += 1
+
+    num_epochs = 20
+    trainer.train(mock_dataset, num_epochs=num_epochs, batch_callback=batch_callback)
+
+    for counter in [train_counter, val_counter]:
+        assert set(counter.keys()) == set(range(num_epochs))
+        for i in range(num_epochs):
+            assert counter[i] == num_batches
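To exercise just the new test locally, pytest's node-id selection should work (assuming the repository's standard pytest setup): ``pytest tests/core/test_models.py::test_model_trainer_batch_callback``.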
