From 9f1d026ad9698524cbaefe283dc811bfcf0160ee Mon Sep 17 00:00:00 2001 From: David Widmann Date: Fri, 16 Aug 2019 00:25:58 +0200 Subject: [PATCH 1/2] Add Wfact and Wfact_t support --- src/functionwrapper.jl | 47 ++++++++++++++++++++++++++++-------------- 1 file changed, 32 insertions(+), 15 deletions(-) diff --git a/src/functionwrapper.jl b/src/functionwrapper.jl index 119dc84..6bc38b8 100644 --- a/src/functionwrapper.jl +++ b/src/functionwrapper.jl @@ -1,3 +1,29 @@ +# convenience macro +macro wrap_h(signature) + Meta.isexpr(signature, :call) || + throw(ArgumentError("signature has to be a function call expression")) + + name = signature.args[1] + args = signature.args[2:end] + args_wo_h = [arg for arg in args if arg !== :h] + + quote + if f.$name === nothing + nothing + else + if isinplace(f) + let _f = f.$name, h = h + ($(args_wo_h...),) -> _f($(args...)) + end + else + let _f = f.$name, h = h + ($(args_wo_h[2:end]...),) -> _f($(args[2:end]...)) + end + end + end + end |> esc +end + struct ODEFunctionWrapper{iip,F,H,TMM,Ta,Tt,TJ,JP,TW,TWt,TPJ,S,TCV} <: DiffEqBase.AbstractODEFunction{iip} f::F h::H @@ -14,26 +40,17 @@ struct ODEFunctionWrapper{iip,F,H,TMM,Ta,Tt,TJ,JP,TW,TWt,TPJ,S,TCV} <: DiffEqBas end function ODEFunctionWrapper(f::DDEFunction, h) - if f.jac === nothing - jac = nothing - else - if isinplace(f) - jac = let f_jac = f.jac, h = h - (J, u, p, t) -> f_jac(J, u, h, p, t) - end - else - jac = let f_jac = f.jac, h = h - (u, p, t) -> f_jac(u, h, p, t) - end - end - end + # wrap functions + jac = @wrap_h jac(J, u, h, p, t) + Wfact = @wrap_h Wfact(W, u, h, p, dtgamma, t) + Wfact_t = @wrap_h Wfact_t(W, u, h, p, dtgamma, t) ODEFunctionWrapper{isinplace(f),typeof(f.f),typeof(h),typeof(f.mass_matrix), typeof(f.analytic),typeof(f.tgrad),typeof(jac), - typeof(f.jac_prototype),typeof(f.Wfact),typeof(f.Wfact_t), + typeof(f.jac_prototype),typeof(Wfact),typeof(Wfact_t), typeof(f.paramjac),typeof(f.syms),typeof(f.colorvec)}( f.f, h, f.mass_matrix, f.analytic, f.tgrad, jac, - f.jac_prototype, f.Wfact, f.Wfact_t, f.paramjac, f.syms, + f.jac_prototype, Wfact, Wfact_t, f.paramjac, f.syms, f.colorvec) end From 2ae59d30d1393a38b4e2495692880a6a27f958ab Mon Sep 17 00:00:00 2001 From: David Widmann Date: Fri, 16 Aug 2019 00:28:10 +0200 Subject: [PATCH 2/2] Update Jacobian and Wfact tests --- test/interface/jacobian.jl | 169 ++++++++++++++++++++++++++++++++----- 1 file changed, 150 insertions(+), 19 deletions(-) diff --git a/test/interface/jacobian.jl b/test/interface/jacobian.jl index 0b110ac..7a7229c 100644 --- a/test/interface/jacobian.jl +++ b/test/interface/jacobian.jl @@ -8,49 +8,180 @@ using Test nothing end - function g(J, u, h, p, t) + njacs = Ref(0) + function jac(J, u, h, p, t) + njacs[] += 1 J[1, 1] = 1 - h(p, t - 1)[1] nothing end + nWfacts = Ref(0) + function Wfact(W, u, h, p, dtgamma, t) + nWfacts[] += 1 + W[1,1] = dtgamma * (1 - h(p, t - 1)[1]) - 1 + nothing + end + + nWfact_ts = Ref(0) + function Wfact_t(W, u, h, p, dtgamma, t) + nWfact_ts[] += 1 + W[1,1] = 1 - h(p, t - 1)[1] - inv(dtgamma) + nothing + end + h(p, t) = [0.0] # define problems - prob_wo_jac = DDEProblem(DDEFunction{true}(f), [1.0], h, (0.0, 40.0); - constant_lags = [1]) - prob_w_jac = DDEProblem(DDEFunction{true}(f; jac = g), [1.0], h, (0.0, 40.0); - constant_lags = [1]) + prob = DDEProblem(DDEFunction{true}(f), [1.0], h, (0.0, 40.0); constant_lags = [1]) + prob_jac = remake(prob; f = DDEFunction{true}(f; jac = jac)) + prob_Wfact = remake(prob; f = DDEFunction{true}(f; Wfact = Wfact)) + prob_Wfact_t = remake(prob; f = DDEFunction{true}(f; Wfact_t = Wfact_t)) # compute solutions for alg in (Rosenbrock23(), TRBDF2()) - sol_wo_jac = solve(prob_wo_jac, MethodOfSteps(alg)) - sol_w_jac = solve(prob_w_jac, MethodOfSteps(alg)) + sol = solve(prob, MethodOfSteps(alg)) + + ## Jacobian + njacs[] = 0 + sol_jac = solve(prob_jac, MethodOfSteps(alg)) + + # check number of function evaluations + @test !iszero(njacs[]) + @test njacs[] == sol_jac.destats.njacs + if alg isa Rosenbrock23 + @test njacs[] == sol_jac.destats.nw + else + @test_broken njacs[] == sol_jac.destats.nw + end + + # check resulting solution + @test sol.t ≈ sol_jac.t + @test sol.u ≈ sol_jac.u + + ## Wfact + nWfacts[] = 0 + sol_Wfact = solve(prob_Wfact, MethodOfSteps(alg)) + + # check number of function evaluations + if alg isa Rosenbrock23 + @test !iszero(nWfacts[]) + @test nWfacts[] == njacs[] + @test iszero(sol_Wfact.destats.njacs) + else + @test_broken !iszero(nWfacts[]) + @test_broken nWfacts[] == njacs[] + @test_broken iszero(sol_Wfact.destats.njacs) + end + @test_broken nWfacts[] == sol_Wfact.destats.nw + + # check resulting solution + @test sol.t ≈ sol_Wfact.t + @test sol.u ≈ sol_Wfact.u + + ## Wfact_t + nWfact_ts[] = 0 + sol_Wfact_t = solve(prob_Wfact_t, MethodOfSteps(alg)) + + # check number of function evaluations + if alg isa Rosenbrock23 + @test_broken !iszero(nWfact_ts[]) + @test_broken nWfact_ts[] == njacs[] + @test_broken iszero(sol_Wfact_t.destats.njacs) + else + @test !iszero(nWfact_ts[]) + @test_broken nWfact_ts[] == njacs[] + @test iszero(sol_Wfact_t.destats.njacs) + end + @test_broken nWfact_ts[] == sol_Wfact_t.destats.nw - @test sol_wo_jac.t ≈ sol_w_jac.t - @test sol_wo_jac.u ≈ sol_w_jac.u + # check resulting solution + if alg isa Rosenbrock23 + @test sol.t ≈ sol_Wfact_t.t + @test sol.u ≈ sol_Wfact_t.u + else + @test_broken sol.t ≈ sol_Wfact_t.t + @test_broken sol.u ≈ sol_Wfact_t.u + end end end @testset "out-of-place" begin # define functions (Hutchinson's equation) - f(u, h, p, t) = [u[1] * (1 - h(p, t - 1)[1])] + f(u, h, p, t) = u[1] .* (1 .- h(p, t - 1)) - g(u, h, p, t) = fill(1 - h(p, t - 1)[1], 1, 1) + njacs = Ref(0) + function jac(u, h, p, t) + njacs[] += 1 + reshape(1 .- h(p, t - 1), 1, 1) + end + + nWfacts = Ref(0) + function Wfact(u, h, p, dtgamma, t) + nWfacts[] += 1 + reshape(dtgamma .* (1 .- h(p, t - 1)) .- 1, 1, 1) + end + + nWfact_ts = Ref(0) + function Wfact_t(u, h, p, dtgamma, t) + nWfact_ts[] += 1 + reshape((1 - inv(dtgamma)) .- h(p, t - 1), 1, 1) + end h(p, t) = [0.0] # define problems - prob_wo_jac = DDEProblem(DDEFunction{false}(f), [1.0], h, (0.0, 40.0); - constant_lags = [1]) - prob_w_jac = DDEProblem(DDEFunction{false}(f; jac = g), [1.0], h, (0.0, 40.0); - constant_lags = [1]) + prob = DDEProblem(DDEFunction{false}(f), [1.0], h, (0.0, 40.0); constant_lags = [1]) + prob_jac = remake(prob; f = DDEFunction{false}(f; jac = jac)) + prob_Wfact = remake(prob; f = DDEFunction{false}(f; Wfact = Wfact)) + prob_Wfact_t = remake(prob; f = DDEFunction{false}(f; Wfact_t = Wfact_t)) # compute solutions for alg in (Rosenbrock23(), TRBDF2()) - sol_wo_jac = solve(prob_wo_jac, MethodOfSteps(alg)) - sol_w_jac = solve(prob_w_jac, MethodOfSteps(alg)) + sol = solve(prob, MethodOfSteps(alg)) + + ## Jacobian + njacs[] = 0 + sol_jac = solve(prob_jac, MethodOfSteps(alg)) + + # check number of function evaluations + @test !iszero(njacs[]) + @test_broken njacs[] == sol_jac.destats.njacs + if alg isa Rosenbrock23 + @test njacs[] == sol_jac.destats.nw + else + @test_broken njacs[] == sol_jac.destats.nw + end + + # check resulting solution + @test sol.t ≈ sol_jac.t + @test sol.u ≈ sol_jac.u + + ## Wfact + nWfacts[] = 0 + sol_Wfact = solve(prob_Wfact, MethodOfSteps(alg)) + + # check number of function evaluations + @test_broken !iszero(nWfacts[]) + @test_broken nWfacts[] == njacs[] + @test_broken iszero(sol_Wfact.destats.njacs) + @test_broken nWfacts[] == sol_Wfact.destats.nw + + # check resulting solution + @test sol.t ≈ sol_Wfact.t + @test sol.u ≈ sol_Wfact.u + + ## Wfact_t + nWfact_ts[] = 0 + sol_Wfact_t = solve(prob_Wfact_t, MethodOfSteps(alg)) + + # check number of function evaluations + @test_broken !iszero(nWfact_ts[]) + @test_broken nWfact_ts[] == njacs[] + @test_broken iszero(sol_Wfact_ts.destats.njacs) + @test_broken nWfact_ts[] == sol_Wfact_t.destats.nw - @test sol_wo_jac.t ≈ sol_w_jac.t - @test sol_wo_jac.u ≈ sol_w_jac.u + # check resulting solution + @test sol.t ≈ sol_Wfact_t.t + @test sol.u ≈ sol_Wfact_t.u end end