diff --git a/brax/envs/humanoid.py b/brax/envs/humanoid.py index e1ba0ed5..0e2600c1 100644 --- a/brax/envs/humanoid.py +++ b/brax/envs/humanoid.py @@ -255,6 +255,12 @@ def reset(self, rng: jax.Array) -> State: def step(self, state: State, action: jax.Array) -> State: """Runs one timestep of the environment's dynamics.""" + + # Scale action from [-1,1] to actuator limits + action_min = self.sys.actuator.ctrl_range[:, 0] + action_max = self.sys.actuator.ctrl_range[:, 1] + action = (action + 1) * (action_max - action_min) * 0.5 + action_min + pipeline_state0 = state.pipeline_state pipeline_state = self.pipeline_step(pipeline_state0, action) diff --git a/brax/envs/humanoidstandup.py b/brax/envs/humanoidstandup.py index 777e863e..5fd28dc9 100644 --- a/brax/envs/humanoidstandup.py +++ b/brax/envs/humanoidstandup.py @@ -219,6 +219,12 @@ def reset(self, rng: jax.Array) -> State: def step(self, state: State, action: jax.Array) -> State: """Runs one timestep of the environment's dynamics.""" + + # Scale action from [-1,1] to actuator limits + action_min = self.sys.actuator.ctrl_range[:, 0] + action_max = self.sys.actuator.ctrl_range[:, 1] + action = (action + 1) * (action_max - action_min) * 0.5 + action_min + pipeline_state = self.pipeline_step(state.pipeline_state, action) pos_after = pipeline_state.x.pos[0, 2] # z coordinate of torso diff --git a/brax/envs/inverted_double_pendulum.py b/brax/envs/inverted_double_pendulum.py index a546512f..8916fd21 100644 --- a/brax/envs/inverted_double_pendulum.py +++ b/brax/envs/inverted_double_pendulum.py @@ -46,14 +46,14 @@ class InvertedDoublePendulum(PipelineEnv): The agent take a 1-element vector for actions. - The action space is a continuous `(action)` in `[-3, 3]`, where `action` + The action space is a continuous `(action)` in `[-1, 1]`, where `action` represents the numerical force applied to the cart (with magnitude representing the amount of force and sign representing the direction) | Num | Action | Control Min | Control Max | Name (in corresponding config) | Joint | Unit | |-----|---------------------------|-------------|-------------|--------------------------------|-------|-----------| - | 0 | Force applied on the cart | -3 | 3 | slider + | 0 | Force applied on the cart | -1 | 1 | slider | slide | Force (N) | ### Observation Space diff --git a/brax/envs/inverted_pendulum.py b/brax/envs/inverted_pendulum.py index 30692177..58c396cf 100644 --- a/brax/envs/inverted_pendulum.py +++ b/brax/envs/inverted_pendulum.py @@ -129,6 +129,12 @@ def reset(self, rng: jax.Array) -> State: def step(self, state: State, action: jax.Array) -> State: """Run one timestep of the environment's dynamics.""" + + # Scale action from [-1,1] to actuator limits + action_min = self.sys.actuator.ctrl_range[:, 0] + action_max = self.sys.actuator.ctrl_range[:, 1] + action = (action + 1) * (action_max - action_min) * 0.5 + action_min + pipeline_state = self.pipeline_step(state.pipeline_state, action) obs = self._get_obs(pipeline_state) reward = 1.0 diff --git a/brax/envs/pusher.py b/brax/envs/pusher.py index d30e3511..5d3db9dd 100644 --- a/brax/envs/pusher.py +++ b/brax/envs/pusher.py @@ -193,6 +193,14 @@ def reset(self, rng: jax.Array) -> State: return State(pipeline_state, obs, reward, done, metrics) def step(self, state: State, action: jax.Array) -> State: + + # Scale action from [-1,1] to actuator limits + action_min = self.sys.actuator.ctrl_range[:, 0] + action_max = self.sys.actuator.ctrl_range[:, 1] + action = (action + 1) * (action_max - action_min) * 0.5 + action_min + + pipeline_state = self.pipeline_step(state.pipeline_state, action) + assert state.pipeline_state is not None x_i = state.pipeline_state.x.vmap().do( base.Transform.create(pos=self.sys.link.inertia.transform.pos) @@ -205,8 +213,6 @@ def step(self, state: State, action: jax.Array) -> State: reward_ctrl = -jp.square(action).sum() reward = reward_dist + 0.1 * reward_ctrl + 0.5 * reward_near - pipeline_state = self.pipeline_step(state.pipeline_state, action) - obs = self._get_obs(pipeline_state) state.metrics.update( reward_near=reward_near,