Skip to content

Commit

Permalink
Propogate ens_path to plot_api StorageService.session calls
Browse files Browse the repository at this point in the history
  • Loading branch information
larsevj committed Feb 12, 2025
1 parent 91fe12a commit 9e51947
Show file tree
Hide file tree
Showing 10 changed files with 49 additions and 30 deletions.
9 changes: 7 additions & 2 deletions src/ert/gui/main_window.py
Original file line number Diff line number Diff line change
Expand Up @@ -3,6 +3,7 @@
import datetime
import functools
import webbrowser
from pathlib import Path

from qtpy.QtCore import QCoreApplication, QEvent, QSize, Qt, Signal, Slot
from qtpy.QtGui import QCloseEvent, QCursor, QIcon, QMouseEvent
Expand Down Expand Up @@ -147,7 +148,9 @@ def is_dark_mode(self) -> bool:
def right_clicked(self) -> None:
actor = self.sender()
if actor and actor.property("index") == "Create plot":
pw = PlotWindow(self.config_file, None)
pw = PlotWindow(
self.config_file, Path(self.ert_config.ens_path).absolute(), None
)
pw.show()
self._external_plot_windows.append(pw)

Expand Down Expand Up @@ -180,7 +183,9 @@ def select_central_widget(self) -> None:
if index_name == "Create plot":
if self._plot_window:
self._plot_window.close()
self._plot_window = PlotWindow(self.config_file, self)
self._plot_window = PlotWindow(
self.config_file, Path(self.ert_config.ens_path).absolute(), self
)
self.central_layout.addWidget(self._plot_window)
self.central_panels_map["Create plot"] = self._plot_window

Expand Down
20 changes: 13 additions & 7 deletions src/ert/gui/tools/plot/plot_api.py
Original file line number Diff line number Diff line change
@@ -1,9 +1,11 @@
from __future__ import annotations

import io
import logging
from dataclasses import dataclass
from itertools import combinations as combi
from json.decoder import JSONDecodeError
from typing import Any, NamedTuple
from typing import TYPE_CHECKING, Any, NamedTuple
from urllib.parse import quote

import httpx
Expand All @@ -16,6 +18,9 @@

logger = logging.getLogger(__name__)

if TYPE_CHECKING:
from pathlib import Path


@dataclass(frozen=True, eq=True)
class EnsembleObject:
Expand All @@ -35,7 +40,8 @@ class PlotApiKeyDefinition(NamedTuple):


class PlotApi:
def __init__(self) -> None:
def __init__(self, ens_path: Path) -> None:
self.ens_path = ens_path
self._all_ensembles: list[EnsembleObject] | None = None
self._timeout = 120

Expand All @@ -54,7 +60,7 @@ def get_all_ensembles(self) -> list[EnsembleObject]:
return self._all_ensembles

self._all_ensembles = []
with StorageService.session() as client:
with StorageService.session(project=self.ens_path) as client:
try:
response = client.get("/experiments", timeout=self._timeout)
self._check_response(response)
Expand Down Expand Up @@ -103,7 +109,7 @@ def all_data_type_keys(self) -> list[PlotApiKeyDefinition]:

all_keys: dict[str, PlotApiKeyDefinition] = {}

with StorageService.session() as client:
with StorageService.session(project=self.ens_path) as client:
response = client.get("/experiments", timeout=self._timeout)
self._check_response(response)

Expand Down Expand Up @@ -164,7 +170,7 @@ def data_for_key(self, ensemble_id: str, key: str) -> pd.DataFrame:
if not ensemble:
return pd.DataFrame()

