This repository has been archived by the owner on Jan 13, 2022. It is now read-only.

Ab initio training of models
tmassingham-ont committed May 3, 2019
1 parent be9b02f commit 3b924f6
Showing 8 changed files with 350 additions and 9 deletions.
3 changes: 3 additions & 0 deletions CHANGELOG.md
@@ -4,6 +4,9 @@ Version numbers: major.minor.patch
* Minor version bump indicates a change in functionality that may affect users.
* Patch version bump indicates bug-fixes or minor improvements not expected to affect users.

## v4.1.0
* Ab initio ("bootstrap") training of models

## v4.0.0
* Modified base training and basecalling
* Minor changes to input format to trainer, use `misc/upgrade_mapped_signal.py` to upgrade old data
16 changes: 15 additions & 1 deletion README.md
@@ -103,7 +103,7 @@ Tests can be run as follows:

If Taiyaki has been installed in a virtual environment, it will have to be activated before running tests: `source venv/bin/activate`. To deactivate, run `deactivate`.

-# Walk through
+# Walk throughs and further documentation
For a walk-through of Taiyaki model training, including how to obtain sample training data, see [docs/walkthrough.rst](docs/walkthrough.rst).

For an example of training a modified base model, see [docs/modbase.rst](docs/modbase.rst).
@@ -255,6 +255,20 @@ When training a model from scratch it is generally recommended to set this facto

Modified base models can be used in megalodon (release imminent) to call modified bases anchored to a reference.

## Ab initio training

'Ab initio' is an alternative entry point for Taiyaki that obtains acceptable models with fewer input requirements;
in particular, it does not require a previously trained model.

The input for ab initio training is a set of signal-sequence pairs:

- Fixed length chunks from reads
- A reference sequence trimmed for each chunk.
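
As a rough illustration of how these pairs are consumed, the sketch below draws a random minibatch of chunk indices and looks up the reference for each one, in the spirit of the training loop in `bin/train_abinitio.py`. The toy `chunks` and `references` values and the helper name `sample_batch` are assumptions made for this example; they are not part of Taiyaki.

```python
import random

# Toy stand-ins (assumed for illustration): three chunks of four samples each,
# and one trimmed reference per chunk, keyed by the chunk's index.
chunks = [
    [0.1, -0.3, 0.7, 0.2],
    [1.2, 0.4, -0.8, 0.0],
    [-0.5, 0.9, 0.3, -1.1],
]
references = {0: "ACGT", 1: "GGTA", 2: "TTCA"}


def sample_batch(chunks, references, batch_size, rng):
    """Pick chunk indices with replacement and pair each chunk with its reference."""
    idx = [rng.randrange(len(chunks)) for _ in range(batch_size)]
    return [(chunks[i], references[i]) for i in idx]


batch = sample_batch(chunks, references, batch_size=2, rng=random.Random(0))
for signal, sequence in batch:
    print(len(signal), sequence)
```

The only link between a signal chunk and its sequence is the shared index, which is why the reference names must match the chunk positions exactly.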

The models produced are not as accurate as those from the normal training process but can be used to bootstrap it.


The process is described in the [abinitio](docs/abinitio.rst) walk-through.

# Guppy compatibility

4 changes: 1 addition & 3 deletions bin/basecall.py
@@ -28,10 +28,8 @@
     description="Basecall reads using a taiyaki model",
     formatter_class=argparse.ArgumentDefaultsHelpFormatter)

-add_common_command_args(parser, 'device input_folder input_strand_list limit output quiet recursive version'.split())
+add_common_command_args(parser, 'alphabet device input_folder input_strand_list limit output quiet recursive version'.split())

-parser.add_argument("--alphabet", default=DEFAULT_ALPHABET,
-                    help="Alphabet used by basecaller")
 parser.add_argument("--chunk_size", type=Positive(int),
                     default=basecall_helpers._DEFAULT_CHUNK_SIZE,
                     help="Size of signal chunks sent to GPU")
5 changes: 1 addition & 4 deletions bin/prepare_mapped_reads.py
@@ -5,18 +5,15 @@
 import sys
 from taiyaki.cmdargs import FileExists
 from taiyaki.common_cmdargs import add_common_command_args
-from taiyaki.constants import DEFAULT_ALPHABET
 from taiyaki import alphabet, fast5utils, helpers, prepare_mapping_funcs


 program_description = "Prepare data for model training and save to hdf5 file by remapping with flip-flop model"
 parser = argparse.ArgumentParser(description=program_description,
                                  formatter_class=argparse.ArgumentDefaultsHelpFormatter)

-add_common_command_args(parser, 'device input_folder input_strand_list jobs limit overwrite recursive version'.split())
+add_common_command_args(parser, 'alphabet device input_folder input_strand_list jobs limit overwrite recursive version'.split())

-parser.add_argument('--alphabet', default=DEFAULT_ALPHABET,
-                    help='Canonical base alphabet')
 parser.add_argument('--mod', nargs=3, metavar=('base', 'canonical', 'name'),
                     default=[], action='append',
                     help='Modified base description')
204 changes: 204 additions & 0 deletions bin/train_abinitio.py
@@ -0,0 +1,204 @@
#!/usr/bin/env python3
import argparse
from Bio import SeqIO
import h5py
import numpy as np
import os
import pickle
from shutil import copyfile
import sys
import time

import torch
from torch.optim.lr_scheduler import CosineAnnealingLR

from taiyaki import ctc, flipflopfings, helpers
from taiyaki import __version__
from taiyaki.cmdargs import FileExists, Positive
from taiyaki.common_cmdargs import add_common_command_args
from taiyaki.constants import DEFAULT_ALPHABET


# This is here, not in main to allow documentation to be built
parser = argparse.ArgumentParser(
    description='Train a flip-flop neural network',
    formatter_class=argparse.ArgumentDefaultsHelpFormatter)

add_common_command_args(parser, """adam alphabet device limit niteration
                                   overwrite quiet save_every version""".split())

parser.add_argument('--batch_size', default=128, metavar='chunks',
                    type=Positive(int), help='Number of chunks to run in parallel')
parser.add_argument('--lr_max', default=4.0e-3, metavar='rate',
                    type=Positive(float), help='Initial learning rate')
parser.add_argument('--outdir', default='training',
                    help='Output directory, created when run.')
parser.add_argument('--size', default=96, metavar='neurons',
                    type=Positive(int), help='Base layer size for model')
parser.add_argument('--seed', default=None, metavar='integer', type=Positive(int),
                    help='Set random number seed')
parser.add_argument('--stride', default=2, metavar='samples', type=Positive(int),
                    help='Stride for model')
parser.add_argument('--winlen', default=19, type=Positive(int),
                    help='Length of window over data')

parser.add_argument('model', action=FileExists,
                    help='File to read python model description from')
parser.add_argument('chunks', action=FileExists,
                    help='file containing chunks')
parser.add_argument('reference', action=FileExists,
                    help='file containing fasta reference')


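def convert_seq(s, alphabet):
    # Replace each base by its index in the alphabet, then flip-flop code the result
    buf = np.array(list(s))
    for i, b in enumerate(alphabet):
        buf[buf == b] = i
    buf = buf.astype('i4')
    # After conversion, every entry must be a valid alphabet index
    assert np.all(buf < len(alphabet)), "Alphabet violates assumption in convert_seq"
    return flipflopfings.flipflop_code(buf, len(alphabet))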


def save_model(network, outdir, index=None):
    if index is None:
        basename = 'model_final'
    else:
        basename = 'model_checkpoint_{:05d}'.format(index)

    model_file = os.path.join(outdir, basename + '.checkpoint')
    torch.save(network, model_file)
    params_file = os.path.join(outdir, basename + '.params')
    torch.save(network.state_dict(), params_file)


if __name__ == '__main__':
    args = parser.parse_args()

    np.random.seed(args.seed)

    device = torch.device(args.device)
    if device.type == 'cuda':
        torch.cuda.set_device(device)

    if not os.path.exists(args.outdir):
        os.mkdir(args.outdir)
    elif not args.overwrite:
        sys.stderr.write('Error: Output directory {} exists but --overwrite is false\n'.format(args.outdir))
        exit(1)
    if not os.path.isdir(args.outdir):
        sys.stderr.write('Error: Output location {} is not directory\n'.format(args.outdir))
        exit(1)

    copyfile(args.model, os.path.join(args.outdir, 'model.py'))

    log = helpers.Logger(os.path.join(args.outdir, 'model.log'), args.quiet)
    log.write('* Taiyaki version {}\n'.format(__version__))
    log.write('* Command line\n')
    log.write(' '.join(sys.argv) + '\n')
    log.write('* Loading data from {}\n'.format(args.chunks))
    log.write('* Per read file MD5 {}\n'.format(helpers.file_md5(args.chunks)))

    if args.limit is not None:
        log.write('* Limiting number of strands to {}\n'.format(args.limit))

    with h5py.File(args.chunks, 'r') as h5:
        chunks = h5['chunks'][:args.limit]
    log.write('* Loaded {} reads from {}.\n'.format(len(chunks), args.chunks))

    if os.path.splitext(args.reference)[1] == '.pkl':
        # Read preprocessed sequences from pickle
        with open(args.reference, 'rb') as fh:
            seq_dict = pickle.load(fh)
        log.write('* Loaded preprocessed references from {}.\n'.format(args.reference))
    else:
        # Read sequences from .fa / .fasta file
        seq_dict = {int(seq.id) : convert_seq(str(seq.seq), args.alphabet)
                    for seq in SeqIO.parse(args.reference, "fasta")}
        log.write('* Loaded references from {}.\n'.format(args.reference))
        # Write pickle for future
        pickle_name = os.path.splitext(args.reference)[0] + '.pkl'
        with open(pickle_name, 'wb') as fh:
            pickle.dump(seq_dict, fh)
        log.write('* Written pickle of processed references to {} for future use.\n'.format(pickle_name))

    log.write('* Reading network from {}\n'.format(args.model))
    nbase = len(args.alphabet)
    model_kwargs = {
        'size' : args.size,
        'stride': args.stride,
        'winlen': args.winlen,
        'insize': 1,  # Number of input features to model e.g. was >1 for event-based models (level, std, dwell)
        'outsize': flipflopfings.nstate_flipflop(nbase)
    }
    network = helpers.load_model(args.model, **model_kwargs).to(device)
    log.write('* Network has {} parameters.\n'.format(sum([p.nelement()
                                                           for p in network.parameters()])))

    optimizer = torch.optim.Adam(network.parameters(), lr=args.lr_max,
                                 betas=args.adam, eps=1e-6)
    lr_scheduler = CosineAnnealingLR(optimizer, args.niteration)

    score_smoothed = helpers.WindowedExpSmoother()

    log.write('* Dumping initial model\n')
    save_model(network, args.outdir, 0)

    total_bases = 0
    total_samples = 0
    total_chunks = 0

    t0 = time.time()
    log.write('* Training\n')

    for i in range(args.niteration):
        lr_scheduler.step()

        idx = np.random.randint(len(chunks), size=args.batch_size)
        indata = chunks[idx].transpose(1, 0)
        indata = torch.tensor(indata[...,np.newaxis], device=device, dtype=torch.float32)
        seqs = [seq_dict[i] for i in idx]

        seqlens = torch.tensor([len(seq) for seq in seqs], dtype=torch.long, device=device)
        seqs = torch.tensor(np.concatenate(seqs), device=device, dtype=torch.long)

        optimizer.zero_grad()
        outputs = network(indata)
        lossvector = ctc.crf_flipflop_loss(outputs, seqs, seqlens, 1.0)
        loss = lossvector.sum() / (seqlens > 0.0).float().sum()
        loss.backward()
        optimizer.step()

        fval = float(loss)
        score_smoothed.update(fval)

        total_bases += int(seqlens.sum())
        total_samples += int(indata.nelement())

        # Doing this deletion leads to less CUDA memory usage.
        del indata, seqs, seqlens, outputs, loss, lossvector
        if device.type == 'cuda':
            torch.cuda.empty_cache()

        if (i + 1) % args.save_every == 0:
            save_model(network, args.outdir, (i + 1) // args.save_every)
            log.write('C')
        else:
            log.write('.')

        if (i + 1) % 50 == 0:
            # In case of super batching, additional functionality must be
            # added here
            learning_rate = lr_scheduler.get_lr()[0]
            tn = time.time()
            dt = tn - t0
            t = ' {:5d} {:5.3f} {:5.2f}s ({:.2f} ksample/s {:.2f} kbase/s) lr={:.2e}\n'
            log.write(t.format((i + 1) // 50, score_smoothed.value,
                               dt, total_samples / 1000.0 / dt,
                               total_bases / 1000.0 / dt, learning_rate))
            total_bases = 0
            total_samples = 0
            t0 = tn

    save_model(network, args.outdir)
120 changes: 120 additions & 0 deletions docs/abinitio.rst
@@ -0,0 +1,120 @@
Ab Initio training
==================
.. _`walk through`: walkthrough.rst

This walk-through describes an alternative entry point for training models, with lighter input requirements than the full `walk through`_.
The models obtained will not achieve the same accuracy as those from the full training process, but they are a useful starting point for basecalling and mapping reads in preparation for more rigorous training.

The input for ab initio training is a set of signal-sequence pairs:

- Fixed length chunks from reads
- A reference sequence trimmed for each chunk.

Three sets of input data are provided by way of example:

- R9.4.1 DNA, 1497098 chunks of 2000 samples

  + r941_dna/chunks.hdf5
  + r941_dna/chunks.fa

- R9.4.1 RNA, 44735 chunks of 10000 samples

  + r941_rna/chunks.hdf5
  + r941_rna/chunks.fa

- R10 DNA, 498224 chunks of 2000 samples

  + r10_dna/chunks.hdf5
  + r10_dna/chunks.fa

Training
--------

Training is as simple as:

.. code-block:: bash

    train_abinitio.py --device 0 mGru_flipflop.py signal_chunks.hdf5 references.fa

+----------------------+------------------------------------------------------------------+
| --device             | Run training on GPU 0                                            |
+----------------------+------------------------------------------------------------------+
| mGru_flipflop.py     | Model description file, see ``taiyaki/models``                   |
+----------------------+------------------------------------------------------------------+
| signal_chunks.hdf5   | Signal chunk file, formatted as described in `Chunk format`_.    |
+----------------------+------------------------------------------------------------------+
| references.fa        | Per-chunk reference sequence                                     |
+----------------------+------------------------------------------------------------------+

A ``Makefile`` is provided to demonstrate training for the example data sets provided.

.. code-block:: bash

    # Run all examples
    make all
    # Run single example. Possible examples r941_dna, r941_rna, or r10_dna
    make r941_dna/training

Chunk format
------------
.. _HDF5: https://www.hdfgroup.org

Chunks are stored in a HDF5_ file as a 2D array, *chunks x samples*. The TODO

Creating this file and the corresponding reads TODO
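
Creating the chunk file is still marked TODO above, so the following is only a sketch of one plausible approach. It assumes chunks are cut as non-overlapping windows and that an incomplete tail is dropped. The function name ``chop_into_chunks`` is hypothetical and not part of Taiyaki. The resulting rows would then form the 2D ``chunks`` dataset that ``train_abinitio.py`` reads.

```python
def chop_into_chunks(signal, chunk_len):
    """Split a 1D signal into non-overlapping chunks of exactly chunk_len samples.

    Trailing samples that do not fill a whole chunk are discarded (an assumption
    of this sketch, not a documented Taiyaki behaviour).
    """
    n_chunks = len(signal) // chunk_len
    return [signal[i * chunk_len:(i + 1) * chunk_len] for i in range(n_chunks)]


# Toy signal of 10 samples cut into chunks of 4: two full chunks, 2 samples dropped.
signal = [float(x) for x in range(10)]
chunks = chop_into_chunks(signal, chunk_len=4)
print(len(chunks), len(chunks[0]))  # chunks x samples, prints: 2 4
```

Every row has the same length, which is what allows the chunks to be stored as a single rectangular array.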

For example, the training file for the R941 DNA consists of 1497098 chunks of 2000 samples.

.. code-block:: bash

    h5ls -r r941_dna/chunks.hdf5
    /                        Group
    /chunks                  Dataset {1497098, 2000}

Scaling issues
..............
.. _`file formats`: FILE_FORMATS.md#per-read-parameter-files
.. _MAD: https://en.wikipedia.org/wiki/Median_absolute_deviation

For compatibility with ONT's basecallers and the default tool-chain, it is recommended that each read (not chunk) is scaled as follows:

.. code-block:: bash

    signal_scaled = signal - median(signal)
                    -----------------------
                      1.4826 mad(signal)

where the MAD_ (median absolute deviation) has an additional multiplicative factor of 1.4826 to scale it consistently with the standard deviation.


Other scaling methods could be used if the user is willing to create a per-read parameter file for future training (see `file formats`_).
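
The recommended median/MAD scaling can be sketched in a few lines of plain Python. This is an illustration of the formula only; the function name ``med_mad_scale`` is an assumption made for this example and is not a Taiyaki API.

```python
from statistics import median


def med_mad_scale(signal, factor=1.4826):
    """Scale a whole read: subtract the median, divide by factor * MAD.

    Assumes the MAD is non-zero, i.e. the signal is not constant.
    """
    med = median(signal)
    mad = median(abs(x - med) for x in signal)
    return [(x - med) / (factor * mad) for x in signal]


# Toy read with one outlier; the median (85) and MAD (5) are barely affected by it.
raw = [80.0, 82.0, 85.0, 90.0, 120.0]
scaled = med_mad_scale(raw)
print(scaled)
```

The scaling is applied to the full read before it is cut into chunks, so all chunks from one read share the same shift and scale.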


Reference format
----------------
The references are stored in a *fasta* format, one reference for each **chunk** trimmed to that chunk.
The name of each reference should be the index of its respective chunk.
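
The naming convention can be illustrated with a small parser that reads fasta text into a dictionary keyed by integer chunk index, with each base replaced by its position in the alphabet. This is a minimal sketch: the function name ``parse_chunk_references`` and the default alphabet ``"ACGT"`` are assumptions made for this example. ``train_abinitio.py`` itself uses ``Bio.SeqIO`` and additionally applies flip-flop coding, which is not reproduced here.

```python
def parse_chunk_references(fasta_text, alphabet="ACGT"):
    """Parse fasta text into {chunk index: list of base indices in alphabet}."""
    code = {base: i for i, base in enumerate(alphabet)}
    references = {}
    name = None
    for line in fasta_text.splitlines():
        line = line.strip()
        if not line:
            continue
        if line.startswith(">"):
            # The record name is the integer index of the chunk it belongs to.
            name = int(line[1:].split()[0])
            references[name] = []
        else:
            # Sequences may wrap over several lines; a KeyError flags a base
            # that is not in the alphabet.
            references[name].extend(code[base] for base in line)
    return references


example = ">0\nAGAC\nAG\n>1\nGTAT\n"
refs = parse_chunk_references(example)
print(refs[0])  # [0, 2, 0, 1, 0, 2]
```

Because the training script looks up references by chunk index, a record whose name is not an integer, or a chunk without a record, would cause a failure at load or training time.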


For example, a reference file begins as follows, with one record per chunk named by the chunk's index:

.. code-block::

    >0
    AGACAGCGAGGTTTATCCAATATTTTACAAGACACAAGAACTTCATGTCCATGCTTCAGG
    AACAGGACGTCAGATAGCAAACAATGGGAAGTATATTTTTATAACCGAGCAACATCTCTA
    CGGAACAGCGTTATCGGTATACAAGTACTCTATATCTTTCAAACGGTGGCTGTTCGTGGG
    CTACTCAGACATTAGGGCCAAATACGGTATA
    >1
    GTATAAGGAGTGTCAAAGATCTCTTTGTTGGTAACTGTCCCTCTGTAAATAGCCCAGTGC
    TGACAATTCTTACTGATGACAATAACATTCAAACAATTCTTCTTAAATAAAGGTTAAGGA
    AATGTAAATAAAAAAATAACAGTGACATTAATTTGTATATATCTCAACTTCTTCACTTTA
    ACCTGTCTGAGCTGTTTGGTTTTGAACTG

Modified bases
--------------
.. _modbase: modbase.rst

Ab initio training does not yet support our modified base models.
While a model could be trained treating each modified base as an additional canonical base, the recommended procedure is to train a canonical model using the ab initio process and then use this as the 'pre-trained' model in the modbase_ walk through.