Skip to content

Commit

Permalink
more callback args fixes
Browse files Browse the repository at this point in the history
  • Loading branch information
Vaibhavdixit02 committed Oct 28, 2024
1 parent b256e1e commit 70c4fb3
Show file tree
Hide file tree
Showing 2 changed files with 7 additions and 5 deletions.
4 changes: 2 additions & 2 deletions docs/src/examples/mnist_conv_neural_ode.md
Original file line number Diff line number Diff line change
Expand Up @@ -104,10 +104,10 @@ iter = 0
opt_func = OptimizationFunction(loss_function, Optimization.AutoZygote())
opt_prob = OptimizationProblem(opt_func, ps, dataloader);
function callback(ps, l, pred)
function callback(state, l)
global iter += 1
iter % 10 == 0 &&
@info "[MNIST Conv GPU] Accuracy: $(accuracy(m, dataloader, ps.u, st))"
@info "[MNIST Conv GPU] Accuracy: $(accuracy(m, dataloader, state.u, st))"
return false
end
Expand Down
8 changes: 5 additions & 3 deletions docs/src/examples/multiple_shooting.md
Original file line number Diff line number Diff line change
Expand Up @@ -86,7 +86,7 @@ group_size = 3
continuity_term = 200
l1, preds = multiple_shoot(ps, ode_data, tsteps, prob_node, loss_function,
Tsit5(), group_size; continuity_term)
Tsit5(), group_size; continuity_term)
function loss_function(data, pred)
return sum(abs2, data - pred)
Expand Down Expand Up @@ -124,14 +124,16 @@ pd, pax = getdata(ps), getaxes(ps)
function loss_single_shooting(p)
ps = ComponentArray(p, pax)
return multiple_shoot(ps, ode_data, tsteps, prob_node, loss_function,
loss, currpred = multiple_shoot(ps, ode_data, tsteps, prob_node, loss_function,
Tsit5(), group_size; continuity_term)
global preds = currpred
return loss
end
adtype = Optimization.AutoZygote()
optf = Optimization.OptimizationFunction((x, p) -> loss_single_shooting(x), adtype)
optprob = Optimization.OptimizationProblem(optf, pd)
res_ms = Optimization.solve(optprob, PolyOpt(); callback = callback)
res_ms = Optimization.solve(optprob, PolyOpt(); callback = callback, maxiters = 5000)
gif(anim, "single_shooting.gif"; fps = 15)
```

Expand Down

0 comments on commit 70c4fb3

Please sign in to comment.