Skip to content

Commit

Permalink
Write full Node table once (#1312)
Browse files Browse the repository at this point in the history
Co-authored-by: Hofer-Julian <[email protected]>
Co-authored-by: Hofer-Julian <[email protected]>
  • Loading branch information
3 people authored Mar 28, 2024
1 parent 1b784c3 commit af6b1ad
Show file tree
Hide file tree
Showing 2 changed files with 15 additions and 10 deletions.
7 changes: 6 additions & 1 deletion python/ribasim/ribasim/input_base.py
Original file line number Diff line number Diff line change
Expand Up @@ -32,6 +32,7 @@
validate_call,
)

import ribasim
from ribasim.types import FilePath

__all__ = ("TableModel",)
Expand Down Expand Up @@ -424,7 +425,11 @@ def _layername(cls, field: str) -> str:
def _tables(self) -> Generator[TableModel[Any], Any, None]:
for key in self.fields():
attr = getattr(self, key)
if isinstance(attr, TableModel) and attr.df is not None:
if (
isinstance(attr, TableModel)
and (attr.df is not None)
and not (isinstance(attr, ribasim.geometry.node.NodeTable))
):
yield attr

def node_ids(self) -> set[int]:
Expand Down
18 changes: 9 additions & 9 deletions python/ribasim/ribasim/model.py
Original file line number Diff line number Diff line change
Expand Up @@ -3,7 +3,6 @@
from pathlib import Path
from typing import Any

import geopandas as gpd
import numpy as np
import pandas as pd
import tomli
Expand Down Expand Up @@ -143,18 +142,19 @@ def _save(self, directory: DirectoryPath, input_dir: DirectoryPath):
db_path.unlink(missing_ok=True)
context_file_loading.get()["database"] = db_path
self.edge._save(directory, input_dir)
for sub in self._nodes():
sub._save(directory, input_dir)

node = self.node_table()
# Temporarily require unique node_id for #1262
# and copy them to the fid for #1306.
df = gpd.read_file(db_path, layer="Node")
if not df["node_id"].is_unique:
if not node.df["node_id"].is_unique:
raise ValueError("node_id must be unique")
df.set_index("node_id", drop=False, inplace=True)
df.sort_index(inplace=True)
df.index.name = "fid"
df.to_file(db_path, layer="Node", driver="GPKG", index=True)
node.df.set_index("node_id", drop=False, inplace=True)
node.df.sort_index(inplace=True)
node.df.index.name = "fid"
node._save(directory, input_dir)

for sub in self._nodes():
sub._save(directory, input_dir)

def node_table(self) -> NodeTable:
"""Compute the full NodeTable from all node types."""
Expand Down

0 comments on commit af6b1ad

Please sign in to comment.