Skip to content

Commit

Permalink
Merge pull request #200 from lxvm/issue14
Browse files Browse the repository at this point in the history
add inplace and batched quadgk
  • Loading branch information
ChrisRackauckas authored Nov 20, 2023
2 parents b91f3da + 378c3f2 commit 747e507
Show file tree
Hide file tree
Showing 3 changed files with 69 additions and 34 deletions.
2 changes: 1 addition & 1 deletion Project.toml
Original file line number Diff line number Diff line change
Expand Up @@ -42,7 +42,7 @@ ForwardDiff = "0.10"
HCubature = "1.4"
LinearAlgebra = "1.9"
MonteCarloIntegration = "0.0.1, 0.0.2, 0.0.3, 0.1"
QuadGK = "2.5"
QuadGK = "2.9"
Reexport = "0.2, 1.0"
Requires = "1"
SciMLBase = "2.6"
Expand Down
64 changes: 54 additions & 10 deletions src/Integrals.jl
Original file line number Diff line number Diff line change
Expand Up @@ -71,12 +71,12 @@ function quadgk_prob_types(f, lb::T, ub::T, p, nrm) where {T}
return DT, RT, NT
end
function init_cacheval(alg::QuadGKJL, prob::IntegralProblem)
lb, ub = prob.domain
lb, ub = map(first, prob.domain)
DT, RT, NT = quadgk_prob_types(prob.f, lb, ub, prob.p, alg.norm)
return (isconcretetype(RT) ? QuadGK.alloc_segbuf(DT, RT, NT) : nothing)
end
function refresh_cacheval(cacheval, alg::QuadGKJL, prob)
lb, ub = prob.domain
lb, ub = map(first, prob.domain)
DT, RT, NT = quadgk_prob_types(prob.f, lb, ub, prob.p, alg.norm)
isconcretetype(RT) || return nothing
T = QuadGK.Segment{DT, RT, NT}
Expand All @@ -87,16 +87,60 @@ function __solvebp_call(cache::IntegralCache, alg::QuadGKJL, sensealg, domain, p
reltol = 1e-8, abstol = 1e-8,
maxiters = typemax(Int))
prob = build_problem(cache)
lb, ub = domain
if isinplace(prob) || lb isa AbstractArray || ub isa AbstractArray
lb_, ub_ = domain
lb, ub = map(first, domain)
if !isone(length(lb_)) || !isone(length(ub_))
error("QuadGKJL only accepts one-dimensional quadrature problems.")
end
@assert prob.f isa IntegralFunction

f = x -> prob.f(x, p)
val, err = quadgk(f, lb, ub, segbuf = cache.cacheval, maxevals = maxiters,
rtol = reltol, atol = abstol, order = alg.order, norm = alg.norm)
SciMLBase.build_solution(prob, QuadGKJL(), val, err, retcode = ReturnCode.Success)
mid = ((lb + ub) / 2)
if prob.f isa BatchIntegralFunction
if isinplace(prob)
# quadgk only works with vector buffers. If the buffer is an array, we have to
# turn it into a vector of arrays
u = prob.f.integrand_prototype
f = if u isa AbstractVector
BatchIntegrand((y, x) -> prob.f(y, x, p), similar(u))
else
fsize = size(u)[begin:(end - 1)]
BatchIntegrand{Array{eltype(u),ndims(u)-1}}() do y, x
y_ = similar(u, fsize..., length(y))
prob.f(y_, x, p)
map!(collect, y, eachslice(y_; dims=ndims(u)))
return nothing
end
end
val, err = quadgk(f, lb, ub, segbuf = cache.cacheval, maxevals = maxiters,
rtol = reltol, atol = abstol, order = alg.order, norm = alg.norm)
SciMLBase.build_solution(prob, QuadGKJL(), val, err, retcode = ReturnCode.Success)
else
u = prob.f(typeof(mid)[], p)
f = if u isa AbstractVector
BatchIntegrand((y, x) -> y .= prob.f(x, p), u)
else
BatchIntegrand{Array{eltype(u),ndims(u)-1}}() do y, x
map!(collect, y, eachslice(prob.f(x, p); dims=ndims(u)))
return nothing
end
end
val, err = quadgk(f, lb, ub, segbuf = cache.cacheval, maxevals = maxiters,
rtol = reltol, atol = abstol, order = alg.order, norm = alg.norm)
SciMLBase.build_solution(prob, QuadGKJL(), val, err, retcode = ReturnCode.Success)
end
else
if isinplace(prob)
result = prob.f.integrand_prototype * mid # result may have different units than prototype
f = (y, x) -> prob.f(y, x, p)
val, err = quadgk!(f, result, lb, ub, segbuf = cache.cacheval, maxevals = maxiters,
rtol = reltol, atol = abstol, order = alg.order, norm = alg.norm)
SciMLBase.build_solution(prob, QuadGKJL(), val, err, retcode = ReturnCode.Success)
else
f = x -> prob.f(x, p)
val, err = quadgk(f, lb, ub, segbuf = cache.cacheval, maxevals = maxiters,
rtol = reltol, atol = abstol, order = alg.order, norm = alg.norm)
SciMLBase.build_solution(prob, QuadGKJL(), val, err, retcode = ReturnCode.Success)
end
end
end

