diff --git a/lib/cufft/fft.jl b/lib/cufft/fft.jl index 8751ae5749..c235ac2b93 100644 --- a/lib/cufft/fft.jl +++ b/lib/cufft/fft.jl @@ -25,7 +25,7 @@ Base.:(*)(p::ScaledPlan, x::DenseCuArray) = rmul!(p.p * x, p.scale) # N is the number of dimensions -mutable struct CuFFTPlan{T<:cufftNumber,S<:cufftNumber,K,inplace,N} <: Plan{S} +mutable struct CuFFTPlan{T<:cufftNumber,S<:cufftNumber,K,inplace,N,R,B} <: Plan{S} # handle to Cuda low level plan. Note that this plan sometimes has lower dimensions # to handle more transform cases such as individual directions handle::cufftHandle @@ -33,24 +33,26 @@ mutable struct CuFFTPlan{T<:cufftNumber,S<:cufftNumber,K,inplace,N} <: Plan{S} stream::CuStream input_size::NTuple{N,Int} # Julia size of input array output_size::NTuple{N,Int} # Julia size of output array - region::Any + region::NTuple{R,Int} + buffer::B # scratch copy of the input for out-of-place complex-to-real FFT (which overwrites its input), or `nothing` if not needed pinv::ScaledPlan{T} # required by AbstractFFTs API, will be defined by AbstractFFTs if needed - function CuFFTPlan{T,S,K,inplace,N}(handle::cufftHandle, - input_size::NTuple{N,Int}, output_size::NTuple{N,Int}, region - ) where {T<:cufftNumber,S<:cufftNumber,K,inplace,N} + function CuFFTPlan{T,S,K,inplace,N,R,B}(handle::cufftHandle, + input_size::NTuple{N,Int}, output_size::NTuple{N,Int}, + region::NTuple{R,Int}, buffer::B + ) where {T<:cufftNumber,S<:cufftNumber,K,inplace,N,R,B} abs(K) == 1 || throw(ArgumentError("FFT direction must be either -1 (forward) or +1 (inverse)")) inplace isa Bool || throw(ArgumentError("FFT inplace argument must be a Bool")) - p = new{T,S,K,inplace,N}(handle, context(), stream(), input_size, output_size, region) + p = new{T,S,K,inplace,N,R,B}(handle, context(), stream(), input_size, output_size, region, buffer) finalizer(unsafe_free!, p) p end end -function CuFFTPlan{T,S,K,inplace,N}(handle::cufftHandle, X::DenseCuArray{S,N}, - sizey::NTuple{N,Int}, region, - ) where {T<:cufftNumber,S<:cufftNumber,K,inplace,N} - 
CuFFTPlan{T,S,K,inplace,N}(handle, size(X), sizey, region) +function CuFFTPlan{T,S,K,inplace,N,R,B}(handle::cufftHandle, X::DenseCuArray{S,N}, + sizey::NTuple{N,Int}, region::NTuple{R,Int}, buffer::B + ) where {T<:cufftNumber,S<:cufftNumber,K,inplace,N,R,B} + CuFFTPlan{T,S,K,inplace,N,R,B}(handle, size(X), sizey, region, buffer) end function CUDA.unsafe_free!(plan::CuFFTPlan) @@ -60,6 +62,9 @@ function CUDA.unsafe_free!(plan::CuFFTPlan) end plan.handle = C_NULL end + if !isnothing(plan.buffer) + CUDA.unsafe_free!(plan.buffer) + end end function showfftdims(io, sz, T) @@ -151,103 +156,116 @@ end function plan_fft!(X::DenseCuArray{T,N}, region) where {T<:cufftComplexes,N} K = CUFFT_FORWARD inplace = true - region = Tuple(region) + R = length(region) + region = NTuple{R,Int}(region) md = plan_max_dims(region, size(X)) sizex = size(X)[1:md] handle = cufftGetPlan(T, T, sizex, region) - CuFFTPlan{T,T,K,inplace,N}(handle, X, size(X), region) + CuFFTPlan{T,T,K,inplace,N,R,Nothing}(handle, X, size(X), region, nothing) end - function plan_bfft!(X::DenseCuArray{T,N}, region) where {T<:cufftComplexes,N} K = CUFFT_INVERSE inplace = true - region = Tuple(region) + R = length(region) + region = NTuple{R,Int}(region) md = plan_max_dims(region, size(X)) sizex = size(X)[1:md] handle = cufftGetPlan(T, T, sizex, region) - CuFFTPlan{T,T,K,inplace,N}(handle, X, size(X), region) + CuFFTPlan{T,T,K,inplace,N,R,Nothing}(handle, X, size(X), region, nothing) end # out-of-place complex function plan_fft(X::DenseCuArray{T,N}, region) where {T<:cufftComplexes,N} K = CUFFT_FORWARD inplace = false - region = Tuple(region) + R = length(region) + region = NTuple{R,Int}(region) md = plan_max_dims(region,size(X)) sizex = size(X)[1:md] handle = cufftGetPlan(T, T, sizex, region) - CuFFTPlan{T,T,K,inplace,N}(handle, X, size(X), region) + CuFFTPlan{T,T,K,inplace,N,R,Nothing}(handle, X, size(X), region, nothing) end function plan_bfft(X::DenseCuArray{T,N}, region) where {T<:cufftComplexes,N} K = 
CUFFT_INVERSE inplace = false - region = Tuple(region) + R = length(region) + region = NTuple{R,Int}(region) md = plan_max_dims(region,size(X)) sizex = size(X)[1:md] handle = cufftGetPlan(T, T, sizex, region) - CuFFTPlan{T,T,K,inplace,N}(handle, size(X), size(X), region) + CuFFTPlan{T,T,K,inplace,N,R,Nothing}(handle, size(X), size(X), region, nothing) end # out-of-place real-to-complex function plan_rfft(X::DenseCuArray{T,N}, region) where {T<:cufftReals,N} K = CUFFT_FORWARD inplace = false - region = Tuple(region) + R = length(region) + region = NTuple{R,Int}(region) md = plan_max_dims(region,size(X)) - # X = front_view(X, md) sizex = size(X)[1:md] handle = cufftGetPlan(complex(T), T, sizex, region) ydims = collect(size(X)) - ydims[region[1]] = div(ydims[region[1]],2)+1 + ydims[region[1]] = div(ydims[region[1]], 2) + 1 - CuFFTPlan{complex(T),T,K,inplace,N}(handle, size(X), (ydims...,), region) + # The buffer is not needed for real-to-complex (`mul!`), + # but it is required for complex-to-real (`ldiv!`). + buffer = CuArray{complex(T)}(undef, ydims...) + B = typeof(buffer) + + CuFFTPlan{complex(T),T,K,inplace,N,R,B}(handle, size(X), (ydims...,), region, buffer) end -function plan_brfft(X::DenseCuArray{T,N}, d::Integer, region::Any) where {T<:cufftComplexes,N} +# out-of-place complex-to-real +function plan_brfft(X::DenseCuArray{T,N}, d::Integer, region) where {T<:cufftComplexes,N} K = CUFFT_INVERSE inplace = false - region = Tuple(region) + R = length(region) + region = NTuple{R,Int}(region) ydims = collect(size(X)) ydims[region[1]] = d handle = cufftGetPlan(real(T), T, (ydims...,), region) - CuFFTPlan{real(T),T,K,inplace,N}(handle, size(X), (ydims...,), region) + buffer = CuArray{T}(undef, size(X)) # input-sized scratch: the transform overwrites its input + B = typeof(buffer) + + CuFFTPlan{real(T),T,K,inplace,N,R,B}(handle, size(X), (ydims...,), region, buffer) end # FIXME: plan_inv methods allocate needlessly (to provide type parameters) # Perhaps use FakeArray types to avoid this. 
-function plan_inv(p::CuFFTPlan{T,S,CUFFT_INVERSE,inplace,N} - ) where {T<:cufftNumber,S<:cufftNumber,N,inplace} +function plan_inv(p::CuFFTPlan{T,S,CUFFT_INVERSE,inplace,N,R,B} + ) where {T<:cufftNumber,S<:cufftNumber,inplace,N,R,B} md_osz = plan_max_dims(p.region, p.output_size) sz_X = p.output_size[1:md_osz] handle = cufftGetPlan(S, T, sz_X, p.region) - ScaledPlan(CuFFTPlan{S,T,CUFFT_FORWARD,inplace,N}(handle, p.output_size, p.input_size, p.region), + ScaledPlan(CuFFTPlan{S,T,CUFFT_FORWARD,inplace,N,R,B}(handle, p.output_size, p.input_size, p.region, isnothing(p.buffer) ? nothing : similar(p.buffer)), normalization(real(T), p.output_size, p.region)) end -function plan_inv(p::CuFFTPlan{T,S,CUFFT_FORWARD,inplace,N} - ) where {T<:cufftNumber,S<:cufftNumber,N,inplace} +function plan_inv(p::CuFFTPlan{T,S,CUFFT_FORWARD,inplace,N,R,B} + ) where {T<:cufftNumber,S<:cufftNumber,inplace,N,R,B} md_isz = plan_max_dims(p.region, p.input_size) sz_Y = p.input_size[1:md_isz] handle = cufftGetPlan(S, T, sz_Y, p.region) - ScaledPlan(CuFFTPlan{S,T,CUFFT_INVERSE,inplace,N}(handle, p.output_size, p.input_size, p.region), + ScaledPlan(CuFFTPlan{S,T,CUFFT_INVERSE,inplace,N,R,B}(handle, p.output_size, p.input_size, p.region, isnothing(p.buffer) ? nothing : similar(p.buffer)), normalization(real(S), p.input_size, p.region)) end @@ -309,10 +327,14 @@ function LinearAlgebra.mul!(y::DenseCuArray{T}, p::CuFFTPlan{T,S,K,inplace}, x:: ) where {T,S,K,inplace} assert_applicable(p, x, y) if !inplace && T<:Real - # Out-of-place complex-to-real FFT will always overwrite input buffer. - x = copy(x) + # Out-of-place complex-to-real FFT will always overwrite input x. + # We copy the input x in an auxiliary buffer. 
+ z = p.buffer + copyto!(z, x) + else + z = x end - unsafe_execute_trailing!(p, x, y) + unsafe_execute_trailing!(p, z, y) y end @@ -323,13 +345,22 @@ function Base.:(*)(p::CuFFTPlan{T,S,K,true}, x::DenseCuArray{S}) where {T,S,K} end function Base.:(*)(p::CuFFTPlan{T,S,K,false}, x::DenseCuArray{S1,M}) where {T,S,K,S1,M} - if S1 != S || T<:Real - # Convert to the expected input type. Also, - # Out-of-place complex-to-real FFT will always overwrite input buffer. - x = copy1(S, x) + if T<:Real + # Out-of-place complex-to-real FFT will always overwrite input x, so we copy x + # into the plan's buffer. Check x first: the buffer itself always has the plan's size. + size(x) == p.input_size || throw(ArgumentError("CuFFT plan applied to wrong-size input")) + z = p.buffer + copyto!(z, x) + else + if S1 != S + # Convert to the expected input type. + z = copy1(S, x) + else + z = x + end end - assert_applicable(p, x) + assert_applicable(p, z) y = CuArray{T,M}(undef, p.output_size) - unsafe_execute_trailing!(p, x, y) + unsafe_execute_trailing!(p, z, y) y end