diff --git a/CHANGELOG.md b/CHANGELOG.md index a19cb15..470aea3 100644 --- a/CHANGELOG.md +++ b/CHANGELOG.md @@ -4,6 +4,9 @@ Version numbers: major.minor.patch * Minor version bump indicates a change in functionality that may affect users. * Patch version bump indicates bug-fixes or minor improvements not expected to affect users. +## v4.1.0 +* Ab initio ("bootstrap") training of models + ## v4.0.0 * Modified base training and basecalling * Minor changes to input format to trainer, use `misc/upgrade_mapped_signal.py` to upgrade old data diff --git a/README.md b/README.md index abd5949..4462b3c 100644 --- a/README.md +++ b/README.md @@ -103,7 +103,7 @@ Tests can be run as follows: If Taiyaki has install in a virtual environment, it will have to activated before running tests: `source venv/bin/activate`. To deactivate, run `deactivate`. -# Walk through +# Walk throughs and further documentation For a walk-through of Taiyaki model training, including how to obtain sample training data, see [docs/walkthrough.rst](docs/walkthrough.rst). For an example of training a modifed base model, see [docs/modbase.rst](docs/modbase.rst). @@ -255,6 +255,20 @@ When training a model from scratch it is generally recommended to set this facto Modified base models can be used in megalodon (release imminent) to call modified bases anchored to a reference. +## Ab initio training + +'Ab initio' is an alternative entry point for Taiyaki that obtains acceptable models with fewer input requirements; +in particular, it does not require a previously trained model. + +The input for ab initio training is a set of signal-sequence pairs: + +- Fixed length chunks from reads +- A reference sequence trimmed for each chunk. + +The models produced are not as accurate as those from the normal training process but can be used to bootstrap it. + + +The process is described in the [abinitio](docs/abinitio.rst) walk-through. 
# Guppy compatibility diff --git a/bin/basecall.py b/bin/basecall.py index 9292f44..ebce5cd 100755 --- a/bin/basecall.py +++ b/bin/basecall.py @@ -28,10 +28,8 @@ description="Basecall reads using a taiyaki model", formatter_class=argparse.ArgumentDefaultsHelpFormatter) -add_common_command_args(parser, 'device input_folder input_strand_list limit output quiet recursive version'.split()) +add_common_command_args(parser, 'alphabet device input_folder input_strand_list limit output quiet recursive version'.split()) -parser.add_argument("--alphabet", default=DEFAULT_ALPHABET, - help="Alphabet used by basecaller") parser.add_argument("--chunk_size", type=Positive(int), default=basecall_helpers._DEFAULT_CHUNK_SIZE, help="Size of signal chunks sent to GPU") diff --git a/bin/prepare_mapped_reads.py b/bin/prepare_mapped_reads.py index a78eca4..bc34223 100755 --- a/bin/prepare_mapped_reads.py +++ b/bin/prepare_mapped_reads.py @@ -5,7 +5,6 @@ import sys from taiyaki.cmdargs import FileExists from taiyaki.common_cmdargs import add_common_command_args -from taiyaki.constants import DEFAULT_ALPHABET from taiyaki import alphabet, fast5utils, helpers, prepare_mapping_funcs @@ -13,10 +12,8 @@ parser = argparse.ArgumentParser(description=program_description, formatter_class=argparse.ArgumentDefaultsHelpFormatter) -add_common_command_args(parser, 'device input_folder input_strand_list jobs limit overwrite recursive version'.split()) +add_common_command_args(parser, 'alphabet device input_folder input_strand_list jobs limit overwrite recursive version'.split()) -parser.add_argument('--alphabet', default=DEFAULT_ALPHABET, - help='Canonical base alphabet') parser.add_argument('--mod', nargs=3, metavar=('base', 'canonical', 'name'), default=[], action='append', help='Modified base description') diff --git a/bin/train_abinitio.py b/bin/train_abinitio.py new file mode 100755 index 0000000..b4779cb --- /dev/null +++ b/bin/train_abinitio.py @@ -0,0 +1,204 @@ +#!/usr/bin/env python3 +import 
import argparse
from Bio import SeqIO
import h5py
import numpy as np
import os
import pickle
from shutil import copyfile
import sys
import time

