diff --git a/README.md b/README.md index 81af4253..81be897b 100644 --- a/README.md +++ b/README.md @@ -215,10 +215,15 @@ jitted_vmapped_frame_fk = jit(vmapped_frame_fk) joints_batch = jnp.tile(joints, (1024, 1)) w_H_b_batch = jnp.tile(w_H_b, (1024, 1, 1)) w_H_f_batch = jitted_vmapped_frame_fk(w_H_b_batch, joints_batch) -# Note that the first call of the jitted function can be slow, since JAX needs to compile the function. Than it will be fast. + ``` +> [!NOTE] +> The first call of the jitted function can be slow, since JAX needs to compile the function. Than it will be fast! + +```python + ### CasADi interface ```python @@ -260,7 +265,6 @@ joints = cs.MX.sym('joints', len(joints_name_list)) M = kinDyn.mass_matrix_fun() print(M(w_H_b, joints)) - ``` ### PyTorch interface @@ -295,6 +299,9 @@ print(M) ### PyTorch Batched interface +> [!NOTE] +> When using this interface, note that the first call of the jitted function can be slow, since JAX needs to compile the function. Than it will be fast! + ```python import adam from adam.pytorch import KinDynComputationsBatch @@ -325,7 +332,6 @@ joints_batch = torch.tensor(np.tile(joints, (num_samples, 1)), dtype=torch.float M = kinDyn.mass_matrix(w_H_b_batch, joints_batch) w_H_f = kinDyn.forward_kinematics('frame_name', w_H_b_batch, joints_batch) -# Note that the first call of the jitted function can be slow, since JAX needs to compile the function. Than it will be fast. ``` ## 🦸‍♂️ Contributing