Updated basic APG algorithm (#476)
Updates APG to learn useful policies; see google-deepmind/mujoco#1601
Andrew-Luo1 authored Apr 18, 2024
1 parent 2329ae7 commit b45760c
Showing 5 changed files with 108 additions and 46 deletions.
10 changes: 7 additions & 3 deletions brax/training/agents/apg/networks.py
@@ -22,6 +22,7 @@
from brax.training.types import PRNGKey
import flax
from flax import linen
from flax.linen.initializers import orthogonal


@flax.struct.dataclass
@@ -55,15 +56,18 @@ def make_apg_networks(
preprocess_observations_fn: types.PreprocessObservationFn = types
.identity_observation_preprocessor,
hidden_layer_sizes: Sequence[int] = (32,) * 4,
activation: networks.ActivationFn = linen.swish) -> APGNetworks:
activation: networks.ActivationFn = linen.elu,
layer_norm: bool = True) -> APGNetworks:
"""Make APG networks."""
parametric_action_distribution = distribution.NormalTanhDistribution(
event_size=action_size)
event_size=action_size, var_scale=0.1)
policy_network = networks.make_policy_network(
parametric_action_distribution.param_size,
observation_size,
preprocess_observations_fn=preprocess_observations_fn,
hidden_layer_sizes=hidden_layer_sizes, activation=activation)
hidden_layer_sizes=hidden_layer_sizes, activation=activation,
kernel_init = orthogonal(0.01),
layer_norm=layer_norm)
return APGNetworks(
policy_network=policy_network,
parametric_action_distribution=parametric_action_distribution)
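
A minimal sketch (not part of the commit) of how the new defaults above surface when building APG networks: ELU plus LayerNorm hidden layers, orthogonal(0.01) kernels, and var_scale=0.1 on the action distribution. The observation and action sizes are hypothetical, and None is passed for the preprocessor params only because the default identity observation preprocessor ignores them.

# Sketch only; sizes and keys are illustrative.
import jax
import jax.numpy as jnp
from brax.training.agents.apg import networks as apg_networks

obs_size, act_size = 27, 8  # hypothetical environment sizes
apg_network = apg_networks.make_apg_networks(obs_size, act_size)
policy_params = apg_network.policy_network.init(jax.random.PRNGKey(0))
make_policy = apg_networks.make_inference_fn(apg_network)
# With the default identity preprocessor the normalizer params are unused,
# so None is a placeholder here.
policy = make_policy((None, policy_params), deterministic=True)
action, _ = policy(jnp.zeros((obs_size,)), jax.random.PRNGKey(1))
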
118 changes: 83 additions & 35 deletions brax/training/agents/apg/train.py
@@ -22,6 +22,7 @@
from brax import base
from brax import envs
from brax.training import acting
from brax.training import gradients
from brax.training import pmap
from brax.training import types
from brax.training.acme import running_statistics
@@ -56,15 +57,20 @@ def _unpmap(v):
def train(
environment: Union[envs_v1.Env, envs.Env],
episode_length: int,
action_repeat: int = 1,
policy_updates: int,
horizon_length: int = 32,
num_envs: int = 1,
num_evals: int = 1,
action_repeat: int = 1,
max_devices_per_host: Optional[int] = None,
num_eval_envs: int = 128,
learning_rate: float = 1e-4,
adam_b: list = [0.7, 0.95],
use_schedule: bool = True,
use_float64: bool = False,
schedule_decay: float = 0.997,
seed: int = 0,
truncation_length: Optional[int] = None,
max_gradient_norm: float = 1e9,
num_evals: int = 1,
normalize_observations: bool = False,
deterministic_eval: bool = False,
network_factory: types.NetworkFactory[
@@ -91,10 +97,9 @@ def train(
process_id, local_device_count, local_devices_to_use)
device_count = local_devices_to_use * process_count

if truncation_length is not None:
assert truncation_length > 0

num_updates = policy_updates
num_evals_after_init = max(num_evals - 1, 1)
updates_per_epoch = jnp.round(num_updates / (num_evals_after_init))

assert num_envs % device_count == 0
env = environment
@@ -120,6 +125,9 @@
action_repeat=action_repeat,
randomization_fn=v_randomiation_fn,
)

reset_fn = jax.jit(jax.vmap(env.reset))
step_fn = jax.jit(jax.vmap(env.step))

normalize = lambda x, y: x
if normalize_observations:
@@ -129,8 +137,24 @@
env.action_size,
preprocess_observations_fn=normalize)
make_policy = apg_networks.make_inference_fn(apg_network)

if use_schedule:
learning_rate = optax.exponential_decay(
init_value=learning_rate,
transition_steps=1,
decay_rate=schedule_decay
)

optimizer = optax.chain(
optax.clip(1.0),
optax.adam(learning_rate=learning_rate, b1=adam_b[0], b2=adam_b[1])
)

optimizer = optax.adam(learning_rate=learning_rate)
def scramble_times(state, key):
state.info['steps'] = jnp.round(
jax.random.uniform(key, (local_devices_to_use, num_envs,),
maxval=episode_length))
return state

def env_step(
carry: Tuple[Union[envs.State, envs_v1.State], PRNGKey],
@@ -141,23 +165,17 @@ def env_step(
key, key_sample = jax.random.split(key)
actions = policy(env_state.obs, key_sample)[0]
nstate = env.step(env_state, actions)
if truncation_length is not None:
nstate = jax.lax.cond(
jnp.mod(step_index + 1, truncation_length) == 0.,
jax.lax.stop_gradient, lambda x: x, nstate)

return (nstate, key), (nstate.reward, env_state.obs)

def loss(policy_params, normalizer_params, key):
key_reset, key_scan = jax.random.split(key)
env_state = env.reset(
jax.random.split(key_reset, num_envs // process_count))
def loss(policy_params, normalizer_params, env_state, key):
f = functools.partial(
env_step, policy=make_policy((normalizer_params, policy_params)))
(rewards,
obs) = jax.lax.scan(f, (env_state, key_scan),
(jnp.array(range(episode_length // action_repeat))))[1]
return -jnp.mean(rewards), obs
(state_h, _), (rewards,
obs) = jax.lax.scan(f, (env_state, key),
(jnp.arange(horizon_length // action_repeat)))

return -jnp.mean(rewards), (obs, state_h)

loss_grad = jax.grad(loss, has_aux=True)

@@ -168,62 +186,83 @@ def clip_by_global_norm(updates):
lambda t: jnp.where(trigger, t, (t / g_norm) * max_gradient_norm),
updates)

def training_epoch(training_state: TrainingState, key: PRNGKey):
def minibatch_step(
carry, epoch_step_index: int):
(optimizer_state, normalizer_params,
policy_params, key, state) = carry

key, key_grad = jax.random.split(key)
grad, obs = loss_grad(training_state.policy_params,
training_state.normalizer_params, key_grad)
grad, (obs, state_h) = loss_grad(policy_params,
normalizer_params,
state,
key_grad)

grad = clip_by_global_norm(grad)
grad = jax.lax.pmean(grad, axis_name='i')
params_update, optimizer_state = optimizer.update(
grad, training_state.optimizer_state)
policy_params = optax.apply_updates(training_state.policy_params,
grad, optimizer_state)
policy_params = optax.apply_updates(policy_params,
params_update)

normalizer_params = running_statistics.update(
training_state.normalizer_params, obs, pmap_axis_name=_PMAP_AXIS_NAME)
normalizer_params, obs, pmap_axis_name=_PMAP_AXIS_NAME)

metrics = {
'grad_norm': optax.global_norm(grad),
'params_norm': optax.global_norm(policy_params)
}

return (optimizer_state, normalizer_params, policy_params, key, state_h), metrics

def training_epoch(training_state: TrainingState, env_state: Union[envs.State, envs_v1.State], key: PRNGKey):

(optimizer_state, normalizer_params,
policy_params, key, state_h), metrics = jax.lax.scan(
minibatch_step,
(training_state.optimizer_state, training_state.normalizer_params,
training_state.policy_params, key, env_state),
jnp.arange(updates_per_epoch))

return TrainingState(
optimizer_state=optimizer_state,
normalizer_params=normalizer_params,
policy_params=policy_params), metrics
policy_params=policy_params), state_h, metrics, key

training_epoch = jax.pmap(training_epoch, axis_name=_PMAP_AXIS_NAME)

training_walltime = 0

# Note that this is NOT a pure jittable method.
def training_epoch_with_timing(training_state: TrainingState,
env_state: Union[envs.State, envs_v1.State],
key: PRNGKey) -> Tuple[TrainingState, Metrics]:
nonlocal training_walltime
t = time.time()
(training_state, metrics) = training_epoch(training_state, key)
(training_state, env_state, metrics, key) = training_epoch(training_state, env_state, key)
metrics = jax.tree_util.tree_map(jnp.mean, metrics)
jax.tree_util.tree_map(lambda x: x.block_until_ready(), metrics)

epoch_training_time = time.time() - t
training_walltime += epoch_training_time
sps = (episode_length * num_envs) / epoch_training_time
sps = (updates_per_epoch * num_envs * horizon_length) / epoch_training_time
metrics = {
'training/sps': sps,
'training/walltime': training_walltime,
**{f'training/{name}': value for name, value in metrics.items()}
}
return training_state, metrics # pytype: disable=bad-return-type # py311-upgrade
return training_state, env_state, metrics, key # pytype: disable=bad-return-type # py311-upgrade

# The network key should be global, so that networks are initialized the same
# way for different processes.
policy_params = apg_network.policy_network.init(global_key)
del global_key

dtype = 'float64' if use_float64 else 'float32'
training_state = TrainingState(
optimizer_state=optimizer.init(policy_params),
policy_params=policy_params,
normalizer_params=running_statistics.init_state(
specs.Array((env.observation_size,), jnp.dtype('float32'))))
specs.Array((env.observation_size,), jnp.dtype(dtype))))
training_state = jax.device_put_replicated(
training_state,
jax.local_devices()[:local_devices_to_use])
@@ -251,6 +290,7 @@ def training_epoch_with_timing(training_state: TrainingState,

# Run initial eval
metrics = {}

if process_id == 0 and num_evals > 1:
metrics = evaluator.run_evaluation(
_unpmap(
@@ -259,14 +299,21 @@ def training_epoch_with_timing(training_state: TrainingState,
logging.info(metrics)
progress_fn(0, metrics)

init_key, scramble_key, local_key = jax.random.split(local_key, 3)
init_key = jax.random.split(init_key, (local_devices_to_use, num_envs // process_count))
env_state = reset_fn(init_key)
env_state = scramble_times(env_state, scramble_key)
env_state = step_fn(env_state, jnp.zeros((local_devices_to_use, num_envs // process_count,
env.action_size))) # Prevent recompilation on the second epoch

epoch_key, local_key = jax.random.split(local_key)
epoch_key = jax.random.split(epoch_key, local_devices_to_use)

for it in range(num_evals_after_init):
logging.info('starting iteration %s %s', it, time.time() - xt)

# optimization
epoch_key, local_key = jax.random.split(local_key)
epoch_keys = jax.random.split(epoch_key, local_devices_to_use)
(training_state,
training_metrics) = training_epoch_with_timing(training_state, epoch_keys)
(training_state, env_state,
training_metrics, epoch_key) = training_epoch_with_timing(training_state, env_state, epoch_key)

if process_id == 0:
# Run evals.
@@ -284,3 +331,4 @@ def training_epoch_with_timing(training_state: TrainingState,
(training_state.normalizer_params, training_state.policy_params))
pmap.synchronize_hosts()
return (make_policy, params, metrics)
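
For orientation, a hedged usage sketch of the updated entry point; the environment choice and hyperparameter values are illustrative, not tuned recommendations. Per the diff, training now performs policy_updates gradient updates, each backpropagating through a rollout of horizon_length // action_repeat steps, and it carries a persistent vmapped env_state across updates (with per-env step counters scrambled by scramble_times so truncations are staggered) rather than resetting inside the loss. With use_schedule=True the learning rate decays roughly as learning_rate * schedule_decay**update.

# Usage sketch only; values are illustrative.
from brax import envs
from brax.training.agents.apg import train as apg

make_policy, params, metrics = apg.train(
    envs.get_environment('inverted_pendulum', backend='spring'),
    policy_updates=200,       # now required: total number of gradient updates
    horizon_length=32,        # rollout length per update (the BPTT window)
    episode_length=1000,
    num_envs=64,
    num_evals=10,
    learning_rate=1e-4,
    schedule_decay=0.997,     # per-update exponential LR decay when use_schedule=True
    normalize_observations=True,
)
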

8 changes: 5 additions & 3 deletions brax/training/agents/apg/train_test.py
@@ -23,14 +23,14 @@
from brax.training.agents.apg import train as apg
import jax


class APGTest(parameterized.TestCase):
"""Tests for APG module."""

def testTrain(self):
"""Test APG with a simple env."""
_, _, metrics = apg.train(
envs.get_environment('fast'),
policy_updates=200,
episode_length=128,
num_envs=64,
num_evals=200,
@@ -45,13 +45,14 @@ def testNetworkEncoding(self, normalize_observations):
env = envs.get_environment('fast')
original_inference, params, _ = apg.train(
envs.get_environment('fast'),
policy_updates=200,
episode_length=100,
action_repeat=4,
num_envs=16,
learning_rate=3e-3,
normalize_observations=normalize_observations,
num_evals=200,
truncation_length=10)
num_evals=200
)
normalize_fn = lambda x, y: x
if normalize_observations:
normalize_fn = running_statistics.normalize
@@ -86,6 +87,7 @@ def get_offset(rng):

_, _, _ = apg.train(
envs.get_environment('inverted_pendulum', backend='spring'),
policy_updates=200,
episode_length=100,
num_envs=8,
num_evals=10,
6 changes: 4 additions & 2 deletions brax/training/distribution.py
@@ -131,12 +131,13 @@ def forward_log_det_jacobian(self, x):
class NormalTanhDistribution(ParametricDistribution):
"""Normal distribution followed by tanh."""

def __init__(self, event_size, min_std=0.001):
def __init__(self, event_size, min_std=0.001, var_scale=1):
"""Initialize the distribution.
Args:
event_size: the size of events (i.e. actions).
min_std: minimum std for the gaussian.
var_scale: adjust the gaussian's scale parameter.
"""
# We apply tanh to gaussian actions to bound them.
# Normally we would use TransformedDistribution to automatically
@@ -151,8 +152,9 @@ def __init__(self, event_size, min_std=0.001):
event_ndims=1,
reparametrizable=True)
self._min_std = min_std
self._var_scale = var_scale

def create_dist(self, parameters):
loc, scale = jnp.split(parameters, 2, axis=-1)
scale = jax.nn.softplus(scale) + self._min_std
scale = (jax.nn.softplus(scale) + self._min_std) * self._var_scale
return NormalDistribution(loc=loc, scale=scale)
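
A standalone numeric sketch (not from the commit) of what the var_scale factor does to the post-softplus standard deviation; the raw parameter values are arbitrary.

import jax.numpy as jnp
from jax import nn

raw = jnp.array([0.0, 1.0])       # example raw scale parameters
min_std, var_scale = 0.001, 0.1   # APG now passes var_scale=0.1
old_scale = nn.softplus(raw) + min_std                # ~[0.694, 1.314]
new_scale = (nn.softplus(raw) + min_std) * var_scale  # ~[0.069, 0.131]

With the default var_scale=1 existing callers are unaffected; only callers that opt in, such as make_apg_networks above with var_scale=0.1, get the narrower initial action noise.
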
12 changes: 9 additions & 3 deletions brax/training/networks.py
@@ -41,7 +41,8 @@ class MLP(linen.Module):
kernel_init: Initializer = jax.nn.initializers.lecun_uniform()
activate_final: bool = False
bias: bool = True

layer_norm: bool = False

@linen.compact
def __call__(self, data: jnp.ndarray):
hidden = data
@@ -54,6 +55,8 @@ def __call__(self, data: jnp.ndarray):
hidden)
if i != len(self.layer_sizes) - 1 or self.activate_final:
hidden = self.activation(hidden)
if self.layer_norm:
hidden = linen.LayerNorm()(hidden)
return hidden


@@ -86,12 +89,15 @@ def make_policy_network(
preprocess_observations_fn: types.PreprocessObservationFn = types
.identity_observation_preprocessor,
hidden_layer_sizes: Sequence[int] = (256, 256),
activation: ActivationFn = linen.relu) -> FeedForwardNetwork:
activation: ActivationFn = linen.relu,
kernel_init: Initializer = jax.nn.initializers.lecun_uniform(),
layer_norm: bool = False) -> FeedForwardNetwork:
"""Creates a policy network."""
policy_module = MLP(
layer_sizes=list(hidden_layer_sizes) + [param_size],
activation=activation,
kernel_init=jax.nn.initializers.lecun_uniform())
kernel_init=kernel_init,
layer_norm=layer_norm)

def apply(processor_params, policy_params, obs):
obs = preprocess_observations_fn(obs, processor_params)
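
A rough sketch of how the new kernel_init and layer_norm arguments thread into the policy MLP; the sizes are hypothetical, the first two arguments are passed positionally following the existing make_policy_network signature, and None stands in for the preprocessor params under the default identity preprocessor.

# Sketch only; sizes are illustrative.
import jax
import jax.numpy as jnp
from flax import linen
from flax.linen.initializers import orthogonal
from brax.training import networks

obs_size, act_size = 27, 8  # hypothetical sizes
policy_net = networks.make_policy_network(
    2 * act_size,  # e.g. loc and scale per action for a NormalTanhDistribution
    obs_size,
    hidden_layer_sizes=(32,) * 4,
    activation=linen.elu,
    kernel_init=orthogonal(0.01),  # used for every Dense kernel in the MLP
    layer_norm=True)               # LayerNorm after each hidden activation
params = policy_net.init(jax.random.PRNGKey(0))
out = policy_net.apply(None, params, jnp.zeros((obs_size,)))
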
