This repository has been archived by the owner on Sep 1, 2024. It is now read-only.

Commit
Merge pull request #109 from facebookresearch/lep.update_returns_metadata

Model functions now include metadata
luisenp authored Jul 24, 2021
2 parents ffc4e7e + 79e758e commit 3b8505a
Showing 7 changed files with 199 additions and 43 deletions.
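The common thread across the diff is that `loss()`, `eval_score()`, and `update()` now return a `(value, metadata)` pair instead of a bare tensor or float, with a deprecation path for models that still return the old form. Below is a minimal sketch of how downstream code might unpack either form; the helper name is illustrative and not part of this commit, while the `LossOutput` alias mirrors the one added in mbrl/models/model.py further down.

```python
from typing import Any, Dict, Tuple, Union

import torch

# Mirrors the alias introduced in mbrl/models/model.py in this commit.
LossOutput = Union[torch.Tensor, Tuple[torch.Tensor, Dict[str, Any]]]


def unpack_loss_output(out: LossOutput) -> Tuple[torch.Tensor, Dict[str, Any]]:
    """Normalize old-style (tensor) and new-style (tensor, metadata) returns."""
    if isinstance(out, tuple):
        loss, meta = out
        return loss, meta if meta is not None else {}
    # Old-style return, kept for back-compatibility until v0.2.0.
    return out, {}
```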
40 changes: 32 additions & 8 deletions mbrl/models/basic_ensemble.py
@@ -3,14 +3,14 @@
# This source code is licensed under the MIT license found in the
# LICENSE file in the root directory of this source tree.
import warnings
from typing import Optional, Sequence, Tuple, Union
from typing import Any, Dict, Optional, Sequence, Tuple, Union, cast

import hydra
import omegaconf
import torch
import torch.nn as nn

from .model import Ensemble
from .model import _NO_META_WARNING_MSG, Ensemble


class BasicEnsemble(Ensemble):
@@ -193,9 +193,12 @@ def loss( # type: ignore
self,
model_ins: Sequence[torch.Tensor],
targets: Optional[Sequence[torch.Tensor]] = None,
) -> torch.Tensor:
) -> Tuple[torch.Tensor, Dict[str, Any]]:
"""Computes average loss over the losses of all members of the ensemble.
Returns a dictionary with metadata for all models, indexed as
meta["model_i"] = meta_for_model_i
Args:
model_ins (sequence of tensors): one input for each model in the ensemble.
targets (sequence of tensors): one target for each model in the ensemble.
@@ -205,15 +208,25 @@ def loss( # type: ignore
"""
assert targets is not None
avg_ensemble_loss: torch.Tensor = 0.0
ensemble_meta = {}
for i, model in enumerate(self.members):
model.train()
loss = model.loss(model_ins[i], targets[i])
loss_and_maybe_meta = model.loss(model_ins[i], targets[i])
if isinstance(loss_and_maybe_meta, tuple):
loss = cast(torch.Tensor, loss_and_maybe_meta[0])
meta = cast(Dict[str, Any], loss_and_maybe_meta[1])
else:
# TODO remove in v0.2.0
warnings.warn(_NO_META_WARNING_MSG)
loss = cast(torch.Tensor, loss_and_maybe_meta)
meta = None
ensemble_meta[f"model_{i}"] = meta
avg_ensemble_loss += loss
return avg_ensemble_loss / len(self.members)
return avg_ensemble_loss / len(self.members), ensemble_meta

def eval_score( # type: ignore
self, model_in: torch.Tensor, target: Optional[torch.Tensor] = None
) -> torch.Tensor:
) -> Tuple[torch.Tensor, Dict[str, Any]]:
"""Computes the average score over all members given input/target.
The input and target tensors are replicated once for each model in the ensemble.
@@ -231,14 +244,25 @@ def eval_score( # type: ignore

with torch.no_grad():
scores = []
ensemble_meta = {}
for i, model in enumerate(self.members):
model.eval()
score = model.eval_score(inputs[i], targets[i])
score_and_maybe_meta = model.eval_score(inputs[i], targets[i])
if isinstance(score_and_maybe_meta, tuple):
score = cast(torch.Tensor, score_and_maybe_meta[0])
meta = cast(Dict[str, Any], score_and_maybe_meta[1])
else:
# TODO remove in v0.2.0
warnings.warn(_NO_META_WARNING_MSG)
score = cast(torch.Tensor, score_and_maybe_meta)
meta = None
ensemble_meta[f"model_{i}"] = meta

if score.ndim == 3:
assert score.shape[0] == 1
score = score[0]
scores.append(score)
return torch.stack(scores)
return torch.stack(scores), ensemble_meta

def reset( # type: ignore
self, x: torch.Tensor, rng: Optional[torch.Generator] = None
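As the docstring above describes, `BasicEnsemble.loss` now also returns a dictionary of per-member metadata keyed as `model_i`. A usage sketch, assuming `ensemble` is an already-constructed `BasicEnsemble` and `model_ins`/`targets` are sequences with one tensor per member:

```python
# Per-member metadata may be None when a member still uses the old
# single-tensor return convention (deprecated, removed in v0.2.0).
avg_loss, ensemble_meta = ensemble.loss(model_ins, targets)
for i in range(len(ensemble.members)):
    member_meta = ensemble_meta[f"model_{i}"]
    if member_meta:
        print(f"member {i} metadata keys: {sorted(member_meta.keys())}")
```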
16 changes: 10 additions & 6 deletions mbrl/models/gaussian_mlp.py
@@ -5,7 +5,7 @@
import pathlib
import pickle
import warnings
from typing import List, Optional, Sequence, Tuple, Union
from typing import Any, Dict, List, Optional, Sequence, Tuple, Union

import torch
from torch import nn as nn
@@ -295,12 +295,14 @@ def loss(
self,
model_in: torch.Tensor,
target: Optional[torch.Tensor] = None,
) -> torch.Tensor:
) -> Tuple[torch.Tensor, Dict[str, Any]]:
"""Computes Gaussian NLL loss.
It also includes terms for ``max_logvar`` and ``min_logvar`` with small weights,
with positive and negative signs, respectively.
This function returns no metadata, so the second output is set to an empty dict.
Args:
model_in (tensor): input tensor. The shape must be ``E x B x Id``, or ``B x Id``
where ``E``, ``B`` and ``Id`` represent ensemble size, batch size, and input
@@ -315,19 +317,21 @@
the average over all models.
"""
if self.deterministic:
return self._mse_loss(model_in, target)
return self._mse_loss(model_in, target), {}
else:
return self._nll_loss(model_in, target)
return self._nll_loss(model_in, target), {}

def eval_score( # type: ignore
self, model_in: torch.Tensor, target: Optional[torch.Tensor] = None
) -> torch.Tensor:
) -> Tuple[torch.Tensor, Dict[str, Any]]:
"""Computes the squared error for the model over the given input/target.
When the model is not an ensemble, this is equivalent to
`F.mse_loss(model(model_in), target, reduction="none")`. If the model is an ensemble,
the return value is batched over the model dimension.
This function returns no metadata, so the second output is set to an empty dict.
Args:
model_in (tensor): input tensor. The shape must be ``B x Id``, where ``B`` and ``Id``
represent batch size and input dimension, respectively.
@@ -341,7 +345,7 @@ def eval_score( # type: ignore
with torch.no_grad():
pred_mean, _ = self.forward(model_in, use_propagation=False)
target = target.repeat((self.num_members, 1, 1))
return F.mse_loss(pred_mean, target, reduction="none")
return F.mse_loss(pred_mean, target, reduction="none"), {}

def reset( # type: ignore
self, x: torch.Tensor, rng: Optional[torch.Generator] = None
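`GaussianMLP` itself reports no metadata, so the second element of the returned tuple is always an empty dict and callers can unpack it uniformly. A hedged sketch with placeholder shapes; the model construction is assumed to have happened elsewhere, with input/output sizes matching the tensors below:

```python
import torch

# Placeholder batch: 8 samples, input dim 10, target dim 4 (illustrative sizes
# that must match how `model`, a GaussianMLP, was constructed).
model_in = torch.randn(8, 10)
target = torch.randn(8, 4)

loss, meta = model.loss(model_in, target)          # meta == {}
score, meta = model.eval_score(model_in, target)   # meta == {}
```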
59 changes: 48 additions & 11 deletions mbrl/models/model.py
@@ -4,13 +4,25 @@
# LICENSE file in the root directory of this source tree.
import abc
import pathlib
from typing import Optional, Sequence, Tuple, Union, cast
import warnings
from typing import Any, Dict, Optional, Sequence, Tuple, Union, cast

import torch
from torch import nn as nn

from mbrl.types import ModelInput

# TODO: these aliases are temporary; eventually the return type will be Tuple[tensor, dict].
# The Union is kept for back-compatibility with v0.1.x and will be removed in v0.2.0.
LossOutput = Union[torch.Tensor, Tuple[torch.Tensor, Dict[str, Any]]]
UpdateOutput = Union[float, Tuple[float, Dict[str, Any]]]


_NO_META_WARNING_MSG = (
"Starting in version v0.2.0, `model.loss()`, model.update(), and model.eval_score() "
"must all return a tuple with (loss, metadata)."
)


# ---------------------------------------------------------------------------
# ABSTRACT MODEL CLASS
@@ -92,7 +104,7 @@ def loss(
self,
model_in: ModelInput,
target: Optional[torch.Tensor] = None,
) -> torch.Tensor:
) -> LossOutput:
"""Computes a loss that can be used to update the model using backpropagation.
Args:
@@ -101,13 +113,16 @@ def loss(
cannot be computed from ``model_in``.
Returns:
(tensor): a loss tensor.
(tuple of tensor and optional dict): the loss tensor and, optionally,
a dictionary mapping strings to any additional metadata computed by
the model (e.g., reconstruction, entropy) that will be used for logging.
"""

@abc.abstractmethod
def eval_score(
self, model_in: ModelInput, target: Optional[torch.Tensor] = None
) -> torch.Tensor:
) -> LossOutput:
"""Computes an evaluation score for the model over the given input/target.
This method should compute a non-reduced score for the model, intended mostly for
@@ -127,15 +142,17 @@ def eval_score(
cannot be computed from ``model_in``.
Returns:
(tensor): a non-reduced tensor score.
(tuple of tensor and optional dict): a non-reduced tensor score, and a dictionary
from strings to objects with metadata computed by the model
(e.g., reconstructions, entropy, etc.) that will be used for logging.
"""

def update(
self,
model_in: ModelInput,
optimizer: torch.optim.Optimizer,
target: Optional[torch.Tensor] = None,
) -> float:
) -> UpdateOutput:
"""Updates the model using backpropagation with given input and target tensors.
Provides a basic update function, following the steps below:
@@ -155,15 +172,34 @@
Returns:
(float): the numeric value of the computed loss.
(dict): any additional metadata dictionary computed by :meth:`loss`.
"""
optimizer = cast(torch.optim.Optimizer, optimizer)
self.train()
optimizer.zero_grad()
loss = self.loss(model_in, target)
loss.backward()
optimizer.step()
return loss.item()
loss_and_maybe_meta = self.loss(model_in, target)
if isinstance(loss_and_maybe_meta, tuple):
# TODO - v0.2.0 remove this back-compatibility logic
loss = cast(torch.Tensor, loss_and_maybe_meta[0])
meta = cast(Dict[str, Any], loss_and_maybe_meta[1])
loss.backward()

if meta is not None:
with torch.no_grad():
grad_norm = 0.0
for p in list(
filter(lambda p: p.grad is not None, self.parameters())
):
grad_norm += p.grad.data.norm(2).item() ** 2
meta["grad_norm"] = grad_norm
optimizer.step()
return loss.item(), meta

else:
warnings.warn(_NO_META_WARNING_MSG)
loss_and_maybe_meta.backward()
optimizer.step()
return loss_and_maybe_meta.item()

def __len__(self):
return 1
@@ -239,6 +275,7 @@ def forward(self, x: ModelInput, **kwargs) -> Tuple[torch.Tensor, ...]:
"""
pass

# TODO this and eval_score are no longer necessary
@abc.abstractmethod
def loss(
self,
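The abstract `Model` API above now lets subclasses surface arbitrary metadata from `loss` and `eval_score`, and the default `update` adds a `grad_norm` entry when metadata is returned. A toy sketch of a class following the new convention; it is deliberately a plain `nn.Module` rather than a real `mbrl.models.Model` subclass, so the example stays self-contained:

```python
from typing import Any, Dict, Optional, Tuple

import torch
import torch.nn.functional as F
from torch import nn


class TinyRegressor(nn.Module):
    """Illustrative model that follows the new (value, metadata) convention."""

    def __init__(self, in_size: int, out_size: int):
        super().__init__()
        self.net = nn.Linear(in_size, out_size)

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        return self.net(x)

    def loss(
        self, model_in: torch.Tensor, target: Optional[torch.Tensor] = None
    ) -> Tuple[torch.Tensor, Dict[str, Any]]:
        pred = self.forward(model_in)
        loss = F.mse_loss(pred, target)
        # Metadata is free-form: log whatever is useful (here, prediction variance).
        return loss, {"pred_var": pred.var().item()}

    def eval_score(
        self, model_in: torch.Tensor, target: Optional[torch.Tensor] = None
    ) -> Tuple[torch.Tensor, Dict[str, Any]]:
        with torch.no_grad():
            return F.mse_loss(self.forward(model_in), target, reduction="none"), {}
```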
60 changes: 51 additions & 9 deletions mbrl/models/model_trainer.py
@@ -3,8 +3,10 @@
# This source code is licensed under the MIT license found in the
# LICENSE file in the root directory of this source tree.
import copy
import functools
import itertools
from typing import Callable, Dict, List, Optional, Tuple
import warnings
from typing import Callable, Dict, List, Optional, Tuple, cast

import numpy as np
import torch
Expand All @@ -13,7 +15,7 @@
from mbrl.util.logger import Logger
from mbrl.util.replay_buffer import BootstrapIterator, TransitionIterator

from .model import Model
from .model import _NO_META_WARNING_MSG, Model

MODEL_LOG_FORMAT = [
("train_iteration", "I", "int"),
@@ -71,6 +73,7 @@ def train(
patience: Optional[int] = None,
improvement_threshold: float = 0.01,
callback: Optional[Callable] = None,
batch_callback: Optional[Callable] = None,
) -> Tuple[List[float], List[float]]:
"""Trains the model for some number of epochs.
@@ -104,6 +107,12 @@
- validation score (for ensembles, factored per member)
- best validation score so far
batch_callback (callable, optional): if provided, this function will be called
for every batch with the output of ``model.update()`` (during training),
and ``model.eval_score()`` (during evaluation). It will be called
with four arguments ``(epoch_index, loss/score, meta, mode)``, where
``mode`` is one of ``"train"`` or ``"eval"``, indicating if the callback
was called during training or evaluation.
Returns:
(tuple of two list(float)): the history of training losses and validation losses.
@@ -117,14 +126,30 @@
epochs_since_update = 0
best_val_score = self.evaluate(eval_dataset)
for epoch in epoch_iter:
if batch_callback:
batch_callback_epoch = functools.partial(batch_callback, epoch)
else:
batch_callback_epoch = None
batch_losses: List[float] = []
for batch in dataset_train:
loss = self.model.update(batch, self.optimizer)
loss_and_maybe_meta = self.model.update(batch, self.optimizer)
if isinstance(loss_and_maybe_meta, tuple):
loss = cast(float, loss_and_maybe_meta[0])
meta = cast(Dict, loss_and_maybe_meta[1])
else:
# TODO remove this if in v0.2.0
warnings.warn(_NO_META_WARNING_MSG)
loss = cast(float, loss_and_maybe_meta)
meta = None
batch_losses.append(loss)
if batch_callback_epoch:
batch_callback_epoch(loss, meta, "train")
total_avg_loss = np.mean(batch_losses).mean().item()
training_losses.append(total_avg_loss)

eval_score = self.evaluate(eval_dataset)
eval_score = self.evaluate(
eval_dataset, batch_callback=batch_callback_epoch
)
val_scores.append(eval_score.mean().item())

maybe_best_weights = self.maybe_get_best_weights(
@@ -171,7 +196,9 @@ def train(
self._train_iteration += 1
return training_losses, val_scores

def evaluate(self, dataset: TransitionIterator) -> torch.Tensor:
def evaluate(
self, dataset: TransitionIterator, batch_callback: Optional[Callable] = None
) -> torch.Tensor:
"""Evaluates the model on the validation dataset.
Iterates over the dataset, one batch at a time, and calls
@@ -180,6 +207,11 @@ def evaluate(self, dataset: TransitionIterator) -> torch.Tensor:
Args:
dataset (TransitionIterator): the transition iterator to use.
batch_callback (callable, optional): if provided, this function will be called
for every batch with the output of ``model.eval_score()`` (the score will
be passed as a float, reduced using mean()). It will be called
with four arguments ``(epoch_index, loss/score, meta, mode)``, where
``mode`` is the string ``"eval"``.
Returns:
(tensor): The average score of the model over the dataset (and for ensembles, per
@@ -190,15 +222,25 @@

batch_scores_list = []
for batch in dataset:
avg_batch_score = self.model.eval_score(batch)
batch_scores_list.append(avg_batch_score)
batch_scores = torch.cat(batch_scores_list, axis=batch_scores_list[0].ndim - 2)
batch_score_and_maybe_meta = self.model.eval_score(batch)
if isinstance(batch_score_and_maybe_meta, tuple):
batch_score = cast(torch.Tensor, batch_score_and_maybe_meta[0])
meta = cast(Dict, batch_score_and_maybe_meta[1])
else:
# TODO remove this "else" in v0.2.0
warnings.warn(_NO_META_WARNING_MSG)
batch_score = cast(torch.Tensor, batch_score_and_maybe_meta)
meta = None
batch_scores_list.append(batch_score)
if batch_callback:
batch_callback(batch_score.mean(), meta, "eval")
batch_scores = torch.cat(batch_scores_list, dim=batch_scores_list[0].ndim - 2)

if isinstance(dataset, BootstrapIterator):
dataset.toggle_bootstrap()

mean_axis = 1 if batch_scores.ndim == 2 else (1, 2)
batch_scores = batch_scores.mean(axis=mean_axis)
batch_scores = batch_scores.mean(dim=mean_axis)

return batch_scores

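The new `batch_callback` hook passed to `ModelTrainer.train` is invoked once per batch with `(epoch_index, loss_or_score, meta, mode)`, as documented above. A sketch of a callback that collects the `grad_norm` entry added by `Model.update` during training; the list-based bookkeeping and the commented-out call at the end are illustrative only:

```python
from typing import Any, Dict, Optional

grad_norms = []  # (epoch, grad_norm) pairs collected across training batches


def log_grad_norm(
    epoch: int, loss_or_score: Any, meta: Optional[Dict[str, Any]], mode: str
) -> None:
    # `meta` can be None for old-style models; `grad_norm` is only added by
    # Model.update(), so it appears only when mode == "train".
    if mode == "train" and meta and "grad_norm" in meta:
        grad_norms.append((epoch, meta["grad_norm"]))


# Later passed to an already-constructed trainer (construction omitted):
# trainer.train(dataset_train, dataset_val, num_epochs=50, batch_callback=log_grad_norm)
```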
