diff --git a/stoix/configs/default_ff_dpo_continuous.yaml b/stoix/configs/default_ff_dpo_continuous.yaml
new file mode 100644
index 00000000..bfe4e609
--- /dev/null
+++ b/stoix/configs/default_ff_dpo_continuous.yaml
@@ -0,0 +1,7 @@
+defaults:
+  - logger: ff_ppo
+  - arch: anakin
+  - system: ff_dpo
+  - network: mlp_continuous
+  - env: brax/ant
+  - _self_
diff --git a/stoix/configs/system/ff_dpo.yaml b/stoix/configs/system/ff_dpo.yaml
new file mode 100644
index 00000000..81685a91
--- /dev/null
+++ b/stoix/configs/system/ff_dpo.yaml
@@ -0,0 +1,23 @@
+# --- Defaults FF-DPO ---
+
+total_timesteps: 1e8 # Set the total environment steps.
+# If unspecified, it's derived from num_updates; otherwise, num_updates adjusts based on this value.
+num_updates: ~ # Number of updates
+seed: 42
+
+# --- RL hyperparameters ---
+actor_lr: 3e-4 # Learning rate for actor network
+critic_lr: 3e-4 # Learning rate for critic network
+update_batch_size: 1 # Number of vectorised gradient updates per device.
+rollout_length: 16 # Number of environment steps per vectorised environment.
+ppo_epochs: 4 # Number of ppo epochs per training data batch.
+num_minibatches: 16 # Number of minibatches per ppo epoch.
+gamma: 0.99 # Discounting factor.
+gae_lambda: 0.95 # Lambda value for GAE computation.
+clip_eps: 0.2 # Clipping value for the value function update.
+ent_coef: 0.001 # Entropy regularisation term for loss function.
+vf_coef: 1.0 # Critic weight in the total loss.
+max_grad_norm: 0.5 # Maximum norm of the gradients for a weight update.
+decay_learning_rates: False # Whether learning rates should be linearly decayed during training.
+alpha: 2.0 # DPO drift parameter used when the advantage is non-negative.
+beta: 0.6 # DPO drift parameter used when the advantage is negative.
diff --git a/stoix/systems/ppo/ff_dpo_continuous.py b/stoix/systems/ppo/ff_dpo_continuous.py
new file mode 100644
index 00000000..787e660e
--- /dev/null
+++ b/stoix/systems/ppo/ff_dpo_continuous.py
@@ -0,0 +1,556 @@
+import copy
+import time
+from typing import Any, Dict, Tuple
+
+import chex
+import flax
+import hydra
+import jax
+import jax.numpy as jnp
+import optax
+from colorama import Fore, Style
+from flax.core.frozen_dict import FrozenDict
+from jumanji.env import Environment
+from omegaconf import DictConfig, OmegaConf
+from optax._src.base import OptState
+from rich.pretty import pprint
+
+from stoix.evaluator import evaluator_setup
+from stoix.networks.base import FeedForwardActor as Actor
+from stoix.networks.base import FeedForwardCritic as Critic
+from stoix.systems.ppo.types import (
+    ActorCriticOptStates,
+    ActorCriticParams,
+    LearnerState,
+    PPOTransition,
+)
+from stoix.types import ActorApply, CriticApply, ExperimentOutput, LearnerFn
+from stoix.utils import make_env as environments
+from stoix.utils.checkpointing import Checkpointer
+from stoix.utils.jax import (
+    merge_leading_dims,
+    unreplicate_batch_dim,
+    unreplicate_n_dims,
+)
+from stoix.utils.logger import LogEvent, StoixLogger
+from stoix.utils.loss import dpo_loss
+from stoix.utils.multistep import calculate_gae
+from stoix.utils.total_timestep_checker import check_total_timesteps
+from stoix.utils.training import make_learning_rate
+
+
+def get_learner_fn(
+    env: Environment,
+    apply_fns: Tuple[ActorApply, CriticApply],
+    update_fns: Tuple[optax.TransformUpdateFn, optax.TransformUpdateFn],
+    config: DictConfig,
+) -> LearnerFn[LearnerState]:
+    """Get the learner function."""
+
+    # Get apply and update functions for actor and critic networks.
+ actor_apply_fn, critic_apply_fn = apply_fns + actor_update_fn, critic_update_fn = update_fns + + def _update_step(learner_state: LearnerState, _: Any) -> Tuple[LearnerState, Tuple]: + """A single update of the network. + + This function steps the environment and records the trajectory batch for + training. It then calculates advantages and targets based on the recorded + trajectory and updates the actor and critic networks based on the calculated + losses. + + Args: + learner_state (NamedTuple): + - params (ActorCriticParams): The current model parameters. + - opt_states (OptStates): The current optimizer states. + - key (PRNGKey): The random number generator state. + - env_state (State): The environment state. + - last_timestep (TimeStep): The last timestep in the current trajectory. + _ (Any): The current metrics info. + """ + + def _env_step(learner_state: LearnerState, _: Any) -> Tuple[LearnerState, PPOTransition]: + """Step the environment.""" + params, opt_states, key, env_state, last_timestep = learner_state + + # SELECT ACTION + key, policy_key = jax.random.split(key) + actor_policy = actor_apply_fn(params.actor_params, last_timestep.observation) + value = critic_apply_fn(params.critic_params, last_timestep.observation) + action = actor_policy.sample(seed=policy_key) + log_prob = actor_policy.log_prob(action) + + # STEP ENVIRONMENT + env_state, timestep = jax.vmap(env.step, in_axes=(0, 0))(env_state, action) + + # LOG EPISODE METRICS + done = timestep.last().reshape(-1) + info = timestep.extras["episode_metrics"] + + transition = PPOTransition( + done, action, value, timestep.reward, log_prob, last_timestep.observation, info + ) + learner_state = LearnerState(params, opt_states, key, env_state, timestep) + return learner_state, transition + + # STEP ENVIRONMENT FOR ROLLOUT LENGTH + learner_state, traj_batch = jax.lax.scan( + _env_step, learner_state, None, config.system.rollout_length + ) + + # CALCULATE ADVANTAGE + params, opt_states, key, env_state, last_timestep = learner_state + last_val = critic_apply_fn(params.critic_params, last_timestep.observation) + + r_t = traj_batch.reward + v_t = traj_batch.value + d_t = 1.0 - traj_batch.done.astype(jnp.float32) + d_t = (d_t * config.system.gamma).astype(jnp.float32) + advantages, targets = calculate_gae(v_t, r_t, d_t, last_val, config.system.gae_lambda) + + def _update_epoch(update_state: Tuple, _: Any) -> Tuple: + """Update the network for a single epoch.""" + + def _update_minibatch(train_state: Tuple, batch_info: Tuple) -> Tuple: + """Update the network for a single minibatch.""" + + # UNPACK TRAIN STATE AND BATCH INFO + params, opt_states, key = train_state + traj_batch, advantages, targets = batch_info + + def _actor_loss_fn( + actor_params: FrozenDict, + actor_opt_state: OptState, + traj_batch: PPOTransition, + gae: chex.Array, + rng_key: chex.PRNGKey, + ) -> Tuple: + """Calculate the actor loss.""" + # RERUN NETWORK + actor_policy = actor_apply_fn(actor_params, traj_batch.obs) + log_prob = actor_policy.log_prob(traj_batch.action) + + # CALCULATE ACTOR LOSS + loss_actor = dpo_loss( + log_prob, traj_batch.log_prob, gae, config.system.alpha, config.system.beta + ) + entropy = actor_policy.entropy(seed=rng_key).mean() + + total_loss_actor = loss_actor - config.system.ent_coef * entropy + return total_loss_actor, (loss_actor, entropy) + + def _critic_loss_fn( + critic_params: FrozenDict, + critic_opt_state: OptState, + traj_batch: PPOTransition, + targets: chex.Array, + ) -> Tuple: + """Calculate the critic loss.""" + # RERUN 
NETWORK + value = critic_apply_fn(critic_params, traj_batch.obs) + + # CALCULATE VALUE LOSS + value_pred_clipped = traj_batch.value + (value - traj_batch.value).clip( + -config.system.clip_eps, config.system.clip_eps + ) + value_losses = jnp.square(value - targets) + value_losses_clipped = jnp.square(value_pred_clipped - targets) + value_loss = 0.5 * jnp.maximum(value_losses, value_losses_clipped).mean() + + critic_total_loss = config.system.vf_coef * value_loss + return critic_total_loss, (value_loss) + + # CALCULATE ACTOR LOSS + key, actor_loss_key = jax.random.split(key) + actor_grad_fn = jax.value_and_grad(_actor_loss_fn, has_aux=True) + actor_loss_info, actor_grads = actor_grad_fn( + params.actor_params, + opt_states.actor_opt_state, + traj_batch, + advantages, + actor_loss_key, + ) + + # CALCULATE CRITIC LOSS + critic_grad_fn = jax.value_and_grad(_critic_loss_fn, has_aux=True) + critic_loss_info, critic_grads = critic_grad_fn( + params.critic_params, opt_states.critic_opt_state, traj_batch, targets + ) + + # Compute the parallel mean (pmean) over the batch. + # This calculation is inspired by the Anakin architecture demo notebook. + # available at https://tinyurl.com/26tdzs5x + # This pmean could be a regular mean as the batch axis is on the same device. + actor_grads, actor_loss_info = jax.lax.pmean( + (actor_grads, actor_loss_info), axis_name="batch" + ) + # pmean over devices. + actor_grads, actor_loss_info = jax.lax.pmean( + (actor_grads, actor_loss_info), axis_name="device" + ) + + critic_grads, critic_loss_info = jax.lax.pmean( + (critic_grads, critic_loss_info), axis_name="batch" + ) + # pmean over devices. + critic_grads, critic_loss_info = jax.lax.pmean( + (critic_grads, critic_loss_info), axis_name="device" + ) + + # UPDATE ACTOR PARAMS AND OPTIMISER STATE + actor_updates, actor_new_opt_state = actor_update_fn( + actor_grads, opt_states.actor_opt_state + ) + actor_new_params = optax.apply_updates(params.actor_params, actor_updates) + + # UPDATE CRITIC PARAMS AND OPTIMISER STATE + critic_updates, critic_new_opt_state = critic_update_fn( + critic_grads, opt_states.critic_opt_state + ) + critic_new_params = optax.apply_updates(params.critic_params, critic_updates) + + # PACK NEW PARAMS AND OPTIMISER STATE + new_params = ActorCriticParams(actor_new_params, critic_new_params) + new_opt_state = ActorCriticOptStates(actor_new_opt_state, critic_new_opt_state) + + # PACK LOSS INFO + total_loss = actor_loss_info[0] + critic_loss_info[0] + value_loss = critic_loss_info[1] + actor_loss = actor_loss_info[1][0] + entropy = actor_loss_info[1][1] + loss_info = { + "total_loss": total_loss, + "value_loss": value_loss, + "actor_loss": actor_loss, + "entropy": entropy, + } + return (new_params, new_opt_state, key), loss_info + + params, opt_states, traj_batch, advantages, targets, key = update_state + key, shuffle_key = jax.random.split(key) + + # SHUFFLE MINIBATCHES + batch_size = config.system.rollout_length * config.arch.num_envs + permutation = jax.random.permutation(shuffle_key, batch_size) + batch = (traj_batch, advantages, targets) + batch = jax.tree_util.tree_map(lambda x: merge_leading_dims(x, 2), batch) + shuffled_batch = jax.tree_util.tree_map( + lambda x: jnp.take(x, permutation, axis=0), batch + ) + minibatches = jax.tree_util.tree_map( + lambda x: jnp.reshape(x, [config.system.num_minibatches, -1] + list(x.shape[1:])), + shuffled_batch, + ) + + # UPDATE MINIBATCHES + (params, opt_states, key), loss_info = jax.lax.scan( + _update_minibatch, (params, opt_states, key), minibatches 
+ ) + + update_state = (params, opt_states, traj_batch, advantages, targets, key) + return update_state, loss_info + + update_state = (params, opt_states, traj_batch, advantages, targets, key) + + # UPDATE EPOCHS + update_state, loss_info = jax.lax.scan( + _update_epoch, update_state, None, config.system.ppo_epochs + ) + + params, opt_states, traj_batch, advantages, targets, key = update_state + learner_state = LearnerState(params, opt_states, key, env_state, last_timestep) + metric = traj_batch.info + return learner_state, (metric, loss_info) + + def learner_fn(learner_state: LearnerState) -> ExperimentOutput[LearnerState]: + """Learner function. + + This function represents the learner, it updates the network parameters + by iteratively applying the `_update_step` function for a fixed number of + updates. The `_update_step` function is vectorized over a batch of inputs. + + Args: + learner_state (NamedTuple): + - params (ActorCriticParams): The initial model parameters. + - opt_states (OptStates): The initial optimizer state. + - key (chex.PRNGKey): The random number generator state. + - env_state (LogEnvState): The environment state. + - timesteps (TimeStep): The initial timestep in the initial trajectory. + """ + + batched_update_step = jax.vmap(_update_step, in_axes=(0, None), axis_name="batch") + + learner_state, (episode_info, loss_info) = jax.lax.scan( + batched_update_step, learner_state, None, config.system.num_updates_per_eval + ) + return ExperimentOutput( + learner_state=learner_state, + episode_metrics=episode_info, + train_metrics=loss_info, + ) + + return learner_fn + + +def learner_setup( + env: Environment, keys: chex.Array, config: DictConfig +) -> Tuple[LearnerFn[LearnerState], Actor, LearnerState]: + """Initialise learner_fn, network, optimiser, environment and states.""" + # Get available TPU cores. + n_devices = len(jax.devices()) + + # Get number of actions. + num_actions = int(env.action_spec().shape[-1]) + config.system.action_dim = num_actions + + # PRNG keys. + key, actor_net_key, critic_net_key = keys + + # Define network and optimiser. + actor_torso = hydra.utils.instantiate(config.network.actor_network.pre_torso) + actor_action_head = hydra.utils.instantiate( + config.network.actor_network.action_head, action_dim=num_actions + ) + critic_torso = hydra.utils.instantiate(config.network.critic_network.pre_torso) + critic_head = hydra.utils.instantiate(config.network.critic_network.critic_head) + + actor_network = Actor(torso=actor_torso, action_head=actor_action_head) + critic_network = Critic(torso=critic_torso, critic_head=critic_head) + + actor_lr = make_learning_rate( + config.system.actor_lr, config, config.system.ppo_epochs, config.system.num_minibatches + ) + critic_lr = make_learning_rate( + config.system.critic_lr, config, config.system.ppo_epochs, config.system.num_minibatches + ) + + actor_optim = optax.chain( + optax.clip_by_global_norm(config.system.max_grad_norm), + optax.adam(actor_lr, eps=1e-5), + ) + critic_optim = optax.chain( + optax.clip_by_global_norm(config.system.max_grad_norm), + optax.adam(critic_lr, eps=1e-5), + ) + + # Initialise observation: Select only obs for a single agent. + init_x = env.observation_spec().generate_value() + init_x = jax.tree_util.tree_map(lambda x: x[None, ...], init_x) + + # Initialise actor params and optimiser state. + actor_params = actor_network.init(actor_net_key, init_x) + actor_opt_state = actor_optim.init(actor_params) + + # Initialise critic params and optimiser state. 
+    critic_params = critic_network.init(critic_net_key, init_x)
+    critic_opt_state = critic_optim.init(critic_params)
+
+    # Pack params.
+    params = ActorCriticParams(actor_params, critic_params)
+
+    # Get network apply functions.
+    vmapped_actor_network_apply_fn = actor_network.apply
+
+    vmapped_critic_network_apply_fn = critic_network.apply
+
+    # Pack apply and update functions.
+    apply_fns = (vmapped_actor_network_apply_fn, vmapped_critic_network_apply_fn)
+    update_fns = (actor_optim.update, critic_optim.update)
+
+    # Get batched iterated update and replicate it to pmap it over cores.
+    learn = get_learner_fn(env, apply_fns, update_fns, config)
+    learn = jax.pmap(learn, axis_name="device")
+
+    # Initialise environment states and timesteps: across devices and batches.
+    key, *env_keys = jax.random.split(
+        key, n_devices * config.system.update_batch_size * config.arch.num_envs + 1
+    )
+    env_states, timesteps = jax.vmap(env.reset, in_axes=(0))(
+        jnp.stack(env_keys),
+    )
+    reshape_states = lambda x: x.reshape(
+        (n_devices, config.system.update_batch_size, config.arch.num_envs) + x.shape[1:]
+    )
+    # (devices, update batch size, num_envs, ...)
+    env_states = jax.tree_map(reshape_states, env_states)
+    timesteps = jax.tree_map(reshape_states, timesteps)
+
+    # Load model from checkpoint if specified.
+    if config.logger.checkpointing.load_model:
+        loaded_checkpoint = Checkpointer(
+            model_name=config.logger.system_name,
+            **config.logger.checkpointing.load_args,  # Other checkpoint args
+        )
+        # Restore the learner state from the checkpoint
+        restored_params, _ = loaded_checkpoint.restore_params()
+        # Update the params
+        params = restored_params
+
+    # Define params to be replicated across devices and batches.
+    key, step_keys = jax.random.split(key)
+    opt_states = ActorCriticOptStates(actor_opt_state, critic_opt_state)
+    replicate_learner = (params, opt_states, step_keys)
+
+    # Duplicate learner for update_batch_size.
+    broadcast = lambda x: jnp.broadcast_to(x, (config.system.update_batch_size,) + x.shape)
+    replicate_learner = jax.tree_map(broadcast, replicate_learner)
+
+    # Duplicate learner across devices.
+    replicate_learner = flax.jax_utils.replicate(replicate_learner, devices=jax.devices())
+
+    # Initialise learner state.
+    params, opt_states, step_keys = replicate_learner
+    init_learner_state = LearnerState(params, opt_states, step_keys, env_states, timesteps)
+
+    return learn, actor_network, init_learner_state
+
+
+def run_experiment(_config: DictConfig) -> None:
+    """Runs experiment."""
+    config = copy.deepcopy(_config)
+
+    # Calculate total timesteps.
+    n_devices = len(jax.devices())
+    config = check_total_timesteps(config)
+    assert (
+        config.system.num_updates > config.arch.num_evaluation
+    ), "Number of evaluations must be less than the total number of updates."
+
+    # Create the environments for train and eval.
+    env, eval_env = environments.make(config=config)
+
+    # PRNG keys.
+    key, key_e, actor_net_key, critic_net_key = jax.random.split(
+        jax.random.PRNGKey(config["system"]["seed"]), num=4
+    )
+
+    # Setup learner.
+    learn, actor_network, learner_state = learner_setup(
+        env, (key, actor_net_key, critic_net_key), config
+    )
+
+    # Setup evaluator.
+    evaluator, absolute_metric_evaluator, (trained_params, eval_keys) = evaluator_setup(
+        eval_env=eval_env,
+        key_e=key_e,
+        network=actor_network,
+        params=learner_state.params.actor_params,
+        config=config,
+    )
+
+    # Calculate number of updates per evaluation.
+ config.system.num_updates_per_eval = config.system.num_updates // config.arch.num_evaluation + steps_per_rollout = ( + n_devices + * config.system.num_updates_per_eval + * config.system.rollout_length + * config.system.update_batch_size + * config.arch.num_envs + ) + + # Logger setup + logger = StoixLogger(config) + cfg: Dict = OmegaConf.to_container(config, resolve=True) + cfg["arch"]["devices"] = jax.devices() + pprint(cfg) + + # Set up checkpointer + save_checkpoint = config.logger.checkpointing.save_model + if save_checkpoint: + checkpointer = Checkpointer( + metadata=config, # Save all config as metadata in the checkpoint + model_name=config.logger.system_name, + **config.logger.checkpointing.save_args, # Checkpoint args + ) + + # Run experiment for a total number of evaluations. + max_episode_return = jnp.float32(0.0) + best_params = unreplicate_batch_dim(learner_state.params.actor_params) + for eval_step in range(config.arch.num_evaluation): + # Train. + start_time = time.time() + + learner_output = learn(learner_state) + jax.block_until_ready(learner_output) + + # Log the results of the training. + elapsed_time = time.time() - start_time + t = int(steps_per_rollout * (eval_step + 1)) + learner_output.episode_metrics["steps_per_second"] = steps_per_rollout / elapsed_time + + # Separately log timesteps, actoring metrics and training metrics. + logger.log({"timestep": t}, t, eval_step, LogEvent.MISC) + logger.log(learner_output.episode_metrics, t, eval_step, LogEvent.ACT) + logger.log(learner_output.train_metrics, t, eval_step, LogEvent.TRAIN) + + # Prepare for evaluation. + start_time = time.time() + trained_params = unreplicate_batch_dim( + learner_output.learner_state.params.actor_params + ) # Select only actor params + key_e, *eval_keys = jax.random.split(key_e, n_devices + 1) + eval_keys = jnp.stack(eval_keys) + eval_keys = eval_keys.reshape(n_devices, -1) + + # Evaluate. + evaluator_output = evaluator(trained_params, eval_keys) + jax.block_until_ready(evaluator_output) + + # Log the results of the evaluation. + elapsed_time = time.time() - start_time + episode_return = jnp.mean(evaluator_output.episode_metrics["episode_return"]) + + evaluator_output.episode_metrics["steps_per_second"] = steps_per_rollout / elapsed_time + logger.log(evaluator_output.episode_metrics, t, eval_step, LogEvent.EVAL) + + if save_checkpoint: + # Save checkpoint of learner state + checkpointer.save( + timestep=int(steps_per_rollout * (eval_step + 1)), + unreplicated_learner_state=unreplicate_n_dims(learner_output.learner_state), + episode_return=episode_return, + ) + + if config.arch.absolute_metric and max_episode_return <= episode_return: + best_params = copy.deepcopy(trained_params) + max_episode_return = episode_return + + # Update runner state to continue training. + learner_state = learner_output.learner_state + + # Measure absolute metric. + if config.arch.absolute_metric: + start_time = time.time() + + key_e, *eval_keys = jax.random.split(key_e, n_devices + 1) + eval_keys = jnp.stack(eval_keys) + eval_keys = eval_keys.reshape(n_devices, -1) + + evaluator_output = absolute_metric_evaluator(best_params, eval_keys) + jax.block_until_ready(evaluator_output) + + elapsed_time = time.time() - start_time + t = int(steps_per_rollout * (eval_step + 1)) + evaluator_output.episode_metrics["steps_per_second"] = steps_per_rollout / elapsed_time + logger.log(evaluator_output.episode_metrics, t, eval_step, LogEvent.ABSOLUTE) + + # Stop the logger. 
+ logger.stop() + + +@hydra.main( + config_path="../../configs", config_name="default_ff_dpo_continuous.yaml", version_base="1.2" +) +def hydra_entry_point(cfg: DictConfig) -> None: + """Experiment entry point.""" + # Allow dynamic attributes. + OmegaConf.set_struct(cfg, False) + + # Run experiment. + run_experiment(cfg) + + print(f"{Fore.CYAN}{Style.BRIGHT}DPO experiment completed{Style.RESET_ALL}") + + +if __name__ == "__main__": + hydra_entry_point() diff --git a/stoix/systems/ppo/ff_ppo_continuous.py b/stoix/systems/ppo/ff_ppo_continuous.py index f3156d80..1c2b99b7 100644 --- a/stoix/systems/ppo/ff_ppo_continuous.py +++ b/stoix/systems/ppo/ff_ppo_continuous.py @@ -33,6 +33,7 @@ unreplicate_n_dims, ) from stoix.utils.logger import LogEvent, StoixLogger +from stoix.utils.loss import ppo_loss from stoix.utils.multistep import calculate_gae from stoix.utils.total_timestep_checker import check_total_timesteps from stoix.utils.training import make_learning_rate @@ -130,19 +131,9 @@ def _actor_loss_fn( log_prob = actor_policy.log_prob(traj_batch.action) # CALCULATE ACTOR LOSS - ratio = jnp.exp(log_prob - traj_batch.log_prob) - gae = (gae - gae.mean()) / (gae.std() + 1e-8) - loss_actor1 = ratio * gae - loss_actor2 = ( - jnp.clip( - ratio, - 1.0 - config.system.clip_eps, - 1.0 + config.system.clip_eps, - ) - * gae + loss_actor = ppo_loss( + log_prob, traj_batch.log_prob, gae, config.system.clip_eps ) - loss_actor = -jnp.minimum(loss_actor1, loss_actor2) - loss_actor = loss_actor.mean() entropy = actor_policy.entropy(seed=rng_key).mean() total_loss_actor = loss_actor - config.system.ent_coef * entropy diff --git a/stoix/utils/loss.py b/stoix/utils/loss.py index 5b723d54..51032bff 100644 --- a/stoix/utils/loss.py +++ b/stoix/utils/loss.py @@ -30,6 +30,25 @@ def ppo_loss( return loss_actor +def dpo_loss( + pi_log_prob_t: chex.Array, + b_pi_log_prob_t: chex.Array, + gae_t: chex.Array, + alpha: float, + beta: float, +) -> chex.Array: + log_diff = pi_log_prob_t - b_pi_log_prob_t + gae_t = (gae_t - gae_t.mean()) / (gae_t.std() + 1e-8) + ratio = jnp.exp(log_diff) + is_pos = (gae_t >= 0.0).astype(jnp.float32) + r1 = ratio - 1.0 + drift1 = jax.nn.relu(r1 * gae_t - alpha * jax.nn.tanh(r1 * gae_t / alpha)) + drift2 = jax.nn.relu(log_diff * gae_t - beta * jax.nn.tanh(log_diff * gae_t / beta)) + drift = drift1 * is_pos + drift2 * (1 - is_pos) + loss_actor = -(ratio * gae_t - drift).mean() + return loss_actor + + def clipped_value_loss( pred_value_t: chex.Array, behavior_value_t: chex.Array, targets_t: chex.Array, epsilon: float ) -> chex.Array:
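Review note (not part of the patch): the snippet below is a minimal, standalone sketch for sanity-checking the new `dpo_loss` against the refactored `ppo_loss` on dummy data. It assumes `jax` is installed and that a `stoix` checkout with this diff applied is on the Python path; the shapes, seed, and perturbation size are arbitrary, and the alpha/beta/clip values mirror the defaults in the new `ff_dpo.yaml`.

import jax
import jax.numpy as jnp

from stoix.utils.loss import dpo_loss, ppo_loss  # both touched by this diff

# Dummy rollout data: 64 transitions with advantages and behaviour log-probs.
key = jax.random.PRNGKey(0)
k_gae, k_logp, k_noise = jax.random.split(key, 3)
gae = jax.random.normal(k_gae, (64,))
behaviour_log_prob = jax.random.normal(k_logp, (64,))
# Perturb the behaviour log-probs slightly to stand in for the updated policy.
current_log_prob = behaviour_log_prob + 0.05 * jax.random.normal(k_noise, (64,))

# PPO's clipped surrogate (clip_eps=0.2) next to the DPO drift surrogate
# (alpha=2.0, beta=0.6, the defaults added in ff_dpo.yaml).
print("ppo loss:", ppo_loss(current_log_prob, behaviour_log_prob, gae, 0.2))
print("dpo loss:", dpo_loss(current_log_prob, behaviour_log_prob, gae, 2.0, 0.6))

When the current and behaviour log-probabilities coincide, the ratio is 1 and both drift terms vanish, so `dpo_loss` collapses to a plain policy-gradient surrogate on the normalised advantages; alpha shapes the drift penalty on transitions with non-negative advantage and beta the penalty on the rest. Since the new entry point is wrapped in `@hydra.main` with `config_name="default_ff_dpo_continuous.yaml"`, the full system should be runnable directly via `python stoix/systems/ppo/ff_dpo_continuous.py`, with the config fields above overridable on the command line in the usual Hydra way.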