Skip to content

Commit

Permalink
Add style and fix last tests
Browse files Browse the repository at this point in the history
  • Loading branch information
frode-aarstad committed Jan 4, 2024
1 parent 126b5b7 commit 28d63c2
Show file tree
Hide file tree
Showing 6 changed files with 32 additions and 22 deletions.
4 changes: 3 additions & 1 deletion src/ert/analysis/_es_update.py
Original file line number Diff line number Diff line change
Expand Up @@ -316,7 +316,9 @@ def _get_obs_and_measure_data(
ds = source_fs.load_responses(group, tuple(iens_active_index))

if "time" in observation.coords:
observation.coords["time"]= [t[:-3] for t in observation.coords["time"].values.astype(str)]
observation.coords["time"] = [
t[:-3] for t in observation.coords["time"].values.astype(str)
]

try:
filtered_ds = observation.merge(ds, join="left")
Expand Down
2 changes: 1 addition & 1 deletion src/ert/config/summary_config.py
Original file line number Diff line number Diff line change
Expand Up @@ -62,7 +62,7 @@ def read_from_file(self, run_path: str, iens: int) -> xr.Dataset:
summary_data.sort(key=lambda x: x[0])
data = [d for _, d in summary_data]
keys = [k for k, _ in summary_data]
time_map = [datetime.isoformat(t, timespec="microseconds") for t in time_map]
time_map = [datetime.isoformat(t, timespec="microseconds") for t in time_map]
ds = xr.Dataset(
{"values": (["name", "time"], data)},
coords={"time": time_map, "name": keys},
Expand Down
27 changes: 16 additions & 11 deletions src/ert/data/_measured_data.py
Original file line number Diff line number Diff line change
Expand Up @@ -117,9 +117,11 @@ def _get_data(
raise ResponseError(_msg)
except KeyError as e:
raise ResponseError(_msg) from e

if "time" in obs.coords:
obs.coords["time"]= [t[:-3] for t in obs.coords["time"].values.astype(str)]
obs.coords["time"] = [
t[:-3] for t in obs.coords["time"].values.astype(str)
]

ds = obs.merge(
response,
Expand All @@ -137,12 +139,10 @@ def _get_data(
ds = ds.rename(time="key_index")
ds = ds.assign_coords({"name": [key]})

new_index = pd.DatetimeIndex(response.indexes["time"].values.astype('datetime64[ns]'))
data_index = [
new_index.get_loc(date) for date in obs.time.values
]
#data_index = [response.indexes["time"].get_loc(date) for date in obs.time.values ]

new_index = pd.DatetimeIndex(
response.indexes["time"].values.astype("datetime64[ns]")
)
data_index = [new_index.get_loc(date) for date in obs.time.values]
index_vals = ds.observations.coords.to_index(
["name", "key_index"]
).values
Expand Down Expand Up @@ -212,9 +212,14 @@ def _create_condition(
conditions = []
for obs_key, index_list in zip(obs_keys, index_lists):
if index_list is not None:
if isinstance(index_list[0], datetime):
index_list= [datetime.isoformat(t, timespec="microseconds") for t in index_list]
index_cond = [data_index == index for index in index_list]
if all(isinstance(e, datetime) for e in index_list):
index_list_str = [
datetime.isoformat(t, timespec="microseconds") # type: ignore[arg-type]
for t in index_list
]
index_cond = [data_index == index for index in index_list_str]
else:
index_cond = [data_index == index for index in index_list]
index_cond = np.logical_or.reduce(index_cond)
conditions.append(np.logical_and(index_cond, (names == obs_key)))
return np.logical_or.reduce(conditions)
Expand Down
5 changes: 4 additions & 1 deletion src/ert/storage/local_ensemble.py
Original file line number Diff line number Diff line change
Expand Up @@ -289,9 +289,12 @@ def load_all_summary_data(
try:
df = self.load_responses(
"summary", tuple(self._filter_active_realizations(realization_index))
).to_dataframe()
).to_dataframe(["time", "name", "realization"])
except (ValueError, KeyError):
return pd.DataFrame()

# Remove the time part of the 'time' index
df.index = df.index.set_levels([t[:-16] for t in df.index.levels[0]], level=0)
df = df.unstack(level="name")
df.columns = [col[1] for col in df.columns.values]
df.index = df.index.rename(
Expand Down
7 changes: 4 additions & 3 deletions tests/unit_tests/data/test_integration_data.py
Original file line number Diff line number Diff line change
Expand Up @@ -48,9 +48,10 @@ def test_summary_obs(create_measured_data):
summary_obs.remove_inactive_observations()
assert all(summary_obs.data.columns.get_level_values("data_index").values == [71])
# Only one observation, we check the key_index is what we expect:
assert summary_obs.data.columns.get_level_values("key_index").values[
0
] == "2011-12-21T00:00:00.000000"
assert (
summary_obs.data.columns.get_level_values("key_index").values[0]
== "2011-12-21T00:00:00.000000"
)


@pytest.mark.filterwarnings("ignore::ert.config.ConfigWarning")
Expand Down
9 changes: 4 additions & 5 deletions tests/unit_tests/test_load_forward_model.py
Original file line number Diff line number Diff line change
@@ -1,5 +1,3 @@

import xarray as xr
import fileinput
import logging
import os
Expand All @@ -9,6 +7,7 @@

import numpy as np
import pytest
import xarray as xr
from resdata.summary import Summary

from ert.config import ErtConfig
Expand Down Expand Up @@ -142,9 +141,9 @@ def test_datetime_2500():
realizations = [False] * facade.get_ensemble_size()
realizations[realisation_number] = True
facade.load_from_forward_model(ensemble, realizations, 0)
dataset= ensemble.load_responses("summary", tuple([0]))
assert dataset.coords["time"].data.dtype == np.dtype('object')

dataset = ensemble.load_responses("summary", tuple([0]))
assert dataset.coords["time"].data.dtype == np.dtype("object")


@pytest.mark.usefixtures("copy_snake_oil_case_storage")
Expand Down

0 comments on commit 28d63c2

Please sign in to comment.