Skip to content

Commit

Permalink
Scaling control actions for BRAX environments (#473)
Browse files Browse the repository at this point in the history
* Updated docstrings with action scaling info

* Added action scaling to relevant environments

* Removed edits to docstrings, moved action scaling to each environment

* Removed trailing white space

* Undo changes to .gitignore

* Undo changes to .gitignore
  • Loading branch information
nic-barbara authored May 20, 2024
1 parent 0d513cd commit 2635079
Show file tree
Hide file tree
Showing 5 changed files with 28 additions and 4 deletions.
6 changes: 6 additions & 0 deletions brax/envs/humanoid.py
Original file line number Diff line number Diff line change
Expand Up @@ -255,6 +255,12 @@ def reset(self, rng: jax.Array) -> State:

def step(self, state: State, action: jax.Array) -> State:
"""Runs one timestep of the environment's dynamics."""

# Scale action from [-1,1] to actuator limits
action_min = self.sys.actuator.ctrl_range[:, 0]
action_max = self.sys.actuator.ctrl_range[:, 1]
action = (action + 1) * (action_max - action_min) * 0.5 + action_min

pipeline_state0 = state.pipeline_state
pipeline_state = self.pipeline_step(pipeline_state0, action)

Expand Down
6 changes: 6 additions & 0 deletions brax/envs/humanoidstandup.py
Original file line number Diff line number Diff line change
Expand Up @@ -219,6 +219,12 @@ def reset(self, rng: jax.Array) -> State:

def step(self, state: State, action: jax.Array) -> State:
"""Runs one timestep of the environment's dynamics."""

# Scale action from [-1,1] to actuator limits
action_min = self.sys.actuator.ctrl_range[:, 0]
action_max = self.sys.actuator.ctrl_range[:, 1]
action = (action + 1) * (action_max - action_min) * 0.5 + action_min

pipeline_state = self.pipeline_step(state.pipeline_state, action)

pos_after = pipeline_state.x.pos[0, 2] # z coordinate of torso
Expand Down
4 changes: 2 additions & 2 deletions brax/envs/inverted_double_pendulum.py
Original file line number Diff line number Diff line change
Expand Up @@ -46,14 +46,14 @@ class InvertedDoublePendulum(PipelineEnv):
The agent takes a 1-element vector for actions.
The action space is a continuous `(action)` in `[-3, 3]`, where `action`
The action space is a continuous `(action)` in `[-1, 1]`, where `action`
represents the numerical force applied to the cart (with magnitude
representing the amount of force and sign representing the direction)
| Num | Action | Control Min | Control Max | Name (in
corresponding config) | Joint | Unit |
|-----|---------------------------|-------------|-------------|--------------------------------|-------|-----------|
| 0 | Force applied on the cart | -3 | 3 | slider
| 0 | Force applied on the cart | -1 | 1 | slider
| slide | Force (N) |
### Observation Space
Expand Down
6 changes: 6 additions & 0 deletions brax/envs/inverted_pendulum.py
Original file line number Diff line number Diff line change
Expand Up @@ -129,6 +129,12 @@ def reset(self, rng: jax.Array) -> State:

def step(self, state: State, action: jax.Array) -> State:
"""Run one timestep of the environment's dynamics."""

# Scale action from [-1,1] to actuator limits
action_min = self.sys.actuator.ctrl_range[:, 0]
action_max = self.sys.actuator.ctrl_range[:, 1]
action = (action + 1) * (action_max - action_min) * 0.5 + action_min

pipeline_state = self.pipeline_step(state.pipeline_state, action)
obs = self._get_obs(pipeline_state)
reward = 1.0
Expand Down
10 changes: 8 additions & 2 deletions brax/envs/pusher.py
Original file line number Diff line number Diff line change
Expand Up @@ -193,6 +193,14 @@ def reset(self, rng: jax.Array) -> State:
return State(pipeline_state, obs, reward, done, metrics)

def step(self, state: State, action: jax.Array) -> State:

# Scale action from [-1,1] to actuator limits
action_min = self.sys.actuator.ctrl_range[:, 0]
action_max = self.sys.actuator.ctrl_range[:, 1]
action = (action + 1) * (action_max - action_min) * 0.5 + action_min

pipeline_state = self.pipeline_step(state.pipeline_state, action)

assert state.pipeline_state is not None
x_i = state.pipeline_state.x.vmap().do(
base.Transform.create(pos=self.sys.link.inertia.transform.pos)
Expand All @@ -205,8 +213,6 @@ def step(self, state: State, action: jax.Array) -> State:
reward_ctrl = -jp.square(action).sum()
reward = reward_dist + 0.1 * reward_ctrl + 0.5 * reward_near

pipeline_state = self.pipeline_step(state.pipeline_state, action)

obs = self._get_obs(pipeline_state)
state.metrics.update(
reward_near=reward_near,
Expand Down

0 comments on commit 2635079

Please sign in to comment.