From 394806cc691190de7d394ca8d0f2717f63d6e87b Mon Sep 17 00:00:00 2001 From: Vaibhav Dixit Date: Tue, 24 Sep 2024 16:13:30 -0400 Subject: [PATCH 01/12] Bump compats and update tutorials for Optimization v4 --- Project.toml | 6 ++-- docs/Project.toml | 8 ++--- docs/src/examples/augmented_neural_ode.md | 34 +++++++++---------- docs/src/examples/hamiltonian_nn.md | 25 ++++++-------- docs/src/examples/mnist_conv_neural_ode.md | 14 ++++---- docs/src/examples/mnist_neural_ode.md | 34 +++++++++---------- docs/src/examples/multiple_shooting.md | 19 +++++++---- docs/src/examples/neural_ode.md | 20 ++++++----- .../examples/neural_ode_weather_forecast.md | 4 +-- docs/src/examples/neural_sde.md | 19 +++++++---- docs/src/examples/normalizing_flows.md | 4 +-- docs/src/examples/physical_constraints.md | 2 +- 12 files changed, 100 insertions(+), 89 deletions(-) diff --git a/Project.toml b/Project.toml index aee97ae05..ca68f6a54 100644 --- a/Project.toml +++ b/Project.toml @@ -54,9 +54,9 @@ LuxLib = "1.2" NNlib = "0.9.22" OneHotArrays = "0.2.5" Optimisers = "0.3" -Optimization = "3.25.0" -OptimizationOptimJL = "0.3.0" -OptimizationOptimisers = "0.2.1" +Optimization = "4" +OptimizationOptimJL = "0.4" +OptimizationOptimisers = "0.3" OrdinaryDiffEq = "6.76.0" Printf = "1.10" Random = "1.10" diff --git a/docs/Project.toml b/docs/Project.toml index 09aa6590e..ee0dbb7d9 100644 --- a/docs/Project.toml +++ b/docs/Project.toml @@ -57,10 +57,10 @@ MLUtils = "0.4" NNlib = "0.9" OneHotArrays = "0.2" Optimisers = "0.3" -Optimization = "3.9" -OptimizationOptimJL = "0.2, 0.3" -OptimizationOptimisers = "0.2" -OptimizationPolyalgorithms = "0.2" +Optimization = "4" +OptimizationOptimJL = "0.4" +OptimizationOptimisers = "0.3" +OptimizationPolyalgorithms = "0.3" OrdinaryDiffEq = "6.31" Plots = "1.36" Printf = "1" diff --git a/docs/src/examples/augmented_neural_ode.md b/docs/src/examples/augmented_neural_ode.md index a0f9a6a0f..98261b9d5 100644 --- a/docs/src/examples/augmented_neural_ode.md +++ b/docs/src/examples/augmented_neural_ode.md @@ -69,13 +69,13 @@ function plot_contour(model, ps, st, npoints = 300) return contour(x, y, sol; fill = true, linewidth = 0.0) end -loss_node(model, x, y, ps, st) = mean((first(model(x, ps, st)) .- y) .^ 2) +loss_node(model, data, ps, st) = mean((first(model(data[1], ps, st)) .- data[2]) .^ 2) dataloader = concentric_sphere( 2, (0.0f0, 2.0f0), (3.0f0, 4.0f0), 2000, 2000; batch_size = 256) iter = 0 -cb = function (ps, l) +cb = function (state, l) global iter iter += 1 if iter % 10 == 0 @@ -87,15 +87,15 @@ end model, ps, st = construct_model(1, 2, 64, 0) opt = OptimizationOptimisers.Adam(0.005) -loss_node(model, dataloader.data[1], dataloader.data[2], ps, st) +loss_node(model, (dataloader.data[1], dataloader.data[2]), ps, st) println("Training Neural ODE") optfunc = OptimizationFunction( - (x, p, data, target) -> loss_node(model, data, target, x, st), + (x, data) -> loss_node(model, data, x, st), Optimization.AutoZygote()) -optprob = OptimizationProblem(optfunc, ComponentArray(ps |> cdev) |> gdev) -res = solve(optprob, opt, IterTools.ncycle(dataloader, 5); callback = cb) +optprob = OptimizationProblem(optfunc, ComponentArray(ps |> cdev) |> gdev, dataloader) +res = solve(optprob, opt; callback = cb, epochs = 1000) plt_node = plot_contour(model, res.u, st) @@ -106,10 +106,10 @@ println() println("Training Augmented Neural ODE") optfunc = OptimizationFunction( - (x, p, data, target) -> loss_node(model, data, target, x, st), + (x, data) -> loss_node(model, data, x, st), Optimization.AutoZygote()) 
-optprob = OptimizationProblem(optfunc, ComponentArray(ps |> cdev) |> gdev) -res = solve(optprob, opt, IterTools.ncycle(dataloader, 5); callback = cb) +optprob = OptimizationProblem(optfunc, ComponentArray(ps |> cdev) |> gdev, dataloader) +res = solve(optprob, opt; callback = cb, epochs = 1000) plot_contour(model, res.u, st) ``` @@ -229,7 +229,7 @@ We use the L2 distance between the model prediction `model(x)` and the actual pr optimization objective. ```@example augneuralode -loss_node(model, x, y, ps, st) = mean((first(model(x, ps, st)) .- y) .^ 2) +loss_node(model, data, ps, st) = mean((first(model(data[1], ps, st)) .- data[2]) .^ 2) ``` #### Dataset @@ -248,7 +248,7 @@ Additionally, we define a callback function which displays the total loss at spe ```@example augneuralode iter = 0 -cb = function (ps, l) +cb = function (state, l) global iter iter += 1 if iter % 10 == 0 @@ -276,10 +276,10 @@ for `20` epochs. model, ps, st = construct_model(1, 2, 64, 0) optfunc = OptimizationFunction( - (x, p, data, target) -> loss_node(model, data, target, x, st), + (x, data) -> loss_node(model, data, x, st), Optimization.AutoZygote()) -optprob = OptimizationProblem(optfunc, ComponentArray(ps |> cdev) |> gdev) -res = solve(optprob, opt, IterTools.ncycle(dataloader, 5); callback = cb) +optprob = OptimizationProblem(optfunc, ComponentArray(ps |> cdev) |> gdev, dataloader) +res = solve(optprob, opt; callback = cb, epochs = 1000) plot_contour(model, res.u, st) ``` @@ -297,10 +297,10 @@ a function which can be expressed by the neural ode. For more details and proofs model, ps, st = construct_model(1, 2, 64, 1) optfunc = OptimizationFunction( - (x, p, data, target) -> loss_node(model, data, target, x, st), + (x, data) -> loss_node(model, data, x, st), Optimization.AutoZygote()) -optprob = OptimizationProblem(optfunc, ComponentArray(ps |> cdev) |> gdev) -res = solve(optprob, opt, IterTools.ncycle(dataloader, 5); callback = cb) +optprob = OptimizationProblem(optfunc, ComponentArray(ps |> cdev) |> gdev, dataloader) +res = solve(optprob, opt; callback = cb, epochs = 1000) plot_contour(model, res.u, st) ``` diff --git a/docs/src/examples/hamiltonian_nn.md b/docs/src/examples/hamiltonian_nn.md index 9c1716bad..43f816bda 100644 --- a/docs/src/examples/hamiltonian_nn.md +++ b/docs/src/examples/hamiltonian_nn.md @@ -75,7 +75,7 @@ The HNN predicts the gradients ``(\dot q, \dot p)`` given ``(q, p)``. Hence, we ```@example hamiltonian using Lux, DiffEqFlux, OrdinaryDiffEq, Statistics, Plots, Zygote, ForwardDiff, Random, - ComponentArrays, Optimization, OptimizationOptimisers, IterTools + ComponentArrays, Optimization, OptimizationOptimisers, MLUtils t = range(0.0f0, 1.0f0; length = 1024) π_32 = Float32(π) @@ -87,12 +87,8 @@ dpdt = -2π_32 .* q_t data = cat(q_t, p_t; dims = 1) target = cat(dqdt, dpdt; dims = 1) B = 256 -NEPOCHS = 100 -dataloader = ncycle( - ((selectdim(data, 2, ((i - 1) * B + 1):(min(i * B, size(data, 2)))), - selectdim(target, 2, ((i - 1) * B + 1):(min(i * B, size(data, 2))))) - for i in 1:(size(data, 2) ÷ B)), - NEPOCHS) +NEPOCHS = 1000 +dataloader = DataLoader((data, target); batchsize = B) ``` ### Training the HamiltonianNN @@ -103,24 +99,25 @@ We parameterize the with a small MultiLayered Perceptron. 
HNNs are trained by o hnn = Layers.HamiltonianNN{true}(Layers.MLP(2, (64, 1)); autodiff = AutoZygote()) ps, st = Lux.setup(Xoshiro(0), hnn) ps_c = ps |> ComponentArray +hnn_stateful = StatefulLuxLayer{true}(hnn, ps_c, st) opt = OptimizationOptimisers.Adam(0.01f0) -function loss_function(ps, data, target) - pred, st_ = hnn(data, ps, st) +function loss_function(ps, databatch) + (data, target) = databatch + pred = hnn_stateful(data, ps) return mean(abs2, pred .- target), pred end -function callback(ps, loss, pred) +function callback(state, loss) println("[Hamiltonian NN] Loss: ", loss) return false end -opt_func = OptimizationFunction( - (ps, _, data, target) -> loss_function(ps, data, target), Optimization.AutoZygote()) -opt_prob = OptimizationProblem(opt_func, ps_c) +opt_func = OptimizationFunction(loss_function, Optimization.AutoZygote()) +opt_prob = OptimizationProblem(opt_func, ps_c, dataloader) -res = solve(opt_prob, opt, dataloader; callback) +res = solve(opt_prob, opt; callback, epochs = NEPOCHS) ps_trained = res.u ``` diff --git a/docs/src/examples/mnist_conv_neural_ode.md b/docs/src/examples/mnist_conv_neural_ode.md index f67262780..e1f9024b4 100644 --- a/docs/src/examples/mnist_conv_neural_ode.md +++ b/docs/src/examples/mnist_conv_neural_ode.md @@ -89,20 +89,20 @@ end # burn in accuracy accuracy(m, ((img, lab),), ps, st) -function loss_function(ps, x, y) +function loss_function(ps, data) + (x, y) = data pred, _ = m(x, ps, st) - return logitcrossentropy(pred, y), pred + return logitcrossentropy(pred, y) end # burn in loss -loss_function(ps, img, lab) +loss_function(ps, (img, lab)) opt = OptimizationOptimisers.Adam(0.005) iter = 0 -opt_func = OptimizationFunction( - (ps, _, x, y) -> loss_function(ps, x, y), Optimization.AutoZygote()) -opt_prob = OptimizationProblem(opt_func, ps); +opt_func = OptimizationFunction(loss_function, Optimization.AutoZygote()) +opt_prob = OptimizationProblem(opt_func, ps, dataloader); function callback(ps, l, pred) global iter += 1 @@ -112,7 +112,7 @@ function callback(ps, l, pred) end # Train the NN-ODE and monitor the loss and weights. 
-res = Optimization.solve(opt_prob, opt, dataloader; maxiters = 5, callback) +res = Optimization.solve(opt_prob, opt; maxiters = 5, callback) acc = accuracy(m, dataloader, res.u, st) acc # hide ``` diff --git a/docs/src/examples/mnist_neural_ode.md b/docs/src/examples/mnist_neural_ode.md index 349dbcff2..27b78f26c 100644 --- a/docs/src/examples/mnist_neural_ode.md +++ b/docs/src/examples/mnist_neural_ode.md @@ -81,9 +81,10 @@ end accuracy(m, ((x_train1, y_train1),), ps, st) # burn in accuracy -function loss_function(ps, x, y) +function loss_function(ps, data) + (x, y) = data pred, st_ = m(x, ps, st) - return logitcrossentropy(pred, y), pred + return logitcrossentropy(pred, y) end loss_function(ps, x_train1, y_train1) # burn in loss @@ -91,19 +92,18 @@ loss_function(ps, x_train1, y_train1) # burn in loss opt = OptimizationOptimisers.Adam(0.05) iter = 0 -opt_func = OptimizationFunction( - (ps, _, x, y) -> loss_function(ps, x, y), Optimization.AutoZygote()) -opt_prob = OptimizationProblem(opt_func, ps) +opt_func = OptimizationFunction(loss_function, Optimization.AutoZygote()) +opt_prob = OptimizationProblem(opt_func, ps, dataloader) -function callback(ps, l, pred) +function callback(state, l) global iter += 1 iter % 10 == 0 && - @info "[MNIST GPU] Accuracy: $(accuracy(m, dataloader, ps.u, st))" + @info "[MNIST GPU] Accuracy: $(accuracy(m, dataloader, state.u, st))" return false end # Train the NN-ODE and monitor the loss and weights. -res = Optimization.solve(opt_prob, opt, dataloader; callback, maxiters = 5) +res = Optimization.solve(opt_prob, opt; callback, maxiters = 5) accuracy(m, dataloader, res.u, st) ``` @@ -285,12 +285,13 @@ final output of our model. `logitcrossentropy` takes in the prediction from our model `model(x)` and compares it to actual output `y`: ```@example mnist -function loss_function(ps, x, y) +function loss_function(ps, data) + (x, y) = data pred, st_ = m(x, ps, st) - return logitcrossentropy(pred, y), pred + return logitcrossentropy(pred, y) end -loss_function(ps, x_train1, y_train1) # burn in loss +loss_function(ps, (x_train1, y_train1)) # burn in loss ``` #### Optimizer @@ -309,14 +310,13 @@ This callback function is used to print both the training and testing accuracy a ```@example mnist iter = 0 -opt_func = OptimizationFunction( - (ps, _, x, y) -> loss_function(ps, x, y), Optimization.AutoZygote()) -opt_prob = OptimizationProblem(opt_func, ps) +opt_func = OptimizationFunction(loss_function, Optimization.AutoZygote()) +opt_prob = OptimizationProblem(opt_func, ps, dataloader) -function callback(ps, l, pred) +function callback(state, l) global iter += 1 iter % 10 == 0 && - @info "[MNIST GPU] Accuracy: $(accuracy(m, dataloader, ps.u, st))" + @info "[MNIST GPU] Accuracy: $(accuracy(m, dataloader, state.u, st))" return false end ``` @@ -329,6 +329,6 @@ for Neural ODE is given by `nn_ode.p`: ```@example mnist # Train the NN-ODE and monitor the loss and weights. 
-res = Optimization.solve(opt_prob, opt, dataloader; callback, maxiters = 5) +res = Optimization.solve(opt_prob, opt; callback, maxiters = 5) accuracy(m, dataloader, res.u, st) ``` diff --git a/docs/src/examples/multiple_shooting.md b/docs/src/examples/multiple_shooting.md index 2b20219f5..b48dd8e8f 100644 --- a/docs/src/examples/multiple_shooting.md +++ b/docs/src/examples/multiple_shooting.md @@ -48,6 +48,9 @@ ode_data = Array(solve(prob_trueode, Tsit5(); saveat = tsteps)) nn = Chain(x -> x .^ 3, Dense(2, 16, tanh), Dense(16, 2)) p_init, st = Lux.setup(rng, nn) +ps = ComponentArray(p_init) +pd, pax = getdata(ps), getaxes(ps) + neuralode = NeuralODE(nn, tspan, Tsit5(); saveat = tsteps) prob_node = ODEProblem((u, p, t) -> nn(u, p, st)[1], u0, tspan, ComponentArray(p_init)) @@ -62,14 +65,13 @@ end anim = Plots.Animation() iter = 0 -callback = function (p, l, preds; doplot = true) +function callback(state, l; doplot = true, prob_node = prob_node) display(l) global iter iter += 1 if doplot && iter % 1 == 0 # plot the original data plt = scatter(tsteps, ode_data[1, :]; label = "Data") - # plot the different predictions for individual shoot plot_multiple_shoot(plt, preds, group_size) @@ -83,23 +85,26 @@ end group_size = 3 continuity_term = 200 +l1, preds = multiple_shoot(ps, ode_data, tsteps, prob_node, loss_function, + Tsit5(), group_size; continuity_term) + function loss_function(data, pred) return sum(abs2, data - pred) end -ps = ComponentArray(p_init) -pd, pax = getdata(ps), getaxes(ps) - function loss_multiple_shooting(p) ps = ComponentArray(p, pax) - return multiple_shoot(ps, ode_data, tsteps, prob_node, loss_function, + + loss, currpred = multiple_shoot(ps, ode_data, tsteps, prob_node, loss_function, Tsit5(), group_size; continuity_term) + global preds = currpred + return loss end adtype = Optimization.AutoZygote() optf = Optimization.OptimizationFunction((x, p) -> loss_multiple_shooting(x), adtype) optprob = Optimization.OptimizationProblem(optf, pd) -res_ms = Optimization.solve(optprob, PolyOpt(); callback = callback) +res_ms = Optimization.solve(optprob, PolyOpt(); callback = callback, maxiters = 5000) gif(anim, "multiple_shooting.gif"; fps = 15) ``` diff --git a/docs/src/examples/neural_ode.md b/docs/src/examples/neural_ode.md index c1dda9248..da9174bda 100644 --- a/docs/src/examples/neural_ode.md +++ b/docs/src/examples/neural_ode.md @@ -40,15 +40,16 @@ end function loss_neuralode(p) pred = predict_neuralode(p) loss = sum(abs2, ode_data .- pred) - return loss, pred + return loss end # Do not plot by default for the documentation # Users should change doplot=true to see the plots callbacks -callback = function (p, l, pred; doplot = false) +function callback(state, l; doplot = false) println(l) # plot current prediction against data if doplot + pred = predict_neuralode(state.u) plt = scatter(tsteps, ode_data[1, :]; label = "data") scatter!(plt, tsteps, pred[1, :]; label = "prediction") display(plot(plt)) @@ -57,7 +58,7 @@ callback = function (p, l, pred; doplot = false) end pinit = ComponentArray(p) -callback(pinit, loss_neuralode(pinit)...; doplot = true) +callback((;u = pinit), loss_neuralode(pinit); doplot = true) # use Optimization.jl to solve the problem adtype = Optimization.AutoZygote() @@ -73,7 +74,7 @@ optprob2 = remake(optprob; u0 = result_neuralode.u) result_neuralode2 = Optimization.solve( optprob2, Optim.BFGS(; initial_stepnorm = 0.01); callback, allow_f_increases = false) -callback(result_neuralode2.u, loss_neuralode(result_neuralode2.u)...; doplot = true) +callback((;u 
= result_neuralode2.u), loss_neuralode(result_neuralode2.u); doplot = true) ``` ![Neural ODE](https://user-images.githubusercontent.com/1814174/88589293-e8207f80-d026-11ea-86e2-8a3feb8252ca.gif) @@ -134,7 +135,7 @@ end function loss_neuralode(p) pred = predict_neuralode(p) loss = sum(abs2, ode_data .- pred) - return loss, pred + return loss end ``` @@ -143,10 +144,11 @@ it would show every step and overflow the documentation, but for your use case s ```@example neuralode # Callback function to observe training -callback = function (p, l, pred; doplot = false) +callback = function (state, l; doplot = false) println(l) # plot current prediction against data if doplot + pred = predict_neuralode(state.u) plt = scatter(tsteps, ode_data[1, :]; label = "data") scatter!(plt, tsteps, pred[1, :]; label = "prediction") display(plot(plt)) @@ -155,7 +157,7 @@ callback = function (p, l, pred; doplot = false) end pinit = ComponentArray(p) -callback(pinit, loss_neuralode(pinit)...) +callback((; u = pinit), loss_neuralode(pinit)) ``` We then train the neural network to learn the ODE. @@ -198,8 +200,8 @@ result_neuralode2 = Optimization.solve(optprob2, Optim.BFGS(; initial_stepnorm = And then we use the callback with `doplot=true` to see the final plot: ```@example neuralode -callback(result_neuralode2.u, loss_neuralode(result_neuralode2.u)...; doplot = true) +callback((; u = result_neuralode2.u), loss_neuralode(result_neuralode2.u); doplot = true) plt = scatter(tsteps, ode_data[1, :]; label = "data") # hide -scatter!(plt, tsteps, loss_neuralode(result_neuralode2.u)[2][1, :]; label = "prediction") # hide +scatter!(plt, tsteps, predict(result_neuralode2.u)[1, :]; label = "prediction") # hide plt # hide ``` diff --git a/docs/src/examples/neural_ode_weather_forecast.md b/docs/src/examples/neural_ode_weather_forecast.md index 71553a504..8f69acfe5 100644 --- a/docs/src/examples/neural_ode_weather_forecast.md +++ b/docs/src/examples/neural_ode_weather_forecast.md @@ -122,8 +122,8 @@ function train_one_round(node, p, state, y, opt, maxiters, rng, y0 = y[:, 1]; kw end function train(t, y, obs_grid, maxiters, lr, rng, p = nothing, state = nothing; kwargs...) 
-    log_results(ps, losses) = (p, loss) -> begin
-        push!(ps, copy(p.u))
+    log_results(ps, losses) = (state, loss) -> begin
+        push!(ps, copy(state.u))
         push!(losses, loss)
         false
     end
diff --git a/docs/src/examples/neural_sde.md b/docs/src/examples/neural_sde.md
index 88eddd091..8d895ed16 100644
--- a/docs/src/examples/neural_sde.md
+++ b/docs/src/examples/neural_sde.md
@@ -119,21 +119,28 @@ end
 function loss_neuralsde(p; n = 100)
     u = repeat(reshape(u0, :, 1), 1, n)
     samples = predict_neuralsde(p, u)
-    means = mean(samples; dims = 2)
-    vars = var(samples; dims = 2, mean = means)[:, 1, :]
-    means = means[:, 1, :]
+    currmeans = mean(samples; dims = 2)
+    currvars = var(samples; dims = 2, mean = means)[:, 1, :]
+    currmeans = currmeans[:, 1, :]
     loss = sum(abs2, sde_data - means) + sum(abs2, sde_data_vars - vars)
-    return loss, means, vars
+    global means = currmeans
+    global vars = currvars
+    return loss
 end
 ```

 ```@example nsde
 list_plots = []
 iter = 0
+u = repeat(reshape(u0, :, 1), 1, 100)
+samples = predict_neuralsde(ps, u)
+means = mean(samples; dims = 2)
+vars = var(samples; dims = 2, mean = means)[:, 1, :]
+means = means[:, 1, :]

 # Callback function to observe training
-callback = function (p, loss, means, vars; doplot = false)
-    global list_plots, iter
+callback = function (state, loss; doplot = false)
+    global list_plots, iter, means, vars

     if iter == 0
         list_plots = []
diff --git a/docs/src/examples/normalizing_flows.md b/docs/src/examples/normalizing_flows.md
index 6d4898ae3..aa88af6cd 100644
--- a/docs/src/examples/normalizing_flows.md
+++ b/docs/src/examples/normalizing_flows.md
@@ -28,7 +28,7 @@ function loss(θ)
     return -mean(logpx)
 end

-function cb(p, l)
+function cb(state, l)
     @info "FFJORD Training" loss=l
     return false
 end
@@ -95,7 +95,7 @@ function loss(θ)
     return -mean(logpx)
 end

-function cb(p, l)
-    @info "FFJORD Training" loss=loss(p)
+function cb(state, l)
+    @info "FFJORD Training" loss=l
     return false
 end
diff --git a/docs/src/examples/physical_constraints.md b/docs/src/examples/physical_constraints.md
index 723344619..d3e0832bc 100644
--- a/docs/src/examples/physical_constraints.md
+++ b/docs/src/examples/physical_constraints.md
@@ -191,7 +191,7 @@ The optimizer is `BFGS`(see below).

 The callback function displays the loss during training.
```@example dae2 -callback = function (state, l, pred) #callback function to observe training +callback = function (state, l) #callback function to observe training display(l) return false end From b256e1e906fde48b71e72a20606caa0f3080a442 Mon Sep 17 00:00:00 2001 From: Vaibhav Dixit Date: Tue, 24 Sep 2024 16:20:17 -0400 Subject: [PATCH 02/12] bump DataInterpolations --- Project.toml | 2 +- docs/src/examples/multiple_shooting.md | 2 +- docs/src/examples/neural_ode.md | 4 ++-- 3 files changed, 4 insertions(+), 4 deletions(-) diff --git a/Project.toml b/Project.toml index ca68f6a54..b638ea65a 100644 --- a/Project.toml +++ b/Project.toml @@ -34,7 +34,7 @@ Boltz = "1" ChainRulesCore = "1" ComponentArrays = "0.15.17" ConcreteStructs = "0.2" -DataInterpolations = "5, 6" +DataInterpolations = "6.4" DelayDiffEq = "5.47.3" DiffEqCallbacks = "3.6.2" Distances = "0.10.11" diff --git a/docs/src/examples/multiple_shooting.md b/docs/src/examples/multiple_shooting.md index b48dd8e8f..2a6d3ca70 100644 --- a/docs/src/examples/multiple_shooting.md +++ b/docs/src/examples/multiple_shooting.md @@ -86,7 +86,7 @@ group_size = 3 continuity_term = 200 l1, preds = multiple_shoot(ps, ode_data, tsteps, prob_node, loss_function, - Tsit5(), group_size; continuity_term) + Tsit5(), group_size; continuity_term) function loss_function(data, pred) return sum(abs2, data - pred) diff --git a/docs/src/examples/neural_ode.md b/docs/src/examples/neural_ode.md index da9174bda..74b97b7d9 100644 --- a/docs/src/examples/neural_ode.md +++ b/docs/src/examples/neural_ode.md @@ -58,7 +58,7 @@ function callback(state, l; doplot = false) end pinit = ComponentArray(p) -callback((;u = pinit), loss_neuralode(pinit); doplot = true) +callback((; u = pinit), loss_neuralode(pinit); doplot = true) # use Optimization.jl to solve the problem adtype = Optimization.AutoZygote() @@ -74,7 +74,7 @@ optprob2 = remake(optprob; u0 = result_neuralode.u) result_neuralode2 = Optimization.solve( optprob2, Optim.BFGS(; initial_stepnorm = 0.01); callback, allow_f_increases = false) -callback((;u = result_neuralode2.u), loss_neuralode(result_neuralode2.u); doplot = true) +callback((; u = result_neuralode2.u), loss_neuralode(result_neuralode2.u); doplot = true) ``` ![Neural ODE](https://user-images.githubusercontent.com/1814174/88589293-e8207f80-d026-11ea-86e2-8a3feb8252ca.gif) From d20f849495dba16f4884b002326014f988acb114 Mon Sep 17 00:00:00 2001 From: Vaibhav Dixit Date: Mon, 28 Oct 2024 12:07:40 -0400 Subject: [PATCH 03/12] more callback args fixes --- docs/src/examples/mnist_conv_neural_ode.md | 4 ++-- docs/src/examples/multiple_shooting.md | 6 ++++-- 2 files changed, 6 insertions(+), 4 deletions(-) diff --git a/docs/src/examples/mnist_conv_neural_ode.md b/docs/src/examples/mnist_conv_neural_ode.md index e1f9024b4..bff4ecaf9 100644 --- a/docs/src/examples/mnist_conv_neural_ode.md +++ b/docs/src/examples/mnist_conv_neural_ode.md @@ -104,10 +104,10 @@ iter = 0 opt_func = OptimizationFunction(loss_function, Optimization.AutoZygote()) opt_prob = OptimizationProblem(opt_func, ps, dataloader); -function callback(ps, l, pred) +function callback(state, l) global iter += 1 iter % 10 == 0 && - @info "[MNIST Conv GPU] Accuracy: $(accuracy(m, dataloader, ps.u, st))" + @info "[MNIST Conv GPU] Accuracy: $(accuracy(m, dataloader, state.u, st))" return false end diff --git a/docs/src/examples/multiple_shooting.md b/docs/src/examples/multiple_shooting.md index 2a6d3ca70..5670b20fb 100644 --- a/docs/src/examples/multiple_shooting.md +++ b/docs/src/examples/multiple_shooting.md 
@@ -124,14 +124,16 @@ pd, pax = getdata(ps), getaxes(ps) function loss_single_shooting(p) ps = ComponentArray(p, pax) - return multiple_shoot(ps, ode_data, tsteps, prob_node, loss_function, + loss, currpred = multiple_shoot(ps, ode_data, tsteps, prob_node, loss_function, Tsit5(), group_size; continuity_term) + global preds = currpred + return loss end adtype = Optimization.AutoZygote() optf = Optimization.OptimizationFunction((x, p) -> loss_single_shooting(x), adtype) optprob = Optimization.OptimizationProblem(optf, pd) -res_ms = Optimization.solve(optprob, PolyOpt(); callback = callback) +res_ms = Optimization.solve(optprob, PolyOpt(); callback = callback, maxiters = 5000) gif(anim, "single_shooting.gif"; fps = 15) ``` From fd9a413238c2e4636011cbb0217b1a5e3b0aa560 Mon Sep 17 00:00:00 2001 From: Vaibhav Dixit Date: Tue, 29 Oct 2024 14:35:48 -0400 Subject: [PATCH 04/12] fix more examples and update gde link --- docs/src/examples/hamiltonian_nn.md | 24 ++++++------- docs/src/examples/mnist_conv_neural_ode.md | 2 +- docs/src/examples/mnist_neural_ode.md | 6 ++-- docs/src/examples/multiple_shooting.md | 42 +++++++++++----------- docs/src/examples/neural_gde.md | 2 +- docs/src/examples/neural_ode.md | 2 +- docs/src/examples/neural_sde.md | 2 +- docs/src/examples/physical_constraints.md | 4 +-- 8 files changed, 41 insertions(+), 43 deletions(-) diff --git a/docs/src/examples/hamiltonian_nn.md b/docs/src/examples/hamiltonian_nn.md index 43f816bda..7399816af 100644 --- a/docs/src/examples/hamiltonian_nn.md +++ b/docs/src/examples/hamiltonian_nn.md @@ -14,7 +14,7 @@ Before getting to the explanation, here's some code to start with. We will follo ```@example hamiltonian_cp using Lux, DiffEqFlux, OrdinaryDiffEq, Statistics, Plots, Zygote, ForwardDiff, Random, - ComponentArrays, Optimization, OptimizationOptimisers, IterTools + ComponentArrays, Optimization, OptimizationOptimisers, MLUtils t = range(0.0f0, 1.0f0; length = 1024) π_32 = Float32(π) @@ -23,15 +23,11 @@ p_t = reshape(cos.(2π_32 * t), 1, :) dqdt = 2π_32 .* p_t dpdt = -2π_32 .* q_t -data = vcat(q_t, p_t) -target = vcat(dqdt, dpdt) +data = cat(q_t, p_t; dims = 1) +target = cat(dqdt, dpdt; dims = 1) B = 256 NEPOCHS = 100 -dataloader = ncycle( - ((selectdim(data, 2, ((i - 1) * B + 1):(min(i * B, size(data, 2)))), - selectdim(target, 2, ((i - 1) * B + 1):(min(i * B, size(data, 2))))) - for i in 1:(size(data, 2) ÷ B)), - NEPOCHS) +dataloader = DataLoader((data, target); batchsize = B) hnn = Layers.HamiltonianNN{true}(Layers.MLP(2, (64, 1)); autodiff = AutoZygote()) ps, st = Lux.setup(Xoshiro(0), hnn) @@ -39,21 +35,21 @@ ps_c = ps |> ComponentArray opt = OptimizationOptimisers.Adam(0.01f0) -function loss_function(ps, data, target) +function loss_function(ps, databatch) + data, target = databatch pred, st_ = hnn(data, ps, st) return mean(abs2, pred .- target), pred end -function callback(ps, loss, pred) +function callback(st, loss) println("[Hamiltonian NN] Loss: ", loss) return false end -opt_func = OptimizationFunction((ps, _, data, target) -> loss_function(ps, data, target), - Optimization.AutoForwardDiff()) -opt_prob = OptimizationProblem(opt_func, ps_c) +opt_func = OptimizationFunction(loss_function, Optimization.AutoForwardDiff()) +opt_prob = OptimizationProblem(opt_func, ps_c, dataloader) -res = Optimization.solve(opt_prob, opt, dataloader; callback) +res = Optimization.solve(opt_prob, opt; callback) ps_trained = res.u diff --git a/docs/src/examples/mnist_conv_neural_ode.md b/docs/src/examples/mnist_conv_neural_ode.md index 
bff4ecaf9..80f68bfe6 100644 --- a/docs/src/examples/mnist_conv_neural_ode.md +++ b/docs/src/examples/mnist_conv_neural_ode.md @@ -112,7 +112,7 @@ function callback(state, l) end # Train the NN-ODE and monitor the loss and weights. -res = Optimization.solve(opt_prob, opt; maxiters = 5, callback) +res = Optimization.solve(opt_prob, opt; epochs = 5, callback) acc = accuracy(m, dataloader, res.u, st) acc # hide ``` diff --git a/docs/src/examples/mnist_neural_ode.md b/docs/src/examples/mnist_neural_ode.md index 27b78f26c..158cc7a67 100644 --- a/docs/src/examples/mnist_neural_ode.md +++ b/docs/src/examples/mnist_neural_ode.md @@ -87,7 +87,7 @@ function loss_function(ps, data) return logitcrossentropy(pred, y) end -loss_function(ps, x_train1, y_train1) # burn in loss +loss_function(ps, (x_train1, y_train1)) # burn in loss opt = OptimizationOptimisers.Adam(0.05) iter = 0 @@ -103,7 +103,7 @@ function callback(state, l) end # Train the NN-ODE and monitor the loss and weights. -res = Optimization.solve(opt_prob, opt; callback, maxiters = 5) +res = Optimization.solve(opt_prob, opt; callback, epochs = 5) accuracy(m, dataloader, res.u, st) ``` @@ -329,6 +329,6 @@ for Neural ODE is given by `nn_ode.p`: ```@example mnist # Train the NN-ODE and monitor the loss and weights. -res = Optimization.solve(opt_prob, opt; callback, maxiters = 5) +res = Optimization.solve(opt_prob, opt; callback, epochs = 5) accuracy(m, dataloader, res.u, st) ``` diff --git a/docs/src/examples/multiple_shooting.md b/docs/src/examples/multiple_shooting.md index 5670b20fb..a7c5f88e4 100644 --- a/docs/src/examples/multiple_shooting.md +++ b/docs/src/examples/multiple_shooting.md @@ -54,6 +54,26 @@ pd, pax = getdata(ps), getaxes(ps) neuralode = NeuralODE(nn, tspan, Tsit5(); saveat = tsteps) prob_node = ODEProblem((u, p, t) -> nn(u, p, st)[1], u0, tspan, ComponentArray(p_init)) +# Define parameters for Multiple Shooting +group_size = 3 +continuity_term = 200 + +function loss_function(data, pred) + return sum(abs2, data - pred) +end + +l1, preds = multiple_shoot(ps, ode_data, tsteps, prob_node, loss_function, + Tsit5(), group_size; continuity_term) + +function loss_multiple_shooting(p) + ps = ComponentArray(p, pax) + + loss, currpred = multiple_shoot(ps, ode_data, tsteps, prob_node, loss_function, + Tsit5(), group_size; continuity_term) + global preds = currpred + return loss +end + function plot_multiple_shoot(plt, preds, group_size) step = group_size - 1 ranges = group_ranges(datasize, group_size) @@ -73,6 +93,8 @@ function callback(state, l; doplot = true, prob_node = prob_node) # plot the original data plt = scatter(tsteps, ode_data[1, :]; label = "Data") # plot the different predictions for individual shoot + l1, preds = multiple_shoot(st.u, ode_data, tsteps, prob_node, loss_function, + Tsit5(), group_size; continuity_term) plot_multiple_shoot(plt, preds, group_size) frame(anim) @@ -81,26 +103,6 @@ function callback(state, l; doplot = true, prob_node = prob_node) return false end -# Define parameters for Multiple Shooting -group_size = 3 -continuity_term = 200 - -l1, preds = multiple_shoot(ps, ode_data, tsteps, prob_node, loss_function, - Tsit5(), group_size; continuity_term) - -function loss_function(data, pred) - return sum(abs2, data - pred) -end - -function loss_multiple_shooting(p) - ps = ComponentArray(p, pax) - - loss, currpred = multiple_shoot(ps, ode_data, tsteps, prob_node, loss_function, - Tsit5(), group_size; continuity_term) - global preds = currpred - return loss -end - adtype = Optimization.AutoZygote() optf = 
Optimization.OptimizationFunction((x, p) -> loss_multiple_shooting(x), adtype) optprob = Optimization.OptimizationProblem(optf, pd) diff --git a/docs/src/examples/neural_gde.md b/docs/src/examples/neural_gde.md index e50c70245..ce0e2c730 100644 --- a/docs/src/examples/neural_gde.md +++ b/docs/src/examples/neural_gde.md @@ -4,7 +4,7 @@ This tutorial has not been ran or updated in awhile. -This tutorial has been adapted from [here](https://github.com/CarloLucibello/GraphNeuralNetworks.jl/blob/master/examples/neural_ode_cora.jl). +This tutorial has been adapted from [here](https://github.com/CarloLucibello/GraphNeuralNetworks.jl/blob/master/GraphNeuralNetworks/examples/neural_ode_cora.jl). In this tutorial, we will use Graph Differential Equations (GDEs) to perform classification on the [CORA Dataset](https://paperswithcode.com/dataset/cora). We shall be using the Graph Neural Networks primitives from the package [GraphNeuralNetworks](https://github.com/CarloLucibello/GraphNeuralNetworks.jl). diff --git a/docs/src/examples/neural_ode.md b/docs/src/examples/neural_ode.md index 74b97b7d9..699eea94f 100644 --- a/docs/src/examples/neural_ode.md +++ b/docs/src/examples/neural_ode.md @@ -202,6 +202,6 @@ And then we use the callback with `doplot=true` to see the final plot: ```@example neuralode callback((; u = result_neuralode2.u), loss_neuralode(result_neuralode2.u); doplot = true) plt = scatter(tsteps, ode_data[1, :]; label = "data") # hide -scatter!(plt, tsteps, predict(result_neuralode2.u)[1, :]; label = "prediction") # hide +scatter!(plt, tsteps, predict_neuralode(result_neuralode2.u)[1, :]; label = "prediction") # hide plt # hide ``` diff --git a/docs/src/examples/neural_sde.md b/docs/src/examples/neural_sde.md index 8d895ed16..0c67ee2f8 100644 --- a/docs/src/examples/neural_sde.md +++ b/docs/src/examples/neural_sde.md @@ -120,7 +120,7 @@ function loss_neuralsde(p; n = 100) u = repeat(reshape(u0, :, 1), 1, n) samples = predict_neuralsde(p, u) currmeans = mean(samples; dims = 2) - currvars = var(samples; dims = 2, mean = means)[:, 1, :] + currvars = var(samples; dims = 2, mean = currmeans)[:, 1, :] currmeans = currmeans[:, 1, :] loss = sum(abs2, sde_data - means) + sum(abs2, sde_data_vars - vars) global means = currmeans diff --git a/docs/src/examples/physical_constraints.md b/docs/src/examples/physical_constraints.md index d3e0832bc..430ec4e1e 100644 --- a/docs/src/examples/physical_constraints.md +++ b/docs/src/examples/physical_constraints.md @@ -50,7 +50,7 @@ end function loss_stiff_ndae(p) pred = predict_stiff_ndae(p) loss = sum(abs2, Array(sol_stiff) .- pred) - return loss, pred + return loss end # callback = function (state, l, pred) #callback function to observe training @@ -172,7 +172,7 @@ from these predictions. In this case, we use **least squares** as our loss. 
function loss_stiff_ndae(p) pred = predict_stiff_ndae(p) loss = sum(abs2, sol_stiff .- pred) - return loss, pred + return loss end l1 = first(loss_stiff_ndae(ComponentArray(pinit))) From 95da38ef6a170ed4a766fd6620ca0484efe05e31 Mon Sep 17 00:00:00 2001 From: Christopher Rackauckas Date: Thu, 31 Oct 2024 07:27:52 -0100 Subject: [PATCH 05/12] Update neural_ode_mm_tests.jl --- test/neural_ode_mm_tests.jl | 4 ++-- 1 file changed, 2 insertions(+), 2 deletions(-) diff --git a/test/neural_ode_mm_tests.jl b/test/neural_ode_mm_tests.jl index 964ea42e9..0a1acc83a 100644 --- a/test/neural_ode_mm_tests.jl +++ b/test/neural_ode_mm_tests.jl @@ -32,10 +32,10 @@ function loss(p) pred = first(ndae(u₀, p, st)) loss = sum(abs2, Array(sol) .- pred) - return loss, pred + return loss end - cb = function (p, l, pred) + cb = function (state, l) @info "[NeuralODEMM] Loss: $l" return false end From 4426d96cd9167a6048b26fc031895ce58458470f Mon Sep 17 00:00:00 2001 From: Christopher Rackauckas Date: Thu, 31 Oct 2024 07:37:07 -0100 Subject: [PATCH 06/12] Update multiple_shooting.md --- docs/src/examples/multiple_shooting.md | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/docs/src/examples/multiple_shooting.md b/docs/src/examples/multiple_shooting.md index a7c5f88e4..3e63a46e6 100644 --- a/docs/src/examples/multiple_shooting.md +++ b/docs/src/examples/multiple_shooting.md @@ -93,7 +93,7 @@ function callback(state, l; doplot = true, prob_node = prob_node) # plot the original data plt = scatter(tsteps, ode_data[1, :]; label = "Data") # plot the different predictions for individual shoot - l1, preds = multiple_shoot(st.u, ode_data, tsteps, prob_node, loss_function, + l1, preds = multiple_shoot(state.u, ode_data, tsteps, prob_node, loss_function, Tsit5(), group_size; continuity_term) plot_multiple_shoot(plt, preds, group_size) From 9de27c645b2bbb17ff8bc688314140ef44c74ed7 Mon Sep 17 00:00:00 2001 From: Christopher Rackauckas Date: Thu, 31 Oct 2024 07:41:40 -0100 Subject: [PATCH 07/12] Update hamiltonian_nn.md --- docs/src/examples/hamiltonian_nn.md | 6 +++--- 1 file changed, 3 insertions(+), 3 deletions(-) diff --git a/docs/src/examples/hamiltonian_nn.md b/docs/src/examples/hamiltonian_nn.md index 7399816af..4ec243311 100644 --- a/docs/src/examples/hamiltonian_nn.md +++ b/docs/src/examples/hamiltonian_nn.md @@ -26,7 +26,7 @@ dpdt = -2π_32 .* q_t data = cat(q_t, p_t; dims = 1) target = cat(dqdt, dpdt; dims = 1) B = 256 -NEPOCHS = 100 +NEPOCHS = 1000 dataloader = DataLoader((data, target); batchsize = B) hnn = Layers.HamiltonianNN{true}(Layers.MLP(2, (64, 1)); autodiff = AutoZygote()) @@ -41,7 +41,7 @@ function loss_function(ps, databatch) return mean(abs2, pred .- target), pred end -function callback(st, loss) +function callback(state, loss) println("[Hamiltonian NN] Loss: ", loss) return false end @@ -49,7 +49,7 @@ end opt_func = OptimizationFunction(loss_function, Optimization.AutoForwardDiff()) opt_prob = OptimizationProblem(opt_func, ps_c, dataloader) -res = Optimization.solve(opt_prob, opt; callback) +res = Optimization.solve(opt_prob, opt; callback, epochs = NEPOCHS) ps_trained = res.u From bf53f508b154333e943aa517fc460be589a79a03 Mon Sep 17 00:00:00 2001 From: Christopher Rackauckas Date: Thu, 31 Oct 2024 14:53:53 -0100 Subject: [PATCH 08/12] Update docs/src/examples/hamiltonian_nn.md --- docs/src/examples/hamiltonian_nn.md | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/docs/src/examples/hamiltonian_nn.md b/docs/src/examples/hamiltonian_nn.md index 4ec243311..3819d8537 100644 --- 
a/docs/src/examples/hamiltonian_nn.md +++ b/docs/src/examples/hamiltonian_nn.md @@ -38,7 +38,7 @@ opt = OptimizationOptimisers.Adam(0.01f0) function loss_function(ps, databatch) data, target = databatch pred, st_ = hnn(data, ps, st) - return mean(abs2, pred .- target), pred + return mean(abs2, pred .- target) end function callback(state, loss) From 58427ab5073da72cce1ec534e92c5422be85e3b6 Mon Sep 17 00:00:00 2001 From: Christopher Rackauckas Date: Sun, 3 Nov 2024 12:07:56 -0100 Subject: [PATCH 09/12] should all be fixed --- docs/src/examples/hamiltonian_nn.md | 14 +++++++------- docs/src/examples/multiple_shooting.md | 2 +- docs/src/examples/neural_sde.md | 18 ++++++++++++------ 3 files changed, 20 insertions(+), 14 deletions(-) diff --git a/docs/src/examples/hamiltonian_nn.md b/docs/src/examples/hamiltonian_nn.md index 3819d8537..87e2750f2 100644 --- a/docs/src/examples/hamiltonian_nn.md +++ b/docs/src/examples/hamiltonian_nn.md @@ -26,10 +26,10 @@ dpdt = -2π_32 .* q_t data = cat(q_t, p_t; dims = 1) target = cat(dqdt, dpdt; dims = 1) B = 256 -NEPOCHS = 1000 +NEPOCHS = 500 dataloader = DataLoader((data, target); batchsize = B) -hnn = Layers.HamiltonianNN{true}(Layers.MLP(2, (64, 1)); autodiff = AutoZygote()) +hnn = Layers.HamiltonianNN{true}(Layers.MLP(2, (1028, 1)); autodiff = AutoZygote()) ps, st = Lux.setup(Xoshiro(0), hnn) ps_c = ps |> ComponentArray @@ -83,7 +83,7 @@ dpdt = -2π_32 .* q_t data = cat(q_t, p_t; dims = 1) target = cat(dqdt, dpdt; dims = 1) B = 256 -NEPOCHS = 1000 +NEPOCHS = 500 dataloader = DataLoader((data, target); batchsize = B) ``` @@ -92,17 +92,17 @@ dataloader = DataLoader((data, target); batchsize = B) We parameterize the with a small MultiLayered Perceptron. HNNs are trained by optimizing the gradients of the Neural Network. Zygote currently doesn't support nesting itself, so we will be using ForwardDiff in the training loop to compute the gradients of the HNN Layer for Optimization. 
```@example hamiltonian -hnn = Layers.HamiltonianNN{true}(Layers.MLP(2, (64, 1)); autodiff = AutoZygote()) +hnn = Layers.HamiltonianNN{true}(Layers.MLP(2, (1028, 1)); autodiff = AutoZygote()) ps, st = Lux.setup(Xoshiro(0), hnn) ps_c = ps |> ComponentArray hnn_stateful = StatefulLuxLayer{true}(hnn, ps_c, st) -opt = OptimizationOptimisers.Adam(0.01f0) +opt = OptimizationOptimisers.Adam(0.005f0) function loss_function(ps, databatch) (data, target) = databatch pred = hnn_stateful(data, ps) - return mean(abs2, pred .- target), pred + return mean(abs2, pred .- target) end function callback(state, loss) @@ -113,7 +113,7 @@ end opt_func = OptimizationFunction(loss_function, Optimization.AutoZygote()) opt_prob = OptimizationProblem(opt_func, ps_c, dataloader) -res = solve(opt_prob, opt; callback, epochs = NEPOCHS) +res = Optimization.solve(opt_prob, opt; callback, epochs = NEPOCHS) ps_trained = res.u ``` diff --git a/docs/src/examples/multiple_shooting.md b/docs/src/examples/multiple_shooting.md index 3e63a46e6..73d454173 100644 --- a/docs/src/examples/multiple_shooting.md +++ b/docs/src/examples/multiple_shooting.md @@ -93,7 +93,7 @@ function callback(state, l; doplot = true, prob_node = prob_node) # plot the original data plt = scatter(tsteps, ode_data[1, :]; label = "Data") # plot the different predictions for individual shoot - l1, preds = multiple_shoot(state.u, ode_data, tsteps, prob_node, loss_function, + l1, preds = multiple_shoot(ComponentArray(state.u, pax), ode_data, tsteps, prob_node, loss_function, Tsit5(), group_size; continuity_term) plot_multiple_shoot(plt, preds, group_size) diff --git a/docs/src/examples/neural_sde.md b/docs/src/examples/neural_sde.md index 0c67ee2f8..555b1453f 100644 --- a/docs/src/examples/neural_sde.md +++ b/docs/src/examples/neural_sde.md @@ -86,8 +86,8 @@ Let's see what that looks like: # Get the prediction using the correct initial condition prediction0 = neuralsde(u0, ps, st)[1] -drift_model = StatefulLuxLayer{true}(drift_dudt, nothing, st.drift) -diffusion_model = StatefulLuxLayer{true}(diffusion_dudt, nothing, st.diffusion) +drift_model = StatefulLuxLayer{true}(drift_dudt, ps.drift, st.drift) +diffusion_model = StatefulLuxLayer{true}(diffusion_dudt, ps.diffusion, st.diffusion) drift_(u, p, t) = drift_model(u, p.drift) diffusion_(u, p, t) = diffusion_model(u, p.diffusion) @@ -110,7 +110,7 @@ mean and variance from `n` runs at each time point and uses the distance from the data values: ```@example nsde -neuralsde_model = StatefulLuxLayer{true}(neuralsde, nothing, st) +neuralsde_model = StatefulLuxLayer{true}(neuralsde, ps, st) function predict_neuralsde(p, u = u0) return Array(neuralsde_model(u, p)) @@ -122,7 +122,7 @@ function loss_neuralsde(p; n = 100) currmeans = mean(samples; dims = 2) currvars = var(samples; dims = 2, mean = currmeans)[:, 1, :] currmeans = currmeans[:, 1, :] - loss = sum(abs2, sde_data - means) + sum(abs2, sde_data_vars - vars) + loss = sum(abs2, sde_data - currmeans) + sum(abs2, sde_data_vars - currvars) global means = currmeans global vars = currvars return loss @@ -181,15 +181,21 @@ We resume the training with a larger `n`. (WARNING - this step is a couple of orders of magnitude longer than the previous one). 
```@example nsde
+opt = OptimizationOptimisers.Adam(0.001)
 optf2 = Optimization.OptimizationFunction((x, p) -> loss_neuralsde(x; n = 100), adtype)
 optprob2 = Optimization.OptimizationProblem(optf2, result1.u)
-result2 = Optimization.solve(optprob2, opt; callback, maxiters = 20)
+result2 = Optimization.solve(optprob2, opt; callback, maxiters = 100)
 ```

 And now we plot the solution to an ensemble of the trained neural SDE:

 ```@example nsde
-_, means, vars = loss_neuralsde(result2.u; n = 1000)
+n = 1000
+u = repeat(reshape(u0, :, 1), 1, n)
+samples = predict_neuralsde(result2.u, u)
+currmeans = mean(samples; dims = 2)
+currvars = var(samples; dims = 2, mean = currmeans)[:, 1, :]
+currmeans = currmeans[:, 1, :]

 plt2 = Plots.scatter(tsteps, sde_data'; yerror = sde_data_vars',
     label = "data", title = "Neural SDE: After Training", xlabel = "Time")

From 2b6461228e4f2bef0a62cbc6d7e994abeba791f7 Mon Sep 17 00:00:00 2001
From: Christopher Rackauckas
Date: Mon, 4 Nov 2024 07:51:33 -0100
Subject: [PATCH 10/12] Update multiple_shooting.md

---
 docs/src/examples/multiple_shooting.md | 4 ++--
 1 file changed, 2 insertions(+), 2 deletions(-)

diff --git a/docs/src/examples/multiple_shooting.md b/docs/src/examples/multiple_shooting.md
index 73d454173..6e3773b55 100644
--- a/docs/src/examples/multiple_shooting.md
+++ b/docs/src/examples/multiple_shooting.md
@@ -106,7 +106,7 @@ end
 adtype = Optimization.AutoZygote()
 optf = Optimization.OptimizationFunction((x, p) -> loss_multiple_shooting(x), adtype)
 optprob = Optimization.OptimizationProblem(optf, pd)
-res_ms = Optimization.solve(optprob, PolyOpt(); callback = callback, maxiters = 5000)
+res_ms = Optimization.solve(optprob, PolyOpt(); callback = callback, maxiters = 300)
 gif(anim, "multiple_shooting.gif"; fps = 15)
 ```
@@ -135,7 +135,7 @@ end
 adtype = Optimization.AutoZygote()
 optf = Optimization.OptimizationFunction((x, p) -> loss_single_shooting(x), adtype)
 optprob = Optimization.OptimizationProblem(optf, pd)
-res_ms = Optimization.solve(optprob, PolyOpt(); callback = callback, maxiters = 5000)
+res_ms = Optimization.solve(optprob, PolyOpt(); callback = callback, maxiters = 300)
 gif(anim, "single_shooting.gif"; fps = 15)
 ```

From 08405b3036b6cd1fcc1b660942803fd80e723b18 Mon Sep 17 00:00:00 2001
From: Christopher Rackauckas
Date: Mon, 4 Nov 2024 07:53:35 -0100
Subject: [PATCH 11/12] Update docs/src/examples/multiple_shooting.md

Co-authored-by: github-actions[bot] <41898282+github-actions[bot]@users.noreply.github.com>

---
 docs/src/examples/multiple_shooting.md | 3 ++-
 1 file changed, 2 insertions(+), 1 deletion(-)

diff --git a/docs/src/examples/multiple_shooting.md b/docs/src/examples/multiple_shooting.md
index 6e3773b55..6d08f5eed 100644
--- a/docs/src/examples/multiple_shooting.md
+++ b/docs/src/examples/multiple_shooting.md
@@ -93,7 +93,8 @@ function callback(state, l; doplot = true, prob_node = prob_node)
     # plot the original data
     plt = scatter(tsteps, ode_data[1, :]; label = "Data")
     # plot the different predictions for individual shoot
-    l1, preds = multiple_shoot(ComponentArray(state.u, pax), ode_data, tsteps, prob_node, loss_function,
+    l1, preds = multiple_shoot(
+        ComponentArray(state.u, pax), ode_data, tsteps, prob_node, loss_function,
         Tsit5(), group_size; continuity_term)
     plot_multiple_shoot(plt, preds, group_size)

From b8b003da3afbf7ccb3bac61fee03d02ffb9880fa Mon Sep 17 00:00:00 2001
From: Christopher Rackauckas
Date: Mon, 4 Nov 2024 08:39:35 -0100
Subject: [PATCH 12/12] Update augmented_neural_ode.md

---
 docs/src/examples/augmented_neural_ode.md | 8 
++++---- 1 file changed, 4 insertions(+), 4 deletions(-) diff --git a/docs/src/examples/augmented_neural_ode.md b/docs/src/examples/augmented_neural_ode.md index 98261b9d5..6096f597b 100644 --- a/docs/src/examples/augmented_neural_ode.md +++ b/docs/src/examples/augmented_neural_ode.md @@ -95,7 +95,7 @@ optfunc = OptimizationFunction( (x, data) -> loss_node(model, data, x, st), Optimization.AutoZygote()) optprob = OptimizationProblem(optfunc, ComponentArray(ps |> cdev) |> gdev, dataloader) -res = solve(optprob, opt; callback = cb, epochs = 1000) +res = solve(optprob, opt; callback = cb, epochs = 100) plt_node = plot_contour(model, res.u, st) @@ -109,7 +109,7 @@ optfunc = OptimizationFunction( (x, data) -> loss_node(model, data, x, st), Optimization.AutoZygote()) optprob = OptimizationProblem(optfunc, ComponentArray(ps |> cdev) |> gdev, dataloader) -res = solve(optprob, opt; callback = cb, epochs = 1000) +res = solve(optprob, opt; callback = cb, epochs = 100) plot_contour(model, res.u, st) ``` @@ -279,7 +279,7 @@ optfunc = OptimizationFunction( (x, data) -> loss_node(model, data, x, st), Optimization.AutoZygote()) optprob = OptimizationProblem(optfunc, ComponentArray(ps |> cdev) |> gdev, dataloader) -res = solve(optprob, opt; callback = cb, epochs = 1000) +res = solve(optprob, opt; callback = cb, epochs = 100) plot_contour(model, res.u, st) ``` @@ -300,7 +300,7 @@ optfunc = OptimizationFunction( (x, data) -> loss_node(model, data, x, st), Optimization.AutoZygote()) optprob = OptimizationProblem(optfunc, ComponentArray(ps |> cdev) |> gdev, dataloader) -res = solve(optprob, opt; callback = cb, epochs = 1000) +res = solve(optprob, opt; callback = cb, epochs = 100) plot_contour(model, res.u, st) ```
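The pattern applied across all of the tutorials above is the Optimization.jl v4 data-handling API: the loss takes `(ps, batch)` and returns a scalar, the `DataLoader` becomes the third argument of `OptimizationProblem`, the callback takes `(state, l)` and reads parameters from `state.u`, and `solve` accepts `epochs` in place of `IterTools.ncycle` plus `maxiters`. The following is a minimal self-contained sketch of that pattern, assuming Optimization v4, OptimizationOptimisers v0.3, and MLUtils; the toy data and two-parameter model are illustrative stand-ins, not code from the patch:

```julia
using Optimization, OptimizationOptimisers, MLUtils, Statistics

# Toy regression data (illustrative stand-in for the tutorials' datasets).
x = rand(Float32, 2, 1024)
y = 2.0f0 .* x
dataloader = DataLoader((x, y); batchsize = 256)

ps0 = rand(Float32, 2) # stand-in parameter vector

# v4 loss: (parameters, batch) -> scalar; each batch drawn from the dataloader
# stored in the OptimizationProblem is passed as the second argument.
function loss(ps, batch)
    (xb, yb) = batch
    return mean(abs2, ps .* xb .- yb)
end

# v4 callback: (state, l) -> halt?; the current parameters live in `state.u`.
callback = function (state, l)
    @info "training" loss=l
    return false
end

optf = OptimizationFunction(loss, Optimization.AutoZygote())
optprob = OptimizationProblem(optf, ps0, dataloader) # data rides along in the problem
res = solve(optprob, OptimizationOptimisers.Adam(0.01f0); callback, epochs = 10)
```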