Skip to content
This repository has been archived by the owner on Sep 1, 2024. It is now read-only.

Commit

Permalink
Merge pull request #44 from facebookresearch/test_util
Browse files Browse the repository at this point in the history
Added tests for common utilities
  • Loading branch information
luisenp authored Mar 9, 2021
2 parents e01261e + 3a30b1e commit c7a7425
Show file tree
Hide file tree
Showing 8 changed files with 298 additions and 32 deletions.
10 changes: 10 additions & 0 deletions docs/models.rst
Original file line number Diff line number Diff line change
@@ -1,5 +1,15 @@
Models module
=============
This module provides implementations of common model architectures used in model-based RL,
including probabilistic and deterministic ensembles. All models in the library derive from
class :class:`mbrl.models.Model`. We provide a generic ensemble implementation,
:class:`mbrl.models.BasicEnsemble`, which can be used to produce epistemic uncertainty estimates
for any subclass of ``Model``. For efficiency reasons, some specific model implementations
also provide their own ensemble implementations, without having to rely on ``BasicEnsemble``.
One such model is :class:`mbrl.models.GaussianMLP`, which can be used as a single model or as
an ensemble. Additionally, it can be used either as a deterministic model
trained with MSE loss, or as a parameterized Gaussian with mean and log-variance outputs, trained
with negative log-likelihood.

