-
-
Notifications
You must be signed in to change notification settings - Fork 1
Commit
This commit does not belong to any branch on this repository, and may belong to a fork outside of the repository.
- Loading branch information
1 parent
0bd6e2e
commit 145db5e
Showing
10 changed files
with
260 additions
and
13 deletions.
There are no files selected for viewing
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -1 +1,4 @@ | ||
# See https://domluna.github.io/JuliaFormatter.jl/stable/ for a list of options | ||
style = "sciml" | ||
format_markdown = true | ||
format_docstrings = true | ||
annotate_untyped_fields_with_any = false |
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
|
@@ -4,3 +4,4 @@ | |
/Manifest.toml | ||
/docs/Manifest.toml | ||
/docs/build/ | ||
.vscode |
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
|
@@ -3,13 +3,48 @@ uuid = "f162e290-f571-43a6-83d9-22ecc16da15f" | |
authors = ["Sebastian Micluța-Câmpeanu <[email protected]> and contributors"] | ||
version = "1.0.0-DEV" | ||
|
||
[deps] | ||
ComponentArrays = "b0b7db55-cfe3-40fc-9ded-d10e2dbeff66" | ||
Lux = "b2108857-7c20-44ae-9111-449ecde12c47" | ||
LuxCore = "bb33d45b-7691-41d6-9220-0943567d0623" | ||
ModelingToolkit = "961ee093-0014-501f-94e3-6117800e7a78" | ||
ModelingToolkitStandardLibrary = "16a59e39-deab-5bd0-87e4-056b12336739" | ||
NNlib = "872c559c-99b0-510c-b3b7-b6c96a88d5cd" | ||
Random = "9a3f8284-a2c9-5f02-9a11-845980a1fd5c" | ||
Symbolics = "0c5d862f-8b57-4792-8d23-62f2024744c7" | ||
|
||
[compat] | ||
Aqua = "0.8" | ||
ComponentArrays = "0.15" | ||
ForwardDiff = "0.10.36" | ||
JET = "0.8" | ||
Lux = "0.5.32" | ||
LuxCore = "0.1.14" | ||
ModelingToolkit = "9.9.0" | ||
ModelingToolkitStandardLibrary = "2.6" | ||
NNlib = "0.9" | ||
Optimization = "3.22" | ||
OptimizationOptimisers = "0.2" | ||
OrdinaryDiffEq = "6.74" | ||
Random = "1" | ||
SafeTestsets = "0.1" | ||
SciMLStructures = "1.1.0" | ||
SymbolicIndexingInterface = "0.3.15" | ||
Symbolics = "5.27" | ||
Test = "1" | ||
julia = "1.10" | ||
|
||
[extras] | ||
Aqua = "4c88cf16-eb10-579e-8560-4a9242c79595" | ||
ForwardDiff = "f6369f11-7733-5829-9624-2563aa707210" | ||
JET = "c3a54625-cd67-489e-a8e7-0a5a0ff4e31b" | ||
OrdinaryDiffEq = "1dea7af3-3e70-54e6-95c3-0bf5283fa5ed" | ||
Optimization = "7f7a1694-90dd-40f0-9382-eb1efda571ba" | ||
OptimizationOptimisers = "42dfb2eb-d2b4-4451-abcd-913932933ac1" | ||
SafeTestsets = "1bc83da4-3b8d-516f-aca4-4fe02f6d838f" | ||
SciMLStructures = "53ae85a6-f571-4167-b2af-e1d143709226" | ||
SymbolicIndexingInterface = "2efcf032-c050-4f8e-a9bb-153293bab1f5" | ||
Test = "8dfed614-e22c-5e08-85e1-65c5234f0b40" | ||
|
||
[targets] | ||
test = ["Aqua", "JET", "Test"] | ||
test = ["Aqua", "JET", "Test", "OrdinaryDiffEq", "ForwardDiff", "Optimization", "OptimizationOptimisers", "SafeTestsets", "SciMLStructures", "SymbolicIndexingInterface"] |
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -1,5 +1,47 @@ | ||
module UDEComponents | ||
|
||
# Write your package code here. | ||
using ModelingToolkit: @parameters, @named, ODESystem, t_nounits | ||
using ModelingToolkitStandardLibrary.Blocks: RealInput, RealOutput | ||
using Symbolics: Symbolics, @register_array_symbolic, @wrapped | ||
using LuxCore: stateless_apply | ||
using Lux: Lux | ||
using Random: Xoshiro | ||
using NNlib: softplus | ||
using ComponentArrays: ComponentArray | ||
|
||
export create_ude_component, multi_layer_feed_forward | ||
|
||
include("utils.jl") | ||
include("hacks.jl") # this should be removed / upstreamed | ||
|
||
""" | ||
create_ude_component(n_input = 1, n_output = 1; | ||
chain = multi_layer_feed_forward(n_input, n_output), | ||
rng = Xoshiro(0)) | ||
Create an `ODESystem` with a neural network inside. | ||
""" | ||
function create_ude_component(n_input = 1, | ||
n_output = 1; | ||
chain = multi_layer_feed_forward(n_input, n_output), | ||
rng = Xoshiro(0)) | ||
lux_p, st = Lux.setup(rng, chain) | ||
ca = ComponentArray(lux_p) | ||
|
||
@parameters p[1:length(ca)] = Vector(ca) | ||
@parameters T::typeof(typeof(p))=typeof(p) [tunable = false] | ||
|
||
@named input = RealInput(nin = n_input) | ||
@named output = RealOutput(nout = n_output) | ||
|
||
out = stateless_apply(chain, input.u, lazyconvert(typeof(ca), p)) | ||
|
||
eqs = [output.u ~ out] | ||
|
||
@named ude_comp = ODESystem( | ||
eqs, t_nounits, [], [p, T], systems = [input, output]) | ||
return ude_comp | ||
end | ||
|
||
end |
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,4 @@ | ||
lazyconvert(x, y) = convert(x, y) | ||
lazyconvert(x, y::Symbolics.Arr) = Symbolics.array_term(convert, x, y) | ||
Symbolics.propagate_ndims(::typeof(convert), x, y) = ndims(y) | ||
Symbolics.propagate_shape(::typeof(convert), x, y) = Symbolics.shape(y) |
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,25 @@ | ||
function multi_layer_feed_forward(input_length, output_length; width::Int = 5, | ||
depth::Int = 1, activation = softplus) | ||
Lux.Chain(Lux.Dense(input_length, width, activation), | ||
[Lux.Dense(width, width, activation) for _ in 1:(depth)]..., | ||
Lux.Dense(width, output_length); disable_optimizations = true) | ||
end | ||
|
||
# Symbolics.@register_array_symbolic print_input(x) begin | ||
# size = size(x) | ||
# eltype = eltype(x) | ||
# end | ||
|
||
# function print_input(x) | ||
# @info x | ||
# x | ||
# end | ||
|
||
# function debug_component(n_input, n_output) | ||
# @named input = RealInput(nin = n_input) | ||
# @named output = RealOutput(nout = n_output) | ||
|
||
# eqs = [output.u ~ print_input(input.u)] | ||
|
||
# @named dbg_comp = ODESystem(eqs, t_nounits, [], [], systems = [input, output]) | ||
# end |
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,121 @@ | ||
using Test | ||
using JET | ||
using UDEComponents | ||
using ModelingToolkit | ||
using ModelingToolkitStandardLibrary.Blocks | ||
using OrdinaryDiffEq | ||
using SymbolicIndexingInterface | ||
using Optimization | ||
using OptimizationOptimisers: Adam | ||
using SciMLStructures | ||
using SciMLStructures: Tunable | ||
using ForwardDiff | ||
|
||
function lotka_ude() | ||
@variables t x(t)=3.1 y(t)=1.5 | ||
@parameters α=1.3 β=0.9 γ=0.8 δ=1.8 | ||
Dt = ModelingToolkit.D_nounits | ||
@named nn_in = RealInput(nin = 2) | ||
@named nn_out = RealOutput(nout = 2) | ||
|
||
eqs = [ | ||
Dt(x) ~ α * x + nn_in.u[1], | ||
Dt(y) ~ -δ * y + nn_in.u[2], | ||
nn_out.u[1] ~ x, | ||
nn_out.u[2] ~ y | ||
] | ||
return ODESystem( | ||
eqs, ModelingToolkit.t_nounits, name = :lotka, systems = [nn_in, nn_out]) | ||
end | ||
|
||
function lotka_true() | ||
@variables t x(t)=3.1 y(t)=1.5 | ||
@parameters α=1.3 β=0.9 γ=0.8 δ=1.8 | ||
Dt = ModelingToolkit.D_nounits | ||
|
||
eqs = [ | ||
Dt(x) ~ α * x - β * x * y, | ||
Dt(y) ~ -δ * y + δ * x * y | ||
] | ||
return ODESystem(eqs, ModelingToolkit.t_nounits, name = :lotka_true) | ||
end | ||
|
||
model = lotka_ude() | ||
nn = create_ude_component(2, 2) | ||
|
||
eqs = [ | ||
connect(model.nn_in, nn.output) | ||
connect(model.nn_out, nn.input) | ||
] | ||
|
||
ude_sys = complete(ODESystem( | ||
eqs, ModelingToolkit.t_nounits, systems = [model, nn], name = :ude_sys)) | ||
|
||
sys = structural_simplify(ude_sys) | ||
|
||
prob = ODEProblem{true, SciMLBase.FullSpecialize}(sys, [], (0, 1.0), []) | ||
|
||
model_true = structural_simplify(lotka_true()) | ||
prob_true = ODEProblem{true, SciMLBase.FullSpecialize}(model_true, [], (0, 1.0), []) | ||
sol_ref = solve(prob_true, Rodas4()) | ||
|
||
x0 = reduce(vcat, getindex.((default_values(sys),), tunable_parameters(sys))) | ||
|
||
get_vars = getu(sys, [sys.lotka.x, sys.lotka.y]) | ||
get_refs = getu(model_true, [model_true.x, model_true.y]) | ||
|
||
function loss(x, (prob, sol_ref, get_vars, get_refs)) | ||
new_p = SciMLStructures.replace(Tunable(), prob.p, x) | ||
new_prob = remake(prob, p = new_p) | ||
ts = sol_ref.t | ||
new_sol = solve(new_prob, Rodas4(), saveat = ts) | ||
|
||
loss = zero(eltype(x)) | ||
|
||
for i in eachindex(new_sol.u) | ||
loss += sum(sqrt.(abs2.(get_vars(new_sol, i) .- get_refs(sol_ref, i)))) | ||
end | ||
|
||
if SciMLBase.successful_retcode(new_sol) | ||
loss | ||
else | ||
Inf | ||
end | ||
end | ||
|
||
|
||
of = OptimizationFunction{true}(loss, AutoForwardDiff()) | ||
|
||
ps = (prob, sol_ref, get_vars, get_refs); | ||
|
||
@test_call target_modules=(UDEComponents,) loss(x0, ps) | ||
@test_opt target_modules=(UDEComponents,) loss(x0, ps) | ||
|
||
@test all(.!isnan.(ForwardDiff.gradient(Base.Fix2(of, ps), x0))) | ||
|
||
op = OptimizationProblem(of, x0, (prob, sol_ref, get_vars, get_refs)) | ||
|
||
|
||
# using Plots | ||
|
||
# oh = [] | ||
|
||
plot_cb = (opt_state, loss) -> begin | ||
@info "step $(opt_state.iter), loss: $loss" | ||
# push!(oh, opt_state) | ||
# new_p = SciMLStructures.replace(Tunable(), prob.p, opt_state.u) | ||
# new_prob = remake(prob, p = new_p) | ||
# sol = solve(new_prob, Rodas4()) | ||
# display(plot(sol)) | ||
false | ||
end | ||
|
||
res = solve(op, Adam(), maxiters = 2000)#, callback = plot_cb) | ||
|
||
@test res.objective < 1 | ||
|
||
res_p = SciMLStructures.replace(Tunable(), prob.p, res) | ||
res_prob = remake(prob, p = res_p) | ||
res_sol = solve(res_prob, Rodas4()) | ||
|
||
@test SciMLBase.successful_retcode(res_sol) |
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,20 @@ | ||
using Test | ||
using UDEComponents | ||
using Aqua | ||
using JET | ||
|
||
@testset verbose = true "Code quality (Aqua.jl)" begin | ||
Aqua.find_persistent_tasks_deps(UDEComponents) | ||
Aqua.test_ambiguities(UDEComponents, recursive = false) | ||
Aqua.test_deps_compat(UDEComponents) | ||
# TODO: fix type piracy in propagate_ndims and propagate_shape | ||
Aqua.test_piracies(UDEComponents, broken=true) | ||
Aqua.test_project_extras(UDEComponents) | ||
Aqua.test_stale_deps(UDEComponents, ignore = Symbol[]) | ||
Aqua.test_unbound_args(UDEComponents) | ||
Aqua.test_undefined_exports(UDEComponents) | ||
end | ||
|
||
@testset "Code linting (JET.jl)" begin | ||
JET.test_package(UDEComponents; target_defined_modules = true) | ||
end |
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -1,14 +1,8 @@ | ||
using UDEComponents | ||
using Test | ||
using Aqua | ||
using JET | ||
using SafeTestsets | ||
|
||
@testset "UDEComponents.jl" begin | ||
@testset "Code quality (Aqua.jl)" begin | ||
Aqua.test_all(UDEComponents) | ||
end | ||
@testset "Code linting (JET.jl)" begin | ||
JET.test_package(UDEComponents; target_defined_modules = true) | ||
end | ||
# Write your tests here. | ||
@testset verbose=true "UDEComponents.jl" begin | ||
@safetestset "QA" include("qa.jl") | ||
@safetestset "Basic" include("lotka_volterra.jl") | ||
end |