From 89325d539de2b284d930c334b4883836962be44c Mon Sep 17 00:00:00 2001 From: alicjapolanska Date: Wed, 30 Oct 2024 13:14:39 +0000 Subject: [PATCH] Add sample batching to avoid memory issues. --- harmonic/evidence.py | 38 ++++++++++++++++++++++++++++++++++---- 1 file changed, 34 insertions(+), 4 deletions(-) diff --git a/harmonic/evidence.py b/harmonic/evidence.py index 5ff73f8..825a47f 100644 --- a/harmonic/evidence.py +++ b/harmonic/evidence.py @@ -212,7 +212,7 @@ def get_masks(self, chain_start_ixs: jnp.ndarray) -> jnp.ndarray: return masks_arr - def add_chains(self, chains): + def add_chains(self, chains, num_slices=None): """Add new chains and calculate an estimate of the inverse evidence, its variance, and the variance of the variance. @@ -228,6 +228,10 @@ def add_chains(self, chains): chains (Chains): An instance of the chains class containing the chains to be used in the calculation. + num_slices (int): Number of slices into which the samples are divided row-wise + when using flow models to avoid memory issues. If None, the samples are + considered all-together. Defaults to None. + Raises: ValueError: Raised if the input number of chains to not match the @@ -247,10 +251,36 @@ def add_chains(self, chains): Y = chains.ln_posterior nchains = self.nchains + if not num_slices is None: + if num_slices > X.shape[0]: + raise ValueError( + "Can't split chains into more blocks than there are samples." + ) + if self.batch_calculation: - lnpred = self.model.predict(x=X) - lnargs = lnpred - Y - lnargs = lnargs.at[jnp.isinf(lnargs)].set(jnp.nan) + if num_slices: + # Number of rows in each slice + slice_size = X.shape[0] // num_slices + lnpred_list = [] + + # Calculate lnpred in row-wise slices + for i in range(num_slices): + start_row = i * slice_size + end_row = (i + 1) * slice_size if i < num_slices - 1 else X.shape[0] + X_slice = X[start_row:end_row] + + # Predict for each row slice and append result + lnpred_slice = self.model.predict(x=X_slice) + lnpred_list.append(lnpred_slice) + + # Concatenate all row slice predictions + lnpred = jnp.concatenate(lnpred_list, axis=0) + lnargs = lnpred - Y + lnargs = lnargs.at[jnp.isinf(lnargs)].set(jnp.nan) + else: + lnpred = self.model.predict(x=X) + lnargs = lnpred - Y + lnargs = lnargs.at[jnp.isinf(lnargs)].set(jnp.nan) else: lnpred = np.zeros_like(Y)