Skip to content

Commit

Permalink
Merge pull request #1 from lxvm/IlianPihlajamaa/master
Browse files Browse the repository at this point in the history
add init interface for SampledIntegralProblem
  • Loading branch information
IlianPihlajamaa authored Sep 21, 2023
2 parents e7aefc9 + 9ab8ab2 commit 6ba2e3c
Show file tree
Hide file tree
Showing 6 changed files with 184 additions and 37 deletions.
2 changes: 1 addition & 1 deletion Project.toml
Original file line number Diff line number Diff line change
Expand Up @@ -35,7 +35,7 @@ MonteCarloIntegration = "0.0.1, 0.0.2, 0.0.3"
QuadGK = "2.5"
Reexport = "0.2, 1.0"
Requires = "1"
SciMLBase = "1.70"
SciMLBase = "1.98"
Zygote = "0.4.22, 0.5, 0.6"
julia = "1.6"

Expand Down
47 changes: 47 additions & 0 deletions docs/src/tutorials/caching_interface.md
Original file line number Diff line number Diff line change
Expand Up @@ -50,3 +50,50 @@ Note that the types of these variables is not allowed to change.
If it is necessary to change the integrand `f` instead of defining a new
`IntegralProblem`, consider using
[FunctionWrappers.jl](https://github.com/yuyichao/FunctionWrappers.jl).

## Caching for sampled integral problems

For sampled integral problems, it is possible to cache the weights and reuse
them for multiple data sets.
```@example cache2
using Integrals
x = 0.0:0.1:1.0
y = sin.(x)
prob = SampledIntegralProblem(y, x)
alg = TrapezoidalRule()
cache = init(prob, alg)
sol1 = solve!(cache)
```

```@example cache2
cache.y = cos.(x) # use .= to update in-place
sol2 = solve!(cache)
```
If the grid is modified, the weights are recomputed.
```@example cache2
cache.x = 0.0:0.2:2.0
cache.y = sin.(cache.x)
sol3 = solve!(cache)
```

For multi-dimensional datasets, the integration dimension can also be changed
```@example cache3
using Integrals
x = 0.0:0.1:1.0
y = sin.(x) .* cos.(x')
prob = SampledIntegralProblem(y, x)
alg = TrapezoidalRule()
cache = init(prob, alg)
sol1 = solve!(cache)
```

```@example cache3
cache.dim = 1
sol2 = solve!(cache)
```
67 changes: 61 additions & 6 deletions src/common.jl
Original file line number Diff line number Diff line change
Expand Up @@ -80,12 +80,6 @@ function SciMLBase.solve(prob::IntegralProblem,
solve!(init(prob, alg; kwargs...))
end

function SciMLBase.solve(prob::SampledIntegralProblem,
alg::SciMLBase.AbstractIntegralAlgorithm;
kwargs...)
__solvebp(prob, alg; kwargs...)
end

function SciMLBase.solve!(cache::IntegralCache)
__solvebp(cache, cache.alg, cache.sensealg, cache.lb, cache.ub, cache.p;
cache.kwargs...)
Expand All @@ -101,3 +95,64 @@ function __solvebp_call(cache::IntegralCache, args...; kwargs...)
__solvebp_call(build_problem(cache), args...; kwargs...)
end


mutable struct SampledIntegralCache{Y, X, D, PK, A, K, Tc}
y::Y
x::X
dim::D
prob_kwargs::PK
alg::A
kwargs::K
isfresh::Bool # state of whether weights have been calculated
cacheval::Tc # store alg weights here
end

function Base.setproperty!(cache::SampledIntegralCache, name::Symbol, x)
if name === :x
setfield!(cache, :isfresh, true)
end
setfield!(cache, name, x)
end

function SciMLBase.init(prob::SampledIntegralProblem,
alg::SciMLBase.AbstractIntegralAlgorithm;
kwargs...)
NamedTuple(kwargs) == NamedTuple() || throw(ArgumentError("There are no keyword arguments allowed to `solve`"))

cacheval = init_cacheval(alg, prob)
isfresh = true

SampledIntegralCache(
prob.y,
prob.x,
prob.dim,
prob.kwargs,
alg,
kwargs,
isfresh,
cacheval)
end


"""
```julia
solve(prob::SampledIntegralProblem, alg::SciMLBase.AbstractIntegralAlgorithm; kwargs...)
```
## Keyword Arguments
There are no keyword arguments used to solve `SampledIntegralProblem`s
"""
function SciMLBase.solve(prob::SampledIntegralProblem,
alg::SciMLBase.AbstractIntegralAlgorithm;
kwargs...)
solve!(init(prob, alg; kwargs...))
end

function SciMLBase.solve!(cache::SampledIntegralCache)
__solvebp(cache, cache.alg; cache.kwargs...)
end

function build_problem(cache::SampledIntegralCache)
SampledIntegralProblem(cache.y, cache.x; dim = dimension(cache.dim), cache.prob_kwargs...)
end
58 changes: 31 additions & 27 deletions src/sampled.jl
Original file line number Diff line number Diff line change
Expand Up @@ -9,7 +9,7 @@ Base.eltype(w::UniformWeights) = typeof(w.h)
Base.size(w::UniformWeights) = (length(w), )

Check warning on line 9 in src/sampled.jl

View check run for this annotation

Codecov / codecov/patch

src/sampled.jl#L7-L9

Added lines #L7 - L9 were not covered by tests

# must contain field `x` which are the sampling points
abstract type NonuniformWeights <: AbstractWeights end
abstract type NonuniformWeights <: AbstractWeights end
@inline Base.iterate(w::NonuniformWeights) = (0 == length(w.x)) ? nothing : (w[firstindex(w.x)], firstindex(w.x))
@inline Base.iterate(w::NonuniformWeights, i) = (i == lastindex(w.x)) ? nothing : (w[i+1], i+1)
Base.length(w::NonuniformWeights) = length(w.x)
Expand All @@ -22,46 +22,50 @@ _eachslice(data::AbstractArray{T, 1}; dims=ndims(data)) where T = data

# these can be removed when the Val(dim) is removed from SciMLBase
dimension(::Val{D}) where {D} = D

Check warning on line 24 in src/sampled.jl

View check run for this annotation

Codecov / codecov/patch

src/sampled.jl#L24

Added line #L24 was not covered by tests
dimension(D::Int) = D
dimension(D::Int) = D


function evalrule(data::AbstractArray, weights, dim)
f = _eachslice(data, dims=dim)
f1, statef = iterate(f)
w1, statew = iterate(weights)
fw = zip(_eachslice(data, dims=dim), weights)
next = iterate(fw)
next === nothing && throw(ArgumentError("No points to integrate"))
(f1, w1), state = next
out = w1 * f1
nextf = iterate(f, statef)
nextw = iterate(weights, statew)
next = iterate(fw, state)
if isbits(out)
while nextf !== nothing
fi, statef = nextf
wi, statew = nextw
while next !== nothing
(fi, wi), state = next
out += wi * fi
nextf = iterate(f, statef)
nextw = iterate(weights, statew)
next = iterate(fw, state)
end
else
while nextf !== nothing
fi, statef = nextf
wi, statew = nextw
else
while next !== nothing
(fi, wi), state = next
out .+= wi .* fi
nextf = iterate(f, statef)
nextw = iterate(weights, statew)
next = iterate(fw, state)
end
end
return out
return out
end


# can be reused for other sampled rules
function __solvebp_call(prob::SampledIntegralProblem, alg::TrapezoidalRule; kwargs...)
dim = dimension(prob.dim)
# can be reused for other sampled rules, which should implement find_weights(x, alg)

function init_cacheval(alg::SciMLBase.AbstractIntegralAlgorithm, prob::SampledIntegralProblem)
find_weights(prob.x, alg)
end

function __solvebp_call(cache::SampledIntegralCache, alg::SciMLBase.AbstractIntegralAlgorithm; kwargs...)
dim = dimension(cache.dim)
err = nothing
data = prob.y
grid = prob.x
weights = find_weights(grid, alg)
data = cache.y
grid = cache.x
if cache.isfresh
cache.cacheval = find_weights(grid, alg)
cache.isfresh = false
end
weights = cache.cacheval
I = evalrule(data, weights, dim)
prob = build_problem(cache)
return SciMLBase.build_solution(prob, alg, I, err, retcode = ReturnCode.Success)
end


4 changes: 2 additions & 2 deletions src/trapezoidal.jl
Original file line number Diff line number Diff line change
Expand Up @@ -10,7 +10,7 @@ struct TrapezoidalNonuniformWeights{X<:AbstractArray} <: NonuniformWeights
x::X
end

@inline function Base.getindex(w::TrapezoidalNonuniformWeights, i)
@inline function Base.getindex(w::TrapezoidalNonuniformWeights, i)
x = w.x
(i == firstindex(x)) && return (x[i + 1] - x[i])*0.5
(i == lastindex(x)) && return (x[i] - x[i - 1])*0.5
Expand All @@ -20,4 +20,4 @@ end
function find_weights(x::AbstractVector, ::TrapezoidalRule)
x isa AbstractRange && return TrapezoidalUniformWeights(length(x), step(x))
return TrapezoidalNonuniformWeights(x)
end
end
43 changes: 42 additions & 1 deletion test/sampled_tests.jl
Original file line number Diff line number Diff line change
Expand Up @@ -27,4 +27,45 @@ using Integrals, Test
end
end
end
end
end

@testset "Caching interface" begin

x = 0.0:0.1:1.0
y = sin.(x)

prob = SampledIntegralProblem(y, x)
alg = TrapezoidalRule()

cache = init(prob, alg)
sol1 = solve!(cache)

@test sol1 == solve(prob, alg)

cache.y = cos.(x) # use .= to update in-place
sol2 = solve!(cache)

@test sol2 == solve(SampledIntegralProblem(cache.y, cache.x), alg)

cache.x = 0.0:0.2:2.0
cache.y = sin.(cache.x)
sol3 = solve!(cache)

@test sol3 == solve(SampledIntegralProblem(cache.y, cache.x), alg)

x = 0.0:0.1:1.0
y = sin.(x) .* cos.(x')

prob = SampledIntegralProblem(y, x)
alg = TrapezoidalRule()

cache = init(prob, alg)
sol1 = solve!(cache)

@test sol1 == solve(prob, alg)

cache.dim = 1
sol2 = solve!(cache)

@test sol2 == solve(SampledIntegralProblem(y, x, dim=1), alg)
end

0 comments on commit 6ba2e3c

Please sign in to comment.