From 61c9672559832d5bb2edf7471f8d1c1cf959f9c6 Mon Sep 17 00:00:00 2001 From: hellkite500 Date: Wed, 24 Jul 2024 10:42:05 -0600 Subject: [PATCH] refactor: change order of args in check_point --- python/ngen_cal/src/ngen/cal/calibratable.py | 2 +- python/ngen_cal/src/ngen/cal/calibration_set.py | 2 +- python/ngen_cal/src/ngen/cal/search.py | 12 ++++++------ python/ngen_cal/tests/test_model.py | 2 +- 4 files changed, 9 insertions(+), 9 deletions(-) diff --git a/python/ngen_cal/src/ngen/cal/calibratable.py b/python/ngen_cal/src/ngen/cal/calibratable.py index ea72c26b..ce3f2f2b 100644 --- a/python/ngen_cal/src/ngen/cal/calibratable.py +++ b/python/ngen_cal/src/ngen/cal/calibratable.py @@ -83,7 +83,7 @@ def check_point_file(self) -> 'Path': """ return Path('{}_parameter_df_state.parquet'.format(self.id)) - def check_point(self, info: JobMeta, iteration: int) -> None: + def check_point(self, iteration: int, info: JobMeta) -> None: """ Save calibration information """ diff --git a/python/ngen_cal/src/ngen/cal/calibration_set.py b/python/ngen_cal/src/ngen/cal/calibration_set.py index 4e5a30ab..9e7883a0 100644 --- a/python/ngen_cal/src/ngen/cal/calibration_set.py +++ b/python/ngen_cal/src/ngen/cal/calibration_set.py @@ -94,7 +94,7 @@ def observed(self) -> 'DataFrame': def observed(self, df): self._observed = df - def check_point(self, info: JobMeta, iteration: int) -> None: + def check_point(self, iteration: int, info: JobMeta) -> None: """ Save calibration information """ diff --git a/python/ngen_cal/src/ngen/cal/search.py b/python/ngen_cal/src/ngen/cal/search.py index c8a81952..63fbce2b 100644 --- a/python/ngen_cal/src/ngen/cal/search.py +++ b/python/ngen_cal/src/ngen/cal/search.py @@ -125,7 +125,7 @@ def dds(start_iteration: int, iterations: int, calibration_object: 'Evaluatable _execute(agent) with pushd(agent.job.workdir): _evaluate(0, calibration_object, info=True) - calibration_object.check_point(agent.job, 0) + calibration_object.check_point(0, agent.job) start_iteration += 1 for i in range(start_iteration, iterations+1): @@ -137,7 +137,7 @@ def dds(start_iteration: int, iterations: int, calibration_object: 'Evaluatable _execute(agent) with pushd(agent.job.workdir): _evaluate(i, calibration_object, info=True) - calibration_object.check_point(agent.job, i) + calibration_object.check_point(i, agent.job) def dds_set(start_iteration: int, iterations: int, agent: 'Agent'): """ @@ -175,7 +175,7 @@ def dds_set(start_iteration: int, iterations: int, agent: 'Agent'): _execute(agent) with pushd(agent.job.workdir): _evaluate(0, calibration_set, info=True) - calibration_set.check_point(agent.job, 0) + calibration_set.check_point(0, agent.job) start_iteration += 1 for i in range(start_iteration, iterations+1): @@ -188,7 +188,7 @@ def dds_set(start_iteration: int, iterations: int, agent: 'Agent'): _execute(agent) with pushd(agent.job.workdir): _evaluate(i, calibration_set, info=True) - calibration_set.check_point(agent.job, i) + calibration_set.check_point(i, agent.job) def compute(calibration_object, iteration, input) -> float: params = input[0] @@ -202,7 +202,7 @@ def compute(calibration_object, iteration, input) -> float: agent.update_config(iteration, calibration_object.df[[str(iteration), 'param', 'model']], calibration_object.id) _execute(agent) cost = _evaluate(iteration, calibration_object) - calibration_object.check_point(agent.job, iteration) + calibration_object.check_point(iteration, agent.job) #cost = _objective_func(calibration_object.output, calibration_object.observed, calibration_object.objective, calibration_object.evaluation_range) return cost @@ -272,7 +272,7 @@ def pso_search(start_iteration: int, iterations: int, agent): #For pyswarm, DO NOT use the embedded multi-processing -- it is impossible to track the mapping of an agent to the params cost, pos = optimizer.optimize(cf, iters=iterations, n_processes=None) calibration_object.df.loc[:,'global_best'] = pos - calibration_object.check_point(agent.job, iterations) + calibration_object.check_point(iterations, agent.job) print("Best params with cost {}:".format(cost)) print(calibration_object.df[['param','global_best']].set_index('param')) \ No newline at end of file diff --git a/python/ngen_cal/tests/test_model.py b/python/ngen_cal/tests/test_model.py index 5b6243ee..54d9ba16 100644 --- a/python/ngen_cal/tests/test_model.py +++ b/python/ngen_cal/tests/test_model.py @@ -61,7 +61,7 @@ def test_restart_1(ngen_config: 'Ngen', eval: 'EvaluationOptions', workdir: 'Dir eval.write_param_log_file(2) info = JobMeta(ngen_config.type, workdir, workdir = workdir) #make sure the catchment param df is saved before trying to restart - ngen_config.adjustables[0].check_point(info, 1) + ngen_config.adjustables[0].check_point(1, info) iteration = ngen_config.restart() assert iteration == 3