Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Wrap GSL.jl in extension #213

Draft
wants to merge 3 commits into
base: master
Choose a base branch
from
Draft
Show file tree
Hide file tree
Changes from 2 commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
6 changes: 5 additions & 1 deletion Project.toml
Original file line number Diff line number Diff line change
Expand Up @@ -19,6 +19,7 @@ Cuba = "8a292aeb-7a57-582c-b821-06e4c11590b1"
Cubature = "667455a9-e2ce-5579-9412-b964f529a492"
FastGaussQuadrature = "442a2c76-b920-505d-bb47-c5924d526838"
ForwardDiff = "f6369f11-7733-5829-9624-2563aa707210"
GSL = "92c85e6c-cbff-5e0c-80f7-495c94daaecd"
Zygote = "e88e6eb3-aa80-5325-afca-941959d7151f"

[extensions]
Expand All @@ -27,6 +28,7 @@ IntegralsCubaExt = "Cuba"
IntegralsCubatureExt = "Cubature"
IntegralsFastGaussQuadratureExt = "FastGaussQuadrature"
IntegralsForwardDiffExt = "ForwardDiff"
IntegralsGSLExt = "GSL"
IntegralsZygoteExt = ["Zygote", "ChainRulesCore"]

[compat]
Expand All @@ -40,6 +42,7 @@ Distributions = "0.25.87"
FastGaussQuadrature = "0.5"
FiniteDiff = "2.12"
ForwardDiff = "0.10.19"
GSL = "1"
HCubature = "1.5"
LinearAlgebra = "1.9"
MonteCarloIntegration = "0.0.3, 0.1"
Expand All @@ -64,6 +67,7 @@ Distributions = "31c24e10-a181-5473-b8eb-7969acd0382f"
FastGaussQuadrature = "442a2c76-b920-505d-bb47-c5924d526838"
FiniteDiff = "6a86dc24-6348-571c-b903-95158fe2bd41"
ForwardDiff = "f6369f11-7733-5829-9624-2563aa707210"
GSL = "92c85e6c-cbff-5e0c-80f7-495c94daaecd"
Pkg = "44cfe95a-1eb2-52ea-b672-e2afdf69b78f"
SafeTestsets = "1bc83da4-3b8d-516f-aca4-4fe02f6d838f"
SciMLSensitivity = "1ed8b502-d754-442c-8d5d-10ac956f44a1"
Expand All @@ -72,4 +76,4 @@ Test = "8dfed614-e22c-5e08-85e1-65c5234f0b40"
Zygote = "e88e6eb3-aa80-5325-afca-941959d7151f"

[targets]
test = ["Aqua", "Arblib", "SciMLSensitivity", "StaticArrays", "FiniteDiff", "Pkg", "SafeTestsets", "Test", "Distributions", "ForwardDiff", "Zygote", "ChainRulesCore", "FastGaussQuadrature", "Cuba", "Cubature"]
test = ["Aqua", "Arblib", "SciMLSensitivity", "StaticArrays", "FiniteDiff", "GSL", "Pkg", "SafeTestsets", "Test", "Distributions", "ForwardDiff", "Zygote", "ChainRulesCore", "FastGaussQuadrature", "Cuba", "Cubature"]
48 changes: 48 additions & 0 deletions ext/IntegralsGSLExt.jl
Original file line number Diff line number Diff line change
@@ -0,0 +1,48 @@
module IntegralsGSLExt

using GSL
using Integrals
using Integrals: IntegralCache

mutable struct GSLCache{T}
value::T
end
getvalue(cache::GSLCache) = cache.value

function Integrals.init_cacheval(alg::GSLIntegration{typeof(integration_cquad)}, prob::IntegralProblem)
ws = integration_cquad_workspace_alloc(alg.kws.wssize)
gslcache = GSLCache(ws)
finalizer(integration_cquad_workspace_free∘getvalue, gslcache)
lxvm marked this conversation as resolved.
Show resolved Hide resolved
result = Cdouble[0]
abserr = Cdouble[0]
nevals = C_NULL # Csize_t[0]
return (; gslcache, result, abserr, nevals)
end

function Integrals.__solvebp_call(cache::IntegralCache, alg::GSLIntegration{typeof(integration_cquad)}, sensealg, domain, p;
reltol = 1e-8, abstol = 1e-8, maxiters = nothing)

prob = Integrals.build_problem(cache)

if !all(isone∘length, domain)
error("GSLIntegration only accepts one-dimensional quadrature problems.")
end
@assert prob.f isa IntegralFunction

