Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

add --QuadratureFunction-- QuadratureRule #176

Merged
merged 7 commits into from
Sep 17, 2023
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
2 changes: 2 additions & 0 deletions docs/src/solvers/IntegralSolvers.md
Original file line number Diff line number Diff line change
Expand Up @@ -12,10 +12,12 @@ The following algorithms are available:
- `CubaDivonne`: Divonne from Cuba.jl. Requires `using IntegralsCuba`. Works only for `>1`-dimensional integrations.
- `CubaCuhre`: Cuhre from Cuba.jl. Requires `using IntegralsCuba`. Works only for `>1`-dimensional integrations.
- `GaussLegendre`: Uses Gauss-Legendre quadrature with nodes and weights from FastGaussQuadrature.jl.
- `QuadratureRule`: Accepts a user-defined function that returns nodes and weights.

```@docs
QuadGKJL
HCubatureJL
VEGAS
GaussLegendre
QuadratureRule
```
3 changes: 2 additions & 1 deletion src/Integrals.jl
Original file line number Diff line number Diff line change
Expand Up @@ -12,6 +12,7 @@ include("common.jl")
include("init.jl")
include("algorithms.jl")
include("infinity_handling.jl")
include("quadrules.jl")

abstract type QuadSensitivityAlg end
struct ReCallVJP{V}
Expand Down Expand Up @@ -147,5 +148,5 @@ function __solvebp_call(prob::IntegralProblem, alg::VEGAS, sensealg, lb, ub, p;
SciMLBase.build_solution(prob, alg, val, err, chi = chi, retcode = ReturnCode.Success)
end

export QuadGKJL, HCubatureJL, VEGAS, GaussLegendre
export QuadGKJL, HCubatureJL, VEGAS, GaussLegendre, QuadratureRule
end # module
21 changes: 21 additions & 0 deletions src/algorithms.jl
Original file line number Diff line number Diff line change
Expand Up @@ -122,3 +122,24 @@ function GaussLegendre(; n = 250, subintervals = 1, nodes = nothing, weights = n
end
return GaussLegendre(nodes, weights, subintervals)
end

"""
QuadratureRule(q; n=250)

Algorithm to construct and evaluate a quadrature rule `q` of `n` points computed from the
inputs as `x, w = q(n)`. It assumes the nodes and weights are for the standard interval
`[-1, 1]^d` in `d` dimensions, and rescales the nodes to the specific hypercube being
solved. The nodes `x` may be scalars in 1d or vectors in arbitrary dimensions, and the
weights `w` must be scalar. The algorithm computes the quadrature rule `sum(w .* f.(x))` and
the caller must check that the result is converged with respect to `n`.
"""
struct QuadratureRule{Q} <: SciMLBase.AbstractIntegralAlgorithm
q::Q
n::Int
function QuadratureRule(q::Q, n::Integer) where {Q}
n > 0 ||
throw(ArgumentError("Cannot use a nonpositive number of quadrature nodes."))
return new{Q}(q, n)
end
end
QuadratureRule(q; n = 250) = QuadratureRule(q, n)
39 changes: 39 additions & 0 deletions src/quadrules.jl
Original file line number Diff line number Diff line change
@@ -0,0 +1,39 @@
function evalrule(f, p, lb, ub, nodes, weights)
scale = map((u, l) -> (u - l) / 2, ub, lb)
shift = (lb + ub) / 2
f_ = x -> f(x, p)
xw = ((map(*, scale, x) + shift, w) for (x, w) in zip(nodes, weights))
# we are basically computing sum(w .* f.(x))
# unroll first loop iteration to get right types
next = iterate(xw)
next === nothing && throw(ArgumentError("empty quadrature rule"))
(x0, w0), state = next
I = w0 * f_(x0)
next = iterate(xw, state)
while next !== nothing
(xi, wi), state = next
I += wi * f_(xi)
next = iterate(xw, state)
end
return prod(scale) * I
end

function init_cacheval(alg::QuadratureRule, ::IntegralProblem)
return alg.q(alg.n)
end

function Integrals.__solvebp_call(cache::IntegralCache, alg::QuadratureRule,
sensealg, lb, ub, p;
reltol = nothing, abstol = nothing,
maxiters = nothing)
prob = build_problem(cache)
if isinplace(prob)
error("QuadratureRule does not support inplace integrands.")
Copy link
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Any reason it doesn't? That seems like it should be possible. I guess open an issue on this as something to do later?

Copy link
Collaborator Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

We don't know the element type of the output array to allocate. This will be fixed by the IntegralFunction PR

end
@assert prob.batch == 0

val = evalrule(cache.f, cache.p, lb, ub, cache.cacheval...)

err = nothing
SciMLBase.build_solution(prob, alg, val, err, retcode = ReturnCode.Success)
end
101 changes: 101 additions & 0 deletions test/quadrule_tests.jl
Original file line number Diff line number Diff line change
@@ -0,0 +1,101 @@
using Integrals
using FastGaussQuadrature
using StaticArrays
using Test

f = (x, p) -> prod(y -> cos(p * y), x)
exact_f = (lb, ub, p) -> prod(lu -> (sin(p * lu[2]) - sin(p * lu[1])) / p, zip(lb, ub))

# single dim

"""
trapz(n::Integer)

Return the weights and nodes on the standard interval [-1,1] of the [trapezoidal
rule](https://en.wikipedia.org/wiki/Trapezoidal_rule).
"""
function trapz(n::Integer)
@assert n > 1
r = range(-1, 1, length = n)
x = collect(r)
halfh = step(r) / 2
h = step(r)
w = [(i == 1) || (i == n) ? halfh : h for i in 1:n]
return (x, w)
end

alg = QuadratureRule(trapz, n = 1000)

lb = -1.2
ub = 3.5
p = 2.0

prob = IntegralProblem(f, lb, ub, p)
u = solve(prob, alg).u

@test u≈exact_f(lb, ub, p) rtol=1e-3

# multi-dim

# here we just form a tensor product of 1d rules to make a 2d rule
function trapz2(n)
x, w = trapz(n)
return [SVector(y, z) for (y, z) in Iterators.product(x, x)], w .* w'
end

alg = QuadratureRule(trapz2, n = 100)

lb = SVector(-1.2, -1.0)
ub = SVector(3.5, 3.7)
p = 1.2

prob = IntegralProblem(f, lb, ub, p)
u = solve(prob, alg).u

@test u≈exact_f(lb, ub, p) rtol=1e-3

# 1d with inf limits

g = (x, p) -> p / (x^2 + p^2)

alg = QuadratureRule(gausslegendre, n = 1000)

lb = -Inf
ub = Inf
p = 1.0

prob = IntegralProblem(g, lb, ub, p)

@test solve(prob, alg).u≈pi rtol=1e-4

# 1d with nout

g2 = (x, p) -> [p[1] / (x^2 + p[1]^2), p[2] / (x^2 + p[2]^2)]

alg = QuadratureRule(gausslegendre, n = 1000)

lb = -Inf
ub = Inf
p = (1.0, 1.3)

prob = IntegralProblem(g2, lb, ub, p)

@test solve(prob, alg).u≈[pi, pi] rtol=1e-4

#= derivative tests

using Zygote

function testf(lb, ub, p, f = f)
prob = IntegralProblem(f, lb, ub, p)
solve(prob, QuadratureRule(trapz, n=200))[1]
end

lb = -1.2
ub = 2.0
p = 3.1

dp = Zygote.gradient(p -> testf(lb, ub, p), p)

@test dp ≈ f(ub, p)-f(lb, p) rtol=1e-4
=#
3 changes: 3 additions & 0 deletions test/runtests.jl
Original file line number Diff line number Diff line change
Expand Up @@ -22,3 +22,6 @@ end
@time @safetestset "Gaussian Quadrature Tests" begin
include("gaussian_quadrature_tests.jl")
end
@time @safetestset "QuadratureFunction Tests" begin
include("quadrule_tests.jl")
end