Skip to content

Commit

Permalink
Merge pull request #168 from lxvm/cache_quadgk
Browse files Browse the repository at this point in the history
Cache for QuadGKJL
  • Loading branch information
ChrisRackauckas authored Aug 12, 2023
2 parents 646abf1 + 28a7789 commit 6b46c42
Show file tree
Hide file tree
Showing 16 changed files with 415 additions and 258 deletions.
2 changes: 1 addition & 1 deletion Project.toml
Original file line number Diff line number Diff line change
Expand Up @@ -20,7 +20,7 @@ Distributions = "0.23, 0.24, 0.25"
ForwardDiff = "0.10"
HCubature = "1.4"
MonteCarloIntegration = "0.0.1, 0.0.2, 0.0.3"
QuadGK = "2.1"
QuadGK = "2.5"
Reexport = "0.2, 1.0"
Requires = "1"
SciMLBase = "1.70"
Expand Down
32 changes: 16 additions & 16 deletions docs/make.jl
Original file line number Diff line number Diff line change
Expand Up @@ -6,21 +6,21 @@ cp("./docs/Project.toml", "./docs/src/assets/Project.toml", force = true)
include("pages.jl")

makedocs(sitename = "Integrals.jl",
authors = "Chris Rackauckas",
modules = [Integrals, Integrals.SciMLBase],
clean = true, doctest = false,
strict = [
:doctest,
:linkcheck,
:parse_error,
:example_block,
# Other available options are
# :autodocs_block, :cross_references, :docs_block, :eval_block, :example_block, :footnote, :meta_block, :missing_docs, :setup_block
],
format = Documenter.HTML(analytics = "UA-90474609-3",
assets = ["assets/favicon.ico"],
canonical = "https://docs.sciml.ai/Integrals/stable/"),
pages = pages)
authors = "Chris Rackauckas",
modules = [Integrals, Integrals.SciMLBase],
clean = true, doctest = false,
strict = [
:doctest,
:linkcheck,
:parse_error,
:example_block,
# Other available options are
# :autodocs_block, :cross_references, :docs_block, :eval_block, :example_block, :footnote, :meta_block, :missing_docs, :setup_block
],
format = Documenter.HTML(analytics = "UA-90474609-3",
assets = ["assets/favicon.ico"],
canonical = "https://docs.sciml.ai/Integrals/stable/"),
pages = pages)

deploydocs(repo = "github.com/SciML/Integrals.jl.git";
push_preview = true)
push_preview = true)
6 changes: 3 additions & 3 deletions docs/pages.jl
Original file line number Diff line number Diff line change
@@ -1,8 +1,8 @@
pages = ["index.md",
"Tutorials" => Any["tutorials/numerical_integrals.md",
"tutorials/differentiating_integrals.md"],
"tutorials/differentiating_integrals.md"],
"Basics" => Any["basics/IntegralProblem.md",
"basics/solve.md",
"basics/FAQ.md"],
"basics/solve.md",
"basics/FAQ.md"],
"Solvers" => Any["solvers/IntegralSolvers.md"],
]
2 changes: 1 addition & 1 deletion docs/src/solvers/IntegralSolvers.md
Original file line number Diff line number Diff line change
Expand Up @@ -7,7 +7,7 @@ The following algorithms are available:
- `VEGAS`: Uses MonteCarloIntegration.jl. Requires `nout=1`. Works only for `>1`-dimensional integrations.
- `CubatureJLh`: h-Cubature from Cubature.jl. Requires `using IntegralsCubature`.
- `CubatureJLp`: p-Cubature from Cubature.jl. Requires `using IntegralsCubature`.
- `CubaVegas`: Vegas from Cuba.jl. Requires `using IntegralsCuba`, `nout=1`.
- `CubaVegas`: Vegas from Cuba.jl. Requires `using IntegralsCuba`, `nout=1`.
- `CubaSUAVE`: SUAVE from Cuba.jl. Requires `using IntegralsCuba`.
- `CubaDivonne`: Divonne from Cuba.jl. Requires `using IntegralsCuba`. Works only for `>1`-dimensional integrations.
- `CubaCuhre`: Cuhre from Cuba.jl. Requires `using IntegralsCuba`. Works only for `>1`-dimensional integrations.
Expand Down
52 changes: 52 additions & 0 deletions docs/src/tutorials/caching_interface.md
Original file line number Diff line number Diff line change
@@ -0,0 +1,52 @@
# Integrals with Caching Interface

