diff --git a/Project.toml b/Project.toml
index aee97ae05..ca68f6a54 100644
--- a/Project.toml
+++ b/Project.toml
@@ -54,9 +54,9 @@ LuxLib = "1.2"
 NNlib = "0.9.22"
 OneHotArrays = "0.2.5"
 Optimisers = "0.3"
-Optimization = "3.25.0"
-OptimizationOptimJL = "0.3.0"
-OptimizationOptimisers = "0.2.1"
+Optimization = "4"
+OptimizationOptimJL = "0.4"
+OptimizationOptimisers = "0.3"
 OrdinaryDiffEq = "6.76.0"
 Printf = "1.10"
 Random = "1.10"
diff --git a/docs/Project.toml b/docs/Project.toml
index 09aa6590e..ee0dbb7d9 100644
--- a/docs/Project.toml
+++ b/docs/Project.toml
@@ -57,10 +57,10 @@ MLUtils = "0.4"
 NNlib = "0.9"
 OneHotArrays = "0.2"
 Optimisers = "0.3"
-Optimization = "3.9"
-OptimizationOptimJL = "0.2, 0.3"
-OptimizationOptimisers = "0.2"
-OptimizationPolyalgorithms = "0.2"
+Optimization = "4"
+OptimizationOptimJL = "0.4"
+OptimizationOptimisers = "0.3"
+OptimizationPolyalgorithms = "0.3"
 OrdinaryDiffEq = "6.31"
 Plots = "1.36"
 Printf = "1"
diff --git a/docs/src/examples/augmented_neural_ode.md b/docs/src/examples/augmented_neural_ode.md
index a0f9a6a0f..98261b9d5 100644
--- a/docs/src/examples/augmented_neural_ode.md
+++ b/docs/src/examples/augmented_neural_ode.md
@@ -69,13 +69,13 @@ function plot_contour(model, ps, st, npoints = 300)
     return contour(x, y, sol; fill = true, linewidth = 0.0)
 end
 
-loss_node(model, x, y, ps, st) = mean((first(model(x, ps, st)) .- y) .^ 2)
+loss_node(model, data, ps, st) = mean((first(model(data[1], ps, st)) .- data[2]) .^ 2)
 
 dataloader = concentric_sphere(
     2, (0.0f0, 2.0f0), (3.0f0, 4.0f0), 2000, 2000; batch_size = 256)
 
 iter = 0
-cb = function (ps, l)
+cb = function (state, l)
     global iter
     iter += 1
     if iter % 10 == 0
@@ -87,15 +87,15 @@ end
 model, ps, st = construct_model(1, 2, 64, 0)
 opt = OptimizationOptimisers.Adam(0.005)
 
-loss_node(model, dataloader.data[1], dataloader.data[2], ps, st)
+loss_node(model, (dataloader.data[1], dataloader.data[2]), ps, st)
 
 println("Training Neural ODE")
 
 optfunc = OptimizationFunction(
-    (x, p, data, target) -> loss_node(model, data, target, x, st),
+    (x, data) -> loss_node(model, data, x, st),
     Optimization.AutoZygote())
-optprob = OptimizationProblem(optfunc, ComponentArray(ps |> cdev) |> gdev)
-res = solve(optprob, opt, IterTools.ncycle(dataloader, 5); callback = cb)
+optprob = OptimizationProblem(optfunc, ComponentArray(ps |> cdev) |> gdev, dataloader)
+res = solve(optprob, opt; callback = cb, epochs = 1000)
 
 plt_node = plot_contour(model, res.u, st)
 
@@ -106,10 +106,10 @@ println()
 println("Training Augmented Neural ODE")
 
 optfunc = OptimizationFunction(
-    (x, p, data, target) -> loss_node(model, data, target, x, st),
+    (x, data) -> loss_node(model, data, x, st),
     Optimization.AutoZygote())
-optprob = OptimizationProblem(optfunc, ComponentArray(ps |> cdev) |> gdev)
-res = solve(optprob, opt, IterTools.ncycle(dataloader, 5); callback = cb)
+optprob = OptimizationProblem(optfunc, ComponentArray(ps |> cdev) |> gdev, dataloader)
+res = solve(optprob, opt; callback = cb, epochs = 1000)
 
 plot_contour(model, res.u, st)
 ```
@@ -229,7 +229,7 @@ We use the L2 distance between the model prediction `model(x)` and the actual pr
 optimization objective.
 
 ```@example augneuralode
-loss_node(model, x, y, ps, st) = mean((first(model(x, ps, st)) .- y) .^ 2)
+loss_node(model, data, ps, st) = mean((first(model(data[1], ps, st)) .- data[2]) .^ 2)
 ```
 
 #### Dataset
@@ -248,7 +248,7 @@ Additionally, we define a callback function which displays the total loss at spe
 
 ```@example augneuralode
 iter = 0
-cb = function (ps, l)
+cb = function (state, l)
     global iter
     iter += 1
     if iter % 10 == 0
@@ -276,10 +276,10 @@ for `20` epochs.
 model, ps, st = construct_model(1, 2, 64, 0)
 
 optfunc = OptimizationFunction(
-    (x, p, data, target) -> loss_node(model, data, target, x, st),
+    (x, data) -> loss_node(model, data, x, st),
     Optimization.AutoZygote())
-optprob = OptimizationProblem(optfunc, ComponentArray(ps |> cdev) |> gdev)
-res = solve(optprob, opt, IterTools.ncycle(dataloader, 5); callback = cb)
+optprob = OptimizationProblem(optfunc, ComponentArray(ps |> cdev) |> gdev, dataloader)
+res = solve(optprob, opt; callback = cb, epochs = 1000)
 
 plot_contour(model, res.u, st)
 ```
@@ -297,10 +297,10 @@ a function which can be expressed by the neural ode. For more details and proofs
 model, ps, st = construct_model(1, 2, 64, 1)
 
 optfunc = OptimizationFunction(
-    (x, p, data, target) -> loss_node(model, data, target, x, st),
+    (x, data) -> loss_node(model, data, x, st),
     Optimization.AutoZygote())
-optprob = OptimizationProblem(optfunc, ComponentArray(ps |> cdev) |> gdev)
-res = solve(optprob, opt, IterTools.ncycle(dataloader, 5); callback = cb)
+optprob = OptimizationProblem(optfunc, ComponentArray(ps |> cdev) |> gdev, dataloader)
+res = solve(optprob, opt; callback = cb, epochs = 1000)
 
 plot_contour(model, res.u, st)
 ```
diff --git a/docs/src/examples/hamiltonian_nn.md b/docs/src/examples/hamiltonian_nn.md
index 9c1716bad..43f816bda 100644
--- a/docs/src/examples/hamiltonian_nn.md
+++ b/docs/src/examples/hamiltonian_nn.md
@@ -75,7 +75,7 @@ The HNN predicts the gradients ``(\dot q, \dot p)`` given ``(q, p)``. Hence, we
 
 ```@example hamiltonian
 using Lux, DiffEqFlux, OrdinaryDiffEq, Statistics, Plots, Zygote, ForwardDiff, Random,
-    ComponentArrays, Optimization, OptimizationOptimisers, IterTools
+    ComponentArrays, Optimization, OptimizationOptimisers, MLUtils
 
 t = range(0.0f0, 1.0f0; length = 1024)
 π_32 = Float32(π)
@@ -87,12 +87,8 @@ dpdt = -2π_32 .* q_t
 data = cat(q_t, p_t; dims = 1)
 target = cat(dqdt, dpdt; dims = 1)
 B = 256
-NEPOCHS = 100
-dataloader = ncycle(
-    ((selectdim(data, 2, ((i - 1) * B + 1):(min(i * B, size(data, 2)))),
-         selectdim(target, 2, ((i - 1) * B + 1):(min(i * B, size(data, 2)))))
-    for i in 1:(size(data, 2) ÷ B)),
-    NEPOCHS)
+NEPOCHS = 1000
+dataloader = DataLoader((data, target); batchsize = B)
 ```
 
 ### Training the HamiltonianNN
@@ -103,24 +99,25 @@ We parameterize the with a small MultiLayered Perceptron. HNNs are trained by o
 hnn = Layers.HamiltonianNN{true}(Layers.MLP(2, (64, 1)); autodiff = AutoZygote())
 ps, st = Lux.setup(Xoshiro(0), hnn)
 ps_c = ps |> ComponentArray
+hnn_stateful = StatefulLuxLayer{true}(hnn, ps_c, st)
 
 opt = OptimizationOptimisers.Adam(0.01f0)
 
-function loss_function(ps, data, target)
-    pred, st_ = hnn(data, ps, st)
-    return mean(abs2, pred .- target), pred
+function loss_function(ps, databatch)
+    (data, target) = databatch
+    pred = hnn_stateful(data, ps)
+    return mean(abs2, pred .- target)
 end
 
-function callback(ps, loss, pred)
+function callback(state, loss)
     println("[Hamiltonian NN] Loss: ", loss)
     return false
 end
 
-opt_func = OptimizationFunction(
-    (ps, _, data, target) -> loss_function(ps, data, target), Optimization.AutoZygote())
-opt_prob = OptimizationProblem(opt_func, ps_c)
+opt_func = OptimizationFunction(loss_function, Optimization.AutoZygote())
+opt_prob = OptimizationProblem(opt_func, ps_c, dataloader)
 
-res = solve(opt_prob, opt, dataloader; callback)
+res = solve(opt_prob, opt; callback, epochs = NEPOCHS)
 
 ps_trained = res.u
 ```
diff --git a/docs/src/examples/mnist_conv_neural_ode.md b/docs/src/examples/mnist_conv_neural_ode.md
index f67262780..e1f9024b4 100644
--- a/docs/src/examples/mnist_conv_neural_ode.md
+++ b/docs/src/examples/mnist_conv_neural_ode.md
@@ -89,20 +89,20 @@ end
 # burn in accuracy
 accuracy(m, ((img, lab),), ps, st)
 
-function loss_function(ps, x, y)
+function loss_function(ps, data)
+    (x, y) = data
     pred, _ = m(x, ps, st)
-    return logitcrossentropy(pred, y), pred
+    return logitcrossentropy(pred, y)
 end
 
 # burn in loss
-loss_function(ps, img, lab)
+loss_function(ps, (img, lab))
 
 opt = OptimizationOptimisers.Adam(0.005)
 iter = 0
 
-opt_func = OptimizationFunction(
-    (ps, _, x, y) -> loss_function(ps, x, y), Optimization.AutoZygote())
-opt_prob = OptimizationProblem(opt_func, ps);
+opt_func = OptimizationFunction(loss_function, Optimization.AutoZygote())
+opt_prob = OptimizationProblem(opt_func, ps, dataloader);
 
-function callback(ps, l, pred)
+function callback(ps, l)
     global iter += 1
@@ -112,7 +112,7 @@
 end
 
 # Train the NN-ODE and monitor the loss and weights.
-res = Optimization.solve(opt_prob, opt, dataloader; maxiters = 5, callback)
+res = Optimization.solve(opt_prob, opt; maxiters = 5, callback)
 acc = accuracy(m, dataloader, res.u, st)
 acc # hide
 ```
diff --git a/docs/src/examples/mnist_neural_ode.md b/docs/src/examples/mnist_neural_ode.md
index 349dbcff2..27b78f26c 100644
--- a/docs/src/examples/mnist_neural_ode.md
+++ b/docs/src/examples/mnist_neural_ode.md
@@ -81,9 +81,10 @@ end
 
 accuracy(m, ((x_train1, y_train1),), ps, st) # burn in accuracy
 
-function loss_function(ps, x, y)
+function loss_function(ps, data)
+    (x, y) = data
     pred, st_ = m(x, ps, st)
-    return logitcrossentropy(pred, y), pred
+    return logitcrossentropy(pred, y)
 end
 
-loss_function(ps, x_train1, y_train1) # burn in loss
+loss_function(ps, (x_train1, y_train1)) # burn in loss
@@ -91,19 +92,18 @@ loss_function(ps, x_train1, y_train1) # burn in loss
 
 opt = OptimizationOptimisers.Adam(0.05)
 iter = 0
 
-opt_func = OptimizationFunction(
-    (ps, _, x, y) -> loss_function(ps, x, y), Optimization.AutoZygote())
-opt_prob = OptimizationProblem(opt_func, ps)
+opt_func = OptimizationFunction(loss_function, Optimization.AutoZygote())
+opt_prob = OptimizationProblem(opt_func, ps, dataloader)
 
-function callback(ps, l, pred)
+function callback(state, l)
     global iter += 1
     iter % 10 == 0 &&
-        @info "[MNIST GPU] Accuracy: $(accuracy(m, dataloader, ps.u, st))"
+        @info "[MNIST GPU] Accuracy: $(accuracy(m, dataloader, state.u, st))"
     return false
 end
 
 # Train the NN-ODE and monitor the loss and weights.
-res = Optimization.solve(opt_prob, opt, dataloader; callback, maxiters = 5)
+res = Optimization.solve(opt_prob, opt; callback, maxiters = 5)
 accuracy(m, dataloader, res.u, st)
 ```
@@ -285,12 +285,13 @@ final output of our model. `logitcrossentropy` takes in the prediction from our
 model `model(x)` and compares it to actual output `y`:
 
 ```@example mnist
-function loss_function(ps, x, y)
+function loss_function(ps, data)
+    (x, y) = data
     pred, st_ = m(x, ps, st)
-    return logitcrossentropy(pred, y), pred
+    return logitcrossentropy(pred, y)
 end
 
-loss_function(ps, x_train1, y_train1) # burn in loss
+loss_function(ps, (x_train1, y_train1)) # burn in loss
 ```
 
 #### Optimizer
@@ -309,14 +310,13 @@ This callback function is used to print both the training and testing accuracy a
 ```@example mnist
 iter = 0
 
-opt_func = OptimizationFunction(
-    (ps, _, x, y) -> loss_function(ps, x, y), Optimization.AutoZygote())
-opt_prob = OptimizationProblem(opt_func, ps)
+opt_func = OptimizationFunction(loss_function, Optimization.AutoZygote())
+opt_prob = OptimizationProblem(opt_func, ps, dataloader)
 
-function callback(ps, l, pred)
+function callback(state, l)
     global iter += 1
     iter % 10 == 0 &&
-        @info "[MNIST GPU] Accuracy: $(accuracy(m, dataloader, ps.u, st))"
+        @info "[MNIST GPU] Accuracy: $(accuracy(m, dataloader, state.u, st))"
     return false
 end
 ```
@@ -329,6 +329,6 @@ for Neural ODE is given by `nn_ode.p`:
 
 ```@example mnist
 # Train the NN-ODE and monitor the loss and weights.
-res = Optimization.solve(opt_prob, opt, dataloader; callback, maxiters = 5)
+res = Optimization.solve(opt_prob, opt; callback, maxiters = 5)
 accuracy(m, dataloader, res.u, st)
 ```
diff --git a/docs/src/examples/multiple_shooting.md b/docs/src/examples/multiple_shooting.md
index 2b20219f5..b48dd8e8f 100644
--- a/docs/src/examples/multiple_shooting.md
+++ b/docs/src/examples/multiple_shooting.md
@@ -48,6 +48,9 @@ ode_data = Array(solve(prob_trueode, Tsit5(); saveat = tsteps))
 
 nn = Chain(x -> x .^ 3, Dense(2, 16, tanh), Dense(16, 2))
 p_init, st = Lux.setup(rng, nn)
 
+ps = ComponentArray(p_init)
+pd, pax = getdata(ps), getaxes(ps)
+
 neuralode = NeuralODE(nn, tspan, Tsit5(); saveat = tsteps)
 prob_node = ODEProblem((u, p, t) -> nn(u, p, st)[1], u0, tspan, ComponentArray(p_init))
@@ -62,14 +65,13 @@
 end
 
 anim = Plots.Animation()
 iter = 0
-callback = function (p, l, preds; doplot = true)
+function callback(state, l; doplot = true, prob_node = prob_node)
     display(l)
     global iter
     iter += 1
     if doplot && iter % 1 == 0
         # plot the original data
         plt = scatter(tsteps, ode_data[1, :]; label = "Data")
-
         # plot the different predictions for individual shoot
         plot_multiple_shoot(plt, preds, group_size)
@@ -83,23 +85,26 @@ end
 
 group_size = 3
 continuity_term = 200
 
 function loss_function(data, pred)
     return sum(abs2, data - pred)
 end
 
-ps = ComponentArray(p_init)
-pd, pax = getdata(ps), getaxes(ps)
-
+l1, preds = multiple_shoot(ps, ode_data, tsteps, prob_node, loss_function,
+    Tsit5(), group_size; continuity_term)
+
 function loss_multiple_shooting(p)
     ps = ComponentArray(p, pax)
-    return multiple_shoot(ps, ode_data, tsteps, prob_node, loss_function,
+
+    loss, currpred = multiple_shoot(ps, ode_data, tsteps, prob_node, loss_function,
         Tsit5(), group_size; continuity_term)
+    global preds = currpred
+    return loss
 end
 
 adtype = Optimization.AutoZygote()
 optf = Optimization.OptimizationFunction((x, p) -> loss_multiple_shooting(x), adtype)
 optprob = Optimization.OptimizationProblem(optf, pd)
-res_ms = Optimization.solve(optprob, PolyOpt(); callback = callback)
+res_ms = Optimization.solve(optprob, PolyOpt(); callback = callback, maxiters = 5000)
 gif(anim, "multiple_shooting.gif"; fps = 15)
 ```
diff --git a/docs/src/examples/neural_ode.md b/docs/src/examples/neural_ode.md
index c1dda9248..da9174bda 100644
--- a/docs/src/examples/neural_ode.md
+++ b/docs/src/examples/neural_ode.md
@@ -40,15 +40,16 @@ end
 function loss_neuralode(p)
     pred = predict_neuralode(p)
     loss = sum(abs2, ode_data .- pred)
-    return loss, pred
+    return loss
 end
 
 # Do not plot by default for the documentation
 # Users should change doplot=true to see the plots callbacks
-callback = function (p, l, pred; doplot = false)
+function callback(state, l; doplot = false)
     println(l)
     # plot current prediction against data
     if doplot
+        pred = predict_neuralode(state.u)
         plt = scatter(tsteps, ode_data[1, :]; label = "data")
         scatter!(plt, tsteps, pred[1, :]; label = "prediction")
         display(plot(plt))
@@ -57,7 +58,7 @@ callback = function (p, l, pred; doplot = false)
 end
 
 pinit = ComponentArray(p)
-callback(pinit, loss_neuralode(pinit)...; doplot = true)
+callback((; u = pinit), loss_neuralode(pinit); doplot = true)
 
 # use Optimization.jl to solve the problem
 adtype = Optimization.AutoZygote()
@@ -73,7 +74,7 @@ optprob2 = remake(optprob; u0 = result_neuralode.u)
 result_neuralode2 = Optimization.solve(
     optprob2, Optim.BFGS(; initial_stepnorm = 0.01); callback, allow_f_increases = false)
 
-callback(result_neuralode2.u, loss_neuralode(result_neuralode2.u)...; doplot = true)
+callback((; u = result_neuralode2.u), loss_neuralode(result_neuralode2.u); doplot = true)
 ```
 
 ![Neural ODE](https://user-images.githubusercontent.com/1814174/88589293-e8207f80-d026-11ea-86e2-8a3feb8252ca.gif)
@@ -134,7 +135,7 @@ end
 function loss_neuralode(p)
     pred = predict_neuralode(p)
     loss = sum(abs2, ode_data .- pred)
-    return loss, pred
+    return loss
 end
 ```
 
@@ -143,10 +144,11 @@ it would show every step and overflow the documentation, but for your use case s
 
 ```@example neuralode
 # Callback function to observe training
-callback = function (p, l, pred; doplot = false)
+callback = function (state, l; doplot = false)
     println(l)
     # plot current prediction against data
     if doplot
+        pred = predict_neuralode(state.u)
         plt = scatter(tsteps, ode_data[1, :]; label = "data")
         scatter!(plt, tsteps, pred[1, :]; label = "prediction")
         display(plot(plt))
@@ -155,7 +157,7 @@ callback = function (p, l, pred; doplot = false)
 end
 
 pinit = ComponentArray(p)
-callback(pinit, loss_neuralode(pinit)...)
+callback((; u = pinit), loss_neuralode(pinit))
 ```
 
 We then train the neural network to learn the ODE.
@@ -198,8 +200,8 @@ result_neuralode2 = Optimization.solve(optprob2, Optim.BFGS(; initial_stepnorm =
 And then we use the callback with `doplot=true` to see the final plot:
 
 ```@example neuralode
-callback(result_neuralode2.u, loss_neuralode(result_neuralode2.u)...; doplot = true)
+callback((; u = result_neuralode2.u), loss_neuralode(result_neuralode2.u); doplot = true)
 plt = scatter(tsteps, ode_data[1, :]; label = "data") # hide
-scatter!(plt, tsteps, loss_neuralode(result_neuralode2.u)[2][1, :]; label = "prediction") # hide
+scatter!(plt, tsteps, predict_neuralode(result_neuralode2.u)[1, :]; label = "prediction") # hide
 plt # hide
 ```
diff --git a/docs/src/examples/neural_ode_weather_forecast.md b/docs/src/examples/neural_ode_weather_forecast.md
index 71553a504..8f69acfe5 100644
--- a/docs/src/examples/neural_ode_weather_forecast.md
+++ b/docs/src/examples/neural_ode_weather_forecast.md
@@ -122,8 +122,8 @@ function train_one_round(node, p, state, y, opt, maxiters, rng, y0 = y[:, 1]; kw
 end
 
 function train(t, y, obs_grid, maxiters, lr, rng, p = nothing, state = nothing; kwargs...)
-    log_results(ps, losses) = (p, loss) -> begin
-        push!(ps, copy(p.u))
+    log_results(ps, losses) = (state, loss) -> begin
+        push!(ps, copy(state.u))
         push!(losses, loss)
         false
     end
diff --git a/docs/src/examples/neural_sde.md b/docs/src/examples/neural_sde.md
index 88eddd091..8d895ed16 100644
--- a/docs/src/examples/neural_sde.md
+++ b/docs/src/examples/neural_sde.md
@@ -119,21 +119,28 @@ end
 function loss_neuralsde(p; n = 100)
     u = repeat(reshape(u0, :, 1), 1, n)
     samples = predict_neuralsde(p, u)
-    means = mean(samples; dims = 2)
-    vars = var(samples; dims = 2, mean = means)[:, 1, :]
-    means = means[:, 1, :]
-    loss = sum(abs2, sde_data - means) + sum(abs2, sde_data_vars - vars)
-    return loss, means, vars
+    currmeans = mean(samples; dims = 2)
+    currvars = var(samples; dims = 2, mean = currmeans)[:, 1, :]
+    currmeans = currmeans[:, 1, :]
+    loss = sum(abs2, sde_data - currmeans) + sum(abs2, sde_data_vars - currvars)
+    global means = currmeans
+    global vars = currvars
+    return loss
 end
 ```
 
 ```@example nsde
 list_plots = []
 iter = 0
+u = repeat(reshape(u0, :, 1), 1, 100)
+samples = predict_neuralsde(ps, u)
+means = mean(samples; dims = 2)
+vars = var(samples; dims = 2, mean = means)[:, 1, :]
+means = means[:, 1, :]
 
 # Callback function to observe training
-callback = function (p, loss, means, vars; doplot = false)
-    global list_plots, iter
+callback = function (state, loss; doplot = false)
+    global list_plots, iter, means, vars
 
     if iter == 0
         list_plots = []
diff --git a/docs/src/examples/normalizing_flows.md b/docs/src/examples/normalizing_flows.md
index 6d4898ae3..aa88af6cd 100644
--- a/docs/src/examples/normalizing_flows.md
+++ b/docs/src/examples/normalizing_flows.md
@@ -28,7 +28,7 @@ function loss(θ)
     return -mean(logpx)
 end
 
-function cb(p, l)
+function cb(state, l)
     @info "FFJORD Training" loss=l
     return false
 end
@@ -95,7 +95,7 @@ function loss(θ)
     return -mean(logpx)
 end
 
-function cb(p, l)
-    @info "FFJORD Training" loss=loss(p)
+function cb(state, l)
+    @info "FFJORD Training" loss=loss(state.u)
     return false
 end
diff --git a/docs/src/examples/physical_constraints.md b/docs/src/examples/physical_constraints.md
index 723344619..d3e0832bc 100644
--- a/docs/src/examples/physical_constraints.md
+++ b/docs/src/examples/physical_constraints.md
@@ -191,7 +191,7 @@ The optimizer is `BFGS`(see below).
 The callback function displays the loss during training.
 
 ```@example dae2
-callback = function (state, l, pred) #callback function to observe training
+callback = function (state, l) #callback function to observe training
     display(l)
     return false
 end
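All of the hunks above follow the same Optimization.jl v4 migration: the data iterator moves into `OptimizationProblem`, the objective receives `(parameters, batch)`, the callback receives `(state, loss)` with the current parameters in `state.u`, and `epochs` replaces manually cycling the data with `IterTools.ncycle`. A minimal, self-contained sketch of that pattern follows; the toy linear model, data, and hyperparameters are illustrative assumptions, not code from this PR.

```julia
# Sketch of the Optimization.jl v4 training pattern targeted by this PR.
# The model, data, and hyperparameters are illustrative only.
using Optimization, OptimizationOptimisers, MLUtils, Zygote, Random

rng = Xoshiro(0)
x = rand(rng, Float32, 2, 512)
y = sum(x; dims = 1)                           # toy regression targets (1×512)
dataloader = DataLoader((x, y); batchsize = 64)

θ0 = zeros(Float32, 2)                         # parameters being optimized

function loss(θ, batch)                        # objective: (parameters, batch) -> scalar
    xb, yb = batch
    return sum(abs2, θ' * xb .- yb)
end

callback = function (state, l)                 # v4 callback: (state, loss); state.u holds parameters
    @info "training" loss=l
    return false                               # returning true stops the solve early
end

optf = OptimizationFunction(loss, Optimization.AutoZygote())
prob = OptimizationProblem(optf, θ0, dataloader)   # the data iterator lives in the problem
res = solve(prob, OptimizationOptimisers.Adam(0.01f0); callback, epochs = 10)
res.u                                          # trained parameters
```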