From 142150ba8acd6c0489fde0b87a5691bb03911c97 Mon Sep 17 00:00:00 2001 From: Bart de Koning <74617371+SouthEndMusic@users.noreply.github.com> Date: Thu, 8 Aug 2024 15:58:30 +0200 Subject: [PATCH] Make more python functions private (#1702) Fixes https://github.com/Deltares/Ribasim/issues/1651 --- python/ribasim/ribasim/geometry/edge.py | 14 +++++++++++--- python/ribasim/ribasim/input_base.py | 22 +++++++++++----------- python/ribasim/ribasim/model.py | 10 +++++----- 3 files changed, 27 insertions(+), 19 deletions(-) diff --git a/python/ribasim/ribasim/geometry/edge.py b/python/ribasim/ribasim/geometry/edge.py index 0b7162ab6..ce1f09a91 100644 --- a/python/ribasim/ribasim/geometry/edge.py +++ b/python/ribasim/ribasim/geometry/edge.py @@ -107,7 +107,7 @@ def add( ) self.df.index.name = "fid" - def get_where_edge_type(self, edge_type: str) -> NDArray[np.bool_]: + def _get_where_edge_type(self, edge_type: str) -> NDArray[np.bool_]: assert self.df is not None return (self.df.edge_type == edge_type).to_numpy() @@ -118,6 +118,14 @@ def sort(self): self.df.sort_index(inplace=True) def plot(self, **kwargs) -> Axes: + """Plot the edges of the model. + + Parameters + ---------- + **kwargs : Dict + Supported: 'ax', 'color_flow', 'color_control' + """ + assert self.df is not None kwargs = kwargs.copy() # Avoid side-effects ax = kwargs.get("ax", None) @@ -141,8 +149,8 @@ def plot(self, **kwargs) -> Axes: kwargs_control["color"] = color_control kwargs_control["label"] = "Control edge" - where_flow = self.get_where_edge_type("flow") - where_control = self.get_where_edge_type("control") + where_flow = self._get_where_edge_type("flow") + where_control = self._get_where_edge_type("control") if not self.df[where_flow].empty: self.df[where_flow].plot(**kwargs_flow) diff --git a/python/ribasim/ribasim/input_base.py b/python/ribasim/ribasim/input_base.py index b5c8ee89d..bae678a09 100644 --- a/python/ribasim/ribasim/input_base.py +++ b/python/ribasim/ribasim/input_base.py @@ -92,7 +92,7 @@ class BaseModel(PydanticBaseModel): ) @classmethod - def fields(cls) -> list[str]: + def _fields(cls) -> list[str]: """Return the names of the fields contained in the Model.""" return list(cls.model_fields.keys()) @@ -117,7 +117,7 @@ class FileModel(BaseModel, ABC): @model_validator(mode="before") @classmethod - def check_filepath(cls, value: Any) -> Any: + def _check_filepath(cls, value: Any) -> Any: # Enable initialization with a Path. if isinstance(value, dict): # Pydantic Model init requires a dict @@ -173,7 +173,7 @@ class TableModel(FileModel, Generic[TableT]): @field_validator("df") @classmethod - def check_extra_columns(cls, v: DataFrame[TableT]): + def _check_extra_columns(cls, v: DataFrame[TableT]): """Allow only extra columns with `meta_` prefix.""" if isinstance(v, (pd.DataFrame, gpd.GeoDataFrame)): for colname in v.columns: @@ -187,7 +187,7 @@ def check_extra_columns(cls, v: DataFrame[TableT]): return v @model_serializer - def set_model(self) -> str | None: + def _set_model(self) -> str | None: return str(self.filepath.name) if self.filepath is not None else None @classmethod @@ -213,14 +213,14 @@ def tablename(cls) -> str: @model_validator(mode="before") @classmethod - def check_dataframe(cls, value: Any) -> Any: + def _check_dataframe(cls, value: Any) -> Any: # Enable initialization with a DataFrame. if isinstance(value, pd.DataFrame | gpd.GeoDataFrame): value = {"df": value} return value - def node_ids(self) -> set[int]: + def _node_ids(self) -> set[int]: node_ids: set[int] = set() if self.df is not None and "node_id" in self.df.columns: node_ids.update(self.df["node_id"]) @@ -394,7 +394,7 @@ class ChildModel(BaseModel): _parent_field: str | None = None @model_validator(mode="after") - def check_parent(self) -> "ChildModel": + def _check_parent(self) -> "ChildModel": if self._parent is not None: self._parent.model_fields_set.update({self._parent_field}) return self @@ -432,7 +432,7 @@ def _layername(cls, field: str) -> str: return f"{cls.get_input_type()}{delimiter}{field}" def _tables(self) -> Generator[TableModel[Any], Any, None]: - for key in self.fields(): + for key in self._fields(): attr = getattr(self, key) if ( isinstance(attr, TableModel) @@ -441,10 +441,10 @@ def _tables(self) -> Generator[TableModel[Any], Any, None]: ): yield attr - def node_ids(self) -> set[int]: + def _node_ids(self) -> set[int]: node_ids: set[int] = set() for table in self._tables(): - node_ids.update(table.node_ids()) + node_ids.update(table._node_ids()) return node_ids def _save(self, directory: DirectoryPath, input_dir: DirectoryPath): @@ -457,7 +457,7 @@ def _repr_content(self) -> str: Skip "empty" attributes: when the dataframe of a TableModel is None. """ content = [] - for field in self.fields(): + for field in self._fields(): attr = getattr(self, field) if isinstance(attr, TableModel): if attr.df is not None: diff --git a/python/ribasim/ribasim/model.py b/python/ribasim/ribasim/model.py index 3aac4c4be..4cca8dbbc 100644 --- a/python/ribasim/ribasim/model.py +++ b/python/ribasim/ribasim/model.py @@ -98,7 +98,7 @@ class Model(FileModel): edge: EdgeTable = Field(default_factory=EdgeTable) @model_validator(mode="after") - def set_node_parent(self) -> "Model": + def _set_node_parent(self) -> "Model": for ( k, v, @@ -108,14 +108,14 @@ def set_node_parent(self) -> "Model": return self @model_validator(mode="after") - def ensure_edge_table_is_present(self) -> "Model": + def _ensure_edge_table_is_present(self) -> "Model": if self.edge.df is None: self.edge.df = GeoDataFrame[EdgeSchema]() self.edge.df.set_geometry("geometry", inplace=True, crs=self.crs) return self @field_serializer("input_dir", "results_dir") - def serialize_path(self, path: Path) -> str: + def _serialize_path(self, path: Path) -> str: return str(path) def model_post_init(self, __context: Any) -> None: @@ -132,7 +132,7 @@ def __repr__(self) -> str: """ content = ["ribasim.Model("] INDENT = " " - for field in self.fields(): + for field in self._fields(): attr = getattr(self, field) if isinstance(attr, EdgeTable): content.append(f"{INDENT}{field}=Edge(...),") @@ -289,7 +289,7 @@ def _load(cls, filepath: Path | None) -> dict[str, Any]: return {} @model_validator(mode="after") - def reset_contextvar(self) -> "Model": + def _reset_contextvar(self) -> "Model": # Drop database info context_file_loading.set({}) return self