Often, integral solvers allocate memory or reuse quadrature rules for solving different
problems. For example, if one is going to solve the same integral for several parameters

```julia
using Integrals

prob = IntegralProblem((x, p) -> sin(x * p), 0, 1, 14.0)
alg = QuadGKJL()

solve(prob, alg)

prob = remake(prob, p = 15.0)
solve(prob, alg)
```

then it would be more efficient to allocate the heap used by `quadgk` across several calls,
shown below by directly calling the library

```julia
using QuadGK
segbuf = QuadGK.alloc_segbuf()
quadgk(x -> sin(14x), 0, 1, segbuf = segbuf)
quadgk(x -> sin(15x), 0, 1, segbuf = segbuf)
```

Integrals.jl's caching interface automates this process to reuse resources if an algorithm
supports it and if the necessary types to build the cache can be inferred from `prob`. To do
this with Integrals.jl, you simply `init` a cache, `solve!`, replace `p`, and solve again.
This uses the [SciML `init` interface](https://docs.sciml.ai/SciMLBase/stable/interfaces/Init_Solve/#init-and-the-Iterator-Interface)

```@example cache1
using Integrals
prob = IntegralProblem((x, p) -> sin(x * p), 0, 1, 14.0)
alg = QuadGKJL()
cache = init(prob, alg)
sol1 = solve!(cache)
```

```@example cache1
cache.p = 15.0
sol2 = solve!(cache)
```

The caching interface is intended for updating `p`, `lb`, `ub`, `nout`, and `batch`.
Note that the types of these variables is not allowed to change.
If it is necessary to change the integrand `f` instead of defining a new
`IntegralProblem`, consider using
[FunctionWrappers.jl](https://github.com/yuyichao/FunctionWrappers.jl).
10 changes: 5 additions & 5 deletions ext/IntegralsFastGaussQuadratureExt.jl
Original file line number Diff line number Diff line change
Expand Up @@ -30,20 +30,20 @@ function composite_gauss_legendre(f, p, lb, ub, nodes, weights, subintervals)
end

function Integrals.__solvebp_call(prob::IntegralProblem, alg::Integrals.GaussLegendre{C},
sensealg, lb, ub, p;
reltol = nothing, abstol = nothing,
maxiters = nothing) where {C}
sensealg, lb, ub, p;
reltol = nothing, abstol = nothing,
maxiters = nothing) where {C}
if isinplace(prob) || lb isa AbstractArray || ub isa AbstractArray
error("GaussLegendre only accepts one-dimensional quadrature problems.")
end
@assert prob.batch == 0
@assert prob.nout == 1
if C
val = composite_gauss_legendre(prob.f, prob.p, lb, ub,
alg.nodes, alg.weights, alg.subintervals)
alg.nodes, alg.weights, alg.subintervals)
else
val = gauss_legendre(prob.f, prob.p, lb, ub,
alg.nodes, alg.weights)
alg.nodes, alg.weights)
end
err = nothing
SciMLBase.build_solution(prob, alg, val, err, retcode = ReturnCode.Success)
Expand Down
58 changes: 33 additions & 25 deletions ext/IntegralsForwardDiffExt.jl
Original file line number Diff line number Diff line change
Expand Up @@ -4,36 +4,36 @@ isdefined(Base, :get_extension) ? (using ForwardDiff) : (using ..ForwardDiff)
### Forward-Mode AD Intercepts

# Direct AD on solvers with QuadGK and HCubature
function Integrals.__solvebp(prob, alg::QuadGKJL, sensealg, lb, ub,
p::AbstractArray{<:ForwardDiff.Dual{T, V, P}, N};
kwargs...) where {T, V, P, N}
Integrals.__solvebp_call(prob, alg, sensealg, lb, ub, p; kwargs...)
function Integrals.__solvebp(cache, alg::QuadGKJL, sensealg, lb, ub,
p::AbstractArray{<:ForwardDiff.Dual{T, V, P}, N};
kwargs...) where {T, V, P, N}
Integrals.__solvebp_call(cache, alg, sensealg, lb, ub, p; kwargs...)
end

