From f96c64cd9d6b234b405cf4c9e7fd32baac91fb79 Mon Sep 17 00:00:00 2001 From: lxvm Date: Sat, 30 Dec 2023 10:50:03 -0800 Subject: [PATCH] move vegasmc to ext --- Project.toml | 6 ++-- ext/IntegralsMCIntegrationExt.jl | 48 ++++++++++++++++++++++++++++++++ src/Integrals.jl | 45 +----------------------------- src/algorithms.jl | 35 +++++++++++------------ test/interface_tests.jl | 2 +- 5 files changed, 71 insertions(+), 65 deletions(-) create mode 100644 ext/IntegralsMCIntegrationExt.jl diff --git a/Project.toml b/Project.toml index 0d4f8e50..0d5a58e9 100644 --- a/Project.toml +++ b/Project.toml @@ -7,7 +7,6 @@ version = "4.1.0" CommonSolve = "38540f10-b2f7-11e9-35d8-d573e4eb0ff2" HCubature = "19dc6840-f33b-545b-b366-655c7e3ffd49" LinearAlgebra = "37e2e46d-f89d-539d-b4ee-838fcccc9c8e" -MCIntegration = "ea1e2de9-7db7-4b42-91ee-0cd1bf6df167" MonteCarloIntegration = "4886b29c-78c9-11e9-0a6e-41e1f4161f7b" QuadGK = "1fd47b50-473d-5c70-9696-f719f8f3bcdc" Reexport = "189a3867-3050-52da-a836-e630ba90ab69" @@ -20,6 +19,7 @@ Cuba = "8a292aeb-7a57-582c-b821-06e4c11590b1" Cubature = "667455a9-e2ce-5579-9412-b964f529a492" FastGaussQuadrature = "442a2c76-b920-505d-bb47-c5924d526838" ForwardDiff = "f6369f11-7733-5829-9624-2563aa707210" +MCIntegration = "ea1e2de9-7db7-4b42-91ee-0cd1bf6df167" Zygote = "e88e6eb3-aa80-5325-afca-941959d7151f" [extensions] @@ -28,6 +28,7 @@ IntegralsCubaExt = "Cuba" IntegralsCubatureExt = "Cubature" IntegralsFastGaussQuadratureExt = "FastGaussQuadrature" IntegralsForwardDiffExt = "ForwardDiff" +IntegralsMCIntegrationExt = "MCIntegration" IntegralsZygoteExt = ["Zygote", "ChainRulesCore"] [compat] @@ -65,6 +66,7 @@ Distributions = "31c24e10-a181-5473-b8eb-7969acd0382f" FastGaussQuadrature = "442a2c76-b920-505d-bb47-c5924d526838" FiniteDiff = "6a86dc24-6348-571c-b903-95158fe2bd41" ForwardDiff = "f6369f11-7733-5829-9624-2563aa707210" +MCIntegration = "ea1e2de9-7db7-4b42-91ee-0cd1bf6df167" Pkg = "44cfe95a-1eb2-52ea-b672-e2afdf69b78f" SafeTestsets = "1bc83da4-3b8d-516f-aca4-4fe02f6d838f" StaticArrays = "90137ffa-7385-5640-81b9-e52037218182" @@ -72,4 +74,4 @@ Test = "8dfed614-e22c-5e08-85e1-65c5234f0b40" Zygote = "e88e6eb3-aa80-5325-afca-941959d7151f" [targets] -test = ["Aqua", "Arblib", "StaticArrays", "FiniteDiff", "Pkg", "SafeTestsets", "Test", "Distributions", "ForwardDiff", "Zygote", "ChainRulesCore", "FastGaussQuadrature", "Cuba", "Cubature"] +test = ["Aqua", "Arblib", "StaticArrays", "FiniteDiff", "Pkg", "SafeTestsets", "Test", "Distributions", "ForwardDiff", "Zygote", "ChainRulesCore", "FastGaussQuadrature", "Cuba", "Cubature", "MCIntegration"] diff --git a/ext/IntegralsMCIntegrationExt.jl b/ext/IntegralsMCIntegrationExt.jl new file mode 100644 index 00000000..e14233d7 --- /dev/null +++ b/ext/IntegralsMCIntegrationExt.jl @@ -0,0 +1,48 @@ +module IntegralsMCIntegrationExt + +using MCIntegration, Integrals + +function Integrals.__solvebp_call(prob::IntegralProblem, alg::VEGASMC, sensealg, domain, p; + reltol = nothing, abstol = nothing, maxiters = 1000) + lb, ub = domain + mid = vec(collect((lb + ub) / 2)) + vars = Continuous(vec([tuple(a,b) for (a,b) in zip(lb, ub)])) + + if prob.f isa BatchIntegralFunction + error("VEGASMC doesn't support batching. See https://github.com/numericalEFT/MCIntegration.jl/issues/29") + else + if isinplace(prob) + f0 = similar(prob.f.integrand_prototype) + f_ = (x, f, c) -> begin + n = 0 + for v in x + mid[n+=1] = first(v) + end + prob.f(f0, mid, p) + f .= vec(f0) + end + else + f0 = prob.f(mid, p) + f_ = (x, c) -> begin + n = 0 + for v in x + mid[n+=1] = first(v) + end + fx = prob.f(mid, p) + fx isa AbstractArray ? vec(fx) : fx + end + end + dof = ones(Int, length(f0)) # each composite Continuous var gets 1 dof + res = integrate(f_, inplace=isinplace(prob), var=vars, dof=dof, solver=:vegasmc, + neval=alg.neval, niter=min(maxiters,alg.niter), block=alg.block, adapt=alg.adapt, + gamma=alg.gamma, verbose=alg.verbose, debug=alg.debug, type=eltype(f0), print=-2) + out, err, chi = if f0 isa Number + map(only, (res.mean, res.stdev, res.chi2)) + else + map(a -> reshape(a, size(f0)), (res.mean, res.stdev, res.chi2)) + end + SciMLBase.build_solution(prob, VEGASMC(), out, err, chi=chi, retcode = ReturnCode.Success) + end +end + +end diff --git a/src/Integrals.jl b/src/Integrals.jl index 25cf08fe..4f22aa36 100644 --- a/src/Integrals.jl +++ b/src/Integrals.jl @@ -4,7 +4,7 @@ if !isdefined(Base, :get_extension) using Requires end -using Reexport, MonteCarloIntegration, QuadGK, HCubature, MCIntegration +using Reexport, MonteCarloIntegration, QuadGK, HCubature @reexport using SciMLBase using LinearAlgebra @@ -218,49 +218,6 @@ function __solvebp_call(prob::IntegralProblem, alg::VEGAS, sensealg, domain, p; end -function __solvebp_call(prob::IntegralProblem, alg::VEGASMC, sensealg, domain, p; - reltol = nothing, abstol = nothing, maxiters = 1000) - lb, ub = domain - mid = collect((lb + ub) / 2) - vars = Continuous(vec([tuple(a,b) for (a,b) in zip(lb, ub)])) - - if prob.f isa BatchIntegralFunction - error("VEGASMC doesn't support batching. See https://github.com/numericalEFT/MCIntegration.jl/issues/29") - else - if isinplace(prob) - f0 = similar(prob.f.integrand_prototype) - f_ = (x, f, c) -> begin - n = 0 - for v in x - mid[n+=1] = first(v) - end - prob.f(f0, mid, p) - f .= vec(f0) - end - else - f0 = prob.f(mid, p) - f_ = (x, c) -> begin - n = 0 - for v in x - mid[n+=1] = first(v) - end - prob.f(mid, p) - end - end - dof = f0 isa Number ? 1 : ones(Int, length(f0)) - res = integrate(f_, inplace=isinplace(prob), var=vars, dof=dof, solver=:vegasmc, - neval=alg.neval, niter=min(maxiters,alg.niter), block=alg.block, adapt=alg.adapt, - gamma=alg.gamma, verbose=alg.verbose, debug=alg.debug) - out, err, chi = if f0 isa Number - map(only, (res.mean, res.stdev, res.chi2)) - else - map(a -> reshape(a, size(f0)), (res.mean, res.stdev, res.chi2)) - end - SciMLBase.build_solution(prob, VEGASMC(), out, err, chi=chi, retcode = ReturnCode.Success) - end -end - - export QuadGKJL, HCubatureJL, VEGAS, VEGASMC, GaussLegendre, QuadratureRule, TrapezoidalRule export CubaVegas, CubaSUAVE, CubaDivonne, CubaCuhre export CubatureJLh, CubatureJLp diff --git a/src/algorithms.jl b/src/algorithms.jl index c294fff6..972c4641 100644 --- a/src/algorithms.jl +++ b/src/algorithms.jl @@ -87,24 +87,6 @@ struct VEGAS <: SciMLBase.AbstractIntegralAlgorithm end VEGAS(; nbins = 100, ncalls = 1000, debug = false) = VEGAS(nbins, ncalls, debug) - -""" - VEGASMC() - -Markov-chain based Vegas algorithm from MCIntegration.jl -""" -struct VEGASMC <: SciMLBase.AbstractIntegralAlgorithm - neval::Int - niter::Int - block::Int - adapt::Bool - gamma::Float64 - verbose::Int - debug::Bool -end -VEGASMC(; neval=10^4, niter=20, block=16, adapt=true, gamma=1.0, verbose=-2, debug=false) = - VEGASMC(neval, niter, block, adapt, gamma, verbose, debug) - """ GaussLegendre{C, N, W} @@ -388,3 +370,20 @@ end function ArblibJL(; check_analytic=false, take_prec=false, warn_on_no_convergence=false, opts=C_NULL) return ArblibJL(check_analytic, take_prec, warn_on_no_convergence, opts) end + +""" + VEGASMC(; neval=10^4, niter=20, block=16, adapt=true, gamma=1.0, verbose=-2, debug=false) + +Markov-chain based Vegas algorithm from MCIntegration.jl +""" +struct VEGASMC <: SciMLBase.AbstractIntegralAlgorithm + neval::Int + niter::Int + block::Int + adapt::Bool + gamma::Float64 + verbose::Int + debug::Bool +end +VEGASMC(; neval=10^4, niter=20, block=16, adapt=true, gamma=1.0, verbose=-2, debug=false) = + VEGASMC(neval, niter, block, adapt, gamma, verbose, debug) diff --git a/test/interface_tests.jl b/test/interface_tests.jl index 8daef4b3..c3b79319 100644 --- a/test/interface_tests.jl +++ b/test/interface_tests.jl @@ -1,5 +1,5 @@ using Integrals -using Cuba, Cubature, Arblib +using Cuba, Cubature, Arblib, MCIntegration using Test max_dim_test = 2