Skip to content

Commit

Permalink
Add set_crs method
Browse files Browse the repository at this point in the history
  • Loading branch information
Hofer-Julian committed Apr 4, 2024
1 parent 521e1c5 commit b6669a3
Show file tree
Hide file tree
Showing 2 changed files with 20 additions and 12 deletions.
16 changes: 7 additions & 9 deletions python/ribasim/ribasim/input_base.py
Original file line number Diff line number Diff line change
Expand Up @@ -216,19 +216,17 @@ def _load(cls, filepath: Path | None) -> dict[str, Any]:
else:
return {}

def _save(
self, directory: DirectoryPath, input_dir: DirectoryPath, crs: str
) -> None:
def _save(self, directory: DirectoryPath, input_dir: DirectoryPath) -> None:
# TODO directory could be used to save an arrow file
db_path = context_file_loading.get().get("database")
if self.filepath is not None:
self.sort()
self._write_arrow(self.filepath, directory, input_dir)
elif db_path is not None:
self.sort()
self._write_geopackage(db_path, crs)
self._write_geopackage(db_path)

def _write_geopackage(self, temp_path: Path, crs: str) -> None:
def _write_geopackage(self, temp_path: Path) -> None:
"""
Write the contents of the input to a database.
Expand Down Expand Up @@ -353,7 +351,7 @@ def _from_db(cls, path: Path, table: str):

return df

def _write_geopackage(self, path: Path, crs: str) -> None:
def _write_geopackage(self, path: Path) -> None:
"""
Write the contents of the input to the GeoPackage.
Expand All @@ -362,7 +360,7 @@ def _write_geopackage(self, path: Path, crs: str) -> None:
path : Path
"""
assert self.df is not None
self.df.set_crs(crs).to_file(path, layer=self.tablename(), driver="GPKG")
self.df.to_file(path, layer=self.tablename(), driver="GPKG")


class ChildModel(BaseModel):
Expand Down Expand Up @@ -423,9 +421,9 @@ def node_ids(self) -> set[int]:
node_ids.update(table.node_ids())
return node_ids

def _save(self, directory: DirectoryPath, input_dir: DirectoryPath, crs: str):
def _save(self, directory: DirectoryPath, input_dir: DirectoryPath):
for table in self._tables():
table._save(directory, input_dir, crs)
table._save(directory, input_dir)

def _repr_content(self) -> str:
"""Generate a succinct overview of the content.
Expand Down
16 changes: 13 additions & 3 deletions python/ribasim/ribasim/model.py
Original file line number Diff line number Diff line change
Expand Up @@ -45,6 +45,7 @@
from ribasim.input_base import (
ChildModel,
FileModel,
SpatialTableModel,
context_file_loading,
)
from ribasim.utils import MissingOptionalModule
Expand Down Expand Up @@ -160,11 +161,12 @@ def _write_toml(self, fn: Path) -> Path:
return fn

def _save(self, directory: DirectoryPath, input_dir: DirectoryPath):
self.set_crs(self.crs)
db_path = directory / input_dir / "database.gpkg"
db_path.parent.mkdir(parents=True, exist_ok=True)
db_path.unlink(missing_ok=True)
context_file_loading.get()["database"] = db_path
self.edge._save(directory, input_dir, self.crs)
self.edge._save(directory, input_dir)

node = self.node_table()
# Temporarily require unique node_id for #1262
Expand All @@ -174,10 +176,18 @@ def _save(self, directory: DirectoryPath, input_dir: DirectoryPath):
node.df.set_index("node_id", drop=False, inplace=True)
node.df.sort_index(inplace=True)
node.df.index.name = "fid"
node._save(directory, input_dir, self.crs)
node._save(directory, input_dir)

for sub in self._nodes():
sub._save(directory, input_dir, self.crs)
sub._save(directory, input_dir)

def set_crs(self, crs: str) -> None:
self.edge.df.set_crs(crs)
for sub in self._nodes():
for table in sub._tables():
if isinstance(table, SpatialTableModel) and table.df is not None:
# TODO: that is missing node tables now
table.df.set_crs(crs)

def node_table(self) -> NodeTable:
"""Compute the full NodeTable from all node types."""
Expand Down

0 comments on commit b6669a3

Please sign in to comment.