Skip to content

Commit

Permalink
refactor(modpathfile): toward unified particle tracking api (#2127)
Browse files Browse the repository at this point in the history
* introduce base particle track file module and class
* rename _ModpathSeries -> ModpathFile
* deduplicate shared logic in ModpathFile
* deprecate write_shapefile() params
* prep to add support for MF6 PRT
* clarify canonical (minimal) fields
* add dtypes as class attributes
* misc cleanup in plotutil.py
  • Loading branch information
wpbonelli authored Mar 28, 2024
1 parent f75853f commit 77e5e1d
Show file tree
Hide file tree
Showing 7 changed files with 832 additions and 1,155 deletions.
12 changes: 2 additions & 10 deletions autotest/test_mp5.py
Original file line number Diff line number Diff line change
@@ -1,10 +1,8 @@
import os

import numpy as np
import pandas as pd
from autotest.test_mp6 import eval_timeseries
from matplotlib import pyplot as plt
from modflow_devtools.markers import requires_pkg

from flopy.modflow import Modflow
from flopy.plot import PlotMapView
Expand Down Expand Up @@ -50,14 +48,8 @@ def test_mp5_load(function_tmpdir, example_data_path):
for n in pthobj.nid:
p = pthobj.get_data(partid=n)
e = endobj.get_data(partid=n)
try:
mm.plot_pathline(p, colors=colors[n], layer="all")
except:
assert False, f'could not plot pathline {n + 1} with layer="all"'
try:
mm.plot_endpoint(e)
except:
assert False, f'could not plot endpoint {n + 1} with layer="all"'
mm.plot_pathline(p, colors=colors[n], layer="all")
mm.plot_endpoint(e)

# plot the grid and ibound array
mm.plot_grid(lw=0.5)
Expand Down
5 changes: 2 additions & 3 deletions autotest/test_plot_particle_tracks.py
Original file line number Diff line number Diff line change
Expand Up @@ -4,12 +4,12 @@
import pandas as pd
import pytest
from matplotlib.collections import LineCollection, PathCollection
from modflow_devtools.markers import requires_exe, requires_pkg
from modflow_devtools.markers import requires_exe

from flopy.modflow import Modflow
from flopy.modpath import Modpath6, Modpath6Bas
from flopy.plot import PlotCrossSection, PlotMapView
from flopy.utils import CellBudgetFile, EndpointFile, HeadFile, PathlineFile
from flopy.utils import EndpointFile, PathlineFile


@pytest.fixture
Expand Down Expand Up @@ -58,7 +58,6 @@ def test_plot(pl):
mx.plot_grid()
mx.plot_bc("WEL", kper=2, color="blue")
pth = mx.plot_pathline(pl, colors="red")
# plt.show()
assert isinstance(pth, LineCollection)
assert len(pth._paths) == 114

Expand Down
20 changes: 10 additions & 10 deletions autotest/test_plotutil.py
Original file line number Diff line number Diff line change
Expand Up @@ -194,7 +194,7 @@
"PRP000000001", # name
],
],
columns=PRT_PATHLINE_DTYPE.fields.keys(),
columns=PRT_PATHLINE_DTYPE.names,
)
MP7_TEST_PATHLINES = pd.DataFrame.from_records(
[
Expand Down Expand Up @@ -233,7 +233,7 @@
1, # timestep
],
],
columns=MP7_PATHLINE_DTYPE.fields.keys(),
columns=MP7_PATHLINE_DTYPE.names,
)
MP7_TEST_ENDPOINTS = pd.DataFrame.from_records(
[
Expand Down Expand Up @@ -322,7 +322,7 @@
2, # cellface
],
],
columns=MP7_ENDPOINT_DTYPE.fields.keys(),
columns=MP7_ENDPOINT_DTYPE.names,
)


Expand All @@ -342,13 +342,13 @@ def test_to_mp7_pathlines(dataframe):
assert len(mp7_pls) == 10
assert set(
dict(mp7_pls.dtypes).keys() if dataframe else mp7_pls.dtype.names
) == set(MP7_PATHLINE_DTYPE.fields.keys())
) == set(MP7_PATHLINE_DTYPE.names)


@pytest.mark.parametrize("dataframe", [True, False])
def test_to_mp7_pathlines_empty(dataframe):
mp7_pls = to_mp7_pathlines(
pd.DataFrame.from_records([], columns=PRT_PATHLINE_DTYPE.fields.keys())
pd.DataFrame.from_records([], columns=PRT_PATHLINE_DTYPE.names)
if dataframe
else np.recarray((0,), dtype=PRT_PATHLINE_DTYPE)
)
Expand All @@ -374,7 +374,7 @@ def test_to_mp7_pathlines_noop(dataframe):
assert len(mp7_pls) == 2
assert set(
dict(mp7_pls.dtypes).keys() if dataframe else mp7_pls.dtype.names
) == set(MP7_PATHLINE_DTYPE.fields.keys())
) == set(MP7_PATHLINE_DTYPE.names)
assert np.array_equal(
mp7_pls if dataframe else pd.DataFrame(mp7_pls), MP7_TEST_PATHLINES
)
Expand All @@ -391,13 +391,13 @@ def test_to_mp7_endpoints(dataframe):
assert np.isclose(mp7_eps.time[0], PRT_TEST_PATHLINES.t.max())
assert set(
dict(mp7_eps.dtypes).keys() if dataframe else mp7_eps.dtype.names
) == set(MP7_ENDPOINT_DTYPE.fields.keys())
) == set(MP7_ENDPOINT_DTYPE.names)