f = if isinplace(prob)
@assert isone(length(prob.f.integrand_prototype)) "GSL only supports scalar, real-valued integrands"
y = similar(prob.f.integrand_prototype, Cdouble)
x -> (prob.f(y, x, p); only(y))
else
x -> Cdouble(only(prob.f(x, p)))
end
# gslf = @gsl_function(f) # broken, see: https://github.com/JuliaMath/GSL.jl/pull/128
ptr = @cfunction($((x,p) -> f(x)), Cdouble, (Cdouble, Ptr{Cvoid}))
gslf = gsl_function(Base.unsafe_convert(Ptr{Cvoid},ptr), 0)
ChrisRackauckas marked this conversation as resolved.
Show resolved Hide resolved
a, b = map(Cdouble∘only, domain)
(; gslcache, result, abserr, nevals) = cache.cacheval
integration_cquad(gslf, a, b, abstol, reltol, getvalue(gslcache), result, abserr, nevals)
ChrisRackauckas marked this conversation as resolved.
Show resolved Hide resolved
return SciMLBase.build_solution(prob, alg, only(result), only(abserr), retcode = ReturnCode.Success)
end

end
1 change: 1 addition & 0 deletions src/Integrals.jl
Original file line number Diff line number Diff line change
Expand Up @@ -221,5 +221,6 @@ export QuadGKJL, HCubatureJL, VEGAS, GaussLegendre, QuadratureRule, TrapezoidalR
export CubaVegas, CubaSUAVE, CubaDivonne, CubaCuhre
export CubatureJLh, CubatureJLp
export ArblibJL
export GSLIntegration

end # module
13 changes: 13 additions & 0 deletions src/algorithms.jl
Original file line number Diff line number Diff line change
Expand Up @@ -370,3 +370,16 @@
function ArblibJL(; check_analytic=false, take_prec=false, warn_on_no_convergence=false, opts=C_NULL)
return ArblibJL(check_analytic, take_prec, warn_on_no_convergence, opts)
end


"""
GSLIntegration(routine; kws...)

One-dimensional quadrature of Float64-valued function using `routine` from GSL with
additional arguments. For example `using Integrals, GSL; GSLIntegration(integration_cquad; wssize=100)`
Copy link
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

We should probably document the available functions. How many are there?

Copy link
Collaborator Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

There are quite a few, since the GSL includes a rewrite of quadpack. I'll put a full list below and I hope that most of them can be wrapped metaprogramatically

Copy link
Collaborator Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

integration_cquad
integration_fixed
integration_glfixed
integration_qag
integration_qagi
integration_qagil
integration_qagiu
integration_qagp
integration_qags
integration_qawc
integration_qawf
integration_qawo
integration_qaws
integration_qcheb
integration_qk
integration_qk15
integration_qk21
integration_qk31
integration_qk41
integration_qk51
integration_qk61
integration_qng
integration_romberg

"""
struct GSLIntegration{F,A<:NamedTuple} <: SciMLBase.AbstractIntegralAlgorithm
f::F
kws::A
end
GSLIntegration(f; kws...) = GSLIntegration(f, NamedTuple(kws))

Check warning on line 385 in src/algorithms.jl

View check run for this annotation

Codecov / codecov/patch

src/algorithms.jl#L385

Added line #L385 was not covered by tests
7 changes: 4 additions & 3 deletions test/interface_tests.jl
Original file line number Diff line number Diff line change
@@ -1,5 +1,5 @@
using Integrals
using Cuba, Cubature, Arblib
using Cuba, Cubature, Arblib, GSL
using Test

max_dim_test = 2
Expand All @@ -9,7 +9,7 @@ abstol = 1e-3


algs = [QuadGKJL, HCubatureJL, CubatureJLh, CubatureJLp, #VEGAS, #CubaVegas,
CubaSUAVE, CubaDivonne, CubaCuhre]
CubaSUAVE, CubaDivonne, CubaCuhre, ArblibJL, () -> GSL(integration_cquad; wssize=100)]

alg_req = Dict(QuadGKJL => (nout = 1, allows_batch = true, min_dim = 1, max_dim = 1,
allows_iip = true),
Expand All @@ -29,7 +29,8 @@ alg_req = Dict(QuadGKJL => (nout = 1, allows_batch = true, min_dim = 1, max_dim
max_dim = Inf, allows_iip = true),
CubaCuhre => (nout = Inf, allows_batch = true, min_dim = 2, max_dim = Inf,
allows_iip = true),
ArblibJL => (nout=1, allows_batch=false, min_dim=1, max_dim=1, allows_iip=true))
ArblibJL => (nout=1, allows_batch=false, min_dim=1, max_dim=1, allows_iip=true,
GSLIntegration => (nout=1, allows_batch=false, min_dim=1, max_dim=1, allows_iip=true)))

integrands = [
(x, p) -> 1.0,
Expand Down
Loading