.. automodule:: mbrl.models
:members:
Expand Down
30 changes: 21 additions & 9 deletions mbrl/models/base_models.py
Original file line number Diff line number Diff line change
Expand Up @@ -111,7 +111,6 @@ def __init__(
self.in_size = in_size
self.out_size = out_size
self.device = torch.device(device)
self.is_ensemble = False
self.to(device)

def forward(self, x: torch.Tensor, **kwargs) -> Tuple[torch.Tensor, torch.Tensor]:
Expand Down Expand Up @@ -141,7 +140,6 @@ def loss(
Returns:
(tensor): a loss tensor.
"""
pass

@abc.abstractmethod
def eval_score(self, model_in: torch.Tensor, target: torch.Tensor) -> torch.Tensor:
Expand All @@ -166,23 +164,35 @@ def eval_score(self, model_in: torch.Tensor, target: torch.Tensor) -> torch.Tens
Returns:
(tensor): a non-reduced tensor score.
"""
pass

@abc.abstractmethod
def save(self, path: str):
"""Saves the model to the given path. """
pass

@abc.abstractmethod
def load(self, path: str):
"""Loads the model from the given path."""

@abc.abstractmethod
def _is_deterministic_impl(self):
# Subclasses must specify if the model is deterministic or not
pass

@abc.abstractmethod
def is_deterministic(self):
"""Whether the model produces logvar predictions or not."""
def _is_ensemble_impl(self):
# Subclasses must specify if they are ensembles or not
pass

@property
def is_deterministic(self):
"""Whether the model is deterministic or not."""
return self._is_deterministic_impl()

@property
def is_ensemble(self):
"""Whether the model is an ensemble or not."""
return self._is_ensemble_impl()

def update(
self,
model_in: torch.Tensor,
Expand Down Expand Up @@ -277,7 +287,6 @@ def __init__(
):
super().__init__(in_size, out_size, device)
self.members = []
self.is_ensemble = True
for i in range(ensemble_size):
model = hydra.utils.instantiate(member_cfg)
self.members.append(model)
Expand Down Expand Up @@ -483,8 +492,11 @@ def load(self, path: str):
state_dict = torch.load(path)
self.load_state_dict(state_dict)

def is_deterministic(self):
return self.members[0].is_deterministic()
def _is_ensemble_impl(self):
return True

def _is_deterministic_impl(self):
return self.members[0].is_deterministic

def sample_propagation_indices(
self, batch_size: int, rng: torch.Generator
Expand Down
30 changes: 17 additions & 13 deletions mbrl/models/gaussian_mlp.py
Original file line number Diff line number Diff line change
Expand Up @@ -55,8 +55,9 @@ def __init__(
activation_cls = nn.SiLU if use_silu else nn.ReLU

self.num_members = None
self._is_ensemble = False
if ensemble_size > 1:
self.is_ensemble = True
self._is_ensemble = True
self.num_members = ensemble_size

def create_linear_layer(l_in, l_out):
Expand All @@ -77,13 +78,13 @@ def create_linear_layer(l_in, l_out):
)
self.hidden_layers = nn.Sequential(*hidden_layers)

self.deterministic = deterministic
self._deterministic = deterministic
if deterministic:
self.mean_and_logvar = create_linear_layer(hid_size, out_size)
else:
self.mean_and_logvar = create_linear_layer(hid_size, 2 * out_size)
logvar_shape = (
(self.num_members, 1, out_size) if self.is_ensemble else (1, out_size)
(self.num_members, 1, out_size) if self._is_ensemble else (1, out_size)
)
self.min_logvar = nn.Parameter(
-10 * torch.ones(logvar_shape, requires_grad=True)
Expand Down Expand Up @@ -116,12 +117,12 @@ def _default_forward(
x = self.hidden_layers(x)
mean_and_logvar = self.mean_and_logvar(x)
self._maybe_toggle_layers_use_only_elite(only_elite)
if self.deterministic:
if self._deterministic:
return mean_and_logvar, None
else:
mean = mean_and_logvar[..., : self.out_size]
logvar = mean_and_logvar[..., self.out_size :]
if self.is_ensemble and self.elite_models is not None:
if self._is_ensemble and self.elite_models is not None:
model_idx = self.elite_models if only_elite else range(self.num_members)
assert not only_elite or (len(model_idx) != self.num_members), (
"If elite size == self.num_members, it's better "
Expand Down Expand Up @@ -202,7 +203,7 @@ def forward( # type: ignore
) -> Tuple[torch.Tensor, torch.Tensor]:
"""Computes mean and logvar predictions for the given input.
When ``self.is_ensemble = True``, the model supports uncertainty propagation options
When ``self._is_ensemble = True``, the model supports uncertainty propagation options
that can be used to aggregate the outputs of the different models in the ensemble.
Valid propagation options are:
Expand Down Expand Up @@ -250,7 +251,7 @@ def forward( # type: ignore
the output to :func:`mbrl.math.propagate`.
"""
if self.is_ensemble:
if self._is_ensemble:
return self._forward_ensemble(
x,
propagation=propagation,
Expand All @@ -261,7 +262,7 @@ def forward( # type: ignore

def _mse_loss(self, model_in: torch.Tensor, target: torch.Tensor) -> torch.Tensor:
pred_mean, _ = self.forward(model_in)
if self.is_ensemble:
if self._is_ensemble:
assert model_in.ndim == 3 and target.ndim == 3
total_loss: torch.Tensor = 0.0
for i in range(self.num_members):
Expand All @@ -274,7 +275,7 @@ def _mse_loss(self, model_in: torch.Tensor, target: torch.Tensor) -> torch.Tenso

def _nll_loss(self, model_in: torch.Tensor, target: torch.Tensor) -> torch.Tensor:
pred_mean, pred_logvar = self.forward(model_in)
if self.is_ensemble:
if self._is_ensemble:
assert model_in.ndim == 3 and target.ndim == 3
nll: torch.Tensor = 0.0
for i in range(self.num_members):
Expand Down Expand Up @@ -310,7 +311,7 @@ def loss(self, model_in: torch.Tensor, target: torch.Tensor) -> torch.Tensor:
the model over the given input/target. If the model is an ensemble, returns
the average over all models.
"""
if self.deterministic:
if self._deterministic:
return self._mse_loss(model_in, target)
else:
return self._nll_loss(model_in, target)
Expand All @@ -334,7 +335,7 @@ def eval_score(self, model_in: torch.Tensor, target: torch.Tensor) -> torch.Tens
assert model_in.ndim == 2 and target.ndim == 2
with torch.no_grad():
pred_mean, _ = self.forward(model_in)
if self.is_ensemble:
if self._is_ensemble:
target = target.repeat((self.num_members, 1, 1))
return F.mse_loss(pred_mean, target, reduction="none")

Expand All @@ -344,8 +345,11 @@ def save(self, path: str):
def load(self, path: str):
self.load_state_dict(torch.load(path))

def is_deterministic(self):
return self.deterministic
def _is_deterministic_impl(self):
return self._deterministic

def _is_ensemble_impl(self):
return self._is_ensemble

def __len__(self):
return self.num_members
Expand Down
2 changes: 1 addition & 1 deletion mbrl/models/model_env.py
Original file line number Diff line number Diff line change
Expand Up @@ -186,7 +186,7 @@ def evaluate_action_sequences(
actions_for_step, num_particles, dim=0
)
_, rewards, _, _ = self.step(
action_batch, sample=not self.dynamics_model.model.is_deterministic()
action_batch, sample=not self.dynamics_model.model.is_deterministic
)
total_rewards += rewards

Expand Down
5 changes: 3 additions & 2 deletions mbrl/util/common.py
Original file line number Diff line number Diff line change
Expand Up @@ -14,8 +14,8 @@

def create_dynamics_model(
cfg: Union[omegaconf.ListConfig, omegaconf.DictConfig],
obs_shape: Tuple[int],
act_shape: Tuple[int],
obs_shape: Tuple[int, ...],
act_shape: Tuple[int, ...],
model_dir: Optional[Union[str, pathlib.Path]] = None,
):
"""Creates a dynamics model from a given configuration.
Expand All @@ -41,6 +41,7 @@ def create_dynamics_model(
-overrides
-no_delta_list (list[int], optional): to be passed to the dynamics model wrapper
-obs_process_fn (str, optional): a Python function to pre-process observations
-num_elites (int, optional): number of elite members for ensembles
If ``cfg.dynamics_model.model.in_size`` is not provided, it will be automatically set to
``obs_shape[0] + act_shape[0]``. If ``cfg.dynamics_model.model.out_size`` is not provided,
Expand Down
34 changes: 28 additions & 6 deletions notebooks/pets_example.ipynb

Large diffs are not rendered by default.

Loading

0 comments on commit c7a7425

Please sign in to comment.