Skip to content

Commit

Permalink
DOR-740 DOR-774 Dorado Correct Performance
Browse files Browse the repository at this point in the history
  • Loading branch information
HalfPhoton authored and John Stone committed Jul 31, 2024
1 parent 01c39db commit 6178d43
Show file tree
Hide file tree
Showing 8 changed files with 169 additions and 170 deletions.
142 changes: 60 additions & 82 deletions dorado/correct/features.cpp
Original file line number Diff line number Diff line change
@@ -1,13 +1,17 @@
#include "features.h"

#include "conversions.h"
#include "correct/types.h"
#include "read_pipeline/messages.h"
#include "utils/sequence_utils.h"
#include "utils/types.h"

#include <ATen/Tensor.h>
#include <spdlog/spdlog.h>
#include <torch/types.h>

#include <cstdint>

#ifdef NDEBUG
#define LOG_TRACE(...)
#else
Expand All @@ -19,61 +23,27 @@ const int TOP_K = 30;
namespace dorado::correction {

bool overlap_has_long_indel(const OverlapWindow& overlap, const CorrectionAlignments& alignments) {
bool long_indel = false;
const auto& cigar = alignments.cigars[overlap.overlap_idx];
size_t max_cigar_idx = std::min(size_t(overlap.cigar_end_idx + 1), cigar.size());
for (size_t i = overlap.cigar_start_idx; i < max_cigar_idx; i++) {
if (cigar[i].op == CigarOpType::INS || cigar[i].op == CigarOpType::DEL) {
long_indel |= cigar[i].len >= 30;
if (cigar[i].len >= 30 &&
(cigar[i].op == CigarOpType::INS || cigar[i].op == CigarOpType::DEL)) {
LOG_TRACE("filter ? tstart {} qstart {} qend {} long indel: {}", overlap.tstart,
overlap.qstart, overlap.qend, cigar[i].len);
return true;
}
}
LOG_TRACE("filter ? tstart {} qstart {} qend {} res {}", overlap.tstart, overlap.qstart,
overlap.qend, long_indel);
return long_indel;
return false;
}

// Measure the accuracy of an alignment segment within a window
// determined from the cigar string.
void calculate_accuracy(OverlapWindow& overlap,
const CorrectionAlignments& alignments,
size_t win_idx,
int win_len,
int window_size) {
int tstart = overlap.tstart;
int tend = (int)win_idx * window_size + win_len;

// get query region
const auto overlap_idx = overlap.overlap_idx;
int oqstart = alignments.overlaps[overlap_idx].qstart;
int oqend = alignments.overlaps[overlap_idx].qend;
int qstart, qend;
if (alignments.overlaps[overlap_idx].fwd) {
qstart = oqstart + overlap.qstart;
qend = oqstart + overlap.qend;
} else {
qstart = oqend - overlap.qend;
qend = oqend - overlap.qstart;
}

int qlen = qend - qstart;

// Fetch subsequences
std::string tseq = alignments.read_seq.substr(tstart, tend - tstart);
std::string qseq;
if (alignments.overlaps[overlap_idx].fwd) {
qseq = alignments.seqs[overlap_idx].substr(qstart, qlen);
} else {
qseq = utils::reverse_complement(alignments.seqs[overlap_idx].substr(qstart, qlen));
}

LOG_TRACE("tstart {} tend {} qstart {} qend {} cig st {} cig end {}", tstart, tend, qstart,
qend, overlap.cigar_start_idx, overlap.cigar_end_idx);

void calculate_accuracy(OverlapWindow& overlap, const CorrectionAlignments& alignments) {
const auto& cigar = alignments.cigars[overlap.overlap_idx];
bool has_warned_bad_cigar_op = false;

// Calculate accuracy
int tpos = 0, qpos = 0;
int m = 0, s = 0, i = 0, d = 0;
// counts of match, mismatch, insert, deletion
int n_match = 0, n_miss = 0, n_ins = 0, n_del = 0;

for (int idx = overlap.cigar_start_idx; idx <= overlap.cigar_end_idx; idx++) {
int len = -1;
Expand All @@ -91,40 +61,30 @@ void calculate_accuracy(OverlapWindow& overlap,
break;
}

LOG_TRACE("len {} tpos {} qpos {}", len, tpos, qpos);

switch (cigar[idx].op) {
case CigarOpType::MATCH:
for (int j = 0; j < len; j++) {
auto tbase = tseq[tpos + j];
auto qbase = qseq[qpos + j];
LOG_TRACE("{} tbase {}, {} qbase {}", tpos + j, tbase, qpos + j, qbase);

if (tbase == qbase) {
m += 1;
} else {
s += 1;
}
}

tpos += len;
qpos += len;
case CigarOpType::EQ_MATCH:
n_match += len;
break;
case CigarOpType::X_MISMATCH:
n_miss += len;
break;
case CigarOpType::INS:
i += len;
qpos += len;
n_ins += len;
break;
case CigarOpType::DEL:
d += len;
tpos += len;
n_del += len;
break;
default:
if (!has_warned_bad_cigar_op) {
has_warned_bad_cigar_op = true;
LOG_TRACE("unexpected CigarOpType: {}", uint8_t(cigar[idx].op));
}
break;
}
}

overlap.accuracy = (static_cast<float>(m) / (m + s + i + d));
LOG_TRACE("m {} s {} i {} d {}", m, s, i, d);
overlap.accuracy = (static_cast<float>(n_match) / (n_match + n_miss + n_ins + n_del));
LOG_TRACE("m {} s {} i {} d {}", n_match, n_miss, n_ins, n_del);
LOG_TRACE("accuracy qstart {} qend {} {}", overlap.qstart, overlap.qend, overlap.accuracy);
}

Expand All @@ -148,7 +108,8 @@ std::vector<int> get_max_ins_for_window(const std::vector<OverlapWindow>& overla
int len = cigar[i].len;

int l = -1;
if (op == CigarOpType::MATCH || op == CigarOpType::DEL) {
if (op == CigarOpType::EQ_MATCH || op == CigarOpType::X_MISMATCH ||
op == CigarOpType::DEL) {
l = len;
} else if (op == CigarOpType::INS) {
max_ins[tpos - 1] = std::max(len, max_ins[tpos - 1]);
Expand Down Expand Up @@ -283,8 +244,8 @@ std::tuple<at::Tensor, at::Tensor> get_features_for_window(
LOG_TRACE("cigar_idx {} l {}", cigar_idx, l);

switch (op) {
case CigarOpType::MATCH:
case CigarOpType::MISMATCH:
case CigarOpType::EQ_MATCH:
case CigarOpType::X_MISMATCH:
for (uint32_t i = 0; i < l; i++) {
auto base = base_encoding[uint8_t(qseq[query_iter]) + (fwd ? 0 : 32)];
auto qual = qqual[query_iter];
Expand Down Expand Up @@ -414,18 +375,10 @@ at::Tensor get_indices(const at::Tensor& bases, const std::vector<std::pair<int,
.clone();
}

// Main interface function for generating features for each window
// given the overlaps for a target read.
std::vector<WindowFeatures> extract_features(std::vector<std::vector<OverlapWindow>>& windows,
const CorrectionAlignments& alignments,
int window_size) {
const std::string& tseq = alignments.read_seq;
int tlen = (int)tseq.length();

std::vector<WindowFeatures> wfs;
std::unordered_set<int> filter_features(std::vector<std::vector<OverlapWindow>>& windows,
const CorrectionAlignments& alignments) {
std::unordered_set<int> overlap_idxs;
for (int w = 0; w < (int)windows.size(); w++) {
int win_len = (w == (int)windows.size() - 1) ? tlen - window_size * w : window_size;
LOG_TRACE("win idx {}: win len {}", w, win_len);
auto& overlap_windows = windows[w];

// Filter overlaps with very large indels
Expand All @@ -442,7 +395,7 @@ std::vector<WindowFeatures> extract_features(std::vector<std::vector<OverlapWind
// Sort overlaps by score
if (overlap_windows.size() > 1) {
for (auto& ovlp : overlap_windows) {
calculate_accuracy(ovlp, alignments, w, win_len, window_size);
calculate_accuracy(ovlp, alignments);
}
// Sort the filtered overlaps by accuracy score
std::sort(overlap_windows.begin(), overlap_windows.end(),
Expand All @@ -456,14 +409,39 @@ std::vector<WindowFeatures> extract_features(std::vector<std::vector<OverlapWind
return a.accuracy > b.accuracy;
});
}
// Take the TOP_K best overlaps
overlap_windows.resize(std::min(TOP_K, (int)overlap_windows.size()));

#ifndef NDEBUG
if (overlap_windows.size() == 1) {
LOG_TRACE("window {} 1st {}-{}", w, overlap_windows[0].qstart, overlap_windows[0].qend);
} else if (overlap_windows.size() > 1) {
LOG_TRACE("window {} 1st {}-{} 2nd {}-{}", w, overlap_windows[0].qstart,
overlap_windows[0].qend, overlap_windows[1].qstart, overlap_windows[1].qend);
}
#endif

for (const auto& ov : overlap_windows) {
assert(ov.overlap_idx >= 0);
overlap_idxs.insert(ov.overlap_idx);
}
}
return overlap_idxs;
}

// Main interface function for generating features for the top_k overlaps for each window
// given the overlaps for a target read.
std::vector<WindowFeatures> extract_features(std::vector<std::vector<OverlapWindow>>& windows,
const CorrectionAlignments& alignments,
int window_size) {
const std::string& tseq = alignments.read_seq;
int tlen = (int)tseq.length();

std::vector<WindowFeatures> wfs;
for (int w = 0; w < (int)windows.size(); w++) {
int win_len = (w == (int)windows.size() - 1) ? tlen - window_size * w : window_size;
LOG_TRACE("win idx {}: win len {}", w, win_len);
auto& overlap_windows = windows[w];

WindowFeatures wf;
wf.window_idx = w;
Expand Down
6 changes: 6 additions & 0 deletions dorado/correct/features.h
Original file line number Diff line number Diff line change
Expand Up @@ -2,12 +2,18 @@

#include "types.h"

#include <unordered_set>

namespace dorado {
struct CorrectionAlignments;
}

namespace dorado::correction {

// Filter window features to TOP_K best. Returns collection of useful overlap indices
std::unordered_set<int> filter_features(std::vector<std::vector<OverlapWindow>>& windows,
const CorrectionAlignments& alignments);

std::vector<WindowFeatures> extract_features(std::vector<std::vector<OverlapWindow>>& windows,
const CorrectionAlignments& alignments,
int window_size);
Expand Down
1 change: 1 addition & 0 deletions dorado/correct/types.h
Original file line number Diff line number Diff line change
Expand Up @@ -8,6 +8,7 @@
namespace dorado::correction {

struct OverlapWindow {
// CorrectionAlignments overlap vector index
int overlap_idx = -1;
int tstart = -1;
int qstart = -1;
Expand Down
29 changes: 14 additions & 15 deletions dorado/correct/windows.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -6,6 +6,8 @@

#include <spdlog/spdlog.h>

#include <stdexcept>

#ifdef NDEBUG
#define LOG_TRACE(...)
#else
Expand All @@ -30,10 +32,6 @@ bool extract_windows(std::vector<std::vector<OverlapWindow>>& windows,
for (int aln_idx = 0; aln_idx < num_alignments; aln_idx++) {
const auto& overlap = alignments.overlaps[aln_idx];
const auto& cigar = alignments.cigars[aln_idx];
//if (alignments.qnames[a] != "e3066d3e-2bdf-4803-89b9-0f077ac7ff7f") {
// continue;
//}
LOG_TRACE("window for {}", alignments.qnames[aln_idx]);

// Following the is_target == False logic form the rust code.
if ((overlap.tend - overlap.tstart < window_size) ||
Expand Down Expand Up @@ -62,10 +60,10 @@ bool extract_windows(std::vector<std::vector<OverlapWindow>>& windows,
spdlog::error(
"{} zeroth thres {} nth thres {} first win {} last win {} windows size {} "
"overlap "
"tlen {} overlsp tstart {} overlap tend {} qname {} qlen {} qstart {} qend {}",
"tlen {} overlsp tstart {} overlap tend {} qlen {} qstart {} qend {}",
alignments.read_name, zeroth_window_thresh, nth_window_thresh, first_window,
last_window, windows.size(), overlap.tlen, overlap.tstart, overlap.tend,
alignments.qnames[aln_idx], overlap.qlen, overlap.qstart, overlap.qend);
overlap.qlen, overlap.qstart, overlap.qend);
return false;
}

Expand Down Expand Up @@ -102,8 +100,8 @@ bool extract_windows(std::vector<std::vector<OverlapWindow>>& windows,
int tnew = tpos;
int qnew = qpos;
switch (op.op) {
case CigarOpType::MATCH:
case CigarOpType::MISMATCH:
case CigarOpType::EQ_MATCH:
case CigarOpType::X_MISMATCH:
tnew = tpos + op.len;
qnew = qpos + op.len;
LOG_TRACE("{} {}", op.len, "M");
Expand All @@ -117,7 +115,7 @@ bool extract_windows(std::vector<std::vector<OverlapWindow>>& windows,
LOG_TRACE("{} {}", op.len, "I");
continue;
default:
continue;
throw std::runtime_error("unexpected CigarOpType");
}

LOG_TRACE("tpos {} qpos {} tnew {} qnew {}", tpos, qpos, tnew, qnew);
Expand All @@ -137,9 +135,10 @@ bool extract_windows(std::vector<std::vector<OverlapWindow>>& windows,
for (int i = 1; i < diff_w; i++) {
int offset = (current_w + i) * window_size - tpos;

int q_start_new = (op.op == CigarOpType::MATCH || op.op == CigarOpType::MISMATCH)
? qpos + offset
: qpos;
int q_start_new =
(op.op == CigarOpType::EQ_MATCH || op.op == CigarOpType::X_MISMATCH)
? qpos + offset
: qpos;

if (cigar_start_idx >= 0) {
windows[(current_w + i) - 1].push_back(
Expand All @@ -154,7 +153,7 @@ bool extract_windows(std::vector<std::vector<OverlapWindow>>& windows,

t_window_start = tpos + offset;

if (op.op == CigarOpType::MATCH || op.op == CigarOpType::MISMATCH) {
if (op.op == CigarOpType::EQ_MATCH || op.op == CigarOpType::X_MISMATCH) {
q_window_start = qpos + offset;
} else {
q_window_start = qpos;
Expand All @@ -165,7 +164,7 @@ bool extract_windows(std::vector<std::vector<OverlapWindow>>& windows,
} else {
t_window_start = tpos + offset;

if (op.op == CigarOpType::MATCH || op.op == CigarOpType::MISMATCH) {
if (op.op == CigarOpType::EQ_MATCH || op.op == CigarOpType::X_MISMATCH) {
q_window_start = qpos + offset;
} else {
q_window_start = qpos;
Expand All @@ -179,7 +178,7 @@ bool extract_windows(std::vector<std::vector<OverlapWindow>>& windows,
LOG_TRACE("new_w {} window size {} tpos {}", new_w, window_size, tpos);
int offset = new_w * window_size - tpos;

int qend = (op.op == CigarOpType::MATCH || op.op == CigarOpType::MISMATCH)
int qend = (op.op == CigarOpType::EQ_MATCH || op.op == CigarOpType::X_MISMATCH)
? qpos + offset
: qpos;

Expand Down
Loading

0 comments on commit 6178d43

Please sign in to comment.