Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Support specifying coordinate reference system (CRS) of geometries #1339

Merged
merged 11 commits into from
Apr 4, 2024
1 change: 1 addition & 0 deletions core/src/config.jl
Original file line number Diff line number Diff line change
Expand Up @@ -120,6 +120,7 @@ end
@option @addnodetypes struct Toml <: TableOption
starttime::DateTime
endtime::DateTime
crs::String
Hofer-Julian marked this conversation as resolved.
Show resolved Hide resolved
ribasim_version::String
input_dir::String
results_dir::String
Expand Down
1 change: 1 addition & 0 deletions core/test/data/config_test.toml
Original file line number Diff line number Diff line change
@@ -1,5 +1,6 @@
starttime = 2019-01-01
endtime = 2019-12-31
crs = "EPSG:28992"
input_dir = "../../generated_testmodels/lhm"
results_dir = "../../generated_testmodels/lhm"
ribasim_version = "2024.6.1"
Expand Down
1 change: 1 addition & 0 deletions core/test/data/logging_test_loglevel_debug.toml
Original file line number Diff line number Diff line change
@@ -1,5 +1,6 @@
starttime = 2019-01-01
endtime = 2019-12-31
crs = "EPSG:28992"
input_dir = "."
results_dir = "results"
ribasim_version = "2024.6.1"
Expand Down
1 change: 1 addition & 0 deletions core/test/data/logging_test_no_loglevel.toml
Original file line number Diff line number Diff line change
@@ -1,5 +1,6 @@
starttime = 2019-01-01
endtime = 2019-12-31
crs = "EPSG:28992"
input_dir = "."
results_dir = "results"
ribasim_version = "2024.6.1"
5 changes: 5 additions & 0 deletions core/test/docs.toml
Original file line number Diff line number Diff line change
Expand Up @@ -3,6 +3,11 @@
starttime = 2019-01-01 # required
endtime = 2021-01-01 # required

# Coordinate Reference System
# The accepted strings are documented here:
# https://proj.org/en/9.4/development/reference/functions.html#c.proj_create
crs = "EPSG:4326" # required
Hofer-Julian marked this conversation as resolved.
Show resolved Hide resolved

