diff --git a/docs/Project.toml b/docs/Project.toml
index b3c1ed2088..462d2f4fa2 100644
--- a/docs/Project.toml
+++ b/docs/Project.toml
@@ -1,4 +1,5 @@
 [deps]
+AdvancedHMC = "0bf59076-c3b1-5ca4-86bd-e02cd72cde3d"
 DiffEqBase = "2b5f629d-d688-5b77-993f-72d75c75574e"
 Distributions = "31c24e10-a181-5473-b8eb-7969acd0382f"
 Documenter = "e30172f5-a6a5-5a46-863b-614d45cd2de4"
@@ -6,8 +7,10 @@ DomainSets = "5b8099bc-c8ec-5219-889f-1d9e522a28bf"
 Flux = "587475ba-b771-5e3f-ad9e-33799f191a9c"
 Integrals = "de52edbc-65ea-441a-8357-d3a637375a31"
 IntegralsCubature = "c31f79ba-6e32-46d4-a52f-182a8ac42a54"
+LinearAlgebra = "37e2e46d-f89d-539d-b4ee-838fcccc9c8e"
 Lux = "b2108857-7c20-44ae-9111-449ecde12c47"
 ModelingToolkit = "961ee093-0014-501f-94e3-6117800e7a78"
+MonteCarloMeasurements = "0987c9cc-fe09-11e8-30f0-b96dd679fdca"
 NeuralPDE = "315f7962-48a3-4962-8226-d0f33b1235f0"
 Optimization = "7f7a1694-90dd-40f0-9382-eb1efda571ba"
 OptimizationOptimJL = "36348300-93cb-4f02-beb5-3c3902f8871e"
@@ -20,6 +23,7 @@ Roots = "f2b01f46-fcfa-551c-844a-d8ac1e96c665"
 SpecialFunctions = "276daf66-3868-5448-9aa4-cd146d93841b"
 
 [compat]
+AdvancedHMC = "0.5"
 DiffEqBase = "6.106"
 Documenter = "1"
 DomainSets = "0.6"
@@ -28,6 +32,7 @@ Integrals = "3.3"
 IntegralsCubature = "=0.2.2"
 Lux = "0.4, 0.5"
 ModelingToolkit = "8.33"
+MonteCarloMeasurements = "1"
 NeuralPDE = "5.3"
 Optimization = "3.9"
 OptimizationOptimJL = "0.1"
diff --git a/docs/pages.jl b/docs/pages.jl
index 60df4edd60..8ccf2de6a9 100644
--- a/docs/pages.jl
+++ b/docs/pages.jl
@@ -4,6 +4,7 @@ pages = ["index.md",
         #"examples/nnrode_example.md", # currently incorrect
     ],
     "PDE PINN Tutorials" => Any["Introduction to NeuralPDE for PDEs" => "tutorials/pdesystem.md",
+        "Bayesian PINNs for PDEs" => "tutorials/low_level_2.md",
         "Using GPUs" => "tutorials/gpu.md",
         "Defining Systems of PDEs" => "tutorials/systems.md",
         "Imposing Constraints" => "tutorials/constraints.md",
@@ -21,6 +22,7 @@ pages = ["index.md",
         "examples/nonlinear_hyperbolic.md"],
     "Manual" => Any["manual/ode.md",
         "manual/pinns.md",
+        "manual/bpinns.md",
         "manual/training_strategies.md",
         "manual/adaptive_losses.md",
         "manual/logging.md",
diff --git a/docs/src/manual/bpinns.md b/docs/src/manual/bpinns.md
new file mode 100644
index 0000000000..a5296938ac
--- /dev/null
+++ b/docs/src/manual/bpinns.md
@@ -0,0 +1,22 @@
+# `BayesianPINN` Discretizer for PDESystems
+
+Using the Bayesian PINN solvers, we can solve general nonlinear PDEs and ODEs, and also simultaneously perform parameter estimation on them.
+
+Note: The BPINN PDE solver also works for ODEs defined via the ModelingToolkit interface; see the [ModelingToolkit.jl PDESystem documentation](https://docs.sciml.ai/ModelingToolkit/stable/systems/PDESystem/). A dedicated ODE-specific BPINN solver, [`BNNODE`](https://docs.sciml.ai/NeuralPDE/dev/manual/ode/#NeuralPDE.BNNODE), nonetheless exists and uses `NeuralPDE.ahmc_bayesian_pinn_ode` at a lower level.
+
+## Lower-Level Bayesian PINN Solver Calls for PDEs and ODEs
+
+```@docs
+NeuralPDE.BayesianPINN
+NeuralPDE.ahmc_bayesian_pinn_ode
+NeuralPDE.ahmc_bayesian_pinn_pde
+```
+
+## `symbolic_discretize` for `BayesianPINN` and the lower-level interface
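+
+As with `PhysicsInformedNN`, `symbolic_discretize` can be called on a `BayesianPINN`
+discretizer to inspect the loss functions from which the HMC log-likelihood is assembled,
+before any sampling takes place. A minimal sketch, assuming `pde_system` and `chain` are
+defined as in the Kuramoto–Sivashinsky tutorial (the grid spacings below are illustrative):
+
+```julia
+using NeuralPDE, ModelingToolkit
+
+# Build the Bayesian discretizer and obtain the symbolic PINN representation.
+discretization = BayesianPINN([chain], GridTraining([0.4, 0.2]))
+sym_prob = symbolic_discretize(pde_system, discretization)
+
+# Inspect the individual PDE residual loss functions.
+sym_prob.loss_functions.pde_loss_functions
+```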
+
+```@docs
+SciMLBase.symbolic_discretize(::PDESystem, ::NeuralPDE.AbstractPINN)
+NeuralPDE.BPINNstats
+NeuralPDE.BPINNsolution
+```
+
diff --git a/docs/src/manual/pinns.md b/docs/src/manual/pinns.md
index e1212e78f9..81a77fbcf9 100644
--- a/docs/src/manual/pinns.md
+++ b/docs/src/manual/pinns.md
@@ -29,10 +29,10 @@ NeuralPDE.Phi
 SciMLBase.discretize(::PDESystem, ::NeuralPDE.PhysicsInformedNN)
 ```
 
-## `symbolic_discretize` and the lower-level interface
+## `symbolic_discretize` for `PhysicsInformedNN` and the lower-level interface
 
 ```@docs
-SciMLBase.symbolic_discretize(::PDESystem, ::NeuralPDE.PhysicsInformedNN)
+SciMLBase.symbolic_discretize(::PDESystem, ::NeuralPDE.AbstractPINN)
 NeuralPDE.PINNRepresentation
 NeuralPDE.PINNLossFunctions
 ```
diff --git a/docs/src/tutorials/low_level.md b/docs/src/tutorials/low_level.md
index d60e3163dc..aad1347971 100644
--- a/docs/src/tutorials/low_level.md
+++ b/docs/src/tutorials/low_level.md
@@ -1,4 +1,4 @@
-# Investigating `symbolic_discretize` with the 1-D Burgers' Equation
+# Investigating `symbolic_discretize` with the `PhysicsInformedNN` Discretizer for the 1-D Burgers' Equation
 
 Let's consider the Burgers' equation:
 
diff --git a/docs/src/tutorials/low_level_2.md b/docs/src/tutorials/low_level_2.md
new file mode 100644
index 0000000000..00929ec72f
--- /dev/null
+++ b/docs/src/tutorials/low_level_2.md
@@ -0,0 +1,143 @@
+# Using `ahmc_bayesian_pinn_pde` with the `BayesianPINN` Discretizer for the Kuramoto–Sivashinsky equation
+
+Consider the Kuramoto–Sivashinsky equation:
+
+```math
+∂_t u(x, t) + u(x, t) ∂_x u(x, t) + \alpha ∂^2_x u(x, t) + \beta ∂^3_x u(x, t) + \gamma ∂^4_x u(x, t) = 0 \, ,
+```
+
+where $\alpha = \gamma = 1$ and $\beta = 4$. The exact solution is:
+
+```math
+u_e(x, t) = 11 + 15 \tanh \theta - 15 \tanh^2 \theta - 15 \tanh^3 \theta \, ,
+```
+
+where $\theta = t - x/2$, with initial and boundary conditions:
+
+```math
+\begin{align*}
+    u( x, 0) &= u_e( x, 0) \, ,\\
+    u( 10, t) &= u_e( 10, t) \, ,\\
+    u(-10, t) &= u_e(-10, t) \, ,\\
+∂_x u( 10, t) &= ∂_x u_e( 10, t) \, ,\\
+∂_x u(-10, t) &= ∂_x u_e(-10, t) \, .
+\end{align*}
+```
+
+Here is an example of solving it with Bayesian Physics-Informed Neural Networks, using the `BayesianPINN` discretization with `ahmc_bayesian_pinn_pde`:
+
+```@example low_level_2
+using NeuralPDE, Flux, Lux, ModelingToolkit, LinearAlgebra, AdvancedHMC
+import ModelingToolkit: Interval, infimum, supremum, Distributions
+using Plots, MonteCarloMeasurements
+
+@parameters x, t, α
+@variables u(..)
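+# α is left symbolic: it is the unknown PDE parameter that
+# `ahmc_bayesian_pinn_pde` estimates below from the noisy dataset.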
+Dt = Differential(t)
+Dx = Differential(x)
+Dx2 = Differential(x)^2
+Dx3 = Differential(x)^3
+Dx4 = Differential(x)^4
+
+# α = 1 (true value of the parameter estimated below)
+β = 4
+γ = 1
+eq = Dt(u(x, t)) + u(x, t) * Dx(u(x, t)) + α * Dx2(u(x, t)) + β * Dx3(u(x, t)) + γ * Dx4(u(x, t)) ~ 0
+
+u_analytic(x, t; z = -x / 2 + t) = 11 + 15 * tanh(z) - 15 * tanh(z)^2 - 15 * tanh(z)^3
+du(x, t; z = -x / 2 + t) = 15 / 2 * (tanh(z) + 1) * (3 * tanh(z) - 1) * sech(z)^2
+
+bcs = [u(x, 0) ~ u_analytic(x, 0),
+    u(-10, t) ~ u_analytic(-10, t),
+    u(10, t) ~ u_analytic(10, t),
+    Dx(u(-10, t)) ~ du(-10, t),
+    Dx(u(10, t)) ~ du(10, t)]
+
+# Space and time domains
+domains = [x ∈ Interval(-10.0, 10.0),
+    t ∈ Interval(0.0, 1.0)]
+
+# Discretization
+dx = 0.4
+dt = 0.2
+
+# Function to compute the analytical solution at a specific point (x, t)
+function u_analytic_point(x, t)
+    z = -x / 2 + t
+    return 11 + 15 * tanh(z) - 15 * tanh(z)^2 - 15 * tanh(z)^3
+end
+
+# Function to generate the dataset matrix
+function generate_dataset_matrix(domains, dx, dt)
+    x_values = -10:dx:10
+    t_values = 0.0:dt:1.0
+
+    dataset = []
+
+    for t in t_values
+        for x in x_values
+            u_value = u_analytic_point(x, t)
+            push!(dataset, [u_value, x, t])
+        end
+    end
+
+    return vcat([data' for data in dataset]...)
+end
+
+datasetpde = [generate_dataset_matrix(domains, dx, dt)]
+
+# Add multiplicative noise to the dataset
+noisydataset = deepcopy(datasetpde)
+noisydataset[1][:, 1] = noisydataset[1][:, 1] .+ randn(size(noisydataset[1][:, 1])) .* 5 / 100 .*
+                        noisydataset[1][:, 1]
+```
+
+Plotting the dataset; the added noise is set at 5%:
+
+```@example low_level_2
+plot(datasetpde[1][:, 2], datasetpde[1][:, 1], title="Dataset from Analytical Solution")
+plot!(noisydataset[1][:, 2], noisydataset[1][:, 1])
+```
+
+```@example low_level_2
+# Neural network
+chain = Lux.Chain(Lux.Dense(2, 8, Lux.tanh),
+    Lux.Dense(8, 8, Lux.tanh),
+    Lux.Dense(8, 1))
+
+discretization = NeuralPDE.BayesianPINN([chain],
+    GridTraining([dx, dt]), param_estim = true, dataset = [noisydataset, nothing])
+
+@named pde_system = PDESystem(eq,
+    bcs,
+    domains,
+    [x, t],
+    [u(x, t)],
+    [α],
+    defaults = Dict([α => 0.5]))
+
+sol1 = ahmc_bayesian_pinn_pde(pde_system,
+    discretization;
+    draw_samples = 100, Kernel = AdvancedHMC.NUTS(0.8),
+    bcstd = [0.2, 0.2, 0.2, 0.2, 0.2],
+    phystd = [1.0], l2std = [0.05], param = [Distributions.LogNormal(0.5, 2)],
+    priorsNNw = (0.0, 10.0),
+    saveats = [1 / 100.0, 1 / 100.0], progress = true)
+```
+
+And some analysis:
+
+```@example low_level_2
+phi = discretization.phi[1]
+xs, ts = [infimum(d.domain):dx:supremum(d.domain) for (d, dx) in zip(domains, [dx / 10, dt])]
+u_predict = [[first(pmean(phi([x, t], sol1.estimated_nn_params[1]))) for x in xs]
+             for t in ts]
+u_real = [[u_analytic(x, t) for x in xs] for t in ts]
+diff_u = [[abs(u_analytic(x, t) - first(pmean(phi([x, t], sol1.estimated_nn_params[1]))))
+           for x in xs]
+          for t in ts]
+
+p1 = plot(xs, u_predict, title = "predict")
+p2 = plot(xs, u_real, title = "analytic")
+p3 = plot(xs, diff_u, title = "error")
+plot(p1, p2, p3)
+```
diff --git a/src/BPINN_ode.jl b/src/BPINN_ode.jl
index 7a013371e7..7c4be0f048 100644
--- a/src/BPINN_ode.jl
+++ b/src/BPINN_ode.jl
@@ -26,7 +26,7 @@ of the physics-informed neural network which is used as a solver for a standard
 * `Kernel`: Choice of MCMC Sampling Algorithm. Defaults to `AdvancedHMC.HMC`
 
 ## Keyword Arguments
-(refer ahmc_bayesian_pinn_ode() keyword arguments.)
+(Refer to the `NeuralPDE.ahmc_bayesian_pinn_ode` keyword arguments.)
 
 ## Example
diff --git a/src/NeuralPDE.jl b/src/NeuralPDE.jl
index 6f7dbaf839..7478b088df 100644
--- a/src/NeuralPDE.jl
+++ b/src/NeuralPDE.jl
@@ -67,6 +67,6 @@ export NNODE, TerminalPDEProblem, NNPDEHan, NNPDENS, NNRODE,
        AbstractAdaptiveLoss, NonAdaptiveLoss, GradientScaleAdaptiveLoss,
        MiniMaxAdaptiveLoss, LogOptions,
        ahmc_bayesian_pinn_ode, BNNODE, ahmc_bayesian_pinn_pde, vector_to_parameters,
-       BPINNsolution
+       BPINNsolution, BayesianPINN
 
 end # module
diff --git a/src/PDE_BPINN.jl b/src/PDE_BPINN.jl
index 504e347eeb..b63741e9b6 100644
--- a/src/PDE_BPINN.jl
+++ b/src/PDE_BPINN.jl
@@ -295,7 +295,52 @@ function inference(samples, pinnrep, saveats, numensemble, ℓπ)
     end
 end
 
-# priors: pdf for W,b + pdf for ODE params
+"""
+```julia
+ahmc_bayesian_pinn_pde(pde_system, discretization;
+    draw_samples = 1000,
+    bcstd = [0.01], l2std = [0.05],
+    phystd = [0.05], priorsNNw = (0.0, 2.0),
+    param = [], nchains = 1, Kernel = HMC(0.1, 30),
+    Adaptorkwargs = (Adaptor = StanHMCAdaptor,
+        Metric = DiagEuclideanMetric, targetacceptancerate = 0.8),
+    Integratorkwargs = (Integrator = Leapfrog,), saveats = [1 / 10.0],
+    numensemble = floor(Int, draw_samples / 3), progress = false, verbose = false)
+```
+
+## NOTES
+
+* A dataset is required for accurate parameter estimation while simultaneously solving the equations.
+* The returned solution is a `BPINNsolution` consisting of the ensemble solution and the estimated PDE
+  and NN parameters for the chosen `saveats` grid spacing, computed from the last `numensemble` samples
+  in the chain. The complete set of samples in the MCMC chain is returned as `fullsolution`; refer to
+  `BPINNsolution` for more details.
+
+## Positional Arguments
+
+* `pde_system`: ModelingToolkit-defined PDE equation or system of equations.
+* `discretization`: `BayesianPINN` discretization specifying the neural network and training strategy
+  for the given `pde_system`.
+
+## Keyword Arguments
+
+* `draw_samples`: number of samples to be drawn in the MCMC algorithms (warmup samples are ~2/3 of draw samples).
+* `bcstd`: Vector of standard deviations of the BPINN prediction against the initial/boundary condition equations.
+* `l2std`: Vector of standard deviations of the BPINN prediction against the L2 losses/dataset, one per dependent variable of interest.
+* `phystd`: Vector of standard deviations of the BPINN prediction against the chosen underlying PDE equations.
+* `priorsNNw`: Tuple of (mean, std) for the BPINN network parameters. Weights and biases of the BPINN are Normal distributions by default.
+* `param`: Vector of the distributions of the chosen PDE parameters for inverse problems.
+* `nchains`: number of chains to sample.
+
+# AdvancedHMC.jl is still developing convenience structs, so these may change in new releases.
+* `Kernel`: Choice of MCMC sampling algorithm object: HMC/NUTS/HMCDA (AdvancedHMC.jl implementations).
+* `Adaptorkwargs`: `Adaptor`, `Metric`, `targetacceptancerate`. Refer: https://turinglang.org/AdvancedHMC.jl/stable/
+  Note: `targetacceptancerate` is the target proportion (in decimal) of iterations in which the proposals are accepted (0.8 by default).
+* `Integratorkwargs`: `Integrator`, `jitter_rate`, `tempering_rate`. Refer: https://turinglang.org/AdvancedHMC.jl/stable/
+* `saveats`: Grid spacing for each independent variable for evaluation of the ensemble solution and estimated parameters.
+* `numensemble`: Number of last samples to take for creation of the ensemble solution and estimated parameters.
+* `progress`: controls whether to show the progress meter or not.
+* `verbose`: controls the verbosity.
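+
+## Example
+
+A minimal call sketch mirroring the Kuramoto–Sivashinsky tutorial; `pde_system` and
+`discretization` are assumed to be defined as there (five boundary conditions, one equation,
+and one unknown PDE parameter), and the keyword values are illustrative:
+
+```julia
+sol = ahmc_bayesian_pinn_pde(pde_system, discretization;
+    draw_samples = 100, Kernel = AdvancedHMC.NUTS(0.8),
+    bcstd = [0.2, 0.2, 0.2, 0.2, 0.2], phystd = [1.0], l2std = [0.05],
+    param = [Distributions.LogNormal(0.5, 2)], priorsNNw = (0.0, 10.0),
+    saveats = [1 / 100.0, 1 / 100.0], progress = true)
+```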
+
+(For the remaining sampler arguments, refer to the AdvancedHMC.jl documentation.)
+
+The priors used are the pdfs of the NN parameters (weights and biases) together with the pdfs of the PDE parameters.
+"""
 function ahmc_bayesian_pinn_pde(pde_system, discretization;
     draw_samples = 1000,
     bcstd = [0.01], l2std = [0.05],
     phystd = [0.05], priorsNNw = (0.0, 2.0),
     param = [], nchains = 1, Kernel = HMC(0.1, 30),
     Adaptorkwargs = (Adaptor = StanHMCAdaptor,
         Metric = DiagEuclideanMetric, targetacceptancerate = 0.8),
     Integratorkwargs = (Integrator = Leapfrog,), saveats = [1 / 10.0],
     numensemble = floor(Int, draw_samples / 3), progress = false, verbose = false)
@@ -369,6 +414,7 @@ function ahmc_bayesian_pinn_pde(pde_system, discretization;
     #ode parameter estimation
     nparameters = length(initial_θ)
     ninv = length(param)
+    # priors: pdf for the NN parameters (W, b) + pdf for the PDE parameters
     priors = [
         MvNormal(priorsNNw[1] * ones(nparameters),
             LinearAlgebra.Diagonal(abs2.(priorsNNw[2] .* ones(nparameters)))),
diff --git a/src/advancedHMC_MCMC.jl b/src/advancedHMC_MCMC.jl
index 0a9c569cc4..1a2c47de0d 100644
--- a/src/advancedHMC_MCMC.jl
+++ b/src/advancedHMC_MCMC.jl
@@ -436,40 +436,34 @@ Incase you are only solving the Equations for solution, do not provide dataset
 
 ## Keyword Arguments
 * `strategy`: The training strategy used to choose the points for the evaluations.
   By default GridTraining is used with given physdt discretization.
-* `dataset`: Vector containing Vectors of corresponding u,t values
 * `init_params`: intial parameter values for BPINN (ideally for multiple chains different initializations preferred)
-* `nchains`: number of chains you want to sample (random initialisation of params by default)
+* `nchains`: number of chains you want to sample
 * `draw_samples`: number of samples to be drawn in the MCMC algorithms (warmup samples are ~2/3 of draw samples)
-* `l2std`: standard deviation of BPINN predicition against L2 losses/Dataset
-* `phystd`: standard deviation of BPINN predicition against Chosen Underlying ODE System
-* `priorsNNw`: Vector of [mean, std] for BPINN parameter. Weights and Biases of BPINN are Normal Distributions by default
+* `l2std`: standard deviation of BPINN prediction against L2 losses/Dataset
+* `phystd`: standard deviation of BPINN prediction against Chosen Underlying ODE System
+* `priorsNNw`: Tuple of (mean, std) for BPINN Network parameters. Weights and Biases of BPINN are Normal Distributions by default.
 * `param`: Vector of chosen ODE parameters Distributions in case of Inverse problems.
 * `autodiff`: Boolean Value for choice of Derivative Backend(default is numerical)
 * `physdt`: Timestep for approximating ODE in it's Time domain. (1/20.0 by default)
 
 # AdvancedHMC.jl is still developing convenience structs so might need changes on new releases.
 * `Kernel`: Choice of MCMC Sampling Algorithm (AdvancedHMC.jl implemenations HMC/NUTS/HMCDA)
-* `Integratorkwargs`: A NamedTuple containing the chosen integrator and its keyword Arguments, as follows :
-    * `Integrator`: https://turinglang.org/AdvancedHMC.jl/stable/
-    * `jitter_rate`: https://turinglang.org/AdvancedHMC.jl/stable/
-    * `tempering_rate`: https://turinglang.org/AdvancedHMC.jl/stable/
-* `Adaptorkwargs`: A NamedTuple containing the chosen Adaptor, it's Metric and targetacceptancerate, as follows :
-    * `Adaptor`: https://turinglang.org/AdvancedHMC.jl/stable/
-    * `Metric`: https://turinglang.org/AdvancedHMC.jl/stable/
-    * `targetacceptancerate`: Target percentage(in decimal) of iterations in which the proposals were accepted(0.8 by default)
+* `Integratorkwargs`: `Integrator`, `jitter_rate`, `tempering_rate`. Refer: https://turinglang.org/AdvancedHMC.jl/stable/
+* `Adaptorkwargs`: `Adaptor`, `Metric`, `targetacceptancerate`.
+  Refer: https://turinglang.org/AdvancedHMC.jl/stable/
+  Note: `targetacceptancerate` is the target proportion (in decimal) of iterations in which the proposals are accepted (0.8 by default).
 * `MCMCargs`: A NamedTuple containing all the chosen MCMC kernel's(HMC/NUTS/HMCDA) Arguments, as follows :
     * `n_leapfrog`: number of leapfrog steps for HMC
     * `δ`: target acceptance probability for NUTS and HMCDA
     * `λ`: target trajectory length for HMCDA
     * `max_depth`: Maximum doubling tree depth (NUTS)
     * `Δ_max`: Maximum divergence during doubling tree (NUTS)
+    Refer: https://turinglang.org/AdvancedHMC.jl/stable/
 * `progress`: controls whether to show the progress meter or not.
 * `verbose`: controls the verbosity. (Sample call args in AHMC)
 """
 
 """
-dataset would be (x̂,t)
 priors: pdf for W,b + pdf for ODE params
 """
 function ahmc_bayesian_pinn_ode(prob::DiffEqBase.ODEProblem, chain;
diff --git a/src/discretize.jl b/src/discretize.jl
index 400be5d2c1..c8412b2d15 100644
--- a/src/discretize.jl
+++ b/src/discretize.jl
@@ -389,14 +389,15 @@ end
 
 """
 ```julia
-prob = symbolic_discretize(pde_system::PDESystem, discretization::PhysicsInformedNN)
+prob = symbolic_discretize(pde_system::PDESystem, discretization::AbstractPINN)
 ```
 
 `symbolic_discretize` is the lower level interface to `discretize` for inspecting internals.
 
 It transforms a symbolic description of a ModelingToolkit-defined `PDESystem` into a
 `PINNRepresentation` which holds the pieces required to build an `OptimizationProblem`
-for [Optimization.jl](https://docs.sciml.ai/Optimization/stable) whose solution is the solution
-to the PDE.
+for [Optimization.jl](https://docs.sciml.ai/Optimization/stable), or a likelihood function
+used by HMC-based posterior sampling algorithms from [AdvancedHMC.jl](https://turinglang.org/AdvancedHMC.jl/stable/),
+which is then sampled to give the solution or the solution distribution of the PDE.
 
 For more information, see `discretize` and `PINNRepresentation`.
 """
diff --git a/src/pinn_types.jl b/src/pinn_types.jl
index ea66f725fd..48c8f46da9 100644
--- a/src/pinn_types.jl
+++ b/src/pinn_types.jl
@@ -17,7 +17,7 @@ end
 """This function is defined here as stubs to be overriden by the subpackage NeuralPDELogging if imported"""
 function logvector(logger, v::AbstractVector{R}, name::AbstractString,
-                   step::Integer) where {R <: Real}
+    step::Integer) where {R <: Real}
     nothing
 end
 
@@ -95,17 +95,17 @@ struct PhysicsInformedNN{T, P, PH, DER, PE, AL, ADA, LOG, K} <: AbstractPINN
     kwargs::K
 
     @add_kwonly function PhysicsInformedNN(chain,
-                                           strategy;
-                                           init_params = nothing,
-                                           phi = nothing,
-                                           derivative = nothing,
-                                           param_estim = false,
-                                           additional_loss = nothing,
-                                           adaptive_loss = nothing,
-                                           logger = nothing,
-                                           log_options = LogOptions(),
-                                           iteration = nothing,
-                                           kwargs...)
+        strategy;
+        init_params = nothing,
+        phi = nothing,
+        derivative = nothing,
+        param_estim = false,
+        additional_loss = nothing,
+        adaptive_loss = nothing,
+        logger = nothing,
+        log_options = LogOptions(),
+        iteration = nothing,
+        kwargs...)
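+        # A vector of chains means one neural network per dependent variable (multiple outputs).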
         multioutput = chain isa AbstractArray
 
         if phi === nothing
@@ -134,23 +134,22 @@ struct PhysicsInformedNN{T, P, PH, DER, PE, AL, ADA, LOG, K} <: AbstractPINN
         new{typeof(strategy), typeof(init_params), typeof(_phi), typeof(_derivative),
             typeof(param_estim),
             typeof(additional_loss), typeof(adaptive_loss), typeof(logger), typeof(kwargs)}(chain,
-                      strategy,
-                      init_params,
-                      _phi,
-                      _derivative,
-                      param_estim,
-                      additional_loss,
-                      adaptive_loss,
-                      logger,
-                      log_options,
-                      iteration,
-                      self_increment,
-                      multioutput,
-                      kwargs)
+            strategy,
+            init_params,
+            _phi,
+            _derivative,
+            param_estim,
+            additional_loss,
+            adaptive_loss,
+            logger,
+            log_options,
+            iteration,
+            self_increment,
+            multioutput,
+            kwargs)
     end
 end
-
 """
 ```julia
 BayesianPINN(chain,
     strategy;
     init_params = nothing,
     phi = nothing,
     derivative = nothing,
     param_estim = false,
    additional_loss = nothing,
     adaptive_loss = nothing,
@@ -163,10 +162,15 @@ BayesianPINN(chain,
     logger = nothing,
     log_options = LogOptions(),
     iteration = nothing,
-    dataset=nothing,
+    dataset = nothing,
     kwargs...) where {iip}
 ```
 
+A `discretize` algorithm for the ModelingToolkit PDESystem interface, which transforms a
+`PDESystem` into a likelihood function used by HMC-based posterior sampling algorithms from
+[AdvancedHMC.jl](https://turinglang.org/AdvancedHMC.jl/stable/), which is then sampled to give
+the solution distribution of the PDE, using the Physics-Informed Neural Networks (PINN) methodology.
+
 ## Positional Arguments
 
 * `chain`: a vector of Flux.jl or Lux.jl chains with a d-dimensional input and a
@@ -177,6 +181,9 @@ BayesianPINN(chain,
 
 ## Keyword Arguments
 
+* `dataset`: a vector of matrices, one matrix per dependent variable. The first column of each
+  matrix holds that dependent variable's values, and the remaining columns hold the corresponding
+  independent variables. Needed for inverse problem solving.
 * `init_params`: the initial parameters of the neural networks. This should match the
   specification of the chosen `chain` library. For example, if a Flux.chain is used, then
   `init_params` should match `Flux.destructure(chain)[1]` in shape. If `init_params` is not
@@ -187,7 +194,7 @@ BayesianPINN(chain,
   of the neural network defining `phi`). By default, this is generated from the `chain`. This
   should only be used to more directly impose functional information in the training problem,
   for example imposing the boundary condition by the test function formulation.
-* `adaptive_loss`: the choice for the adaptive loss function. See the
+* `adaptive_loss`: (still a work in progress) the choice for the adaptive loss function. See the
   [adaptive loss page](@ref adaptive_loss) for more details. Defaults to no adaptivity.
 * `additional_loss`: a function `additional_loss(phi, θ, p_)` where `phi` are the neural network
   trial solutions, `θ` are the weights of the neural network(s), and `p_` are the
diff --git a/test/BPINN_PDE_tests.jl b/test/BPINN_PDE_tests.jl
index dc742a7af2..dae9072c38 100644
--- a/test/BPINN_PDE_tests.jl
+++ b/test/BPINN_PDE_tests.jl
@@ -171,7 +171,7 @@ chain = Lux.Chain(Lux.Dense(dim, 9, Lux.σ), Lux.Dense(9, 9, Lux.σ), Lux.Dense(
 
 # Discretization
 dx = 0.05
-discretization=NeuralPDE.BayesianPINN([chain], GridTraining(dx))
+discretization = NeuralPDE.BayesianPINN([chain], GridTraining(dx))
 
 @named pde_system = PDESystem(eq, bcs, domains, [x, y], [u(x, y)])