Skip to content

Commit

Permalink
Merge pull request #235 from lxvm/algs
Browse files Browse the repository at this point in the history
pass the algorithm to the solution
  • Loading branch information
ChrisRackauckas authored Feb 16, 2024
2 parents da1a166 + ec85909 commit a4476c0
Show file tree
Hide file tree
Showing 3 changed files with 18 additions and 6 deletions.
2 changes: 1 addition & 1 deletion ext/IntegralsMCIntegrationExt.jl
Original file line number Diff line number Diff line change
Expand Up @@ -40,7 +40,7 @@ function Integrals.__solvebp_call(prob::IntegralProblem, alg::VEGASMC, sensealg,
else
map(a -> reshape(a, size(f0)), (res.mean, res.stdev, res.chi2))
end
SciMLBase.build_solution(prob, VEGASMC(), out, err, chi=chi, retcode = ReturnCode.Success)
SciMLBase.build_solution(prob, alg, out, err, chi=chi, retcode = ReturnCode.Success)
end
end

Expand Down
10 changes: 5 additions & 5 deletions src/Integrals.jl
Original file line number Diff line number Diff line change
Expand Up @@ -112,7 +112,7 @@ function __solvebp_call(cache::IntegralCache, alg::QuadGKJL, sensealg, domain, p
end
val, err = quadgk(f, lb, ub, segbuf = cache.cacheval, maxevals = maxiters,
rtol = reltol, atol = abstol, order = alg.order, norm = alg.norm)
SciMLBase.build_solution(prob, QuadGKJL(), val, err, retcode = ReturnCode.Success)
SciMLBase.build_solution(prob, alg, val, err, retcode = ReturnCode.Success)
else
u = prob.f(typeof(mid)[], p)
f = if u isa AbstractVector
Expand All @@ -125,20 +125,20 @@ function __solvebp_call(cache::IntegralCache, alg::QuadGKJL, sensealg, domain, p
end
val, err = quadgk(f, lb, ub, segbuf = cache.cacheval, maxevals = maxiters,
rtol = reltol, atol = abstol, order = alg.order, norm = alg.norm)
SciMLBase.build_solution(prob, QuadGKJL(), val, err, retcode = ReturnCode.Success)
SciMLBase.build_solution(prob, alg, val, err, retcode = ReturnCode.Success)
end
else
if isinplace(prob)
result = prob.f.integrand_prototype * mid # result may have different units than prototype
f = (y, x) -> prob.f(y, x, p)
val, err = quadgk!(f, result, lb, ub, segbuf = cache.cacheval, maxevals = maxiters,
rtol = reltol, atol = abstol, order = alg.order, norm = alg.norm)
SciMLBase.build_solution(prob, QuadGKJL(), val, err, retcode = ReturnCode.Success)
SciMLBase.build_solution(prob, alg, val, err, retcode = ReturnCode.Success)
else
f = x -> prob.f(x, p)
val, err = quadgk(f, lb, ub, segbuf = cache.cacheval, maxevals = maxiters,
rtol = reltol, atol = abstol, order = alg.order, norm = alg.norm)
SciMLBase.build_solution(prob, QuadGKJL(), val, err, retcode = ReturnCode.Success)
SciMLBase.build_solution(prob, alg, val, err, retcode = ReturnCode.Success)
end
end
end
Expand Down Expand Up @@ -166,7 +166,7 @@ function __solvebp_call(prob::IntegralProblem, alg::HCubatureJL, sensealg, domai
rtol = reltol, atol = abstol,
maxevals = maxiters, norm = alg.norm, initdiv = alg.initdiv)
end
SciMLBase.build_solution(prob, HCubatureJL(), val, err, retcode = ReturnCode.Success)
SciMLBase.build_solution(prob, alg, val, err, retcode = ReturnCode.Success)
end

function __solvebp_call(prob::IntegralProblem, alg::VEGAS, sensealg, domain, p;
Expand Down
12 changes: 12 additions & 0 deletions test/interface_tests.jl
Original file line number Diff line number Diff line change
Expand Up @@ -102,6 +102,7 @@ end
prob = IntegralProblem(integrands[i], lb, ub)
@info "Alg = $(nameof(typeof(alg))), Integrand = $i, Dimension = $dim, Output Dimension = $nout"
sol = solve(prob, alg, reltol = reltol, abstol = abstol)
@test sol.alg == alg
@test sol.uexact_sol[i](dim, nout, lb, ub) rtol=1e-2
end
end
Expand All @@ -119,6 +120,7 @@ end
end
@info "Alg = $(nameof(typeof(alg))), Integrand = $i, Dimension = $dim, Output Dimension = $nout"
sol = solve(prob, alg, reltol = reltol, abstol = abstol)
@test sol.alg == alg
@test sol.uexact_sol[i](dim, nout, lb, ub) rtol=1e-2
end
end
Expand All @@ -137,6 +139,7 @@ end
end
@info "Alg = $(nameof(typeof(alg))), Integrand = $i, Dimension = $dim, Output Dimension = $nout"
sol = solve(prob, alg, reltol = reltol, abstol = abstol)
@test sol.alg == alg
if sol.u isa Number
@test sol.uexact_sol[i](dim, nout, lb, ub) rtol=1e-2
else
Expand All @@ -158,6 +161,7 @@ end
end
@info "Alg = $(nameof(typeof(alg))), Integrand = $i, Dimension = $dim, Output Dimension = $nout"
sol = solve(prob, alg, reltol = reltol, abstol = abstol)
@test sol.alg == alg
@test sol.u[1]exact_sol[i](dim, nout, lb, ub) rtol=1e-2
end
end
Expand All @@ -175,6 +179,7 @@ end
end
@info "Alg = $(nameof(typeof(alg))), Integrand = $i, Dimension = $dim, Output Dimension = $nout"
sol = solve(prob, alg, reltol = reltol, abstol = abstol)
@test sol.alg == alg
if sol.u isa Number
@test sol.uexact_sol[i](dim, nout, lb, ub) rtol=1e-2
else
Expand All @@ -198,6 +203,7 @@ end
end
@info "Alg = $(nameof(typeof(alg))), Integrand = $i, Dimension = $dim, Output Dimension = $nout"
sol = solve(prob, alg, reltol = reltol, abstol = abstol)
@test sol.alg == alg
if sol.u isa Number
@test sol.uexact_sol[i](dim, nout, lb, ub) rtol=1e-2
else
Expand All @@ -222,6 +228,7 @@ end
end
@info "Alg = $(nameof(typeof(alg))), Integrand = $i, Dimension = $dim, Output Dimension = $nout"
sol = solve(prob, alg, reltol = reltol, abstol = abstol)
@test sol.alg == alg
if nout == 1
@test sol.u[1]exact_sol_v[i](dim, nout, lb, ub)[1] rtol=1e-2
else
Expand All @@ -245,6 +252,7 @@ end
nout = nout)
@info "Alg = $(nameof(typeof(alg))), Integrand = $i, Dimension = $dim, Output Dimension = $nout"
sol = solve(prob, alg, reltol = reltol, abstol = abstol)
@test sol.alg == alg
if nout == 1
@test sol.u[1]exact_sol_v[i](dim, nout, lb, ub)[1] rtol=1e-2
else
Expand All @@ -270,6 +278,7 @@ end
end
@info "Alg = $(nameof(typeof(alg))), Integrand = $i, Dimension = $dim, Output Dimension = $nout"
sol = solve(prob, alg, reltol = reltol, abstol = abstol)
@test sol.alg == alg
if nout == 1
@test sol.u[1]exact_sol_v[i](dim, nout, lb, ub)[1] rtol=1e-2
else
Expand All @@ -293,6 +302,7 @@ end
end
@info "Alg = $(nameof(typeof(alg))), Integrand = $i, Dimension = $dim, Output Dimension = $nout"
sol = solve(prob, alg, reltol = reltol, abstol = abstol)
@test sol.alg == alg
@test sol.uexact_sol_v[i](dim, nout, lb, ub) rtol=1e-2
end
end
Expand All @@ -313,6 +323,7 @@ end
end
@info "Alg = $(nameof(typeof(alg))), Integrand = $i, Dimension = $dim, Output Dimension = $nout"
sol = solve(prob, alg, reltol = reltol, abstol = abstol)
@test sol.alg == alg
@test sol.uexact_sol_v[i](dim, nout, lb, ub) rtol=1e-2
end
end
Expand All @@ -333,6 +344,7 @@ end
end
@info "Alg = $(nameof(typeof(alg))), Integrand = $i, Dimension = $dim, Output Dimension = $nout"
sol = solve(prob, alg, reltol = reltol, abstol = abstol)
@test sol.alg == alg
@test sol.uexact_sol_v[i](dim, nout, lb, ub) rtol=1e-2
end
end
Expand Down

0 comments on commit a4476c0

Please sign in to comment.