@pytest.mark.parametrize("dataframe", [True, False])
def test_to_mp7_endpoints_empty(dataframe):
mp7_eps = to_mp7_endpoints(
pd.DataFrame.from_records([], columns=PRT_PATHLINE_DTYPE.fields.keys())
pd.DataFrame.from_records([], columns=PRT_PATHLINE_DTYPE.names)
if dataframe
else np.recarray((0,), dtype=PRT_PATHLINE_DTYPE)
)
Expand Down Expand Up @@ -445,7 +445,7 @@ def test_to_prt_pathlines_roundtrip(dataframe):
@pytest.mark.parametrize("dataframe", [True, False])
def test_to_prt_pathlines_roundtrip_empty(dataframe):
mp7_pls = to_mp7_pathlines(
pd.DataFrame.from_records([], columns=PRT_PATHLINE_DTYPE.fields.keys())
pd.DataFrame.from_records([], columns=PRT_PATHLINE_DTYPE.names)
if dataframe
else np.recarray((0,), dtype=PRT_PATHLINE_DTYPE)
)
Expand All @@ -454,4 +454,4 @@ def test_to_prt_pathlines_roundtrip_empty(dataframe):
assert prt_pls.empty if dataframe else mp7_pls.size == 0
assert set(
dict(mp7_pls.dtypes).keys() if dataframe else mp7_pls.dtype.names
) == set(MP7_PATHLINE_DTYPE.fields.keys())
) == set(MP7_PATHLINE_DTYPE.names)
31 changes: 17 additions & 14 deletions flopy/plot/plotutil.py
Original file line number Diff line number Diff line change
Expand Up @@ -2679,7 +2679,6 @@ def parse_modpath_selection_options(
("cellface", np.int32),
]
)
MP_MIN_PLOT_FIELDS = ["x", "y", "z", "time", "k", "particleid"]


def to_mp7_pathlines(
Expand All @@ -2698,6 +2697,8 @@ def to_mp7_pathlines(
np.recarray or pd.DataFrame (consistent with input type)
"""

from flopy.utils.particletrackfile import MIN_PARTICLE_TRACK_DTYPE

# determine return type
ret_type = type(data)

Expand All @@ -2708,13 +2709,13 @@ def to_mp7_pathlines(
# check format
dt = data.dtypes
if not (
all(n in dt for n in MP_MIN_PLOT_FIELDS)
or all(n in dt for n in PRT_PATHLINE_DTYPE.fields.keys())
all(n in dt for n in MIN_PARTICLE_TRACK_DTYPE.names)
or all(n in dt for n in PRT_PATHLINE_DTYPE.names)
):
raise ValueError(
"Pathline data must contain the following fields: "
f"{MP_MIN_PLOT_FIELDS} for MODPATH 7, or "
f"{PRT_PATHLINE_DTYPE.fields.keys()} for MODFLOW 6 PRT"
f"{MIN_PARTICLE_TRACK_DTYPE.names} for MODPATH 7, or "
f"{PRT_PATHLINE_DTYPE.names} for MODFLOW 6 PRT"
)

# return early if already in MP7 format
Expand Down Expand Up @@ -2780,6 +2781,8 @@ def to_mp7_endpoints(
np.recarray or pd.DataFrame (consistent with input type)
"""

from flopy.utils.particletrackfile import MIN_PARTICLE_TRACK_DTYPE

# determine return type
ret_type = type(data)

Expand All @@ -2789,18 +2792,18 @@ def to_mp7_endpoints(

# check format
dt = data.dtypes
if all(n in dt for n in MP7_ENDPOINT_DTYPE.fields.keys()):
if all(n in dt for n in MP7_ENDPOINT_DTYPE.names):
return (
data if ret_type == pd.DataFrame else data.to_records(index=False)
)
if not (
all(n in dt for n in MP_MIN_PLOT_FIELDS)
or all(n in dt for n in PRT_PATHLINE_DTYPE.fields.keys())
all(n in dt for n in MIN_PARTICLE_TRACK_DTYPE.names)
or all(n in dt for n in PRT_PATHLINE_DTYPE.names)
):
raise ValueError(
"Pathline data must contain the following fields: "
f"{MP_MIN_PLOT_FIELDS} for MODPATH 7, or "
f"{PRT_PATHLINE_DTYPE.fields.keys()} for MODFLOW 6 PRT"
f"{MIN_PARTICLE_TRACK_DTYPE.names} for MODPATH 7, or "
f"{PRT_PATHLINE_DTYPE.names} for MODFLOW 6 PRT"
)

# return early if empty
Expand Down Expand Up @@ -2909,13 +2912,13 @@ def to_prt_pathlines(
# check format
dt = data.dtypes
if not (
all(n in dt for n in MP7_PATHLINE_DTYPE.fields.keys())
or all(n in dt for n in PRT_PATHLINE_DTYPE.fields.keys())
all(n in dt for n in MP7_PATHLINE_DTYPE.names)
or all(n in dt for n in PRT_PATHLINE_DTYPE.names)
):
raise ValueError(
"Pathline data must contain the following fields: "
f"{MP7_PATHLINE_DTYPE.fields.keys()} for MODPATH 7, or "
f"{PRT_PATHLINE_DTYPE.fields.keys()} for MODFLOW 6 PRT"
f"{MP7_PATHLINE_DTYPE.names} for MODPATH 7, or "
f"{PRT_PATHLINE_DTYPE.names} for MODFLOW 6 PRT"
)

# return early if already in PRT format
Expand Down
1 change: 0 additions & 1 deletion flopy/utils/flopy_io.py
Original file line number Diff line number Diff line change
Expand Up @@ -352,7 +352,6 @@ def loadtxt(
ra : np.recarray
Numpy record array of file contents.
"""
from ..utils import import_optional_dependency

if use_pandas:
if delimiter.isspace():
Expand Down
Loading

0 comments on commit 77e5e1d

Please sign in to comment.