Skip to content

Commit

Permalink
Reduce nesting in analysis_ES
Browse files Browse the repository at this point in the history
  • Loading branch information
dafeda committed Dec 16, 2023
1 parent 1c0a627 commit 349704a
Showing 1 changed file with 54 additions and 38 deletions.
92 changes: 54 additions & 38 deletions src/ert/analysis/_es_update.py
Original file line number Diff line number Diff line change
Expand Up @@ -534,7 +534,46 @@ def analysis_ES(
)
truncation = module.enkf_truncation

if module.localization:
# If doing global update, i.e., udpating all parameters using all observations.
if not module.localization:
# Compute transition matrix so that
# X_posterior = X_prior @ (I + T)
T = smoother_es.compute_transition_matrix(
Y=S, alpha=1.0, truncation=truncation
)
# Add identity in place for fast computation
np.fill_diagonal(T, T.diagonal() + 1)

# One parameter group is updated at a time to save memory.
# We call this a "Streaming Algorithm".
for param_group in update_step.parameters:
updated_parameter_groups.append(param_group.name)
source: Union[EnsembleReader, EnsembleAccessor]
if target_fs.has_parameter_group(param_group.name):
source = target_fs
else:
source = source_fs
temp_storage = _create_temporary_parameter_storage(
source, iens_active_index, param_group.name
)

# Use low-level ies API to allow looping over parameters
if active_indices := param_group.index_list:
# The batch of parameters
X_local = temp_storage[param_group.name][active_indices, :]

# Update manually using global transition matrix T
temp_storage[param_group.name][active_indices, :] = X_local @ T

else:
# Update manually using global transition matrix T
temp_storage[param_group.name] @= T

progress_callback(
AnalysisStatusEvent(msg=f"Storing data for {param_group.name}..")
)
_save_temp_storage_to_disk(target_fs, temp_storage, iens_active_index)
else:
batch_size: int = 1000
smoother_adaptive_es = AdaptiveESMDA(
covariance=observation_errors**2,
Expand All @@ -549,26 +588,16 @@ def analysis_ES(
ensemble_size=ensemble_size, alpha=1.0
)

else:
# Compute transition matrix so that
# X_posterior = X_prior @ T
T = smoother_es.compute_transition_matrix(
Y=S, alpha=1.0, truncation=truncation
)
# Add identity in place for fast computation
np.fill_diagonal(T, T.diagonal() + 1)

for param_group in update_step.parameters:
updated_parameter_groups.append(param_group.name)
source: Union[EnsembleReader, EnsembleAccessor]
if target_fs.has_parameter_group(param_group.name):
source = target_fs
else:
source = source_fs
temp_storage = _create_temporary_parameter_storage(
source, iens_active_index, param_group.name
)
if module.localization:
for param_group in update_step.parameters:
updated_parameter_groups.append(param_group.name)
source: Union[EnsembleReader, EnsembleAccessor]
if target_fs.has_parameter_group(param_group.name):
source = target_fs
else:
source = source_fs
temp_storage = _create_temporary_parameter_storage(
source, iens_active_index, param_group.name
)
num_params = temp_storage[param_group.name].shape[0]
batches = _split_by_batchsize(np.arange(0, num_params), batch_size)

Expand All @@ -592,23 +621,10 @@ def analysis_ES(
verbose=False,
)

else:
# Use low-level ies API to allow looping over parameters
if active_indices := param_group.index_list:
# The batch of parameters
X_local = temp_storage[param_group.name][active_indices, :]

# Update manually using global transition matrix T
temp_storage[param_group.name][active_indices, :] = X_local @ T

else:
# Update manually using global transition matrix T
temp_storage[param_group.name] @= T

progress_callback(
AnalysisStatusEvent(msg=f"Storing data for {param_group.name}..")
)
_save_temp_storage_to_disk(target_fs, temp_storage, iens_active_index)
progress_callback(
AnalysisStatusEvent(msg=f"Storing data for {param_group.name}..")
)
_save_temp_storage_to_disk(target_fs, temp_storage, iens_active_index)

# Finally, if some parameter groups have not been updated we need to copy the parameters
# from the parent ensemble.
Expand Down

0 comments on commit 349704a

Please sign in to comment.