From e4bd1af6146cacb37b32f4feffdec25d40d2b33e Mon Sep 17 00:00:00 2001 From: Avik Pal Date: Thu, 26 Sep 2024 03:38:26 -0400 Subject: [PATCH] fix: patch optimization tutorial --- examples/OptimizationIntegration/main.jl | 6 ++---- 1 file changed, 2 insertions(+), 4 deletions(-) diff --git a/examples/OptimizationIntegration/main.jl b/examples/OptimizationIntegration/main.jl index 7c617348a..3bc227963 100644 --- a/examples/OptimizationIntegration/main.jl +++ b/examples/OptimizationIntegration/main.jl @@ -117,15 +117,13 @@ function train_model(dataloader) res_adam = solve(opt_prob, Optimisers.Adam(0.001); callback, maxiters=epochs) ## Let's finetune a bit with L-BFGS - opt_prob = remake(opt_prob; u0=res_adam.u) + opt_prob = OptimizationProblem(opt_func, res_adam.u, (gdev(ode_data), TimeWrapper(t))) res_lbfgs = solve(opt_prob, LBFGS(); callback, maxiters=epochs) ## Now that we have a good fit, let's train it on the entire dataset without ## Minibatching. We need to do this since ODE solves can lead to accumulated errors if ## the model was trained on individual parts (without a data-shooting approach). - opt_func = OptimizationFunction(loss_adjoint, Optimization.AutoZygote()) - opt_prob = OptimizationProblem(opt_func, res_lbfgs.u, (gdev(ode_data), TimeWrapper(t))) - + opt_prob = remake(opt_prob; u0=res_lbfgs.u) res = solve(opt_prob, Optimisers.Adam(0.005); maxiters=500, callback) return StatefulLuxLayer{true}(model, res.u, smodel.st)