function Integrals.__solvebp(prob, alg::HCubatureJL, sensealg, lb, ub,
p::AbstractArray{<:ForwardDiff.Dual{T, V, P}, N};
kwargs...) where {T, V, P, N}
Integrals.__solvebp_call(prob, alg, sensealg, lb, ub, p; kwargs...)
function Integrals.__solvebp(cache, alg::HCubatureJL, sensealg, lb, ub,
p::AbstractArray{<:ForwardDiff.Dual{T, V, P}, N};
kwargs...) where {T, V, P, N}
Integrals.__solvebp_call(cache, alg, sensealg, lb, ub, p; kwargs...)
end

# Manually split for the pushforward
function Integrals.__solvebp(prob, alg, sensealg, lb, ub,
p::AbstractArray{<:ForwardDiff.Dual{T, V, P}, N};
kwargs...) where {T, V, P, N}
primal = Integrals.__solvebp_call(prob, alg, sensealg, lb, ub, ForwardDiff.value.(p);
kwargs...)
function Integrals.__solvebp(cache, alg, sensealg, lb, ub,
p::AbstractArray{<:ForwardDiff.Dual{T, V, P}, N};
kwargs...) where {T, V, P, N}
primal = Integrals.__solvebp_call(cache, alg, sensealg, lb, ub, ForwardDiff.value.(p);
kwargs...)

nout = prob.nout * P
nout = cache.nout * P

if isinplace(prob)
if isinplace(cache)
dfdp = function (out, x, p)
dualp = reinterpret(ForwardDiff.Dual{T, V, P}, p)
if prob.batch > 0
dx = similar(dualp, prob.nout, size(x, 2))
if cache.batch > 0
dx = similar(dualp, cache.nout, size(x, 2))
else
dx = similar(dualp, prob.nout)
dx = similar(dualp, cache.nout)
end
prob.f(dx, x, dualp)
cache.f(dx, x, dualp)

