diff --git a/src/solve.jl b/src/solve.jl index f9c98c32f..0f3660bdc 100644 --- a/src/solve.jl +++ b/src/solve.jl @@ -51,15 +51,17 @@ The callback function `callback` is a function which is called after every optim step. Its signature is: ```julia -callback = (u, loss_val, other_args) -> false +callback = (state, loss_val, other_args) -> false ``` -where `u` and `loss_val` are the current optimization variables and loss/objective value -in the optimization loop and `other_args` can be the extra things returned from -the optimization `f`. This allows for saving values from the optimization and +where `state` is a `OptimizationState` and stores information for the current +iteration of the solver and `loss_val` is loss/objective value. For more +information about the fields of the `state` look at the `OptimizationState` +documentation. The `other_args` can be the extra things returned from the +optimization `f`. This allows for saving values from the optimization and using them for plotting and display without recalculating. The callback should -return a Boolean value, and the default should be `false`, such that the optimization -gets stopped if it returns `true`. +return a Boolean value, and the default should be `false`, such that the +optimization gets stopped if it returns `true`. ### Callback Example @@ -76,7 +78,7 @@ function loss(u, p) sum(abs2, batch .- pred), pred end -callback = function (p, l, pred; doplot = false) #callback function to observe training +callback = function (state, l, pred; doplot = false) #callback function to observe training display(l) # plot current prediction against data if doplot