diff --git a/src/interface/parse_args.hpp b/src/interface/parse_args.hpp index 29c97c9b..f98d08f1 100644 --- a/src/interface/parse_args.hpp +++ b/src/interface/parse_args.hpp @@ -86,6 +86,7 @@ void parse_args(int argc, args::Flag drop_low_map_pct_identity(mapping_opts, "K", "drop mappings with estimated identity below --map-pct-id=%", {'K', "drop-low-map-id"}); args::Flag no_filter(mapping_opts, "MODE", "disable mapping filtering", {'f', "no-filter"}); args::ValueFlag map_sparsification(mapping_opts, "FACTOR", "keep this fraction of mappings", {'x', "sparsify-mappings"}); + args::Flag keep_same_score(mapping_opts, "", "keep all mappings with equal score even if it results in more than n mappings", {'D', "keep-same-score"}); args::ValueFlag sketch_size(mapping_opts, "N", "sketch size for sketching.", {'w', "sketch-size"}); args::ValueFlag kmer_complexity(mapping_opts, "F", "Drop segments w/ predicted kmer complexity below this cutoff. Kmer complexity defined as #kmers / (s - k + 1)", {'J', "kmer-complexity"}); args::Flag no_hg_filter(mapping_opts, "", "Don't use the hypergeometric filtering and instead use the MashMap2 first pass filtering.", {'1', "no-hg-filter"}); @@ -351,6 +352,7 @@ void parse_args(int argc, align_parameters.sam_format = args::get(sam_format); align_parameters.no_seq_in_sam = args::get(no_seq_in_sam); map_parameters.split = !args::get(no_split); + map_parameters.dropRand = !args::get(keep_same_score); align_parameters.split = !args::get(no_split); map_parameters.mergeMappings = !args::get(no_merge); diff --git a/src/map/include/computeMap.hpp b/src/map/include/computeMap.hpp index f9059d7a..6a3f98a2 100644 --- a/src/map/include/computeMap.hpp +++ b/src/map/include/computeMap.hpp @@ -498,7 +498,7 @@ namespace skch std::move(subrange_begin, subrange_end, tmpMappings.begin()); std::sort(tmpMappings.begin(), tmpMappings.end(), [](const auto& a, const auto& b) { return std::tie(a.queryStartPos, a.refSeqId, a.refStartPos) < std::tie(b.queryStartPos, b.refSeqId, b.refStartPos); }); - skch::Filter::query::filterMappings(tmpMappings, n_mappings); + skch::Filter::query::filterMappings(tmpMappings, n_mappings, param.dropRand); std::move(tmpMappings.begin(), tmpMappings.end(), std::back_inserter(filteredMappings)); subrange_begin = subrange_end; } diff --git a/src/map/include/filter.hpp b/src/map/include/filter.hpp index 67b34da6..e7cadd32 100644 --- a/src/map/include/filter.hpp +++ b/src/map/include/filter.hpp @@ -74,7 +74,7 @@ namespace skch * @param[in/out] L container with mappings */ template - inline void markGood(Type &L, int secondaryToKeep) + inline void markGood(Type &L, int secondaryToKeep, bool dropRand) { //first segment in the set order auto beg = L.begin(); @@ -91,6 +91,38 @@ namespace skch vec[*it].discard = 0; ++kept; } + + // check for the case where there are multiple best mappings > secondaryToKeep + // which have the same score + // we will hash the mapping struct and keep the one with the secondaryToKeep with the lowest hash value + if (kept > secondaryToKeep && dropRand) + { + // we will use hashes of the mapping structs to break ties + // first we'll make a vector of the mappings including the hashes + std::vector> score_and_hash; // The tuple is (score, hash, pointer to the mapping) + for(auto it = L.begin(); it != L.end(); it++) + { + if(vec[*it].discard == 0) + { + score_and_hash.emplace_back(get_score(*it), vec[*it].hash(), &vec[*it]); + } + } + // now we'll sort the vector by score and hash + std::sort(score_and_hash.begin(), score_and_hash.end(), std::greater{}); + // reset kept counter + kept = 0; + for (auto& x : score_and_hash) { + std::get<2>(x)->discard = 1; + } + // now we mark the best to keep + for (auto& x : score_and_hash) { + if (kept > secondaryToKeep) { + break; + } + std::get<2>(x)->discard = 0; + ++kept; + } + } } }; @@ -100,7 +132,7 @@ namespace skch * @param[in/out] readMappings Mappings computed by Mashmap */ template - void liFilterAlgorithm(VecIn &readMappings, int secondaryToKeep) + void liFilterAlgorithm(VecIn &readMappings, int secondaryToKeep, bool dropRand) { if(readMappings.size() <= 1) return; @@ -148,7 +180,7 @@ namespace skch }); //mark mappings as good - obj.markGood(bst, secondaryToKeep); + obj.markGood(bst, secondaryToKeep, dropRand); it = it2; } @@ -218,14 +250,17 @@ namespace skch } /** - * @brief filter mappings (best for query sequence) - * @param[in/out] readMappings Mappings computed by Mashmap (post merge step) + * @brief filter mappings (best for query sequence) + * @param[in/out] readMappings Mappings computed by Mashmap (post merge step) + * @param[in] secondaryToKeep How many mappings in addition to the best to keep + * @param[in] dropRand If multiple mappings have the same score, drop randomly + * until we only have secondaryToKeep secondary mappings */ template - void filterMappings(VecIn &readMappings, uint16_t secondaryToKeep) + void filterMappings(VecIn &readMappings, uint16_t secondaryToKeep, bool dropRand) { //Apply the main filtering algorithm to ensure the best mappings across complete axis - liFilterAlgorithm(readMappings, secondaryToKeep); + liFilterAlgorithm(readMappings, secondaryToKeep, dropRand); } /** diff --git a/src/map/include/map_parameters.hpp b/src/map/include/map_parameters.hpp index 2034fe36..a8da7f6a 100644 --- a/src/map/include/map_parameters.hpp +++ b/src/map/include/map_parameters.hpp @@ -47,6 +47,7 @@ struct Parameters int filterMode; //filtering mode in mashmap uint32_t numMappingsForSegment; //how many mappings to retain for each segment uint32_t numMappingsForShortSequence; //how many secondary alignments we keep for reads < segLength + bool dropRand; //drop mappings w/ same score until only numMappingsForSegment remain int threads; //execution thread count std::vector refSequences; //reference sequence(s) std::vector querySequences; //query sequence(s)