Fix off-by-one error in split_by_batchsize
If the batch size is equal to the number of parameters,
we want _split_by_batchsize to return a single batch and
not two.

Update batch_size calculation to make sure the new
_split_by_batchsize works when the number of parameters
is less than the hard-coded batch_size of 1000.
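The off-by-one described above can be reproduced with plain NumPy (a standalone sketch, not the ert code itself): with 10 parameters and a batch size of 10, the old expression asks `np.array_split` for two sections, while the new one asks for exactly one.

```python
import numpy as np

arr = np.arange(10)
batch_size = 10

# Old expression: int(10 / 10) + 1 == 2 sections -> two batches of 5
old = np.array_split(arr, int(arr.shape[0] / batch_size) + 1)

# New expression: 10 // 10 == 1 section -> the single batch we want
new = np.array_split(arr, len(arr) // batch_size)

print([len(b) for b in old])  # [5, 5]
print([len(b) for b in new])  # [10]
```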

dafeda committed Dec 18, 2023
1 parent e4230ff commit f6b0b50
Showing 1 changed file with 19 additions and 2 deletions.
src/ert/analysis/_es_update.py
@@ -421,7 +421,24 @@ def _load_observations_and_responses(
 def _split_by_batchsize(
     arr: npt.NDArray[np.int_], batch_size: int
 ) -> List[npt.NDArray[np.int_]]:
-    return np.array_split(arr, int((arr.shape[0] / batch_size)) + 1)
+    """
+    Splits an array into sub-arrays of a specified batch size.
+
+    Examples
+    --------
+    >>> num_params = 10
+    >>> batch_size = 3
+    >>> s = np.arange(0, num_params)
+    >>> _split_by_batchsize(s, batch_size)
+    [array([0, 1, 2, 3]), array([4, 5, 6]), array([7, 8, 9])]
+
+    >>> num_params = 10
+    >>> batch_size = 10
+    >>> s = np.arange(0, num_params)
+    >>> _split_by_batchsize(s, batch_size)
+    [array([0, 1, 2, 3, 4, 5, 6, 7, 8, 9])]
+    """
+    return np.array_split(arr, len(arr) // batch_size)
 
 
 def _update_with_row_scaling(
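A caveat with the new `len(arr) // batch_size` expression, and the reason the commit also caps `batch_size` in `analysis_ES` below: `np.array_split` rejects zero sections, which is what integer division yields once `batch_size` exceeds the array length. A small standalone sketch of the failure and the guard:

```python
import numpy as np

arr = np.arange(5)  # fewer parameters than the hard-coded 1000

try:
    np.array_split(arr, len(arr) // 1000)  # 5 // 1000 == 0 sections
    raised = False
except ValueError:
    raised = True  # np.array_split raises on zero sections

# The guard added in analysis_ES avoids this:
batch_size = min(1000, len(arr))
batches = np.array_split(arr, len(arr) // batch_size)
print(raised, [b.tolist() for b in batches])  # True [[0, 1, 2, 3, 4]]
```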
@@ -535,7 +552,6 @@ def analysis_ES(
     truncation = module.enkf_truncation
 
     if module.localization:
-        batch_size: int = 1000
         smoother_adaptive_es = AdaptiveESMDA(
             covariance=observation_errors**2,
             observations=observation_values,
@@ -570,6 +586,7 @@
         )
         if module.localization:
             num_params = temp_storage[param_group.name].shape[0]
+            batch_size = min(1000, num_params)
             batches = _split_by_batchsize(np.arange(0, num_params), batch_size)
 
             progress_callback(
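How the capped `batch_size` plays out for parameter counts on either side of the cap (hypothetical counts, mirroring the patched helper rather than importing ert):

```python
import numpy as np

def split_by_batchsize(arr, batch_size):
    # mirrors the patched _split_by_batchsize
    return np.array_split(arr, len(arr) // batch_size)

for num_params in (250, 2500):
    batch_size = min(1000, num_params)
    batches = split_by_batchsize(np.arange(num_params), batch_size)
    print(num_params, [len(b) for b in batches])
```

Note that `np.array_split` balances the sections, so an individual batch can exceed `batch_size` (2500 parameters yield two batches of 1250 each); the expression bounds the number of batches rather than the size of each one.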
