From 2719303056807f5dbd895d49648afd92fa708078 Mon Sep 17 00:00:00 2001 From: Erik Garrison Date: Fri, 2 Aug 2024 10:04:07 +0200 Subject: [PATCH] completely rewrite alignment computation to allow for multiple seq reader threads, avoiding contention --- src/align/include/computeAlignments.hpp | 717 ++++++++++++------------ 1 file changed, 369 insertions(+), 348 deletions(-) diff --git a/src/align/include/computeAlignments.hpp b/src/align/include/computeAlignments.hpp index f5ec7675..328dc8ba 100644 --- a/src/align/include/computeAlignments.hpp +++ b/src/align/include/computeAlignments.hpp @@ -55,6 +55,7 @@ struct seq_record_t { uint64_t queryStartPos; uint64_t queryLen; uint64_t queryTotalLength; + seq_record_t(const MappingBoundaryRow& c, const std::string& r, const std::string& ref, uint64_t refStart, uint64_t refLength, uint64_t refTotalLength, const std::string& query, uint64_t queryStart, uint64_t queryLength, uint64_t queryTotalLength) @@ -183,363 +184,383 @@ typedef atomic_queue::AtomicQueue reader_done; - reader_done.store(false); - - auto& nthreads = param.threads; - //for ( - - // atomics to record if we're working or not - std::vector> working(nthreads); - for (auto& w : working) { - w.store(true); - } - - size_t total_alignments_queued = 0; - auto reader_thread = [&]() { - std::ifstream mappingListStream(param.mashmapPafFile); - if (!mappingListStream.is_open()) { - throw std::runtime_error("[wfmash::align::computeAlignments] Error! Failed to open input mapping file: " + param.mashmapPafFile); - } - - std::string mappingRecordLine; - MappingBoundaryRow currentRecord; - - while (!mappingListStream.eof()) { - std::getline(mappingListStream, mappingRecordLine); - if (!mappingRecordLine.empty()) { - parseMashmapRow(mappingRecordLine, currentRecord); - - // Get the reference sequence length - const int64_t ref_size = faidx_seq_len(ref_faidx, currentRecord.refId.c_str()); - // Get the query sequence length - const int64_t query_size = faidx_seq_len(query_faidx, currentRecord.qId.c_str()); - - // Compute padding - const uint64_t head_padding = currentRecord.rStartPos >= param.wflign_max_len_minor - ? param.wflign_max_len_minor : currentRecord.rStartPos; - const uint64_t tail_padding = ref_size - currentRecord.rEndPos >= param.wflign_max_len_minor - ? param.wflign_max_len_minor : ref_size - currentRecord.rEndPos; - - // Extract reference sequence - int64_t ref_len; - char* ref_seq = faidx_fetch_seq64(ref_faidx, currentRecord.refId.c_str(), - currentRecord.rStartPos - head_padding, currentRecord.rEndPos + tail_padding, &ref_len); - - // Extract query sequence - int64_t query_len; - char* query_seq = faidx_fetch_seq64(query_faidx, currentRecord.qId.c_str(), - currentRecord.qStartPos, currentRecord.qEndPos, &query_len); - - // Create a new seq_record_t object for the alignment using std::move - seq_record_t* rec = new seq_record_t(currentRecord, mappingRecordLine, - std::string(ref_seq, ref_len), currentRecord.rStartPos - head_padding, ref_len, ref_size, - std::string(query_seq, query_len), currentRecord.qStartPos, query_len, query_size); - - // Clean up - free(ref_seq); - free(query_seq); - - ++total_alignments_queued; - seq_queue.push(rec); - } - } - - mappingListStream.close(); - reader_done.store(true); - }; - - // helper to check if we're still aligning - auto still_working = - [&](const std::vector>& working) { - bool ongoing = false; - for (auto& w : working) { - ongoing = ongoing || w.load(); - } - return ongoing; - }; - - // writer, picks output from queue and writes it to our output stream - std::ofstream outstrm(param.pafOutputFile, ios::app); - - size_t total_alignments_written = 0; - auto writer_thread = - [&]() { - while (true) { - std::string* paf_lines = nullptr; - if (!paf_queue.try_pop(paf_lines) - && !still_working(working)) { - break; - } else if (paf_lines != nullptr) { - ++total_alignments_written; - outstrm << *paf_lines; - delete paf_lines; - } else { - std::this_thread::sleep_for(100ns); - } - } - }; - -#ifdef WFA_PNG_TSV_TIMING - auto writer_thread_tsv = - [&]() { - if (!param.tsvOutputPrefix.empty()) { - uint64_t num_alignments_completed = 0; - - while (true) { - std::string* tsv_lines = nullptr; - if (!tsv_queue.try_pop(tsv_lines) - && !still_working(working)) { - break; - } else if (tsv_lines != nullptr) { - std::ofstream ofstream_tsv(param.tsvOutputPrefix + std::to_string(num_alignments_completed++) + ".tsv"); - ofstream_tsv << *tsv_lines; - ofstream_tsv.close(); - - delete tsv_lines; - } else { - std::this_thread::sleep_for(100ns); - } - } - } - }; - - std::ofstream ofstream_patching_tsv(param.path_patching_info_in_tsv); - auto writer_thread_patching_tsv = - [&]() { - if (!param.path_patching_info_in_tsv.empty()) { - while (true) { - std::string* tsv_lines = nullptr; - if (!patching_tsv_queue.try_pop(tsv_lines) - && !still_working(working)) { - break; - } else if (tsv_lines != nullptr) { - ofstream_patching_tsv << *tsv_lines; - - delete tsv_lines; - } else { - std::this_thread::sleep_for(100ns); - } - } - } - }; -#endif - - // worker, takes candidate alignments and runs wfa alignment on them - auto worker_thread = - [&](uint64_t tid, - std::atomic& is_working) { - is_working.store(true); - while (true) { - seq_record_t* rec = nullptr; - if (!seq_queue.try_pop(rec) - && reader_done.load()) { - break; - } else if (rec != nullptr) { - std::stringstream output; -#ifdef WFA_PNG_TSV_TIMING - std::stringstream output_tsv; - std::stringstream patching_output_tsv; -#endif - doAlignment( - output, -#ifdef WFA_PNG_TSV_TIMING - output_tsv, - patching_output_tsv, -#endif - rec, - tid); - progress.increment(rec->currentRecord.qEndPos - rec->currentRecord.qStartPos); - - auto* paf_rec = new std::string(output.str()); - if (!paf_rec->empty()) { - paf_queue.push(paf_rec); - } else { - delete paf_rec; - } - -#ifdef WFA_PNG_TSV_TIMING - auto* tsv_rec = new std::string(output_tsv.str()); - if (!tsv_rec->empty()) { - tsv_queue.push(tsv_rec); - } else { - delete tsv_rec; - } - - auto* patching_tsv_rec = new std::string(patching_output_tsv.str()); - if (!patching_tsv_rec->empty()) { - patching_tsv_queue.push(patching_tsv_rec); - } else { - delete patching_tsv_rec; - } -#endif - - delete rec; - } else { - std::this_thread::sleep_for(100ns); - } - } - is_working.store(false); - }; - - // launch reader - std::thread reader(reader_thread); - // launch PAF/SAM writer - std::thread writer(writer_thread); -#ifdef WFA_PNG_TSV_TIMING - // launch TSV writer - std::thread writer_tsv(writer_thread_tsv); - std::thread writer_patching_tsv(writer_thread_patching_tsv); -#endif - // launch workers - std::vector workers; workers.reserve(nthreads); - for (uint64_t t = 0; t < nthreads; ++t) { - workers.emplace_back(worker_thread, - t, - std::ref(working[t])); - } - - // wait for reader and workers to complete - reader.join(); - for (auto& worker : workers) { - worker.join(); - } - // and finally the writer - writer.join(); -#ifdef WFA_PNG_TSV_TIMING - writer_tsv.join(); - writer_patching_tsv.join(); - ofstream_patching_tsv.close(); -#endif - - progress.finish(); - std::cerr << "[wfmash::align::computeAlignments] " - << "count of mapped reads = " << total_seqs - << ", total aligned bp = " << total_alignment_length << std::endl; - } +seq_record_t* createSeqRecord(const MappingBoundaryRow& currentRecord, + const std::string& mappingRecordLine, + faidx_t* ref_faidx, + faidx_t* query_faidx) { + // Get the reference sequence length + const int64_t ref_size = faidx_seq_len(ref_faidx, currentRecord.refId.c_str()); + // Get the query sequence length + const int64_t query_size = faidx_seq_len(query_faidx, currentRecord.qId.c_str()); + + // Compute padding + const uint64_t head_padding = currentRecord.rStartPos >= param.wflign_max_len_minor + ? param.wflign_max_len_minor : currentRecord.rStartPos; + const uint64_t tail_padding = ref_size - currentRecord.rEndPos >= param.wflign_max_len_minor + ? param.wflign_max_len_minor : ref_size - currentRecord.rEndPos; + + // Extract reference sequence + int64_t ref_len; + char* ref_seq = faidx_fetch_seq64(ref_faidx, currentRecord.refId.c_str(), + currentRecord.rStartPos - head_padding, + currentRecord.rEndPos + tail_padding, &ref_len); + + // Extract query sequence + int64_t query_len; + char* query_seq = faidx_fetch_seq64(query_faidx, currentRecord.qId.c_str(), + currentRecord.qStartPos, currentRecord.qEndPos, &query_len); + + // Create a new seq_record_t object for the alignment + seq_record_t* rec = new seq_record_t(currentRecord, mappingRecordLine, + std::string(ref_seq, ref_len), + currentRecord.rStartPos - head_padding, ref_len, ref_size, + std::string(query_seq, query_len), + currentRecord.qStartPos, query_len, query_size); + + // Clean up + free(ref_seq); + free(query_seq); + + return rec; +} - // core alignment computation function - void doAlignment( - std::stringstream& output, +std::string processAlignment(seq_record_t* rec) { + std::string& ref_seq = rec->refSequence; + std::string& query_seq = rec->querySequence; + + skch::CommonFunc::makeUpperCaseAndValidDNA(ref_seq.data(), ref_seq.length()); + skch::CommonFunc::makeUpperCaseAndValidDNA(query_seq.data(), query_seq.length()); + + // Adjust the reference sequence to start from the original start position + char* ref_seq_ptr = &ref_seq[rec->currentRecord.rStartPos - rec->refStartPos]; + + std::vector queryRegionStrand(query_seq.size() + 1); + + if(rec->currentRecord.strand == skch::strnd::FWD) { + std::copy(query_seq.begin(), query_seq.end(), queryRegionStrand.begin()); + } else { + skch::CommonFunc::reverseComplement(query_seq.data(), queryRegionStrand.data(), query_seq.size()); + } + + wflign::wavefront::WFlign wflign( + param.wflambda_segment_length, + param.min_identity, + param.force_biwfa_alignment, + param.wfa_mismatch_score, + param.wfa_gap_opening_score, + param.wfa_gap_extension_score, + param.wfa_patching_mismatch_score, + param.wfa_patching_gap_opening_score1, + param.wfa_patching_gap_extension_score1, + param.wfa_patching_gap_opening_score2, + param.wfa_patching_gap_extension_score2, + rec->currentRecord.mashmap_estimated_identity, + param.wflign_mismatch_score, + param.wflign_gap_opening_score, + param.wflign_gap_extension_score, + param.wflign_max_mash_dist, + param.wflign_min_wavefront_length, + param.wflign_max_distance_threshold, + param.wflign_max_len_major, + param.wflign_max_len_minor, + param.wflign_erode_k, + param.chain_gap, + param.wflign_min_inv_patch_len, + param.wflign_max_patching_score); + + std::stringstream output; + wflign.set_output( + &output, #ifdef WFA_PNG_TSV_TIMING - std::stringstream& output_tsv, - std::stringstream& patching_output_tsv, -#endif - seq_record_t* rec, - uint64_t tid) { -#ifdef DEBUG - std::cerr << "INFO, align::Aligner::doAlignment, aligning mashmap record: " << rec->mappingRecordLine << std::endl; + !param.tsvOutputPrefix.empty(), + nullptr, + param.prefix_wavefront_plot_in_png, + param.wfplot_max_size, + !param.path_patching_info_in_tsv.empty(), + nullptr, #endif + true, // merge alignments + param.emit_md_tag, + !param.sam_format, + param.no_seq_in_sam); + + wflign.wflign_affine_wavefront( + rec->currentRecord.qId, + queryRegionStrand.data(), + rec->queryTotalLength, + rec->queryStartPos, + rec->queryLen, + rec->currentRecord.strand != skch::strnd::FWD, + rec->currentRecord.refId, + ref_seq_ptr, + rec->refTotalLength, + rec->currentRecord.rStartPos, + rec->currentRecord.rEndPos - rec->currentRecord.rStartPos); + + return output.str(); +} - std::string& ref_seq = rec->refSequence; - std::string& query_seq = rec->querySequence; +void single_reader_thread(const std::string& input_file, + atomic_queue::AtomicQueue& line_queue, + std::atomic& reader_done) { + std::ifstream mappingListStream(input_file); + if (!mappingListStream.is_open()) { + throw std::runtime_error("[wfmash::align::computeAlignments] Error! Failed to open input mapping file: " + input_file); + } + + std::string line; + while (std::getline(mappingListStream, line)) { + if (!line.empty()) { + std::string* line_ptr = new std::string(std::move(line)); + line_queue.push(line_ptr); + } + } + + mappingListStream.close(); + reader_done.store(true); +} - skch::CommonFunc::makeUpperCaseAndValidDNA(ref_seq.data(), ref_seq.length()); - skch::CommonFunc::makeUpperCaseAndValidDNA(query_seq.data(), query_seq.length()); +void processor_thread(std::atomic& total_alignments_queued, + std::atomic& reader_done, + atomic_queue::AtomicQueue& line_queue, + seq_atomic_queue_t& seq_queue, + std::atomic& thread_should_exit) { + faidx_t* local_ref_faidx = fai_load(param.refSequences.front().c_str()); + faidx_t* local_query_faidx = fai_load(param.querySequences.front().c_str()); + + while (!thread_should_exit.load()) { + std::string* line_ptr = nullptr; + //std::cerr << "size of line queue " << line_queue.was_size() << std::endl; + if (line_queue.try_pop(line_ptr)) { + MappingBoundaryRow currentRecord; + parseMashmapRow(*line_ptr, currentRecord); + + // Process the record and create seq_record_t + seq_record_t* rec = createSeqRecord(currentRecord, *line_ptr, local_ref_faidx, local_query_faidx); + //std::cerr << "size of seq_queue " << seq_queue.was_size() << std::endl; + + while (!seq_queue.try_push(rec)) { + if (thread_should_exit.load()) { + delete rec; + delete line_ptr; + goto cleanup; + } + std::this_thread::sleep_for(std::chrono::milliseconds(10)); + } + + ++total_alignments_queued; + delete line_ptr; + } else if (reader_done.load() && line_queue.was_empty()) { + break; + } else { + std::this_thread::sleep_for(std::chrono::milliseconds(10)); + } + } + +cleanup: + fai_destroy(local_ref_faidx); + fai_destroy(local_query_faidx); +} - // Adjust the reference sequence to start from the original start position - char* ref_seq_ptr = &ref_seq[rec->currentRecord.rStartPos - rec->refStartPos]; +void processor_manager(seq_atomic_queue_t& seq_queue, + atomic_queue::AtomicQueue& line_queue, + std::atomic& total_alignments_queued, + std::atomic& reader_done, + std::atomic& processor_done, + size_t max_processors) { + std::vector processor_threads; + std::vector> thread_should_exit(max_processors); + + const size_t queue_capacity = seq_queue.capacity(); + const size_t low_threshold = queue_capacity * 0.2; + const size_t high_threshold = queue_capacity * 0.8; + + auto spawn_processor = [&](size_t id) { + //std::cerr << "spawn_processor: " << id << std::endl; + thread_should_exit[id].store(false); + processor_threads.emplace_back([this, &total_alignments_queued, &reader_done, &line_queue, &seq_queue, &thread_should_exit, id]() { + this->processor_thread(total_alignments_queued, reader_done, line_queue, seq_queue, thread_should_exit[id]); + }); + }; + + // Start with one processor + spawn_processor(0); + size_t current_processors = 1; + + while (!reader_done.load() || !line_queue.was_empty() || !seq_queue.was_empty()) { + size_t queue_size = seq_queue.was_size(); + + //std::cerr << "queue_size: " << queue_size << std::endl; + + if (queue_size < low_threshold && current_processors < max_processors) { + //std::cerr << "spawn_processor: " << current_processors << std::endl; + spawn_processor(current_processors++); + } else if (queue_size > high_threshold && current_processors > 1) { + //std::cerr << "kill_processor: " << current_processors << std::endl; + thread_should_exit[--current_processors].store(true); + } + + std::this_thread::sleep_for(std::chrono::milliseconds(100)); + } + + // Signal all remaining threads to exit + for (size_t i = 0; i < current_processors; ++i) { + thread_should_exit[i].store(true); + } + + // Wait for all processor threads to finish + for (auto& thread : processor_threads) { + thread.join(); + } + processor_done.store(true); +} - char* queryRegionStrand = new char[query_seq.size() + 1]; +void worker_thread(uint64_t tid, + std::atomic& is_working, + seq_atomic_queue_t& seq_queue, + paf_atomic_queue_t& paf_queue, + std::atomic& reader_done, + progress_meter::ProgressMeter& progress, + std::atomic& processed_alignment_length) { + is_working.store(true); + while (true) { + seq_record_t* rec = nullptr; + if (seq_queue.try_pop(rec)) { + is_working.store(true); + std::string alignment_output = processAlignment(rec); + + // Push the alignment output to the paf_queue + paf_queue.push(new std::string(std::move(alignment_output))); + + // Update progress meter and processed alignment length + uint64_t alignment_length = rec->currentRecord.qEndPos - rec->currentRecord.qStartPos; + progress.increment(alignment_length); + processed_alignment_length.fetch_add(alignment_length, std::memory_order_relaxed); + + delete rec; + } else if (reader_done.load() && seq_queue.was_empty()) { + break; + } else { + is_working.store(false); + std::this_thread::sleep_for(std::chrono::milliseconds(10)); + } + } + is_working.store(false); +} - if(rec->currentRecord.strand == skch::strnd::FWD) { - strncpy(queryRegionStrand, query_seq.data(), query_seq.size()); - } else { - skch::CommonFunc::reverseComplement(query_seq.data(), queryRegionStrand, query_seq.size()); - } +void writer_thread(const std::string& output_file, + paf_atomic_queue_t& paf_queue, + std::atomic& reader_done, + std::atomic& processor_done, + const std::vector>& worker_working) { + std::ofstream outstream(output_file); + if (!outstream.is_open()) { + throw std::runtime_error("[wfmash::align::computeAlignments] Error! Failed to open output file: " + output_file); + } + + auto all_workers_done = [&]() { + return std::all_of(worker_working.begin(), worker_working.end(), + [](const std::atomic& w) { return !w.load(); }); + }; + + while (true) { + std::string* paf_output = nullptr; + if (paf_queue.try_pop(paf_output)) { + outstream << *paf_output; + outstream.flush(); + delete paf_output; + } else if (reader_done.load() && processor_done.load() && paf_queue.was_empty() && all_workers_done()) { + break; + } else { + std::this_thread::sleep_for(std::chrono::milliseconds(10)); + } + } + + outstream.close(); +} - // To distinguish split alignment in SAM output format (currentRecord.rankMapping == 0 to avoid the suffix there is just one alignment for the query) - //const std::string query_name_suffix = param.split && param.sam_format ? "_" + std::to_string(rec->currentRecord.rankMapping) : ""; - - wflign::wavefront::WFlign* wflign = new wflign::wavefront::WFlign( - param.wflambda_segment_length, - param.min_identity, - param.force_biwfa_alignment, - param.wfa_mismatch_score, - param.wfa_gap_opening_score, - param.wfa_gap_extension_score, - param.wfa_patching_mismatch_score, - param.wfa_patching_gap_opening_score1, - param.wfa_patching_gap_extension_score1, - param.wfa_patching_gap_opening_score2, - param.wfa_patching_gap_extension_score2, - rec->currentRecord.mashmap_estimated_identity, - param.wflign_mismatch_score, - param.wflign_gap_opening_score, - param.wflign_gap_extension_score, - param.wflign_max_mash_dist, - param.wflign_min_wavefront_length, - param.wflign_max_distance_threshold, - param.wflign_max_len_major, - param.wflign_max_len_minor, - param.wflign_erode_k, - param.chain_gap, - param.wflign_min_inv_patch_len, - param.wflign_max_patching_score); - wflign->set_output( - &output, -#ifdef WFA_PNG_TSV_TIMING - !param.tsvOutputPrefix.empty(), - &output_tsv, - param.prefix_wavefront_plot_in_png, - param.wfplot_max_size, - !param.path_patching_info_in_tsv.empty(), - &patching_output_tsv, -#endif - true, // merge alignments - param.emit_md_tag, - !param.sam_format, - param.no_seq_in_sam); - wflign->wflign_affine_wavefront( - rec->currentRecord.qId,// + query_name_suffix, - queryRegionStrand, - rec->queryTotalLength, - rec->queryStartPos, - rec->queryLen, - rec->currentRecord.strand != skch::strnd::FWD, - rec->currentRecord.refId, - ref_seq_ptr, - rec->refTotalLength, - rec->currentRecord.rStartPos, - rec->currentRecord.rEndPos - rec->currentRecord.rStartPos); - delete wflign; - - delete[] queryRegionStrand; - // n.b. rec is deleted in calling context - } +void computeAlignments() { + std::atomic total_alignments_queued(0); + std::atomic reader_done(false); + std::atomic processor_done(false); + + // Create queues + atomic_queue::AtomicQueue line_queue; + seq_atomic_queue_t seq_queue; + paf_atomic_queue_t paf_queue; // Add this line + + // Calculate max_processors based on the number of worker threads + size_t max_processors = std::max(1UL, static_cast(param.threads)); + + // Calculate total alignment length + uint64_t total_alignment_length = 0; + { + std::ifstream mappingListStream(param.mashmapPafFile); + std::string mappingRecordLine; + MappingBoundaryRow currentRecord; + + while(std::getline(mappingListStream, mappingRecordLine)) { + if (!mappingRecordLine.empty()) { + parseMashmapRow(mappingRecordLine, currentRecord); + total_alignment_length += currentRecord.qEndPos - currentRecord.qStartPos; + } + } + } + + // Create progress meter + progress_meter::ProgressMeter progress(total_alignment_length, "[wfmash::align::computeAlignments] aligned"); + + // Create atomic counter for processed alignment length + std::atomic processed_alignment_length(0); + + // Start timing + auto start_time = std::chrono::high_resolution_clock::now(); + + // Launch single reader thread + std::thread single_reader([this, &line_queue, &reader_done]() { + this->single_reader_thread(param.mashmapPafFile, line_queue, reader_done); + }); + + // Launch processor manager + std::thread processor_manager_thread([this, &seq_queue, &line_queue, &total_alignments_queued, &reader_done, &processor_done, max_processors]() { + this->processor_manager(seq_queue, line_queue, total_alignments_queued, reader_done, processor_done, max_processors); + }); + + // Launch worker threads + std::vector workers; + std::vector> worker_working(param.threads); + for (uint64_t t = 0; t < param.threads; ++t) { + workers.emplace_back([this, t, &worker_working, &seq_queue, &paf_queue, &reader_done, &progress, &processed_alignment_length]() { + this->worker_thread(t, worker_working[t], seq_queue, paf_queue, reader_done, progress, processed_alignment_length); + }); + } + + // Launch writer thread + std::thread writer([this, &paf_queue, &reader_done, &processor_done, &worker_working]() { + this->writer_thread(param.pafOutputFile, paf_queue, reader_done, processor_done, worker_working); + }); + + // Wait for all threads to complete + single_reader.join(); + processor_manager_thread.join(); + for (auto& worker : workers) { + worker.join(); + } + writer.join(); + + // Stop timing + auto end_time = std::chrono::high_resolution_clock::now(); + auto duration = std::chrono::duration_cast(end_time - start_time); + + // Finish progress meter + progress.finish(); + + std::cerr << "[wfmash::align::computeAlignments] " + << "total aligned records = " << total_alignments_queued.load() + << ", total aligned bp = " << processed_alignment_length.load() + << ", time taken = " << duration.count() << " seconds" << std::endl; +} + }; }