Skip to content

Commit

Permalink
Merge pull request #297 from waveygang/index-buildy
Browse files Browse the repository at this point in the history
Parallelize k-mer frequency counting and index building
  • Loading branch information
ekg authored Nov 20, 2024
2 parents 0243f62 + 753fe63 commit da81070
Showing 1 changed file with 111 additions and 54 deletions.
165 changes: 111 additions & 54 deletions src/map/include/winSketch.hpp
Original file line number Diff line number Diff line change
Expand Up @@ -231,72 +231,129 @@ namespace skch
total_windows,
"[wfmash::mashmap] building index");

// First pass - count k-mer frequencies
// Parallel k-mer frequency counting
std::vector<HF_Map_t> thread_kmer_freqs(param.threads);
std::vector<std::thread> freq_threads;

// Split outputs into chunks for parallel processing
size_t chunk_size = (threadOutputs.size() + param.threads - 1) / param.threads;

for (size_t t = 0; t < param.threads; ++t) {
freq_threads.emplace_back([&, t]() {
size_t start = t * chunk_size;
size_t end = std::min(start + chunk_size, threadOutputs.size());

for (size_t i = start; i < end; ++i) {
for (const MinmerInfo& mi : *threadOutputs[i]) {
thread_kmer_freqs[t][mi.hash]++;
}
}
});
}

for (auto& thread : freq_threads) {
thread.join();
}

// Merge frequency maps
HF_Map_t kmer_freqs;
for (auto* output : threadOutputs) {
for (const MinmerInfo& mi : *output) {
kmer_freqs[mi.hash]++;
for (const auto& thread_freq : thread_kmer_freqs) {
for (const auto& [hash, freq] : thread_freq) {
kmer_freqs[hash] += freq;
}
}

// Second pass - build filtered indexes
uint64_t total_kmers = 0;
uint64_t filtered_kmers = 0;
// Parallel index building
std::vector<MI_Map_t> thread_pos_indexes(param.threads);
std::vector<MI_Type> thread_minmer_indexes(param.threads);
std::vector<uint64_t> thread_total_kmers(param.threads, 0);
std::vector<uint64_t> thread_filtered_kmers(param.threads, 0);
std::vector<std::thread> index_threads;

for (size_t t = 0; t < param.threads; ++t) {
index_threads.emplace_back([&, t]() {
size_t start = t * chunk_size;
size_t end = std::min(start + chunk_size, threadOutputs.size());

for (size_t i = start; i < end; ++i) {
for (const MinmerInfo& mi : *threadOutputs[i]) {
thread_total_kmers[t]++;

auto freq_it = kmer_freqs.find(mi.hash);
if (freq_it == kmer_freqs.end()) {
continue; // Should never happen
}

// Clear existing indexes
minmerPosLookupIndex.clear();
minmerIndex.clear();
uint64_t freq = freq_it->second;
uint64_t min_occ = 10;
uint64_t max_occ = std::numeric_limits<uint64_t>::max();
uint64_t count_threshold;

if (param.max_kmer_freq <= 1.0) {
count_threshold = std::min(max_occ,
std::max(min_occ,
(uint64_t)(total_windows * param.max_kmer_freq)));
} else {
count_threshold = std::min(max_occ,
std::max(min_occ,
(uint64_t)param.max_kmer_freq));
}

for (auto* output : threadOutputs) {
for (const MinmerInfo& mi : *output) {
total_kmers++;

auto freq_it = kmer_freqs.find(mi.hash);
if (freq_it == kmer_freqs.end()) {
continue; // Should never happen
}
if (freq > count_threshold && freq > min_occ) {
thread_filtered_kmers[t]++;
continue;
}

uint64_t freq = freq_it->second;
uint64_t min_occ = 10; // minimum occurrence threshold to prevent over-filtering in small datasets
uint64_t max_occ = std::numeric_limits<uint64_t>::max(); // no upper limit on occurrences
uint64_t count_threshold;

if (param.max_kmer_freq <= 1.0) {
// Calculate threshold based on fraction, but respect min/max bounds
count_threshold = std::min(max_occ,
std::max(min_occ,
(uint64_t)(total_windows * param.max_kmer_freq)));
} else {
// Use direct count threshold, but respect min/max bounds
count_threshold = std::min(max_occ,
std::max(min_occ,
(uint64_t)param.max_kmer_freq));
}
auto& pos_list = thread_pos_indexes[t][mi.hash];
if (pos_list.size() == 0
|| pos_list.back().hash != mi.hash
|| pos_list.back().pos != mi.wpos) {
pos_list.push_back(IntervalPoint {mi.wpos, mi.hash, mi.seqId, side::OPEN});
pos_list.push_back(IntervalPoint {mi.wpos_end, mi.hash, mi.seqId, side::CLOSE});
} else {
pos_list.back().pos = mi.wpos_end;
}

// Filter only if BOTH conditions are met:
// 1. Frequency exceeds the calculated threshold
// 2. Count exceeds minimum occurrence threshold
if (freq > count_threshold && freq > min_occ) {
filtered_kmers++;
continue;
thread_minmer_indexes[t].push_back(mi);
index_progress.increment(1);
}
delete threadOutputs[i];
}
});
}

// Add to position lookup index
auto& pos_list = minmerPosLookupIndex[mi.hash];
if (pos_list.size() == 0
|| pos_list.back().hash != mi.hash
|| pos_list.back().pos != mi.wpos) {
pos_list.push_back(IntervalPoint {mi.wpos, mi.hash, mi.seqId, side::OPEN});
pos_list.push_back(IntervalPoint {mi.wpos_end, mi.hash, mi.seqId, side::CLOSE});
} else {
pos_list.back().pos = mi.wpos_end;
}
for (auto& thread : index_threads) {
thread.join();
}

// Merge results
uint64_t total_kmers = std::accumulate(thread_total_kmers.begin(), thread_total_kmers.end(), 0ULL);
uint64_t filtered_kmers = std::accumulate(thread_filtered_kmers.begin(), thread_filtered_kmers.end(), 0ULL);

// Add to minmer index
minmerIndex.push_back(mi);
index_progress.increment(1);
// Clear main indexes before merging the per-thread results
minmerPosLookupIndex.clear();
minmerIndex.clear();

// Reserve space for the combined size of all per-thread minmer indexes
size_t total_minmers = 0;
for (const auto& thread_index : thread_minmer_indexes) {
total_minmers += thread_index.size();
}
minmerIndex.reserve(total_minmers);

// Merge position lookup indexes
for (auto& thread_pos_index : thread_pos_indexes) {
for (auto& [hash, pos_list] : thread_pos_index) {
auto& main_pos_list = minmerPosLookupIndex[hash];
main_pos_list.insert(main_pos_list.end(), pos_list.begin(), pos_list.end());
}
delete output;
}

// Merge minmer indexes
for (auto& thread_index : thread_minmer_indexes) {
minmerIndex.insert(minmerIndex.end(),
std::make_move_iterator(thread_index.begin()),
std::make_move_iterator(thread_index.end()));
}

// Finish second progress meter
Expand Down

0 comments on commit da81070

Please sign in to comment.