From c49066f7ce902798cdda5f6631223f1eb1e3ba74 Mon Sep 17 00:00:00 2001
From: Luis Pineda
Date: Tue, 20 Jul 2021 15:19:29 -0400
Subject: [PATCH 01/10] Changed signature of Model.loss(), Model.eval_score(),
 Model.update() to also return a dict with metadata.

---
 mbrl/models/model.py | 49 ++++++++++++++++++++++++++++++++++----------
 1 file changed, 38 insertions(+), 11 deletions(-)

diff --git a/mbrl/models/model.py b/mbrl/models/model.py
index e8e93eba..8d375632 100644
--- a/mbrl/models/model.py
+++ b/mbrl/models/model.py
@@ -4,13 +4,25 @@
 # LICENSE file in the root directory of this source tree.
 import abc
 import pathlib
-from typing import Optional, Sequence, Tuple, Union, cast
+import warnings
+from typing import Any, Dict, Optional, Sequence, Tuple, Union, cast
 
 import torch
 from torch import nn as nn
 
 from mbrl.types import ModelInput
 
+# TODO: these unions are temporary; eventually the return type will be tuple(tensor, dict).
+# They are kept for backward compatibility with v0.1.x and will be removed in v0.2.0.
+LossOutput = Union[torch.Tensor, Tuple[torch.Tensor, Dict[str, Any]]]
+UpdateOutput = Union[float, Tuple[float, Dict[str, Any]]]
+
+
+_NO_META_WARNING_MSG = (
+    "Starting in version v0.2.0, `model.loss()`, `model.update()`, and `model.eval_score()` "
+    "must all return a tuple with (loss, metadata)."
+)
+
 
 # ---------------------------------------------------------------------------
 #                                ABSTRACT MODEL CLASS
@@ -92,7 +104,7 @@ def loss(
         self,
         model_in: ModelInput,
         target: Optional[torch.Tensor] = None,
-    ) -> torch.Tensor:
+    ) -> LossOutput:
         """Computes a loss that can be used to update the model using backpropagation.
 
         Args:
@@ -101,13 +113,16 @@
             cannot be computed from ``model_in``.
 
         Returns:
-            (tensor): a loss tensor.
+            (tuple of tensor and optional dict): the loss tensor and, optionally,
+                a dictionary mapping strings to additional metadata computed by
+                the model (e.g., reconstruction, entropy) that will be used
+                for logging.
         """
 
     @abc.abstractmethod
     def eval_score(
         self, model_in: ModelInput, target: Optional[torch.Tensor] = None
-    ) -> torch.Tensor:
+    ) -> LossOutput:
         """Computes an evaluation score for the model over the given input/target.
 
         This method should compute a non-reduced score for the model, intended mostly for
@@ -127,7 +142,9 @@
             cannot be computed from ``model_in``.
 
         Returns:
-            (tensor): a non-reduced tensor score.
+            (tuple of tensor and optional dict): a non-reduced tensor score, and a dictionary
+                mapping strings to metadata computed by the model
+                (e.g., reconstructions, entropy, etc.) that will be used for logging.
         """
 
     def update(
         self,
         model_in: ModelInput,
         optimizer: torch.optim.Optimizer,
         target: Optional[torch.Tensor] = None,
-    ) -> float:
+    ) -> UpdateOutput:
         """Updates the model using backpropagation with given input and target tensors.
 
         Provides a basic update function, following the steps below:
@@ -155,15 +172,25 @@
 
         Returns:
             (float): the numeric value of the computed loss.
-
+            (dict): any additional metadata dictionary computed by :meth:`loss`.
""" optimizer = cast(torch.optim.Optimizer, optimizer) self.train() optimizer.zero_grad() - loss = self.loss(model_in, target) - loss.backward() - optimizer.step() - return loss.item() + loss_and_maybe_meta = self.loss(model_in, target) + if isinstance(loss_and_maybe_meta, tuple): + # TODO - v0.2.0 remove this back-compatibility logic + loss = cast(torch.Tensor, loss_and_maybe_meta[0]) + meta = cast(Optional[Dict[str, Any]], loss_and_maybe_meta[1]) + loss.backward() + optimizer.step() + return loss.item(), meta + + else: + warnings.warn(_NO_META_WARNING_MSG) + loss_and_maybe_meta.backward() + optimizer.step() + return loss_and_maybe_meta.item() def __len__(self): return 1 From 5be4b3850a259885749d19f3554e359f9cbfc4c0 Mon Sep 17 00:00:00 2001 From: Luis Pineda Date: Tue, 20 Jul 2021 15:19:59 -0400 Subject: [PATCH 02/10] Updated GaussianMLP to match new Model API. --- mbrl/models/gaussian_mlp.py | 16 ++++++++++------ 1 file changed, 10 insertions(+), 6 deletions(-) diff --git a/mbrl/models/gaussian_mlp.py b/mbrl/models/gaussian_mlp.py index 29d4f69e..94406d5d 100644 --- a/mbrl/models/gaussian_mlp.py +++ b/mbrl/models/gaussian_mlp.py @@ -5,7 +5,7 @@ import pathlib import pickle import warnings -from typing import List, Optional, Sequence, Tuple, Union +from typing import Any, Dict, List, Optional, Sequence, Tuple, Union import torch from torch import nn as nn @@ -295,12 +295,14 @@ def loss( self, model_in: torch.Tensor, target: Optional[torch.Tensor] = None, - ) -> torch.Tensor: + ) -> Tuple[torch.Tensor, Optional[Dict[str, Any]]]: """Computes Gaussian NLL loss. It also includes terms for ``max_logvar`` and ``min_logvar`` with small weights, with positive and negative signs, respectively. + This function returns no metadata, so the second output is set to an empty dict. + Args: model_in (tensor): input tensor. The shape must be ``E x B x Id``, or ``B x Id`` where ``E``, ``B`` and ``Id`` represent ensemble size, batch size, and input @@ -315,19 +317,21 @@ def loss( the average over all models. """ if self.deterministic: - return self._mse_loss(model_in, target) + return self._mse_loss(model_in, target), {} else: - return self._nll_loss(model_in, target) + return self._nll_loss(model_in, target), {} def eval_score( # type: ignore self, model_in: torch.Tensor, target: Optional[torch.Tensor] = None - ) -> torch.Tensor: + ) -> Tuple[torch.Tensor, Optional[Dict[str, Any]]]: """Computes the squared error for the model over the given input/target. When model is not an ensemble, this is equivalent to `F.mse_loss(model(model_in, target), reduction="none")`. If the model is ensemble, then return is batched over the model dimension. + This function returns no metadata, so the second output is set to an empty dict. + Args: model_in (tensor): input tensor. The shape must be ``B x Id``, where `B`` and ``Id`` batch size, and input dimension, respectively. @@ -341,7 +345,7 @@ def eval_score( # type: ignore with torch.no_grad(): pred_mean, _ = self.forward(model_in, use_propagation=False) target = target.repeat((self.num_members, 1, 1)) - return F.mse_loss(pred_mean, target, reduction="none") + return F.mse_loss(pred_mean, target, reduction="none"), {} def reset( # type: ignore self, x: torch.Tensor, rng: Optional[torch.Generator] = None From 61f1a7e972dafce675798c66848b8dc8ea31fe66 Mon Sep 17 00:00:00 2001 From: Luis Pineda Date: Tue, 20 Jul 2021 15:20:22 -0400 Subject: [PATCH 03/10] Updated BasicEnsemble to match new Model API. 
--- mbrl/models/basic_ensemble.py | 36 +++++++++++++++++++++++++++++------ 1 file changed, 30 insertions(+), 6 deletions(-) diff --git a/mbrl/models/basic_ensemble.py b/mbrl/models/basic_ensemble.py index e6448013..f1d4c42c 100644 --- a/mbrl/models/basic_ensemble.py +++ b/mbrl/models/basic_ensemble.py @@ -3,14 +3,14 @@ # This source code is licensed under the MIT license found in the # LICENSE file in the root directory of this source tree. import warnings -from typing import Optional, Sequence, Tuple, Union +from typing import Any, Dict, Optional, Sequence, Tuple, Union, cast import hydra import omegaconf import torch import torch.nn as nn -from .model import Ensemble +from .model import _NO_META_WARNING_MSG, Ensemble class BasicEnsemble(Ensemble): @@ -193,9 +193,12 @@ def loss( # type: ignore self, model_ins: Sequence[torch.Tensor], targets: Optional[Sequence[torch.Tensor]] = None, - ) -> torch.Tensor: + ) -> Tuple[torch.Tensor, Optional[Dict[str, Any]]]: """Computes average loss over the losses of all members of the ensemble. + Returns a dictionary with metadata for all models, indexed as + meta["model_i"] = meta_for_model_i + Args: model_ins (sequence of tensors): one input for each model in the ensemble. targets (sequence of tensors): one target for each model in the ensemble. @@ -205,11 +208,21 @@ def loss( # type: ignore """ assert targets is not None avg_ensemble_loss: torch.Tensor = 0.0 + ensemble_meta = {} for i, model in enumerate(self.members): model.train() - loss = model.loss(model_ins[i], targets[i]) + loss_and_maybe_meta = model.loss(model_ins[i], targets[i]) + if isinstance(loss_and_maybe_meta, tuple): + loss = cast(torch.Tensor, loss_and_maybe_meta[0]) + meta = cast(Dict[str, Any], loss_and_maybe_meta[1]) + else: + # TODO remove in v0.2.0 + warnings.warn(_NO_META_WARNING_MSG) + loss = cast(torch.Tensor, loss_and_maybe_meta) + meta = None + ensemble_meta[f"model_{i}"] = meta avg_ensemble_loss += loss - return avg_ensemble_loss / len(self.members) + return avg_ensemble_loss / len(self.members), ensemble_meta def eval_score( # type: ignore self, model_in: torch.Tensor, target: Optional[torch.Tensor] = None @@ -231,9 +244,20 @@ def eval_score( # type: ignore with torch.no_grad(): scores = [] + ensemble_meta = {} for i, model in enumerate(self.members): model.eval() - score = model.eval_score(inputs[i], targets[i]) + score_and_maybe_meta = model.eval_score(inputs[i], targets[i]) + if isinstance(score_and_maybe_meta, tuple): + score = cast(torch.Tensor, score_and_maybe_meta[0]) + meta = cast(Dict[str, Any], score_and_maybe_meta[1]) + else: + # TODO remove in v0.2.0 + warnings.warn(_NO_META_WARNING_MSG) + score = cast(torch.Tensor, score_and_maybe_meta) + meta = None + ensemble_meta[f"model_{i}"] = meta + if score.ndim == 3: assert score.shape[0] == 1 score = score[0] From 954861b80adcb75f15330e9238b1f549d5876293 Mon Sep 17 00:00:00 2001 From: Luis Pineda Date: Tue, 20 Jul 2021 15:20:55 -0400 Subject: [PATCH 04/10] Updated OneDTransitionRewardModel to match new Model API. 
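The per-member metadata layout introduced in the previous patch can be consumed as in the sketch below; ``ensemble``, ``model_ins``, and ``targets`` are assumed to be an already-constructed ``BasicEnsemble`` with matching per-member inputs and targets.

```python
# Sketch: unpacking the per-member metadata dict returned by
# BasicEnsemble.loss() after PATCH 03.
loss, ensemble_meta = ensemble.loss(model_ins, targets)
for i in range(len(ensemble)):
    member_meta = ensemble_meta[f"model_{i}"]  # keyed as in the patch
    if member_meta is not None:  # None when a member still uses the old API
        print(f"member {i}: {member_meta}")
```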
---
 mbrl/models/one_dim_tr_model.py | 18 +++++++++++-------
 1 file changed, 11 insertions(+), 7 deletions(-)

diff --git a/mbrl/models/one_dim_tr_model.py b/mbrl/models/one_dim_tr_model.py
index f3e4995e..55938a38 100644
--- a/mbrl/models/one_dim_tr_model.py
+++ b/mbrl/models/one_dim_tr_model.py
@@ -14,7 +14,7 @@
 import mbrl.types
 import mbrl.util.math
 
-from .model import Ensemble, Model
+from .model import Ensemble, LossOutput, Model, UpdateOutput
 
 MODEL_LOG_FORMAT = [
     ("train_iteration", "I", "int"),
@@ -173,17 +173,18 @@ def loss(
         self,
         batch: mbrl.types.TransitionBatch,
         target: Optional[torch.Tensor] = None,
-    ) -> torch.Tensor:
-        """Evaluates the model score over a batch of transitions.
+    ) -> LossOutput:
+        """Computes the model loss over a batch of transitions.
 
         This method constructs input and targets from the information in the batch,
-        then calls `self.model.eval_score()` on them and returns the value.
+        then calls `self.model.loss()` on them and returns the value and the metadata
+        as returned by the model.
 
         Args:
             batch (transition batch): a batch of transitions to train the model.
 
         Returns:
-            (tensor): as returned by `model.eval_score().`
+            (tensor and optional dict): as returned by ``model.loss()``.
         """
         assert target is None
         model_in, target = self._get_model_input_and_target_from_batch(batch)
@@ -194,12 +195,15 @@ def update(
         batch: mbrl.types.TransitionBatch,
         optimizer: torch.optim.Optimizer,
         target: Optional[torch.Tensor] = None,
-    ) -> float:
+    ) -> UpdateOutput:
         """Updates the model given a batch of transitions and an optimizer.
 
         Args:
             batch (transition batch): a batch of transitions to train the model.
             optimizer (torch optimizer): the optimizer to use to update the model.
+
+        Returns:
+            (float and optional dict): as returned by ``model.update()``.
         """
         assert target is None
         model_in, target = self._get_model_input_and_target_from_batch(batch)
@@ -209,7 +213,7 @@ def eval_score(
         self,
         batch: mbrl.types.TransitionBatch,
         target: Optional[torch.Tensor] = None,
-    ) -> torch.Tensor:
+    ) -> LossOutput:
         """Evaluates the model score over a batch of transitions.
 
         This method constructs input and targets from the information in the batch,

From 802627021fb336fe2c03ea98f1a4fae0395fe09b Mon Sep 17 00:00:00 2001
From: Luis Pineda
Date: Tue, 20 Jul 2021 15:28:44 -0400
Subject: [PATCH 05/10] Updated ModelTrainer to use metadata returned by
 models and pass it to an optional batch-level callback.

---
 mbrl/models/model_trainer.py | 43 ++++++++++++++++++++++++++++--------
 1 file changed, 34 insertions(+), 9 deletions(-)

diff --git a/mbrl/models/model_trainer.py b/mbrl/models/model_trainer.py
index 6ad117c2..a490ab8c 100644
--- a/mbrl/models/model_trainer.py
+++ b/mbrl/models/model_trainer.py
@@ -4,7 +4,8 @@
 # LICENSE file in the root directory of this source tree.
 import copy
 import itertools
-from typing import Callable, Dict, List, Optional, Tuple
+import warnings
+from typing import Callable, Dict, List, Optional, Tuple, cast
 
 import numpy as np
 import torch
@@ -13,7 +14,7 @@
 from mbrl.util.logger import Logger
 from mbrl.util.replay_buffer import BootstrapIterator, TransitionIterator
 
-from .model import Model
+from .model import _NO_META_WARNING_MSG, Model
 
 MODEL_LOG_FORMAT = [
     ("train_iteration", "I", "int"),
@@ -71,6 +72,7 @@ def train(
         patience: Optional[int] = None,
         improvement_threshold: float = 0.01,
         callback: Optional[Callable] = None,
+        batch_callback: Optional[Callable] = None,
     ) -> Tuple[List[float], List[float]]:
         """Trains the model for some number of epochs.
@@ -104,6 +106,8 @@
               - validation score (for ensembles, factored per member)
               - best validation score so far
 
+            batch_callback (callable, optional): if provided, this function will be called
+                for every batch with the output of ``model.update()``.
 
         Returns:
             (tuple of two list(float)): the history of training losses and validation losses.
@@ -119,12 +123,20 @@
         for epoch in epoch_iter:
             batch_losses: List[float] = []
             for batch in dataset_train:
-                loss = self.model.update(batch, self.optimizer)
+                loss_and_maybe_meta = self.model.update(batch, self.optimizer)
+                if isinstance(loss_and_maybe_meta, tuple):
+                    loss = cast(float, loss_and_maybe_meta[0])
+                else:
+                    # TODO remove this if in v0.2.0
+                    warnings.warn(_NO_META_WARNING_MSG)
+                    loss = cast(float, loss_and_maybe_meta)
                 batch_losses.append(loss)
+                if batch_callback:
+                    batch_callback(loss_and_maybe_meta)
             total_avg_loss = np.mean(batch_losses).mean().item()
             training_losses.append(total_avg_loss)
 
-            eval_score = self.evaluate(eval_dataset)
+            eval_score = self.evaluate(eval_dataset, batch_callback=batch_callback)
             val_scores.append(eval_score.mean().item())
 
             maybe_best_weights = self.maybe_get_best_weights(
@@ -171,7 +183,9 @@
         self._train_iteration += 1
         return training_losses, val_scores
 
-    def evaluate(self, dataset: TransitionIterator) -> torch.Tensor:
+    def evaluate(
+        self, dataset: TransitionIterator, batch_callback: Optional[Callable] = None
+    ) -> torch.Tensor:
         """Evaluates the model on the validation dataset.
 
         Iterates over the dataset, one batch at a time, and calls
@@ -180,6 +194,9 @@
 
         Args:
             dataset (transition iterator): the transition iterator to use.
+            batch_callback (callable, optional): if provided, this function will be called
+                for every batch with the output of ``model.eval_score()`` (the score will
+                be passed as a float, reduced using mean()).
 
         Returns:
             (tensor): The average score of the model over the dataset (and for ensembles, per
@@ -190,15 +207,23 @@
 
         batch_scores_list = []
         for batch in dataset:
-            avg_batch_score = self.model.eval_score(batch)
-            batch_scores_list.append(avg_batch_score)
-        batch_scores = torch.cat(batch_scores_list, axis=batch_scores_list[0].ndim - 2)
+            batch_score_and_maybe_meta = self.model.eval_score(batch)
+            if isinstance(batch_score_and_maybe_meta, tuple):
+                batch_score = cast(torch.Tensor, batch_score_and_maybe_meta[0])
+            else:
+                # TODO remove this if in v0.2.0
+                warnings.warn(_NO_META_WARNING_MSG)
+                batch_score = cast(torch.Tensor, batch_score_and_maybe_meta)
+            batch_scores_list.append(batch_score)
+            if batch_callback:
+                batch_callback()
+        batch_scores = torch.cat(batch_scores_list, dim=batch_scores_list[0].ndim - 2)
 
         if isinstance(dataset, BootstrapIterator):
             dataset.toggle_bootstrap()
 
         mean_axis = 1 if batch_scores.ndim == 2 else (1, 2)
-        batch_scores = batch_scores.mean(axis=mean_axis)
+        batch_scores = batch_scores.mean(dim=mean_axis)
 
         return batch_scores

From 941b654028f34e39e8e6a46842d7b4d81585fad5 Mon Sep 17 00:00:00 2001
From: Luis Pineda
Date: Tue, 20 Jul 2021 16:22:56 -0400
Subject: [PATCH 06/10] Added an option to return the norm of the gradients as
 metadata when doing Model.update().
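The callback added in the previous patch receives the raw output of ``model.update()`` during training; a minimal loss-collecting callback compatible with that stage might look like the sketch below (name illustrative). Note that at this point the evaluation path still invokes the callback with no arguments — the bug fixed in PATCH 07.

```python
# Sketch of a batch-level callback matching PATCH 05, where the trainer
# passes the raw output of model.update() during training.
batch_losses = []

def my_batch_callback(update_output):
    # A new-style model returns a (loss, metadata) tuple; an old-style
    # model yields a bare float.
    if isinstance(update_output, tuple):
        loss, meta = update_output
    else:
        loss, meta = update_output, None
    batch_losses.append(loss)

# trainer.train(dataset_train, dataset_val, batch_callback=my_batch_callback)
```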
---
 mbrl/models/model.py                         | 12 ++++++++++++
 mbrl/models/model_trainer.py                 |  9 ++++++++-
 mbrl/models/one_dim_tr_model.py              | 11 ++++++++++-
 notebooks/fit_gaussian_mlp_ensemble_1d.ipynb |  2 +-
 4 files changed, 31 insertions(+), 3 deletions(-)

diff --git a/mbrl/models/model.py b/mbrl/models/model.py
index 8d375632..476c5a6a 100644
--- a/mbrl/models/model.py
+++ b/mbrl/models/model.py
@@ -152,6 +152,7 @@ def update(
         model_in: ModelInput,
         optimizer: torch.optim.Optimizer,
         target: Optional[torch.Tensor] = None,
+        meta_includes_grad_norm: bool = False,
     ) -> UpdateOutput:
         """Updates the model using backpropagation with given input and target tensors.
@@ -169,6 +170,9 @@ def update(
             optimizer (torch.optimizer): the optimizer to use for the model.
             target (tensor or sequence of tensors): the expected output for the given inputs, if it
                 cannot be computed from ``model_in``.
+            meta_includes_grad_norm (bool): if ``True`` and metadata returned by ``model.loss()``
+                is not ``None``, the norm of the gradients (summed over all parameters)
+                will be included in the metadata dictionary.
 
         Returns:
             (float): the numeric value of the computed loss.
@@ -183,6 +187,14 @@ def update(
             loss = cast(torch.Tensor, loss_and_maybe_meta[0])
             meta = cast(Optional[Dict[str, Any]], loss_and_maybe_meta[1])
             loss.backward()
+            if meta_includes_grad_norm:
+                with torch.no_grad():
+                    grad_norm: torch.Tensor = 0  # type: ignore
+                    for p in list(
+                        filter(lambda p: p.grad is not None, self.parameters())
+                    ):
+                        grad_norm += p.grad.data.norm(2).item() ** 2
+                    meta["grad_norm"] = grad_norm
             optimizer.step()
             return loss.item(), meta
diff --git a/mbrl/models/model_trainer.py b/mbrl/models/model_trainer.py
index a490ab8c..4699e71c 100644
--- a/mbrl/models/model_trainer.py
+++ b/mbrl/models/model_trainer.py
@@ -72,6 +72,7 @@ def train(
         patience: Optional[int] = None,
         improvement_threshold: float = 0.01,
         callback: Optional[Callable] = None,
+        meta_includes_grad_norm: bool = False,
         batch_callback: Optional[Callable] = None,
     ) -> Tuple[List[float], List[float]]:
         """Trains the model for some number of epochs.
@@ -106,6 +107,8 @@ def train(
               - validation score (for ensembles, factored per member)
               - best validation score so far
 
+            meta_includes_grad_norm (bool): passed as keyword arg to ``model.update()``,
+                which indicates if ``batch_callback`` will also receive gradient norms.
             batch_callback (callable, optional): if provided, this function will be called
                 for every batch with the output of ``model.update()``.
 
@@ -123,7 +126,11 @@ def train(
         for epoch in epoch_iter:
             batch_losses: List[float] = []
             for batch in dataset_train:
-                loss_and_maybe_meta = self.model.update(batch, self.optimizer)
+                loss_and_maybe_meta = self.model.update(
+                    batch,
+                    self.optimizer,
+                    meta_includes_grad_norm=meta_includes_grad_norm,
+                )
                 if isinstance(loss_and_maybe_meta, tuple):
                     loss = cast(float, loss_and_maybe_meta[0])
                 else:
diff --git a/mbrl/models/one_dim_tr_model.py b/mbrl/models/one_dim_tr_model.py
index 55938a38..edc8199b 100644
--- a/mbrl/models/one_dim_tr_model.py
+++ b/mbrl/models/one_dim_tr_model.py
@@ -195,19 +195,28 @@ def update(
         batch: mbrl.types.TransitionBatch,
         optimizer: torch.optim.Optimizer,
         target: Optional[torch.Tensor] = None,
+        meta_includes_grad_norm: bool = False,
     ) -> UpdateOutput:
         """Updates the model given a batch of transitions and an optimizer.
 
         Args:
             batch (transition batch): a batch of transitions to train the model.
             optimizer (torch optimizer): the optimizer to use to update the model.
+            meta_includes_grad_norm (bool): if ``True`` and metadata returned by ``model.loss()``
+                is not ``None``, the norm of the gradients (summed over all parameters)
+                will be included in the metadata dictionary.
 
         Returns:
             (float and optional dict): as returned by ``model.update()``.
         """
         assert target is None
         model_in, target = self._get_model_input_and_target_from_batch(batch)
-        return self.model.update(model_in, optimizer, target=target)
+        return self.model.update(
+            model_in,
+            optimizer,
+            target=target,
+            meta_includes_grad_norm=meta_includes_grad_norm,
+        )
 
     def eval_score(
         self,
diff --git a/notebooks/fit_gaussian_mlp_ensemble_1d.ipynb b/notebooks/fit_gaussian_mlp_ensemble_1d.ipynb
index ff156b65..00c8b0d8 100644
--- a/notebooks/fit_gaussian_mlp_ensemble_1d.ipynb
+++ b/notebooks/fit_gaussian_mlp_ensemble_1d.ipynb
@@ -215,4 +215,4 @@
  },
  "nbformat": 4,
  "nbformat_minor": 2
-}
+}
\ No newline at end of file

From f1e0a425ac636040261671fe94487a789ae619d5 Mon Sep 17 00:00:00 2001
From: Luis Pineda
Date: Wed, 21 Jul 2021 13:31:52 -0400
Subject: [PATCH 07/10] Fixed bug in the way ModelTrainer was passing args to
 batch_callback, and also added information about epoch_index and whether
 it's train/eval.

---
 mbrl/models/model_trainer.py | 25 +++++++++++++++++++------
 1 file changed, 19 insertions(+), 6 deletions(-)

diff --git a/mbrl/models/model_trainer.py b/mbrl/models/model_trainer.py
index 4699e71c..5a2068a6 100644
--- a/mbrl/models/model_trainer.py
+++ b/mbrl/models/model_trainer.py
@@ -3,6 +3,7 @@
 # This source code is licensed under the MIT license found in the
 # LICENSE file in the root directory of this source tree.
 import copy
+import functools
 import itertools
 import warnings
 from typing import Callable, Dict, List, Optional, Tuple, cast
@@ -110,7 +111,9 @@ def train(
             meta_includes_grad_norm (bool): passed as keyword arg to ``model.update()``,
                 which indicates if ``batch_callback`` will also receive gradient norms.
             batch_callback (callable, optional): if provided, this function will be called
-                for every batch with the output of ``model.update()``.
+                for every batch with the output of ``model.update()`` (during training),
+                and ``model.eval_score()`` (during evaluation). It will be called
+                with three arguments ``(epoch_index, loss/score, meta)``.
 
         Returns:
             (tuple of two list(float)): the history of training losses and validation losses.
@@ -124,6 +127,10 @@
         epochs_since_update = 0
         best_val_score = self.evaluate(eval_dataset)
         for epoch in epoch_iter:
+            if batch_callback:
+                batch_callback_epoch = functools.partial(batch_callback, epoch)
+            else:
+                batch_callback_epoch = None
             batch_losses: List[float] = []
             for batch in dataset_train:
                 loss_and_maybe_meta = self.model.update(
@@ -133,17 +140,21 @@
                 )
                 if isinstance(loss_and_maybe_meta, tuple):
                     loss = cast(float, loss_and_maybe_meta[0])
+                    meta = cast(Dict, loss_and_maybe_meta[1])
                 else:
                     # TODO remove this if in v0.2.0
                     warnings.warn(_NO_META_WARNING_MSG)
                     loss = cast(float, loss_and_maybe_meta)
+                    meta = None
                 batch_losses.append(loss)
-                if batch_callback:
-                    batch_callback(loss_and_maybe_meta)
+                if batch_callback_epoch:
+                    batch_callback_epoch(loss, meta)
             total_avg_loss = np.mean(batch_losses).mean().item()
             training_losses.append(total_avg_loss)
 
-            eval_score = self.evaluate(eval_dataset, batch_callback=batch_callback)
+            eval_score = self.evaluate(
+                eval_dataset, batch_callback=batch_callback_epoch
+            )
             val_scores.append(eval_score.mean().item())
 
             maybe_best_weights = self.maybe_get_best_weights(
@@ -217,13 +228,15 @@
             batch_score_and_maybe_meta = self.model.eval_score(batch)
             if isinstance(batch_score_and_maybe_meta, tuple):
                 batch_score = cast(torch.Tensor, batch_score_and_maybe_meta[0])
+                meta = cast(Dict, batch_score_and_maybe_meta[1])
             else:
-                # TODO remove this if in v0.2.0
+                # TODO remove this "else" in v0.2.0
                 warnings.warn(_NO_META_WARNING_MSG)
                 batch_score = cast(torch.Tensor, batch_score_and_maybe_meta)
+                meta = None
             batch_scores_list.append(batch_score)
             if batch_callback:
-                batch_callback()
+                batch_callback(batch_score.mean(), meta)
         batch_scores = torch.cat(batch_scores_list, dim=batch_scores_list[0].ndim - 2)
 
         if isinstance(dataset, BootstrapIterator):

From 26e74cb9dfa388d40a362ab55c0b25e00ada0317 Mon Sep 17 00:00:00 2001
From: Luis Pineda
Date: Wed, 21 Jul 2021 14:00:05 -0400
Subject: [PATCH 08/10] Removed meta_includes_grad_norm argument from
 Model.update() and ModelTrainer.train()

---
 mbrl/models/model.py            |  7 ++-----
 mbrl/models/model_trainer.py    |  9 +--------
 mbrl/models/one_dim_tr_model.py | 11 +----------
 3 files changed, 4 insertions(+), 23 deletions(-)

diff --git a/mbrl/models/model.py b/mbrl/models/model.py
index 476c5a6a..8dbf1993 100644
--- a/mbrl/models/model.py
+++ b/mbrl/models/model.py
@@ -152,7 +152,6 @@ def update(
         model_in: ModelInput,
         optimizer: torch.optim.Optimizer,
         target: Optional[torch.Tensor] = None,
-        meta_includes_grad_norm: bool = False,
     ) -> UpdateOutput:
         """Updates the model using backpropagation with given input and target tensors.
@@ -170,9 +169,6 @@ def update(
             optimizer (torch.optimizer): the optimizer to use for the model.
             target (tensor or sequence of tensors): the expected output for the given inputs, if it
                 cannot be computed from ``model_in``.
-            meta_includes_grad_norm (bool): if ``True`` and metadata returned by ``model.loss()``
-                is not ``None``, the norm of the gradients (summed over all parameters)
-                will be included in the metadata dictionary.
 
         Returns:
             (float): the numeric value of the computed loss.
@@ -187,7 +183,8 @@ def update( loss = cast(torch.Tensor, loss_and_maybe_meta[0]) meta = cast(Optional[Dict[str, Any]], loss_and_maybe_meta[1]) loss.backward() - if meta_includes_grad_norm: + + if meta is not None: with torch.no_grad(): grad_norm: torch.Tensor = 0 # type: ignore for p in list( diff --git a/mbrl/models/model_trainer.py b/mbrl/models/model_trainer.py index 5a2068a6..93b86346 100644 --- a/mbrl/models/model_trainer.py +++ b/mbrl/models/model_trainer.py @@ -73,7 +73,6 @@ def train( patience: Optional[int] = None, improvement_threshold: float = 0.01, callback: Optional[Callable] = None, - meta_includes_grad_norm: bool = False, batch_callback: Optional[Callable] = None, ) -> Tuple[List[float], List[float]]: """Trains the model for some number of epochs. @@ -108,8 +107,6 @@ def train( - validation score (for ensembles, factored per member) - best validation score so far - meta_includes_grad_norm (bool): passed as keyword arg to ``model.update()``, - which indicates if ``batch_callback`` will also receive gradient norms. batch_callback (callable, optional): if provided, this function will be called for every batch with the output of ``model.update()`` (during training), and ``model.eval_score()`` (during evaluation). It will be called @@ -133,11 +130,7 @@ def train( batch_callback_epoch = None batch_losses: List[float] = [] for batch in dataset_train: - loss_and_maybe_meta = self.model.update( - batch, - self.optimizer, - meta_includes_grad_norm=meta_includes_grad_norm, - ) + loss_and_maybe_meta = self.model.update(batch, self.optimizer) if isinstance(loss_and_maybe_meta, tuple): loss = cast(float, loss_and_maybe_meta[0]) meta = cast(Dict, loss_and_maybe_meta[1]) diff --git a/mbrl/models/one_dim_tr_model.py b/mbrl/models/one_dim_tr_model.py index edc8199b..55938a38 100644 --- a/mbrl/models/one_dim_tr_model.py +++ b/mbrl/models/one_dim_tr_model.py @@ -195,28 +195,19 @@ def update( batch: mbrl.types.TransitionBatch, optimizer: torch.optim.Optimizer, target: Optional[torch.Tensor] = None, - meta_includes_grad_norm: bool = False, ) -> UpdateOutput: """Updates the model given a batch of transitions and an optimizer. Args: batch (transition batch): a batch of transition to train the model. optimizer (torch optimizer): the optimizer to use to update the model. - meta_includes_grad_norm (bool): if ``True`` and metadata returned by ``model.loss()`` - is not ``None``, the norm of the gradients (summed over all parameters) - will be included in the metadata dictionary. Returns: (tensor and optional dict): as returned by `model.loss().` """ assert target is None model_in, target = self._get_model_input_and_target_from_batch(batch) - return self.model.update( - model_in, - optimizer, - target=target, - meta_includes_grad_norm=meta_includes_grad_norm, - ) + return self.model.update(model_in, optimizer, target=target) def eval_score( self, From c666174f70c13404797672c0bc49165fcb94f009 Mon Sep 17 00:00:00 2001 From: Luis Pineda Date: Thu, 22 Jul 2021 10:12:28 -0400 Subject: [PATCH 09/10] Added missing meta to BasicEnsemble.eval_score() and fixed some type errors. 
--- mbrl/models/basic_ensemble.py | 6 +++--- mbrl/models/gaussian_mlp.py | 4 ++-- mbrl/models/model.py | 5 +++-- 3 files changed, 8 insertions(+), 7 deletions(-) diff --git a/mbrl/models/basic_ensemble.py b/mbrl/models/basic_ensemble.py index f1d4c42c..b15d979b 100644 --- a/mbrl/models/basic_ensemble.py +++ b/mbrl/models/basic_ensemble.py @@ -193,7 +193,7 @@ def loss( # type: ignore self, model_ins: Sequence[torch.Tensor], targets: Optional[Sequence[torch.Tensor]] = None, - ) -> Tuple[torch.Tensor, Optional[Dict[str, Any]]]: + ) -> Tuple[torch.Tensor, Dict[str, Any]]: """Computes average loss over the losses of all members of the ensemble. Returns a dictionary with metadata for all models, indexed as @@ -226,7 +226,7 @@ def loss( # type: ignore def eval_score( # type: ignore self, model_in: torch.Tensor, target: Optional[torch.Tensor] = None - ) -> torch.Tensor: + ) -> Tuple[torch.Tensor, Dict[str, Any]]: """Computes the average score over all members given input/target. The input and target tensors are replicated once for each model in the ensemble. @@ -262,7 +262,7 @@ def eval_score( # type: ignore assert score.shape[0] == 1 score = score[0] scores.append(score) - return torch.stack(scores) + return torch.stack(scores), ensemble_meta def reset( # type: ignore self, x: torch.Tensor, rng: Optional[torch.Generator] = None diff --git a/mbrl/models/gaussian_mlp.py b/mbrl/models/gaussian_mlp.py index 94406d5d..2eb3ccfb 100644 --- a/mbrl/models/gaussian_mlp.py +++ b/mbrl/models/gaussian_mlp.py @@ -295,7 +295,7 @@ def loss( self, model_in: torch.Tensor, target: Optional[torch.Tensor] = None, - ) -> Tuple[torch.Tensor, Optional[Dict[str, Any]]]: + ) -> Tuple[torch.Tensor, Dict[str, Any]]: """Computes Gaussian NLL loss. It also includes terms for ``max_logvar`` and ``min_logvar`` with small weights, @@ -323,7 +323,7 @@ def loss( def eval_score( # type: ignore self, model_in: torch.Tensor, target: Optional[torch.Tensor] = None - ) -> Tuple[torch.Tensor, Optional[Dict[str, Any]]]: + ) -> Tuple[torch.Tensor, Dict[str, Any]]: """Computes the squared error for the model over the given input/target. When model is not an ensemble, this is equivalent to diff --git a/mbrl/models/model.py b/mbrl/models/model.py index 8dbf1993..6825f2a7 100644 --- a/mbrl/models/model.py +++ b/mbrl/models/model.py @@ -181,12 +181,12 @@ def update( if isinstance(loss_and_maybe_meta, tuple): # TODO - v0.2.0 remove this back-compatibility logic loss = cast(torch.Tensor, loss_and_maybe_meta[0]) - meta = cast(Optional[Dict[str, Any]], loss_and_maybe_meta[1]) + meta = cast(Dict[str, Any], loss_and_maybe_meta[1]) loss.backward() if meta is not None: with torch.no_grad(): - grad_norm: torch.Tensor = 0 # type: ignore + grad_norm = 0.0 for p in list( filter(lambda p: p.grad is not None, self.parameters()) ): @@ -275,6 +275,7 @@ def forward(self, x: ModelInput, **kwargs) -> Tuple[torch.Tensor, ...]: """ pass + # TODO this and eval_score are no longer necessary @abc.abstractmethod def loss( self, From 79e758eb16d3d06e481c918f7954d9c4cbf1e7a7 Mon Sep 17 00:00:00 2001 From: Luis Pineda Date: Thu, 22 Jul 2021 11:29:59 -0400 Subject: [PATCH 10/10] Added an argument to the batch_callback of ModelTrainer to indicate if it's training or evaluation. Also added a unit test for the callback. 
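A sketch of a callback compatible with the final four-argument signature ``(epoch_index, loss/score, meta, mode)`` implemented below; the print-based logging is purely illustrative.

```python
# Sketch of a mode-aware batch callback matching the signature in PATCH 10.
def batch_callback(epoch, loss_or_score, meta, mode):
    if mode == "train":
        print(f"epoch {epoch}: train loss {loss_or_score:.4f}")
        if meta is not None and "grad_norm" in meta:
            print(f"  grad norm (squared sum): {meta['grad_norm']:.4f}")
    else:  # mode == "eval"; the score arrives as a mean-reduced tensor
        print(f"epoch {epoch}: eval score {float(loss_or_score):.4f}")

# trainer.train(train_iter, val_iter, num_epochs=50, batch_callback=batch_callback)
```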
--- mbrl/models/model_trainer.py | 12 ++++++--- tests/core/test_models.py | 47 +++++++++++++++++++++++++++++++++++- 2 files changed, 54 insertions(+), 5 deletions(-) diff --git a/mbrl/models/model_trainer.py b/mbrl/models/model_trainer.py index 93b86346..a19fb2af 100644 --- a/mbrl/models/model_trainer.py +++ b/mbrl/models/model_trainer.py @@ -110,7 +110,9 @@ def train( batch_callback (callable, optional): if provided, this function will be called for every batch with the output of ``model.update()`` (during training), and ``model.eval_score()`` (during evaluation). It will be called - with three arguments ``(epoch_index, loss/score, meta)``. + with four arguments ``(epoch_index, loss/score, meta, mode)``, where + ``mode`` is one of ``"train"`` or ``"eval"``, indicating if the callback + was called during training or evaluation. Returns: (tuple of two list(float)): the history of training losses and validation losses. @@ -141,7 +143,7 @@ def train( meta = None batch_losses.append(loss) if batch_callback_epoch: - batch_callback_epoch(loss, meta) + batch_callback_epoch(loss, meta, "train") total_avg_loss = np.mean(batch_losses).mean().item() training_losses.append(total_avg_loss) @@ -207,7 +209,9 @@ def evaluate( dataset (bool): the transition iterator to use. batch_callback (callable, optional): if provided, this function will be called for every batch with the output of ``model.eval_score()`` (the score will - be passed as a float, reduced using mean()). + be passed as a float, reduced using mean()). It will be called + with four arguments ``(epoch_index, loss/score, meta, mode)``, where + ``mode`` is the string ``"eval"``. Returns: (tensor): The average score of the model over the dataset (and for ensembles, per @@ -229,7 +233,7 @@ def evaluate( meta = None batch_scores_list.append(batch_score) if batch_callback: - batch_callback(batch_score.mean(), meta) + batch_callback(batch_score.mean(), meta, "eval") batch_scores = torch.cat(batch_scores_list, dim=batch_scores_list[0].ndim - 2) if isinstance(dataset, BootstrapIterator): diff --git a/tests/core/test_models.py b/tests/core/test_models.py index 8c4e0871..962698b8 100644 --- a/tests/core/test_models.py +++ b/tests/core/test_models.py @@ -2,6 +2,7 @@ # # This source code is licensed under the MIT license found in the # LICENSE file in the root directory of this source tree. 
+import collections import functools import numpy as np @@ -11,7 +12,9 @@ import torch.nn as nn import mbrl.models +import mbrl.util.replay_buffer from mbrl.env.termination_fns import no_termination +from mbrl.types import TransitionBatch _DEVICE = "cuda:0" if torch.cuda.is_available() else "cpu" @@ -298,6 +301,7 @@ def test_model_env_expectation_fixed(): class DummyModel(mbrl.models.Model): def __init__(self): super().__init__() + self.param = nn.Parameter(torch.ones(1)) self.device = torch.device(_DEVICE) self.to(self.device) @@ -312,9 +316,12 @@ def sample(self, x, deterministic=False, rng=None): return (self.forward(x),) def loss(self, _input, target=None): - pass + return 0.0 * self.param, {"loss": 0} def eval_score(self, _input, target=None): + return torch.zeros_like(_input), {"score": 0} + + def set_elite(self, _indices): pass @@ -339,3 +346,41 @@ def test_model_env_evaluate_action_sequences(): num_particles=num_particles, ) assert torch.allclose(expected_returns, returns) + + +def test_model_trainer_batch_callback(): + model = DummyModel() + wrapper = mbrl.models.OneDTransitionRewardModel(model, target_is_delta=False) + trainer = mbrl.models.ModelTrainer(wrapper) + num_batches = 10 + dummy_data = torch.zeros(num_batches, 1) + mock_dataset = mbrl.util.replay_buffer.TransitionIterator( + TransitionBatch( + dummy_data, + dummy_data, + dummy_data, + dummy_data.squeeze(1), + dummy_data.squeeze(1), + ), + 1, + ) + + train_counter = collections.Counter() + val_counter = collections.Counter() + + def batch_callback(epoch, val, meta, mode): + assert mode in ["train", "eval"] + if mode == "train": + assert "loss" in meta + train_counter[epoch] += 1 + else: + assert "score" in meta + val_counter[epoch] += 1 + + num_epochs = 20 + trainer.train(mock_dataset, num_epochs=num_epochs, batch_callback=batch_callback) + + for counter in [train_counter, val_counter]: + assert set(counter.keys()) == set(range(num_epochs)) + for i in range(num_epochs): + assert counter[i] == num_batches
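Taken together, a typical use of the API after this series might look like the following sketch. The sizes, hyperparameters, and iterator names are illustrative, and the constructor arguments are assumptions based on the v0.1.x classes rather than a verified recipe.

```python
# End-to-end sketch: a GaussianMLP wrapped in OneDTransitionRewardModel,
# trained with the four-argument batch_callback from PATCH 10.
import torch

import mbrl.models

device = "cuda:0" if torch.cuda.is_available() else "cpu"
model = mbrl.models.GaussianMLP(5, 4, device, ensemble_size=3)
wrapper = mbrl.models.OneDTransitionRewardModel(model, target_is_delta=False)
trainer = mbrl.models.ModelTrainer(wrapper, optim_lr=1e-3)

def log_batch(epoch, loss_or_score, meta, mode):
    # the loss is a float during training; the score is a tensor during eval
    print(mode, epoch, float(loss_or_score))

# train_iter/val_iter would be TransitionIterators built from a replay buffer:
# trainer.train(train_iter, val_iter, num_epochs=10, batch_callback=log_batch)
```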