From 5109968142d83ac07989407e12fffa21a542007f Mon Sep 17 00:00:00 2001 From: niklasmueboe Date: Wed, 24 Jul 2024 15:20:47 +0200 Subject: [PATCH] add lowmemory mode --- src/KDE.jl | 129 ++++++++++++++++++++++++++++++++++++++--------------- 1 file changed, 92 insertions(+), 37 deletions(-) diff --git a/src/KDE.jl b/src/KDE.jl index c0f5f68..106f469 100644 --- a/src/KDE.jl +++ b/src/KDE.jl @@ -2,7 +2,7 @@ module KDE export gaussiankernel, kde, assigncelltype -using ..GridCount: GridCounts +using ..GridCount: GridCounts, gridsize using Base.Broadcast: @__dot__ using Base.Threads: @threads @@ -76,34 +76,65 @@ end # Celltype assignment function chunk_slices(i, step, n, pad) - bound1 = (i - 1) * step + 1 - bound2 = i * step + start = (i - 1) * step + 1 + stop = i * step - slice = max(1, bound1 + pad[1]):min(n, bound2 + pad[2]) - reslice = range(1 + max(0, bound1 - slice[1]); length=min(step, n - bound1 + 1)) - return slice, reslice + slice = max(1, start + pad[1]):min(n, stop + pad[2]) + unpad = range(1 + max(0, start - slice[1]); length=min(step, n - start + 1)) + return slice, unpad end -function chunk(counts, kernel, sx=500, sy=500) - m, n = size(first(counts)) - pad_x = extrema(axes(kernel, 1)) - pad_y = extrema(axes(kernel, 2)) +function padinfo(counts, kernel) + padrow = extrema(axes(kernel, 1)) + padcol = extrema(axes(kernel, 2)) + return size(first(counts)), (padrow, padcol) +end + +function chunkinfo(counts, kernel, chunksize) + function chunklengths(n, s, pad) + lengths = Int[] + for i in 1:cld(n, s) + _, unpad = chunk_slices(i, s, n, pad) + push!(lengths, length(unpad)) + end + return lengths + end + srow, scol = chunksize + (m, n), (padrow, padcol) = padinfo(counts, kernel) + + return chunklengths(m, srow, padrow), chunklengths(n, scol, padcol) +end + +function getchunk(counts, kernel, i, j; chunksize=(500, 500)) + (m, n), (padrow, padcol) = padinfo(counts, kernel) + srow, scol = chunksize + + slicerow, unpadrow = chunk_slices(i, srow, m, padrow) + slicecol, unpadcol = chunk_slices(j, scol, n, padcol) + chunk = map(x -> x[slicerow, slicecol], counts) + + return chunk, (unpadrow, unpadcol) +end + +function getchunks(counts, kernel; chunksize=(500, 500)) + (m, n), (padrow, padcol) = padinfo(counts, kernel) + srow, scol = chunksize chunks = Dict{Block{2,Int},Any}() colslices = UnitRange{Int}[] rowslices = UnitRange{Int}[] - for j in 1:cld(n, sx) - slice_x, reslice_x = chunk_slices(j, sx, n, pad_x) - col_chunk = map(x -> x[:, slice_x], counts) - push!(colslices, reslice_x) + for j in 1:cld(n, scol) + slicecol, unpadcol = chunk_slices(j, scol, n, padcol) + col_chunk = map(x -> x[:, slicecol], counts) + push!(colslices, unpadcol) - for i in 1:cld(m, sy) - slice_y, reslice_y = chunk_slices(i, sy, m, pad_y) - chunks[Block(i, j)] = map(x -> x[slice_y, :], col_chunk) + for i in 1:cld(m, srow) + slicerow, unpadrow = chunk_slices(i, srow, m, padrow) + chunks[Block(i, j)] = map(x -> x[slicerow, :], col_chunk) if j == 1 - push!(rowslices, reslice_y) + push!(rowslices, unpadrow) end end end @@ -213,8 +244,17 @@ The `eltype(kernel)` will be used for calculations and `signatures` will be cast - `celltypes::Vector{AbstractString}=nothing`: optional celltype names. - `log::Bool`: whether to log-transform the KDE. Useful if `signatures` are calculated from log-transformed gene expression. +- `lowmemory::Bool`: if false will use more memory but may have shorter runtime. """ -function assigncelltype(counts, signatures, kernel; celltypes=nothing, log=false) +function assigncelltype( + counts, + signatures, + kernel; + celltypes=nothing, + log=false, + lowmemory=true, + chunksize=(500, 500), +) if !isnothing(celltypes) && length(celltypes) != nrow(signatures) error("Length of 'celltypes' must match number of rows in 'signatures'") end @@ -243,24 +283,39 @@ function assigncelltype(counts, signatures, kernel; celltypes=nothing, log=false end end - chunked_counts, rowslices, colslices = chunk([counts[g] for g in genes], kernel) - rows, cols = length.(rowslices), length.(colslices) - - cosine = BlockArray(undef_blocks, Matrix{T}, rows, cols) - score = BlockArray(undef_blocks, Matrix{T}, rows, cols) - celltypemap = BlockArray(undef_blocks, Matrix{U}, rows, cols) - - @threads for (i, r) in collect(enumerate(rowslices)) - @threads for (j, c) in collect(enumerate(colslices)) - idx = Block(i, j) - celltypemap[idx], cosine[idx], score[idx] = calculatecosinesim( - pop!(chunked_counts, idx), - signatures, - sigcorrection, - kernel, - (r, c); - log=log, - ) + m, n = gridsize(counts) + + counts = [counts[g] for g in genes] + + chunklengths = chunkinfo(counts, kernel, chunksize) + cosine = BlockArray(undef_blocks, Matrix{T}, chunklengths...) + score = BlockArray(undef_blocks, Matrix{T}, chunklengths...) + celltypemap = BlockArray(undef_blocks, Matrix{U}, chunklengths...) + + if lowmemory + srow, scol = chunksize + @threads for i in 1:cld(m, srow) + @threads for j in 1:cld(n, scol) + idx = Block(i, j) + chunk, unpad = getchunk(counts, kernel, i, j; chunksize=chunksize) + celltypemap[idx], cosine[idx], score[idx] = calculatecosinesim( + chunk, signatures, sigcorrection, kernel, unpad; log=log + ) + end + end + else + chunked_counts, rowslices, colslices = getchunks( + counts, kernel; chunksize=chunksize + ) + + @threads for (i, r) in collect(enumerate(rowslices)) + @threads for (j, c) in collect(enumerate(colslices)) + idx = Block(i, j) + chunk = pop!(chunked_counts, idx) + celltypemap[idx], cosine[idx], score[idx] = calculatecosinesim( + chunk, signatures, sigcorrection, kernel, (r, c); log=log + ) + end end end