Skip to content

Commit

Permalink
Internal change
Browse files Browse the repository at this point in the history
PiperOrigin-RevId: 718976152
Change-Id: I8a4aaaa7e489916964c769ef957a87b82108b622
  • Loading branch information
Brax Team authored and btaba committed Jan 23, 2025
1 parent d48b0b3 commit 1bc5c00
Show file tree
Hide file tree
Showing 4 changed files with 248 additions and 91 deletions.
21 changes: 21 additions & 0 deletions brax/envs/wrappers/training.py
Original file line number Diff line number Diff line change
Expand Up @@ -83,6 +83,15 @@ def reset(self, rng: jax.Array) -> State:
state = self.env.reset(rng)
state.info['steps'] = jp.zeros(rng.shape[:-1])
state.info['truncation'] = jp.zeros(rng.shape[:-1])
# Keep a separate record of episode done, as state.done can be erased
# by AutoResetWrapper
state.info['episode_done'] = jp.zeros(rng.shape[:-1])
episode_metrics = dict()
episode_metrics['sum_reward'] = jp.zeros(rng.shape[:-1])
episode_metrics['length'] = jp.zeros(rng.shape[:-1])
for metric_name in state.metrics.keys():
episode_metrics[metric_name] = jp.zeros(rng.shape[:-1])
state.info['episode_metrics'] = episode_metrics
return state

def step(self, state: State, action: jax.Array) -> State:
Expand All @@ -101,6 +110,18 @@ def f(state, _):
steps >= episode_length, 1 - state.done, zero
)
state.info['steps'] = steps

# Aggregate state metrics into episode metrics
prev_done = state.info['episode_done']
state.info['episode_metrics']['sum_reward'] += jp.sum(rewards, axis=0)
state.info['episode_metrics']['sum_reward'] *= (1 - prev_done)
state.info['episode_metrics']['length'] += self.action_repeat
state.info['episode_metrics']['length'] *= (1 - prev_done)
for metric_name in state.metrics.keys():
if metric_name != 'reward':
state.info['episode_metrics'][metric_name] += state.metrics[metric_name]
state.info['episode_metrics'][metric_name] *= (1 - prev_done)
state.info['episode_done'] = done
return state.replace(done=done)


Expand Down
7 changes: 3 additions & 4 deletions brax/generalized/constraint_test.py
Original file line number Diff line number Diff line change
Expand Up @@ -79,10 +79,9 @@ def test_force(self, xml_file):
efc_jt = np.reshape(mj_next.efc_J, (-1, sys.qd_size())).T
# recover con_frc by backing it out from qf_constraint
con_frc = np.linalg.lstsq(efc_jt, state.qf_constraint, None)[0]
err += np.sum((mj_next.efc_AR @ con_frc + mj_next.efc_b) ** 2)
mj_err += np.sum(
(mj_next.efc_AR @ mj_next.efc_force + mj_next.efc_b) ** 2
)
ar = mj_next.efc_AR.reshape((con_frc.shape[0], con_frc.shape[0]))
err += np.sum((ar @ con_frc + mj_next.efc_b) ** 2)
mj_err += np.sum((ar @ mj_next.efc_force + mj_next.efc_b) ** 2)

self.assertLessEqual(err, mj_err + 0.01)

Expand Down
Loading

0 comments on commit 1bc5c00

Please sign in to comment.