diff --git a/src/FFTW.jl b/src/FFTW.jl index 4366ee7..bc1c7c6 100644 --- a/src/FFTW.jl +++ b/src/FFTW.jl @@ -68,6 +68,7 @@ end include("fft.jl") include("dct.jl") +include("rfft!.jl") include("precompile.jl") _precompile_() diff --git a/src/rfft!.jl b/src/rfft!.jl new file mode 100644 index 0000000..7f5cc36 --- /dev/null +++ b/src/rfft!.jl @@ -0,0 +1,266 @@ +import Base: IndexStyle, getindex, setindex!, eltype, \, similar, copy, real, read! + +export PaddedRFFTArray, plan_rfft!, rfft!, plan_irfft!, plan_brfft!, brfft!, irfft! + +_check_valid_strides(a::AbstractArray) = strides(a) == colmajorstrides(size(a)) + +# This struct reinterprets the `data` array to a Complex or Float array, depending on `eltype(data)` +# It is used internally with the PaddedRFFTArray in place of `Base.ReinterpretArray` +# ReinterpretArray has some performance issues when reinterprreting a Complex array to Real +struct ComplexOrRealReinterpretArray{T<:fftwNumber,N,A<:AbstractArray{<:fftwNumber,N},B<:Ptr} <: DenseArray{T,N} + data::A # Either a real or complex array + _unsafe_pointer::B # Pointer to the `data` array, but converted to a different type representation. + + function ComplexOrRealReinterpretArray(rarray::AbstractArray{T,N}) where {T<:fftwReal,N} + _check_valid_strides(rarray) || throw( + ArgumentError("Only contiguous strided arrays are supported")) + ptr = unsafe_convert(Ptr{Complex{T}}, pointer(rarray)) + return new{Complex{T},N,typeof(rarray),typeof(ptr)}(rarray,ptr) + end + + function ComplexOrRealReinterpretArray(carray::AbstractArray{T,N}) where {T<:fftwComplex,N} + _check_valid_strides(carray) || throw( + ArgumentError("Only contiguous strided arrays are supported")) + FT = T === ComplexF64 ? Float64 : Float32 + ptr = unsafe_convert(Ptr{FT}, pointer(carray)) + return new{FT,N,typeof(carray),typeof(ptr)}(carray,ptr) + end +end + +const RealReinterpretArray{N} = ComplexOrRealReinterpretArray{<:fftwReal,N,<:AbstractArray{<:fftwComplex,N}} +const ComplexReinterpretArray{N} = ComplexOrRealReinterpretArray{<:fftwComplex,N,<:AbstractArray{<:fftwReal,N}} + +@inline size_convertion(::RealReinterpretArray,i::Integer) = 2i +@inline size_convertion(::ComplexReinterpretArray,i::Integer) = i÷2 + +IndexStyle(::Type{T}) where {T<:ComplexOrRealReinterpretArray} = IndexLinear() + +Base.size(a::ComplexOrRealReinterpretArray) = + ntuple(i->(i == 1 ? size_convertion(a,size(a.data)[i]) : size(a.data)[i]),Val(ndims(a.data))) + +Base.@propagate_inbounds function getindex(a::ComplexOrRealReinterpretArray,i::Integer) + data = a.data + @boundscheck checkbounds(a,i) + GC.@preserve data r = unsafe_load(a._unsafe_pointer, i) + return r +end + +Base.@propagate_inbounds function setindex!(a::ComplexOrRealReinterpretArray,v,i::Integer) + data = a.data + @boundscheck checkbounds(a,i) + GC.@preserve data unsafe_store!(a._unsafe_pointer,v, i) + return a +end + +Base.unsafe_convert(p::Type{Ptr{T}}, a::ComplexOrRealReinterpretArray{T,N}) where {T,N} = Base.unsafe_convert(p,a.data) + +Base.elsize(::Type{<:ComplexOrRealReinterpretArray{T,N}}) where {T,N} = sizeof(T) + +complex_or_real_reinterpret(a::AbstractArray) = ComplexOrRealReinterpretArray(a) +complex_or_real_reinterpret(a::ComplexOrRealReinterpretArray) = a.data # Avoid nesting of ComplexOrRealReinterpretArrays + +# At the time this code was written the new `ReinterpretArray` in Base had some performace issues. +# Those issues were bypassed with the usage of our simplified version of ReinterpretArray above. +# Hopefully, once the performance issues with ReinterpretArray +# are solved we can just use Base.ReinterpretArray directly. + +struct PaddedRFFTArray{T<:fftwReal,N,R,C,L,Nm1} <: DenseArray{Complex{T},N} + data::R + r::SubArray{T,N,R,Tuple{Base.OneTo{Int},Vararg{Base.Slice{Base.OneTo{Int}},Nm1}},L} # Real view skipping padding + c::C + + function PaddedRFFTArray{T}(rr::AbstractArray{T,N},d::Int) where {T<:fftwReal,N} + fsize = size(rr)[1] + iseven(fsize) || throw( + ArgumentError("First dimension of allocated array must have even number of elements")) + (d == fsize-2 || d == fsize-1) || throw( + ArgumentError("Number of elements on the first dimension of array must be either 1 or 2 less than the number of elements on the first dimension of the allocated array")) + c = complex_or_real_reinterpret(rr) + r = view(rr, Base.OneTo(d), ntuple(i->Colon(),Val(N-1))...) + return new{T, N, typeof(rr), typeof(c), N===1, N-1}(rr,r,c) + end # function + + function PaddedRFFTArray{T}(c::AbstractArray{Complex{T},N},d::Int) where {T<:fftwReal,N} + rr = complex_or_real_reinterpret(c) + fsize = size(rr)[1] + (d == fsize-2 || d == fsize-1) || throw( + ArgumentError("Given first dimension of real array d=$d incompatible with size of complex array $(size(c)). Valid values are d=$(fsize-2) or d=$(fsize-1)")) + r = view(rr, Base.OneTo(d), ntuple(i->Colon(),Val(N-1))...) + return new{T, N, typeof(rr), typeof(c), N===1, N-1}(rr,r,c) + end # function + +end # struct + +PaddedRFFTArray(a::AbstractArray{<:Union{T,Complex{T}},N},d::Int) where {T<:fftwReal,N} = + PaddedRFFTArray{T}(a,d) + +function PaddedRFFTArray{T}(ndims::Vararg{Integer}) where {T} + fsize = (ndims[1]÷2 + 1)*2 + a = zeros(T,(fsize, ndims[2:end]...)) + PaddedRFFTArray{T}(a, ndims[1]) +end + +PaddedRFFTArray{T}(ndims::NTuple{N,Integer}) where {T,N} = + PaddedRFFTArray{T}(ndims...) + +PaddedRFFTArray(ndims::Vararg{Integer,N}) where N = + PaddedRFFTArray{Float64}(ndims...) + +PaddedRFFTArray(ndims::NTuple{N,Integer}) where N = + PaddedRFFTArray{Float64}(ndims...) + +function PaddedRFFTArray(a::AbstractArray{T,N}) where {T<:fftwReal,N} + t = PaddedRFFTArray{T}(size(a)) + @inbounds copyto!(t.r, a) + return t +end + +copy(S::PaddedRFFTArray) = PaddedRFFTArray(copy(S.data),size(S.r,1)) + +similar(f::PaddedRFFTArray,::Type{T},dims::Tuple{Vararg{Int,N}}) where {T, N} = + PaddedRFFTArray{T}(dims) + +similar(f::PaddedRFFTArray{T,N,L},dims::NTuple{N2,Int}) where {T,N,L,N2} = + PaddedRFFTArray{T}(dims) + +similar(f::PaddedRFFTArray,::Type{T}) where {T} = + PaddedRFFTArray{T}(size(f.r)) + +similar(f::PaddedRFFTArray{T}) where {T} = + PaddedRFFTArray{T}(similar(f.data), size(f.r,1)) + +size(S::PaddedRFFTArray) = + size(S.c) + +IndexStyle(::Type{T}) where {T<:PaddedRFFTArray} = + IndexLinear() + +Base.@propagate_inbounds getindex(A::PaddedRFFTArray,i::Integer) = + getindex(A.c,i) + +Base.@propagate_inbounds setindex!(A::PaddedRFFTArray,x, i::Integer) = + setindex!(A.c,x,i) + +Base.unsafe_convert(p::Type{Ptr{Complex{T}}}, a::PaddedRFFTArray{T,N}) where {T,N} = Base.unsafe_convert(p,a.c) + +Base.elsize(::Type{<:PaddedRFFTArray{T,N}}) where {T,N} = sizeof(Complex{T}) + + +function PaddedRFFTArray(stream::IO, dims) + field = PaddedRFFTArray(dims) + return read!(stream,field) +end + +function PaddedRFFTArray{T}(stream::IO, dims) where T + field = PaddedRFFTArray{T}(dims) + return read!(stream,field) +end + +function read!(file::AbstractString, field::PaddedRFFTArray) + open(file) do io + return read!(io,field) + end +end + +# Read a binary file of an unpaded array directly to a PaddedRFFT array, without the need +# of the creation of a intermediary Array. If the data is already padded then the user +# should just use PaddedRFFTArray{T}(read("file",unpaddeddim),d) +function read!(stream::IO, field::PaddedRFFTArray{T,N,L}) where {T,N,L} + rr = field.data + dims = size(field.r) + d = dims[1] + nb = sizeof(T)*d + npencils = prod(dims)÷d + npad = iseven(d) ? 2 : 1 + for i=0:(npencils-1) + unsafe_read(stream,Ref(rr,Int((d+npad)*i+1)),nb) + end + return field +end + + +########################################################################################### +# Foward plans + +function plan_rfft!(X::PaddedRFFTArray{T,N}, region; + flags::Integer=ESTIMATE, + timelimit::Real=NO_TIMELIMIT) where {T<:fftwReal,N} + + (1 in region) || throw(ArgumentError("The first dimension must always be transformed")) + return rFFTWPlan{T,FORWARD,true,N}(X.r, X.c, region, flags, timelimit) +end + +plan_rfft!(f::PaddedRFFTArray;kws...) = plan_rfft!(f, 1:ndims(f); kws...) + +*(p::rFFTWPlan{T,FORWARD,true,N},f::PaddedRFFTArray{T,N}) where {T<:fftwReal,N} = + (mul!(f.c, p, f.r); f) + +rfft!(f::PaddedRFFTArray, region=1:ndims(f)) = plan_rfft!(f, region) * f + +function rfft!(r::SubArray{<:fftwReal}, region=1:ndims(r)) + f = PaddedRFFTArray(parent(r),size(r,1)) + plan_rfft!(f, region) * f +end + +function \(p::rFFTWPlan{T,FORWARD,true,N},f::PaddedRFFTArray{T,N}) where {T<:fftwReal,N} + isdefined(p,:pinv) || (p.pinv = plan_irfft!(f,p.region)) + return p.pinv * f +end + + +########################################################################################## +# Inverse plans + +function plan_brfft!(X::PaddedRFFTArray{T,N}, region; + flags::Integer=ESTIMATE, + timelimit::Real=NO_TIMELIMIT) where {T<:fftwReal,N} + (1 in region) || throw(ArgumentError("The first dimension must always be transformed")) + return rFFTWPlan{Complex{T},BACKWARD,true,N}(X.c, X.r, region, flags,timelimit) +end + +plan_brfft!(f::PaddedRFFTArray;kws...) = plan_brfft!(f,1:ndims(f);kws...) + +*(p::rFFTWPlan{Complex{T},BACKWARD,true,N},f::PaddedRFFTArray{T,N}) where {T<:fftwReal,N} = + (mul!(f.r, p, f.c); f.r) + +brfft!(f::PaddedRFFTArray, region=1:ndims(f)) = plan_brfft!(f, region) * f + +function brfft!(f::PaddedRFFTArray, i::Integer) + if i == size(f.r,1) # Assume `i` is the same as `d` in the brfft!(c::DenseArray{<:fftComplex}, d::Integer, region) defined below + return brfft!(f,1:ndims(f)) + else # Assume `i` is specifying the region. `plan_brfft!` will throw an error if i != 1 + return brfft!(f,(i,)) + end +end + +function brfft!(c::AbstractArray{<:fftwComplex}, d::Integer, region=1:ndims(c)) + f = PaddedRFFTArray(c,d) + plan_brfft!(f, region) * f +end + +function plan_irfft!(x::PaddedRFFTArray{T,N}, region; kws...) where {T,N} + ScaledPlan(plan_brfft!(x, region; kws...),normalization(T, size(x.r), region)) +end + +plan_irfft!(f::PaddedRFFTArray;kws...) = plan_irfft!(f,1:ndims(f);kws...) + +*(p::ScaledPlan{Complex{T},rFFTWPlan{Complex{T},BACKWARD,true,N}},f::PaddedRFFTArray{T,N}) where {T,N} = begin + p.p * f + rmul!(f.data, p.scale) + f.r +end + +irfft!(f::PaddedRFFTArray, region=1:ndims(f)) = plan_irfft!(f,region) * f + +function irfft!(f::PaddedRFFTArray, i::Integer) + if i == size(f.r,1) # Assume `i` is the same as `d` in the irfft!(c::AbstractArray{<:fftComplex}, d::Integer, region) defined below + return irfft!(f,1:ndims(f)) + else # Assume `i` is specifying the region. `plan_irfft!` will throw an error if i != 1 + return irfft!(f,(i,)) + end +end + +function irfft!(c::AbstractArray{<:fftwComplex}, d::Integer, region=1:ndims(c)) + f = PaddedRFFTArray(c,d) + plan_irfft!(f, region) * f +end diff --git a/test/runtests.jl b/test/runtests.jl index 2291003..1ca26ca 100644 --- a/test/runtests.jl +++ b/test/runtests.jl @@ -528,3 +528,77 @@ end @test occursin("dft-thr", string(p2)) end end + +let a = rand(Float64,(8,4,4)), b = PaddedRFFTArray(a), c = copy(b) + + @testset "PaddedRFFTArray creation" begin + @test a == b.r + @test c == b + @test c.r == b.r + @test typeof(similar(b)) === typeof(b) + @test size(similar(b,Float32)) === size(b) + @test size(similar(b,Float32).r) === size(b.r) + @test size(similar(b,(4,4,4)).r) === (4,4,4) + @test size(similar(b,Float32,(4,4,4)).r) === (4,4,4) + end + + @testset "rfft! and irfft!" begin + @test rfft(a) ≈ rfft!(b) + @test a ≈ irfft!(b) + @test rfft(a,1:2) ≈ rfft!(b,1:2) + @test a ≈ irfft!(b,1:2) + @test rfft(a,(1,3)) ≈ rfft!(b,(1,3)) + @test a ≈ irfft!(b,(1,3)) + + p = plan_rfft!(c) + @test p*c ≈ rfft!(b) + @test p\c ≈ irfft!(b) + + aa = rand(Float64,(9,4,4)) + bb = PaddedRFFTArray(aa) + @test aa == bb.r + @test rfft(aa) ≈ rfft!(bb) + @test aa ≈ irfft!(bb) + @test rfft(aa,1:2) ≈ rfft!(bb,1:2) + @test aa ≈ irfft!(bb,1:2) + @test rfft(aa,(1,3)) ≈ rfft!(bb,(1,3)) + @test aa ≈ irfft!(bb,(1,3)) + end + + @testset "Read binary file to PaddedRFFTArray" begin + for s in ((8,4,4),(9,4,4),(8,),(9,)) + aa = rand(Float64,s) + f = IOBuffer() + write(f,aa) + @test aa == (PaddedRFFTArray(seekstart(f),s)).r + aa = rand(Float32,s) + f = IOBuffer() + write(f,aa) + @test aa == PaddedRFFTArray{Float32}(seekstart(f),s).r + end + end + + @testset "brfft!" begin + a = rand(Float64,(4,4)) + b = PaddedRFFTArray(a) + rfft!(b) + @test (brfft!(b) ./ 16) ≈ a + end + + @testset "FFTW MEASURE flag" begin + c = similar(b) + p = plan_rfft!(c,flags=FFTW.MEASURE) + p.pinv = plan_irfft!(c,flags=FFTW.MEASURE) + c .= b + @test c == b + @test p*c ≈ rfft!(b) + @test p\c ≈ irfft!(b) + end + + @testset "irfft! and brfft! of complex Array and rfft! of SubArray" begin + r = rand(8,6) + @test brfft!(rfft!(irfft!(rfft(r),8)),8)./48 ≈ r + r2 = rand(Float32,9,3,2) + @test brfft!(rfft!(irfft!(rfft(r2,(1,3)),9,(1,3))),9)./54 ≈ r2 + end +end #let block