From a5bf3017c2dd047e62338ad0965c89ee6a29ffbd Mon Sep 17 00:00:00 2001
From: Jakob Nybo Nissen
Date: Fri, 2 Jun 2023 12:34:47 +0200
Subject: [PATCH] Handle too high batch steps more gracefully

Instead of erroring when batchsteps is set so high that the final batch
size would exceed the dataset length, simply truncate the batch steps
and emit a warning instead of throwing an error. This change enables
experimenting with more aggressive batch steps, and also comes in handy
when working with long-read data.
---
 test/test_encode.py | 14 ++++++++++++--
 vamb/encode.py      | 42 ++++++++++++++++++++++++++++--------------
 2 files changed, 40 insertions(+), 16 deletions(-)

diff --git a/test/test_encode.py b/test/test_encode.py
index 12afdc07..b579657e 100644
--- a/test/test_encode.py
+++ b/test/test_encode.py
@@ -174,7 +174,7 @@ def test_loss_falls(self):
         vae = vamb.encode.VAE(self.rpkm.shape[1])
         rpkm_copy = self.rpkm.copy()
         tnfs_copy = self.tnfs.copy()
-        dl, mask = vamb.encode.make_dataloader(
+        dl, _ = vamb.encode.make_dataloader(
             rpkm_copy, tnfs_copy, self.lens, batchsize=16, destroy=True
         )
         di = torch.Tensor(rpkm_copy)
@@ -202,10 +202,20 @@ def test_loss_falls(self):
             after_encoding = vae_2.encode(dl)
             self.assertTrue(np.all(np.abs(before_encoding - after_encoding) < 1e-6))
 
+    def test_warn_too_many_batch_steps(self):
+        vae = vamb.encode.VAE(self.rpkm.shape[1])
+        rpkm_copy = self.rpkm.copy()
+        tnfs_copy = self.tnfs.copy()
+        dl, _ = vamb.encode.make_dataloader(
+            rpkm_copy, tnfs_copy, self.lens, batchsize=16, destroy=True
+        )
+        with self.assertWarns(Warning):
+            vae.trainmodel(dl, nepochs=4, batchsteps=[1, 2, 3])
+
     def test_encoding(self):
         nlatent = 15
         vae = vamb.encode.VAE(self.rpkm.shape[1], nlatent=nlatent)
-        dl, mask = vamb.encode.make_dataloader(
+        dl, _ = vamb.encode.make_dataloader(
             self.rpkm, self.tnfs, self.lens, batchsize=32
         )
         encoding = vae.encode(dl)
diff --git a/vamb/encode.py b/vamb/encode.py
index 5e380b34..221a34cd 100644
--- a/vamb/encode.py
+++ b/vamb/encode.py
@@ -9,6 +9,7 @@
 from torch import nn as _nn
 from math import log as _log
 from time import time
+import warnings
 
 __doc__ = """Encode a depths matrix and a tnf matrix to latent representation.
@@ -379,7 +380,7 @@ def trainepoch(
         epoch_celoss = 0.0
 
         if epoch in batchsteps:
-            data_loader = set_batchsize(data_loader, data_loader.batch_size * 2)
+            data_loader = set_batchsize(data_loader, data_loader.batch_size * 2)  # type: ignore
 
         for depths_in, tnf_in, weights in data_loader:
             depths_in.requires_grad = True
@@ -450,7 +451,7 @@ def encode(self, data_loader) -> _np.ndarray:
         row = 0
         with _torch.no_grad():
-            for depths, tnf, weights in new_data_loader:
+            for depths, tnf, _ in new_data_loader:
                 # Move input to GPU if requested
                 if self.usecuda:
                     depths = depths.cuda()
@@ -551,28 +552,41 @@ def trainmodel(
         if nepochs < 1:
             raise ValueError(f"Minimum 1 epoch, not {nepochs}")
 
-        if batchsteps is None:
-            batchsteps_set: set[int] = set()
+        if batchsteps is None or len(batchsteps) == 0:
+            sorted_batch_steps: list[int] = []
         else:
             # Check that all elements are integers, then sort and
             # deduplicate them
-            batchsteps = list(batchsteps)
             if not all(isinstance(i, int) for i in batchsteps):
                 raise ValueError("All elements of batchsteps must be integers")
-            if max(batchsteps, default=0) >= nepochs:
+            sorted_batch_steps = sorted(set(batchsteps))
+            if sorted_batch_steps[0] < 1:
+                raise ValueError(
+                    f"Minimum of batchsteps must be 1, not {sorted_batch_steps[0]}"
+                )
+            if sorted_batch_steps[-1] >= nepochs:
                 raise ValueError("Max batchsteps must not equal or exceed nepochs")
-        last_batchsize = dataloader.batch_size * 2 ** len(batchsteps)
-        if len(dataloader.dataset) < last_batchsize:  # type: ignore
+
+        n_contigs = len(dataloader.dataset)  # type: ignore
+        starting_batch_size: int = dataloader.batch_size  # type: ignore
+        if n_contigs < starting_batch_size:
             raise ValueError(
-                f"Last batch size of {last_batchsize} exceeds dataset length "
-                f"of {len(dataloader.dataset)}. "  # type: ignore
+                f"Starting batch size of {starting_batch_size} exceeds dataset length "
+                f"of {n_contigs}. "
                 "This means you have too few contigs left after filtering to train. "
                 "It is not advised to run Vamb with fewer than 10,000 sequences "
                 "after filtering. "
                 "Please check the Vamb log file to see where the sequences were "
                 "filtered away, and verify the BAM files have sensible content."
             )
-        batchsteps_set = set(batchsteps)
+        maximum_batch_steps = (n_contigs // starting_batch_size).bit_length() - 1
+        if maximum_batch_steps < len(sorted_batch_steps):
+            warnings.warn(
+                f"Requested {len(sorted_batch_steps)} batch steps, but with a starting "
+                f"batch size of {starting_batch_size} and {n_contigs} contigs, "
+                f"only the first {maximum_batch_steps} batch steps can be used."
+            )
+            sorted_batch_steps = sorted_batch_steps[:maximum_batch_steps]
 
         # Get number of features
         # Following line is un-inferrable due to typing problems with DataLoader
@@ -591,8 +605,8 @@ def trainmodel(
             print("\tN epochs:", nepochs, file=logfile)
             print("\tStarting batch size:", dataloader.batch_size, file=logfile)
             batchsteps_string = (
-                ", ".join(map(str, sorted(batchsteps_set)))
-                if batchsteps_set
+                ", ".join(map(str, sorted_batch_steps))
+                if len(sorted_batch_steps) > 0
                 else "None"
             )
             print("\tBatchsteps:", batchsteps_string, file=logfile)
@@ -603,7 +617,7 @@ def trainmodel(
         # Train
         for epoch in range(nepochs):
             dataloader = self.trainepoch(
-                dataloader, epoch, optimizer, sorted(batchsteps_set), time(), logfile
+                dataloader, epoch, optimizer, sorted_batch_steps, time(), logfile
             )
 
         # Save weights - Lord forgive me, for I have sinned when catching all exceptions
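
Reviewer note: the cap on usable batch steps follows from the batch size
doubling at every step, so the largest usable count k satisfies
starting_batch_size * 2**k <= n_contigs, i.e. k = floor(log2(n_contigs //
starting_batch_size)), which is what (n_contigs //
starting_batch_size).bit_length() - 1 computes. A minimal standalone sketch
(not part of the patch; max_batch_steps is a hypothetical helper name)
cross-checking that closed form:

def max_batch_steps(n_contigs: int, starting_batch_size: int) -> int:
    # Assumes n_contigs >= starting_batch_size, which trainmodel checks first.
    return (n_contigs // starting_batch_size).bit_length() - 1

# Brute-force cross-check: count how many doublings actually fit.
for n_contigs in range(16, 2049):
    for start in (16, 32, 64):
        if n_contigs < start:
            continue
        steps = 0
        while start * 2 ** (steps + 1) <= n_contigs:
            steps += 1
        assert max_batch_steps(n_contigs, start) == steps
print("closed form matches brute force")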
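
And a hypothetical end-to-end sketch of the new behaviour outside the test
suite (the shapes, dtypes, contig lengths and 103-column TNF width below are
assumptions for illustration; make_dataloader, VAE and trainmodel are called
as in the test above). With 100 contigs and a starting batch size of 16,
only floor(log2(100 // 16)) = 2 doublings fit, so requesting three batch
steps now warns and truncates to the first two instead of raising:

import warnings
import numpy as np
import vamb

# Toy inputs -- sizes here are made up for the example.
rpkm = np.random.rand(100, 4).astype(np.float32)    # 100 contigs, 4 samples
tnfs = np.random.rand(100, 103).astype(np.float32)  # assumed TNF feature width
lens = np.full(100, 2000)                           # assumed contig lengths

dl, _ = vamb.encode.make_dataloader(rpkm, tnfs, lens, batchsize=16)
vae = vamb.encode.VAE(rpkm.shape[1])

# Before this patch the call below raised ValueError; now it warns and
# trains with the truncated batch steps [1, 2].
with warnings.catch_warnings(record=True) as caught:
    warnings.simplefilter("always")
    vae.trainmodel(dl, nepochs=4, batchsteps=[1, 2, 3])

print(caught[0].message if caught else "no warning")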