Skip to content

Commit

Permalink
Add GEN_KW to heat equation and turn localization on
Browse files Browse the repository at this point in the history
Makes test more comprehensive and realistic.
  • Loading branch information
dafeda committed Jan 6, 2025
1 parent 839610b commit 8e6e54a
Show file tree
Hide file tree
Showing 6 changed files with 77 additions and 29 deletions.
6 changes: 6 additions & 0 deletions test-data/ert/heat_equation/config.ert
Original file line number Diff line number Diff line change
Expand Up @@ -13,13 +13,19 @@ QUEUE_OPTION LOCAL MAX_RUNNING 100

RANDOM_SEED 11223344

ANALYSIS_SET_VAR STD_ENKF LOCALIZATION True
ANALYSIS_SET_VAR STD_ENKF LOCALIZATION_CORRELATION_THRESHOLD 0.1

NUM_REALIZATIONS 100
GRID CASE.EGRID

OBS_CONFIG observations

FIELD COND PARAMETER cond.bgrdecl INIT_FILES:cond.bgrdecl FORWARD_INIT:True

GEN_KW INIT_TEMP_SCALE init_temp_prior.txt
GEN_KW CORR_LENGTH corr_length_prior.txt

GEN_DATA MY_RESPONSE RESULT_FILE:gen_data_%d.out REPORT_STEPS:10,71,132,193,255,316,377,438 INPUT_FORMAT:ASCII

INSTALL_JOB heat_equation HEAT_EQUATION
Expand Down
1 change: 1 addition & 0 deletions test-data/ert/heat_equation/corr_length_prior.txt
Original file line number Diff line number Diff line change
@@ -0,0 +1 @@
NORMAL 0.8 0.1
23 changes: 19 additions & 4 deletions test-data/ert/heat_equation/heat_equation.py
Original file line number Diff line number Diff line change
@@ -1,6 +1,7 @@
#!/usr/bin/env python3
"""Partial Differential Equations to use as forward models."""

import json
import sys

import geostat
Expand Down Expand Up @@ -51,16 +52,28 @@ def heat_equation(
return u_


def sample_prior_conductivity(ensemble_size, nx, rng):
def sample_prior_conductivity(ensemble_size, nx, rng, corr_length):
mesh = np.meshgrid(np.linspace(0, 1, nx), np.linspace(0, 1, nx))
return np.exp(geostat.gaussian_fields(mesh, rng, ensemble_size, r=0.8))
return np.exp(geostat.gaussian_fields(mesh, rng, ensemble_size, r=corr_length))


def load_parameters(filename):
with open(filename, encoding="utf-8") as f:
return json.load(f)


if __name__ == "__main__":
iens = int(sys.argv[1])
iteration = int(sys.argv[2])
rng = np.random.default_rng(iens)
cond = sample_prior_conductivity(ensemble_size=1, nx=nx, rng=rng).reshape(nx, nx)

parameters = load_parameters("parameters.json")
init_temp_scale = parameters["INIT_TEMP_SCALE"]
corr_length = parameters["CORR_LENGTH"]

cond = sample_prior_conductivity(
ensemble_size=1, nx=nx, rng=rng, corr_length=float(corr_length["x"])
).reshape(nx, nx)

if iteration == 0:
resfo.write(
Expand All @@ -78,7 +91,9 @@ def sample_prior_conductivity(ensemble_size, nx, rng):
# Note that this could be avoided if we used an implicit solver.
dt = dx**2 / (4 * max(np.max(cond), np.max(cond)))

response = heat_equation(u_init, cond, dx, dt, k_start, k_end, rng)
scaled_u_init = u_init * float(init_temp_scale["x"])

response = heat_equation(scaled_u_init, cond, dx, dt, k_start, k_end, rng)

index = sorted((obs.x, obs.y) for obs in obs_coordinates)
for time_step in obs_times:
Expand Down
1 change: 1 addition & 0 deletions test-data/ert/heat_equation/init_temp_prior.txt
Original file line number Diff line number Diff line change
@@ -0,0 +1 @@
UNIFORM 0 1
75 changes: 50 additions & 25 deletions tests/ert/ui_tests/cli/test_field_parameter.py
Original file line number Diff line number Diff line change
Expand Up @@ -23,36 +23,61 @@
from .run_cli import run_cli


def test_field_param_update_using_heat_equation(heat_equation_storage):
config = ErtConfig.from_file("config.ert")
with open_storage(config.ens_path, mode="w") as storage:
def test_shared_heat_equation_storage(heat_equation_storage):
"""The fixture heat_equation_storage runs the heat equation test case.
This test verifies that results are as expected.
"""
config = heat_equation_storage
with open_storage(config.ens_path) as storage:
experiment = storage.get_experiment_by_name("es-mda")
prior = experiment.get_ensemble_by_name("default_0")
posterior = experiment.get_ensemble_by_name("default_1")

prior_result = prior.load_parameters("COND")["values"]
ensembles = [experiment.get_ensemble_by_name(f"default_{i}") for i in range(4)]

param_config = config.ensemble_config.parameter_configs["COND"]
assert len(prior_result.x) == param_config.nx
assert len(prior_result.y) == param_config.ny
assert len(prior_result.z) == param_config.nz

posterior_result = posterior.load_parameters("COND")["values"]
prior_covariance = np.cov(
prior_result.values.reshape(
prior.ensemble_size, param_config.nx * param_config.ny * param_config.nz
),
rowvar=False,
)
posterior_covariance = np.cov(
posterior_result.values.reshape(
posterior.ensemble_size,
# Check that generalized variance decreases across consecutive ensembles
covariances = []
for ensemble in ensembles:
results = ensemble.load_parameters("COND")["values"]
reshaped_values = results.values.reshape(
ensemble.ensemble_size,
param_config.nx * param_config.ny * param_config.nz,
),
rowvar=False,
)
# Check that generalized variance is reduced by update step.
assert np.trace(prior_covariance) > np.trace(posterior_covariance)
)
covariances.append(np.cov(reshaped_values, rowvar=False))
for i in range(len(covariances) - 1):
assert np.trace(covariances[i]) > np.trace(
covariances[i + 1]
), f"Generalized variance did not decrease from iteration {i} to {i + 1}"

# Check that the saved cross-correlations are as expected.
for i in range(3):
ensemble = ensembles[i]
corr_XY = ensemble.load_cross_correlations()

assert sorted(corr_XY["param_group"].unique().to_list()) == [
"CORR_LENGTH",
"INIT_TEMP_SCALE",
]
assert corr_XY["param_name"].unique().to_list() == ["x"]

# Make sure correlations are between -1 and 1.
is_valid = (corr_XY["value"] >= -1) & (corr_XY["value"] <= 1)
assert is_valid.all()

# Check obs names and obs groups
expected_obs_groups = [obs[0] for obs in config.observation_config]
obs_groups = corr_XY["obs_group"].unique().to_list()
assert sorted(obs_groups) == sorted(expected_obs_groups)
# Check that obs names are created using obs groups
obs_name_starts_with_group = (
corr_XY.with_columns(
pl.col("obs_name")
.str.starts_with(pl.col("obs_group"))
.alias("starts_with_check")
)
.get_column("starts_with_check")
.all()
)
assert obs_name_starts_with_group

# Check that fields in the runpath are different between iterations
cond_iter0 = resfo.read("simulations/realization-0/iter-0/cond.bgrdecl")[0][1]
Expand Down
Loading
Sorry, something went wrong. Reload?
Sorry, we cannot display this file.
Sorry, this file is invalid so it cannot be displayed.

0 comments on commit 8e6e54a

Please sign in to comment.