Skip to content

Commit

Permalink
Finish it
Browse files Browse the repository at this point in the history
  • Loading branch information
visr committed Apr 9, 2024
1 parent 1029996 commit 697d87f
Show file tree
Hide file tree
Showing 5 changed files with 19 additions and 12 deletions.
7 changes: 5 additions & 2 deletions python/ribasim/ribasim/geometry/edge.py
Original file line number Diff line number Diff line change
Expand Up @@ -79,14 +79,17 @@ def add(
self.df = GeoDataFrame[EdgeSchema](
pd.concat([self.df, table_to_append], ignore_index=True)
)
self.df.index.name = "fid"

def get_where_edge_type(self, edge_type: str) -> NDArray[np.bool_]:
assert self.df is not None
return (self.df.edge_type == edge_type).to_numpy()

def sort(self):
# To keep the fid / edge_id as stable as possible, don't sort the edges.
return None
# Only sort the index (fid / edge_id) since this needs to be sorted in a GeoPackage.
# Under most circumstances, this retains the input order,
# making the edge_id as stable as possible; useful for post-processing.
self.df.sort_index(inplace=True)

def plot(self, **kwargs) -> Axes:
assert self.df is not None
Expand Down
11 changes: 4 additions & 7 deletions python/ribasim/ribasim/input_base.py
Original file line number Diff line number Diff line change
Expand Up @@ -163,7 +163,6 @@ def check_extra_columns(cls, v: DataFrame[TableT]):
for colname in v.columns:
if colname == "fid":
# Autogenerated on writing, don't carry them
# TODO: read in the fid as index for EdgeTable
v = v.drop(columns=["fid"]) # type: ignore
elif colname not in cls.columns() and not colname.startswith("meta_"):
raise ValueError(
Expand Down Expand Up @@ -242,11 +241,7 @@ def _write_geopackage(self, temp_path: Path) -> None:
# Add `fid` to all tables as primary key
# Enables editing values manually in QGIS
df = self.df.copy()
if table == "Edge":
df.index.name = "fid"
df.reset_index(inplace=True)
else:
df["fid"] = range(1, len(df) + 1)
df["fid"] = range(1, len(df) + 1)

with closing(connect(temp_path)) as connection:
df.to_sql(
Expand Down Expand Up @@ -365,7 +360,9 @@ def _write_geopackage(self, path: Path) -> None:
path : Path
"""
assert self.df is not None
self.df.to_file(path, layer=self.tablename(), driver="GPKG")
# the index name must be fid otherwise it will generate a separate fid column
self.df.index.name = "fid"
self.df.to_file(path, layer=self.tablename(), index=True, driver="GPKG")


class ChildModel(BaseModel):
Expand Down
3 changes: 2 additions & 1 deletion python/ribasim/ribasim/model.py
Original file line number Diff line number Diff line change
Expand Up @@ -174,8 +174,8 @@ def _save(self, directory: DirectoryPath, input_dir: DirectoryPath):
if not node.df["node_id"].is_unique:
raise ValueError("node_id must be unique")
node.df.set_index("node_id", drop=False, inplace=True)
node.df.sort_index(inplace=True)
node.df.index.name = "fid"
node.df.sort_index(inplace=True)
node._save(directory, input_dir)

for sub in self._nodes():
Expand All @@ -197,6 +197,7 @@ def node_table(self) -> NodeTable:
df = pd.concat(df_chunks, ignore_index=True)
node_table = NodeTable(df=df)
node_table.sort()
node_table.df.index.name = "fid"
return node_table

def _nodes(self) -> Generator[MultiNodeModel, Any, None]:
Expand Down
8 changes: 7 additions & 1 deletion python/ribasim/tests/test_io.py
Original file line number Diff line number Diff line change
Expand Up @@ -79,7 +79,7 @@ def test_repr():
assert isinstance(pump_static._repr_html_(), str)


def test_extra_columns(basic_transient):
def test_extra_columns():
terminal_static = terminal.Static(meta_id=[-1, -2, -3])
assert "meta_id" in terminal_static.df.columns
assert (terminal_static.df.meta_id == [-1, -2, -3]).all()
Expand Down Expand Up @@ -127,6 +127,8 @@ def test_sort(level_setpoint_with_minmax, tmp_path):

def test_roundtrip(trivial, tmp_path):
model1 = trivial
# set custom Edge index
model1.edge.df.index = [15, 12]
model1dir = tmp_path / "model1"
model2dir = tmp_path / "model2"
# read a model and then write it to a different path
Expand All @@ -141,6 +143,10 @@ def test_roundtrip(trivial, tmp_path):
model2dir / "ribasim.toml"
).read_text()

# check if custom Edge indexes are retained (sorted)
assert (model1.edge.df.index == [12, 15]).all()
assert (model2.edge.df.index == [12, 15]).all()

# check if all tables are the same
__assert_equal(model1.node_table().df, model2.node_table().df)
__assert_equal(model1.edge.df, model2.edge.df)
Expand Down
2 changes: 1 addition & 1 deletion python/ribasim/tests/test_model.py
Original file line number Diff line number Diff line change
Expand Up @@ -162,7 +162,7 @@ def test_write_adds_fid_in_tables(basic, tmp_path):

# for edge no index was provided, but it still needs to write it to file
nrow = len(model_orig.edge.df)
assert model_orig.edge.df.index.name is None
assert model_orig.edge.df.index.name == "fid"
assert model_orig.edge.df.index.equals(pd.RangeIndex(nrow))

model_orig.write(tmp_path / "basic/ribasim.toml")
Expand Down

0 comments on commit 697d87f

Please sign in to comment.