ys = reinterpret(ForwardDiff.Dual{T, V, P}, dx)
idx = 0
Expand All @@ -47,8 +47,8 @@ function Integrals.__solvebp(prob, alg, sensealg, lb, ub,
else
dfdp = function (x, p)
dualp = reinterpret(ForwardDiff.Dual{T, V, P}, p)
ys = prob.f(x, dualp)
if prob.batch > 0
ys = cache.f(x, dualp)
if cache.batch > 0
out = similar(p, V, nout, size(x, 2))
else
out = similar(p, V, nout)
Expand All @@ -64,12 +64,20 @@ function Integrals.__solvebp(prob, alg, sensealg, lb, ub,
return out
end
end

rawp = copy(reinterpret(V, p))

dp_prob = IntegralProblem(dfdp, lb, ub, rawp; nout = nout, batch = prob.batch,
kwargs...)
dual = Integrals.__solvebp_call(dp_prob, alg, sensealg, lb, ub, rawp; kwargs...)
res = similar(p, prob.nout)
prob = Integrals.build_problem(cache)
dp_prob = remake(prob, f = dfdp, nout = nout, p = rawp)
# the infinity transformation was already applied to f so we don't apply it to dfdp
dp_cache = init(dp_prob,
alg;
sensealg = sensealg,
do_inf_transformation = Val(false),
cache.kwargs...)
dual = Integrals.__solvebp_call(dp_cache, alg, sensealg, lb, ub, rawp; kwargs...)

res = similar(p, cache.nout)
partials = reinterpret(typeof(first(res).partials), dual.u)
for idx in eachindex(res)
res[idx] = ForwardDiff.Dual{T, V, P}(primal.u[idx], partials[idx])
Expand Down
43 changes: 26 additions & 17 deletions ext/IntegralsZygoteExt.jl
Original file line number Diff line number Diff line change
Expand Up @@ -11,19 +11,21 @@ else
end
ChainRulesCore.@non_differentiable Integrals.checkkwargs(kwargs...)

function ChainRulesCore.rrule(::typeof(Integrals.__solvebp), prob, alg, sensealg, lb, ub, p;
kwargs...)
out = Integrals.__solvebp_call(prob, alg, sensealg, lb, ub, p; kwargs...)
function ChainRulesCore.rrule(::typeof(Integrals.__solvebp), cache, alg, sensealg, lb, ub,
p;
kwargs...)
out = Integrals.__solvebp_call(cache, alg, sensealg, lb, ub, p; kwargs...)

function quadrature_adjoint(Δ)
y = typeof(Δ) <: Array{<:Number, 0} ? Δ[1] : Δ
if isinplace(prob)
dx = zeros(prob.nout)
_f = x -> prob.f(dx, x, p)
if isinplace(cache)
dx = zeros(cache.nout)
_f = x -> cache.f(dx, x, p)
if sensealg.vjp isa Integrals.ZygoteVJP
dfdp = function (dx, x, p)
_, back = Zygote.pullback(p) do p
_dx = Zygote.Buffer(x, prob.nout, size(x, 2))
prob.f(_dx, x, p)
_dx = Zygote.Buffer(x, cache.nout, size(x, 2))
cache.f(_dx, x, p)
copy(_dx)
end

Expand All @@ -38,11 +40,11 @@ function ChainRulesCore.rrule(::typeof(Integrals.__solvebp), prob, alg, sensealg
error("TODO")
end
else
_f = x -> prob.f(x, p)
_f = x -> cache.f(x, p)
if sensealg.vjp isa Integrals.ZygoteVJP
if prob.batch > 0
if cache.batch > 0
dfdp = function (x, p)
_, back = Zygote.pullback(p -> prob.f(x, p), p)
_, back = Zygote.pullback(p -> cache.f(x, p), p)

out = zeros(length(p), size(x, 2))
z = zeros(size(x, 2))
Expand All @@ -55,7 +57,7 @@ function ChainRulesCore.rrule(::typeof(Integrals.__solvebp), prob, alg, sensealg
end
else
dfdp = function (x, p)
_, back = Zygote.pullback(p -> prob.f(x, p), p)
_, back = Zygote.pullback(p -> cache.f(x, p), p)
back(y)[1]
end
end
Expand All @@ -65,12 +67,19 @@ function ChainRulesCore.rrule(::typeof(Integrals.__solvebp), prob, alg, sensealg
end
end

dp_prob = remake(prob, f = dfdp, lb = lb, ub = ub, p = p, nout = length(p))
prob = Integrals.build_problem(cache)
dp_prob = remake(prob, f = dfdp, nout = length(p))
# the infinity transformation was already applied to f so we don't apply it to dfdp
dp_cache = init(dp_prob,
alg;
sensealg = sensealg,
do_inf_transformation = Val(false),
cache.kwargs...)

if p isa Number
dp = Integrals.__solvebp_call(dp_prob, alg, sensealg, lb, ub, p; kwargs...)[1]
dp = Integrals.__solvebp_call(dp_cache, alg, sensealg, lb, ub, p; kwargs...)[1]
else
dp = Integrals.__solvebp_call(dp_prob, alg, sensealg, lb, ub, p; kwargs...).u
dp = Integrals.__solvebp_call(dp_cache, alg, sensealg, lb, ub, p; kwargs...).u
end

if lb isa Number
Expand All @@ -79,14 +88,14 @@ function ChainRulesCore.rrule(::typeof(Integrals.__solvebp), prob, alg, sensealg
return (NoTangent(), NoTangent(), NoTangent(), NoTangent(), dlb, dub, dp)
else
return (NoTangent(), NoTangent(), NoTangent(), NoTangent(), NoTangent(),
NoTangent(), dp)
NoTangent(), dp)
end
end
out, quadrature_adjoint
end

Zygote.@adjoint function Zygote.literal_getproperty(sol::SciMLBase.IntegralSolution,
::Val{:u})
::Val{:u})
sol.u, Δ -> (SciMLBase.build_solution(sol.prob, sol.alg, Δ, sol.resid),)
end
end
Loading

0 comments on commit 6b46c42

Please sign in to comment.