Fix: Sebulba PPO Metrics (#108)
fix: sebulba ppo metrics issues
EdanToledo authored Aug 27, 2024
1 parent b3b7c2b commit 5284124
Showing 4 changed files with 163 additions and 104 deletions.
2 changes: 1 addition & 1 deletion stoix/evaluator.py
@@ -371,7 +371,7 @@ def get_sebulba_eval_fn(
# we will run all episodes in parallel.
# Otherwise we will run `num_envs` parallel envs and loop enough times
# so that we do at least `eval_episodes` number of episodes.
- n_parallel_envs = int(min(eval_episodes, config.arch.num_envs))
+ n_parallel_envs = int(min(eval_episodes, config.arch.total_num_envs))
episode_loops = math.ceil(eval_episodes / n_parallel_envs)
envs = env_factory(n_parallel_envs)
cpu = jax.devices("cpu")[0]
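
As a quick illustration of the arithmetic above (values illustrative, not from the commit): with the fix, the evaluator sizes its env pool from arch.total_num_envs rather than the per-actor arch.num_envs, then loops enough times to cover the requested episodes.

    import math

    # Illustrative numbers only: 128 requested episodes, 48 total envs.
    eval_episodes, total_num_envs = 128, 48
    n_parallel_envs = int(min(eval_episodes, total_num_envs))   # 48
    episode_loops = math.ceil(eval_episodes / n_parallel_envs)  # ceil(128/48) = 3
    # 48 envs * 3 loops = 144 episodes, i.e. at least the 128 requested.
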
145 changes: 66 additions & 79 deletions stoix/systems/ppo/sebulba/ff_ppo.py
@@ -1,6 +1,7 @@
import copy
import queue
import threading
+ import time
import warnings
from collections import defaultdict
from queue import Queue
@@ -91,7 +92,7 @@ def get_rollout_fn(
move_to_device = lambda tree: jax.tree.map(lambda x: jax.device_put(x, actor_device), tree)
split_key_fn = jax.jit(jax.random.split, device=actor_device)
# Build the environments
- envs = env_factory(config.arch.actor.envs_per_actor)
+ envs = env_factory(config.arch.actor.num_envs_per_actor)

# Create the rollout function
def rollout_fn(rng_key: chex.PRNGKey) -> None:
@@ -348,7 +349,7 @@ def _critic_loss_fn(

# SHUFFLE MINIBATCHES
# Since we shard the envs per actor across the devices
- envs_per_batch = config.arch.actor.envs_per_actor // len(config.arch.learner.device_ids)
+ envs_per_batch = config.arch.actor.num_envs_per_actor // config.num_learner_devices
batch_size = config.system.rollout_length * envs_per_batch
permutation = jax.random.permutation(shuffle_key, batch_size)
batch = (traj_batch, advantages, targets)
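
A minimal sketch of this shuffle-and-split pattern, as performed on one learner device's share of the envs (the names and shapes here are illustrative, not Stoix's actual batch layout):

    import jax
    import jax.numpy as jnp

    # Illustrative shapes: one learner device's share of the batch.
    rollout_length, envs_per_batch, num_minibatches = 16, 8, 4
    batch_size = rollout_length * envs_per_batch  # 128 rows

    key = jax.random.PRNGKey(0)
    data = jnp.arange(batch_size * 3).reshape(batch_size, 3)  # stand-in for (traj, adv, targets)

    permutation = jax.random.permutation(key, batch_size)
    shuffled = jnp.take(data, permutation, axis=0)
    # Split the shuffled batch into equal minibatches for SGD.
    minibatches = shuffled.reshape(num_minibatches, batch_size // num_minibatches, 3)
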
@@ -429,38 +430,39 @@ def learner_rollout(learner_state: CoreLearnerState) -> None:
actor_timings: List[Dict] = []
learner_timings: Dict[str, List[float]] = defaultdict(list)
q_sizes: List[int] = []
- # Loop for the number of updates per evaluation
- for _ in range(config.arch.num_updates_per_eval):
-     # Get the trajectory batch from the pipeline
-     # This is blocking so it will wait until the pipeline has data.
-     with RecordTimeTo(learner_timings["rollout_get_time"]):
-         (
-             traj_batch,
-             timestep,
-             actor_times,
-             episode_metrics,
-         ) = pipeline.get(  # type: ignore
-             block=True
-         )
-     # We then replace the timestep in the learner state with the latest timestep
-     # This means the learner has access to the entire trajectory as well as
-     # an additional timestep which it can use to bootstrap.
-     learner_state = learner_state._replace(timestep=timestep)
-     # We then call the update function to update the networks
-     with RecordTimeTo(learner_timings["learning_time"]):
-         learner_state, train_metrics = learn_step(learner_state, traj_batch)
-
-     # We store the metrics and timings for this update
-     metrics.append((episode_metrics, train_metrics))
-     actor_timings.append(actor_times)
-     q_sizes.append(pipeline.qsize())
-
-     # After the update we need to update the params sources with the new params
-     unreplicated_params = unreplicate(learner_state.params)
-     # We loop over all params sources and update them with the new params
-     # This is so that all the actors can get the latest params
-     for source in params_sources:
-         source.update(unreplicated_params)
+ with RecordTimeTo(learner_timings["learner_time_per_eval"]):
+     # Loop for the number of updates per evaluation
+     for _ in range(config.arch.num_updates_per_eval):
+         # Get the trajectory batch from the pipeline
+         # This is blocking so it will wait until the pipeline has data.
+         with RecordTimeTo(learner_timings["rollout_get_time"]):
+             (
+                 traj_batch,
+                 timestep,
+                 actor_times,
+                 episode_metrics,
+             ) = pipeline.get(  # type: ignore
+                 block=True
+             )
+         # We then replace the timestep in the learner state with the latest timestep
+         # This means the learner has access to the entire trajectory as well as
+         # an additional timestep which it can use to bootstrap.
+         learner_state = learner_state._replace(timestep=timestep)
+         # We then call the update function to update the networks
+         with RecordTimeTo(learner_timings["learner_step_time"]):
+             learner_state, train_metrics = learn_step(learner_state, traj_batch)
+
+         # We store the metrics and timings for this update
+         metrics.append((episode_metrics, train_metrics))
+         actor_timings.append(actor_times)
+         q_sizes.append(pipeline.qsize())
+
+         # After the update we need to update the params sources with the new params
+         unreplicated_params = unreplicate(learner_state.params)
+         # We loop over all params sources and update them with the new params
+         # This is so that all the actors can get the latest params
+         for source in params_sources:
+             source.update(unreplicated_params)

# We then pass all the environment metrics, training metrics, current learner state
# and timings to the evaluation queue. This is so the evaluator correctly evaluates
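
RecordTimeTo itself is not shown in this diff; a minimal sketch consistent with how it is used above (appending elapsed seconds to the supplied list) might look like this — the real Stoix implementation may differ:

    import time
    from typing import List

    class RecordTimeTo:
        # Sketch inferred from usage: appends elapsed wall-clock seconds
        # to the list it is given, e.g. learner_timings["learner_step_time"].
        def __init__(self, to: List[float]):
            self.to = to

        def __enter__(self) -> "RecordTimeTo":
            self.start = time.monotonic()
            return self

        def __exit__(self, *exc) -> None:
            self.to.append(time.monotonic() - self.start)
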
@@ -617,18 +619,6 @@ def run_experiment(_config: DictConfig) -> float:
"""Runs experiment."""
config = copy.deepcopy(_config)

- # Perform some checks on the config
- # This additionally calculates certain
- # values based on the config
- config = check_total_timesteps(config)
-
- assert (
-     config.arch.num_updates > config.arch.num_evaluation
- ), "Total number of updates must be greater than the number of evaluations."
-
- # Calculate the number of updates per evaluation
- config.arch.num_updates_per_eval = int(config.arch.num_updates // config.arch.num_evaluation)
-
# Get the learner and actor devices
local_devices = jax.local_devices()
global_devices = jax.devices()
@@ -647,28 +637,13 @@ def run_experiment(_config: DictConfig) -> float:
print(f"{Fore.MAGENTA}{Style.BRIGHT}Global devices: {global_devices}{Style.RESET_ALL}")
# Set the number of learning and acting devices in the config
# useful for keeping track of experimental setup
- config.num_learning_devices = len(local_learner_devices)
- config.num_actor_actor_devices = len(actor_devices)
-
- # Calculate the number of envs per actor
- assert (
-     config.arch.num_envs == config.arch.total_num_envs
- ), "arch.num_envs must equal arch.total_num_envs for Sebulba architectures"
- # We first simply take the total number of envs and divide by the number of actor devices
- # to get the number of envs per actor device
- num_envs_per_actor_device = config.arch.total_num_envs // len(actor_devices)
- # We then divide this by the number of actors per device to get the number of envs per actor
- num_envs_per_actor = int(num_envs_per_actor_device // config.arch.actor.actor_per_device)
- config.arch.actor.envs_per_actor = num_envs_per_actor
-
- # We then perform a simple check to ensure that the number of envs per actor is
- # divisible by the number of learner devices. This is because we shard the envs
- # per actor across the learner devices. This check is mainly relevant for on-policy
- # algorithms.
- assert num_envs_per_actor % len(local_learner_devices) == 0, (
-     f"The number of envs per actor must be divisible by the number of learner devices. "
-     f"Got {num_envs_per_actor} envs per actor and {len(local_learner_devices)} learner devices"
- )
+ config.num_learner_devices = len(local_learner_devices)
+ config.num_actor_devices = len(actor_devices)
+
+ # Perform some checks on the config
+ # This additionally calculates certain
+ # values based on the config
+ config = check_total_timesteps(config)

# Create the environment factory.
env_factory = environments.make_factory(config)
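
The env-sharding arithmetic that the removed block performed (and that check_total_timesteps presumably now handles, since it runs after the device counts are set) can be illustrated with made-up values:

    # Illustrative values only.
    total_num_envs = 64       # config.arch.total_num_envs
    num_actor_devices = 2     # len(actor_devices)
    actor_per_device = 2      # config.arch.actor.actor_per_device
    num_learner_devices = 4   # len(local_learner_devices)

    num_envs_per_actor_device = total_num_envs // num_actor_devices     # 32
    num_envs_per_actor = num_envs_per_actor_device // actor_per_device  # 16
    # Each actor's envs are sharded across the learner devices,
    # so the per-actor count must divide evenly:
    assert num_envs_per_actor % num_learner_devices == 0
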
@@ -711,10 +686,10 @@ def run_experiment(_config: DictConfig) -> float:
# Get initial parameters
initial_params = unreplicate(learner_state.params)

- # Get the number of steps per rollout
- steps_per_rollout = (
-     config.system.rollout_length * config.arch.total_num_envs * config.arch.num_updates_per_eval
- )
+ # Get the number of steps consumed by the learner per learner step
+ steps_per_learner_step = config.system.rollout_length * config.arch.actor.num_envs_per_actor
+ # Get the number of steps consumed by the learner per evaluation
+ steps_consumed_per_eval = steps_per_learner_step * config.arch.num_updates_per_eval
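
Worked example of the new accounting with illustrative values (the old steps_per_rollout multiplied by arch.total_num_envs, which appears to over-count when each pipeline batch carries only one actor's envs):

    # Illustrative values only.
    rollout_length = 128       # config.system.rollout_length
    num_envs_per_actor = 16    # config.arch.actor.num_envs_per_actor
    num_updates_per_eval = 50  # config.arch.num_updates_per_eval

    steps_per_learner_step = rollout_length * num_envs_per_actor             # 2048
    steps_consumed_per_eval = steps_per_learner_step * num_updates_per_eval  # 102400
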

# Creating the pipeline
# First we create the lifetime so we can stop the pipeline when we want
@@ -743,7 +718,7 @@ def run_experiment(_config: DictConfig) -> float:
for i in range(config.arch.actor.actor_per_device):
key, actors_key = jax.random.split(key)
seeds = np_rng.integers(
-     np.iinfo(np.int32).max, size=config.arch.actor.envs_per_actor
+     np.iinfo(np.int32).max, size=config.arch.actor.num_envs_per_actor
).tolist()
actor_thread = get_actor_thread(
env_factory,
@@ -779,17 +754,25 @@ def run_experiment(_config: DictConfig) -> float:
episode_metrics, train_metrics, learner_state, timings_dict = eval_queue.get(block=True)

# Log the metrics and timings
- t = int(steps_per_rollout * (eval_step + 1))
+ t = int(steps_consumed_per_eval * (eval_step + 1))
timings_dict["timestep"] = t
logger.log(timings_dict, t, eval_step, LogEvent.MISC)

episode_metrics, ep_completed = get_final_step_metrics(episode_metrics)
# Calculate steps per second for actor
+ # Here we use the number of steps pushed to the pipeline each time
+ # and the average time it takes to do a single rollout across
+ # all the updates per evaluation
episode_metrics["steps_per_second"] = (
-     steps_per_rollout / timings_dict["single_rollout_time"]
+     steps_per_learner_step / timings_dict["single_rollout_time"]
)
if ep_completed:
logger.log(episode_metrics, t, eval_step, LogEvent.ACT)

+ train_metrics["learner_step"] = (eval_step + 1) * config.arch.num_updates_per_eval
+ train_metrics["sgd_steps_per_second"] = (config.arch.num_updates_per_eval) / timings_dict[
+     "learner_time_per_eval"
+ ]
logger.log(train_metrics, t, eval_step, LogEvent.TRAIN)
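
Illustrative arithmetic for the two throughput metrics logged above (the timing values are made up, and it is assumed single_rollout_time and learner_time_per_eval arrive as scalars after aggregation):

    # Illustrative values only.
    steps_per_learner_step = 2048
    num_updates_per_eval = 50
    single_rollout_time = 0.5     # avg seconds per rollout pushed to the pipeline
    learner_time_per_eval = 20.0  # seconds spent in the update loop this window

    steps_per_second = steps_per_learner_step / single_rollout_time      # 4096.0 env steps/s
    sgd_steps_per_second = num_updates_per_eval / learner_time_per_eval  # 2.5 SGD steps/s
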

# Evaluate the current model and log the metrics
@@ -803,7 +786,7 @@ def run_experiment(_config: DictConfig) -> float:
if save_checkpoint:
# Save checkpoint of learner state
checkpointer.save(
-     timestep=steps_per_rollout * (eval_step + 1),
+     timestep=steps_consumed_per_eval * (eval_step + 1),
unreplicated_learner_state=unreplicate(learner_state),
episode_return=episode_return,
)
@@ -824,6 +807,7 @@ def run_experiment(_config: DictConfig) -> float:

# Now we stop the actors and params sources
print(f"{Fore.MAGENTA}{Style.BRIGHT}Closing actors...{Style.RESET_ALL}")
+ pipeline.clear()
for actor in actor_threads:
# We clear the pipeline before stopping each actor thread
# since actors can be blocked on the pipeline
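
The clear-before-stop ordering matters: a producer blocked on a full queue never gets to check its stop flag. A minimal sketch of the hazard with a plain queue.Queue (names illustrative):

    import queue

    pipe: queue.Queue = queue.Queue(maxsize=1)
    pipe.put("rollout-0")  # the pipeline is now full

    # An actor calling pipe.put("rollout-1") here would block indefinitely,
    # never reaching the check of its stop flag. Draining first unblocks it:
    while not pipe.empty():
        pipe.get_nowait()
    # Now a blocked producer can complete its put() and observe the stop signal.
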
@@ -850,7 +834,7 @@ def run_experiment(_config: DictConfig) -> float:
key, eval_key = jax.random.split(key, 2)
eval_metrics = abs_metric_evaluator(best_params, eval_key)

- t = int(steps_per_rollout * (eval_step + 1))
+ t = int(steps_consumed_per_eval * (eval_step + 1))
logger.log(eval_metrics, t, eval_step, LogEvent.ABSOLUTE)
abs_metric_evaluator_envs.close()

@@ -871,9 +855,12 @@ def hydra_entry_point(cfg: DictConfig) -> float:
OmegaConf.set_struct(cfg, False)

# Run experiment.
+ start = time.monotonic()
eval_performance = run_experiment(cfg)

- print(f"{Fore.CYAN}{Style.BRIGHT}PPO experiment completed{Style.RESET_ALL}")
+ end = time.monotonic()
+ print(
+     f"{Fore.CYAN}{Style.BRIGHT}PPO experiment completed in {end - start:.2f}s.{Style.RESET_ALL}"
+ )
return eval_performance


