diff --git a/mbrl/models/basic_ensemble.py b/mbrl/models/basic_ensemble.py
index e6448013..b15d979b 100644
--- a/mbrl/models/basic_ensemble.py
+++ b/mbrl/models/basic_ensemble.py
@@ -3,14 +3,14 @@
 # This source code is licensed under the MIT license found in the
 # LICENSE file in the root directory of this source tree.
 import warnings
-from typing import Optional, Sequence, Tuple, Union
+from typing import Any, Dict, Optional, Sequence, Tuple, Union, cast
 
 import hydra
 import omegaconf
 import torch
 import torch.nn as nn
 
-from .model import Ensemble
+from .model import _NO_META_WARNING_MSG, Ensemble
 
 
 class BasicEnsemble(Ensemble):
@@ -193,9 +193,12 @@ def loss(  # type: ignore
         self,
         model_ins: Sequence[torch.Tensor],
         targets: Optional[Sequence[torch.Tensor]] = None,
-    ) -> torch.Tensor:
+    ) -> Tuple[torch.Tensor, Dict[str, Any]]:
         """Computes average loss over the losses of all members of the ensemble.
 
+        Returns a dictionary with metadata for all models, indexed as
+            ``meta["model_i"] = meta_for_model_i``.
+
         Args:
             model_ins (sequence of tensors): one input for each model in the ensemble.
             targets (sequence of tensors): one target for each model in the ensemble.
@@ -205,15 +208,25 @@ def loss(  # type: ignore
         """
         assert targets is not None
         avg_ensemble_loss: torch.Tensor = 0.0
+        ensemble_meta = {}
         for i, model in enumerate(self.members):
             model.train()
-            loss = model.loss(model_ins[i], targets[i])
+            loss_and_maybe_meta = model.loss(model_ins[i], targets[i])
+            if isinstance(loss_and_maybe_meta, tuple):
+                loss = cast(torch.Tensor, loss_and_maybe_meta[0])
+                meta = cast(Dict[str, Any], loss_and_maybe_meta[1])
+            else:
+                # TODO remove in v0.2.0
+                warnings.warn(_NO_META_WARNING_MSG)
+                loss = cast(torch.Tensor, loss_and_maybe_meta)
+                meta = None
+            ensemble_meta[f"model_{i}"] = meta
             avg_ensemble_loss += loss
-        return avg_ensemble_loss / len(self.members)
+        return avg_ensemble_loss / len(self.members), ensemble_meta
 
     def eval_score(  # type: ignore
         self, model_in: torch.Tensor, target: Optional[torch.Tensor] = None
-    ) -> torch.Tensor:
+    ) -> Tuple[torch.Tensor, Dict[str, Any]]:
         """Computes the average score over all members given input/target.
 
         The input and target tensors are replicated once for each model in the ensemble.
@@ -231,14 +244,25 @@ def eval_score(  # type: ignore
 
         with torch.no_grad():
             scores = []
+            ensemble_meta = {}
             for i, model in enumerate(self.members):
                 model.eval()
-                score = model.eval_score(inputs[i], targets[i])
+                score_and_maybe_meta = model.eval_score(inputs[i], targets[i])
+                if isinstance(score_and_maybe_meta, tuple):
+                    score = cast(torch.Tensor, score_and_maybe_meta[0])
+                    meta = cast(Dict[str, Any], score_and_maybe_meta[1])
+                else:
+                    # TODO remove in v0.2.0
+                    warnings.warn(_NO_META_WARNING_MSG)
+                    score = cast(torch.Tensor, score_and_maybe_meta)
+                    meta = None
+                ensemble_meta[f"model_{i}"] = meta
+
                 if score.ndim == 3:
                     assert score.shape[0] == 1
                     score = score[0]
                 scores.append(score)
-            return torch.stack(scores)
+            return torch.stack(scores), ensemble_meta
 
     def reset(  # type: ignore
         self, x: torch.Tensor, rng: Optional[torch.Generator] = None
diff --git a/mbrl/models/gaussian_mlp.py b/mbrl/models/gaussian_mlp.py
index 29d4f69e..2eb3ccfb 100644
--- a/mbrl/models/gaussian_mlp.py
+++ b/mbrl/models/gaussian_mlp.py
@@ -5,7 +5,7 @@
 import pathlib
 import pickle
 import warnings
-from typing import List, Optional, Sequence, Tuple, Union
+from typing import Any, Dict, List, Optional, Sequence, Tuple, Union
 
 import torch
 from torch import nn as nn
@@ -295,12 +295,14 @@ def loss(
         self,
         model_in: torch.Tensor,
         target: Optional[torch.Tensor] = None,
-    ) -> torch.Tensor:
+    ) -> Tuple[torch.Tensor, Dict[str, Any]]:
         """Computes Gaussian NLL loss.
 
         It also includes terms for ``max_logvar`` and ``min_logvar`` with small weights,
         with positive and negative signs, respectively.
 
+        This function returns no metadata, so the second output is set to an empty dict.
+
         Args:
             model_in (tensor): input tensor. The shape must be ``E x B x Id``, or ``B x Id``
                 where ``E``, ``B`` and ``Id`` represent ensemble size, batch size, and input
@@ -315,19 +317,21 @@ def loss(
                 the average over all models.
         """
         if self.deterministic:
-            return self._mse_loss(model_in, target)
+            return self._mse_loss(model_in, target), {}
         else:
-            return self._nll_loss(model_in, target)
+            return self._nll_loss(model_in, target), {}
 
     def eval_score(  # type: ignore
         self, model_in: torch.Tensor, target: Optional[torch.Tensor] = None
-    ) -> torch.Tensor:
+    ) -> Tuple[torch.Tensor, Dict[str, Any]]:
         """Computes the squared error for the model over the given input/target.
 
         When model is not an ensemble, this is equivalent to
         `F.mse_loss(model(model_in, target), reduction="none")`. If the model is ensemble,
         then return is batched over the model dimension.
 
+        This function returns no metadata, so the second output is set to an empty dict.
+
         Args:
             model_in (tensor): input tensor. The shape must be ``B x Id``, where `B`` and
                 ``Id`` batch size, and input dimension, respectively.
@@ -341,7 +345,7 @@ def eval_score(  # type: ignore
         with torch.no_grad():
             pred_mean, _ = self.forward(model_in, use_propagation=False)
             target = target.repeat((self.num_members, 1, 1))
-            return F.mse_loss(pred_mean, target, reduction="none")
+            return F.mse_loss(pred_mean, target, reduction="none"), {}
 
     def reset(  # type: ignore
         self, x: torch.Tensor, rng: Optional[torch.Generator] = None
diff --git a/mbrl/models/model.py b/mbrl/models/model.py
index e8e93eba..6825f2a7 100644
--- a/mbrl/models/model.py
+++ b/mbrl/models/model.py
@@ -4,13 +4,25 @@
 # LICENSE file in the root directory of this source tree.
 import abc
 import pathlib
-from typing import Optional, Sequence, Tuple, Union, cast
+import warnings
+from typing import Any, Dict, Optional, Sequence, Tuple, Union, cast
 
 import torch
 from torch import nn as nn
 
 from mbrl.types import ModelInput
 
+# TODO: these are temporary, eventually it will be tuple(tensor, dict), keeping this
+# for back-compatibility with v0.1.x, and will be removed in v0.2.0
+LossOutput = Union[torch.Tensor, Tuple[torch.Tensor, Dict[str, Any]]]
+UpdateOutput = Union[float, Tuple[float, Dict[str, Any]]]
+
+
+_NO_META_WARNING_MSG = (
+    "Starting in version v0.2.0, `model.loss()`, `model.update()`, and "
+    "`model.eval_score()` must all return a tuple with (loss, metadata)."
+)
+
 
 # ---------------------------------------------------------------------------
 #                           ABSTRACT MODEL CLASS
@@ -92,7 +104,7 @@ def loss(
         self,
         model_in: ModelInput,
         target: Optional[torch.Tensor] = None,
-    ) -> torch.Tensor:
+    ) -> LossOutput:
         """Computes a loss that can be used to update the model using backpropagation.
 
         Args:
@@ -101,13 +113,16 @@ def loss(
                 cannot be computed from ``model_in``.
 
         Returns:
-            (tensor): a loss tensor.
+            (tuple of tensor and optional dict): the loss tensor and, optionally,
+                a dictionary mapping strings to additional metadata computed by
+                the model (e.g., reconstruction, entropy) that will be used
+                for logging.
         """
 
     @abc.abstractmethod
     def eval_score(
         self, model_in: ModelInput, target: Optional[torch.Tensor] = None
-    ) -> torch.Tensor:
+    ) -> LossOutput:
         """Computes an evaluation score for the model over the given input/target.
 
         This method should compute a non-reduced score for the model, intended mostly for
@@ -127,7 +142,9 @@ def eval_score(
                 cannot be computed from ``model_in``.
 
         Returns:
-            (tensor): a non-reduced tensor score.
+            (tuple of tensor and optional dict): a non-reduced tensor score, and a
+                dictionary mapping strings to metadata computed by the model
+                (e.g., reconstructions, entropy) that will be used for logging.
         """
 
     def update(
@@ -134,8 +151,8 @@ def update(
         self,
         model_in: ModelInput,
         optimizer: torch.optim.Optimizer,
         target: Optional[torch.Tensor] = None,
-    ) -> float:
+    ) -> UpdateOutput:
         """Updates the model using backpropagation with given input and target tensors.
 
         Provides a basic update function, following the steps below:
@@ -155,15 +172,34 @@ def update(
 
         Returns:
             (float): the numeric value of the computed loss.
-
+            (dict): any additional metadata dictionary computed by :meth:`loss`.
""" optimizer = cast(torch.optim.Optimizer, optimizer) self.train() optimizer.zero_grad() - loss = self.loss(model_in, target) - loss.backward() - optimizer.step() - return loss.item() + loss_and_maybe_meta = self.loss(model_in, target) + if isinstance(loss_and_maybe_meta, tuple): + # TODO - v0.2.0 remove this back-compatibility logic + loss = cast(torch.Tensor, loss_and_maybe_meta[0]) + meta = cast(Dict[str, Any], loss_and_maybe_meta[1]) + loss.backward() + + if meta is not None: + with torch.no_grad(): + grad_norm = 0.0 + for p in list( + filter(lambda p: p.grad is not None, self.parameters()) + ): + grad_norm += p.grad.data.norm(2).item() ** 2 + meta["grad_norm"] = grad_norm + optimizer.step() + return loss.item(), meta + + else: + warnings.warn(_NO_META_WARNING_MSG) + loss_and_maybe_meta.backward() + optimizer.step() + return loss_and_maybe_meta.item() def __len__(self): return 1 @@ -239,6 +275,7 @@ def forward(self, x: ModelInput, **kwargs) -> Tuple[torch.Tensor, ...]: """ pass + # TODO this and eval_score are no longer necessary @abc.abstractmethod def loss( self, diff --git a/mbrl/models/model_trainer.py b/mbrl/models/model_trainer.py index 6ad117c2..a19fb2af 100644 --- a/mbrl/models/model_trainer.py +++ b/mbrl/models/model_trainer.py @@ -3,8 +3,10 @@ # This source code is licensed under the MIT license found in the # LICENSE file in the root directory of this source tree. import copy +import functools import itertools -from typing import Callable, Dict, List, Optional, Tuple +import warnings +from typing import Callable, Dict, List, Optional, Tuple, cast import numpy as np import torch @@ -13,7 +15,7 @@ from mbrl.util.logger import Logger from mbrl.util.replay_buffer import BootstrapIterator, TransitionIterator -from .model import Model +from .model import _NO_META_WARNING_MSG, Model MODEL_LOG_FORMAT = [ ("train_iteration", "I", "int"), @@ -71,6 +73,7 @@ def train( patience: Optional[int] = None, improvement_threshold: float = 0.01, callback: Optional[Callable] = None, + batch_callback: Optional[Callable] = None, ) -> Tuple[List[float], List[float]]: """Trains the model for some number of epochs. @@ -104,6 +107,12 @@ def train( - validation score (for ensembles, factored per member) - best validation score so far + batch_callback (callable, optional): if provided, this function will be called + for every batch with the output of ``model.update()`` (during training), + and ``model.eval_score()`` (during evaluation). It will be called + with four arguments ``(epoch_index, loss/score, meta, mode)``, where + ``mode`` is one of ``"train"`` or ``"eval"``, indicating if the callback + was called during training or evaluation. Returns: (tuple of two list(float)): the history of training losses and validation losses. 
@@ -117,14 +126,30 @@ def train(
         epochs_since_update = 0
         best_val_score = self.evaluate(eval_dataset)
         for epoch in epoch_iter:
+            if batch_callback:
+                batch_callback_epoch = functools.partial(batch_callback, epoch)
+            else:
+                batch_callback_epoch = None
             batch_losses: List[float] = []
             for batch in dataset_train:
-                loss = self.model.update(batch, self.optimizer)
+                loss_and_maybe_meta = self.model.update(batch, self.optimizer)
+                if isinstance(loss_and_maybe_meta, tuple):
+                    loss = cast(float, loss_and_maybe_meta[0])
+                    meta = cast(Dict, loss_and_maybe_meta[1])
+                else:
+                    # TODO remove this "if" in v0.2.0
+                    warnings.warn(_NO_META_WARNING_MSG)
+                    loss = cast(float, loss_and_maybe_meta)
+                    meta = None
                 batch_losses.append(loss)
+                if batch_callback_epoch:
+                    batch_callback_epoch(loss, meta, "train")
             total_avg_loss = np.mean(batch_losses).mean().item()
             training_losses.append(total_avg_loss)
 
-            eval_score = self.evaluate(eval_dataset)
+            eval_score = self.evaluate(
+                eval_dataset, batch_callback=batch_callback_epoch
+            )
             val_scores.append(eval_score.mean().item())
 
             maybe_best_weights = self.maybe_get_best_weights(
@@ -171,7 +196,9 @@ def train(
         self._train_iteration += 1
         return training_losses, val_scores
 
-    def evaluate(self, dataset: TransitionIterator) -> torch.Tensor:
+    def evaluate(
+        self, dataset: TransitionIterator, batch_callback: Optional[Callable] = None
+    ) -> torch.Tensor:
         """Evaluates the model on the validation dataset.
 
         Iterates over the dataset, one batch at a time, and calls
@@ -180,6 +207,11 @@ def evaluate(
 
         Args:
             dataset (bool): the transition iterator to use.
+            batch_callback (callable, optional): if provided, this function will be called
+                for every batch with the output of ``model.eval_score()`` (the score is
+                reduced with ``mean()`` before being passed). It will be called
+                with four arguments ``(epoch_index, loss/score, meta, mode)``, where
+                ``mode`` is the string ``"eval"``.
 
         Returns:
             (tensor): The average score of the model over the dataset (and for ensembles, per
@@ -190,15 +222,25 @@ def evaluate(
 
         batch_scores_list = []
         for batch in dataset:
-            avg_batch_score = self.model.eval_score(batch)
-            batch_scores_list.append(avg_batch_score)
-        batch_scores = torch.cat(batch_scores_list, axis=batch_scores_list[0].ndim - 2)
+            batch_score_and_maybe_meta = self.model.eval_score(batch)
+            if isinstance(batch_score_and_maybe_meta, tuple):
+                batch_score = cast(torch.Tensor, batch_score_and_maybe_meta[0])
+                meta = cast(Dict, batch_score_and_maybe_meta[1])
+            else:
+                # TODO remove this "else" in v0.2.0
+                warnings.warn(_NO_META_WARNING_MSG)
+                batch_score = cast(torch.Tensor, batch_score_and_maybe_meta)
+                meta = None
+            batch_scores_list.append(batch_score)
+            if batch_callback:
+                batch_callback(batch_score.mean(), meta, "eval")
+        batch_scores = torch.cat(batch_scores_list, dim=batch_scores_list[0].ndim - 2)
 
         if isinstance(dataset, BootstrapIterator):
             dataset.toggle_bootstrap()
 
         mean_axis = 1 if batch_scores.ndim == 2 else (1, 2)
-        batch_scores = batch_scores.mean(axis=mean_axis)
+        batch_scores = batch_scores.mean(dim=mean_axis)
         return batch_scores
 
     def maybe_get_best_weights(
diff --git a/mbrl/models/one_dim_tr_model.py b/mbrl/models/one_dim_tr_model.py
index f3e4995e..55938a38 100644
--- a/mbrl/models/one_dim_tr_model.py
+++ b/mbrl/models/one_dim_tr_model.py
@@ -14,7 +14,7 @@
 import mbrl.types
 import mbrl.util.math
 
-from .model import Ensemble, Model
+from .model import Ensemble, LossOutput, Model, UpdateOutput
 
 MODEL_LOG_FORMAT = [
     ("train_iteration", "I", "int"),
@@ -173,17 +173,18 @@ def loss(
         self,
         batch: mbrl.types.TransitionBatch,
         target: Optional[torch.Tensor] = None,
-    ) -> torch.Tensor:
-        """Evaluates the model score over a batch of transitions.
+    ) -> LossOutput:
+        """Computes the model loss over a batch of transitions.
 
         This method constructs input and targets from the information in the batch,
-        then calls `self.model.eval_score()` on them and returns the value.
+        then calls `self.model.loss()` on them and returns the loss value and the
+        metadata dictionary returned by the model.
 
         Args:
             batch (transition batch): a batch of transition to train the model.
 
         Returns:
-            (tensor): as returned by `model.eval_score().`
+            (tensor and optional dict): as returned by `model.loss()`.
         """
         assert target is None
         model_in, target = self._get_model_input_and_target_from_batch(batch)
@@ -194,12 +195,15 @@ def update(
         self,
         batch: mbrl.types.TransitionBatch,
         optimizer: torch.optim.Optimizer,
         target: Optional[torch.Tensor] = None,
-    ) -> float:
+    ) -> UpdateOutput:
         """Updates the model given a batch of transitions and an optimizer.
 
         Args:
             batch (transition batch): a batch of transition to train the model.
             optimizer (torch optimizer): the optimizer to use to update the model.
+
+        Returns:
+            (float and optional dict): as returned by `model.update()`.
         """
         assert target is None
         model_in, target = self._get_model_input_and_target_from_batch(batch)
@@ -209,7 +213,7 @@ def eval_score(
         self,
         batch: mbrl.types.TransitionBatch,
         target: Optional[torch.Tensor] = None,
-    ) -> torch.Tensor:
+    ) -> LossOutput:
         """Evaluates the model score over a batch of transitions.
 
         This method constructs input and targets from the information in the batch,
diff --git a/notebooks/fit_gaussian_mlp_ensemble_1d.ipynb b/notebooks/fit_gaussian_mlp_ensemble_1d.ipynb
index ff156b65..00c8b0d8 100644
--- a/notebooks/fit_gaussian_mlp_ensemble_1d.ipynb
+++ b/notebooks/fit_gaussian_mlp_ensemble_1d.ipynb
@@ -215,4 +215,4 @@
 },
 "nbformat": 4,
 "nbformat_minor": 2
-}
+}
\ No newline at end of file
diff --git a/tests/core/test_models.py b/tests/core/test_models.py
index 8c4e0871..962698b8 100644
--- a/tests/core/test_models.py
+++ b/tests/core/test_models.py
@@ -2,6 +2,7 @@
 #
 # This source code is licensed under the MIT license found in the
 # LICENSE file in the root directory of this source tree.
+import collections
 import functools
 
 import numpy as np
@@ -11,7 +12,9 @@
 import torch.nn as nn
 
 import mbrl.models
+import mbrl.util.replay_buffer
 from mbrl.env.termination_fns import no_termination
+from mbrl.types import TransitionBatch
 
 
 _DEVICE = "cuda:0" if torch.cuda.is_available() else "cpu"
@@ -298,6 +301,7 @@ def test_model_env_expectation_fixed():
 class DummyModel(mbrl.models.Model):
     def __init__(self):
         super().__init__()
+        self.param = nn.Parameter(torch.ones(1))
         self.device = torch.device(_DEVICE)
         self.to(self.device)
 
@@ -312,9 +316,12 @@ def sample(self, x, deterministic=False, rng=None):
         return (self.forward(x),)
 
     def loss(self, _input, target=None):
-        pass
+        return 0.0 * self.param, {"loss": 0}
 
     def eval_score(self, _input, target=None):
+        return torch.zeros_like(_input), {"score": 0}
+
+    def set_elite(self, _indices):
         pass
 
 
@@ -339,3 +346,41 @@ def test_model_env_evaluate_action_sequences():
             num_particles=num_particles,
         )
         assert torch.allclose(expected_returns, returns)
+
+
+def test_model_trainer_batch_callback():
+    model = DummyModel()
+    wrapper = mbrl.models.OneDTransitionRewardModel(model, target_is_delta=False)
+    trainer = mbrl.models.ModelTrainer(wrapper)
+    num_batches = 10
+    dummy_data = torch.zeros(num_batches, 1)
+    mock_dataset = mbrl.util.replay_buffer.TransitionIterator(
+        TransitionBatch(
+            dummy_data,
+            dummy_data,
+            dummy_data,
+            dummy_data.squeeze(1),
+            dummy_data.squeeze(1),
+        ),
+        1,
+    )
+
+    train_counter = collections.Counter()
+    val_counter = collections.Counter()
+
+    def batch_callback(epoch, val, meta, mode):
+        assert mode in ["train", "eval"]
+        if mode == "train":
+            assert "loss" in meta
+            train_counter[epoch] += 1
+        else:
+            assert "score" in meta
+            val_counter[epoch] += 1
+
+    num_epochs = 20
+    trainer.train(mock_dataset, num_epochs=num_epochs, batch_callback=batch_callback)
+
+    for counter in [train_counter, val_counter]:
+        assert set(counter.keys()) == set(range(num_epochs))
+        for i in range(num_epochs):
+            assert counter[i] == num_batches
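
Usage note (not part of the patch): the sketch below shows how the new `(value, metadata)` return convention and the `batch_callback` hook fit together from a user's perspective. The `GaussianMLP` sizes, the `device`, and the `dataset_train` iterator are hypothetical placeholders; only the constructor and `train()` signatures exercised by the diff above are assumed.

```python
import mbrl.models

# Hypothetical dimensions: obs_size=3, act_size=1, so the model input is 4-dim and
# the output is obs_size + 1 because the wrapper learns rewards by default.
model = mbrl.models.GaussianMLP(in_size=4, out_size=4, device="cpu")
wrapper = mbrl.models.OneDTransitionRewardModel(model, target_is_delta=True)
trainer = mbrl.models.ModelTrainer(wrapper)


def batch_callback(epoch, value, meta, mode):
    # mode == "train": value/meta come from model.update(), so meta also carries
    # the "grad_norm" entry added in Model.update() above.
    # mode == "eval": value is the eval_score() tensor reduced with mean().
    if mode == "train" and meta is not None and "grad_norm" in meta:
        print(f"epoch {epoch}: loss={value:.4f}, grad_norm={meta['grad_norm']:.4f}")


# `dataset_train` stands in for a TransitionIterator over collected transitions:
# trainer.train(dataset_train, num_epochs=20, batch_callback=batch_callback)
```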