|
1 | 1 | from functools import lru_cache
|
2 | 2 | from typing import Optional, Sequence
|
3 | 3 |
|
| 4 | +import copy |
4 | 5 | import numpy as np
|
5 | 6 | import torch
|
6 | 7 |
|
@@ -430,28 +431,29 @@ class SerialChain(Chain):
|
430 | 431 | """
|
431 | 432 |
|
432 | 433 | def __init__(self, chain, end_frame_name, root_frame_name="", **kwargs):
|
433 |
| - if root_frame_name == "": |
434 |
| - super().__init__(chain._root, **kwargs) |
435 |
| - else: |
436 |
| - super().__init__(chain.find_frame(root_frame_name), **kwargs) |
437 |
| - if self._root is None: |
438 |
| - raise ValueError("Invalid root frame name %s." % root_frame_name) |
439 |
| - self._serial_frames = [self._root] + self._generate_serial_chain_recurse(self._root, end_frame_name) |
440 |
| - if self._serial_frames is None: |
441 |
| - raise ValueError("Invalid end frame name %s." % end_frame_name) |
442 |
| - |
443 |
| - @staticmethod |
444 |
| - def _generate_serial_chain_recurse(root_frame, end_frame_name): |
445 |
| - for child in root_frame.children: |
446 |
| - if child.name == end_frame_name: |
447 |
| - # chop off any remaining tree after end frame |
448 |
| - child.children = [] |
449 |
| - return [child] |
450 |
| - else: |
451 |
| - frames = SerialChain._generate_serial_chain_recurse(child, end_frame_name) |
452 |
| - if not frames is None: |
453 |
| - return [child] + frames |
454 |
| - return None |
| 434 | + root_frame = chain._root if root_frame_name == "" else chain.find_frame(root_frame_name) |
| 435 | + if root_frame is None: |
| 436 | + raise ValueError("Invalid root frame name %s." % root_frame_name) |
| 437 | + chain = Chain(root_frame, **kwargs) |
| 438 | + |
| 439 | + # make a copy of those frames that includes only the chain up to the end effector |
| 440 | + end_frame_idx = chain.get_frame_indices(end_frame_name) |
| 441 | + ancestors = chain.parents_indices[end_frame_idx] |
| 442 | + |
| 443 | + frames = [] |
| 444 | + # first pass create copies of the ancestor nodes |
| 445 | + for idx in ancestors: |
| 446 | + this_frame_name = chain.idx_to_frame[idx.item()] |
| 447 | + this_frame = copy.deepcopy(chain.find_frame(this_frame_name)) |
| 448 | + if idx == end_frame_idx: |
| 449 | + this_frame.children = [] |
| 450 | + frames.append(this_frame) |
| 451 | + # second pass assign correct children (only the next one in the frame list) |
| 452 | + for i in range(len(ancestors) - 1): |
| 453 | + frames[i].children = [frames[i + 1]] |
| 454 | + |
| 455 | + self._serial_frames = frames |
| 456 | + super().__init__(frames[0], **kwargs) |
455 | 457 |
|
456 | 458 | def jacobian(self, th, locations=None, **kwargs):
|
457 | 459 | if locations is not None:
|
|
0 commit comments