Skip to content

Commit

Permalink
add lowmemory mode
Browse files Browse the repository at this point in the history
  • Loading branch information
niklasmueboe committed Jul 24, 2024
1 parent 77aea86 commit 5109968
Showing 1 changed file with 92 additions and 37 deletions.
129 changes: 92 additions & 37 deletions src/KDE.jl
Original file line number Diff line number Diff line change
Expand Up @@ -2,7 +2,7 @@ module KDE

export gaussiankernel, kde, assigncelltype

using ..GridCount: GridCounts
using ..GridCount: GridCounts, gridsize

using Base.Broadcast: @__dot__
using Base.Threads: @threads
Expand Down Expand Up @@ -76,34 +76,65 @@ end

# Celltype assignment
"""
    chunk_slices(i, step, n, pad)

Compute the index ranges of the `i`-th chunk of width `step` along an axis of length `n`.

`pad` is a 2-tuple of offsets that are added to the first and last index of the chunk
before the result is clamped to `1:n`.

Returns a tuple `(slice, unpad)`:
- `slice`: the padded index range into the full axis, clamped to `1:n`.
- `unpad`: the range, relative to `slice`, that selects the unpadded part of the chunk.
"""
function chunk_slices(i, step, n, pad)
    start = (i - 1) * step + 1
    stop = i * step

    slice = max(1, start + pad[1]):min(n, stop + pad[2])
    # Position of `start` inside `slice`; zero when no leading padding was applied.
    offset = max(0, start - first(slice))
    # The last chunk may be shorter than `step` when `n` is not a multiple of `step`.
    unpad = range(1 + offset; length=min(step, n - start + 1))
    return slice, unpad
end

function chunk(counts, kernel, sx=500, sy=500)
m, n = size(first(counts))
pad_x = extrema(axes(kernel, 1))
pad_y = extrema(axes(kernel, 2))
"""
    padinfo(counts, kernel)

Return `(gridshape, (padrow, padcol))`, where `gridshape` is `size(first(counts))` and
`padrow`/`padcol` are the `(min, max)` indices of the first and second axis of `kernel`.
"""
function padinfo(counts, kernel)
    gridshape = size(first(counts))
    pads = map(dim -> extrema(axes(kernel, dim)), (1, 2))
    return gridshape, pads
end

"""
    chunkinfo(counts, kernel, chunksize)

Return the unpadded chunk lengths along the rows and the columns as two `Vector{Int}`s.

`chunksize` is destructured into the chunk size along rows and along columns. The grid
size and padding are taken from `padinfo(counts, kernel)`.
"""
function chunkinfo(counts, kernel, chunksize)
    # Unpadded length of every chunk along a single axis.
    function axislengths(len, step, pad)
        nchunks = cld(len, step)
        return Int[length(last(chunk_slices(k, step, len, pad))) for k in 1:nchunks]
    end

    rowstep, colstep = chunksize
    gridshape, pads = padinfo(counts, kernel)
    nrows, ncols = gridshape
    rowpad, colpad = pads

    return axislengths(nrows, rowstep, rowpad), axislengths(ncols, colstep, colpad)
end

"""
    getchunk(counts, kernel, i, j; chunksize=(500, 500))

Extract the padded chunk at chunk-row `i` and chunk-column `j` from each element of
`counts`.

Returns `(chunk, (unpadrow, unpadcol))`, where `chunk` holds the padded slice of every
element of `counts` and `unpadrow`/`unpadcol` are the ranges, relative to that slice,
that select the unpadded part of the chunk.
"""
function getchunk(counts, kernel, i, j; chunksize=(500, 500))
    gridshape, pads = padinfo(counts, kernel)
    rowstep, colstep = chunksize

    rows, unpadrows = chunk_slices(i, rowstep, gridshape[1], pads[1])
    cols, unpadcols = chunk_slices(j, colstep, gridshape[2], pads[2])

    chunk = map(counts) do layer
        layer[rows, cols]
    end

    return chunk, (unpadrows, unpadcols)
end

function getchunks(counts, kernel; chunksize=(500, 500))
(m, n), (padrow, padcol) = padinfo(counts, kernel)
srow, scol = chunksize

chunks = Dict{Block{2,Int},Any}()

colslices = UnitRange{Int}[]
rowslices = UnitRange{Int}[]

for j in 1:cld(n, sx)
slice_x, reslice_x = chunk_slices(j, sx, n, pad_x)
col_chunk = map(x -> x[:, slice_x], counts)
push!(colslices, reslice_x)
for j in 1:cld(n, scol)
slicecol, unpadcol = chunk_slices(j, scol, n, padcol)
col_chunk = map(x -> x[:, slicecol], counts)
push!(colslices, unpadcol)

for i in 1:cld(m, sy)
slice_y, reslice_y = chunk_slices(i, sy, m, pad_y)
chunks[Block(i, j)] = map(x -> x[slice_y, :], col_chunk)
for i in 1:cld(m, srow)
slicerow, unpadrow = chunk_slices(i, srow, m, padrow)
chunks[Block(i, j)] = map(x -> x[slicerow, :], col_chunk)
if j == 1
push!(rowslices, reslice_y)
push!(rowslices, unpadrow)
end
end
end
Expand Down Expand Up @@ -213,8 +244,17 @@ The `eltype(kernel)` will be used for calculations and `signatures` will be cast
- `celltypes::Vector{AbstractString}=nothing`: optional celltype names.
- `log::Bool`: whether to log-transform the KDE. Useful if `signatures` are calculated
from log-transformed gene expression.
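- `lowmemory::Bool=true`: if `false`, all chunks are extracted up front, which uses more
memory but may shorten the runtime.
- `chunksize=(500, 500)`: number of rows and columns of the chunks that are processed
independently.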
"""
function assigncelltype(counts, signatures, kernel; celltypes=nothing, log=false)
function assigncelltype(
counts,
signatures,
kernel;
celltypes=nothing,
log=false,
lowmemory=true,
chunksize=(500, 500),
)
if !isnothing(celltypes) && length(celltypes) != nrow(signatures)
error("Length of 'celltypes' must match number of rows in 'signatures'")
end
Expand Down Expand Up @@ -243,24 +283,39 @@ function assigncelltype(counts, signatures, kernel; celltypes=nothing, log=false
end
end

chunked_counts, rowslices, colslices = chunk([counts[g] for g in genes], kernel)
rows, cols = length.(rowslices), length.(colslices)

cosine = BlockArray(undef_blocks, Matrix{T}, rows, cols)
score = BlockArray(undef_blocks, Matrix{T}, rows, cols)
celltypemap = BlockArray(undef_blocks, Matrix{U}, rows, cols)

@threads for (i, r) in collect(enumerate(rowslices))
@threads for (j, c) in collect(enumerate(colslices))
idx = Block(i, j)
celltypemap[idx], cosine[idx], score[idx] = calculatecosinesim(
pop!(chunked_counts, idx),
signatures,
sigcorrection,
kernel,
(r, c);
log=log,
)
m, n = gridsize(counts)

counts = [counts[g] for g in genes]

chunklengths = chunkinfo(counts, kernel, chunksize)
cosine = BlockArray(undef_blocks, Matrix{T}, chunklengths...)
score = BlockArray(undef_blocks, Matrix{T}, chunklengths...)
celltypemap = BlockArray(undef_blocks, Matrix{U}, chunklengths...)

if lowmemory
srow, scol = chunksize
@threads for i in 1:cld(m, srow)
@threads for j in 1:cld(n, scol)
idx = Block(i, j)
chunk, unpad = getchunk(counts, kernel, i, j; chunksize=chunksize)
celltypemap[idx], cosine[idx], score[idx] = calculatecosinesim(
chunk, signatures, sigcorrection, kernel, unpad; log=log
)
end
end
else
chunked_counts, rowslices, colslices = getchunks(
counts, kernel; chunksize=chunksize
)

@threads for (i, r) in collect(enumerate(rowslices))
@threads for (j, c) in collect(enumerate(colslices))
idx = Block(i, j)
chunk = pop!(chunked_counts, idx)
celltypemap[idx], cosine[idx], score[idx] = calculatecosinesim(
chunk, signatures, sigcorrection, kernel, (r, c); log=log
)
end
end
end

Expand Down

0 comments on commit 5109968

Please sign in to comment.