Skip to content

Commit

Permalink
Merge branch 'bucket-sort-quantile' into 'master'
Browse files Browse the repository at this point in the history
counting sort quantile impl

See merge request machine-learning/dorado!99
  • Loading branch information
iiSeymour committed Oct 4, 2022
2 parents 8c76bec + 734e8e1 commit 3e893dc
Show file tree
Hide file tree
Showing 10 changed files with 150 additions and 35 deletions.
1 change: 1 addition & 0 deletions CMakeLists.txt
Original file line number Diff line number Diff line change
Expand Up @@ -43,6 +43,7 @@ endif()
set(LIB_SOURCE_FILES
dorado/cli/basecaller.cpp
dorado/cli/download.cpp
dorado/cli/benchmark.cpp
dorado/cli/cli.h
dorado/models.h
dorado/nn/CRFModel.h
Expand Down
2 changes: 1 addition & 1 deletion dorado/cli/basecaller.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -112,7 +112,7 @@ void setup(std::vector<std::string> args,
basecaller_node = std::make_unique<BasecallerNode>(
writer_node, std::move(runners), batch_size, chunk_size, overlap, model_stride);
}
ScalerNode scaler_node(*basecaller_node, num_devices * 5);
ScalerNode scaler_node(*basecaller_node, num_devices * 2);
DataLoader loader(scaler_node, "cpu", num_devices);
loader.load_reads(data_path);
}
Expand Down
67 changes: 67 additions & 0 deletions dorado/cli/benchmark.cpp
Original file line number Diff line number Diff line change
@@ -0,0 +1,67 @@
#include "../utils/tensor_utils.h"
#include "Version.h"
#include "torch/torch.h"

#include <argparse.hpp>

#include <chrono>
#include <iostream>

int benchmark(int argc, char* argv[]) {
argparse::ArgumentParser parser("dorado", DORADO_VERSION);

try {
parser.parse_args(argc, argv);
} catch (const std::exception& e) {
std::cerr << e.what() << std::endl;
std::cerr << parser;
std::exit(1);
}

std::vector<size_t> sizes{1000, 1000, 2000, 3000, 4000, 10000, 100000, 1000000, 10000000};

for (auto n : sizes) {
std::cerr << "samples : " << n << std::endl;

// generate some input
auto x = torch::randint(0, 2047, n);
auto q = torch::tensor({0.2, 0.9}, {torch::kFloat32});

// torch::quantile
auto start = std::chrono::system_clock::now();
auto res = torch::quantile(x, q);
auto end = std::chrono::system_clock::now();

auto duration = std::chrono::duration_cast<std::chrono::microseconds>(end - start).count();

std::cerr << "torch:quant "
<< " q20=" << res[0].item<int>() << " q90=" << res[1].item<int>() << " "
<< duration << "us" << std::endl;

// nth_element
start = std::chrono::system_clock::now();
res = ::utils::quantile(x, q);
end = std::chrono::system_clock::now();

duration = std::chrono::duration_cast<std::chrono::microseconds>(end - start).count();

std::cerr << "nth_element "
<< " q20=" << res[0].item<int>() << " q90=" << res[1].item<int>() << " "
<< duration << "us" << std::endl;

x = x.to(torch::kInt16);

// counting
start = std::chrono::system_clock::now();
res = ::utils::quantile_counting(x, q);
end = std::chrono::system_clock::now();
duration = std::chrono::duration_cast<std::chrono::microseconds>(end - start).count();

std::cerr << "counting "
<< " q20=" << res[0].item<int>() << " q90=" << res[1].item<int>() << " "
<< duration << "us" << std::endl
<< std::endl;
}

return 0;
}
27 changes: 11 additions & 16 deletions dorado/data_loader/DataLoader.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -174,11 +174,13 @@ std::shared_ptr<Read> process_pod5_read(size_t row,
total_sample_count += signal_rows[i]->stored_sample_count;
}

std::vector<std::int16_t> samples(total_sample_count);
auto options = torch::TensorOptions().dtype(torch::kInt16);
auto samples = torch::empty(total_sample_count, options);

