Skip to content

Commit

Permalink
Fix assemble_tree() for non-disassembled Tensors
Browse files Browse the repository at this point in the history
  • Loading branch information
holl- committed Jul 25, 2024
1 parent 92a872a commit 92701b5
Showing 1 changed file with 2 additions and 0 deletions.
2 changes: 2 additions & 0 deletions phiml/math/_tensors.py
Original file line number Diff line number Diff line change
Expand Up @@ -1898,6 +1898,8 @@ def assemble_tree(obj: PhiTreeNodeType, values: List[Tensor], attr_type=variable
return tuple([assemble_tree(item, values, attr_type) for item in obj])
elif isinstance(obj, dict):
return {name: assemble_tree(val, values, attr_type) for name, val in obj.items()}
elif isinstance(obj, Tensor):
return obj
elif isinstance(obj, PhiTreeNode):
attributes = attr_type(obj)
values = {a: assemble_tree(getattr(obj, a), values, attr_type) for a in attributes}
Expand Down

0 comments on commit 92701b5

Please sign in to comment.