From 8b360d7e5ba5f526e6646450e4a193dca6cc2d70 Mon Sep 17 00:00:00 2001 From: Austin Raney Date: Fri, 20 Sep 2024 12:00:38 -0500 Subject: [PATCH 1/4] feat(ngen.cal): optionally run validation simulation w/ best params NOTE: this needs to be refactored... and tested. This was pushed through just to get thing working will not affect the code around it if this feature is not used. This feature enables running ngen with the best calibration parameters after calibration has completed. A validation interval is configurable under `model.val_params`. See ngen.cal.model.ValidationOptions for specifics. --- python/ngen_cal/src/ngen/cal/__main__.py | 94 ++++++++++++++++++++++++ python/ngen_cal/src/ngen/cal/model.py | 47 ++++++++++++ 2 files changed, 141 insertions(+) diff --git a/python/ngen_cal/src/ngen/cal/__main__.py b/python/ngen_cal/src/ngen/cal/__main__.py index 14dd04df..7cfe5833 100644 --- a/python/ngen_cal/src/ngen/cal/__main__.py +++ b/python/ngen_cal/src/ngen/cal/__main__.py @@ -15,6 +15,7 @@ from types import ModuleType if TYPE_CHECKING: + from ngen.config.realization import NgenRealization from typing import Mapping, Any from pluggy import PluginManager @@ -35,6 +36,25 @@ def _loaded_plugins(pm: PluginManager) -> str: return f"Plugins Loaded: {', '.join(plugins)}" +def _update_troute_config( + realization: NgenRealization, + troute_config: dict[str, Any], +): + start = realization.time.start_time + end = realization.time.end_time + duration = (end - start).total_seconds() + + troute_config["compute_parameters"]["restart_parameters"]["start_datetime"] = ( + start.strftime("%Y-%m-%d %H:%M:%S") + ) + + forcing_parameters = troute_config["compute_parameters"]["forcing_parameters"] + dt = forcing_parameters["dt"] + nts, r = divmod(duration, dt) + assert r == 0, "routing timestep is not evenly divisible by ngen_timesteps" + forcing_parameters["nts"] = nts + + def main(general: General, model_conf: Mapping[str, Any]): #seed the random number generators if requested if general.random_seed is not None: @@ -108,6 +128,80 @@ def main(general: General, model_conf: Mapping[str, Any]): #for catchment_set in agent.model.adjustables: # func(start_iteration, general.iterations, catchment_set, agent) func(start_iteration, general.iterations, agent) + + if (validation_parms := model.model.unwrap().val_params) is not None: + print("configuring calibration") + # NOTE: importing here so its easier to refactor in the future + from ngen.cal.calibration_set import CalibrationSet + import pandas as pd + from typing import TYPE_CHECKING + if TYPE_CHECKING: + from typing import Sequence + import pandas as pd + from ngen.cal.calibration_cathment import CalibrationCatchment + + adjustables: Sequence[CalibrationCatchment] = agent.model.adjustables + + realization: NgenRealization = agent.model.unwrap().ngen_realization + assert realization is not None + + sim_start, sim_end = validation_parms.sim_interval() + eval_start, eval_end = validation_parms.evaluation_interval() + print(f"validation {sim_start=} {sim_end=}") + + # NOTE: do this before `update_config` is called so the right path is written to disk + realization.time.start_time = sim_start + realization.time.end_time = sim_end + + assert realization.routing is not None + + troute_config_path = realization.routing.config + + with troute_config_path.open() as fp: + troute_config = yaml.safe_load(fp) + + _update_troute_config(realization, troute_config) + + troute_config_path_validation = troute_config_path.with_name("troute_validation.yaml") + with troute_config_path_validation.open("w") as fp: + yaml.dump(troute_config, fp) + + # NOTE: do this before `update_config` is called so the right path is written to disk + realization.routing.config = troute_config_path_validation + + for calibration_object in adjustables: + best_df: pd.DataFrame = calibration_object.df[[str(agent.best_params), 'param', 'model']] + + agent.update_config(agent.best_params, best_df, calibration_object.id) + + # NOTE: importing here so its easier to refactor in the future + from ngen.cal.search import _execute, _objective_func + from ngen.cal.utils import pushd + + print("starting calibration") + # TODO: validation_parms.objective and target are not being correctly configured + _execute(agent) + with pushd(agent.job.workdir): + sim = calibration_object.output + + assert isinstance(calibration_object, CalibrationSet) + # TODO: get from realization config + simulation_interval = pd.Timedelta(3600, unit="s") + # TODO: need a way to get the nexus + nexus = calibration_object._eval_nexus + agent_pm = agent.model.unwrap()._plugin_manager + obs = agent_pm.hook.ngen_cal_model_observations( + nexus=nexus, + # NOTE: techinically start_time=`eval_start` + `simulation_interval` + start_time=eval_start, + end_time=eval_end, + simulation_interval=simulation_interval, + ) + print(f"{sim=}") + print(f"{obs=}") + score = _objective_func(sim, obs, validation_parms.objective, (sim_start, sim_end)) + print(f"validation run score: {score}") + # call `ngen_cal_finish` plugin hook functions except Exception as e: plugin_manager.hook.ngen_cal_finish(exception=e) diff --git a/python/ngen_cal/src/ngen/cal/model.py b/python/ngen_cal/src/ngen/cal/model.py index a88796d6..2d5da773 100644 --- a/python/ngen_cal/src/ngen/cal/model.py +++ b/python/ngen_cal/src/ngen/cal/model.py @@ -212,6 +212,51 @@ def restart(self) -> int: return start_iteration + +class ValidationOptions(BaseModel): + """A data class holding validation options""" + #Optional, but co-dependent, see @_validate_start_stop_both_or_neither_exist for validation logic + evaluation_start: datetime + evaluation_stop: datetime + sim_start: Optional[datetime] = None + sim_stop: Optional[datetime] = None + objective: Optional[Union[Objective, PyObject]] = Objective.custom + target: Union[Literal['min'], Literal['max'], float] = 'min' + + def sim_interval(self) -> tuple[datetime, datetime]: + """Returns a tuple of simulation start and stop datetimes""" + start = self.sim_start if self.sim_start is not None else self.evaluation_start + stop = self.sim_stop if self.sim_stop is not None else self.evaluation_stop + return start, stop + + def evaluation_interval(self) -> tuple[datetime, datetime]: + """Returns a tuple of evaluation start and stop datetimes""" + return self.evaluation_start, self.evaluation_stop + + @root_validator(skip_on_failure=True) + @classmethod + def _validate_periods(cls, values: dict[str, datetime | None]) -> dict[str, datetime | None]: + evaluation_start: datetime = values["evaluation_start"] # type: ignore + evaluation_stop: datetime = values["evaluation_stop"] # type: ignore + sim_start: datetime | None = values.get("sim_start") + sim_stop: datetime | None = values.get("sim_stop") + + errs: list[str] = [] + if sim_start is not None and sim_start > evaluation_start: + errs.append("`sim_start` must be <= `evaluation_start`") + + if sim_stop is not None and sim_stop < evaluation_stop: + errs.append("`evaluation_stop` must be <= `sim_stop`") + + if evaluation_stop < evaluation_start: + errs.append("`evaluation_start` must be <= `evaluation_stop`") + + if errs: + raise ValueError("\n".join(errs)) + + return values + + class ModelExec(BaseModel, Configurable): """ The data class for a given model, which must also be Configurable @@ -220,6 +265,8 @@ class ModelExec(BaseModel, Configurable): args: Optional[str] workdir: DirectoryPath = Path("./") #FIXME test the various workdirs eval_params: Optional[EvaluationOptions] = Field(default_factory=EvaluationOptions) + # TODO: likely want to move this into `NgenBase` instead of here + val_params: Optional[ValidationOptions] = None plugins: List[PyObjectOrModule] = Field(default_factory=list) plugin_settings: Dict[str, Dict[str, Any]] = Field(default_factory=dict) From 10277ecf0e356cb9cdfd8288b010b4b1aad8ffbf Mon Sep 17 00:00:00 2001 From: Austin Raney Date: Fri, 20 Sep 2024 12:07:15 -0500 Subject: [PATCH 2/4] todo: model files that may be desirable to remove. --- python/ngen_cal/src/ngen/cal/ngen.py | 11 ++++++++++- 1 file changed, 10 insertions(+), 1 deletion(-) diff --git a/python/ngen_cal/src/ngen/cal/ngen.py b/python/ngen_cal/src/ngen/cal/ngen.py index 9d2cd498..86bea941 100644 --- a/python/ngen_cal/src/ngen/cal/ngen.py +++ b/python/ngen_cal/src/ngen/cal/ngen.py @@ -388,7 +388,16 @@ def update_config(self, i: int, params: pd.DataFrame, id: str = None, path=Path( # Cleanup any t-route parquet files between runs # TODO this may not be _the_ best place to do this, but for now, # it works, so here it be... - for file in Path(path).glob("*NEXOUT.parquet"): + import itertools + to_remove = ( + # Path(path).glob("troute_output_*.*"), + # Path(path).glob("flowveldepth_*.*"), + Path(path).glob("*NEXOUT.parquet"), + # ngen files + # Path(path).glob("cat-*.csv"), + # Path(path).glob("nex-*_output.csv"), + ) + for file in itertools.chain(*to_remove): file.unlink() class NgenExplicit(NgenBase): From 06fa1b88bd4c940026349b12edc700d17f8a7341 Mon Sep 17 00:00:00 2001 From: Austin Raney Date: Fri, 20 Sep 2024 12:07:57 -0500 Subject: [PATCH 3/4] refactor: TrouteOutput to handle validation --- .../src/ngen/cal/ngen_hooks/ngen_output.py | 85 +++++++++++++++---- 1 file changed, 68 insertions(+), 17 deletions(-) diff --git a/python/ngen_cal/src/ngen/cal/ngen_hooks/ngen_output.py b/python/ngen_cal/src/ngen/cal/ngen_hooks/ngen_output.py index 40344d64..d6625f9a 100644 --- a/python/ngen_cal/src/ngen/cal/ngen_hooks/ngen_output.py +++ b/python/ngen_cal/src/ngen/cal/ngen_hooks/ngen_output.py @@ -7,22 +7,30 @@ import pandas as pd from ngen.cal import hookimpl +from pydantic import BaseModel if TYPE_CHECKING: from ngen.cal.meta import JobMeta - from ngen.cal.model import ModelExec + from ngen.cal.model import ModelExec, ValidationOptions, EvaluationOptions from ngen.config.realization import NgenRealization class _NgenCalModelOutputFn(typing.Protocol): def __call__(self, id: str) -> pd.Series: ... +class TrouteOutputSettings(BaseModel): + validation_routing_output: Path + @typing.final class TrouteOutput: def __init__(self, filepath: Path) -> None: self._output_file = filepath + self._settings: TrouteOutputSettings | None = None + self._ngen_realization: NgenRealization | None = None + self._validation_options: ValidationOptions | None = None + self._eval_options: EvaluationOptions | None = None @hookimpl def ngen_cal_model_configure(self, config: ModelExec) -> None: @@ -33,36 +41,79 @@ def ngen_cal_model_configure(self, config: ModelExec) -> None: assert config.ngen_realization is not None self._ngen_realization = config.ngen_realization - # Try external provided output hooks, if those fail, try this one - # this will only execute if all other hooks return None (or they don't exist) - @hookimpl(specname="ngen_cal_model_output", trylast=True) - def get_output(self, id: str) -> pd.Series | None: + if (eval_options := config.eval_params) is not None: + self._eval_options = eval_options + + if (validation_config := config.val_params) is not None: + self._validation_options = validation_config + + # maybe pull in plugin settings + if (plugin_settings := config.plugin_settings.get("ngen_cal_troute_output")) is not None: + self._settings = TrouteOutputSettings.parse_obj(plugin_settings) + + def _sim_eval_interval(self) -> tuple[datetime.datetime, datetime.datetime]: assert ( self._ngen_realization is not None ), "ngen realization required; ensure `ngen_cal_model_configure` was called and the plugin was properly configured" - if not self._output_file.exists(): - print( - f"{self._output_file} not found. Current working directory is {Path.cwd()!s}" - ) - print("Setting output to None") - return None + if self._eval_options is not None and self._eval_options.evaluation_start is not None: + assert self._eval_options.evaluation_stop is not None + return self._eval_options.evaluation_start, self._eval_options.evaluation_stop + + return self._ngen_realization.time.start_time, self._ngen_realization.time.end_time - filetype = self._output_file.suffix.lower() + def _validation_eval_interval(self) -> tuple[datetime.datetime, datetime.datetime]: + if self._validation_options is None: + print("validation options not provided, using sim evaluation interval") + return self._sim_eval_interval() + return self._validation_options.evaluation_interval() + + def _output_handler_factory(self, output_file: Path) -> _NgenCalModelOutputFn: + filetype = output_file.suffix.lower() if filetype == ".csv": - fn = self._factory_handler_csv(self._output_file) + fn = self._factory_handler_csv(output_file) # TODO: fix. dont know if this format still works # elif filetype == ".hdf5": # fn = _model_output_legacy_hdf5(self._output_file) elif filetype == ".nc": - fn = _stream_output_netcdf_v1(self._output_file) + fn = _stream_output_netcdf_v1(output_file) elif filetype == ".parquet": - fn = _stream_output_parquet_v1(self._output_file) + fn = _stream_output_parquet_v1(output_file) else: raise RuntimeError( - f"unsupported t-route output filetype: {self._output_file.suffix}" + f"unsupported t-route output filetype: {output_file.suffix}" ) + return fn + + # Try external provided output hooks, if those fail, try this one + # this will only execute if all other hooks return None (or they don't exist) + @hookimpl(specname="ngen_cal_model_output", trylast=True) + def get_output(self, id: str) -> pd.Series | None: + assert ( + self._ngen_realization is not None + ), "ngen realization required; ensure `ngen_cal_model_configure` was called and the plugin was properly configured" + + if self._settings is not None and self._settings.validation_routing_output.exists(): + output_file = self._settings.validation_routing_output + print(f"retrieving simulation data from validation output file: {output_file!s}") + + start, end = self._validation_eval_interval() + print(f"validation: {start=} {end=}") + elif self._output_file.exists(): + output_file = self._output_file + print(f"retrieving simulation data from output file: {output_file!s}") + + start, end = self._sim_eval_interval() + print(f"{start=} {end=}") + else: + print( + f"{self._output_file} not found. Current working directory is {Path.cwd()!s}" + ) + print("Setting output to None") + return None + # TODO: I dont think all output handlers can handle validation (csv comes to mind). circle back to this + fn = self._output_handler_factory(output_file) ds = fn(id) ds.name = "sim_flow" @@ -74,7 +125,7 @@ def get_output(self, id: str) -> pd.Series | None: seconds=self._ngen_realization.time.output_interval ) start = self._ngen_realization.time.start_time - ds = ds.loc[start + ngen_dt :] + ds = ds.loc[start + ngen_dt :end] ds = ds.resample("1h").first() return ds From 804bb5fe8820a31ee6feaaf42a877b4b46f4acc5 Mon Sep 17 00:00:00 2001 From: Austin Raney Date: Fri, 20 Sep 2024 12:08:40 -0500 Subject: [PATCH 4/4] chore: improve sim/obs df error messages --- python/ngen_cal/src/ngen/cal/search.py | 9 ++++++--- 1 file changed, 6 insertions(+), 3 deletions(-) diff --git a/python/ngen_cal/src/ngen/cal/search.py b/python/ngen_cal/src/ngen/cal/search.py index 7e29d920..5ca2b8b8 100644 --- a/python/ngen_cal/src/ngen/cal/search.py +++ b/python/ngen_cal/src/ngen/cal/search.py @@ -24,11 +24,14 @@ def _objective_func(simulated_hydrograph, observed_hydrograph, objective, eval_range: tuple[datetime, datetime] | None = None): df = pd.merge(simulated_hydrograph, observed_hydrograph, left_index=True, right_index=True) - if df.empty: - print("WARNING: Cannot compute objective function, do time indicies align?") if eval_range: df = df.loc[eval_range[0]:eval_range[1]] - #print( df ) + if df.empty: + print("WARNING: Cannot compute objective function, do time indicies align?") + if eval_range: + print(f"\teval range: [{eval_range[0]!s} : {eval_range[1]!s}]") + print(f"\tsim interval: [{simulated_hydrograph.index.min()!s} : {simulated_hydrograph.index.max()!s}]") + print(f"\tobs interval: [{observed_hydrograph.index.min()!s} : {observed_hydrograph.index.max()!s}]") #Evaluate custom objective function providing simulated, observed series return objective(df['obs_flow'], df['sim_flow'])