Skip to content

Commit 01533b6

Browse files
committed
Update model reduction logic
1 parent 874c16c commit 01533b6

File tree

3 files changed

+42
-68
lines changed

3 files changed

+42
-68
lines changed

src/adam/core/rbd_algorithms.py

+18-36
Original file line numberDiff line numberDiff line change
@@ -481,15 +481,7 @@ def aba(
481481
joint_accelerations (T): The joints acceleration
482482
"""
483483
model = self.model.reduce(self.model.actuated_joints)
484-
485-
joints = list(
486-
filter(
487-
lambda joint: joint.name in self.model.actuated_joints,
488-
self.model.joints.values(),
489-
)
490-
)
491-
492-
joints.sort(key=lambda joint: joint.idx)
484+
joints = list(model.joints.values())
493485

494486
NB = model.N
495487

@@ -506,10 +498,8 @@ def aba(
506498
sdd = self.math.factory.zeros(NB, 1, 1)
507499
B_X_W = self.math.adjoint_mixed(base_transform)
508500

509-
if self.model.floating_base:
510-
IA[0] = self.model.tree.get_node_from_name(
511-
self.root_link
512-
).link.spatial_inertia()
501+
if model.floating_base:
502+
IA[0] = model.tree.get_node_from_name(self.root_link).link.spatial_inertia()
513503
v[0] = B_X_W @ base_velocity
514504
pA[0] = (
515505
self.math.spatial_skew_star(v[0]) @ IA[0] @ v[0]
@@ -522,7 +512,7 @@ def get_tree_transform(self, joints) -> "Array":
522512
Array: the tree transform
523513
"""
524514
relative_transform = lambda j: self.math.inv(
525-
self.model.tree.graph[j.parent].parent_arc.spatial_transform(0)
515+
model.tree.graph[j.child].parent_arc.spatial_transform(0)
526516
) @ j.spatial_transform(0)
527517

528518
return self.math.vertcat(
@@ -538,19 +528,7 @@ def get_tree_transform(self, joints) -> "Array":
538528
)
539529

540530
tree_transform = get_tree_transform(self, joints)
541-
542-
find_parent = (
543-
lambda j: find_parent(model.tree.get_node_from_name(j.parent).parent_arc)
544-
if model.tree.get_node_from_name(j.parent).parent_arc.idx is None
545-
else model.tree.get_node_from_name(j.parent).parent_arc.idx
546-
)
547-
548-
p = [-1] + [
549-
model.tree.get_idx_from_name(i.parent)
550-
if model.tree.get_idx_from_name(i.parent) < NB
551-
else find_parent(i)
552-
for i in joints
553-
]
531+
p = lambda i: list(model.tree.graph).index(joints[i].parent)
554532

555533
# Pass 1
556534
for i, joint in enumerate(joints[1:], start=1):
@@ -561,8 +539,8 @@ def get_tree_transform(self, joints) -> "Array":
561539
i_X_pi[i] = joint.spatial_transform(q) @ tree_transform[i]
562540
v_J = joint.motion_subspace() * q_dot
563541

564-
v[i] = i_X_pi[i] @ v[p[i]] + v_J
565-
c[i] = i_X_pi[i] @ c[p[i]] + self.math.spatial_skew(v[i]) @ v_J
542+
v[i] = i_X_pi[i] @ v[p(i)] + v_J
543+
c[i] = i_X_pi[i] @ c[p(i)] + self.math.spatial_skew(v[i]) @ v_J
566544

567545
IA[i] = model.tree.get_node_from_name(joint.parent).link.spatial_inertia()
568546

@@ -579,17 +557,21 @@ def get_tree_transform(self, joints) -> "Array":
579557
):
580558
U[i] = IA[i] @ joint.motion_subspace()
581559
D[i] = joint.motion_subspace().T @ U[i]
582-
u[i] = self.math.vertcat(tau[joint.idx]) - joint.motion_subspace().T @ pA[i]
560+
u[i] = (
561+
self.math.vertcat(tau[joint.idx]) - joint.motion_subspace().T @ pA[i]
562+
if joint.idx is not None
563+
else 0.0
564+
)
583565

584566
Ia = IA[i] - U[i] / D[i] @ U[i].T
585567
pa = pA[i] + Ia @ c[i] + U[i] * u[i] / D[i]
586568

587-
if joint.parent != self.root_link or not self.model.floating_base:
588-
IA[p[i]] += i_X_pi[i].T @ Ia @ i_X_pi[i]
589-
pA[p[i]] += i_X_pi[i].T @ pa
569+
if joint.parent != self.root_link or not model.floating_base:
570+
IA[p(i)] += i_X_pi[i].T @ Ia @ i_X_pi[i]
571+
pA[p(i)] += i_X_pi[i].T @ pa
590572
continue
591573

592-
a[0] = B_X_W @ g if self.model.floating_base else self.math.solve(-IA[0], pA[0])
574+
a[0] = B_X_W @ g if model.floating_base else self.math.solve(-IA[0], pA[0])
593575

