diff --git a/python/ngen_cal/src/ngen/cal/ngen.py b/python/ngen_cal/src/ngen/cal/ngen.py index 9d2cd498..ebbe25e9 100644 --- a/python/ngen_cal/src/ngen/cal/ngen.py +++ b/python/ngen_cal/src/ngen/cal/ngen.py @@ -46,13 +46,13 @@ def _params_as_df(params: Mapping[str, Parameters], name: str = None): df['model'] = k df.rename(columns={'name':'param'}, inplace=True) dfs.append(df) - return pd.concat(dfs) + return pd.concat(dfs).set_index('param') else: p = params.get(name, []) df = pd.DataFrame([s.__dict__ for s in p]) df['model'] = name df.rename(columns={'name':'param'}, inplace=True) - return df + return df.set_index('param') def _map_params_to_realization(params: Mapping[str, Parameters], realization: Realization): # don't even think about calibration multiple formulations at once just yet.. diff --git a/python/ngen_cal/tests/conftest.py b/python/ngen_cal/tests/conftest.py index 32c06a48..e3af5da5 100644 --- a/python/ngen_cal/tests/conftest.py +++ b/python/ngen_cal/tests/conftest.py @@ -1,5 +1,7 @@ +from __future__ import annotations + import pytest -from typing import Generator, List +from typing import Generator, List, Mapping from pathlib import Path from copy import deepcopy import json @@ -11,6 +13,7 @@ from ngen.cal.calibration_cathment import CalibrationCatchment from ngen.cal.model import EvaluationOptions from ngen.cal.agent import Agent +from ngen.cal.parameter import Parameter from hypy import Nexus from .utils import * @@ -219,3 +222,22 @@ def explicit_catchments(nexus, fabric, workdir) -> Generator[ List[ CalibrationC cat = CalibrationCatchment(workdir, id, nexus, start, end, fabric, 'Q_Out', eval_options, data) catchments.append(cat) yield catchments + +@pytest.fixture +def multi_model_shared_params() -> Mapping[str, list[Parameter]]: + p1 = Parameter(name='a', min=0, max=1, init=0) + p2 = Parameter(name='d', min=2, max=3, init=0) + p3 = Parameter(name='c', min=0, max=1, init=0) + p4 = Parameter(name='a', min=0, max=1, init=0) + params = {'A':[p1], 'B':[p2, p4], 'C':[p3, p4]} + + return params + +@pytest.fixture +def multi_model_shared_params2() -> Mapping[str, list[Parameter]]: + p1 = Parameter(name='a', min=0, max=1, init=0) + p2 = Parameter(name='a', min=0, max=1, init=0) + p3 = Parameter(name='c', min=0, max=1, init=0) + params = {'A':[p1, p3], 'B':[p2]} + + return params diff --git a/python/ngen_cal/tests/test_params.py b/python/ngen_cal/tests/test_params.py new file mode 100644 index 00000000..bb483c9d --- /dev/null +++ b/python/ngen_cal/tests/test_params.py @@ -0,0 +1,42 @@ +from __future__ import annotations + +from ngen.cal.ngen import _params_as_df +import pandas as pd + +from typing import TYPE_CHECKING + +if TYPE_CHECKING: + from typing import Mapping + from ngen.cal.parameter import Parameter + +def test_multi_params(multi_model_shared_params: Mapping[str, list[Parameter]]): + # This is essentially the path the params go through from + # creation in model.py to update in search.py + params = _params_as_df(multi_model_shared_params) + params = pd.DataFrame(params).rename(columns={'init':'0'}) + # create new iteration from old + params['1'] = params['0'] + #update the parameters by index + idx1 = 'a' + idx2 = 'c' + params.loc[idx1, '1'] = 0.5 + pa = params[ params['model'] == 'A' ].loc[idx1] + pb = params[ params['model'] == 'B' ].loc[idx1] + pc = params[ params['model'] == 'C' ].loc[idx2] + + assert pa.drop('model').equals( pb.drop('model') ) + # ensure unique params/alias are not modifed by selection + assert pa.loc['1'] != pc.loc['1'] + +def test_multi_params2(multi_model_shared_params2: Mapping[str, list[Parameter]]): + # This is essentially the path the params go through from + # creation in model.py to update in search.py + params = _params_as_df(multi_model_shared_params2) + params = pd.DataFrame(params).rename(columns={'init':'0'}) + # create new iteration from old + params['1'] = params['0'] + #update the parameters by index + params.loc['a', '1'] = 0.5 + pa = params[ params['model'] == 'A' ].drop('model', axis=1).loc['a'] + pb = params[ params['model'] == 'B' ].drop('model', axis=1).loc['a'] + assert pa.equals( pb )