I am building a universal differential equation model. It is supposed to train on multiple input conditions at the same time. I could obviously iterate through the different conditions in my loss function, add the losses up, and compute the gradient of that total loss. But since computing that gradient is rather time-consuming (stiff ODE), I would like to compute a gradient for each set instead, then take the mean gradient and apply it to my parameters.
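In other words, writing $L_i$ for the loss on condition $i$ (out of $N$ conditions), I want the parameter update to use the averaged per-condition gradient,

$$\theta \leftarrow \theta - \eta \, \frac{1}{N} \sum_{i=1}^{N} \nabla_\theta L_i(\theta),$$

which equals the gradient of the mean loss, just computed one condition at a time.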
Since the whole NODE part seems irrelevant for this problem, I have constructed an MWE that computes the gradients of individual batches.
using Lux
using ComponentArrays
using Random
using Zygote
using LinearAlgebra
using Optimization, OptimizationOptimisers

rng = Random.default_rng()
Random.seed!(rng, 42)

nx = 5  # input dimension
ny = 2  # output dimension

model = Lux.Chain(
    Lux.Dense(nx, 10, Lux.tanh),
    Lux.Dense(10, ny))
ps, st = Lux.setup(rng, model)
ps = ComponentArray(ps)

n_batches = 10
per_batch = 50
X_batches = rand(rng, n_batches, per_batch, nx)
Y_batches = rand(rng, n_batches, per_batch, ny)

function batch_loss(p, x, y)
    # apply the model to each sample (row of x) and stack the outputs column-wise
    y_pred = hcat([Lux.apply(model, e, p, st)[1] for e in eachrow(x)]...)
    loss = sum(abs.(y_pred' .- y))
    return loss
end

function single_gradient(p, x, y)
    f(p) = batch_loss(p, x, y)
    return Zygote.gradient(f, p)
end

function gather_gradients(p)
    gradients = []
    for i in 1:n_batches
        push!(gradients, single_gradient(p, X_batches[i, :, :], Y_batches[i, :, :]))
    end
    # return mean gradient of all batches
end

opt = Adam(0.01)
Could someone help me finish this example on how to efficiently compute the mean gradient? And how would I pass this gradient to an optimizer such as Adam?
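For reference, here is my rough attempt at the missing pieces (replacing gather_gradients above), based on what I found in the Optimisers.jl documentation (Optimisers.setup and Optimisers.update). I am assuming that the per-batch gradients (ComponentArrays here) can simply be summed elementwise and that Adam can be driven manually like this, but I am not sure it is idiomatic:

using Optimisers

function mean_gradient(p)
    # Zygote.gradient returns a tuple, so [1] extracts the gradient itself
    grads = [single_gradient(p, X_batches[i, :, :], Y_batches[i, :, :])[1] for i in 1:n_batches]
    return sum(grads) / n_batches  # elementwise mean over the per-batch gradients
end

function train(ps, opt_state, n_epochs)
    for epoch in 1:n_epochs
        g = mean_gradient(ps)
        # Optimisers.update returns the new optimizer state and the updated parameters
        opt_state, ps = Optimisers.update(opt_state, ps, g)
    end
    return ps, opt_state
end

opt_state = Optimisers.setup(opt, ps)
ps, opt_state = train(ps, opt_state, 100)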
I am still pretty new to Julia, so any advice on how to optimise my approach and code is very welcome.
Thank you!