Skip to content

Commit 2baf77c

Browse files
committed
Generate copies of frames with only direct ancestors in SerialChain
1 parent c7cda5e commit 2baf77c

File tree

2 files changed

+25
-24
lines changed

2 files changed

+25
-24
lines changed

src/pytorch_kinematics/chain.py

+24-22
Original file line numberDiff line numberDiff line change
@@ -1,6 +1,7 @@
11
from functools import lru_cache
22
from typing import Optional, Sequence
33

4+
import copy
45
import numpy as np
56
import torch
67

@@ -430,28 +431,29 @@ class SerialChain(Chain):
430431
"""
431432

432433
def __init__(self, chain, end_frame_name, root_frame_name="", **kwargs):
433-
if root_frame_name == "":
434-
super().__init__(chain._root, **kwargs)
435-
else:
436-
super().__init__(chain.find_frame(root_frame_name), **kwargs)
437-
if self._root is None:
438-
raise ValueError("Invalid root frame name %s." % root_frame_name)
439-
self._serial_frames = [self._root] + self._generate_serial_chain_recurse(self._root, end_frame_name)
440-
if self._serial_frames is None:
441-
raise ValueError("Invalid end frame name %s." % end_frame_name)
442-
443-
@staticmethod
444-
def _generate_serial_chain_recurse(root_frame, end_frame_name):
445-
for child in root_frame.children:
446-
if child.name == end_frame_name:
447-
# chop off any remaining tree after end frame
448-
child.children = []
449-
return [child]
450-
else:
451-
frames = SerialChain._generate_serial_chain_recurse(child, end_frame_name)
452-
if not frames is None:
453-
return [child] + frames
454-
return None
434+
root_frame = chain._root if root_frame_name == "" else chain.find_frame(root_frame_name)
435+
if root_frame is None:
436+
raise ValueError("Invalid root frame name %s." % root_frame_name)
437+
chain = Chain(root_frame, **kwargs)
438+
439+
# make a copy of those frames that includes only the chain up to the end effector
440+
end_frame_idx = chain.get_frame_indices(end_frame_name)
441+
ancestors = chain.parents_indices[end_frame_idx]
442+
443+
frames = []
444+
# first pass create copies of the ancestor nodes
445+
for idx in ancestors:
446+
this_frame_name = chain.idx_to_frame[idx.item()]
447+
this_frame = copy.deepcopy(chain.find_frame(this_frame_name))
448+
if idx == end_frame_idx:
449+
this_frame.children = []
450+
frames.append(this_frame)
451+
# second pass assign correct children (only the next one in the frame list)
452+
for i in range(len(ancestors) - 1):
453+
frames[i].children = [frames[i + 1]]
454+
455+
self._serial_frames = frames
456+
super().__init__(frames[0], **kwargs)
455457

456458
def jacobian(self, th, locations=None, **kwargs):
457459
if locations is not None:

src/pytorch_kinematics/urdf.py

+1-2
Original file line numberDiff line numberDiff line change
@@ -133,5 +133,4 @@ def build_serial_chain_from_urdf(data, end_link_name, root_link_name=""):
133133
SerialChain object created from URDF.
134134
"""
135135
urdf_chain = build_chain_from_urdf(data)
136-
return chain.SerialChain(urdf_chain, end_link_name,
137-
"" if root_link_name == "" else root_link_name)
136+
return chain.SerialChain(urdf_chain, end_link_name, root_link_name or '')

0 commit comments

Comments
 (0)