Provide more information to the user (#358)
Co-authored-by: Alexander Soare <[email protected]>
Co-authored-by: Remi <[email protected]>
3 people authored Aug 23, 2024
1 parent b5ad79a commit a2592a5
Showing 8 changed files with 76 additions and 15 deletions.
9 changes: 8 additions & 1 deletion README.md
@@ -267,13 +267,20 @@ checkpoints
│ └── training_state.pth # optimizer/scheduler/rng state and training step
```

To resume training from a checkpoint, you can add these arguments to the `train.py` python command:
```bash
hydra.run.dir=your/original/experiment/dir resume=true
```

This will load the pretrained model, optimizer and scheduler states for training. For more information, please see our tutorial on training resumption [here](https://github.com/huggingface/lerobot/blob/main/examples/5_resume_training.md).
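If you are curious about what exactly gets restored, you can peek at the saved training state directly. This is only a minimal sketch, assuming the checkpoint layout shown above; the path is an example and the exact contents depend on your version of LeRobot:

```python
import torch

# Example path: replace with your own experiment directory.
ckpt_path = "outputs/train/my_experiment/checkpoints/last/training_state.pth"

# training_state.pth is a regular torch.save file holding the optimizer/scheduler/rng
# state and the training step, so torch.load gives back a plain Python object.
training_state = torch.load(ckpt_path, map_location="cpu")
print(type(training_state))
print(list(training_state.keys()) if isinstance(training_state, dict) else training_state)
```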

To use wandb for logging training and evaluation curves, make sure you've run `wandb login` as a one-time setup step. Then, when running the training command above, enable WandB in the configuration by adding:

```bash
wandb.enable=true
```

A link to the wandb logs for the run will also show up in yellow in your terminal. Here is an example of what they look like in your browser:
A link to the wandb logs for the run will also show up in yellow in your terminal. Here is an example of what they look like in your browser. Please also check [here](https://github.com/huggingface/lerobot/blob/main/examples/4_train_policy_with_script.md#typical-logs-and-metrics) for an explanation of some commonly used metrics in the logs.

![](media/wandb.png)

30 changes: 30 additions & 0 deletions examples/4_train_policy_with_script.md
@@ -170,6 +170,36 @@ python lerobot/scripts/train.py --config-dir outputs/train/my_experiment/checkpo

Note that you may still use the regular syntax for config parameter overrides (e.g. by adding `training.offline_steps=200000`).

## Typical logs and metrics

When you start the training process, you will first see your full configuration printed in the terminal. You can check it to make sure that you configured everything correctly and that your settings are not overridden by other files. The final configuration will also be saved with the checkpoint.
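As a rough illustration of how that configuration dump ends up in your terminal, here is a minimal sketch mirroring the `logging.info(pformat(OmegaConf.to_container(cfg)))` call that `train.py` uses (the config values below are made up):

```python
import logging
from pprint import pformat

from omegaconf import OmegaConf

logging.basicConfig(level=logging.INFO)

# Made-up stand-in for the resolved Hydra config.
cfg = OmegaConf.create(
    {"training": {"offline_steps": 200000}, "eval": {"n_episodes": 50, "batch_size": 50}}
)

# Convert to a plain dict and pretty-print it into the logs.
logging.info(pformat(OmegaConf.to_container(cfg)))
```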

After that, you will see training logs like this one:

```
INFO 2024-08-14 13:35:12 ts/train.py:192 step:0 smpl:64 ep:1 epch:0.00 loss:1.112 grdn:15.387 lr:2.0e-07 updt_s:1.738 data_s:4.774
```

or evaluation logs like this one:

```
INFO 2024-08-14 13:38:45 ts/train.py:226 step:100 smpl:6K ep:52 epch:0.25 ∑rwrd:20.693 success:0.0% eval_s:120.266
```

These logs are also saved in wandb if `wandb.enable` is set to `true`. Here is the meaning of some of the abbreviations (a small parsing sketch follows the list):

- `smpl`: number of samples seen during training.
- `ep`: number of episodes seen during training. An episode is one complete manipulation task and contains multiple samples.
- `epch`: number of times all unique samples have been seen (epochs).
- `grdn`: gradient norm.
- `∑rwrd`: sum of rewards within each evaluation episode, averaged over the episodes.
- `success`: average success rate over the evaluation episodes. Reward and success are usually different, except in the sparse-reward setting where reward=1 only when the task is completed successfully.
- `eval_s`: time to evaluate the policy in the environment, in seconds.
- `updt_s`: time to update the network parameters, in seconds.
- `data_s`: time to load a batch of data, in seconds.
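If you want to track these values outside of wandb, one quick option is to pull the `key:value` pairs out of a log line. Below is a minimal sketch that assumes the log format shown above and keeps values such as `6K` or `0.0%` as strings:

```python
# Example training log line copied from above.
TRAIN_LOG = (
    "INFO 2024-08-14 13:35:12 ts/train.py:192 step:0 smpl:64 ep:1 epch:0.00 "
    "loss:1.112 grdn:15.387 lr:2.0e-07 updt_s:1.738 data_s:4.774"
)

# Keys we expect in training and evaluation log lines.
KNOWN_KEYS = {
    "step", "smpl", "ep", "epch", "loss", "grdn", "lr",
    "updt_s", "data_s", "∑rwrd", "success", "eval_s",
}


def parse_log_line(line: str) -> dict[str, str]:
    """Return the known key:value pairs of a log line, leaving values as strings."""
    metrics = {}
    for token in line.split():
        key, sep, value = token.partition(":")
        if sep and key in KNOWN_KEYS:
            metrics[key] = value
    return metrics


print(parse_log_line(TRAIN_LOG))
# {'step': '0', 'smpl': '64', 'ep': '1', 'epch': '0.00', 'loss': '1.112', ...}
```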

Some metrics are useful for initial performance profiling. For example, if you find that GPU utilization is low (e.g. via the `nvidia-smi` command) and `data_s` is sometimes too high, you may need to adjust the batch size or the number of dataloading workers to speed up data loading. We also recommend the [PyTorch profiler](https://github.com/huggingface/lerobot?tab=readme-ov-file#improve-your-code-with-profiling) for detailed performance probing.
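For reference, a bare-bones pass with the standard `torch.profiler` API could look like the sketch below. The model, batch and optimizer are illustrative stand-ins, not LeRobot code, and this is not the exact snippet from the linked profiling section:

```python
import torch
from torch.profiler import ProfilerActivity, profile

model = torch.nn.Linear(64, 64)                      # stand-in for your policy
batch = torch.randn(256, 64)                         # stand-in for a dataloader batch
optimizer = torch.optim.Adam(model.parameters(), lr=1e-4)

activities = [ProfilerActivity.CPU]
if torch.cuda.is_available():
    activities.append(ProfilerActivity.CUDA)

with profile(activities=activities, record_shapes=True) as prof:
    for _ in range(10):                              # a few representative steps
        loss = model(batch).pow(2).mean()
        optimizer.zero_grad()
        loss.backward()
        optimizer.step()

# Show where the time goes, sorted by total CPU time.
print(prof.key_averages().table(sort_by="cpu_time_total", row_limit=10))
```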

---

So far we've seen how to train Diffusion Policy for PushT and ACT for ALOHA. Now, what if we want to train ACT for PushT? Well, there are aspects of the ACT configuration that are specific to the ALOHA environments, and these happen to be incompatible with PushT. Therefore, trying to run the following will almost certainly raise an exception of sorts (eg: feature dimension mismatch):
17 changes: 16 additions & 1 deletion lerobot/common/utils/utils.py
@@ -14,6 +14,7 @@
# See the License for the specific language governing permissions and
# limitations under the License.
import logging
import os
import os.path as osp
import random
from contextlib import contextmanager
@@ -27,6 +27,12 @@
from omegaconf import DictConfig


def inside_slurm():
"""Check whether the python process was launched through slurm"""
# TODO(rcadene): return False for interactive mode `--pty bash`
return "SLURM_JOB_ID" in os.environ


def get_safe_torch_device(cfg_device: str, log: bool = False) -> torch.device:
"""Given a string, return a torch.device with checks on whether the device is available."""
match cfg_device:
@@ -158,7 +165,15 @@ def init_hydra_config(config_path: str, overrides: list[str] | None = None) -> D
version_base="1.2",
)
cfg = hydra.compose(Path(config_path).stem, overrides)

if cfg.eval.batch_size > cfg.eval.n_episodes:
raise ValueError(
"The eval batch size is greater than the number of eval episodes "
f"({cfg.eval.batch_size} > {cfg.eval.n_episodes}). As a result, {cfg.eval.batch_size} "
f"eval environments will be instantiated, but only {cfg.eval.n_episodes} will be used. "
"This might significantly slow down evaluation. To fix this, you should update your command "
f"to increase the number of episodes to match the batch size (e.g. `eval.n_episodes={cfg.eval.batch_size}`), "
f"or lower the batch size (e.g. `eval.batch_size={cfg.eval.n_episodes}`)."
)
return cfg


2 changes: 1 addition & 1 deletion lerobot/configs/default.yaml
@@ -120,7 +120,7 @@ eval:
# `batch_size` specifies the number of environments to use in a gym.vector.VectorEnv.
batch_size: 1
# `use_async_envs` specifies whether to use asynchronous environments (multiprocessing).
use_async_envs: false
use_async_envs: true

wandb:
enable: false
5 changes: 5 additions & 0 deletions lerobot/configs/env/aloha.yaml
@@ -2,6 +2,11 @@

fps: 50

eval:
# `use_async_envs` specifies whether to use asynchronous environments (multiprocessing).
# Set it to false to avoid some problems with the aloha env.
use_async_envs: false

env:
name: aloha
task: AlohaInsertion-v0
5 changes: 5 additions & 0 deletions lerobot/configs/env/xarm.yaml
@@ -2,6 +2,11 @@

fps: 15

eval:
# `use_async_envs` specifies whether to use asynchronous environments (multiprocessing).
# Set it to false to avoid some problems with the xarm env.
use_async_envs: false

env:
name: xarm
task: XarmLift-v0
22 changes: 10 additions & 12 deletions lerobot/scripts/eval.py
@@ -70,7 +70,13 @@
from lerobot.common.policies.policy_protocol import Policy
from lerobot.common.policies.utils import get_device_from_parameters
from lerobot.common.utils.io_utils import write_video
from lerobot.common.utils.utils import get_safe_torch_device, init_hydra_config, init_logging, set_global_seed
from lerobot.common.utils.utils import (
get_safe_torch_device,
init_hydra_config,
init_logging,
inside_slurm,
set_global_seed,
)


def rollout(
@@ -79,7 +85,6 @@ def rollout(
seeds: list[int] | None = None,
return_observations: bool = False,
render_callback: Callable[[gym.vector.VectorEnv], None] | None = None,
enable_progbar: bool = False,
) -> dict:
"""Run a batched policy rollout once through a batch of environments.
@@ -109,7 +114,6 @@
are returned optionally because they typically take more memory to cache. Defaults to False.
render_callback: Optional rendering callback to be used after the environments are reset, and after
every step.
enable_progbar: Enable a progress bar over rollout steps.
Returns:
The dictionary described above.
"""
@@ -136,7 +140,7 @@
progbar = trange(
max_steps,
desc=f"Running rollout with at most {max_steps} steps",
disable=not enable_progbar,
disable=inside_slurm(),  # we don't want a progress bar when using slurm, since it clutters the logs
leave=False,
)
while not np.all(done):
@@ -210,8 +214,6 @@ def eval_policy(
videos_dir: Path | None = None,
return_episode_data: bool = False,
start_seed: int | None = None,
enable_progbar: bool = False,
enable_inner_progbar: bool = False,
) -> dict:
"""
Args:
@@ -224,8 +226,6 @@
the "episodes" key of the returned dictionary.
start_seed: The first seed to use for the first individual rollout. For all subsequent rollouts the
seed is incremented by 1. If not provided, the environments are not manually seeded.
enable_progbar: Enable progress bar over batches.
enable_inner_progbar: Enable progress bar over steps in each batch.
Returns:
Dictionary with metrics and data regarding the rollouts.
"""
@@ -266,7 +266,8 @@ def render_frame(env: gym.vector.VectorEnv):
if return_episode_data:
episode_data: dict | None = None

progbar = trange(n_batches, desc="Stepping through eval batches", disable=not enable_progbar)
# we don't want a progress bar when using slurm, since it clutters the logs
progbar = trange(n_batches, desc="Stepping through eval batches", disable=inside_slurm())
for batch_ix in progbar:
# Cache frames for rendering videos. Each item will be (b, h, w, c), and the list indexes the rollout
# step.
@@ -285,7 +286,6 @@ def render_frame(env: gym.vector.VectorEnv):
seeds=list(seeds) if seeds else None,
return_observations=return_episode_data,
render_callback=render_frame if max_episodes_rendered > 0 else None,
enable_progbar=enable_inner_progbar,
)

# Figure out where in each rollout sequence the first done condition was encountered (results after
@@ -487,8 +487,6 @@ def main(
max_episodes_rendered=10,
videos_dir=Path(out_dir) / "videos",
start_seed=hydra_cfg.seed,
enable_progbar=True,
enable_inner_progbar=True,
)
print(info["aggregated"])

1 change: 1 addition & 0 deletions lerobot/scripts/train.py
@@ -241,6 +241,7 @@ def train(cfg: DictConfig, out_dir: str | None = None, job_name: str | None = No
raise NotImplementedError()

init_logging()
logging.info(pformat(OmegaConf.to_container(cfg)))

if cfg.training.online_steps > 0 and isinstance(cfg.dataset_repo_id, ListConfig):
raise NotImplementedError("Online training with LeRobotMultiDataset is not implemented.")
