Skip to content

Commit d3fda4e

Browse files
committed
Insert names in everest results using polars
1 parent fac55ca commit d3fda4e

File tree

1 file changed

+75
-31
lines changed

1 file changed

+75
-31
lines changed

src/everest/everest_storage.py

+75-31
Original file line numberDiff line numberDiff line change
@@ -408,24 +408,12 @@ def init(self, everest_config: EverestConfig) -> None:
408408
)
409409

410410
def _store_function_results(self, results: FunctionResults) -> _EvaluationResults:
411-
names = {
412-
"variable": self.data.controls["control_name"],
413-
"objective": self.data.objective_functions["objective_name"],
414-
"nonlinear_constraint": (
415-
self.data.nonlinear_constraints["constraint_name"]
416-
if self.data.nonlinear_constraints is not None
417-
else None
418-
),
419-
"realization": self.data.realization_weights["realization"],
420-
}
421-
422411
# We could select only objective values,
423412
# but we select all to also get the constraint values (if they exist)
424413
realization_objectives = polars.from_pandas(
425414
results.to_dataframe(
426415
"evaluations",
427416
select=["objectives", "evaluation_ids"],
428-
names=names,
429417
).reset_index(),
430418
).select(
431419
"batch_id",
@@ -434,13 +422,20 @@ def _store_function_results(self, results: FunctionResults) -> _EvaluationResult
434422
"objectives",
435423
"evaluation_ids",
436424
)
425+
realization_objectives = realization_objectives.with_columns(
426+
polars.col("realization").replace_strict(
427+
dict(enumerate(self.data.realization_weights["realization"]))
428+
),
429+
polars.col("objective").replace_strict(
430+
dict(enumerate(self.data.objective_functions["objective_name"]))
431+
),
432+
)
437433

438434
if results.functions is not None and results.functions.constraints is not None:
439435
realization_constraints = polars.from_pandas(
440436
results.to_dataframe(
441437
"evaluations",
442438
select=["constraints", "evaluation_ids"],
443-
names=names,
444439
).reset_index(),
445440
).select(
446441
"batch_id",
@@ -449,16 +444,27 @@ def _store_function_results(self, results: FunctionResults) -> _EvaluationResult
449444
"nonlinear_constraint",
450445
"constraints",
451446
)
447+
realization_constraints = realization_constraints.with_columns(
448+
polars.col("realization").replace_strict(
449+
dict(enumerate(self.data.realization_weights["realization"]))
450+
),
451+
polars.col("nonlinear_constraint").replace_strict(
452+
dict(enumerate(self.data.nonlinear_constraints["constraint_name"]))
453+
),
454+
)
452455

453456
realization_constraints = self._rename_ropt_df_columns(
454457
realization_constraints
455458
)
456459

457460
batch_constraints = polars.from_pandas(
458-
results.to_dataframe(
459-
"functions", select=["constraints"], names=names
460-
).reset_index()
461+
results.to_dataframe("functions", select=["constraints"]).reset_index()
461462
).select("batch_id", "nonlinear_constraint", "constraints")
463+
batch_constraints = batch_constraints.with_columns(
464+
polars.col("nonlinear_constraint").replace_strict(
465+
dict(enumerate(self.data.nonlinear_constraints["constraint_name"]))
466+
),
467+
)
462468

