Skip to content

Commit

Permalink
Sort only Edge fid so edge_id becomes more stable (#1363)
Browse files Browse the repository at this point in the history
Before we sorted the Edge table like this:

```python
        sort_keys = [
            "from_node_type",
            "from_node_id",
            "to_node_type",
            "to_node_id",
        ]
```

This made it appear a bit more neat, though it served no other purpose.

The `fid` index was mostly an implementation detail that the user did
not specify, and it went from 1:n following the sorting above. However
`fid` becomes `edge_id` in the flow.arrow. Therefore when a users added
a new edge, usually half of all `edge_id`s changed, making
post-processing unnecessarily difficult. Therefore this PR removes this
sorting, such that the input order is retained, keeping the old
`edge_id`s stable.

With this I think we can close #1310. It is useful to have a single
identifier value for an Edge, even though it is somewhat superfluous.
With this PR it becomes stable, unless users start modifying `fid`s
themselves.
  • Loading branch information
visr authored Apr 9, 2024
1 parent 5cf4492 commit 35dfc2c
Show file tree
Hide file tree
Showing 7 changed files with 32 additions and 23 deletions.
2 changes: 1 addition & 1 deletion core/test/run_models_test.jl
Original file line number Diff line number Diff line change
Expand Up @@ -127,7 +127,7 @@

@testset "Results values" begin
@test flow.time[1] == DateTime(2020)
@test coalesce.(flow.edge_id[1:2], -1) == [1, 2]
@test coalesce.(flow.edge_id[1:2], -1) == [0, 1]
@test flow.from_node_id[1:2] == [6, 0]
@test flow.to_node_id[1:2] == [0, 922]

Expand Down
4 changes: 2 additions & 2 deletions core/test/validation_test.jl
Original file line number Diff line number Diff line change
Expand Up @@ -361,10 +361,10 @@ end
@test length(logger.logs) == 2
@test logger.logs[1].level == Error
@test logger.logs[1].message ==
"Invalid edge type 'foo' for edge #1 from node #1 to node #2."
"Invalid edge type 'foo' for edge #0 from node #1 to node #2."
@test logger.logs[2].level == Error
@test logger.logs[2].message ==
"Invalid edge type 'bar' for edge #2 from node #2 to node #3."
"Invalid edge type 'bar' for edge #1 from node #2 to node #3."
end

@testitem "Subgrid validation" begin
Expand Down
17 changes: 8 additions & 9 deletions python/ribasim/ribasim/geometry/edge.py
Original file line number Diff line number Diff line change
Expand Up @@ -76,21 +76,20 @@ def add(
crs=self.df.crs,
)

self.df = GeoDataFrame[EdgeSchema](pd.concat([self.df, table_to_append]))
self.df = GeoDataFrame[EdgeSchema](
pd.concat([self.df, table_to_append], ignore_index=True)
)
self.df.index.name = "fid"

def get_where_edge_type(self, edge_type: str) -> NDArray[np.bool_]:
assert self.df is not None
return (self.df.edge_type == edge_type).to_numpy()

def sort(self):
assert self.df is not None
sort_keys = [
"from_node_type",
"from_node_id",
"to_node_type",
"to_node_id",
]
self.df.sort_values(sort_keys, ignore_index=True, inplace=True)
# Only sort the index (fid / edge_id) since this needs to be sorted in a GeoPackage.
# Under most circumstances, this retains the input order,
# making the edge_id as stable as possible; useful for post-processing.
self.df.sort_index(inplace=True)

def plot(self, **kwargs) -> Axes:
assert self.df is not None
Expand Down
4 changes: 3 additions & 1 deletion python/ribasim/ribasim/input_base.py
Original file line number Diff line number Diff line change
Expand Up @@ -360,7 +360,9 @@ def _write_geopackage(self, path: Path) -> None:
path : Path
"""
assert self.df is not None
self.df.to_file(path, layer=self.tablename(), driver="GPKG")
# the index name must be fid otherwise it will generate a separate fid column
self.df.index.name = "fid"
self.df.to_file(path, layer=self.tablename(), index=True, driver="GPKG")


class ChildModel(BaseModel):
Expand Down
5 changes: 4 additions & 1 deletion python/ribasim/ribasim/model.py
Original file line number Diff line number Diff line change
Expand Up @@ -175,8 +175,8 @@ def _save(self, directory: DirectoryPath, input_dir: DirectoryPath):
if not node.df["node_id"].is_unique:
raise ValueError("node_id must be unique")
node.df.set_index("node_id", drop=False, inplace=True)
node.df.sort_index(inplace=True)
node.df.index.name = "fid"
node.df.sort_index(inplace=True)
node._save(directory, input_dir)

for sub in self._nodes():
Expand Down Expand Up @@ -207,6 +207,7 @@ def node_table(self) -> NodeTable:
df = pd.concat(df_chunks, ignore_index=True)
node_table = NodeTable(df=df)
node_table.sort()
node_table.df.index.name = "fid"
return node_table

def _nodes(self) -> Generator[MultiNodeModel, Any, None]:
Expand Down Expand Up @@ -413,6 +414,7 @@ def to_xugrid(self):
edge_df = edge_df[edge_df.edge_type == "flow"]

node_id = node_df.node_id.to_numpy()
edge_id = edge_df.index.to_numpy()
from_node_id = edge_df.from_node_id.to_numpy()
to_node_id = edge_df.to_node_id.to_numpy()

Expand Down Expand Up @@ -443,6 +445,7 @@ def to_xugrid(self):

uds = xugrid.UgridDataset(None, grid)
uds = uds.assign_coords(node_id=(node_dim, node_id))
uds = uds.assign_coords(edge_id=(edge_dim, edge_id))
uds = uds.assign_coords(from_node_id=(edge_dim, from_node_id))
uds = uds.assign_coords(to_node_id=(edge_dim, to_node_id))

Expand Down
19 changes: 12 additions & 7 deletions python/ribasim/tests/test_io.py
Original file line number Diff line number Diff line change
Expand Up @@ -79,7 +79,7 @@ def test_repr():
assert isinstance(pump_static._repr_html_(), str)


def test_extra_columns(basic_transient):
def test_extra_columns():
terminal_static = terminal.Static(meta_id=[-1, -2, -3])
assert "meta_id" in terminal_static.df.columns
assert (terminal_static.df.meta_id == [-1, -2, -3]).all()
Expand All @@ -106,28 +106,29 @@ def test_sort(level_setpoint_with_minmax, tmp_path):
table.sort()
assert table.df.iloc[0]["greater_than"] == 5.0

edge.df.sort_values("from_node_type", ascending=False, inplace=True)
assert edge.df.iloc[0]["from_node_type"] != "Basin"
edge.sort()
assert edge.df.iloc[0]["from_node_type"] == "Basin"
# The edge table is not sorted
assert edge.df.iloc[1]["from_node_type"] == "Pump"
assert edge.df.iloc[1]["from_node_id"] == 3

# re-apply wrong sort, then check if it gets sorted on write
table.df.sort_values("greater_than", ascending=False, inplace=True)
edge.df.sort_values("from_node_type", ascending=False, inplace=True)
model.write(tmp_path / "basic/ribasim.toml")
# write sorts the model in place
assert table.df.iloc[0]["greater_than"] == 5.0
model_loaded = ribasim.Model.read(filepath=tmp_path / "basic/ribasim.toml")
table_loaded = model_loaded.discrete_control.condition
edge_loaded = model_loaded.edge
assert table_loaded.df.iloc[0]["greater_than"] == 5.0
assert edge.df.iloc[0]["from_node_type"] == "Basin"
assert edge.df.iloc[1]["from_node_type"] == "Pump"
assert edge.df.iloc[1]["from_node_id"] == 3
__assert_equal(table.df, table_loaded.df)
__assert_equal(edge.df, edge_loaded.df)


def test_roundtrip(trivial, tmp_path):
model1 = trivial
# set custom Edge index
model1.edge.df.index = [15, 12]
model1dir = tmp_path / "model1"
model2dir = tmp_path / "model2"
# read a model and then write it to a different path
Expand All @@ -142,6 +143,10 @@ def test_roundtrip(trivial, tmp_path):
model2dir / "ribasim.toml"
).read_text()

# check if custom Edge indexes are retained (sorted)
assert (model1.edge.df.index == [12, 15]).all()
assert (model2.edge.df.index == [12, 15]).all()

# check if all tables are the same
__assert_equal(model1.node_table().df, model2.node_table().df)
__assert_equal(model1.edge.df, model2.edge.df)
Expand Down
4 changes: 2 additions & 2 deletions python/ribasim/tests/test_model.py
Original file line number Diff line number Diff line change
Expand Up @@ -162,8 +162,8 @@ def test_write_adds_fid_in_tables(basic, tmp_path):

# for edge no index was provided, but it still needs to write it to file
nrow = len(model_orig.edge.df)
assert model_orig.edge.df.index.name is None
assert model_orig.edge.df.index.equals(pd.Index(np.full(nrow, 0)))
assert model_orig.edge.df.index.name == "fid"
assert model_orig.edge.df.index.equals(pd.RangeIndex(nrow))

model_orig.write(tmp_path / "basic/ribasim.toml")
with connect(tmp_path / "basic/database.gpkg") as connection:
Expand Down

0 comments on commit 35dfc2c

Please sign in to comment.