Reduce nesting in analysis_ES
dafeda committed Dec 19, 2023
1 parent 603daad commit 603ff8e
Showing 1 changed file with 75 additions and 58 deletions.
133 changes: 75 additions & 58 deletions src/ert/analysis/_es_update.py
@@ -491,6 +491,20 @@ def _update_with_row_scaling(
_save_temp_storage_to_disk(target_fs, temp_storage, iens_active_index)


def _determine_inversion_type(ies_inversion: int, update_step_name: str) -> str:
inversion_types = {0: "exact", 1: "subspace", 2: "subspace", 3: "subspace"}

try:
inversion_type = inversion_types[ies_inversion]
except KeyError as e:
raise ErtAnalysisError(
f"Mismatched inversion type for update step: {update_step_name}. "
f"Specified: {ies_inversion}, with possible: {list(inversion_types.keys())}"
) from e

return inversion_type
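
# For reference (not part of the commit): _determine_inversion_type(0, name)
# returns "exact", values 1-3 return "subspace", and any other value raises
# ErtAnalysisError naming the update step and listing the valid keys.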


def analysis_ES(
updatestep: UpdateConfiguration,
rng: np.random.Generator,
@@ -506,11 +520,14 @@ def analysis_ES(
misfit_process: bool,
) -> None:
iens_active_index = np.flatnonzero(ens_mask)

ensemble_size = ens_mask.sum()
updated_parameter_groups = []

for update_step in updatestep:
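        # Record every parameter group touched by this update step so that
        # untouched groups can later be copied over from the prior ensemble.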
updated_parameter_groups.extend(
[param_group.name for param_group in update_step.parameters]
)

progress_callback(
AnalysisStatusEvent(msg="Loading observations and responses..")
)
@@ -533,6 +550,7 @@
)
except IndexError as e:
raise ErtAnalysisError(e) from e

smoother_snapshot.update_step_snapshots[update_step.name] = update_snapshot

num_obs = len(observation_values)
@@ -541,25 +559,52 @@
f"No active observations for update step: {update_step.name}."
)

inversion_types = {0: "exact", 1: "subspace", 2: "subspace", 3: "subspace"}
try:
inversion_type = inversion_types[module.ies_inversion]
except KeyError as e:
raise ErtAnalysisError(
f"Mismatched inversion type for: "
f"Specified: {module.ies_inversion}, with possible: {inversion_types}"
) from e

smoother_es = ies.ESMDA(
covariance=observation_errors**2,
observations=observation_values,
alpha=1, # The user is responsible for scaling observation covariance (esmda usage)
seed=rng,
inversion=inversion_type,
inversion_type = _determine_inversion_type(
module.ies_inversion, update_step.name
)

truncation = module.enkf_truncation

if module.localization:
        # If doing a global update, i.e., updating all parameters using all observations.
if not module.localization:
smoother_es = ies.ESMDA(
covariance=observation_errors**2,
observations=observation_values,
alpha=1, # The user is responsible for scaling observation covariance (esmda usage)
seed=rng,
inversion=inversion_type,
)
# Compute transition matrix so that
# X_posterior = X_prior @ (I + T)
T = smoother_es.compute_transition_matrix(
Y=S, alpha=1.0, truncation=truncation
)
# Add identity in place for fast computation
np.fill_diagonal(T, T.diagonal() + 1)
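            # T now holds (I + T), so a single X @ T below gives the posterior
            # without allocating a separate identity matrix.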

# One parameter group is updated at a time to save memory.
# We call this a "Streaming Algorithm".
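            # temp_storage holds only the current group's matrix; it is written
            # to disk and replaced on the next iteration.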
for param_group in update_step.parameters:
source: Union[EnsembleReader, EnsembleAccessor]
if target_fs.has_parameter_group(param_group.name):
source = target_fs
else:
source = source_fs
temp_storage = _create_temporary_parameter_storage(
source, iens_active_index, param_group.name
)

# Update manually using global transition matrix T
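                # If the update step defines an index list for this group, only
                # those parameter rows are updated; otherwise the whole group is.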
if active_indices := param_group.index_list:
temp_storage[param_group.name][active_indices, :] @= T
else:
temp_storage[param_group.name] @= T

progress_callback(
AnalysisStatusEvent(msg=f"Storing data for {param_group.name}..")
)
_save_temp_storage_to_disk(target_fs, temp_storage, iens_active_index)
else: # Adaptive Localization
smoother_adaptive_es = AdaptiveESMDA(
covariance=observation_errors**2,
observations=observation_values,
@@ -573,26 +618,15 @@
ensemble_size=ensemble_size, alpha=1.0
)

else:
# Compute transition matrix so that
# X_posterior = X_prior @ T
T = smoother_es.compute_transition_matrix(
Y=S, alpha=1.0, truncation=truncation
)
# Add identity in place for fast computation
np.fill_diagonal(T, T.diagonal() + 1)

for param_group in update_step.parameters:
updated_parameter_groups.append(param_group.name)
source: Union[EnsembleReader, EnsembleAccessor]
if target_fs.has_parameter_group(param_group.name):
source = target_fs
else:
source = source_fs
temp_storage = _create_temporary_parameter_storage(
source, iens_active_index, param_group.name
)
if module.localization:
for param_group in update_step.parameters:
source: Union[EnsembleReader, EnsembleAccessor]
if target_fs.has_parameter_group(param_group.name):
source = target_fs
else:
source = source_fs
temp_storage = _create_temporary_parameter_storage(
source, iens_active_index, param_group.name
)
num_params = temp_storage[param_group.name].shape[0]

# Calculate adaptive batch size.
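                # Batching bounds memory use when a parameter group is large.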
@@ -646,27 +680,10 @@
f"Adaptive Localization of {param_group} completed in {(time.time() - start) / 60} minutes"
)

else:
# Use low-level ies API to allow looping over parameters
if active_indices := param_group.index_list:
# The batch of parameters
X_local = temp_storage[param_group.name][active_indices, :]

# Update manually using global transition matrix T
temp_storage[param_group.name][active_indices, :] = X_local @ T

else:
# Update manually using global transition matrix T
temp_storage[param_group.name] @= T

log_msg = f"Storing data for {param_group.name}.."
_logger.info(log_msg)
progress_callback(AnalysisStatusEvent(msg=log_msg))
start = time.time()
_save_temp_storage_to_disk(target_fs, temp_storage, iens_active_index)
_logger.info(
f"Storing data for {param_group.name} completed in {(time.time() - start) / 60} minutes"
)
progress_callback(
AnalysisStatusEvent(msg=f"Storing data for {param_group.name}..")
)
_save_temp_storage_to_disk(target_fs, temp_storage, iens_active_index)

    # Finally, if some parameter groups have not been updated, we need to copy the parameters
# from the parent ensemble.
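
For context, the global (non-localized) path above amounts to computing one transition matrix from the responses and applying it to each parameter matrix. Below is a minimal, self-contained sketch with synthetic data; it is illustrative only, assumes the iterative_ensemble_smoother package (imported as ies, as in this file), and uses an arbitrary truncation of 0.98 in place of module.enkf_truncation.

import numpy as np
import iterative_ensemble_smoother as ies

rng = np.random.default_rng(42)
num_obs, ensemble_size, num_params = 5, 10, 3

S = rng.standard_normal((num_obs, ensemble_size))           # simulated responses
observation_values = rng.standard_normal(num_obs)           # observed values
observation_errors = np.ones(num_obs)                       # observation std. deviations
X_prior = rng.standard_normal((num_params, ensemble_size))  # one parameter group

smoother_es = ies.ESMDA(
    covariance=observation_errors**2,
    observations=observation_values,
    alpha=1,
    seed=rng,
    inversion="exact",
)
T = smoother_es.compute_transition_matrix(Y=S, alpha=1.0, truncation=0.98)
np.fill_diagonal(T, T.diagonal() + 1)  # T now represents (I + T)
X_posterior = X_prior @ T              # same shape as X_prior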
