@@ -326,9 +326,12 @@ def _init_batch_data(
326
326
control_values : NDArray [np .float64 ],
327
327
evaluator_context : EvaluatorContext ,
328
328
cached_results : dict [int , Any ],
329
+ prefix : str = "" ,
329
330
) -> dict [int , dict [str , Any ]]:
330
331
def _add_controls (
331
- controls_config : list [ControlConfig ], values : NDArray [np .float64 ]
332
+ controls_config : list [ControlConfig ],
333
+ values : NDArray [np .float64 ],
334
+ prefix : str = "" ,
332
335
) -> dict [str , Any ]:
333
336
batch_data_item : dict [str , Any ] = {}
334
337
value_list = values .tolist ()
@@ -344,13 +347,27 @@ def _add_controls(
344
347
else :
345
348
variable_value = value_list .pop (0 )
346
349
control_dict [variable .name ] = variable_value
347
- batch_data_item [control .name ] = control_dict
350
+ batch_data_item [prefix + control .name ] = control_dict
351
+ return batch_data_item
352
+
353
+ def _add_controls_with_rescaling (
354
+ controls_config : list [ControlConfig ], values : NDArray [np .float64 ]
355
+ ) -> dict [str , Any ]:
356
+ batch_data_item = _add_controls (controls_config , values )
357
+ if self ._ropt_transforms .variables is not None :
358
+ rescaled_values = self ._ropt_transforms .variables .backward (values )
359
+ rescaled_item = _add_controls (
360
+ controls_config , rescaled_values , prefix = "rescaled-"
361
+ )
362
+ batch_data_item .update (rescaled_item )
348
363
return batch_data_item
349
364
350
365
active = evaluator_context .active
351
366
realizations = evaluator_context .realizations
352
367
return {
353
- idx : _add_controls (self ._everest_config .controls , control_values [idx , :])
368
+ idx : _add_controls_with_rescaling (
369
+ self ._everest_config .controls , control_values [idx , :]
370
+ )
354
371
for idx in range (control_values .shape [0 ])
355
372
if (
356
373
idx not in cached_results
@@ -392,7 +409,12 @@ def _check_suffix(
392
409
f"Key { key } has suffixes, a suffix must be specified"
393
410
)
394
411
395
- if set (controls .keys ()) != set (self ._everest_config .control_names ):
412
+ control_names = set (self ._everest_config .control_names )
413
+ if any (control .auto_scale for control in self ._everest_config .controls ):
414
+ control_names |= {
415
+ "rescaled-" + name for name in self ._everest_config .control_names
416
+ }
417
+ if set (controls .keys ()) != control_names :
396
418
err_msg = "Mismatch between initialized and provided control names."
397
419
raise KeyError (err_msg )
398
420
0 commit comments