Skip to content

Commit

Permalink
Fix off-by-one error in _split_by_batchsize
Browse files Browse the repository at this point in the history
If the batch size is equal to the number of parameters,
we want _split_by_batchsize to return a single batch and
not two.

Update batch_size calculation to make sure the new
_split_by_batchsize works when the number of parameters
is less than the hard-coded batch_size of 1000.
  • Loading branch information
dafeda committed Dec 18, 2023
1 parent e4230ff commit 30f53af
Showing 1 changed file with 26 additions and 2 deletions.
28 changes: 26 additions & 2 deletions src/ert/analysis/_es_update.py
Original file line number Diff line number Diff line change
Expand Up @@ -421,7 +421,31 @@ def _load_observations_and_responses(
def _split_by_batchsize(
arr: npt.NDArray[np.int_], batch_size: int
) -> List[npt.NDArray[np.int_]]:
return np.array_split(arr, int((arr.shape[0] / batch_size)) + 1)
"""
Splits an array into sub-arrays of a specified batch size.
Examples
--------
>>> num_params = 10
>>> batch_size = 3
>>> s = np.arange(0, num_params)
>>> _split_by_batchsize(s, batch_size)
[array([0, 1, 2, 3]), array([4, 5, 6]), array([7, 8, 9])]
>>> num_params = 10
>>> batch_size = 10
>>> s = np.arange(0, num_params)
>>> _split_by_batchsize(s, batch_size)
[array([0, 1, 2, 3, 4, 5, 6, 7, 8, 9])]
>>> num_params = 10
>>> batch_size = 20
>>> s = np.arange(0, num_params)
>>> _split_by_batchsize(s, batch_size)
[array([0, 1, 2, 3, 4, 5, 6, 7, 8, 9])]
"""
sections = 1 if batch_size > len(arr) else len(arr) // batch_size
return np.array_split(arr, sections)


def _update_with_row_scaling(
Expand Down Expand Up @@ -535,7 +559,6 @@ def analysis_ES(
truncation = module.enkf_truncation

if module.localization:
batch_size: int = 1000
smoother_adaptive_es = AdaptiveESMDA(
covariance=observation_errors**2,
observations=observation_values,
Expand Down Expand Up @@ -570,6 +593,7 @@ def analysis_ES(
)
if module.localization:
num_params = temp_storage[param_group.name].shape[0]
batch_size = min(1000, num_params)
batches = _split_by_batchsize(np.arange(0, num_params), batch_size)

progress_callback(
Expand Down

0 comments on commit 30f53af

Please sign in to comment.