Skip to content

Commit

Permalink
Remove arrays of scaling factors (#116)
Browse files Browse the repository at this point in the history
* Remove arrays of scaling factors

* Fix typo

* Test view inputs in interface tests (#117)

---------

Co-authored-by: Gaurav Arya <[email protected]>
  • Loading branch information
devmotion and gaurav-arya authored Aug 22, 2023
1 parent ee9f1b8 commit fae1170
Show file tree
Hide file tree
Showing 3 changed files with 53 additions and 27 deletions.
47 changes: 30 additions & 17 deletions ext/AbstractFFTsChainRulesCoreExt.jl
Original file line number Diff line number Diff line change
Expand Up @@ -30,16 +30,20 @@ function ChainRulesCore.rrule(::typeof(rfft), x::AbstractArray{<:Real}, dims)
halfdim = first(dims)
d = size(x, halfdim)
n = size(y, halfdim)
scale = reshape(
[i == 1 || (i == n && 2 * (i - 1) == d) ? 1 : 2 for i in 1:n],
ntuple(i -> i == first(dims) ? n : 1, Val(ndims(x))),
)

project_x = ChainRulesCore.ProjectTo(x)
function rfft_pullback(ȳ)
ybar = ChainRulesCore.unthunk(ȳ)
_scale = convert(typeof(ybar),scale)
= project_x(brfft(ybar ./ _scale, d, dims))
ybar_scaled = map(ybar, CartesianIndices(ybar)) do ybar_j, j
i = j[halfdim]
ybar_scaled_j = if i == 1 || (i == n && 2 * (i - 1) == d)
ybar_j
else
ybar_j / 2
end
return ybar_scaled_j
end
= project_x(brfft(ybar_scaled, d, dims))
return ChainRulesCore.NoTangent(), x̄, ChainRulesCore.NoTangent()
end
return y, rfft_pullback
Expand Down Expand Up @@ -74,16 +78,20 @@ function ChainRulesCore.rrule(::typeof(irfft), x::AbstractArray, d::Int, dims)
n = size(x, halfdim)
invN = AbstractFFTs.normalization(y, dims)
twoinvN = 2 * invN
scale = reshape(
[i == 1 || (i == n && 2 * (i - 1) == d) ? invN : twoinvN for i in 1:n],
ntuple(i -> i == first(dims) ? n : 1, Val(ndims(x))),
)

project_x = ChainRulesCore.ProjectTo(x)
function irfft_pullback(ȳ)
ybar = ChainRulesCore.unthunk(ȳ)
_scale = convert(typeof(ybar),scale)
= project_x(_scale .* rfft(real.(ybar), dims))
x̄_scaled = rfft(real.(ybar), dims)
= project_x(map(x̄_scaled, CartesianIndices(x̄_scaled)) do x̄_scaled_j, j
i = j[halfdim]
x̄_j = if i == 1 || (i == n && 2 * (i - 1) == d)
invN * x̄_scaled_j
else
twoinvN * x̄_scaled_j
end
return x̄_j
end)
return ChainRulesCore.NoTangent(), x̄, ChainRulesCore.NoTangent(), ChainRulesCore.NoTangent()
end
return y, irfft_pullback
Expand Down Expand Up @@ -115,14 +123,19 @@ function ChainRulesCore.rrule(::typeof(brfft), x::AbstractArray, d::Int, dims)
# compute scaling factors
halfdim = first(dims)
n = size(x, halfdim)
scale = reshape(
[i == 1 || (i == n && 2 * (i - 1) == d) ? 1 : 2 for i in 1:n],
ntuple(i -> i == first(dims) ? n : 1, Val(ndims(x))),
)

project_x = ChainRulesCore.ProjectTo(x)
function brfft_pullback(ȳ)
= project_x(scale .* rfft(real.(ChainRulesCore.unthunk(ȳ)), dims))
x̄_scaled = rfft(real.(ChainRulesCore.unthunk(ȳ)), dims)
= project_x(map(x̄_scaled, CartesianIndices(x̄_scaled)) do x̄_scaled_j, j
i = j[halfdim]
x̄_j = if i == 1 || (i == n && 2 * (i - 1) == d)
x̄_scaled_j
else
2 * x̄_scaled_j
end
return x̄_j
end)
return ChainRulesCore.NoTangent(), x̄, ChainRulesCore.NoTangent(), ChainRulesCore.NoTangent()
end
return y, brfft_pullback
Expand Down
2 changes: 2 additions & 0 deletions ext/AbstractFFTsTestExt.jl
Original file line number Diff line number Diff line change
Expand Up @@ -60,6 +60,7 @@ function TestUtils.test_plan(P::AbstractFFTs.Plan, x::AbstractArray, x_transform
_x_out = similar(P * _copy(x))
@test mul!(_x_out, P, _copy(x)) x_transformed
@test _x_out x_transformed
@test P * view(_copy(x), axes(x)...) x_transformed # test view input
else
_x = copy(x)
@test P * _copy(_x) x_transformed
Expand All @@ -85,6 +86,7 @@ function TestUtils.test_plan_adjoint(P::AbstractFFTs.Plan, x::AbstractArray; rea
@test _component_dot(y, P * _copy(x)) _component_dot(P' * _copy(y), x)
@test _component_dot(x, P \ _copy(y)) _component_dot(P' \ _copy(x), y)
end
@test P' * view(_copy(y), axes(y)...) P' * _copy(y) # test view input (AbstractFFTs.jl#112)
@test_throws MethodError mul!(x, P', y)
end

Expand Down
31 changes: 21 additions & 10 deletions src/definitions.jl
Original file line number Diff line number Diff line change
Expand Up @@ -702,11 +702,16 @@ function adjoint_mul(p::Plan{T}, x::AbstractArray, ::RFFTAdjointStyle) where {T<
halfdim = first(dims)
d = size(p, halfdim)
n = output_size(p, halfdim)
scale = reshape(
[(i == 1 || (i == n && 2 * (i - 1)) == d) ? N : 2 * N for i in 1:n],
ntuple(i -> i == halfdim ? n : 1, Val(ndims(x)))
)
return p \ (x ./ convert(typeof(x), scale))
y = map(x, CartesianIndices(x)) do xj, j
i = j[halfdim]
yj = if i == 1 || (i == n && 2 * (i - 1) == d)
xj / N
else
xj / (2 * N)
end
return yj
end
return p \ y
end

function adjoint_mul(p::Plan{T}, x::AbstractArray, ::IRFFTAdjointStyle) where {T}
Expand All @@ -715,11 +720,17 @@ function adjoint_mul(p::Plan{T}, x::AbstractArray, ::IRFFTAdjointStyle) where {T
halfdim = first(dims)
n = size(p, halfdim)
d = output_size(p, halfdim)
scale = reshape(
[(i == 1 || (i == n && 2 * (i - 1)) == d) ? 1 : 2 for i in 1:n],
ntuple(i -> i == halfdim ? n : 1, Val(ndims(x)))
)
return (convert(typeof(x), scale) ./ N) .* (p \ x)
y = p \ x
z = map(y, CartesianIndices(y)) do yj, j
i = j[halfdim]
zj = if i == 1 || (i == n && 2 * (i - 1) == d)
yj / N
else
2 * yj / N
end
return zj
end
return z
end

adjoint_mul(p::Plan, x::AbstractArray, ::UnitaryAdjointStyle) = p \ x
Expand Down

0 comments on commit fae1170

Please sign in to comment.