This repository has been archived by the owner on Sep 1, 2024. It is now read-only.

Commit

Merge pull request #111 from facebookresearch/lep.double_normalization
Add option to do normalization using double precision for 1-D models
luisenp authored Jul 24, 2021
2 parents 3b8505a + 14f75e1 commit 2af0cf1
Showing 12 changed files with 99 additions and 26 deletions.
10 changes: 10 additions & 0 deletions CHANGELOG.md
@@ -1,5 +1,15 @@
# Changelog

## v0.1.3
- Methods `loss`, `eval_score` and `update` of the `Model` class now return a
tuple of loss/score and metadata. The old return format is still supported,
but it will be deprecated in v0.2.0.
- `ModelTrainer` now accepts a callback that will be called after every batch
both during training and evaluation.
- `Normalizer` in `util.math` can now operate using double precision. Utilities
now allow specifying, via the Hydra config, whether the replay buffer and
normalizer should use double or float precision.

## v0.1.2
- Multiple bug fixes
- Added a training browser to compare results of multiple runs
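The double-precision normalization added in v0.1.3 is controlled by a single new config key, `algorithm.normalize_double_precision`, which the diffs below add to the MBPO and PETS configs and read in the training scripts. A minimal sketch of the key being set and read through omegaconf (the config contents here are illustrative, not the full algorithm configs):

```python
import omegaconf

# Illustrative config; only the keys relevant to this commit are shown.
cfg = omegaconf.OmegaConf.create(
    {
        "algorithm": {
            "normalize": True,
            # New in v0.1.3: keep normalizer statistics in double precision.
            "normalize_double_precision": True,
        }
    }
)

# The training scripts read the flag with a default of False, so configs
# written before this change keep the old float32 behavior.
use_double = cfg.algorithm.get("normalize_double_precision", False)
print(use_double)  # True
```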
2 changes: 1 addition & 1 deletion mbrl/__init__.py
@@ -2,4 +2,4 @@
#
# This source code is licensed under the MIT license found in the
# LICENSE file in the root directory of this source tree.
__version__ = "0.1.2"
__version__ = "0.1.3"
11 changes: 9 additions & 2 deletions mbrl/algorithms/mbpo.py
@@ -142,9 +142,16 @@ def train(

# -------------- Create initial overrides. dataset --------------
dynamics_model = mbrl.util.common.create_one_dim_tr_model(cfg, obs_shape, act_shape)

use_double_dtype = cfg.algorithm.get("normalize_double_precision", False)
dtype = np.double if use_double_dtype else np.float32
replay_buffer = mbrl.util.common.create_replay_buffer(
cfg, obs_shape, act_shape, rng=rng
cfg,
obs_shape,
act_shape,
rng=rng,
obs_type=dtype,
action_type=dtype,
reward_type=dtype,
)
random_explore = cfg.algorithm.random_initial_explore
mbrl.util.common.rollout_agent_trajectories(
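A note on the pattern above, which is applied identically in `pets.py` below: `cfg.algorithm.get("normalize_double_precision", False)` falls back to `False` when the key is absent, and the resulting NumPy dtype is used for the replay buffer's observation, action, and reward arrays, so the normalizer is fed double-precision data. A tiny sketch with the flag value hard-coded as a stand-in for the config lookup:

```python
import numpy as np

# Stand-in for: cfg.algorithm.get("normalize_double_precision", False)
use_double_dtype = True

dtype = np.double if use_double_dtype else np.float32
rewards = np.empty(100, dtype=dtype)  # buffer arrays are allocated with this dtype
print(rewards.dtype)  # float64 -- np.double is an alias for np.float64
```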
10 changes: 9 additions & 1 deletion mbrl/algorithms/pets.py
@@ -53,8 +53,16 @@ def train(

# -------- Create and populate initial env dataset --------
dynamics_model = mbrl.util.common.create_one_dim_tr_model(cfg, obs_shape, act_shape)
use_double_dtype = cfg.algorithm.get("normalize_double_precision", False)
dtype = np.double if use_double_dtype else np.float32
replay_buffer = mbrl.util.common.create_replay_buffer(
cfg, obs_shape, act_shape, rng=rng
cfg,
obs_shape,
act_shape,
rng=rng,
obs_type=dtype,
action_type=dtype,
reward_type=dtype,
)
mbrl.util.common.rollout_agent_trajectories(
env,
2 changes: 1 addition & 1 deletion mbrl/diagnostics/finetune_model_with_controller.py
@@ -43,7 +43,7 @@ def __init__(
self.cfg,
self.env.observation_space.shape,
self.env.action_space.shape,
None if new_model else model_dir,
load_dir=None if new_model else model_dir,
)
self.rng = np.random.default_rng(seed)

1 change: 1 addition & 0 deletions mbrl/examples/conf/algorithm/mbpo.yaml
@@ -2,6 +2,7 @@
name: "mbpo"

normalize: true
normalize_double_precision: true
target_is_delta: true
learned_rewards: true
freq_train_model: ${overrides.freq_train_model}
1 change: 1 addition & 0 deletions mbrl/examples/conf/algorithm/pets.yaml
@@ -22,6 +22,7 @@ optimizer:
device: ${device}

normalize: true
normalize_double_precision: true
target_is_delta: true
initial_exploration_steps: ${overrides.trial_length}
freq_train_model: ${overrides.freq_train_model}
27 changes: 18 additions & 9 deletions mbrl/models/one_dim_tr_model.py
@@ -52,6 +52,8 @@ class OneDTransitionRewardModel(Model):
which will be used every time the model is called using the methods in this
class. To update the normalizer statistics, the user needs to call
:meth:`update_normalizer` before using the model. Defaults to ``False``.
normalize_double_precision (bool): if ``True``, the normalizer will work with
double precision.
learned_rewards (bool): if ``True``, the wrapper considers the last output of the model
to correspond to rewards predictions, and will use it to construct training
targets for the model and when returning model predictions. Defaults to ``True``.
@@ -74,6 +76,7 @@ def __init__(
model: Model,
target_is_delta: bool = True,
normalize: bool = False,
normalize_double_precision: bool = False,
learned_rewards: bool = True,
obs_process_fn: Optional[mbrl.types.ObsProcessFnType] = None,
no_delta_list: Optional[List[int]] = None,
@@ -84,7 +87,9 @@ def __init__(
self.input_normalizer: Optional[mbrl.util.math.Normalizer] = None
if normalize:
self.input_normalizer = mbrl.util.math.Normalizer(
self.model.in_size, self.model.device
self.model.in_size,
self.model.device,
dtype=torch.double if normalize_double_precision else torch.float,
)
self.device = self.model.device
self.learned_rewards = learned_rewards
@@ -109,15 +114,15 @@ def _get_model_input_from_np(
model_in_np = np.concatenate([obs, action], axis=obs.ndim - 1)
if self.input_normalizer:
# Normalizer lives on device
return self.input_normalizer.normalize(model_in_np)
return self.input_normalizer.normalize(model_in_np).float().to(device)
return torch.from_numpy(model_in_np).to(device)

def _get_model_input_from_tensors(self, obs: torch.Tensor, action: torch.Tensor):
if self.obs_process_fn:
obs = self.obs_process_fn(obs)
model_in = torch.cat([obs, action], axis=obs.ndim - 1)
if self.input_normalizer:
model_in = self.input_normalizer.normalize(model_in)
model_in = self.input_normalizer.normalize(model_in).float()
return model_in

def _get_model_input_and_target_from_batch(
@@ -133,14 +138,18 @@ def _get_model_input_and_target_from_batch(

model_in = self._get_model_input_from_np(obs, action, self.device)
if self.learned_rewards:
target = torch.from_numpy(
np.concatenate(
[target_obs, np.expand_dims(reward, axis=reward.ndim)],
axis=obs.ndim - 1,
target = (
torch.from_numpy(
np.concatenate(
[target_obs, np.expand_dims(reward, axis=reward.ndim)],
axis=obs.ndim - 1,
)
)
).to(self.device)
.float()
.to(self.device)
)
else:
target = torch.from_numpy(target_obs).to(self.device)
target = torch.from_numpy(target_obs).float().to(self.device)
return model_in, target

def forward(self, x: torch.Tensor, **kwargs) -> Tuple[torch.Tensor, ...]:
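The `.float()` casts introduced above are the counterpart of the double-precision normalizer: once the statistics are stored in float64, PyTorch's type promotion makes the normalized tensor come out as float64, so it is cast back to float32 before reaching the (float32) model. A small sketch of that promotion-and-cast behavior, with illustrative shapes:

```python
import torch

mean = torch.zeros(1, 3, dtype=torch.double)  # normalizer statistics in float64
std = torch.ones(1, 3, dtype=torch.double)

x = torch.randn(4, 3)               # float32 model input batch
normalized = (x - mean) / std       # type promotion yields a float64 tensor
model_in = normalized.float()       # cast back before feeding the float32 model

print(normalized.dtype, model_in.dtype)  # torch.float64 torch.float32
```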
14 changes: 13 additions & 1 deletion mbrl/util/common.py
@@ -3,7 +3,7 @@
# This source code is licensed under the MIT license found in the
# LICENSE file in the root directory of this source tree.
import pathlib
from typing import Callable, Dict, List, Optional, Sequence, Tuple, Union
from typing import Callable, Dict, List, Optional, Sequence, Tuple, Type, Union

import gym.wrappers
import hydra
@@ -96,6 +96,9 @@ def create_one_dim_tr_model(
model,
target_is_delta=cfg.algorithm.target_is_delta,
normalize=cfg.algorithm.normalize,
normalize_double_precision=cfg.algorithm.get(
"normalize_double_precision", False
),
learned_rewards=cfg.algorithm.learned_rewards,
obs_process_fn=obs_process_fn,
no_delta_list=cfg.overrides.get("no_delta_list", None),
@@ -131,6 +134,9 @@ def create_replay_buffer(
cfg: omegaconf.DictConfig,
obs_shape: Sequence[int],
act_shape: Sequence[int],
obs_type: Type = np.float32,
action_type: Type = np.float32,
reward_type: Type = np.float32,
load_dir: Optional[Union[str, pathlib.Path]] = None,
collect_trajectories: bool = False,
rng: Optional[np.random.Generator] = None,
@@ -155,6 +161,9 @@
cfg (omegaconf.DictConfig): the configuration to use.
obs_shape (Sequence of ints): the shape of observation arrays.
act_shape (Sequence of ints): the shape of action arrays.
obs_type (type): the data type of the observations (defaults to np.float32).
action_type (type): the data type of the actions (defaults to np.float32).
reward_type (type): the data type of the rewards (defaults to np.float32).
load_dir (optional str or pathlib.Path): if provided, the function will attempt to
populate the buffers from "load_dir/replay_buffer.npz".
collect_trajectories (bool, optional): if ``True`` sets the replay buffers to collect
@@ -183,6 +192,9 @@
dataset_size,
obs_shape,
act_shape,
obs_type=obs_type,
action_type=action_type,
reward_type=reward_type,
rng=rng,
max_trajectory_length=maybe_max_trajectory_len,
)
9 changes: 6 additions & 3 deletions mbrl/util/math.py
@@ -98,13 +98,15 @@ class Normalizer:
Args:
in_size (int): the size of the data that will be normalized.
device (torch.device): the device in which the data will reside.
dtype (torch.dtype): the data type to use for the normalizer.
"""

_STATS_FNAME = "env_stats.pickle"

def __init__(self, in_size: int, device: torch.device):
self.mean = torch.zeros((1, in_size), device=device)
self.std = torch.ones((1, in_size), device=device)
def __init__(self, in_size: int, device: torch.device, dtype=torch.float32):
self.mean = torch.zeros((1, in_size), device=device, dtype=dtype)
self.std = torch.ones((1, in_size), device=device, dtype=dtype)
self.eps = 1e-12 if dtype == torch.double else 1e-5
self.device = device

def update_stats(self, data: mbrl.types.TensorType):
@@ -120,6 +122,7 @@ def update_stats(self, data: mbrl.types.TensorType):
data = torch.from_numpy(data).to(self.device)
self.mean = data.mean(0, keepdim=True)
self.std = data.std(0, keepdim=True)
self.std[self.std < self.eps] = 1.0

def normalize(self, val: Union[float, mbrl.types.TensorType]) -> torch.Tensor:
"""Normalizes the value according to the stored statistics.
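A short usage sketch of the updated `Normalizer` (the data values are illustrative): the statistics are initialized in float64 when `dtype=torch.double`, the replay-buffer changes in this commit keep the training data in double precision so they stay that way after `update_stats`, and the new epsilon check clamps near-zero standard deviations to 1.0 so constant input dimensions cannot blow up the normalized values:

```python
import numpy as np
import torch

from mbrl.util.math import Normalizer

normalizer = Normalizer(in_size=2, device=torch.device("cpu"), dtype=torch.double)

data = np.array([[1.0, 5.0], [1.0, 7.0], [1.0, 9.0]])  # first column is constant
normalizer.update_stats(data)

print(normalizer.mean.dtype)  # torch.float64
print(normalizer.std)         # std of the constant first column was clamped to 1.0
```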
10 changes: 6 additions & 4 deletions mbrl/util/replay_buffer.py
@@ -4,7 +4,7 @@
# LICENSE file in the root directory of this source tree.
import pathlib
import warnings
from typing import List, Optional, Sequence, Sized, Tuple, Union
from typing import List, Optional, Sequence, Sized, Tuple, Type, Union

import numpy as np

@@ -303,6 +303,7 @@ class ReplayBuffer:
action_shape (Sequence of ints): the shape of the actions to store.
obs_type (type): the data type of the observations (defaults to np.float32).
action_type (type): the data type of the actions (defaults to np.float32).
reward_type (type): the data type of the rewards (defaults to np.float32).
rng (np.random.Generator, optional): a random number generator when sampling
batches. If None (default value), a new default generator will be used.
max_trajectory_length (int, optional): if given, indicates that trajectory
@@ -321,8 +322,9 @@ def __init__(
capacity: int,
obs_shape: Sequence[int],
action_shape: Sequence[int],
obs_type=np.float32,
action_type=np.float32,
obs_type: Type = np.float32,
action_type: Type = np.float32,
reward_type: Type = np.float32,
rng: Optional[np.random.Generator] = None,
max_trajectory_length: Optional[int] = None,
):
@@ -338,7 +340,7 @@ def __init__(
self.obs = np.empty((capacity, *obs_shape), dtype=obs_type)
self.next_obs = np.empty((capacity, *obs_shape), dtype=obs_type)
self.action = np.empty((capacity, *action_shape), dtype=action_type)
self.reward = np.empty(capacity, dtype=np.float32)
self.reward = np.empty(capacity, dtype=reward_type)
self.done = np.empty(capacity, dtype=bool)

if rng is None:
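A minimal sketch of the new `reward_type` argument (capacity and shapes are small illustrative values): the reward array used to be hard-coded to float32, while observations and actions were already configurable, so this change lets all three arrays share one precision:

```python
import numpy as np

from mbrl.util.replay_buffer import ReplayBuffer

buffer = ReplayBuffer(
    capacity=100,
    obs_shape=(3,),
    action_shape=(1,),
    obs_type=np.double,
    action_type=np.double,
    reward_type=np.double,  # new in this commit; previously fixed to np.float32
)

print(buffer.obs.dtype, buffer.action.dtype, buffer.reward.dtype)  # all float64
```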
28 changes: 24 additions & 4 deletions tests/core/test_common_utils.py
@@ -61,6 +61,8 @@ def test_create_one_dim_tr_model():
assert dynamics_model.model.x == 1 and dynamics_model.model.y == 2
assert dynamics_model.num_elites is None
assert dynamics_model.no_delta_list == []
# default when no normalization type is given is float
assert dynamics_model.input_normalizer.mean.dtype == torch.float32

# Check given input/output sizes, overrides active, and no learned rewards option
cfg.dynamics_model.model.in_size = 11
@@ -77,6 +79,14 @@ def test_create_one_dim_tr_model():
assert dynamics_model.no_delta_list == [0]
assert dynamics_model.obs_process_fn == mock_obs_func

# Test normalization option
for double_norm in [True, False]:
cfg_dict["algorithm"]["normalize_double_precision"] = double_norm
cfg = omegaconf.OmegaConf.create(cfg_dict)
dynamics_model = utils.create_one_dim_tr_model(cfg, obs_shape, act_shape)
dtype = torch.double if double_norm else torch.float32
assert dynamics_model.input_normalizer.mean.dtype == dtype


def test_create_replay_buffer():
trial_length = 20
@@ -104,10 +114,20 @@ def _check_shapes(how_many):
_check_shapes(num_trials * trial_length)

# Now add a training bootstrap and override the dataset size
cfg_dict["algorithm"]["dataset_size"] = 1500
cfg = omegaconf.OmegaConf.create(cfg_dict)
buffer = utils.create_replay_buffer(cfg, obs_shape, act_shape)
_check_shapes(1500)
for dtype in [np.float32, np.double]:
cfg_dict["algorithm"]["dataset_size"] = 1500
cfg = omegaconf.OmegaConf.create(cfg_dict)
buffer = utils.create_replay_buffer(
cfg,
obs_shape,
act_shape,
obs_type=dtype,
action_type=dtype,
reward_type=dtype,
)
for array in [buffer.obs, buffer.action, buffer.reward]:
assert array.dtype == dtype
_check_shapes(1500)


class MockModelEnv:
