Skip to content

Commit

Permalink
fixup! Issue #150 CrossBackendSplitter: decouple graph splitting from…
Browse files Browse the repository at this point in the history
… SubJob yielding
  • Loading branch information
soxofaan committed Sep 16, 2024
1 parent 5302f74 commit 1ab0736
Show file tree
Hide file tree
Showing 4 changed files with 34 additions and 16 deletions.
5 changes: 4 additions & 1 deletion scripts/crossbackend-processing-poc.py
Original file line number Diff line number Diff line change
Expand Up @@ -8,6 +8,7 @@
from openeo_aggregator.partitionedjobs import PartitionedJob
from openeo_aggregator.partitionedjobs.crossbackend import (
CrossBackendSplitter,
LoadCollectionGraphSplitter,
run_partitioned_job,
)

Expand Down Expand Up @@ -62,7 +63,9 @@ def backend_for_collection(collection_id) -> str:
metadata = connection.describe_collection(collection_id)
return metadata["summaries"][STAC_PROPERTY_FEDERATION_BACKENDS][0]

splitter = CrossBackendSplitter(backend_for_collection=backend_for_collection, always_split=True)
splitter = CrossBackendSplitter(
graph_splitter=LoadCollectionGraphSplitter(backend_for_collection=backend_for_collection, always_split=True)
)
pjob: PartitionedJob = splitter.split({"process_graph": process_graph})
_log.info(f"Partitioned job: {pjob!r}")

Expand Down
13 changes: 9 additions & 4 deletions src/openeo_aggregator/backend.py
Original file line number Diff line number Diff line change
Expand Up @@ -100,7 +100,10 @@
single_backend_collection_post_processing,
)
from openeo_aggregator.partitionedjobs import PartitionedJob
from openeo_aggregator.partitionedjobs.crossbackend import CrossBackendSplitter
from openeo_aggregator.partitionedjobs.crossbackend import (
CrossBackendSplitter,
LoadCollectionGraphSplitter,
)
from openeo_aggregator.partitionedjobs.splitting import FlimsySplitter, TileGridSplitter
from openeo_aggregator.partitionedjobs.tracking import (
PartitionedJobConnection,
Expand Down Expand Up @@ -940,9 +943,11 @@ def backend_for_collection(collection_id) -> str:
return self._catalog.get_backends_for_collection(cid=collection_id)[0]

splitter = CrossBackendSplitter(
backend_for_collection=backend_for_collection,
# TODO: job option for `always_split` feature?
always_split=True,
graph_splitter=LoadCollectionGraphSplitter(
backend_for_collection=backend_for_collection,
# TODO: job option for `always_split` feature?
always_split=True,
)
)

pjob_id = self.partitioned_job_tracker.create_crossbackend_pjob(
Expand Down
7 changes: 2 additions & 5 deletions src/openeo_aggregator/partitionedjobs/crossbackend.py
Original file line number Diff line number Diff line change
Expand Up @@ -169,15 +169,12 @@ class CrossBackendSplitter(AbstractJobSplitter):
"""

def __init__(self, backend_for_collection: Callable[[CollectionId], BackendId], always_split: bool = False):
def __init__(self, graph_splitter: ProcessGraphSplitterInterface):
"""
:param graph_splitter: process graph splitter implementation (a `ProcessGraphSplitterInterface`),
e.g. a `LoadCollectionGraphSplitter`, which is where the `backend_for_collection` callable
and the `always_split` flag are now configured
"""
# TODO: inject splitter instead of building it here
self._graph_splitter = LoadCollectionGraphSplitter(
backend_for_collection=backend_for_collection, always_split=always_split
)
self._graph_splitter = graph_splitter

def split_streaming(
self,
Expand Down
25 changes: 19 additions & 6 deletions tests/partitionedjobs/test_crossbackend.py
Original file line number Diff line number Diff line change
Expand Up @@ -16,6 +16,7 @@
from openeo_aggregator.partitionedjobs.crossbackend import (
CrossBackendSplitter,
GraphSplitException,
LoadCollectionGraphSplitter,
SubGraphId,
_FrozenGraph,
_FrozenNode,
Expand All @@ -26,15 +27,19 @@
class TestCrossBackendSplitter:
def test_split_simple(self):
process_graph = {"add": {"process_id": "add", "arguments": {"x": 3, "y": 5}, "result": True}}
splitter = CrossBackendSplitter(backend_for_collection=lambda cid: "foo")
splitter = CrossBackendSplitter(
graph_splitter=LoadCollectionGraphSplitter(backend_for_collection=lambda cid: "foo")
)
res = splitter.split({"process_graph": process_graph})

assert res.subjobs == {"main": SubJob(process_graph, backend_id=None)}
assert res.dependencies == {}

def test_split_streaming_simple(self):
process_graph = {"add": {"process_id": "add", "arguments": {"x": 3, "y": 5}, "result": True}}
splitter = CrossBackendSplitter(backend_for_collection=lambda cid: "foo")
splitter = CrossBackendSplitter(
graph_splitter=LoadCollectionGraphSplitter(backend_for_collection=lambda cid: "foo")
)
res = splitter.split_streaming(process_graph)
assert isinstance(res, types.GeneratorType)
assert list(res) == [("main", SubJob(process_graph, backend_id=None), [])]
Expand All @@ -56,7 +61,9 @@ def test_split_basic(self):
"result": True,
},
}
splitter = CrossBackendSplitter(backend_for_collection=lambda cid: cid.split("_")[0])
splitter = CrossBackendSplitter(
graph_splitter=LoadCollectionGraphSplitter(backend_for_collection=lambda cid: cid.split("_")[0])
)
res = splitter.split({"process_graph": process_graph})

assert res.subjobs == {
Expand Down Expand Up @@ -119,7 +126,9 @@ def test_split_streaming_basic(self):
"result": True,
},
}
splitter = CrossBackendSplitter(backend_for_collection=lambda cid: cid.split("_")[0])
splitter = CrossBackendSplitter(
graph_splitter=LoadCollectionGraphSplitter(backend_for_collection=lambda cid: cid.split("_")[0])
)
result = splitter.split_streaming(process_graph)
assert isinstance(result, types.GeneratorType)

Expand Down Expand Up @@ -179,7 +188,9 @@ def test_split_streaming_get_replacement(self):
"result": True,
},
}
splitter = CrossBackendSplitter(backend_for_collection=lambda cid: cid.split("_")[0])
splitter = CrossBackendSplitter(
graph_splitter=LoadCollectionGraphSplitter(backend_for_collection=lambda cid: cid.split("_")[0])
)

batch_jobs = {}

Expand Down Expand Up @@ -375,7 +386,9 @@ def test_basic(self, aggregator: _FakeAggregator):
"result": True,
},
}
splitter = CrossBackendSplitter(backend_for_collection=lambda cid: cid.split("_")[0])
splitter = CrossBackendSplitter(
graph_splitter=LoadCollectionGraphSplitter(backend_for_collection=lambda cid: cid.split("_")[0])
)
pjob: PartitionedJob = splitter.split({"process_graph": process_graph})

connection = openeo.Connection(aggregator.url)
Expand Down

0 comments on commit 1ab0736

Please sign in to comment.