Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

chain rules for DCT #273

Open
wants to merge 16 commits into
base: master
Choose a base branch
from
11 changes: 11 additions & 0 deletions Project.toml
Original file line number Diff line number Diff line change
Expand Up @@ -9,11 +9,22 @@ LinearAlgebra = "37e2e46d-f89d-539d-b4ee-838fcccc9c8e"
MKL_jll = "856f044c-d86e-5d09-b602-aeab76dc8ba7"
Preferences = "21216c6a-2e73-6563-6e65-726566657250"
Reexport = "189a3867-3050-52da-a836-e630ba90ab69"
ChainRulesCore = "d360d2e6-b24c-11e9-a2a3-2a2ae2dbcce4"

[extensions]
FFTWChainRulesCoreExt = "ChainRulesCore"

[compat]
ChainRulesCore = "1"
AbstractFFTs = "1.5"
FFTW_jll = "3.3.9"
MKL_jll = "2019.0.117, 2020, 2021, 2022, 2023"
Preferences = "1.2"
Reexport = "0.2, 1.0"
julia = "1.6"

[extras]
ChainRulesCore = "d360d2e6-b24c-11e9-a2a3-2a2ae2dbcce4"

[weakdeps]
ChainRulesCore = "d360d2e6-b24c-11e9-a2a3-2a2ae2dbcce4"
42 changes: 42 additions & 0 deletions ext/FFTWChainRulesCoreExt.jl
Original file line number Diff line number Diff line change
@@ -0,0 +1,42 @@
module FFTWChainRulesCoreExt

using FFTW
using FFTW: r2r
using ChainRulesCore

# DCT/IDCT

for (fwd, bwd) in (
(dct, idct),
(idct, dct),
)
function ChainRulesCore.frule(Δ, ::typeof(fwd), x::AbstractArray, region = 1:ndims(x))
Δx = Δ[2]
y = fwd(x, region)
Δy = fwd(Δx, region)
return y, Δy

Check warning on line 17 in ext/FFTWChainRulesCoreExt.jl

View check run for this annotation

Codecov / codecov/patch

ext/FFTWChainRulesCoreExt.jl#L13-L17

Added lines #L13 - L17 were not covered by tests
end

function ChainRulesCore.rrule(::typeof(fwd), x::AbstractArray)
project_x = ProjectTo(x)
dct_pb(Δ) = NoTangent(), project_x(bwd(unthunk(Δ)))
return fwd(x), dct_pb

Check warning on line 23 in ext/FFTWChainRulesCoreExt.jl

View check run for this annotation

Codecov / codecov/patch

ext/FFTWChainRulesCoreExt.jl#L20-L23

Added lines #L20 - L23 were not covered by tests
end

function ChainRulesCore.rrule(::typeof(fwd), x::AbstractArray, region)
project_x = ProjectTo(x)
dct_pb(Δ) = NoTangent(), project_x(bwd(unthunk(Δ), region)), NoTangent()
return fwd(x, region), dct_pb

Check warning on line 29 in ext/FFTWChainRulesCoreExt.jl

View check run for this annotation

Codecov / codecov/patch

ext/FFTWChainRulesCoreExt.jl#L26-L29

Added lines #L26 - L29 were not covered by tests
end
end

# R2R

function ChainRulesCore.frule(Δ, ::typeof(r2r), x::AbstractArray, kind, region = 1:ndims(x))
Δx = Δ[2]
y = r2r(x, kind, region)
Δy = r2r(Δx, kind, region)
return y, Δy

Check warning on line 39 in ext/FFTWChainRulesCoreExt.jl

View check run for this annotation

Codecov / codecov/patch

ext/FFTWChainRulesCoreExt.jl#L35-L39

Added lines #L35 - L39 were not covered by tests
end

end # module
4 changes: 4 additions & 0 deletions src/FFTW.jl
Original file line number Diff line number Diff line change
Expand Up @@ -72,4 +72,8 @@ include("dct.jl")
include("precompile.jl")
_precompile_()

@static if !isdefined(Base, :get_extension)
include("../ext/FFTWChainRulesCoreExt.jl")
end

end # module
6 changes: 2 additions & 4 deletions test/Project.toml
Original file line number Diff line number Diff line change
@@ -1,8 +1,6 @@
# A bug in Julia 1.6.0's Pkg causes Preferences to be dropped during `Pkg.test()`, so we work around
# it by explicitly creating a `test/Project.toml` which will correctly communicate any preferences
# through to the child Julia process. X-ref: https://github.com/JuliaLang/Pkg.jl/issues/2500

[deps]
AbstractFFTs = "621f4979-c628-5d54-868e-fcf4e3e8185c"
ChainRulesCore = "d360d2e6-b24c-11e9-a2a3-2a2ae2dbcce4"
ChainRulesTestUtils = "cdddcdb0-9152-4a09-a978-84456f9df70a"
LinearAlgebra = "37e2e46d-f89d-539d-b4ee-838fcccc9c8e"
Test = "8dfed614-e22c-5e08-85e1-65c5234f0b40"
2 changes: 1 addition & 1 deletion test/runtests.jl
Original file line number Diff line number Diff line change
@@ -1,6 +1,6 @@
# This file was formerly a part of Julia. License is MIT: https://julialang.org/license
using FFTW
using FFTW: fftw_provider
using FFTW: fftw_provider, r2r
using AbstractFFTs: Plan, plan_inv
using Test
using LinearAlgebra
Expand Down
Loading