diff --git a/scripts/crossbackend-processing-poc.py b/scripts/crossbackend-processing-poc.py index e8b83e5..e8a138c 100644 --- a/scripts/crossbackend-processing-poc.py +++ b/scripts/crossbackend-processing-poc.py @@ -8,6 +8,7 @@ from openeo_aggregator.partitionedjobs import PartitionedJob from openeo_aggregator.partitionedjobs.crossbackend import ( CrossBackendSplitter, + LoadCollectionGraphSplitter, run_partitioned_job, ) @@ -62,7 +63,9 @@ def backend_for_collection(collection_id) -> str: metadata = connection.describe_collection(collection_id) return metadata["summaries"][STAC_PROPERTY_FEDERATION_BACKENDS][0] - splitter = CrossBackendSplitter(backend_for_collection=backend_for_collection, always_split=True) + splitter = CrossBackendSplitter( + graph_splitter=LoadCollectionGraphSplitter(backend_for_collection=backend_for_collection, always_split=True) + ) pjob: PartitionedJob = splitter.split({"process_graph": process_graph}) _log.info(f"Partitioned job: {pjob!r}") diff --git a/src/openeo_aggregator/backend.py b/src/openeo_aggregator/backend.py index b568ce3..541ab93 100644 --- a/src/openeo_aggregator/backend.py +++ b/src/openeo_aggregator/backend.py @@ -100,7 +100,10 @@ single_backend_collection_post_processing, ) from openeo_aggregator.partitionedjobs import PartitionedJob -from openeo_aggregator.partitionedjobs.crossbackend import CrossBackendSplitter +from openeo_aggregator.partitionedjobs.crossbackend import ( + CrossBackendSplitter, + LoadCollectionGraphSplitter, +) from openeo_aggregator.partitionedjobs.splitting import FlimsySplitter, TileGridSplitter from openeo_aggregator.partitionedjobs.tracking import ( PartitionedJobConnection, @@ -940,9 +943,11 @@ def backend_for_collection(collection_id) -> str: return self._catalog.get_backends_for_collection(cid=collection_id)[0] splitter = CrossBackendSplitter( - backend_for_collection=backend_for_collection, - # TODO: job option for `always_split` feature? - always_split=True, + graph_splitter=LoadCollectionGraphSplitter( + backend_for_collection=backend_for_collection, + # TODO: job option for `always_split` feature? + always_split=True, + ) ) pjob_id = self.partitioned_job_tracker.create_crossbackend_pjob( diff --git a/src/openeo_aggregator/partitionedjobs/crossbackend.py b/src/openeo_aggregator/partitionedjobs/crossbackend.py index 124b2be..3163e79 100644 --- a/src/openeo_aggregator/partitionedjobs/crossbackend.py +++ b/src/openeo_aggregator/partitionedjobs/crossbackend.py @@ -1,5 +1,6 @@ from __future__ import annotations +import abc import collections import copy import dataclasses @@ -18,6 +19,7 @@ Iterator, List, Mapping, + NamedTuple, Optional, Protocol, Sequence, @@ -45,6 +47,7 @@ _LOAD_RESULT_PLACEHOLDER = "_placeholder:" # Some type annotation aliases to make things more self-documenting +CollectionId = str SubGraphId = str NodeId = str BackendId = str @@ -87,6 +90,75 @@ def _default_get_replacement(node_id: str, node: dict, subgraph_id: SubGraphId) } +class _SubGraphData(NamedTuple): + split_node: NodeId + node_ids: Set[NodeId] + backend_id: BackendId + + +class _PGSplitResult(NamedTuple): + primary_node_ids: Set[NodeId] + primary_backend_id: BackendId + secondary_graphs: List[_SubGraphData] + + +class ProcessGraphSplitterInterface(metaclass=abc.ABCMeta): + @abc.abstractmethod + def split(self, process_graph: FlatPG) -> _PGSplitResult: + """ + Split given process graph (flat graph representation) into sub graphs + + Returns primary graph data (node ids and backend id) + and secondary graphs data (list of tuples: split node id, subgraph node ids,backend id) + """ + ... + + +class LoadCollectionGraphSplitter(ProcessGraphSplitterInterface): + """Simple process graph splitter that just splits off load_collection nodes""" + + def __init__(self, backend_for_collection: Callable[[CollectionId], BackendId], always_split: bool = False): + # TODO: also support not not having a backend_for_collection map? + self._backend_for_collection = backend_for_collection + self._always_split = always_split + + def split(self, process_graph: FlatPG) -> _PGSplitResult: + # Extract necessary back-ends from `load_collection` usage + backend_per_collection: Dict[str, str] = { + cid: self._backend_for_collection(cid) + for cid in ( + node["arguments"]["id"] for node in process_graph.values() if node["process_id"] == "load_collection" + ) + } + backend_usage = collections.Counter(backend_per_collection.values()) + _log.info(f"Extracted backend usage from `load_collection` nodes: {backend_usage=} {backend_per_collection=}") + + # TODO: more options to determine primary backend? + primary_backend = backend_usage.most_common(1)[0][0] if backend_usage else None + secondary_backends = {b for b in backend_usage if b != primary_backend} + _log.info(f"Backend split: {primary_backend=} {secondary_backends=}") + + primary_has_load_collection = False + primary_graph_node_ids = set() + secondary_graphs: List[_SubGraphData] = [] + for node_id, node in process_graph.items(): + if node["process_id"] == "load_collection": + bid = backend_per_collection[node["arguments"]["id"]] + if bid == primary_backend and (not self._always_split or not primary_has_load_collection): + primary_graph_node_ids.add(node_id) + primary_has_load_collection = True + else: + secondary_graphs.append(_SubGraphData(split_node=node_id, node_ids={node_id}, backend_id=bid)) + else: + primary_graph_node_ids.add(node_id) + + return _PGSplitResult( + primary_node_ids=primary_graph_node_ids, + primary_backend_id=primary_backend, + secondary_graphs=secondary_graphs, + ) + + class CrossBackendSplitter(AbstractJobSplitter): """ Split a process graph, to be executed across multiple back-ends, @@ -97,14 +169,12 @@ class CrossBackendSplitter(AbstractJobSplitter): """ - def __init__(self, backend_for_collection: Callable[[str], str], always_split: bool = False): + def __init__(self, graph_splitter: ProcessGraphSplitterInterface): """ :param backend_for_collection: callable that determines backend id for given collection id :param always_split: split all load_collections, also when on same backend """ - # TODO: just handle this `backend_for_collection` callback with a regular method? - self.backend_for_collection = backend_for_collection - self._always_split = always_split + self._graph_splitter = graph_splitter def split_streaming( self, @@ -127,36 +197,12 @@ def split_streaming( - dependencies as list of subgraph ids """ - # Extract necessary back-ends from `load_collection` usage - backend_per_collection: Dict[str, str] = { - cid: self.backend_for_collection(cid) - for cid in ( - node["arguments"]["id"] for node in process_graph.values() if node["process_id"] == "load_collection" - ) - } - backend_usage = collections.Counter(backend_per_collection.values()) - _log.info(f"Extracted backend usage from `load_collection` nodes: {backend_usage=} {backend_per_collection=}") - - # TODO: more options to determine primary backend? - primary_backend = backend_usage.most_common(1)[0][0] if backend_usage else None - secondary_backends = {b for b in backend_usage if b != primary_backend} - _log.info(f"Backend split: {primary_backend=} {secondary_backends=}") + graph_split_result = self._graph_splitter.split(process_graph=process_graph) - primary_has_load_collection = False - sub_graphs: List[Tuple[NodeId, Set[NodeId], BackendId]] = [] - for node_id, node in process_graph.items(): - if node["process_id"] == "load_collection": - bid = backend_per_collection[node["arguments"]["id"]] - if bid == primary_backend and (not self._always_split or not primary_has_load_collection): - primary_has_load_collection = True - else: - sub_graphs.append((node_id, {node_id}, bid)) - - primary_graph_node_ids = set(process_graph.keys()).difference(n for _, ns, _ in sub_graphs for n in ns) - primary_pg = {k: process_graph[k] for k in primary_graph_node_ids} + primary_pg = {k: process_graph[k] for k in graph_split_result.primary_node_ids} primary_dependencies = [] - for node_id, subgraph_node_ids, backend_id in sub_graphs: + for node_id, subgraph_node_ids, backend_id in graph_split_result.secondary_graphs: # New secondary pg sub_id = f"{backend_id}:{node_id}" sub_pg = {k: v for k, v in process_graph.items() if k in subgraph_node_ids} @@ -178,8 +224,11 @@ def split_streaming( primary_pg.update(get_replacement(node_id=node_id, node=process_graph[node_id], subgraph_id=sub_id)) primary_dependencies.append(sub_id) - primary_id = main_subgraph_id - yield (primary_id, SubJob(process_graph=primary_pg, backend_id=primary_backend), primary_dependencies) + yield ( + main_subgraph_id, + SubJob(process_graph=primary_pg, backend_id=graph_split_result.primary_backend_id), + primary_dependencies, + ) def split(self, process: PGWithMetadata, metadata: dict = None, job_options: dict = None) -> PartitionedJob: """Split given process graph into a `PartitionedJob`""" diff --git a/tests/partitionedjobs/test_crossbackend.py b/tests/partitionedjobs/test_crossbackend.py index d513068..014ea4b 100644 --- a/tests/partitionedjobs/test_crossbackend.py +++ b/tests/partitionedjobs/test_crossbackend.py @@ -16,6 +16,7 @@ from openeo_aggregator.partitionedjobs.crossbackend import ( CrossBackendSplitter, GraphSplitException, + LoadCollectionGraphSplitter, SubGraphId, _FrozenGraph, _FrozenNode, @@ -26,7 +27,9 @@ class TestCrossBackendSplitter: def test_split_simple(self): process_graph = {"add": {"process_id": "add", "arguments": {"x": 3, "y": 5}, "result": True}} - splitter = CrossBackendSplitter(backend_for_collection=lambda cid: "foo") + splitter = CrossBackendSplitter( + graph_splitter=LoadCollectionGraphSplitter(backend_for_collection=lambda cid: "foo") + ) res = splitter.split({"process_graph": process_graph}) assert res.subjobs == {"main": SubJob(process_graph, backend_id=None)} @@ -34,7 +37,9 @@ def test_split_simple(self): def test_split_streaming_simple(self): process_graph = {"add": {"process_id": "add", "arguments": {"x": 3, "y": 5}, "result": True}} - splitter = CrossBackendSplitter(backend_for_collection=lambda cid: "foo") + splitter = CrossBackendSplitter( + graph_splitter=LoadCollectionGraphSplitter(backend_for_collection=lambda cid: "foo") + ) res = splitter.split_streaming(process_graph) assert isinstance(res, types.GeneratorType) assert list(res) == [("main", SubJob(process_graph, backend_id=None), [])] @@ -56,7 +61,9 @@ def test_split_basic(self): "result": True, }, } - splitter = CrossBackendSplitter(backend_for_collection=lambda cid: cid.split("_")[0]) + splitter = CrossBackendSplitter( + graph_splitter=LoadCollectionGraphSplitter(backend_for_collection=lambda cid: cid.split("_")[0]) + ) res = splitter.split({"process_graph": process_graph}) assert res.subjobs == { @@ -119,7 +126,9 @@ def test_split_streaming_basic(self): "result": True, }, } - splitter = CrossBackendSplitter(backend_for_collection=lambda cid: cid.split("_")[0]) + splitter = CrossBackendSplitter( + graph_splitter=LoadCollectionGraphSplitter(backend_for_collection=lambda cid: cid.split("_")[0]) + ) result = splitter.split_streaming(process_graph) assert isinstance(result, types.GeneratorType) @@ -179,7 +188,9 @@ def test_split_streaming_get_replacement(self): "result": True, }, } - splitter = CrossBackendSplitter(backend_for_collection=lambda cid: cid.split("_")[0]) + splitter = CrossBackendSplitter( + graph_splitter=LoadCollectionGraphSplitter(backend_for_collection=lambda cid: cid.split("_")[0]) + ) batch_jobs = {} @@ -375,7 +386,9 @@ def test_basic(self, aggregator: _FakeAggregator): "result": True, }, } - splitter = CrossBackendSplitter(backend_for_collection=lambda cid: cid.split("_")[0]) + splitter = CrossBackendSplitter( + graph_splitter=LoadCollectionGraphSplitter(backend_for_collection=lambda cid: cid.split("_")[0]) + ) pjob: PartitionedJob = splitter.split({"process_graph": process_graph}) connection = openeo.Connection(aggregator.url)