Skip to content

Commit

Permalink
rename to QuadratureRule
Browse files Browse the repository at this point in the history
  • Loading branch information
lxvm committed Sep 17, 2023
1 parent a7d495d commit 3c26875
Show file tree
Hide file tree
Showing 5 changed files with 28 additions and 14 deletions.
2 changes: 1 addition & 1 deletion docs/src/solvers/IntegralSolvers.md
Original file line number Diff line number Diff line change
Expand Up @@ -12,7 +12,7 @@ The following algorithms are available:
- `CubaDivonne`: Divonne from Cuba.jl. Requires `using IntegralsCuba`. Works only for `>1`-dimensional integrations.
- `CubaCuhre`: Cuhre from Cuba.jl. Requires `using IntegralsCuba`. Works only for `>1`-dimensional integrations.
- `GaussLegendre`: Uses Gauss-Legendre quadrature with nodes and weights from FastGaussQuadrature.jl.
- `QuadratureFunction`: Accepts a user-defined function that returns nodes and weights.
- `QuadratureRule`: Accepts a user-defined function that returns nodes and weights.

```@docs
QuadGKJL
Expand Down
2 changes: 1 addition & 1 deletion src/Integrals.jl
Original file line number Diff line number Diff line change
Expand Up @@ -148,5 +148,5 @@ function __solvebp_call(prob::IntegralProblem, alg::VEGAS, sensealg, lb, ub, p;
SciMLBase.build_solution(prob, alg, val, err, chi = chi, retcode = ReturnCode.Success)
end

export QuadGKJL, HCubatureJL, VEGAS, GaussLegendre, QuadratureFunction
export QuadGKJL, HCubatureJL, VEGAS, GaussLegendre, QuadratureRule
end # module
8 changes: 4 additions & 4 deletions src/algorithms.jl
Original file line number Diff line number Diff line change
Expand Up @@ -124,7 +124,7 @@ function GaussLegendre(; n = 250, subintervals = 1, nodes = nothing, weights = n
end

"""
QuadratureFunction(q; n=250)
QuadratureRule(q; n=250)
Algorithm to construct and evaluate a quadrature rule `q` of `n` points computed from the
inputs as `x, w = q(n)`. It assumes the nodes and weights are for the standard interval
Expand All @@ -133,13 +133,13 @@ solved. The nodes `x` may be scalars in 1d or vectors in arbitrary dimensions, a
weights `w` must be scalar. The algorithm computes the quadrature rule `sum(w .* f.(x))` and
the caller must check that the result is converged with respect to `n`.
"""
struct QuadratureFunction{Q} <: SciMLBase.AbstractIntegralAlgorithm
struct QuadratureRule{Q} <: SciMLBase.AbstractIntegralAlgorithm
q::Q
n::Int
function QuadratureFunction(q::Q, n::Integer) where {Q}
function QuadratureRule(q::Q, n::Integer) where {Q}
n > 0 ||
throw(ArgumentError("Cannot use a nonpositive number of quadrature nodes."))
return new{Q}(q, n)
end
end
QuadratureFunction(q; n = 250) = QuadratureFunction(q, n)
QuadratureRule(q; n = 250) = QuadratureRule(q, n)
7 changes: 3 additions & 4 deletions src/quadrules.jl
Original file line number Diff line number Diff line change
Expand Up @@ -18,20 +18,19 @@ function evalrule(f, p, lb, ub, nodes, weights)
return prod(scale) * I
end

function init_cacheval(alg::QuadratureFunction, ::IntegralProblem)
function init_cacheval(alg::QuadratureRule, ::IntegralProblem)
return alg.q(alg.n)
end

function Integrals.__solvebp_call(cache::IntegralCache, alg::QuadratureFunction,
function Integrals.__solvebp_call(cache::IntegralCache, alg::QuadratureRule,
sensealg, lb, ub, p;
reltol = nothing, abstol = nothing,
maxiters = nothing)
prob = build_problem(cache)
if isinplace(prob)
error("QuadratureFunction does not support inplace integrandss.")
error("QuadratureRule does not support inplace integrandss.")
end
@assert prob.batch == 0
@assert prob.nout == 1

val = evalrule(cache.f, cache.p, lb, ub, cache.cacheval...)

Expand Down
23 changes: 19 additions & 4 deletions test/quadrule_tests.jl
Original file line number Diff line number Diff line change
Expand Up @@ -24,7 +24,7 @@ function trapz(n::Integer)
return (x, w)
end

alg = QuadratureFunction(trapz, n = 1000)
alg = QuadratureRule(trapz, n = 1000)

lb = -1.2
ub = 3.5
Expand All @@ -43,7 +43,7 @@ function trapz2(n)
return [SVector(y, z) for (y, z) in Iterators.product(x, x)], w .* w'
end

alg = QuadratureFunction(trapz2, n = 100)
alg = QuadratureRule(trapz2, n = 100)

lb = SVector(-1.2, -1.0)
ub = SVector(3.5, 3.7)
Expand All @@ -58,7 +58,7 @@ u = solve(prob, alg).u

g = (x, p) -> p / (x^2 + p^2)

alg = QuadratureFunction(gausslegendre, n = 1000)
alg = QuadratureRule(gausslegendre, n = 1000)

lb = -Inf
ub = Inf
Expand All @@ -68,13 +68,28 @@ prob = IntegralProblem(g, lb, ub, p)

@test solve(prob, alg).upi rtol=1e-4

# 1d with nout

g2 = (x, p) -> [p[1] / (x^2 + p[1]^2), p[2] / (x^2 + p[2]^2)]

alg = QuadratureRule(gausslegendre, n = 1000)

lb = -Inf
ub = Inf
p = (1.0, 1.3)

prob = IntegralProblem(g2, lb, ub, p)

@test solve(prob, alg).u[pi,pi] rtol=1e-4


#= derivative tests
using Zygote
function testf(lb, ub, p, f = f)
prob = IntegralProblem(f, lb, ub, p)
solve(prob, QuadratureFunction(trapz, n=200))[1]
solve(prob, QuadratureRule(trapz, n=200))[1]
end
lb = -1.2
Expand Down

0 comments on commit 3c26875

Please sign in to comment.