Commit
Ab initio training of models

See merge request algorithm/taiyaki!64
Showing 8 changed files with 350 additions and 9 deletions.
@@ -0,0 +1,204 @@
#!/usr/bin/env python3
import argparse
from Bio import SeqIO
import h5py
import numpy as np
import os
import pickle
from shutil import copyfile
import sys
import time

import torch
from torch.optim.lr_scheduler import CosineAnnealingLR

from taiyaki import ctc, flipflopfings, helpers
from taiyaki import __version__
from taiyaki.cmdargs import FileExists, Positive
from taiyaki.common_cmdargs import add_common_command_args
from taiyaki.constants import DEFAULT_ALPHABET


# This is here, not in main, to allow documentation to be built
parser = argparse.ArgumentParser(
    description='Train a flip-flop neural network',
    formatter_class=argparse.ArgumentDefaultsHelpFormatter)

add_common_command_args(parser, """adam alphabet device limit niteration
                        overwrite quiet save_every version""".split())

parser.add_argument('--batch_size', default=128, metavar='chunks',
                    type=Positive(int), help='Number of chunks to run in parallel')
parser.add_argument('--lr_max', default=4.0e-3, metavar='rate',
                    type=Positive(float), help='Initial learning rate')
parser.add_argument('--outdir', default='training',
                    help='Output directory, created when run.')
parser.add_argument('--size', default=96, metavar='neurons',
                    type=Positive(int), help='Base layer size for model')
parser.add_argument('--seed', default=None, metavar='integer', type=Positive(int),
                    help='Set random number seed')
parser.add_argument('--stride', default=2, metavar='samples', type=Positive(int),
                    help='Stride for model')
parser.add_argument('--winlen', default=19, type=Positive(int),
                    help='Length of window over data')

parser.add_argument('model', action=FileExists,
                    help='File to read python model description from')
parser.add_argument('chunks', action=FileExists,
                    help='file containing chunks')
parser.add_argument('reference', action=FileExists,
                    help='file containing fasta reference')


def convert_seq(s, alphabet):
    buf = np.array(list(s))
    # Every character in the sequence must appear in the alphabet
    assert np.all(np.isin(buf, list(alphabet))), \
        "Alphabet violates assumption in convert_seq"
    for i, b in enumerate(alphabet):
        buf[buf == b] = i
    return flipflopfings.flipflop_code(buf.astype('i4'), len(alphabet))


def save_model(network, outdir, index=None):
    if index is None:
        basename = 'model_final'
    else:
        basename = 'model_checkpoint_{:05d}'.format(index)

    model_file = os.path.join(outdir, basename + '.checkpoint')
    torch.save(network, model_file)
    params_file = os.path.join(outdir, basename + '.params')
    torch.save(network.state_dict(), params_file)


if __name__ == '__main__':
    args = parser.parse_args()

    np.random.seed(args.seed)

    device = torch.device(args.device)
    if device.type == 'cuda':
        torch.cuda.set_device(device)

    if not os.path.exists(args.outdir):
        os.mkdir(args.outdir)
    elif not args.overwrite:
        sys.stderr.write('Error: Output directory {} exists but --overwrite '
                         'is false\n'.format(args.outdir))
        exit(1)
    if not os.path.isdir(args.outdir):
        sys.stderr.write('Error: Output location {} is not directory\n'.format(
            args.outdir))
        exit(1)

    copyfile(args.model, os.path.join(args.outdir, 'model.py'))

    log = helpers.Logger(os.path.join(args.outdir, 'model.log'), args.quiet)
    log.write('* Taiyaki version {}\n'.format(__version__))
    log.write('* Command line\n')
    log.write(' '.join(sys.argv) + '\n')
    log.write('* Loading data from {}\n'.format(args.chunks))
    log.write('* Per read file MD5 {}\n'.format(helpers.file_md5(args.chunks)))

    if args.limit is not None:
        log.write('* Limiting number of strands to {}\n'.format(args.limit))

    with h5py.File(args.chunks, 'r') as h5:
        chunks = h5['chunks'][:args.limit]
    log.write('* Loaded {} reads from {}.\n'.format(len(chunks), args.chunks))

    if os.path.splitext(args.reference)[1] == '.pkl':
        # Read preprocessed sequences from pickle
        with open(args.reference, 'rb') as fh:
            seq_dict = pickle.load(fh)
        log.write('* Loaded preprocessed references from {}.\n'.format(args.reference))
    else:
        # Read sequences from .fa / .fasta file
        seq_dict = {int(seq.id): convert_seq(str(seq.seq), args.alphabet)
                    for seq in SeqIO.parse(args.reference, "fasta")}
        log.write('* Loaded references from {}.\n'.format(args.reference))
        # Write pickle for future use
        pickle_name = os.path.splitext(args.reference)[0] + '.pkl'
        with open(pickle_name, 'wb') as fh:
            pickle.dump(seq_dict, fh)
        log.write('* Written pickle of processed references to {} for future use.\n'.format(pickle_name))

    log.write('* Reading network from {}\n'.format(args.model))
    nbase = len(args.alphabet)
    model_kwargs = {
        'size': args.size,
        'stride': args.stride,
        'winlen': args.winlen,
        # Number of input features to model, e.g. was >1 for event-based
        # models (level, std, dwell)
        'insize': 1,
        'outsize': flipflopfings.nstate_flipflop(nbase)
    }
    network = helpers.load_model(args.model, **model_kwargs).to(device)
    log.write('* Network has {} parameters.\n'.format(
        sum(p.nelement() for p in network.parameters())))

    optimizer = torch.optim.Adam(network.parameters(), lr=args.lr_max,
                                 betas=args.adam, eps=1e-6)
    lr_scheduler = CosineAnnealingLR(optimizer, args.niteration)

    score_smoothed = helpers.WindowedExpSmoother()

    log.write('* Dumping initial model\n')
    save_model(network, args.outdir, 0)

    total_bases = 0
    total_samples = 0
    total_chunks = 0

    t0 = time.time()
    log.write('* Training\n')

    for i in range(args.niteration):
        lr_scheduler.step()

        idx = np.random.randint(len(chunks), size=args.batch_size)
        indata = chunks[idx].transpose(1, 0)
        indata = torch.tensor(indata[..., np.newaxis], device=device,
                              dtype=torch.float32)
        seqs = [seq_dict[i] for i in idx]

        seqlens = torch.tensor([len(seq) for seq in seqs], dtype=torch.long,
                               device=device)
        seqs = torch.tensor(np.concatenate(seqs), device=device, dtype=torch.long)

        optimizer.zero_grad()
        outputs = network(indata)
        lossvector = ctc.crf_flipflop_loss(outputs, seqs, seqlens, 1.0)
        # Mean loss over the non-empty chunks in the batch
        loss = lossvector.sum() / (seqlens > 0.0).float().sum()
        loss.backward()
        optimizer.step()

        fval = float(loss)
        score_smoothed.update(fval)

        total_bases += int(seqlens.sum())
        total_samples += int(indata.nelement())

        # Doing this deletion leads to less CUDA memory usage.
        del indata, seqs, seqlens, outputs, loss, lossvector
        if device.type == 'cuda':
            torch.cuda.empty_cache()

        if (i + 1) % args.save_every == 0:
            save_model(network, args.outdir, (i + 1) // args.save_every)
            log.write('C')
        else:
            log.write('.')

        if (i + 1) % 50 == 0:
            # In case of super batching, additional functionality must be
            # added here
            learning_rate = lr_scheduler.get_lr()[0]
            tn = time.time()
            dt = tn - t0
            t = ' {:5d} {:5.3f} {:5.2f}s ({:.2f} ksample/s {:.2f} kbase/s) lr={:.2e}\n'
            log.write(t.format((i + 1) // 50, score_smoothed.value,
                               dt, total_samples / 1000.0 / dt,
                               total_bases / 1000.0 / dt, learning_rate))
            total_bases = 0
            total_samples = 0
            t0 = tn

    save_model(network, args.outdir)
@@ -0,0 +1,120 @@
Ab Initio training
==================

.. _`walk through`: walkthrough.rst

This walk-through describes an alternative entry point for training models, one with lighter input requirements than the full `walk through`_.
The models obtained will not achieve the same accuracy as those from the full training process, but they are a useful starting point for basecalling and mapping reads in preparation for more rigorous training.

The input for ab initio training is a set of signal-sequence pairs:

- Fixed length chunks from reads
- A reference sequence trimmed for each chunk

Three sets of input data are provided by way of example:

- R9.4.1 DNA, 1497098 chunks of 2000 samples

  + r941_dna/chunks.hdf5
  + r941_dna/chunks.fa

- R9.4.1 RNA, 44735 chunks of 10000 samples

  + r941_rna/chunks.hdf5
  + r941_rna/chunks.fa

- R10 DNA, 498224 chunks of 2000 samples

  + r10_dna/chunks.hdf5
  + r10_dna/chunks.fa

Training
--------

Training is as simple as:

.. code-block:: bash

    train_abinitio.py --device 0 mGru_flipflop.py signal_chunks.hdf5 references.fa

+----------------------+------------------------------------------------------------------+
| --device             | Run training on GPU 0                                            |
+----------------------+------------------------------------------------------------------+
| mGru_flipflop.py     | Model description file, see ``taiyaki/models``                   |
+----------------------+------------------------------------------------------------------+
| signal_chunks.hdf5   | Signal chunk file, formatted as described in `Chunk format`_.    |
+----------------------+------------------------------------------------------------------+
| references.fa        | Per-chunk reference sequence                                     |
+----------------------+------------------------------------------------------------------+

A ``Makefile`` is provided to demonstrate training for each of the example data sets.

.. code-block:: bash

    # Run all examples
    make all
    # Run a single example; possible examples are r941_dna, r941_rna, or r10_dna
    make r941_dna/training

Chunk format
------------

.. _HDF5: https://www.hdfgroup.org

Chunks are stored in a HDF5_ file as a 2D array, *chunks x samples*. The TODO

Creating this file and the corresponding reads TODO

For example, the training file for the R9.4.1 DNA example consists of 1497098 chunks of 2000 samples.

.. code-block:: bash

    h5ls -r r941_dna/chunks.hdf5
    /                        Group
    /chunks                  Dataset {1497098, 2000}
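
The sketch below shows one way such a file could be created with ``h5py``; the dataset name ``chunks`` matches the listing above, and the placeholder array stands in for real scaled, segmented signal.

.. code-block:: python

    import h5py
    import numpy as np

    # One row per chunk, one column per sample; real data would come from
    # scaled reads cut into fixed-length pieces.
    chunk_array = np.random.randn(1000, 2000).astype(np.float32)

    with h5py.File('chunks.hdf5', 'w') as h5:
        h5.create_dataset('chunks', data=chunk_array)
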
Scaling issues
..............

.. _`file formats`: FILE_FORMATS.md#per-read-parameter-files
.. _MAD: https://en.wikipedia.org/wiki/Median_absolute_deviation

For compatibility with ONT's basecallers and the default tool-chain, it is recommended that each read (not each chunk) is scaled as follows:

.. code-block::

    signal_scaled = (signal - median(signal)) / (1.4826 * mad(signal))

where the MAD_ (median absolute deviation) carries an additional multiplicative factor of 1.4826 so that it scales consistently with the standard deviation.
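
A minimal numpy sketch of this per-read scaling (the helper name ``med_mad_scale`` is illustrative, not part of Taiyaki's API):

.. code-block:: python

    import numpy as np

    def med_mad_scale(signal):
        # Centre on the median; the 1.4826 factor makes the MAD a consistent
        # estimator of the standard deviation for Gaussian noise
        med = np.median(signal)
        mad = 1.4826 * np.median(np.abs(signal - med))
        return (signal - med) / mad
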

Other scaling methods could be used if the user is willing to create a per-read parameter file for future training (see `file formats`_).

Reference format
----------------

The references are stored in *fasta* format, one reference for each **chunk**, trimmed to that chunk.
The name of each reference should be the index of its respective chunk.

For example, the reference file for the R9.4.1 DNA example begins:

.. code-block::

    >0
    AGACAGCGAGGTTTATCCAATATTTTACAAGACACAAGAACTTCATGTCCATGCTTCAGG
    AACAGGACGTCAGATAGCAAACAATGGGAAGTATATTTTTATAACCGAGCAACATCTCTA
    CGGAACAGCGTTATCGGTATACAAGTACTCTATATCTTTCAAACGGTGGCTGTTCGTGGG
    CTACTCAGACATTAGGGCCAAATACGGTATA
    >1
    GTATAAGGAGTGTCAAAGATCTCTTTGTTGGTAACTGTCCCTCTGTAAATAGCCCAGTGC
    TGACAATTCTTACTGATGACAATAACATTCAAACAATTCTTCTTAAATAAAGGTTAAGGA
    AATGTAAATAAAAAAATAACAGTGACATTAATTTGTATATATCTCAACTTCTTCACTTTA
    ACCTGTCTGAGCTGTTTGGTTTTGAACTG
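
As a minimal sketch, per-chunk references in this format could be written as follows, assuming ``references`` is a list whose i-th entry is the sequence trimmed to chunk i (the sequences shown are placeholders):

.. code-block:: python

    references = ['AGACAGCGAGGTTTATCC', 'GTATAAGGAGTGTCAAAG']  # placeholder data

    with open('references.fa', 'w') as fh:
        for idx, seq in enumerate(references):
            fh.write('>{}\n{}\n'.format(idx, seq))
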
Modified bases
--------------

.. _modbase: modbase.rst

Ab initio training does not yet support our modified base models.
While a model could be trained treating each modified base as an additional canonical base, the recommended procedure is to train a canonical model using the ab initio process and then use this as the 'pre-trained' model in the modbase_ walk through.