From ddb720b1d30f55af6ccfeb5bf1f99122d2664a39 Mon Sep 17 00:00:00 2001 From: Guilherme Sena Date: Mon, 11 Dec 2023 18:38:41 +0100 Subject: [PATCH] Update to version 0.7.0 --- README.md | 11 + trgt/Cargo.toml | 2 +- trgt/src/cli.rs | 17 +- trgt/src/cluster/cluster.rs | 228 ++++---- trgt/src/cluster/consensus.rs | 2 + trgt/src/faidx.rs | 158 ++++++ trgt/src/genotype/consensus.rs | 2 + trgt/src/genotype/flank.rs | 141 +++-- trgt/src/genotype/genotype.rs | 8 +- trgt/src/karyotype.rs | 96 ++++ trgt/src/label/hmm.rs | 14 +- trgt/src/label/hmm_defs.rs | 12 - trgt/src/label/label_with_hmm.rs | 523 +++++++++++++++--- trgt/src/label/mod.rs | 4 +- trgt/src/label/refine_motif_counts.rs | 241 ++++++-- trgt/src/locate/locate.rs | 188 +++---- trgt/src/locate/mod.rs | 2 +- trgt/src/locus.rs | 224 +++----- trgt/src/main.rs | 157 +++--- trgt/src/reads/clip_bases.rs | 10 +- trgt/src/reads/clip_region.rs | 32 +- trgt/src/reads/mod.rs | 1 + trgt/src/reads/read.rs | 89 ++- trgt/src/{ => reads}/snp.rs | 3 +- trgt/src/workflows/tr.rs | 18 +- trgt/src/writers/mod.rs | 5 + .../{read_output.rs => writers/write_bam.rs} | 29 +- trgt/src/{vcf.rs => writers/write_vcf.rs} | 11 +- trvz/Cargo.toml | 2 +- trvz/src/cli.rs | 2 +- trvz/src/hmm.rs | 25 +- trvz/src/hmm_defs.rs | 12 - trvz/src/label_hmm.rs | 324 ++++++++--- trvz/src/main.rs | 32 +- trvz/src/pipe_plot.rs | 8 +- 35 files changed, 1760 insertions(+), 873 deletions(-) create mode 100644 trgt/src/faidx.rs create mode 100644 trgt/src/karyotype.rs delete mode 100644 trgt/src/label/hmm_defs.rs rename trgt/src/{ => reads}/snp.rs (99%) create mode 100644 trgt/src/writers/mod.rs rename trgt/src/{read_output.rs => writers/write_bam.rs} (79%) rename trgt/src/{vcf.rs => writers/write_vcf.rs} (96%) delete mode 100644 trvz/src/hmm_defs.rs diff --git a/README.md b/README.md index a24b9fc..5d1bfbe 100644 --- a/README.md +++ b/README.md @@ -85,6 +85,17 @@ assessment of tandem repeats at scale. bioRxiv. 2023](https://doi.org/10.1101/20 - BAM files now contain read-to-allele assignments - Added support for gzip compressed repeat files - Improved error handling and error messages +- 0.6.0 + - Add alignment CIGARs to spanning.bam reads + - Increase read extraction region + - Cluster genotyper reports confidence intervals + - Improved error handling of invalid input files (genome, catalog + and reads) +- 0.7.0 + - Read phasing information can now be used during repeat genotyping (via `HP` tags) + - Users can now define complex repeats by specifying motif sequences in the MOTIFS field and setting STRUC to <`locus_name`> + - The original MAPQ values in the input reads are now reported in the BAM output + - BAMlet sample name can now be provided using the `--sample-name` flag; if it not provided, it is extracted from the input BAM or file stem (addressing issue #18) ### DISCLAIMER THIS WEBSITE AND CONTENT AND ALL SITE-RELATED SERVICES, INCLUDING ANY DATA, ARE diff --git a/trgt/Cargo.toml b/trgt/Cargo.toml index 8a35501..e64ac6d 100644 --- a/trgt/Cargo.toml +++ b/trgt/Cargo.toml @@ -1,6 +1,6 @@ [package] name = "trgt" -version = "0.5.0" +version = "0.7.0" edition = "2021" build = "build.rs" diff --git a/trgt/src/cli.rs b/trgt/src/cli.rs index a0b0c60..2a9083b 100644 --- a/trgt/src/cli.rs +++ b/trgt/src/cli.rs @@ -70,6 +70,13 @@ pub struct CliParams { #[clap(default_value = "XX")] pub karyotype: String, + #[clap(long = "sample-name")] + #[clap(value_name = "SAMPLE_NAME")] + #[clap(help = "Sample name")] + #[clap(default_value = None)] + #[arg(value_parser = check_sample_name_nonempty)] + pub sample_name: Option, + #[clap(long = "max-depth")] #[clap(value_name = "MAX_DEPTH")] #[clap(help = "Maximum locus depth")] @@ -94,7 +101,7 @@ pub struct CliParams { pub aln_scoring: TrgtScoring, #[clap(help_heading("Advanced"))] - #[clap(long = "min-flank-id-perc")] + #[clap(long = "min-flank-id-frac")] #[clap(value_name = "PERC")] #[clap(help = "Minimum fraction of matches in a flank sequence to consider it 'found'")] #[clap(default_value = "0.7")] @@ -199,6 +206,14 @@ fn check_file_exists(s: &str) -> Result { } } +fn check_sample_name_nonempty(s: &str) -> Result { + if s.trim().is_empty() { + Err("Sample name cannot be an empty string".to_string()) + } else { + Ok(s.to_string()) + } +} + fn scoring_from_string(s: &str) -> Result { const NUM_EXPECTED_VALUES: usize = 6; let values: Vec = s.split(',').filter_map(|x| x.parse().ok()).collect(); diff --git a/trgt/src/cluster/cluster.rs b/trgt/src/cluster/cluster.rs index e36f0c8..4373724 100644 --- a/trgt/src/cluster/cluster.rs +++ b/trgt/src/cluster/cluster.rs @@ -2,152 +2,135 @@ use crate::cluster::consensus; use crate::cluster::math::median; use crate::genotype::{Gt, TrSize}; use arrayvec::ArrayVec; -use bio::alignment::distance::levenshtein; +use bio::alignment::distance::simd::bounded_levenshtein; use bio::alignment::{pairwise::*, Alignment}; use itertools::Itertools; use kodama::{linkage, Method}; +pub fn central_read(num_seqs: usize, group: &[usize], dists: &[f64]) -> usize { + let group_size = group.len(); + if group_size <= 2 { + return group[0]; + } + + let mut dist_sums = vec![0_f64; group_size]; + for i in 0..(group_size - 1) { + for j in (i + 1)..group_size { + let index1 = group[i]; + let index2 = group[j]; + + /* dist_sums has the condensed distance matrix, element 0 is + * dist(seq[0], seq[1]), element 1 is dist(seq[0], seq[2]), etc. + * This is the "inverse" that finds the matrix index based on the + * index of the two seqs being compared */ + let mat_index = num_seqs * index1 - index1 * (index1 + 3) / 2 + index2 - 1; + dist_sums[i] += dists[mat_index]; + dist_sums[j] += dists[mat_index]; + } + } + + dist_sums + .iter() + .enumerate() + .min_by(|(_, a), (_, b)| a.partial_cmp(b).unwrap()) + .map(|(index, _)| group[index]) + .unwrap() +} + +pub fn make_consensus( + num_seqs: usize, + trs: &[&str], + dists: &[f64], + group: &[usize], +) -> (String, TrSize) { + let seqs = group.iter().map(|&i| trs[i]).collect_vec(); + let backbone = trs[central_read(num_seqs, group, dists)]; + let aligns = align(backbone, &seqs); + let allele = consensus::repair_consensus(backbone, &seqs, &aligns); + let size = TrSize::new(allele.len(), get_ci(&seqs)); + + (allele, size) +} + pub fn genotype(seqs: &Vec<&[u8]>, trs: &[&str]) -> (Gt, Vec, Vec) { - const MAX_SEQS: usize = 50; // max number of sequences to do all-pairs Levenshtein - let mut groups = cluster(seqs, MAX_SEQS); + let mut dists = get_dist_matrix(seqs); + let mut groups = cluster(seqs.len(), &mut dists); groups.sort_by_key(|a| a.len()); let group1 = groups.pop().unwrap(); - let seqs1 = group1.iter().map(|i| trs[*i]).collect_vec(); - let backbone = *seqs1.first().unwrap(); - let aligns1 = align(backbone, &seqs1); - let allele1 = consensus::repair_consensus(backbone, &seqs1, &aligns1); - - let size1 = TrSize { - size: allele1.len(), - ci: get_ci(&seqs1), + let (allele1, size1) = make_consensus(seqs.len(), trs, &dists, &group1); + + let group2 = groups.pop().unwrap_or_default(); + + const MAX_GROUP_RATIO: usize = 10; + let is_homozygous = group2.is_empty() || group2.len() * MAX_GROUP_RATIO <= group1.len() || { + // reject group2 if estimated consensus size difference is negligible + // and group 1 has significantly more reads. + let group1_len = + median(&group1.iter().map(|&i| trs[i].len() as i32).collect_vec()).unwrap(); + let group2_len = + median(&group2.iter().map(|&i| trs[i].len() as i32).collect_vec()).unwrap(); + let delta = (group1_len - group2_len).abs(); + let group1_frac = group1.len() as f64 / (group1.len() as f64 + group2.len() as f64); + + const MIN_GROUP_SIZE_DIFF: f32 = 100.0; + const MAX_SIMILAR_GROUP_FRAC: f64 = 0.80; + delta <= MIN_GROUP_SIZE_DIFF && group1_frac >= MAX_SIMILAR_GROUP_FRAC }; - if groups.is_empty() { + if is_homozygous { let gt = ArrayVec::from([size1.clone(), size1]); let alleles = vec![allele1.clone(), allele1]; - let classification = vec![0_i32; seqs.len()]; - return (gt, alleles, classification); - } - - let group2 = groups.pop().unwrap(); - let tiny_group2 = group1.len() >= 10 && group2.len() == 1; - let group1_len = median(&group1.iter().map(|i| trs[*i].len() as i32).collect_vec()).unwrap(); - let group2_len = median(&group2.iter().map(|i| trs[*i].len() as i32).collect_vec()).unwrap(); - let delta = (group1_len - group2_len).abs(); - let group1_frac = group1.len() as f64 / (group1.len() as f64 + group2.len() as f64); - let small_group2 = delta <= 100.0 && group1_frac >= 0.80; - - if tiny_group2 || small_group2 { - let gt = ArrayVec::from([size1.clone(), size1]); - let alleles = vec![allele1.clone(), allele1]; - let classification = vec![0_i32; seqs.len()]; + // distribute reads across alleles "randomly" + let classification = (0..seqs.len() as i32).map(|x| x % 2).collect::>(); return (gt, alleles, classification); } - let seqs2 = group2.iter().map(|i| trs[*i]).collect_vec(); - let backbone = *seqs2.first().unwrap(); - let aligns2 = align(backbone, &seqs2); - let allele2 = consensus::repair_consensus(backbone, &seqs2, &aligns2); - - let mut classifications = vec![2_i32; seqs.len()]; - - // for reads already clustered, no need to do edit distance + let (allele2, size2) = make_consensus(seqs.len(), trs, &dists, &group2); + let mut classifications = vec![2; seqs.len()]; for seq_index in group1 { - classifications[seq_index] = 0_i32; + classifications[seq_index] = 0; } for seq_index in group2 { - classifications[seq_index] = 1_i32; + classifications[seq_index] = 1; } - // assign reads that weren't used for clustering to the closest allele - if seqs.len() > MAX_SEQS { - let mut tie_breaker = 1; - let a1 = &allele1.as_bytes(); - let a2 = &allele2.as_bytes(); - for (tr, classification) in trs.iter().zip(classifications.iter_mut()) { - if *classification == 2 { - //no allele assigned - let tr = &tr.as_bytes(); - let dist1 = levenshtein(tr, a1); - let dist2 = levenshtein(tr, a2); - *classification = match dist1.cmp(&dist2) { - std::cmp::Ordering::Less => 0, - std::cmp::Ordering::Greater => 1, - std::cmp::Ordering::Equal => { - tie_breaker = (tie_breaker + 1) % 2; - tie_breaker - } - }; - } - } - } - let size2 = TrSize { - size: allele2.len(), - ci: get_ci(&seqs2), + let (gt, alleles) = if allele1.len() > allele2.len() { + classifications = classifications.iter().map(|x| 1 - x).collect(); + (ArrayVec::from([size2, size1]), vec![allele2, allele1]) + } else { + (ArrayVec::from([size1, size2]), vec![allele1, allele2]) }; - let mut gt = Gt::new(); - gt.push(size1); - gt.push(size2); - (gt, vec![allele1, allele2], classifications) + (gt, alleles, classifications) } -pub fn cluster(seqs: &Vec<&[u8]>, max_seqs: usize) -> Vec> { - if seqs.is_empty() { +pub fn cluster(num_seqs: usize, dists: &mut Vec) -> Vec> { + if num_seqs == 0 { return Vec::new(); } - if seqs.len() == 1 { + assert_eq!(num_seqs * (num_seqs - 1) / 2, dists.len()); + if num_seqs == 1 { return vec![vec![0]]; } - if seqs.len() == 2 { + if num_seqs == 2 { return vec![vec![0], vec![1]]; } - let (seqs, orig_index) = if seqs.len() <= max_seqs { - (seqs.clone(), (0..seqs.len()).collect::>()) - } else { - // uniformly pick sequences by length distribution - let num_reads = seqs.len(); - log::warn!( - "Subsampling {} / {} reads for sequence-based clustering", - max_seqs, - num_reads - ); - let mut ret: Vec<&[u8]> = Vec::new(); - let mut orig_index = vec![0; max_seqs]; - ret.reserve(max_seqs); - let mut fast: f64 = 0.0; - let step = (num_reads as f64) / (max_seqs as f64); - - for item in orig_index.iter_mut().take(max_seqs) { - let ind = fast.floor() as usize; - ret.push(seqs[ind]); - *item = ind; - fast += step; - } + let dendrogram = linkage(dists, num_seqs, Method::Ward); - (ret, orig_index) - }; - - let mut dists = get_dist_matrix(&seqs); - let dendrogram = linkage(&mut dists, seqs.len(), Method::Ward); - let cutoff = dendrogram - .steps() - .iter() - .map(|s| s.dissimilarity) - .max_by(|a, b| a.partial_cmp(b).unwrap()) - .unwrap() - * 0.7; + // last element is the last merge, which is the highest dissimilarity + let cutoff = dendrogram.steps().last().unwrap().dissimilarity; + let cutoff = cutoff * 0.7; let mut num_groups = 0; - let num_nodes = 2 * seqs.len() - 1; + let num_nodes = 2 * num_seqs - 1; let mut membership = vec![None; num_nodes]; for (cluster_index, step) in dendrogram.steps().iter().enumerate().rev() { - let cluster = cluster_index + seqs.len(); + let cluster = cluster_index + num_seqs; if step.dissimilarity <= cutoff { if membership[cluster].is_none() { membership[cluster] = Some(num_groups); @@ -159,10 +142,8 @@ pub fn cluster(seqs: &Vec<&[u8]>, max_seqs: usize) -> Vec> { } } - let mut groups = Vec::new(); - groups.reserve(seqs.len()); - - for group in membership.into_iter().take(seqs.len()) { + let mut groups = Vec::with_capacity(num_seqs); + for group in membership.into_iter().take(num_seqs) { if let Some(group) = group { groups.push(group); } else { @@ -173,7 +154,7 @@ pub fn cluster(seqs: &Vec<&[u8]>, max_seqs: usize) -> Vec> { let mut seqs_by_group = vec![Vec::new(); num_groups]; for (seq_index, group) in groups.iter().enumerate() { - seqs_by_group[*group].push(orig_index[seq_index]); + seqs_by_group[*group].push(seq_index); } seqs_by_group @@ -186,13 +167,21 @@ fn get_ci(seqs: &[&str]) -> (usize, usize) { } fn get_dist_matrix(seqs: &Vec<&[u8]>) -> Vec { - let dist_len = seqs.len() * (seqs.len() - 1); - let dist_len = (dist_len as f64 / 2.0) as usize; - let mut dists = Vec::new(); - dists.reserve(dist_len); - for index1 in 0..seqs.len() { - for index2 in index1 + 1..seqs.len() { - let dist = levenshtein(seqs[index1], seqs[index2]); + let dist_len = seqs.len() * (seqs.len() - 1) / 2; + let mut dists = Vec::with_capacity(dist_len); + for (index1, seq1) in seqs.iter().enumerate() { + for (_index2, seq2) in seqs.iter().enumerate().skip(index1 + 1) { + let max_len = std::cmp::max(seq1.len(), seq2.len()) as u32; + let min_len = std::cmp::min(seq1.len(), seq2.len()) as u32; + let length_diff = max_len - min_len; + + // we'll skip ED in cases we already know it will be too costly to do so. + const MAX_K: u32 = 500; + let dist = if length_diff <= MAX_K { + bounded_levenshtein(seq1, seq2, MAX_K).unwrap_or(MAX_K) + } else { + length_diff // lower bound on ED + }; dists.push(dist as f64); } } @@ -202,7 +191,6 @@ fn get_dist_matrix(seqs: &Vec<&[u8]>) -> Vec { fn align(backbone: &str, seqs: &[&str]) -> Vec { let mut aligner = Aligner::new(-5, -1, |a, b| if a == b { 1i32 } else { -1i32 }); - seqs.iter() .map(|seq| aligner.global(seq.as_bytes(), backbone.as_bytes())) .collect_vec() diff --git a/trgt/src/cluster/consensus.rs b/trgt/src/cluster/consensus.rs index eb1e3b9..a304d02 100644 --- a/trgt/src/cluster/consensus.rs +++ b/trgt/src/cluster/consensus.rs @@ -107,6 +107,7 @@ fn get_ins_consensus(ins_by_read: &mut Vec, num_reads: usize) -> &str { } } +/* #[cfg(test)] mod tests { @@ -152,3 +153,4 @@ mod tests { //assert_eq!(fixed, expected); } } + */ diff --git a/trgt/src/faidx.rs b/trgt/src/faidx.rs new file mode 100644 index 0000000..2dd153f --- /dev/null +++ b/trgt/src/faidx.rs @@ -0,0 +1,158 @@ +// https://github.com/rust-bio/rust-htslib/blob/master/src/faidx/mod.rs +// Copyright 2020 Manuel Landesfeind, Evotec International GmbH +// Licensed under the MIT license (http://opensource.org/licenses/MIT) +// This file may not be copied, modified, or distributed +// except according to those terms. + +// TODO: Temporary work-around until https://github.com/rust-bio/rust-htslib/pull/410 gets merged which adds fetch_seq_len + +//! +//! Module for working with faidx-indexed FASTA files. +//! + +use rust_htslib::errors::{Error, Result}; +use rust_htslib::htslib; +use rust_htslib::utils::path_as_bytes; +use std::collections::HashMap; +use std::ffi; +use std::path::Path; + +/// A Fasta reader. +#[derive(Debug)] +pub struct Reader { + inner: *mut htslib::faidx_t, +} + +impl Reader { + /// Create a new Reader from a path. + /// + /// # Arguments + /// + /// * `path` - the path to open. + pub fn from_path>(path: P) -> Result { + Self::new(&path_as_bytes(path, true)?) + } + + /// Internal function to create a Reader from some sort of path (could be file path but also URL). + /// The path or URL will be handled by the c-implementation transparently. + /// + /// # Arguments + /// + /// * `path` - the path or URL to open + fn new(path: &[u8]) -> Result { + let cpath = ffi::CString::new(path).unwrap(); + let inner = unsafe { htslib::fai_load(cpath.as_ptr()) }; + Ok(Self { inner }) + } + + /// Fetch the sequence as a byte array. + /// + /// # Arguments + /// + /// * `name` - the name of the template sequence (e.g., "chr1") + /// * `begin` - the offset within the template sequence (starting with 0) + /// * `end` - the end position to return (if smaller than `begin`, the behavior is undefined). + pub fn fetch_seq>(&self, name: N, begin: usize, end: usize) -> Result<&[u8]> { + if begin > std::i64::MAX as usize { + return Err(Error::FaidxPositionTooLarge); + } + if end > std::i64::MAX as usize { + return Err(Error::FaidxPositionTooLarge); + } + let cname = ffi::CString::new(name.as_ref().as_bytes()).unwrap(); + let len_out: i64 = 0; + let cseq = unsafe { + let ptr = htslib::faidx_fetch_seq64( + self.inner, //*const faidx_t, + cname.as_ptr(), // c_name + begin as htslib::hts_pos_t, // p_beg_i + end as htslib::hts_pos_t, // p_end_i + &mut (len_out as htslib::hts_pos_t), //len + ); + ffi::CStr::from_ptr(ptr) + }; + + Ok(cseq.to_bytes()) + } + + /// Fetches the sequence and returns it as string. + /// + /// # Arguments + /// + /// * `name` - the name of the template sequence (e.g., "chr1") + /// * `begin` - the offset within the template sequence (starting with 0) + /// * `end` - the end position to return (if smaller than `begin`, the behavior is undefined). + pub fn fetch_seq_string>( + &self, + name: N, + begin: usize, + end: usize, + ) -> Result { + let bytes = self.fetch_seq(name, begin, end)?; + Ok(std::str::from_utf8(bytes).unwrap().to_owned()) + } + + /// Fetches the number of sequences in the fai index + pub fn n_seqs(&self) -> u64 { + let n = unsafe { htslib::faidx_nseq(self.inner) }; + n as u64 + } + + /// Fetches the i-th sequence name + /// + /// # Arguments + /// + /// * `i` - index to query + pub fn seq_name(&self, i: i32) -> Result { + let cname = unsafe { + let ptr = htslib::faidx_iseq(self.inner, i); + ffi::CStr::from_ptr(ptr) + }; + + let out = match cname.to_str() { + Ok(s) => s.to_string(), + Err(_) => { + return Err(Error::FaidxBadSeqName); + } + }; + + Ok(out) + } + + pub fn fetch_seq_len>(&self, name: N) -> Option { + let cname = ffi::CString::new(name.as_ref().as_bytes()).unwrap(); + let seq_len = unsafe { htslib::faidx_seq_len(self.inner, cname.as_ptr()) }; + if seq_len >= 0 { + Some(seq_len) + } else { + None + } + } + + /// Create a HashMap mapping each sequence name to its length. + pub fn create_chrom_lookup(&self) -> Result, String> { + let num_seqs = self.n_seqs() as usize; + let mut map = HashMap::with_capacity(num_seqs); + for i in 0..num_seqs { + let name = self.seq_name(i as i32).map_err(|e| e.to_string())?; + if let Some(len) = self.fetch_seq_len(&name) { + let len_u32 = u32::try_from(len).map_err(|_| { + format!( + "Sequence length for '{}' is negative and cannot be converted to u32", + &name + ) + })?; + map.insert(name, len_u32); + } + } + Ok(map) + } +} + +impl Drop for Reader { + fn drop(&mut self) { + unsafe { + htslib::fai_destroy(self.inner); + } + } +} diff --git a/trgt/src/genotype/consensus.rs b/trgt/src/genotype/consensus.rs index ea913cf..abf5643 100644 --- a/trgt/src/genotype/consensus.rs +++ b/trgt/src/genotype/consensus.rs @@ -124,6 +124,7 @@ fn get_ins_consensus(ins_by_read: &mut Vec, num_reads: usize) -> &str { } } +/* #[cfg(test)] mod tests { use super::*; @@ -170,3 +171,4 @@ mod tests { //assert_eq!(fixed, expected); } } +*/ diff --git a/trgt/src/genotype/flank.rs b/trgt/src/genotype/flank.rs index 744fde4..800ead4 100644 --- a/trgt/src/genotype/flank.rs +++ b/trgt/src/genotype/flank.rs @@ -1,12 +1,87 @@ use super::Gt; +use crate::cluster::consensus::repair_consensus; +use crate::cluster::math::median; use crate::{genotype::TrSize, reads::HiFiRead}; +use bio::alignment::{pairwise::*, Alignment}; use itertools::Itertools; use std::cmp::Ordering; type Profile = Vec>; -pub fn genotype(reads: &Vec, tr_seqs: &[&str]) -> Option<(Gt, Vec, Vec)> { - if reads.is_empty() { +pub fn genotype(reads: &[HiFiRead], tr_seqs: &[&str]) -> Option<(Gt, Vec, Vec)> { + let (trs_by_allele, mut allele_assignment) = + get_trs_with_hp(reads, tr_seqs).or_else(|| get_trs_with_clustering(reads, tr_seqs))?; + let mut gt = Gt::new(); + let mut alleles = Vec::new(); + + for trs in trs_by_allele { + let (backbone, frequency) = simple_consensus(&trs)?; + const MIN_FREQ_TO_ALIGN: f64 = 0.5; + let allele = if frequency < MIN_FREQ_TO_ALIGN { + let aligns = align(&backbone, &trs); + repair_consensus(&backbone, &trs, &aligns) + } else { + backbone.to_string() + }; + + let min_tr_len = trs.iter().map(|tr| tr.len()).min().unwrap(); + let max_tr_len = trs.iter().map(|tr| tr.len()).max().unwrap(); + + let size = TrSize::new(allele.len(), (min_tr_len, max_tr_len)); + gt.push(size); + alleles.push(allele); + } + + // Smaller allele should always appear first + if alleles[0].len() > alleles[1].len() { + gt.swap(0, 1); + alleles.swap(0, 1); + allele_assignment = allele_assignment.into_iter().map(|a| (a + 1) % 2).collect(); + } + + Some((gt, alleles, allele_assignment)) +} + +fn get_trs_with_hp<'a>( + reads: &[HiFiRead], + tr_seqs: &[&'a str], +) -> Option<([Vec<&'a str>; 2], Vec)> { + let mut allele_assignment = Vec::new(); + let mut trs_by_allele = [Vec::new(), Vec::new()]; + let mut assignment_tie_breaker: usize = 1; + let mut num_unassigned = 0; + for (read, tr_seq) in reads.iter().zip(tr_seqs.iter()) { + match read.hp_tag { + Some(1) => { + allele_assignment.push(0_i32); + trs_by_allele[0].push(*tr_seq); + } + Some(2) => { + allele_assignment.push(1_i32); + trs_by_allele[1].push(*tr_seq); + } + _ => { + assignment_tie_breaker = (assignment_tie_breaker + 1) % 2; + allele_assignment.push(assignment_tie_breaker as i32); + trs_by_allele[assignment_tie_breaker].push(*tr_seq); + num_unassigned += 1; + } + } + } + + let prop_assigned = (reads.len() - num_unassigned) as f64 / reads.len() as f64; + if !trs_by_allele[0].is_empty() && !trs_by_allele[1].is_empty() && prop_assigned >= 0.7 { + Some((trs_by_allele, allele_assignment)) + } else { + None + } +} + +fn get_trs_with_clustering<'a>( + reads: &[HiFiRead], + tr_seqs: &[&'a str], +) -> Option<([Vec<&'a str>; 2], Vec)> { + if tr_seqs.is_empty() { return None; } @@ -58,32 +133,7 @@ pub fn genotype(reads: &Vec, tr_seqs: &[&str]) -> Option<(Gt, Vec alleles[1].len() { - gt.swap(0, 1); - alleles.swap(0, 1); - allele_assignment = allele_assignment.into_iter().map(|a| (a + 1) % 2).collect(); - } - - Some((gt, alleles, allele_assignment)) + Some((trs_by_allele, allele_assignment)) } fn get_dist(read: &[Option], allele: &[bool]) -> usize { @@ -93,13 +143,25 @@ fn get_dist(read: &[Option], allele: &[bool]) -> usize { .sum() } -fn simple_consensus(strs: &[&str]) -> String { - strs.iter() - .counts() +/// Determine consensus for the input sequences +/// +/// Return the most frequent sequence and its relative frequency. +/// If multiple sequences have the same frequency, +/// return the one whose length is closest to the median. +/// +fn simple_consensus(seqs: &[&str]) -> Option<(String, f64)> { + let median_len = median(&seqs.iter().map(|s| s.len() as i32).collect_vec())? as usize; + let seq_to_count = seqs.iter().counts(); + let top_group_size = *seq_to_count.values().max()?; + let consensus = seq_to_count .into_iter() - .max_by_key(|&(_, count)| count) - .map(|(s, _)| s.to_string()) - .unwrap_or_default() + .filter(|(_, c)| *c == top_group_size) + .map(|(s, _)| (s, s.len().abs_diff(median_len))) + .min_by_key(|(_, delta)| *delta) + .map(|(s, _)| s.to_string())?; + + let top_group_frequency = (top_group_size as f64) / (seqs.len() as f64); + Some((consensus, top_group_frequency)) } fn get_loglik(gt: &(Vec, Vec), profiles: &Vec) -> f64 { @@ -206,7 +268,7 @@ fn get_profiles(reads: &[HiFiRead], snvs: &[i32]) -> Vec { profiles } -fn call_snvs(region: (i32, i32), reads: &Vec, min_freq: f64) -> Vec { +fn call_snvs(region: (i32, i32), reads: &[HiFiRead], min_freq: f64) -> Vec { let offset_counts = reads .iter() .filter_map(|r| r.mismatch_offsets.as_ref()) @@ -223,6 +285,13 @@ fn call_snvs(region: (i32, i32), reads: &Vec, min_freq: f64) -> Vec Vec { + let mut aligner = Aligner::new(-5, -1, |a, b| if a == b { 1i32 } else { -1i32 }); + seqs.iter() + .map(|seq| aligner.global(seq.as_bytes(), backbone.as_bytes())) + .collect_vec() +} + #[cfg(test)] mod tests { use super::*; @@ -262,6 +331,8 @@ mod tests { start_offset, end_offset, cigar: None, + hp_tag: None, + mapq: 60, } } diff --git a/trgt/src/genotype/genotype.rs b/trgt/src/genotype/genotype.rs index 8a7f6a1..3cf9a99 100644 --- a/trgt/src/genotype/genotype.rs +++ b/trgt/src/genotype/genotype.rs @@ -30,6 +30,10 @@ pub fn genotype(ploidy: Ploidy, seqs: &Vec<&str>) -> (Gt, Vec, Vec) } alleles = fixed_alleles; + if ploidy == Ploidy::Two && alleles.len() == 1 { + alleles.push(alleles[0].clone()); + } + let mut classifications = vec![0_i32; seqs.len()]; let mut tie_breaker = 1; for (seq, classification) in seqs.iter().zip(classifications.iter_mut()) { @@ -46,10 +50,6 @@ pub fn genotype(ploidy: Ploidy, seqs: &Vec<&str>) -> (Gt, Vec, Vec) }; } } - // allele_seqs is expected to contain two elements - if alleles.len() == 1 { - alleles.push(alleles[0].clone()); - } (gt, alleles, classifications) } diff --git a/trgt/src/karyotype.rs b/trgt/src/karyotype.rs new file mode 100644 index 0000000..d15d6aa --- /dev/null +++ b/trgt/src/karyotype.rs @@ -0,0 +1,96 @@ +use crate::genotype::Ploidy; +use std::collections::HashMap; +use std::fs; + +#[derive(Debug, PartialEq, Clone)] +pub struct Karyotype { + ploidy: PloidyInfo, +} + +#[derive(Debug, PartialEq, Clone)] +enum PloidyInfo { + PresetXX, + PresetXY, + Custom(HashMap), +} + +impl Karyotype { + pub fn new(encoding: &str) -> Result { + let ploidy = match encoding { + "XX" => PloidyInfo::PresetXX, + "XY" => PloidyInfo::PresetXY, + _ => return Self::from_file(encoding), + }; + Ok(Self { ploidy }) + } + + #[cfg(test)] + pub fn new_for_test(ploidies: HashMap) -> Self { + Self { + ploidy: PloidyInfo::Custom(ploidies), + } + } + + fn from_file(path: &str) -> Result { + let contents = fs::read_to_string(path).map_err(|e| format!("File {}: {}", path, e))?; + + let ploidies = contents + .lines() + .enumerate() + .map(|(line_number, line)| { + let mut parts = line.split_whitespace(); + let chrom = parts + .next() + .ok_or(format!("Missing chromosome at line {}", line_number).to_string())?; + let ploidy_str = parts + .next() + .ok_or(format!("Missing ploidy at line {}", line_number).to_string())?; + let ploidy = ploidy_str.parse().map_err(|e: String| { + format!("Invalid ploidy at line {}, {}", line_number, e) + })?; + Ok((chrom.to_string(), ploidy)) + }) + .collect::, String>>()?; + + Ok(Self { + ploidy: PloidyInfo::Custom(ploidies), + }) + } + + pub fn get_ploidy(&self, chrom: &str) -> Result { + match &self.ploidy { + PloidyInfo::PresetXX => match chrom { + "Y" | "chrY" => Ok(Ploidy::Zero), + _ => Ok(Ploidy::Two), + }, + PloidyInfo::PresetXY => match chrom { + "X" | "chrX" | "Y" | "chrY" => Ok(Ploidy::One), + _ => Ok(Ploidy::Two), + }, + PloidyInfo::Custom(ploidies) => ploidies + .get(chrom) + .copied() + .ok_or_else(|| format!("Ploidy was not specified for chromosome: {}", chrom)), + } + } +} + +#[cfg(test)] +mod tests { + use super::*; + use crate::genotype::Ploidy; + use std::collections::HashMap; + + #[test] + fn test_karyotype_custom() { + let mut ploidies = HashMap::new(); + ploidies.insert("chr1".to_string(), Ploidy::Two); + ploidies.insert("chr2".to_string(), Ploidy::One); + + let karyotype = Karyotype::new_for_test(ploidies); + + assert_eq!(karyotype.get_ploidy("chr1").unwrap(), Ploidy::Two); + assert_eq!(karyotype.get_ploidy("chr2").unwrap(), Ploidy::One); + assert!(karyotype.get_ploidy("chrX").is_err()); + } +} diff --git a/trgt/src/label/hmm.rs b/trgt/src/label/hmm.rs index e7bcbe9..0b8dfc6 100644 --- a/trgt/src/label/hmm.rs +++ b/trgt/src/label/hmm.rs @@ -1,8 +1,7 @@ +use super::spans::Span; use itertools::Itertools; use std::collections::HashMap; -use super::Span; - // List of abbreviations // lp = log probability // ems = emissions @@ -10,14 +9,16 @@ use super::Span; type MatF64 = Vec>; type MatInt = Vec>; +#[derive(Debug, PartialEq)] pub struct Hmm { - num_states: usize, - ems: MatF64, + pub num_states: usize, + pub ems: MatF64, in_states: MatInt, in_lps: MatF64, pub motifs: Vec, } +#[derive(Debug, PartialEq)] pub struct HmmMotif { pub start_state: usize, pub end_state: usize, @@ -145,6 +146,9 @@ impl Hmm { } pub fn label(&self, query: &str) -> Vec { + if query.is_empty() { + return Vec::new(); + } let query = "#" .bytes() .chain(query.bytes().chain("#".bytes())) @@ -204,7 +208,7 @@ impl Hmm { motif_spans } - fn emits_base(&self, state: usize) -> bool { + pub fn emits_base(&self, state: usize) -> bool { self.ems[state].iter().skip(1).any(|e| e.is_finite()) } diff --git a/trgt/src/label/hmm_defs.rs b/trgt/src/label/hmm_defs.rs deleted file mode 100644 index dbeff87..0000000 --- a/trgt/src/label/hmm_defs.rs +++ /dev/null @@ -1,12 +0,0 @@ -use lazy_static::lazy_static; -use std::collections::HashMap; - -lazy_static! { - pub static ref HMM_DEFS: HashMap<&'static str, &'static str> = { - let mut hmm_defs = HashMap::new(); - hmm_defs.insert("", "[[1.0,0.0,0.0,0.0,0.0],[0.0,0.05,0.45,0.45,0.05],[0.0,0.7,0.1,0.1,0.1],[0.0,0.7,0.1,0.1,0.1],[0.0,0.1,0.1,0.1,0.7],[0.0,0.7,0.1,0.1,0.1],[0.0,0.1,0.1,0.1,0.7],[0.0,0.1,0.1,0.1,0.7],[0.0,0.45,0.45,0.05,0.05],[0.0,0.25,0.25,0.25,0.25],[0.0,0.25,0.25,0.25,0.25],[0.0,0.25,0.25,0.25,0.25],[0.0,0.25,0.25,0.25,0.25],[0.0,0.25,0.25,0.25,0.25],[0.0,0.25,0.25,0.25,0.25],[0.0,0.0,0.0,0.0,0.0],[0.0,0.0,0.0,0.0,0.0],[0.0,0.0,0.0,0.0,0.0],[0.0,0.0,0.0,0.0,0.0],[0.0,0.0,0.0,0.0,0.0],[0.0,0.0,0.0,0.0,0.0],[1.0,0.0,0.0,0.0,0.0]]|[[[],[]],[[0,1,8],[1.0,0.7,0.2]],[[1],[0.2]],[[2,9,15],[0.8,0.5,0.5]],[[3,10,16],[0.8,0.5,0.5]],[[4,11,17],[0.8,0.5,0.5]],[[5,12,18],[0.8,0.5,0.5]],[[6,13,19],[0.8,0.5,0.5]],[[7,8,14,20],[0.8,0.7,0.5,0.1]],[[1,2,9],[0.05,0.1,0.5]],[[3,10],[0.1,0.5]],[[4,11],[0.1,0.5]],[[5,12],[0.1,0.5]],[[6,13],[0.1,0.5]],[[7,14],[0.2,0.5]],[[1],[0.05]],[[2,15],[0.1,0.5]],[[3,16],[0.1,0.5]],[[4,17],[0.1,0.5]],[[5,18],[0.1,0.5]],[[6,19],[0.1,0.5]],[[8],[0.1]]]|(1,8,0)"); - hmm_defs.insert("", "[[1.0,0.0,0.0,0.0,0.0],[0.0,0.0,0.0,0.0,0.0],[0.0,0.0,0.0,0.0,0.0],[0.0,0.7,0.1,0.1,0.1],[0.0,0.7,0.1,0.1,0.1],[0.0,0.7,0.1,0.1,0.1],[0.0,0.7,0.1,0.1,0.1],[0.0,0.1,0.1,0.1,0.7],[0.0,0.25,0.25,0.25,0.25],[0.0,0.25,0.25,0.25,0.25],[0.0,0.25,0.25,0.25,0.25],[0.0,0.25,0.25,0.25,0.25],[0.0,0.25,0.25,0.25,0.25],[0.0,0.0,0.0,0.0,0.0],[0.0,0.0,0.0,0.0,0.0],[0.0,0.0,0.0,0.0,0.0],[0.0,0.0,0.0,0.0,0.0],[0.0,0.0,0.0,0.0,0.0],[0.0,0.0,0.0,0.0,0.0],[0.0,0.7,0.1,0.1,0.1],[0.0,0.7,0.1,0.1,0.1],[0.0,0.7,0.1,0.1,0.1],[0.0,0.1,0.1,0.1,0.7],[0.0,0.1,0.1,0.1,0.7],[0.0,0.25,0.25,0.25,0.25],[0.0,0.25,0.25,0.25,0.25],[0.0,0.25,0.25,0.25,0.25],[0.0,0.25,0.25,0.25,0.25],[0.0,0.25,0.25,0.25,0.25],[0.0,0.0,0.0,0.0,0.0],[0.0,0.0,0.0,0.0,0.0],[0.0,0.0,0.0,0.0,0.0],[0.0,0.0,0.0,0.0,0.0],[0.0,0.0,0.0,0.0,0.0],[0.0,0.0,0.0,0.0,0.0],[0.0,0.7,0.1,0.1,0.1],[0.0,0.7,0.1,0.1,0.1],[0.0,0.1,0.1,0.1,0.7],[0.0,0.1,0.1,0.1,0.7],[0.0,0.1,0.1,0.1,0.7],[0.0,0.25,0.25,0.25,0.25],[0.0,0.25,0.25,0.25,0.25],[0.0,0.25,0.25,0.25,0.25],[0.0,0.25,0.25,0.25,0.25],[0.0,0.25,0.25,0.25,0.25],[0.0,0.0,0.0,0.0,0.0],[0.0,0.0,0.0,0.0,0.0],[0.0,0.0,0.0,0.0,0.0],[0.0,0.0,0.0,0.0,0.0],[0.0,0.0,0.0,0.0,0.0],[0.0,0.0,0.0,0.0,0.0],[0.0,0.7,0.1,0.1,0.1],[0.0,0.7,0.1,0.1,0.1],[0.0,0.1,0.1,0.1,0.7],[0.0,0.7,0.1,0.1,0.1],[0.0,0.1,0.1,0.1,0.7],[0.0,0.25,0.25,0.25,0.25],[0.0,0.25,0.25,0.25,0.25],[0.0,0.25,0.25,0.25,0.25],[0.0,0.25,0.25,0.25,0.25],[0.0,0.25,0.25,0.25,0.25],[0.0,0.0,0.0,0.0,0.0],[0.0,0.0,0.0,0.0,0.0],[0.0,0.0,0.0,0.0,0.0],[0.0,0.0,0.0,0.0,0.0],[0.0,0.0,0.0,0.0,0.0],[0.0,0.0,0.0,0.0,0.0],[0.0,0.7,0.1,0.1,0.1],[0.0,0.1,0.1,0.1,0.7],[0.0,0.7,0.1,0.1,0.1],[0.0,0.1,0.1,0.1,0.7],[0.0,0.1,0.1,0.1,0.7],[0.0,0.25,0.25,0.25,0.25],[0.0,0.25,0.25,0.25,0.25],[0.0,0.25,0.25,0.25,0.25],[0.0,0.25,0.25,0.25,0.25],[0.0,0.25,0.25,0.25,0.25],[0.0,0.0,0.0,0.0,0.0],[0.0,0.0,0.0,0.0,0.0],[0.0,0.0,0.0,0.0,0.0],[0.0,0.0,0.0,0.0,0.0],[0.0,0.0,0.0,0.0,0.0],[0.0,0.0,0.0,0.0,0.0],[0.0,0.7,0.1,0.1,0.1],[0.0,0.7,0.1,0.1,0.1],[0.0,0.1,0.1,0.7,0.1],[0.0,0.1,0.1,0.1,0.7],[0.0,0.1,0.1,0.1,0.7],[0.0,0.25,0.25,0.25,0.25],[0.0,0.25,0.25,0.25,0.25],[0.0,0.25,0.25,0.25,0.25],[0.0,0.25,0.25,0.25,0.25],[0.0,0.25,0.25,0.25,0.25],[0.0,0.0,0.0,0.0,0.0],[0.0,0.0,0.0,0.0,0.0],[0.0,0.0,0.0,0.0,0.0],[0.0,0.0,0.0,0.0,0.0],[0.0,0.0,0.0,0.0,0.0],[0.0,0.0,0.0,0.0,0.0],[0.0,0.1,0.1,0.1,0.7],[0.0,0.1,0.1,0.1,0.7],[0.0,0.1,0.1,0.1,0.7],[0.0,0.7,0.1,0.1,0.1],[0.0,0.1,0.1,0.7,0.1],[0.0,0.25,0.25,0.25,0.25],[0.0,0.25,0.25,0.25,0.25],[0.0,0.25,0.25,0.25,0.25],[0.0,0.25,0.25,0.25,0.25],[0.0,0.25,0.25,0.25,0.25],[0.0,0.0,0.0,0.0,0.0],[0.0,0.0,0.0,0.0,0.0],[0.0,0.0,0.0,0.0,0.0],[0.0,0.0,0.0,0.0,0.0],[0.0,0.0,0.0,0.0,0.0],[0.0,0.0,0.0,0.0,0.0],[0.0,0.7,0.1,0.1,0.1],[0.0,0.7,0.1,0.1,0.1],[0.0,0.7,0.1,0.1,0.1],[0.0,0.1,0.1,0.1,0.7],[0.0,0.1,0.1,0.1,0.7],[0.0,0.1,0.1,0.1,0.7],[0.0,0.25,0.25,0.25,0.25],[0.0,0.25,0.25,0.25,0.25],[0.0,0.25,0.25,0.25,0.25],[0.0,0.25,0.25,0.25,0.25],[0.0,0.25,0.25,0.25,0.25],[0.0,0.25,0.25,0.25,0.25],[0.0,0.0,0.0,0.0,0.0],[0.0,0.0,0.0,0.0,0.0],[0.0,0.0,0.0,0.0,0.0],[0.0,0.0,0.0,0.0,0.0],[0.0,0.0,0.0,0.0,0.0],[0.0,0.0,0.0,0.0,0.0],[0.0,0.0,0.0,0.0,0.0],[1.0,0.0,0.0,0.0,0.0]]|[[[],[]],[[0,133],[1.0,0.9]],[[1,17],[0.125,0.5]],[[2],[0.9]],[[2,3,8],[0.03387533875338753,0.8,0.2]],[[2,4,9,13],[0.02710027100271003,0.8,0.2,0.5]],[[2,5,10,14],[0.02168021680216803,0.8,0.2,0.5]],[[2,6,11,15],[0.01734417344173442,0.8,0.2,0.5]],[[8,3],[0.8,0.1]],[[9,4],[0.8,0.1]],[[10,5],[0.8,0.1]],[[11,6],[0.8,0.1]],[[12,7],[0.8,0.1]],[[3],[0.1]],[[13,4],[0.5,0.1]],[[14,5],[0.5,0.1]],[[15,6],[0.5,0.1]],[[7,12,16],[0.9,0.2,1.0]],[[1,33],[0.125,0.5]],[[18],[0.9]],[[18,19,24],[0.03387533875338753,0.8,0.2]],[[18,20,25,29],[0.02710027100271003,0.8,0.2,0.5]],[[18,21,26,30],[0.02168021680216803,0.8,0.2,0.5]],[[18,22,27,31],[0.01734417344173442,0.8,0.2,0.5]],[[24,19],[0.8,0.1]],[[25,20],[0.8,0.1]],[[26,21],[0.8,0.1]],[[27,22],[0.8,0.1]],[[28,23],[0.8,0.1]],[[19],[0.1]],[[29,20],[0.5,0.1]],[[30,21],[0.5,0.1]],[[31,22],[0.5,0.1]],[[23,28,32],[0.9,0.2,1.0]],[[1,49],[0.125,0.5]],[[34],[0.9]],[[34,35,40],[0.03387533875338753,0.8,0.2]],[[34,36,41,45],[0.02710027100271003,0.8,0.2,0.5]],[[34,37,42,46],[0.02168021680216803,0.8,0.2,0.5]],[[34,38,43,47],[0.01734417344173442,0.8,0.2,0.5]],[[40,35],[0.8,0.1]],[[41,36],[0.8,0.1]],[[42,37],[0.8,0.1]],[[43,38],[0.8,0.1]],[[44,39],[0.8,0.1]],[[35],[0.1]],[[45,36],[0.5,0.1]],[[46,37],[0.5,0.1]],[[47,38],[0.5,0.1]],[[39,44,48],[0.9,0.2,1.0]],[[1,65],[0.125,0.5]],[[50],[0.9]],[[50,51,56],[0.03387533875338753,0.8,0.2]],[[50,52,57,61],[0.02710027100271003,0.8,0.2,0.5]],[[50,53,58,62],[0.02168021680216803,0.8,0.2,0.5]],[[50,54,59,63],[0.01734417344173442,0.8,0.2,0.5]],[[56,51],[0.8,0.1]],[[57,52],[0.8,0.1]],[[58,53],[0.8,0.1]],[[59,54],[0.8,0.1]],[[60,55],[0.8,0.1]],[[51],[0.1]],[[61,52],[0.5,0.1]],[[62,53],[0.5,0.1]],[[63,54],[0.5,0.1]],[[55,60,64],[0.9,0.2,1.0]],[[1,81],[0.125,0.5]],[[66],[0.9]],[[66,67,72],[0.03387533875338753,0.8,0.2]],[[66,68,73,77],[0.02710027100271003,0.8,0.2,0.5]],[[66,69,74,78],[0.02168021680216803,0.8,0.2,0.5]],[[66,70,75,79],[0.01734417344173442,0.8,0.2,0.5]],[[72,67],[0.8,0.1]],[[73,68],[0.8,0.1]],[[74,69],[0.8,0.1]],[[75,70],[0.8,0.1]],[[76,71],[0.8,0.1]],[[67],[0.1]],[[77,68],[0.5,0.1]],[[78,69],[0.5,0.1]],[[79,70],[0.5,0.1]],[[71,76,80],[0.9,0.2,1.0]],[[1,97],[0.125,0.5]],[[82],[0.9]],[[82,83,88],[0.03387533875338753,0.8,0.2]],[[82,84,89,93],[0.02710027100271003,0.8,0.2,0.5]],[[82,85,90,94],[0.02168021680216803,0.8,0.2,0.5]],[[82,86,91,95],[0.01734417344173442,0.8,0.2,0.5]],[[88,83],[0.8,0.1]],[[89,84],[0.8,0.1]],[[90,85],[0.8,0.1]],[[91,86],[0.8,0.1]],[[92,87],[0.8,0.1]],[[83],[0.1]],[[93,84],[0.5,0.1]],[[94,85],[0.5,0.1]],[[95,86],[0.5,0.1]],[[87,92,96],[0.9,0.2,1.0]],[[1,113],[0.125,0.5]],[[98],[0.9]],[[98,99,104],[0.03387533875338753,0.8,0.2]],[[98,100,105,109],[0.02710027100271003,0.8,0.2,0.5]],[[98,101,106,110],[0.02168021680216803,0.8,0.2,0.5]],[[98,102,107,111],[0.01734417344173442,0.8,0.2,0.5]],[[104,99],[0.8,0.1]],[[105,100],[0.8,0.1]],[[106,101],[0.8,0.1]],[[107,102],[0.8,0.1]],[[108,103],[0.8,0.1]],[[99],[0.1]],[[109,100],[0.5,0.1]],[[110,101],[0.5,0.1]],[[111,102],[0.5,0.1]],[[103,108,112],[0.9,0.2,1.0]],[[1,132],[0.125,0.5]],[[114],[0.9]],[[114,115,121],[0.029747739171822943,0.8,0.2]],[[114,116,122,127],[0.023798191337458356,0.8,0.2,0.5]],[[114,117,123,128],[0.01903855306996669,0.8,0.2,0.5]],[[114,118,124,129],[0.015230842455973349,0.8,0.2,0.5]],[[114,119,125,130],[0.01218467396477868,0.8,0.2,0.5]],[[121,115],[0.8,0.1]],[[122,116],[0.8,0.1]],[[123,117],[0.8,0.1]],[[124,118],[0.8,0.1]],[[125,119],[0.8,0.1]],[[126,120],[0.8,0.1]],[[115],[0.1]],[[127,116],[0.5,0.1]],[[128,117],[0.5,0.1]],[[129,118],[0.5,0.1]],[[130,119],[0.5,0.1]],[[120,126,131],[0.9,0.2,1.0]],[[17,33,49,65,81,97,113,132],[0.5,0.5,0.5,0.5,0.5,0.5,0.5,0.5]],[[133],[1.0]]]|(2,17,0),(18,33,1),(34,49,2),(50,65,3),(66,81,4),(82,97,5),(98,113,6),(114,132,7)"); - - hmm_defs - }; -} diff --git a/trgt/src/label/label_with_hmm.rs b/trgt/src/label/label_with_hmm.rs index 7bea064..c3a5283 100644 --- a/trgt/src/label/label_with_hmm.rs +++ b/trgt/src/label/label_with_hmm.rs @@ -1,20 +1,46 @@ -use super::hmm::{self, HmmMotif}; -use super::hmm_defs::HMM_DEFS; -use super::{Annotation, Span}; +use super::hmm::{Hmm, HmmMotif}; +use super::spans::Span; +use super::Annotation; +use crate::cli::handle_error_and_exit; use crate::locus::Locus; use itertools::Itertools; pub fn label_with_hmm(locus: &Locus, seqs: &Vec) -> Vec { - let encoding = HMM_DEFS.get(&locus.struc[..]).unwrap(); - let hmm = decode_hmm(encoding); - let mut annotations = Vec::new(); + let motifs = locus + .motifs + .iter() + .map(|m| m.as_bytes().to_vec()) + .collect_vec(); + let hmm = build_hmm(&motifs); + let mut annotations = Vec::new(); for seq in seqs { - let labels = hmm.label(seq); + let seq: String = seq + .as_bytes() + .iter() + .enumerate() + .map(|(i, b)| match b { + b'A' | b'T' | b'C' | b'G' => *b as char, + _ => ['A', 'T', 'C', 'G'][i % 4], + }) + .collect(); + let labels = hmm.label(&seq); + let purity = calc_purity(&seq.as_bytes(), &hmm, &motifs, &labels); let labels = hmm.label_motifs(&labels); + // Remove labels corresponding to the skip state + let labels = labels + .into_iter() + .filter(|rec| rec.motif_index < motifs.len()) + .collect_vec(); let motif_counts = count_motifs(&locus.motifs, &labels); - let labels = Some(collapse_labels(labels)); - let purity = 1.0; + let labels = collapse_labels(labels); + // TODO: Consider using empty labels instead of None + let labels = if !labels.is_empty() { + Some(labels) + } else { + None + }; + annotations.push(Annotation { labels, motif_counts, @@ -25,118 +51,431 @@ pub fn label_with_hmm(locus: &Locus, seqs: &Vec) -> Vec { annotations } -fn decode_hmm(encoding: &str) -> hmm::Hmm { - let mats = encoding.split('|').collect_vec(); - assert!(mats.len() == 3); - let ems = decode_emissions(mats[0]); - let transitions = decode_transitions(mats[1]); +fn collapse_labels(spans: Vec) -> Vec { + let mut collapsed = Vec::new(); + for span in spans { + if collapsed.is_empty() { + collapsed.push(span); + continue; + } + + let last_span = collapsed.last_mut().unwrap(); + if last_span.motif_index == span.motif_index && last_span.end == span.start { + last_span.end = span.end; + } else { + collapsed.push(span); + } + } + collapsed +} + +fn count_motifs(motifs: &Vec, labels: &Vec) -> Vec { + let mut motif_counts = vec![0; motifs.len()]; + for span in labels { + motif_counts[span.motif_index] += 1; + } + motif_counts +} + +pub fn build_hmm(motifs: &[Vec]) -> Hmm { + // 2 terminal states + 2 run start states + 3 states of the skip block + (4n - 1) states for each motif of length n + let num_states = 7 + motifs.iter().map(|m| 3 * m.len() + 1).sum::(); + let mut hmm = Hmm::new(num_states); + + let start = 0; + let end = num_states - 1; + let rs = start + 1; + let re = end - 1; + + // # A T C G + hmm.set_ems(start, vec![1.00, 0.00, 0.00, 0.00, 0.00]); + hmm.set_ems(end, vec![1.00, 0.00, 0.00, 0.00, 0.00]); + hmm.set_trans(end, vec![re], vec![0.10]); - assert!(ems.len() == transitions.len()); - let num_states = ems.len(); - let mut hmm = hmm::Hmm::new(num_states); + hmm.set_ems(rs, vec![0.00, 0.00, 0.00, 0.00, 0.00]); + hmm.set_trans(rs, vec![start, re], vec![1.00, 0.90]); - for (state, state_ems) in ems.into_iter().enumerate() { - hmm.set_ems(state, state_ems); + let rs_to_ms = 0.50; // / (motifs.len() as f64 + 1.0); <- No longer an HMM because of this change + let me_to_re = 0.05; + let mut mes = Vec::new(); + mes.reserve(motifs.len() + 1); + let mut ms = rs + 1; + for motif in motifs { + let num_motif_states = 3 * motif.len() + 1; + let me = ms + num_motif_states - 1; + + hmm.set_ems(ms, vec![0.00, 0.00, 0.00, 0.00, 0.00]); + hmm.set_trans(ms, vec![rs, me], vec![rs_to_ms, 1.0 - me_to_re]); + + define_motif_block(&mut hmm, ms, &motif); + + mes.push(me); + ms += num_motif_states; } - for (state, (in_states, probs)) in transitions.into_iter().enumerate() { - hmm.set_trans(state, in_states, probs); + assert_eq!(ms + 3, re); + + // Defined the skip block + let (skip_state, me) = (ms + 1, ms + 2); + hmm.set_ems(ms, vec![0.00, 0.00, 0.00, 0.00, 0.00]); + hmm.set_trans(ms, vec![rs, me], vec![rs_to_ms, 1.0 - me_to_re]); + + let skip_to_skip = 0.9; + hmm.set_ems(skip_state, vec![0.00, 0.25, 0.25, 0.25, 0.25]); + hmm.set_trans(skip_state, vec![ms, skip_state], vec![1.0, skip_to_skip]); + + hmm.set_ems(me, vec![0.00, 0.00, 0.00, 0.00, 0.00]); + hmm.set_trans(me, vec![skip_state], vec![1.0 - skip_to_skip]); + + mes.push(me); + + // Define the re state + hmm.set_ems(re, vec![0.00, 0.00, 0.00, 0.00, 0.00]); + hmm.set_trans(re, mes.clone(), vec![me_to_re; motifs.len() + 1]); + + // Define motif spans + for (motif_index, motif) in motifs.iter().enumerate() { + let me = mes[motif_index]; + let ms = me - 3 * motif.len(); + hmm.motifs.push(HmmMotif { + start_state: ms, + end_state: me, + motif_index, + }); } - hmm.motifs = decode_motifs(mats[2]); + // Add skip state span + hmm.motifs.push(HmmMotif { + start_state: skip_state - 1, + end_state: skip_state + 1, + motif_index: motifs.len(), + }); hmm } -fn decode_emissions(encoding: &str) -> Vec> { - let mut mat = Vec::new(); - for row in encoding.split("],[") { - let row = row.trim_matches(|c| "\"[]".contains(c)); - let row = row - .split(',') - .map(|e| e.parse::().unwrap()) - .collect_vec(); - mat.push(row); - } +fn define_motif_block(hmm: &mut Hmm, ms: usize, motif: &Vec) { + let match_states = (ms + 1..ms + 1 + motif.len()).collect_vec(); + let first_ins_state = *match_states.last().unwrap() + 1; + let ins_states = (first_ins_state..first_ins_state + motif.len()).collect_vec(); + let first_del_state = *ins_states.last().unwrap() + 1; // If any + let del_states = (first_del_state..first_del_state + motif.len() - 1).collect_vec(); - mat -} + let match_prob = 0.90; + let ins_to_ins = 0.25; + let match_to_indel = (1.00 - match_prob) / 2.00; + let del_to_match = 0.50; -fn decode_transitions(encoding: &str) -> Vec<(Vec, Vec)> { - let mut transitions = Vec::new(); - let encoding = encoding - .strip_prefix("[[[") - .unwrap() - .strip_suffix("]]]") - .unwrap(); - - for row in encoding.split("]],[[") { - if row.chars().all(|c| "[],".contains(c)) { - transitions.push((Vec::new(), Vec::new())); - continue; + // Define match states + let mismatch_seed_prob = 2.00 * (1.00 - match_prob) / (motif.len() * (motif.len() - 1)) as f64; + for (match_index, match_state) in match_states.iter().enumerate() { + hmm.set_ems(*match_state, get_match_emissions(motif[match_index])); + if match_index == 0 { + hmm.set_trans(*match_state, vec![ms], vec![match_prob]); + } else if match_index == 1 { + let multiplier = motif.len() - match_index; + let mismatch_prob = mismatch_seed_prob * multiplier as f64; + let prev_ins = ins_states[match_index - 1]; + + hmm.set_trans( + *match_state, + vec![match_state - 1, ms, prev_ins], + vec![match_prob, mismatch_prob, 1.0 - ins_to_ins], + ); + } else { + let multiplier = motif.len() - match_index; + let mismatch_prob = mismatch_seed_prob * multiplier as f64; + let prev_ins = ins_states[match_index - 1]; + let prev_del = del_states[match_index - 2]; + + hmm.set_trans( + *match_state, + vec![match_state - 1, ms, prev_ins, prev_del], + vec![match_prob, mismatch_prob, 1.0 - ins_to_ins, del_to_match], + ); } + } - let states_and_probs = row.split("],[").collect_vec(); - assert!(states_and_probs.len() == 2); - let (states, probs) = (states_and_probs[0], states_and_probs[1]); - let states = states - .split(',') - .map(|e| e.parse::().unwrap()) - .collect_vec(); - let probs = probs - .split(',') - .map(|e| e.parse::().unwrap()) - .collect_vec(); + // Define insersion states + for (ins_index, ins_state) in ins_states.iter().enumerate() { + hmm.set_ems(*ins_state, vec![0.00, 0.25, 0.25, 0.25, 0.25]); + let match_state = match_states[ins_index]; + hmm.set_trans( + *ins_state, + vec![*ins_state, match_state], + vec![ins_to_ins, match_to_indel], + ); + } - assert!(states.len() == probs.len()); - transitions.push((states, probs)); + // Define deletion states + for (del_index, del_state) in del_states.iter().enumerate() { + hmm.set_ems(*del_state, vec![0.00, 0.00, 0.00, 0.00, 0.00]); + let prev_match = match_states[del_index]; + if del_index == 0 { + hmm.set_trans(*del_state, vec![prev_match], vec![match_to_indel]); + } else { + let prev_del = del_states[del_index - 1]; + hmm.set_trans( + *del_state, + vec![prev_match, prev_del], + vec![match_to_indel, 1.0 - del_to_match], + ); + } } - transitions + let num_motif_states = 3 * motif.len() + 1; + let me = ms + num_motif_states - 1; + hmm.set_ems(me, vec![0.00, 0.00, 0.00, 0.00, 0.00]); + if !del_states.is_empty() { + let last_match = *match_states.last().unwrap(); + let last_ins = *ins_states.last().unwrap(); + let last_del = *del_states.last().unwrap(); + hmm.set_trans( + me, + vec![last_match, last_ins, last_del], + vec![match_prob, 1.0 - ins_to_ins, 1.0], + ); + } else if !ins_states.is_empty() { + let last_match = *match_states.last().unwrap(); + let last_ins = *ins_states.last().unwrap(); + hmm.set_trans( + me, + vec![last_match, last_ins], + vec![match_prob, 1.0 - ins_to_ins], + ); + } else { + let last_match = *match_states.last().unwrap(); + hmm.set_trans(me, vec![last_match], vec![match_prob]); + } } -fn decode_motifs(encoding: &str) -> Vec { - let mut motifs = Vec::new(); - for motif_encoding in encoding.split("),(") { - let (start_state, end_state, motif_index) = motif_encoding - .trim_matches('(') - .trim_end_matches(')') - .split(',') - .map(|m| m.parse::().unwrap()) - .collect_tuple() - .unwrap(); - motifs.push(HmmMotif { - start_state, - end_state, - motif_index, - }); +fn get_match_emissions(char: u8) -> Vec { + match char { + b'A' => vec![0.00, 0.90, 0.03, 0.03, 0.03], + b'T' => vec![0.00, 0.03, 0.90, 0.03, 0.03], + b'C' => vec![0.00, 0.03, 0.03, 0.90, 0.03], + b'G' => vec![0.00, 0.03, 0.03, 0.03, 0.90], + _ => panic!("Enountered unknown base {char}"), } +} - motifs +#[derive(Debug, PartialEq)] +pub enum HmmEvent { + Match, + Mismatch, + Ins, + Del, + Trans, // Silent states that don't encode alignment operation + Skip, // Skip state that matches bases outside of any motif run } -fn collapse_labels(spans: Vec) -> Vec { - let mut collapsed = Vec::new(); - for span in spans { - if collapsed.is_empty() { - collapsed.push(span); - continue; +pub fn get_events(hmm: &Hmm, motifs: &[Vec], states: &[usize], query: &[u8]) -> Vec { + let mut state_to_hmm_motif = vec![-1; hmm.num_states]; + for hmm_motif in &hmm.motifs { + // Skip two terminal states MS and ME + for state in hmm_motif.start_state + 1..hmm_motif.end_state { + state_to_hmm_motif[state] = hmm_motif.motif_index as i32; } + } - let last_span = collapsed.last_mut().unwrap(); - if last_span.motif_index == span.motif_index && last_span.end == span.start { - last_span.end = span.end; + let mut base_index = 0; + let mut events = Vec::new(); + let base_consumers = [ + HmmEvent::Match, + HmmEvent::Mismatch, + HmmEvent::Ins, + HmmEvent::Skip, + ]; + for state in states { + let motif_index = state_to_hmm_motif[*state]; + let event = if motif_index == -1 { + HmmEvent::Trans + } else if motif_index as usize + 1 == hmm.motifs.len() { + HmmEvent::Skip } else { - collapsed.push(span); + let motif = &hmm.motifs[motif_index as usize]; + let offset = state - motif.start_state - 1; + let motif_len = motifs[motif.motif_index].len(); + match offset.div_euclid(motif_len) { + 0 => { + let base = query[base_index]; + if base == get_base_match(&hmm, *state) { + HmmEvent::Match + } else { + HmmEvent::Mismatch + } + } + 1 => HmmEvent::Ins, + 2 => HmmEvent::Del, + _ => handle_error_and_exit(format!("Event decoding error")), + } + }; + + if base_consumers.contains(&event) { + base_index += 1; } + events.push(event); } - collapsed + + events } -fn count_motifs(motifs: &Vec, labels: &Vec) -> Vec { - let mut motif_counts = vec![0; motifs.len()]; - for span in labels { - motif_counts[span.motif_index] += 1; +fn get_base_match(hmm: &Hmm, state: usize) -> u8 { + let ems = &hmm.ems[state]; + assert_eq!(ems.len(), 5); + + if !hmm.emits_base(state) { + return b' '; + } + + let max_lp = *ems.iter().max_by(|a, b| a.partial_cmp(b).unwrap()).unwrap(); + let top_indexes = ems + .iter() + .enumerate() + .filter(|(_i, p)| **p == max_lp) + .map(|(i, _p)| i) + .collect_vec(); + if top_indexes.len() == 1 { + match top_indexes[0] { + 0 => b'#', + 1 => b'A', + 2 => b'T', + 3 => b'C', + 4 => b'G', + _ => handle_error_and_exit(format!("Unexpected base match event")), + } + } else { + b' ' + } +} + +fn calc_purity(query: &[u8], hmm: &Hmm, motifs: &[Vec], states: &[usize]) -> f64 { + let events = get_events(hmm, motifs, states, query); + let num_matches = events + .iter() + .map(|e| match *e { + HmmEvent::Del | HmmEvent::Ins | HmmEvent::Mismatch | HmmEvent::Skip => -1, + HmmEvent::Match => 1, + HmmEvent::Trans => 0, + }) + .sum::(); + num_matches as f64 / query.len() as f64 +} + +#[cfg(test)] +mod tests { + use super::*; + + fn summarize(spans: &Vec) -> Vec<(usize, usize, usize)> { + let mut summary = Vec::new(); + for (motif_index, group) in &spans + .iter() + .map(|s| (s.start, s.end, s.motif_index)) + .group_by(|(_s, _e, m)| *m) + { + let group = group.collect_vec(); + summary.push(( + group.first().unwrap().0, + group.last().unwrap().1, + motif_index, + )); + } + summary + } + + #[test] + fn annotate_two_perfect_motif_runs() { + let motifs = vec!["CAG".as_bytes().to_vec(), "A".as_bytes().to_vec()]; + let hmm = build_hmm(&motifs); + let labels = hmm.label_motifs(&hmm.label("CAGCAGCAGCAGAAAAA")); + let expected = vec![(0, 12, 0), (12, 17, 1)]; + + assert_eq!(summarize(&labels), expected); + } + + #[test] + fn annotate_motif_runs_separated_by_insertion() { + let motifs = vec!["CAG".as_bytes().to_vec(), "A".as_bytes().to_vec()]; + let hmm = build_hmm(&motifs); + let labels = hmm.label_motifs(&hmm.label("CAGCAGATCGATCGATCGATCGAAAAA")); + let expected = vec![(0, 6, 0), (6, 22, 2), (22, 27, 1)]; + + assert_eq!(summarize(&labels), expected); + } + + #[test] + fn annotate_imperfect_repeat_run() { + let motifs = vec!["CAG".as_bytes().to_vec(), "A".as_bytes().to_vec()]; + let hmm = build_hmm(&motifs); + let labels = hmm.label_motifs(&hmm.label("CAGCAGCTGCAGCAGAAACAG")); + let expected = vec![(0, 21, 0)]; + + assert_eq!(summarize(&labels), expected); + } + + #[test] + fn parse_aga_repeat() { + // TODO: Consider improving this segmentation + let motifs = vec!["AAG".as_bytes().to_vec(), "CAAC".as_bytes().to_vec()]; + let hmm = build_hmm(&motifs); + let query = "TCTATGCAACCAACTTTCTGTTAGTCATAGTACCCCAAGAAGAAGAAGAAGAAGAAGAAGAAGAAGAAGAAGAAGAAGAAGAAGAAGAAGAAGAATAGAAATGTGTTTAAGAATTCCTCAATAAG"; + let labels = hmm.label_motifs(&hmm.label(query)); + let expected = vec![ + (0, 6, 2), + (6, 14, 1), + (14, 36, 2), + (36, 101, 0), + (101, 119, 2), + (119, 125, 0), + ]; + + assert_eq!(summarize(&labels), expected); + } + + #[test] + fn states_match_the_most_likely_base() { + let motifs = vec!["A".as_bytes().to_vec()]; + let hmm = build_hmm(&motifs); + assert_eq!(get_base_match(&hmm, 3), b'A'); + } + + #[test] + fn silent_states_match_a_blank_character() { + let mut hmm = Hmm::new(3); + hmm.set_ems(0, vec![1.00, 0.00, 0.00, 0.00, 0.00]); // Start state + hmm.set_ems(1, vec![0.00, 0.50, 0.50, 0.00, 0.00]); // State with identical scores for 'A' and 'T' + hmm.set_ems(2, vec![1.00, 0.00, 0.00, 0.00, 0.00]); // End state + assert_eq!(get_base_match(&hmm, 1), b' ',); + } + + #[test] + fn calculate_purity_of_perfect_repeats() { + let motifs = vec!["CAG".as_bytes().to_vec(), "CCG".as_bytes().to_vec()]; + let hmm = build_hmm(&motifs); + let query = "CAGCAGCAGCCGCCGCCGCCG"; + let states = hmm.label(&query); + assert_eq!(calc_purity(&query.as_bytes(), &hmm, &motifs, &states), 1.0); + } + + #[test] + fn calculate_purity_of_imperfect_repeats() { + let motifs = vec!["CAG".as_bytes().to_vec(), "CCG".as_bytes().to_vec()]; + let hmm = build_hmm(&motifs); + let query = "CAGCGCAGCCGCCGCCGGG"; + let states = hmm.label(&query); + + let purity = calc_purity(&query.as_bytes(), &hmm, &motifs, &states); + assert_eq!(purity, 16.0 / 19.0); + } + + #[test] + fn calculate_purity_of_repeats_with_skip_states() { + let motifs = vec!["CAG".as_bytes().to_vec(), "CCG".as_bytes().to_vec()]; + let hmm = build_hmm(&motifs); + let query = "CAGCAGCAGTTTTTTTTCCGCCGCCG"; + let states = hmm.label(&query); + + let purity = calc_purity(&query.as_bytes(), &hmm, &motifs, &states); + assert_eq!(purity, 10.0 / 26.0); } - motif_counts } diff --git a/trgt/src/label/mod.rs b/trgt/src/label/mod.rs index bcd9dcd..dee1056 100644 --- a/trgt/src/label/mod.rs +++ b/trgt/src/label/mod.rs @@ -1,8 +1,5 @@ -mod struc; - mod guess_motif_counts; mod hmm; -mod hmm_defs; mod kmer_filter; mod label_alleles; mod label_motif; @@ -10,6 +7,7 @@ mod label_with_hmm; mod label_with_regexp; mod refine_motif_counts; mod spans; +mod struc; pub use label_alleles::label_alleles; pub use label_with_hmm::label_with_hmm; diff --git a/trgt/src/label/refine_motif_counts.rs b/trgt/src/label/refine_motif_counts.rs index ef64f64..bb641a1 100644 --- a/trgt/src/label/refine_motif_counts.rs +++ b/trgt/src/label/refine_motif_counts.rs @@ -2,11 +2,11 @@ use super::struc::*; use bio::alignment::distance::simd::*; pub fn refine_motif_counts( - motifs: &Vec, + motifs: &[Motif], query: &str, motif_counts: &[usize], ) -> (Vec, f64) { - let mut top_counts = fix(motifs, motif_counts); + let mut top_counts = apply_multiplicity_counts(motifs, motif_counts); let mut min_dist = get_dist(motifs, query, &top_counts); loop { @@ -26,7 +26,7 @@ pub fn refine_motif_counts( (top_counts, purity) } -fn get_candidates(motifs: &Vec, seed_counts: &[usize]) -> Vec> { +fn get_candidates(motifs: &[Motif], seed_counts: &[usize]) -> Vec> { let mut candidates = Vec::new(); for motif_index in 0..motifs.len() { if motifs[motif_index].mult == Mult::Once { @@ -40,15 +40,10 @@ fn get_candidates(motifs: &Vec, seed_counts: &[usize]) -> Vec> candidates.last_mut().unwrap()[motif_index] -= 1; } } - candidates } -fn get_top_candidate( - motifs: &Vec, - query: &str, - mut candidates: Vec>, -) -> Vec { +fn get_top_candidate(motifs: &[Motif], query: &str, mut candidates: Vec>) -> Vec { let mut top_candidate = candidates.pop().unwrap(); let mut top_dist = get_dist(motifs, query, &top_candidate); for candidate in candidates { @@ -61,67 +56,203 @@ fn get_top_candidate( top_candidate } -fn fix(motifs: &Vec, motif_counts: &[usize]) -> Vec { - let mut fixed_counts = Vec::new(); - - for index in 0..motifs.len() { - if motifs[index].mult == Mult::Many { - fixed_counts.push(motif_counts[index]); - } else { - fixed_counts.push(1); - } - } - - fixed_counts +fn apply_multiplicity_counts(motifs: &[Motif], motif_counts: &[usize]) -> Vec { + motifs + .iter() + .zip(motif_counts.iter()) + .map( + |(motif, &count)| { + if motif.mult == Mult::Many { + count + } else { + 1 + } + }, + ) + .collect() } -fn get_dist(template: &Vec, query: &str, motif_counts: &[usize]) -> u32 { +fn get_dist(template: &[Motif], query: &str, motif_counts: &[usize]) -> u32 { let mut reference = String::new(); - for index in 0..template.len() { - let seq = &template[index]; - reference += &seq.seq.repeat(motif_counts[index]); + for (motif, &count) in template.iter().zip(motif_counts) { + reference.push_str(&motif.seq.repeat(count)); } - - let score = levenshtein(reference.as_bytes(), query.as_bytes()); - score + levenshtein(reference.as_bytes(), query.as_bytes()) } -/* #[cfg(test)] mod tests { use super::*; - use crate::label::struc::decode_regexp; + + fn setup_motifs() -> Vec { + vec![ + Motif { + seq: "A".to_string(), + mult: Mult::Once, + }, + Motif { + seq: "C".to_string(), + mult: Mult::Many, + }, + Motif { + seq: "G".to_string(), + mult: Mult::Once, + }, + ] + } + + #[test] + fn test_get_dist_exact_match() { + let template = setup_motifs(); + let query = "ACG"; + let motif_counts = vec![1, 1, 1]; + assert_eq!(get_dist(&template, query, &motif_counts), 0); + } + + #[test] + fn test_get_dist_1_ins() { + let template = setup_motifs(); + let query = "AACG"; + let motif_counts = vec![1, 1, 1]; + assert_eq!(get_dist(&template, query, &motif_counts), 1); + } + + #[test] + fn test_get_dist_1_del() { + let template = setup_motifs(); + let query = "CG"; + let motif_counts = vec![1, 1, 1]; + assert_eq!(get_dist(&template, query, &motif_counts), 1); + } + + #[test] + fn test_get_dist_1_sub() { + let template = setup_motifs(); + let query = "ACC"; + let motif_counts = vec![1, 1, 1]; + assert_eq!(get_dist(&template, query, &motif_counts), 1); + } + + #[test] + fn test_get_dist_diff() { + let template = setup_motifs(); + let query = "TTT"; + let motif_counts = vec![1, 1, 1]; + assert_eq!(get_dist(&template, query, &motif_counts), 3); + } + + #[test] + fn test_get_dist_empty_template() { + let template: Vec = Vec::new(); + let query = "ACG"; + let motif_counts: Vec = Vec::new(); + assert_eq!( + get_dist(&template, query, &motif_counts), + query.len() as u32 + ); + } + + #[test] + fn test_get_dist_empty_query() { + let template = setup_motifs(); + let query = ""; + let motif_counts = vec![1, 2, 1]; + // Distance should be the sum of motif_counts since query is empty. + assert_eq!( + get_dist(&template, query, &motif_counts), + motif_counts.iter().sum::() as u32 + ); + } + + #[test] + fn test_get_dist_empty_both() { + let template: Vec = Vec::new(); + let query = ""; + let motif_counts: Vec = Vec::new(); + assert_eq!(get_dist(&template, query, &motif_counts), 0); + } + + #[test] + fn test_apply_multiplicity_counts_basic() { + let motifs = vec![ + Motif { + seq: "A".to_string(), + mult: Mult::Once, + }, + Motif { + seq: "B".to_string(), + mult: Mult::Many, + }, + ]; + let motif_counts = vec![1, 3]; + let expected = vec![1, 3]; + assert_eq!(apply_multiplicity_counts(&motifs, &motif_counts), expected); + } + + #[test] + fn test_apply_multiplicity_counts_empty_vectors() { + let motifs: Vec = Vec::new(); + let motif_counts: Vec = Vec::new(); + let expected: Vec = Vec::new(); + assert_eq!(apply_multiplicity_counts(&motifs, &motif_counts), expected); + } #[test] - fn test_imperfect_htt_repeat() { - let motifs = decode_regexp("(CAG)nCAACAG(CCG)n"); - // 111111111111111333333 - let query = "CAGCAGCAGCATCAGCCGCCG"; - let counts = vec![5, 0, 2]; - let expected = vec![ - Span { - motif_index: 0, - start: 0, - end: 9, + fn test_all_once() { + let motifs = vec![ + Motif { + seq: "A".to_string(), + mult: Mult::Once, }, - Span { - motif_index: 1, - start: 9, - end: 15, + Motif { + seq: "B".to_string(), + mult: Mult::Once, + }, + ]; + let motif_counts = vec![5, 10]; // These counts should be ignored. + let expected = vec![1, 1]; + assert_eq!(apply_multiplicity_counts(&motifs, &motif_counts), expected); + } + + #[test] + fn test_all_many() { + let motifs = vec![ + Motif { + seq: "A".to_string(), + mult: Mult::Many, }, - Span { - motif_index: 2, - start: 15, - end: 21, + Motif { + seq: "B".to_string(), + mult: Mult::Many, }, ]; - assert_eq!(refine_motif_counts(&motifs, &query, &counts), expected); + let motif_counts = vec![2, 4]; + let expected = vec![2, 4]; + assert_eq!(apply_multiplicity_counts(&motifs, &motif_counts), expected); + } - // 111111111111111333333 - //let query = "CACCAGCAGCATCAGCGGCCG"; - //let counts = vec![1, 0, 5]; - //assert_eq!(match_with_align(&motifs, &query, &counts), vec![3, 2, 2]); + #[test] + fn test_mixed_multiplicities() { + let motifs = vec![ + Motif { + seq: "A".to_string(), + mult: Mult::Once, + }, + Motif { + seq: "B".to_string(), + mult: Mult::Many, + }, + Motif { + seq: "C".to_string(), + mult: Mult::Once, + }, + Motif { + seq: "D".to_string(), + mult: Mult::Many, + }, + ]; + let motif_counts = vec![5, 10, 15, 20]; + let expected = vec![1, 10, 1, 20]; + assert_eq!(apply_multiplicity_counts(&motifs, &motif_counts), expected); } } - -*/ diff --git a/trgt/src/locate/locate.rs b/trgt/src/locate/locate.rs index e9907ee..c0353e2 100644 --- a/trgt/src/locate/locate.rs +++ b/trgt/src/locate/locate.rs @@ -3,25 +3,6 @@ use bio::alignment::{pairwise::*, AlignmentOperation}; use itertools::Itertools; use std::str; -/* -#[derive(Debug)] -pub struct SearchParams { - pub search_flank_len: usize, - pub output_flank_len: usize, - pub kmer_len: usize, - pub step_len: usize, - pub max_delta: i32, - pub min_kmer_count: usize, -} -*/ - -pub struct Locator { - //params: SearchParams, - lf: String, - rf: String, - flank_len: usize, -} - #[derive(Debug, Clone, Copy)] pub struct TrgtScoring { pub match_scr: i32, @@ -34,102 +15,91 @@ pub struct TrgtScoring { type Span = (usize, usize); -impl Locator { - pub fn new(lf: &str, rf: &str, flank_len: usize) -> Locator { - assert!(lf.len() == rf.len()); - Locator { - lf: lf.to_string(), - rf: rf.to_string(), - flank_len, - } - } - - fn find_spans( - aligner: &mut banded::Aligner, - piece: &str, - seqs: &[&str], - params: &Params, - ) -> Vec> - where - F: Fn(u8, u8) -> i32, - { - seqs.iter() - .map(|s| { - if let Some(start) = s.find(piece) { - Some((start, start + piece.len())) +fn find_spans( + aligner: &mut banded::Aligner, + piece: &str, + seqs: &[&str], + params: &Params, +) -> Vec> +where + F: Fn(u8, u8) -> i32, +{ + seqs.iter() + .map(|s| { + if let Some(start) = s.find(piece) { + Some((start, start + piece.len())) + } else { + let align = aligner.semiglobal(piece.as_bytes(), s.as_bytes()); + let flank_aln_len = align + .operations + .iter() + .filter(|x| **x == AlignmentOperation::Match) + .count(); + if flank_aln_len as f32 + >= (params.search_flank_len as f32) * params.min_flank_id_frac + { + Some((align.ystart, align.yend)) } else { - let align = aligner.semiglobal(piece.as_bytes(), s.as_bytes()); - let flank_aln_len = align - .operations - .iter() - .filter(|x| **x == AlignmentOperation::Match) - .count(); - if flank_aln_len as f32 - >= (params.search_flank_len as f32) * params.min_flank_id_frac - { - Some((align.ystart, align.yend)) - } else { - None - } + None } - }) - .collect() - } + } + }) + .collect() +} - pub fn locate(&mut self, reads: &[HiFiRead], params: &Params) -> Vec> { - let lf_piece = &self.lf[self.lf.len() - self.flank_len..]; - let rf_piece = &self.rf[..self.flank_len]; +pub fn find_tr_spans(lf: &str, rf: &str, reads: &[HiFiRead], params: &Params) -> Vec> { + let lf_piece = &lf[lf.len() - params.search_flank_len..]; + let rf_piece = &rf[..params.search_flank_len]; - let scoring = Scoring { - match_fn: |a: u8, b: u8| { - if a == b { - params.aln_scoring.match_scr + let scoring = Scoring { + match_fn: |a: u8, b: u8| { + if a == b { + params.aln_scoring.match_scr + } else { + -params.aln_scoring.mism_scr + } + }, + match_scores: Some((params.aln_scoring.match_scr, -params.aln_scoring.mism_scr)), + gap_open: -params.aln_scoring.gapo_scr, + gap_extend: -params.aln_scoring.gape_scr, + xclip_prefix: MIN_SCORE, + xclip_suffix: MIN_SCORE, + yclip_prefix: 0, + yclip_suffix: 0, + }; + + let mut aligner = banded::Aligner::with_capacity_and_scoring( + params.search_flank_len + 10, // global length + 20000, // local length: maximum HiFi read length + scoring, + params.aln_scoring.kmer_len, + params.aln_scoring.bandwidth, + ); + + let seqs = reads + .iter() + .map(|r| std::str::from_utf8(&r.bases).unwrap()) + .collect_vec(); + + let lf_spans = find_spans(&mut aligner, lf_piece, &seqs, params); + let rf_spans = find_spans(&mut aligner, rf_piece, &seqs, params); + + lf_spans + .iter() + .zip(rf_spans.iter()) + .map(|(lf_span, rf_span)| match (lf_span, rf_span) { + (None, None) => None, // No left or right span + (Some(_lf), None) => None, // Left flanking + (None, Some(_rf)) => None, // Right flanking + (Some(lf), Some(rf)) => { + if lf.1 <= rf.0 { + Some((lf.1, rf.0)) } else { - -params.aln_scoring.mism_scr - } - }, - match_scores: Some((params.aln_scoring.match_scr, -params.aln_scoring.mism_scr)), - gap_open: -params.aln_scoring.gapo_scr, - gap_extend: -params.aln_scoring.gape_scr, - xclip_prefix: MIN_SCORE, - xclip_suffix: MIN_SCORE, - yclip_prefix: 0, - yclip_suffix: 0, - }; - - let mut aligner = banded::Aligner::with_capacity_and_scoring( - self.flank_len + 10, // global length - 20000, // local length: maximum HiFi read length - scoring, - params.aln_scoring.kmer_len, - params.aln_scoring.bandwidth, - ); - - let seqs = reads - .iter() - .map(|r| std::str::from_utf8(&r.bases).unwrap()) - .collect_vec(); - - let lf_spans = Self::find_spans(&mut aligner, lf_piece, &seqs, params); - let rf_spans = Self::find_spans(&mut aligner, rf_piece, &seqs, params); - - lf_spans - .iter() - .zip(rf_spans.iter()) - .map(|(lf_span, rf_span)| match (lf_span, rf_span) { - (None, None) => None, // No left or right span - (Some(_lf), None) => None, // Left flanking - (None, Some(_rf)) => None, // Right flanking - (Some(lf), Some(rf)) => { - if lf.1 <= rf.0 { - Some((lf.1, rf.0)) - } else { - None // Discordant flanks - } + None // Discordant flanks } - }) - .collect() - } + } + }) + .collect() } /* diff --git a/trgt/src/locate/mod.rs b/trgt/src/locate/mod.rs index e78ecef..d0cc5a6 100644 --- a/trgt/src/locate/mod.rs +++ b/trgt/src/locate/mod.rs @@ -1,5 +1,5 @@ mod locate; -pub use locate::Locator; +pub use locate::find_tr_spans; pub use locate::TrgtScoring; mod consensus; diff --git a/trgt/src/locus.rs b/trgt/src/locus.rs index b16bb8e..99ab38c 100644 --- a/trgt/src/locus.rs +++ b/trgt/src/locus.rs @@ -1,83 +1,11 @@ +use crate::faidx; use crate::genotype::Ploidy; -use crate::utils::{self, GenomicRegion}; -use rust_htslib::faidx; -use std::collections::{HashMap, HashSet}; -use std::fs; +use crate::karyotype::Karyotype; +use crate::utils::GenomicRegion; +use std::collections::HashMap; use std::io::{BufRead, BufReader, Read as ioRead}; use std::str::FromStr; -#[derive(Debug, PartialEq, Clone)] -pub struct Karyotype { - ploidy: PloidyInfo, -} - -#[derive(Debug, PartialEq, Clone)] -enum PloidyInfo { - PresetXX, - PresetXY, - Custom(HashMap), -} - -impl Karyotype { - pub fn new(encoding: &str) -> Result { - match encoding { - "XX" => Ok(Self { - ploidy: PloidyInfo::PresetXX, - }), - "XY" => Ok(Self { - ploidy: PloidyInfo::PresetXY, - }), - _ => Self::from_file(encoding), - } - } - - fn from_file(path: &str) -> Result { - let contents = fs::read_to_string(path).map_err(|e| format!("File {}: {}", path, e))?; - - let ploidies = contents - .lines() - .map(|line| { - let mut parts = line.split_whitespace(); - let chrom = parts.next().ok_or("Missing chromosome".to_string())?; - let ploidy_str = parts.next().ok_or("Missing ploidy".to_string())?; - let ploidy = ploidy_str - .parse() - .map_err(|e: String| format!("Invalid ploidy: {}", e))?; - Ok((chrom.to_string(), ploidy)) - }) - .collect::, String>>()?; - - Ok(Self { - ploidy: PloidyInfo::Custom(ploidies), - }) - } - - pub fn get_ploidy(&self, chrom: &str) -> Result { - let is_on_chrx = chrom == "X" || chrom == "chrX"; - let is_on_chry = chrom == "Y" || chrom == "chrY"; - match &self.ploidy { - PloidyInfo::PresetXX => { - if is_on_chry { - Ok(Ploidy::Zero) - } else { - Ok(Ploidy::Two) - } - } - PloidyInfo::PresetXY => { - if is_on_chrx || is_on_chry { - Ok(Ploidy::One) - } else { - Ok(Ploidy::Two) - } - } - PloidyInfo::Custom(ploidies) => ploidies - .get(chrom) - .copied() - .ok_or_else(|| format!("Ploidy was not specified for chromosome: {}", chrom)), - } - } -} - #[derive(Debug, Clone, Copy)] pub enum Genotyper { Size, @@ -101,7 +29,7 @@ pub struct Locus { pub left_flank: String, pub tr: String, pub right_flank: String, - pub region: utils::GenomicRegion, + pub region: GenomicRegion, pub motifs: Vec, pub struc: String, pub ploidy: Ploidy, @@ -111,23 +39,34 @@ pub struct Locus { impl Locus { pub fn new( genome_reader: &faidx::Reader, - chrom_lookup: &HashSet, + chrom_lookup: &HashMap, line: &str, flank_len: usize, karyotype: &Karyotype, genotyper: Genotyper, ) -> Result { + const EXPECTED_FIELD_COUNT: usize = 4; let split_line: Vec<&str> = line.split_whitespace().collect(); - if split_line.len() != 4 { - return Err(format!("Expected 4 fields, found {}", split_line.len())); + if split_line.len() != EXPECTED_FIELD_COUNT { + return Err(format!( + "Expected {} fields in the format 'chrom start end info', found {}: {}", + EXPECTED_FIELD_COUNT, + split_line.len(), + line + )); } - let (chrom, start, end) = (split_line[0], split_line[1], split_line[2]); + let (chrom, start, end, info_fields) = match &split_line[..] { + [chrom, start, end, info_fields] => (*chrom, *start, *end, *info_fields), + _ => unreachable!(), + }; + let region = GenomicRegion::new(&format!("{}:{}-{}", chrom, start, end))?; + check_region_bounds(®ion, flank_len, chrom_lookup)?; + let ploidy = karyotype.get_ploidy(chrom)?; - let info_fields = split_line[3]; let fields = decode_fields(info_fields)?; let get_field = |key: &str| { @@ -144,8 +83,7 @@ impl Locus { .collect(); let struc = get_field("STRUC")?; - let (left_flank, tr, right_flank) = - get_tr_and_flanks(genome_reader, chrom_lookup, ®ion, flank_len)?; + let (left_flank, tr, right_flank) = get_tr_and_flanks(genome_reader, ®ion, flank_len)?; Ok(Locus { id, @@ -161,16 +99,14 @@ impl Locus { } } -pub fn get_loci<'a>( +pub fn get_loci( catalog_reader: BufReader>, - genome_reader: &'a faidx::Reader, + genome_reader: &faidx::Reader, + karyotype: Karyotype, flank_len: usize, - karyotype: &'a Karyotype, genotyper: Genotyper, -) -> impl Iterator> + 'a { - let chrom_lookup: HashSet = (0..genome_reader.n_seqs()) - .filter_map(|i| genome_reader.seq_name(i as i32).ok()) - .collect(); +) -> impl Iterator> + '_ { + let chrom_lookup = genome_reader.create_chrom_lookup().unwrap(); catalog_reader .lines() @@ -182,7 +118,7 @@ pub fn get_loci<'a>( &chrom_lookup, &line, flank_len, - karyotype, + &karyotype, genotyper, ) { Ok(locus) => Some(Ok(locus)), @@ -195,71 +131,83 @@ pub fn get_loci<'a>( fn get_tr_and_flanks( genome: &faidx::Reader, - chrom_lookup: &HashSet, - region: &utils::GenomicRegion, + region: &GenomicRegion, flank_len: usize, ) -> Result<(String, String, String), String> { - let (lf_start, lf_end) = (region.start as usize - flank_len, region.start as usize); - let (rf_start, rf_end) = (region.end as usize, region.end as usize + flank_len); + let fetch_flank = |start: usize, end: usize| { + genome + .fetch_seq_string(®ion.contig, start, end) + .map_err(|e| { + format!( + "Error fetching sequence for region {}:{}-{}: {}", + ®ion.contig, start, end, e + ) + }) + .map(|seq| seq.to_uppercase()) + }; - // TODO: This is necessary because faidx is unsafe and segfaults when - // the region is invalid, so we need to fail gracefully in the event of - // a bad input. Should be removed if rust_htslib addresses this. - if !chrom_lookup.contains(®ion.contig) { - return Err(format!( - "FASTA reference does not contain chromosome '{}' in BED file", - region.contig - )); - } + let left_flank = fetch_flank(region.start as usize - flank_len, region.start as usize - 1)?; + let tr = fetch_flank(region.start as usize, region.end as usize - 1)?; + let right_flank = fetch_flank(region.end as usize, region.end as usize + flank_len - 1)?; - let left_flank = genome - .fetch_seq_string(®ion.contig, lf_start, lf_end - 1) - .map_err(|e| e.to_string())?; - let tr = genome - .fetch_seq_string( - ®ion.contig, - region.start as usize, - region.end as usize - 1, - ) - .map_err(|e| e.to_string())?; - let right_flank = genome - .fetch_seq_string(®ion.contig, rf_start, rf_end - 1) - .map_err(|e| e.to_string())?; - Ok(( - left_flank.to_uppercase(), - tr.to_uppercase(), - right_flank.to_uppercase(), - )) + Ok((left_flank, tr, right_flank)) } fn decode_fields(info_fields: &str) -> Result, String> { let mut fields = HashMap::new(); for field_encoding in info_fields.split(';') { - let (name, value) = decode_info_field(field_encoding)?; + let (name, value) = decode_info_field(field_encoding).map_err(|e| e.to_string())?; if fields.insert(name, value.to_string()).is_some() { - return Err(format!("Duplicate field: {}", name)); + return Err(format!("Duplicate field name: '{}'", name)); } } Ok(fields) } fn decode_info_field(encoding: &str) -> Result<(&str, &str), String> { - if encoding.is_empty() { - return Err("Field is empty".to_string()); + let error_message = || format!("Field must be in 'name=value' format: '{}'", encoding); + let parts: Vec<&str> = encoding.splitn(2, '=').collect(); + if parts.len() != 2 || parts[0].is_empty() || parts[1].is_empty() { + Err(error_message()) + } else { + Ok((parts[0], parts[1])) } +} - let mut name_and_value = encoding.splitn(2, '='); - let error_message = || format!("Invalid field entry: {}", encoding); +fn check_region_bounds( + region: &GenomicRegion, + flank_len: usize, + chrom_lookup: &HashMap, +) -> Result<(), String> { + let chrom_length = *chrom_lookup.get(®ion.contig).ok_or_else(|| { + format!( + "FASTA reference does not contain chromosome '{}' in BED file", + ®ion.contig + ) + })?; + + let flank_len_u32 = flank_len as u32; + + if region.start < flank_len_u32 + 1 { + return Err(format!( + "Region start '{}' with flank length '{}' underflows for chromosome '{}'.", + region.start, flank_len, ®ion.contig + )); + } - let name = name_and_value - .next() - .filter(|s| !s.is_empty()) - .ok_or_else(error_message)?; + let adjusted_end = region.end.checked_add(flank_len_u32).ok_or_else(|| { + format!( + "Region end '{}' with flank length '{}' overflows for chromosome '{}'.", + region.end, flank_len, ®ion.contig + ) + })?; - let value = name_and_value - .next() - .filter(|s| !s.is_empty()) - .ok_or_else(error_message)?; + if adjusted_end > chrom_length { + return Err(format!( + "Region end '{}' with flank length '{}' exceeds chromosome '{}' bounds (0..{}).", + adjusted_end, flank_len, ®ion.contig, chrom_length + )); + } - Ok((name, value)) + Ok(()) } diff --git a/trgt/src/main.rs b/trgt/src/main.rs index fe19695..32d6d8c 100644 --- a/trgt/src/main.rs +++ b/trgt/src/main.rs @@ -16,14 +16,13 @@ //! --output-prefix sample //! ``` -use crate::read_output::BamWriter; -use crate::vcf::VcfWriter; use cli::{get_cli_params, handle_error_and_exit}; use flate2::read::GzDecoder; -use locus::Karyotype; +use karyotype::Karyotype; +use rust_htslib::bam; use rust_htslib::bam::Read; -use rust_htslib::{bam, faidx}; use std::cell::RefCell; +use std::collections::HashSet; use std::fs::File; use std::io::{BufReader, Read as ioRead}; use std::path::{Path, PathBuf}; @@ -32,18 +31,19 @@ use std::sync::Arc; use std::{thread, time}; use threadpool::ThreadPool; use workflows::analyze_tr; +use writers::{BamWriter, VcfWriter}; mod cli; mod cluster; +mod faidx; mod genotype; +mod karyotype; mod label; mod locate; mod locus; -mod read_output; mod reads; -mod snp; mod utils; -mod vcf; mod workflows; +mod writers; pub type Result = std::result::Result; @@ -58,21 +58,49 @@ thread_local! { } pub fn get_bam_header(bam_path: &PathBuf) -> Result { - let bam = match bam::IndexedReader::from_path(bam_path) { - Ok(reader) => reader, - Err(e) => return Err(format!("Failed to create bam reader: {}", e)), - }; - let bam_header = bam::Header::from_template(bam.header()); - Ok(bam_header) + let bam = bam::IndexedReader::from_path(bam_path) + .map_err(|e| format!("Failed to create bam reader: {}", e))?; + Ok(bam::Header::from_template(bam.header())) +} + +fn is_bam_mapped(bam_header: &bam::Header) -> bool { + // input is already sorted because it fails an index. + // If it is mapped, the index needs the SQ tags to fetch data. + for line in String::from_utf8(bam_header.to_bytes()).unwrap().lines() { + if line.starts_with("@SQ") { + return true; + } + } + false } -fn get_sample_name(reads_path: &Path) -> String { - reads_path +fn get_sample_name(reads_path: &PathBuf) -> Result { + let bam_header = get_bam_header(reads_path)?; + + let header_hashmap = bam_header.to_hashmap(); + let mut sample_names = HashSet::new(); + + if let Some(rg_fields) = header_hashmap.get("RG") { + for rg_field in rg_fields { + if let Some(sample_name) = rg_field.get("SM") { + sample_names.insert(sample_name.to_owned()); + } + } + } + + match sample_names.len() { + 1 => return Ok(sample_names.into_iter().next().unwrap()), + 0 => log::warn!("No sample names found"), + _ => log::warn!("Multiple sample names found"), + }; + + let sample = reads_path .file_stem() - .unwrap() - .to_str() - .unwrap() - .to_string() + .and_then(|stem| stem.to_str()) + .ok_or("Invalid reads file name")? + .to_string(); + + Ok(sample) } fn create_writer(output_prefix: &str, output_suffix: &str, f: F) -> Result @@ -80,56 +108,48 @@ where F: FnOnce(&str) -> Result, { let output_path = format!("{}.{}", output_prefix, output_suffix); - f(&output_path).map_err(|e| { - eprintln!("Error creating writer: {}", e); - e - }) + f(&output_path) } fn open_catalog_reader(path: &PathBuf) -> Result>> { - fn get_format(path: &Path) -> Option<&'static str> { + fn is_gzipped(path: &Path) -> bool { let path_str = path.to_string_lossy(); - let formats = ["bed", "bed.gz", "bed.gzip"]; - formats - .iter() - .find(|&&format| path_str.ends_with(format)) - .copied() + let formats = [".gz", ".gzip", ".GZ", ".GZIP"]; + formats.iter().any(|format| path_str.ends_with(*format)) } let file = File::open(path).map_err(|e| e.to_string())?; - match get_format(path) { - Some("bed.gz") | Some("bed.gzip") => { - let gz_decoder = GzDecoder::new(file); - if gz_decoder.header().is_some() { - Ok(BufReader::new(Box::new(gz_decoder))) - } else { - Err(format!("Invalid gzip header: {}", path.to_string_lossy())) - } + if is_gzipped(path) { + let gz_decoder = GzDecoder::new(file); + if gz_decoder.header().is_some() { + Ok(BufReader::new(Box::new(gz_decoder))) + } else { + Err(format!("Invalid gzip header: {}", path.to_string_lossy())) } - Some("bed") => Ok(BufReader::new(Box::new(file))), - _ => Err(format!( - "Unknown bed format: {}. Supported formats are: .bed or .bed.gz(ip)", - path.to_string_lossy() - )), + } else { + Ok(BufReader::new(Box::new(file))) } } -fn open_genome_reader(path: &PathBuf) -> Result { - let reader = faidx::Reader::from_path(path).map_err(|e| e.to_string())?; - Ok(reader) +fn open_genome_reader(path: &Path) -> Result { + let extension = path.extension().unwrap().to_str().unwrap(); + let fai_path = path.with_extension(extension.to_owned() + ".fai"); + if !fai_path.exists() { + return Err(format!( + "Reference index file not found: {}. Create it using 'samtools faidx {}'", + fai_path.display(), + path.display() + )); + } + faidx::Reader::from_path(path).map_err(|e| e.to_string()) } -fn is_bam_mapped(bam_header: &bam::Header) -> bool { - // input is already sorted because it fails an index. - // If it is mapped, the index needs the SQ tags to fetch data. - for line in String::from_utf8(bam_header.to_bytes()).unwrap().lines() { - if line.starts_with("@SQ") { - return true; - } +fn main() { + if let Err(e) = run_trgt() { + handle_error_and_exit(e); } - false } -fn main() -> Result<()> { +fn run_trgt() -> Result<()> { let params = get_cli_params(); log::info!( @@ -139,26 +159,23 @@ fn main() -> Result<()> { ); let start_timer = time::Instant::now(); - let karyotype = - Karyotype::new(¶ms.karyotype).unwrap_or_else(|err| handle_error_and_exit(err)); + let karyotype = Karyotype::new(¶ms.karyotype)?; - let search_flank_len = params.flank_len; - let output_flank_len = std::cmp::min(search_flank_len, 50); - let sample_name = get_sample_name(¶ms.reads_path); + let sample_name = params + .sample_name + .unwrap_or(get_sample_name(¶ms.reads_path)?); - let catalog_reader = - open_catalog_reader(¶ms.repeats_path).unwrap_or_else(|err| handle_error_and_exit(err)); + let catalog_reader = open_catalog_reader(¶ms.repeats_path)?; let genome_reader = open_genome_reader(¶ms.genome_path)?; let all_loci = locus::get_loci( catalog_reader, &genome_reader, - search_flank_len, - &karyotype, + karyotype, + params.flank_len, params.genotyper, ) - .collect::>>() - .unwrap_or_else(|err| handle_error_and_exit(err)); + .collect::>>()?; let bam_header = get_bam_header(¶ms.reads_path)?; if !is_bam_mapped(&bam_header) { @@ -168,8 +185,10 @@ fn main() -> Result<()> { let mut vcf_writer = create_writer(¶ms.output_prefix, "vcf.gz", |path| { VcfWriter::new(path, &sample_name, &bam_header) })?; + + let output_flank_len = std::cmp::min(params.flank_len, 50); let mut bam_writer = create_writer(¶ms.output_prefix, "spanning.bam", |path| { - BamWriter::new(path, bam_header) + BamWriter::new(path, bam_header, output_flank_len) })?; log::info!("Starting job pool with {} threads...", params.num_threads); @@ -179,13 +198,13 @@ fn main() -> Result<()> { let writer_thread = thread::spawn(move || { for (locus, results) in &receiver { vcf_writer.write(&locus, &results); - bam_writer.write(&locus, output_flank_len, &results); + bam_writer.write(&locus, &results); } }); let reads_path = Arc::new(params.reads_path.clone()); let workflow_params = Arc::new(workflows::Params { - search_flank_len, + search_flank_len: params.flank_len, min_read_qual: params.min_hifi_read_qual, max_depth: params.max_depth, aln_scoring: params.aln_scoring, @@ -211,7 +230,7 @@ fn main() -> Result<()> { sender.send((locus, results)).unwrap(); } Err(err) => { - eprintln!("Error occurred while analyzing: {}", err); + log::error!("Error occurred while analyzing: {}", err); } } }); diff --git a/trgt/src/reads/clip_bases.rs b/trgt/src/reads/clip_bases.rs index dec262b..4f21d28 100644 --- a/trgt/src/reads/clip_bases.rs +++ b/trgt/src/reads/clip_bases.rs @@ -41,14 +41,12 @@ pub fn clip_bases(read: &HiFiRead, left_len: usize, right_len: usize) -> Option< }; Some(HiFiRead { - id: read.id.clone(), bases: clipped_bases, meth: clipped_meth, - read_qual: read.read_qual, - mismatch_offsets: read.mismatch_offsets.clone(), - start_offset: read.start_offset, - end_offset: read.end_offset, cigar: clipped_cigar, + id: read.id.clone(), + mismatch_offsets: read.mismatch_offsets.clone(), + ..*read }) } @@ -131,6 +129,8 @@ mod tests { start_offset: 0, end_offset: 0, cigar: Some(cigar), + hp_tag: None, + mapq: 60, } } diff --git a/trgt/src/reads/clip_region.rs b/trgt/src/reads/clip_region.rs index 7ef0f44..779070a 100644 --- a/trgt/src/reads/clip_region.rs +++ b/trgt/src/reads/clip_region.rs @@ -44,15 +44,12 @@ pub fn clip_to_region(read: HiFiRead, region: (i64, i64)) -> Option { ref_pos: clipped_ref_start, ops: clipped_cigar, }; + Some(HiFiRead { - id: read.id, bases: clipped_bases, meth: clipped_meth, - read_qual: read.read_qual, - mismatch_offsets: read.mismatch_offsets, - start_offset: read.start_offset, - end_offset: read.end_offset, cigar: Some(cigar), + ..read }) } @@ -132,16 +129,19 @@ fn clip_cigar(cigar: &Cigar, region: (i64, i64)) -> Option<(i64, i64, Vec CigarOp::Match(ref_inside_len as u32), - CigarOp::RefSkip(_) => CigarOp::RefSkip(ref_inside_len as u32), - CigarOp::Del(_) => CigarOp::Del(ref_inside_len as u32), - CigarOp::Equal(_) => CigarOp::Equal(ref_inside_len as u32), - CigarOp::Diff(_) => CigarOp::Diff(ref_inside_len as u32), - op => panic!("Unexpected operation {:?}", op), - }); + if let Some(op) = current_op { + if ref_pos < region_end { + let ref_inside_len = (region_end - ref_pos) as u32; + let clipped_op = match op { + CigarOp::Match(_) => CigarOp::Match(ref_inside_len), + CigarOp::RefSkip(_) => CigarOp::RefSkip(ref_inside_len), + CigarOp::Del(_) => CigarOp::Del(ref_inside_len), + CigarOp::Equal(_) => CigarOp::Equal(ref_inside_len), + CigarOp::Diff(_) => CigarOp::Diff(ref_inside_len), + _ => panic!("Unexpected operation {:?}", op), + }; + clipped_ops.push(clipped_op); + } } Some((clipped_ref_start, clipped_query_start, clipped_ops)) @@ -184,6 +184,8 @@ mod tests { start_offset: 0, end_offset: 0, cigar: Some(cigar), + hp_tag: None, + mapq: 60, } } diff --git a/trgt/src/reads/mod.rs b/trgt/src/reads/mod.rs index 46670bb..a2754fa 100644 --- a/trgt/src/reads/mod.rs +++ b/trgt/src/reads/mod.rs @@ -3,6 +3,7 @@ mod clip_bases; mod clip_region; mod meth; mod read; +mod snp; pub use clip_bases::clip_bases; pub use clip_region::clip_to_region; diff --git a/trgt/src/reads/read.rs b/trgt/src/reads/read.rs index e08b3c3..3253942 100644 --- a/trgt/src/reads/read.rs +++ b/trgt/src/reads/read.rs @@ -1,11 +1,9 @@ -use super::cigar::Cigar; +use super::{cigar::Cigar, meth, snp::extract_snps_offset}; use crate::utils::GenomicRegion; use itertools::Itertools; use rust_htslib::bam::{self, ext::BamRecordExtensions, record::Aux}; use std::str; -use super::meth; - #[derive(Debug)] pub struct MethInfo { pub poses: Vec, @@ -22,6 +20,8 @@ pub struct HiFiRead { pub start_offset: i32, pub end_offset: i32, pub cigar: Option, + pub hp_tag: Option, + pub mapq: u8, } impl std::fmt::Debug for HiFiRead { @@ -45,48 +45,47 @@ impl HiFiRead { let id = str::from_utf8(rec.qname()).unwrap().to_string(); let bases = rec.seq().as_bytes(); - let mm_tag = get_mm_tag(&rec); - let ml_tag = get_ml_tag(&rec); - - let meth = match mm_tag { - Some(mm_tag) => parse_meth_tags(mm_tag, ml_tag.unwrap()), - None => None, - }; - - let meth = match meth { - Some(tags) => { - if rec.is_reverse() { - meth::decode_on_minus(&bases, &tags) - } else { - meth::decode_on_plus(&bases, &tags) - } - } - None => None, - }; - + let meth = get_mm_tag(&rec).and_then(|mm_tag| { + get_ml_tag(&rec) + .and_then(|ml_tag| parse_meth_tags(mm_tag, ml_tag)) + .and_then(|tags| { + if rec.is_reverse() { + meth::decode_on_minus(&bases, &tags) + } else { + meth::decode_on_plus(&bases, &tags) + } + }) + }); + + let mapq = rec.mapq(); + let hp_tag = get_hp_tag(&rec); let read_qual = get_rq_tag(&rec); - let cigar = if rec.is_unmapped() { - None + let cigar = if !rec.is_unmapped() { + Some(Cigar { + ref_pos: rec.reference_start(), + ops: rec.cigar().take().to_vec(), + }) } else { - let ref_pos = rec.reference_start(); - let ops = rec.cigar().take().to_vec(); - let cigar = Cigar { ref_pos, ops }; - Some(cigar) + None }; let start_offset = (rec.reference_start() - region.start as i64) as i32; let end_offset = (rec.reference_end() - region.end as i64) as i32; + let mismatch_offsets = cigar.as_ref().map(|c| extract_snps_offset(c, region)); + HiFiRead { id, bases, meth, read_qual, - mismatch_offsets: None, + mismatch_offsets, start_offset, end_offset, cigar, + hp_tag, + mapq, } } } @@ -125,35 +124,23 @@ fn parse_meth_tags(mm_tag: Aux, ml_tag: Aux) -> Option { } fn get_mm_tag(rec: &bam::Record) -> Option { - if let Ok(value) = rec.aux(b"MM") { - Some(value) - } else if let Ok(value) = rec.aux(b"Mm") { - Some(value) - } else { - None - } + rec.aux(b"MM").or_else(|_| rec.aux(b"Mm")).ok() } fn get_ml_tag(rec: &bam::Record) -> Option { - if let Ok(value) = rec.aux(b"ML") { - Some(value) - } else if let Ok(value) = rec.aux(b"Ml") { - Some(value) - } else { - None - } + rec.aux(b"ML").or_else(|_| rec.aux(b"Ml")).ok() } fn get_rq_tag(rec: &bam::Record) -> Option { - let rq_tag = rec.aux(b"rq"); - if rq_tag.is_err() { - return None; + match rec.aux(b"rq") { + Ok(Aux::Float(value)) => Some(f64::from(value)), + _ => None, } +} - let rq_tag = rq_tag.unwrap(); - if let Aux::Float(value) = rq_tag { - return Some(value as f64); +fn get_hp_tag(rec: &bam::Record) -> Option { + match rec.aux(b"HP") { + Ok(Aux::U8(value)) => Some(u8::from(value)), + _ => None, } - - panic!("Unexpected rq tag format: {:?}", rq_tag); } diff --git a/trgt/src/snp.rs b/trgt/src/reads/snp.rs similarity index 99% rename from trgt/src/snp.rs rename to trgt/src/reads/snp.rs index 889ff82..925dcfb 100644 --- a/trgt/src/snp.rs +++ b/trgt/src/reads/snp.rs @@ -36,8 +36,7 @@ pub fn extract_snps(cigar: &Cigar, region: &GenomicRegion) -> Vec { mismatches } -// Collect all positions in a read that are mismatches relative to the starting point of the region -#[allow(dead_code)] +// Collect all positions in a read that are mismatches relative to the starting and ending point of the region pub fn extract_snps_offset(cigar: &Cigar, region: &GenomicRegion) -> Vec { let mut mismatches: Vec = Vec::new(); let mut start_ref = cigar.ref_pos as u32; diff --git a/trgt/src/workflows/tr.rs b/trgt/src/workflows/tr.rs index 8b928fc..ff71005 100644 --- a/trgt/src/workflows/tr.rs +++ b/trgt/src/workflows/tr.rs @@ -1,10 +1,9 @@ use crate::cluster; use crate::genotype::{self, flank_genotype, Gt}; use crate::label::label_alleles; -use crate::locate::{Locator, TrgtScoring}; +use crate::locate::{find_tr_spans, TrgtScoring}; use crate::locus::{Genotyper, Locus}; use crate::reads::{clip_to_region, HiFiRead}; -use crate::snp; use crate::workflows::{Allele, Genotype, LocusResult}; use itertools::Itertools; use rust_htslib::bam; @@ -28,16 +27,14 @@ pub fn analyze( if locus.ploidy == genotype::Ploidy::Zero { return Ok(LocusResult::empty()); } - let mut reads = extract_reads( + let reads = extract_reads( locus, - params.search_flank_len as u32, bam, + params.search_flank_len as u32, params.min_read_qual, )?; log::debug!("{}: Collected {} reads", locus.id, reads.len()); - snp::analyze_snps(&mut reads, &locus.region); - let clip_radius = 500; let reads = clip_reads(locus, clip_radius, reads); log::debug!("{}: {} reads left after clipping", locus.id, reads.len()); @@ -108,12 +105,7 @@ fn get_spanning_reads( params: &Params, reads: Vec, ) -> (Vec, Vec<(usize, usize)>) { - let mut locator = Locator::new( - &locus.left_flank, - &locus.right_flank, - params.search_flank_len, - ); - let tr_spans = locator.locate(&reads, params); + let tr_spans = find_tr_spans(&locus.left_flank, &locus.right_flank, &reads, params); let reads_and_spans = reads .into_iter() @@ -269,8 +261,8 @@ fn assign_read(gt: &Gt, tr_len: usize) -> Assignment { fn extract_reads( locus: &Locus, - flank_len: u32, bam: &mut bam::IndexedReader, + flank_len: u32, min_read_qual: f64, ) -> Result> { let mut reads = Vec::new(); diff --git a/trgt/src/writers/mod.rs b/trgt/src/writers/mod.rs new file mode 100644 index 0000000..1255188 --- /dev/null +++ b/trgt/src/writers/mod.rs @@ -0,0 +1,5 @@ +mod write_bam; +mod write_vcf; + +pub use write_bam::BamWriter; +pub use write_vcf::VcfWriter; diff --git a/trgt/src/read_output.rs b/trgt/src/writers/write_bam.rs similarity index 79% rename from trgt/src/read_output.rs rename to trgt/src/writers/write_bam.rs index f398fc4..13f8709 100644 --- a/trgt/src/read_output.rs +++ b/trgt/src/writers/write_bam.rs @@ -9,6 +9,7 @@ use std::env; pub struct BamWriter { writer: bam::Writer, + output_flank_len: usize, } impl BamWriter { @@ -27,27 +28,34 @@ impl BamWriter { header } - pub fn new(output_bam_path: &str, bam_header: bam::Header) -> Result { + pub fn new( + output_bam_path: &str, + bam_header: bam::Header, + output_flank_len: usize, + ) -> Result { let bam_header = Self::update_header(bam_header); let writer = bam::Writer::from_path(output_bam_path, &bam_header, bam::Format::Bam) .map_err(|e| e.to_string())?; - Ok(BamWriter { writer }) + Ok(BamWriter { + writer, + output_flank_len, + }) } - pub fn write(&mut self, locus: &Locus, output_flank_len: usize, results: &LocusResult) { + pub fn write(&mut self, locus: &Locus, results: &LocusResult) { let num_reads = results.reads.len(); for index in 0..num_reads { let read = &results.reads[index]; let classification = results.classification[index]; let span = &results.tr_spans[index]; - if span.0 < output_flank_len || read.bases.len() < span.1 + output_flank_len { + if span.0 < self.output_flank_len || read.bases.len() < span.1 + self.output_flank_len { log::error!("Read {} has unexpectedly short flanks", read.id); continue; } - let left_clip_len = span.0 - output_flank_len; - let right_clip_len = read.bases.len() - span.1 - output_flank_len; + let left_clip_len = span.0 - self.output_flank_len; + let right_clip_len = read.bases.len() - span.1 - self.output_flank_len; let clipped_read = clip_bases(read, left_clip_len, right_clip_len); if clipped_read.is_none() { log::error!("Read {} has unexpectedly short flanks", read.id); @@ -71,7 +79,7 @@ impl BamWriter { &read.bases, quals.as_bytes(), ); - rec.set_mapq(60); + rec.set_mapq(read.mapq); } else { rec.set(read.id.as_bytes(), None, &read.bases, quals.as_bytes()); rec.set_pos(locus.region.start as i64); @@ -91,11 +99,16 @@ impl BamWriter { rec.push_aux(b"MO", mm_tag).unwrap(); } + if let Some(hp) = read.hp_tag { + let hp_tag = Aux::U8(hp); + rec.push_aux(b"HP", hp_tag).unwrap(); + } + rec.push_aux(b"SO", Aux::I32(read.start_offset)).unwrap(); rec.push_aux(b"EO", Aux::I32(read.end_offset)).unwrap(); rec.push_aux(b"AL", Aux::I32(classification)).unwrap(); - let dat: &Vec = &vec![output_flank_len as u32, output_flank_len as u32]; + let dat: &Vec = &vec![self.output_flank_len as u32, self.output_flank_len as u32]; let fl_tag: AuxArray = dat.into(); rec.push_aux(b"FL", Aux::ArrayU32(fl_tag)).unwrap(); diff --git a/trgt/src/vcf.rs b/trgt/src/writers/write_vcf.rs similarity index 96% rename from trgt/src/vcf.rs rename to trgt/src/writers/write_vcf.rs index 890defa..078a521 100644 --- a/trgt/src/vcf.rs +++ b/trgt/src/writers/write_vcf.rs @@ -1,7 +1,6 @@ use crate::locus::Locus; use crate::workflows::{Genotype, LocusResult}; use itertools::Itertools; -use log::error; use rust_htslib::bam::{self}; use rust_htslib::bcf::record::GenotypeAllele; use rust_htslib::bcf::{self, Format, Record}; @@ -60,13 +59,9 @@ impl VcfWriter { vcf_header.push_sample(sample_name.as_bytes()); - let writer = match bcf::Writer::from_path(output_path, &vcf_header, false, Format::Vcf) { - Ok(file) => file, - Err(_) => { - error!("Invalid VCF output path: {}", &output_path); - std::process::exit(1); - } - }; + let writer = bcf::Writer::from_path(output_path, &vcf_header, false, Format::Vcf) + .map_err(|_| format!("Invalid VCF output path: {}", output_path))?; + Ok(VcfWriter { writer }) } diff --git a/trvz/Cargo.toml b/trvz/Cargo.toml index 1887548..b3ec49e 100644 --- a/trvz/Cargo.toml +++ b/trvz/Cargo.toml @@ -1,6 +1,6 @@ [package] name = "trvz" -version = "0.5.0" +version = "0.7.0" edition = "2021" build = "build.rs" diff --git a/trvz/src/cli.rs b/trvz/src/cli.rs index aff848c..dac6920 100644 --- a/trvz/src/cli.rs +++ b/trvz/src/cli.rs @@ -17,7 +17,7 @@ lazy_static! { #[derive(Parser)] #[command(name="trvz", author="Egor Dolzhenko \nGuilherme De Sena Brandine \nTom Mokveld ", - version=&**FULL_VERSION, + version=&**FULL_VERSION, about="Tandem Repeat Visualizer", long_about = None, after_help = format!("Copyright (C) 2004-{} Pacific Biosciences of California, Inc. diff --git a/trvz/src/hmm.rs b/trvz/src/hmm.rs index 35a7be7..5b47519 100644 --- a/trvz/src/hmm.rs +++ b/trvz/src/hmm.rs @@ -5,9 +5,17 @@ use std::collections::HashMap; // lp = log probability // ems = emissions +#[derive(Debug, Clone, PartialEq)] +pub struct Span { + pub motif_index: usize, + pub start: usize, + pub end: usize, +} + type MatF64 = Vec>; type MatInt = Vec>; +#[derive(Debug, PartialEq)] pub struct Hmm { num_states: usize, ems: MatF64, @@ -16,6 +24,7 @@ pub struct Hmm { pub motifs: Vec, } +#[derive(Debug, PartialEq)] pub struct HmmMotif { pub start_state: usize, pub end_state: usize, @@ -88,6 +97,7 @@ impl Hmm { } } + //TODO: Investigate "em_term != 0.0" if index == 0 && in_states.is_empty() && em_term.is_finite() { max_score = em_term; best_state = Some(&state); @@ -152,7 +162,7 @@ impl Hmm { self.traceback(&query, &states) } - pub fn label_motifs(&self, states: &Vec) -> Vec<(usize, usize, usize)> { + pub fn label_motifs(&self, states: &Vec) -> Vec { let state_to_motif: HashMap = self .motifs .iter() @@ -160,7 +170,7 @@ impl Hmm { .map(|(index, m)| (m.start_state, index)) .collect(); - let mut motif_spans: Vec<(usize, usize, usize)> = Vec::new(); + let mut motif_spans: Vec = Vec::new(); let mut state_index = 0; while state_index < states.len() { let state = states[state_index]; @@ -183,11 +193,15 @@ impl Hmm { let motif_start = if motif_spans.is_empty() { 0 } else { - motif_spans.last().unwrap().1 + motif_spans.last().unwrap().end }; let motif_end = motif_start + motif_span; - motif_spans.push((motif_start, motif_end, motif_index)); + motif_spans.push(Span { + motif_index, + start: motif_start, + end: motif_end, + }); } else { assert!(!self.emits_base(state)); state_index += 1; @@ -249,8 +263,7 @@ fn encode_base(base: u8) -> u8 { } } -/* -#[cfg(test)] +/*#[cfg(test)] mod tests { use super::*; use approx::assert_relative_eq; diff --git a/trvz/src/hmm_defs.rs b/trvz/src/hmm_defs.rs deleted file mode 100644 index dbeff87..0000000 --- a/trvz/src/hmm_defs.rs +++ /dev/null @@ -1,12 +0,0 @@ -use lazy_static::lazy_static; -use std::collections::HashMap; - -lazy_static! { - pub static ref HMM_DEFS: HashMap<&'static str, &'static str> = { - let mut hmm_defs = HashMap::new(); - hmm_defs.insert("", "[[1.0,0.0,0.0,0.0,0.0],[0.0,0.05,0.45,0.45,0.05],[0.0,0.7,0.1,0.1,0.1],[0.0,0.7,0.1,0.1,0.1],[0.0,0.1,0.1,0.1,0.7],[0.0,0.7,0.1,0.1,0.1],[0.0,0.1,0.1,0.1,0.7],[0.0,0.1,0.1,0.1,0.7],[0.0,0.45,0.45,0.05,0.05],[0.0,0.25,0.25,0.25,0.25],[0.0,0.25,0.25,0.25,0.25],[0.0,0.25,0.25,0.25,0.25],[0.0,0.25,0.25,0.25,0.25],[0.0,0.25,0.25,0.25,0.25],[0.0,0.25,0.25,0.25,0.25],[0.0,0.0,0.0,0.0,0.0],[0.0,0.0,0.0,0.0,0.0],[0.0,0.0,0.0,0.0,0.0],[0.0,0.0,0.0,0.0,0.0],[0.0,0.0,0.0,0.0,0.0],[0.0,0.0,0.0,0.0,0.0],[1.0,0.0,0.0,0.0,0.0]]|[[[],[]],[[0,1,8],[1.0,0.7,0.2]],[[1],[0.2]],[[2,9,15],[0.8,0.5,0.5]],[[3,10,16],[0.8,0.5,0.5]],[[4,11,17],[0.8,0.5,0.5]],[[5,12,18],[0.8,0.5,0.5]],[[6,13,19],[0.8,0.5,0.5]],[[7,8,14,20],[0.8,0.7,0.5,0.1]],[[1,2,9],[0.05,0.1,0.5]],[[3,10],[0.1,0.5]],[[4,11],[0.1,0.5]],[[5,12],[0.1,0.5]],[[6,13],[0.1,0.5]],[[7,14],[0.2,0.5]],[[1],[0.05]],[[2,15],[0.1,0.5]],[[3,16],[0.1,0.5]],[[4,17],[0.1,0.5]],[[5,18],[0.1,0.5]],[[6,19],[0.1,0.5]],[[8],[0.1]]]|(1,8,0)"); - hmm_defs.insert("", "[[1.0,0.0,0.0,0.0,0.0],[0.0,0.0,0.0,0.0,0.0],[0.0,0.0,0.0,0.0,0.0],[0.0,0.7,0.1,0.1,0.1],[0.0,0.7,0.1,0.1,0.1],[0.0,0.7,0.1,0.1,0.1],[0.0,0.7,0.1,0.1,0.1],[0.0,0.1,0.1,0.1,0.7],[0.0,0.25,0.25,0.25,0.25],[0.0,0.25,0.25,0.25,0.25],[0.0,0.25,0.25,0.25,0.25],[0.0,0.25,0.25,0.25,0.25],[0.0,0.25,0.25,0.25,0.25],[0.0,0.0,0.0,0.0,0.0],[0.0,0.0,0.0,0.0,0.0],[0.0,0.0,0.0,0.0,0.0],[0.0,0.0,0.0,0.0,0.0],[0.0,0.0,0.0,0.0,0.0],[0.0,0.0,0.0,0.0,0.0],[0.0,0.7,0.1,0.1,0.1],[0.0,0.7,0.1,0.1,0.1],[0.0,0.7,0.1,0.1,0.1],[0.0,0.1,0.1,0.1,0.7],[0.0,0.1,0.1,0.1,0.7],[0.0,0.25,0.25,0.25,0.25],[0.0,0.25,0.25,0.25,0.25],[0.0,0.25,0.25,0.25,0.25],[0.0,0.25,0.25,0.25,0.25],[0.0,0.25,0.25,0.25,0.25],[0.0,0.0,0.0,0.0,0.0],[0.0,0.0,0.0,0.0,0.0],[0.0,0.0,0.0,0.0,0.0],[0.0,0.0,0.0,0.0,0.0],[0.0,0.0,0.0,0.0,0.0],[0.0,0.0,0.0,0.0,0.0],[0.0,0.7,0.1,0.1,0.1],[0.0,0.7,0.1,0.1,0.1],[0.0,0.1,0.1,0.1,0.7],[0.0,0.1,0.1,0.1,0.7],[0.0,0.1,0.1,0.1,0.7],[0.0,0.25,0.25,0.25,0.25],[0.0,0.25,0.25,0.25,0.25],[0.0,0.25,0.25,0.25,0.25],[0.0,0.25,0.25,0.25,0.25],[0.0,0.25,0.25,0.25,0.25],[0.0,0.0,0.0,0.0,0.0],[0.0,0.0,0.0,0.0,0.0],[0.0,0.0,0.0,0.0,0.0],[0.0,0.0,0.0,0.0,0.0],[0.0,0.0,0.0,0.0,0.0],[0.0,0.0,0.0,0.0,0.0],[0.0,0.7,0.1,0.1,0.1],[0.0,0.7,0.1,0.1,0.1],[0.0,0.1,0.1,0.1,0.7],[0.0,0.7,0.1,0.1,0.1],[0.0,0.1,0.1,0.1,0.7],[0.0,0.25,0.25,0.25,0.25],[0.0,0.25,0.25,0.25,0.25],[0.0,0.25,0.25,0.25,0.25],[0.0,0.25,0.25,0.25,0.25],[0.0,0.25,0.25,0.25,0.25],[0.0,0.0,0.0,0.0,0.0],[0.0,0.0,0.0,0.0,0.0],[0.0,0.0,0.0,0.0,0.0],[0.0,0.0,0.0,0.0,0.0],[0.0,0.0,0.0,0.0,0.0],[0.0,0.0,0.0,0.0,0.0],[0.0,0.7,0.1,0.1,0.1],[0.0,0.1,0.1,0.1,0.7],[0.0,0.7,0.1,0.1,0.1],[0.0,0.1,0.1,0.1,0.7],[0.0,0.1,0.1,0.1,0.7],[0.0,0.25,0.25,0.25,0.25],[0.0,0.25,0.25,0.25,0.25],[0.0,0.25,0.25,0.25,0.25],[0.0,0.25,0.25,0.25,0.25],[0.0,0.25,0.25,0.25,0.25],[0.0,0.0,0.0,0.0,0.0],[0.0,0.0,0.0,0.0,0.0],[0.0,0.0,0.0,0.0,0.0],[0.0,0.0,0.0,0.0,0.0],[0.0,0.0,0.0,0.0,0.0],[0.0,0.0,0.0,0.0,0.0],[0.0,0.7,0.1,0.1,0.1],[0.0,0.7,0.1,0.1,0.1],[0.0,0.1,0.1,0.7,0.1],[0.0,0.1,0.1,0.1,0.7],[0.0,0.1,0.1,0.1,0.7],[0.0,0.25,0.25,0.25,0.25],[0.0,0.25,0.25,0.25,0.25],[0.0,0.25,0.25,0.25,0.25],[0.0,0.25,0.25,0.25,0.25],[0.0,0.25,0.25,0.25,0.25],[0.0,0.0,0.0,0.0,0.0],[0.0,0.0,0.0,0.0,0.0],[0.0,0.0,0.0,0.0,0.0],[0.0,0.0,0.0,0.0,0.0],[0.0,0.0,0.0,0.0,0.0],[0.0,0.0,0.0,0.0,0.0],[0.0,0.1,0.1,0.1,0.7],[0.0,0.1,0.1,0.1,0.7],[0.0,0.1,0.1,0.1,0.7],[0.0,0.7,0.1,0.1,0.1],[0.0,0.1,0.1,0.7,0.1],[0.0,0.25,0.25,0.25,0.25],[0.0,0.25,0.25,0.25,0.25],[0.0,0.25,0.25,0.25,0.25],[0.0,0.25,0.25,0.25,0.25],[0.0,0.25,0.25,0.25,0.25],[0.0,0.0,0.0,0.0,0.0],[0.0,0.0,0.0,0.0,0.0],[0.0,0.0,0.0,0.0,0.0],[0.0,0.0,0.0,0.0,0.0],[0.0,0.0,0.0,0.0,0.0],[0.0,0.0,0.0,0.0,0.0],[0.0,0.7,0.1,0.1,0.1],[0.0,0.7,0.1,0.1,0.1],[0.0,0.7,0.1,0.1,0.1],[0.0,0.1,0.1,0.1,0.7],[0.0,0.1,0.1,0.1,0.7],[0.0,0.1,0.1,0.1,0.7],[0.0,0.25,0.25,0.25,0.25],[0.0,0.25,0.25,0.25,0.25],[0.0,0.25,0.25,0.25,0.25],[0.0,0.25,0.25,0.25,0.25],[0.0,0.25,0.25,0.25,0.25],[0.0,0.25,0.25,0.25,0.25],[0.0,0.0,0.0,0.0,0.0],[0.0,0.0,0.0,0.0,0.0],[0.0,0.0,0.0,0.0,0.0],[0.0,0.0,0.0,0.0,0.0],[0.0,0.0,0.0,0.0,0.0],[0.0,0.0,0.0,0.0,0.0],[0.0,0.0,0.0,0.0,0.0],[1.0,0.0,0.0,0.0,0.0]]|[[[],[]],[[0,133],[1.0,0.9]],[[1,17],[0.125,0.5]],[[2],[0.9]],[[2,3,8],[0.03387533875338753,0.8,0.2]],[[2,4,9,13],[0.02710027100271003,0.8,0.2,0.5]],[[2,5,10,14],[0.02168021680216803,0.8,0.2,0.5]],[[2,6,11,15],[0.01734417344173442,0.8,0.2,0.5]],[[8,3],[0.8,0.1]],[[9,4],[0.8,0.1]],[[10,5],[0.8,0.1]],[[11,6],[0.8,0.1]],[[12,7],[0.8,0.1]],[[3],[0.1]],[[13,4],[0.5,0.1]],[[14,5],[0.5,0.1]],[[15,6],[0.5,0.1]],[[7,12,16],[0.9,0.2,1.0]],[[1,33],[0.125,0.5]],[[18],[0.9]],[[18,19,24],[0.03387533875338753,0.8,0.2]],[[18,20,25,29],[0.02710027100271003,0.8,0.2,0.5]],[[18,21,26,30],[0.02168021680216803,0.8,0.2,0.5]],[[18,22,27,31],[0.01734417344173442,0.8,0.2,0.5]],[[24,19],[0.8,0.1]],[[25,20],[0.8,0.1]],[[26,21],[0.8,0.1]],[[27,22],[0.8,0.1]],[[28,23],[0.8,0.1]],[[19],[0.1]],[[29,20],[0.5,0.1]],[[30,21],[0.5,0.1]],[[31,22],[0.5,0.1]],[[23,28,32],[0.9,0.2,1.0]],[[1,49],[0.125,0.5]],[[34],[0.9]],[[34,35,40],[0.03387533875338753,0.8,0.2]],[[34,36,41,45],[0.02710027100271003,0.8,0.2,0.5]],[[34,37,42,46],[0.02168021680216803,0.8,0.2,0.5]],[[34,38,43,47],[0.01734417344173442,0.8,0.2,0.5]],[[40,35],[0.8,0.1]],[[41,36],[0.8,0.1]],[[42,37],[0.8,0.1]],[[43,38],[0.8,0.1]],[[44,39],[0.8,0.1]],[[35],[0.1]],[[45,36],[0.5,0.1]],[[46,37],[0.5,0.1]],[[47,38],[0.5,0.1]],[[39,44,48],[0.9,0.2,1.0]],[[1,65],[0.125,0.5]],[[50],[0.9]],[[50,51,56],[0.03387533875338753,0.8,0.2]],[[50,52,57,61],[0.02710027100271003,0.8,0.2,0.5]],[[50,53,58,62],[0.02168021680216803,0.8,0.2,0.5]],[[50,54,59,63],[0.01734417344173442,0.8,0.2,0.5]],[[56,51],[0.8,0.1]],[[57,52],[0.8,0.1]],[[58,53],[0.8,0.1]],[[59,54],[0.8,0.1]],[[60,55],[0.8,0.1]],[[51],[0.1]],[[61,52],[0.5,0.1]],[[62,53],[0.5,0.1]],[[63,54],[0.5,0.1]],[[55,60,64],[0.9,0.2,1.0]],[[1,81],[0.125,0.5]],[[66],[0.9]],[[66,67,72],[0.03387533875338753,0.8,0.2]],[[66,68,73,77],[0.02710027100271003,0.8,0.2,0.5]],[[66,69,74,78],[0.02168021680216803,0.8,0.2,0.5]],[[66,70,75,79],[0.01734417344173442,0.8,0.2,0.5]],[[72,67],[0.8,0.1]],[[73,68],[0.8,0.1]],[[74,69],[0.8,0.1]],[[75,70],[0.8,0.1]],[[76,71],[0.8,0.1]],[[67],[0.1]],[[77,68],[0.5,0.1]],[[78,69],[0.5,0.1]],[[79,70],[0.5,0.1]],[[71,76,80],[0.9,0.2,1.0]],[[1,97],[0.125,0.5]],[[82],[0.9]],[[82,83,88],[0.03387533875338753,0.8,0.2]],[[82,84,89,93],[0.02710027100271003,0.8,0.2,0.5]],[[82,85,90,94],[0.02168021680216803,0.8,0.2,0.5]],[[82,86,91,95],[0.01734417344173442,0.8,0.2,0.5]],[[88,83],[0.8,0.1]],[[89,84],[0.8,0.1]],[[90,85],[0.8,0.1]],[[91,86],[0.8,0.1]],[[92,87],[0.8,0.1]],[[83],[0.1]],[[93,84],[0.5,0.1]],[[94,85],[0.5,0.1]],[[95,86],[0.5,0.1]],[[87,92,96],[0.9,0.2,1.0]],[[1,113],[0.125,0.5]],[[98],[0.9]],[[98,99,104],[0.03387533875338753,0.8,0.2]],[[98,100,105,109],[0.02710027100271003,0.8,0.2,0.5]],[[98,101,106,110],[0.02168021680216803,0.8,0.2,0.5]],[[98,102,107,111],[0.01734417344173442,0.8,0.2,0.5]],[[104,99],[0.8,0.1]],[[105,100],[0.8,0.1]],[[106,101],[0.8,0.1]],[[107,102],[0.8,0.1]],[[108,103],[0.8,0.1]],[[99],[0.1]],[[109,100],[0.5,0.1]],[[110,101],[0.5,0.1]],[[111,102],[0.5,0.1]],[[103,108,112],[0.9,0.2,1.0]],[[1,132],[0.125,0.5]],[[114],[0.9]],[[114,115,121],[0.029747739171822943,0.8,0.2]],[[114,116,122,127],[0.023798191337458356,0.8,0.2,0.5]],[[114,117,123,128],[0.01903855306996669,0.8,0.2,0.5]],[[114,118,124,129],[0.015230842455973349,0.8,0.2,0.5]],[[114,119,125,130],[0.01218467396477868,0.8,0.2,0.5]],[[121,115],[0.8,0.1]],[[122,116],[0.8,0.1]],[[123,117],[0.8,0.1]],[[124,118],[0.8,0.1]],[[125,119],[0.8,0.1]],[[126,120],[0.8,0.1]],[[115],[0.1]],[[127,116],[0.5,0.1]],[[128,117],[0.5,0.1]],[[129,118],[0.5,0.1]],[[130,119],[0.5,0.1]],[[120,126,131],[0.9,0.2,1.0]],[[17,33,49,65,81,97,113,132],[0.5,0.5,0.5,0.5,0.5,0.5,0.5,0.5]],[[133],[1.0]]]|(2,17,0),(18,33,1),(34,49,2),(50,65,3),(66,81,4),(82,97,5),(98,113,6),(114,132,7)"); - - hmm_defs - }; -} diff --git a/trvz/src/label_hmm.rs b/trvz/src/label_hmm.rs index ba55fef..8e9b5c8 100644 --- a/trvz/src/label_hmm.rs +++ b/trvz/src/label_hmm.rs @@ -1,12 +1,17 @@ -use super::hmm::{self, HmmMotif}; -use super::hmm_defs::HMM_DEFS; +use super::hmm::{Hmm, HmmMotif}; use crate::locus::{BaseLabel, Locus}; use itertools::Itertools; pub fn label_with_hmm(locus: &Locus, alleles: &Vec) -> Vec> { + let motifs = locus + .motifs + .iter() + .map(|m| m.as_bytes().to_vec()) + .collect_vec(); + let hmm = build_hmm(&motifs); + let mut labels_by_allele = Vec::new(); - let encoding = HMM_DEFS.get(&locus.struc[..]).unwrap(); - let hmm = decode_hmm(encoding); + for allele in alleles { let query = &allele[locus.left_flank.len()..allele.len() - locus.right_flank.len()]; let mut labels = vec![BaseLabel::Match; locus.left_flank.len()]; @@ -19,114 +24,265 @@ pub fn label_with_hmm(locus: &Locus, alleles: &Vec) -> Vec) -> Vec { +fn get_base_labels(hmm: &Hmm, states: &Vec) -> Vec { let mut base_labels = vec![BaseLabel::MotifBound]; - for (start, end, _motif_index) in hmm.label_motifs(states) { - base_labels.extend(vec![BaseLabel::Match; end - start]); + for span in hmm.label_motifs(states) { + base_labels.extend(vec![BaseLabel::Match; span.end - span.start]); base_labels.push(BaseLabel::MotifBound); } base_labels } -/*pub fn label_hmm(locus: &Locus, consensuses: &Vec) -> Vec> { - let hmm = decode_hmm(&locus.struc, &locus.motifs); - let mut labels_by_hap = Vec::new(); +pub fn build_hmm(motifs: &[Vec]) -> Hmm { + // 2 terminal states + 2 run start states + 3 states of the skip block + (4n - 1) states for each motif of length n + let num_states = 7 + motifs.iter().map(|m| 3 * m.len() + 1).sum::(); + let mut hmm = Hmm::new(num_states); - for seq in consensuses { - let labels = hmm.label(seq); - let spans = hmm.label_motifs(&labels); - labels_by_hap.push(Some(spans)); - } + let start = 0; + let end = num_states - 1; + let rs = start + 1; + let re = end - 1; + + // # A T C G + hmm.set_ems(start, vec![1.00, 0.00, 0.00, 0.00, 0.00]); + hmm.set_ems(end, vec![1.00, 0.00, 0.00, 0.00, 0.00]); + hmm.set_trans(end, vec![re], vec![0.10]); + + hmm.set_ems(rs, vec![0.00, 0.00, 0.00, 0.00, 0.00]); + hmm.set_trans(rs, vec![start, re], vec![1.00, 0.90]); - labels_by_hap -} */ + let rs_to_ms = 0.50; // / (motifs.len() as f64 + 1.0); <- No longer an HMM because of this change + let me_to_re = 0.05; + let mut mes = Vec::new(); + mes.reserve(motifs.len() + 1); + let mut ms = rs + 1; + for motif in motifs { + let num_motif_states = 3 * motif.len() + 1; + let me = ms + num_motif_states - 1; -fn decode_hmm(encoding: &str) -> hmm::Hmm { - let mats = encoding.split('|').collect_vec(); - assert!(mats.len() == 3); - let ems = decode_emissions(mats[0]); - let transitions = decode_transitions(mats[1]); + hmm.set_ems(ms, vec![0.00, 0.00, 0.00, 0.00, 0.00]); + hmm.set_trans(ms, vec![rs, me], vec![rs_to_ms, 1.0 - me_to_re]); - assert!(ems.len() == transitions.len()); - let num_states = ems.len(); - let mut hmm = hmm::Hmm::new(num_states); + define_motif_block(&mut hmm, ms, &motif); - for (state, state_ems) in ems.into_iter().enumerate() { - hmm.set_ems(state, state_ems); + mes.push(me); + ms += num_motif_states; } - for (state, (in_states, probs)) in transitions.into_iter().enumerate() { - hmm.set_trans(state, in_states, probs); + assert_eq!(ms + 3, re); + + // Defined the skip block + let (skip_state, me) = (ms + 1, ms + 2); + hmm.set_ems(ms, vec![0.00, 0.00, 0.00, 0.00, 0.00]); + hmm.set_trans(ms, vec![rs, me], vec![rs_to_ms, 1.0 - me_to_re]); + + let skip_to_skip = 0.9; + hmm.set_ems(skip_state, vec![0.00, 0.25, 0.25, 0.25, 0.25]); + hmm.set_trans(skip_state, vec![ms, skip_state], vec![1.0, skip_to_skip]); + + hmm.set_ems(me, vec![0.00, 0.00, 0.00, 0.00, 0.00]); + hmm.set_trans(me, vec![skip_state], vec![1.0 - skip_to_skip]); + + mes.push(me); + + // Define the re state + hmm.set_ems(re, vec![0.00, 0.00, 0.00, 0.00, 0.00]); + hmm.set_trans(re, mes.clone(), vec![me_to_re; motifs.len() + 1]); + + // Define motif spans + for (motif_index, motif) in motifs.iter().enumerate() { + let me = mes[motif_index]; + let ms = me - 3 * motif.len(); + hmm.motifs.push(HmmMotif { + start_state: ms, + end_state: me, + motif_index, + }); } - hmm.motifs = decode_motifs(mats[2]); + // Add skip state span + hmm.motifs.push(HmmMotif { + start_state: skip_state - 1, + end_state: skip_state + 1, + motif_index: motifs.len(), + }); hmm } -fn decode_emissions(encoding: &str) -> Vec> { - let mut mat = Vec::new(); - for row in encoding.split("],[") { - let row = row.trim_matches(|c| "\"[]".contains(c)); - let row = row - .split(',') - .map(|e| e.parse::().unwrap()) - .collect_vec(); - mat.push(row); +fn define_motif_block(hmm: &mut Hmm, ms: usize, motif: &Vec) { + let match_states = (ms + 1..ms + 1 + motif.len()).collect_vec(); + let first_ins_state = *match_states.last().unwrap() + 1; + let ins_states = (first_ins_state..first_ins_state + motif.len()).collect_vec(); + let first_del_state = *ins_states.last().unwrap() + 1; // If any + let del_states = (first_del_state..first_del_state + motif.len() - 1).collect_vec(); + + let match_prob = 0.90; + let ins_to_ins = 0.25; + let match_to_indel = (1.00 - match_prob) / 2.00; + let del_to_match = 0.50; + + // Define match states + let mismatch_seed_prob = 2.00 * (1.00 - match_prob) / (motif.len() * (motif.len() - 1)) as f64; + for (match_index, match_state) in match_states.iter().enumerate() { + hmm.set_ems(*match_state, get_match_emissions(motif[match_index])); + if match_index == 0 { + hmm.set_trans(*match_state, vec![ms], vec![match_prob]); + } else if match_index == 1 { + let multiplier = motif.len() - match_index; + let mismatch_prob = mismatch_seed_prob * multiplier as f64; + let prev_ins = ins_states[match_index - 1]; + + hmm.set_trans( + *match_state, + vec![match_state - 1, ms, prev_ins], + vec![match_prob, mismatch_prob, 1.0 - ins_to_ins], + ); + } else { + let multiplier = motif.len() - match_index; + let mismatch_prob = mismatch_seed_prob * multiplier as f64; + let prev_ins = ins_states[match_index - 1]; + let prev_del = del_states[match_index - 2]; + + hmm.set_trans( + *match_state, + vec![match_state - 1, ms, prev_ins, prev_del], + vec![match_prob, mismatch_prob, 1.0 - ins_to_ins, del_to_match], + ); + } } - mat -} + // Define insersion states + for (ins_index, ins_state) in ins_states.iter().enumerate() { + hmm.set_ems(*ins_state, vec![0.00, 0.25, 0.25, 0.25, 0.25]); + let match_state = match_states[ins_index]; + hmm.set_trans( + *ins_state, + vec![*ins_state, match_state], + vec![ins_to_ins, match_to_indel], + ); + } -fn decode_transitions(encoding: &str) -> Vec<(Vec, Vec)> { - let mut transitions = Vec::new(); - let encoding = encoding - .strip_prefix("[[[") - .unwrap() - .strip_suffix("]]]") - .unwrap(); - - for row in encoding.split("]],[[") { - if row.chars().all(|c| "[],".contains(c)) { - transitions.push((Vec::new(), Vec::new())); - continue; + // Define deletion states + for (del_index, del_state) in del_states.iter().enumerate() { + hmm.set_ems(*del_state, vec![0.00, 0.00, 0.00, 0.00, 0.00]); + let prev_match = match_states[del_index]; + if del_index == 0 { + hmm.set_trans(*del_state, vec![prev_match], vec![match_to_indel]); + } else { + let prev_del = del_states[del_index - 1]; + hmm.set_trans( + *del_state, + vec![prev_match, prev_del], + vec![match_to_indel, 1.0 - del_to_match], + ); } + } - let states_and_probs = row.split("],[").collect_vec(); - assert!(states_and_probs.len() == 2); - let (states, probs) = (states_and_probs[0], states_and_probs[1]); - let states = states - .split(',') - .map(|e| e.parse::().unwrap()) - .collect_vec(); - let probs = probs - .split(',') - .map(|e| e.parse::().unwrap()) - .collect_vec(); - - assert!(states.len() == probs.len()); - transitions.push((states, probs)); + let num_motif_states = 3 * motif.len() + 1; + let me = ms + num_motif_states - 1; + hmm.set_ems(me, vec![0.00, 0.00, 0.00, 0.00, 0.00]); + if !del_states.is_empty() { + let last_match = *match_states.last().unwrap(); + let last_ins = *ins_states.last().unwrap(); + let last_del = *del_states.last().unwrap(); + hmm.set_trans( + me, + vec![last_match, last_ins, last_del], + vec![match_prob, 1.0 - ins_to_ins, 1.0], + ); + } else if !ins_states.is_empty() { + let last_match = *match_states.last().unwrap(); + let last_ins = *ins_states.last().unwrap(); + hmm.set_trans( + me, + vec![last_match, last_ins], + vec![match_prob, 1.0 - ins_to_ins], + ); + } else { + let last_match = *match_states.last().unwrap(); + hmm.set_trans(me, vec![last_match], vec![match_prob]); } +} - transitions +fn get_match_emissions(char: u8) -> Vec { + match char { + b'A' => vec![0.00, 0.90, 0.03, 0.03, 0.03], + b'T' => vec![0.00, 0.03, 0.90, 0.03, 0.03], + b'C' => vec![0.00, 0.03, 0.03, 0.90, 0.03], + b'G' => vec![0.00, 0.03, 0.03, 0.03, 0.90], + _ => panic!("Enountered unknown base {char}"), + } } -fn decode_motifs(encoding: &str) -> Vec { - let mut motifs = Vec::new(); - for motif_encoding in encoding.split("),(") { - let (start_state, end_state, motif_index) = motif_encoding - .trim_matches('(') - .trim_end_matches(')') - .split(',') - .map(|m| m.parse::().unwrap()) - .collect_tuple() - .unwrap(); - motifs.push(HmmMotif { - start_state, - end_state, - motif_index, - }); +#[cfg(test)] +mod tests { + use super::*; + use crate::hmm::Span; + + fn summarize(spans: &Vec) -> Vec<(usize, usize, usize)> { + let mut summary = Vec::new(); + for (motif_index, group) in &spans + .iter() + .map(|s| (s.start, s.end, s.motif_index)) + .group_by(|(_s, _e, m)| *m) + { + let group = group.collect_vec(); + summary.push(( + group.first().unwrap().0, + group.last().unwrap().1, + motif_index, + )); + } + summary } - motifs + #[test] + fn annotate_two_perfect_motif_runs() { + let motifs = vec!["CAG".as_bytes().to_vec(), "A".as_bytes().to_vec()]; + let hmm = build_hmm(&motifs); + let labels = hmm.label_motifs(&hmm.label("CAGCAGCAGCAGAAAAA")); + let expected = vec![(0, 12, 0), (12, 17, 1)]; + + assert_eq!(summarize(&labels), expected); + } + + #[test] + fn annotate_motif_runs_separated_by_insertion() { + let motifs = vec!["CAG".as_bytes().to_vec(), "A".as_bytes().to_vec()]; + let hmm = build_hmm(&motifs); + let labels = hmm.label_motifs(&hmm.label("CAGCAGATCGATCGATCGATCGAAAAA")); + let expected = vec![(0, 6, 0), (6, 22, 2), (22, 27, 1)]; + + assert_eq!(summarize(&labels), expected); + } + + #[test] + fn annotate_imperfect_repeat_run() { + let motifs = vec!["CAG".as_bytes().to_vec(), "A".as_bytes().to_vec()]; + let hmm = build_hmm(&motifs); + let labels = hmm.label_motifs(&hmm.label("CAGCAGCTGCAGCAGAAACAG")); + let expected = vec![(0, 21, 0)]; + + assert_eq!(summarize(&labels), expected); + } + + #[test] + fn parse_aga_repeat() { + // TODO: Consider improving this segmentation + let motifs = vec!["AAG".as_bytes().to_vec(), "CAAC".as_bytes().to_vec()]; + let hmm = build_hmm(&motifs); + let query = "TCTATGCAACCAACTTTCTGTTAGTCATAGTACCCCAAGAAGAAGAAGAAGAAGAAGAAGAAGAAGAAGAAGAAGAAGAAGAAGAAGAAGAAGAATAGAAATGTGTTTAAGAATTCCTCAATAAG"; + let labels = hmm.label_motifs(&hmm.label(query)); + let expected = vec![ + (0, 6, 2), + (6, 14, 1), + (14, 36, 2), + (36, 101, 0), + (101, 119, 2), + (119, 125, 0), + ]; + + assert_eq!(summarize(&labels), expected); + } } diff --git a/trvz/src/main.rs b/trvz/src/main.rs index 1147e85..32d4cb7 100644 --- a/trvz/src/main.rs +++ b/trvz/src/main.rs @@ -14,7 +14,6 @@ mod align; mod cli; mod genotype_plot; mod hmm; -mod hmm_defs; mod input; mod label_hmm; mod label_motifs; @@ -91,30 +90,21 @@ fn create_image(plot: PipePlot, path: String) -> io::Result<()> { } fn open_catalog_reader(path: &PathBuf) -> Result>> { - fn get_format(path: &Path) -> Option<&'static str> { + fn is_gzipped(path: &Path) -> bool { let path_str = path.to_string_lossy(); - let formats = ["bed", "bed.gz", "bed.gzip"]; - formats - .iter() - .find(|&&format| path_str.ends_with(format)) - .copied() + let formats = [".gz", ".gzip", ".GZ", ".GZIP"]; + formats.iter().any(|format| path_str.ends_with(*format)) } let file = File::open(path).map_err(|e| e.to_string())?; - match get_format(path) { - Some("bed.gz") | Some("bed.gzip") => { - let gz_decoder = GzDecoder::new(file); - if gz_decoder.header().is_some() { - Ok(BufReader::new(Box::new(gz_decoder))) - } else { - Err(format!("Invalid gzip header: {}", path.to_string_lossy()).into()) - } + if is_gzipped(path) { + let gz_decoder = GzDecoder::new(file); + if gz_decoder.header().is_some() { + Ok(BufReader::new(Box::new(gz_decoder))) + } else { + Err(format!("Invalid gzip header: {}", path.to_string_lossy())) } - Some("bed") => Ok(BufReader::new(Box::new(file))), - _ => Err(format!( - "Unknown bed format: {}. Supported formats are: .bed or .bed.gz(ip)", - path.to_string_lossy() - ) - .into()), + } else { + Ok(BufReader::new(Box::new(file))) } } diff --git a/trvz/src/pipe_plot.rs b/trvz/src/pipe_plot.rs index ade2146..8e97558 100644 --- a/trvz/src/pipe_plot.rs +++ b/trvz/src/pipe_plot.rs @@ -20,6 +20,8 @@ pub enum Color { Yellow, Red, Khaki, + PaleRed, + PaleBlue, Grad(f64), } @@ -85,6 +87,8 @@ pub fn encode_color(color: &Color) -> String { Color::Green => "#009D4E".to_string(), Color::Red => "#E3371E".to_string(), Color::Khaki => "#F0E68C".to_string(), + Color::PaleRed => "#FF4858".to_string(), + Color::PaleBlue => "#46B2E8".to_string(), Color::Grad(value) => get_gradient(*value), } } @@ -454,6 +458,8 @@ pub fn get_color(locus: &Locus, op: AlignmentOperation, label: &RegionLabel) -> Color::Green, Color::Red, Color::Khaki, + Color::PaleRed, + Color::PaleBlue, ]; if op == AlignOp::Subst { @@ -466,7 +472,7 @@ pub fn get_color(locus: &Locus, op: AlignmentOperation, label: &RegionLabel) -> match label { RegionLabel::Flank(_, _) => Color::Teal, - RegionLabel::Seq(_, _) => Color::Teal, + RegionLabel::Seq(_, _) => Color::LightGray, RegionLabel::Tr(_, _, motif) => { let index = locus.motifs.iter().position(|m| m == motif).unwrap(); tr_colors[index % tr_colors.len()].clone()