diff --git a/docs/src/solvers/IntegralSolvers.md b/docs/src/solvers/IntegralSolvers.md index c650d3b2..f9fd3c61 100644 --- a/docs/src/solvers/IntegralSolvers.md +++ b/docs/src/solvers/IntegralSolvers.md @@ -12,10 +12,12 @@ The following algorithms are available: - `CubaDivonne`: Divonne from Cuba.jl. Requires `using IntegralsCuba`. Works only for `>1`-dimensional integrations. - `CubaCuhre`: Cuhre from Cuba.jl. Requires `using IntegralsCuba`. Works only for `>1`-dimensional integrations. - `GaussLegendre`: Uses Gauss-Legendre quadrature with nodes and weights from FastGaussQuadrature.jl. + - `QuadratureRule`: Accepts a user-defined function that returns nodes and weights. ```@docs QuadGKJL HCubatureJL VEGAS GaussLegendre +QuadratureRule ``` diff --git a/src/Integrals.jl b/src/Integrals.jl index 07bb525e..9c73b40c 100644 --- a/src/Integrals.jl +++ b/src/Integrals.jl @@ -12,6 +12,7 @@ include("common.jl") include("init.jl") include("algorithms.jl") include("infinity_handling.jl") +include("quadrules.jl") abstract type QuadSensitivityAlg end struct ReCallVJP{V} @@ -147,5 +148,5 @@ function __solvebp_call(prob::IntegralProblem, alg::VEGAS, sensealg, lb, ub, p; SciMLBase.build_solution(prob, alg, val, err, chi = chi, retcode = ReturnCode.Success) end -export QuadGKJL, HCubatureJL, VEGAS, GaussLegendre +export QuadGKJL, HCubatureJL, VEGAS, GaussLegendre, QuadratureRule end # module diff --git a/src/algorithms.jl b/src/algorithms.jl index e49122a4..ca686beb 100644 --- a/src/algorithms.jl +++ b/src/algorithms.jl @@ -122,3 +122,24 @@ function GaussLegendre(; n = 250, subintervals = 1, nodes = nothing, weights = n end return GaussLegendre(nodes, weights, subintervals) end + +""" + QuadratureRule(q; n=250) + +Algorithm to construct and evaluate a quadrature rule `q` of `n` points computed from the +inputs as `x, w = q(n)`. It assumes the nodes and weights are for the standard interval +`[-1, 1]^d` in `d` dimensions, and rescales the nodes to the specific hypercube being +solved. The nodes `x` may be scalars in 1d or vectors in arbitrary dimensions, and the +weights `w` must be scalar. The algorithm computes the quadrature rule `sum(w .* f.(x))` and +the caller must check that the result is converged with respect to `n`. +""" +struct QuadratureRule{Q} <: SciMLBase.AbstractIntegralAlgorithm + q::Q + n::Int + function QuadratureRule(q::Q, n::Integer) where {Q} + n > 0 || + throw(ArgumentError("Cannot use a nonpositive number of quadrature nodes.")) + return new{Q}(q, n) + end +end +QuadratureRule(q; n = 250) = QuadratureRule(q, n) diff --git a/src/quadrules.jl b/src/quadrules.jl new file mode 100644 index 00000000..956d4377 --- /dev/null +++ b/src/quadrules.jl @@ -0,0 +1,39 @@ +function evalrule(f, p, lb, ub, nodes, weights) + scale = map((u, l) -> (u - l) / 2, ub, lb) + shift = (lb + ub) / 2 + f_ = x -> f(x, p) + xw = ((map(*, scale, x) + shift, w) for (x, w) in zip(nodes, weights)) + # we are basically computing sum(w .* f.(x)) + # unroll first loop iteration to get right types + next = iterate(xw) + next === nothing && throw(ArgumentError("empty quadrature rule")) + (x0, w0), state = next + I = w0 * f_(x0) + next = iterate(xw, state) + while next !== nothing + (xi, wi), state = next + I += wi * f_(xi) + next = iterate(xw, state) + end + return prod(scale) * I +end + +function init_cacheval(alg::QuadratureRule, ::IntegralProblem) + return alg.q(alg.n) +end + +function Integrals.__solvebp_call(cache::IntegralCache, alg::QuadratureRule, + sensealg, lb, ub, p; + reltol = nothing, abstol = nothing, + maxiters = nothing) + prob = build_problem(cache) + if isinplace(prob) + error("QuadratureRule does not support inplace integrands.") + end + @assert prob.batch == 0 + + val = evalrule(cache.f, cache.p, lb, ub, cache.cacheval...) + + err = nothing + SciMLBase.build_solution(prob, alg, val, err, retcode = ReturnCode.Success) +end diff --git a/test/quadrule_tests.jl b/test/quadrule_tests.jl new file mode 100644 index 00000000..8cbfe266 --- /dev/null +++ b/test/quadrule_tests.jl @@ -0,0 +1,101 @@ +using Integrals +using FastGaussQuadrature +using StaticArrays +using Test + +f = (x, p) -> prod(y -> cos(p * y), x) +exact_f = (lb, ub, p) -> prod(lu -> (sin(p * lu[2]) - sin(p * lu[1])) / p, zip(lb, ub)) + +# single dim + +""" + trapz(n::Integer) + +Return the weights and nodes on the standard interval [-1,1] of the [trapezoidal +rule](https://en.wikipedia.org/wiki/Trapezoidal_rule). +""" +function trapz(n::Integer) + @assert n > 1 + r = range(-1, 1, length = n) + x = collect(r) + halfh = step(r) / 2 + h = step(r) + w = [(i == 1) || (i == n) ? halfh : h for i in 1:n] + return (x, w) +end + +alg = QuadratureRule(trapz, n = 1000) + +lb = -1.2 +ub = 3.5 +p = 2.0 + +prob = IntegralProblem(f, lb, ub, p) +u = solve(prob, alg).u + +@test u≈exact_f(lb, ub, p) rtol=1e-3 + +# multi-dim + +# here we just form a tensor product of 1d rules to make a 2d rule +function trapz2(n) + x, w = trapz(n) + return [SVector(y, z) for (y, z) in Iterators.product(x, x)], w .* w' +end + +alg = QuadratureRule(trapz2, n = 100) + +lb = SVector(-1.2, -1.0) +ub = SVector(3.5, 3.7) +p = 1.2 + +prob = IntegralProblem(f, lb, ub, p) +u = solve(prob, alg).u + +@test u≈exact_f(lb, ub, p) rtol=1e-3 + +# 1d with inf limits + +g = (x, p) -> p / (x^2 + p^2) + +alg = QuadratureRule(gausslegendre, n = 1000) + +lb = -Inf +ub = Inf +p = 1.0 + +prob = IntegralProblem(g, lb, ub, p) + +@test solve(prob, alg).u≈pi rtol=1e-4 + +# 1d with nout + +g2 = (x, p) -> [p[1] / (x^2 + p[1]^2), p[2] / (x^2 + p[2]^2)] + +alg = QuadratureRule(gausslegendre, n = 1000) + +lb = -Inf +ub = Inf +p = (1.0, 1.3) + +prob = IntegralProblem(g2, lb, ub, p) + +@test solve(prob, alg).u≈[pi, pi] rtol=1e-4 + +#= derivative tests + +using Zygote + +function testf(lb, ub, p, f = f) + prob = IntegralProblem(f, lb, ub, p) + solve(prob, QuadratureRule(trapz, n=200))[1] +end + +lb = -1.2 +ub = 2.0 +p = 3.1 + +dp = Zygote.gradient(p -> testf(lb, ub, p), p) + +@test dp ≈ f(ub, p)-f(lb, p) rtol=1e-4 +=# diff --git a/test/runtests.jl b/test/runtests.jl index ef990ac6..a79dfef7 100644 --- a/test/runtests.jl +++ b/test/runtests.jl @@ -22,3 +22,6 @@ end @time @safetestset "Gaussian Quadrature Tests" begin include("gaussian_quadrature_tests.jl") end +@time @safetestset "QuadratureFunction Tests" begin + include("quadrule_tests.jl") +end