Skip to content

Commit

Permalink
fixed up setting of state in workflow constructors
Browse files Browse the repository at this point in the history
  • Loading branch information
tclose committed Jan 29, 2025
1 parent 0a32bfa commit 4ff7303
Showing 1 changed file with 28 additions and 39 deletions.
67 changes: 28 additions & 39 deletions pydra/engine/node.py
Original file line number Diff line number Diff line change
Expand Up @@ -54,26 +54,7 @@ class Node(ty.Generic[OutputType]):
) # QUESTION: should this be included in the state?

def __attrs_post_init__(self):
# Add node name to state's splitter, combiner and cont_dim loaded from the def
splitter = self._definition._splitter
combiner = self._definition._combiner
if splitter:
splitter = hlpst.add_name_splitter(splitter, self.name)
if combiner:
combiner = hlpst.add_name_combiner(combiner, self.name)
if self._definition._cont_dim:
self._cont_dim = {}
for key, val in self._definition._cont_dim.items():
self._cont_dim[f"{self.name}.{key}"] = val
self._set_state(splitter=splitter, combiner=combiner)
if combiner:
if not_split := [
c for c in combiner if not any(c in s for s in self.state.splitter_rpn)
]:
raise ValueError(
f"Combiner fields {not_split} for Node {self.name!r} are not in the "
f"splitter fields {self.state.splitter_rpn}"
)
self._set_state()

class Inputs:
"""A class to wrap the inputs of a node and control access to them so lazy fields
Expand Down Expand Up @@ -101,6 +82,7 @@ def __setattr__(self, name: str, value: ty.Any) -> None:
f"cannot set {name!r} input to {value} because it changes the "
f"state"
)
self._set_state()

@property
def inputs(self) -> Inputs:
Expand All @@ -115,9 +97,6 @@ def state(self):
"""Initialise the state of the node just after it has been created (i.e. before
it has been split or combined) based on the upstream connections
"""
if self._state is not NOT_SET:
return self._state
self._set_state(other_states=self._get_upstream_states())
return self._state

@property
Expand Down Expand Up @@ -248,29 +227,39 @@ def _wrap_lzout_types_in_state_arrays(self) -> None:
type_ = StateArray[type_]
outpt_lf.type = type_

def _set_state(
self,
splitter: list[str] | tuple[str, ...] | None = None,
combiner: list[str] | None = None,
other_states: dict[str, tuple["State", list[str]]] | None = None,
) -> None:
if self._state not in (NOT_SET, None):
if splitter is None:
splitter = self._state.current_splitter
if combiner is None:
combiner = self._state.current_combiner
if other_states is None:
other_states = self._state.other_states
if not (splitter or combiner or other_states):
self._state = None
else:
def _set_state(self) -> None:
# Add node name to state's splitter, combiner and cont_dim loaded from the def
splitter = self._definition._splitter
combiner = self._definition._combiner
if splitter:
splitter = hlpst.add_name_splitter(splitter, self.name)
if combiner:
combiner = hlpst.add_name_combiner(combiner, self.name)
if self._definition._cont_dim:
self._cont_dim = {}
for key, val in self._definition._cont_dim.items():
self._cont_dim[f"{self.name}.{key}"] = val
other_states = self._get_upstream_states()
if splitter or combiner or other_states:
self._state = State(
self.name,
self._definition,
splitter=splitter,
other_states=other_states,
combiner=combiner,
)
if combiner:
if not_split := [
c
for c in combiner
if not any(c in s for s in self.state.splitter_rpn)
]:
raise ValueError(
f"Combiner fields {not_split} for Node {self.name!r} are not in the "
f"splitter fields {self.state.splitter_rpn}"
)
else:
self._state = None

def _get_upstream_states(self) -> dict[str, tuple["State", list[str]]]:
"""Get the states of the upstream nodes that are connected to this node"""
Expand Down

0 comments on commit 4ff7303

Please sign in to comment.