diff --git a/src/openeo_aggregator/partitionedjobs/crossbackend.py b/src/openeo_aggregator/partitionedjobs/crossbackend.py
index 9c97300..40ce637 100644
--- a/src/openeo_aggregator/partitionedjobs/crossbackend.py
+++ b/src/openeo_aggregator/partitionedjobs/crossbackend.py
@@ -1,11 +1,30 @@
+from __future__ import annotations
+
 import collections
 import copy
+import dataclasses
 import datetime
+import fractions
+import functools
 import itertools
 import logging
 import time
+import types
 from contextlib import nullcontext
-from typing import Callable, Dict, Iterator, List, Optional, Protocol, Sequence, Tuple
+from typing import (
+    Callable,
+    Dict,
+    Iterable,
+    Iterator,
+    List,
+    Mapping,
+    Optional,
+    Protocol,
+    Sequence,
+    Set,
+    Tuple,
+    Union,
+)
 
 import openeo
 from openeo import BatchJob
@@ -14,7 +33,12 @@
 from openeo_aggregator.constants import JOB_OPTION_FORCE_BACKEND
 from openeo_aggregator.partitionedjobs import PartitionedJob, SubJob
 from openeo_aggregator.partitionedjobs.splitting import AbstractJobSplitter
-from openeo_aggregator.utils import FlatPG, PGWithMetadata, SkipIntermittentFailures
+from openeo_aggregator.utils import (
+    _UNSET,
+    FlatPG,
+    PGWithMetadata,
+    SkipIntermittentFailures,
+)
 
 _log = logging.getLogger(__name__)
 
@@ -24,6 +48,10 @@
 SubGraphId = str
 
 
+class GraphSplitException(Exception):
+    pass
+
+
 class GetReplacementCallable(Protocol):
     """
     Type annotation for callback functions that produce a node replacement
@@ -88,11 +116,11 @@
         (e.g. main "primary" graph comes last).
 
         The iterator approach allows working with a dynamic `get_replacement` implementation
-        that adapting to on previously produced subgraphs
+        that can adapt to previously produced subgraphs
         (e.g. creating openEO batch jobs on the fly and injecting the corresponding batch job ids appropriately).
 
-        :return: tuple containing:
-        - subgraph id, recommended to handle it as opaque id (but usually format '{backend_id}:{node_id}')
+        :return: Iterator of tuples containing:
+        - subgraph id; it's recommended to handle it as an opaque id (but it usually has the format '{backend_id}:{node_id}')
         - SubJob
         - dependencies as list of subgraph ids
         """
@@ -351,3 +379,311 @@ def run_partitioned_job(pjob: PartitionedJob, connection: openeo.Connection, fai
         }
         for sid in subjobs.keys()
     }
+
+
+# Type aliases to make things more self-documenting
+NodeId = str
+BackendId = str
+
+
+@dataclasses.dataclass(frozen=True)
+class _FrozenNode:
+    """
+    Node in a _FrozenGraph, with pointers to other nodes it depends on (needs data/input from)
+    and nodes it provides input to.
+
+    This is as immutable as possible (as far as Python allows) to
+    be used and reused in iterative/recursive graph handling algorithms,
+    without having to worry about accidentally changing state.
+    """
+
+    # TODO: instead of frozen dataclass: have __init__ with some type casting/validation. Or use attrs?
+    # TODO: better name for this class?
+
+    # Node ids of other nodes this node depends on (aka parents)
+    depends_on: frozenset[NodeId]
+    # Node ids of other nodes that depend on this node (aka children)
+    flows_to: frozenset[NodeId]
+
+    # Backend ids this node is marked to be supported on.
+    # A value of None means it is unknown/unconstrained for this node.
+    # TODO: Move this to _FrozenGraph as responsibility?
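+    # (e.g. frozenset({"b1"}) means only backend "b1" can handle this node,
+    # while frozenset() means no backend can handle it; "b1" is an illustrative id)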
+    backend_candidates: Union[frozenset[BackendId], None]
+
+    def __repr__(self):
+        return "".join(
+            [
+                "Node ",
+                f"@({','.join(self.backend_candidates) if self.backend_candidates else None})",
+            ]
+            + [f"<{d}" for d in self.depends_on]
+            + [f">{f}" for f in self.flows_to]
+        )
+
+
+class _FrozenGraph:
+    """
+    Graph of _FrozenNode objects.
+    """
+
+    # TODO: find better class name: e.g. SplitGraphView, GraphSplitUtility, GraphSplitter, ...?
+
+    def __init__(self, graph: dict[NodeId, _FrozenNode]):
+        # Work with a read-only proxy to prevent accidental changes
+        # TODO: check consistency of references?
+        self._graph: Mapping[NodeId, _FrozenNode] = types.MappingProxyType(graph)
+
+    def __repr__(self):
+        return f"<{type(self).__name__}({self._graph})>"
+
+    @classmethod
+    def from_flat_graph(cls, flat_graph: FlatPG, backend_candidates_map: Dict[NodeId, Iterable[BackendId]]):
+        """
+        Build _FrozenGraph from a flat process graph representation
+        """
+        # Extract dependency links between nodes
+        depends_on = collections.defaultdict(list)
+        flows_to = collections.defaultdict(list)
+        for node_id, node in flat_graph.items():
+            for arg_value in node.get("arguments", {}).values():
+                if isinstance(arg_value, dict) and list(arg_value.keys()) == ["from_node"]:
+                    from_node = arg_value["from_node"]
+                    depends_on[node_id].append(from_node)
+                    flows_to[from_node].append(node_id)
+        graph = {
+            node_id: _FrozenNode(
+                depends_on=frozenset(depends_on.get(node_id, [])),
+                flows_to=frozenset(flows_to.get(node_id, [])),
+                backend_candidates=(
+                    # TODO move this logic to _FrozenNode.__init__
+                    frozenset(backend_candidates_map.get(node_id))
+                    if node_id in backend_candidates_map
+                    else None
+                ),
+            )
+            for node_id, node in flat_graph.items()
+        }
+        return cls(graph=graph)
+
+    @classmethod
+    def from_edges(
+        cls,
+        edges: Iterable[Tuple[NodeId, NodeId]],
+        backend_candidates_map: Optional[Dict[NodeId, Iterable[BackendId]]] = None,
+    ):
+        """
+        Simple factory to build a graph from parent-child tuples, for testing purposes
+        """
+        depends_on = collections.defaultdict(list)
+        flows_to = collections.defaultdict(list)
+        for parent, child in edges:
+            depends_on[child].append(parent)
+            flows_to[parent].append(child)
+
+        graph = {
+            node_id: _FrozenNode(
+                # Note that we just use node id as process id. Do we have better options here?
+                depends_on=frozenset(depends_on.get(node_id, [])),
+                flows_to=frozenset(flows_to.get(node_id, [])),
+                backend_candidates=(
+                    frozenset(backend_candidates_map.get(node_id))
+                    if backend_candidates_map and node_id in backend_candidates_map
+                    else None
+                ),
+            )
+            for node_id in set(depends_on.keys()).union(flows_to.keys())
+        }
+        return cls(graph=graph)
+
+    def node(self, node_id: NodeId) -> _FrozenNode:
+        return self._graph[node_id]
+
+    def iter_nodes(self) -> Iterator[Tuple[NodeId, _FrozenNode]]:
+        """Iterate through node_id-node pairs"""
+        yield from self._graph.items()
+
+    def _walk(
+        self, seeds: Iterable[NodeId], next_nodes: Callable[[NodeId], Iterable[NodeId]], include_seeds: bool = True
+    ) -> Iterator[NodeId]:
+        """
+        Walk the graph nodes starting from the given seed nodes, taking steps as defined by the `next_nodes` function.
+        The walk is breadth-first; the seeds themselves are only included when `include_seeds` is set.
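+
+        For example (illustrative): with edges a->b and b->c, walking from seeds ["c"]
+        with `next_nodes` following `depends_on` yields "c", "b", "a"
+        (just "b", "a" when `include_seeds` is false).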
+ """ + if include_seeds: + visited = set() + to_visit = list(seeds) + else: + visited = set(seeds) + to_visit = [n for s in seeds for n in next_nodes(s)] + + while to_visit: + node_id = to_visit.pop(0) + if node_id in visited: + continue + yield node_id + visited.add(node_id) + to_visit.extend(set(next_nodes(node_id)).difference(visited)) + + def walk_upstream_nodes(self, seeds: Iterable[NodeId], include_seeds: bool = True) -> Iterator[NodeId]: + """ + Walk upstream nodes (along `depends_on` link) starting from given seed nodes. + Optionally include seeds or not, and walk breadth first. + """ + return self._walk(seeds=seeds, next_nodes=lambda n: self.node(n).depends_on, include_seeds=include_seeds) + + def walk_downstream_nodes(self, seeds: Iterable[NodeId], include_seeds: bool = True) -> Iterator[NodeId]: + """ + Walk downstream nodes (along `flows_to` link) starting from given seed nodes. + Optionally include seeds or not, and walk breadth first. + """ + return self._walk(seeds=seeds, next_nodes=lambda n: self.node(n).flows_to, include_seeds=include_seeds) + + def get_backend_candidates(self, node_id: NodeId) -> Union[frozenset[BackendId], None]: + """Determine backend candidates for given node id""" + if self.node(node_id).backend_candidates is not None: + # Node has explicit backend candidates listed + return self.node(node_id).backend_candidates + elif self.node(node_id).depends_on: + # Backend support is unset: determine it (as intersection) from upstream nodes + # TODO: cache intermediate sets? (Only when caching is safe: e.g. wrapped graph is immutable/not manipulated) + upstream_candidates = (self.get_backend_candidates(n) for n in self.node(node_id).depends_on) + upstream_candidates = [c for c in upstream_candidates if c is not None] + if upstream_candidates: + return functools.reduce(lambda a, b: a.intersection(b), upstream_candidates) + else: + return None + else: + return None + + def find_forsaken_nodes(self) -> Set[NodeId]: + """ + Find nodes that have no backend candidates to process them + """ + return set(node_id for (node_id, _) in self.iter_nodes() if self.get_backend_candidates(node_id) == set()) + + def find_articulation_points(self) -> Set[NodeId]: + """ + Find articulation points (cut vertices) in the directed graph: + nodes that when removed would split the graph into multiple sub-graphs. + + Note that, unlike in traditional graph theory, the search also includes leaf nodes + (e.g. nodes with no parents), as in this context of openEO graph splitting, + when we "cut" a node, we replace it with two disconnected new nodes + (one connecting to the original parents and one connecting to the original children). + """ + # Approach: label the start nodes (e.g. load_collection) with their id and weight 1. + # Propagate these labels along the depends-on links, but split/sum the weight according + # to the number of children/parents. + # At the end: the articulation points are the nodes where all flows have weight 1. 
+
+        # Mapping: node_id -> start_node_id -> flow_weight
+        flow_weights: Dict[NodeId, Dict[NodeId, fractions.Fraction]] = {}
+
+        # Initialize at the pure input nodes (nodes with no upstream dependencies)
+        for node_id, node in self.iter_nodes():
+            if not node.depends_on:
+                flow_weights[node_id] = {node_id: fractions.Fraction(1, 1)}
+
+        # Propagate flow weights using recursion + caching
+        def get_flow_weights(node_id: NodeId) -> Dict[NodeId, fractions.Fraction]:
+            nonlocal flow_weights
+            if node_id not in flow_weights:
+                flow_weights[node_id] = {}
+                # Calculate from upstream nodes
+                for upstream in self.node(node_id).depends_on:
+                    for start_node_id, weight in get_flow_weights(upstream).items():
+                        flow_weights[node_id].setdefault(start_node_id, fractions.Fraction(0, 1))
+                        flow_weights[node_id][start_node_id] += weight / len(self.node(upstream).flows_to)
+            return flow_weights[node_id]
+
+        for node_id, node in self.iter_nodes():
+            get_flow_weights(node_id)
+
+        # Select articulation points: nodes where all flows have weight 1
+        return set(node_id for node_id, flows in flow_weights.items() if all(w == 1 for w in flows.values()))
+
+    def split_at(self, split_node_id: NodeId) -> Tuple[_FrozenGraph, _FrozenGraph]:
+        """
+        Split the graph at the given node id (which must be an articulation point),
+        creating two new graphs, each containing the original nodes of its side
+        plus an adapted copy of the split node.
+
+        :return: tuple of the "down"-stream graph and the "up"-stream graph (in that order)
+        """
+        split_node = self.node(split_node_id)
+
+        # TODO: first verify that node_id is a valid articulation point?
+        # Or let this fail, e.g. in validation of _FrozenGraph.__init__?
+
+        # Walk the graph, upstream from the split node
+        def next_nodes(node_id: NodeId) -> Iterable[NodeId]:
+            node = self.node(node_id)
+            if node_id == split_node_id:
+                return node.depends_on
+            else:
+                return node.depends_on.union(node.flows_to)
+
+        up_node_ids = set(self._walk(seeds=[split_node_id], next_nodes=next_nodes))
+
+        if split_node.flows_to.intersection(up_node_ids):
+            raise GraphSplitException(f"Graph cannot be split at {split_node_id}: not an articulation point.")
+
+        up_graph = {n: self.node(n) for n in up_node_ids}
+        up_graph[split_node_id] = _FrozenNode(
+            depends_on=split_node.depends_on,
+            flows_to=frozenset(),
+            backend_candidates=split_node.backend_candidates,
+        )
+        up = _FrozenGraph(graph=up_graph)
+
+        down_graph = {n: node for n, node in self.iter_nodes() if n not in up_node_ids}
+        down_graph[split_node_id] = _FrozenNode(
+            depends_on=frozenset(),
+            flows_to=split_node.flows_to,
+            backend_candidates=None,
+        )
+        down = _FrozenGraph(graph=down_graph)
+
+        return down, up
+
+    def produce_split_locations(self, limit: int = 2) -> Iterator[List[NodeId]]:
+        """
+        Produce lists of split locations (node ids) at which the graph can be cut
+        into disjoint subgraphs that can be processed independently
+        """
+        # Find nodes that have an empty set of backend_candidates
+        forsaken_nodes = self.find_forsaken_nodes()
+
+        if forsaken_nodes:
+            # Sort forsaken nodes (based on forsaken parent count), to start higher up the graph
+            # TODO: avoid need for this sort, and just use a better scoring metric higher up?
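+            # (e.g. a forsaken node with two forsaken parents gets sort key 2 and comes
+            # before a forsaken node whose parents all still have candidates, sort key 0)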
+            forsaken_nodes = sorted(
+                forsaken_nodes, reverse=True, key=lambda n: sum(p in forsaken_nodes for p in self.node(n).depends_on)
+            )
+            # Collect nodes where we could split the graph in disjoint subgraphs
+            articulation_points: Set[NodeId] = set(self.find_articulation_points())
+
+            # Walk upstream from forsaken nodes to find articulation points, where we can cut
+            split_options = [
+                n
+                for n in self.walk_upstream_nodes(seeds=forsaken_nodes, include_seeds=False)
+                if n in articulation_points
+            ]
+            if not split_options:
+                raise GraphSplitException("No split options found.")
+            # TODO: how to handle limit? will it scale feasibly to iterate over all possibilities at this point?
+            # TODO: smarter picking of split node (e.g. one with most upstream nodes)
+            for split_node_id in split_options[:limit]:
+                # Split graph at this articulation point
+                down, up = self.split_at(split_node_id)
+                if down.find_forsaken_nodes():
+                    down_splits = list(down.produce_split_locations(limit=limit - 1))
+                else:
+                    down_splits = [[]]
+                if up.find_forsaken_nodes():
+                    up_splits = list(up.produce_split_locations(limit=limit - 1))
+                else:
+                    up_splits = [[]]
+
+                for down_split, up_split in itertools.product(down_splits, up_splits):
+                    yield [split_node_id] + down_split + up_split
+
+        else:
+            # All nodes can be handled as is, no need to split
+            yield []
diff --git a/tests/partitionedjobs/test_crossbackend.py b/tests/partitionedjobs/test_crossbackend.py
index b4b6239..f2c8c85 100644
--- a/tests/partitionedjobs/test_crossbackend.py
+++ b/tests/partitionedjobs/test_crossbackend.py
@@ -4,6 +4,7 @@
 from typing import Dict, List, Optional
 from unittest import mock
 
+import dirty_equals
 import openeo
 import pytest
 import requests
@@ -14,7 +15,10 @@
 from openeo_aggregator.partitionedjobs import PartitionedJob, SubJob
 from openeo_aggregator.partitionedjobs.crossbackend import (
     CrossBackendSplitter,
+    GraphSplitException,
     SubGraphId,
+    _FrozenGraph,
+    _FrozenNode,
     run_partitioned_job,
 )
 
@@ -417,3 +421,332 @@ def test_basic(self, aggregator: _FakeAggregator):
             "result": True,
         },
     }
+
+
+class TestFrozenGraph:
+    def test_empty(self):
+        graph = _FrozenGraph(graph={})
+        assert list(graph.iter_nodes()) == []
+
+    def test_from_flat_graph_basic(self):
+        flat = {
+            "lc1": {"process_id": "load_collection", "arguments": {"id": "B1_NDVI"}},
+            "ndvi1": {"process_id": "ndvi", "arguments": {"data": {"from_node": "lc1"}}, "result": True},
+        }
+        graph = _FrozenGraph.from_flat_graph(flat, backend_candidates_map={"lc1": ["b1"]})
+        assert sorted(graph.iter_nodes()) == [
+            (
+                "lc1",
+                _FrozenNode(frozenset(), frozenset(["ndvi1"]), backend_candidates=frozenset(["b1"])),
+            ),
+            ("ndvi1", _FrozenNode(frozenset(["lc1"]), frozenset([]), backend_candidates=None)),
+        ]
+
+    # TODO: test from_flat_graph with more complex graphs
+
+    def test_from_edges(self):
+        graph = _FrozenGraph.from_edges([("a", "c"), ("b", "d"), ("c", "e"), ("d", "e"), ("e", "f")])
+        assert sorted(graph.iter_nodes()) == [
+            ("a", _FrozenNode(frozenset(), frozenset(["c"]), backend_candidates=None)),
+            ("b", _FrozenNode(frozenset(), frozenset(["d"]), backend_candidates=None)),
+            ("c", _FrozenNode(frozenset(["a"]), frozenset(["e"]), backend_candidates=None)),
+            ("d", _FrozenNode(frozenset(["b"]), frozenset(["e"]), backend_candidates=None)),
+            ("e", _FrozenNode(frozenset(["c", "d"]), frozenset(["f"]), backend_candidates=None)),
+            ("f", _FrozenNode(frozenset(["e"]), frozenset(), backend_candidates=None)),
+        ]
+
+    @pytest.mark.parametrize(
+        ["seed", "include_seeds", "expected"],
+        [
+            (["a"], True, ["a"]),
+            (["a"], False, []),
+            (["c"], True, ["c", "a"]),
+            (["c"], False, ["a"]),
+            (["a", "c"], True, ["a", "c"]),
+            (["a", "c"], False, []),
+            (["c", "a"], True, ["c", "a"]),
+            (["c", "a"], False, []),
+            (
+                ["e"],
+                True,
+                dirty_equals.IsOneOf(
+                    ["e", "c", "d", "a", "b"],
+                    ["e", "d", "c", "b", "a"],
+                ),
+            ),
+            (
+                ["e"],
+                False,
+                dirty_equals.IsOneOf(
+                    ["c", "d", "a", "b"],
+                    ["d", "c", "b", "a"],
+                ),
+            ),
+            (["e", "d"], True, ["e", "d", "c", "b", "a"]),
+            (["e", "d"], False, ["c", "b", "a"]),
+            (["d", "e"], True, ["d", "e", "b", "c", "a"]),
+            (["d", "e"], False, ["b", "c", "a"]),
+            (["f", "c"], True, ["f", "c", "e", "a", "d", "b"]),
+            (["f", "c"], False, ["e", "a", "d", "b"]),
+        ],
+    )
+    def test_walk_upstream_nodes(self, seed, include_seeds, expected):
+        graph = _FrozenGraph.from_edges([("a", "c"), ("b", "d"), ("c", "e"), ("d", "e"), ("e", "f")])
+        assert list(graph.walk_upstream_nodes(seed, include_seeds)) == expected
+
+    def test_get_backend_candidates_basic(self):
+        graph = _FrozenGraph.from_edges(
+            [("a", "b"), ("b", "d"), ("c", "d")],
+            backend_candidates_map={"a": ["b1"], "c": ["b2"]},
+        )
+        assert graph.get_backend_candidates("a") == {"b1"}
+        assert graph.get_backend_candidates("b") == {"b1"}
+        assert graph.get_backend_candidates("c") == {"b2"}
+        assert graph.get_backend_candidates("d") == set()
+
+    def test_get_backend_candidates_none(self):
+        graph = _FrozenGraph.from_edges(
+            [("a", "b"), ("b", "d"), ("c", "d")],
+            backend_candidates_map={},
+        )
+        assert graph.get_backend_candidates("a") is None
+        assert graph.get_backend_candidates("b") is None
+        assert graph.get_backend_candidates("c") is None
+        assert graph.get_backend_candidates("d") is None
+
+    def test_get_backend_candidates_intersection(self):
+        graph = _FrozenGraph.from_edges(
+            [("a", "d"), ("b", "d"), ("b", "e"), ("c", "e"), ("d", "f"), ("e", "f")],
+            backend_candidates_map={"a": ["b1", "b2"], "b": ["b2", "b3"], "c": ["b4"]},
+        )
+        assert graph.get_backend_candidates("a") == {"b1", "b2"}
+        assert graph.get_backend_candidates("b") == {"b2", "b3"}
+        assert graph.get_backend_candidates("c") == {"b4"}
+        assert graph.get_backend_candidates("d") == {"b2"}
+        assert graph.get_backend_candidates("e") == set()
+        assert graph.get_backend_candidates("f") == set()
+
+    def test_find_forsaken_nodes(self):
+        graph = _FrozenGraph.from_edges(
+            [("a", "d"), ("b", "d"), ("b", "e"), ("c", "e"), ("d", "f"), ("e", "f"), ("f", "g"), ("f", "h")],
+            backend_candidates_map={"a": ["b1", "b2"], "b": ["b2", "b3"], "c": ["b4"]},
+        )
+        assert graph.find_forsaken_nodes() == {"e", "f", "g", "h"}
+
+    def test_find_articulation_points_basic(self):
+        flat = {
+            "lc1": {"process_id": "load_collection", "arguments": {"id": "S2"}},
+            "ndvi1": {"process_id": "ndvi", "arguments": {"data": {"from_node": "lc1"}}, "result": True},
+        }
+        graph = _FrozenGraph.from_flat_graph(flat, backend_candidates_map={})
+        assert graph.find_articulation_points() == {"lc1", "ndvi1"}
+
+    @pytest.mark.parametrize(
+        ["flat", "expected"],
+        [
+            (
+                {
+                    "lc1": {"process_id": "load_collection", "arguments": {"id": "S2"}},
+                    "ndvi1": {"process_id": "ndvi", "arguments": {"data": {"from_node": "lc1"}}},
+                },
+                {"lc1", "ndvi1"},
+            ),
+            (
+                {
+                    "lc1": {"process_id": "load_collection", "arguments": {"id": "S2"}},
+                    "bands1": {
+                        "process_id": "filter_bands",
+                        "arguments": {"data": {"from_node": "lc1"}, "bands": ["b1"]},
+                    },
+                    "bands2": {
+                        "process_id": "filter_bands",
+                        "arguments": {"data": {"from_node": "lc1"}, "bands": ["b2"]},
+                    },
+                    "merge1": {
+                        "process_id": "merge_cubes",
"merge_cubes", + "arguments": {"cube1": {"from_node": "bands1"}, "cube2": {"from_node": "bands2"}}, + }, + "save1": { + "process_id": "save_result", + "arguments": {"data": {"from_node": "merge1"}, "format": "GTiff"}, + }, + }, + {"lc1", "merge1", "save1"}, + ), + ( + { + "lc1": {"process_id": "load_collection", "arguments": {"id": "S2"}}, + "lc2": {"process_id": "load_collection", "arguments": {"id": "S2"}}, + "merge1": { + "process_id": "merge_cubes", + "arguments": {"cube1": {"from_node": "lc1"}, "cube2": {"from_node": "lc2"}}, + }, + "save1": { + "process_id": "save_result", + "arguments": {"data": {"from_node": "merge1"}, "format": "GTiff"}, + }, + }, + {"lc1", "lc2", "merge1", "save1"}, + ), + ( + { + "lc1": {"process_id": "load_collection", "arguments": {"id": "S2"}}, + "bands1": { + "process_id": "filter_bands", + "arguments": {"data": {"from_node": "lc1"}, "bands": ["b1"]}, + }, + "bbox1": { + "process_id": "filter_spatial", + "arguments": {"data": {"from_node": "bands1"}}, + }, + "merge1": { + "process_id": "merge_cubes", + "arguments": {"cube1": {"from_node": "lc1"}, "cube2": {"from_node": "bbox1"}}, + }, + "save1": { + "process_id": "save_result", + "arguments": {"data": {"from_node": "merge1"}}, + }, + }, + {"lc1", "merge1", "save1"}, + ), + ], + ) + def test_find_articulation_points(self, flat, expected): + graph = _FrozenGraph.from_flat_graph(flat, backend_candidates_map={}) + assert graph.find_articulation_points() == expected + + def test_split_at_minimal(self): + graph = _FrozenGraph.from_edges([("a", "b")], backend_candidates_map={"a": "A"}) + # Split at a + down, up = graph.split_at("a") + assert sorted(up.iter_nodes()) == [ + ("a", _FrozenNode(frozenset(), frozenset(), backend_candidates=frozenset(["A"]))), + ] + assert sorted(down.iter_nodes()) == [ + ("a", _FrozenNode(frozenset(), frozenset(["b"]), backend_candidates=None)), + ("b", _FrozenNode(frozenset(["a"]), frozenset([]), backend_candidates=None)), + ] + # Split at b + down, up = graph.split_at("b") + assert sorted(up.iter_nodes()) == [ + ("a", _FrozenNode(frozenset(), frozenset(["b"]), backend_candidates=frozenset(["A"]))), + ("b", _FrozenNode(frozenset(["a"]), frozenset([]), backend_candidates=None)), + ] + assert sorted(down.iter_nodes()) == [ + ("b", _FrozenNode(frozenset(), frozenset(), backend_candidates=None)), + ] + + def test_split_at_basic(self): + graph = _FrozenGraph.from_edges([("a", "b"), ("b", "c")], backend_candidates_map={"a": "A"}) + down, up = graph.split_at("b") + assert sorted(up.iter_nodes()) == [ + ("a", _FrozenNode(frozenset(), frozenset(["b"]), backend_candidates=frozenset(["A"]))), + ("b", _FrozenNode(frozenset(["a"]), frozenset([]), backend_candidates=None)), + ] + assert sorted(down.iter_nodes()) == [ + ("b", _FrozenNode(frozenset(), frozenset(["c"]), backend_candidates=None)), + ("c", _FrozenNode(frozenset(["b"]), frozenset([]), backend_candidates=None)), + ] + + def test_split_at_complex(self): + graph = _FrozenGraph.from_edges( + [("a", "b"), ("a", "c"), ("b", "d"), ("c", "d"), ("c", "e"), ("e", "g"), ("f", "g"), ("X", "Y")] + ) + down, up = graph.split_at("e") + assert sorted(up.iter_nodes()) == sorted( + _FrozenGraph.from_edges([("a", "b"), ("a", "c"), ("b", "d"), ("c", "d"), ("c", "e")]).iter_nodes() + ) + assert sorted(down.iter_nodes()) == sorted( + _FrozenGraph.from_edges([("e", "g"), ("f", "g"), ("X", "Y")]).iter_nodes() + ) + + def test_split_at_non_articulation_point(self): + graph = _FrozenGraph.from_edges([("a", "b"), ("b", "c"), ("a", "c")]) + with 
+            _ = graph.split_at("b")
+
+        # These should still work
+        down, up = graph.split_at("a")
+        assert sorted(up.iter_nodes()) == [
+            ("a", _FrozenNode(frozenset(), frozenset(), backend_candidates=None)),
+        ]
+        assert sorted(down.iter_nodes()) == [
+            ("a", _FrozenNode(frozenset(), frozenset(["b", "c"]), backend_candidates=None)),
+            ("b", _FrozenNode(frozenset(["a"]), frozenset(["c"]), backend_candidates=None)),
+            ("c", _FrozenNode(frozenset(["a", "b"]), frozenset(), backend_candidates=None)),
+        ]
+
+        down, up = graph.split_at("c")
+        assert sorted(up.iter_nodes()) == [
+            ("a", _FrozenNode(frozenset(), frozenset(["b", "c"]), backend_candidates=None)),
+            ("b", _FrozenNode(frozenset(["a"]), frozenset(["c"]), backend_candidates=None)),
+            ("c", _FrozenNode(frozenset(["a", "b"]), frozenset(), backend_candidates=None)),
+        ]
+        assert sorted(down.iter_nodes()) == [
+            ("c", _FrozenNode(frozenset(), frozenset(), backend_candidates=None)),
+        ]
+
+    def test_produce_split_locations_simple(self):
+        """Simple produce_split_locations use case: no need for splits"""
+        flat = {
+            "lc1": {"process_id": "load_collection", "arguments": {"id": "S2"}},
+            "ndvi1": {"process_id": "ndvi", "arguments": {"data": {"from_node": "lc1"}}, "result": True},
+        }
+        graph = _FrozenGraph.from_flat_graph(flat, backend_candidates_map={"lc1": ["b1"]})
+        assert list(graph.produce_split_locations()) == [[]]
+
+    def test_produce_split_locations_merge_basic(self):
+        """
+        Basic produce_split_locations use case:
+        two load collections on different backends and a merge
+        """
+        flat = {
+            "lc1": {"process_id": "load_collection", "arguments": {"id": "S1"}},
+            "lc2": {"process_id": "load_collection", "arguments": {"id": "S2"}},
+            "merge1": {
+                "process_id": "merge_cubes",
+                "arguments": {"cube1": {"from_node": "lc1"}, "cube2": {"from_node": "lc2"}},
+            },
+        }
+        graph = _FrozenGraph.from_flat_graph(flat, backend_candidates_map={"lc1": ["b1"], "lc2": ["b2"]})
+        assert sorted(graph.produce_split_locations()) == [["lc1"], ["lc2"]]
+
+    def test_produce_split_locations_merge_longer(self):
+        flat = {
+            "lc1": {"process_id": "load_collection", "arguments": {"id": "S1"}},
+            "bands1": {"process_id": "filter_bands", "arguments": {"data": {"from_node": "lc1"}, "bands": ["B01"]}},
+            "lc2": {"process_id": "load_collection", "arguments": {"id": "S2"}},
+            "bands2": {"process_id": "filter_bands", "arguments": {"data": {"from_node": "lc2"}, "bands": ["B02"]}},
+            "merge1": {
+                "process_id": "merge_cubes",
+                "arguments": {"cube1": {"from_node": "bands1"}, "cube2": {"from_node": "bands2"}},
+            },
+        }
+        graph = _FrozenGraph.from_flat_graph(flat, backend_candidates_map={"lc1": ["b1"], "lc2": ["b2"]})
+        assert sorted(graph.produce_split_locations(limit=2)) == [["bands1"], ["bands2"]]
+        assert list(graph.produce_split_locations(limit=4)) == dirty_equals.IsOneOf(
+            [["bands1"], ["bands2"], ["lc1"], ["lc2"]],
+            [["bands2"], ["bands1"], ["lc2"], ["lc1"]],
+        )
+
+    def test_produce_split_locations_merge_longer_triangle(self):
+        flat = {
+            "lc1": {"process_id": "load_collection", "arguments": {"id": "S1"}},
+            "bands1": {"process_id": "filter_bands", "arguments": {"data": {"from_node": "lc1"}, "bands": ["B01"]}},
+            "mask1": {
+                "process_id": "mask",
+                "arguments": {"data": {"from_node": "bands1"}, "mask": {"from_node": "lc1"}},
+            },
+            "lc2": {"process_id": "load_collection", "arguments": {"id": "S2"}},
+            "bands2": {"process_id": "filter_bands", "arguments": {"data": {"from_node": "lc2"}, "bands": ["B02"]}},
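+            # note: "lc1" feeds both "bands1" and "mask1" (a triangle), so "bands1"
+            # cannot be an articulation point here, while "mask1" can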
"merge1": { + "process_id": "merge_cubes", + "arguments": {"cube1": {"from_node": "mask1"}, "cube2": {"from_node": "bands2"}}, + }, + } + graph = _FrozenGraph.from_flat_graph(flat, backend_candidates_map={"lc1": ["b1"], "lc2": ["b2"]}) + assert list(graph.produce_split_locations(limit=4)) == dirty_equals.IsOneOf( + [["mask1"], ["bands2"], ["lc1"], ["lc2"]], + [["bands2"], ["mask1"], ["lc2"], ["lc1"]], + )