Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Scaling control actions for BRAX environments #473

Merged
merged 7 commits into from
May 20, 2024
Merged
Show file tree
Hide file tree
Changes from 3 commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
17 changes: 17 additions & 0 deletions brax/envs/base.py
Original file line number Diff line number Diff line change
Expand Up @@ -26,6 +26,7 @@
from brax.spring import pipeline as s_pipeline
from flax import struct
import jax
import jax.numpy as jnp
import mujoco
from mujoco import mjx
import numpy as np
Expand Down Expand Up @@ -128,6 +129,22 @@ def f(state, _):
)

return jax.lax.scan(f, pipeline_state, (), self._n_frames)[0]

def scale_and_clip_actions(self, action: jax.Array) -> jax.Array:
  """Scales a normalized action to the actuator control limits and clips it.

  The action is assumed to lie in `[-1, 1]`. It is mapped linearly onto the
  control limits `[a, b]` of each actuator in the model with
  `u = (action + 1) * (b - a) / 2 + a`, and the result is then clipped to
  `[a, b]` so that inputs outside `[-1, 1]` saturate at the limits.

  Args:
    action: normalized control input; it is broadcast against the per-actuator
      limits, so its last axis is expected to index actuators.

  Returns:
    The control signal scaled to, and clipped within, the actuator limits.
  """
  # Column 0 of ctrl_range is used as the lower limit, column 1 as the upper.
  action_min = self.sys.actuator.ctrl_range[:, 0]
  action_max = self.sys.actuator.ctrl_range[:, 1]

  scaled = (action + 1) * (action_max - action_min) / 2 + action_min

  # The lower bound must be action_min: clipping with action_max on both sides
  # would force every output to the upper limit. Bounds are passed positionally
  # so the call does not depend on the keyword names used by jnp.clip.
  return jnp.clip(scaled, action_min, action_max)
nic-barbara marked this conversation as resolved.
Show resolved Hide resolved

@property
def dt(self) -> jax.Array:
Expand Down
1 change: 1 addition & 0 deletions brax/envs/humanoid.py
btaba marked this conversation as resolved.
Show resolved Hide resolved
Original file line number Diff line number Diff line change
Expand Up @@ -256,6 +256,7 @@ def reset(self, rng: jax.Array) -> State:
def step(self, state: State, action: jax.Array) -> State:
"""Runs one timestep of the environment's dynamics."""
pipeline_state0 = state.pipeline_state
action = self.scale_and_clip_actions(action)
pipeline_state = self.pipeline_step(pipeline_state0, action)

com_before, *_ = self._com(pipeline_state0)
Expand Down
1 change: 1 addition & 0 deletions brax/envs/humanoidstandup.py
Original file line number Diff line number Diff line change
Expand Up @@ -219,6 +219,7 @@ def reset(self, rng: jax.Array) -> State:

def step(self, state: State, action: jax.Array) -> State:
"""Runs one timestep of the environment's dynamics."""
action = self.scale_and_clip_actions(action)
pipeline_state = self.pipeline_step(state.pipeline_state, action)

pos_after = pipeline_state.x.pos[0, 2] # z coordinate of torso
Expand Down
4 changes: 2 additions & 2 deletions brax/envs/inverted_double_pendulum.py
Original file line number Diff line number Diff line change
Expand Up @@ -46,14 +46,14 @@ class InvertedDoublePendulum(PipelineEnv):

The agent takes a 1-element vector for actions.

The action space is a continuous `(action)` in `[-3, 3]`, where `action`
The action space is a continuous `(action)` in `[-1, 1]`, where `action`
represents the numerical force applied to the cart (with magnitude
representing the amount of force and sign representing the direction)

| Num | Action | Control Min | Control Max | Name (in
corresponding config) | Joint | Unit |
|-----|---------------------------|-------------|-------------|--------------------------------|-------|-----------|
| 0 | Force applied on the cart | -3 | 3 | slider
| 0 | Force applied on the cart | -1 | 1 | slider
| slide | Force (N) |

### Observation Space
Expand Down
4 changes: 4 additions & 0 deletions brax/envs/inverted_pendulum.py
Original file line number Diff line number Diff line change
Expand Up @@ -46,6 +46,9 @@ class InvertedPendulum(PipelineEnv):
continuous `(action)` in `[-3, 3]`, where `action` represents the numerical
force applied to the cart (with magnitude representing the amount of force and
sign representing the direction)

Actions are assumed to be within `[-1, 1]` and are (linearly) scaled
nic-barbara marked this conversation as resolved.
Show resolved Hide resolved
to `[-3, 3]` within the environment's `step()` call.

| Num | Action | Control Min | Control Max | Name (in
corresponding config) | Joint | Unit |
Expand Down Expand Up @@ -129,6 +132,7 @@ def reset(self, rng: jax.Array) -> State:

def step(self, state: State, action: jax.Array) -> State:
"""Run one timestep of the environment's dynamics."""
action = self.scale_and_clip_actions(action)
pipeline_state = self.pipeline_step(state.pipeline_state, action)
obs = self._get_obs(pipeline_state)
reward = 1.0
Expand Down
10 changes: 7 additions & 3 deletions brax/envs/pusher.py
Original file line number Diff line number Diff line change
Expand Up @@ -42,7 +42,10 @@ class Pusher(PipelineEnv):
### Action Space

The action space is a `Box(-2, 2, (7,), float32)`. An action `(a, b)`
represents the torques applied at the hinge joints.
represents the torques applied at the hinge joints.
btaba marked this conversation as resolved.
Show resolved Hide resolved

Actions are assumed to be within `[-1, 1]` and are (linearly) scaled
to `[-2, 2]` within the environment's `step()` call.

| Num | Action | Control Min | Control Max | Name (in corresponding config) | Joint | Unit |
|-----|-----------------------------------------------|-------------|-------------|--------------------------------|-------|--------------|
Expand Down Expand Up @@ -193,6 +196,9 @@ def reset(self, rng: jax.Array) -> State:
return State(pipeline_state, obs, reward, done, metrics)

def step(self, state: State, action: jax.Array) -> State:
action = self.scale_and_clip_actions(action)
pipeline_state = self.pipeline_step(state.pipeline_state, action)

assert state.pipeline_state is not None
x_i = state.pipeline_state.x.vmap().do(
base.Transform.create(pos=self.sys.link.inertia.transform.pos)
Expand All @@ -205,8 +211,6 @@ def step(self, state: State, action: jax.Array) -> State:
reward_ctrl = -jp.square(action).sum()
reward = reward_dist + 0.1 * reward_ctrl + 0.5 * reward_near

pipeline_state = self.pipeline_step(state.pipeline_state, action)

obs = self._get_obs(pipeline_state)
state.metrics.update(
reward_near=reward_near,
Expand Down