Skip to content

Commit

Permalink
Merge branch 'DOR-826-cherry-pick-dorado-correct-fixes-for-v0.7' into 'release-v0.7'
Browse files: browse the repository at this point in the history

Cherry-pick Dorado Correct fixes to release-v0.7

See merge request machine-learning/dorado!1138
  • Loading branch information
svc-jstone committed Jul 31, 2024
2 parents 37d316c + 550385c commit 5dc78ab
Show file tree
Hide file tree
Showing 12 changed files with 180 additions and 163 deletions.
6 changes: 4 additions & 2 deletions .gitignore
Original file line number Diff line number Diff line change
Expand Up @@ -6,11 +6,13 @@
*.qdrep
dist
build
build-Release
build-Debug
build-*
archive
cmake-build*
cmake.lock
CMakeSettings.json
_CPack_Packages
tests/test_output
.temp_dorado_model-*
temp
venv
4 changes: 4 additions & 0 deletions dorado/alignment/Minimap2Index.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -253,6 +253,10 @@ const mm_mapopt_t& Minimap2Index::mapping_options() const {
return *m_mapping_options;
}

// Non-const overload of mapping_options(). It reuses the const overload by
// viewing *this as const, then removes the const qualifier from the returned
// reference so the caller receives a mutable mm_mapopt_t&.
mm_mapopt_t& Minimap2Index::mapping_options() {
    const Minimap2Index& self = *this;
    const mm_mapopt_t& options = self.mapping_options();
    return const_cast<mm_mapopt_t&>(options);
}

