Skip to content

Commit 290705f

Browse files
committed
Export re-scaled controls
1 parent 6e0574b commit 290705f

File tree

6 files changed

+33
-39
lines changed

6 files changed

+33
-39
lines changed

src/ert/run_models/everest_run_model.py

+26-4
Original file line numberDiff line numberDiff line change
@@ -326,9 +326,12 @@ def _init_batch_data(
326326
control_values: NDArray[np.float64],
327327
evaluator_context: EvaluatorContext,
328328
cached_results: dict[int, Any],
329+
prefix: str = "",
329330
) -> dict[int, dict[str, Any]]:
330331
def _add_controls(
331-
controls_config: list[ControlConfig], values: NDArray[np.float64]
332+
controls_config: list[ControlConfig],
333+
values: NDArray[np.float64],
334+
prefix: str = "",
332335
) -> dict[str, Any]:
333336
batch_data_item: dict[str, Any] = {}
334337
value_list = values.tolist()
@@ -344,13 +347,27 @@ def _add_controls(
344347
else:
345348
variable_value = value_list.pop(0)
346349
control_dict[variable.name] = variable_value
347-
batch_data_item[control.name] = control_dict
350+
batch_data_item[prefix + control.name] = control_dict
351+
return batch_data_item
352+
353+
def _add_controls_with_rescaling(
354+
controls_config: list[ControlConfig], values: NDArray[np.float64]
355+
) -> dict[str, Any]:
356+
batch_data_item = _add_controls(controls_config, values)
357+
if self._ropt_transforms.variables is not None:
358+
rescaled_values = self._ropt_transforms.variables.backward(values)
359+
rescaled_item = _add_controls(
360+
controls_config, rescaled_values, prefix="rescaled-"
361+
)
362+
batch_data_item.update(rescaled_item)
348363
return batch_data_item
349364

350365
active = evaluator_context.active
351366
realizations = evaluator_context.realizations
352367
return {
353-
idx: _add_controls(self._everest_config.controls, control_values[idx, :])
368+
idx: _add_controls_with_rescaling(
369+
self._everest_config.controls, control_values[idx, :]
370+
)
354371
for idx in range(control_values.shape[0])
355372
if (
356373
idx not in cached_results
@@ -392,7 +409,12 @@ def _check_suffix(
392409
f"Key {key} has suffixes, a suffix must be specified"
393410
)
394411

395-
if set(controls.keys()) != set(self._everest_config.control_names):
412+
control_names = set(self._everest_config.control_names)
413+
if any(control.auto_scale for control in self._everest_config.controls):
414+
control_names |= {
415+
"rescaled-" + name for name in self._everest_config.control_names
416+
}
417+
if set(controls.keys()) != control_names:
396418
err_msg = "Mismatch between initialized and provided control names."
397419
raise KeyError(err_msg)
398420

src/everest/config/control_variable_config.py

-18
Original file line numberDiff line numberDiff line change
@@ -34,24 +34,6 @@ class _ControlVariable(BaseModel):
3434
initial value.
3535
""",
3636
)
37-
auto_scale: bool | None = Field(
38-
default=None,
39-
description="""
40-
Can be set to true to re-scale variable from the range
41-
defined by [min, max] to the range defined by scaled_range (default [0, 1])
42-
""",
43-
)
44-
scaled_range: Annotated[tuple[float, float] | None, AfterValidator(valid_range)] = (
45-
Field(
46-
default=None,
47-
description="""
48-
Can be used to set the range of the variable values
49-
after scaling (default = [0, 1]).
50-
51-
This option has no effect if auto_scale is not set.
52-
""",
53-
)
54-
)
5537
min: float | None = Field(
5638
default=None,
5739
description="""

src/everest/config/utils.py

-2
Original file line numberDiff line numberDiff line change
@@ -66,8 +66,6 @@ def _add_variable(
6666
for key in [
6767
"control_type",
6868
"enabled",
69-
"auto_scale",
70-
"scaled_range",
7169
"min",
7270
"max",
7371
"perturbation_magnitude",

src/everest/optimizer/everest2ropt.py

+1-1
Original file line numberDiff line numberDiff line change
@@ -389,7 +389,7 @@ def everest2ropt(
389389

390390
def get_ropt_transforms(ever_config: EverestConfig) -> Transforms:
391391
controls = FlattenedControls(ever_config.controls)
392-
if any(item is not None for item in controls.auto_scales):
392+
if any(controls.auto_scales):
393393
variable_scaler = ControlScaler(
394394
controls.lower_bounds,
395395
controls.upper_bounds,

src/everest/simulator/everest_to_ert.py

+6
Original file line numberDiff line numberDiff line change
@@ -526,5 +526,11 @@ def _get_variables(
526526
input_keys=_get_variables(control.variables),
527527
output_file=control.name + ".json",
528528
)
529+
if control.auto_scale:
530+
ens_config.parameter_configs["rescaled-" + control.name] = ExtParamConfig(
531+
name="rescaled-" + control.name,
532+
input_keys=_get_variables(control.variables),
533+
output_file="rescaled-" + control.name + ".json",
534+
)
529535

530536
return ert_config

tests/everest/test_ropt_initialization.py

-14
Original file line numberDiff line numberDiff line change
@@ -46,20 +46,6 @@ def test_everest2ropt_controls_auto_scale():
4646
assert numpy.allclose(ropt_config.variables.upper_bounds, 0.7)
4747

4848

49-
def test_everest2ropt_variables_auto_scale():
50-
config = EverestConfig.load_file(os.path.join(_CONFIG_DIR, _CONFIG_FILE))
51-
controls = config.controls
52-
controls[0].variables[1].auto_scale = True
53-
controls[0].variables[1].scaled_range = [0.3, 0.7]
54-
ropt_config = everest2ropt(config, transforms=get_ropt_transforms(config))
55-
assert ropt_config.variables.lower_bounds[0] == 0.0
56-
assert ropt_config.variables.upper_bounds[0] == 0.1
57-
assert ropt_config.variables.lower_bounds[1] == 0.3
58-
assert ropt_config.variables.upper_bounds[1] == 0.7
59-
assert numpy.allclose(ropt_config.variables.lower_bounds[2:], 0.0)
60-
assert numpy.allclose(ropt_config.variables.upper_bounds[2:], 0.1)
61-
62-
6349
def test_everest2ropt_controls_input_constraint():
6450
config = EverestConfig.load_file(
6551
os.path.join(_CONFIG_DIR, "config_input_constraints.yml")

0 commit comments

Comments
 (0)