-
Notifications
You must be signed in to change notification settings - Fork 18
New issue
Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.
By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.
Already on GitHub? Sign in to your account
Param mapping #182
base: master
Are you sure you want to change the base?
Param mapping #182
Changes from all commits
819f17c
a9e2f50
3c3428b
933b465
223161b
3ed98b1
ebd1b6e
ba07a38
9875096
File filter
Filter by extension
Conversations
Jump to
Diff view
Diff view
There are no files selected for viewing
Original file line number | Diff line number | Diff line change | ||||
---|---|---|---|---|---|---|
@@ -1,7 +1,7 @@ | ||||||
from __future__ import annotations | ||||||
|
||||||
from pydantic import BaseModel, Field | ||||||
from typing import Sequence | ||||||
from pydantic import BaseModel, Field, root_validator | ||||||
from typing import Sequence, Mapping, Optional | ||||||
|
||||||
class Parameter(BaseModel, allow_population_by_field_name = True): | ||||||
""" | ||||||
|
@@ -11,5 +11,13 @@ class Parameter(BaseModel, allow_population_by_field_name = True): | |||||
min: float | ||||||
max: float | ||||||
init: float | ||||||
alias: Optional[str] | ||||||
aaraney marked this conversation as resolved.
Show resolved
Hide resolved
|
||||||
|
||||||
@root_validator | ||||||
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. Minor improvement that without could result in a misleading error.
Suggested change
|
||||||
def _set_alias(cls, values: dict) -> dict: | ||||||
alias = values.get('alias', None) | ||||||
if alias is None: | ||||||
values['alias'] = values['name'] | ||||||
return values | ||||||
|
||||||
Parameters = Sequence[Parameter] |
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -1,5 +1,7 @@ | ||
from __future__ import annotations | ||
|
||
import pytest | ||
from typing import Generator, List | ||
from typing import Generator, List, Mapping | ||
from pathlib import Path | ||
from copy import deepcopy | ||
import json | ||
|
@@ -11,6 +13,7 @@ | |
from ngen.cal.calibration_cathment import CalibrationCatchment | ||
from ngen.cal.model import EvaluationOptions | ||
from ngen.cal.agent import Agent | ||
from ngen.cal.parameter import Parameter | ||
from hypy import Nexus | ||
|
||
from .utils import * | ||
|
@@ -219,3 +222,32 @@ def explicit_catchments(nexus, fabric, workdir) -> Generator[ List[ CalibrationC | |
cat = CalibrationCatchment(workdir, id, nexus, start, end, fabric, 'Q_Out', eval_options, data) | ||
catchments.append(cat) | ||
yield catchments | ||
|
||
@pytest.fixture | ||
def multi_model_shared_params() -> Mapping[str, list[Parameter]]: | ||
p1 = Parameter(name='a', min=0, max=1, init=0) | ||
p2 = Parameter(name='d', min=2, max=3, init=0) | ||
p3 = Parameter(name='c', min=0, max=1, init=0) | ||
p4 = Parameter(name='a', min=0, max=1, init=0) | ||
params = {'A':[p1], 'B':[p2, p4], 'C':[p3, p4]} | ||
|
||
return params | ||
|
||
@pytest.fixture | ||
def multi_model_shared_params2() -> Mapping[str, list[Parameter]]: | ||
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. At first glance, this overlaps with |
||
p1 = Parameter(name='a', min=0, max=1, init=0) | ||
p2 = Parameter(name='a', min=0, max=1, init=0) | ||
p3 = Parameter(name='c', min=0, max=1, init=0) | ||
params = {'A':[p1, p3], 'B':[p2]} | ||
|
||
return params | ||
|
||
@pytest.fixture | ||
def multi_model_alias_params() -> Mapping[str, list[Parameter]]: | ||
p1 = Parameter(name='a', alias='c', min=0, max=1, init=0) | ||
p2 = Parameter(name='b', alias='c', min=0, max=1, init=0) | ||
p3 = Parameter(name='d', min=0, max=1, init=0) | ||
|
||
params = {'A':[p1,p3], 'B':[p2, p3], 'C':[p3]} | ||
|
||
return params |
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,60 @@ | ||
from __future__ import annotations | ||
|
||
from ngen.cal.ngen import _params_as_df | ||
import pandas as pd | ||
|
||
from typing import TYPE_CHECKING | ||
|
||
if TYPE_CHECKING: | ||
from typing import Mapping | ||
from ngen.cal.parameter import Parameter | ||
|
||
def test_multi_params(multi_model_shared_params: Mapping[str, list[Parameter]]): | ||
# This is essentially the path the params go through from | ||
# creation in model.py to update in search.py | ||
params = _params_as_df(multi_model_shared_params) | ||
params = pd.DataFrame(params).rename(columns={'init':'0'}) | ||
# create new iteration from old | ||
params['1'] = params['0'] | ||
#update the parameters by index | ||
idx1 = 'a' | ||
idx2 = 'c' | ||
params.loc[idx1, '1'] = 0.5 | ||
pa = params[ params['model'] == 'A' ].loc[idx1] | ||
pb = params[ params['model'] == 'B' ].loc[idx1] | ||
pc = params[ params['model'] == 'C' ].loc[idx2] | ||
|
||
assert pa.drop('model').equals( pb.drop('model') ) | ||
# ensure unique params/alias are not modifed by selection | ||
assert pa.loc['1'] != pc.loc['1'] | ||
|
||
def test_multi_params2(multi_model_shared_params2: Mapping[str, list[Parameter]]): | ||
# This is essentially the path the params go through from | ||
# creation in model.py to update in search.py | ||
params = _params_as_df(multi_model_shared_params2) | ||
params = pd.DataFrame(params).rename(columns={'init':'0'}) | ||
# create new iteration from old | ||
params['1'] = params['0'] | ||
#update the parameters by index | ||
params.loc['a', '1'] = 0.5 | ||
pa = params[ params['model'] == 'A' ].drop('model', axis=1).loc['a'] | ||
pb = params[ params['model'] == 'B' ].drop('model', axis=1).loc['a'] | ||
assert pa.equals( pb ) | ||
|
||
def test_alias_params(multi_model_alias_params: Mapping[str, list[Parameter]]): | ||
# This is essentially the path the params go through from | ||
# creation in model.py to update in search.py | ||
assert multi_model_alias_params['C'][0].alias == multi_model_alias_params['C'][0].name | ||
params = _params_as_df(multi_model_alias_params) | ||
params = pd.DataFrame(params).rename(columns={'init':'0'}) | ||
# create new iteration from old | ||
params['1'] = params['0'] | ||
#update the parameters by index | ||
|
||
params.loc['c', '1'] = 0.5 | ||
pa = params[ params['model'] == 'A' ] | ||
pb = params[ params['model'] == 'B' ] | ||
|
||
assert pa.drop(['model', 'param'], axis=1).equals( pb.drop(['model', 'param'], axis=1) ) | ||
|
||
# TODO test/document case where params have same name but different min/max/init values |
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
In looking at the tests, what are your thoughts on validating that:
init
,min
, andmax
are notnan
(on the fence about this one)init
is in the bounds ofmin
andmax
min
>=max
?These seem like sane invariants that we should uphold.