Use Dadaptation for all DL models
jakobnissen committed Jun 25, 2024
1 parent 0ebd8f5 commit bd69eac
Showing 3 changed files with 17 additions and 30 deletions.
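In short: every torch.optim.Adam instance in the deep-learning models is swapped for dadaptation.DAdaptAdam, which estimates its own step size via D-Adaptation, so the hand-tuned --lrate option becomes a deprecated no-op. A minimal sketch of the swap this commit applies, assuming the dadaptation package is installed (the toy model is illustrative only):

    import torch
    import dadaptation

    model = torch.nn.Linear(10, 2)  # stand-in for a VAE/AAE module

    # Before: Adam with a hand-picked learning rate.
    optimizer = torch.optim.Adam(model.parameters(), lr=1e-3)

    # After: D-Adaptation Adam estimates the step size on the fly;
    # decouple=True applies AdamW-style decoupled weight decay.
    optimizer = dadaptation.DAdaptAdam(model.parameters(), decouple=True)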
24 changes: 9 additions & 15 deletions vamb/__main__.py
@@ -433,20 +433,20 @@ def __init__(
         encoder_options: EncoderOptions,
         vae_options: Optional[VAETrainingOptions],
         aae_options: Optional[AAETrainingOptions],
-        lrate: float,
+        lrate: Optional[float],
     ):
-        assert isinstance(lrate, float)
+        assert isinstance(lrate, (type(None), float))
 
         assert (encoder_options.vae_options is None) == (vae_options is None)
         assert (encoder_options.aae_options is None) == (aae_options is None)
 
-        if lrate <= 0.0:
-            raise argparse.ArgumentTypeError("Learning rate must be positive")
-        self.lrate = lrate
+        if lrate is not None:
+            logger.warning(
+                "The --lrate argument is deprecated, and has no effect in Vamb 5 onwards"
+            )
 
         self.vae_options = vae_options
         self.aae_options = aae_options
+        self.lrate = lrate
 
 
 class ClusterOptions:
@@ -660,7 +660,6 @@ def trainvae(
     vae_options: VAEOptions,
     training_options: VAETrainingOptions,
     vamb_options: VambOptions,
-    lrate: float,
     alpha: Optional[float],
     data_loader: DataLoader,
 ) -> np.ndarray:
@@ -684,7 +683,6 @@
     vae.trainmodel(
         vamb.encode.set_batchsize(data_loader, training_options.batchsize),
         nepochs=training_options.nepochs,
-        lrate=lrate,
         batchsteps=training_options.batchsteps,
         modelfile=modelpath,
     )
@@ -705,7 +703,6 @@ def trainaae(
     aae_options: AAEOptions,
     training_options: AAETrainingOptions,
     vamb_options: VambOptions,
-    lrate: float,
     alpha: Optional[float],  # set automatically if None
     contignames: Sequence[str],
 ) -> tuple[np.ndarray, dict[str, set[str]]]:
@@ -732,7 +729,6 @@
         training_options.nepochs,
         training_options.batchsteps,
         training_options.temp,
-        lrate,
         modelpath,
     )

@@ -933,7 +929,6 @@ def run(
         vae_options=vae_options,
         training_options=vae_training_options,
         vamb_options=vamb_options,
-        lrate=training_options.lrate,
         alpha=encoder_options.alpha,
         data_loader=data_loader,
     )
@@ -948,7 +943,6 @@
         aae_options=aae_options,
         vamb_options=vamb_options,
         training_options=aae_training_options,
-        lrate=training_options.lrate,
         alpha=encoder_options.alpha,
         contignames=composition.metadata.identifiers,  # type:ignore
     )
@@ -1827,9 +1821,9 @@ def add_vae_arguments(subparser):
         "-r",
         dest="lrate",
         metavar="",
-        type=float,
-        default=1e-3,
-        help="learning rate [0.001]",
+        type=Optional[float],
+        default=None,
+        help=argparse.SUPPRESS,
     )
     return subparser

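The CLI side follows a standard deprecation pattern: the flag stays parseable so existing command lines do not break, its help text is hidden with argparse.SUPPRESS, and supplying it only triggers the warning added above. A self-contained sketch of that pattern (parser and message are illustrative, not Vamb's real CLI):

    import argparse

    parser = argparse.ArgumentParser()
    parser.add_argument(
        "-r",
        dest="lrate",
        metavar="",
        type=float,
        default=None,            # stays None unless explicitly passed
        help=argparse.SUPPRESS,  # hidden from --help output
    )

    args = parser.parse_args(["-r", "0.001"])
    if args.lrate is not None:
        print("warning: -r/--lrate is deprecated and has no effect")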
10 changes: 5 additions & 5 deletions vamb/aamb_encode.py
@@ -13,6 +13,7 @@
 from typing import Optional, IO, Union
 from numpy.typing import NDArray
 from loguru import logger
+import dadaptation
 
 
 ############################################################################# MODEL ###########################################################
@@ -206,7 +207,6 @@ def trainmodel(
         nepochs: int,
         batchsteps: list[int],
         T,
-        lr: float,
         modelfile: Union[None, str, IO[bytes]] = None,
     ):
         Tensor = torch.cuda.FloatTensor if self.usecuda else torch.FloatTensor
@@ -251,11 +251,11 @@
             adversarial_loss.cuda()
 
         #### Optimizers
-        optimizer_E = torch.optim.Adam(enc_params, lr=lr)
-        optimizer_D = torch.optim.Adam(dec_params, lr=lr)
+        optimizer_E = dadaptation.DAdaptAdam(enc_params, decouple=True)
+        optimizer_D = dadaptation.DAdaptAdam(dec_params, decouple=True)
 
-        optimizer_D_z = torch.optim.Adam(disc_z_params, lr=lr)
-        optimizer_D_y = torch.optim.Adam(disc_y_params, lr=lr)
+        optimizer_D_z = dadaptation.DAdaptAdam(disc_z_params, decouple=True)
+        optimizer_D_y = dadaptation.DAdaptAdam(disc_y_params, decouple=True)
 
         for epoch_i in range(nepochs):
             if epoch_i in batchsteps:
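Because DAdaptAdam subclasses torch.optim.Optimizer, the surrounding GAN-style training loop needs no other changes: zero_grad, backward, and step work exactly as with Adam. A runnable toy sketch of one such loop (model, data, and loss are invented for illustration):

    import torch
    import dadaptation

    model = torch.nn.Linear(8, 1)
    optimizer = dadaptation.DAdaptAdam(model.parameters(), decouple=True)
    loss_fn = torch.nn.MSELoss()

    x, y = torch.randn(32, 8), torch.randn(32, 1)
    for _ in range(10):
        optimizer.zero_grad()
        loss = loss_fn(model(x), y)
        loss.backward()
        optimizer.step()  # step size is adapted online; no lr to tune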
13 changes: 3 additions & 10 deletions vamb/encode.py
@@ -4,8 +4,8 @@
 from torch.utils.data.dataset import TensorDataset as _TensorDataset
 from torch.utils.data import DataLoader as _DataLoader
 from torch.nn.functional import softmax as _softmax
-from torch.optim import Adam as _Adam
 from torch import Tensor
+import dadaptation
 from torch import nn as _nn
 from math import log as _log
 from loguru import logger
@@ -155,7 +155,7 @@ class VAE(_nn.Module):
     dropout: Probability of dropout on forward pass [0.2]
     cuda: Use CUDA (GPU accelerated training) [False]
 
-    vae.trainmodel(dataloader, nepochs, batchsteps, lrate, modelfile)
+    vae.trainmodel(dataloader, nepochs, batchsteps, modelfile)
         Trains the model, returning None
 
     vae.encode(self, data_loader):
@@ -535,7 +535,6 @@ def trainmodel(
         self,
         dataloader: _DataLoader[tuple[Tensor, Tensor, Tensor]],
         nepochs: int = 500,
-        lrate: float = 1e-3,
         batchsteps: Optional[list[int]] = [25, 75, 150, 300],
         modelfile: Union[None, str, Path, IO[bytes]] = None,
     ):
@@ -544,16 +543,11 @@
 
         Inputs:
             dataloader: DataLoader made by make_dataloader
             nepochs: Train for this many epochs before encoding [500]
-            lrate: Starting learning rate for the optimizer [0.001]
             batchsteps: None or double batchsize at these epochs [25, 75, 150, 300]
             modelfile: Save models to this file if not None [None]
 
         Output: None
         """
-
-        if lrate < 0:
-            raise ValueError(f"Learning rate must be positive, not {lrate}")
-
         if nepochs < 1:
             raise ValueError(f"Minimum 1 epoch, not {nepochs}")

@@ -572,7 +566,7 @@
         # Get number of features
         # Following line is un-inferrable due to typing problems with DataLoader
         ncontigs, nsamples = dataloader.dataset.tensors[0].shape  # type: ignore
-        optimizer = _Adam(self.parameters(), lr=lrate)
+        optimizer = dadaptation.DAdaptAdam(self.parameters(), decouple=True)
 
         logger.info("\tNetwork properties:")
         logger.info(f"\tCUDA: {self.usecuda}")
@@ -588,7 +582,6 @@
         batchsteps_string = (
             ", ".join(map(str, sorted(batchsteps_set))) if batchsteps_set else "None"
         )
         logger.info(f"\tBatchsteps: {batchsteps_string}")
-        logger.info(f"\tLearning rate: {lrate}")
         logger.info(f"\tN sequences: {ncontigs}")
         logger.info(f"\tN samples: {nsamples}")
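Net effect for callers of vamb.encode.VAE: the lrate argument simply disappears from trainmodel. A hedged sketch of a post-commit call, assuming vae and dataloader were already built via vamb.encode (the values shown are the defaults from the new signature; the output path is hypothetical):

    vae.trainmodel(
        dataloader,
        nepochs=500,
        batchsteps=[25, 75, 150, 300],
        modelfile="model.pt",  # hypothetical output path
    )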
