From 70c4fb387f59806843558114f6abf421d2faa1f7 Mon Sep 17 00:00:00 2001 From: Vaibhav Dixit Date: Mon, 28 Oct 2024 12:07:40 -0400 Subject: [PATCH] more callback args fixes --- docs/src/examples/mnist_conv_neural_ode.md | 4 ++-- docs/src/examples/multiple_shooting.md | 8 +++++--- 2 files changed, 7 insertions(+), 5 deletions(-) diff --git a/docs/src/examples/mnist_conv_neural_ode.md b/docs/src/examples/mnist_conv_neural_ode.md index e1f9024b4..bff4ecaf9 100644 --- a/docs/src/examples/mnist_conv_neural_ode.md +++ b/docs/src/examples/mnist_conv_neural_ode.md @@ -104,10 +104,10 @@ iter = 0 opt_func = OptimizationFunction(loss_function, Optimization.AutoZygote()) opt_prob = OptimizationProblem(opt_func, ps, dataloader); -function callback(ps, l, pred) +function callback(state, l) global iter += 1 iter % 10 == 0 && - @info "[MNIST Conv GPU] Accuracy: $(accuracy(m, dataloader, ps.u, st))" + @info "[MNIST Conv GPU] Accuracy: $(accuracy(m, dataloader, state.u, st))" return false end diff --git a/docs/src/examples/multiple_shooting.md b/docs/src/examples/multiple_shooting.md index 2a6d3ca70..c6ca773ab 100644 --- a/docs/src/examples/multiple_shooting.md +++ b/docs/src/examples/multiple_shooting.md @@ -86,7 +86,7 @@ group_size = 3 continuity_term = 200 l1, preds = multiple_shoot(ps, ode_data, tsteps, prob_node, loss_function, - Tsit5(), group_size; continuity_term) + Tsit5(), group_size; continuity_term) function loss_function(data, pred) return sum(abs2, data - pred) @@ -124,14 +124,16 @@ pd, pax = getdata(ps), getaxes(ps) function loss_single_shooting(p) ps = ComponentArray(p, pax) - return multiple_shoot(ps, ode_data, tsteps, prob_node, loss_function, + loss, currpred = multiple_shoot(ps, ode_data, tsteps, prob_node, loss_function, Tsit5(), group_size; continuity_term) + global preds = currpred + return loss end adtype = Optimization.AutoZygote() optf = Optimization.OptimizationFunction((x, p) -> loss_single_shooting(x), adtype) optprob = Optimization.OptimizationProblem(optf, pd) -res_ms = Optimization.solve(optprob, PolyOpt(); callback = callback) +res_ms = Optimization.solve(optprob, PolyOpt(); callback = callback, maxiters = 5000) gif(anim, "single_shooting.gif"; fps = 15) ```