with StorageService.session() as client:
with StorageService.session(project=self.ens_path) as client:
response = client.get(
f"/ensembles/{ensemble.id}/records/{PlotApi.escape(key)}",
headers={"accept": "application/x-parquet"},
Expand Down Expand Up @@ -197,7 +203,7 @@ def observations_for_key(self, ensemble_ids: list[str], key: str) -> pd.DataFram
if not ensemble:
continue

with StorageService.session() as client:
with StorageService.session(project=self.ens_path) as client:
response = client.get(
f"/ensembles/{ensemble.id}/records/{PlotApi.escape(key)}/observations",
timeout=self._timeout,
Expand Down Expand Up @@ -271,7 +277,7 @@ def std_dev_for_parameter(
if not ensemble:
return np.array([])

with StorageService.session() as client:
with StorageService.session(project=self.ens_path) as client:
response = client.get(
f"/ensembles/{ensemble.id}/records/{PlotApi.escape(key)}/std_dev",
params={"z": z},
Expand Down
8 changes: 6 additions & 2 deletions src/ert/gui/tools/plot/plot_window.py
Original file line number Diff line number Diff line change
@@ -1,3 +1,5 @@
from __future__ import annotations

import logging
import time
from typing import TYPE_CHECKING
Expand Down Expand Up @@ -53,6 +55,8 @@
from ert.gui.ertwidgets import CopyButton

if TYPE_CHECKING:
from pathlib import Path

import numpy.typing as npt


Expand Down Expand Up @@ -96,7 +100,7 @@ def open_error_dialog(title: str, content: str) -> None:


class PlotWindow(QMainWindow):
def __init__(self, config_file: str, parent: QWidget | None):
def __init__(self, config_file: str, ens_path: Path, parent: QWidget | None):
QMainWindow.__init__(self, parent)
t = time.perf_counter()

Expand All @@ -108,7 +112,7 @@ def __init__(self, config_file: str, parent: QWidget | None):
self._preferred_ensemble_x_axis_format = PlotContext.INDEX_AXIS
QApplication.setOverrideCursor(Qt.CursorShape.WaitCursor)
try:
self._api = PlotApi()
self._api = PlotApi(ens_path)
self._key_definitions = self._api.all_data_type_keys()
except (RequestError, TimeoutError) as e:
logger.exception(f"plot api request failed: {e}")
Expand Down
5 changes: 2 additions & 3 deletions src/ert/services/_base_service.py
Original file line number Diff line number Diff line change
Expand Up @@ -266,17 +266,16 @@ def start_server(cls: type[T], *args: Any, **kwargs: Any) -> _Context[T]:
def connect(
cls,
*,
project: os.PathLike[str] | None = None,
project: os.PathLike[str],
timeout: int | None = None,
) -> Self:
if cls._instance is not None:
cls._instance.wait_until_ready()
assert isinstance(cls._instance, cls)
return cls._instance

path = Path(project) if project is not None else Path.cwd()
path = Path(project)
name = f"{cls.service_name}_server.json"

# Note: If the caller actually pass None, we override that here...
if timeout is None:
timeout = 240
Expand Down
9 changes: 6 additions & 3 deletions src/ert/services/storage_service.py
Original file line number Diff line number Diff line change
Expand Up @@ -14,6 +14,9 @@
HTTPXClientInstrumentor().instrument()


import os


class StorageService(BaseService):
service_name = "storage"

Expand Down Expand Up @@ -56,7 +59,7 @@ def fetch_auth(self) -> tuple[str, Any]:
@classmethod
def init_service(cls, *args: Any, **kwargs: Any) -> _Context[StorageService]:
try:
service = cls.connect(timeout=0, project=kwargs.get("project"))
service = cls.connect(timeout=0, project=kwargs.get("project", os.getcwd()))
# Check the server is up and running
_ = service.fetch_url()
except TimeoutError:
Expand Down Expand Up @@ -87,11 +90,11 @@ def fetch_url(self) -> str:
)

@classmethod
def session(cls, timeout: int | None = None) -> Client:
def session(cls, project: os.PathLike[str], timeout: int | None = None) -> Client:
"""
Start a HTTP transaction with the server
"""
inst = cls.connect(timeout=timeout)
inst = cls.connect(timeout=timeout, project=project)
return Client(
conn_info=ConnInfo(
base_url=inst.fetch_url(), auth_token=inst.fetch_auth()[1]
Expand Down
2 changes: 1 addition & 1 deletion src/ert/shared/storage/connection.py
Original file line number Diff line number Diff line change
Expand Up @@ -7,7 +7,7 @@


def get_info(
project_id: os.PathLike[str] | None = None,
project_id: os.PathLike[str],
) -> dict[str, str | tuple[str, Any]]:
client = StorageService.connect(project=project_id)
return {
Expand Down
9 changes: 5 additions & 4 deletions tests/ert/performance_tests/test_dark_storage_performance.py
Original file line number Diff line number Diff line change
Expand Up @@ -32,7 +32,7 @@
@pytest.fixture(autouse=True)
def use_testclient(monkeypatch):
client = TestClient(app)
monkeypatch.setattr(StorageService, "session", lambda: client)
monkeypatch.setattr(StorageService, "session", lambda project: client)

def test_escape(s: str) -> str:
"""
Expand Down Expand Up @@ -216,10 +216,11 @@ def test_direct_dark_performance_with_storage(

@pytest.fixture
def api_and_storage(monkeypatch, tmp_path):
with open_storage(tmp_path / "storage", mode="w") as storage:
ens_path = tmp_path / "storage"
with open_storage(ens_path, mode="w") as storage:
monkeypatch.setenv("ERT_STORAGE_NO_TOKEN", "yup")
monkeypatch.setenv("ERT_STORAGE_ENS_PATH", str(storage.path))
api = PlotApi()
api = PlotApi(ens_path)
yield api, storage
if enkf._storage is not None:
enkf._storage.close()
Expand All @@ -233,7 +234,7 @@ def api_and_snake_oil_storage(snake_oil_case_storage, monkeypatch):
monkeypatch.setenv("ERT_STORAGE_NO_TOKEN", "yup")
monkeypatch.setenv("ERT_STORAGE_ENS_PATH", str(storage.path))

api = PlotApi()
api = PlotApi(snake_oil_case_storage.ens_path)
yield api, storage

if enkf._storage is not None:
Expand Down
4 changes: 2 additions & 2 deletions tests/ert/unit_tests/gui/tools/plot/conftest.py
Original file line number Diff line number Diff line change
Expand Up @@ -32,7 +32,7 @@ def is_success(self):
@pytest.fixture
def api(tmpdir, source_root, monkeypatch):
@contextmanager
def session():
def session(project: str):
yield MagicMock(get=mocked_requests_get)

monkeypatch.setattr(StorageService, "session", session)
Expand All @@ -42,7 +42,7 @@ def session():
test_data_dir = os.path.join(test_data_root, "snake_oil")
shutil.copytree(test_data_dir, "test_data")
os.chdir("test_data")
api = PlotApi()
api = PlotApi(test_data_dir)
yield api


Expand Down
7 changes: 4 additions & 3 deletions tests/ert/unit_tests/gui/tools/plot/test_plot_api.py
Original file line number Diff line number Diff line change
Expand Up @@ -23,7 +23,7 @@
@pytest.fixture(autouse=True)
def use_testclient(monkeypatch):
client = TestClient(app)
monkeypatch.setattr(StorageService, "session", lambda: client)
monkeypatch.setattr(StorageService, "session", lambda project: client)

def test_escape(s: str) -> str:
"""
Expand Down Expand Up @@ -184,10 +184,11 @@ def test_plot_api_request_errors(api):

@pytest.fixture
def api_and_storage(monkeypatch, tmp_path):
with open_storage(tmp_path / "storage", mode="w") as storage:
ens_path = tmp_path / "storage"
with open_storage(ens_path, mode="w") as storage:
monkeypatch.setenv("ERT_STORAGE_NO_TOKEN", "yup")
monkeypatch.setenv("ERT_STORAGE_ENS_PATH", str(storage.path))
api = PlotApi()
api = PlotApi(ens_path)
yield api, storage
if enkf._storage is not None:
enkf._storage.close()
Expand Down
6 changes: 3 additions & 3 deletions tests/ert/unit_tests/services/test_base_service.py
Original file line number Diff line number Diff line change
Expand Up @@ -203,9 +203,9 @@ def test_singleton_start(server_script, tmp_path):
os.close(fd)
"""
)
def test_singleton_connect(server_script):
def test_singleton_connect(tmp_path, server_script):
with _DummyService.start_server(exec_args=[str(server_script)]) as server:
client = _DummyService.connect(timeout=30)
client = _DummyService.connect(project=tmp_path, timeout=30)
assert server is client


Expand All @@ -231,7 +231,7 @@ class ClientThread(threading.Thread):
def run(self):
start_event.set()
try:
self.client = _DummyService.connect(timeout=30)
self.client = _DummyService.connect(project=tmp_path, timeout=30)
except Exception as ex:
self.exception = ex
ready_event.set()
Expand Down

0 comments on commit 9e51947

Please sign in to comment.