Skip to content

Commit

Permalink
Adding forgotten gravity term function
Browse files Browse the repository at this point in the history
  • Loading branch information
Giulero committed Jun 27, 2024
1 parent 003f99f commit 46f0265
Showing 1 changed file with 29 additions and 11 deletions.
40 changes: 29 additions & 11 deletions src/adam/pytorch/computation_batch.py
Original file line number Diff line number Diff line change
Expand Up @@ -365,25 +365,43 @@ def fun(base_transform, joint_positions, base_velocity, joint_velocities):
return self.funcs["coriolis_term"]

def gravity_term(
self, base_transform: jnp.array, joint_positions: jnp.array
) -> jnp.array:
self, base_transform: torch.Tensor, joint_positions: torch.Tensor
) -> torch.Tensor:
"""Returns the gravity term of the floating-base dynamics equation,
using a reduced RNEA (no acceleration and external forces)
Args:
base_transform (jnp.array): The homogenous transform from base to world frame
joint_positions (jnp.array): The joints position
base_transform (torch.Tensor): The homogenous transform from base to world frame
joint_positions (torch.Tensor): The joints position
Returns:
G (jnp.array): the gravity term
"""
return self.rbdalgos.rnea(
base_transform,
joint_positions,
np.zeros(6).reshape(6, 1),
np.zeros(self.NDoF),
self.g,
).array.squeeze()
return self.gravity_term_fun()(base_transform, joint_positions)

def gravity_term_fun(self):
"""Returns the gravity term of the floating-base dynamics equation as a pytorch function
Returns:
G (pytorch function): the gravity term
"""
if self.funcs.get("gravity_term") is not None:
return self.funcs["gravity_term"]
print("[INFO] Compiling gravity term function")

def fun(base_transform, joint_positions):
return self.rbdalgos.rnea(
base_transform,
joint_positions,
np.zeros(6).reshape(6, 1),
np.zeros(self.NDoF),
self.g,
).array.squeeze()

vmapped_fun = jax.vmap(fun)
jit_vmapped_fun = jax.jit(vmapped_fun)
self.funcs["gravity_term"] = jax2torch(jit_vmapped_fun)
return self.funcs["gravity_term"]

def CoM_position(
self, base_transform: torch.Tensor, joint_positions: torch.Tensor
Expand Down

0 comments on commit 46f0265

Please sign in to comment.