From 5cacc30d614963def9c78f1d152f6caaff576fc9 Mon Sep 17 00:00:00 2001 From: Michel Schanen Date: Wed, 17 Apr 2024 09:43:46 -0500 Subject: [PATCH] Lint code and add JuliaFormatter GH action --- .github/workflows/format.yml | 30 +++ docs/make.jl | 8 +- examples/box_model.jl | 59 ++++-- examples/heat.jl | 48 +++-- examples/optcontrol.jl | 104 ++++----- examples/optcontrolfunc.jl | 30 +-- examples/printaction.jl | 6 +- src/Checkpointing.jl | 62 +++--- src/Rules/ChainRules.jl | 9 +- src/Rules/EnzymeRules.jl | 10 +- src/Schemes/Online_r2.jl | 394 +++++++++++++++++++++-------------- src/Schemes/Periodic.jl | 51 +++-- src/Schemes/Revolve.jl | 182 +++++++++------- src/Storage/HDF5Storage.jl | 2 +- test/enzyme.jl | 51 ++--- test/multilevel.jl | 19 +- test/output_chkp.jl | 10 +- test/runtests.jl | 78 ++++--- test/speelpenning.jl | 12 +- 19 files changed, 690 insertions(+), 475 deletions(-) create mode 100644 .github/workflows/format.yml diff --git a/.github/workflows/format.yml b/.github/workflows/format.yml new file mode 100644 index 0000000..3ed83e5 --- /dev/null +++ b/.github/workflows/format.yml @@ -0,0 +1,30 @@ +name: Format suggestions + +on: + pull_request: + +concurrency: + # Skip intermediate builds: always. + # Cancel intermediate builds: only if it is a pull request build. + group: ${{ github.workflow }}-${{ github.ref }} + cancel-in-progress: ${{ startsWith(github.ref, 'refs/pull/') }} + +jobs: + format: + permissions: + contents: read + pull-requests: write + runs-on: ubuntu-latest + steps: + - uses: actions/checkout@v4 + - uses: julia-actions/setup-julia@v1 + with: + version: 1 + - run: | + julia -e 'using Pkg; Pkg.add("JuliaFormatter")' + julia -e 'using JuliaFormatter; format("."; verbose=true)' + - uses: reviewdog/action-suggester@v1 + with: + github_token: ${{ secrets.GITHUB_TOKEN }} + tool_name: JuliaFormatter + fail_on_error: true \ No newline at end of file diff --git a/docs/make.jl b/docs/make.jl index c76f19c..05155c8 100644 --- a/docs/make.jl +++ b/docs/make.jl @@ -1,6 +1,6 @@ using Pkg -checkpointingspec = PackageSpec(path=joinpath(dirname(@__FILE__), "..")) +checkpointingspec = PackageSpec(path = joinpath(dirname(@__FILE__), "..")) Pkg.develop(checkpointingspec) # when first running instantiate @@ -13,7 +13,7 @@ makedocs( sitename = "Checkpointing.jl", format = Documenter.HTML( prettyurls = Base.get(ENV, "CI", nothing) == "true", - mathengine = Documenter.KaTeX() + mathengine = Documenter.KaTeX(), ), modules = [Checkpointing], repo = "https://github.com/Argonne-National-Laboratory/Checkpointing.jl/blob/{commit}{path}#{line}", @@ -25,7 +25,7 @@ makedocs( "Rules" => "rules.md", "Storage" => "storage.md", "API" => "lib/checkpointing.md", - ] + ], ) deploydocs( @@ -34,4 +34,4 @@ deploydocs( devbranch = "main", devurl = "main", push_preview = true, -) \ No newline at end of file +) diff --git a/examples/box_model.jl b/examples/box_model.jl index 6b81b4d..536f6af 100644 --- a/examples/box_model.jl +++ b/examples/box_model.jl @@ -2,40 +2,44 @@ const blength = [5000.0e5; 1000.0e5; 5000.0e5] ## north-south size of boxes, c const bdepth = [1.0e5; 5.0e5; 4.0e5] ## depth of boxes, centimeters -const delta = bdepth[1]/(bdepth[1] + bdepth[3]) ## constant ratio of two depths +const delta = bdepth[1] / (bdepth[1] + bdepth[3]) ## constant ratio of two depths -const bwidth = 4000.0*1e5 ## box width, centimeters +const bwidth = 4000.0 * 1e5 ## box width, centimeters # box areas -const barea = [blength[1]*bwidth; - blength[2]*bwidth; - blength[3]*bwidth] +const barea = [ + blength[1] * bwidth + blength[2] * bwidth + blength[3] * bwidth +] # box volumes -const bvol = [barea[1]*bdepth[1]; - barea[2]*bdepth[2]; - barea[3]*bdepth[3]] +const bvol = [ + barea[1] * bdepth[1] + barea[2] * bdepth[2] + barea[3] * bdepth[3] +] # parameters that are used to ensure units are in CGS (cent-gram-sec) const hundred = 100.0 const thousand = 1000.0 -const day = 3600.0*24.0 -const year = day*365.0 +const day = 3600.0 * 24.0 +const year = day * 365.0 const Sv = 1e12 ## one Sverdrup (a unit of ocean transport), 1e6 meters^3/second # parameters that appear in box model equations -const u0 = 16.0*Sv/0.0004 +const u0 = 16.0 * Sv / 0.0004 const alpha = 1668e-7 const beta = 0.7811e-3 -const gamma = 1/(300*day) +const gamma = 1 / (300 * day) # robert filter coefficient for the smoother part of the timestep const robert_filter_coeff = 0.25 # freshwater forcing -const FW = [(hundred/year) * 35.0 * barea[1]; -(hundred/year) * 35.0 * barea[1]] +const FW = [(hundred / year) * 35.0 * barea[1]; -(hundred / year) * 35.0 * barea[1]] # restoring atmospheric temperatures const Tstar = [22.0; 0.0] @@ -47,7 +51,7 @@ const Sstar = [36.0; 34.0]; function U_func(dens) - U = u0*(dens[2] - (delta * dens[1] + (1 - delta)*dens[3])) + U = u0 * (dens[2] - (delta * dens[1] + (1 - delta) * dens[3])) return U end @@ -123,9 +127,10 @@ end function forward_func_4_AD(in_now, in_old, out_old, out_now) rho_now = rho_func(in_now) ## compute density u_now = U_func(rho_now) ## compute transport - in_new = timestep_func(in_now, in_old, u_now, 10*day) ## compute new state values + in_new = timestep_func(in_now, in_old, u_now, 10 * day) ## compute new state values for j = 1:6 - in_now[j] = in_now[j] + robert_filter_coeff * (in_new[j] - 2.0 * in_now[j] + in_old[j]) + in_now[j] = + in_now[j] + robert_filter_coeff * (in_new[j] - 2.0 * in_now[j] + in_old[j]) end out_old[:] = in_now out_now[:] = in_new @@ -138,7 +143,7 @@ function advance(box::Box) end function timestepper_for(box::Box, scheme::Scheme, tsteps::Int) - @checkpoint_struct scheme box for i in 1:tsteps + @checkpoint_struct scheme box for i = 1:tsteps advance(box) box.in_now[:] = box.out_old box.in_old[:] = box.out_now @@ -157,7 +162,13 @@ function box_for(scheme::Scheme, tsteps::Int, ::EnzymeTool) dbox = Box(zeros(6), zeros(6), zeros(6), zeros(6), 0) # Compute gradient - autodiff(Enzyme.ReverseWithPrimal, timestepper_for, Duplicated(box, dbox), Const(scheme), Const(tsteps)) + autodiff( + Enzyme.ReverseWithPrimal, + timestepper_for, + Duplicated(box, dbox), + Const(scheme), + Const(tsteps), + ) return box.out_now[1], dbox.in_old end @@ -175,12 +186,12 @@ end function timestepper_while(box::Box, scheme::Scheme, tsteps::Int) - box.i=1 + box.i = 1 @checkpoint_struct scheme box while box.i <= tsteps advance(box) box.in_now[:] = box.out_old box.in_old[:] = box.out_now - box.i = box.i+1 + box.i = box.i + 1 nothing end return box.out_now[1] @@ -196,7 +207,13 @@ function box_while(scheme::Scheme, tsteps::Int, ::EnzymeTool) dbox = Box(zeros(6), zeros(6), zeros(6), zeros(6), 0) # Compute gradient - autodiff(Enzyme.ReverseWithPrimal, timestepper_while, Duplicated(box, dbox), Const(scheme), Const(tsteps)) + autodiff( + Enzyme.ReverseWithPrimal, + timestepper_while, + Duplicated(box, dbox), + Const(scheme), + Const(tsteps), + ) return box.out_now[1], dbox.in_old end diff --git a/examples/heat.jl b/examples/heat.jl index 96cf11c..88ff725 100644 --- a/examples/heat.jl +++ b/examples/heat.jl @@ -15,8 +15,8 @@ function advance(heat::Heat) last = heat.Tlast λ = heat.λ n = heat.n - for i in 2:(n-1) - next[i] = last[i] + λ*(last[i-1]-2*last[i]+last[i+1]) + for i = 2:(n-1) + next[i] = last[i] + λ * (last[i-1] - 2 * last[i] + last[i+1]) end return nothing end @@ -24,8 +24,8 @@ end function sumheat_for(heat::Heat, chkpscheme::Scheme, tsteps::Int64) # AD: Create shadow copy for derivatives - @checkpoint_struct chkpscheme heat for i in 1:tsteps - # checkpoint_struct_for(advance, heat) + @checkpoint_struct chkpscheme heat for i = 1:tsteps + # checkpoint_struct_for(advance, heat) heat.Tlast .= heat.Tnext advance(heat) end @@ -45,8 +45,8 @@ end function heat_for(scheme::Scheme, tsteps::Int, ::EnzymeTool) n = 100 - Δx=0.1 - Δt=0.001 + Δx = 0.1 + Δt = 0.001 # Select μ such that λ ≤ 0.5 for stability with μ = (λ*Δt)/Δx^2 λ = 0.5 @@ -56,19 +56,25 @@ function heat_for(scheme::Scheme, tsteps::Int, ::EnzymeTool) dheat = Heat(zeros(n), zeros(n), n, λ, tsteps) # Boundary conditions - heat.Tnext[1] = 20.0 + heat.Tnext[1] = 20.0 heat.Tnext[end] = 0 # Compute gradient - autodiff(Enzyme.ReverseWithPrimal, sumheat_for, Duplicated(heat, dheat), Const(scheme), Const(tsteps)) + autodiff( + Enzyme.ReverseWithPrimal, + sumheat_for, + Duplicated(heat, dheat), + Const(scheme), + Const(tsteps), + ) return heat.Tnext, dheat.Tnext[2:end-1] end function heat_for(scheme::Scheme, tsteps::Int, ::ZygoteTool) n = 100 - Δx=0.1 - Δt=0.001 + Δx = 0.1 + Δt = 0.001 # Select μ such that λ ≤ 0.5 for stability with μ = (λ*Δt)/Δx^2 λ = 0.5 @@ -76,7 +82,7 @@ function heat_for(scheme::Scheme, tsteps::Int, ::ZygoteTool) heat = Heat(zeros(n), zeros(n), n, λ, tsteps) # Boundary conditions - heat.Tnext[1] = 20.0 + heat.Tnext[1] = 20.0 heat.Tnext[end] = 0 # Compute gradient @@ -87,8 +93,8 @@ end function heat_while(scheme::Scheme, tsteps::Int, ::EnzymeTool) n = 100 - Δx=0.1 - Δt=0.001 + Δx = 0.1 + Δt = 0.001 # Select μ such that λ ≤ 0.5 for stability with μ = (λ*Δt)/Δx^2 λ = 0.5 @@ -98,19 +104,25 @@ function heat_while(scheme::Scheme, tsteps::Int, ::EnzymeTool) dheat = Heat(zeros(n), zeros(n), n, λ, tsteps) # Boundary conditions - heat.Tnext[1] = 20.0 + heat.Tnext[1] = 20.0 heat.Tnext[end] = 0 # Compute gradient - autodiff(Enzyme.ReverseWithPrimal, sumheat_while, Duplicated(heat, dheat), Const(scheme), Const(tsteps)) + autodiff( + Enzyme.ReverseWithPrimal, + sumheat_while, + Duplicated(heat, dheat), + Const(scheme), + Const(tsteps), + ) return heat.Tnext, dheat.Tnext[2:end-1] end function heat_while(scheme::Scheme, tsteps::Int, ::ZygoteTool) n = 100 - Δx=0.1 - Δt=0.001 + Δx = 0.1 + Δt = 0.001 # Select μ such that λ ≤ 0.5 for stability with μ = (λ*Δt)/Δx^2 λ = 0.5 @@ -118,7 +130,7 @@ function heat_while(scheme::Scheme, tsteps::Int, ::ZygoteTool) heat = Heat(zeros(n), zeros(n), n, λ, 1) # Boundary conditions - heat.Tnext[1] = 20.0 + heat.Tnext[1] = 20.0 heat.Tnext[end] = 0 # Compute gradient diff --git a/examples/optcontrol.jl b/examples/optcontrol.jl index 75c6648..007b801 100644 --- a/examples/optcontrol.jl +++ b/examples/optcontrol.jl @@ -10,38 +10,38 @@ using Zygote include("optcontrolfunc.jl") function header() - println("**************************************************************************") - println("* Solution of the optimal control problem *") - println("* *") - println("* J(y) = y_2(1) -> min *") - println("* s.t. dy_1/dt = 0.5*y_1(t) + u(t), y_1(0)=1 *") - println("* dy_2/dt = y_1(t)^2 + 0.5*u(t)^2 y_2(0)=0 *") - println("* *") - println("* the adjoints equations fulfill *") - println("* *") - println("* dl_1/dt = -0.5*l_1(t) - 2*y_1(t)*l_2(t) l_1(1)=0 *") - println("* dl_2/dt = 0 l_2(1)=1 *") - println("* *") - println("* with Revolve for Online and (Multi-Stage) Offline Checkpointing *") - println("* *") - println("**************************************************************************") - - println("**************************************************************************") - println("* The solution of the optimal control problem above is *") - println("* *") - println("* y_1*(t) = (2*e^(3t)+e^3)/(e^(3t/2)*(2+e^3)) *") - println("* y_2*(t) = (2*e^(3t)-e^(6-3t)-2+e^6)/((2+e^3)^2) *") - println("* u*(t) = (2*e^(3t)-e^3)/(e^(3t/2)*(2+e^3)) *") - println("* l_1*(t) = (2*e^(3-t)-2*e^(2t))/(e^(t/2)*(2+e^3)) *") - println("* l_2*(t) = 1 *") - println("* *") - println("**************************************************************************") - - return + println("**************************************************************************") + println("* Solution of the optimal control problem *") + println("* *") + println("* J(y) = y_2(1) -> min *") + println("* s.t. dy_1/dt = 0.5*y_1(t) + u(t), y_1(0)=1 *") + println("* dy_2/dt = y_1(t)^2 + 0.5*u(t)^2 y_2(0)=0 *") + println("* *") + println("* the adjoints equations fulfill *") + println("* *") + println("* dl_1/dt = -0.5*l_1(t) - 2*y_1(t)*l_2(t) l_1(1)=0 *") + println("* dl_2/dt = 0 l_2(1)=1 *") + println("* *") + println("* with Revolve for Online and (Multi-Stage) Offline Checkpointing *") + println("* *") + println("**************************************************************************") + + println("**************************************************************************") + println("* The solution of the optimal control problem above is *") + println("* *") + println("* y_1*(t) = (2*e^(3t)+e^3)/(e^(3t/2)*(2+e^3)) *") + println("* y_2*(t) = (2*e^(3t)-e^(6-3t)-2+e^6)/((2+e^3)^2) *") + println("* u*(t) = (2*e^(3t)-e^3)/(e^(3t/2)*(2+e^3)) *") + println("* l_1*(t) = (2*e^(3-t)-2*e^(2t))/(e^(t/2)*(2+e^3)) *") + println("* l_2*(t) = 1 *") + println("* *") + println("**************************************************************************") + + return end function muoptcontrol(scheme, steps, ::EnzymeTool) - println( "\n STEPS -> number of time steps to perform") + println("\n STEPS -> number of time steps to perform") println("SNAPS -> number of checkpoints") println("INFO = 1 -> calculate only approximate solution") println("INFO = 2 -> calculate approximate solution + takeshots") @@ -56,13 +56,13 @@ function muoptcontrol(scheme, steps, ::EnzymeTool) F = [1.0, 0.0] F_H = [0.0, 0.0] t = 0.0 - h = 1.0/steps + h = 1.0 / steps model = Model(F, F_H, t, h) # Just make sure it's all zero. - bmodel = Model([0.0,0.0], [0.0,0.0], 0.0, 0.0) + bmodel = Model([0.0, 0.0], [0.0, 0.0], 0.0, 0.0) function foo(model::Model) - @checkpoint_struct scheme model for i in 1:steps + @checkpoint_struct scheme model for i = 1:steps model.F_H .= model.F advance(model) model.t += h @@ -74,20 +74,20 @@ function muoptcontrol(scheme, steps, ::EnzymeTool) F = model.F L = bmodel.F - F_opt = Array{Float64, 1}(undef, 2) - L_opt = Array{Float64, 1}(undef, 2) - opt_sol(F_opt,1.0) - opt_lambda(L_opt,0.0) + F_opt = Array{Float64,1}(undef, 2) + L_opt = Array{Float64,1}(undef, 2) + opt_sol(F_opt, 1.0) + opt_lambda(L_opt, 0.0) println("\n\n") - println("y_1*(1) = " , F_opt[1] , " y_2*(1) = " , F_opt[2]) - println("y_1 (1) = " , F[1] , " y_2 (1) = " , F[2] , " \n\n") - println("l_1*(0) = " , L_opt[1] , " l_2*(0) = " , L_opt[2]) - println("l_1 (0) = " , L[1] , " sl_2 (0) = " , L[2] , " ") + println("y_1*(1) = ", F_opt[1], " y_2*(1) = ", F_opt[2]) + println("y_1 (1) = ", F[1], " y_2 (1) = ", F[2], " \n\n") + println("l_1*(0) = ", L_opt[1], " l_2*(0) = ", L_opt[2]) + println("l_1 (0) = ", L[1], " sl_2 (0) = ", L[2], " ") return F, L, F_opt, L_opt end function muoptcontrol(scheme, steps, ::ZygoteTool) - println( "\n STEPS -> number of time steps to perform") + println("\n STEPS -> number of time steps to perform") println("SNAPS -> number of checkpoints") println("INFO = 1 -> calculate only approximate solution") println("INFO = 2 -> calculate approximate solution + takeshots") @@ -102,30 +102,30 @@ function muoptcontrol(scheme, steps, ::ZygoteTool) F = [1.0, 0.0] F_H = [0.0, 0.0] t = 0.0 - h = 1.0/steps + h = 1.0 / steps model = Model(F, F_H, t, h) function foo(model::Model) - @checkpoint_struct scheme model for i in 1:steps + @checkpoint_struct scheme model for i = 1:steps model.F_H .= model.F advance(model) model.t += h end return model.F[2] end - g = Zygote.gradient(foo,model) + g = Zygote.gradient(foo, model) F = model.F L = [g[1].F[1], g[1].F[2]] - F_opt = Array{Float64, 1}(undef, 2) - L_opt = Array{Float64, 1}(undef, 2) - opt_sol(F_opt,1.0) - opt_lambda(L_opt,0.0) + F_opt = Array{Float64,1}(undef, 2) + L_opt = Array{Float64,1}(undef, 2) + opt_sol(F_opt, 1.0) + opt_lambda(L_opt, 0.0) println("\n\n") - println("y_1*(1) = " , F_opt[1] , " y_2*(1) = " , F_opt[2]) - println("y_1 (1) = " , F[1] , " y_2 (1) = " , F[2] , " \n\n") - println("l_1*(0) = " , L_opt[1] , " l_2*(0) = " , L_opt[2]) - println("l_1 (0) = " , L[1] , " sl_2 (0) = " , L[2] , " ") + println("y_1*(1) = ", F_opt[1], " y_2*(1) = ", F_opt[2]) + println("y_1 (1) = ", F[1], " y_2 (1) = ", F[2], " \n\n") + println("l_1*(0) = ", L_opt[1], " l_2*(0) = ", L_opt[2]) + println("l_1 (0) = ", L[1], " sl_2 (0) = ", L[2], " ") return F, L, F_opt, L_opt end diff --git a/examples/optcontrolfunc.jl b/examples/optcontrolfunc.jl index efc0297..2ea50ee 100644 --- a/examples/optcontrolfunc.jl +++ b/examples/optcontrolfunc.jl @@ -8,12 +8,12 @@ end function func_U(t) e = exp(1) - return 2.0*((e^(3.0*t))-(e^3))/((e^(3.0*t/2.0))*(2.0+(e^3))) + return 2.0 * ((e^(3.0 * t)) - (e^3)) / ((e^(3.0 * t / 2.0)) * (2.0 + (e^3))) end -function func(F, X,t) - F[2] = X[1]*X[1]+0.5*(func_U(t)*func_U(t)) - F[1] = 0.5*X[1]+ func_U(t) +function func(F, X, t) + F[2] = X[1] * X[1] + 0.5 * (func_U(t) * func_U(t)) + F[1] = 0.5 * X[1] + func_U(t) return nothing end @@ -22,25 +22,25 @@ function advance(model) F = model.F t = model.t h = model.h - func(F, F_H,t) - F[1] = F_H[1] + h/2.0*F[1] - F[2] = F_H[2] + h/2.0*F[2] - func(F,F,t+h/2.0) - model.F[1] = F_H[1] + h*F[1] - model.F[2] = F_H[2] + h*F[2] + func(F, F_H, t) + F[1] = F_H[1] + h / 2.0 * F[1] + F[2] = F_H[2] + h / 2.0 * F[2] + func(F, F, t + h / 2.0) + model.F[1] = F_H[1] + h * F[1] + model.F[2] = F_H[2] + h * F[2] return nothing end -function opt_sol(Y,t) +function opt_sol(Y, t) e = exp(1) - Y[1] = (2.0*e^(3.0*t)+e^3)/(e^(3.0*t/2.0)*(2.0+e^3)) - Y[2] = (2.0*e^(3.0*t)-e^(6.0-3.0*t)-2.0+e^6)/((2.0+e^3)^2) + Y[1] = (2.0 * e^(3.0 * t) + e^3) / (e^(3.0 * t / 2.0) * (2.0 + e^3)) + Y[2] = (2.0 * e^(3.0 * t) - e^(6.0 - 3.0 * t) - 2.0 + e^6) / ((2.0 + e^3)^2) return end -function opt_lambda(L,t) +function opt_lambda(L, t) e = exp(1) - L[1] = (2.0*e^(3-t)-2.0*e^(2.0*t))/(e^(t/2.0)*(2+e^3)) + L[1] = (2.0 * e^(3 - t) - 2.0 * e^(2.0 * t)) / (e^(t / 2.0) * (2 + e^3)) L[2] = 1.0 return end diff --git a/examples/printaction.jl b/examples/printaction.jl index 63a658e..c0f51c2 100644 --- a/examples/printaction.jl +++ b/examples/printaction.jl @@ -1,8 +1,8 @@ using Checkpointing -function main(steps, checkpoints; verbose=0) +function main(steps, checkpoints; verbose = 0) store = function f() end - revolve = Revolve{Nothing}(steps, checkpoints, store, store; verbose=verbose) + revolve = Revolve{Nothing}(steps, checkpoints, store, store; verbose = verbose) while true next_action = next_action!(revolve) if next_action.actionflag == Checkpointing.done @@ -10,4 +10,4 @@ function main(steps, checkpoints; verbose=0) end end return revolve -end \ No newline at end of file +end diff --git a/src/Checkpointing.jl b/src/Checkpointing.jl index 3e49903..eac728b 100644 --- a/src/Checkpointing.jl +++ b/src/Checkpointing.jl @@ -28,13 +28,13 @@ done: we are done with adjoining the loop equivalent to the `terminate` enum val """ @enum ActionFlag begin none - store - restore - forward - firstuturn - uturn + store + restore + forward + firstuturn + uturn err - done + done end """ @@ -48,10 +48,10 @@ Stores the state of the checkpointing scheme after an action is taken. """ struct Action - actionflag::ActionFlag - iteration::Int - startiteration::Int - cpnum::Int + actionflag::ActionFlag + iteration::Int + startiteration::Int + cpnum::Int end export Scheme @@ -86,21 +86,21 @@ export Revolve, guess, factor, next_action!, ActionFlag, Periodic export Online_r2, update_revolve @generated function copyto!(dest::MT, src::MT) where {MT} - assignments = [ - :( dest.$name = src.$name ) for name in fieldnames(MT) - ] - quote $(assignments...) end + assignments = [:(dest.$name = src.$name) for name in fieldnames(MT)] + quote + $(assignments...) + end end function copyto!(dest::MT, src::TT) where {MT,TT} for name in (fieldnames(MT)) - if !isa(src[name], ChainRulesCore.ZeroTangent) && !isa(getfield(dest,name), Int) + if !isa(src[name], ChainRulesCore.ZeroTangent) && !isa(getfield(dest, name), Int) setfield!(dest, name, convert(typeof(getfield(dest, name)), src[name])) end end end -to_named_tuple(p) = (; (v=>getfield(p, v) for v in fieldnames(typeof(p)))...) +to_named_tuple(p) = (; (v => getfield(p, v) for v in fieldnames(typeof(p)))...) function create_tangent(shadowmodel::MT) where {MT} shadowtuple = to_named_tuple(shadowmodel) @@ -162,18 +162,22 @@ adjoints and is created here. It is supposed to be initialized by ChainRules. """ macro checkpoint_struct(alg, model, loop) if loop.head == :for - body = loop.args[2] + body = loop.args[2] iterator = loop.args[1].args[1] - from = loop.args[1].args[2].args[2] - to = loop.args[1].args[2].args[3] - range = loop.args[1].args[2] + from = loop.args[1].args[2].args[2] + to = loop.args[1].args[2].args[3] + range = loop.args[1].args[2] ex = quote let if !isa($range, UnitRange{Int64}) error("Checkpointing.jl: Only UnitRange{Int64} is supported.") end $iterator = $from - $model = Checkpointing.checkpoint_struct_for($alg, $model, $(loop.args[1].args[2])) do $model + $model = Checkpointing.checkpoint_struct_for( + $alg, + $model, + $(loop.args[1].args[2]), + ) do $model $body $iterator += 1 nothing @@ -185,10 +189,11 @@ macro checkpoint_struct(alg, model, loop) function condition($model) $(loop.args[1]) end - $model = Checkpointing.checkpoint_struct_while($alg, $model, condition) do $model - $(loop.args[2]) - nothing - end + $model = + Checkpointing.checkpoint_struct_while($alg, $model, condition) do $model + $(loop.args[2]) + nothing + end end else error("Checkpointing.jl: Unknown loop construct.") @@ -196,7 +201,12 @@ macro checkpoint_struct(alg, model, loop) esc(ex) end -function fwd_checkpoint_struct_for(body::Function, scheme::Scheme, model, range::UnitRange{Int64}) +function fwd_checkpoint_struct_for( + body::Function, + scheme::Scheme, + model, + range::UnitRange{Int64}, +) for i in range body(model) end diff --git a/src/Rules/ChainRules.jl b/src/Rules/ChainRules.jl index f5a3c53..a84340e 100644 --- a/src/Rules/ChainRules.jl +++ b/src/Rules/ChainRules.jl @@ -11,12 +11,7 @@ function ChainRulesCore.rrule( # TODO: store checkpoints during this forward call and # start the reverse with first uturn model_input = deepcopy(model) - model = fwd_checkpoint_struct_for( - body, - alg, - model, - range, - ) + model = fwd_checkpoint_struct_for(body, alg, model, range) function checkpoint_struct_pullback(dmodel) shadowmodel = deepcopy(model_input) set_zero!(shadowmodel) @@ -33,7 +28,7 @@ function ChainRulesCore.rrule( body::Function, alg::Scheme, model::MT, - condition::Function + condition::Function, ) where {MT} model_input = deepcopy(model) while condition(model) diff --git a/src/Rules/EnzymeRules.jl b/src/Rules/EnzymeRules.jl index ccedbc5..a85a161 100644 --- a/src/Rules/EnzymeRules.jl +++ b/src/Rules/EnzymeRules.jl @@ -35,7 +35,7 @@ function reverse( alg.val, model_input, model.dval, - range.val + range.val, ) copyto!(model.val, model_final) return (nothing, nothing, nothing, nothing) @@ -51,7 +51,11 @@ function augmented_primal( condition, ) if needs_primal(config) - return AugmentedReturn(func.val(body.val, alg.val, model.val, condition.val), nothing, (model.val,)) + return AugmentedReturn( + func.val(body.val, alg.val, model.val, condition.val), + nothing, + (model.val,), + ) else return AugmentedReturn(nothing, nothing, (model.val,)) end @@ -73,7 +77,7 @@ function reverse( alg.val, model_input, model.dval, - condition.val + condition.val, ) copyto!(model.val, model_final) return (nothing, nothing, nothing, nothing) diff --git a/src/Schemes/Online_r2.jl b/src/Schemes/Online_r2.jl index 63c6b40..1016e0c 100644 --- a/src/Schemes/Online_r2.jl +++ b/src/Schemes/Online_r2.jl @@ -9,8 +9,8 @@ # TODO: Extend Online_r2 to Online_r3 mutable struct Online_r2{MT} <: Scheme where {MT} - check::Int - capo::Int + check::Int + capo::Int acp::Int numfwd::Int numcmd::Int @@ -38,107 +38,127 @@ function Online_r2{MT}( frestore::Union{Function,Nothing} = nothing; storage::AbstractStorage = ArrayStorage{MT}(checkpoints), anActionInstance::Union{Nothing,Action} = nothing, - verbose::Int = 0 + verbose::Int = 0, ) where {MT} if !isa(anActionInstance, Nothing) anActionInstance.actionflag = 0 - anActionInstance.iteration = 0 - anActionInstance.cpNum = 0 + anActionInstance.iteration = 0 + anActionInstance.cpNum = 0 end if checkpoints < 0 - @error("Online_r2: negative checkpoints") + @error("Online_r2: negative checkpoints") end - acp = checkpoints - numfwd = 0 - numcmd = 0 - numstore = 0 - oldcapo = 0 + acp = checkpoints + numfwd = 0 + numcmd = 0 + numstore = 0 + oldcapo = 0 check = -1 capo = 0 oldind = -1 ind = -1 iter = -1 incr = -1 - offset= -1 - t=-1 + offset = -1 + t = -1 ch = Vector{Int}(undef, acp) ord_ch = Vector{Int}(undef, acp) num_rep = Vector{Int}(undef, acp) - for i in 1:acp + for i = 1:acp ch[i] = -1 ord_ch[i] = -1 num_rep[i] = -1 end - revolve = Revolve{MT}(typemax(Int64), acp, fstore, frestore; verbose=verbose) - online_r2 = Online_r2{MT}(check, capo, acp, numfwd, numcmd, numstore, - oldcapo, ind, oldind, iter, incr, offset, t, - verbose, fstore, frestore, ch, ord_ch, num_rep, revolve, storage) + revolve = Revolve{MT}(typemax(Int64), acp, fstore, frestore; verbose = verbose) + online_r2 = Online_r2{MT}( + check, + capo, + acp, + numfwd, + numcmd, + numstore, + oldcapo, + ind, + oldind, + iter, + incr, + offset, + t, + verbose, + fstore, + frestore, + ch, + ord_ch, + num_rep, + revolve, + storage, + ) return online_r2 end function update_revolve(online::Online_r2{MT}, steps) where {MT} online.revolve = Revolve{MT}(steps, online.acp, online.fstore, online.frestore) - online.revolve.rwcp = online.revolve.acp-1 + online.revolve.rwcp = online.revolve.acp - 1 online.revolve.steps = steps online.revolve.acp = online.acp - online.revolve.cstart = steps-1 + online.revolve.cstart = steps - 1 online.revolve.cend = steps - online.revolve.numfwd = steps-1 - online.revolve.numinv= online.revolve.numfwd-1 - online.revolve.numstore= online.acp - online.revolve.prevcend= steps - online.revolve.firstuturned=false - online.revolve.verbose= 0 + online.revolve.numfwd = steps - 1 + online.revolve.numinv = online.revolve.numfwd - 1 + online.revolve.numstore = online.acp + online.revolve.prevcend = steps + online.revolve.firstuturned = false + online.revolve.verbose = 0 num_ch = Vector{Int}(undef, online.acp) - for i=1:online.acp + for i = 1:online.acp num_ch[i] = 1 - for j=1:online.acp + for j = 1:online.acp if (online.ch[j] < online.ch[i]) - num_ch[i] = num_ch[i]+1 + num_ch[i] = num_ch[i] + 1 end end end - for i=1:online.acp - for j=1:online.acp + for i = 1:online.acp + for j = 1:online.acp if (num_ch[j] == i) - online.ord_ch[i]=j; + online.ord_ch[i] = j end end end - for j=1:online.acp + for j = 1:online.acp online.revolve.stepof[j] = online.ch[online.ord_ch[j]] end - online.revolve.stepof[online.acp+1]=0 + online.revolve.stepof[online.acp+1] = 0 end function next_action!(online::Online_r2)::Action # Default values for next action - actionflag = none + actionflag = none if online.verbose > 0 - if(online.check !=-1) - @info(online.check+1, online.ch[online.check+1], online.capo) - for i in 1:online.acp - println("online.ch[",i,"] =", online.ch[i]) + if (online.check != -1) + @info(online.check + 1, online.ch[online.check+1], online.capo) + for i = 1:online.acp + println("online.ch[", i, "] =", online.ch[i]) end else @info(online.check, online.capo) - for i in 1:online.acp - println("online.ch[",i,"] =", online.ch[i]) + for i = 1:online.acp + println("online.ch[", i, "] =", online.ch[i]) end end end - online.numcmd+=1 + online.numcmd += 1 # We use this logic because the C++ version uses short circuiting cond2 = false if online.check != -1 - cond2 = online.ch[online.check+1] != online.capo + cond2 = online.ch[online.check+1] != online.capo end online.oldcapo = online.capo - if ((online.check == -1) || ( cond2 && (online.capo <= online.acp-1))) - # condition for takeshot for r=1 - # (If no checkpoint has been taken before OR - # If a store has not just occurred AND the iteration count is - # less than the total number of checkpoints) + if ((online.check == -1) || (cond2 && (online.capo <= online.acp - 1))) + # condition for takeshot for r=1 + # (If no checkpoint has been taken before OR + # If a store has not just occurred AND the iteration count is + # less than the total number of checkpoints) if online.verbose > 0 @info("condition for takeshot for r=1") end @@ -146,39 +166,39 @@ function next_action!(online::Online_r2)::Action online.ch[online.check+1] = online.capo online.t = 0 if (online.acp < 4) - for i in 1:online.acp - online.num_rep[i] = 2 + for i = 1:online.acp + online.num_rep[i] = 2 end online.incr = 2 online.iter = 1 - online.oldind = online.acp-1 + online.oldind = online.acp - 1 else online.iter = 1 online.incr = 1 online.oldind = 1 - for i in 1:online.acp - online.num_rep[i] = 1 - online.ord_ch[i] = i-1 + for i = 1:online.acp + online.num_rep[i] = 1 + online.ord_ch[i] = i - 1 end - online.offset = online.acp-1 + online.offset = online.acp - 1 end - if (online.capo == online.acp-1) + if (online.capo == online.acp - 1) online.ind = 2 end # Increase the number of takeshots and the corresponding checkpoint - online.numstore+=1 - return Action(store, online.capo-1, -1, online.check) - elseif (online.capo < online.acp-1) - # condition for advance for r=1 - # (the iteraton is less that the total number of checkpoints) + online.numstore += 1 + return Action(store, online.capo - 1, -1, online.check) + elseif (online.capo < online.acp - 1) + # condition for advance for r=1 + # (the iteraton is less that the total number of checkpoints) if online.verbose > 0 @info("condition for advance for r=1") end - online.capo = online.oldcapo+1 - online.numfwd+=1 + online.capo = online.oldcapo + 1 + online.numfwd += 1 return Action(forward, online.capo, online.oldcapo, -1) else - # Online_r2-Checkpointing for r=2 + # Online_r2-Checkpointing for r=2 if (online.ch[online.check+1] == online.capo) # condition for advance for r=2 # (checkpoint has just occurred) @@ -187,48 +207,53 @@ function next_action!(online::Online_r2)::Action end if (online.acp == 1) online.capo = BigInt(typemax(Int64)) - online.numfwd+=1 + online.numfwd += 1 return Action(forward, online.capo, online.oldcapo, -1) elseif (online.acp == 2) - online.capo = online.ch[1+1]+online.incr - online.numfwd+=1 + online.capo = online.ch[1+1] + online.incr + online.numfwd += 1 return Action(forward, online.capo, online.oldcapo, -1) elseif (online.acp == 3) - online.numfwd+=online.incr - if (online.iter == 0) - online.capo = online.ch[online.oldind+1] - for i=0:(online.t+1)/2 - online.capo += online.incr - online.incr = online.incr + 1 - online.iter = online.iter + 1 + online.numfwd += online.incr + if (online.iter == 0) + online.capo = online.ch[online.oldind+1] + for i = 0:(online.t+1)/2 + online.capo += online.incr + online.incr = online.incr + 1 + online.iter = online.iter + 1 end - else - online.capo = online.ch[online.ind+1]+online.incr - online.incr = online.incr + 1 - online.iter = online.iter + 1 + else + online.capo = online.ch[online.ind+1] + online.incr + online.incr = online.incr + 1 + online.iter = online.iter + 1 end actionflag = forward return Action(forward, online.capo, online.oldcapo, -1) - else + else if online.verbose > 0 - @info("Online_r2-condition for advance for r=2 online.acp-1=", online.acp-1," online.capo= ", online.capo) + @info( + "Online_r2-condition for advance for r=2 online.acp-1=", + online.acp - 1, + " online.capo= ", + online.capo + ) end - if (online.capo == online.acp-1) - online.capo = online.capo+2 - online.ind=online.acp-1 - online.numfwd+=2 + if (online.capo == online.acp - 1) + online.capo = online.capo + 2 + online.ind = online.acp - 1 + online.numfwd += 2 return Action(forward, online.capo, online.oldcapo, -1) end - if (online.t == 0) - if (online.iter < online.offset) - online.capo = online.capo+1 - online.numfwd+=1 - else - online.capo = online.capo+2 - online.numfwd+=2 + if (online.t == 0) + if (online.iter < online.offset) + online.capo = online.capo + 1 + online.numfwd += 1 + else + online.capo = online.capo + 2 + online.numfwd += 2 end - if (online.offset == 1) - online.t += 1 + if (online.offset == 1) + online.t += 1 end return Action(forward, online.capo, online.oldcapo, -1) end @@ -245,100 +270,149 @@ function next_action!(online::Online_r2)::Action end if (online.acp == 2) online.ch[1+1] = online.capo - online.incr+=1 + online.incr += 1 # Increase the number of takeshots and the corresponding checkpoint - online.numstore+=1 - return Action(store, online.capo-1, -1, 1+1) + online.numstore += 1 + return Action(store, online.capo - 1, -1, 1 + 1) elseif (online.acp == 3) online.ch[online.ind+1] = online.capo - online.check = online.ind + online.check = online.ind if online.verbose > 0 - @info(" iter ", online.iter, " online.num_rep[1] ", online.num_rep[1+1]) + @info(" iter ", online.iter, " online.num_rep[1] ", online.num_rep[1+1]) end - if (online.iter == online.num_rep[1+1]) + if (online.iter == online.num_rep[1+1]) online.iter = 0 - online.t+=1 - online.oldind = online.ind - online.num_rep[1+1]+=1 - online.ind = 2 - online.num_rep[1+1]%2 - online.incr=1 + online.t += 1 + online.oldind = online.ind + online.num_rep[1+1] += 1 + online.ind = 2 - online.num_rep[1+1] % 2 + online.incr = 1 end - # Increase the number of takeshots and the corresponding checkpoint - online.numstore+=1 - return Action(store, online.capo-1, -1, online.check) + # Increase the number of takeshots and the corresponding checkpoint + online.numstore += 1 + return Action(store, online.capo - 1, -1, online.check) else if online.verbose > 0 @info(" online.capo ", online.capo, " online.acp ", online.acp) end - if (online.capo < online.acp+2) - online.ch[online.ind+1] = online.capo - online.check = online.ind - if (online.capo == online.acp+1) + if (online.capo < online.acp + 2) + online.ch[online.ind+1] = online.capo + online.check = online.ind + if (online.capo == online.acp + 1) online.oldind = online.ord_ch[online.acp-1+1] online.ind = online.ch[online.ord_ch[online.acp-1+1]+1] if online.verbose > 0 @info(" oldind ", online.oldind, " ind ", online.ind) end - for k=online.acp:-1:3 - online.ord_ch[k]=online.ord_ch[k-1] - online.ch[online.ord_ch[k]+1] = online.ch[online.ord_ch[k-1]+1] + for k = online.acp:-1:3 + online.ord_ch[k] = online.ord_ch[k-1] + online.ch[online.ord_ch[k]+1] = online.ch[online.ord_ch[k-1]+1] end - online.ord_ch[1+1] = online.oldind - online.ch[online.ord_ch[1+1]+1] = online.ind - online.incr = 2 - online.ind = 2 + online.ord_ch[1+1] = online.oldind + online.ch[online.ord_ch[1+1]+1] = online.ind + online.incr = 2 + online.ind = 2 if online.verbose > 0 - @info(" ind ", online.ind, " incr ", online.incr, " iter ", online.iter) - for j=1:online.acp - @info(" j ", j, " ord_ch ", online.ord_ch[j], " ch ", online.ch[online.ord_ch[j]+1], " rep ", online.num_rep[online.ord_ch[j]+1]) + @info( + " ind ", + online.ind, + " incr ", + online.incr, + " iter ", + online.iter + ) + for j = 1:online.acp + @info( + " j ", + j, + " ord_ch ", + online.ord_ch[j], + " ch ", + online.ch[online.ord_ch[j]+1], + " rep ", + online.num_rep[online.ord_ch[j]+1] + ) end end end - #Increase the number of takeshots and the corresponding checkpoint - online.numstore+=1 - return Action(store, online.capo-1, -1, online.check) + #Increase the number of takeshots and the corresponding checkpoint + online.numstore += 1 + return Action(store, online.capo - 1, -1, online.check) end if (online.t == 0) if online.verbose > 0 - @info(" online.ind ", online.ind, " online.incr ", online.incr, " iter ", online.iter, " offset ", online.offset) + @info( + " online.ind ", + online.ind, + " online.incr ", + online.incr, + " iter ", + online.iter, + " offset ", + online.offset + ) end if (online.iter == online.offset) - online.offset=online.offset-1 + online.offset = online.offset - 1 online.iter = 1 online.check = online.ord_ch[online.acp-1+1] online.ch[online.ord_ch[online.acp-1+1]+1] = online.capo online.oldind = online.ord_ch[online.acp-1+1] online.ind = online.ch[online.ord_ch[online.acp-1+1]+1] if online.verbose > 0 - @info(" oldind " , online.oldind , " ind " , online.ind) + @info(" oldind ", online.oldind, " ind ", online.ind) end - for k=online.acp-1:-1:online.incr+1 - online.ord_ch[k+1]=online.ord_ch[k-1+1] - online.ch[online.ord_ch[k+1]+1] = online.ch[online.ord_ch[k-1+1]+1] + for k = online.acp-1:-1:online.incr+1 + online.ord_ch[k+1] = online.ord_ch[k-1+1] + online.ch[online.ord_ch[k+1]+1] = + online.ch[online.ord_ch[k-1+1]+1] end online.ord_ch[online.incr+1] = online.oldind online.ch[online.ord_ch[online.incr+1]+1] = online.ind - online.incr+=1 - online.ind=online.incr + online.incr += 1 + online.ind = online.incr if online.verbose > 0 - @info(" ind ", online.ind, " incr ", online.incr, " iter ", online.iter) - for j=1:online.acp - @info(" j ", j, " ord_ch ", online.ord_ch[j], " ch ", online.ch[online.ord_ch[j]+1], " rep ", online.num_rep[online.ord_ch[j]+1]) + @info( + " ind ", + online.ind, + " incr ", + online.incr, + " iter ", + online.iter + ) + for j = 1:online.acp + @info( + " j ", + j, + " ord_ch ", + online.ord_ch[j], + " ch ", + online.ch[online.ord_ch[j]+1], + " rep ", + online.num_rep[online.ord_ch[j]+1] + ) end end else online.ch[online.ord_ch[online.ind+1]+1] = online.capo online.check = online.ord_ch[online.ind+1] - online.iter+=1 - online.ind+=1 + online.iter += 1 + online.ind += 1 if online.verbose > 0 - @info(" xx ind ", online.ind, " incr ", online.incr, " iter ", online.iter) + @info( + " xx ind ", + online.ind, + " incr ", + online.incr, + " iter ", + online.iter + ) end end #Increase the number of takeshots and the corresponding checkpoint - online.numstore=online.numstore+1 - return Action(store, online.capo-1, -1, online.check) + online.numstore = online.numstore + 1 + return Action(store, online.capo - 1, -1, online.check) end end @@ -346,7 +420,9 @@ function next_action!(online::Online_r2)::Action end # This means that the end of Online_r2 Checkpointing for r=2 is reached and # another Online_r2 Checkpointing class must be started - @info("Online_r2 is optimal over the range [0,(numcheckpoints+2)*(numcheckpoints+1)/2]. Online_r3 needs to be implemented") + @info( + "Online_r2 is optimal over the range [0,(numcheckpoints+2)*(numcheckpoints+1)/2]. Online_r3 needs to be implemented" + ) return Action(err, online.capo, online.oldcapo, -1) end @@ -355,7 +431,7 @@ function rev_checkpoint_struct_while( alg::Online_r2, model_input::MT, shadowmodel::MT, - condition::Function + condition::Function, ) where {MT} model = deepcopy(model_input) model_check = alg.storage @@ -364,44 +440,44 @@ function rev_checkpoint_struct_while( storemapinv = Dict{Int32,Int32}() storemap = Dict{Int32,Int32}() check = 0 - oldcapo=0 - onlinesteps=0 + oldcapo = 0 + onlinesteps = 0 go = true while go next_action = next_action!(alg) if (next_action.actionflag == Checkpointing.store) - check=next_action.cpnum+1 - storemapinv[check]=next_action.iteration + check = next_action.cpnum + 1 + storemapinv[check] = next_action.iteration model_check[check] = deepcopy(model) elseif (next_action.actionflag == Checkpointing.forward) - for j= oldcapo:(next_action.iteration-1) + for j = oldcapo:(next_action.iteration-1) body(model) go = condition(model) - onlinesteps=onlinesteps+1 + onlinesteps = onlinesteps + 1 if !go break end end - oldcapo=next_action.iteration + oldcapo = next_action.iteration else @error("Unexpected action in online phase: ", next_action.actionflag) go = false end end for (key, value) in storemapinv - storemap[value]=key + storemap[value] = key end # Switch to offline revolve now. - update_revolve(alg, onlinesteps+1) + update_revolve(alg, onlinesteps + 1) while true next_action = next_action!(alg.revolve) if (next_action.actionflag == Checkpointing.store) - check=pop!(freeindices) - storemap[next_action.iteration-1]=check + check = pop!(freeindices) + storemap[next_action.iteration-1] = check model_check[check] = deepcopy(model) elseif (next_action.actionflag == Checkpointing.forward) - for j= next_action.startiteration:(next_action.iteration-1) + for j = next_action.startiteration:(next_action.iteration-1) body(model) end elseif (next_action.actionflag == Checkpointing.firstuturn) @@ -410,16 +486,16 @@ function rev_checkpoint_struct_while( model_final = deepcopy(model) # Enzyme.autodiff(body, Duplicated(model,shadowmodel)) elseif (next_action.actionflag == Checkpointing.uturn) - Enzyme.autodiff(Reverse, body, Duplicated(model,shadowmodel)) - if haskey(storemap,next_action.iteration-1-1) + Enzyme.autodiff(Reverse, body, Duplicated(model, shadowmodel)) + if haskey(storemap, next_action.iteration - 1 - 1) push!(freeindices, storemap[next_action.iteration-1-1]) - delete!(storemap,next_action.iteration-1-1) + delete!(storemap, next_action.iteration - 1 - 1) end elseif (next_action.actionflag == Checkpointing.restore) model = deepcopy(model_check[storemap[next_action.iteration-1]]) elseif next_action.actionflag == Checkpointing.done - if haskey(storemap,next_action.iteration-1-1) - delete!(storemap,next_action.iteration-1-1) + if haskey(storemap, next_action.iteration - 1 - 1) + delete!(storemap, next_action.iteration - 1 - 1) end break end diff --git a/src/Schemes/Periodic.jl b/src/Schemes/Periodic.jl index 268f8dc..82c44e1 100644 --- a/src/Schemes/Periodic.jl +++ b/src/Schemes/Periodic.jl @@ -31,21 +31,27 @@ function Periodic{MT}( bundle_::Union{Nothing,Int} = nothing, verbose::Int = 0, gc::Bool = true, - write_checkpoints::Bool = false + write_checkpoints::Bool = false, ) where {MT} if !isa(anActionInstance, Nothing) # same as default init above anActionInstance.actionflag = 0 - anActionInstance.iteration = 0 - anActionInstance.cpNum = 0 + anActionInstance.iteration = 0 + anActionInstance.cpNum = 0 end - acp = checkpoints - period = div(steps, checkpoints) + acp = checkpoints + period = div(steps, checkpoints) periodic = Periodic{MT}( - steps, acp, period,verbose, - fstore, frestore, storage, gc, - write_checkpoints + steps, + acp, + period, + verbose, + fstore, + frestore, + storage, + gc, + write_checkpoints, ) forwardcount(periodic) @@ -58,7 +64,12 @@ function forwardcount(periodic::Periodic) elseif periodic.steps < 1 error("Periodic forwardcount: error: steps < 1") elseif mod(periodic.steps, periodic.acp) != 0 - error("Periodic forwardcount: error: steps ", periodic.steps, " not divisible by checkpoints ", periodic.acp) + error( + "Periodic forwardcount: error: steps ", + periodic.steps, + " not divisible by checkpoints ", + periodic.acp, + ) end end @@ -67,8 +78,8 @@ function rev_checkpoint_struct_for( alg::Periodic, model_input::MT, shadowmodel::MT, - range -) where{MT} + range, +) where {MT} model = deepcopy(model_input) model_final = [] model_check_outer = alg.storage @@ -78,28 +89,34 @@ function rev_checkpoint_struct_for( GC.enable(false) end if alg.write_checkpoints - prim_output = HDF5Storage{MT}(alg.steps; filename="primal_$(alg.write_checkpoints_filename).h5") - adj_output = HDF5Storage{MT}(alg.steps; filename="adjoint_$(alg.write_checkpoints_filename).h5") + prim_output = HDF5Storage{MT}( + alg.steps; + filename = "primal_$(alg.write_checkpoints_filename).h5", + ) + adj_output = HDF5Storage{MT}( + alg.steps; + filename = "adjoint_$(alg.write_checkpoints_filename).h5", + ) end for i = 1:alg.acp model_check_outer[i] = deepcopy(model) - for j= (i-1)*alg.period: (i)*alg.period-1 + for j = (i-1)*alg.period:(i)*alg.period-1 body(model) end end model_final = deepcopy(model) for i = alg.acp:-1:1 model = deepcopy(model_check_outer[i]) - for j= 1:alg.period + for j = 1:alg.period model_check_inner[j] = deepcopy(model) body(model) end - for j= alg.period:-1:1 + for j = alg.period:-1:1 if alg.write_checkpoints && step % alg.write_checkpoints_period == 1 prim_output[j] = model end model = deepcopy(model_check_inner[j]) - Enzyme.autodiff(Reverse, body, Duplicated(model,shadowmodel)) + Enzyme.autodiff(Reverse, body, Duplicated(model, shadowmodel)) if alg.write_checkpoints && step % alg.write_checkpoints_period == 1 adj_output[j] = shadowmodel end diff --git a/src/Schemes/Revolve.jl b/src/Schemes/Revolve.jl index b7ac852..e323e29 100644 --- a/src/Schemes/Revolve.jl +++ b/src/Schemes/Revolve.jl @@ -42,13 +42,13 @@ function Revolve{MT}( gc::Bool = true, write_checkpoints::Bool = false, write_checkpoints_filename::String = "chkp.h5", - write_checkpoints_period::Int = 1 + write_checkpoints_period::Int = 1, ) where {MT} if !isa(anActionInstance, Nothing) # same as default init above anActionInstance.actionflag = 0 - anActionInstance.iteration = 0 - anActionInstance.cpNum = 0 + anActionInstance.iteration = 0 + anActionInstance.cpNum = 0 end if verbose > 0 @info "Revolve: Number of checkpoints: $checkpoints" @@ -56,14 +56,14 @@ function Revolve{MT}( end !isa(bundle_, Nothing) ? bundle = bundle_ : bundle = 1 if bundle < 1 || bundle > steps - error("Revolve: bundle parameter out of range [1,steps]") - elseif steps<0 - error("Revolve: negative steps") + error("Revolve: bundle parameter out of range [1,steps]") + elseif steps < 0 + error("Revolve: negative steps") elseif checkpoints < 0 - error("Revolve: negative checkpoints") + error("Revolve: negative checkpoints") end cstart = 0 - tail = 1 + tail = 1 if bundle > 1 tail = mod(steps, bundle) steps = steps / bundle @@ -73,22 +73,38 @@ function Revolve{MT}( tail = bundle end end - cend = steps - acp = checkpoints - numfwd = 0 - numinv = 0 - numstore = 0 - rwcp = -1 - prevcend = 0 - firstuturned = false - stepof = Vector{Int}(undef, acp+1) + cend = steps + acp = checkpoints + numfwd = 0 + numinv = 0 + numstore = 0 + rwcp = -1 + prevcend = 0 + firstuturned = false + stepof = Vector{Int}(undef, acp + 1) revolve = Revolve{MT}( - steps, bundle, tail, acp, cstart, cend, numfwd, - numinv, numstore, rwcp, prevcend, firstuturned, - stepof, verbose, fstore, frestore, storage, gc, + steps, + bundle, + tail, + acp, + cstart, + cend, + numfwd, + numinv, + numstore, + rwcp, + prevcend, + firstuturned, + stepof, + verbose, + fstore, + frestore, + storage, + gc, write_checkpoints, - write_checkpoints_filename, write_checkpoints_period + write_checkpoints_filename, + write_checkpoints_period, ) if verbose > 0 @@ -110,10 +126,10 @@ end function next_action!(revolve::Revolve)::Action # Default values for next action - actionflag = none - iteration = 0 + actionflag = none + iteration = 0 startiteration = 0 - cpnum = 0 + cpnum = 0 if revolve.numinv == 0 # first invocation for v in revolve.stepof @@ -125,10 +141,10 @@ function next_action!(revolve::Revolve)::Action revolve.numinv += 1 rwcptest = (revolve.rwcp == -1) if !rwcptest - rwcptest = revolve.stepof[revolve.rwcp+1] != revolve.cstart + rwcptest = revolve.stepof[revolve.rwcp+1] != revolve.cstart end if (revolve.cend - revolve.cstart) == 0 - # nothing in current subrange + # nothing in current subrange if (revolve.rwcp == -1) || (revolve.cstart == revolve.stepof[1]) # we are done revolve.rwcp = revolve.rwcp - 1 @@ -143,14 +159,14 @@ function next_action!(revolve::Revolve)::Action end actionflag = done else - revolve.cstart = revolve.stepof[revolve.rwcp+1] - revolve.prevcend = revolve.cend - actionflag = restore + revolve.cstart = revolve.stepof[revolve.rwcp+1] + revolve.prevcend = revolve.cend + actionflag = restore end elseif (revolve.cend - revolve.cstart) == 1 revolve.cend = revolve.cend - 1 revolve.prevcend = revolve.cend - if (revolve.rwcp >= 0) && (revolve.stepof[revolve.rwcp + 1] == revolve.cstart) + if (revolve.rwcp >= 0) && (revolve.stepof[revolve.rwcp+1] == revolve.cstart) revolve.rwcp -= 1 end if !revolve.firstuturned @@ -170,7 +186,7 @@ function next_action!(revolve::Revolve)::Action actionflag = store end elseif (revolve.prevcend < revolve.cend) && (revolve.acp == revolve.rwcp + 1) - error("Revolve: insufficient allowed checkpoints") + error("Revolve: insufficient allowed checkpoints") else availcp = revolve.acp - revolve.rwcp if availcp < 1 @@ -188,9 +204,9 @@ function next_action!(revolve::Revolve)::Action else bino2 = 1 end - if availcp==1 + if availcp == 1 bino3 = 0 - elseif availcp>2 + elseif availcp > 2 bino3 = bino2 * (availcp - 1) / (availcp + reps - 2) else bino3 = 1 @@ -199,13 +215,13 @@ function next_action!(revolve::Revolve)::Action if availcp < 3 bino5 = 0 elseif availcp > 3 - bino5 = bino3 * (availcp - 1) / reps + bino5 = bino3 * (availcp - 1) / reps else bino5 = 1 end if (revolve.cend - revolve.cstart) <= (bino1 + bino3) revolve.cstart = trunc(Int, revolve.cstart + bino4) - elseif (revolve.cend - revolve.cstart) >= (range-bino5) + elseif (revolve.cend - revolve.cstart) >= (range - bino5) revolve.cstart = trunc(Int, revolve.cstart + bino1) else revolve.cstart = trunc(Int, revolve.cend - bino2 - bino3) @@ -214,10 +230,14 @@ function next_action!(revolve::Revolve)::Action revolve.cstart = prevcstart + 1 end if revolve.cstart == revolve.steps - revolve.numfwd = (revolve.numfwd + ((revolve.cstart - 1) - prevcstart) - * revolve.bundle + revolve.tail) + revolve.numfwd = ( + revolve.numfwd + + ((revolve.cstart - 1) - prevcstart) * revolve.bundle + + revolve.tail + ) else - revolve.numfwd = revolve.numfwd + (revolve.cstart - prevcstart) * revolve.bundle + revolve.numfwd = + revolve.numfwd + (revolve.cstart - prevcstart) * revolve.bundle end actionflag = forward end @@ -245,21 +265,21 @@ function next_action!(revolve::Revolve)::Action if (revolve.verbose > 1) && (actionflag == store) @info " store input of iteration $iteration " end - cpnum=revolve.rwcp + cpnum = revolve.rwcp return Action(actionflag, iteration, startiteration, cpnum) end -function guess(revolve::Revolve; bundle::Union{Nothing, Int} = nothing)::Int - b=1 - bSteps=revolve.steps +function guess(revolve::Revolve; bundle::Union{Nothing,Int} = nothing)::Int + b = 1 + bSteps = revolve.steps if !isa(bundle, Nothing) - b=bundle + b = bundle end if revolve.steps < 1 error("Revolve: error: steps < 1") - elseif b<1 + elseif b < 1 error("Revolve: error: bundle < 1") else if b > 1 @@ -275,10 +295,10 @@ function guess(revolve::Revolve; bundle::Union{Nothing, Int} = nothing)::Int checkpoints = 1 reps = 1 s = 0 - while chkrange(revolve, checkpoints+s, reps+s) > bSteps + while chkrange(revolve, checkpoints + s, reps + s) > bSteps s -= 1 end - while chkrange(revolve, checkpoints+s, reps+s) < bSteps + while chkrange(revolve, checkpoints + s, reps + s) < bSteps s += 1 end checkpoints += s @@ -314,18 +334,18 @@ function factor(revolve::Revolve, steps, checkpoints, bundle::Union{Nothing,Int} if f == -1 error("Revolve: error returned by forwardcount") else - factor = float(f)/steps + factor = float(f) / steps end return factor end function chkrange(::Revolve, ss, tt) ret = Int(0) - res = 1. + res = 1.0 if tt < 0 || ss < 0 error("Revolve chkrange: error: negative parameter") else - for i in 1:tt + for i = 1:tt res = res * (ss + i) res = res / i if res > typemax(typeof(ret)) @@ -344,8 +364,8 @@ end function forwardcount(revolve::Revolve) checkpoints = revolve.acp - bundle = revolve.bundle - steps = revolve.steps + bundle = revolve.bundle + steps = revolve.steps if checkpoints < 0 error("Revolve forwardcount: error: checkpoints < 0") elseif steps < 1 @@ -353,12 +373,12 @@ function forwardcount(revolve::Revolve) elseif bundle < 1 error("Revolve forwardcount: error: bundle < 1") else - s=steps + s = steps if bundle > 1 - tail = mod(s,bundle) + tail = mod(s, bundle) s = s / bundle if tail > 0 - s = s + 1 + s = s + 1 end end if s == 1 @@ -369,8 +389,8 @@ function forwardcount(revolve::Revolve) reps = 0 range = 1 while range < s - reps = reps + 1 - range = range*(reps+checkpoints)/reps + reps = reps + 1 + range = range * (reps + checkpoints) / reps end ret = (reps * s - range * reps / (checkpoints + 1)) * bundle end @@ -380,7 +400,7 @@ end function reset!(revolve::Revolve) revolve.cstart = 0 - revolve.tail = 1 + revolve.tail = 1 if revolve.bundle > 1 tail = mod(steps, bundle) steps = steps / bundle @@ -390,12 +410,12 @@ function reset!(revolve::Revolve) tail = bundle end end - revolve.numfwd = 0 - revolve.numinv = 0 - revolve.numstore = 0 - revolve.rwcp = -1 - revolve.prevcend = 0 - revolve.firstuturned = false + revolve.numfwd = 0 + revolve.numinv = 0 + revolve.numstore = 0 + revolve.rwcp = -1 + revolve.prevcend = 0 + revolve.firstuturned = false return nothing end @@ -404,7 +424,7 @@ function rev_checkpoint_struct_for( alg::Revolve, model_input::MT, shadowmodel::MT, - range + range, ) where {MT} model = deepcopy(model_input) if alg.verbose > 0 @@ -418,23 +438,29 @@ function rev_checkpoint_struct_for( GC.enable(false) end if alg.write_checkpoints - prim_output = HDF5Storage{MT}(alg.steps; filename="primal_$(alg.write_checkpoints_filename).h5") - adj_output = HDF5Storage{MT}(alg.steps; filename="adjoint_$(alg.write_checkpoints_filename).h5") + prim_output = HDF5Storage{MT}( + alg.steps; + filename = "primal_$(alg.write_checkpoints_filename).h5", + ) + adj_output = HDF5Storage{MT}( + alg.steps; + filename = "adjoint_$(alg.write_checkpoints_filename).h5", + ) end step = alg.steps while true next_action = next_action!(alg) if (next_action.actionflag == Checkpointing.store) - check = check+1 - storemap[next_action.iteration-1]=check + check = check + 1 + storemap[next_action.iteration-1] = check model_check[check] = deepcopy(model) elseif (next_action.actionflag == Checkpointing.forward) - for j= next_action.startiteration:(next_action.iteration - 1) + for j = next_action.startiteration:(next_action.iteration-1) body(model) end elseif (next_action.actionflag == Checkpointing.firstuturn) body(model) - model_final = deepcopy(model) + model_final = deepcopy(model) if alg.write_checkpoints && step % alg.write_checkpoints_period == 1 prim_output[step] = model_final end @@ -442,7 +468,7 @@ function rev_checkpoint_struct_for( @info "Revolve: First Uturn" @info "Size of total storage: $(Base.format_bytes(Base.summarysize(alg.storage)))" end - Enzyme.autodiff(Reverse, body, Duplicated(model,shadowmodel)) + Enzyme.autodiff(Reverse, body, Duplicated(model, shadowmodel)) if alg.write_checkpoints && step % alg.write_checkpoints_period == 1 adj_output[step] = shadowmodel end @@ -454,7 +480,7 @@ function rev_checkpoint_struct_for( if alg.write_checkpoints && step % alg.write_checkpoints_period == 1 prim_output[step] = model end - Enzyme.autodiff(Reverse, body, Duplicated(model,shadowmodel)) + Enzyme.autodiff(Reverse, body, Duplicated(model, shadowmodel)) if alg.write_checkpoints && step % alg.write_checkpoints_period == 1 adj_output[step] = shadowmodel end @@ -462,16 +488,16 @@ function rev_checkpoint_struct_for( if !alg.gc GC.gc() end - if haskey(storemap,next_action.iteration-1-1) - delete!(storemap,next_action.iteration-1-1) - check=check-1 + if haskey(storemap, next_action.iteration - 1 - 1) + delete!(storemap, next_action.iteration - 1 - 1) + check = check - 1 end elseif (next_action.actionflag == Checkpointing.restore) model = deepcopy(model_check[storemap[next_action.iteration-1]]) elseif next_action.actionflag == Checkpointing.done - if haskey(storemap,next_action.iteration-1-1) - delete!(storemap,next_action.iteration-1-1) - check=check-1 + if haskey(storemap, next_action.iteration - 1 - 1) + delete!(storemap, next_action.iteration - 1 - 1) + check = check - 1 end break end diff --git a/src/Storage/HDF5Storage.jl b/src/Storage/HDF5Storage.jl index 8ab2b05..5d5161f 100644 --- a/src/Storage/HDF5Storage.jl +++ b/src/Storage/HDF5Storage.jl @@ -10,7 +10,7 @@ mutable struct HDF5Storage{MT} <: AbstractStorage where {MT} acp::Int64 end -function HDF5Storage{MT}(acp::Int; filename=tempname()) where {MT} +function HDF5Storage{MT}(acp::Int; filename = tempname()) where {MT} fid = h5open(filename, "w") storage = HDF5Storage{MT}(fid, filename, acp) function _finalizer(storage::HDF5Storage{MT}) diff --git a/test/enzyme.jl b/test/enzyme.jl index ff97053..05d49be 100644 --- a/test/enzyme.jl +++ b/test/enzyme.jl @@ -5,34 +5,34 @@ using Test function f(x) - return [(x[1]*(x[2]-2.0))^2, (x[1]-1.0)^2*x[2]^2] + return [(x[1] * (x[2] - 2.0))^2, (x[1] - 1.0)^2 * x[2]^2] end -function f2(x,res) +function f2(x, res) y = f(x) - copyto!(res,y) + copyto!(res, y) return nothing end -x = [2.0,6.0] +x = [2.0, 6.0] J = ReverseDiff.jacobian(f, x) dx = copy(x) -dx = [0.0,0.0] -y = [0.0,0.0] -dy = [0.0,1.0] -autodiff(f2, Duplicated(x,dx), Duplicated(y, dy)) +dx = [0.0, 0.0] +y = [0.0, 0.0] +dy = [0.0, 1.0] +autodiff(f2, Duplicated(x, dx), Duplicated(y, dy)) J2 = similar(J) fill!(J2, 0) -for i in 1:2 - x = [2.0,6.0] +for i = 1:2 + x = [2.0, 6.0] fill!(dx, 0) fill!(y, 0) dy[i] = 1.0 - autodiff(f2, Duplicated(x,dx), Duplicated(y, dy)) - J2[i,:] = dx[:] + autodiff(f2, Duplicated(x, dx), Duplicated(y, dy)) + J2[i, :] = dx[:] end # Is correct, but gives the memmove warning @@ -46,26 +46,27 @@ include("../examples/optcontrol.jl") global snaps = 3 global info = 1 - function store(F_H, F_C,t, i) - F_C[1,i] = F_H[1] - F_C[2,i] = F_H[2] - F_C[3,i] = t + function store(F_H, F_C, t, i) + F_C[1, i] = F_H[1] + F_C[2, i] = F_H[2] + F_C[3, i] = t return end function restore(F_C, i) - F_H = [F_C[1,i], F_C[2,i]] - t = F_C[3,i] + F_H = [F_C[1, i], F_C[2, i]] + t = F_C[3, i] return F_H, t end - revolve = Revolve(steps, snaps, store, restore; verbose=info) + revolve = Revolve(steps, snaps, store, restore; verbose = info) F_opt, F_final, L_opt, L = optcontrol(revolve, steps, ReverseDiffADTool()) # Check whether ReverseDiff works - @test isapprox(F_opt, F_final, rtol=1e-4) - @test isapprox(L_opt, L, rtol=1e-4) - revolve = Revolve(steps, snaps, store, restore; verbose=info) - F_opt_enzyme, F_final_enzyme, L_opt_enzyme, L_enzyme = optcontrol(revolve, steps, EnzymeADTool()) + @test isapprox(F_opt, F_final, rtol = 1e-4) + @test isapprox(L_opt, L, rtol = 1e-4) + revolve = Revolve(steps, snaps, store, restore; verbose = info) + F_opt_enzyme, F_final_enzyme, L_opt_enzyme, L_enzyme = + optcontrol(revolve, steps, EnzymeADTool()) @test isapprox(F_final, F_final_enzyme) # Returns wrong adjoints - @test_broken isapprox(L, L_enzyme, rtol=1e-4) -end \ No newline at end of file + @test_broken isapprox(L, L_enzyme, rtol = 1e-4) +end diff --git a/test/multilevel.jl b/test/multilevel.jl index 9b81c59..2ea3533 100644 --- a/test/multilevel.jl +++ b/test/multilevel.jl @@ -8,8 +8,8 @@ mutable struct Chkp end function loops(chkp::Chkp, scheme1::Scheme, it1::Int, it2::Int) - @checkpoint_struct scheme1 chkp for i in 1:it1 - @checkpoint_struct chkp.scheme chkp for j in 1:it2 + @checkpoint_struct scheme1 chkp for i = 1:it1 + @checkpoint_struct chkp.scheme chkp for j = 1:it2 chkp.x .= 2.0 * sqrt.(chkp.x) .* sqrt.(chkp.x) end end @@ -26,15 +26,22 @@ dx = Chkp([0.0, 0.0, 0.0], revolve) primal = loops(x, periodic, it1, it2) -peridoc = Periodic{Chkp}(it1, 1; verbose=0) -revolve = Revolve{Chkp}(it2, 2; verbose=0) +peridoc = Periodic{Chkp}(it1, 1; verbose = 0) +revolve = Revolve{Chkp}(it2, 2; verbose = 0) x = Chkp([2.0, 3.0, 4.0], revolve) dx = Chkp([0.0, 0.0, 0.0], revolve) -g = autodiff(Enzyme.ReverseWithPrimal, loops, Active, Duplicated(x, dx), Const(periodic), Const(it1), Const(it2)) +g = autodiff( + Enzyme.ReverseWithPrimal, + loops, + Active, + Duplicated(x, dx), + Const(periodic), + Const(it1), + Const(it2), +) # TODO: Primal is wrong only when multilevel checkpointing is used @test_broken g[2] == primal @test all(dx.x .== [1024.0, 1024.0, 1024.0]) - diff --git a/test/output_chkp.jl b/test/output_chkp.jl index 2762c23..b790b41 100644 --- a/test/output_chkp.jl +++ b/test/output_chkp.jl @@ -8,7 +8,7 @@ mutable struct ChkpOut end function loops(chkp::ChkpOut, scheme::Scheme, iters::Int) - @checkpoint_struct scheme chkp for i in 1:iters + @checkpoint_struct scheme chkp for i = 1:iters chkp.x .= 2.0 * sqrt.(chkp.x) .* sqrt.(chkp.x) end return reduce(+, chkp.x) @@ -18,10 +18,10 @@ iters = 10 revolve = Revolve{ChkpOut}( iters, 3; - verbose=0, - write_checkpoints=true, - write_checkpoints_filename="chkp", - write_checkpoints_period=2 + verbose = 0, + write_checkpoints = true, + write_checkpoints_filename = "chkp", + write_checkpoints_period = 2, ) x = ChkpOut([2.0, 3.0, 4.0]) diff --git a/test/runtests.jl b/test/runtests.jl index f1f0493..a933f31 100644 --- a/test/runtests.jl +++ b/test/runtests.jl @@ -13,10 +13,16 @@ adtools = [ZygoteTool(), EnzymeTool()] @testset "Checkpointing.jl" begin @testset "Enzyme..." begin @test Enzyme.EnzymeRules.has_rrule_from_sig( - Base.signature_type(Checkpointing.checkpoint_struct_for, Tuple{Any,Any,Any,Any}) + Base.signature_type( + Checkpointing.checkpoint_struct_for, + Tuple{Any,Any,Any,Any}, + ), ) @test Enzyme.EnzymeRules.has_rrule_from_sig( - Base.signature_type(Checkpointing.checkpoint_struct_while, Tuple{Any,Any,Any,Any}) + Base.signature_type( + Checkpointing.checkpoint_struct_while, + Tuple{Any,Any,Any,Any}, + ), ) include("speelpenning.jl") errf, errg = main() @@ -42,10 +48,10 @@ adtools = [ZygoteTool(), EnzymeTool()] snaps = 3 info = 0 - revolve = Revolve{Model}(steps, snaps; verbose=info) + revolve = Revolve{Model}(steps, snaps; verbose = info) F, L, F_opt, L_opt = muoptcontrol(revolve, steps, adtool) - @test isapprox(F_opt, F, rtol=1e-4) - @test isapprox(L_opt, L, rtol=1e-4) + @test isapprox(F_opt, F, rtol = 1e-4) + @test isapprox(L_opt, L, rtol = 1e-4) end @testset "Periodic..." begin @@ -53,10 +59,10 @@ adtools = [ZygoteTool(), EnzymeTool()] snaps = 4 info = 0 - periodic = Periodic{Model}(steps, snaps; verbose=info) + periodic = Periodic{Model}(steps, snaps; verbose = info) F, L, F_opt, L_opt = muoptcontrol(periodic, steps, adtool) - @test isapprox(F_opt, F, rtol=1e-4) - @test isapprox(L_opt, L, rtol=1e-4) + @test isapprox(F_opt, F, rtol = 1e-4) + @test isapprox(L_opt, L, rtol = 1e-4) end end end @@ -68,11 +74,11 @@ adtools = [ZygoteTool(), EnzymeTool()] snaps = 4 info = 0 - revolve = Revolve{Heat}(steps, snaps; verbose=info) + revolve = Revolve{Heat}(steps, snaps; verbose = info) T, dT = heat_for(revolve, steps, adtool) - @test isapprox(norm(T), 66.21987468492061, atol=1e-11) - @test isapprox(norm(dT), 6.970279349365908, atol=1e-11) + @test isapprox(norm(T), 66.21987468492061, atol = 1e-11) + @test isapprox(norm(dT), 6.970279349365908, atol = 1e-11) end @testset "Periodic..." begin @@ -80,22 +86,22 @@ adtools = [ZygoteTool(), EnzymeTool()] snaps = 4 info = 0 - periodic = Periodic{Heat}(steps, snaps; verbose=info) + periodic = Periodic{Heat}(steps, snaps; verbose = info) T, dT = heat_for(periodic, steps, adtool) - @test isapprox(norm(T), 66.21987468492061, atol=1e-11) - @test isapprox(norm(dT), 6.970279349365908, atol=1e-11) + @test isapprox(norm(T), 66.21987468492061, atol = 1e-11) + @test isapprox(norm(dT), 6.970279349365908, atol = 1e-11) end @testset "Online_r2..." begin steps = 500 snaps = 100 info = 0 - online = Online_r2{Heat}(snaps; verbose=info) + online = Online_r2{Heat}(snaps; verbose = info) T, dT = heat_while(online, steps, adtool) - @test isapprox(norm(T), 66.21987468492061, atol=1e-11) - @test isapprox(norm(dT), 6.970279349365908, atol=1e-11) + @test isapprox(norm(T), 66.21987468492061, atol = 1e-11) + @test isapprox(norm(dT), 6.970279349365908, atol = 1e-11) end end @testset "Testing HDF5 storage using heat example" begin @@ -105,11 +111,16 @@ adtools = [ZygoteTool(), EnzymeTool()] snaps = 4 info = 0 - revolve = Revolve{Heat}(steps, snaps; storage=HDF5Storage{Heat}(snaps), verbose=info) + revolve = Revolve{Heat}( + steps, + snaps; + storage = HDF5Storage{Heat}(snaps), + verbose = info, + ) T, dT = heat_for(revolve, steps, adtool) - @test isapprox(norm(T), 66.21987468492061, atol=1e-11) - @test isapprox(norm(dT), 6.970279349365908, atol=1e-11) + @test isapprox(norm(T), 66.21987468492061, atol = 1e-11) + @test isapprox(norm(dT), 6.970279349365908, atol = 1e-11) end @testset "Periodic..." begin @@ -117,22 +128,31 @@ adtools = [ZygoteTool(), EnzymeTool()] snaps = 4 info = 0 - periodic = Periodic{Heat}(steps, snaps; storage=HDF5Storage{Heat}(snaps), verbose=info) + periodic = Periodic{Heat}( + steps, + snaps; + storage = HDF5Storage{Heat}(snaps), + verbose = info, + ) T, dT = heat_for(periodic, steps, adtool) - @test isapprox(norm(T), 66.21987468492061, atol=1e-11) - @test isapprox(norm(dT), 6.970279349365908, atol=1e-11) + @test isapprox(norm(T), 66.21987468492061, atol = 1e-11) + @test isapprox(norm(dT), 6.970279349365908, atol = 1e-11) end @testset "Online_r2..." begin steps = 500 snaps = 100 info = 0 - online = Online_r2{Heat}(snaps; storage=HDF5Storage{Heat}(snaps), verbose=info) + online = Online_r2{Heat}( + snaps; + storage = HDF5Storage{Heat}(snaps), + verbose = info, + ) T, dT = heat_while(online, steps, adtool) - @test isapprox(norm(T), 66.21987468492061, atol=1e-11) - @test isapprox(norm(dT), 6.970279349365908, atol=1e-11) + @test isapprox(norm(T), 66.21987468492061, atol = 1e-11) + @test isapprox(norm(dT), 6.970279349365908, atol = 1e-11) end end end @@ -146,7 +166,7 @@ adtools = [ZygoteTool(), EnzymeTool()] snaps = 100 info = 0 - revolve = Revolve{Box}(steps, snaps; verbose=info) + revolve = Revolve{Box}(steps, snaps; verbose = info) T, dT = box_for(revolve, steps, adtool) @test isapprox(T, 21.41890316892692) @test isapprox(dT[5], 0.00616139595759519) @@ -158,7 +178,7 @@ adtools = [ZygoteTool(), EnzymeTool()] snaps = 100 info = 0 - periodic = Periodic{Box}(steps, snaps; verbose=info) + periodic = Periodic{Box}(steps, snaps; verbose = info) T, dT = box_for(periodic, steps, adtool) @test isapprox(T, 21.41890316892692) @test isapprox(dT[5], 0.00616139595759519) @@ -168,7 +188,7 @@ adtools = [ZygoteTool(), EnzymeTool()] steps = 10000 snaps = 500 info = 0 - online = Online_r2{Box}(snaps; verbose=info) + online = Online_r2{Box}(snaps; verbose = info) T, dT = box_while(online, steps, adtool) @test isapprox(T, 21.41890316892692) @test isapprox(dT[5], 0.00616139595759519) diff --git a/test/speelpenning.jl b/test/speelpenning.jl index 40f24f4..cfe023f 100644 --- a/test/speelpenning.jl +++ b/test/speelpenning.jl @@ -8,19 +8,19 @@ function main() y = [0.0] n = 10 - x = [i/(1.0+i) for i in 1:n] - speelpenning(y,x) + x = [i / (1.0 + i) for i = 1:n] + speelpenning(y, x) dx = zeros(n) dy = [1.0] - autodiff(Reverse, speelpenning, Duplicated(y,dy), Duplicated(x,dx)) + autodiff(Reverse, speelpenning, Duplicated(y, dy), Duplicated(x, dx)) y = [0.0] - speelpenning(y,x) + speelpenning(y, x) errg = 0.0 for (i, v) in enumerate(x) - errg += abs(dx[i]-y[1]/v) + errg += abs(dx[i] - y[1] / v) end - return (y[1]-1/(1.0+n)), errg + return (y[1] - 1 / (1.0 + n)), errg end