import torch
from torch.optim.lr_scheduler import CosineAnnealingLR

from taiyaki import ctc, flipflopfings, helpers
from taiyaki import __version__
from taiyaki.cmdargs import FileExists, Positive
from taiyaki.common_cmdargs import add_common_command_args
from taiyaki.constants import DEFAULT_ALPHABET


# Number of training iterations between two summary lines in the log.
_REPORT_EVERY = 50


# This is here, not in main to allow documentation to be built
parser = argparse.ArgumentParser(
    description='Train a flip-flop neural network',
    formatter_class=argparse.ArgumentDefaultsHelpFormatter)

add_common_command_args(parser, """adam alphabet device limit niteration
                        overwrite quiet save_every version""".split())

parser.add_argument('--batch_size', default=128, metavar='chunks',
                    type=Positive(int), help='Number of chunks to run in parallel')
parser.add_argument('--lr_max', default=4.0e-3, metavar='rate',
                    type=Positive(float), help='Initial learning rate')
parser.add_argument('--outdir', default='training',
                    help='Output directory, created when run.')
parser.add_argument('--size', default=96, metavar='neurons',
                    type=Positive(int), help='Base layer size for model')
parser.add_argument('--seed', default=None, metavar='integer', type=Positive(int),
                    help='Set random number seed')
parser.add_argument('--stride', default=2, metavar='samples', type=Positive(int),
                    help='Stride for model')
parser.add_argument('--winlen', default=19, type=Positive(int),
                    help='Length of window over data')

parser.add_argument('model', action=FileExists,
                    help='File to read python model description from')
parser.add_argument('chunks', action=FileExists,
                    help='file containing chunks')
parser.add_argument('reference', action=FileExists,
                    help='file containing fasta reference')


def convert_seq(s, alphabet):
    """Encode a base sequence for flip-flop training.

    Every character of `s` is replaced by its position in `alphabet`; the
    resulting int32 array is passed to `flipflopfings.flipflop_code`
    together with the alphabet size.

    Args:
        s (str): sequence of bases. Every character must occur in `alphabet`.
        alphabet (str): canonical bases; a base's position is its code.

    Returns:
        The value returned by `flipflopfings.flipflop_code` for the encoded
        sequence (presumably an integer array of flip-flop labels -- confirm
        against `taiyaki.flipflopfings`).

    Raises:
        ValueError: if `s` contains a character that is not in `alphabet`.
    """
    # A dictionary lookup is used instead of overwriting characters in a
    # string array, so an alphabet that itself contains digit characters can
    # never be confused with indices that have already been written.
    index_of = {}
    for position, base in enumerate(alphabet):
        # First occurrence wins if the alphabet repeats a character.
        index_of.setdefault(base, position)

    unknown = sorted(set(s) - set(index_of))
    if unknown:
        # Explicit exception rather than `assert`, which is stripped under -O.
        raise ValueError(
            'Sequence contains symbols {} that are not in alphabet {!r}'.format(
                unknown, alphabet))

    codes = np.fromiter((index_of[base] for base in s), dtype='i4', count=len(s))
    return flipflopfings.flipflop_code(codes, len(alphabet))


def save_model(network, outdir, index=None):
    """Save `network` into `outdir` as a checkpoint and a parameter file.

    Two files are written with `torch.save`: ``<basename>.checkpoint`` holds
    the network object itself and ``<basename>.params`` holds its
    ``state_dict()``.

    Args:
        network: model to save; must provide ``state_dict()``.
        outdir (str): directory to write into.
        index (int, optional): checkpoint number.  When None the files are
            named ``model_final``, otherwise ``model_checkpoint_<index>``
            with the index zero-padded to five digits.
    """
    if index is None:
        basename = 'model_final'
    else:
        basename = 'model_checkpoint_{:05d}'.format(index)

    model_file = os.path.join(outdir, basename + '.checkpoint')
    torch.save(network, model_file)
    params_file = os.path.join(outdir, basename + '.params')
    torch.save(network.state_dict(), params_file)


