Skip to content

Commit

Permalink
refactor: change order of args in check_point
Browse files Browse the repository at this point in the history
  • Loading branch information
hellkite500 committed Jul 24, 2024
1 parent 5a25197 commit 61c9672
Show file tree
Hide file tree
Showing 4 changed files with 9 additions and 9 deletions.
2 changes: 1 addition & 1 deletion python/ngen_cal/src/ngen/cal/calibratable.py
Original file line number Diff line number Diff line change
Expand Up @@ -83,7 +83,7 @@ def check_point_file(self) -> 'Path':
"""
return Path('{}_parameter_df_state.parquet'.format(self.id))

def check_point(self, info: JobMeta, iteration: int) -> None:
def check_point(self, iteration: int, info: JobMeta) -> None:
"""
Save calibration information
"""
Expand Down
2 changes: 1 addition & 1 deletion python/ngen_cal/src/ngen/cal/calibration_set.py
Original file line number Diff line number Diff line change
Expand Up @@ -94,7 +94,7 @@ def observed(self) -> 'DataFrame':
def observed(self, df):
self._observed = df

def check_point(self, info: JobMeta, iteration: int) -> None:
def check_point(self, iteration: int, info: JobMeta) -> None:
"""
Save calibration information
"""
Expand Down
12 changes: 6 additions & 6 deletions python/ngen_cal/src/ngen/cal/search.py
Original file line number Diff line number Diff line change
Expand Up @@ -125,7 +125,7 @@ def dds(start_iteration: int, iterations: int, calibration_object: 'Evaluatable
_execute(agent)
with pushd(agent.job.workdir):
_evaluate(0, calibration_object, info=True)
calibration_object.check_point(agent.job, 0)
calibration_object.check_point(0, agent.job)
start_iteration += 1

for i in range(start_iteration, iterations+1):
Expand All @@ -137,7 +137,7 @@ def dds(start_iteration: int, iterations: int, calibration_object: 'Evaluatable
_execute(agent)
with pushd(agent.job.workdir):
_evaluate(i, calibration_object, info=True)
calibration_object.check_point(agent.job, i)
calibration_object.check_point(i, agent.job)

def dds_set(start_iteration: int, iterations: int, agent: 'Agent'):
"""
Expand Down Expand Up @@ -175,7 +175,7 @@ def dds_set(start_iteration: int, iterations: int, agent: 'Agent'):
_execute(agent)
with pushd(agent.job.workdir):
_evaluate(0, calibration_set, info=True)
calibration_set.check_point(agent.job, 0)
calibration_set.check_point(0, agent.job)
start_iteration += 1

for i in range(start_iteration, iterations+1):
Expand All @@ -188,7 +188,7 @@ def dds_set(start_iteration: int, iterations: int, agent: 'Agent'):
_execute(agent)
with pushd(agent.job.workdir):
_evaluate(i, calibration_set, info=True)
calibration_set.check_point(agent.job, i)
calibration_set.check_point(i, agent.job)

def compute(calibration_object, iteration, input) -> float:
params = input[0]
Expand All @@ -202,7 +202,7 @@ def compute(calibration_object, iteration, input) -> float:
agent.update_config(iteration, calibration_object.df[[str(iteration), 'param', 'model']], calibration_object.id)
_execute(agent)
cost = _evaluate(iteration, calibration_object)
calibration_object.check_point(agent.job, iteration)
calibration_object.check_point(iteration, agent.job)
#cost = _objective_func(calibration_object.output, calibration_object.observed, calibration_object.objective, calibration_object.evaluation_range)
return cost

Expand Down Expand Up @@ -272,7 +272,7 @@ def pso_search(start_iteration: int, iterations: int, agent):
#For pyswarm, DO NOT use the embedded multi-processing -- it is impossible to track the mapping of an agent to the params
cost, pos = optimizer.optimize(cf, iters=iterations, n_processes=None)
calibration_object.df.loc[:,'global_best'] = pos
calibration_object.check_point(agent.job, iterations)
calibration_object.check_point(iterations, agent.job)
print("Best params with cost {}:".format(cost))
print(calibration_object.df[['param','global_best']].set_index('param'))

2 changes: 1 addition & 1 deletion python/ngen_cal/tests/test_model.py
Original file line number Diff line number Diff line change
Expand Up @@ -61,7 +61,7 @@ def test_restart_1(ngen_config: 'Ngen', eval: 'EvaluationOptions', workdir: 'Dir
eval.write_param_log_file(2)
info = JobMeta(ngen_config.type, workdir, workdir = workdir)
#make sure the catchment param df is saved before trying to restart
ngen_config.adjustables[0].check_point(info, 1)
ngen_config.adjustables[0].check_point(1, info)

iteration = ngen_config.restart()
assert iteration == 3
Expand Down

0 comments on commit 61c9672

Please sign in to comment.