diff --git a/dorado/read_pipeline/PairingNode.cpp b/dorado/read_pipeline/PairingNode.cpp index c1699cdee..0704b24c9 100644 --- a/dorado/read_pipeline/PairingNode.cpp +++ b/dorado/read_pipeline/PairingNode.cpp @@ -9,6 +9,12 @@ #include #include +namespace { +const int kMaxTimeDeltaMs = 1000; +const float kMinSeqLenRatio = 0.2f; +const int kMinOverlapLength = 50; +} // namespace + namespace dorado { // Determine whether 2 proposed reads form a duplex pair or not. @@ -24,14 +30,10 @@ namespace dorado { // one read maps to the reverse strand of the other, and the end // of the template is mapped to the beginning // of the complement read, then consider them a pair. -std::tuple -PairingNode::is_within_time_and_length_criteria(const std::shared_ptr& temp, - const std::shared_ptr& comp, - int tid) { - std::tuple pair_result = {false, 0, 0, 0, 0}; - const int kMaxTimeDeltaMs = 1000; - const float kMinSeqLenRatio = 0.2f; - const int kMinOverlapLength = 50; +PairingNode::PairingResult PairingNode::is_within_time_and_length_criteria( + const std::shared_ptr& temp, + const std::shared_ptr& comp, + int tid) { int delta = comp->start_time_ms - temp->get_end_time_ms(); int seq_len1 = temp->seq.length(); int seq_len2 = comp->seq.length(); @@ -39,96 +41,108 @@ PairingNode::is_within_time_and_length_criteria(const std::shared_ptr(min_seq_len) / static_cast(max_seq_len); - if ((delta >= 0) && (delta < kMaxTimeDeltaMs) && (len_ratio > kMinSeqLenRatio) && - (min_seq_len > kMinOverlapLength)) { - const float kEarlyAcceptSeqLenRatio = 0.98; - const int kEarlyAcceptTimeDeltaMs = 100; - if (delta <= kEarlyAcceptTimeDeltaMs && len_ratio >= kEarlyAcceptSeqLenRatio && - min_seq_len >= 5000) { - spdlog::debug( - "Early acceptance: len frac {}, delta {} temp len {}, comp len {}, {} and {}", - len_ratio, delta, temp->seq.length(), comp->seq.length(), temp->read_id, - comp->read_id); - m_early_accepted_pairs++; - return {true, 0, temp->seq.length() - 1, 0, comp->seq.length() - 1}; - } + if ((delta < 0) || (delta >= kMaxTimeDeltaMs) || (len_ratio <= kMinSeqLenRatio) || + (min_seq_len <= kMinOverlapLength)) { + return {false, 0, 0, 0, 0}; + } - const std::string nvtx_id = "pairing_map_" + std::to_string(tid); - nvtx3::scoped_range loop{nvtx_id}; - // Add mm2 based overlap check. - mm_idxopt_t m_idx_opt; - mm_mapopt_t m_map_opt; - mm_set_opt(0, &m_idx_opt, &m_map_opt); - mm_set_opt("map-hifi", &m_idx_opt, &m_map_opt); - - std::vector seqs = {temp->seq.c_str()}; - std::vector names = {temp->read_id.c_str()}; - mm_idx_t* m_index = mm_idx_str(m_idx_opt.w, m_idx_opt.k, 0, m_idx_opt.bucket_bits, 1, - seqs.data(), names.data()); - mm_mapopt_update(&m_map_opt, m_index); - - mm_tbuf_t* tbuf = m_tbufs[tid].get(); - - int hits = 0; - mm_reg1_t* reg = mm_map(m_index, comp->seq.length(), comp->seq.c_str(), &hits, tbuf, - &m_map_opt, comp->read_id.c_str()); - - mm_idx_destroy(m_index); - - // Multiple hits implies ambiguous mapping, so ignore those pairs. - if (hits == 1) { - uint8_t mapq = 0; - int32_t temp_start = 0; - int32_t temp_end = 0; - int32_t comp_start = 0; - int32_t comp_end = 0; - bool rev = false; - - auto best_map = ®[0]; - mapq = best_map->mapq; - temp_start = best_map->rs; - temp_end = best_map->re; - comp_start = best_map->qs; - comp_end = best_map->qe; - rev = best_map->rev; - - free(best_map->p); - - const int kMinMapQ = 50; - const float kMinOverlapFraction = 0.8f; - - // Require high mapping quality. - bool meets_mapq = (mapq >= kMinMapQ); - // Require overlap to cover most of at least one of the reads. - float overlap_frac = - std::max(static_cast(temp_end - temp_start) / temp->seq.length(), - static_cast(comp_end - comp_start) / comp->seq.length()); - bool meets_length = overlap_frac > kMinOverlapFraction; - // Require the start of the complement strand to map to end - // of the template strand. - bool ends_anchored = (comp_start + (temp->seq.length() - temp_end)) <= 500; - int min_overlap_length = std::min(temp_end - temp_start, comp_end - comp_start); - bool meets_min_overlap_length = min_overlap_length > kMinOverlapLength; - bool cond = (meets_mapq && meets_length && rev && ends_anchored && - meets_min_overlap_length); - - spdlog::debug( - "hits {}, mapq {}, overlap length {}, overlap frac {}, delta {}, read 1 {}, " - "read 2 " - "{}, strand {}, pass {}, temp start {} temp end {}, comp start {} comp end {}, " - "{} " - "and {}", - hits, mapq, temp_end - temp_start, overlap_frac, delta, temp->seq.length(), - comp->seq.length(), rev ? "-" : "+", cond, temp_start, temp_end, comp_start, - comp_end, temp->read_id, comp->read_id); - - if (cond) { - m_overlap_accepted_pairs++; - pair_result = {true, temp_start, temp_end, comp_start, comp_end}; - } + const float kEarlyAcceptSeqLenRatio = 0.98; + const int kEarlyAcceptTimeDeltaMs = 100; + if (delta <= kEarlyAcceptTimeDeltaMs && len_ratio >= kEarlyAcceptSeqLenRatio && + min_seq_len >= 5000) { + spdlog::debug("Early acceptance: len frac {}, delta {} temp len {}, comp len {}, {} and {}", + len_ratio, delta, temp->seq.length(), comp->seq.length(), temp->read_id, + comp->read_id); + m_early_accepted_pairs++; + return {true, 0, temp->seq.length() - 1, 0, comp->seq.length() - 1}; + } + + return is_within_alignment_criteria(temp, comp, delta, true, tid); +} + +PairingNode::PairingResult PairingNode::is_within_alignment_criteria( + const std::shared_ptr& temp, + const std::shared_ptr& comp, + int delta, + bool allow_rejection, + int tid) { + PairingResult pair_result = {false, 0, 0, 0, 0}; + const std::string nvtx_id = "pairing_map_" + std::to_string(tid); + nvtx3::scoped_range loop{nvtx_id}; + // Add mm2 based overlap check. + mm_idxopt_t m_idx_opt; + mm_mapopt_t m_map_opt; + mm_set_opt(0, &m_idx_opt, &m_map_opt); + mm_set_opt("map-hifi", &m_idx_opt, &m_map_opt); + + std::vector seqs = {temp->seq.c_str()}; + std::vector names = {temp->read_id.c_str()}; + mm_idx_t* m_index = mm_idx_str(m_idx_opt.w, m_idx_opt.k, 0, m_idx_opt.bucket_bits, 1, + seqs.data(), names.data()); + mm_mapopt_update(&m_map_opt, m_index); + + mm_tbuf_t* tbuf = m_tbufs[tid].get(); + + int hits = 0; + mm_reg1_t* reg = mm_map(m_index, comp->seq.length(), comp->seq.c_str(), &hits, tbuf, &m_map_opt, + comp->read_id.c_str()); + + mm_idx_destroy(m_index); + + // Multiple hits implies ambiguous mapping, so ignore those pairs. + if (hits == 1 || (!allow_rejection && hits > 0)) { + uint8_t mapq = 0; + int32_t temp_start = 0; + int32_t temp_end = 0; + int32_t comp_start = 0; + int32_t comp_end = 0; + bool rev = false; + + auto best_map = ®[0]; + mapq = best_map->mapq; + temp_start = best_map->rs; + temp_end = best_map->re; + comp_start = best_map->qs; + comp_end = best_map->qe; + rev = best_map->rev; + + const int kMinMapQ = 50; + const float kMinOverlapFraction = 0.8f; + + // Require high mapping quality. + bool meets_mapq = (mapq >= kMinMapQ); + // Require overlap to cover most of at least one of the reads. + float overlap_frac = + std::max(static_cast(temp_end - temp_start) / temp->seq.length(), + static_cast(comp_end - comp_start) / comp->seq.length()); + bool meets_length = overlap_frac > kMinOverlapFraction; + // Require the start of the complement strand to map to end + // of the template strand. + bool ends_anchored = (comp_start + (temp->seq.length() - temp_end)) <= 500; + int min_overlap_length = std::min(temp_end - temp_start, comp_end - comp_start); + bool meets_min_overlap_length = min_overlap_length > kMinOverlapLength; + bool cond = + (meets_mapq && meets_length && rev && ends_anchored && meets_min_overlap_length); + + spdlog::debug( + "hits {}, mapq {}, overlap length {}, overlap frac {}, delta {}, read 1 {}, " + "read 2 {}, strand {}, pass {}, accepted {}, temp start {} temp end {}, " + "comp start {} comp end {}, {} and {}", + hits, mapq, temp_end - temp_start, overlap_frac, delta, temp->seq.length(), + comp->seq.length(), rev ? "-" : "+", cond, !allow_rejection, temp_start, temp_end, + comp_start, comp_end, temp->read_id, comp->read_id); + + if (cond || !allow_rejection) { + m_overlap_accepted_pairs++; + pair_result = {true, temp_start, temp_end, comp_start, comp_end}; } - free(reg); } + + for (int i = 0; i < hits; ++i) { + free(reg[i].p); + } + free(reg); + return pair_result; } @@ -192,13 +206,21 @@ void PairingNode::pair_list_worker_thread(int tid) { template_read = partner_read; } - ReadPair read_pair; - read_pair.read_1 = template_read; - read_pair.read_2 = complement_read; + int delta = complement_read->start_time_ms - template_read->get_end_time_ms(); + auto [is_pair, qs, qe, rs, re] = is_within_alignment_criteria( + template_read, complement_read, delta, false, tid); + if (is_pair) { + ReadPair read_pair = {template_read, complement_read, qs, qe, rs, re}; + template_read->is_duplex_parent = true; + complement_read->is_duplex_parent = true; - ++template_read->num_duplex_candidate_pairs; + ++template_read->num_duplex_candidate_pairs; - send_message_to_sink(std::make_shared(read_pair)); + send_message_to_sink(std::make_shared(read_pair)); + } else { + spdlog::debug("- rejected explicitly requested read pair: {} and {}", + template_read->read_id, complement_read->read_id); + } } } } diff --git a/dorado/read_pipeline/PairingNode.h b/dorado/read_pipeline/PairingNode.h index eb29c07a1..1d6fabeb6 100644 --- a/dorado/read_pipeline/PairingNode.h +++ b/dorado/read_pipeline/PairingNode.h @@ -103,10 +103,16 @@ class PairingNode : public MessageSink { */ size_t m_max_num_reads; - std::tuple is_within_time_and_length_criteria( - const std::shared_ptr& read1, - const std::shared_ptr& read2, - int tid); + using PairingResult = std::tuple; + PairingResult is_within_time_and_length_criteria(const std::shared_ptr& read1, + const std::shared_ptr& read2, + int tid); + + PairingResult is_within_alignment_criteria(const std::shared_ptr& temp, + const std::shared_ptr& comp, + int delta, + bool allow_rejection, + int tid); // Store the minimap2 buffers used for mapping. One buffer per thread. std::vector m_tbufs; diff --git a/dorado/read_pipeline/ReadPipeline.h b/dorado/read_pipeline/ReadPipeline.h index cfcbd7cb2..46bd21452 100644 --- a/dorado/read_pipeline/ReadPipeline.h +++ b/dorado/read_pipeline/ReadPipeline.h @@ -143,10 +143,10 @@ class ReadPair { public: std::shared_ptr read_1; std::shared_ptr read_2; - uint32_t read_1_start; - uint32_t read_1_end; - uint32_t read_2_start; - uint32_t read_2_end; + uint64_t read_1_start; + uint64_t read_1_end; + uint64_t read_2_start; + uint64_t read_2_end; }; class CacheFlushMessage { diff --git a/dorado/read_pipeline/StereoDuplexEncoderNode.cpp b/dorado/read_pipeline/StereoDuplexEncoderNode.cpp index 6c8d4396d..b3bba5343 100644 --- a/dorado/read_pipeline/StereoDuplexEncoderNode.cpp +++ b/dorado/read_pipeline/StereoDuplexEncoderNode.cpp @@ -15,10 +15,10 @@ namespace dorado { std::shared_ptr StereoDuplexEncoderNode::stereo_encode( std::shared_ptr template_read, std::shared_ptr complement_read, - uint32_t temp_start, - uint32_t temp_end, - uint32_t comp_start, - uint32_t comp_end) { + uint64_t temp_start, + uint64_t temp_end, + uint64_t comp_start, + uint64_t comp_end) { // We rely on the incoming read raw data being of type float16 to allow direct memcpy // of tensor elements. assert(template_read->raw_data.dtype() == torch::kFloat16); diff --git a/dorado/read_pipeline/StereoDuplexEncoderNode.h b/dorado/read_pipeline/StereoDuplexEncoderNode.h index 1816a8469..500ed04a1 100644 --- a/dorado/read_pipeline/StereoDuplexEncoderNode.h +++ b/dorado/read_pipeline/StereoDuplexEncoderNode.h @@ -14,10 +14,10 @@ class StereoDuplexEncoderNode : public MessageSink { std::shared_ptr stereo_encode(std::shared_ptr template_read, std::shared_ptr complement_read, - uint32_t temp_start, - uint32_t temp_end, - uint32_t comp_start, - uint32_t comp_end); + uint64_t temp_start, + uint64_t temp_end, + uint64_t comp_start, + uint64_t comp_end); ~StereoDuplexEncoderNode() { terminate_impl(); }; std::string get_name() const override { return "StereoDuplexEncoderNode"; }