Skip to content

Commit

Permalink
module const for canonical particle tracking dtype, cleanup plotutil.py
Browse files Browse the repository at this point in the history
  • Loading branch information
wpbonelli committed Mar 21, 2024
1 parent 27f80f6 commit 9149045
Show file tree
Hide file tree
Showing 3 changed files with 40 additions and 34 deletions.
20 changes: 10 additions & 10 deletions autotest/test_plotutil.py
Original file line number Diff line number Diff line change
Expand Up @@ -194,7 +194,7 @@
"PRP000000001", # name
],
],
columns=PRT_PATHLINE_DTYPE.fields.keys(),
columns=PRT_PATHLINE_DTYPE.names,
)
MP7_TEST_PATHLINES = pd.DataFrame.from_records(
[
Expand Down Expand Up @@ -233,7 +233,7 @@
1, # timestep
],
],
columns=MP7_PATHLINE_DTYPE.fields.keys(),
columns=MP7_PATHLINE_DTYPE.names,
)
MP7_TEST_ENDPOINTS = pd.DataFrame.from_records(
[
Expand Down Expand Up @@ -322,7 +322,7 @@
2, # cellface
],
],
columns=MP7_ENDPOINT_DTYPE.fields.keys(),
columns=MP7_ENDPOINT_DTYPE.names,
)


Expand All @@ -342,13 +342,13 @@ def test_to_mp7_pathlines(dataframe):
assert len(mp7_pls) == 10
assert set(
dict(mp7_pls.dtypes).keys() if dataframe else mp7_pls.dtype.names
) == set(MP7_PATHLINE_DTYPE.fields.keys())
) == set(MP7_PATHLINE_DTYPE.names)


@pytest.mark.parametrize("dataframe", [True, False])
def test_to_mp7_pathlines_empty(dataframe):
mp7_pls = to_mp7_pathlines(
pd.DataFrame.from_records([], columns=PRT_PATHLINE_DTYPE.fields.keys())
pd.DataFrame.from_records([], columns=PRT_PATHLINE_DTYPE.names)
if dataframe
else np.recarray((0,), dtype=PRT_PATHLINE_DTYPE)
)
Expand All @@ -374,7 +374,7 @@ def test_to_mp7_pathlines_noop(dataframe):
assert len(mp7_pls) == 2
assert set(
dict(mp7_pls.dtypes).keys() if dataframe else mp7_pls.dtype.names
) == set(MP7_PATHLINE_DTYPE.fields.keys())
) == set(MP7_PATHLINE_DTYPE.names)
assert np.array_equal(
mp7_pls if dataframe else pd.DataFrame(mp7_pls), MP7_TEST_PATHLINES
)
Expand All @@ -391,13 +391,13 @@ def test_to_mp7_endpoints(dataframe):
assert np.isclose(mp7_eps.time[0], PRT_TEST_PATHLINES.t.max())
assert set(
dict(mp7_eps.dtypes).keys() if dataframe else mp7_eps.dtype.names
) == set(MP7_ENDPOINT_DTYPE.fields.keys())
) == set(MP7_ENDPOINT_DTYPE.names)


@pytest.mark.parametrize("dataframe", [True, False])
def test_to_mp7_endpoints_empty(dataframe):
mp7_eps = to_mp7_endpoints(
pd.DataFrame.from_records([], columns=PRT_PATHLINE_DTYPE.fields.keys())
pd.DataFrame.from_records([], columns=PRT_PATHLINE_DTYPE.names)
if dataframe
else np.recarray((0,), dtype=PRT_PATHLINE_DTYPE)
)
Expand Down Expand Up @@ -445,7 +445,7 @@ def test_to_prt_pathlines_roundtrip(dataframe):
@pytest.mark.parametrize("dataframe", [True, False])
def test_to_prt_pathlines_roundtrip_empty(dataframe):
mp7_pls = to_mp7_pathlines(
pd.DataFrame.from_records([], columns=PRT_PATHLINE_DTYPE.fields.keys())
pd.DataFrame.from_records([], columns=PRT_PATHLINE_DTYPE.names)
if dataframe
else np.recarray((0,), dtype=PRT_PATHLINE_DTYPE)
)
Expand All @@ -454,4 +454,4 @@ def test_to_prt_pathlines_roundtrip_empty(dataframe):
assert prt_pls.empty if dataframe else mp7_pls.size == 0
assert set(
dict(mp7_pls.dtypes).keys() if dataframe else mp7_pls.dtype.names
) == set(MP7_PATHLINE_DTYPE.fields.keys())
) == set(MP7_PATHLINE_DTYPE.names)
31 changes: 17 additions & 14 deletions flopy/plot/plotutil.py
Original file line number Diff line number Diff line change
Expand Up @@ -2679,7 +2679,6 @@ def parse_modpath_selection_options(
("cellface", np.int32),
]
)
MP_MIN_PLOT_FIELDS = ["x", "y", "z", "time", "k", "particleid"]


