Skip to content

Commit

Permalink
Merge pull request #310 from jmccreight/feat_numpy_2.0
Browse files Browse the repository at this point in the history
Feat numpy 2.0
  • Loading branch information
jmccreight authored Oct 26, 2024
2 parents e566244 + abd4b79 commit 2920c8f
Show file tree
Hide file tree
Showing 13 changed files with 50 additions and 29 deletions.
10 changes: 7 additions & 3 deletions autotest/test_mmr_to_mf6_dfw.py
Original file line number Diff line number Diff line change
Expand Up @@ -48,7 +48,6 @@
}


@pytest.mark.xfail
@pytest.mark.skipif(mf6_bin_unavailable, reason="mf6 binary not available")
@pytest.mark.domainless
@pytest.mark.parametrize("binary_flw", [True, False])
Expand Down Expand Up @@ -191,9 +190,14 @@ def test_mmr_to_mf6_chf_dfw(tmp_path, binary_flw):
else:
flw_vol = 0.0

# <
time_coord_data = np.array([control.start_time]).astype(
"datetime64[ns]"
)

_ = xr.Dataset(
coords=dict(
time=np.array([control.start_time]),
time=time_coord_data,
nsegment=params.parameters["seg_id"],
),
data_vars={
Expand Down Expand Up @@ -293,7 +297,6 @@ def test_mmr_to_mf6_chf_dfw(tmp_path, binary_flw):
}


@pytest.mark.xfail
@pytest.mark.skipif(mf6_bin_unavailable, reason="mf6 binary not available")
@pytest.mark.domain
def test_mmr_to_mf6_dfw_regression(simulation, tmp_path):
Expand Down Expand Up @@ -410,6 +413,7 @@ def test_mmr_to_mf6_dfw_regression(simulation, tmp_path):
inflow_dir=inflow_dir,
)

dfw.write()
success, buff = dfw.run(silent=False, report=True)
assert success

Expand Down
2 changes: 1 addition & 1 deletion autotest/test_starfit_flow_graph.py
Original file line number Diff line number Diff line change
Expand Up @@ -173,7 +173,7 @@ def add_disconnected_node_data(ds: xr.Dataset) -> xr.Dataset:

# the first in the list is for the disconnected node
check_names = ["prms_channel"] + new_nodes_maker_names
check_indices = [dis_ds.dims["nsegment"] - 1] + new_nodes_maker_indices
check_indices = [dis_ds.sizes["nsegment"] - 1] + new_nodes_maker_indices
check_ids = [dis_ds.nhm_seg[-1].values.tolist()] + new_nodes_maker_ids

