Skip to content

Commit

Permalink
Merge pull request #292 from waveygang/tweak-hits
Browse files Browse the repository at this point in the history
Tweak hits
  • Loading branch information
ekg authored Nov 11, 2024
2 parents c8985d3 + 8f18e1f commit fb42892
Show file tree
Hide file tree
Showing 4 changed files with 131 additions and 105 deletions.
59 changes: 38 additions & 21 deletions src/interface/parse_args.hpp
Original file line number Diff line number Diff line change
Expand Up @@ -96,13 +96,12 @@ void parse_args(int argc,
args::Flag no_filter(mapping_opts, "", "disable mapping filtering", {'f', "no-filter"});
args::Flag no_merge(mapping_opts, "", "disable merging of consecutive mappings", {'M', "no-merge"});
args::ValueFlag<double> kmer_complexity(mapping_opts, "FLOAT", "minimum k-mer complexity threshold", {'J', "kmer-cmplx"});
args::ValueFlag<std::string> hg_filter(mapping_opts, "numer,ani-Δ,conf", "hypergeometric filter params [1,0,99.9]", {"hg-filter"});
//args::Flag window_minimizers(mapping_opts, "", "Use window minimizers rather than world minimizers", {'U', "window-minimizers"});
//args::ValueFlag<std::string> path_high_frequency_kmers(mapping_opts, "FILE", " input file containing list of high frequency kmers", {'H', "high-freq-kmers"});
//args::ValueFlag<std::string> spaced_seed_params(mapping_opts, "spaced-seeds", "Params to generate spaced seeds <weight_of_seed> <number_of_seeds> <similarity> <region_length> e.g \"10 5 0.75 20\"", {'e', "spaced-seeds"});
args::ValueFlag<std::string> hg_filter(mapping_opts, "numer,ani-Δ,conf", "hypergeometric filter params [1.0,0.0,99.9]", {"hg-filter"});
args::ValueFlag<int> min_hits(mapping_opts, "INT", "minimum number of hits for L1 filtering [auto]", {'H', "l1-hits"});
args::ValueFlag<uint64_t> max_kmer_freq(mapping_opts, "INT", "maximum allowed k-mer frequency [unlimited]", {'F', "max-kmer-freq"});

args::Group alignment_opts(options_group, "Alignment:");
args::ValueFlag<std::string> input_mapping(alignment_opts, "FILE", "input PAF/SAM file for alignment", {'i', "input-mapping"});
args::ValueFlag<std::string> input_mapping(alignment_opts, "FILE", "input PAF file for alignment", {'i', "align-paf"});
args::ValueFlag<std::string> wfa_params(alignment_opts, "vals",
"scoring: mismatch, gap1(o,e), gap2(o,e) [6,6,2,26,1]", {'g', "wfa-params"});

Expand Down Expand Up @@ -151,7 +150,7 @@ void parse_args(int argc,
}

map_parameters.skip_self = false;
map_parameters.lower_triangular = args::get(lower_triangular);
map_parameters.lower_triangular = lower_triangular ? args::get(lower_triangular) : false;
map_parameters.keep_low_pct_id = true;

if (skip_prefix) {
Expand Down Expand Up @@ -521,26 +520,44 @@ void parse_args(int argc,

map_parameters.filterLengthMismatches = true;

args::Flag no_hg_filter(mapping_opts, "", "disable hypergeometric filter", {"no-hg-filter"});
map_parameters.stage1_topANI_filter = !bool(no_hg_filter);
// Parse hypergeometric filter parameters
map_parameters.stage1_topANI_filter = true;
map_parameters.stage2_full_scan = true;

args::ValueFlag<double> hg_filter_ani_diff(mapping_opts, "FLOAT", "hypergeometric filter ANI difference [0.0]", {"hg-filter-ani-diff"});
if (hg_filter_ani_diff)
{
map_parameters.ANIDiff = args::get(hg_filter_ani_diff);
map_parameters.ANIDiff /= 100;

if (hg_filter) {
std::string hg_params = args::get(hg_filter);
std::vector<std::string> params = skch::CommonFunc::split(hg_params, ',');
if (params.size() != 3) {
std::cerr << "[wfmash] ERROR: hypergeometric filter requires 3 comma-separated values: numerator,ani-diff,confidence" << std::endl;
exit(1);
}
// Parse numerator
map_parameters.hgNumerator = std::stod(params[0]);
if (map_parameters.hgNumerator < 1.0) {
std::cerr << "[wfmash] ERROR: hg-filter numerator must be >= 1.0" << std::endl;
exit(1);
}
// Parse ANI difference
map_parameters.ANIDiff = std::stod(params[1]) / 100.0;
// Parse confidence
map_parameters.ANIDiffConf = std::stod(params[2]) / 100.0;
} else {
// Use defaults
map_parameters.hgNumerator = 1.0;
map_parameters.ANIDiff = skch::fixed::ANIDiff;
map_parameters.ANIDiffConf = skch::fixed::ANIDiffConf;
}

args::ValueFlag<double> hg_filter_conf(mapping_opts, "FLOAT", "hypergeometric filter confidence [99.9]", {"hg-filter-conf"});
if (hg_filter_conf)
{
map_parameters.ANIDiffConf = args::get(hg_filter_conf);
map_parameters.ANIDiffConf /= 100;
if (min_hits) {
map_parameters.minimum_hits = args::get(min_hits);
} else {
map_parameters.ANIDiffConf = skch::fixed::ANIDiffConf;
map_parameters.minimum_hits = -1; // auto
}

if (max_kmer_freq) {
map_parameters.max_kmer_freq = args::get(max_kmer_freq);
} else {
map_parameters.max_kmer_freq = std::numeric_limits<uint64_t>::max(); // unlimited
}

//if (window_minimizers) {
Expand Down Expand Up @@ -590,7 +607,7 @@ void parse_args(int argc,
}

if (input_mapping) {
// directly use the input mapping file
// directly use the input PAF file
yeet_parameters.remapping = true;
map_parameters.outFileName = args::get(input_mapping);
align_parameters.mashmapPafFile = args::get(input_mapping);
Expand Down
19 changes: 16 additions & 3 deletions src/map/include/computeMap.hpp
Original file line number Diff line number Diff line change
Expand Up @@ -97,7 +97,11 @@ namespace skch
return a.intersectionSize < b.intersectionSize;
};

//Type for Stage L2's predicted mapping coordinate within each L1 candidate
//Cache for commonly used values
offset_t cached_segment_length;
int cached_minimum_hits;

//Type for Stage L2's predicted mapping coordinate within each L1 candidate
struct L2_mapLocus_t
{
seqno_t seqId; //sequence id where read is mapped
Expand Down Expand Up @@ -203,7 +207,9 @@ namespace skch
std::vector<std::string>{p.target_prefix},
std::string(1, p.prefix_delim),
p.query_list,
p.target_list))
p.target_list)),
cached_segment_length(p.segLength),
cached_minimum_hits(p.minimum_hits > 0 ? p.minimum_hits : Stat::estimateMinimumHitsRelaxed(p.sketchSize, p.kmerSize, p.percentageIdentity, skch::fixed::confidence_interval))
{
// Initialize sequence names right after creating idManager
this->querySequenceNames = idManager->getQuerySequenceNames();
Expand Down Expand Up @@ -481,6 +487,8 @@ namespace skch

void mapQuery()
{
std::cerr << "[wfmash::mashmap] L1 filtering parameters: cached_minimum_hits=" << cached_minimum_hits << std::endl;

//Count of reads mapped by us
//Some reads are dropped because of short length
seqno_t totalReadsPickedForMapping = 0;
Expand Down Expand Up @@ -1444,7 +1452,12 @@ namespace skch
getSeedIntervalPoints(Q, intervalPoints);

//3. Compute L1 windows
int minimumHits = Stat::estimateMinimumHitsRelaxed(Q.sketchSize, param.kmerSize, param.percentageIdentity, skch::fixed::confidence_interval);
// Always respect the minimum hits parameter if set
int minimumHits = param.minimum_hits > 0 ?
param.minimum_hits :
(Q.len == cached_segment_length ?
cached_minimum_hits :
Stat::estimateMinimumHitsRelaxed(Q.sketchSize, param.kmerSize, param.percentageIdentity, skch::fixed::confidence_interval));

// For each "group"
auto ip_begin = intervalPoints.begin();
Expand Down
2 changes: 2 additions & 0 deletions src/map/include/map_parameters.hpp
Original file line number Diff line number Diff line change
Expand Up @@ -86,6 +86,8 @@ struct Parameters
bool legacy_output;
//std::unordered_set<std::string> high_freq_kmers; //
int64_t index_by_size = std::numeric_limits<int64_t>::max(); // Target total size of sequences for each index subset
int minimum_hits = -1; // Minimum number of hits required for L1 filtering (-1 means auto)
uint64_t max_kmer_freq = std::numeric_limits<uint64_t>::max(); // Maximum allowed k-mer frequency
};


Expand Down
156 changes: 75 additions & 81 deletions src/map/include/winSketch.hpp
Original file line number Diff line number Diff line change
Expand Up @@ -170,11 +170,6 @@ namespace skch
total_seq_length += idManager.getSequenceLength(seqId);
}

// Initialize progress meter with known total
progress_meter::ProgressMeter progress(
total_seq_length,
"[wfmash::mashmap] computing sketch");

// First progress meter for sketch computation
progress_meter::ProgressMeter sketch_progress(
total_seq_length,
Expand Down Expand Up @@ -236,14 +231,63 @@ namespace skch
total_windows,
"[wfmash::mashmap] building index");

// Build index in parallel
buildIndexInParallel(threadOutputs, index_progress, param.threads);
// First pass - count k-mer frequencies
HF_Map_t kmer_freqs;
for (auto* output : threadOutputs) {
for (const MinmerInfo& mi : *output) {
kmer_freqs[mi.hash]++;
}
}

// Second pass - build filtered indexes
uint64_t total_kmers = 0;
uint64_t filtered_kmers = 0;

// Clear existing indexes
minmerPosLookupIndex.clear();
minmerIndex.clear();

for (auto* output : threadOutputs) {
for (const MinmerInfo& mi : *output) {
total_kmers++;

auto freq_it = kmer_freqs.find(mi.hash);
if (freq_it == kmer_freqs.end()) {
continue; // Should never happen
}

if (freq_it->second > param.max_kmer_freq) {
filtered_kmers++;
continue;
}

// Add to position lookup index
auto& pos_list = minmerPosLookupIndex[mi.hash];
if (pos_list.size() == 0
|| pos_list.back().hash != mi.hash
|| pos_list.back().pos != mi.wpos) {
pos_list.push_back(IntervalPoint {mi.wpos, mi.hash, mi.seqId, side::OPEN});
pos_list.push_back(IntervalPoint {mi.wpos_end, mi.hash, mi.seqId, side::CLOSE});
} else {
pos_list.back().pos = mi.wpos_end;
}

// Add to minmer index
minmerIndex.push_back(mi);
index_progress.increment(1);
}
delete output;
}

// Finish second progress meter
index_progress.finish();

double filtered_pct = (filtered_kmers * 100.0) / total_kmers;
std::cerr << "[wfmash::mashmap] Processed " << totalSeqProcessed << " sequences (" << totalSeqSkipped << " skipped, " << total_seq_length << " total bp), "
<< minmerPosLookupIndex.size() << " unique hashes, " << minmerIndex.size() << " windows" << std::endl;
<< minmerPosLookupIndex.size() << " unique hashes, " << minmerIndex.size() << " windows" << std::endl
<< "[wfmash::mashmap] Filtered " << filtered_kmers << "/" << total_kmers
<< " k-mers (" << std::fixed << std::setprecision(2) << filtered_pct << "%) exceeding frequency threshold of "
<< param.max_kmer_freq << std::endl;
}

std::chrono::duration<double> timeRefSketch = skch::Time::now() - t0;
Expand Down Expand Up @@ -288,83 +332,24 @@ namespace skch
* @brief routine to handle thread's local minmer index
* @param[in] output thread local minmer output
*/
/**
 * @brief Build the index from thread outputs in parallel
 *
 * The thread outputs are distributed round-robin over the workers. Each worker
 * builds a thread-local position index from its share, then, while holding a
 * mutex, appends that local index to minmerPosLookupIndex and moves the
 * minmers of its outputs into minmerIndex. Every pointer in threadOutputs is
 * deleted by the worker that processed it.
 *
 * NOTE(review): workers merge in whatever order they acquire the mutex, so the
 * relative order of entries contributed by different workers is not fixed
 * between runs -- confirm that downstream code does not rely on it.
 *
 * @param[in] threadOutputs Vector of thread-local minmer indices (consumed and deleted)
 * @param[in] progress      Progress meter for tracking; incremented once per minmer
 * @param[in] num_threads   Number of threads to use; 0 is treated as 1
 */
void buildIndexInParallel(std::vector<MI_Type*>& threadOutputs,
                          progress_meter::ProgressMeter& progress,
                          size_t num_threads) {
  // Nothing to index: avoid spawning workers that would do no work
  if (threadOutputs.empty()) {
    return;
  }

  // Guard against a modulo by zero below, and never start more workers than
  // there are outputs (the extra workers would receive empty chunks anyway,
  // so the assignment of outputs to chunks is unchanged)
  size_t num_workers = (num_threads == 0) ? 1 : num_threads;
  if (num_workers > threadOutputs.size()) {
    num_workers = threadOutputs.size();
  }

  // Split the thread outputs into chunks for parallel processing
  std::vector<std::vector<MI_Type*>> chunks(num_workers);
  for (size_t i = 0; i < threadOutputs.size(); ++i) {
    chunks[i % num_workers].push_back(threadOutputs[i]);
  }

  std::vector<std::thread> threads;
  threads.reserve(num_workers);
  std::mutex index_mutex; // For thread-safe index updates

  for (size_t i = 0; i < num_workers; ++i) {
    threads.emplace_back([this, &chunks, i, &progress, &index_mutex]() {
      MI_Map_t local_index; // Thread-local index

      // Process all outputs in this chunk
      for (auto* output : chunks[i]) {
        for (const MinmerInfo& mi : *output) {
          // Single lookup per minmer instead of one per access
          auto& points = local_index[mi.hash];
          if (points.empty()
              || points.back().hash != mi.hash
              || points.back().pos != mi.wpos) {
            points.push_back(IntervalPoint {mi.wpos, mi.hash, mi.seqId, side::OPEN});
            points.push_back(IntervalPoint {mi.wpos_end, mi.hash, mi.seqId, side::CLOSE});
          } else {
            // Window starts where the previous interval closed: extend it
            points.back().pos = mi.wpos_end;
          }
          progress.increment(1);
        }
      }

      // Publish this worker's results to the shared indexes under one lock
      std::lock_guard<std::mutex> lock(index_mutex);

      // Merge thread-local index into global index
      for (auto& [hash, points] : local_index) {
        auto& global_points = minmerPosLookupIndex[hash];
        global_points.insert(
            global_points.end(),
            std::make_move_iterator(points.begin()),
            std::make_move_iterator(points.end()));
      }

      // Insert minmers into global minmerIndex and release the outputs
      for (auto* output : chunks[i]) {
        minmerIndex.insert(
            minmerIndex.end(),
            std::make_move_iterator(output->begin()),
            std::make_move_iterator(output->end()));
        delete output;
      }
    });
  }

  // Wait for all threads to complete
  for (auto& thread : threads) {
    thread.join();
  }
}

void buildHandleThreadOutput(MI_Type* contigMinmerIndex)
{
// Count k-mer frequencies first
HF_Map_t kmer_freqs;
for (const auto& mi : *contigMinmerIndex) {
kmer_freqs[mi.hash]++;
}

// This function is kept for compatibility but should not be used
// when parallel index building is enabled
for (MinmerInfo& mi : *contigMinmerIndex) {
// Skip high-frequency k-mers
auto freq_it = kmer_freqs.find(mi.hash);
if (freq_it != kmer_freqs.end() && freq_it->second > param.max_kmer_freq) {
continue;
}

if (minmerPosLookupIndex[mi.hash].size() == 0
|| minmerPosLookupIndex[mi.hash].back().hash != mi.hash
|| minmerPosLookupIndex[mi.hash].back().pos != mi.wpos) {
Expand All @@ -375,10 +360,19 @@ namespace skch
}
}

// Only add k-mers that aren't too frequent
MI_Type filtered_minmers;
for (const auto& mi : *contigMinmerIndex) {
auto freq_it = kmer_freqs.find(mi.hash);
if (freq_it == kmer_freqs.end() || freq_it->second <= param.max_kmer_freq) {
filtered_minmers.push_back(mi);
}
}

this->minmerIndex.insert(
this->minmerIndex.end(),
std::make_move_iterator(contigMinmerIndex->begin()),
std::make_move_iterator(contigMinmerIndex->end()));
std::make_move_iterator(filtered_minmers.begin()),
std::make_move_iterator(filtered_minmers.end()));

delete contigMinmerIndex;
}
Expand Down

0 comments on commit fb42892

Please sign in to comment.