[RLlib; Offline RL] Fix small memory leak in `OfflineSingleAgentEnvRunner`. (#48309)
simonsays1980 authored Dec 2, 2024
1 parent f89aaf9 commit 0e9f3d1
Showing 5 changed files with 217 additions and 72 deletions.
11 changes: 11 additions & 0 deletions rllib/BUILD
@@ -2929,6 +2929,17 @@ py_test(
# subdirectory: offline_rl/
# ....................................

# Runs into scheduling problems in CI tests. Works locally and
# on GCP cloud.
# py_test(
# name = "examples/offline_rl/cartpole_recording",
# main = "examples/offline_rl/cartpole_recording.py",
# tags = ["team:rllib", "examples", "exclusive"],
# size = "large",
# srcs = ["examples/offline_rl/cartpole_recording.py"],
# args = ["--enable-new-api-stack", "--as-test", "--framework=torch", "--num-cpus=12"],
# )

py_test(
name = "examples/offline_rl/train_w_bc_finetune_w_ppo",
main = "examples/offline_rl/train_w_bc_finetune_w_ppo.py",
13 changes: 13 additions & 0 deletions rllib/algorithms/algorithm_config.py
@@ -468,6 +468,7 @@ def __init__(self, algo_class: Optional[type] = None):
self.output_compress_columns = [Columns.OBS, Columns.NEXT_OBS]
self.output_max_file_size = 64 * 1024 * 1024
self.output_max_rows_per_file = None
self.output_write_remaining_data = False
self.output_write_method = "write_parquet"
self.output_write_method_kwargs = {}
self.output_filesystem = None
@@ -2579,6 +2580,7 @@ def offline_data(
output_compress_columns: Optional[List[str]] = NotProvided,
output_max_file_size: Optional[float] = NotProvided,
output_max_rows_per_file: Optional[int] = NotProvided,
output_write_remaining_data: Optional[bool] = NotProvided,
output_write_method: Optional[str] = NotProvided,
output_write_method_kwargs: Optional[Dict] = NotProvided,
output_filesystem: Optional[str] = NotProvided,
@@ -2748,6 +2750,15 @@ def offline_data(
to a new file.
output_max_rows_per_file: Max output row numbers before rolling over to a
new file.
output_write_remaining_data: Determines whether any remaining data in the
recording buffers should be stored to disk. It is only applicable if
`output_max_rows_per_file` is defined. When sampling data, it is
buffered until the threshold specified by `output_max_rows_per_file`
is reached. Only complete multiples of `output_max_rows_per_file` are
written to disk, while any leftover data remains in the buffers. If a
recording session is stopped, residual data may still reside in these
buffers. Setting `output_write_remaining_data` to `True` ensures this
data is flushed to disk. By default, this attribute is set to `False`.
output_write_method: Write method for the `ray.data.Dataset` to write the
offline data to `output`. The default is `write_parquet` for Parquet
files. See https://docs.ray.io/en/latest/data/api/input_output.html for
@@ -2855,6 +2866,8 @@ def offline_data(
self.output_max_file_size = output_max_file_size
if output_max_rows_per_file is not NotProvided:
self.output_max_rows_per_file = output_max_rows_per_file
if output_write_remaining_data is not NotProvided:
self.output_write_remaining_data = output_write_remaining_data
if output_write_method is not NotProvided:
self.output_write_method = output_write_method
if output_write_method_kwargs is not NotProvided:
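To make the interaction of the two write options concrete, here is a minimal, hypothetical config sketch (the output path is a placeholder; only the options added or referenced in this diff are shown):

from ray.rllib.algorithms.ppo import PPOConfig

config = (
    PPOConfig()
    .environment("CartPole-v1")
    .offline_data(
        # Placeholder output path for the recorded data.
        output="local:///tmp/my-recording/",
        # Buffer sampled rows and write them in complete blocks of 1,000.
        output_max_rows_per_file=1000,
        # Flush whatever is left in the buffer (fewer than 1,000 rows)
        # when the recording `EnvRunner` stops, instead of dropping it.
        output_write_remaining_data=True,
    )
)

With these settings, each written file holds exactly 1,000 rows, and any remainder is written as a final, smaller file when the `EnvRunner` is stopped.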
163 changes: 163 additions & 0 deletions rllib/examples/offline_rl/cartpole_recording.py
@@ -0,0 +1,163 @@
"""Example showing how to record expert data from a trained policy.
This example:
- demonstrates how you can train a single-agent expert PPO Policy (RLModule)
and checkpoint it.
- shows how you can then record expert data from the trained PPO Policy to
disk during evaluation.
How to run this script
----------------------
`python [script file name].py --checkpoint-at-end`
For debugging, use the following additional command line options
`--no-tune --num-env-runners=0`
which should allow you to set breakpoints anywhere in the RLlib code and
have the execution stop there for inspection and debugging.
For logging to your WandB account, use:
`--wandb-key=[your WandB API key] --wandb-project=[some project name]
--wandb-run-name=[optional: WandB run name (within the defined project)]`
Results to expect
-----------------
In the console output you can see that the episode return of 350.0 is reached
before the timestep stop criteria is touched. Afterwards evaluation starts and
runs 10 iterations while recording the data. The number of recorded experiences
might differ from evaluation run to evaluation run because evaluation
`EnvRunner`s sample episodes while recording timesteps and episodes contain
usually different numbers of timesteps. Note, this is different when recording
episodes - in this case each row is one episode.
+-----------------------------+------------+----------------------+
| Trial name | status | loc |
| | | |
|-----------------------------+------------+----------------------+
| PPO_CartPole-v1_df83f_00000 | TERMINATED | 192.168.0.119:233661 |
+-----------------------------+------------+----------------------+
+--------+------------------+------------------------+------------------------+
| iter | total time (s) | num_training_step_ca | num_env_steps_sample |
| | | lls_per_iteration | d_lifetime |
+--------+------------------+------------------------+------------------------|
| 21 | 25.9162 | 1 | 84000 |
+--------+------------------+------------------------+------------------------+
...
Number of experiences recorded: 26644
"""

import ray

from ray.rllib.algorithms.ppo import PPOConfig
from ray.rllib.core import COMPONENT_RL_MODULE
from ray.rllib.core.columns import Columns
from ray.rllib.core.rl_module.default_model_config import DefaultModelConfig
from ray.rllib.utils.metrics import (
ENV_RUNNER_RESULTS,
EPISODE_RETURN_MEAN,
EVALUATION_RESULTS,
NUM_ENV_STEPS_SAMPLED_LIFETIME,
)
from ray.rllib.utils.test_utils import add_rllib_example_script_args

parser = add_rllib_example_script_args(
default_timesteps=200000,
default_reward=350.0,
)
parser.set_defaults(checkpoint_at_end=True, max_concurrent_trials=1)
# Use `parser` to add your own custom command line options to this script
# and (if needed) use their values to set up `config` below.
args = parser.parse_args()

config = (
PPOConfig()
.env_runners(
num_env_runners=5,
)
.environment("CartPole-v1")
.rl_module(
model_config=DefaultModelConfig(
fcnet_hiddens=[32],
fcnet_activation="linear",
vf_share_layers=True,
),
)
.training(
lr=0.0003,
num_epochs=6,
vf_loss_coeff=0.01,
)
.evaluation(
evaluation_num_env_runners=1,
evaluation_interval=1,
evaluation_parallel_to_training=True,
evaluation_config=PPOConfig.overrides(explore=False),
)
)

stop = {
f"{NUM_ENV_STEPS_SAMPLED_LIFETIME}": args.stop_timesteps,
f"{EVALUATION_RESULTS}/{ENV_RUNNER_RESULTS}/{EPISODE_RETURN_MEAN}": (
args.stop_reward
),
}


if __name__ == "__main__":
from ray.rllib.utils.test_utils import run_rllib_example_script_experiment

results = run_rllib_example_script_experiment(config, args, stop=stop)

# Store the best checkpoint for recording.
best_checkpoint = results.get_best_result(
metric=f"{EVALUATION_RESULTS}/{ENV_RUNNER_RESULTS}/{EPISODE_RETURN_MEAN}",
mode="max",
).checkpoint.path

# Configure the algorithm for offline recording.
config.offline_data(
output="local:///tmp/cartpole/",
# Store columnar (tabular) data.
output_write_episodes=False,
# Each file should hold 1,000 rows.
output_max_rows_per_file=1000,
output_write_remaining_data=True,
# LZ4-compress the columns 'obs' and 'actions' to save disk space
# and increase performance. Note that this means you have to use
# `input_compress_columns` in the same way when using the data for
# training in `RLlib`.
output_compress_columns=[Columns.OBS, Columns.ACTIONS],
)
# Change the evaluation settings to sample exactly 50 episodes
# per evaluation iteration and increase the number of evaluation
# env-runners to 5.
config.evaluation(
evaluation_num_env_runners=5,
evaluation_duration=50,
evaluation_duration_unit="episodes",
evaluation_interval=1,
evaluation_parallel_to_training=False,
evaluation_config=PPOConfig.overrides(explore=False),
)

# Build the algorithm for evaluation.
algo = config.build()
# Load the checkpoint stored above.
algo.restore_from_path(
best_checkpoint,
component=COMPONENT_RL_MODULE,
)

# Evaluate over 10 iterations and record the data.
for i in range(10):
print(f"Iteration: {i + 1}:\n")
res = algo.evaluate()
print(res)

# Stop the algorithm.
algo.stop()

# Check the number of rows in the dataset.
ds = ray.data.read_parquet("local:///tmp/cartpole")
print(f"Number of experiences recorded: {ds.count()}")
41 changes: 30 additions & 11 deletions rllib/offline/offline_env_runner.py
@@ -29,6 +29,13 @@ def __init__(self, config: AlgorithmConfig, **kwargs):
# Initialize the parent.
super().__init__(config, **kwargs)

# Get the data context for this `EnvRunner`.
data_context = ray.data.DataContext.get_current()
# Limit the resources for Ray Data to the CPUs given to this `EnvRunner`.
data_context.execution_options.resource_limits.cpu = (
config.num_cpus_per_env_runner
)

# Set the output write method.
self.output_write_method = self.config.output_write_method
self.output_write_method_kwargs = self.config.output_write_method_kwargs
@@ -92,6 +99,10 @@ def __init__(self, config: AlgorithmConfig, **kwargs):
else:
self.write_data_this_iter = True

# Whether the remaining data should be stored. Note, this is only
# relevant if `output_max_rows_per_file` is defined.
self.write_remaining_data = self.config.output_write_remaining_data

# Counts how often `sample` is called to define the output path for
# each file.
self._sample_counter = 0
@@ -155,15 +166,18 @@ def sample(
if self.output_max_rows_per_file:
# Reset the event.
self.write_data_this_iter = False

# Extract the number of samples to be written to disk this iteration.
samples_to_write = self._samples[: self.output_max_rows_per_file]
# Reset the buffer to the remaining data. This only makes sense, if
# `rollout_fragment_length` is smaller `output_max_rows_per_file` or
# a 2 x `output_max_rows_per_file`.
# TODO (simon): Find a better way to write these data.
self._samples = self._samples[self.output_max_rows_per_file :]
samples_ds = ray.data.from_items(samples_to_write)
# Ensure that all data ready to be written is released from
# the buffer. Note, this is important in case we have many
# episodes sampled and a relatively small `output_max_rows_per_file`.
while len(self._samples) >= self.output_max_rows_per_file:
# Extract the number of samples to be written to disk this
# iteration.
samples_to_write = self._samples[: self.output_max_rows_per_file]
# Reset the buffer to the remaining data. This only makes sense if
# `rollout_fragment_length` is smaller than `output_max_rows_per_file`
# or at most 2 x `output_max_rows_per_file`.
self._samples = self._samples[self.output_max_rows_per_file :]
samples_ds = ray.data.from_items(samples_to_write)
# Otherwise, write the complete data.
else:
samples_ds = ray.data.from_items(self._samples)
@@ -183,6 +197,11 @@ def sample(
except Exception as e:
logger.error(e)

self.metrics.log_value(
key="recording_buffer_size",
value=len(self._samples),
)

# Finally return the samples as usual.
return samples

@@ -196,11 +215,11 @@ def stop(self) -> None:
"""
# If there are samples left over, we have to write them to disk.
if self._samples:
if self._samples and self.write_remaining_data:
# Convert them to a `ray.data.Dataset`.
samples_ds = ray.data.from_items(self._samples)
# Increase the sample counter for the folder/file name.
self._sample_counter += 1.0
self._sample_counter += 1
# Try to write the dataset to disk/cloud storage.
try:
# Setup the path for writing data. Each run will be written to
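To illustrate what the new `while` loop in `sample()` changes, here is a small, self-contained sketch (plain Python, not RLlib code) of the buffer-draining behavior, assuming `output_max_rows_per_file=3`:

def drain_buffer(buffer, max_rows_per_file):
    """Return the complete blocks to write and the leftover buffer rows."""
    blocks = []
    # Keep cutting complete blocks as long as enough rows are buffered.
    # Before this fix, only one block was cut per `sample()` call, so the
    # buffer could keep growing whenever sampling produced more rows per
    # iteration than `max_rows_per_file`.
    while len(buffer) >= max_rows_per_file:
        blocks.append(buffer[:max_rows_per_file])
        buffer = buffer[max_rows_per_file:]
    return blocks, buffer

blocks, remainder = drain_buffer(list(range(10)), max_rows_per_file=3)
print(blocks)     # [[0, 1, 2], [3, 4, 5], [6, 7, 8]]
print(remainder)  # [9] - written at `stop()` only if
                  # `output_write_remaining_data=True`.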
61 changes: 0 additions & 61 deletions rllib/tuned_examples/bc/cartpole_recording.py

This file was deleted.
