Skip to content

Commit

Permalink
[CUFFT] Preallocate a buffer for complex-to-real FFT (#2578)
Browse files Browse the repository at this point in the history
* [CUFFT] Preallocate a buffer for complex-to-real FFT

* Update cufft.jl

* Fix new errors in fft.jl

* More fixes in fft.jl

* Allocate a buffer in both plan_rfft and plan_brfft

* Allocate a buffer in both plan_rfft and plan_brfft

* Update lib/cufft/fft.jl

Co-authored-by: Tim Besard <[email protected]>

---------

Co-authored-by: Tim Besard <[email protected]>
  • Loading branch information
amontoison and maleadt authored Dec 14, 2024
1 parent ca8f6cf commit 19a08ef
Showing 1 changed file with 71 additions and 41 deletions.
112 changes: 71 additions & 41 deletions lib/cufft/fft.jl
Original file line number Diff line number Diff line change
Expand Up @@ -25,32 +25,34 @@ Base.:(*)(p::ScaledPlan, x::DenseCuArray) = rmul!(p.p * x, p.scale)

# N is the number of dimensions

# Plan object pairing a cuFFT handle with the array sizes, transform region and
# optional scratch buffer used when the plan is applied.
# NOTE(review): this listing is a diff with the +/- markers stripped, so the
# pre-commit and post-commit versions of a line appear back to back (e.g. the two
# `mutable struct` headers, the two `region` fields, the two `new` calls).
# Type parameters of the longer header, as used in this block:
#   T, S    - both constrained to cufftNumber; S is also the `Plan{S}` parameter
#   K       - transform direction; the constructor requires abs(K) == 1
#   inplace - must be a Bool (checked in the constructor)
#   N       - length of the input_size / output_size tuples
#   R       - length of the region tuple
#   B       - type of the buffer field
mutable struct CuFFTPlan{T<:cufftNumber,S<:cufftNumber,K,inplace,N} <: Plan{S}
mutable struct CuFFTPlan{T<:cufftNumber,S<:cufftNumber,K,inplace,N,R,B} <: Plan{S}
# handle to Cuda low level plan. Note that this plan sometimes has lower dimensions
# to handle more transform cases such as individual directions
handle::cufftHandle
ctx::CuContext
stream::CuStream
input_size::NTuple{N,Int} # Julia size of input array
output_size::NTuple{N,Int} # Julia size of output array
region::Any
region::NTuple{R,Int}
buffer::B # buffer for out-of-place complex-to-real FFT, or `nothing` if not needed
pinv::ScaledPlan{T} # required by AbstractFFTs API, will be defined by AbstractFFTs if needed

# Inner constructor: validates K and inplace, stores the values returned by
# context() and stream() next to the handle, and registers unsafe_free! as the
# finalizer of the new plan. `pinv` is not passed to `new`, so it starts undefined.
function CuFFTPlan{T,S,K,inplace,N}(handle::cufftHandle,
input_size::NTuple{N,Int}, output_size::NTuple{N,Int}, region
) where {T<:cufftNumber,S<:cufftNumber,K,inplace,N}
function CuFFTPlan{T,S,K,inplace,N,R,B}(handle::cufftHandle,
input_size::NTuple{N,Int}, output_size::NTuple{N,Int},
region::NTuple{R,Int}, buffer::B
) where {T<:cufftNumber,S<:cufftNumber,K,inplace,N,R,B}
abs(K) == 1 || throw(ArgumentError("FFT direction must be either -1 (forward) or +1 (inverse)"))
inplace isa Bool || throw(ArgumentError("FFT inplace argument must be a Bool"))
p = new{T,S,K,inplace,N}(handle, context(), stream(), input_size, output_size, region)
p = new{T,S,K,inplace,N,R,B}(handle, context(), stream(), input_size, output_size, region, buffer)
finalizer(unsafe_free!, p)
p
end
end

# Convenience constructor: takes the input array `X` instead of an explicit input
# size and forwards `size(X)` to the constructor that accepts size tuples.
# `sizey` becomes the plan's output size.
# NOTE(review): diff listing with +/- markers stripped; the shorter signature and
# the shorter forwarding call look like the pre-commit lines, the longer ones
# (with `R`, `B` and `buffer`) like their replacements.
function CuFFTPlan{T,S,K,inplace,N}(handle::cufftHandle, X::DenseCuArray{S,N},
sizey::NTuple{N,Int}, region,
) where {T<:cufftNumber,S<:cufftNumber,K,inplace,N}
CuFFTPlan{T,S,K,inplace,N}(handle, size(X), sizey, region)
function CuFFTPlan{T,S,K,inplace,N,R,B}(handle::cufftHandle, X::DenseCuArray{S,N},
sizey::NTuple{N,Int}, region::NTuple{R,Int}, buffer::B
) where {T<:cufftNumber,S<:cufftNumber,K,inplace,N,R,B}
CuFFTPlan{T,S,K,inplace,N,R,B}(handle, size(X), sizey, region, buffer)
end

function CUDA.unsafe_free!(plan::CuFFTPlan)
Expand All @@ -60,6 +62,9 @@ function CUDA.unsafe_free!(plan::CuFFTPlan)
end
plan.handle = C_NULL
end
if !isnothing(plan.buffer)
CUDA.unsafe_free!(plan.buffer)
end
end

function showfftdims(io, sz, T)
Expand Down Expand Up @@ -151,103 +156,116 @@ end
# Build an in-place (inplace = true) forward (K = CUFFT_FORWARD) plan for the
# complex array `X` over the dimensions listed in `region`.
# Input and output sizes are both size(X); no scratch buffer is attached (`nothing`).
# NOTE(review): diff listing with +/- markers stripped; `region = Tuple(region)`
# and the shorter CuFFTPlan{...,N} call appear to be the removed lines.
function plan_fft!(X::DenseCuArray{T,N}, region) where {T<:cufftComplexes,N}
K = CUFFT_FORWARD
inplace = true
region = Tuple(region)
# Convert region to a tuple of R Ints so that R can be used as a type parameter.
R = length(region)
region = NTuple{R,Int}(region)

# Only the first `md` dimensions of X are passed to cufftGetPlan.
# presumably plan_max_dims returns the highest dimension the handle must cover — TODO confirm
md = plan_max_dims(region, size(X))
sizex = size(X)[1:md]
handle = cufftGetPlan(T, T, sizex, region)

CuFFTPlan{T,T,K,inplace,N}(handle, X, size(X), region)
CuFFTPlan{T,T,K,inplace,N,R,Nothing}(handle, X, size(X), region, nothing)
end


# Build an in-place (inplace = true) inverse-direction (K = CUFFT_INVERSE) plan for
# the complex array `X` over the dimensions listed in `region`.
# Input and output sizes are both size(X); no scratch buffer is attached (`nothing`).
# NOTE(review): diff listing with +/- markers stripped; `region = Tuple(region)`
# and the shorter CuFFTPlan{...,N} call appear to be the removed lines.
function plan_bfft!(X::DenseCuArray{T,N}, region) where {T<:cufftComplexes,N}
K = CUFFT_INVERSE
inplace = true
region = Tuple(region)
# Convert region to a tuple of R Ints so that R can be used as a type parameter.
R = length(region)
region = NTuple{R,Int}(region)

# Only the first `md` dimensions of X are passed to cufftGetPlan.
# presumably plan_max_dims returns the highest dimension the handle must cover — TODO confirm
md = plan_max_dims(region, size(X))
sizex = size(X)[1:md]
handle = cufftGetPlan(T, T, sizex, region)

CuFFTPlan{T,T,K,inplace,N}(handle, X, size(X), region)
CuFFTPlan{T,T,K,inplace,N,R,Nothing}(handle, X, size(X), region, nothing)
end

# out-of-place complex
# Build an out-of-place (inplace = false) forward (K = CUFFT_FORWARD) plan for the
# complex array `X` over the dimensions listed in `region`.
# Input and output sizes are both size(X); no scratch buffer is attached (`nothing`).
# NOTE(review): diff listing with +/- markers stripped; `region = Tuple(region)`
# and the shorter CuFFTPlan{...,N} call appear to be the removed lines.
function plan_fft(X::DenseCuArray{T,N}, region) where {T<:cufftComplexes,N}
K = CUFFT_FORWARD
inplace = false
region = Tuple(region)
# Convert region to a tuple of R Ints so that R can be used as a type parameter.
R = length(region)
region = NTuple{R,Int}(region)

# Only the first `md` dimensions of X are passed to cufftGetPlan.
# presumably plan_max_dims returns the highest dimension the handle must cover — TODO confirm
md = plan_max_dims(region,size(X))
sizex = size(X)[1:md]
handle = cufftGetPlan(T, T, sizex, region)

CuFFTPlan{T,T,K,inplace,N}(handle, X, size(X), region)
CuFFTPlan{T,T,K,inplace,N,R,Nothing}(handle, X, size(X), region, nothing)
end

# Build an out-of-place (inplace = false) inverse-direction (K = CUFFT_INVERSE) plan
# for the complex array `X` over the dimensions listed in `region`.
# Unlike the sibling planners, this one passes size(X) (not X) as the input size,
# i.e. it calls the size-tuple constructor directly. No scratch buffer is attached.
# NOTE(review): diff listing with +/- markers stripped; `region = Tuple(region)`
# and the shorter CuFFTPlan{...,N} call appear to be the removed lines.
function plan_bfft(X::DenseCuArray{T,N}, region) where {T<:cufftComplexes,N}
K = CUFFT_INVERSE
inplace = false
region = Tuple(region)
# Convert region to a tuple of R Ints so that R can be used as a type parameter.
R = length(region)
region = NTuple{R,Int}(region)

# Only the first `md` dimensions of X are passed to cufftGetPlan.
# presumably plan_max_dims returns the highest dimension the handle must cover — TODO confirm
md = plan_max_dims(region,size(X))
sizex = size(X)[1:md]
handle = cufftGetPlan(T, T, sizex, region)

CuFFTPlan{T,T,K,inplace,N}(handle, size(X), size(X), region)
CuFFTPlan{T,T,K,inplace,N,R,Nothing}(handle, size(X), size(X), region, nothing)
end

# out-of-place real-to-complex
# Build an out-of-place forward plan taking the real array `X` to a complex result
# (plan output element type complex(T), input element type T) over `region`.
# The output size equals size(X) except along region[1], where it is
# div(n, 2) + 1 for an input length n.
# NOTE(review): diff listing with +/- markers stripped; `region = Tuple(region)`,
# the first `ydims[region[1]] = ...` line and the shorter CuFFTPlan{...,N} call
# appear to be the removed lines.
function plan_rfft(X::DenseCuArray{T,N}, region) where {T<:cufftReals,N}
K = CUFFT_FORWARD
inplace = false
region = Tuple(region)
# Convert region to a tuple of R Ints so that R can be used as a type parameter.
R = length(region)
region = NTuple{R,Int}(region)

# Only the first `md` dimensions of X are passed to cufftGetPlan.
# presumably plan_max_dims returns the highest dimension the handle must cover — TODO confirm
md = plan_max_dims(region,size(X))
# X = front_view(X, md)
sizex = size(X)[1:md]

handle = cufftGetPlan(complex(T), T, sizex, region)

# Start from the input dims and shrink the first transformed dimension.
ydims = collect(size(X))
ydims[region[1]] = div(ydims[region[1]],2)+1
ydims[region[1]] = div(ydims[region[1]], 2) + 1

# The buffer is not needed for real-to-complex (`mul!`),
# but it’s required for complex-to-real (`ldiv!`).
# It is allocated uninitialized, with the complex element type and the output dims.
buffer = CuArray{complex(T)}(undef, ydims...)
B = typeof(buffer)

CuFFTPlan{complex(T),T,K,inplace,N}(handle, size(X), (ydims...,), region)
CuFFTPlan{complex(T),T,K,inplace,N,R,B}(handle, size(X), (ydims...,), region, buffer)
end

# Build an out-of-place inverse-direction plan taking the complex array `X` to a
# real result (plan output element type real(T)) over `region`.
# `d` is the output length along region[1]; all other output dims equal size(X).
# NOTE(review): diff listing with +/- markers stripped; the `region::Any` header,
# `region = Tuple(region)` and the shorter CuFFTPlan{...,N} call appear to be the
# removed lines.
function plan_brfft(X::DenseCuArray{T,N}, d::Integer, region::Any) where {T<:cufftComplexes,N}
# out-of-place complex-to-real
function plan_brfft(X::DenseCuArray{T,N}, d::Integer, region) where {T<:cufftComplexes,N}
K = CUFFT_INVERSE
inplace = false
region = Tuple(region)
# Convert region to a tuple of R Ints so that R can be used as a type parameter.
R = length(region)
region = NTuple{R,Int}(region)

# Output dims: copy of the input dims with the first transformed dimension set to d.
ydims = collect(size(X))
ydims[region[1]] = d

# The handle is created from the real-side (output) dims, not from size(X).
handle = cufftGetPlan(real(T), T, (ydims...,), region)

CuFFTPlan{real(T),T,K,inplace,N}(handle, size(X), (ydims...,), region)
# Uninitialized scratch array with the same element type and size as the input X;
# the out-of-place `*` and `mul!` methods copy their input into it when T<:Real.
buffer = CuArray{T}(undef, size(X))
B = typeof(buffer)

CuFFTPlan{real(T),T,K,inplace,N,R,B}(handle, size(X), (ydims...,), region, buffer)
end


# FIXME: plan_inv methods allocate needlessly (to provide type parameters)
# Perhaps use FakeArray types to avoid this.

# Inverse of an inverse-direction plan: creates a CUFFT_FORWARD plan with the element
# types swapped (S,T) and input/output sizes swapped, wrapped in a ScaledPlan whose
# factor is normalization(real(T), p.output_size, p.region).
# The new handle is requested for the leading dims of p.output_size.
# NOTE(review): diff listing with +/- markers stripped; the shorter signature and
# the shorter CuFFTPlan{...,N} call appear to be the removed lines.
# NOTE(review): the new plan is given `p.buffer` itself (same object, same type B),
# and each plan's finalizer (unsafe_free!) frees its non-nothing buffer — confirm the
# shared buffer cannot be freed while the other plan is still in use.
function plan_inv(p::CuFFTPlan{T,S,CUFFT_INVERSE,inplace,N}
) where {T<:cufftNumber,S<:cufftNumber,N,inplace}
function plan_inv(p::CuFFTPlan{T,S,CUFFT_INVERSE,inplace,N,R,B}
) where {T<:cufftNumber,S<:cufftNumber,inplace,N,R,B}
md_osz = plan_max_dims(p.region, p.output_size)
sz_X = p.output_size[1:md_osz]
handle = cufftGetPlan(S, T, sz_X, p.region)
ScaledPlan(CuFFTPlan{S,T,CUFFT_FORWARD,inplace,N}(handle, p.output_size, p.input_size, p.region),
ScaledPlan(CuFFTPlan{S,T,CUFFT_FORWARD,inplace,N,R,B}(handle, p.output_size, p.input_size, p.region, p.buffer),
normalization(real(T), p.output_size, p.region))
end

# Inverse of a forward plan: creates a CUFFT_INVERSE plan with the element types
# swapped (S,T) and input/output sizes swapped, wrapped in a ScaledPlan whose factor
# is normalization(real(S), p.input_size, p.region).
# The new handle is requested for the leading dims of p.input_size.
# NOTE(review): diff listing with +/- markers stripped; the shorter signature and
# the shorter CuFFTPlan{...,N} call appear to be the removed lines.
# NOTE(review): the new plan is given `p.buffer` itself (same object, same type B),
# and each plan's finalizer (unsafe_free!) frees its non-nothing buffer — confirm the
# shared buffer cannot be freed while the other plan is still in use.
function plan_inv(p::CuFFTPlan{T,S,CUFFT_FORWARD,inplace,N}
) where {T<:cufftNumber,S<:cufftNumber,N,inplace}
function plan_inv(p::CuFFTPlan{T,S,CUFFT_FORWARD,inplace,N,R,B}
) where {T<:cufftNumber,S<:cufftNumber,inplace,N,R,B}
md_isz = plan_max_dims(p.region, p.input_size)
sz_Y = p.input_size[1:md_isz]
handle = cufftGetPlan(S, T, sz_Y, p.region)
ScaledPlan(CuFFTPlan{S,T,CUFFT_INVERSE,inplace,N}(handle, p.output_size, p.input_size, p.region),
ScaledPlan(CuFFTPlan{S,T,CUFFT_INVERSE,inplace,N,R,B}(handle, p.output_size, p.input_size, p.region, p.buffer),
normalization(real(S), p.input_size, p.region))
end

Expand Down Expand Up @@ -309,10 +327,14 @@ function LinearAlgebra.mul!(y::DenseCuArray{T}, p::CuFFTPlan{T,S,K,inplace}, x::
) where {T,S,K,inplace}
assert_applicable(p, x, y)
if !inplace && T<:Real
# Out-of-place complex-to-real FFT will always overwrite input buffer.
x = copy(x)
# Out-of-place complex-to-real FFT will always overwrite input x.
# We copy the input x in an auxiliary buffer.
z = p.buffer
copyto!(z, x)
else
z = x
end
unsafe_execute_trailing!(p, x, y)
unsafe_execute_trailing!(p, z, y)
y
end

Expand All @@ -323,13 +345,21 @@ function Base.:(*)(p::CuFFTPlan{T,S,K,true}, x::DenseCuArray{S}) where {T,S,K}
end

# Apply an out-of-place plan to `x`: allocates a fresh output array of element type T
# and size p.output_size, executes the plan into it and returns it.
# NOTE(review): diff listing with +/- markers stripped. The `if S1 != S || T<:Real`
# branch with `x = copy1(S, x)`, `assert_applicable(p, x)` and
# `unsafe_execute_trailing!(p, x, y)` appear to be the removed lines; the lines that
# use `z` are their replacements. As written the if/end nesting does not balance.
function Base.:(*)(p::CuFFTPlan{T,S,K,false}, x::DenseCuArray{S1,M}) where {T,S,K,S1,M}
if S1 != S || T<:Real
# Convert to the expected input type. Also,
# Out-of-place complex-to-real FFT will always overwrite input buffer.
x = copy1(S, x)
if T<:Real
# Out-of-place complex-to-real FFT will always overwrite input x.
# We copy the input x in an auxiliary buffer.
# The plan's own buffer is reused here instead of allocating a copy per call.
z = p.buffer
copyto!(z, x)
else
if S1 != S
# Convert to the expected input type.
z = copy1(S, x)
else
z = x
end
end
assert_applicable(p, x)
# NOTE(review): when T<:Real, `z` is the plan's own buffer, so this check no longer
# looks at the caller's `x`; only copyto! above touches x — confirm a wrong-sized x
# is still rejected rather than partially copied.
assert_applicable(p, z)
y = CuArray{T,M}(undef, p.output_size)
unsafe_execute_trailing!(p, x, y)
unsafe_execute_trailing!(p, z, y)
y
end

0 comments on commit 19a08ef

Please sign in to comment.