Skip to content

Commit

Permalink
Merge pull request #195 from lxvm/mcinterface
Browse files Browse the repository at this point in the history
Fix VEGAS interface
  • Loading branch information
ChrisRackauckas authored Nov 4, 2023
2 parents 893ab8b + aa15339 commit 6cea3a8
Show file tree
Hide file tree
Showing 2 changed files with 11 additions and 5 deletions.
14 changes: 10 additions & 4 deletions src/Integrals.jl
Original file line number Diff line number Diff line change
Expand Up @@ -127,20 +127,25 @@ end

function __solvebp_call(prob::IntegralProblem, alg::VEGAS, sensealg, domain, p;
reltol = 1e-8, abstol = 1e-8,
maxiters = typemax(Int))
maxiters = 1000)
lb, ub = domain
mid = (lb + ub) / 2
if prob.f isa BatchIntegralFunction
if isinplace(prob)
y = similar(prob.f.integrand_prototype,
size(prob.f.integrand_prototype)[begin:(end - 1)]...,
prob.f.max_batch)
f = x -> (prob.f(y, x', p); vec(y))
# MonteCarloIntegration v0.0.x passes points as rows of a matrix
# MonteCarloIntegration v0.1 passes batches as a vector of views of
# a matrix with points as columns of a matrix
# see https://github.com/ranjanan/MonteCarloIntegration.jl/issues/16
# This is an ugly hack that is compatible with both
f = x -> (prob.f(y, eltype(x) <: SubArray ? parent(first(x)) : x', p); vec(y))
else
y = prob.f(mid isa Number ? typeof(mid)[] :
Matrix{eltype(mid)}(undef, length(mid), 0),
p)
f = x -> prob.f(x', p)
f = x -> prob.f(eltype(x) <: SubArray ? parent(first(x)) : x', p)
end
else
if isinplace(prob)
Expand All @@ -161,9 +166,10 @@ function __solvebp_call(prob::IntegralProblem, alg::VEGAS, sensealg, domain, p;
end

ncalls = prob.f isa BatchIntegralFunction ? prob.f.max_batch : alg.ncalls
val, err, chi = vegas(f, lb, ub, rtol = reltol, atol = abstol,
out = vegas(f, lb, ub, rtol = reltol, atol = abstol,
maxiter = maxiters, nbins = alg.nbins, debug = alg.debug,
ncalls = ncalls, batch = prob.f isa BatchIntegralFunction)
val, err, chi = out isa Tuple ? out : (out.integral_estimate, out.standard_deviation, out.chi_squared_average)
SciMLBase.build_solution(prob, alg, val, err, chi = chi, retcode = ReturnCode.Success)
end

Expand Down
2 changes: 1 addition & 1 deletion test/interface_tests.jl
Original file line number Diff line number Diff line change
Expand Up @@ -7,7 +7,7 @@ max_nout_test = 2
reltol = 1e-3
abstol = 1e-3

algs = [QuadGKJL(), HCubatureJL(), CubatureJLh(), CubatureJLp(), #VEGAS(), #CubaVegas(),
algs = [QuadGKJL(), HCubatureJL(), CubatureJLh(), CubatureJLp(), VEGAS(), #CubaVegas(),
CubaSUAVE(), CubaDivonne(), CubaCuhre()]

alg_req = Dict(QuadGKJL() => (nout = 1, allows_batch = false, min_dim = 1, max_dim = 1,
Expand Down

0 comments on commit 6cea3a8

Please sign in to comment.