diff --git a/dorado/correct/features.cpp b/dorado/correct/features.cpp index dddbe9d82..ae183f9b1 100644 --- a/dorado/correct/features.cpp +++ b/dorado/correct/features.cpp @@ -1,13 +1,17 @@ #include "features.h" #include "conversions.h" +#include "correct/types.h" #include "read_pipeline/messages.h" #include "utils/sequence_utils.h" +#include "utils/types.h" #include #include #include +#include + #ifdef NDEBUG #define LOG_TRACE(...) #else @@ -19,61 +23,27 @@ const int TOP_K = 30; namespace dorado::correction { bool overlap_has_long_indel(const OverlapWindow& overlap, const CorrectionAlignments& alignments) { - bool long_indel = false; const auto& cigar = alignments.cigars[overlap.overlap_idx]; size_t max_cigar_idx = std::min(size_t(overlap.cigar_end_idx + 1), cigar.size()); for (size_t i = overlap.cigar_start_idx; i < max_cigar_idx; i++) { - if (cigar[i].op == CigarOpType::INS || cigar[i].op == CigarOpType::DEL) { - long_indel |= cigar[i].len >= 30; + if (cigar[i].len >= 30 && + (cigar[i].op == CigarOpType::INS || cigar[i].op == CigarOpType::DEL)) { + LOG_TRACE("filter ? tstart {} qstart {} qend {} long indel: {}", overlap.tstart, + overlap.qstart, overlap.qend, cigar[i].len); + return true; } } - LOG_TRACE("filter ? tstart {} qstart {} qend {} res {}", overlap.tstart, overlap.qstart, - overlap.qend, long_indel); - return long_indel; + return false; } // Measure the accuracy of an alignment segment within a window // determined from the cigar string. -void calculate_accuracy(OverlapWindow& overlap, - const CorrectionAlignments& alignments, - size_t win_idx, - int win_len, - int window_size) { - int tstart = overlap.tstart; - int tend = (int)win_idx * window_size + win_len; - - // get query region - const auto overlap_idx = overlap.overlap_idx; - int oqstart = alignments.overlaps[overlap_idx].qstart; - int oqend = alignments.overlaps[overlap_idx].qend; - int qstart, qend; - if (alignments.overlaps[overlap_idx].fwd) { - qstart = oqstart + overlap.qstart; - qend = oqstart + overlap.qend; - } else { - qstart = oqend - overlap.qend; - qend = oqend - overlap.qstart; - } - - int qlen = qend - qstart; - - // Fetch subsequences - std::string tseq = alignments.read_seq.substr(tstart, tend - tstart); - std::string qseq; - if (alignments.overlaps[overlap_idx].fwd) { - qseq = alignments.seqs[overlap_idx].substr(qstart, qlen); - } else { - qseq = utils::reverse_complement(alignments.seqs[overlap_idx].substr(qstart, qlen)); - } - - LOG_TRACE("tstart {} tend {} qstart {} qend {} cig st {} cig end {}", tstart, tend, qstart, - qend, overlap.cigar_start_idx, overlap.cigar_end_idx); - +void calculate_accuracy(OverlapWindow& overlap, const CorrectionAlignments& alignments) { const auto& cigar = alignments.cigars[overlap.overlap_idx]; + bool has_warned_bad_cigar_op = false; - // Calculate accuracy - int tpos = 0, qpos = 0; - int m = 0, s = 0, i = 0, d = 0; + // counts of match, mismatch, insert, deletion + int n_match = 0, n_miss = 0, n_ins = 0, n_del = 0; for (int idx = overlap.cigar_start_idx; idx <= overlap.cigar_end_idx; idx++) { int len = -1; @@ -91,40 +61,30 @@ void calculate_accuracy(OverlapWindow& overlap, break; } - LOG_TRACE("len {} tpos {} qpos {}", len, tpos, qpos); - switch (cigar[idx].op) { - case CigarOpType::MATCH: - for (int j = 0; j < len; j++) { - auto tbase = tseq[tpos + j]; - auto qbase = qseq[qpos + j]; - LOG_TRACE("{} tbase {}, {} qbase {}", tpos + j, tbase, qpos + j, qbase); - - if (tbase == qbase) { - m += 1; - } else { - s += 1; - } - } - - tpos += len; - qpos += len; + case CigarOpType::EQ_MATCH: + n_match += len; + break; + case CigarOpType::X_MISMATCH: + n_miss += len; break; case CigarOpType::INS: - i += len; - qpos += len; + n_ins += len; break; case CigarOpType::DEL: - d += len; - tpos += len; + n_del += len; break; default: + if (!has_warned_bad_cigar_op) { + has_warned_bad_cigar_op = true; + LOG_TRACE("unexpected CigarOpType: {}", uint8_t(cigar[idx].op)); + } break; } } - overlap.accuracy = (static_cast(m) / (m + s + i + d)); - LOG_TRACE("m {} s {} i {} d {}", m, s, i, d); + overlap.accuracy = (static_cast(n_match) / (n_match + n_miss + n_ins + n_del)); + LOG_TRACE("m {} s {} i {} d {}", n_match, n_miss, n_ins, n_del); LOG_TRACE("accuracy qstart {} qend {} {}", overlap.qstart, overlap.qend, overlap.accuracy); } @@ -148,7 +108,8 @@ std::vector get_max_ins_for_window(const std::vector& overla int len = cigar[i].len; int l = -1; - if (op == CigarOpType::MATCH || op == CigarOpType::DEL) { + if (op == CigarOpType::EQ_MATCH || op == CigarOpType::X_MISMATCH || + op == CigarOpType::DEL) { l = len; } else if (op == CigarOpType::INS) { max_ins[tpos - 1] = std::max(len, max_ins[tpos - 1]); @@ -283,8 +244,8 @@ std::tuple get_features_for_window( LOG_TRACE("cigar_idx {} l {}", cigar_idx, l); switch (op) { - case CigarOpType::MATCH: - case CigarOpType::MISMATCH: + case CigarOpType::EQ_MATCH: + case CigarOpType::X_MISMATCH: for (uint32_t i = 0; i < l; i++) { auto base = base_encoding[uint8_t(qseq[query_iter]) + (fwd ? 0 : 32)]; auto qual = qqual[query_iter]; @@ -414,18 +375,10 @@ at::Tensor get_indices(const at::Tensor& bases, const std::vector extract_features(std::vector>& windows, - const CorrectionAlignments& alignments, - int window_size) { - const std::string& tseq = alignments.read_seq; - int tlen = (int)tseq.length(); - - std::vector wfs; +std::unordered_set filter_features(std::vector>& windows, + const CorrectionAlignments& alignments) { + std::unordered_set overlap_idxs; for (int w = 0; w < (int)windows.size(); w++) { - int win_len = (w == (int)windows.size() - 1) ? tlen - window_size * w : window_size; - LOG_TRACE("win idx {}: win len {}", w, win_len); auto& overlap_windows = windows[w]; // Filter overlaps with very large indels @@ -442,7 +395,7 @@ std::vector extract_features(std::vector 1) { for (auto& ovlp : overlap_windows) { - calculate_accuracy(ovlp, alignments, w, win_len, window_size); + calculate_accuracy(ovlp, alignments); } // Sort the filtered overlaps by accuracy score std::sort(overlap_windows.begin(), overlap_windows.end(), @@ -456,14 +409,39 @@ std::vector extract_features(std::vector b.accuracy; }); } + // Take the TOP_K best overlaps overlap_windows.resize(std::min(TOP_K, (int)overlap_windows.size())); +#ifndef NDEBUG if (overlap_windows.size() == 1) { LOG_TRACE("window {} 1st {}-{}", w, overlap_windows[0].qstart, overlap_windows[0].qend); } else if (overlap_windows.size() > 1) { LOG_TRACE("window {} 1st {}-{} 2nd {}-{}", w, overlap_windows[0].qstart, overlap_windows[0].qend, overlap_windows[1].qstart, overlap_windows[1].qend); } +#endif + + for (const auto& ov : overlap_windows) { + assert(ov.overlap_idx >= 0); + overlap_idxs.insert(ov.overlap_idx); + } + } + return overlap_idxs; +} + +// Main interface function for generating features for the top_k overlaps for each window +// given the overlaps for a target read. +std::vector extract_features(std::vector>& windows, + const CorrectionAlignments& alignments, + int window_size) { + const std::string& tseq = alignments.read_seq; + int tlen = (int)tseq.length(); + + std::vector wfs; + for (int w = 0; w < (int)windows.size(); w++) { + int win_len = (w == (int)windows.size() - 1) ? tlen - window_size * w : window_size; + LOG_TRACE("win idx {}: win len {}", w, win_len); + auto& overlap_windows = windows[w]; WindowFeatures wf; wf.window_idx = w; diff --git a/dorado/correct/features.h b/dorado/correct/features.h index 24dbd9c9b..25b3a61c6 100644 --- a/dorado/correct/features.h +++ b/dorado/correct/features.h @@ -2,12 +2,18 @@ #include "types.h" +#include + namespace dorado { struct CorrectionAlignments; } namespace dorado::correction { +// Filter window features to TOP_K best. Returns collection of useful overlap indices +std::unordered_set filter_features(std::vector>& windows, + const CorrectionAlignments& alignments); + std::vector extract_features(std::vector>& windows, const CorrectionAlignments& alignments, int window_size); diff --git a/dorado/correct/types.h b/dorado/correct/types.h index 4bdf8ab20..c89e8f62c 100644 --- a/dorado/correct/types.h +++ b/dorado/correct/types.h @@ -8,6 +8,7 @@ namespace dorado::correction { struct OverlapWindow { + // CorrectionAlignments overlap vector index int overlap_idx = -1; int tstart = -1; int qstart = -1; diff --git a/dorado/correct/windows.cpp b/dorado/correct/windows.cpp index 1a3266ebd..0103f4db7 100644 --- a/dorado/correct/windows.cpp +++ b/dorado/correct/windows.cpp @@ -6,6 +6,8 @@ #include +#include + #ifdef NDEBUG #define LOG_TRACE(...) #else @@ -30,10 +32,6 @@ bool extract_windows(std::vector>& windows, for (int aln_idx = 0; aln_idx < num_alignments; aln_idx++) { const auto& overlap = alignments.overlaps[aln_idx]; const auto& cigar = alignments.cigars[aln_idx]; - //if (alignments.qnames[a] != "e3066d3e-2bdf-4803-89b9-0f077ac7ff7f") { - // continue; - //} - LOG_TRACE("window for {}", alignments.qnames[aln_idx]); // Following the is_target == False logic form the rust code. if ((overlap.tend - overlap.tstart < window_size) || @@ -62,10 +60,10 @@ bool extract_windows(std::vector>& windows, spdlog::error( "{} zeroth thres {} nth thres {} first win {} last win {} windows size {} " "overlap " - "tlen {} overlsp tstart {} overlap tend {} qname {} qlen {} qstart {} qend {}", + "tlen {} overlsp tstart {} overlap tend {} qlen {} qstart {} qend {}", alignments.read_name, zeroth_window_thresh, nth_window_thresh, first_window, last_window, windows.size(), overlap.tlen, overlap.tstart, overlap.tend, - alignments.qnames[aln_idx], overlap.qlen, overlap.qstart, overlap.qend); + overlap.qlen, overlap.qstart, overlap.qend); return false; } @@ -102,8 +100,8 @@ bool extract_windows(std::vector>& windows, int tnew = tpos; int qnew = qpos; switch (op.op) { - case CigarOpType::MATCH: - case CigarOpType::MISMATCH: + case CigarOpType::EQ_MATCH: + case CigarOpType::X_MISMATCH: tnew = tpos + op.len; qnew = qpos + op.len; LOG_TRACE("{} {}", op.len, "M"); @@ -117,7 +115,7 @@ bool extract_windows(std::vector>& windows, LOG_TRACE("{} {}", op.len, "I"); continue; default: - continue; + throw std::runtime_error("unexpected CigarOpType"); } LOG_TRACE("tpos {} qpos {} tnew {} qnew {}", tpos, qpos, tnew, qnew); @@ -137,9 +135,10 @@ bool extract_windows(std::vector>& windows, for (int i = 1; i < diff_w; i++) { int offset = (current_w + i) * window_size - tpos; - int q_start_new = (op.op == CigarOpType::MATCH || op.op == CigarOpType::MISMATCH) - ? qpos + offset - : qpos; + int q_start_new = + (op.op == CigarOpType::EQ_MATCH || op.op == CigarOpType::X_MISMATCH) + ? qpos + offset + : qpos; if (cigar_start_idx >= 0) { windows[(current_w + i) - 1].push_back( @@ -154,7 +153,7 @@ bool extract_windows(std::vector>& windows, t_window_start = tpos + offset; - if (op.op == CigarOpType::MATCH || op.op == CigarOpType::MISMATCH) { + if (op.op == CigarOpType::EQ_MATCH || op.op == CigarOpType::X_MISMATCH) { q_window_start = qpos + offset; } else { q_window_start = qpos; @@ -165,7 +164,7 @@ bool extract_windows(std::vector>& windows, } else { t_window_start = tpos + offset; - if (op.op == CigarOpType::MATCH || op.op == CigarOpType::MISMATCH) { + if (op.op == CigarOpType::EQ_MATCH || op.op == CigarOpType::X_MISMATCH) { q_window_start = qpos + offset; } else { q_window_start = qpos; @@ -179,7 +178,7 @@ bool extract_windows(std::vector>& windows, LOG_TRACE("new_w {} window size {} tpos {}", new_w, window_size, tpos); int offset = new_w * window_size - tpos; - int qend = (op.op == CigarOpType::MATCH || op.op == CigarOpType::MISMATCH) + int qend = (op.op == CigarOpType::EQ_MATCH || op.op == CigarOpType::X_MISMATCH) ? qpos + offset : qpos; diff --git a/dorado/read_pipeline/CorrectionNode.cpp b/dorado/read_pipeline/CorrectionNode.cpp index 74031ccb4..e7b1d36f4 100644 --- a/dorado/read_pipeline/CorrectionNode.cpp +++ b/dorado/read_pipeline/CorrectionNode.cpp @@ -11,6 +11,9 @@ #include "utils/string_utils.h" #include "utils/thread_utils.h" #include "utils/types.h" + +#include +#include #if DORADO_CUDA_BUILD #include "utils/cuda_utils.h" #endif @@ -46,44 +49,23 @@ dorado::BamPtr create_bam_record(const std::string& read_id, const std::string& return dorado::BamPtr(rec); } -std::vector parse_cigar(const uint32_t* cigar, uint32_t n_cigar) { - std::vector cigar_ops; - cigar_ops.resize(n_cigar); - for (uint32_t i = 0; i < n_cigar; i++) { - uint32_t op = cigar[i] & 0xf; - uint32_t len = cigar[i] >> 4; - if (op == MM_CIGAR_MATCH) { - cigar_ops[i] = {dorado::CigarOpType::MATCH, len}; - } else if (op == MM_CIGAR_INS) { - cigar_ops[i] = {dorado::CigarOpType::INS, len}; - } else if (op == MM_CIGAR_DEL) { - cigar_ops[i] = {dorado::CigarOpType::DEL, len}; - } else { - throw std::runtime_error("Unknown cigar op: " + std::to_string(op)); - } - } - return cigar_ops; -} - bool populate_alignments(dorado::CorrectionAlignments& alignments, - dorado::hts_io::FastxRandomReader* reader) { + dorado::hts_io::FastxRandomReader* reader, + const std::unordered_set& useful_overlap_idxs) { const auto& tname = alignments.read_name; + alignments.read_seq = reader->fetch_seq(tname); alignments.read_qual = reader->fetch_qual(tname); int tlen = (int)alignments.read_seq.length(); + + // Might be worthwhile generating dense vectors with some index mapping to save memory + // as using filtering of useful overlaps makes these vectors sparse. auto num_qnames = alignments.qnames.size(); alignments.seqs.resize(num_qnames); alignments.quals.resize(num_qnames); alignments.cigars.resize(num_qnames); - // In some cases the target read length reported by mm2 has differed from the - // read length when loaded from the fastq. So we check that here and skip - // any alignments where information is inconsisteny. - // TODO: This was mainly observed before a bug fix for proper loading - // of split mm2 indices was added. However the check is being kept around - // for now, and can be removed later. - std::vector pos_to_remove; - for (size_t i = 0; i < num_qnames; i++) { + for (const size_t i : useful_overlap_idxs) { const std::string& qname = alignments.qnames[i]; alignments.seqs[i] = reader->fetch_seq(qname); if ((int)alignments.seqs[i].length() != alignments.overlaps[i].qlen) { @@ -97,9 +79,6 @@ bool populate_alignments(dorado::CorrectionAlignments& alignments, alignments.overlaps[i].tlen, tlen, tname); return false; } - alignments.cigars[i] = parse_cigar(alignments.mm2_cigars[i].data(), - (uint32_t)alignments.mm2_cigars[i].size()); - alignments.mm2_cigars[i] = {}; } return alignments.check_consistent_overlaps(); @@ -354,19 +333,32 @@ void CorrectionNode::input_thread_fn() { auto alignments = std::get(std::move(message)); auto tname = alignments.read_name; - if (!populate_alignments(alignments, fastx_reader.get())) { + + if (alignments.overlaps.empty()) { continue; } - size_t n_windows = (alignments.read_seq.length() + m_window_size - 1) / m_window_size; - LOG_TRACE("num windows {} for read {}", n_windows, alignments.read_name); // Get the windows + size_t n_windows = (alignments.overlaps[0].tlen + m_window_size - 1) / m_window_size; + LOG_TRACE("num windows {} for read {}", n_windows, alignments.read_name); std::vector> windows; windows.resize(n_windows); if (!extract_windows(windows, alignments, m_window_size)) { continue; } - // Get the features + + // Filter the window features and get the set of unique overlaps + const std::unordered_set overlap_idxs = filter_features(windows, alignments); + if (overlap_idxs.empty()) { + continue; + } + + // Populate the alignment data with only the records that are useful after TOP_K filter + if (!populate_alignments(alignments, fastx_reader.get(), overlap_idxs)) { + continue; + } + + // Get the filtered features auto wfs = extract_features(windows, alignments, m_window_size); std::vector corrected_seqs; diff --git a/dorado/read_pipeline/ErrorCorrectionMapperNode.cpp b/dorado/read_pipeline/ErrorCorrectionMapperNode.cpp index 97d594aa9..3017ba579 100644 --- a/dorado/read_pipeline/ErrorCorrectionMapperNode.cpp +++ b/dorado/read_pipeline/ErrorCorrectionMapperNode.cpp @@ -6,9 +6,11 @@ #include "alignment/Minimap2Index.h" #include "alignment/Minimap2IndexSupportTypes.h" #include "alignment/Minimap2Options.h" +#include "alignment/minimap2_args.h" +#include "alignment/minimap2_wrappers.h" #include "utils/PostCondition.h" #include "utils/bam_utils.h" -#include "utils/thread_utils.h" +#include "utils/thread_naming.h" #include #include @@ -20,12 +22,32 @@ #include #include -const size_t MAX_OVERLAPS_PER_READ = 500; - namespace { -std::vector copy_mm2_cigar(const uint32_t* cigar, uint32_t n_cigar) { - std::vector cigar_ops(cigar, cigar + n_cigar); +std::vector parse_cigar(const uint32_t* cigar, uint32_t n_cigar) { + std::vector cigar_ops; + cigar_ops.resize(n_cigar); + for (uint32_t i = 0; i < n_cigar; i++) { + const uint32_t op = cigar[i] & 0xf; + const uint32_t len = cigar[i] >> 4; + + // minimap2 --eqx must be set + if (op == MM_CIGAR_EQ_MATCH) { + cigar_ops[i] = {dorado::CigarOpType::EQ_MATCH, len}; + } else if (op == MM_CIGAR_X_MISMATCH) { + cigar_ops[i] = {dorado::CigarOpType::X_MISMATCH, len}; + } else if (op == MM_CIGAR_INS) { + cigar_ops[i] = {dorado::CigarOpType::INS, len}; + } else if (op == MM_CIGAR_DEL) { + cigar_ops[i] = {dorado::CigarOpType::DEL, len}; + } else if (op == MM_CIGAR_MATCH) { + throw std::runtime_error( + "cigar op MATCH is not supported must set minimap2 --eqx flag" + + std::to_string(op)); + } else { + throw std::runtime_error("Unknown cigar op: " + std::to_string(op)); + } + } return cigar_ops; } @@ -104,22 +126,13 @@ void ErrorCorrectionMapperNode::extract_alignments(const mm_reg1_t* reg, continue; } - uint32_t n_cigar = aln->p->n_cigar; - auto cigar = copy_mm2_cigar(aln->p->cigar, n_cigar); - + auto cigar = parse_cigar(aln->p->cigar, aln->p->n_cigar); { std::lock_guard aln_lock(mtx); auto& alignments = m_correction_records[tname]; - - // Cap total overlaps per read. - if (alignments.qnames.size() >= MAX_OVERLAPS_PER_READ) { - continue; - } - alignments.qnames.push_back(qname); - - alignments.mm2_cigars.push_back(std::move(cigar)); + alignments.cigars.push_back(std::move(cigar)); alignments.overlaps.push_back(std::move(ovlp)); } } @@ -245,20 +258,31 @@ ErrorCorrectionMapperNode::ErrorCorrectionMapperNode(const std::string& index_fi m_index_file(index_file), m_num_threads(threads), m_reads_queue(5000) { - alignment::Minimap2Options options = alignment::dflt_options; - options.kmer_size = 25; - options.window_size = 17; - options.index_batch_size = index_size; - options.mm2_preset = "ava-ont"; - options.bandwidth = 150; - options.bandwidth_long = 2000; - options.min_chain_score = 4000; - options.zdrop = options.zdrop_inv = 200; - options.occ_dist = 200; - options.cs = "short"; - options.dual = "yes"; - options.cap_kalloc = std::nullopt; - options.max_sw_mat = std::nullopt; + auto options = alignment::create_preset_options("ava-ont"); + auto& index_options = options.index_options->get(); + index_options.k = 25; + index_options.w = 17; + index_options.batch_size = index_size; + auto& mapping_options = options.mapping_options->get(); + mapping_options.bw = 150; + mapping_options.bw_long = 2000; + mapping_options.min_chain_score = 4000; + mapping_options.zdrop = 200; + mapping_options.zdrop_inv = 200; + mapping_options.occ_dist = 200; + mapping_options.flag |= MM_F_EQX; + + // --cs short + alignment::mm2::apply_cs_option(options, "short"); + + // --dual yes + alignment::mm2::apply_dual_option(options, "yes"); + + // reset to larger minimap2 defaults + mm_mapopt_t minimap_default_mapopt; + mm_mapopt_init(&minimap_default_mapopt); + mapping_options.cap_kalloc = minimap_default_mapopt.cap_kalloc; + mapping_options.max_sw_mat = minimap_default_mapopt.max_sw_mat; m_index = std::make_shared(); if (!m_index->initialise(options)) { diff --git a/dorado/read_pipeline/messages.h b/dorado/read_pipeline/messages.h index 7a0bd54f4..91bbd5a96 100644 --- a/dorado/read_pipeline/messages.h +++ b/dorado/read_pipeline/messages.h @@ -222,15 +222,17 @@ struct Overlap { // Overlaps for error correction struct CorrectionAlignments { + // Populated in ErrorCorrectionMapperNode::extract_alignments std::string read_name = ""; + std::vector qnames = {}; + std::vector> cigars = {}; + std::vector overlaps = {}; + + // Populated in CorrectionNode::populate_alignments if the alignment is useful std::string read_seq = ""; std::vector read_qual = {}; - std::vector overlaps = {}; - std::vector> cigars = {}; - std::vector> mm2_cigars = {}; std::vector seqs = {}; std::vector> quals = {}; - std::vector qnames = {}; // This is mostly to workaround an issue where sometimes // the tend of an overlap is much bigger than the @@ -257,9 +259,6 @@ struct CorrectionAlignments { for (auto& v : cigars) { si += v.size() * sizeof(CigarOp); } - for (auto& v : mm2_cigars) { - si += v.size() * sizeof(uint32_t); - } for (auto& s : seqs) { si += s.length(); } diff --git a/dorado/utils/types.h b/dorado/utils/types.h index 1d3b9ce3d..63532bff4 100644 --- a/dorado/utils/types.h +++ b/dorado/utils/types.h @@ -237,7 +237,7 @@ struct ModBaseInfo { }; // Enum for handling CIGAR ops -enum class CigarOpType : uint8_t { INS = 0, DEL, MATCH, MISMATCH }; +enum class CigarOpType : uint8_t { INS = 0, DEL, EQ_MATCH, X_MISMATCH }; struct CigarOp { CigarOpType op;