From acd0f02c908cbb48ea093715bf640ea35d523f01 Mon Sep 17 00:00:00 2001 From: Vaibhav Kumar Dixit Date: Wed, 27 Dec 2023 15:55:53 -0500 Subject: [PATCH] Update solve.jl --- src/solve.jl | 14 +++++++------- 1 file changed, 7 insertions(+), 7 deletions(-) diff --git a/src/solve.jl b/src/solve.jl index f9c98c32f..a725f646b 100644 --- a/src/solve.jl +++ b/src/solve.jl @@ -51,19 +51,19 @@ The callback function `callback` is a function which is called after every optim step. Its signature is: ```julia -callback = (u, loss_val, other_args) -> false +callback = (u, loss_val, other_args...) -> false ``` where `u` and `loss_val` are the current optimization variables and loss/objective value in the optimization loop and `other_args` can be the extra things returned from -the optimization `f`. This allows for saving values from the optimization and -using them for plotting and display without recalculating. The callback should -return a Boolean value, and the default should be `false`, such that the optimization -gets stopped if it returns `true`. +the optimization `f` or the state of the optimizer and iteration count. This allows +for saving values from the optimization and using them for plotting and display without +recalculating. The callback should return a Boolean value, and the default should be `false`, +such that the optimization gets stopped if it returns `true`. ### Callback Example -Here we show an example a callback function that plots the prediction at the current value of the optimization variables. +Here we show an example of a callback function that plots the prediction at the current value of the optimization variables. The loss function here returns the loss and the prediction i.e. the solution of the `ODEProblem` `prob`, so we can use the prediction in the callback. ```julia @@ -76,7 +76,7 @@ function loss(u, p) sum(abs2, batch .- pred), pred end -callback = function (p, l, pred; doplot = false) #callback function to observe training +callback = function (p, l, pred, args...; doplot = false) #callback function to observe training display(l) # plot current prediction against data if doplot