From 5bc659f956bed872555c13f90d2d6b88672f8b94 Mon Sep 17 00:00:00 2001 From: LTLA Date: Sun, 12 Jan 2025 12:36:30 -0800 Subject: [PATCH] Manually allocate thread-local vectors to protect against false sharing. This is more memory-efficient than tatami_stats::LocalOutputBuffer as we don't create any allocations in the t = 0 case, and we avoid storing unnecessary NULL pointers in the t > 0 case. More efficiency is important as there might be an arbitrarily large number of groups and we don't want memory usage to be dominated by the parallization overhead. Incidentally, this allows us to remove the tatami_stats dependency. --- CMakeLists.txt | 3 +- cmake/Config.cmake.in | 1 - extern/CMakeLists.txt | 7 -- .../aggregate_across_cells.hpp | 73 ++++++++++++------- 4 files changed, 46 insertions(+), 38 deletions(-) diff --git a/CMakeLists.txt b/CMakeLists.txt index f67d58d..2ce6ec7 100644 --- a/CMakeLists.txt +++ b/CMakeLists.txt @@ -23,10 +23,9 @@ if(SCRAN_AGGREGATE_FETCH_EXTERN) add_subdirectory(extern) else() find_package(tatami_tatami 3.0.0 CONFIG REQUIRED) - find_package(tatami_tatami_stats 1.1.0 CONFIG REQUIRED) endif() -target_link_libraries(scran_aggregate INTERFACE tatami::tatami tatami::tatami_stats) +target_link_libraries(scran_aggregate INTERFACE tatami::tatami) # Tests if(CMAKE_PROJECT_NAME STREQUAL PROJECT_NAME) diff --git a/cmake/Config.cmake.in b/cmake/Config.cmake.in index a236203..8397be9 100644 --- a/cmake/Config.cmake.in +++ b/cmake/Config.cmake.in @@ -2,6 +2,5 @@ include(CMakeFindDependencyMacro) find_dependency(tatami_tatami 3.0.0 CONFIG REQUIRED) -find_dependency(tatami_tatami_stats 1.1.0 CONFIG REQUIRED) include("${CMAKE_CURRENT_LIST_DIR}/libscran_scran_aggregateTargets.cmake") diff --git a/extern/CMakeLists.txt b/extern/CMakeLists.txt index a34208d..c7828fe 100644 --- a/extern/CMakeLists.txt +++ b/extern/CMakeLists.txt @@ -6,11 +6,4 @@ FetchContent_Declare( GIT_TAG master # ^3.0.0 ) -FetchContent_Declare( - tatami_stats - GIT_REPOSITORY https://github.com/tatami-inc/tatami_stats - GIT_TAG master # ^1.1.0 -) - FetchContent_MakeAvailable(tatami) -FetchContent_MakeAvailable(tatami_stats) diff --git a/include/scran_aggregate/aggregate_across_cells.hpp b/include/scran_aggregate/aggregate_across_cells.hpp index 391d597..fefa4e1 100644 --- a/include/scran_aggregate/aggregate_across_cells.hpp +++ b/include/scran_aggregate/aggregate_across_cells.hpp @@ -5,7 +5,6 @@ #include #include "tatami/tatami.hpp" -#include "tatami_stats/tatami_stats.hpp" /** * @file aggregate_across_cells.hpp @@ -175,23 +174,37 @@ void compute_aggregate_by_column( tatami::Options opt; opt.sparse_ordered_index = false; - tatami::parallelize([&](size_t t, Index_ s, Index_ l) { + tatami::parallelize([&](size_t t, Index_ start, Index_ length) { auto NC = p.ncol(); - auto ext = tatami::consecutive_extractor(&p, false, static_cast(0), NC, s, l, opt); - std::vector vbuffer(l); - typename std::conditional, Index_>::type ibuffer(l); - + auto ext = tatami::consecutive_extractor(&p, false, static_cast(0), NC, start, length, opt); + std::vector vbuffer(length); + typename std::conditional, Index_>::type ibuffer(length); + + // Creating local buffers to protect against false sharing in all but + // the first thread. The first thread has the honor of writing directly + // to the output buffers, to avoid extra allocations in the serial case + // where no false sharing can occur. + std::vector > local_sums; + std::vector > local_detected; size_t num_sums = buffers.sums.size(); - std::vector > local_sums; - local_sums.reserve(num_sums); - for (auto ptr : buffers.sums) { - local_sums.emplace_back(t, s, l, ptr); - } size_t num_detected = buffers.detected.size(); - std::vector > local_detected; - local_detected.reserve(num_detected); - for (auto ptr : buffers.detected) { - local_detected.emplace_back(t, s, l, ptr); + if (t != 0) { + local_sums.reserve(num_sums); + for (size_t s = 0; s < num_sums; ++s) { + local_sums.emplace_back(length); + } + local_detected.reserve(num_detected); + for (size_t d = 0; d < num_detected; ++d) { + local_detected.emplace_back(length); + } + } else { + // Need to zero it in the first thread for consistency with the other threads. + for (size_t s = 0; s < num_sums; ++s) { + std::fill_n(buffers.sums[s] + start, length, static_cast(0)); + } + for (size_t d = 0; d < num_sums; ++d) { + std::fill_n(buffers.detected[d] + start, length, static_cast(0)); + } } for (Index_ x = 0; x < NC; ++x) { @@ -200,42 +213,46 @@ void compute_aggregate_by_column( if constexpr(sparse_) { auto col = ext->fetch(vbuffer.data(), ibuffer.data()); if (num_sums) { - auto cursum = local_sums[current].data(); + auto cursum = (t != 0 ? local_sums[current].data() : buffers.sums[current] + start); for (Index_ i = 0; i < col.number; ++i) { - cursum[col.index[i] - s] += col.value[i]; + cursum[col.index[i] - start] += col.value[i]; } } if (num_detected) { - auto curdetected = local_detected[current].data(); + auto curdetected = (t != 0 ? local_detected[current].data() : buffers.detected[current] + start); for (Index_ i = 0; i < col.number; ++i) { - curdetected[col.index[i] - s] += (col.value[i] > 0); + curdetected[col.index[i] - start] += (col.value[i] > 0); } } } else { auto col = ext->fetch(vbuffer.data()); if (num_sums) { - auto cursum = local_sums[current].data(); - for (Index_ i = 0; i < l; ++i) { + auto cursum = (t != 0 ? local_sums[current].data() : buffers.sums[current] + start); + for (Index_ i = 0; i < length; ++i) { cursum[i] += col[i]; } } if (num_detected) { - auto curdetected = local_detected[current].data(); - for (Index_ i = 0; i < l; ++i) { + auto curdetected = (t != 0 ? local_detected[current].data() : buffers.detected[current] + start); + for (Index_ i = 0; i < length; ++i) { curdetected[i] += (col[i] > 0); } } } } - for (auto& lsums : local_sums) { - lsums.transfer(); - } - for (auto& ldetected : local_detected) { - ldetected.transfer(); + if (t != 0) { + for (size_t s = 0; s < num_sums; ++s) { + const auto& current = local_sums[s]; + std::copy(current.begin(), current.end(), buffers.sums[s] + start); + } + for (size_t d = 0; d < num_detected; ++d) { + const auto& current = local_detected[d]; + std::copy(current.begin(), current.end(), buffers.detected[d] + start); + } } }, p.nrow(), options.num_threads); }