def to_mp7_pathlines(
Expand All @@ -2698,6 +2697,8 @@ def to_mp7_pathlines(
np.recarray or pd.DataFrame (consistent with input type)
"""

from flopy.utils.particletracking import MIN_PARTICLE_TRACK_DTYPE

# determine return type
ret_type = type(data)

Expand All @@ -2708,13 +2709,13 @@ def to_mp7_pathlines(
# check format
dt = data.dtypes
if not (
all(n in dt for n in MP_MIN_PLOT_FIELDS)
or all(n in dt for n in PRT_PATHLINE_DTYPE.fields.keys())
all(n in dt for n in MIN_PARTICLE_TRACK_DTYPE.names)
or all(n in dt for n in PRT_PATHLINE_DTYPE.names)
):
raise ValueError(
"Pathline data must contain the following fields: "
f"{MP_MIN_PLOT_FIELDS} for MODPATH 7, or "
f"{PRT_PATHLINE_DTYPE.fields.keys()} for MODFLOW 6 PRT"
f"{MIN_PARTICLE_TRACK_DTYPE.names} for MODPATH 7, or "
f"{PRT_PATHLINE_DTYPE.names} for MODFLOW 6 PRT"
)

# return early if already in MP7 format
Expand Down Expand Up @@ -2780,6 +2781,8 @@ def to_mp7_endpoints(
np.recarray or pd.DataFrame (consistent with input type)
"""

from flopy.utils.particletracking import MIN_PARTICLE_TRACK_DTYPE

# determine return type
ret_type = type(data)

Expand All @@ -2789,18 +2792,18 @@ def to_mp7_endpoints(

# check format
dt = data.dtypes
if all(n in dt for n in MP7_ENDPOINT_DTYPE.fields.keys()):
if all(n in dt for n in MP7_ENDPOINT_DTYPE.names):
return (
data if ret_type == pd.DataFrame else data.to_records(index=False)
)
if not (
all(n in dt for n in MP_MIN_PLOT_FIELDS)
or all(n in dt for n in PRT_PATHLINE_DTYPE.fields.keys())
all(n in dt for n in MIN_PARTICLE_TRACK_DTYPE.names)
or all(n in dt for n in PRT_PATHLINE_DTYPE.names)
):
raise ValueError(
"Pathline data must contain the following fields: "
f"{MP_MIN_PLOT_FIELDS} for MODPATH 7, or "
f"{PRT_PATHLINE_DTYPE.fields.keys()} for MODFLOW 6 PRT"
f"{MIN_PARTICLE_TRACK_DTYPE.names} for MODPATH 7, or "
f"{PRT_PATHLINE_DTYPE.names} for MODFLOW 6 PRT"
)

# return early if empty
Expand Down Expand Up @@ -2909,13 +2912,13 @@ def to_prt_pathlines(
# check format
dt = data.dtypes
if not (
all(n in dt for n in MP7_PATHLINE_DTYPE.fields.keys())
or all(n in dt for n in PRT_PATHLINE_DTYPE.fields.keys())
all(n in dt for n in MP7_PATHLINE_DTYPE.names)
or all(n in dt for n in PRT_PATHLINE_DTYPE.names)
):
raise ValueError(
"Pathline data must contain the following fields: "
f"{MP7_PATHLINE_DTYPE.fields.keys()} for MODPATH 7, or "
f"{PRT_PATHLINE_DTYPE.fields.keys()} for MODFLOW 6 PRT"
f"{MP7_PATHLINE_DTYPE.names} for MODPATH 7, or "
f"{PRT_PATHLINE_DTYPE.names} for MODFLOW 6 PRT"
)

# return early if already in PRT format
Expand Down
23 changes: 13 additions & 10 deletions flopy/utils/particletracking.py
Original file line number Diff line number Diff line change
Expand Up @@ -12,6 +12,18 @@
from pathlib import Path


MIN_PARTICLE_TRACK_DTYPE = np.dtype(
[
("x", np.float32),
("y", np.float32),
("z", np.float32),
("time", np.float32),
("k", np.int32),
("particleid", np.int32),
]
)


class ParticleTrackFile(ABC):
"""
Abstract base class for particle track output files. Exposes a unified API
Expand All @@ -30,16 +42,7 @@ class ParticleTrackFile(ABC):
"""

outdtype = np.dtype(
[
("x", np.float32),
("y", np.float32),
("z", np.float32),
("time", np.float32),
("k", np.int32),
("particleid", np.int32),
]
)
outdtype = MIN_PARTICLE_TRACK_DTYPE
"""
Minimal information shared by all particle track file formats.
Track data are converted to this dtype for internal storage
Expand Down

0 comments on commit 9149045

Please sign in to comment.