Skip to content

Commit

Permalink
make cache mutable
Browse files Browse the repository at this point in the history
  • Loading branch information
lxvm committed Aug 9, 2023
1 parent 05ab510 commit 5026be1
Show file tree
Hide file tree
Showing 5 changed files with 17 additions and 77 deletions.
1 change: 0 additions & 1 deletion Project.toml
Original file line number Diff line number Diff line change
Expand Up @@ -12,7 +12,6 @@ QuadGK = "1fd47b50-473d-5c70-9696-f719f8f3bcdc"
Reexport = "189a3867-3050-52da-a836-e630ba90ab69"
Requires = "ae029012-a4dd-5104-9daa-d747884805df"
SciMLBase = "0bca4576-84f4-4d90-8ffe-ffa030f20462"
Setfield = "efcf1570-3423-57d1-acb7-fd33fddbac46"

[compat]
ChainRulesCore = "0.10.7, 1"
Expand Down
20 changes: 11 additions & 9 deletions docs/src/tutorials/caching_interface.md
Original file line number Diff line number Diff line change
@@ -1,7 +1,7 @@
# Integrals with Caching Interface

Often, integral solvers allocate memory or reuse quadrature rules for solving different
problems. For example, if one is going to perform
problems. For example, if one is going to solve the same integral for several parameters

```julia
using Integrals
Expand All @@ -11,7 +11,7 @@ alg = QuadGKJL()

solve(prob, alg)

prob = remake(prob, f = (x, p) -> cos(x * p))
prob = remake(prob, p = 15.0)
solve(prob, alg)
```

Expand All @@ -21,14 +21,14 @@ shown below by directly calling the library
```julia
using QuadGK
segbuf = QuadGK.alloc_segbuf()
quadgk(x -> sin(14x), 0, 1, segbuf = segbuf)
quadgk(x -> sin(15x), 0, 1, segbuf = segbuf)
quadgk(x -> cos(15x), 0, 1, segbuf = segbuf)
```

Integrals.jl's caching interface automates this process to reuse resources if an algorithm
supports it and if the necessary types to build the cache can be inferred from `prob`. To do
this with Integrals.jl, you simply `init` a cache, `solve`, replace `f`, and solve again.
This looks like
this with Integrals.jl, you simply `init` a cache, `solve!`, replace `p`, and solve again.
This uses the [SciML `init` interface](https://docs.sciml.ai/SciMLBase/stable/interfaces/Init_Solve/#init-and-the-Iterator-Interface)

```@example cache1
using Integrals
Expand All @@ -41,10 +41,12 @@ sol1 = solve!(cache)
```

```@example cache1
cache = Integrals.set_f(cache, (x, p) -> cos(x * p))
cache.p = 15.0
sol2 = solve!(cache)
```

Similar cache-rebuilding functions are provided, including: `set_p`, `set_lb`, and `set_ub`,
each of which provides a new value of `lb`, `ub`, or `p`, respectively. When resetting the
cache, new allocations may be needed if those inferred types change.
The caching interface is intended for updating `p`, `lb`, `ub`, `nout`, and `batch`.
Note that the types of these variables is not allowed to change.
If it is necessary to change the integrand `f` instead of defining a new
`IntegralProblem`, consider using
[FunctionWrappers.jl](https://github.com/yuyichao/FunctionWrappers.jl).
1 change: 0 additions & 1 deletion src/Integrals.jl
Original file line number Diff line number Diff line change
Expand Up @@ -7,7 +7,6 @@ end
using Reexport, MonteCarloIntegration, QuadGK, HCubature
@reexport using SciMLBase
using LinearAlgebra
using Setfield

include("common.jl")
include("init.jl")
Expand Down
57 changes: 1 addition & 56 deletions src/common.jl
Original file line number Diff line number Diff line change
@@ -1,4 +1,4 @@
struct IntegralCache{iip, F, B, P, PK, A, S, K, Tc}
mutable struct IntegralCache{iip, F, B, P, PK, A, S, K, Tc}
iip::Val{iip}
f::F
lb::B
Expand Down Expand Up @@ -47,61 +47,6 @@ function SciMLBase.init(prob::IntegralProblem{iip},
cacheval)
end

refresh_cacheval(cacheval, alg, prob) = nothing

"""
set_f(cache, f, [nout=cache.nout])
Return a new cache with the new integrand `f`, optionally resetting `nout` at the same time.
"""
function set_f(cache::IntegralCache, f, nout = cache.nout)
prob = remake(build_problem(cache), f = f, nout = nout)
alg = cache.alg
cacheval = cache.cacheval
# lots of type-instability hereafter
@set! cache.f = f
@set! cache.iip = Val(isinplace(f, 3))
@set! cache.nout = nout
@set! cache.cacheval = refresh_cacheval(cacheval, alg, prob)
return cache
end

"""
set_lb(cache, lb)
Return a new cache with new lower limits `lb`.
"""
function set_lb(cache::IntegralCache, lb)
@set! cache.lb = lb
return cache
end

# since types of lb and ub are constrained, we do not need to refresh cache

"""
set_ub(cache, ub)
Return a new cache with new lower limits `ub`.
"""
function set_ub(cache::IntegralCache, ub)
@set! cache.ub = ub
return cache
end

"""
set_p(cache, p, [refresh=true])
Return a new cache with parameters `p`.
"""
function set_p(cache::IntegralCache, p)
prob = remake(build_problem(cache), p = p)
alg = cache.alg
cacheval = cache.cacheval
@set! cache.p = p
@set! cache.cacheval = refresh_cacheval(cacheval, alg, prob)
return cache
end

# Throw error if alg is not provided, as defaults are not implemented.
function SciMLBase.solve(::IntegralProblem; kwargs...)
checkkwargs(kwargs...)
Expand Down
15 changes: 5 additions & 10 deletions test/interface_tests.jl
Original file line number Diff line number Diff line change
Expand Up @@ -360,27 +360,22 @@ end

@testset "Caching interface" begin
lb, ub = (1.0, 3.0)
p = NaN # the integrands don't actually use this
nout = 1
dim = 1
for alg in algs
if alg_req[alg].min_dim > 1
continue
end
for i in 1:length(integrands)
prob = IntegralProblem(integrands[i], lb, ub)
prob = IntegralProblem(integrands[i], lb, ub, p)
cache = init(prob, alg, reltol = reltol, abstol = abstol)
@test solve!(cache).uexact_sol[i](dim, nout, lb, ub) rtol=1e-2
lb = 0.5
cache = Integrals.set_lb(cache, lb)
@test solve!(cache).uexact_sol[i](dim, nout, lb, ub) rtol=1e-2
ub = 3.5
cache = Integrals.set_ub(cache, ub)
cache.lb = lb = 0.5
@test solve!(cache).uexact_sol[i](dim, nout, lb, ub) rtol=1e-2
p = missing # the integrands don't actually use this
cache = Integrals.set_p(cache, p)
cache.ub = ub = 3.5
@test solve!(cache).uexact_sol[i](dim, nout, lb, ub) rtol=1e-2
f = (x, p) -> integrands[i](x, p) # for lack of creativity, wrap the old integrand
cache = Integrals.set_f(cache, f)
cache.p = Inf
@test solve!(cache).uexact_sol[i](dim, nout, lb, ub) rtol=1e-2
end
end
Expand Down

0 comments on commit 5026be1

Please sign in to comment.