Skip to content

Commit

Permalink
Add sample batching to avoid memory issues.
Browse files Browse the repository at this point in the history
  • Loading branch information
alicjapolanska committed Oct 30, 2024
1 parent e924067 commit 89325d5
Showing 1 changed file with 34 additions and 4 deletions.
38 changes: 34 additions & 4 deletions harmonic/evidence.py
Original file line number Diff line number Diff line change
Expand Up @@ -212,7 +212,7 @@ def get_masks(self, chain_start_ixs: jnp.ndarray) -> jnp.ndarray:

return masks_arr

def add_chains(self, chains):
def add_chains(self, chains, num_slices=None):
"""Add new chains and calculate an estimate of the inverse evidence, its
variance, and the variance of the variance.
Expand All @@ -228,6 +228,10 @@ def add_chains(self, chains):
chains (Chains): An instance of the chains class containing the chains to
be used in the calculation.
num_slices (int): Number of slices into which the samples are divided row-wise
when using flow models to avoid memory issues. If None, the samples are
considered all-together. Defaults to None.
Raises:
ValueError: Raised if the input number of chains to not match the
Expand All @@ -247,10 +251,36 @@ def add_chains(self, chains):
Y = chains.ln_posterior
nchains = self.nchains

if not num_slices is None:
if num_slices > X.shape[0]:
raise ValueError(
"Can't split chains into more blocks than there are samples."
)

if self.batch_calculation:
lnpred = self.model.predict(x=X)
lnargs = lnpred - Y
lnargs = lnargs.at[jnp.isinf(lnargs)].set(jnp.nan)
if num_slices:
# Number of rows in each slice
slice_size = X.shape[0] // num_slices
lnpred_list = []

# Calculate lnpred in row-wise slices
for i in range(num_slices):
start_row = i * slice_size
end_row = (i + 1) * slice_size if i < num_slices - 1 else X.shape[0]
X_slice = X[start_row:end_row]

# Predict for each row slice and append result
lnpred_slice = self.model.predict(x=X_slice)
lnpred_list.append(lnpred_slice)

# Concatenate all row slice predictions
lnpred = jnp.concatenate(lnpred_list, axis=0)
lnargs = lnpred - Y
lnargs = lnargs.at[jnp.isinf(lnargs)].set(jnp.nan)
else:
lnpred = self.model.predict(x=X)
lnargs = lnpred - Y
lnargs = lnargs.at[jnp.isinf(lnargs)].set(jnp.nan)

else:
lnpred = np.zeros_like(Y)
Expand Down

0 comments on commit 89325d5

Please sign in to comment.