def _prepare_outdir(outdir, overwrite):
    """Create `outdir`, or exit with status 1 if it cannot be used."""
    if not os.path.exists(outdir):
        os.mkdir(outdir)
    elif not overwrite:
        sys.stderr.write('Error: Output directory {} exists but --overwrite is false\n'.format(outdir))
        sys.exit(1)
    if not os.path.isdir(outdir):
        sys.stderr.write('Error: Output location {} is not directory\n'.format(outdir))
        sys.exit(1)


def _load_references(reference, alphabet, log):
    """Return a dict mapping integer chunk index to its encoded reference.

    A file ending in ``.pkl`` is read as a pickle of that dict.  Any other
    file is parsed as fasta whose record names are integers, encoded with
    `convert_seq`, and the result is pickled next to the input for reuse.
    """
    if os.path.splitext(reference)[1] == '.pkl':
        # NOTE: unpickling can execute arbitrary code; only pass pickle files
        # that were written by this script.
        with open(reference, 'rb') as fh:
            seq_dict = pickle.load(fh)
        log.write('* Loaded preprocessed references from {}.\n'.format(reference))
        return seq_dict

    # Read sequences from .fa / .fasta file
    seq_dict = {int(seq.id): convert_seq(str(seq.seq), alphabet)
                for seq in SeqIO.parse(reference, "fasta")}
    log.write('* Loaded references from {}.\n'.format(reference))
    # Write pickle for future
    pickle_name = os.path.splitext(reference)[0] + '.pkl'
    with open(pickle_name, 'wb') as fh:
        pickle.dump(seq_dict, fh)
    log.write('* Written pickle of processed references to {} for future use.\n'.format(pickle_name))
    return seq_dict


