From 5026be13fc5ff49e4f8eca11fb77a47d717bbafd Mon Sep 17 00:00:00 2001 From: lxvm Date: Wed, 9 Aug 2023 13:54:47 -0400 Subject: [PATCH] make cache mutable --- Project.toml | 1 - docs/src/tutorials/caching_interface.md | 20 +++++---- src/Integrals.jl | 1 - src/common.jl | 57 +------------------------ test/interface_tests.jl | 15 +++---- 5 files changed, 17 insertions(+), 77 deletions(-) diff --git a/Project.toml b/Project.toml index 9d9aa4ca..1b739b4f 100644 --- a/Project.toml +++ b/Project.toml @@ -12,7 +12,6 @@ QuadGK = "1fd47b50-473d-5c70-9696-f719f8f3bcdc" Reexport = "189a3867-3050-52da-a836-e630ba90ab69" Requires = "ae029012-a4dd-5104-9daa-d747884805df" SciMLBase = "0bca4576-84f4-4d90-8ffe-ffa030f20462" -Setfield = "efcf1570-3423-57d1-acb7-fd33fddbac46" [compat] ChainRulesCore = "0.10.7, 1" diff --git a/docs/src/tutorials/caching_interface.md b/docs/src/tutorials/caching_interface.md index 9a6a440d..13fe41de 100644 --- a/docs/src/tutorials/caching_interface.md +++ b/docs/src/tutorials/caching_interface.md @@ -1,7 +1,7 @@ # Integrals with Caching Interface Often, integral solvers allocate memory or reuse quadrature rules for solving different -problems. For example, if one is going to perform +problems. For example, if one is going to solve the same integral for several parameters ```julia using Integrals @@ -11,7 +11,7 @@ alg = QuadGKJL() solve(prob, alg) -prob = remake(prob, f = (x, p) -> cos(x * p)) +prob = remake(prob, p = 15.0) solve(prob, alg) ``` @@ -21,14 +21,14 @@ shown below by directly calling the library ```julia using QuadGK segbuf = QuadGK.alloc_segbuf() +quadgk(x -> sin(14x), 0, 1, segbuf = segbuf) quadgk(x -> sin(15x), 0, 1, segbuf = segbuf) -quadgk(x -> cos(15x), 0, 1, segbuf = segbuf) ``` Integrals.jl's caching interface automates this process to reuse resources if an algorithm supports it and if the necessary types to build the cache can be inferred from `prob`. To do -this with Integrals.jl, you simply `init` a cache, `solve`, replace `f`, and solve again. -This looks like +this with Integrals.jl, you simply `init` a cache, `solve!`, replace `p`, and solve again. +This uses the [SciML `init` interface](https://docs.sciml.ai/SciMLBase/stable/interfaces/Init_Solve/#init-and-the-Iterator-Interface) ```@example cache1 using Integrals @@ -41,10 +41,12 @@ sol1 = solve!(cache) ``` ```@example cache1 -cache = Integrals.set_f(cache, (x, p) -> cos(x * p)) +cache.p = 15.0 sol2 = solve!(cache) ``` -Similar cache-rebuilding functions are provided, including: `set_p`, `set_lb`, and `set_ub`, -each of which provides a new value of `lb`, `ub`, or `p`, respectively. When resetting the -cache, new allocations may be needed if those inferred types change. +The caching interface is intended for updating `p`, `lb`, `ub`, `nout`, and `batch`. +Note that the types of these variables is not allowed to change. +If it is necessary to change the integrand `f` instead of defining a new +`IntegralProblem`, consider using +[FunctionWrappers.jl](https://github.com/yuyichao/FunctionWrappers.jl). diff --git a/src/Integrals.jl b/src/Integrals.jl index d1d80fc1..07bb525e 100644 --- a/src/Integrals.jl +++ b/src/Integrals.jl @@ -7,7 +7,6 @@ end using Reexport, MonteCarloIntegration, QuadGK, HCubature @reexport using SciMLBase using LinearAlgebra -using Setfield include("common.jl") include("init.jl") diff --git a/src/common.jl b/src/common.jl index b58defc8..d13a0165 100644 --- a/src/common.jl +++ b/src/common.jl @@ -1,4 +1,4 @@ -struct IntegralCache{iip, F, B, P, PK, A, S, K, Tc} +mutable struct IntegralCache{iip, F, B, P, PK, A, S, K, Tc} iip::Val{iip} f::F lb::B @@ -47,61 +47,6 @@ function SciMLBase.init(prob::IntegralProblem{iip}, cacheval) end -refresh_cacheval(cacheval, alg, prob) = nothing - -""" - set_f(cache, f, [nout=cache.nout]) - -Return a new cache with the new integrand `f`, optionally resetting `nout` at the same time. -""" -function set_f(cache::IntegralCache, f, nout = cache.nout) - prob = remake(build_problem(cache), f = f, nout = nout) - alg = cache.alg - cacheval = cache.cacheval - # lots of type-instability hereafter - @set! cache.f = f - @set! cache.iip = Val(isinplace(f, 3)) - @set! cache.nout = nout - @set! cache.cacheval = refresh_cacheval(cacheval, alg, prob) - return cache -end - -""" - set_lb(cache, lb) - -Return a new cache with new lower limits `lb`. -""" -function set_lb(cache::IntegralCache, lb) - @set! cache.lb = lb - return cache -end - -# since types of lb and ub are constrained, we do not need to refresh cache - -""" - set_ub(cache, ub) - -Return a new cache with new lower limits `ub`. -""" -function set_ub(cache::IntegralCache, ub) - @set! cache.ub = ub - return cache -end - -""" - set_p(cache, p, [refresh=true]) - -Return a new cache with parameters `p`. -""" -function set_p(cache::IntegralCache, p) - prob = remake(build_problem(cache), p = p) - alg = cache.alg - cacheval = cache.cacheval - @set! cache.p = p - @set! cache.cacheval = refresh_cacheval(cacheval, alg, prob) - return cache -end - # Throw error if alg is not provided, as defaults are not implemented. function SciMLBase.solve(::IntegralProblem; kwargs...) checkkwargs(kwargs...) diff --git a/test/interface_tests.jl b/test/interface_tests.jl index 4946dc57..6789f653 100644 --- a/test/interface_tests.jl +++ b/test/interface_tests.jl @@ -360,6 +360,7 @@ end @testset "Caching interface" begin lb, ub = (1.0, 3.0) + p = NaN # the integrands don't actually use this nout = 1 dim = 1 for alg in algs @@ -367,20 +368,14 @@ end continue end for i in 1:length(integrands) - prob = IntegralProblem(integrands[i], lb, ub) + prob = IntegralProblem(integrands[i], lb, ub, p) cache = init(prob, alg, reltol = reltol, abstol = abstol) @test solve!(cache).u≈exact_sol[i](dim, nout, lb, ub) rtol=1e-2 - lb = 0.5 - cache = Integrals.set_lb(cache, lb) - @test solve!(cache).u≈exact_sol[i](dim, nout, lb, ub) rtol=1e-2 - ub = 3.5 - cache = Integrals.set_ub(cache, ub) + cache.lb = lb = 0.5 @test solve!(cache).u≈exact_sol[i](dim, nout, lb, ub) rtol=1e-2 - p = missing # the integrands don't actually use this - cache = Integrals.set_p(cache, p) + cache.ub = ub = 3.5 @test solve!(cache).u≈exact_sol[i](dim, nout, lb, ub) rtol=1e-2 - f = (x, p) -> integrands[i](x, p) # for lack of creativity, wrap the old integrand - cache = Integrals.set_f(cache, f) + cache.p = Inf @test solve!(cache).u≈exact_sol[i](dim, nout, lb, ub) rtol=1e-2 end end