From e1d18769ffbe59675a856a0291b7019f5600715d Mon Sep 17 00:00:00 2001 From: lxvm Date: Sun, 28 May 2023 08:44:18 -0400 Subject: [PATCH 01/13] initial commit --- Project.toml | 1 + ext/IntegralsForwardDiffExt.jl | 40 ++++++++--------- ext/IntegralsZygoteExt.jl | 31 ++++++++------ src/Integrals.jl | 1 + src/common.jl | 78 ++++++++++++++++++++++++++++------ 5 files changed, 105 insertions(+), 46 deletions(-) diff --git a/Project.toml b/Project.toml index 21435c1f..013a3422 100644 --- a/Project.toml +++ b/Project.toml @@ -12,6 +12,7 @@ QuadGK = "1fd47b50-473d-5c70-9696-f719f8f3bcdc" Reexport = "189a3867-3050-52da-a836-e630ba90ab69" Requires = "ae029012-a4dd-5104-9daa-d747884805df" SciMLBase = "0bca4576-84f4-4d90-8ffe-ffa030f20462" +Setfield = "efcf1570-3423-57d1-acb7-fd33fddbac46" [compat] ChainRulesCore = "0.10.7, 1" diff --git a/ext/IntegralsForwardDiffExt.jl b/ext/IntegralsForwardDiffExt.jl index 2eff6541..8ab97743 100644 --- a/ext/IntegralsForwardDiffExt.jl +++ b/ext/IntegralsForwardDiffExt.jl @@ -1,39 +1,40 @@ module IntegralsForwardDiffExt using Integrals +using Integrals: set_f, set_p, build_problem isdefined(Base, :get_extension) ? (using ForwardDiff) : (using ..ForwardDiff) ### Forward-Mode AD Intercepts # Direct AD on solvers with QuadGK and HCubature -function Integrals.__solvebp(prob, alg::QuadGKJL, sensealg, lb, ub, +function Integrals.__solvebp(cache, alg::QuadGKJL, sensealg, lb, ub, p::AbstractArray{<:ForwardDiff.Dual{T, V, P}, N}; kwargs...) where {T, V, P, N} - Integrals.__solvebp_call(prob, alg, sensealg, lb, ub, p; kwargs...) + Integrals.__solvebp_call(cache, alg, sensealg, lb, ub, p; kwargs...) end -function Integrals.__solvebp(prob, alg::HCubatureJL, sensealg, lb, ub, +function Integrals.__solvebp(cache, alg::HCubatureJL, sensealg, lb, ub, p::AbstractArray{<:ForwardDiff.Dual{T, V, P}, N}; kwargs...) where {T, V, P, N} - Integrals.__solvebp_call(prob, alg, sensealg, lb, ub, p; kwargs...) + Integrals.__solvebp_call(cache, alg, sensealg, lb, ub, p; kwargs...) end # Manually split for the pushforward -function Integrals.__solvebp(prob, alg, sensealg, lb, ub, +function Integrals.__solvebp(cache, alg, sensealg, lb, ub, p::AbstractArray{<:ForwardDiff.Dual{T, V, P}, N}; kwargs...) where {T, V, P, N} - primal = Integrals.__solvebp_call(prob, alg, sensealg, lb, ub, ForwardDiff.value.(p); + primal = Integrals.__solvebp_call(cache, alg, sensealg, lb, ub, ForwardDiff.value.(p); kwargs...) - nout = prob.nout * P + nout = cache.nout * P - if isinplace(prob) + if isinplace(cache) dfdp = function (out, x, p) dualp = reinterpret(ForwardDiff.Dual{T, V, P}, p) - if prob.batch > 0 - dx = similar(dualp, prob.nout, size(x, 2)) + if cache.batch > 0 + dx = similar(dualp, cache.nout, size(x, 2)) else - dx = similar(dualp, prob.nout) + dx = similar(dualp, cache.nout) end - prob.f(dx, x, dualp) + cache.f(dx, x, dualp) ys = reinterpret(ForwardDiff.Dual{T, V, P}, dx) idx = 0 @@ -47,8 +48,8 @@ function Integrals.__solvebp(prob, alg, sensealg, lb, ub, else dfdp = function (x, p) dualp = reinterpret(ForwardDiff.Dual{T, V, P}, p) - ys = prob.f(x, dualp) - if prob.batch > 0 + ys = cache.f(x, dualp) + if cache.batch > 0 out = similar(p, V, nout, size(x, 2)) else out = similar(p, V, nout) @@ -64,12 +65,13 @@ function Integrals.__solvebp(prob, alg, sensealg, lb, ub, return out end end + rawp = copy(reinterpret(V, p)) - dp_prob = IntegralProblem(dfdp, lb, ub, rawp; nout = nout, batch = prob.batch, - kwargs...) - dual = Integrals.__solvebp_call(dp_prob, alg, sensealg, lb, ub, rawp; kwargs...) - res = similar(p, prob.nout) + dp_cache = set_p(set_f(cache, dfdp, nout), rawp) + dual = Integrals.__solvebp_call(dp_cache, alg, sensealg, lb, ub, rawp; kwargs...) + + res = similar(p, cache.nout) partials = reinterpret(typeof(first(res).partials), dual.u) for idx in eachindex(res) res[idx] = ForwardDiff.Dual{T, V, P}(primal.u[idx], partials[idx]) @@ -77,6 +79,6 @@ function Integrals.__solvebp(prob, alg, sensealg, lb, ub, if primal.u isa Number res = first(res) end - SciMLBase.build_solution(prob, alg, res, primal.resid) + SciMLBase.build_solution(build_problem(cache), alg, res, primal.resid) end end diff --git a/ext/IntegralsZygoteExt.jl b/ext/IntegralsZygoteExt.jl index bfa94947..56d2e4c5 100644 --- a/ext/IntegralsZygoteExt.jl +++ b/ext/IntegralsZygoteExt.jl @@ -1,5 +1,6 @@ module IntegralsZygoteExt using Integrals +using Integrals: set_f if isdefined(Base, :get_extension) using Zygote import ChainRulesCore @@ -11,19 +12,21 @@ else end ChainRulesCore.@non_differentiable Integrals.checkkwargs(kwargs...) -function ChainRulesCore.rrule(::typeof(Integrals.__solvebp), prob, alg, sensealg, lb, ub, p; +function ChainRulesCore.rrule(::typeof(Integrals.__solvebp), cache, alg, sensealg, lb, ub, p; kwargs...) - out = Integrals.__solvebp_call(prob, alg, sensealg, lb, ub, p; kwargs...) + + out = Integrals.__solvebp_call(cache, alg, sensealg, lb, ub, p; kwargs...) + function quadrature_adjoint(Δ) y = typeof(Δ) <: Array{<:Number, 0} ? Δ[1] : Δ - if isinplace(prob) - dx = zeros(prob.nout) - _f = x -> prob.f(dx, x, p) + if isinplace(cache) + dx = zeros(cache.nout) + _f = x -> cache.f(dx, x, p) if sensealg.vjp isa Integrals.ZygoteVJP dfdp = function (dx, x, p) _, back = Zygote.pullback(p) do p - _dx = Zygote.Buffer(x, prob.nout, size(x, 2)) - prob.f(_dx, x, p) + _dx = Zygote.Buffer(x, cache.nout, size(x, 2)) + cache.f(_dx, x, p) copy(_dx) end @@ -38,11 +41,11 @@ function ChainRulesCore.rrule(::typeof(Integrals.__solvebp), prob, alg, sensealg error("TODO") end else - _f = x -> prob.f(x, p) + _f = x -> cache.f(x, p) if sensealg.vjp isa Integrals.ZygoteVJP - if prob.batch > 0 + if cache.batch > 0 dfdp = function (x, p) - _, back = Zygote.pullback(p -> prob.f(x, p), p) + _, back = Zygote.pullback(p -> cache.f(x, p), p) out = zeros(length(p), size(x, 2)) z = zeros(size(x, 2)) @@ -55,7 +58,7 @@ function ChainRulesCore.rrule(::typeof(Integrals.__solvebp), prob, alg, sensealg end else dfdp = function (x, p) - _, back = Zygote.pullback(p -> prob.f(x, p), p) + _, back = Zygote.pullback(p -> cache.f(x, p), p) back(y)[1] end end @@ -65,12 +68,12 @@ function ChainRulesCore.rrule(::typeof(Integrals.__solvebp), prob, alg, sensealg end end - dp_prob = remake(prob, f = dfdp, lb = lb, ub = ub, p = p, nout = length(p)) + dp_cache = set_f(cache, dfdp, length(p)) if p isa Number - dp = Integrals.__solvebp_call(dp_prob, alg, sensealg, lb, ub, p; kwargs...)[1] + dp = Integrals.__solvebp_call(dp_cache, alg, sensealg, lb, ub, p; kwargs...)[1] else - dp = Integrals.__solvebp_call(dp_prob, alg, sensealg, lb, ub, p; kwargs...).u + dp = Integrals.__solvebp_call(dp_cache, alg, sensealg, lb, ub, p; kwargs...).u end if lb isa Number diff --git a/src/Integrals.jl b/src/Integrals.jl index 04552869..dd357595 100644 --- a/src/Integrals.jl +++ b/src/Integrals.jl @@ -7,6 +7,7 @@ end using Reexport, MonteCarloIntegration, QuadGK, HCubature @reexport using SciMLBase using LinearAlgebra +using Setfield include("common.jl") include("init.jl") diff --git a/src/common.jl b/src/common.jl index 8d3123fe..b4fa83fb 100644 --- a/src/common.jl +++ b/src/common.jl @@ -1,27 +1,70 @@ -struct IntegralCache{P, A, S, K, Tc} - prob::P +struct IntegralCache{iip,F,B,P,PK,A,S,K,Tc} + iip::Val{iip} + f::F + lb::B + ub::B + nout::Int + p::P + batch::Int + prob_kwargs::PK alg::A sensealg::S kwargs::K - # cache for algorithm goes here (currently unused) - cacheval::Tc - isfresh::Bool + cacheval::Tc # store alg cache here + isfresh::Bool # false => cacheval is set wrt f, true => update cacheval wrt f end -function SciMLBase.init(prob::IntegralProblem, +SciMLBase.isinplace(::IntegralCache{iip}) where iip = iip + +function set_f(cache::IntegralCache, f, nout=cache.nout) + @set! cache.f = f + @set! cache.iip = Val(isinplace(f, 3)) + @set! cache.nout = nout + @set! cache.isfresh = true + return cache +end + +function set_lb(cache::IntegralCache, lb) + @set! cache.lb = lb + return cache +end + +function set_ub(cache::IntegralCache, ub) + @set! cache.ub = ub + return cache +end + +function set_p(cache::IntegralCache, p) + @set! cache.p = p + return cache +end + +init_cacheval(::SciMLBase.AbstractIntegralAlgorithm, args...) = (nothing, true) + +function SciMLBase.init(prob::IntegralProblem{iip}, alg::SciMLBase.AbstractIntegralAlgorithm; sensealg = ReCallVJP(ZygoteVJP()), - do_inf_transformation = nothing, kwargs...) + do_inf_transformation = nothing, kwargs...) where iip checkkwargs(kwargs...) prob = transformation_if_inf(prob, do_inf_transformation) - cacheval = nothing - isfresh = true + cacheval, isfresh = init_cacheval(alg, prob) - IntegralCache{typeof(prob), + IntegralCache{iip, + typeof(prob.f), + typeof(prob.lb), + typeof(prob.p), + typeof(prob.kwargs), typeof(alg), typeof(sensealg), typeof(kwargs), - typeof(cacheval)}(prob, + typeof(cacheval)}(Val(iip), + prob.f, + prob.lb, + prob.ub, + prob.nout, + prob.p, + prob.batch, + prob.kwargs, alg, sensealg, kwargs, @@ -63,6 +106,15 @@ function SciMLBase.solve(prob::IntegralProblem, end function SciMLBase.solve!(cache::IntegralCache) - prob = cache.prob - __solvebp(prob, cache.alg, cache.sensealg, prob.lb, prob.ub, prob.p; cache.kwargs...) + __solvebp(cache, cache.alg, cache.sensealg, cache.lb, cache.ub, cache.p; cache.kwargs...) +end + +function build_problem(cache::IntegralCache{iip}) where iip + IntegralProblem{iip}(cache.f, cache.lb, cache.ub, cache.p; + nout = cache.nout, batch = cache.batch, cache.prob_kwargs...) +end + +# fallback method for existing algorithms which use no cache +function __solvebp_call(cache::IntegralCache, args...; kwargs...) + __solvebp_call(build_problem(cache), args...; kwargs...) end From 0f4ca5f5adf7e019cc6be544f5855892a92f7c7c Mon Sep 17 00:00:00 2001 From: lxvm Date: Sun, 28 May 2023 08:58:44 -0400 Subject: [PATCH 02/13] add quadgk cache --- src/Integrals.jl | 17 +++++++++++++++-- src/common.jl | 10 +++++++++- 2 files changed, 24 insertions(+), 3 deletions(-) diff --git a/src/Integrals.jl b/src/Integrals.jl index dd357595..cf427254 100644 --- a/src/Integrals.jl +++ b/src/Integrals.jl @@ -62,17 +62,30 @@ end # Give a layer to intercept with AD __solvebp(args...; kwargs...) = __solvebp_call(args...; kwargs...) -function __solvebp_call(prob::IntegralProblem, alg::QuadGKJL, sensealg, lb, ub, p; +function init_cacheval(alg::QuadGKJL, prob::IntegralProblem) + mid = (prob.lb + prob.ub) / 2 + DT = typeof(mid) + val = prob.f(mid, prob.p) # TODO: infer this type or let user pass it + RT = typeof(val) + NT = typeof(alg.norm(val)) + return (QuadGK.alloc_segbuf(DT, RT, NT), false) +end + +function __solvebp_call(cache::IntegralCache, alg::QuadGKJL, sensealg, lb, ub, p; reltol = 1e-8, abstol = 1e-8, maxiters = typemax(Int)) + prob = build_problem(cache) if isinplace(prob) || lb isa AbstractArray || ub isa AbstractArray error("QuadGKJL only accepts one-dimensional quadrature problems.") end @assert prob.batch == 0 @assert prob.nout == 1 + + cache.isfresh && throw(ArgumentError("cannot reset QuadGK cache")) + p = p f = x -> prob.f(x, p) - val, err = quadgk(f, lb, ub, + val, err = quadgk(f, lb, ub, segbuf=cache.cacheval, maxevals=maxiters, rtol = reltol, atol = abstol, order = alg.order, norm = alg.norm) SciMLBase.build_solution(prob, QuadGKJL(), val, err, retcode = ReturnCode.Success) end diff --git a/src/common.jl b/src/common.jl index b4fa83fb..64c0de25 100644 --- a/src/common.jl +++ b/src/common.jl @@ -20,7 +20,10 @@ function set_f(cache::IntegralCache, f, nout=cache.nout) @set! cache.f = f @set! cache.iip = Val(isinplace(f, 3)) @set! cache.nout = nout - @set! cache.isfresh = true + prob = build_problem(cache) + cacheval, isfresh = init_cacheval(cache.alg, prob) + @set! cache.cacheval = cacheval + @set! cache.isfresh = isfresh return cache end @@ -29,6 +32,7 @@ function set_lb(cache::IntegralCache, lb) return cache end +# since types of lb and ub are constrained, we do not need to refresh cache function set_ub(cache::IntegralCache, ub) @set! cache.ub = ub return cache @@ -36,6 +40,10 @@ end function set_p(cache::IntegralCache, p) @set! cache.p = p + prob = build_problem(cache) + cacheval, isfresh = init_cacheval(cache.alg, prob) + @set! cache.cacheval = cacheval + @set! cache.isfresh = isfresh return cache end From 9d0d7754092bbcaf94e6326ed604ad782a89951a Mon Sep 17 00:00:00 2001 From: lxvm Date: Sun, 28 May 2023 09:01:17 -0400 Subject: [PATCH 03/13] apply format --- docs/src/solvers/IntegralSolvers.md | 2 +- ext/IntegralsZygoteExt.jl | 4 ++-- src/Integrals.jl | 2 +- src/common.jl | 15 ++++++++------- 4 files changed, 12 insertions(+), 11 deletions(-) diff --git a/docs/src/solvers/IntegralSolvers.md b/docs/src/solvers/IntegralSolvers.md index e523f17b..c650d3b2 100644 --- a/docs/src/solvers/IntegralSolvers.md +++ b/docs/src/solvers/IntegralSolvers.md @@ -7,7 +7,7 @@ The following algorithms are available: - `VEGAS`: Uses MonteCarloIntegration.jl. Requires `nout=1`. Works only for `>1`-dimensional integrations. - `CubatureJLh`: h-Cubature from Cubature.jl. Requires `using IntegralsCubature`. - `CubatureJLp`: p-Cubature from Cubature.jl. Requires `using IntegralsCubature`. - - `CubaVegas`: Vegas from Cuba.jl. Requires `using IntegralsCuba`, `nout=1`. + - `CubaVegas`: Vegas from Cuba.jl. Requires `using IntegralsCuba`, `nout=1`. - `CubaSUAVE`: SUAVE from Cuba.jl. Requires `using IntegralsCuba`. - `CubaDivonne`: Divonne from Cuba.jl. Requires `using IntegralsCuba`. Works only for `>1`-dimensional integrations. - `CubaCuhre`: Cuhre from Cuba.jl. Requires `using IntegralsCuba`. Works only for `>1`-dimensional integrations. diff --git a/ext/IntegralsZygoteExt.jl b/ext/IntegralsZygoteExt.jl index 56d2e4c5..45827f8c 100644 --- a/ext/IntegralsZygoteExt.jl +++ b/ext/IntegralsZygoteExt.jl @@ -12,9 +12,9 @@ else end ChainRulesCore.@non_differentiable Integrals.checkkwargs(kwargs...) -function ChainRulesCore.rrule(::typeof(Integrals.__solvebp), cache, alg, sensealg, lb, ub, p; +function ChainRulesCore.rrule(::typeof(Integrals.__solvebp), cache, alg, sensealg, lb, ub, + p; kwargs...) - out = Integrals.__solvebp_call(cache, alg, sensealg, lb, ub, p; kwargs...) function quadrature_adjoint(Δ) diff --git a/src/Integrals.jl b/src/Integrals.jl index cf427254..755cfd68 100644 --- a/src/Integrals.jl +++ b/src/Integrals.jl @@ -85,7 +85,7 @@ function __solvebp_call(cache::IntegralCache, alg::QuadGKJL, sensealg, lb, ub, p p = p f = x -> prob.f(x, p) - val, err = quadgk(f, lb, ub, segbuf=cache.cacheval, maxevals=maxiters, + val, err = quadgk(f, lb, ub, segbuf = cache.cacheval, maxevals = maxiters, rtol = reltol, atol = abstol, order = alg.order, norm = alg.norm) SciMLBase.build_solution(prob, QuadGKJL(), val, err, retcode = ReturnCode.Success) end diff --git a/src/common.jl b/src/common.jl index 64c0de25..21b989be 100644 --- a/src/common.jl +++ b/src/common.jl @@ -1,4 +1,4 @@ -struct IntegralCache{iip,F,B,P,PK,A,S,K,Tc} +struct IntegralCache{iip, F, B, P, PK, A, S, K, Tc} iip::Val{iip} f::F lb::B @@ -14,9 +14,9 @@ struct IntegralCache{iip,F,B,P,PK,A,S,K,Tc} isfresh::Bool # false => cacheval is set wrt f, true => update cacheval wrt f end -SciMLBase.isinplace(::IntegralCache{iip}) where iip = iip +SciMLBase.isinplace(::IntegralCache{iip}) where {iip} = iip -function set_f(cache::IntegralCache, f, nout=cache.nout) +function set_f(cache::IntegralCache, f, nout = cache.nout) @set! cache.f = f @set! cache.iip = Val(isinplace(f, 3)) @set! cache.nout = nout @@ -52,7 +52,7 @@ init_cacheval(::SciMLBase.AbstractIntegralAlgorithm, args...) = (nothing, true) function SciMLBase.init(prob::IntegralProblem{iip}, alg::SciMLBase.AbstractIntegralAlgorithm; sensealg = ReCallVJP(ZygoteVJP()), - do_inf_transformation = nothing, kwargs...) where iip + do_inf_transformation = nothing, kwargs...) where {iip} checkkwargs(kwargs...) prob = transformation_if_inf(prob, do_inf_transformation) cacheval, isfresh = init_cacheval(alg, prob) @@ -114,12 +114,13 @@ function SciMLBase.solve(prob::IntegralProblem, end function SciMLBase.solve!(cache::IntegralCache) - __solvebp(cache, cache.alg, cache.sensealg, cache.lb, cache.ub, cache.p; cache.kwargs...) + __solvebp(cache, cache.alg, cache.sensealg, cache.lb, cache.ub, cache.p; + cache.kwargs...) end -function build_problem(cache::IntegralCache{iip}) where iip +function build_problem(cache::IntegralCache{iip}) where {iip} IntegralProblem{iip}(cache.f, cache.lb, cache.ub, cache.p; - nout = cache.nout, batch = cache.batch, cache.prob_kwargs...) + nout = cache.nout, batch = cache.batch, cache.prob_kwargs...) end # fallback method for existing algorithms which use no cache From 5298836e133f41b3de4154cba0e1e9fbceb620da Mon Sep 17 00:00:00 2001 From: lxvm Date: Sun, 4 Jun 2023 11:02:13 -0400 Subject: [PATCH 04/13] bump quadgk compat --- Project.toml | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/Project.toml b/Project.toml index 013a3422..9d9aa4ca 100644 --- a/Project.toml +++ b/Project.toml @@ -21,7 +21,7 @@ Distributions = "0.23, 0.24, 0.25" ForwardDiff = "0.10" HCubature = "1.4" MonteCarloIntegration = "0.0.1, 0.0.2, 0.0.3" -QuadGK = "2.1" +QuadGK = "2.5" Reexport = "0.2, 1.0" Requires = "1" SciMLBase = "1.70" From 5a9aa5f5cf35ea989aeb9f9ac0507eefacf4e940 Mon Sep 17 00:00:00 2001 From: lxvm Date: Sun, 4 Jun 2023 12:52:01 -0400 Subject: [PATCH 05/13] use inference to build cache --- src/Integrals.jl | 23 +++++++----- src/common.jl | 93 +++++++++++++++++++++++++++++------------------- 2 files changed, 72 insertions(+), 44 deletions(-) diff --git a/src/Integrals.jl b/src/Integrals.jl index 755cfd68..dcd93c88 100644 --- a/src/Integrals.jl +++ b/src/Integrals.jl @@ -62,13 +62,22 @@ end # Give a layer to intercept with AD __solvebp(args...; kwargs...) = __solvebp_call(args...; kwargs...) -function init_cacheval(alg::QuadGKJL, prob::IntegralProblem) - mid = (prob.lb + prob.ub) / 2 +function quadgk_prob_types(f, lb, ub, p, nrm) + mid = (lb + ub) / 2 DT = typeof(mid) - val = prob.f(mid, prob.p) # TODO: infer this type or let user pass it - RT = typeof(val) - NT = typeof(alg.norm(val)) - return (QuadGK.alloc_segbuf(DT, RT, NT), false) + RT = Base.promote_op(f, DT, typeof(p)) + NT = Base.promote_op(nrm, RT) + return DT, RT, NT +end +function init_cacheval(alg::QuadGKJL, prob::IntegralProblem) + DT, RT, NT = quadgk_prob_types(prob.f, prob.lb, prob.ub, prob.p, alg.norm) + return (isconcretetype(RT) ? QuadGK.alloc_segbuf(DT, RT, NT) : nothing) +end +function refresh_cacheval(cacheval, alg::QuadGKJL, prob) + DT, RT, NT = quadgk_prob_types(prob.f, prob.lb, prob.ub, prob.p, alg.norm) + isconcretetype(RT) || return nothing + T = QuadGK.Segment{DT,RT,NT} + return (cacheval isa Vector{T} ? cacheval : QuadGK.alloc_segbuf(DT, RT, NT)) end function __solvebp_call(cache::IntegralCache, alg::QuadGKJL, sensealg, lb, ub, p; @@ -81,8 +90,6 @@ function __solvebp_call(cache::IntegralCache, alg::QuadGKJL, sensealg, lb, ub, p @assert prob.batch == 0 @assert prob.nout == 1 - cache.isfresh && throw(ArgumentError("cannot reset QuadGK cache")) - p = p f = x -> prob.f(x, p) val, err = quadgk(f, lb, ub, segbuf = cache.cacheval, maxevals = maxiters, diff --git a/src/common.jl b/src/common.jl index 21b989be..a8f1d589 100644 --- a/src/common.jl +++ b/src/common.jl @@ -11,43 +11,11 @@ struct IntegralCache{iip, F, B, P, PK, A, S, K, Tc} sensealg::S kwargs::K cacheval::Tc # store alg cache here - isfresh::Bool # false => cacheval is set wrt f, true => update cacheval wrt f end SciMLBase.isinplace(::IntegralCache{iip}) where {iip} = iip -function set_f(cache::IntegralCache, f, nout = cache.nout) - @set! cache.f = f - @set! cache.iip = Val(isinplace(f, 3)) - @set! cache.nout = nout - prob = build_problem(cache) - cacheval, isfresh = init_cacheval(cache.alg, prob) - @set! cache.cacheval = cacheval - @set! cache.isfresh = isfresh - return cache -end - -function set_lb(cache::IntegralCache, lb) - @set! cache.lb = lb - return cache -end - -# since types of lb and ub are constrained, we do not need to refresh cache -function set_ub(cache::IntegralCache, ub) - @set! cache.ub = ub - return cache -end - -function set_p(cache::IntegralCache, p) - @set! cache.p = p - prob = build_problem(cache) - cacheval, isfresh = init_cacheval(cache.alg, prob) - @set! cache.cacheval = cacheval - @set! cache.isfresh = isfresh - return cache -end - -init_cacheval(::SciMLBase.AbstractIntegralAlgorithm, args...) = (nothing, true) +init_cacheval(::SciMLBase.AbstractIntegralAlgorithm, args...) = nothing function SciMLBase.init(prob::IntegralProblem{iip}, alg::SciMLBase.AbstractIntegralAlgorithm; @@ -55,7 +23,7 @@ function SciMLBase.init(prob::IntegralProblem{iip}, do_inf_transformation = nothing, kwargs...) where {iip} checkkwargs(kwargs...) prob = transformation_if_inf(prob, do_inf_transformation) - cacheval, isfresh = init_cacheval(alg, prob) + cacheval = init_cacheval(alg, prob) IntegralCache{iip, typeof(prob.f), @@ -76,10 +44,63 @@ function SciMLBase.init(prob::IntegralProblem{iip}, alg, sensealg, kwargs, - cacheval, - isfresh) + cacheval) end +refresh_cacheval(cacheval, alg, prob) = nothing + +""" + set_f(cache, f, [nout=cache.nout]) + +Return a new cache with the new integrand `f`, optionally resetting `nout` at the same time. +""" +function set_f(cache::IntegralCache, f, nout = cache.nout) + prob = remake(build_problem(cache), f=f, nout=nout) + alg = cache.alg; cacheval = cache.cacheval + # lots of type-instability hereafter + @set! cache.f = f + @set! cache.iip = Val(isinplace(f, 3)) + @set! cache.nout = nout + @set! cache.cacheval = refresh_cacheval(cacheval, alg, prob) + return cache +end + +""" + set_lb(cache, lb) + +Return a new cache with new lower limits `lb`. +""" +function set_lb(cache::IntegralCache, lb) + @set! cache.lb = lb + return cache +end + +# since types of lb and ub are constrained, we do not need to refresh cache + +""" + set_ub(cache, ub) + +Return a new cache with new lower limits `ub`. +""" +function set_ub(cache::IntegralCache, ub) + @set! cache.ub = ub + return cache +end + +""" + set_p(cache, p, [refresh=true]) + +Return a new cache with parameters `p`. +""" +function set_p(cache::IntegralCache, p) + prob = remake(build_problem(cache), p=p) + alg = cache.alg; cacheval = cache.cacheval + @set! cache.p = p + @set! cache.cacheval = refresh_cacheval(cacheval, alg, prob) + return cache +end + + # Throw error if alg is not provided, as defaults are not implemented. function SciMLBase.solve(::IntegralProblem; kwargs...) checkkwargs(kwargs...) From aa987970f46727df23cfc717da9a16b8f49a045b Mon Sep 17 00:00:00 2001 From: lxvm Date: Sun, 4 Jun 2023 13:03:58 -0400 Subject: [PATCH 06/13] add tests --- test/interface_tests.jl | 28 ++++++++++++++++++++++++++++ 1 file changed, 28 insertions(+) diff --git a/test/interface_tests.jl b/test/interface_tests.jl index 5d10fd80..01f5bb9e 100644 --- a/test/interface_tests.jl +++ b/test/interface_tests.jl @@ -353,3 +353,31 @@ end relztol = 1e-3, abstol = 1e-3) end + +@testset "Caching interface" begin + lb, ub = (1.0, 3.0) + nout = 1 + dim = 1 + for alg in algs + if alg_req[alg].min_dim > 1 + continue + end + for i in 1:length(integrands) + prob = IntegralProblem(integrands[i], lb, ub) + cache = init(prob, alg, reltol = reltol, abstol = abstol) + @test solve!(cache).u≈exact_sol[i](dim, nout, lb, ub) rtol=1e-2 + lb = 0.5 + cache = Integrals.set_lb(cache, lb) + @test solve!(cache).u≈exact_sol[i](dim, nout, lb, ub) rtol=1e-2 + ub = 3.5 + cache = Integrals.set_ub(cache, ub) + @test solve!(cache).u≈exact_sol[i](dim, nout, lb, ub) rtol=1e-2 + p = missing # the integrands don't actually use this + cache = Integrals.set_p(cache, p) + @test solve!(cache).u≈exact_sol[i](dim, nout, lb, ub) rtol=1e-2 + f = (x,p) -> integrands[i](x,p) # for lack of creativity, wrap the old integrand + cache = Integrals.set_f(cache, f) + @test solve!(cache).u≈exact_sol[i](dim, nout, lb, ub) rtol=1e-2 + end + end +end From 36aeb75e9c700a73703e8435c453a575ec3323bc Mon Sep 17 00:00:00 2001 From: lxvm Date: Sun, 4 Jun 2023 13:04:06 -0400 Subject: [PATCH 07/13] add docs --- docs/src/tutorials/caching_interface.md | 44 +++++++++++++++++++++++++ 1 file changed, 44 insertions(+) create mode 100644 docs/src/tutorials/caching_interface.md diff --git a/docs/src/tutorials/caching_interface.md b/docs/src/tutorials/caching_interface.md new file mode 100644 index 00000000..3aee849b --- /dev/null +++ b/docs/src/tutorials/caching_interface.md @@ -0,0 +1,44 @@ +# Integrals with Caching Interface + +Often, integral solvers allocate memory or reuse quadrature rules for solving different +problems. For example, if one is going to perform +```julia +using Integrals + +prob = IntegralProblem((x,p) -> sin(x*p), 0, 1, 14.0) +alg = QuadGKJL() + +solve(prob, alg) + +prob = remake(prob, f=(x,p) -> cos(x*p)) +solve(prob, alg) +``` +then it would be more efficient to allocate the heap used by `quadgk` across several calls, +shown below by directly calling the library +```julia +using QuadGK +segbuf = QuadGK.alloc_segbuf() +quadgk(x -> sin(15x), 0, 1, segbuf=segbuf) +quadgk(x -> cos(15x), 0, 1, segbuf=segbuf) +``` +Integrals.jl's caching interface automates this process to reuse resources if an algorithm +supports it and if the necessary types to build the cache can be inferred from `prob`. To do +this with Integrals.jl, you simply `init` a cache, `solve`, replace `f`, and solve again. +This looks like +```@example cache1 +using Integrals + +prob = IntegralProblem((x,p) -> sin(x*p), 0, 1, 14.0) +alg = QuadGKJL() + +cache = init(prob, alg) +sol1 = solve!(cache) +``` + +```@example cache1 +cache = Integrals.set_f(cache, (x,p) -> cos(x*p)) +sol2 = solve!(cache) +``` +Similar cache-rebuilding functions are provided, including: `set_p`, `set_lb`, and `set_ub`, +each of which provides a new value of `lb`, `ub`, or `p`, respectively. When resetting the +cache, new allocations may be needed if those inferred types change. \ No newline at end of file From 4c2aebf324d3c254b4ba0437d9950ed3278b4d89 Mon Sep 17 00:00:00 2001 From: lxvm Date: Sun, 4 Jun 2023 13:05:28 -0400 Subject: [PATCH 08/13] apply format --- docs/src/tutorials/caching_interface.md | 20 +++++++++++++------- src/Integrals.jl | 2 +- src/common.jl | 13 +++++++------ test/interface_tests.jl | 2 +- 4 files changed, 22 insertions(+), 15 deletions(-) diff --git a/docs/src/tutorials/caching_interface.md b/docs/src/tutorials/caching_interface.md index 3aee849b..9a6a440d 100644 --- a/docs/src/tutorials/caching_interface.md +++ b/docs/src/tutorials/caching_interface.md @@ -2,33 +2,38 @@ Often, integral solvers allocate memory or reuse quadrature rules for solving different problems. For example, if one is going to perform + ```julia using Integrals -prob = IntegralProblem((x,p) -> sin(x*p), 0, 1, 14.0) +prob = IntegralProblem((x, p) -> sin(x * p), 0, 1, 14.0) alg = QuadGKJL() solve(prob, alg) -prob = remake(prob, f=(x,p) -> cos(x*p)) +prob = remake(prob, f = (x, p) -> cos(x * p)) solve(prob, alg) ``` + then it would be more efficient to allocate the heap used by `quadgk` across several calls, shown below by directly calling the library + ```julia using QuadGK segbuf = QuadGK.alloc_segbuf() -quadgk(x -> sin(15x), 0, 1, segbuf=segbuf) -quadgk(x -> cos(15x), 0, 1, segbuf=segbuf) +quadgk(x -> sin(15x), 0, 1, segbuf = segbuf) +quadgk(x -> cos(15x), 0, 1, segbuf = segbuf) ``` + Integrals.jl's caching interface automates this process to reuse resources if an algorithm supports it and if the necessary types to build the cache can be inferred from `prob`. To do this with Integrals.jl, you simply `init` a cache, `solve`, replace `f`, and solve again. This looks like + ```@example cache1 using Integrals -prob = IntegralProblem((x,p) -> sin(x*p), 0, 1, 14.0) +prob = IntegralProblem((x, p) -> sin(x * p), 0, 1, 14.0) alg = QuadGKJL() cache = init(prob, alg) @@ -36,9 +41,10 @@ sol1 = solve!(cache) ``` ```@example cache1 -cache = Integrals.set_f(cache, (x,p) -> cos(x*p)) +cache = Integrals.set_f(cache, (x, p) -> cos(x * p)) sol2 = solve!(cache) ``` + Similar cache-rebuilding functions are provided, including: `set_p`, `set_lb`, and `set_ub`, each of which provides a new value of `lb`, `ub`, or `p`, respectively. When resetting the -cache, new allocations may be needed if those inferred types change. \ No newline at end of file +cache, new allocations may be needed if those inferred types change. diff --git a/src/Integrals.jl b/src/Integrals.jl index dcd93c88..77d9015a 100644 --- a/src/Integrals.jl +++ b/src/Integrals.jl @@ -76,7 +76,7 @@ end function refresh_cacheval(cacheval, alg::QuadGKJL, prob) DT, RT, NT = quadgk_prob_types(prob.f, prob.lb, prob.ub, prob.p, alg.norm) isconcretetype(RT) || return nothing - T = QuadGK.Segment{DT,RT,NT} + T = QuadGK.Segment{DT, RT, NT} return (cacheval isa Vector{T} ? cacheval : QuadGK.alloc_segbuf(DT, RT, NT)) end diff --git a/src/common.jl b/src/common.jl index a8f1d589..1df059d0 100644 --- a/src/common.jl +++ b/src/common.jl @@ -55,8 +55,9 @@ refresh_cacheval(cacheval, alg, prob) = nothing Return a new cache with the new integrand `f`, optionally resetting `nout` at the same time. """ function set_f(cache::IntegralCache, f, nout = cache.nout) - prob = remake(build_problem(cache), f=f, nout=nout) - alg = cache.alg; cacheval = cache.cacheval + prob = remake(build_problem(cache), f = f, nout = nout) + alg = cache.alg + cacheval = cache.cacheval # lots of type-instability hereafter @set! cache.f = f @set! cache.iip = Val(isinplace(f, 3)) @@ -93,14 +94,14 @@ end Return a new cache with parameters `p`. """ function set_p(cache::IntegralCache, p) - prob = remake(build_problem(cache), p=p) - alg = cache.alg; cacheval = cache.cacheval + prob = remake(build_problem(cache), p = p) + alg = cache.alg + cacheval = cache.cacheval @set! cache.p = p - @set! cache.cacheval = refresh_cacheval(cacheval, alg, prob) + @set! cache.cacheval = refresh_cacheval(cacheval, alg, prob) return cache end - # Throw error if alg is not provided, as defaults are not implemented. function SciMLBase.solve(::IntegralProblem; kwargs...) checkkwargs(kwargs...) diff --git a/test/interface_tests.jl b/test/interface_tests.jl index 01f5bb9e..ce75aed7 100644 --- a/test/interface_tests.jl +++ b/test/interface_tests.jl @@ -375,7 +375,7 @@ end p = missing # the integrands don't actually use this cache = Integrals.set_p(cache, p) @test solve!(cache).u≈exact_sol[i](dim, nout, lb, ub) rtol=1e-2 - f = (x,p) -> integrands[i](x,p) # for lack of creativity, wrap the old integrand + f = (x, p) -> integrands[i](x, p) # for lack of creativity, wrap the old integrand cache = Integrals.set_f(cache, f) @test solve!(cache).u≈exact_sol[i](dim, nout, lb, ub) rtol=1e-2 end From 31ab73ba11ed91bfaf72575184a8fbdbc7696c8c Mon Sep 17 00:00:00 2001 From: lxvm Date: Sun, 4 Jun 2023 13:58:51 -0400 Subject: [PATCH 09/13] adjust quadgk inference kernel --- src/Integrals.jl | 7 +++---- 1 file changed, 3 insertions(+), 4 deletions(-) diff --git a/src/Integrals.jl b/src/Integrals.jl index 77d9015a..a74e559c 100644 --- a/src/Integrals.jl +++ b/src/Integrals.jl @@ -62,10 +62,9 @@ end # Give a layer to intercept with AD __solvebp(args...; kwargs...) = __solvebp_call(args...; kwargs...) -function quadgk_prob_types(f, lb, ub, p, nrm) - mid = (lb + ub) / 2 - DT = typeof(mid) - RT = Base.promote_op(f, DT, typeof(p)) +function quadgk_prob_types(f, lb::T, ub::T, p, nrm) where {T} + DT = float(T) # we need to be careful to infer the same result as `evalrule` + RT = Base.promote_op(*, real(DT), Base.promote_op(f, DT, typeof(p))) # kernel NT = Base.promote_op(nrm, RT) return DT, RT, NT end From 03bfa85cfef4318a40050559229cccf062fab303 Mon Sep 17 00:00:00 2001 From: lxvm Date: Wed, 9 Aug 2023 11:35:01 -0400 Subject: [PATCH 10/13] tweak type inference --- src/Integrals.jl | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/src/Integrals.jl b/src/Integrals.jl index a74e559c..120631d8 100644 --- a/src/Integrals.jl +++ b/src/Integrals.jl @@ -64,7 +64,7 @@ __solvebp(args...; kwargs...) = __solvebp_call(args...; kwargs...) function quadgk_prob_types(f, lb::T, ub::T, p, nrm) where {T} DT = float(T) # we need to be careful to infer the same result as `evalrule` - RT = Base.promote_op(*, real(DT), Base.promote_op(f, DT, typeof(p))) # kernel + RT = Base.promote_op(*, DT, Base.promote_op(f, DT, typeof(p))) # kernel NT = Base.promote_op(nrm, RT) return DT, RT, NT end From 05ab510480e0501a2746ce7041f6efeb7514937f Mon Sep 17 00:00:00 2001 From: lxvm Date: Wed, 9 Aug 2023 12:01:59 -0400 Subject: [PATCH 11/13] apply format --- docs/make.jl | 32 ++-- docs/pages.jl | 6 +- ext/IntegralsFastGaussQuadratureExt.jl | 10 +- ext/IntegralsForwardDiffExt.jl | 14 +- ext/IntegralsZygoteExt.jl | 8 +- lib/IntegralsCuba/src/IntegralsCuba.jl | 56 +++---- .../src/IntegralsCubature.jl | 88 +++++------ src/Integrals.jl | 26 ++-- src/common.jl | 52 +++---- src/init.jl | 12 +- test/inf_integral_tests.jl | 4 +- test/interface_tests.jl | 142 +++++++++--------- test/runtests.jl | 16 +- 13 files changed, 242 insertions(+), 224 deletions(-) diff --git a/docs/make.jl b/docs/make.jl index a2c90ffb..a63fbdcc 100644 --- a/docs/make.jl +++ b/docs/make.jl @@ -6,21 +6,21 @@ cp("./docs/Project.toml", "./docs/src/assets/Project.toml", force = true) include("pages.jl") makedocs(sitename = "Integrals.jl", - authors = "Chris Rackauckas", - modules = [Integrals, Integrals.SciMLBase], - clean = true, doctest = false, - strict = [ - :doctest, - :linkcheck, - :parse_error, - :example_block, - # Other available options are - # :autodocs_block, :cross_references, :docs_block, :eval_block, :example_block, :footnote, :meta_block, :missing_docs, :setup_block - ], - format = Documenter.HTML(analytics = "UA-90474609-3", - assets = ["assets/favicon.ico"], - canonical = "https://docs.sciml.ai/Integrals/stable/"), - pages = pages) + authors = "Chris Rackauckas", + modules = [Integrals, Integrals.SciMLBase], + clean = true, doctest = false, + strict = [ + :doctest, + :linkcheck, + :parse_error, + :example_block, + # Other available options are + # :autodocs_block, :cross_references, :docs_block, :eval_block, :example_block, :footnote, :meta_block, :missing_docs, :setup_block + ], + format = Documenter.HTML(analytics = "UA-90474609-3", + assets = ["assets/favicon.ico"], + canonical = "https://docs.sciml.ai/Integrals/stable/"), + pages = pages) deploydocs(repo = "github.com/SciML/Integrals.jl.git"; - push_preview = true) + push_preview = true) diff --git a/docs/pages.jl b/docs/pages.jl index 73a377b7..76c058fc 100644 --- a/docs/pages.jl +++ b/docs/pages.jl @@ -1,8 +1,8 @@ pages = ["index.md", "Tutorials" => Any["tutorials/numerical_integrals.md", - "tutorials/differentiating_integrals.md"], + "tutorials/differentiating_integrals.md"], "Basics" => Any["basics/IntegralProblem.md", - "basics/solve.md", - "basics/FAQ.md"], + "basics/solve.md", + "basics/FAQ.md"], "Solvers" => Any["solvers/IntegralSolvers.md"], ] diff --git a/ext/IntegralsFastGaussQuadratureExt.jl b/ext/IntegralsFastGaussQuadratureExt.jl index 9c84745a..7dd1a55a 100644 --- a/ext/IntegralsFastGaussQuadratureExt.jl +++ b/ext/IntegralsFastGaussQuadratureExt.jl @@ -30,9 +30,9 @@ function composite_gauss_legendre(f, p, lb, ub, nodes, weights, subintervals) end function Integrals.__solvebp_call(prob::IntegralProblem, alg::Integrals.GaussLegendre{C}, - sensealg, lb, ub, p; - reltol = nothing, abstol = nothing, - maxiters = nothing) where {C} + sensealg, lb, ub, p; + reltol = nothing, abstol = nothing, + maxiters = nothing) where {C} if isinplace(prob) || lb isa AbstractArray || ub isa AbstractArray error("GaussLegendre only accepts one-dimensional quadrature problems.") end @@ -40,10 +40,10 @@ function Integrals.__solvebp_call(prob::IntegralProblem, alg::Integrals.GaussLeg @assert prob.nout == 1 if C val = composite_gauss_legendre(prob.f, prob.p, lb, ub, - alg.nodes, alg.weights, alg.subintervals) + alg.nodes, alg.weights, alg.subintervals) else val = gauss_legendre(prob.f, prob.p, lb, ub, - alg.nodes, alg.weights) + alg.nodes, alg.weights) end err = nothing SciMLBase.build_solution(prob, alg, val, err, retcode = ReturnCode.Success) diff --git a/ext/IntegralsForwardDiffExt.jl b/ext/IntegralsForwardDiffExt.jl index 8ab97743..f87b4103 100644 --- a/ext/IntegralsForwardDiffExt.jl +++ b/ext/IntegralsForwardDiffExt.jl @@ -6,23 +6,23 @@ isdefined(Base, :get_extension) ? (using ForwardDiff) : (using ..ForwardDiff) # Direct AD on solvers with QuadGK and HCubature function Integrals.__solvebp(cache, alg::QuadGKJL, sensealg, lb, ub, - p::AbstractArray{<:ForwardDiff.Dual{T, V, P}, N}; - kwargs...) where {T, V, P, N} + p::AbstractArray{<:ForwardDiff.Dual{T, V, P}, N}; + kwargs...) where {T, V, P, N} Integrals.__solvebp_call(cache, alg, sensealg, lb, ub, p; kwargs...) end function Integrals.__solvebp(cache, alg::HCubatureJL, sensealg, lb, ub, - p::AbstractArray{<:ForwardDiff.Dual{T, V, P}, N}; - kwargs...) where {T, V, P, N} + p::AbstractArray{<:ForwardDiff.Dual{T, V, P}, N}; + kwargs...) where {T, V, P, N} Integrals.__solvebp_call(cache, alg, sensealg, lb, ub, p; kwargs...) end # Manually split for the pushforward function Integrals.__solvebp(cache, alg, sensealg, lb, ub, - p::AbstractArray{<:ForwardDiff.Dual{T, V, P}, N}; - kwargs...) where {T, V, P, N} + p::AbstractArray{<:ForwardDiff.Dual{T, V, P}, N}; + kwargs...) where {T, V, P, N} primal = Integrals.__solvebp_call(cache, alg, sensealg, lb, ub, ForwardDiff.value.(p); - kwargs...) + kwargs...) nout = cache.nout * P diff --git a/ext/IntegralsZygoteExt.jl b/ext/IntegralsZygoteExt.jl index 45827f8c..3d3185b2 100644 --- a/ext/IntegralsZygoteExt.jl +++ b/ext/IntegralsZygoteExt.jl @@ -13,8 +13,8 @@ end ChainRulesCore.@non_differentiable Integrals.checkkwargs(kwargs...) function ChainRulesCore.rrule(::typeof(Integrals.__solvebp), cache, alg, sensealg, lb, ub, - p; - kwargs...) + p; + kwargs...) out = Integrals.__solvebp_call(cache, alg, sensealg, lb, ub, p; kwargs...) function quadrature_adjoint(Δ) @@ -82,14 +82,14 @@ function ChainRulesCore.rrule(::typeof(Integrals.__solvebp), cache, alg, senseal return (NoTangent(), NoTangent(), NoTangent(), NoTangent(), dlb, dub, dp) else return (NoTangent(), NoTangent(), NoTangent(), NoTangent(), NoTangent(), - NoTangent(), dp) + NoTangent(), dp) end end out, quadrature_adjoint end Zygote.@adjoint function Zygote.literal_getproperty(sol::SciMLBase.IntegralSolution, - ::Val{:u}) + ::Val{:u}) sol.u, Δ -> (SciMLBase.build_solution(sol.prob, sol.alg, Δ, sol.resid),) end end diff --git a/lib/IntegralsCuba/src/IntegralsCuba.jl b/lib/IntegralsCuba/src/IntegralsCuba.jl index b7eff607..26bfd7a5 100644 --- a/lib/IntegralsCuba/src/IntegralsCuba.jl +++ b/lib/IntegralsCuba/src/IntegralsCuba.jl @@ -116,26 +116,26 @@ struct CubaCuhre <: AbstractCubaAlgorithm end function CubaVegas(; flags = 0, seed = 0, minevals = 0, nstart = 1000, nincrease = 500, - gridno = 0) + gridno = 0) CubaVegas(flags, seed, minevals, nstart, nincrease, gridno) end function CubaSUAVE(; flags = 0, seed = 0, minevals = 0, nnew = 1000, nmin = 2, - flatness = 25.0) + flatness = 25.0) CubaSUAVE(flags, seed, minevals, nnew, nmin, flatness) end function CubaDivonne(; flags = 0, seed = 0, minevals = 0, - key1 = 47, key2 = 1, key3 = 1, maxpass = 5, border = 0.0, - maxchisq = 10.0, mindeviation = 0.25) + key1 = 47, key2 = 1, key3 = 1, maxpass = 5, border = 0.0, + maxchisq = 10.0, mindeviation = 0.25) CubaDivonne(flags, seed, minevals, key1, key2, key3, maxpass, border, maxchisq, - mindeviation) + mindeviation) end CubaCuhre(; flags = 0, minevals = 0, key = 0) = CubaCuhre(flags, minevals, key) function Integrals.__solvebp_call(prob::IntegralProblem, alg::AbstractCubaAlgorithm, - sensealg, - lb, ub, p; - reltol = 1e-8, abstol = 1e-8, - maxiters = alg isa CubaSUAVE ? 1000000 : typemax(Int)) + sensealg, + lb, ub, p; + reltol = 1e-8, abstol = 1e-8, + maxiters = alg isa CubaSUAVE ? 1000000 : typemax(Int)) @assert maxiters>=1000 "maxiters for $alg should be larger than 1000" if lb isa Number && prob.batch == 0 _x = Float64[lb] @@ -208,30 +208,30 @@ function Integrals.__solvebp_call(prob::IntegralProblem, alg::AbstractCubaAlgori if alg isa CubaVegas out = Cuba.vegas(f, ndim, prob.nout; rtol = reltol, - atol = abstol, nvec = nvec, - maxevals = maxiters, - flags = alg.flags, seed = alg.seed, minevals = alg.minevals, - nstart = alg.nstart, nincrease = alg.nincrease, - gridno = alg.gridno) + atol = abstol, nvec = nvec, + maxevals = maxiters, + flags = alg.flags, seed = alg.seed, minevals = alg.minevals, + nstart = alg.nstart, nincrease = alg.nincrease, + gridno = alg.gridno) elseif alg isa CubaSUAVE out = Cuba.suave(f, ndim, prob.nout; rtol = reltol, - atol = abstol, nvec = nvec, - maxevals = maxiters, - flags = alg.flags, seed = alg.seed, minevals = alg.minevals, - nnew = alg.nnew, nmin = alg.nmin, flatness = alg.flatness) + atol = abstol, nvec = nvec, + maxevals = maxiters, + flags = alg.flags, seed = alg.seed, minevals = alg.minevals, + nnew = alg.nnew, nmin = alg.nmin, flatness = alg.flatness) elseif alg isa CubaDivonne out = Cuba.divonne(f, ndim, prob.nout; rtol = reltol, - atol = abstol, nvec = nvec, - maxevals = maxiters, - flags = alg.flags, seed = alg.seed, minevals = alg.minevals, - key1 = alg.key1, key2 = alg.key2, key3 = alg.key3, - maxpass = alg.maxpass, border = alg.border, - maxchisq = alg.maxchisq, mindeviation = alg.mindeviation) + atol = abstol, nvec = nvec, + maxevals = maxiters, + flags = alg.flags, seed = alg.seed, minevals = alg.minevals, + key1 = alg.key1, key2 = alg.key2, key3 = alg.key3, + maxpass = alg.maxpass, border = alg.border, + maxchisq = alg.maxchisq, mindeviation = alg.mindeviation) elseif alg isa CubaCuhre out = Cuba.cuhre(f, ndim, prob.nout; rtol = reltol, - atol = abstol, nvec = nvec, - maxevals = maxiters, - flags = alg.flags, minevals = alg.minevals, key = alg.key) + atol = abstol, nvec = nvec, + maxevals = maxiters, + flags = alg.flags, minevals = alg.minevals, key = alg.key) end if isinplace(prob) || prob.batch != 0 @@ -245,7 +245,7 @@ function Integrals.__solvebp_call(prob::IntegralProblem, alg::AbstractCubaAlgori end SciMLBase.build_solution(prob, alg, val, out.error, - chi = out.probability, retcode = ReturnCode.Success) + chi = out.probability, retcode = ReturnCode.Success) end export CubaVegas, CubaSUAVE, CubaDivonne, CubaCuhre diff --git a/lib/IntegralsCubature/src/IntegralsCubature.jl b/lib/IntegralsCubature/src/IntegralsCubature.jl index a845de31..120fdb62 100644 --- a/lib/IntegralsCubature/src/IntegralsCubature.jl +++ b/lib/IntegralsCubature/src/IntegralsCubature.jl @@ -48,10 +48,10 @@ end CubatureJLp() = CubatureJLp(Cubature.INDIVIDUAL) function Integrals.__solvebp_call(prob::IntegralProblem, - alg::AbstractCubatureJLAlgorithm, - sensealg, lb, ub, p; - reltol = 1e-8, abstol = 1e-8, - maxiters = typemax(Int)) + alg::AbstractCubatureJLAlgorithm, + sensealg, lb, ub, p; + reltol = 1e-8, abstol = 1e-8, + maxiters = typemax(Int)) nout = prob.nout if nout == 1 if prob.batch == 0 @@ -64,23 +64,23 @@ function Integrals.__solvebp_call(prob::IntegralProblem, if lb isa Number if alg isa CubatureJLh _val, err = Cubature.hquadrature(f, lb, ub; - reltol = reltol, abstol = abstol, - maxevals = maxiters) + reltol = reltol, abstol = abstol, + maxevals = maxiters) else _val, err = Cubature.pquadrature(f, lb, ub; - reltol = reltol, abstol = abstol, - maxevals = maxiters) + reltol = reltol, abstol = abstol, + maxevals = maxiters) end val = prob.f(lb, p) isa Number ? _val : [_val] else if alg isa CubatureJLh _val, err = Cubature.hcubature(f, lb, ub; - reltol = reltol, abstol = abstol, - maxevals = maxiters) + reltol = reltol, abstol = abstol, + maxevals = maxiters) else _val, err = Cubature.pcubature(f, lb, ub; - reltol = reltol, abstol = abstol, - maxevals = maxiters) + reltol = reltol, abstol = abstol, + maxevals = maxiters) end if isinplace(prob) || !isa(prob.f(lb, p), Number) @@ -112,22 +112,22 @@ function Integrals.__solvebp_call(prob::IntegralProblem, if lb isa Number if alg isa CubatureJLh _val, err = Cubature.hquadrature_v(f, lb, ub; - reltol = reltol, abstol = abstol, - maxevals = maxiters) + reltol = reltol, abstol = abstol, + maxevals = maxiters) else _val, err = Cubature.pquadrature_v(f, lb, ub; - reltol = reltol, abstol = abstol, - maxevals = maxiters) + reltol = reltol, abstol = abstol, + maxevals = maxiters) end else if alg isa CubatureJLh _val, err = Cubature.hcubature_v(f, lb, ub; - reltol = reltol, abstol = abstol, - maxevals = maxiters) + reltol = reltol, abstol = abstol, + maxevals = maxiters) else _val, err = Cubature.pcubature_v(f, lb, ub; - reltol = reltol, abstol = abstol, - maxevals = maxiters) + reltol = reltol, abstol = abstol, + maxevals = maxiters) end end val = _val isa Number ? [_val] : _val @@ -142,26 +142,26 @@ function Integrals.__solvebp_call(prob::IntegralProblem, if lb isa Number if alg isa CubatureJLh val, err = Cubature.hquadrature(nout, f, lb, ub; - reltol = reltol, abstol = abstol, - maxevals = maxiters, - error_norm = alg.error_norm) + reltol = reltol, abstol = abstol, + maxevals = maxiters, + error_norm = alg.error_norm) else val, err = Cubature.pquadrature(nout, f, lb, ub; - reltol = reltol, abstol = abstol, - maxevals = maxiters, - error_norm = alg.error_norm) + reltol = reltol, abstol = abstol, + maxevals = maxiters, + error_norm = alg.error_norm) end else if alg isa CubatureJLh val, err = Cubature.hcubature(nout, f, lb, ub; - reltol = reltol, abstol = abstol, - maxevals = maxiters, - error_norm = alg.error_norm) + reltol = reltol, abstol = abstol, + maxevals = maxiters, + error_norm = alg.error_norm) else val, err = Cubature.pcubature(nout, f, lb, ub; - reltol = reltol, abstol = abstol, - maxevals = maxiters, - error_norm = alg.error_norm) + reltol = reltol, abstol = abstol, + maxevals = maxiters, + error_norm = alg.error_norm) end end else @@ -178,26 +178,26 @@ function Integrals.__solvebp_call(prob::IntegralProblem, if lb isa Number if alg isa CubatureJLh val, err = Cubature.hquadrature_v(nout, f, lb, ub; - reltol = reltol, abstol = abstol, - maxevals = maxiters, - error_norm = alg.error_norm) + reltol = reltol, abstol = abstol, + maxevals = maxiters, + error_norm = alg.error_norm) else val, err = Cubature.pquadrature_v(nout, f, lb, ub; - reltol = reltol, abstol = abstol, - maxevals = maxiters, - error_norm = alg.error_norm) + reltol = reltol, abstol = abstol, + maxevals = maxiters, + error_norm = alg.error_norm) end else if alg isa CubatureJLh val, err = Cubature.hcubature_v(nout, f, lb, ub; - reltol = reltol, abstol = abstol, - maxevals = maxiters, - error_norm = alg.error_norm) + reltol = reltol, abstol = abstol, + maxevals = maxiters, + error_norm = alg.error_norm) else val, err = Cubature.pcubature_v(nout, f, lb, ub; - reltol = reltol, abstol = abstol, - maxevals = maxiters, - error_norm = alg.error_norm) + reltol = reltol, abstol = abstol, + maxevals = maxiters, + error_norm = alg.error_norm) end end end diff --git a/src/Integrals.jl b/src/Integrals.jl index 120631d8..d1d80fc1 100644 --- a/src/Integrals.jl +++ b/src/Integrals.jl @@ -80,8 +80,8 @@ function refresh_cacheval(cacheval, alg::QuadGKJL, prob) end function __solvebp_call(cache::IntegralCache, alg::QuadGKJL, sensealg, lb, ub, p; - reltol = 1e-8, abstol = 1e-8, - maxiters = typemax(Int)) + reltol = 1e-8, abstol = 1e-8, + maxiters = typemax(Int)) prob = build_problem(cache) if isinplace(prob) || lb isa AbstractArray || ub isa AbstractArray error("QuadGKJL only accepts one-dimensional quadrature problems.") @@ -92,13 +92,13 @@ function __solvebp_call(cache::IntegralCache, alg::QuadGKJL, sensealg, lb, ub, p p = p f = x -> prob.f(x, p) val, err = quadgk(f, lb, ub, segbuf = cache.cacheval, maxevals = maxiters, - rtol = reltol, atol = abstol, order = alg.order, norm = alg.norm) + rtol = reltol, atol = abstol, order = alg.order, norm = alg.norm) SciMLBase.build_solution(prob, QuadGKJL(), val, err, retcode = ReturnCode.Success) end function __solvebp_call(prob::IntegralProblem, alg::HCubatureJL, sensealg, lb, ub, p; - reltol = 1e-8, abstol = 1e-8, - maxiters = typemax(Int)) + reltol = 1e-8, abstol = 1e-8, + maxiters = typemax(Int)) p = p if isinplace(prob) @@ -111,19 +111,19 @@ function __solvebp_call(prob::IntegralProblem, alg::HCubatureJL, sensealg, lb, u if lb isa Number val, err = hquadrature(f, lb, ub; - rtol = reltol, atol = abstol, - maxevals = maxiters, norm = alg.norm, initdiv = alg.initdiv) + rtol = reltol, atol = abstol, + maxevals = maxiters, norm = alg.norm, initdiv = alg.initdiv) else val, err = hcubature(f, lb, ub; - rtol = reltol, atol = abstol, - maxevals = maxiters, norm = alg.norm, initdiv = alg.initdiv) + rtol = reltol, atol = abstol, + maxevals = maxiters, norm = alg.norm, initdiv = alg.initdiv) end SciMLBase.build_solution(prob, HCubatureJL(), val, err, retcode = ReturnCode.Success) end function __solvebp_call(prob::IntegralProblem, alg::VEGAS, sensealg, lb, ub, p; - reltol = 1e-8, abstol = 1e-8, - maxiters = typemax(Int)) + reltol = 1e-8, abstol = 1e-8, + maxiters = typemax(Int)) p = p @assert prob.nout == 1 if prob.batch == 0 @@ -143,8 +143,8 @@ function __solvebp_call(prob::IntegralProblem, alg::VEGAS, sensealg, lb, ub, p; end ncalls = prob.batch == 0 ? alg.ncalls : prob.batch val, err, chi = vegas(f, lb, ub, rtol = reltol, atol = abstol, - maxiter = maxiters, nbins = alg.nbins, debug = alg.debug, - ncalls = ncalls, batch = prob.batch != 0) + maxiter = maxiters, nbins = alg.nbins, debug = alg.debug, + ncalls = ncalls, batch = prob.batch != 0) SciMLBase.build_solution(prob, alg, val, err, chi = chi, retcode = ReturnCode.Success) end diff --git a/src/common.jl b/src/common.jl index 1df059d0..b58defc8 100644 --- a/src/common.jl +++ b/src/common.jl @@ -18,33 +18,33 @@ SciMLBase.isinplace(::IntegralCache{iip}) where {iip} = iip init_cacheval(::SciMLBase.AbstractIntegralAlgorithm, args...) = nothing function SciMLBase.init(prob::IntegralProblem{iip}, - alg::SciMLBase.AbstractIntegralAlgorithm; - sensealg = ReCallVJP(ZygoteVJP()), - do_inf_transformation = nothing, kwargs...) where {iip} + alg::SciMLBase.AbstractIntegralAlgorithm; + sensealg = ReCallVJP(ZygoteVJP()), + do_inf_transformation = nothing, kwargs...) where {iip} checkkwargs(kwargs...) prob = transformation_if_inf(prob, do_inf_transformation) cacheval = init_cacheval(alg, prob) IntegralCache{iip, - typeof(prob.f), - typeof(prob.lb), - typeof(prob.p), - typeof(prob.kwargs), - typeof(alg), - typeof(sensealg), - typeof(kwargs), - typeof(cacheval)}(Val(iip), - prob.f, - prob.lb, - prob.ub, - prob.nout, - prob.p, - prob.batch, - prob.kwargs, - alg, - sensealg, - kwargs, - cacheval) + typeof(prob.f), + typeof(prob.lb), + typeof(prob.p), + typeof(prob.kwargs), + typeof(alg), + typeof(sensealg), + typeof(kwargs), + typeof(cacheval)}(Val(iip), + prob.f, + prob.lb, + prob.ub, + prob.nout, + prob.p, + prob.batch, + prob.kwargs, + alg, + sensealg, + kwargs, + cacheval) end refresh_cacheval(cacheval, alg, prob) = nothing @@ -130,19 +130,19 @@ These common arguments are: - `reltol` (relative tolerance in changes of the objective value) """ function SciMLBase.solve(prob::IntegralProblem, - alg::SciMLBase.AbstractIntegralAlgorithm; - kwargs...) + alg::SciMLBase.AbstractIntegralAlgorithm; + kwargs...) solve!(init(prob, alg; kwargs...)) end function SciMLBase.solve!(cache::IntegralCache) __solvebp(cache, cache.alg, cache.sensealg, cache.lb, cache.ub, cache.p; - cache.kwargs...) + cache.kwargs...) end function build_problem(cache::IntegralCache{iip}) where {iip} IntegralProblem{iip}(cache.f, cache.lb, cache.ub, cache.p; - nout = cache.nout, batch = cache.batch, cache.prob_kwargs...) + nout = cache.nout, batch = cache.batch, cache.prob_kwargs...) end # fallback method for existing algorithms which use no cache diff --git a/src/init.jl b/src/init.jl index 3ba7949e..fb12ef32 100644 --- a/src/init.jl +++ b/src/init.jl @@ -1,7 +1,13 @@ @static if !isdefined(Base, :get_extension) function __init__() - @require ForwardDiff="f6369f11-7733-5829-9624-2563aa707210" begin include("../ext/IntegralsForwardDiffExt.jl") end - @require Zygote="e88e6eb3-aa80-5325-afca-941959d7151f" begin include("../ext/IntegralsZygoteExt.jl") end - @require FastGaussQuadrature="442a2c76-b920-505d-bb47-c5924d526838" begin include("../ext/IntegralsFastGaussQuadratureExt.jl") end + @require ForwardDiff="f6369f11-7733-5829-9624-2563aa707210" begin + include("../ext/IntegralsForwardDiffExt.jl") + end + @require Zygote="e88e6eb3-aa80-5325-afca-941959d7151f" begin + include("../ext/IntegralsZygoteExt.jl") + end + @require FastGaussQuadrature="442a2c76-b920-505d-bb47-c5924d526838" begin + include("../ext/IntegralsFastGaussQuadratureExt.jl") + end end end diff --git a/test/inf_integral_tests.jl b/test/inf_integral_tests.jl index 8742cacc..b51ea690 100644 --- a/test/inf_integral_tests.jl +++ b/test/inf_integral_tests.jl @@ -57,5 +57,5 @@ prob = IntegralProblem(m2, SVector(-Inf, -Inf), SVector(Inf, Inf)) prob = @test_nowarn @inferred Integrals.transformation_if_inf(prob, Val(true)) @test_nowarn @inferred Integrals.__solvebp_call(prob, HCubatureJL(), - Integrals.ReCallVJP(Integrals.ZygoteVJP()), - prob.lb, prob.ub, prob.p) + Integrals.ReCallVJP(Integrals.ZygoteVJP()), + prob.lb, prob.ub, prob.p) diff --git a/test/interface_tests.jl b/test/interface_tests.jl index ce75aed7..4946dc57 100644 --- a/test/interface_tests.jl +++ b/test/interface_tests.jl @@ -11,23 +11,23 @@ algs = [QuadGKJL(), HCubatureJL(), CubatureJLh(), CubatureJLp(), #VEGAS(), #Cuba CubaSUAVE(), CubaDivonne(), CubaCuhre()] alg_req = Dict(QuadGKJL() => (nout = 1, allows_batch = false, min_dim = 1, max_dim = 1, - allows_iip = false), - HCubatureJL() => (nout = Inf, allows_batch = false, min_dim = 1, - max_dim = Inf, allows_iip = true), - VEGAS() => (nout = 1, allows_batch = true, min_dim = 2, max_dim = Inf, - allows_iip = true), - CubatureJLh() => (nout = Inf, allows_batch = true, min_dim = 1, - max_dim = Inf, allows_iip = true), - CubatureJLp() => (nout = Inf, allows_batch = true, min_dim = 1, - max_dim = Inf, allows_iip = true), - CubaVegas() => (nout = Inf, allows_batch = true, min_dim = 1, max_dim = Inf, - allows_iip = true), - CubaSUAVE() => (nout = Inf, allows_batch = true, min_dim = 1, max_dim = Inf, - allows_iip = true), - CubaDivonne() => (nout = Inf, allows_batch = true, min_dim = 2, - max_dim = Inf, allows_iip = true), - CubaCuhre() => (nout = Inf, allows_batch = true, min_dim = 2, max_dim = Inf, - allows_iip = true)) + allows_iip = false), + HCubatureJL() => (nout = Inf, allows_batch = false, min_dim = 1, + max_dim = Inf, allows_iip = true), + VEGAS() => (nout = 1, allows_batch = true, min_dim = 2, max_dim = Inf, + allows_iip = true), + CubatureJLh() => (nout = Inf, allows_batch = true, min_dim = 1, + max_dim = Inf, allows_iip = true), + CubatureJLp() => (nout = Inf, allows_batch = true, min_dim = 1, + max_dim = Inf, allows_iip = true), + CubaVegas() => (nout = Inf, allows_batch = true, min_dim = 1, max_dim = Inf, + allows_iip = true), + CubaSUAVE() => (nout = Inf, allows_batch = true, min_dim = 1, max_dim = Inf, + allows_iip = true), + CubaDivonne() => (nout = Inf, allows_batch = true, min_dim = 2, + max_dim = Inf, allows_iip = true), + CubaCuhre() => (nout = Inf, allows_batch = true, min_dim = 2, max_dim = Inf, + allows_iip = true)) integrands = [ (x, p) -> 1.0, @@ -36,7 +36,7 @@ integrands = [ iip_integrands = [(dx, x, p) -> (dx .= f(x, p)) for f in integrands] integrands_v = [(x, p, nout) -> collect(1.0:nout) - (x, p, nout) -> integrands[2](x, p) * collect(1.0:nout)] + (x, p, nout) -> integrands[2](x, p) * collect(1.0:nout)] iip_integrands_v = [(dx, x, p, nout) -> (dx .= f(x, p, nout)) for f in integrands_v] exact_sol = [ @@ -219,7 +219,7 @@ end for i in 1:length(integrands_v) for nout in 1:max_nout_test prob = IntegralProblem((x, p) -> integrands_v[i](x, p, nout), lb, ub, - nout = nout) + nout = nout) if req.min_dim > 1 || req.nout < nout continue end @@ -231,58 +231,62 @@ end end end -@testset "Standard Vector Integrands" begin for alg in algs - req = alg_req[alg] - for i in 1:length(integrands_v) - for dim in 1:max_dim_test - lb, ub = (ones(dim), 3ones(dim)) - for nout in 1:max_nout_test - if dim > req.max_dim || dim < req.min_dim || req.nout < nout || - alg isa QuadGKJL || alg isa VEGAS - #QuadGKJL and VEGAS require numbers, not single element arrays - continue - end - prob = IntegralProblem((x, p) -> integrands_v[i](x, p, nout), lb, ub, - nout = nout) - @info "Alg = $alg, Integrand = $i, Dimension = $dim, Output Dimension = $nout" - sol = solve(prob, alg, reltol = reltol, abstol = abstol) - if sol.u isa Number - @test sol.u≈exact_sol_v[i](dim, nout, lb, ub)[1] rtol=1e-2 - else - @test sol.u≈exact_sol_v[i](dim, nout, lb, ub) rtol=1e-2 +@testset "Standard Vector Integrands" begin + for alg in algs + req = alg_req[alg] + for i in 1:length(integrands_v) + for dim in 1:max_dim_test + lb, ub = (ones(dim), 3ones(dim)) + for nout in 1:max_nout_test + if dim > req.max_dim || dim < req.min_dim || req.nout < nout || + alg isa QuadGKJL || alg isa VEGAS + #QuadGKJL and VEGAS require numbers, not single element arrays + continue + end + prob = IntegralProblem((x, p) -> integrands_v[i](x, p, nout), lb, ub, + nout = nout) + @info "Alg = $alg, Integrand = $i, Dimension = $dim, Output Dimension = $nout" + sol = solve(prob, alg, reltol = reltol, abstol = abstol) + if sol.u isa Number + @test sol.u≈exact_sol_v[i](dim, nout, lb, ub)[1] rtol=1e-2 + else + @test sol.u≈exact_sol_v[i](dim, nout, lb, ub) rtol=1e-2 + end end end end end -end end +end -@testset "In-place Standard Vector Integrands" begin for alg in algs - req = alg_req[alg] - for i in 1:length(iip_integrands_v) - for dim in 1:max_dim_test - lb, ub = (ones(dim), 3ones(dim)) - for nout in 1:max_nout_test - prob = IntegralProblem((dx, x, p) -> iip_integrands_v[i](dx, x, p, nout), - lb, ub, nout = nout) - if dim > req.max_dim || dim < req.min_dim || req.nout < nout || - alg isa QuadGKJL #QuadGKJL requires numbers, not single element arrays - continue - end - @info "Alg = $alg, Integrand = $i, Dimension = $dim, Output Dimension = $nout" - if alg isa HCubatureJL && dim == 1 # HCubature library requires finer tol to pass test. When requiring array outputs for iip integrands - sol = solve(prob, alg, reltol = 1e-5, abstol = 1e-5) - else - sol = solve(prob, alg, reltol = reltol, abstol = abstol) - end - if sol.u isa Number - @test sol.u≈exact_sol_v[i](dim, nout, lb, ub)[1] rtol=1e-2 - else - @test sol.u≈exact_sol_v[i](dim, nout, lb, ub) rtol=1e-2 +@testset "In-place Standard Vector Integrands" begin + for alg in algs + req = alg_req[alg] + for i in 1:length(iip_integrands_v) + for dim in 1:max_dim_test + lb, ub = (ones(dim), 3ones(dim)) + for nout in 1:max_nout_test + prob = IntegralProblem((dx, x, p) -> iip_integrands_v[i](dx, x, p, nout), + lb, ub, nout = nout) + if dim > req.max_dim || dim < req.min_dim || req.nout < nout || + alg isa QuadGKJL #QuadGKJL requires numbers, not single element arrays + continue + end + @info "Alg = $alg, Integrand = $i, Dimension = $dim, Output Dimension = $nout" + if alg isa HCubatureJL && dim == 1 # HCubature library requires finer tol to pass test. When requiring array outputs for iip integrands + sol = solve(prob, alg, reltol = 1e-5, abstol = 1e-5) + else + sol = solve(prob, alg, reltol = reltol, abstol = abstol) + end + if sol.u isa Number + @test sol.u≈exact_sol_v[i](dim, nout, lb, ub)[1] rtol=1e-2 + else + @test sol.u≈exact_sol_v[i](dim, nout, lb, ub) rtol=1e-2 + end end end end end -end end +end @testset "Batched Single Dimension Vector Integrands" begin (lb, ub) = (1.0, 3.0) @@ -291,7 +295,7 @@ end end req = alg_req[alg] for i in 1:length(integrands_v) prob = IntegralProblem(batch_f_v(integrands_v[i], nout), lb, ub, batch = 1000, - nout = nout) + nout = nout) if req.min_dim > 1 || !req.allows_batch || req.nout < nout continue end @@ -310,8 +314,8 @@ end for dim in 1:max_dim_test (lb, ub) = (ones(dim), 3ones(dim)) prob = IntegralProblem(batch_f_v(integrands_v[i], nout), lb, ub, - batch = 1000, - nout = nout) + batch = 1000, + nout = nout) if dim > req.max_dim || dim < req.min_dim || !req.allows_batch || req.nout < nout continue @@ -332,7 +336,7 @@ end for dim in 1:max_dim_test (lb, ub) = (ones(dim), 3ones(dim)) prob = IntegralProblem(batch_iip_f_v(integrands_v[i], nout), lb, ub, - batch = 10, nout = nout) + batch = 10, nout = nout) if dim > req.max_dim || dim < req.min_dim || !req.allows_batch || !req.allows_iip || req.nout < nout continue @@ -349,9 +353,9 @@ end f(u, p) = sum(sin.(u)) prob = IntegralProblem(f, ones(3), 3ones(3)) @test_throws Integrals.CommonKwargError((:relztol => 1e-3, :abstol => 1e-3)) solve(prob, - HCubatureJL(); - relztol = 1e-3, - abstol = 1e-3) + HCubatureJL(); + relztol = 1e-3, + abstol = 1e-3) end @testset "Caching interface" begin diff --git a/test/runtests.jl b/test/runtests.jl index d3e30f04..ef990ac6 100644 --- a/test/runtests.jl +++ b/test/runtests.jl @@ -10,7 +10,15 @@ end dev_subpkg("IntegralsCuba") dev_subpkg("IntegralsCubature") -@time @safetestset "Interface Tests" begin include("interface_tests.jl") end -@time @safetestset "Derivative Tests" begin include("derivative_tests.jl") end -@time @safetestset "Infinite Integral Tests" begin include("inf_integral_tests.jl") end -@time @safetestset "Gaussian Quadrature Tests" begin include("gaussian_quadrature_tests.jl") end +@time @safetestset "Interface Tests" begin + include("interface_tests.jl") +end +@time @safetestset "Derivative Tests" begin + include("derivative_tests.jl") +end +@time @safetestset "Infinite Integral Tests" begin + include("inf_integral_tests.jl") +end +@time @safetestset "Gaussian Quadrature Tests" begin + include("gaussian_quadrature_tests.jl") +end From 5026be13fc5ff49e4f8eca11fb77a47d717bbafd Mon Sep 17 00:00:00 2001 From: lxvm Date: Wed, 9 Aug 2023 13:54:47 -0400 Subject: [PATCH 12/13] make cache mutable --- Project.toml | 1 - docs/src/tutorials/caching_interface.md | 20 +++++---- src/Integrals.jl | 1 - src/common.jl | 57 +------------------------ test/interface_tests.jl | 15 +++---- 5 files changed, 17 insertions(+), 77 deletions(-) diff --git a/Project.toml b/Project.toml index 9d9aa4ca..1b739b4f 100644 --- a/Project.toml +++ b/Project.toml @@ -12,7 +12,6 @@ QuadGK = "1fd47b50-473d-5c70-9696-f719f8f3bcdc" Reexport = "189a3867-3050-52da-a836-e630ba90ab69" Requires = "ae029012-a4dd-5104-9daa-d747884805df" SciMLBase = "0bca4576-84f4-4d90-8ffe-ffa030f20462" -Setfield = "efcf1570-3423-57d1-acb7-fd33fddbac46" [compat] ChainRulesCore = "0.10.7, 1" diff --git a/docs/src/tutorials/caching_interface.md b/docs/src/tutorials/caching_interface.md index 9a6a440d..13fe41de 100644 --- a/docs/src/tutorials/caching_interface.md +++ b/docs/src/tutorials/caching_interface.md @@ -1,7 +1,7 @@ # Integrals with Caching Interface Often, integral solvers allocate memory or reuse quadrature rules for solving different -problems. For example, if one is going to perform +problems. For example, if one is going to solve the same integral for several parameters ```julia using Integrals @@ -11,7 +11,7 @@ alg = QuadGKJL() solve(prob, alg) -prob = remake(prob, f = (x, p) -> cos(x * p)) +prob = remake(prob, p = 15.0) solve(prob, alg) ``` @@ -21,14 +21,14 @@ shown below by directly calling the library ```julia using QuadGK segbuf = QuadGK.alloc_segbuf() +quadgk(x -> sin(14x), 0, 1, segbuf = segbuf) quadgk(x -> sin(15x), 0, 1, segbuf = segbuf) -quadgk(x -> cos(15x), 0, 1, segbuf = segbuf) ``` Integrals.jl's caching interface automates this process to reuse resources if an algorithm supports it and if the necessary types to build the cache can be inferred from `prob`. To do -this with Integrals.jl, you simply `init` a cache, `solve`, replace `f`, and solve again. -This looks like +this with Integrals.jl, you simply `init` a cache, `solve!`, replace `p`, and solve again. +This uses the [SciML `init` interface](https://docs.sciml.ai/SciMLBase/stable/interfaces/Init_Solve/#init-and-the-Iterator-Interface) ```@example cache1 using Integrals @@ -41,10 +41,12 @@ sol1 = solve!(cache) ``` ```@example cache1 -cache = Integrals.set_f(cache, (x, p) -> cos(x * p)) +cache.p = 15.0 sol2 = solve!(cache) ``` -Similar cache-rebuilding functions are provided, including: `set_p`, `set_lb`, and `set_ub`, -each of which provides a new value of `lb`, `ub`, or `p`, respectively. When resetting the -cache, new allocations may be needed if those inferred types change. +The caching interface is intended for updating `p`, `lb`, `ub`, `nout`, and `batch`. +Note that the types of these variables is not allowed to change. +If it is necessary to change the integrand `f` instead of defining a new +`IntegralProblem`, consider using +[FunctionWrappers.jl](https://github.com/yuyichao/FunctionWrappers.jl). diff --git a/src/Integrals.jl b/src/Integrals.jl index d1d80fc1..07bb525e 100644 --- a/src/Integrals.jl +++ b/src/Integrals.jl @@ -7,7 +7,6 @@ end using Reexport, MonteCarloIntegration, QuadGK, HCubature @reexport using SciMLBase using LinearAlgebra -using Setfield include("common.jl") include("init.jl") diff --git a/src/common.jl b/src/common.jl index b58defc8..d13a0165 100644 --- a/src/common.jl +++ b/src/common.jl @@ -1,4 +1,4 @@ -struct IntegralCache{iip, F, B, P, PK, A, S, K, Tc} +mutable struct IntegralCache{iip, F, B, P, PK, A, S, K, Tc} iip::Val{iip} f::F lb::B @@ -47,61 +47,6 @@ function SciMLBase.init(prob::IntegralProblem{iip}, cacheval) end -refresh_cacheval(cacheval, alg, prob) = nothing - -""" - set_f(cache, f, [nout=cache.nout]) - -Return a new cache with the new integrand `f`, optionally resetting `nout` at the same time. -""" -function set_f(cache::IntegralCache, f, nout = cache.nout) - prob = remake(build_problem(cache), f = f, nout = nout) - alg = cache.alg - cacheval = cache.cacheval - # lots of type-instability hereafter - @set! cache.f = f - @set! cache.iip = Val(isinplace(f, 3)) - @set! cache.nout = nout - @set! cache.cacheval = refresh_cacheval(cacheval, alg, prob) - return cache -end - -""" - set_lb(cache, lb) - -Return a new cache with new lower limits `lb`. -""" -function set_lb(cache::IntegralCache, lb) - @set! cache.lb = lb - return cache -end - -# since types of lb and ub are constrained, we do not need to refresh cache - -""" - set_ub(cache, ub) - -Return a new cache with new lower limits `ub`. -""" -function set_ub(cache::IntegralCache, ub) - @set! cache.ub = ub - return cache -end - -""" - set_p(cache, p, [refresh=true]) - -Return a new cache with parameters `p`. -""" -function set_p(cache::IntegralCache, p) - prob = remake(build_problem(cache), p = p) - alg = cache.alg - cacheval = cache.cacheval - @set! cache.p = p - @set! cache.cacheval = refresh_cacheval(cacheval, alg, prob) - return cache -end - # Throw error if alg is not provided, as defaults are not implemented. function SciMLBase.solve(::IntegralProblem; kwargs...) checkkwargs(kwargs...) diff --git a/test/interface_tests.jl b/test/interface_tests.jl index 4946dc57..6789f653 100644 --- a/test/interface_tests.jl +++ b/test/interface_tests.jl @@ -360,6 +360,7 @@ end @testset "Caching interface" begin lb, ub = (1.0, 3.0) + p = NaN # the integrands don't actually use this nout = 1 dim = 1 for alg in algs @@ -367,20 +368,14 @@ end continue end for i in 1:length(integrands) - prob = IntegralProblem(integrands[i], lb, ub) + prob = IntegralProblem(integrands[i], lb, ub, p) cache = init(prob, alg, reltol = reltol, abstol = abstol) @test solve!(cache).u≈exact_sol[i](dim, nout, lb, ub) rtol=1e-2 - lb = 0.5 - cache = Integrals.set_lb(cache, lb) - @test solve!(cache).u≈exact_sol[i](dim, nout, lb, ub) rtol=1e-2 - ub = 3.5 - cache = Integrals.set_ub(cache, ub) + cache.lb = lb = 0.5 @test solve!(cache).u≈exact_sol[i](dim, nout, lb, ub) rtol=1e-2 - p = missing # the integrands don't actually use this - cache = Integrals.set_p(cache, p) + cache.ub = ub = 3.5 @test solve!(cache).u≈exact_sol[i](dim, nout, lb, ub) rtol=1e-2 - f = (x, p) -> integrands[i](x, p) # for lack of creativity, wrap the old integrand - cache = Integrals.set_f(cache, f) + cache.p = Inf @test solve!(cache).u≈exact_sol[i](dim, nout, lb, ub) rtol=1e-2 end end From 28a77890ea5df8e741b507cb2dc75f6db0060968 Mon Sep 17 00:00:00 2001 From: lxvm Date: Wed, 9 Aug 2023 14:52:54 -0400 Subject: [PATCH 13/13] fix AD --- ext/IntegralsForwardDiffExt.jl | 12 +++++++++--- ext/IntegralsZygoteExt.jl | 10 ++++++++-- 2 files changed, 17 insertions(+), 5 deletions(-) diff --git a/ext/IntegralsForwardDiffExt.jl b/ext/IntegralsForwardDiffExt.jl index f87b4103..6cfb3254 100644 --- a/ext/IntegralsForwardDiffExt.jl +++ b/ext/IntegralsForwardDiffExt.jl @@ -1,6 +1,5 @@ module IntegralsForwardDiffExt using Integrals -using Integrals: set_f, set_p, build_problem isdefined(Base, :get_extension) ? (using ForwardDiff) : (using ..ForwardDiff) ### Forward-Mode AD Intercepts @@ -68,7 +67,14 @@ function Integrals.__solvebp(cache, alg, sensealg, lb, ub, rawp = copy(reinterpret(V, p)) - dp_cache = set_p(set_f(cache, dfdp, nout), rawp) + prob = Integrals.build_problem(cache) + dp_prob = remake(prob, f = dfdp, nout = nout, p = rawp) + # the infinity transformation was already applied to f so we don't apply it to dfdp + dp_cache = init(dp_prob, + alg; + sensealg = sensealg, + do_inf_transformation = Val(false), + cache.kwargs...) dual = Integrals.__solvebp_call(dp_cache, alg, sensealg, lb, ub, rawp; kwargs...) res = similar(p, cache.nout) @@ -79,6 +85,6 @@ function Integrals.__solvebp(cache, alg, sensealg, lb, ub, if primal.u isa Number res = first(res) end - SciMLBase.build_solution(build_problem(cache), alg, res, primal.resid) + SciMLBase.build_solution(prob, alg, res, primal.resid) end end diff --git a/ext/IntegralsZygoteExt.jl b/ext/IntegralsZygoteExt.jl index 3d3185b2..a5c41870 100644 --- a/ext/IntegralsZygoteExt.jl +++ b/ext/IntegralsZygoteExt.jl @@ -1,6 +1,5 @@ module IntegralsZygoteExt using Integrals -using Integrals: set_f if isdefined(Base, :get_extension) using Zygote import ChainRulesCore @@ -68,7 +67,14 @@ function ChainRulesCore.rrule(::typeof(Integrals.__solvebp), cache, alg, senseal end end - dp_cache = set_f(cache, dfdp, length(p)) + prob = Integrals.build_problem(cache) + dp_prob = remake(prob, f = dfdp, nout = length(p)) + # the infinity transformation was already applied to f so we don't apply it to dfdp + dp_cache = init(dp_prob, + alg; + sensealg = sensealg, + do_inf_transformation = Val(false), + cache.kwargs...) if p isa Number dp = Integrals.__solvebp_call(dp_cache, alg, sensealg, lb, ub, p; kwargs...)[1]