# input files
input_dir = "." # required
results_dir = "results" # required
Expand Down
3 changes: 3 additions & 0 deletions core/test/io_test.jl
Original file line number Diff line number Diff line change
Expand Up @@ -8,6 +8,7 @@
database = "path/to/file",
input_dir = ".",
results_dir = "results",
crs = "EPSG:28992",
ribasim_version = string(Ribasim.pkgversion(Ribasim)),
)
config = Ribasim.Config(toml, "model")
Expand All @@ -21,6 +22,7 @@
database = "path/to/file",
input_dir = "input",
results_dir = "results",
crs = "EPSG:28992",
ribasim_version = string(Ribasim.pkgversion(Ribasim)),
)
config = Ribasim.Config(toml, "model")
Expand All @@ -34,6 +36,7 @@
database = "/path/to/file",
input_dir = ".",
results_dir = "results",
crs = "EPSG:28992",
ribasim_version = string(Ribasim.pkgversion(Ribasim)),
)
config = Ribasim.Config(toml)
Expand Down
19 changes: 5 additions & 14 deletions docs/python/examples.ipynb
Original file line number Diff line number Diff line change
Expand Up @@ -76,7 +76,7 @@
"metadata": {},
"outputs": [],
"source": [
"model = Model(starttime=\"2020-01-01\", endtime=\"2021-01-01\")"
"model = Model(starttime=\"2020-01-01\", endtime=\"2021-01-01\", crs=\"EPSG:4326\")"
]
},
{
Expand Down Expand Up @@ -457,7 +457,7 @@
"metadata": {},
"outputs": [],
"source": [
"model = Model(starttime=\"2020-01-01\", endtime=\"2021-01-01\")"
"model = Model(starttime=\"2020-01-01\", endtime=\"2021-01-01\", crs=\"EPSG:4326\")"
]
},
{
Expand Down Expand Up @@ -771,10 +771,7 @@
"metadata": {},
"outputs": [],
"source": [
"model = Model(\n",
" starttime=\"2020-01-01\",\n",
" endtime=\"2020-12-01\",\n",
")"
"model = Model(starttime=\"2020-01-01\", endtime=\"2020-12-01\", crs=\"EPSG:4326\")"
]
},
{
Expand Down Expand Up @@ -1037,10 +1034,7 @@
"metadata": {},
"outputs": [],
"source": [
"model = Model(\n",
" starttime=\"2020-01-01\",\n",
" endtime=\"2020-01-20\",\n",
")"
"model = Model(starttime=\"2020-01-01\", endtime=\"2020-01-20\", crs=\"EPSG:4326\")"
]
},
{
Expand Down Expand Up @@ -1430,10 +1424,7 @@
"metadata": {},
"outputs": [],
"source": [
"model = Model(\n",
" starttime=\"2020-01-01\",\n",
" endtime=\"2020-02-01\",\n",
")"
"model = Model(starttime=\"2020-01-01\", endtime=\"2020-02-01\", crs=\"EPSG:4326\")"
]
},
{
Expand Down
16 changes: 4 additions & 12 deletions python/ribasim/ribasim/geometry/edge.py
Original file line number Diff line number Diff line change
Expand Up @@ -10,7 +10,6 @@
from pandera.dtypes import Int32
from pandera.typing import Series
from pandera.typing.geopandas import GeoDataFrame, GeoSeries
from pydantic import model_validator
from shapely.geometry import LineString, MultiLineString, Point

from ribasim.input_base import SpatialTableModel
Expand Down Expand Up @@ -45,13 +44,6 @@ class Config:
class EdgeTable(SpatialTableModel[EdgeSchema]):
"""Defines the connections between nodes."""

@model_validator(mode="after")
def empty_table(self) -> "EdgeTable":
if self.df is None:
self.df = GeoDataFrame[EdgeSchema]()
self.df.set_geometry("geometry", inplace=True)
return self

def add(
self,
from_node: NodeData,
Expand All @@ -68,6 +60,8 @@ def add(
edge_type = (
"control" if from_node.node_type in SPATIALCONTROLNODETYPES else "flow"
)
assert self.df is not None

table_to_append = GeoDataFrame[EdgeSchema](
data={
"from_node_type": pd.Series([from_node.node_type], dtype=str),
Expand All @@ -79,12 +73,10 @@ def add(
"subnetwork_id": pd.Series([subnetwork_id], dtype=pd.Int32Dtype()),
},
geometry=geometry_to_append,
crs=self.df.crs,
)

if self.df is None:
self.df = table_to_append
else:
self.df = GeoDataFrame[EdgeSchema](pd.concat([self.df, table_to_append]))
self.df = GeoDataFrame[EdgeSchema](pd.concat([self.df, table_to_append]))

def get_where_edge_type(self, edge_type: str) -> NDArray[np.bool_]:
assert self.df is not None
Expand Down
29 changes: 7 additions & 22 deletions python/ribasim/ribasim/input_base.py
Original file line number Diff line number Diff line change
@@ -1,5 +1,4 @@
import re
import warnings
from abc import ABC, abstractmethod
from collections.abc import Callable, Generator
from contextlib import closing
Expand Down Expand Up @@ -39,7 +38,6 @@
delimiter = " / "

gpd.options.io_engine = "pyogrio"
warnings.filterwarnings("ignore", category=UserWarning, module="pyogrio")

context_file_loading: ContextVar[dict[str, Any]] = ContextVar(
"file_loading", default={}
Expand Down Expand Up @@ -134,15 +132,6 @@ def set_filepath(self, filepath: Path) -> None:
self.filepath = filepath
self.model_config["validate_assignment"] = True

@abstractmethod
def _save(self, directory: DirectoryPath, input_dir: DirectoryPath) -> None:
"""Save this instance to disk.

This method needs to be implemented by any class deriving from
FileModel.
"""
raise NotImplementedError()

@classmethod
@abstractmethod
def _load(cls, filepath: Path | None) -> dict[str, Any]:
Expand Down Expand Up @@ -227,21 +216,17 @@ def _load(cls, filepath: Path | None) -> dict[str, Any]:
else:
return {}

def _save(
self,
directory: DirectoryPath,
input_dir: DirectoryPath,
) -> None:
def _save(self, directory: DirectoryPath, input_dir: DirectoryPath) -> None:
# TODO directory could be used to save an arrow file
db_path = context_file_loading.get().get("database")
if self.filepath is not None:
self.sort()
self._write_arrow(self.filepath, directory, input_dir)
elif db_path is not None:
self.sort()
self._write_table(db_path)
self._write_geopackage(db_path)

def _write_table(self, temp_path: Path) -> None:
def _write_geopackage(self, temp_path: Path) -> None:
"""
Write the contents of the input to a database.

Expand Down Expand Up @@ -366,16 +351,16 @@ def _from_db(cls, path: Path, table: str):

return df

def _write_table(self, path: Path) -> None:
def _write_geopackage(self, path: Path) -> None:
"""
Write the contents of the input to a database.
Write the contents of the input to the GeoPackage.

Parameters
----------
path : Path
"""
assert self.df is not None
self.df.to_file(path, layer=self.tablename(), driver="GPKG", mode="a")
self.df.to_file(path, layer=self.tablename(), driver="GPKG")


class ChildModel(BaseModel):
Expand Down Expand Up @@ -436,7 +421,7 @@ def node_ids(self) -> set[int]:
node_ids.update(table.node_ids())
return node_ids

def _save(self, directory: DirectoryPath, input_dir: DirectoryPath, **kwargs):
def _save(self, directory: DirectoryPath, input_dir: DirectoryPath):
for table in self._tables():
table._save(directory, input_dir)

Expand Down
33 changes: 24 additions & 9 deletions python/ribasim/ribasim/model.py
Original file line number Diff line number Diff line change
Expand Up @@ -9,6 +9,7 @@
import tomli
import tomli_w
from matplotlib import pyplot as plt
from pandera.typing.geopandas import GeoDataFrame
from pydantic import (
DirectoryPath,
Field,
Expand Down Expand Up @@ -39,11 +40,12 @@
Terminal,
UserDemand,
)
from ribasim.geometry.edge import EdgeTable
from ribasim.geometry.edge import EdgeSchema, EdgeTable
from ribasim.geometry.node import NodeTable
from ribasim.input_base import (
ChildModel,
FileModel,
SpatialTableModel,
context_file_loading,
)
from ribasim.utils import MissingOptionalModule
Expand All @@ -57,6 +59,7 @@
class Model(FileModel):
starttime: datetime.datetime
endtime: datetime.datetime
crs: str

input_dir: Path = Field(default=Path("."))
results_dir: Path = Field(default=Path("results"))
Expand Down Expand Up @@ -97,6 +100,13 @@ def set_node_parent(self) -> "Model":
setattr(v, "_parent_field", k)
return self

@model_validator(mode="after")
def ensure_edge_table_is_present(self) -> "Model":
if self.edge.df is None:
self.edge.df = GeoDataFrame[EdgeSchema]()
self.edge.df.set_geometry("geometry", inplace=True, crs=self.crs)
return self

@field_serializer("input_dir", "results_dir")
def serialize_path(self, path: Path) -> str:
return str(path)
Expand Down Expand Up @@ -151,6 +161,7 @@ def _write_toml(self, fn: Path) -> Path:
return fn

def _save(self, directory: DirectoryPath, input_dir: DirectoryPath):
self.set_crs(self.crs)
db_path = directory / input_dir / "database.gpkg"
db_path.parent.mkdir(parents=True, exist_ok=True)
db_path.unlink(missing_ok=True)
Expand All @@ -170,9 +181,19 @@ def _save(self, directory: DirectoryPath, input_dir: DirectoryPath):
for sub in self._nodes():
sub._save(directory, input_dir)

def set_crs(self, crs: str) -> None:
Hofer-Julian marked this conversation as resolved.
Show resolved Hide resolved
self.crs = crs
self.edge.df = self.edge.df.set_crs(crs)
for sub in self._nodes():
if sub.node.df is not None:
sub.node.df = sub.node.df.set_crs(crs)
for table in sub._tables():
if isinstance(table, SpatialTableModel) and table.df is not None:
Hofer-Julian marked this conversation as resolved.
Show resolved Hide resolved
table.df = table.df.set_crs(crs)

def node_table(self) -> NodeTable:
"""Compute the full NodeTable from all node types."""
df_chunks = [node.node.df for node in self._nodes()]
df_chunks = [node.node.df.set_crs(self.crs) for node in self._nodes()]
Hofer-Julian marked this conversation as resolved.
Show resolved Hide resolved
df = pd.concat(df_chunks, ignore_index=True)
node_table = NodeTable(df=df)
node_table.sort()
Expand Down Expand Up @@ -392,12 +413,6 @@ def to_xugrid(self):
name="node_index",
)

if node_df.crs is None:
# TODO: can be removed when CRS is required, #1254
projected = False
else:
projected = node_df.crs.is_projected

grid = xugrid.Ugrid1d(
node_x=node_df.geometry.x,
node_y=node_df.geometry.y,
Expand All @@ -409,7 +424,7 @@ def to_xugrid(self):
)
),
name="ribasim",
projected=projected,
projected=node_df.crs.is_projected,
crs=node_df.crs,
)

Expand Down
6 changes: 4 additions & 2 deletions python/ribasim/tests/test_io.py
Original file line number Diff line number Diff line change
Expand Up @@ -117,7 +117,7 @@ def test_sort(level_setpoint_with_minmax, tmp_path):
model.write(tmp_path / "basic/ribasim.toml")
# write sorts the model in place
assert table.df.iloc[0]["greater_than"] == 5.0
model_loaded = ribasim.Model(filepath=tmp_path / "basic/ribasim.toml")
model_loaded = ribasim.Model.read(filepath=tmp_path / "basic/ribasim.toml")
table_loaded = model_loaded.discrete_control.condition
edge_loaded = model_loaded.edge
assert table_loaded.df.iloc[0]["greater_than"] == 5.0
Expand Down Expand Up @@ -153,7 +153,9 @@ def test_roundtrip(trivial, tmp_path):
def test_datetime_timezone():
# Due to a pydantic issue, a time zone was added.
# https://github.com/Deltares/Ribasim/issues/1282
model = ribasim.Model(starttime="2000-01-01", endtime="2001-01-01 00:00:00")
model = ribasim.Model(
starttime="2000-01-01", endtime="2001-01-01 00:00:00", crs="EPSG:28992"
)
assert isinstance(model.starttime, datetime)
assert isinstance(model.endtime, datetime)
assert model.starttime.tzinfo is None
Expand Down
Loading
Loading