# Update docs remove extra returns from loss and extra args from callback #249

Merged: 7 commits, Oct 29, 2024
Changes from 2 commits
docs/src/getting_started/fit_simulation.md (12 additions & 17 deletions)
````diff
@@ -99,12 +99,14 @@ function loss(newp)
     newprob = remake(prob, p = newp)
     sol = solve(newprob, saveat = 1)
     loss = sum(abs2, sol .- xy_data)
-    return loss, sol
+    return loss
 end

 # Define a callback function to monitor optimization progress
-function callback(p, l, sol)
+function callback(state, l)
     display(l)
+    newprob = remake(prob, p = state.u)
+    sol = solve(newprob, saveat = 1)
````

> **Member** commented on lines +108 to +109: Don't recompute

````diff
     plt = plot(sol, ylim = (0, 6), label = ["Current x Prediction" "Current y Prediction"])
     scatter!(plt, t_data, xy_data', label = ["x Data" "y Data"])
     display(plt)
````
````diff
@@ -278,37 +280,28 @@ function loss(newp)
     newprob = remake(prob, p = newp)
     sol = solve(newprob, saveat = 1)
     l = sum(abs2, sol .- xy_data)
-    return l, sol
+    return l
 end
 ```

-Notice that our loss function returns the loss value as the first return,
-but returns extra information (the ODE solution with the new parameters)
-as an extra return argument.
-We will explain why this extra return information is helpful in the next section.
-
 ### Step 5: Solve the Optimization Problem

 This step will look very similar to [the first optimization tutorial](@ref first_opt),
-except now we have a new loss function `loss` which returns both the loss value
-and the associated ODE solution.
-(In the previous tutorial, `L` only returned the loss value.)
+except now we have a new loss function `loss` which returns only the loss value.
 The `Optimization.solve` function can accept an optional callback function
-to monitor the optimization process using extra arguments returned from `loss`.
+to monitor the optimization process.

 The callback syntax is always:

 ```
 callback(
-    optimization variables,
+    state,
     the current loss value,
-    other arguments returned from the loss function, ...
 )
 ```

-In this case, we will provide the callback the arguments `(p, l, sol)`,
-since it always takes the current state of the optimization first (`p`)
-then the returns from the loss function (`l, sol`).
+In this case, we will provide the callback the arguments `(state, l)`,
+since it always takes the current state of the optimization first (`state`)
+then the current loss value (`l`).
 The return value of the callback function should default to `false`.
 `Optimization.solve` will halt if/when the callback function returns `true` instead.
 Typically the `return` statement would monitor the loss value
````
````diff
@@ -318,8 +311,10 @@
 More details about callbacks in Optimization.jl can be found
 [here](https://docs.sciml.ai/Optimization/stable/API/solve/).

 ```@example odefit
-function callback(p, l, sol)
+function callback(state, l)
     display(l)
+    newprob = remake(prob, p = state.u)
+    sol = solve(newprob, saveat = 1)
````

> **Member** commented on lines +316 to +317: don't recompute

````diff
     plt = plot(sol, ylim = (0, 6), label = ["Current x Prediction" "Current y Prediction"])
     scatter!(plt, t_data, xy_data', label = ["x Data" "y Data"])
     display(plt)
````
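To see the new two-argument contract end to end, here is a minimal self-contained sketch of the documented callback behavior. The toy objective, the `BFGS` solver from OptimizationOptimJL, and the stopping threshold are our own illustrative choices, not part of this PR:

```julia
using Optimization, OptimizationOptimJL, ForwardDiff

# Toy objective, just to exercise the callback contract.
rosenbrock(u, p) = (p[1] - u[1])^2 + p[2] * (u[2] - u[1]^2)^2
optf = OptimizationFunction(rosenbrock, Optimization.AutoForwardDiff())
prob = OptimizationProblem(optf, zeros(2), [1.0, 100.0])

# New-style callback: the optimization state comes first (current
# iterate in `state.u`), then the current loss `l`. Returning `false`
# lets the solve continue; returning `true` halts it early.
callback = function (state, l)
    println("loss = $l at u = $(state.u)")
    return l < 1e-8  # assumed early-stopping threshold for this sketch
end

sol = solve(prob, BFGS(); callback = callback)
```

Under this contract the callback no longer receives the ODE solution; anything it needs beyond the loss must be rebuilt from `state.u`, which is why the updated tutorial code above calls `remake` and `solve` inside the callback.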
docs/src/showcase/blackhole.md (4 additions & 5 deletions)
````diff
@@ -490,10 +490,8 @@ function loss(NN_params)
     prob_nn, RK4(), u0 = u0, p = NN_params, saveat = tsteps, dt = dt, adaptive = false))
     pred_waveform = compute_waveform(dt_data, pred, mass_ratio, model_params)[1]
-    loss = (sum(abs2,
-        view(waveform, obs_to_use_for_training) .-
-        view(pred_waveform, obs_to_use_for_training)))
-    return loss, pred_waveform
+    loss = (sum(abs2, view(waveform, obs_to_use_for_training) .- view(pred_waveform, obs_to_use_for_training)))
+    return loss
 end
 ```
````

````diff
@@ -508,10 +506,11 @@
 We'll use the following callback to save the history of the loss values.

 ```@example ude
 losses = []
-callback(θ, l, pred_waveform; doplot = true) = begin
+callback(state, l; doplot = true) = begin
     push!(losses, l)
     #= Disable plotting as it trains since in docs
     display(l)
+    waveform = compute_waveform(dt_data, soln, mass_ratio, model_params)[1]
````

> **Member** commented on the added line: don't recompute

````diff
     # plot current prediction against data
     plt = plot(tsteps, waveform,
         markershape=:circle, markeralpha = 0.25,
````
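The "don't recompute" review comments point at the tradeoff of the two-argument callback: intermediate results computed in the loss are no longer passed through, so a callback must either redo the work (as the updated tutorial code does) or read it from a cache. One possible caching pattern, sketched here on a stand-in problem with the derivative-free `NelderMead` solver (our illustration, not what this PR merged):

```julia
using Optimization, OptimizationOptimJL

target = [1.0, 2.0, 3.0]
latest_residual = Ref(zeros(length(target)))

# The loss caches its intermediate result instead of returning it.
function loss(u, p)
    residual = u .- target          # stands in for an expensive solve
    latest_residual[] = residual    # cache for the callback
    return sum(abs2, residual)
end

# The callback reuses the cached value rather than recomputing it.
callback = function (state, l)
    println("loss = $l, residual norm = $(sum(abs2, latest_residual[]))")
    return false
end

optf = OptimizationFunction(loss)
prob = OptimizationProblem(optf, zeros(3))
sol = solve(prob, NelderMead(); callback = callback)
```

One caveat with this pattern: the cached value reflects the most recent loss evaluation, which for some optimizers is not exactly the iterate reported in `state.u`.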
docs/src/showcase/missing_physics.md (1 addition & 1 deletion)
````diff
@@ -222,7 +222,7 @@ current loss:
 ```@example ude
 losses = Float64[]
-callback = function (p, l)
+callback = function (state, l)
     push!(losses, l)
     if length(losses) % 50 == 0
         println("Current loss after $(length(losses)) iterations: $(losses[end])")
````
docs/src/showcase/pinngpu.md (1 addition & 1 deletion)
````diff
@@ -148,7 +148,7 @@ prob = discretize(pde_system, discretization)
 ## Step 6: Solve the Optimization Problem

 ```@example pinn
-callback = function (p, l)
+callback = function (state, l)
     println("Current loss is: $l")
     return false
 end
````
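For context, a callback like this is wired in through the `callback` keyword of `Optimization.solve`. A stand-in sketch follows; the quadratic objective, the `Adam` optimizer from OptimizationOptimisers, and the iteration count are our assumptions, and the real tutorial passes its discretized PINN problem instead:

```julia
using Optimization, OptimizationOptimisers, ForwardDiff

# Stand-in problem; the tutorial's discretized PINN `prob` goes here.
optf = OptimizationFunction((u, p) -> sum(abs2, u), Optimization.AutoForwardDiff())
prob_demo = OptimizationProblem(optf, [1.0, -2.0])

callback = function (state, l)
    println("Current loss is: $l")
    return false
end

res = Optimization.solve(prob_demo, Adam(0.01); callback = callback, maxiters = 200)
```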