Skip to content

Commit

Permalink
changed under_construction so it investigates the call stack instead …
Browse files Browse the repository at this point in the history
…of using a class variable
  • Loading branch information
tclose committed Jan 29, 2025
1 parent 4ff7303 commit 0e0a02d
Show file tree
Hide file tree
Showing 2 changed files with 54 additions and 51 deletions.
2 changes: 1 addition & 1 deletion pydra/design/workflow.py
Original file line number Diff line number Diff line change
Expand Up @@ -200,7 +200,7 @@ def this() -> "Workflow":
"""
from pydra.engine.core import Workflow

return Workflow.under_construction
return Workflow.under_construction()


OutputsType = ty.TypeVar("OutputsType", bound="TaskOutputs")
Expand Down
103 changes: 53 additions & 50 deletions pydra/engine/core.py
Original file line number Diff line number Diff line change
Expand Up @@ -477,7 +477,7 @@ def _combined_output(self, return_inputs=False):
return None
if return_inputs is True or return_inputs == "val":
result = (self.state.states_val[ind], result)
elif return_inputs == "ind":
elif return_inputs is True or return_inputs == "ind":
result = (self.state.states_ind[ind], result)
combined_results_gr.append(result)
combined_results.append(combined_results_gr)
Expand Down Expand Up @@ -637,7 +637,7 @@ def construct(

# Initialise the lzin fields
lazy_spec = copy(definition)
wf = cls.under_construction = Workflow(
workflow = Workflow(
name=type(definition).__name__,
inputs=lazy_spec,
outputs=outputs,
Expand All @@ -647,51 +647,66 @@ def construct(
lazy_spec,
lzy_inpt.name,
LazyInField(
workflow=wf,
workflow=workflow,
field=lzy_inpt.name,
type=lzy_inpt.type,
),
)

input_values = attrs_values(lazy_spec)
constructor = input_values.pop("constructor")
cls._under_construction = wf
try:
# Call the user defined constructor to set the outputs
output_lazy_fields = constructor(**input_values)
# Check to see whether any mandatory inputs are not set
for node in wf.nodes:
node._definition._check_rules()
# Check that the outputs are set correctly, either directly by the constructor
# or via returned values that can be zipped with the output names
if output_lazy_fields:
if not isinstance(output_lazy_fields, (list, tuple)):
output_lazy_fields = [output_lazy_fields]
output_fields = list_fields(definition.Outputs)
if len(output_lazy_fields) != len(output_fields):
raise ValueError(
f"Expected {len(output_fields)} outputs, got "
f"{len(output_lazy_fields)} ({output_lazy_fields})"
)
for outpt, outpt_lf in zip(output_fields, output_lazy_fields):
# Automatically combine any uncombined state arrays into lists
if TypeParser.get_origin(outpt_lf.type) is StateArray:
outpt_lf.type = list[TypeParser.strip_splits(outpt_lf.type)[0]]
setattr(outputs, outpt.name, outpt_lf)
else:
if unset_outputs := [
a for a, v in attrs_values(outputs).items() if v is attrs.NOTHING
]:
raise ValueError(
f"Expected outputs {unset_outputs} to be set by the "
f"constructor of {wf!r}"
)
finally:
cls._under_construction = None
# Call the user defined constructor to set the outputs
output_lazy_fields = constructor(**input_values)
# Check to see whether any mandatory inputs are not set
for node in workflow.nodes:
node._definition._check_rules()
# Check that the outputs are set correctly, either directly by the constructor
# or via returned values that can be zipped with the output names
if output_lazy_fields:
if not isinstance(output_lazy_fields, (list, tuple)):
output_lazy_fields = [output_lazy_fields]
output_fields = list_fields(definition.Outputs)
if len(output_lazy_fields) != len(output_fields):
raise ValueError(
f"Expected {len(output_fields)} outputs, got "
f"{len(output_lazy_fields)} ({output_lazy_fields})"
)
for outpt, outpt_lf in zip(output_fields, output_lazy_fields):
# Automatically combine any uncombined state arrays into lists
if TypeParser.get_origin(outpt_lf.type) is StateArray:
outpt_lf.type = list[TypeParser.strip_splits(outpt_lf.type)[0]]
setattr(outputs, outpt.name, outpt_lf)
else:
if unset_outputs := [
a for a, v in attrs_values(outputs).items() if v is attrs.NOTHING
]:
raise ValueError(
f"Expected outputs {unset_outputs} to be set by the "
f"constructor of {workflow!r}"
)

cls._constructed[hash_key] = wf
cls._constructed[hash_key] = workflow

return wf
return workflow

@classmethod
def under_construction(cls) -> "Workflow[ty.Any]":
"""Access the under_construction variable by iterating up through the call stack."""
frame = inspect.currentframe()
while frame:
# Find the frame where the construct method was called
if (
frame.f_code.co_name == "construct"
and "cls" in frame.f_locals
and frame.f_locals["cls"] is cls
and "workflow" in frame.f_locals
):
return frame.f_locals["workflow"] # local var "workflow" in construct
frame = frame.f_back
raise RuntimeError(
"No workflow is currently under construction (i.e. did not find a "
"`Workflow.construct` in the current call stack"
)

@classmethod
def clear_cache(cls):
Expand Down Expand Up @@ -733,18 +748,6 @@ def nodes(self) -> ty.Iterable[Node]:
def node_names(self) -> list[str]:
return list(self._nodes)

@property
@classmethod
def under_construction(cls) -> "Workflow[ty.Any]":
if cls._under_construction is None:
raise ValueError(
"pydra.design.workflow.this() can only be called from within a workflow "
"constructor function (see 'pydra.design.workflow.define')"
)
return cls._under_construction

# Used to store the workflow that is currently being constructed
_under_construction: "Workflow[ty.Any]" = None
# Used to cache the constructed workflows by their hashed input values
_constructed: dict[int, "Workflow[ty.Any]"] = {}

Expand Down

0 comments on commit 0e0a02d

Please sign in to comment.