def main():
    """Train a flip-flop model from signal chunks and per-chunk references."""
    args = parser.parse_args()

    np.random.seed(args.seed)
    if args.seed is not None:
        # Also seed torch so that --seed covers more than NumPy's batch
        # sampling (presumably the network initialisation draws from torch's
        # generator -- confirm against helpers.load_model).
        torch.manual_seed(args.seed)

    device = torch.device(args.device)
    if device.type == 'cuda':
        torch.cuda.set_device(device)

    _prepare_outdir(args.outdir, args.overwrite)

    copyfile(args.model, os.path.join(args.outdir, 'model.py'))

    log = helpers.Logger(os.path.join(args.outdir, 'model.log'), args.quiet)
    log.write('* Taiyaki version {}\n'.format(__version__))
    log.write('* Command line\n')
    log.write(' '.join(sys.argv) + '\n')
    log.write('* Loading data from {}\n'.format(args.chunks))
    log.write('* Per read file MD5 {}\n'.format(helpers.file_md5(args.chunks)))

    if args.limit is not None:
        log.write('* Limiting number of strands to {}\n'.format(args.limit))

    # The 'chunks' dataset is indexed by chunk along its first axis; the
    # slice reads at most `limit` chunks into memory.
    with h5py.File(args.chunks, 'r') as h5:
        chunks = h5['chunks'][:args.limit]
    log.write('* Loaded {} reads from {}.\n'.format(len(chunks), args.chunks))

    seq_dict = _load_references(args.reference, args.alphabet, log)

    log.write('* Reading network from {}\n'.format(args.model))
    nbase = len(args.alphabet)
    model_kwargs = {
        'size': args.size,
        'stride': args.stride,
        'winlen': args.winlen,
        # Number of input features to model e.g. was >1 for event-based
        # models (level, std, dwell)
        'insize': 1,
        'outsize': flipflopfings.nstate_flipflop(nbase)
    }
    network = helpers.load_model(args.model, **model_kwargs).to(device)
    log.write('* Network has {} parameters.\n'.format(sum([p.nelement()
                                                           for p in network.parameters()])))

    optimizer = torch.optim.Adam(network.parameters(), lr=args.lr_max,
                                 betas=args.adam, eps=1e-6)
    lr_scheduler = CosineAnnealingLR(optimizer, args.niteration)

    score_smoothed = helpers.WindowedExpSmoother()

    log.write('* Dumping initial model\n')
    save_model(network, args.outdir, 0)

    total_bases = 0
    total_samples = 0

    t0 = time.time()
    log.write('* Training\n')

    for i in range(args.niteration):
        # NOTE(review): PyTorch >= 1.1 expects scheduler.step() to be called
        # after optimizer.step(); the original ordering is kept here --
        # confirm against the torch version this project pins.
        lr_scheduler.step()

        # Sample a batch of chunks uniformly, with replacement.
        batch_idx = np.random.randint(len(chunks), size=args.batch_size)
        # Swap the two axes of the selected chunks and append a trailing
        # axis of length one (the single input feature, see 'insize').
        indata = chunks[batch_idx].transpose(1, 0)
        indata = torch.tensor(indata[..., np.newaxis], device=device, dtype=torch.float32)
        seqs = [seq_dict[chunk_id] for chunk_id in batch_idx]

        seqlens = torch.tensor([len(seq) for seq in seqs], dtype=torch.long, device=device)
        seqs = torch.tensor(np.concatenate(seqs), device=device, dtype=torch.long)

        optimizer.zero_grad()
        outputs = network(indata)
        lossvector = ctc.crf_flipflop_loss(outputs, seqs, seqlens, 1.0)
        # Average over the chunks that have a non-empty reference.
        loss = lossvector.sum() / (seqlens > 0.0).float().sum()
        loss.backward()
        optimizer.step()

        fval = float(loss)
        score_smoothed.update(fval)

        total_bases += int(seqlens.sum())
        total_samples += int(indata.nelement())

        # Doing this deletion leads to less CUDA memory usage.
        del indata, seqs, seqlens, outputs, loss, lossvector
        if device.type == 'cuda':
            torch.cuda.empty_cache()

        if (i + 1) % args.save_every == 0:
            save_model(network, args.outdir, (i + 1) // args.save_every)
            log.write('C')
        else:
            log.write('.')

        if (i + 1) % _REPORT_EVERY == 0:
            # In case of super batching, additional functionality must be
            # added here
            learning_rate = lr_scheduler.get_lr()[0]
            tn = time.time()
            dt = tn - t0
            t = ' {:5d} {:5.3f} {:5.2f}s ({:.2f} ksample/s {:.2f} kbase/s) lr={:.2e}\n'
            log.write(t.format((i + 1) // _REPORT_EVERY, score_smoothed.value,
                               dt, total_samples / 1000.0 / dt,
                               total_bases / 1000.0 / dt, learning_rate))
            # Throughput figures are per reporting interval, so reset them.
            total_bases = 0
            total_samples = 0
            t0 = tn

    save_model(network, args.outdir)


if __name__ == '__main__':
    main()
+ +Three sets of input data are provided by way of example: + +- R9.4.1 DNA, 1497098 chunks of 2000 samples + + + r941_dna/chunks.hdf5 + + r941_dna/chunks.fa + +- R9.4.1 RNA, 44735 chunks of 10000 samples + + + r941_rna/chunks.hdf5 + + r941_rna/chunks.fa + +- R10 DNA, 498224 chunks of 2000 samples + + + r10_dna/chunks.hdf5 + + r10_dna/chunks.fa + +Training +-------- + +Training is as simple as: + +.. code-block:: bash + + train_abinitio.py --device 0 mGru_flipflop.py signal_chunks.hdf5 references.fa + ++----------------------+------------------------------------------------------------------+ +| --device | Run training on GPU 0 | ++----------------------+------------------------------------------------------------------+ +| mGru_flipflop.py | Model description file, see ``taiyaki/models`` | ++----------------------+------------------------------------------------------------------+ +| signal_chunks.hdf5 | Signal chunk file, formatted as described in `Chunk format`_. | ++----------------------+------------------------------------------------------------------+ +| references.fa | Per-chunk reference sequence | ++----------------------+------------------------------------------------------------------+ + +A ``Makefile`` is provided to demonstrate training for the example data sets provided. + +.. code-block:: bash + + # Run all examples + make all + # Run single example. Possible examples r941_dna, r941_rna, or r10_dna + make r941_dna/training + + +Chunk format +------------ +.. _HDF5: https://www.hdfgroup.org + +Chunks are stored in a HDF5_ file as a 2D array, *chunks x samples*. The TODO + +Creating this file and the corresponding reads TODO + +For example, the training file for the R941 DNA consists of 1497098 chunks of 2000 samples. + +.. code-block:: bash + + h5ls -r r941_dna/chunks.hdf5 + / Group + /chunks Dataset {1497098, 2000} + + +Scaling issues +.............. +.. _`file formats`: FILE_FORMATS.md#per-read-parameter-files +.. 
_MAD: https://en.wikipedia.org/wiki/Median_absolute_deviation + +For compatibility with ONT's basecallers and the default tool-chain, it is recommended that each read (not chunk) is scaled as follows: + +.. code-block:: bash + + signal_scaled = signal - median(signal) + ----------------------- + 1.4826 mad(signal) + +where the 'MAD_' (median absolute deviation) has an additional multiplicative factor of 1.4826 to scale it consistently with the standard deviation. + + +Other scaling methods could be used if the user is willing to create a per-read parameter file for future training (see `file formats`_). + + +Reference format +---------------- +The references are stored in a *fasta* format, one reference for each **chunk** trimmed to that chunk. +The name of each reference should be the index of its respective chunk. + + +For example, the training file for the R941 DNA consists of 1497098 chunks of 2000 samples. + +.. code-block:: + + >0 + AGACAGCGAGGTTTATCCAATATTTTACAAGACACAAGAACTTCATGTCCATGCTTCAGG + AACAGGACGTCAGATAGCAAACAATGGGAAGTATATTTTTATAACCGAGCAACATCTCTA + CGGAACAGCGTTATCGGTATACAAGTACTCTATATCTTTCAAACGGTGGCTGTTCGTGGG + CTACTCAGACATTAGGGCCAAATACGGTATA + >1 + GTATAAGGAGTGTCAAAGATCTCTTTGTTGGTAACTGTCCCTCTGTAAATAGCCCAGTGC + TGACAATTCTTACTGATGACAATAACATTCAAACAATTCTTCTTAAATAAAGGTTAAGGA + AATGTAAATAAAAAAATAACAGTGACATTAATTTGTATATATCTCAACTTCTTCACTTTA + ACCTGTCTGAGCTGTTTGGTTTTGAACTG + + +Modified bases +-------------- +.. _modbase: modbase.rst +Ab initio training does not yet support our modified base models. +While a model could be trained treating each modified base as an additional canonical base, the recommended procedure is to train a canonical model using the ab initio process and then use this as the 'pre-trained' model in the modbase_ walk through. 
diff --git a/taiyaki/__init__.py b/taiyaki/__init__.py index 1864dde..8eb964f 100644 --- a/taiyaki/__init__.py +++ b/taiyaki/__init__.py @@ -1,7 +1,7 @@ """Custard owns my heart!""" __version_info__ = { 'major': 4, - 'minor': 0, + 'minor': 1, 'revision': 0, } __version__ = "{major}.{minor}.{revision}".format(**__version_info__) diff --git a/taiyaki/common_cmdargs.py b/taiyaki/common_cmdargs.py index 8aad75c..524fc35 100644 --- a/taiyaki/common_cmdargs.py +++ b/taiyaki/common_cmdargs.py @@ -3,6 +3,7 @@ from taiyaki.cmdargs import (AutoBool, DeviceAction, FileAbsent, FileExists, Maybe, NonNegative, ParseToNamedTuple, Positive, display_version_and_exit) +from taiyaki.constants import DEFAULT_ALPHABET from taiyaki import __version__ @@ -30,6 +31,10 @@ def add_common_command_args(parser, arglist): NonNegative(float)), action=ParseToNamedTuple, help='Parameters beta1, beta2 for Exponential Decay Adaptive Momentum') + if 'alphabet' in arglist: + parser.add_argument('--alphabet', default=DEFAULT_ALPHABET, + help='Canonical base alphabet') + if 'chunk_logging_threshold' in arglist: parser.add_argument('--chunk_logging_threshold', default=10.0, metavar='multiple', type=NonNegative(float),