Skip to content

Commit

Permalink
Add merging of models in ribasim Python
Browse files Browse the repository at this point in the history
  • Loading branch information
SouthEndMusic committed Dec 20, 2023
1 parent 836dc71 commit cd6f532
Show file tree
Hide file tree
Showing 4 changed files with 158 additions and 0 deletions.
20 changes: 20 additions & 0 deletions python/ribasim/ribasim/geometry/edge.py
Original file line number Diff line number Diff line change
@@ -1,3 +1,4 @@
from copy import deepcopy
from typing import Any

import matplotlib.pyplot as plt
Expand All @@ -9,6 +10,7 @@
from numpy.typing import NDArray
from pandera.typing import Series
from pandera.typing.geopandas import GeoSeries
from shapely.geometry import LineString

from ribasim.input_base import SpatialTableModel

Expand Down Expand Up @@ -38,6 +40,24 @@ class Edge(SpatialTableModel[EdgeSchema]):
Table describing the flow connections.
"""

def translate_spacially(
self, offset_spacial: tuple[float, float], inplace: bool = True
) -> "Edge":
if inplace:
edge = self
else:
edge = deepcopy(self)

edge.df.geometry = edge.df.geometry.apply(
lambda linestring: LineString(
[
(point[0] + offset_spacial[0], point[1] + offset_spacial[1])
for point in linestring.coords
]
)
)
return edge

def get_where_edge_type(self, edge_type: str) -> NDArray[np.bool_]:
return (self.df.edge_type == edge_type).to_numpy()

Expand Down
17 changes: 17 additions & 0 deletions python/ribasim/ribasim/geometry/node.py
Original file line number Diff line number Diff line change
@@ -1,4 +1,5 @@
from collections.abc import Sequence
from copy import deepcopy
from typing import Any

import matplotlib.pyplot as plt
Expand All @@ -9,6 +10,7 @@
from numpy.typing import NDArray
from pandera.typing import Series
from pandera.typing.geopandas import GeoSeries
from shapely.geometry import Point

from ribasim.input_base import SpatialTableModel

Expand Down Expand Up @@ -59,6 +61,21 @@ def node_ids_and_types(*nodes):

return node_id, node_type

def translate_spacially(
self, offset_spacial: tuple[float, float], inplace: bool = True
) -> "Node":
if inplace:
node = self
else:
node = deepcopy(self)

node.df.geometry = node.df.geometry.apply(
lambda point: Point(
point.x + offset_spacial[0], point.y + offset_spacial[1]
)
)
return node

def geometry_from_connectivity(
self, from_id: Sequence[int], to_id: Sequence[int]
) -> NDArray[Any]:
Expand Down
70 changes: 70 additions & 0 deletions python/ribasim/ribasim/input_base.py
Original file line number Diff line number Diff line change
Expand Up @@ -3,6 +3,7 @@
from collections.abc import Callable, Generator
from contextlib import closing
from contextvars import ContextVar
from copy import deepcopy
from pathlib import Path
from sqlite3 import Connection, connect
from typing import (
Expand Down Expand Up @@ -202,6 +203,42 @@ def node_ids(self) -> set[int]:

return node_ids

def offset_node_ids(self, offset_node_id: int) -> "TableModel":
copy = deepcopy(self)
df = copy.df
if copy.df is not None:
df.index += offset_node_id
for name_column in [
"node_id",
"from_node_id",
"to_node_id",
"listen_node_id",
]:
if hasattr(df, name_column):
df[name_column] += offset_node_id
return copy

def merge_table(
self, table_added: "TableModel", inplace: bool = True
) -> "TableModel":
assert type(self) == type(
table_added
), "Can only merge tables of the same type."

if inplace:
table = self
else:
table = deepcopy(self)

table.df = pd.concat(
[
table.df,
table_added.df,
]
)

return table

@classmethod
def _load(cls, filepath: Path | None) -> dict[str, Any]:
db = context_file_loading.get().get("database")
Expand Down Expand Up @@ -421,6 +458,39 @@ def node_ids_and_types(self) -> tuple[list[int], list[str]]:
ids = self.node_ids()
return list(ids), len(ids) * [self.get_input_type()]

def offset_node_ids(self, offset_node_id: int) -> "NodeModel":
node_copy = deepcopy(self)
for field in node_copy.fields():
attr = getattr(node_copy, field)
if isinstance(attr, TableModel):
table = attr
setattr(
node_copy,
field,
table.offset_node_ids(offset_node_id),
)
return node_copy

def merge_node(self, node_added: "NodeModel", inplace: bool = True) -> "NodeModel":
assert type(self) == type(node_added), "Can only merge nodes of the same type."

if inplace:
node = self
else:
node = deepcopy(self)

for field in node_added.fields():
attr = getattr(node_added, field)
if isinstance(attr, TableModel):
table_added = attr
table_node = getattr(node, field)
if table_added.df is not None:
if table_node.df is not None:
table_added = table_node.merge_table(table_added, inplace=False)

setattr(node, field, table_added)
return node

def _save(self, directory: DirectoryPath, input_dir: DirectoryPath, **kwargs):
for field in self.fields():
getattr(self, field)._save(
Expand Down
51 changes: 51 additions & 0 deletions python/ribasim/ribasim/model.py
Original file line number Diff line number Diff line change
@@ -1,5 +1,6 @@
import datetime
import shutil
from copy import deepcopy
from pathlib import Path
from typing import Any

Expand Down Expand Up @@ -58,6 +59,18 @@ def n_nodes(self):

return n

def translate_spacially(
self, offset_spacial: tuple[float, float], inplace: bool = True
) -> "Network":
if inplace:
network = self
else:
network = deepcopy(self)

network.node.translate_spacially(offset_spacial)
network.edge.translate_spacially(offset_spacial)
return network

@classmethod
def _load(cls, filepath: Path | None) -> dict[str, Any]:
directory = context_file_loading.get().get("directory", None)
Expand Down Expand Up @@ -376,6 +389,44 @@ def reset_contextvar(self) -> "Model":
context_file_loading.set({})
return self

def max_node_id(self) -> int:
return self.network.node.df.index.max()

def merge_model(
self,
model_added: "Model",
offset_node_id: int | None = None,
offset_spacial: tuple[float, float] = (0.0, 0.0),
inplace: bool = True,
):
if inplace:
model = self
else:
model = deepcopy(self)

nodes_model = model.nodes()
nodes_added = model_added.nodes()
nodes_added["network"] = nodes_added["network"].translate_spacially(
offset_spacial, inplace=False
)
min_offset_node_id = model.max_node_id()

if offset_node_id is None:
offset_node_id = min_offset_node_id
else:
assert (
offset_node_id >= min_offset_node_id
), f"The node id offset must be at least the maximum node ID of the main model ({min_offset_node_id}) to avoid conflicts."

for node_type, node_added in nodes_added.items():
node_added = node_added.offset_node_ids(offset_node_id)
if node_type in nodes_model:
node_added = nodes_model[node_type].merge_node(
node_added, inplace=False
)

setattr(model, node_type, node_added)

def plot_control_listen(self, ax):
x_start, x_end = [], []
y_start, y_end = [], []
Expand Down

0 comments on commit cd6f532

Please sign in to comment.