diff --git a/src/adam/pytorch/computation_batch.py b/src/adam/pytorch/computation_batch.py index 1e7753a9..4939cd7c 100644 --- a/src/adam/pytorch/computation_batch.py +++ b/src/adam/pytorch/computation_batch.py @@ -365,25 +365,43 @@ def fun(base_transform, joint_positions, base_velocity, joint_velocities): return self.funcs["coriolis_term"] def gravity_term( - self, base_transform: jnp.array, joint_positions: jnp.array - ) -> jnp.array: + self, base_transform: torch.Tensor, joint_positions: torch.Tensor + ) -> torch.Tensor: """Returns the gravity term of the floating-base dynamics equation, using a reduced RNEA (no acceleration and external forces) Args: - base_transform (jnp.array): The homogenous transform from base to world frame - joint_positions (jnp.array): The joints position + base_transform (torch.Tensor): The homogenous transform from base to world frame + joint_positions (torch.Tensor): The joints position Returns: G (jnp.array): the gravity term """ - return self.rbdalgos.rnea( - base_transform, - joint_positions, - np.zeros(6).reshape(6, 1), - np.zeros(self.NDoF), - self.g, - ).array.squeeze() + return self.gravity_term_fun()(base_transform, joint_positions) + + def gravity_term_fun(self): + """Returns the gravity term of the floating-base dynamics equation as a pytorch function + + Returns: + G (pytorch function): the gravity term + """ + if self.funcs.get("gravity_term") is not None: + return self.funcs["gravity_term"] + print("[INFO] Compiling gravity term function") + + def fun(base_transform, joint_positions): + return self.rbdalgos.rnea( + base_transform, + joint_positions, + np.zeros(6).reshape(6, 1), + np.zeros(self.NDoF), + self.g, + ).array.squeeze() + + vmapped_fun = jax.vmap(fun) + jit_vmapped_fun = jax.jit(vmapped_fun) + self.funcs["gravity_term"] = jax2torch(jit_vmapped_fun) + return self.funcs["gravity_term"] def CoM_position( self, base_transform: torch.Tensor, joint_positions: torch.Tensor