Skip to content

Commit

Permalink
Merge pull request #745 from AstitvaAggarwal/Bpinn_pde
Browse files Browse the repository at this point in the history
BPINN PDE solver
  • Loading branch information
ChrisRackauckas authored Jan 7, 2024
2 parents 9f191f8 + 8008f3a commit 0687aaf
Show file tree
Hide file tree
Showing 13 changed files with 1,267 additions and 111 deletions.
1 change: 1 addition & 0 deletions .github/workflows/CI.yml
Original file line number Diff line number Diff line change
Expand Up @@ -18,6 +18,7 @@ jobs:
matrix:
group:
- ODEBPINN
- PDEBPINN
- NNPDE1
- NNPDE2
- AdaptiveLoss
Expand Down
2 changes: 1 addition & 1 deletion docs/src/tutorials/Lotka_Volterra_BPINNs.md
Original file line number Diff line number Diff line change
Expand Up @@ -108,7 +108,7 @@ plot!(solution, labels = ["true x" "true y"])
We can see the estimated ODE parameters by -

```@example bpinn
sol_pestim.estimated_ode_params
sol_pestim.estimated_de_params
```

We can see it is close to the true values of the parameters.
29 changes: 16 additions & 13 deletions src/BPINN_ode.jl
Original file line number Diff line number Diff line change
Expand Up @@ -148,21 +148,24 @@ end
BPINN Solution contains the original solution from AdvancedHMC.jl sampling(BPINNstats contains fields related to that)
> ensemblesol is the Probabilistic Estimate(MonteCarloMeasurements.jl Particles type) of Ensemble solution from All Neural Network's(made using all sampled parameters) output's.
> estimated_nn_params - Probabilistic Estimate of NN params from sampled weights,biases
> estimated_ode_params - Probabilistic Estimate of ODE params from sampled unknown ode paramters
> estimated_de_params - Probabilistic Estimate of DE params from sampled unknown DE paramters
"""
struct BPINNsolution{O <: BPINNstats, E,
NP <: Vector{<:MonteCarloMeasurements.Particles{<:Float64}},
OP <: Union{Vector{Nothing},
Vector{<:MonteCarloMeasurements.Particles{<:Float64}}}}

struct BPINNsolution{O <: BPINNstats, E, NP, OP, P}
original::O
ensemblesol::E
estimated_nn_params::NP
estimated_ode_params::OP

function BPINNsolution(original, ensemblesol, estimated_nn_params, estimated_ode_params)
estimated_de_params::OP
timepoints::P

function BPINNsolution(original,
ensemblesol,
estimated_nn_params,
estimated_de_params,
timepoints)
new{typeof(original), typeof(ensemblesol), typeof(estimated_nn_params),
typeof(estimated_ode_params)}(original, ensemblesol, estimated_nn_params,
estimated_ode_params)
typeof(estimated_de_params), typeof(timepoints)}(original, ensemblesol, estimated_nn_params,
estimated_de_params, timepoints)
end
end

Expand Down Expand Up @@ -260,14 +263,14 @@ function DiffEqBase.__solve(prob::DiffEqBase.ODEProblem,
end

nnparams = length(θinit)
estimnnparams = [Particles(reduce(hcat, samples)[i, :]) for i in 1:nnparams]
estimnnparams = [Particles(reduce(hcat, samples[(end - numensemble):end])[i, :]) for i in 1:nnparams]

if ninv == 0
estimated_params = [nothing]
else
estimated_params = [Particles(reduce(hcat, samples[(end - ninv + 1):end])[i, :])
estimated_params = [Particles(reduce(hcat, samples[(end - numensemble):end])[i, :])
for i in (nnparams + 1):(nnparams + ninv)]
end

BPINNsolution(fullsolution, ensemblecurves, estimnnparams, estimated_params)
BPINNsolution(fullsolution, ensemblecurves, estimnnparams, estimated_params, t)
end
26 changes: 14 additions & 12 deletions src/NeuralPDE.jl
Original file line number Diff line number Diff line change
Expand Up @@ -52,19 +52,21 @@ include("discretize.jl")
include("neural_adapter.jl")
include("advancedHMC_MCMC.jl")
include("BPINN_ode.jl")
include("PDE_BPINN.jl")

export NNODE, TerminalPDEProblem, NNPDEHan, NNPDENS, NNRODE,
KolmogorovPDEProblem, NNKolmogorov, NNStopping, ParamKolmogorovPDEProblem,
KolmogorovParamDomain, NNParamKolmogorov,
PhysicsInformedNN, discretize,
GridTraining, StochasticTraining, QuadratureTraining, QuasiRandomTraining,
WeightedIntervalTraining,
build_loss_function, get_loss_function,
generate_training_sets, get_variables, get_argument, get_bounds,
get_phi, get_numeric_derivative, get_numeric_integral,
build_symbolic_equation, build_symbolic_loss_function, symbolic_discretize,
AbstractAdaptiveLoss, NonAdaptiveLoss, GradientScaleAdaptiveLoss,
MiniMaxAdaptiveLoss,
LogOptions, ahmc_bayesian_pinn_ode, BNNODE
KolmogorovPDEProblem, NNKolmogorov, NNStopping, ParamKolmogorovPDEProblem,
KolmogorovParamDomain, NNParamKolmogorov,
PhysicsInformedNN, discretize,
GridTraining, StochasticTraining, QuadratureTraining, QuasiRandomTraining,
WeightedIntervalTraining,
build_loss_function, get_loss_function,
generate_training_sets, get_variables, get_argument, get_bounds,
get_phi, get_numeric_derivative, get_numeric_integral,
build_symbolic_equation, build_symbolic_loss_function, symbolic_discretize,
AbstractAdaptiveLoss, NonAdaptiveLoss, GradientScaleAdaptiveLoss,
MiniMaxAdaptiveLoss, LogOptions,
ahmc_bayesian_pinn_ode, BNNODE, ahmc_bayesian_pinn_pde, vector_to_parameters,
BPINNsolution

end # module
Loading

0 comments on commit 0687aaf

Please sign in to comment.