463469
batch_constraints = batch_constraints.rename(
464470
{
@@ -485,13 +491,17 @@ def _store_function_results(self, results: FunctionResults) -> _EvaluationResult
485491
results.to_dataframe(
486492
"functions",
487493
select=["objectives", "weighted_objective"],
488-
names=names,
489494
).reset_index()
490495
).select("batch_id", "objective", "objectives", "weighted_objective")
496+
batch_objectives = batch_objectives.with_columns(
497+
polars.col("objective").replace_strict(
498+
dict(enumerate(self.data.objective_functions["objective_name"]))
499+
),
500+
)
491501

492502
realization_controls = polars.from_pandas(
493503
results.to_dataframe(
494-
"evaluations", select=["variables", "evaluation_ids"], names=names
504+
"evaluations", select=["variables", "evaluation_ids"]
495505
).reset_index()
496506
).select(
497507
"batch_id",
@@ -500,6 +510,14 @@ def _store_function_results(self, results: FunctionResults) -> _EvaluationResult
500510
"variables",
501511
"evaluation_ids",
502512
)
513+
realization_controls = realization_controls.with_columns(
514+
polars.col("variable").replace_strict(
515+
dict(enumerate(self.data.controls["control_name"]))
516+
),
517+
polars.col("realization").replace_strict(
518+
dict(enumerate(self.data.realization_weights["realization"]))
519+
),
520+
)
503521

504522
realization_controls = self._rename_ropt_df_columns(realization_controls)
505523
realization_controls = self._enforce_dtypes(realization_controls)
@@ -541,19 +559,8 @@ def _store_function_results(self, results: FunctionResults) -> _EvaluationResult
541559
}
542560

543561
def _store_gradient_results(self, results: GradientResults) -> _GradientResults:
544-
names = {
545-
"variable": self.data.controls["control_name"],
546-
"objective": self.data.objective_functions["objective_name"],
547-
"nonlinear_constraint": (
548-
self.data.nonlinear_constraints["constraint_name"]
549-
if self.data.nonlinear_constraints is not None
550-
else None
551-
),
552-
"realization": self.data.realization_weights["realization"],
553-
}
554-
555562
perturbation_objectives = polars.from_pandas(
556-
results.to_dataframe("evaluations", names=names).reset_index()
563+
results.to_dataframe("evaluations").reset_index()
557564
).select(
558565
[
559566
"batch_id",
@@ -572,12 +579,30 @@ def _store_gradient_results(self, results: GradientResults) -> _GradientResults:
572579
),
573580
]
574581
)
582+
perturbation_objectives = perturbation_objectives.with_columns(
583+
polars.col("variable").replace_strict(
584+
dict(enumerate(self.data.controls["control_name"]))
585+
),
586+
polars.col("realization").replace_strict(
587+
dict(enumerate(self.data.realization_weights["realization"]))
588+
),
589+
polars.col("objective").replace_strict(
590+
dict(enumerate(self.data.objective_functions["objective_name"]))
591+
),
592+
)
593+
if results.evaluations.perturbed_constraints is not None:
594+
perturbation_objectives = perturbation_objectives.with_columns(
595+
polars.col("nonlinear_constraint").replace_strict(
596+
dict(enumerate(self.data.nonlinear_constraints["constraint_name"]))
597+
),
598+
)
599+
575600
perturbation_objectives = self._rename_ropt_df_columns(perturbation_objectives)
576601
perturbation_objectives = self._enforce_dtypes(perturbation_objectives)
577602

578603
if results.gradients is not None:
579604
batch_objective_gradient = polars.from_pandas(
580-
results.to_dataframe("gradients", names=names).reset_index()
605+
results.to_dataframe("gradients").reset_index()
581606
).select(
582607
[
583608
"batch_id",
@@ -592,6 +617,25 @@ def _store_gradient_results(self, results: GradientResults) -> _GradientResults:
592617
),
593618
]
594619
)
620+
batch_objective_gradient = batch_objective_gradient.with_columns(
621+
polars.col("variable").replace_strict(
622+
dict(enumerate(self.data.controls["control_name"]))
623+
),
624+
polars.col("objective").replace_strict(
625+
dict(enumerate(self.data.objective_functions["objective_name"]))
626+
),
627+
)
628+
if results.gradients.constraints is not None:
629+
batch_objective_gradient = batch_objective_gradient.with_columns(
630+
polars.col("nonlinear_constraint").replace_strict(
631+
dict(
632+
enumerate(
633+
self.data.nonlinear_constraints["constraint_name"]
634+
)
635+
)
636+
),
637+
)
638+
595639
batch_objective_gradient = self._rename_ropt_df_columns(
596640
batch_objective_gradient
597641
)

0 commit comments

Comments
 (0)