function __solvebp_call(prob::IntegralProblem, alg::HCubatureJL, sensealg, domain, p;
Expand Down Expand Up @@ -155,7 +199,7 @@ function __solvebp_call(prob::IntegralProblem, alg::VEGAS, sensealg, domain, p;
f = x -> (prob.f(y, x, p); only(y))
else
y = prob.f(mid, p)
f = x -> prob.f(x, prob.p)
f = x -> only(prob.f(x, prob.p))
end
end

Expand Down
37 changes: 14 additions & 23 deletions test/interface_tests.jl
Original file line number Diff line number Diff line change
Expand Up @@ -7,11 +7,12 @@ max_nout_test = 2
reltol = 1e-3
abstol = 1e-3

algs = [QuadGKJL, HCubatureJL, CubatureJLh, CubatureJLp, VEGAS, #CubaVegas,
CubaSUAVE, CubaDivonne, CubaCuhre, ArblibJL]

alg_req = Dict(QuadGKJL => (nout = 1, allows_batch = false, min_dim = 1, max_dim = 1,
allows_iip = false),
algs = [QuadGKJL, HCubatureJL, CubatureJLh, CubatureJLp, #VEGAS, #CubaVegas,
CubaSUAVE, CubaDivonne, CubaCuhre]

alg_req = Dict(QuadGKJL => (nout = 1, allows_batch = true, min_dim = 1, max_dim = 1,
allows_iip = true),
HCubatureJL => (nout = Inf, allows_batch = false, min_dim = 1,
max_dim = Inf, allows_iip = true),
VEGAS => (nout = 1, allows_batch = true, min_dim = 2, max_dim = Inf,
Expand Down Expand Up @@ -59,10 +60,10 @@ batch_f(f) = (pts, p) -> begin
fevals
end

# TODO ? check if pts is a vector or matrix
batch_iip_f(f) = (fevals, pts, p) -> begin
for i in 1:size(pts, 2)
x = pts[:, i]
ax = axes(pts)
for i in ax[end]
x = pts[ax[begin:(end-1)]..., i]
fevals[i] = f(x, p)
end
nothing
Expand Down Expand Up @@ -111,7 +112,7 @@ end
for dim in 1:max_dim_test
lb, ub = (ones(dim), 3ones(dim))
prob = IntegralProblem(integrands[i], lb, ub)
if dim > req.max_dim || dim < req.min_dim || alg() isa QuadGKJL #QuadGKJL requires numbers, not single element arrays
if dim > req.max_dim || dim < req.min_dim
continue
end
@info "Alg = $alg, Integrand = $i, Dimension = $dim, Output Dimension = $nout"
Expand All @@ -130,15 +131,11 @@ end
for dim in 1:max_dim_test
lb, ub = (ones(dim), 3ones(dim))
prob = IntegralProblem(iip_integrands[i], lb, ub)
if dim > req.max_dim || dim < req.min_dim || alg() isa QuadGKJL #QuadGKJL requires numbers, not single element arrays
if dim > req.max_dim || dim < req.min_dim
continue
end
@info "Alg = $alg, Integrand = $i, Dimension = $dim, Output Dimension = $nout"
if alg() isa HCubatureJL && dim == 1 # HCubature library requires finer tol to pass test. When requiring array outputs for iip integrands
sol = solve(prob, alg(), reltol = 1e-5, abstol = 1e-5)
else
sol = solve(prob, alg(), reltol = reltol, abstol = abstol)
end
sol = solve(prob, alg(), reltol = reltol, abstol = abstol)
if sol.u isa Number
@test sol.uexact_sol[i](dim, nout, lb, ub) rtol=1e-2
else
Expand Down Expand Up @@ -246,8 +243,7 @@ end
lb, ub = (ones(dim), 3ones(dim))
for nout in 1:max_nout_test
if dim > req.max_dim || dim < req.min_dim || req.nout < nout ||
alg() isa QuadGKJL || alg() isa VEGAS
#QuadGKJL and VEGAS require numbers, not single element arrays
alg() isa VEGAS # broken for integrand 2 due to sign problem?
continue
end
prob = IntegralProblem((x, p) -> integrands_v[i](x, p, nout), lb, ub,
Expand All @@ -274,16 +270,11 @@ end
for nout in 1:max_nout_test
prob = IntegralProblem((dx, x, p) -> iip_integrands_v[i](dx, x, p, nout),
lb, ub, nout = nout)
if dim > req.max_dim || dim < req.min_dim || req.nout < nout ||
alg() isa QuadGKJL #QuadGKJL requires numbers, not single element arrays
if dim > req.max_dim || dim < req.min_dim || req.nout < nout
continue
end
@info "Alg = $alg, Integrand = $i, Dimension = $dim, Output Dimension = $nout"
if alg isa HCubatureJL && dim == 1 # HCubature library requires finer tol to pass test. When requiring array outputs for iip integrands
sol = solve(prob, alg(), reltol = 1e-5, abstol = 1e-5)
else
sol = solve(prob, alg(), reltol = reltol, abstol = abstol)
end
sol = solve(prob, alg(), reltol = reltol, abstol = abstol)
if nout == 1
@test sol.u[1]exact_sol_v[i](dim, nout, lb, ub)[1] rtol=1e-2
else
Expand Down

0 comments on commit 747e507

Please sign in to comment.