Skip to content

Commit

Permalink
Pass ropt events to status queue
Browse files Browse the repository at this point in the history
  • Loading branch information
yngve-sk committed Feb 7, 2025
1 parent 856cc02 commit 58d95fb
Show file tree
Hide file tree
Showing 2 changed files with 68 additions and 1 deletion.
19 changes: 18 additions & 1 deletion src/ert/run_models/event.py
Original file line number Diff line number Diff line change
@@ -1,7 +1,7 @@
from __future__ import annotations

from pathlib import Path
from typing import Annotated, Literal
from typing import Annotated, Any, Literal
from uuid import UUID

from pydantic import BaseModel, ConfigDict, Field, TypeAdapter
Expand Down Expand Up @@ -29,6 +29,23 @@ class RunModelStatusEvent(RunModelEvent):
msg: str


class EverestStatusEvent(BaseModel):
iteration: int # Should maybe be batch? But for now we denote batch as iteration
event_type: Literal["EverestStatusEvent"]

# Reflects what is currently in ROPT,
# changes in ROPT should appear here accordingly
everest_event: Literal[
"START_EVALUATION",
"FINISHED_EVALUATION",
"START_OPTIMIZER_STEP",
"FINISHED_OPTIMIZER_STEP",
"START_EVALUATOR_STEP",
"FINISHED_EVALUATOR_STEP",
]
data_json: dict[str, Any]


class RunModelTimeEvent(RunModelEvent):
event_type: Literal["RunModelTimeEvent"] = "RunModelTimeEvent"
remaining_time: float
Expand Down
50 changes: 50 additions & 0 deletions src/ert/run_models/everest_run_model.py
Original file line number Diff line number Diff line change
Expand Up @@ -21,6 +21,7 @@
from ropt.evaluator import EvaluatorContext, EvaluatorResult
from ropt.plan import BasicOptimizer
from ropt.plan import Event as OptimizerEvent
from ropt.results import FunctionResults, GradientResults
from ropt.transforms import OptModelTransforms
from typing_extensions import TypedDict

Expand All @@ -42,6 +43,7 @@

from ..run_arg import RunArg, create_run_arguments
from .base_run_model import BaseRunModel, StatusEvents
from .event import EverestStatusEvent

if TYPE_CHECKING:
from ert.storage import Ensemble, Experiment
Expand Down Expand Up @@ -257,6 +259,54 @@ def _create_optimizer(self) -> BasicOptimizer:
),
)

def _forward_ropt_event(everest_event: OptimizerEvent):

Check failure on line 262 in src/ert/run_models/everest_run_model.py

View workflow job for this annotation

GitHub Actions / type-checking (3.12)

Function is missing a return type annotation
has_results = bool(everest_event.data and everest_event.data.get("results"))

# The batch these results pertain to
# If the event has results, they usually pertain to the
# batch before self._batch_id, i.e., self._batch_id - 1
batch_id = (
everest_event.data["results"][0].batch_id
if has_results
else self._batch_id
)

data_json = {}
if has_results:
results = everest_event.data["results"]
exit_code = everest_event.data["exit_code"]

result_types = [
"FunctionResults"
if isinstance(r, FunctionResults)
else "GradientResults"
for r in results
]

data_json = {
"exit_code": exit_code,
"results": {},
}

if any(isinstance(r, GradientResults) for r in result_types):
data_json["results"]["gradient"] = True

if any(isinstance(r, FunctionResults) for r in result_types):
data_json["results"]["gradient"] = True

self.send_event(
EverestStatusEvent(

Check failure on line 298 in src/ert/run_models/everest_run_model.py

View workflow job for this annotation

GitHub Actions / type-checking (3.12)

Argument 1 to "send_event" of "BaseRunModel" has incompatible type "EverestStatusEvent"; expected "AnalysisStatusEvent | AnalysisTimeEvent | EndEvent | FullSnapshotEvent | SnapshotUpdateEvent | <6 more items>"
event_type="EverestStatusEvent",
iteration=batch_id,
everest_event=everest_event.event_type.name,
data_json=data_json,
)
)

# Forward ROPT events to queue
for event_type in EventType:
optimizer.add_observer(event_type, _forward_ropt_event)

return optimizer

def _on_before_forward_model_evaluation(
Expand Down

0 comments on commit 58d95fb

Please sign in to comment.