Skip to content

Commit

Permalink
Incorporate #185
Browse files Browse the repository at this point in the history
  • Loading branch information
ChrisRackauckas committed Nov 2, 2023
1 parent 24fce57 commit 88b4ea7
Show file tree
Hide file tree
Showing 2 changed files with 146 additions and 2 deletions.
144 changes: 144 additions & 0 deletions ext/IntegralsCubatureExt.jl
Original file line number Diff line number Diff line change
@@ -0,0 +1,144 @@
module IntegralsCubatureExt

using Integrals, Cubature

import Integrals: transformation_if_inf, scale_x, scale_x!
import Cubature: INDIVIDUAL, PAIRED, L1, L2, LINF

function Integrals.__solvebp_call(prob::IntegralProblem,
alg::AbstractCubatureJLAlgorithm,
sensealg, lb, ub, p;
reltol = 1e-8, abstol = 1e-8,
maxiters = typemax(Int))
nout = prob.nout
if nout == 1
# the output of prob.f could be either scalar or a vector of length 1, however
# the behavior of the output of the integration routine is undefined (could differ
# across algorithms)
# Cubature will output a real number in when called without nout/fdim
if prob.batch == 0
if isinplace(prob)
dx = zeros(eltype(lb), prob.nout)
f = (x) -> (prob.f(dx, x, p); dx[1])
else
f = (x) -> prob.f(x, p)[1]
end
if lb isa Number
if alg isa CubatureJLh
val, err = Cubature.hquadrature(f, lb, ub;
reltol = reltol, abstol = abstol,
maxevals = maxiters)
else
val, err = Cubature.pquadrature(f, lb, ub;
reltol = reltol, abstol = abstol,
maxevals = maxiters)
end
else
if alg isa CubatureJLh
val, err = Cubature.hcubature(f, lb, ub;
reltol = reltol, abstol = abstol,
maxevals = maxiters)
else
val, err = Cubature.pcubature(f, lb, ub;
reltol = reltol, abstol = abstol,
maxevals = maxiters)
end
end
else
if isinplace(prob)
f = (x, dx) -> prob.f(dx, x, p)
else
f = (x, dx) -> (dx .= prob.f(x, p))
end
if lb isa Number
if alg isa CubatureJLh
val, err = Cubature.hquadrature_v(f, lb, ub;
reltol = reltol, abstol = abstol,
maxevals = maxiters)
else
val, err = Cubature.pquadrature_v(f, lb, ub;
reltol = reltol, abstol = abstol,
maxevals = maxiters)
end
else
if alg isa CubatureJLh
val, err = Cubature.hcubature_v(f, lb, ub;
reltol = reltol, abstol = abstol,
maxevals = maxiters)
else
val, err = Cubature.pcubature_v(f, lb, ub;
reltol = reltol, abstol = abstol,
maxevals = maxiters)
end
end
end
else
if prob.batch == 0
if isinplace(prob)
f = (x, dx) -> (prob.f(dx, x, p); dx)
else
f = (x, dx) -> (dx .= prob.f(x, p))
end
if lb isa Number
if alg isa CubatureJLh
val, err = Cubature.hquadrature(nout, f, lb, ub;
reltol = reltol, abstol = abstol,
maxevals = maxiters,
error_norm = alg.error_norm)
else
val, err = Cubature.pquadrature(nout, f, lb, ub;
reltol = reltol, abstol = abstol,
maxevals = maxiters,
error_norm = alg.error_norm)
end
else
if alg isa CubatureJLh
val, err = Cubature.hcubature(nout, f, lb, ub;
reltol = reltol, abstol = abstol,
maxevals = maxiters,
error_norm = alg.error_norm)
else
val, err = Cubature.pcubature(nout, f, lb, ub;
reltol = reltol, abstol = abstol,
maxevals = maxiters,
error_norm = alg.error_norm)
end
end
else
if isinplace(prob)
f = (x, dx) -> (prob.f(dx, x, p); dx)
else
f = (x, dx) -> (dx .= prob.f(x, p))
end

if lb isa Number
if alg isa CubatureJLh
val, err = Cubature.hquadrature_v(nout, f, lb, ub;
reltol = reltol, abstol = abstol,
maxevals = maxiters,
error_norm = alg.error_norm)
else
val, err = Cubature.pquadrature_v(nout, f, lb, ub;
reltol = reltol, abstol = abstol,
maxevals = maxiters,
error_norm = alg.error_norm)
end
else
if alg isa CubatureJLh
val, err = Cubature.hcubature_v(nout, f, lb, ub;
reltol = reltol, abstol = abstol,
maxevals = maxiters,
error_norm = alg.error_norm)
else
val, err = Cubature.pcubature_v(nout, f, lb, ub;
reltol = reltol, abstol = abstol,
maxevals = maxiters,
error_norm = alg.error_norm)
end
end
end
end
SciMLBase.build_solution(prob, alg, val, err, retcode = ReturnCode.Success)
end

end
4 changes: 2 additions & 2 deletions src/algorithms.jl
Original file line number Diff line number Diff line change
Expand Up @@ -324,7 +324,7 @@ publisher={Elsevier}
struct CubatureJLh <: AbstractCubatureJLAlgorithm
error_norm::Int32
end
CubatureJLh() = CubatureJLh(Cubature.INDIVIDUAL)
CubatureJLh(; error_norm = Cubature.INDIVIDUAL) = CubatureJLh(error_norm)

"""
CubatureJLp()
Expand All @@ -340,4 +340,4 @@ Defaults to `Cubature.INDIVIDUAL`, other options are
struct CubatureJLp <: AbstractCubatureJLAlgorithm
error_norm::Int32
end
CubatureJLp() = CubatureJLp(Cubature.INDIVIDUAL)
CubatureJLp(; error_norm = Cubature.INDIVIDUAL) = CubatureJLp(error_norm)

0 comments on commit 88b4ea7

Please sign in to comment.