Skip to content

Commit

Permalink
Feature: add SwnMf6.gather_reaches() (#92)
Browse files Browse the repository at this point in the history
  • Loading branch information
mwtoews authored Dec 29, 2023
1 parent b0231e9 commit f1799fc
Show file tree
Hide file tree
Showing 5 changed files with 347 additions and 112 deletions.
1 change: 1 addition & 0 deletions docs/source/swnmf6.rst
Original file line number Diff line number Diff line change
Expand Up @@ -66,4 +66,5 @@ Utilities
:toctree: ref/

SwnMf6.get_location_frame_reach_info
SwnMf6.gather_reaches
SwnMf6.route_reaches
2 changes: 1 addition & 1 deletion swn/file.py
Original file line number Diff line number Diff line change
Expand Up @@ -177,7 +177,7 @@ def gdf_to_shapefile(gdf, shp_fname, **kwargs):
rename[col] = colname10[col]
if rename:
gdf.rename(columns=rename, inplace=True)
gdf.to_file(str(shp_fname), **kwargs)
gdf.to_file(str(shp_fname), index=True, **kwargs)


def read_formatted_frame(fname):
Expand Down
116 changes: 116 additions & 0 deletions swn/modflow/_swnmf6.py
Original file line number Diff line number Diff line change
Expand Up @@ -5,6 +5,7 @@
]

import os
from copy import deepcopy
from itertools import zip_longest

import numpy as np
Expand Down Expand Up @@ -1967,6 +1968,10 @@ def route_reaches(self, start, end, *, allow_indirect=False):
ConnecionError
If start and end reach numbers do not connect.
See Also
--------
gather_reaches : Query multiple reaches up and downstream.
Examples
--------
>>> import flopy
Expand Down Expand Up @@ -2054,6 +2059,117 @@ def go_downstream(ridx):
idx2 = con2.index(ridx)
return con1[:idx1] + list(reversed(con2[:idx2]))

def gather_reaches(
self, *, upstream=[], downstream=[], barrier=[], gather_upstream=False
):
"""Return reaches upstream (inclusive) and downstream (exclusive).
Parameters
----------
upstream, downstream : int or list, default []
Reach number(s) (rno or ifno) from reaches.index to search from.
barriers : int or list, default []
Reach number(s) that cannot be traversed past.
gather_upstream : bool, default False
Gather upstream from all other downstream reaches.
Returns
-------
list
See Also
--------
route_reaches :
Return a list of reaches that connect a pair of reaches.
Examples
--------
>>> import flopy
>>> import geopandas
>>> import swn
>>> lines = geopandas.GeoSeries.from_wkt([
... "LINESTRING (60 100, 60 80)",
... "LINESTRING (40 130, 60 100)",
... "LINESTRING (70 130, 60 100)"])
>>> lines.index += 100
>>> n = swn.SurfaceWaterNetwork.from_lines(lines)
>>> sim = flopy.mf6.MFSimulation()
>>> _ = flopy.mf6.ModflowTdis(sim, nper=1, time_units="days")
>>> gwf = flopy.mf6.ModflowGwf(sim)
>>> _ = flopy.mf6.ModflowGwfdis(
... gwf, nrow=3, ncol=2, delr=20.0, delc=20.0, idomain=1,
... length_units="meters", xorigin=30.0, yorigin=70.0)
>>> nm = swn.SwnMf6.from_swn_flopy(n, gwf)
>>> nm.reaches[["iseg", "ireach", "to_ifno", "from_ifnos", "segnum"]]
iseg ireach to_ifno from_ifnos segnum
ifno
1 1 1 2 {} 101
2 1 2 3 {1} 101
3 1 3 6 {2} 101
4 2 1 5 {} 102
5 2 2 6 {4} 102
6 3 1 7 {3, 5} 100
7 3 2 0 {6} 100
>>> nm.gather_reaches(upstream=6)
[6, 3, 2, 1, 5, 4]
>>> nm.gather_reaches(downstream=4)
[5, 6, 7]
"""
reaches_set = set(self.reaches.index)

def check_and_return_list(var, name):
if isinstance(var, list):
if not reaches_set.issuperset(var):
diff = list(sorted(set(var).difference(reaches_set)))
raise IndexError(
f"{len(diff)} {name} "
f"reach{'' if len(diff) == 1 else 'es'} "
f"not found in reaches.index: {abbr_str(diff)}"
)
return var
else:
if var not in reaches_set:
raise IndexError(
f"{name} {self.reach_index_name} {var} not found in reaches.index"
)
return [var]

def go_upstream(ridx):
yield ridx
for from_ridx in from_ridxs.get(ridx, []):
yield from go_upstream(from_ridx)

def go_downstream(ridx):
yield ridx
if ridx in to_ridxs:
yield from go_downstream(to_ridxs[ridx])

to_ridx_name = f"to_{self.reach_index_name}"
to_ridxs = dict(self.reaches.loc[self.reaches[to_ridx_name] != 0, to_ridx_name])
from_ridxs = self.reaches[f"from_{self.reach_index_name}s"]
# Note that `.copy(deep=True)` does not work; use deepcopy
from_ridxs = from_ridxs[from_ridxs.apply(len) > 0].apply(deepcopy)
for barrier in check_and_return_list(barrier, "barrier"):
for ridx in from_ridxs.get(barrier, []):
del to_ridxs[ridx]
from_ridxs[to_ridxs[barrier]].remove(barrier)
del to_ridxs[barrier]

ridxs = []
for ridx in check_and_return_list(upstream, "upstream"):
upridxs = list(go_upstream(ridx))
ridxs += upridxs # ridx inclusive
for ridx in check_and_return_list(downstream, "downstream"):
downidxs = list(go_downstream(ridx))
ridxs += downidxs[1:] # ridx exclusive
if gather_upstream:
for ridx in downidxs[1:]:
for from_ridx in from_ridxs.get(ridx, []):
if from_ridx not in downidxs:
upridxs = list(go_upstream(from_ridx))
ridxs += upridxs
return ridxs


def get_flopy_mf6_package(name: str):
"""Returns a flopy.mf6 package.
Expand Down
23 changes: 23 additions & 0 deletions tests/conftest.py
Original file line number Diff line number Diff line change
Expand Up @@ -7,6 +7,7 @@
import geopandas
import pandas as pd
import pytest
from shapely import wkt

try:
import matplotlib
Expand All @@ -29,6 +30,28 @@
datadir = Path("tests") / "data"


# https://commons.wikimedia.org/wiki/File:Flussordnung_(Strahler).svg
fluss_gs = geopandas.GeoSeries(
wkt.loads(
"""\
MULTILINESTRING(
(380 490, 370 420), (300 460, 370 420), (370 420, 420 330),
(190 250, 280 270), (225 180, 280 270), (280 270, 420 330),
(420 330, 584 250), (520 220, 584 250), (584 250, 710 160),
(740 270, 710 160), (735 350, 740 270), (880 320, 740 270),
(925 370, 880 320), (974 300, 880 320), (760 460, 735 350),
(650 430, 735 350), (710 160, 770 100), (700 90, 770 100),
(770 100, 820 40))
"""
).geoms
)


@pytest.fixture
def fluss_n():
return swn.SurfaceWaterNetwork.from_lines(fluss_gs)


@pytest.fixture(scope="session", autouse=True)
def coastal_lines_gdf():
gdf = geopandas.read_file(datadir / "DN2_Coastal_strahler1z_stream_vf.shp")
Expand Down
Loading

0 comments on commit f1799fc

Please sign in to comment.