std::size_t samples_read_so_far = 0;
for (std::size_t i = 0; i < signal_row_count; ++i) {
if (pod5_get_signal(file, signal_rows[i], signal_rows[i]->stored_sample_count,
samples.data() + samples_read_so_far) != POD5_OK) {
samples.data_ptr<int16_t>() + samples_read_so_far) != POD5_OK) {
std::scoped_lock lock(cerr_mtx);
std::cerr << "Failed to get read " << row << " signal: " << pod5_get_error_string()
<< "\n";
Expand All @@ -187,19 +189,12 @@ std::shared_ptr<Read> process_pod5_read(size_t row,
samples_read_so_far += signal_rows[i]->stored_sample_count;
}

std::vector<float> floatTmp(samples.begin(), samples.end());

auto new_read = std::make_shared<Read>();

auto options = torch::TensorOptions().dtype(torch::kFloat32);
new_read->raw_data =
torch::from_blob(floatTmp.data(), floatTmp.size(), options).clone().to(device);

new_read->raw_data = samples;
auto start_time_ms = run_acquisition_start_time_ms + ((start_sample * 1000) / run_sample_rate);
auto start_time = get_string_timestamp_from_unix_time(start_time_ms);
new_read->scaling = calib_data->scale;
new_read->offset = calib_data->offset;
new_read->scale_set = true;
new_read->read_id = read_id_str;
new_read->num_trimmed_samples = 0;
new_read->attributes.read_number = read_number;
Expand Down Expand Up @@ -303,9 +298,10 @@ void DataLoader::load_fast5_reads_from_file(const std::string& path) {
if (ds.getDataType().string() != "Integer16")
throw std::runtime_error("Invalid FAST5 Signal data type of " +
ds.getDataType().string());
std::vector<int16_t> tmp;
ds.read(tmp);
std::vector<float> floatTmp(tmp.begin(), tmp.end());

auto options = torch::TensorOptions().dtype(torch::kInt16);
auto samples = torch::empty(ds.getElementCount(), options);
ds.read(samples.data_ptr<int16_t>());

HighFive::Attribute mux_attr = raw.getAttribute("start_mux");
HighFive::Attribute read_number_attr = raw.getAttribute("read_number");
Expand All @@ -329,13 +325,12 @@ void DataLoader::load_fast5_reads_from_file(const std::string& path) {
auto start_time_str =
adjust_time(exp_start_time, static_cast<uint32_t>(start_time / sampling_rate));

auto options = torch::TensorOptions().dtype(torch::kFloat32);
auto new_read = std::make_shared<Read>();
new_read->raw_data =
torch::from_blob(floatTmp.data(), floatTmp.size(), options).clone().to(m_device);
new_read->raw_data = samples;
new_read->digitisation = digitisation;
new_read->range = range;
new_read->offset = offset;
new_read->scaling = range / digitisation;
new_read->read_id = read_id;
new_read->num_trimmed_samples = 0;
new_read->attributes.mux = mux;
Expand Down
2 changes: 1 addition & 1 deletion dorado/main.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -52,4 +52,4 @@ int main(int argc, char* argv[]) {
}

return 0;
}
}
1 change: 0 additions & 1 deletion dorado/read_pipeline/ReadPipeline.h
Original file line number Diff line number Diff line change
Expand Up @@ -56,7 +56,6 @@ class Read {
float shift; // To be set by scaler
float scale; // To be set by scaler

bool scale_set = false; // Set to True if scale has been applied to raw data
float scaling; // Scale factor applied to convert raw integers from sequencer into pore current values

size_t num_chunks; // Number of chunks in the read. Reads raw data is split into chunks for efficient basecalling.
Expand Down
23 changes: 9 additions & 14 deletions dorado/read_pipeline/ScalerNode.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -9,7 +9,7 @@ using namespace std::chrono_literals;

std::pair<float, float> normalisation(torch::Tensor& x) {
//Calculate shift and scale factors for normalisation.
auto quantiles = utils::quantile(x, torch::tensor({0.2, 0.9}));
auto quantiles = utils::quantile_counting(x, torch::tensor({0.2, 0.9}));
float q20 = quantiles[0].item<float>();
float q90 = quantiles[1].item<float>();
float shift = std::max(10.0f, 0.51f * (q20 + q90));
Expand All @@ -35,21 +35,16 @@ void ScalerNode::worker_thread() {
m_reads.pop_front();
lock.unlock();

if (!read->scale_set) {
read->scaling = (float)read->range / (float)read->digitisation;
read->scale_set = true;
}

read->raw_data = read->scaling * (read->raw_data + read->offset);

auto [shift, scale] = normalisation(read->raw_data);
read->shift = shift;
read->scale = scale;
read->raw_data = (read->raw_data - read->shift) / read->scale;
read->raw_data = (read->raw_data - shift) / scale;

// move the shift and scale into pA.
read->scale = read->scaling * scale;
read->shift = read->scaling * (shift + read->offset);

float threshold = shift + scale * 2.4;
float threshold = read->shift + read->scale * 2.4;

//8000 value may be changed in future. Currently this is found to work well.
// 8000 value may be changed in future. Currently this is found to work well.
int trim_start =
trim(read->raw_data.index({torch::indexing::Slice(torch::indexing::None, 8000)}),
threshold);
Expand Down Expand Up @@ -93,7 +88,7 @@ int ScalerNode::trim(torch::Tensor signal,
float max_trim) {
int min_trim = 10;
bool seen_peak = false;
int num_samples = std::min(max_samples, (int)signal.size(0));
int num_samples = std::min(max_samples, static_cast<int>(signal.size(0)));
int num_windows = num_samples / window_size;

for (int pos = 0; pos < num_windows; pos++) {
Expand Down
34 changes: 32 additions & 2 deletions dorado/utils/tensor_utils.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -5,6 +5,7 @@
#include <torch/csrc/jit/serialization/pickle.h>

#include <fstream>
#include <numeric>
#include <vector>

namespace utils {

Expand All @@ -27,8 +28,7 @@ std::vector<torch::Tensor> load_tensors(const std::filesystem::path& dir,
}

torch::Tensor quantile(const torch::Tensor t, const torch::Tensor q) {
assert(t.dtype().name() == "float");
assert(q.dtype().name() == "float");
assert(q.dtype() == torch::kF32);

auto tmp = t.clone();
auto [qval, qidx] = q.sort();
Expand All @@ -48,4 +48,34 @@ torch::Tensor quantile(const torch::Tensor t, const torch::Tensor q) {
return res;
}

// Computes the q-th quantiles of the 1D int16 tensor `t` using a counting
// sort over the value range [t.min(), t.max()], which is extremely fast for
// low-range integers. Matches torch::quantile with interpolation='lower'.
// `q` must be a float32 tensor of quantile fractions in [0, 1].
torch::Tensor quantile_counting(const torch::Tensor t, const torch::Tensor q) {
    assert(q.dtype() == torch::kF32);

    auto p = t.data_ptr<int16_t>();
    auto range_min = t.min().item<int16_t>();
    auto range_max = t.max().item<int16_t>();

    int size = t.size(0);

    // Histogram of values, offset so that counts[0] corresponds to range_min.
    std::vector<int> counts(range_max - range_min + 1, 0);
    for (int i = 0; i < size; ++i) {
        counts[p[i] - range_min]++;
    }
    // Convert the histogram into a cumulative count in place.
    std::partial_sum(counts.begin(), counts.end(), counts.begin());

    auto res = torch::empty_like(q);

    for (int64_t idx = 0; idx < q.numel(); idx++) {
        // Rank of the requested quantile in the sorted data, rounded down
        // ('lower' interpolation).
        int threshold = q[idx].item<float>() * (size - 1);
        // counts.back() == size > threshold, so the scan always finds a
        // match in range. (Fixes an off-by-one: the bound was <=, which
        // could read counts[counts.size()] out of bounds.)
        for (size_t i = 0; i < counts.size(); ++i) {
            if (counts[i] > threshold) {
                res[idx] = static_cast<int>(i) + range_min;
                break;
            }
        }
    }

    return res;
}

} // namespace utils
8 changes: 8 additions & 0 deletions dorado/utils/tensor_utils.h
Original file line number Diff line number Diff line change
Expand Up @@ -14,6 +14,14 @@ void serialise_tensor(torch::Tensor t, const std::string& path);
std::vector<torch::Tensor> load_tensors(const std::filesystem::path& dir,
const std::vector<std::string>& tensors);

// Computes the q-th quantiles of each row of the input tensor `t`
// using a partial sort as opposed a full sort per torch::quantiles
// Only `interpolation='lower'` is currently implemented.
torch::Tensor quantile(const torch::Tensor t, const torch::Tensor q);

// Computes the q-th quantiles of each row of the input tensor `t`
// using a counting sort which is extremely fast for low range integers.
// Only `interpolation='lower'` is currently implemented.
torch::Tensor quantile_counting(const torch::Tensor t, const torch::Tensor q);

} // namespace utils
20 changes: 20 additions & 0 deletions tests/TensorUtilsTest.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -34,3 +34,23 @@ TEST_CASE(CUT_TAG ": test quantiles", CUT_TAG) {

REQUIRE(torch::equal(computed, expected));
}

TEST_CASE(CUT_TAG ": test quartiles_counting", CUT_TAG) {
    // Counting-sort quantiles must match torch::quantile with 'lower'
    // interpolation for all three quartiles on random integer input.
    auto input = torch::randint(0, 2047, 1000);
    auto quartiles = torch::tensor({0.25, 0.5, 0.75}, {torch::kFloat});

    auto reference = torch::quantile(input, quartiles, 0, false, c10::string_view("lower"));
    auto actual = ::utils::quantile_counting(input.to(torch::kI16), quartiles);

    REQUIRE(torch::equal(actual, reference));
}

TEST_CASE(CUT_TAG ": test quantiles_counting", CUT_TAG) {
    // Same comparison for the q20/q90 pair used by the scaler.
    auto input = torch::randint(0, 2047, 1000);
    auto fractions = torch::tensor({0.2, 0.9}, {torch::kFloat});

    auto reference = torch::quantile(input, fractions, 0, false, c10::string_view("lower"));
    auto actual = ::utils::quantile_counting(input.to(torch::kI16), fractions);

    REQUIRE(torch::equal(actual, reference));
}

0 comments on commit 3e893dc

Please sign in to comment.