Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Remove Flux Support #789

Merged
merged 17 commits into from
Feb 1, 2024
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
Show all changes
17 commits
Select commit Hold shift + click to select a range
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
11 changes: 5 additions & 6 deletions Project.toml
Original file line number Diff line number Diff line change
Expand Up @@ -15,7 +15,6 @@ DiffEqNoiseProcess = "77a26b50-5914-5dd7-bc55-306e6241c503"
Distributions = "31c24e10-a181-5473-b8eb-7969acd0382f"
DocStringExtensions = "ffbed154-4ef7-542d-bbb7-c09d3a79fcae"
DomainSets = "5b8099bc-c8ec-5219-889f-1d9e522a28bf"
Flux = "587475ba-b771-5e3f-ad9e-33799f191a9c"
ForwardDiff = "f6369f11-7733-5829-9624-2563aa707210"
Functors = "d9f16b24-f501-4c13-a1f2-28368ffc5196"
Integrals = "de52edbc-65ea-441a-8357-d3a637375a31"
Expand All @@ -26,8 +25,8 @@ MCMCChains = "c7f686f2-ff18-58e9-bc7b-31028e88f75d"
ModelingToolkit = "961ee093-0014-501f-94e3-6117800e7a78"
MonteCarloMeasurements = "0987c9cc-fe09-11e8-30f0-b96dd679fdca"
Optim = "429524aa-4258-5aef-a3af-852621145aeb"
Optimisers = "3bd65402-5787-11e9-1adc-39752487f4e2"
Optimization = "7f7a1694-90dd-40f0-9382-eb1efda571ba"
OptimizationOptimisers = "42dfb2eb-d2b4-4451-abcd-913932933ac1"
QuasiMonteCarlo = "8a4e6c94-4038-4cdc-81c3-7e6ffdb2a71b"
Random = "9a3f8284-a2c9-5f02-9a11-845980a1fd5c"
Reexport = "189a3867-3050-52da-a836-e630ba90ab69"
Expand All @@ -53,7 +52,7 @@ DiffEqNoiseProcess = "5.1"
Distributions = "0.23, 0.24, 0.25"
DocStringExtensions = "0.8, 0.9"
DomainSets = "0.6, 0.7"
Flux = "0.13, 0.14"
Flux = "0.14"
ForwardDiff = "0.10"
Functors = "0.4"
Integrals = "4"
Expand All @@ -64,8 +63,8 @@ MCMCChains = "6"
ModelingToolkit = "8"
MonteCarloMeasurements = "1"
Optim = "1.7.8"
Optimisers = "0.2, 0.3"
Optimization = "3"
OptimizationOptimisers = "0.1"
QuasiMonteCarlo = "0.3.2"
Reexport = "1.0"
RuntimeGeneratedFunctions = "0.5"
Expand All @@ -80,15 +79,15 @@ julia = "1.6"

[extras]
CUDA = "052768ef-5323-5732-b1bb-66c8b64840ba"
Flux = "587475ba-b771-5e3f-ad9e-33799f191a9c"
LineSearches = "d3d80556-e9d4-5f37-9878-2ab0fcc64255"
LuxCUDA = "d0bbae9a-e099-4d5b-a835-1c6931763bda"
OptimizationOptimJL = "36348300-93cb-4f02-beb5-3c3902f8871e"
OptimizationOptimisers = "42dfb2eb-d2b4-4451-abcd-913932933ac1"
OrdinaryDiffEq = "1dea7af3-3e70-54e6-95c3-0bf5283fa5ed"
Pkg = "44cfe95a-1eb2-52ea-b672-e2afdf69b78f"
SafeTestsets = "1bc83da4-3b8d-516f-aca4-4fe02f6d838f"
Test = "8dfed614-e22c-5e08-85e1-65c5234f0b40"
cuDNN = "02a925ec-e4fe-4b08-9a7e-0d78e3d38ccd"

