
How to build parameter optimization problems for Lux.jl using ModelingToolkit.jl? #2604

Closed
chooron opened this issue Apr 3, 2024 · 2 comments
Labels
question Further information is requested

Comments


chooron commented Apr 3, 2024

Question❓
I have built a simple neural network model using Lux.jl, as shown below:

using ModelingToolkit
using Optimization, OptimizationOptimisers
using Lux
using Random

model = Lux.Chain(Lux.Dense(3, 16, tanh), Lux.Dense(16, 16, leakyrelu), Lux.Dense(16, 1, leakyrelu), name=:model)
rng = MersenneTwister()
Random.seed!(rng, 42)
ps, st = Lux.setup(rng, model)
time = 1:100
x1 = sin.(time)
x2 = cos.(time)
x3 = tan.(time)
x = hcat(x1, x2, x3)'
y = @.(2 * x1 + 3 * x2 - 5 * x3) + rand(100)
y_hat = vec(model(x, ps, st)[1])

Then I would like to translate the problem of optimizing model parameters into ModelingToolkit.jl, as follows:

@variables layer1_w[1:16, 1:3], layer1_b[1:16, 1:1]
@variables layer2_w[1:16, 1:16], layer2_b[1:16, 1:1]
@variables layer3_w[1:1, 1:16], layer3_b[1:1, 1:1]

function loss_func(p, _)
    y_hat = vec(model(x, p, st)[1])
    sum(abs.(y .- y_hat))
end

sys = OptimizationSystem(loss_func, [layer1_w...; layer1_b...; layer2_w...; layer2_b...; layer3_w...; layer3_b...], [], name=:sys)
sys = complete(sys)
u0 = [
    layer1_w => ps[:layer_1][:weight],
    layer1_b => ps[:layer_1][:bias],
    layer2_w => ps[:layer_2][:weight],
    layer2_b => ps[:layer_2][:bias],
    layer3_w => ps[:layer_3][:weight],
    layer3_b => ps[:layer_3][:bias],
]
prob = OptimizationProblem(sys, u0)
solve(prob, GradientDescent())

But this approach didn't work. Are there any examples of ModelingToolkit.jl being applied to Lux.jl?
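For comparison, a formulation that sidesteps ModelingToolkit and optimizes the Lux parameters directly with Optimization.jl does work for this kind of fitting problem. The following is a minimal sketch, assuming ComponentArrays.jl and Zygote.jl are available; the variable names (`ps_ca`, `optf`, etc.) are illustrative, not from this thread:

```julia
using Lux, Optimization, OptimizationOptimisers, ComponentArrays, Random, Zygote

model = Lux.Chain(Lux.Dense(3, 16, tanh), Lux.Dense(16, 16, leakyrelu), Lux.Dense(16, 1, leakyrelu))
rng = MersenneTwister()
Random.seed!(rng, 42)
ps, st = Lux.setup(rng, model)

# Flatten the parameter NamedTuple into a vector that Lux still accepts directly.
ps_ca = ComponentArray(ps)

ts = 1:100                         # `ts` instead of `time` to avoid shadowing Base.time
x = hcat(sin.(ts), cos.(ts), tan.(ts))'
y = @.(2 * sin(ts) + 3 * cos(ts) - 5 * tan(ts)) + rand(rng, 100)

# Mean-squared-error loss over the flattened parameters.
loss(p, _) = sum(abs2, y .- vec(first(model(x, p, st))))

optf = OptimizationFunction(loss, Optimization.AutoZygote())
prob = OptimizationProblem(optf, ps_ca)
res = solve(prob, OptimizationOptimisers.Adam(0.01); maxiters = 500)
```

The key design point is that `ComponentArray(ps)` gives the optimizer a flat vector while preserving the `layer_1`/`weight` structure that `model(x, p, st)` expects, so no manual packing/unpacking of per-layer symbolic arrays is needed.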

@chooron chooron added the question Further information is requested label Apr 3, 2024

chooron commented Apr 3, 2024

In fact, what I'm more interested in is using Lux.jl models within an ODE Problem, which might align with DiffEqFlux.jl. I've constructed a test case as follows:

using ModelingToolkit
using Optimization, OptimizationOptimisers
using Lux
using Random

model = Lux.Chain(Lux.Dense(3, 16, tanh), Lux.Dense(16, 16, leakyrelu), Lux.Dense(16, 1, leakyrelu), name=:model)
rng = MersenneTwister()
Random.seed!(rng, 42)
ps, st = Lux.setup(rng, model)
time = 1:100
x1 = sin.(time)
x2 = cos.(time)
x3 = tan.(time)
x = hcat(x1, x2, x3)'
y = @.(2 * x1 + 3 * x2 - 5 * x3) + rand(100)
y_hat = vec(model([1,2,3], ps, st)[1])

D = Differential
@variables t water(t)
@parameters layer1_w[1:16, 1:3], layer1_b[1:16, 1:1]
@parameters layer2_w[1:16, 1:16], layer2_b[1:16, 1:1]
@parameters layer3_w[1:1, 1:16], layer3_b[1:1, 1:1]

p_tuple = (layer_1=(weight=layer1_w, bias=layer1_b), layer_2=(weight=layer2_w, bias=layer2_b), layer_3=(weight=layer3_w, bias=layer3_b))
eqs = [D(water) ~ (model([sin(t), cos(t), tan(t)], p_tuple, st)[1])[1]]
u0 = [water => 0.0]
p = [
    layer1_w => ps[:layer_1][:weight],
    layer1_b => ps[:layer_1][:bias],
    layer2_w => ps[:layer_2][:weight],
    layer2_b => ps[:layer_2][:bias],
    layer3_w => ps[:layer_3][:weight],
    layer3_b => ps[:layer_3][:bias],
]
sys = ODESystem(eqs, t, name=:sys)
sys = structural_simplify(sys)
prob = ODEProblem(sys, u0, (0.0, 100.0), p)

But this approach has problems of its own; for example, the system contains no unknowns or parameters:

Model sys with 1 equations
Unknowns (0):
Parameters (0):

Also, when calling structural_simplify(sys), an error occurs:

ERROR: MethodError: no method matching -(::SymbolicUtils.BasicSymbolic{Real}, ::Differential)

Closest candidates are:
  -(::ChainRulesCore.ZeroTangent, ::Any)
   @ ChainRulesCore D:\Julia\Julia-1.10.0\packages\packages\ChainRulesCore\zgT0R\src\tangent_arithmetic.jl:101
  -(::Any, ::ChainRulesCore.NoTangent)
   @ ChainRulesCore D:\Julia\Julia-1.10.0\packages\packages\ChainRulesCore\zgT0R\src\tangent_arithmetic.jl:62
  -(::MutableArithmetics.Zero, ::Any)
   @ MutableArithmetics D:\Julia\Julia-1.10.0\packages\packages\MutableArithmetics\xVyia\src\rewrite.jl:63
  ...

Stacktrace:
 [1] TearingState(sys::ODESystem; quick_cancel::Bool, check::Bool)
   @ ModelingToolkit D:\Julia\Julia-1.10.0\packages\packages\ModelingToolkit\hPXD5\src\systems\systemstructure.jl:284
 [2] TearingState(sys::ODESystem)
   @ ModelingToolkit D:\Julia\Julia-1.10.0\packages\packages\ModelingToolkit\hPXD5\src\systems\systemstructure.jl:250
 [3] __structural_simplify(sys::ODESystem, io::Nothing; simplify::Bool, kwargs::@Kwargs{})
   @ ModelingToolkit D:\Julia\Julia-1.10.0\packages\packages\ModelingToolkit\hPXD5\src\systems\systems.jl:60
 [4] __structural_simplify
   @ ModelingToolkit D:\Julia\Julia-1.10.0\packages\packages\ModelingToolkit\hPXD5\src\systems\systems.jl:57 [inlined]
 [5] structural_simplify(sys::ODESystem, io::Nothing; simplify::Bool, split::Bool, kwargs::@Kwargs{})
   @ ModelingToolkit D:\Julia\Julia-1.10.0\packages\packages\ModelingToolkit\hPXD5\src\systems\systems.jl:22
 [6] structural_simplify
   @ ModelingToolkit D:\Julia\Julia-1.10.0\packages\packages\ModelingToolkit\hPXD5\src\systems\systems.jl:19 [inlined]
 [7] structural_simplify(sys::ODESystem)
   @ ModelingToolkit D:\Julia\Julia-1.10.0\packages\packages\ModelingToolkit\hPXD5\src\systems\systems.jl:19

ChrisRackauckas (Member) commented

The first tutorials are coming together in SciML/ModelingToolkitNeuralNets.jl#19, let's take the conversations to there.
