diff --git a/ext/AbstractFFTsChainRulesCoreExt.jl b/ext/AbstractFFTsChainRulesCoreExt.jl index 5ab5d2e..0741584 100644 --- a/ext/AbstractFFTsChainRulesCoreExt.jl +++ b/ext/AbstractFFTsChainRulesCoreExt.jl @@ -30,16 +30,20 @@ function ChainRulesCore.rrule(::typeof(rfft), x::AbstractArray{<:Real}, dims) halfdim = first(dims) d = size(x, halfdim) n = size(y, halfdim) - scale = reshape( - [i == 1 || (i == n && 2 * (i - 1) == d) ? 1 : 2 for i in 1:n], - ntuple(i -> i == first(dims) ? n : 1, Val(ndims(x))), - ) project_x = ChainRulesCore.ProjectTo(x) function rfft_pullback(ȳ) ybar = ChainRulesCore.unthunk(ȳ) - _scale = convert(typeof(ybar),scale) - x̄ = project_x(brfft(ybar ./ _scale, d, dims)) + ybar_scaled = map(ybar, CartesianIndices(ybar)) do ybar_j, j + i = j[halfdim] + ybar_scaled_j = if i == 1 || (i == n && 2 * (i - 1) == d) + ybar_j + else + ybar_j / 2 + end + return ybar_scaled_j + end + x̄ = project_x(brfft(ybar_scaled, d, dims)) return ChainRulesCore.NoTangent(), x̄, ChainRulesCore.NoTangent() end return y, rfft_pullback @@ -74,16 +78,20 @@ function ChainRulesCore.rrule(::typeof(irfft), x::AbstractArray, d::Int, dims) n = size(x, halfdim) invN = AbstractFFTs.normalization(y, dims) twoinvN = 2 * invN - scale = reshape( - [i == 1 || (i == n && 2 * (i - 1) == d) ? invN : twoinvN for i in 1:n], - ntuple(i -> i == first(dims) ? n : 1, Val(ndims(x))), - ) project_x = ChainRulesCore.ProjectTo(x) function irfft_pullback(ȳ) ybar = ChainRulesCore.unthunk(ȳ) - _scale = convert(typeof(ybar),scale) - x̄ = project_x(_scale .* rfft(real.(ybar), dims)) + x̄_scaled = rfft(real.(ybar), dims) + x̄ = project_x(map(x̄_scaled, CartesianIndices(x̄_scaled)) do x̄_scaled_j, j + i = j[halfdim] + x̄_j = if i == 1 || (i == n && 2 * (i - 1) == d) + invN * x̄_scaled_j + else + twoinvN * x̄_scaled_j + end + return x̄_j + end) return ChainRulesCore.NoTangent(), x̄, ChainRulesCore.NoTangent(), ChainRulesCore.NoTangent() end return y, irfft_pullback @@ -115,14 +123,19 @@ function ChainRulesCore.rrule(::typeof(brfft), x::AbstractArray, d::Int, dims) # compute scaling factors halfdim = first(dims) n = size(x, halfdim) - scale = reshape( - [i == 1 || (i == n && 2 * (i - 1) == d) ? 1 : 2 for i in 1:n], - ntuple(i -> i == first(dims) ? n : 1, Val(ndims(x))), - ) project_x = ChainRulesCore.ProjectTo(x) function brfft_pullback(ȳ) - x̄ = project_x(scale .* rfft(real.(ChainRulesCore.unthunk(ȳ)), dims)) + x̄_scaled = rfft(real.(ChainRulesCore.unthunk(ȳ)), dims) + x̄ = project_x(map(x̄_scaled, CartesianIndices(x̄_scaled)) do x̄_scaled_j, j + i = j[halfdim] + x̄_j = if i == 1 || (i == n && 2 * (i - 1) == d) + x̄_scaled_j + else + 2 * x̄_scaled_j + end + return x̄_j + end) return ChainRulesCore.NoTangent(), x̄, ChainRulesCore.NoTangent(), ChainRulesCore.NoTangent() end return y, brfft_pullback diff --git a/ext/AbstractFFTsTestExt.jl b/ext/AbstractFFTsTestExt.jl index ccea93a..a3f427d 100644 --- a/ext/AbstractFFTsTestExt.jl +++ b/ext/AbstractFFTsTestExt.jl @@ -60,6 +60,7 @@ function TestUtils.test_plan(P::AbstractFFTs.Plan, x::AbstractArray, x_transform _x_out = similar(P * _copy(x)) @test mul!(_x_out, P, _copy(x)) ≈ x_transformed @test _x_out ≈ x_transformed + @test P * view(_copy(x), axes(x)...) ≈ x_transformed # test view input else _x = copy(x) @test P * _copy(_x) ≈ x_transformed @@ -85,6 +86,7 @@ function TestUtils.test_plan_adjoint(P::AbstractFFTs.Plan, x::AbstractArray; rea @test _component_dot(y, P * _copy(x)) ≈ _component_dot(P' * _copy(y), x) @test _component_dot(x, P \ _copy(y)) ≈ _component_dot(P' \ _copy(x), y) end + @test P' * view(_copy(y), axes(y)...) ≈ P' * _copy(y) # test view input (AbstractFFTs.jl#112) @test_throws MethodError mul!(x, P', y) end diff --git a/src/definitions.jl b/src/definitions.jl index 5dc703f..943718f 100644 --- a/src/definitions.jl +++ b/src/definitions.jl @@ -702,11 +702,16 @@ function adjoint_mul(p::Plan{T}, x::AbstractArray, ::RFFTAdjointStyle) where {T< halfdim = first(dims) d = size(p, halfdim) n = output_size(p, halfdim) - scale = reshape( - [(i == 1 || (i == n && 2 * (i - 1)) == d) ? N : 2 * N for i in 1:n], - ntuple(i -> i == halfdim ? n : 1, Val(ndims(x))) - ) - return p \ (x ./ convert(typeof(x), scale)) + y = map(x, CartesianIndices(x)) do xj, j + i = j[halfdim] + yj = if i == 1 || (i == n && 2 * (i - 1) == d) + xj / N + else + xj / (2 * N) + end + return yj + end + return p \ y end function adjoint_mul(p::Plan{T}, x::AbstractArray, ::IRFFTAdjointStyle) where {T} @@ -715,11 +720,17 @@ function adjoint_mul(p::Plan{T}, x::AbstractArray, ::IRFFTAdjointStyle) where {T halfdim = first(dims) n = size(p, halfdim) d = output_size(p, halfdim) - scale = reshape( - [(i == 1 || (i == n && 2 * (i - 1)) == d) ? 1 : 2 for i in 1:n], - ntuple(i -> i == halfdim ? n : 1, Val(ndims(x))) - ) - return (convert(typeof(x), scale) ./ N) .* (p \ x) + y = p \ x + z = map(y, CartesianIndices(y)) do yj, j + i = j[halfdim] + zj = if i == 1 || (i == n && 2 * (i - 1) == d) + yj / N + else + 2 * yj / N + end + return zj + end + return z end adjoint_mul(p::Plan, x::AbstractArray, ::UnitaryAdjointStyle) = p \ x