594576
# Pass 3
595577
for i, joint in enumerate(joints[1:], start=1):
@@ -598,7 +580,7 @@ def get_tree_transform(self, joints) -> "Array":
598580

599581
sdd[i - 1] = (u[i] - U[i].T @ a[i]) / D[i]
600582

601-
a[i] += i_X_pi[i].T @ a[p[i]] + joint.motion_subspace() * sdd[i - 1] + c[i]
583+
a[i] += i_X_pi[i].T @ a[p(i)] + joint.motion_subspace() * sdd[i - 1] + c[i]
602584

603585
# Squeeze sdd
604586
s_ddot = self.math.vertcat(*[sdd[i] for i in range(sdd.shape[0])])
@@ -613,7 +595,7 @@ def get_tree_transform(self, joints) -> "Array":
613595
return self.math.horzcat(
614596
self.math.vertcat(
615597
self.math.solve(B_X_W, a[0]) + g
616-
if self.model.floating_base
598+
if model.floating_base
617599
else self.math.zeros(6, 1),
618600
),
619601
s_ddot,

src/adam/model/model.py

+5-7
Original file line numberDiff line numberDiff line change
@@ -98,14 +98,12 @@ def reduce(self, joints_name_list: List[str]) -> "Model":
9898
)
9999

100100
tree = self.tree.reduce(joints_name_list)
101-
joints_list = list(
102-
filter(
103-
lambda joint: joint.name in self.actuated_joints,
104-
self.joints.values(),
105-
)
106-
)
101+
102+
joints_list = [
103+
node.parent_arc for node in tree.graph.values() if node.name != tree.root
104+
]
107105
joints_list.sort(key=lambda joint: joint.idx)
108-
# update nodes dict
106+
109107
links = {node.name: node.link for node in tree.graph.values()}
110108
joints = {joint.name: joint for joint in joints_list}
111109
frames = {

src/adam/model/tree.py

+19-25
Original file line numberDiff line numberDiff line change
@@ -82,15 +82,6 @@ def reduce(self, considered_joint_names: List[str]) -> "Tree":
8282
Returns:
8383
Tree: the reduced tree
8484
"""
85-
# find the nodes between two fixed joints
86-
nodes_to_lump = list(
87-
{
88-
joint.child
89-
for node in self.graph.values()
90-
for joint in node.arcs
91-
if joint.name not in considered_joint_names
92-
}
93-
)
9485

9586
relative_transform = (
9687
lambda node: node.link.math.inv(
@@ -101,21 +92,17 @@ def reduce(self, considered_joint_names: List[str]) -> "Tree":
10192
else node.parent_arc.spatial_transform(0)
10293
)
10394

104-
last = []
105-
leaves = [node for node in self.graph.values() if node.children == last]
95+
# find the tree leaves and proceed until the root
96+
leaves = [node for node in self.graph.values() if node.children == []]
10697

10798
while all(leaf.name != self.root for leaf in leaves):
10899
for leaf in leaves:
109-
if leaf is self.graph[self.root]:
110-
continue
111-
112-
if leaf.parent_arc.name not in considered_joint_names:
113-
# create the new node
100+
if leaf.parent_arc.name not in considered_joint_names + [self.root]:
114101
new_node = Node(
115102
name=leaf.parent.name,
116103
link=None,
117104
arcs=[],
118-
children=None,
105+
children=[],
119106
parent=None,
120107
parent_arc=None,
121108
)
@@ -129,22 +116,29 @@ def reduce(self, considered_joint_names: List[str]) -> "Tree":
129116
# update the parents
130117
new_node.parent = self.graph[leaf.parent.name].parent
131118
new_node.parent_arc = self.graph[new_node.name].parent_arc
132-
new_node.parent_arc.parent = (
133-
leaf.children[0].parent_arc.name if leaf.children != [] else []
134-
)
135119

136120
# update the children
137-
new_node.children = leaf.children
121+
new_node.children = [
122+
child for child in leaf.children if child.name in self.graph
123+
]
124+
125+
for child in new_node.children:
126+
child.parent = new_node.link
127+
child.parent_arc = new_node.parent_arc
138128

139129
# update the arcs
140-
if leaf.arcs != []:
141-
for arc in leaf.arcs:
142-
if arc.name in considered_joint_names:
143-
new_node.arcs.append(arc)
130+
new_node.arcs = (
131+
[arc for arc in leaf.arcs if arc.name in considered_joint_names]
132+
if leaf.arcs != []
133+
else []
134+
)
135+
for j in new_node.arcs:
136+
j.parent = new_node.link.name
144137

145138
logging.debug(f"Removing {leaf.name}")
146139
self.graph.pop(leaf.name)
147140
self.graph[new_node.name] = new_node
141+
self.ordered_nodes_list.remove(leaf.name)
148142
leaves = [
149143
self.get_node_from_name((leaf.parent.name))
150144
for leaf in leaves

0 commit comments

Comments
 (0)