Improving Zygote autodiff performance #291
Trying the example in the issue I get this error and I can't find out why:
|
At least now trying to call |
I finally got there for
using DataInterpolations
using Zygote
using Plots
using Random
using ColorSchemes

Random.seed!(3)
pl = plot()

# Ground truth: a linear interpolation through n random values
n = 10
t = collect(1.0:10)
u_exact = rand(n)
A_exact = LinearInterpolation(u_exact, t)

# Noisy samples of the exact interpolation
t_data = 1 .+ 9rand(10n)
u_data = A_exact.(t_data) .+ (rand(10n) .- 0.5) / 8
scatter!(t_data, u_data; label = "perturbed data")

# Initial guess for the values to fit
u_fit = rand(n)

# Sum of squared residuals between the interpolation and the data
function loss(u_fit)
    A = LinearInterpolation(u_fit, t)
    values = A.(t_data)
    return sum((values - u_data) .^ 2)
end

# Plain gradient descent, plotting the fit after every step
lr = 1e-2
N = 10
for (i, color) in enumerate(cgrad(:jet, range(0, 1, length = N)))
    ∇loss = only(gradient(loss, u_fit))
    u_fit .-= lr * ∇loss
    loss_it = loss(u_fit)
    plot!(LinearInterpolation(u_fit, t); color, label = "Iteration $i, loss = $(round(loss_it, digits = 3))")
end
pl
@ChrisRackauckas @marcobonici My findings so far: Trying to calculate the gradient w.r.t. Apart from some minor refactors, the main thing I had to do to make the code compatible with |
Also got it to work for
using DataInterpolations
using Zygote
using ForwardDiff
using Plots
using Random
using ColorSchemes

Random.seed!(3)
pl = plot(legendfontsize = 5)

# Interpolation type under test
method = QuadraticSpline

# Ground truth: an interpolation through n random values
n = 10
t = collect(1.0:10)
u_exact = rand(n)
A_exact = method(u_exact, t)

# Noisy samples of the exact interpolation
t_data = 1 .+ 9rand(10n)
u_data = A_exact.(t_data) .+ (rand(10n) .- 0.5) / 10
scatter!(t_data, u_data; label = "perturbed data")

# Initial guess for the values to fit
u_fit = rand(n)

# Sum of squared residuals between the interpolation and the data
function loss(u_fit)
    A = method(u_fit, t)
    values = A.(t_data)
    return sum((values - u_data) .^ 2)
end

# Gradient descent; every iteration checks the Zygote gradient against ForwardDiff
lr = 1e-3
N = 200
plot_update = 10
for (i, color) in enumerate(cgrad(:jet, range(0, 1, length = N)))
    ∇loss = only(Zygote.gradient(loss, u_fit))
    ∇loss_fd = ForwardDiff.gradient(loss, u_fit)
    u_fit .-= lr * ∇loss
    loss_it = loss(u_fit)
    if i % plot_update == 1
        plot!(
            method(u_fit, t);
            color,
            label = "Iteration $i, loss = $(round(loss_it, digits = 3))",
        )
    end
    @assert ∇loss ≈ ∇loss_fd
end
pl
@marcobonici this is already an order of magnitude better than what you reported for the gradient in your example:
However, it leaves much to be desired. Most time is spent on calculating the cached parameters in the constructor, which it seems |
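One rough way to see how the cost splits between building the interpolation and evaluating it is to time the two steps separately. The sketch below is an illustration only, not the benchmark used in this PR: the sizes are arbitrary, it times the plain forward pass rather than the Zygote gradient, and single `@elapsed` measurements are noisy.

```julia
using DataInterpolations

# Same setup as the snippets above, with arbitrary (assumed) sizes
t = collect(1.0:10)
u = rand(10)
t_data = 1 .+ 9rand(100)

# Run both steps once first so compilation time is excluded from the timings
A = QuadraticSpline(u, t)
A.(t_data)

# Time construction (which computes the cached parameters) and evaluation separately
t_construct = @elapsed QuadraticSpline(u, t)
t_eval = @elapsed A.(t_data)

println("construction: ", t_construct, " s, evaluation: ", t_eval, " s")
```

For numbers worth comparing, repeating each measurement many times (or using a benchmarking package) would be more reliable than a single run.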
Have some more speedup:
|
Closed in favor of #315.
Fixes #289.