// Read-only access to the options stored in m_options.
const Minimap2Options& Minimap2Index::get_options() const {
    return m_options;
}

} // namespace dorado::alignment
1 change: 1 addition & 0 deletions dorado/alignment/Minimap2Index.h
Original file line number Diff line number Diff line change
Expand Up @@ -46,6 +46,7 @@ class Minimap2Index {
// Raw pointer to the index, obtained from m_index.get().
const mm_idx_t* index() const { return m_index.get(); }
// Read-only access to the index options.
const mm_idxopt_t& index_options() const;
// Read-only access to the mapping options.
const mm_mapopt_t& mapping_options() const;
// Non-const overload: returns a mutable reference to the mapping options.
mm_mapopt_t& mapping_options();

HeaderSequenceRecords get_sequence_records_for_header() const;

Expand Down
11 changes: 11 additions & 0 deletions dorado/cli/correct.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -106,10 +106,21 @@ int correct(int argc, char* argv[]) {
std::exit(EXIT_FAILURE);
}

if (!std::filesystem::exists(reads.front())) {
spdlog::error("Input reads file {} does not exist!", reads.front());
std::exit(EXIT_FAILURE);
}

std::filesystem::path model_dir;
bool remove_tmp_dir = false;
if (parser.is_used("--model-path")) {
model_dir = std::filesystem::path(parser.get<std::string>("model-path"));

if (!std::filesystem::exists(model_dir)) {
spdlog::error("Input model path {} does not exist!", model_dir.string());
std::exit(EXIT_FAILURE);
}

} else {
// Download model
auto tmp_dir = utils::get_downloads_path(std::nullopt);
Expand Down
148 changes: 64 additions & 84 deletions dorado/correct/features.cpp
Original file line number Diff line number Diff line change
@@ -1,13 +1,17 @@
#include "features.h"

#include "conversions.h"
#include "correct/types.h"
#include "read_pipeline/messages.h"
#include "utils/sequence_utils.h"
#include "utils/types.h"

#include <ATen/Tensor.h>
#include <spdlog/spdlog.h>
#include <torch/types.h>

#include <cstdint>

#ifdef NDEBUG
#define LOG_TRACE(...)
#else
Expand All @@ -19,61 +23,27 @@ const int TOP_K = 30;
namespace dorado::correction {

// Returns true if any cigar operation in the overlap's cigar range
// [cigar_start_idx, cigar_end_idx] is an insertion or deletion with a length
// of at least 30. The end of the range is clamped to the size of the cigar,
// and the scan stops at the first long indel found.
bool overlap_has_long_indel(const OverlapWindow& overlap, const CorrectionAlignments& alignments) {
    const auto& cigar = alignments.cigars[overlap.overlap_idx];
    // cigar_end_idx is inclusive, hence the +1 before clamping to the cigar size.
    const size_t max_cigar_idx = std::min(size_t(overlap.cigar_end_idx + 1), cigar.size());
    for (size_t i = overlap.cigar_start_idx; i < max_cigar_idx; i++) {
        const bool is_indel =
                cigar[i].op == CigarOpType::INS || cigar[i].op == CigarOpType::DEL;
        if (cigar[i].len >= 30 && is_indel) {
            LOG_TRACE("filter ? tstart {} qstart {} qend {} long indel: {}", overlap.tstart,
                      overlap.qstart, overlap.qend, cigar[i].len);
            return true;
        }
    }
    return false;
}

// Measure the accuracy of an alignment segment within a window
// determined from the cigar string.
void calculate_accuracy(OverlapWindow& overlap,
const CorrectionAlignments& alignments,
size_t win_idx,
int win_len,
int window_size) {
int tstart = overlap.tstart;
int tend = (int)win_idx * window_size + win_len;

// get query region
const auto overlap_idx = overlap.overlap_idx;
int oqstart = alignments.overlaps[overlap_idx].qstart;
int oqend = alignments.overlaps[overlap_idx].qend;
int qstart, qend;
if (alignments.overlaps[overlap_idx].fwd) {
qstart = oqstart + overlap.qstart;
qend = oqstart + overlap.qend;
} else {
qstart = oqend - overlap.qend;
qend = oqend - overlap.qstart;
}

int qlen = qend - qstart;

// Fetch subsequences
std::string tseq = alignments.read_seq.substr(tstart, tend - tstart);
std::string qseq;
if (alignments.overlaps[overlap_idx].fwd) {
qseq = alignments.seqs[overlap_idx].substr(qstart, qlen);
} else {
qseq = utils::reverse_complement(alignments.seqs[overlap_idx].substr(qstart, qlen));
}

LOG_TRACE("tstart {} tend {} qstart {} qend {} cig st {} cig end {}", tstart, tend, qstart,
qend, overlap.cigar_start_idx, overlap.cigar_end_idx);

void calculate_accuracy(OverlapWindow& overlap, const CorrectionAlignments& alignments) {
const auto& cigar = alignments.cigars[overlap.overlap_idx];
bool has_warned_bad_cigar_op = false;

// Calculate accuracy
int tpos = 0, qpos = 0;
int m = 0, s = 0, i = 0, d = 0;
// counts of match, mismatch, insert, deletion
int n_match = 0, n_miss = 0, n_ins = 0, n_del = 0;

for (int idx = overlap.cigar_start_idx; idx <= overlap.cigar_end_idx; idx++) {
int len = -1;
Expand All @@ -91,40 +61,30 @@ void calculate_accuracy(OverlapWindow& overlap,
break;
}

LOG_TRACE("len {} tpos {} qpos {}", len, tpos, qpos);

switch (cigar[idx].op) {
case CigarOpType::MATCH:
for (int j = 0; j < len; j++) {
auto tbase = tseq[tpos + j];
auto qbase = qseq[qpos + j];
LOG_TRACE("{} tbase {}, {} qbase {}", tpos + j, tbase, qpos + j, qbase);

if (tbase == qbase) {
m += 1;
} else {
s += 1;
}
}

tpos += len;
qpos += len;
case CigarOpType::EQ_MATCH:
n_match += len;
break;
case CigarOpType::X_MISMATCH:
n_miss += len;
break;
case CigarOpType::INS:
i += len;
qpos += len;
n_ins += len;
break;
case CigarOpType::DEL:
d += len;
tpos += len;
n_del += len;
break;
default:
if (!has_warned_bad_cigar_op) {
has_warned_bad_cigar_op = true;
LOG_TRACE("unexpected CigarOpType: {}", uint8_t(cigar[idx].op));
}
break;
}
}

overlap.accuracy = (static_cast<float>(m) / (m + s + i + d));
LOG_TRACE("m {} s {} i {} d {}", m, s, i, d);
overlap.accuracy = (static_cast<float>(n_match) / (n_match + n_miss + n_ins + n_del));
LOG_TRACE("m {} s {} i {} d {}", n_match, n_miss, n_ins, n_del);
LOG_TRACE("accuracy qstart {} qend {} {}", overlap.qstart, overlap.qend, overlap.accuracy);
}

Expand All @@ -148,7 +108,8 @@ std::vector<int> get_max_ins_for_window(const std::vector<OverlapWindow>& overla
int len = cigar[i].len;

int l = -1;
if (op == CigarOpType::MATCH || op == CigarOpType::DEL) {
if (op == CigarOpType::EQ_MATCH || op == CigarOpType::X_MISMATCH ||
op == CigarOpType::DEL) {
l = len;
} else if (op == CigarOpType::INS) {
max_ins[tpos - 1] = std::max(len, max_ins[tpos - 1]);
Expand Down Expand Up @@ -248,8 +209,10 @@ std::tuple<at::Tensor, at::Tensor> get_features_for_window(
qseq = utils::reverse_complement(qseq);
std::reverse(qqual.begin(), qqual.end());
}
int cigar_len = overlap.cigar_end_idx - overlap.cigar_start_idx + 1;
int cigar_end = std::min((int)cigar.size(), cigar_len);

const int cigar_len_total = static_cast<int>(std::size(cigar));
const int cigar_len = overlap.cigar_end_idx - overlap.cigar_start_idx + 1;
const int cigar_end = std::min(cigar_len_total - overlap.cigar_start_idx, cigar_len);

uint8_t gap = fwd ? '*' : '#';

Expand Down Expand Up @@ -281,8 +244,8 @@ std::tuple<at::Tensor, at::Tensor> get_features_for_window(
LOG_TRACE("cigar_idx {} l {}", cigar_idx, l);

switch (op) {
case CigarOpType::MATCH:
case CigarOpType::MISMATCH:
case CigarOpType::EQ_MATCH:
case CigarOpType::X_MISMATCH:
for (uint32_t i = 0; i < l; i++) {
auto base = base_encoding[uint8_t(qseq[query_iter]) + (fwd ? 0 : 32)];
auto qual = qqual[query_iter];
Expand Down Expand Up @@ -412,18 +375,10 @@ at::Tensor get_indices(const at::Tensor& bases, const std::vector<std::pair<int,
.clone();
}

// Main interface function for generating features for each window
// given the overlaps for a target read.
std::vector<WindowFeatures> extract_features(std::vector<std::vector<OverlapWindow>>& windows,
const CorrectionAlignments& alignments,
int window_size) {
const std::string& tseq = alignments.read_seq;
int tlen = (int)tseq.length();

std::vector<WindowFeatures> wfs;
std::unordered_set<int> filter_features(std::vector<std::vector<OverlapWindow>>& windows,
const CorrectionAlignments& alignments) {
std::unordered_set<int> overlap_idxs;
for (int w = 0; w < (int)windows.size(); w++) {
int win_len = (w == (int)windows.size() - 1) ? tlen - window_size * w : window_size;
LOG_TRACE("win idx {}: win len {}", w, win_len);
auto& overlap_windows = windows[w];

// Filter overlaps with very large indels
Expand All @@ -440,7 +395,7 @@ std::vector<WindowFeatures> extract_features(std::vector<std::vector<OverlapWind
// Sort overlaps by score
if (overlap_windows.size() > 1) {
for (auto& ovlp : overlap_windows) {
calculate_accuracy(ovlp, alignments, w, win_len, window_size);
calculate_accuracy(ovlp, alignments);
}
// Sort the filtered overlaps by accuracy score
std::sort(overlap_windows.begin(), overlap_windows.end(),
Expand All @@ -454,14 +409,39 @@ std::vector<WindowFeatures> extract_features(std::vector<std::vector<OverlapWind
return a.accuracy > b.accuracy;
});
}
// Take the TOP_K best overlaps
overlap_windows.resize(std::min(TOP_K, (int)overlap_windows.size()));

#ifndef NDEBUG
if (overlap_windows.size() == 1) {
LOG_TRACE("window {} 1st {}-{}", w, overlap_windows[0].qstart, overlap_windows[0].qend);
} else if (overlap_windows.size() > 1) {
LOG_TRACE("window {} 1st {}-{} 2nd {}-{}", w, overlap_windows[0].qstart,
overlap_windows[0].qend, overlap_windows[1].qstart, overlap_windows[1].qend);
}
#endif

for (const auto& ov : overlap_windows) {
assert(ov.overlap_idx >= 0);
overlap_idxs.insert(ov.overlap_idx);
}
}
return overlap_idxs;
}

// Main interface function for generating features for the top_k overlaps for each window
// given the overlaps for a target read.
std::vector<WindowFeatures> extract_features(std::vector<std::vector<OverlapWindow>>& windows,
const CorrectionAlignments& alignments,
int window_size) {
const std::string& tseq = alignments.read_seq;
int tlen = (int)tseq.length();

std::vector<WindowFeatures> wfs;
for (int w = 0; w < (int)windows.size(); w++) {
int win_len = (w == (int)windows.size() - 1) ? tlen - window_size * w : window_size;
LOG_TRACE("win idx {}: win len {}", w, win_len);
auto& overlap_windows = windows[w];

WindowFeatures wf;
wf.window_idx = w;
Expand Down
6 changes: 6 additions & 0 deletions dorado/correct/features.h
Original file line number Diff line number Diff line change
Expand Up @@ -2,12 +2,18 @@

#include "types.h"

#include <unordered_set>

namespace dorado {
struct CorrectionAlignments;
}

namespace dorado::correction {

// Filter each window's overlaps down to the TOP_K best, modifying `windows` in place.
// Per window, overlaps with very large indels are filtered out, the remaining overlaps
// are ranked by their cigar-derived accuracy (when there is more than one), and the
// list is truncated to at most TOP_K entries.
// Returns the set of overlap_idx values still referenced by any window after filtering.
std::unordered_set<int> filter_features(std::vector<std::vector<OverlapWindow>>& windows,
const CorrectionAlignments& alignments);

// Main interface function for generating features for the top-k overlaps of each
// window, given the overlaps for a target read (alignments.read_seq is the target).
// window_size is the nominal window length; the last window covers whatever remains
// of the target read.
// NOTE(review): presumably returns one WindowFeatures per window - confirm against the definition.
std::vector<WindowFeatures> extract_features(std::vector<std::vector<OverlapWindow>>& windows,
const CorrectionAlignments& alignments,
int window_size);
Expand Down
1 change: 1 addition & 0 deletions dorado/correct/types.h
Original file line number Diff line number Diff line change
Expand Up @@ -8,6 +8,7 @@
namespace dorado::correction {

struct OverlapWindow {
// CorrectionAlignments overlap vector index
int overlap_idx = -1;
int tstart = -1;
int qstart = -1;
Expand Down
Loading

0 comments on commit 5dc78ab

Please sign in to comment.