From db25a6c05639ac53698cfa24b1708a4d45d594ef Mon Sep 17 00:00:00 2001 From: Sasha Abramowitz Date: Sun, 3 Nov 2024 17:47:12 +0200 Subject: [PATCH 1/4] feat: switch sebublba to using shard_map like mava --- .../default/sebulba/default_ff_ppo.yaml | 2 +- stoix/systems/ppo/sebulba/ff_ppo.py | 225 ++++++++++-------- stoix/utils/sebulba_utils.py | 40 ++-- 3 files changed, 148 insertions(+), 119 deletions(-) diff --git a/stoix/configs/default/sebulba/default_ff_ppo.yaml b/stoix/configs/default/sebulba/default_ff_ppo.yaml index f0501dc..b5849f2 100644 --- a/stoix/configs/default/sebulba/default_ff_ppo.yaml +++ b/stoix/configs/default/sebulba/default_ff_ppo.yaml @@ -1,7 +1,7 @@ defaults: - logger: base_logger - arch: sebulba - - system: ff_ppo + - system: ppo/ff_ppo - network: mlp - env: envpool/cartpole - _self_ diff --git a/stoix/systems/ppo/sebulba/ff_ppo.py b/stoix/systems/ppo/sebulba/ff_ppo.py index 0910b0e..3e5b0cb 100644 --- a/stoix/systems/ppo/sebulba/ff_ppo.py +++ b/stoix/systems/ppo/sebulba/ff_ppo.py @@ -8,7 +8,6 @@ from typing import Any, Callable, Dict, List, Sequence, Tuple import chex -import flax import hydra import jax import jax.numpy as jnp @@ -16,7 +15,9 @@ import optax from colorama import Fore, Style from flax.core.frozen_dict import FrozenDict -from flax.jax_utils import unreplicate +from jax.experimental import mesh_utils +from jax.experimental.shard_map import shard_map +from jax.sharding import Mesh, NamedSharding, PartitionSpec, Sharding from omegaconf import DictConfig, OmegaConf from rich.pretty import pprint @@ -53,9 +54,10 @@ def get_act_fn( - apply_fns: Tuple[ActorApply, CriticApply] + apply_fns: Tuple[ActorApply, CriticApply], ) -> Callable[ - [ActorCriticParams, Observation, chex.PRNGKey], Tuple[chex.Array, chex.Array, chex.Array] + [ActorCriticParams, Observation, chex.PRNGKey], + Tuple[chex.Array, chex.Array, chex.Array], ]: """Get the act function that is used by the actor threads.""" actor_apply_fn, critic_apply_fn = apply_fns @@ -88,89 +90,87 @@ def get_rollout_fn( # Unpack and set up the functions act_fn = get_act_fn(apply_fns) act_fn = jax.jit(act_fn, device=actor_device) - cpu = jax.devices("cpu")[0] - move_to_device = lambda tree: jax.tree.map(lambda x: jax.device_put(x, actor_device), tree) + move_to_device = lambda x: jax.device_put(x, device=actor_device) split_key_fn = jax.jit(jax.random.split, device=actor_device) # Build the environments envs = env_factory(config.arch.actor.num_envs_per_actor) # Create the rollout function def rollout_fn(rng_key: chex.PRNGKey) -> None: - # Ensure all computation is on the actor device - with jax.default_device(actor_device): - # Reset the environment - timestep = envs.reset(seed=seeds) - - # Loop until the thread is stopped - while not thread_lifetime.should_stop(): - # Create the list to store transitions - traj: List[PPOTransition] = [] - # Create the dictionary to store timings for metrics - actor_timings_dict: Dict[str, List[float]] = defaultdict(list) - episode_metrics: List[Dict[str, List[float]]] = [] - # Rollout the environment - with RecordTimeTo(actor_timings_dict["single_rollout_time"]): - # Loop until the rollout length is reached - for _ in range(config.system.rollout_length): - # Get the latest parameters from the source - with RecordTimeTo(actor_timings_dict["get_params_time"]): - params = params_source.get() - - # Move the environment data to the actor device - cached_obs = move_to_device(timestep.observation) - - # Run the actor and critic networks to get the action, value and log_prob - with 
RecordTimeTo(actor_timings_dict["compute_action_time"]): - rng_key, policy_key = split_key_fn(rng_key) - action, value, log_prob = act_fn(params, cached_obs, policy_key) - + # Reset the environment + timestep = envs.reset(seed=seeds) + + # Loop until the thread is stopped + while not thread_lifetime.should_stop(): + # Create the list to store transitions + traj: List[PPOTransition] = [] + # Create the dictionary to store timings for metrics + actor_timings_dict: Dict[str, List[float]] = defaultdict(list) + episode_metrics: List[Dict[str, List[float]]] = [] + # Rollout the environment + with RecordTimeTo(actor_timings_dict["single_rollout_time"]): + # Loop until the rollout length is reached + for _ in range(config.system.rollout_length): + # Get the latest parameters from the source + with RecordTimeTo(actor_timings_dict["get_params_time"]): + params = params_source.get() + + # Move the environment data to the actor device + cached_obs = move_to_device(timestep.observation) + + # Run the actor and critic networks to get the action, value and log_prob + with RecordTimeTo(actor_timings_dict["compute_action_time"]): + rng_key, policy_key = split_key_fn(rng_key) + action, value, log_prob = act_fn(params, cached_obs, policy_key) # Move the action to the CPU - action_cpu = np.asarray(jax.device_put(action, cpu)) + action_cpu = jax.device_get(action) - # Step the environment - with RecordTimeTo(actor_timings_dict["env_step_time"]): - timestep = envs.step(action_cpu) + # Step the environment + with RecordTimeTo(actor_timings_dict["env_step_time"]): + timestep = envs.step(action_cpu) - # Get the next dones and truncation flags - dones = np.logical_and( - np.asarray(timestep.last()), np.asarray(timestep.discount == 0.0) - ) - trunc = np.logical_and( - np.asarray(timestep.last()), np.asarray(timestep.discount == 1.0) - ) - cached_next_dones = move_to_device(dones) - cached_next_trunc = move_to_device(trunc) - - # Append PPOTransition to the trajectory list - reward = timestep.reward - metrics = timestep.extras["metrics"] - traj.append( - PPOTransition( - cached_next_dones, - cached_next_trunc, - action, - value, - reward, - log_prob, - cached_obs, - metrics, - ) - ) - episode_metrics.append(metrics) - - # Send the trajectory to the pipeline - with RecordTimeTo(actor_timings_dict["rollout_put_time"]): - try: - pipeline.put(traj, timestep, actor_timings_dict, episode_metrics) - except queue.Full: - warnings.warn( - "Waited too long to add to the rollout queue, killing the actor thread", - stacklevel=2, + # Get the next dones and truncation flags + dones = np.logical_and( + np.asarray(timestep.last()), + np.asarray(timestep.discount == 0.0), + ) + trunc = np.logical_and( + np.asarray(timestep.last()), + np.asarray(timestep.discount == 1.0), + ) + cached_next_dones = move_to_device(dones) + cached_next_trunc = move_to_device(trunc) + + # Append PPOTransition to the trajectory list + reward = timestep.reward + metrics = timestep.extras["metrics"] + traj.append( + PPOTransition( + cached_next_dones, + cached_next_trunc, + action, + value, + reward, + log_prob, + cached_obs, + metrics, ) - break + ) + episode_metrics.append(metrics) + + # Send the trajectory to the pipeline + with RecordTimeTo(actor_timings_dict["rollout_put_time"]): + try: + pipeline.put(traj, timestep, actor_timings_dict, episode_metrics) + except queue.Full: + warnings.warn( + "Waited too long to add to the rollout queue, killing the actor thread", + stacklevel=2, + ) + break - # Close the environments - envs.close() + # Close the 
environments + envs.close() return rollout_fn @@ -226,7 +226,6 @@ def get_learner_step_fn( def _update_step( learner_state: CoreLearnerState, traj_batch: PPOTransition ) -> Tuple[CoreLearnerState, Dict[str, chex.Array]]: - # CALCULATE ADVANTAGE params, opt_states, key, last_timestep = learner_state last_val = critic_apply_fn(params.critic_params, last_timestep.observation) @@ -400,6 +399,9 @@ def learner_step_fn( - timesteps (TimeStep): The initial timestep in the initial trajectory. """ + # This function is shard mapped on the batch axis, but `_update_step` needs + # the first axis to be time + traj_batch = jax.tree_util.tree_map(lambda x: x.swapaxes(0, 1), traj_batch) learner_state, loss_info = _update_step(learner_state, traj_batch) return SebulbaExperimentOutput( @@ -458,11 +460,11 @@ def learner_rollout(learner_state: CoreLearnerState) -> None: q_sizes.append(pipeline.qsize()) # After the update we need to update the params sources with the new params - unreplicated_params = unreplicate(learner_state.params) + params = jax.block_until_ready(learner_state.params) # We loop over all params sources and update them with the new params # This is so that all the actors can get the latest params for source in params_sources: - source.update(unreplicated_params) + source.update(params) # We then pass all the environment metrics, training metrics, current learner state # and timings to the evaluation queue. This is so the evaluator correctly evaluates @@ -477,7 +479,8 @@ def learner_rollout(learner_state: CoreLearnerState) -> None: # If the queue is full for more than 60 seconds we kill the learner thread # This should never happen eval_queue.put( - (episode_metrics, train_metrics, learner_state, timing_dict), timeout=60 + (episode_metrics, train_metrics, learner_state, timing_dict), + timeout=60, ) except queue.Full: warnings.warn( @@ -520,6 +523,7 @@ def learner_setup( SebulbaLearnerFn[CoreLearnerState, PPOTransition], Tuple[ActorApply, CriticApply], CoreLearnerState, + Sharding, ]: """Setup for the learner state and networks.""" @@ -546,10 +550,16 @@ def learner_setup( critic_network = Critic(torso=critic_torso, critic_head=critic_head) actor_lr = make_learning_rate( - config.system.actor_lr, config, config.system.epochs, config.system.num_minibatches + config.system.actor_lr, + config, + config.system.epochs, + config.system.num_minibatches, ) critic_lr = make_learning_rate( - config.system.critic_lr, config, config.system.epochs, config.system.num_minibatches + config.system.critic_lr, + config, + config.system.epochs, + config.system.num_minibatches, ) actor_optim = optax.chain( @@ -584,9 +594,24 @@ def learner_setup( apply_fns = (actor_network_apply_fn, critic_network_apply_fn) update_fns = (actor_optim.update, critic_optim.update) - # Get batched iterated update and replicate it to pmap it over cores. 
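The comment and `swapaxes` call added to `learner_step_fn` above encode a layout convention: `shard_map` splits incoming data on the leading `num_envs` axis, while `_update_step` expects time first. A minimal sketch of that transpose, with purely illustrative shapes and names (not Stoix code):

```python
import jax
import jax.numpy as jnp

# Illustrative shapes only: 8 envs, rollout of 16 steps, 4-dim observations.
batch_major = {
    "obs": jnp.zeros((8, 16, 4)),   # [num_envs, rollout_len, obs_dim]
    "reward": jnp.zeros((8, 16)),   # [num_envs, rollout_len]
}

# Swap the first two axes of every leaf to recover the time-major layout
# ([rollout_len, num_envs, ...]) that the PPO update consumes.
time_major = jax.tree_util.tree_map(lambda x: x.swapaxes(0, 1), batch_major)

assert time_major["obs"].shape == (16, 8, 4)
assert time_major["reward"].shape == (16, 8)
```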
+ # Define how data is distributed with `shard_map` + devices = mesh_utils.create_device_mesh((len(learner_devices),), devices=learner_devices) + mesh = Mesh(devices, axis_names=("device",)) + model_spec = PartitionSpec() # replicate the model + data_spec = PartitionSpec("device") # shard the data + learner_sharding = NamedSharding(mesh, model_spec) # used in the pipeline + # Defines how the learner state is sharded: params, opt and key = replicated, timestep = sharded + learn_state_spec = CoreLearnerState(model_spec, model_spec, model_spec, data_spec) + learn_step = get_learner_step_fn(apply_fns, update_fns, config) - learn_step = jax.pmap(learn_step, axis_name="device") + learn_step = jax.jit( + shard_map( + learn_step, + mesh=mesh, + in_specs=(learn_state_spec, data_spec), + out_specs=SebulbaExperimentOutput(learn_state_spec, data_spec), + ) + ) # Load model from checkpoint if specified. if config.logger.checkpointing.load_model: @@ -601,18 +626,15 @@ def learner_setup( # Define params to be replicated across learner devices. opt_states = ActorCriticOptStates(actor_opt_state, critic_opt_state) - replicate_learner = (params, opt_states) - # Duplicate across learner devices. - replicate_learner = flax.jax_utils.replicate(replicate_learner, devices=learner_devices) + # Shard across learner devices. + key, step_key = jax.random.split(key) + params, opt_states, step_keys = jax.device_put((params, opt_states, step_key), learner_sharding) # Initialise learner state. - params, opt_states = replicate_learner - key, step_key = jax.random.split(key) - step_keys = jax.random.split(step_key, len(learner_devices)) init_learner_state = CoreLearnerState(params, opt_states, step_keys, None) - return learn_step, apply_fns, init_learner_state + return learn_step, apply_fns, init_learner_state, learner_sharding def run_experiment(_config: DictConfig) -> float: @@ -658,7 +680,7 @@ def run_experiment(_config: DictConfig) -> float: np_rng = np.random.default_rng(config.arch.seed) # Setup learner. 
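The `shard_map` wiring introduced above replaces the earlier `jax.pmap` call. The following self-contained sketch shows the same pattern under stated assumptions: a toy parameter vector and squared-error loss stand in for the actor-critic networks and PPO losses, and the mesh spans all visible devices rather than the configured learner devices. Parameters use an empty `PartitionSpec` (replicated), data is sharded over the "device" axis, and gradients are averaged inside the mapped step so the replicas stay in sync.

```python
import jax
import jax.numpy as jnp
from jax.experimental import mesh_utils
from jax.experimental.shard_map import shard_map
from jax.sharding import Mesh, NamedSharding, PartitionSpec

# Build a 1-D mesh; the real code restricts this to the configured learner devices.
devices = mesh_utils.create_device_mesh((jax.device_count(),))
mesh = Mesh(devices, axis_names=("device",))
model_spec = PartitionSpec()         # replicated over the mesh
data_spec = PartitionSpec("device")  # sharded on the leading (batch) axis

def step(params, batch):
    # Toy squared-error objective standing in for the PPO update.
    def loss_fn(p):
        preds = batch["x"] @ p
        return jnp.mean((preds - batch["y"]) ** 2)

    loss, grads = jax.value_and_grad(loss_fn)(params)
    # Each device only sees its shard of the batch, so average the gradients
    # (and the loss, for logging) over the "device" axis to keep the
    # replicated parameters identical on every device.
    grads = jax.lax.pmean(grads, axis_name="device")
    loss = jax.lax.pmean(loss, axis_name="device")
    return params - 1e-2 * grads, loss

sharded_step = jax.jit(
    shard_map(
        step,
        mesh=mesh,
        in_specs=(model_spec, data_spec),
        out_specs=(model_spec, model_spec),
    )
)

params = jax.device_put(jnp.zeros((4,)), NamedSharding(mesh, model_spec))
# The leading batch dimension must divide evenly across the mesh devices.
batch = {"x": jnp.ones((8, 4)), "y": jnp.ones((8,))}
params, loss = sharded_step(params, batch)
```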
- learn_step, apply_fns, learner_state = learner_setup( + learn_step, apply_fns, learner_state, learner_sharding = learner_setup( env_factory, (key, actor_net_key, critic_net_key), local_learner_devices, config ) actor_apply_fn, _ = apply_fns @@ -684,7 +706,7 @@ def run_experiment(_config: DictConfig) -> float: ) # Get initial parameters - initial_params = unreplicate(learner_state.params) + initial_params = jax.device_put(learner_state.params, actor_devices[0]) # Get the number of steps consumed by the learner per learner step steps_per_learner_step = config.system.rollout_length * config.arch.actor.num_envs_per_actor @@ -696,7 +718,7 @@ def run_experiment(_config: DictConfig) -> float: pipeline_lifetime = ThreadLifetime() # Now we create the pipeline pipeline = OnPolicyPipeline( - config.arch.pipeline_queue_size, local_learner_devices, pipeline_lifetime + config.arch.pipeline_queue_size, learner_sharding, pipeline_lifetime ) # Start the pipeline pipeline.start() @@ -776,9 +798,9 @@ def run_experiment(_config: DictConfig) -> float: logger.log(train_metrics, t, eval_step, LogEvent.TRAIN) # Evaluate the current model and log the metrics - unreplicated_actor_params = unreplicate(learner_state.params.actor_params) + learner_state_cpu = jax.device_get(learner_state) key, eval_key = jax.random.split(key, 2) - eval_metrics = evaluator(unreplicated_actor_params, eval_key) + eval_metrics = evaluator(learner_state_cpu.params.actor_params, eval_key) logger.log(eval_metrics, t, eval_step, LogEvent.EVAL) episode_return = jnp.mean(eval_metrics["episode_return"]) @@ -787,12 +809,12 @@ def run_experiment(_config: DictConfig) -> float: # Save checkpoint of learner state checkpointer.save( timestep=steps_consumed_per_eval * (eval_step + 1), - unreplicated_learner_state=unreplicate(learner_state), + unreplicated_learner_state=learner_state_cpu, episode_return=episode_return, ) if config.arch.absolute_metric and max_episode_return <= episode_return: - best_params = copy.deepcopy(unreplicated_actor_params) + best_params = copy.deepcopy(learner_state_cpu.params.actor_params) max_episode_return = episode_return evaluator_envs.close() @@ -829,7 +851,12 @@ def run_experiment(_config: DictConfig) -> float: if config.arch.absolute_metric: print(f"{Fore.MAGENTA}{Style.BRIGHT}Measuring absolute metric...{Style.RESET_ALL}") abs_metric_evaluator, abs_metric_evaluator_envs = get_sebulba_eval_fn( - env_factory, eval_act_fn, config, np_rng, evaluator_device, eval_multiplier=10 + env_factory, + eval_act_fn, + config, + np_rng, + evaluator_device, + eval_multiplier=10, ) key, eval_key = jax.random.split(key, 2) eval_metrics = abs_metric_evaluator(best_params, eval_key) diff --git a/stoix/utils/sebulba_utils.py b/stoix/utils/sebulba_utils.py index 6d35ab8..e97a93e 100644 --- a/stoix/utils/sebulba_utils.py +++ b/stoix/utils/sebulba_utils.py @@ -7,10 +7,13 @@ import jax import jax.numpy as jnp from colorama import Fore, Style +from jax.sharding import Sharding from jumanji.types import TimeStep from stoix.base_types import Parameters, StoixTransition +QUEUE_PUT_TIMEOUT = 100 + # Copied from https://github.com/instadeepai/sebulba/blob/main/sebulba/core.py class ThreadLifetime: @@ -33,16 +36,18 @@ class OnPolicyPipeline(threading.Thread): and limit the max number of samples in device memory at one time to avoid OOM issues. 
""" - def __init__(self, max_size: int, learner_devices: List[jax.Device], lifetime: ThreadLifetime): + def __init__(self, max_size: int, learner_sharding: Sharding, lifetime: ThreadLifetime): """ Initializes the pipeline with a maximum size and the devices to shard trajectories across. Args: max_size: The maximum number of trajectories to keep in the pipeline. - learner_devices: The devices to shard trajectories across. + learner_sharding: The sharding used for the learner's update function. + lifetime: A `ThreadLifetime` which is used to stop this thread. """ super().__init__(name="Pipeline") - self.learner_devices = learner_devices + + self.sharding = learner_sharding self.tickets_queue: queue.Queue = queue.Queue() self._queue: queue.Queue = queue.Queue(maxsize=max_size) self.lifetime = lifetime @@ -77,12 +82,7 @@ def put( # [Transition(num_envs)] * rollout_len --> Transition[(rollout_len, num_envs,) traj = self.stack_trajectory(traj) - # Split trajectory on the num envs axis so each learner device gets a valid full rollout - sharded_traj = jax.tree.map(lambda x: self.shard_split_playload(x, axis=1), traj) - - # Timestep[(num_envs, ...), ...] --> - # [(num_envs / num_learner_devices, ...)] * num_learner_devices - sharded_timestep = jax.tree.map(self.shard_split_playload, timestep) + traj, timestep = jax.device_put((traj, timestep), device=self.sharding) # Concatenate metrics - List[Dict[str, List[float]]] --> Dict[str, List[float]] actor_episode_metrics = self.concatenate_metrics(actor_episode_metrics) @@ -99,9 +99,9 @@ def put( # is raised. try: self._queue.put( - (sharded_traj, sharded_timestep, actor_timings_dict, actor_episode_metrics), + (traj, timestep, actor_timings_dict, actor_episode_metrics), block=True, - timeout=180, + timeout=QUEUE_PUT_TIMEOUT, ) except queue.Full: print( @@ -110,7 +110,7 @@ def put( ) finally: with end_condition: - end_condition.notify() # tell we have finish + end_condition.notify() # notify that we have finished def qsize(self) -> int: """Returns the number of trajectories in the pipeline.""" @@ -126,7 +126,10 @@ def get( def stack_trajectory(self, trajectory: List[StoixTransition]) -> StoixTransition: """Stack a list of parallel_env transitions into a single transition of shape [rollout_len, num_envs, ...].""" - return jax.tree_map(lambda *x: jnp.stack(x, axis=0), *trajectory) # type: ignore + return jax.tree_map( # type: ignore + lambda *x: jnp.stack(x, axis=0).swapaxes(0, 1), + *trajectory, + ) @partial(jax.jit, static_argnums=(0,)) def concatenate_metrics( @@ -135,14 +138,13 @@ def concatenate_metrics( """Concatenate a list of actor metrics into a single dictionary.""" return jax.tree_map(lambda *x: jnp.concatenate(x, axis=0), *actor_metrics) # type: ignore - def shard_split_playload(self, payload: Any, axis: int = 0) -> Any: - split_payload = jnp.split(payload, len(self.learner_devices), axis=axis) - return jax.device_put_sharded(split_payload, devices=self.learner_devices) - def clear(self) -> None: """Clear the pipeline.""" while not self._queue.empty(): - self._queue.get() + try: + self._queue.get(block=False) + except queue.Empty: + break class ParamsSource(threading.Thread): @@ -164,7 +166,7 @@ def run(self) -> None: while not self.lifetime.should_stop(): try: waiting = self.new_value.get(block=True, timeout=1) - self.value = jax.device_put(jax.block_until_ready(waiting), self.device) + self.value = jax.device_put(waiting, self.device) except queue.Empty: continue From 1dd52672d0b14d4ce103d28708ae651ad951e3c8 Mon Sep 17 00:00:00 2001 From: 
EdanToledo Date: Mon, 4 Nov 2024 11:10:20 +0000 Subject: [PATCH 2/4] chore: add comments --- stoix/systems/ppo/sebulba/ff_ppo.py | 31 +++++++++++++++++++---------- 1 file changed, 21 insertions(+), 10 deletions(-) diff --git a/stoix/systems/ppo/sebulba/ff_ppo.py b/stoix/systems/ppo/sebulba/ff_ppo.py index 3e5b0cb..77e689e 100644 --- a/stoix/systems/ppo/sebulba/ff_ppo.py +++ b/stoix/systems/ppo/sebulba/ff_ppo.py @@ -54,10 +54,9 @@ def get_act_fn( - apply_fns: Tuple[ActorApply, CriticApply], + apply_fns: Tuple[ActorApply, CriticApply] ) -> Callable[ - [ActorCriticParams, Observation, chex.PRNGKey], - Tuple[chex.Array, chex.Array, chex.Array], + [ActorCriticParams, Observation, chex.PRNGKey], Tuple[chex.Array, chex.Array, chex.Array] ]: """Get the act function that is used by the actor threads.""" actor_apply_fn, critic_apply_fn = apply_fns @@ -595,15 +594,27 @@ def learner_setup( update_fns = (actor_optim.update, critic_optim.update) # Define how data is distributed with `shard_map` - devices = mesh_utils.create_device_mesh((len(learner_devices),), devices=learner_devices) + # First create the device mesh + num_learner_devices = len(learner_devices) + devices = mesh_utils.create_device_mesh((num_learner_devices,), devices=learner_devices) mesh = Mesh(devices, axis_names=("device",)) - model_spec = PartitionSpec() # replicate the model - data_spec = PartitionSpec("device") # shard the data - learner_sharding = NamedSharding(mesh, model_spec) # used in the pipeline - # Defines how the learner state is sharded: params, opt and key = replicated, timestep = sharded + # Then create partition specs for a) the model items such as neural network parameters, + # optimizer state, and the random keys and b) the actual training data. + # For a) we replicate the model over the devices + model_spec = PartitionSpec() + # For b) we shard the data over the devices (each device will have a slice of training data) + data_spec = PartitionSpec("device") + # Using these specs, we create the learner state spec for the respective items. + # The learner state is sharded with: params, opt and key = replicated, timestep = sharded learn_state_spec = CoreLearnerState(model_spec, model_spec, model_spec, data_spec) + # Lastly, we create the learner sharding which is used in the pipeline. + # This uses the model spec so it simply replicates the data over the learner devices. + learner_sharding = NamedSharding(mesh, model_spec) + # We now construct the learner step learn_step = get_learner_step_fn(apply_fns, update_fns, config) + # and compile it giving it the input specs and output specs of how it is + # sharded over the devices learn_step = jax.jit( shard_map( learn_step, @@ -627,7 +638,7 @@ def learner_setup( # Define params to be replicated across learner devices. opt_states = ActorCriticOptStates(actor_opt_state, critic_opt_state) - # Shard across learner devices. + # Shard across learner devices i.e. replicate. key, step_key = jax.random.split(key) params, opt_states, step_keys = jax.device_put((params, opt_states, step_key), learner_sharding) @@ -705,7 +716,7 @@ def run_experiment(_config: DictConfig) -> float: **config.logger.checkpointing.save_args, # Checkpoint args ) - # Get initial parameters + # Get initial parameters and put them on the first actor device. 
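The comments added in this patch spell out the sharding layout; the snippet below (toy arrays, illustrative only, not the real `ActorCriticParams`) shows what that layout means in practice. Placing a pytree with `NamedSharding(mesh, PartitionSpec())` replicates every leaf across the learner devices, after which copying it to a single actor device, as the next line does, is an ordinary `jax.device_put` with no "unreplicate" step.

```python
import jax
import jax.numpy as jnp
from jax.experimental import mesh_utils
from jax.sharding import Mesh, NamedSharding, PartitionSpec

# Toy stand-ins for the learner parameters and the device mesh from the config.
params = {"w": jnp.ones((4, 4)), "b": jnp.zeros((4,))}

devices = mesh_utils.create_device_mesh((jax.device_count(),))
mesh = Mesh(devices, axis_names=("device",))
replicated = NamedSharding(mesh, PartitionSpec())  # every device holds a full copy

learner_params = jax.device_put(params, replicated)
# Hand a full copy to one (actor) device; contrast with the old
# flax.jax_utils.replicate / unreplicate workflow, which added and stripped
# a leading device axis.
actor_params = jax.device_put(learner_params, jax.devices()[0])
```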
initial_params = jax.device_put(learner_state.params, actor_devices[0]) # Get the number of steps consumed by the learner per learner step From 0e049878e39192faa4fee302fffb5afd4624c3cb Mon Sep 17 00:00:00 2001 From: EdanToledo Date: Mon, 4 Nov 2024 12:11:01 +0000 Subject: [PATCH 3/4] chore: change stack code --- stoix/utils/sebulba_utils.py | 7 ++++--- 1 file changed, 4 insertions(+), 3 deletions(-) diff --git a/stoix/utils/sebulba_utils.py b/stoix/utils/sebulba_utils.py index e97a93e..bafb974 100644 --- a/stoix/utils/sebulba_utils.py +++ b/stoix/utils/sebulba_utils.py @@ -80,7 +80,7 @@ def put( self.tickets_queue.put((start_condition, end_condition)) start_condition.wait() # wait to be allowed to start - # [Transition(num_envs)] * rollout_len --> Transition[(rollout_len, num_envs,) + # [Transition(num_envs)] * rollout_len --> Transition[(num_envs, rollout_len,) traj = self.stack_trajectory(traj) traj, timestep = jax.device_put((traj, timestep), device=self.sharding) @@ -125,9 +125,10 @@ def get( @partial(jax.jit, static_argnums=(0,)) def stack_trajectory(self, trajectory: List[StoixTransition]) -> StoixTransition: """Stack a list of parallel_env transitions into a single - transition of shape [rollout_len, num_envs, ...].""" + transition of shape [num_envs, rollout_len, ...] i.e. + stack to create the time axis.""" return jax.tree_map( # type: ignore - lambda *x: jnp.stack(x, axis=0).swapaxes(0, 1), + lambda *x: jnp.stack(x, axis=1), *trajectory, ) From ef08c9a6b0e8bd822aab202e243376e7062ffdc8 Mon Sep 17 00:00:00 2001 From: EdanToledo Date: Mon, 4 Nov 2024 13:30:37 +0000 Subject: [PATCH 4/4] chore: slight chhanges to eval learner state device placement --- stoix/configs/arch/sebulba.yaml | 10 +++++----- stoix/systems/ppo/sebulba/ff_ppo.py | 8 ++++---- 2 files changed, 9 insertions(+), 9 deletions(-) diff --git a/stoix/configs/arch/sebulba.yaml b/stoix/configs/arch/sebulba.yaml index 78f55a1..a7faa17 100644 --- a/stoix/configs/arch/sebulba.yaml +++ b/stoix/configs/arch/sebulba.yaml @@ -2,19 +2,19 @@ architecture_name : sebulba # --- Training --- seed: 42 # RNG seed. -total_num_envs: 1024 # Total Number of vectorised environments across all actors. Needs to be divisible by the number of actor devices and actors per device. -total_timesteps: 1e7 # Set the total environment steps. +total_num_envs: 128 # Total Number of vectorised environments across all actors. Needs to be divisible by the number of actor devices and actors per device. +total_timesteps: 1e5 # Set the total environment steps. # If unspecified, it's derived from num_updates; otherwise, num_updates adjusts based on this value. num_updates: ~ # Number of updates # Define the number of actors per device and which devices to use. actor: - device_ids: [0,1] # Define which devices to use for the actors. - actor_per_device: 2 # number of different threads per actor device. + device_ids: [0] # Define which devices to use for the actors. + actor_per_device: 1 # number of different threads per actor device. # Define which devices to use for the learner. learner: - device_ids: [2,3] # Define which devices to use for the learner. + device_ids: [1] # Define which devices to use for the learner. # Size of the queue for the pipeline where actors push data and the learner pulls data. 
pipeline_queue_size: 10 diff --git a/stoix/systems/ppo/sebulba/ff_ppo.py b/stoix/systems/ppo/sebulba/ff_ppo.py index 77e689e..868b8d0 100644 --- a/stoix/systems/ppo/sebulba/ff_ppo.py +++ b/stoix/systems/ppo/sebulba/ff_ppo.py @@ -809,9 +809,9 @@ def run_experiment(_config: DictConfig) -> float: logger.log(train_metrics, t, eval_step, LogEvent.TRAIN) # Evaluate the current model and log the metrics - learner_state_cpu = jax.device_get(learner_state) + eval_learner_state = jax.device_put(learner_state, evaluator_device) key, eval_key = jax.random.split(key, 2) - eval_metrics = evaluator(learner_state_cpu.params.actor_params, eval_key) + eval_metrics = evaluator(eval_learner_state.params.actor_params, eval_key) logger.log(eval_metrics, t, eval_step, LogEvent.EVAL) episode_return = jnp.mean(eval_metrics["episode_return"]) @@ -820,12 +820,12 @@ def run_experiment(_config: DictConfig) -> float: # Save checkpoint of learner state checkpointer.save( timestep=steps_consumed_per_eval * (eval_step + 1), - unreplicated_learner_state=learner_state_cpu, + unreplicated_learner_state=jax.device_get(learner_state), episode_return=episode_return, ) if config.arch.absolute_metric and max_episode_return <= episode_return: - best_params = copy.deepcopy(learner_state_cpu.params.actor_params) + best_params = copy.deepcopy(eval_learner_state.params.actor_params) max_episode_return = episode_return evaluator_envs.close()
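A small check of the `stack_trajectory` change in PATCH 3/4 (illustrative shapes, not Stoix code): stacking the per-step transitions on `axis=1` builds the `[num_envs, rollout_len, ...]` layout directly, matching the previous `jnp.stack(..., axis=0).swapaxes(0, 1)` while avoiding the extra transpose.

```python
import jax.numpy as jnp

rollout_len, num_envs, obs_dim = 5, 3, 2
# One array per environment step, each shaped [num_envs, obs_dim].
steps = [jnp.full((num_envs, obs_dim), t, dtype=jnp.float32) for t in range(rollout_len)]

stacked = jnp.stack(steps, axis=1)
assert stacked.shape == (num_envs, rollout_len, obs_dim)
# Identical to the old stack-then-swap formulation.
assert (stacked == jnp.stack(steps, axis=0).swapaxes(0, 1)).all()
```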