Skip to content

Commit

Permalink
Merge remote-tracking branch 'origin/master' into smalton/DOR-276-pairs-crash
Browse files Browse the repository at this point in the history
  • Loading branch information
tijyojwad committed Aug 8, 2023
2 parents 8d243e5 + d131786 commit 5119185
Show file tree
Hide file tree
Showing 4 changed files with 41 additions and 17 deletions.
16 changes: 8 additions & 8 deletions cmake/Torch.cmake
Original file line number Diff line number Diff line change
Expand Up @@ -76,8 +76,8 @@ else()
endif()
else()
if (TRY_USING_STATIC_TORCH_LIB)
set(TORCH_URL https://cdn.oxfordnanoportal.com/software/analysis/torch-2.0.0-linux-aarch64-ont.zip)
set(TORCH_PATCH_SUFFIX -ont)
set(TORCH_URL https://cdn.oxfordnanoportal.com/software/analysis/torch-2.0.0.1-linux-aarch64-ont.zip)
set(TORCH_PATCH_SUFFIX -ont.1)
set(TORCH_LIB_SUFFIX "/libtorch")
set(USING_STATIC_TORCH_LIB TRUE)
else()
Expand All @@ -89,11 +89,11 @@ else()
else()
if (TRY_USING_STATIC_TORCH_LIB)
if(DORADO_USING_OLD_CPP_ABI)
set(TORCH_URL https://cdn.oxfordnanoportal.com/software/analysis/torch-2.0.0-linux-x64-ont-pre-cxx11.zip)
set(TORCH_PATCH_SUFFIX -ont-pre-cxx11)
set(TORCH_URL https://cdn.oxfordnanoportal.com/software/analysis/torch-2.0.0.1-linux-x64-ont-pre-cxx11.zip)
set(TORCH_PATCH_SUFFIX -ont.1-pre-cxx11)
else()
set(TORCH_URL https://cdn.oxfordnanoportal.com/software/analysis/torch-2.0.0-linux-x64-ont-cxx11-abi.zip)
set(TORCH_PATCH_SUFFIX -ont-cxx11-abi)
set(TORCH_URL https://cdn.oxfordnanoportal.com/software/analysis/torch-2.0.0.1-linux-x64-ont-cxx11-abi.zip)
set(TORCH_PATCH_SUFFIX -ont.1-cxx11-abi)
endif()
set(USING_STATIC_TORCH_LIB TRUE)
else()
Expand Down Expand Up @@ -135,8 +135,8 @@ else()
endif()
elseif(WIN32)
if (TRY_USING_STATIC_TORCH_LIB)
set(TORCH_URL https://cdn.oxfordnanoportal.com/software/analysis/torch-2.0.0.2-Windows-ont.zip)
set(TORCH_PATCH_SUFFIX -ont.2)
set(TORCH_URL https://cdn.oxfordnanoportal.com/software/analysis/torch-2.0.0.3-Windows-ont.zip)
set(TORCH_PATCH_SUFFIX -ont.3)
set(TORCH_LIB_SUFFIX "/libtorch")
set(USING_STATIC_TORCH_LIB TRUE)
add_compile_options(
Expand Down
7 changes: 5 additions & 2 deletions dorado/nn/CudaCRFModel.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -130,7 +130,9 @@ class CudaCaller {
return granularity;
}

const int max_batch_size = available / (bytes_per_chunk_timestep * chunk_size_out);
const int64_t max_batch_size_limit = 10240;
const int max_batch_size = std::min(available / (bytes_per_chunk_timestep * chunk_size_out),
max_batch_size_limit);
if (max_batch_size < utils::pad_to(128, granularity) + granularity) {
spdlog::warn("Auto batchsize detection failed. Estimated max batch size only {}.",
max_batch_size);
Expand Down Expand Up @@ -166,7 +168,8 @@ class CudaCaller {
handle_cuda_result(cudaEventDestroy(stop));
}

spdlog::debug("Auto batchsize: {}, time per chunk {} ms", batch_size, time);
spdlog::debug("Auto batchsize {}: {}, time per chunk {} ms", m_device, batch_size,
time);
if (time < best_time) {
best_time = time;
best_batch_size = batch_size;
Expand Down
23 changes: 18 additions & 5 deletions dorado/nn/Runners.cpp
Original file line number Diff line number Diff line change
@@ -1,5 +1,6 @@
#include "Runners.h"

#include "cxxpool.h"
#include "decode/CPUDecoder.h"

#if DORADO_GPU_BUILD
Expand All @@ -26,7 +27,7 @@ std::pair<std::vector<dorado::Runner>, size_t> create_basecall_runners(
std::vector<dorado::Runner> runners;

// Default is 1 device. CUDA path may alter this.
int num_devices = 1;
size_t num_devices = 1;

if (device == "cpu") {
num_runners = std::thread::hardware_concurrency();
Expand Down Expand Up @@ -61,14 +62,26 @@ std::pair<std::vector<dorado::Runner>, size_t> create_basecall_runners(
if (num_devices == 0) {
throw std::runtime_error("CUDA device requested but no devices found.");
}

cxxpool::thread_pool pool{num_devices};
std::vector<std::shared_ptr<CudaCaller>> callers;
std::vector<std::future<std::shared_ptr<dorado::CudaCaller>>> futures;

for (auto device_string : devices) {
auto caller = dorado::create_cuda_caller(model_config, chunk_size, batch_size,
device_string, memory_fraction, guard_gpus);
futures.push_back(pool.push(dorado::create_cuda_caller, model_config, chunk_size,
batch_size, device_string, memory_fraction, guard_gpus));
}

for (auto& caller : futures) {
callers.push_back(caller.get());
}

for (size_t j = 0; j < num_devices; j++) {
for (size_t i = 0; i < num_runners; i++) {
runners.push_back(std::make_shared<dorado::CudaModelRunner>(caller));
runners.push_back(std::make_shared<dorado::CudaModelRunner>(callers[j]));
}
if (runners.back()->batch_size() != batch_size) {
spdlog::debug("- set batch size for {} to {}", device_string,
spdlog::debug("- set batch size for {} to {}", devices[j],
runners.back()->batch_size());
}
}
Expand Down
12 changes: 10 additions & 2 deletions dorado/read_pipeline/ProgressTracker.h
Original file line number Diff line number Diff line change
Expand Up @@ -46,6 +46,10 @@ class ProgressTracker {
if (m_num_duplex_reads_filtered > 0) {
spdlog::info("> Duplex reads filtered: {}", m_num_duplex_reads_filtered);
}
spdlog::info("> Duplex rate: {}%",
((static_cast<float>(m_num_duplex_bases_processed) * 2) /
m_num_simplex_bases_processed) *
100);
}
if (m_num_bases_processed > 0) {
std::ostringstream samples_sec;
Expand Down Expand Up @@ -76,10 +80,12 @@ class ProgressTracker {
m_num_simplex_reads_written = fetch_stat("HtsWriter.unique_simplex_reads_written");

m_num_simplex_reads_filtered = fetch_stat("ReadFilterNode.simplex_reads_filtered");
m_num_bases_processed = fetch_stat("BasecallerNode.bases_processed");
m_num_simplex_bases_processed = fetch_stat("BasecallerNode.bases_processed");
m_num_bases_processed = m_num_simplex_bases_processed;
m_num_samples_processed = fetch_stat("BasecallerNode.samples_processed");
if (m_duplex) {
m_num_bases_processed += fetch_stat("StereoBasecallerNode.bases_processed");
m_num_duplex_bases_processed = fetch_stat("StereoBasecallerNode.bases_processed");
m_num_bases_processed += m_num_duplex_bases_processed;
m_num_samples_processed += fetch_stat("StereoBasecallerNode.samples_processed");
}
m_num_duplex_reads_written = fetch_stat("HtsWriter.duplex_reads_written");
Expand Down Expand Up @@ -115,6 +121,8 @@ class ProgressTracker {
private:
int64_t m_num_bases_processed{0};
int64_t m_num_samples_processed{0};
int64_t m_num_simplex_bases_processed{0};
int64_t m_num_duplex_bases_processed{0};
int m_num_reads_processed{0};
int m_num_simplex_reads_written{0};
int m_num_simplex_reads_filtered{0};
Expand Down

0 comments on commit 5119185

Please sign in to comment.