Skip to content

Commit 3c2de13

Browse files
committed
Fix typos and finalize ABA
1 parent e148da1 commit 3c2de13

File tree

3 files changed

+55
-17
lines changed

3 files changed

+55
-17
lines changed

src/adam/casadi/casadi_like.py

+12
Original file line numberDiff line numberDiff line change
@@ -212,6 +212,18 @@ def horzcat(*x) -> "CasadiLike":
212212
y = [xi.array if isinstance(xi, CasadiLike) else xi for xi in x]
213213
return CasadiLike(cs.horzcat(*y))
214214

215+
@staticmethod
216+
def solve(A: "CasadiLike", b: "CasadiLike") -> "CasadiLike":
217+
"""
218+
Args:
219+
A (CasadiLike): matrix
220+
b (CasadiLike): vector
221+
222+
Returns:
223+
CasadiLike: solution of A*x=b
224+
"""
225+
return CasadiLike(cs.solve(A.array, b.array))
226+
215227

216228
if __name__ == "__main__":
217229
math = SpatialMath()

src/adam/core/rbd_algorithms.py

+39-15
Original file line numberDiff line numberDiff line change
@@ -462,7 +462,7 @@ def aba(
462462
base_transform: npt.ArrayLike,
463463
joint_positions: npt.ArrayLike,
464464
joint_velocities: npt.ArrayLike,
465-
joint_torques: npt.ArrayLike,
465+
tau: npt.ArrayLike,
466466
g: npt.ArrayLike,
467467
) -> npt.ArrayLike:
468468
"""Implementation of Articulated Body Algorithm
@@ -483,24 +483,27 @@ def aba(
483483
c = self.math.factory.zeros(self.model.N, 6, 1)
484484
pA = self.math.factory.zeros(self.model.N, 6, 1)
485485
IA = self.math.factory.zeros(self.model.N, 6, 6)
486-
U = self.math.factory.zeros(self.model.N, 6, 6)
487-
D = self.math.factory.zeros(self.model.N, 6, 6)
488-
u = self.math.factory.zeros(self.model.N, 6, 1)
486+
U = self.math.factory.zeros(self.model.N, 6, 1)
487+
D = self.math.factory.zeros(self.model.N, 1, 1)
488+
u = self.math.factory.zeros(self.model.N, 1, 1)
489489
a = self.math.factory.zeros(self.model.N, 6, 1)
490-
f = self.math.factory.zeros(self.model.N, 6, 1)
490+
sdd = self.math.factory.zeros(self.model.N, 1, 1)
491+
B_X_W = self.math.adjoint_mixed_inverse(base_transform)
491492

492493
# Pass 1
493494
for i, node in enumerate(self.model.tree):
494495
link_i, joint_i, link_pi = node.get_elements()
495496

496497
if link_i.name == self.root_link:
497498
continue
499+
q = joint_positions[joint_i.idx] if joint_i.idx is not None else 0.0
500+
q_dot = joint_velocities[joint_i.idx] if joint_i.idx is not None else 0.0
498501

499502
pi = self.model.tree.get_idx_from_name(link_pi.name)
500503

501504
# Parent-child transform
502-
i_X_pi[i] = joint_i.spatial_transform(joint_positions[i])
503-
v_J = joint_i.motion_subspace() * joint_velocities[i]
505+
i_X_pi[i] = joint_i.spatial_transform(q)
506+
v_J = joint_i.motion_subspace() * q_dot
504507

505508
v[i] = i_X_pi[i] @ v[pi] + v_J
506509
c[i] = i_X_pi[i] @ c[pi] + self.math.spatial_skew(v[i]) @ v_J
@@ -519,26 +522,47 @@ def aba(
519522
continue
520523

521524
pi = self.model.tree.get_idx_from_name(link_pi.name)
525+
tau_i = tau[joint_i.idx] if joint_i.idx is not None else 0.0
522526

523-
U[i] = IA[i] @ node.joint.motion_subspace()
524-
D[i] = node.joint.motion_subspace().T @ U[i]
525-
u[i] = tau[i] - node.joint.motion_subspace().T @ pA[i]
527+
U[i] = IA[i] @ joint_i.motion_subspace()
528+
D[i] = joint_i.motion_subspace().T @ U[i]
529+
u[i] = self.math.vertcat(tau_i) - joint_i.motion_subspace().T @ pA[i]
526530

527531
Ia = IA[i] - U[i] / D[i] @ U[i].T
528532
pa = pA[i] + Ia @ c[i] + U[i] * u[i] / D[i]
529533

534+
a[0] = B_X_W @ g if self.model.floating_base else self.math.solve(-IA[0], pA[0])
535+
530536
# Pass 3
531537
for i, node in enumerate(self.model.tree):
532538
link_i, joint_i, link_pi = node.get_elements()
533539

534540
if link_i.name == self.root_link:
535-
IA[pi] += i_X_pi[i].T @ Ia @ i_X_pi[i]
536-
pA[pi] += i_X_pi[i].T @ pa
541+
continue
537542

538543
pi = self.model.tree.get_idx_from_name(link_pi.name)
539544

540-
sdd = (u[i] - U[i].T @ a[i]) / D[i]
545+
sdd[i - 1] = (u[i] - U[i].T @ a[i]) / D[i]
546+
547+
a[i] += i_X_pi[i].T @ a[pi] + joint_i.motion_subspace() * sdd[i - 1] + c[i]
548+
549+
# Filter sdd to remove NaNs generate with lumped joints
550+
s_ddot = self.math.vertcat(
551+
*[sdd[i] for i in range(self.model.N) if sdd[i] == sdd[i]]
552+
)
541553

542-
a[i] = i_X_pi[i].T @ a[pi] + node.joint.motion_subspace() * sdd + c[i]
554+
if (
555+
self.frame_velocity_representation
556+
== Representations.BODY_FIXED_REPRESENTATION
557+
):
558+
return a[0], s_ddot
543559

544-
return a, sdd
560+
elif self.frame_velocity_representation == Representations.MIXED_REPRESENTATION:
561+
return (
562+
self.math.vertcat(
563+
self.math.solve(B_X_W, a[0]) + g
564+
if self.model.floating_base
565+
else self.math.zeros(6, 1),
566+
),
567+
s_ddot,
568+
)

src/adam/jax/computations.py

+4-2
Original file line numberDiff line numberDiff line change
@@ -259,10 +259,12 @@ def forward_dynamics(
259259
base_acceleration (jnp.array): The base acceleration in mixed representation
260260
joint_accelerations (jnp.array): The joints acceleration
261261
"""
262-
return self.rbdalgos.aba(
262+
base_acceleration, joint_accelerations = self.rbdalgos.aba(
263263
base_transform,
264264
joint_positions,
265265
joint_velocities,
266266
joint_torques,
267267
self.g,
268-
).array.squeeze()
268+
)
269+
270+
return base_acceleration.array.squeeze(), joint_accelerations.array.squeeze()

0 commit comments

Comments
 (0)