[targets]
test = ["Test", "CUDA", "SafeTestsets", "OptimizationOptimisers", "OptimizationOptimJL", "Pkg", "OrdinaryDiffEq", "LineSearches", "cuDNN", "LuxCUDA"]
test = ["Test", "CUDA", "SafeTestsets", "OptimizationOptimJL", "Pkg", "OrdinaryDiffEq", "LineSearches", "cuDNN", "LuxCUDA", "Flux"]
47 changes: 20 additions & 27 deletions src/BPINN_ode.jl
Original file line number Diff line number Diff line change
@@ -1,15 +1,13 @@
# HIGH level API for BPINN ODE solver

"""
```julia
BNNODE(chain, Kernel = HMC; strategy = nothing, draw_samples = 2000,
priorsNNw = (0.0, 2.0), param = [nothing], l2std = [0.05],
phystd = [0.05], dataset = [nothing], physdt = 1 / 20.0,
MCMCargs = (n_leapfrog=30), nchains = 1, init_params = nothing,
Adaptorkwargs = (Adaptor = StanHMCAdaptor, targetacceptancerate = 0.8, Metric = DiagEuclideanMetric),
Integratorkwargs = (Integrator = Leapfrog,), autodiff = false,
progress = false, verbose = false)
```
BNNODE(chain, Kernel = HMC; strategy = nothing, draw_samples = 2000,
priorsNNw = (0.0, 2.0), param = [nothing], l2std = [0.05],
phystd = [0.05], dataset = [nothing], physdt = 1 / 20.0,
MCMCargs = (n_leapfrog=30), nchains = 1, init_params = nothing,
Adaptorkwargs = (Adaptor = StanHMCAdaptor, targetacceptancerate = 0.8, Metric = DiagEuclideanMetric),
Integratorkwargs = (Integrator = Leapfrog,), autodiff = false,
progress = false, verbose = false)

Algorithm for solving ordinary differential equations using a Bayesian neural network. This is a specialization
of the physics-informed neural network which is used as a solver for a standard `ODEProblem`.
Expand All @@ -22,7 +20,7 @@ of the physics-informed neural network which is used as a solver for a standard

## Positional Arguments

* `chain`: A neural network architecture, defined as either a `Flux.Chain` or a `Lux.AbstractExplicitLayer`.
* `chain`: A neural network architecture, defined as a `Lux.AbstractExplicitLayer`.
* `Kernel`: Choice of MCMC Sampling Algorithm. Defaults to `AdvancedHMC.HMC`

## Keyword Arguments
Expand All @@ -46,18 +44,18 @@ dataset = [x̂, time]

chainlux = Lux.Chain(Lux.Dense(1, 6, tanh), Lux.Dense(6, 6, tanh), Lux.Dense(6, 1))

alg = NeuralPDE.BNNODE(chainlux, draw_samples = 2000,
alg = BNNODE(chainlux, draw_samples = 2000,
l2std = [0.05], phystd = [0.05],
priorsNNw = (0.0, 3.0), progress = true)

sol_lux = solve(prob, alg)

# with parameter estimation
alg = NeuralPDE.BNNODE(chainlux,dataset = dataset,
draw_samples = 2000,l2std = [0.05],
phystd = [0.05],priorsNNw = (0.0, 10.0),
param = [Normal(6.5, 0.5), Normal(-3, 0.5)],
progress = true)
alg = BNNODE(chainlux,dataset = dataset,
draw_samples = 2000,l2std = [0.05],
phystd = [0.05],priorsNNw = (0.0, 10.0),
param = [Normal(6.5, 0.5), Normal(-3, 0.5)],
progress = true)

sol_lux_pestim = solve(prob, alg)
```
Expand All @@ -74,11 +72,10 @@ is an accurate interpolation (up to the neural network training result). In addi
## References

Liu Yanga, Xuhui Menga, George Em Karniadakis. "B-PINNs: Bayesian Physics-Informed Neural Networks for
Forward and Inverse PDE Problems with Noisy Data"

Kevin Linka, Amelie Schäfer, Xuhui Meng, Zongren Zou, George Em Karniadakis, Ellen Kuhl.
"Bayesian Physics Informed Neural Networks for real-world nonlinear dynamical systems"
Forward and Inverse PDE Problems with Noisy Data".

Kevin Linka, Amelie Schäfer, Xuhui Meng, Zongren Zou, George Em Karniadakis, Ellen Kuhl
"Bayesian Physics Informed Neural Networks for real-world nonlinear dynamical systems".
"""
struct BNNODE{C, K, IT <: NamedTuple,
A <: NamedTuple, H <: NamedTuple,
Expand Down Expand Up @@ -116,6 +113,7 @@ function BNNODE(chain, Kernel = HMC; strategy = nothing, draw_samples = 2000,
targetacceptancerate = 0.8),
Integratorkwargs = (Integrator = Leapfrog,),
autodiff = false, progress = false, verbose = false)
!(chain isa Lux.AbstractExplicitLayer) && (chain = Lux.transform(chain))
BNNODE(chain, Kernel, strategy,
draw_samples, priorsNNw, param, l2std,
phystd, dataset, physdt, MCMCkwargs,
Expand Down Expand Up @@ -222,13 +220,8 @@ function DiffEqBase.__solve(prob::DiffEqBase.ODEProblem,
luxar = [chain(t', θ[i], st)[1] for i in 1:numensemble]
# only need for size
θinit = collect(ComponentArrays.ComponentArray(θinit))
elseif chain isa Flux.Chain
θinit, re1 = Flux.destructure(chain)
out = re1.([samples[i][1:(end - ninv)]
for i in (draw_samples - numensemble):draw_samples])
luxar = collect(out[i](t') for i in eachindex(out))
else
throw(error("Only Lux.AbstractExplicitLayer and Flux.Chain neural networks are supported"))
throw(error("Only Lux.AbstractExplicitLayer neural networks are supported"))
end

# contructing ensemble predictions
Expand Down Expand Up @@ -272,4 +265,4 @@ function DiffEqBase.__solve(prob::DiffEqBase.ODEProblem,
end

BPINNsolution(fullsolution, ensemblecurves, estimnnparams, estimated_params, t)
end
end
7 changes: 3 additions & 4 deletions src/NeuralPDE.jl
Original file line number Diff line number Diff line change
Expand Up @@ -11,6 +11,7 @@ using Reexport, Statistics
using Zygote, ForwardDiff, Random, Distributions
using Adapt, DiffEqNoiseProcess, StochasticDiffEq
using Optimization
using OptimizationOptimisers
using Integrals, Cubature
using QuasiMonteCarlo
using RuntimeGeneratedFunctions
Expand All @@ -24,14 +25,12 @@ using Symbolics: wrap, unwrap, arguments, operation
using SymbolicUtils
using AdvancedHMC, LogDensityProblems, LinearAlgebra, Functors, MCMCChains
using MonteCarloMeasurements

import ModelingToolkit: value, nameof, toexpr, build_expr, expand_derivatives
import DomainSets: Domain, ClosedInterval
import ModelingToolkit: Interval, infimum, supremum #,Ball
import SciMLBase: @add_kwonly, parameterless_type
import Optimisers
import UnPack: @unpack
import ChainRulesCore, Flux, Lux, ComponentArrays
import ChainRulesCore, Lux, ComponentArrays
import ChainRulesCore: @non_differentiable

RuntimeGeneratedFunctions.init(@__MODULE__)
Expand All @@ -45,7 +44,7 @@ include("symbolic_utilities.jl")
include("training_strategies.jl")
include("adaptive_losses.jl")
include("ode_solve.jl")
include("rode_solve.jl")
# include("rode_solve.jl")
include("dae_solve.jl")
include("transform_inf_integral.jl")
include("discretize.jl")
Expand Down
Loading
Loading