Training in batches and building gradient as mean of individual gradients #481

Closed
mayor-slash opened this issue Jan 26, 2024 · 1 comment

@mayor-slash

I am building a universal differential equation (UDE) model that is supposed to train on multiple input conditions at the same time. I could obviously iterate through the different conditions in my loss function, add them all up, and compute the gradient of that sum. But since computing that gradient is rather time-consuming (stiff ODE), I would like to compute a gradient for each set individually, then take the mean gradient and apply it to my parameters.
Since the whole NODE part seems irrelevant to this problem, I have constructed an MWE that computes the gradients of the individual batches.

using Lux
using ComponentArrays
using Random
using Zygote
using LinearAlgebra
using Optimization, OptimizationOptimisers

rng = Random.default_rng()
Random.seed!(rng, 42)

nx = 5
ny = 2
model = Lux.Chain(
    Lux.Dense(nx,10,Lux.tanh),
    Lux.Dense(10,ny))
ps,st = Lux.setup(rng,model)
ps = ComponentArray(ps)

n_batches = 10
per_batch = 50
X_batches = rand(rng,n_batches,per_batch,nx)
Y_batches = rand(rng,n_batches,per_batch,ny)

function batch_loss(p, x, y)
    # Run the model on each sample (row of x); hcat gives an ny × per_batch matrix
    y_pred = reduce(hcat, [Lux.apply(model, e, p, st)[1] for e in eachrow(x)])
    # L1 loss against the targets
    return sum(abs.(y_pred' .- y))
end

function single_gradient(p, x, y)
    f(p) = batch_loss(p, x, y)
    # Zygote.gradient returns a tuple; take its first element (a ComponentArray)
    return Zygote.gradient(f, p)[1]
end

function gather_gradients(p)
    gradients = []
    for i in 1:n_batches
        push!(gradients, single_gradient(p,X_batches[i,:,:],Y_batches[i,:,:]))
    end
    # return mean gradient of all batches
end

opt = Adam(0.01)

Could someone help me finish this example and show how to efficiently compute the mean gradient? And how would I pass that gradient to an optimizer such as Adam?

I am still pretty new to Julia, so any advice on how to optimise my approach and code is very welcome.
Thank you!

@avik-pal
Member

avik-pal commented Feb 5, 2024

The gradient should be a ComponentArray, so you can just take the mean of the gradients array. See https://lux.csail.mit.edu/dev/tutorials/beginner/2_PolynomialFitting#training for training.
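
For completeness, here is a minimal sketch of how that could look, assuming single_gradient returns the ComponentArray itself (the first element of the tuple from Zygote.gradient) and using the Optimisers.jl setup/update API that OptimizationOptimisers builds on; the epoch count of 100 is an arbitrary placeholder:

using Statistics   # for mean
using Optimisers   # setup/update API underlying OptimizationOptimisers

function gather_gradients(p)
    gradients = [single_gradient(p, X_batches[i, :, :], Y_batches[i, :, :])
                 for i in 1:n_batches]
    # ComponentArrays support elementwise arithmetic, so mean averages them directly
    return mean(gradients)
end

opt_state = Optimisers.setup(opt, ps)   # opt = Adam(0.01) from the MWE
for epoch in 1:100
    grad = gather_gradients(ps)
    # update returns the new optimiser state and the updated parameters
    global opt_state, ps = Optimisers.update(opt_state, ps, grad)
end

Wrapping the loop in a function would avoid the global annotations and is generally faster in Julia.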

@avik-pal avik-pal closed this as completed Feb 5, 2024