diff --git a/cubed/core/plan.py b/cubed/core/plan.py index 69ae70e4..86d63c15 100644 --- a/cubed/core/plan.py +++ b/cubed/core/plan.py @@ -8,7 +8,7 @@ import networkx as nx import zarr -from cubed.core.optimization import simple_optimize_dag +from cubed.core.optimization import multiple_inputs_optimize_dag from cubed.primitive.blockwise import BlockwiseSpec from cubed.primitive.types import PrimitiveOperation from cubed.runtime.pipeline import visit_nodes @@ -135,7 +135,7 @@ def optimize( optimize_function: Optional[Callable[..., nx.MultiDiGraph]] = None, ): if optimize_function is None: - optimize_function = simple_optimize_dag + optimize_function = multiple_inputs_optimize_dag dag = optimize_function(self.dag) return Plan(dag)