refactor: remove Flux support
sathvikbhagavan committed Jan 23, 2024
1 parent 9321766 commit a8492f7
Showing 9 changed files with 138 additions and 345 deletions.
21 changes: 8 additions & 13 deletions src/BPINN_ode.jl
@@ -22,7 +22,7 @@ of the physics-informed neural network which is used as a solver for a standard
## Positional Arguments
* `chain`: A neural network architecture, defined as either a `Flux.Chain` or a `Lux.AbstractExplicitLayer`.
* `chain`: A neural network architecture, defined as a `Lux.AbstractExplicitLayer`.
* `Kernel`: Choice of MCMC Sampling Algorithm. Defaults to `AdvancedHMC.HMC`
## Keyword Arguments
@@ -46,18 +46,18 @@ dataset = [x̂, time]
chainlux = Lux.Chain(Lux.Dense(1, 6, tanh), Lux.Dense(6, 6, tanh), Lux.Dense(6, 1))
alg = NeuralPDE.BNNODE(chainlux, draw_samples = 2000,
alg = BNNODE(chainlux, draw_samples = 2000,
l2std = [0.05], phystd = [0.05],
priorsNNw = (0.0, 3.0), progress = true)
sol_lux = solve(prob, alg)
# with parameter estimation
alg = NeuralPDE.BNNODE(chainlux,dataset = dataset,
draw_samples = 2000,l2std = [0.05],
phystd = [0.05],priorsNNw = (0.0, 10.0),
param = [Normal(6.5, 0.5), Normal(-3, 0.5)],
progress = true)
alg = BNNODE(chainlux,dataset = dataset,
draw_samples = 2000,l2std = [0.05],
phystd = [0.05],priorsNNw = (0.0, 10.0),
param = [Normal(6.5, 0.5), Normal(-3, 0.5)],
progress = true)
sol_lux_pestim = solve(prob, alg)
```
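Since `Flux.Chain` is no longer accepted, an existing script only needs its network definition swapped; a minimal migration sketch using the standard Lux API (not part of this diff):

```julia
using Lux, Random

# Previously (no longer supported after this commit):
#   chain = Flux.Chain(Flux.Dense(1, 6, tanh), Flux.Dense(6, 1))

# Lux equivalent. Lux layers are stateless, so parameters and state are
# created explicitly from an RNG rather than stored inside the layers.
chain = Lux.Chain(Lux.Dense(1, 6, tanh), Lux.Dense(6, 1))
ps, st = Lux.setup(Random.default_rng(), chain)
```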
@@ -222,13 +222,8 @@ function DiffEqBase.__solve(prob::DiffEqBase.ODEProblem,
luxar = [chain(t', θ[i], st)[1] for i in 1:numensemble]
# only need for size
θinit = collect(ComponentArrays.ComponentArray(θinit))
elseif chain isa Flux.Chain
θinit, re1 = Flux.destructure(chain)
out = re1.([samples[i][1:(end - ninv)]
for i in (draw_samples - numensemble):draw_samples])
luxar = collect(out[i](t') for i in eachindex(out))
else
throw(error("Only Lux.AbstractExplicitLayer and Flux.Chain neural networks are supported"))
throw(error("Only Lux.AbstractExplicitLayer neural networks are supported"))
end

# constructing ensemble predictions
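For context on the surviving Lux branch above: `ComponentArrays.ComponentArray` flattens the nested parameter `NamedTuple` into a single vector whose length the ensemble code needs. A minimal sketch with hypothetical layer sizes (not part of this diff):

```julia
using ComponentArrays, Lux, Random

chain = Lux.Chain(Lux.Dense(1, 4, tanh), Lux.Dense(4, 1))
θinit, st = Lux.setup(Random.default_rng(), chain)

# ComponentArray keeps the layer structure but exposes the weights as one
# flat vector; `collect` materialises it as a plain Vector of that length.
flat = collect(ComponentArrays.ComponentArray(θinit))
length(flat)  # total number of network parameters (13 here)
```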
5 changes: 3 additions & 2 deletions src/NeuralPDE.jl
@@ -11,6 +11,7 @@ using Reexport, Statistics
using Zygote, ForwardDiff, Random, Distributions
using Adapt, DiffEqNoiseProcess, StochasticDiffEq
using Optimization
using OptimizationOptimisers
using Integrals, Cubature
using QuasiMonteCarlo
using RuntimeGeneratedFunctions
@@ -29,7 +30,7 @@ import DomainSets: Domain, ClosedInterval
import ModelingToolkit: Interval, infimum, supremum #,Ball
import SciMLBase: @add_kwonly, parameterless_type
import UnPack: @unpack
import ChainRulesCore, Flux, Lux, ComponentArrays
import ChainRulesCore, Lux, ComponentArrays
import ChainRulesCore: @non_differentiable

RuntimeGeneratedFunctions.init(@__MODULE__)
@@ -43,7 +44,7 @@ include("symbolic_utilities.jl")
include("training_strategies.jl")
include("adaptive_losses.jl")
include("ode_solve.jl")
include("rode_solve.jl")
# include("rode_solve.jl")
include("transform_inf_integral.jl")
include("discretize.jl")
include("neural_adapter.jl")
169 changes: 52 additions & 117 deletions src/PDE_BPINN.jl
@@ -79,36 +79,22 @@ function setparameters(Tar::PDELogTargetDensity, θ)
ps_new = θ[1:(end - Tar.extraparams)]
ps = Tar.init_params

if (ps[names[1]] isa ComponentArrays.ComponentVector)
# multioutput case for Lux chains, for each depvar ps would contain Lux ComponentVectors
# which we use for mapping current ahmc sampled vector of parameters onto NNs
# multioutput case for Lux chains, for each depvar ps would contain Lux ComponentVectors
# which we use for mapping current ahmc sampled vector of parameters onto NNs
i = 0
Luxparams = [vector_to_parameters(ps_new[((i += length(ps[x])) - length(ps[x]) + 1):i],
ps[x]) for x in names]

i = 0
Luxparams = [vector_to_parameters(ps_new[((i += length(ps[x])) - length(ps[x]) + 1):i],
ps[x]) for x in names]
a = ComponentArrays.ComponentArray(NamedTuple{Tar.names}(i for i in Luxparams))

if Tar.extraparams > 0
b = θ[(end - Tar.extraparams + 1):end]
return ComponentArrays.ComponentArray(;
depvar = a,
p = b)
else
# multioutput Flux
Luxparams = θ
end

if (Luxparams isa AbstractVector) && (Luxparams[1] isa ComponentArrays.ComponentVector)
# multioutput Lux
a = ComponentArrays.ComponentArray(NamedTuple{Tar.names}(i for i in Luxparams))

if Tar.extraparams > 0
b = θ[(end - Tar.extraparams + 1):end]

return ComponentArrays.ComponentArray(;
depvar = a,
p = b)
else
return ComponentArrays.ComponentArray(;
depvar = a)
end
else
# multioutput Flux case
return vector_to_parameters(Luxparams, ps)
return ComponentArrays.ComponentArray(;
depvar = a)
end
end
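The retained Lux path splits the flat sampled vector into one chunk per dependent variable using a running index; a small self-contained sketch of that indexing (hypothetical sizes, plain vectors instead of `ComponentVector`s):

```julia
# ps holds per-depvar initial parameters; ps_new is the flat vector drawn by
# the sampler. The running index i walks through ps_new in chunks whose
# lengths match each depvar's parameter count.
function split_per_depvar(ps_new, ps, names)
    i = 0
    [ps_new[((i += length(ps[x])) - length(ps[x]) + 1):i] for x in names]
end

ps = (u = rand(5), v = rand(3))    # stand-ins for Tar.init_params
chunks = split_per_depvar(rand(8), ps, (:u, :v))
length.(chunks)  # [5, 3]
```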

@@ -138,33 +124,18 @@ function L2LossData(Tar::PDELogTargetDensity, θ)
# dataset[i][:, 1] -> depvar col of depvar's dataset

if Tar.extraparams > 0
if Tar.init_params isa ComponentArrays.ComponentVector
for i in eachindex(Φ)
sumt += logpdf(MvNormal(Φ[i](dataset[i][:, 2:end]',
vector_to_parameters(θ[1:(end - Tar.extraparams)],
init_params)[Tar.names[i]])[1,
:],
LinearAlgebra.Diagonal(abs2.(ones(size(dataset[i])[1]) .*
L2stds[i]))),
dataset[i][:, 1])
end
sumt
else
# Flux case needs subindexing wrt Tar.names indices(hence stored in Tar.names)
for i in eachindex(Φ)
sumt += logpdf(MvNormal(Φ[i](dataset[i][:, 2:end]',
vector_to_parameters(θ[1:(end - Tar.extraparams)],
init_params)[Tar.names[2][i]])[1,
:],
LinearAlgebra.Diagonal(abs2.(ones(size(dataset[i])[1]) .*
L2stds[i]))),
dataset[i][:, 1])
end
sumt
for i in eachindex(Φ)
sumt += logpdf(MvNormal(Φ[i](dataset[i][:, 2:end]',
vector_to_parameters(θ[1:(end - Tar.extraparams)],
init_params)[Tar.names[i]])[1,
:],
LinearAlgebra.Diagonal(abs2.(ones(size(dataset[i])[1]) .*
L2stds[i]))),
dataset[i][:, 1])
end
else
return 0
return sumt
end
return 0
end
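The data term kept above scores each dependent variable's observations under a Gaussian centred on the network prediction; a hedged sketch of a single `logpdf` contribution with made-up data (not the package API):

```julia
using Distributions, LinearAlgebra

pred  = rand(10)                    # stand-in for Φ[i](coords', params)[1, :]
obs   = pred .+ 0.05 .* randn(10)   # stand-in for dataset[i][:, 1]
l2std = 0.05

# Isotropic Gaussian with standard deviation l2std around the prediction;
# abs2 squares the std to build the diagonal covariance, as in the diff.
ll = logpdf(MvNormal(pred, Diagonal(abs2.(fill(l2std, 10)))), obs)
```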

# priors for NN parameters + ODE constants
@@ -182,10 +153,9 @@ function priorlogpdf(Tar::PDELogTargetDensity, θ)

return (invlogpdf
+
logpdf(nnwparams, θ[1:(length(θ) - Tar.extraparams)]))
else
return logpdf(nnwparams, θ)
logpdf(nnwparams, θ[1:(length(θ) - Tar.extraparams)]))
end
return logpdf(nnwparams, θ)
end
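A short sketch of how the simplified prior combines a weight prior with the inverse-problem parameter priors (assumed shapes; the actual `Tar.priors` layout is not shown in this hunk):

```julia
using Distributions, LinearAlgebra

θ = randn(12)                        # 10 NN weights followed by 2 inverse-problem parameters
extraparams = 2
nnwparams = MvNormal(zeros(10), Diagonal(fill(3.0^2, 10)))   # assumed weight prior
invpriors = [Normal(6.5, 0.5), Normal(-3.0, 0.5)]            # assumed parameter priors

invlogpdf = sum(logpdf(invpriors[i], θ[end - extraparams + i]) for i in 1:extraparams)
lp = invlogpdf + logpdf(nnwparams, θ[1:(length(θ) - extraparams)])
```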

function integratorchoice(Integratorkwargs, initial_ϵ)
@@ -243,56 +213,34 @@ function inference(samples, pinnrep, saveats, numensemble, ℓπ)
for i in (nnparams + 1):(nnparams + ninv)]
end

# names is an indicator of type of chain
if names[1] != 1
# getting parameter ranges in case of Lux chains
Luxparams = []
i = 0
for x in names
len = length(initial_nnθ[x])
push!(Luxparams, (i + 1):(i + len))
i += len
end

# convert to format directly usable by lux
estimatedLuxparams = [vector_to_parameters(estimnnparams[Luxparams[i]],
initial_nnθ[names[i]]) for i in eachindex(phi)]

# infer predictions(preds) each row - NN, each col - ith sample
samplesn = reduce(hcat, samples)
preds = []
for j in eachindex(phi)
push!(preds,
[phi[j](timepoints[j],
vector_to_parameters(samplesn[:, i][Luxparams[j]],
initial_nnθ[names[j]])) for i in 1:numensemble])
end

# note here no of samples refers to numensemble and points is the no of points in each dep_vars discretization
# each phi will give output in single domain of depvar(so we have each row as a vector of vector outputs)
# so we get after reduce a single matrix of n rows(samples), and j cols(points)
ensemblecurves = [Particles(reduce(vcat, preds[i])) for i in eachindex(phi)]

return ensemblecurves, estimatedLuxparams, estimated_params, timepoints
else
# get intervals for parameters corresponding to flux chains
Fluxparams = names[2]

# convert to format directly usable by Flux
estimatedFluxparams = [estimnnparams[Fluxparams[i]] for i in eachindex(phi)]

# infer predictions(preds) each row - NN, each col - ith sample
samplesn = reduce(hcat, samples)
preds = []
for j in eachindex(phi)
push!(preds,
[phi[j](timepoints[j], samplesn[:, i][Fluxparams[j]]) for i in 1:numensemble])
end

ensemblecurves = [Particles(reduce(vcat, preds[i])) for i in eachindex(phi)]
# getting parameter ranges in case of Lux chains
Luxparams = []
i = 0
for x in names
len = length(initial_nnθ[x])
push!(Luxparams, (i + 1):(i + len))
i += len
end

return ensemblecurves, estimatedFluxparams, estimated_params, timepoints
# convert to format directly usable by lux
estimatedLuxparams = [vector_to_parameters(estimnnparams[Luxparams[i]],
initial_nnθ[names[i]]) for i in eachindex(phi)]

# infer predictions(preds) each row - NN, each col - ith sample
samplesn = reduce(hcat, samples)
preds = []
for j in eachindex(phi)
push!(preds,
[phi[j](timepoints[j],
vector_to_parameters(samplesn[:, i][Luxparams[j]],
initial_nnθ[names[j]])) for i in 1:numensemble])
end

# note here no of samples refers to numensemble and points is the no of points in each dep_vars discretization
# each phi will give output in single domain of depvar(so we have each row as a vector of vector outputs)
# so we get after reduce a single matrix of n rows(samples), and j cols(points)
ensemblecurves = [Particles(reduce(vcat, preds[i])) for i in eachindex(phi)]
return ensemblecurves, estimatedLuxparams, estimated_params, timepoints
end
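The ensemble curves returned above rely on MonteCarloMeasurements stacking the per-sample predictions; a minimal sketch of that step with hypothetical sizes:

```julia
using MonteCarloMeasurements

# For one network: 100 posterior samples, each a 1×25 row of predictions over
# the time grid. vcat stacks them into a 100×25 matrix and Particles turns
# each column into an uncertainty-carrying value, one per grid point.
preds_one_net = [rand(1, 25) for _ in 1:100]
curve = Particles(reduce(vcat, preds_one_net))
length(curve)  # 25
```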

"""
@@ -396,20 +344,7 @@ function ahmc_bayesian_pinn_pde(pde_system, discretization;
# contains only NN parameters
initial_nnθ = pinnrep.init_params

if (discretization.multioutput && chain[1] isa Lux.AbstractExplicitLayer)
# converting vector of parameters to ComponentArray for runtimegenerated functions
names = ntuple(i -> pinnrep.depvars[i], length(chain))
else
# Flux multioutput
i = 0
temp = []
for j in eachindex(initial_nnθ)
len = length(initial_nnθ[j])
push!(temp, (i + 1):(i + len))
i += len
end
names = tuple(1, temp)
end
names = ntuple(i -> pinnrep.depvars[i], length(chain))

#ode parameter estimation
nparameters = length(initial_θ)
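With the Flux branch gone, `names` is always the tuple of dependent-variable symbols, which later keys the sampled parameters; a tiny sketch with hypothetical depvars:

```julia
depvars = (:u, :v)                  # stand-in for pinnrep.depvars
chain   = (nothing, nothing)        # stand-in for the vector of Lux chains
names   = ntuple(i -> depvars[i], length(chain))     # (:u, :v)

# Downstream, the per-depvar parameters are grouped under these names:
grouped = NamedTuple{names}((rand(3), rand(5)))      # (u = ..., v = ...)
```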
32 changes: 14 additions & 18 deletions src/adaptive_losses.jl
@@ -11,7 +11,6 @@ function vectorify(x, t::Type{T}) where {T <: Real}
end

# Dispatches

"""
```julia
NonAdaptiveLoss{T}(; pde_loss_weights = 1,
@@ -159,8 +158,8 @@ end
"""
```julia
function MiniMaxAdaptiveLoss(reweight_every;
pde_max_optimiser = Flux.ADAM(1e-4),
bc_max_optimiser = Flux.ADAM(0.5),
pde_max_optimiser = OptimizationOptimisers.Adam(1e-4),
bc_max_optimiser = OptimizationOptimisers.Adam(0.5),
pde_loss_weights = 1,
bc_loss_weights = 1,
additional_loss_weights = 1)
Expand All @@ -178,9 +177,9 @@ where loss functions that have not been satisfied get a greater weight,
## Keyword Arguments
* `pde_max_optimiser`: a Flux.Optimise.AbstractOptimiser that is used internally to
* `pde_max_optimiser`: an OptimizationOptimisers optimiser that is used internally to
maximize the weights of the PDE loss functions.
* `bc_max_optimiser`: a Flux.Optimise.AbstractOptimiser that is used internally to maximize
* `bc_max_optimiser`: an OptimizationOptimisers optimiser that is used internally to maximize
the weights of the BC loss functions.
## References
@@ -190,8 +189,8 @@ Levi McClenny, Ulisses Braga-Neto
https://arxiv.org/abs/2009.04544
"""
mutable struct MiniMaxAdaptiveLoss{T <: Real,
PDE_OPT <: Flux.Optimise.AbstractOptimiser,
BC_OPT <: Flux.Optimise.AbstractOptimiser} <:
PDE_OPT,
BC_OPT} <:
AbstractAdaptiveLoss
reweight_every::Int64
pde_max_optimiser::PDE_OPT
@@ -201,17 +200,15 @@ mutable struct MiniMaxAdaptiveLoss{T <: Real,
additional_loss_weights::Vector{T}
SciMLBase.@add_kwonly function MiniMaxAdaptiveLoss{T,
PDE_OPT, BC_OPT}(reweight_every;
pde_max_optimiser = Flux.ADAM(1e-4),
bc_max_optimiser = Flux.ADAM(0.5),
pde_max_optimiser = OptimizationOptimisers.Adam(1e-4),
bc_max_optimiser = OptimizationOptimisers.Adam(0.5),
pde_loss_weights = 1,
bc_loss_weights = 1,
additional_loss_weights = 1) where {
T <:
Real,
PDE_OPT <:
Flux.Optimise.AbstractOptimiser,
BC_OPT <:
Flux.Optimise.AbstractOptimiser
PDE_OPT,
BC_OPT
}
new(convert(Int64, reweight_every), convert(PDE_OPT, pde_max_optimiser),
convert(BC_OPT, bc_max_optimiser),
@@ -222,8 +219,8 @@ end

# default to Float64, ADAM, ADAM
SciMLBase.@add_kwonly function MiniMaxAdaptiveLoss(reweight_every;
pde_max_optimiser = Flux.ADAM(1e-4),
bc_max_optimiser = Flux.ADAM(0.5),
pde_max_optimiser = OptimizationOptimisers.Adam(1e-4),
bc_max_optimiser = OptimizationOptimisers.Adam(0.5),
pde_loss_weights = 1,
bc_loss_weights = 1,
additional_loss_weights = 1)
@@ -245,9 +242,8 @@ function generate_adaptive_loss_function(pinnrep::PINNRepresentation,

function run_minimax_adaptive_loss(θ, pde_losses, bc_losses)
if iteration[1] % adaloss.reweight_every == 0
Flux.Optimise.update!(pde_max_optimiser, adaloss.pde_loss_weights,
-pde_losses)
Flux.Optimise.update!(bc_max_optimiser, adaloss.bc_loss_weights, -bc_losses)
OptimizationOptimisers.Optimisers.update(pde_max_optimiser, adaloss.pde_loss_weights, -pde_losses)
OptimizationOptimisers.Optimisers.update(bc_max_optimiser, adaloss.bc_loss_weights, -bc_losses)
logvector(pinnrep.logger, adaloss.pde_loss_weights,
"adaptive_loss/pde_loss_weights", iteration[1])
logvector(pinnrep.logger, adaloss.bc_loss_weights,
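The max-step now goes through Optimisers.jl (re-exported by OptimizationOptimisers) instead of `Flux.Optimise.update!`. A hedged sketch of the usual Optimisers.jl pattern, which pairs each rule with an explicit state from `setup` (this is the general API, not a claim about the exact call in the diff):

```julia
using Optimisers   # the optimiser rules re-exported by OptimizationOptimisers

weights = ones(4)                      # stand-in for adaloss.pde_loss_weights
rule    = Optimisers.Adam(1e-4)
state   = Optimisers.setup(rule, weights)

grad = -ones(4)                        # ascending on the loss = descending on -loss
state, weights = Optimisers.update(state, weights, grad)
```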