# This warning should say: TODO
Expand Down
5 changes: 3 additions & 2 deletions environment.yml
Original file line number Diff line number Diff line change
Expand Up @@ -23,7 +23,7 @@ dependencies:
- nbconvert
- netCDF4
- networkx
- numpy<2.0.0
- numpy>=2.0.0
- numba
- pandas>=1.4.0
- pint
Expand All @@ -43,10 +43,11 @@ dependencies:
- sphinx-autosummary-accessors
- sphinx-copybutton
- tqdm
- xarray>=2023.05.0
- xarray>=2024.06.0
- pip:
- asv
- click != 8.1.0
- filelock
- git+https://github.com/modflowpy/flopy.git
- jupyter_black
- modflow-devtools
Expand Down
4 changes: 2 additions & 2 deletions environment_w_jupyter.yml
Original file line number Diff line number Diff line change
Expand Up @@ -25,7 +25,7 @@ dependencies:
- nbconvert
- netCDF4
- networkx
- numpy<2.0.0
- numpy>=2.0.0
- numba
- pandas>=1.4.0
- pint
Expand All @@ -45,7 +45,7 @@ dependencies:
- sphinx-autosummary-accessors
- sphinx-copybutton
- tqdm
- xarray>=2023.05.0
- xarray>=2024.06.0
- pip:
- asv
- click != 8.1.0
Expand Down
12 changes: 8 additions & 4 deletions examples/07_mmr_to_mf6_chf_dfw.ipynb
Original file line number Diff line number Diff line change
Expand Up @@ -312,6 +312,7 @@
" inflow_dir=inflow_dir,\n",
" )\n",
"\n",
" dfw.write()\n",
" success, buff = dfw.run(silent=False, report=True)\n",
" assert success"
]
Expand Down Expand Up @@ -354,7 +355,9 @@
"n_substeps = int(ndays_run * 24 * 60 * 60 / tdis_perlen * tdis_nstp)\n",
"substep_len = np.timedelta64(int(tdis_perlen / tdis_nstp), \"s\")\n",
"sim_end_time = sim_start_time + n_substeps * substep_len\n",
"sim_times = np.arange(sim_start_time, sim_end_time, substep_len)\n",
"sim_times = np.arange(sim_start_time, sim_end_time, substep_len).astype(\n",
" \"datetime64[ns]\"\n",
") # ns to avoid xarray warnings\n",
"perioddata = tdis.perioddata.get_data()\n",
"assert len(sim_times) == len(perioddata) * perioddata[0][1]"
]
Expand Down Expand Up @@ -567,10 +570,11 @@
"outputs": [],
"source": [
"# Subset to points of interest (poi), known flow gages\n",
"poi_id = np.chararray(\n",
" prms_mf6_ds[\"prms streamflow\"].nhm_seg.shape, unicode=True, itemsize=15\n",
")\n",
"empty_str = \" \" * 15\n",
"poi_id = np.full(\n",
" prms_mf6_ds[\"prms streamflow\"].nhm_seg.shape, empty_str, dtype=\"<U15\"\n",
")\n",
"\n",
"poi_id[:] = empty_str\n",
"for ii, jj in enumerate(params.parameters[\"poi_gage_segment\"].tolist()):\n",
" poi_id[jj] = params.parameters[\"poi_gage_id\"][ii]\n",
Expand Down
6 changes: 3 additions & 3 deletions pyproject.toml
Original file line number Diff line number Diff line change
@@ -1,7 +1,7 @@
[build-system]
requires = [
"setuptools >=61, <=72.2.0",
"numpy >=1.15.0,<2.0.0",
"setuptools>64",
"numpy>=2.0",
]
build-backend = "setuptools.build_meta"

Expand Down Expand Up @@ -29,7 +29,7 @@ classifiers = [
requires-python = ">=3.10,<3.12"
dependencies = [
"contextily",
"numpy >=1.15.0,<2.0.0",
"numpy>=2.0.0",
"matplotlib >=1.4.0",
"epiweeks",
"flopy",
Expand Down
2 changes: 1 addition & 1 deletion pywatershed/atmosphere/prms_atmosphere.py
Original file line number Diff line number Diff line change
Expand Up @@ -267,7 +267,7 @@ def get_init_values() -> dict:
"tmaxc": nan,
"tavgc": nan,
"tminc": nan,
"pptmix": nan,
"pptmix": -9999,
"orad_hru": nan,
}

Expand Down
5 changes: 5 additions & 0 deletions pywatershed/base/data_model.py
Original file line number Diff line number Diff line change
Expand Up @@ -763,13 +763,18 @@ def xr_ds_to_dd(file_or_ds, schema_only=False, encoding=True) -> dict:

dd = xr_ds.to_dict(data=data_arg, encoding=encoding)

# before = xr_ds.time.values.dtype
# after = dd["coords"]["time"]["data"].dtype
# assert before == after

dd = xr_dd_to_dd(dd)

return dd


def xr_dd_to_dd(xr_dd: dict) -> dict:
dd = deepcopy(xr_dd)
# asdf

# Move the global encoding to a global key of itself
dd["encoding"] = {"global": dd.get("encoding", {})}
Expand Down
2 changes: 1 addition & 1 deletion pywatershed/hydrology/prms_canopy.py
Original file line number Diff line number Diff line change
Expand Up @@ -474,7 +474,7 @@ def _calculate_numpy(
# Keep the f90 call signature consistent with the args in
# python/numba.

intcp_form = np.full_like(hru_rain, np.nan, dtype="int32")
intcp_form = np.full_like(hru_rain, -9999, dtype="int32")
for i in prange(nhru):
netrain = hru_rain[i]
netsnow = hru_snow[i]
Expand Down
17 changes: 10 additions & 7 deletions pywatershed/parameters/prms_parameters.py
Original file line number Diff line number Diff line change
Expand Up @@ -49,7 +49,9 @@ def default(self, obj):


def _json_load(json_filename):
pars = json.load(open(json_filename))
with open(json_filename) as ff:
pars = json.load(ff)

# need to convert lists to numpy arrays
for k, v in pars.items():
if isinstance(v, list):
Expand Down Expand Up @@ -101,12 +103,13 @@ def __init__(

def parameters_to_json(self, json_filename) -> None:
"""write the parameters dictionary out to a json file"""
json.dump(
{**self.dims, **self.parameters},
open(json_filename, "w"),
indent=4,
cls=JSONParameterEncoder,
)
with open(json_filename, "w") as ff:
json.dump(
{**self.dims, **self.parameters},
ff,
indent=4,
cls=JSONParameterEncoder,
)
return None

@staticmethod
Expand Down
2 changes: 1 addition & 1 deletion pywatershed/utils/csv_utils.py
Original file line number Diff line number Diff line change
Expand Up @@ -252,7 +252,7 @@ def _get_data(self) -> None:
"""

def str2date(x):
return dt.datetime.strptime(x.decode("utf-8"), "%Y-%m-%d")
return dt.datetime.strptime(x, "%Y-%m-%d")

all_data = []
ntimes = 0
Expand Down
2 changes: 1 addition & 1 deletion pywatershed/utils/mmr_to_mf6_dfw.py
Original file line number Diff line number Diff line change
Expand Up @@ -140,7 +140,7 @@ def __init__(
# time_units="seconds",
save_flows: bool = True,
time_zone: str = "UTC",
write_on_init: bool = True,
write_on_init: bool = False,
chd_options: dict = None,
cxs_options: dict = None,
disv1d_options: dict = None,
Expand Down
10 changes: 7 additions & 3 deletions pywatershed/utils/prms5_file_util.py
Original file line number Diff line number Diff line change
Expand Up @@ -126,6 +126,9 @@ def _get_file_object(
raise TypeError("file_path must be a file path")
return

def _close_file_object(self) -> None:
self.file_object.close()

def _get_control_variables(
self,
) -> dict:
Expand Down Expand Up @@ -156,7 +159,7 @@ def _get_control_variables(
elif key in ("initial_deltat",):
value = np.timedelta64(int(value[0]), "h")
variable_dict[key] = value
self.file_object.close()
self._close_file_object()
return variable_dict

def _get_dimensions_parameters(self):
Expand Down Expand Up @@ -220,6 +223,7 @@ def _get_dimensions_parameters(self):
for kk, vv in parameter_dimensions_full_dict.items()
}

self._close_file_object()
return parameters_full_dict, parameter_dimensions_full_dict

def _get_parameters(self):
Expand Down Expand Up @@ -307,7 +311,7 @@ def _parse_variable(
for idx in range(num_values):
arr[idx] = float(self._get_line().split()[0])
elif data_type == PrmsDataType.CHARACTER.value:
arr = np.zeros(num_values, dtype=np.chararray)
arr = np.zeros(num_values, dtype="O")
for idx in range(num_values):
arr[idx] = self._get_line().split()[0]
else:
Expand Down Expand Up @@ -377,7 +381,7 @@ def _parse_parameter(
for idx in range(len_array):
arr[idx] = float(self._get_line().split()[0])
elif data_type == PrmsDataType.CHARACTER.value:
arr = np.zeros(len_array, dtype=np.chararray)
arr = np.zeros(len_array, dtype="O")
for idx in range(len_array):
arr[idx] = self._get_line().split()[0]
else:
Expand Down

0 comments on commit 2920c8f

Please sign in to comment.