Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

[WIP] ExpectationProblem interface #55

Merged
merged 47 commits into from
Aug 13, 2022
Merged
Show file tree
Hide file tree
Changes from 37 commits
Commits
Show all changes
47 commits
Select commit Hold shift + click to select a range
ee274b8
add test project toml
Apr 20, 2021
78dac26
initial type stable expectation prototype
Apr 20, 2021
b1d2139
type stable forwarddiff.gradient
Apr 20, 2021
87283d4
remove boxed counter hack in favor of type annotation. Make non-mutat…
Apr 20, 2021
0bcc3ca
add various zygote adjoints
Apr 23, 2021
4cf45df
seperated integrand function so callable outside DEU. Added nout and …
Apr 28, 2021
86aa308
add support for old interface
Apr 28, 2021
a612d20
general cleanup
Apr 28, 2021
e7273e2
add interface transform tests for old interface
Apr 29, 2021
2971537
add missing system dynamics
Apr 29, 2021
0c6b810
add JointPdf interface tests
Apr 29, 2021
677c5ab
add new expectation interface calling JointPdf dispatch and add new i…
Apr 29, 2021
89485ab
add missing file
Apr 29, 2021
2ca17cd
remove unneeded type information
Apr 30, 2021
2062cf9
refactor myintegrate
Apr 30, 2021
054cd7c
initial structure of ExpectationProblem
Apr 30, 2021
5038752
add convenience constructor for independent distributions for Generic…
Apr 30, 2021
66c331c
update GP rand return type
Apr 30, 2021
b7ebe8a
initial working ExpectationProblem
May 5, 2021
7365134
clean up and implementation of Monte Carlo
May 5, 2021
48999eb
remove OBE tests, add GenericDistribution and SystemMap tests
May 6, 2021
d5d5a64
add ExpectationProblem interface tests
May 6, 2021
17af0ff
reorganized tests
May 6, 2021
e741324
add nout to ExpectationProblem
May 6, 2021
6b7f279
change general distribution to use SVector instead of tuples. Start …
May 6, 2021
759aa46
Expectation tests passing for all quadrature and nout >= 1
May 6, 2021
1dbe5a0
clean up type signatures
May 7, 2021
9b28bbc
add MonteCarlo and Generic map tests
May 7, 2021
e564fea
update form of observable to g(x,p). Change general maps interface to…
May 7, 2021
cc023a7
add commented test
May 7, 2021
8bfaa89
add solve inferrence tests
May 7, 2021
03bbc40
reorganization and general cleanup. Added some comments
May 7, 2021
d7a5278
update build_integrand and tests
May 7, 2021
61e5779
move build_integrand to make Koopman() available
May 7, 2021
df42709
Merge branch 'master' into agerlach/ExpectationProblem
agerlach May 10, 2021
235b5fd
add batch mode support for DiffEq problems
May 11, 2021
50024a8
Merge branch 'agerlach/ExpectationProblem' of https://github.com/ager…
May 11, 2021
d26129d
bump versions and remove system map
ChrisRackauckas Aug 13, 2022
e553fd6
Format
ChrisRackauckas Aug 13, 2022
8b97077
bring back system map
ChrisRackauckas Aug 13, 2022
b82abdd
remove sm from tests
ChrisRackauckas Aug 13, 2022
2c990d2
add back problem type specialization
ChrisRackauckas Aug 13, 2022
a3bd13a
get tests passing again
ChrisRackauckas Aug 13, 2022
c94c730
output struct
ChrisRackauckas Aug 13, 2022
bd745e6
add differentiation tests
ChrisRackauckas Aug 13, 2022
2f5d90b
Merge branch 'master' into agerlach/ExpectationProblem
ChrisRackauckas Aug 13, 2022
8abd919
tests pass
ChrisRackauckas Aug 13, 2022
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
1 change: 1 addition & 0 deletions .gitignore
Original file line number Diff line number Diff line change
Expand Up @@ -2,3 +2,4 @@
*.jl.*.cov
*.jl.mem
Manifest.toml
tests/Manifest.toml
21 changes: 6 additions & 15 deletions Project.toml
Original file line number Diff line number Diff line change
Expand Up @@ -7,9 +7,15 @@ version = "1.8.0"
DiffEqBase = "2b5f629d-d688-5b77-993f-72d75c75574e"
Distributions = "31c24e10-a181-5473-b8eb-7969acd0382f"
KernelDensity = "5ab0869b-81aa-558d-bb23-cbf5423bbe9b"
LinearAlgebra = "37e2e46d-f89d-539d-b4ee-838fcccc9c8e"
Parameters = "d96e819e-fc66-5662-9728-84c9c7592b0a"
Quadrature = "67601950-bd08-11e9-3c89-fd23fb4432d2"
Random = "9a3f8284-a2c9-5f02-9a11-845980a1fd5c"
RecursiveArrayTools = "731186ca-8d62-57ce-b412-fbd966d074cd"
Reexport = "189a3867-3050-52da-a836-e630ba90ab69"
StaticArrays = "90137ffa-7385-5640-81b9-e52037218182"
Statistics = "10745b16-79ce-11e8-11f9-7d13ad32a3b2"
Zygote = "e88e6eb3-aa80-5325-afca-941959d7151f"

[compat]
DiffEqBase = "6"
Expand All @@ -19,18 +25,3 @@ Quadrature = "0.1, 1.0"
Reexport = "0.2, 1.0"
julia = "1"

[extras]
Cuba = "8a292aeb-7a57-582c-b821-06e4c11590b1"
Cubature = "667455a9-e2ce-5579-9412-b964f529a492"
DiffEqGPU = "071ae1c0-96b5-11e9-1965-c90190d839ea"
DiffEqProblemLibrary = "a077e3f3-b75c-5d7f-a0c6-6bc4c8ec64a9"
DiffEqSensitivity = "41bf760c-e81c-5289-8e54-58b1f1f8abe2"
FiniteDiff = "6a86dc24-6348-571c-b903-95158fe2bd41"
ForwardDiff = "f6369f11-7733-5829-9624-2563aa707210"
LinearAlgebra = "37e2e46d-f89d-539d-b4ee-838fcccc9c8e"
OrdinaryDiffEq = "1dea7af3-3e70-54e6-95c3-0bf5283fa5ed"
Test = "8dfed614-e22c-5e08-85e1-65c5234f0b40"
Zygote = "e88e6eb3-aa80-5325-afca-941959d7151f"

[targets]
test = ["DiffEqProblemLibrary", "Test", "OrdinaryDiffEq", "Cubature", "Cuba", "FiniteDiff", "ForwardDiff", "Zygote", "DiffEqGPU", "DiffEqSensitivity", "LinearAlgebra"]
28 changes: 24 additions & 4 deletions src/DiffEqUncertainty.jl
Original file line number Diff line number Diff line change
@@ -1,18 +1,38 @@
module DiffEqUncertainty

using DiffEqBase, Statistics, Distributions, Reexport
# LinearAlgebra
using DiffEqBase, Statistics, Reexport, RecursiveArrayTools, StaticArrays,
Distributions, KernelDensity, Zygote, LinearAlgebra, Random
using Parameters: @unpack

@reexport using Quadrature
using KernelDensity
import DiffEqBase: solve

include("expectation/system_utils.jl")
include("expectation/distribution_utils.jl")
include("expectation/problem_types.jl")
include("expectation/expectation.jl")

include("probints.jl")
include("koopman.jl")


# Type Piracy, should upstream
Base.eltype(K::UnivariateKDE) = eltype(K.density)
Base.minimum(K::UnivariateKDE) = minimum(K.x)
Base.maximum(K::UnivariateKDE) = maximum(K.x)
Base.extrema(K::UnivariateKDE) = minimum(K), maximum(K)

Base.minimum(d::AbstractMvNormal) = fill(-Inf, length(d))

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

These lines are technically type piracy right? Also, they might not be correct for any AbstractNormal that accepts singular covariance matrices, e.g., Diagonal([0, 1]) does not have infinite support for the first variable.

Copy link
Contributor Author

@agerlach agerlach May 10, 2022

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Yes, the UnivariateKDE parts should be upstreamed. Base.minimum(d::AbstractMvNormal) was upstreamed in Distributions in JuliaStats/Distributions.jl#1319, probably need to double check the PR discussion to make sure a bug wasn't introduced.

Base.maximum(d::AbstractMvNormal) = fill(Inf, length(d))
Base.extrema(d::AbstractMvNormal) = minimum(d), maximum(d)

Base.minimum(d::Product) = minimum.(d.v)

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Copy link
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

I had these here as a stop gap until this PR when through JuliaStats/Distributions.jl#1319

Base.maximum(d::Product) = maximum.(d.v)
Base.extrema(d::Product) = minimum(d), maximum(d)

export ProbIntsUncertainty,AdaptiveProbIntsUncertainty
export expectation, centralmoment, Koopman, MonteCarlo

export Koopman, MonteCarlo, PrefusedAD, PostfusedAD, NonfusedAD
export GenericDistribution, SystemMap, ExpectationProblem, build_integrand

end
27 changes: 27 additions & 0 deletions src/expectation/distribution_utils.jl
Original file line number Diff line number Diff line change
@@ -0,0 +1,27 @@
## GenericDistribution wrapper
# Defines a generic distribution that just wraps functions for pdf function, rand and bounds.
# User can use this for define any arbitray joint pdf
struct GenericDistribution{TF, TRF, TLB, TUB}
pdf_func::TF
rand_func::TRF
lb::TLB
ub::TUB
end

# Forms joint pdf for independent univariate distributions
# included b/c Distributions.jl Product method of mixed distirbutions are type instable
function GenericDistribution(d, ds...)
dists = (d, ds...)
pdf_func(x) = exp(sum(logpdf(f,y) for (f,y) in zip(dists, x)))
rand_func() = [rand(d) for d in dists]
lb = SVector(map(minimum, dists)...)
ub = SVector(map(maximum, dists)...)

GenericDistribution(pdf_func, rand_func, lb, ub)
end

Distributions.pdf(d::GenericDistribution, x) = d.pdf_func(x)
Base.minimum(d::GenericDistribution) = d.lb
Base.maximum(d::GenericDistribution) = d.ub
Base.extrema(d::GenericDistribution) = minimum(d), maximum(d)
Base.rand(d::GenericDistribution) = d.rand_func()
212 changes: 212 additions & 0 deletions src/expectation/expectation.jl
Original file line number Diff line number Diff line change
@@ -0,0 +1,212 @@
abstract type AbstractExpectationADAlgorithm end
struct NonfusedAD <: AbstractExpectationADAlgorithm end
struct PrefusedAD <: AbstractExpectationADAlgorithm
norm_partials::Bool
end
PrefusedAD() = PrefusedAD(true)
struct PostfusedAD <: AbstractExpectationADAlgorithm
norm_partials::Bool
end
PostfusedAD() = PostfusedAD(true)

abstract type AbstractExpectationAlgorithm <: DiffEqBase.DEAlgorithm end
struct Koopman{TS} <:AbstractExpectationAlgorithm where TS<:AbstractExpectationADAlgorithm
sensealg::TS
end
Koopman() = Koopman(NonfusedAD())
struct MonteCarlo <: AbstractExpectationAlgorithm
trajectories::Int
end

# Builds integrand for arbitrary functions
# TODO return in and out of place functions as tuple?
function build_integrand(prob::ExpectationProblem, ::Koopman, ::Val{false})
@unpack g, d = prob
function(x,p)
g(x,p)*pdf(d,x)
end
end

# Builds integrand for DEProblems
function build_integrand(prob::ExpectationProblem{F}, ::Koopman, ::Val{false}) where F<:SystemMap
@unpack S, g, h, d = prob
function(x,p)
uΜ„, pΜ„ = h(x, p.x[1], p.x[2])
g(S(uΜ„,pΜ„), pΜ„)*pdf(d,x)
end
end

function _make_view(x::Union{Vector{T}, Adjoint{T, Vector{T}}}, i) where T
@view x[i]
end

function _make_view(x, i)
@view x[:,i]
end

function build_integrand(prob::ExpectationProblem{F}, ::Koopman, ::Val{true}) where F<:SystemMap
@unpack S, g, h, d = prob

if prob.nout == 1 # TODO fix upstream in quadrature, expected sizes depend on quadrature method is requires different copying based on nout > 1
set_result! = @inline function(dx, sol)
dx[:] .= sol[:]

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

I'm not sure about the semantics of sol[:], but it looks like there are unnecessary allocations from the [:] here.
Since they are marked @inline, I guess performance matters here?

Copy link
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

set_result! ends up formating the solution of an EnsembleSolve to the format required by Quadrature.jl. This is getting called once per batch integrand evaluation. I'm not sure how to clean up the allocation.

Copy link
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

This also highlights an interface issue with Quadrature (which I am probably guilty of creating), i.e. we shouldn't really need separate set_result! for scalar and vector valued integrands.

end
else
set_result! = @inline function (dx, sol)
dx .= reshape(sol[:,:], size(dx))
end
end

prob_func = function (prob, i, repeat, x) # TODO is it better to make prob/output funcs outside of integrand, then call w/ closure?
u0, p = h((_make_view(x,i)), prob.u0, prob.p)
remake(prob, u0=u0, p=p)
end

output_func(sol, i, x) = (g(sol,sol.prob.p)*pdf(d, (_make_view(x,i))), false)

function(dx, x, p) where T
trajectories = size(x,2)
# TODO How to inject ensemble method in solve? currently in SystemMap, but does that make sense?
ensprob = EnsembleProblem(S.prob; output_func= (sol, i) -> output_func(sol, i, x),
prob_func= (prob, i, repeat) -> prob_func(prob, i, repeat, x))
sol = solve(ensprob, S.args...; trajectories = trajectories, S.kwargs...)
set_result!(dx, sol)
nothing
end
end

#TODO need common return type for koopman() and MonteCarlo() solves
# solve expectation problem of generic callable functions via MonteCarlo
function DiffEqBase.solve(exprob::ExpectationProblem, expalg::MonteCarlo)
params = parameters(exprob)
dist = distribution(exprob)
g = observable(exprob)
mean(g(rand(dist), params) for _ ∈ 1:expalg.trajectories)
end

# solve expectation over DEProblem via MonteCarlo
function DiffEqBase.solve(exprob::ExpectationProblem{F}, expalg::MonteCarlo) where F<:SystemMap
d = distribution(exprob)
cov = input_cov(exprob)
S = mapping(exprob)
g = observable(exprob)

prob_func = function (prob, i, repeat)
u0, p = cov(rand(d), prob.u0, prob.p)
remake(prob, u0=u0, p=p)
end

output_func(sol, i) = (g(sol,sol.prob.p), false)

monte_prob = EnsembleProblem(S.prob;
output_func=output_func,
prob_func=prob_func)
sol = solve(monte_prob, S.args...;trajectories=expalg.trajectories,S.kwargs...)
mean(sol.u)
end

# Solve Koopman expectation
function DiffEqBase.solve(prob::ExpectationProblem, expalg::Koopman, args...;
maxiters=1000000,
batch=0,
quadalg=HCubatureJL(),
ireltol=1e-2, iabstol=1e-2,
kwargs...) where {A<:AbstractExpectationADAlgorithm}

integrand = build_integrand(prob, expalg, Val(batch > 1))
lb, ub = extrema(prob.d)

sol = integrate(quadalg, expalg.sensealg, integrand, lb, ub, prob.params;
reltol=ireltol, abstol=iabstol, maxiters=maxiters,
nout = prob.nout, batch = batch,
kwargs...)

return sol
end


# Integrate function to test new Adjoints, will need to roll up to Quadrature.jl
function integrate(quadalg, adalg::AbstractExpectationADAlgorithm, f, lb::TB, ub::TB, p;
nout = 1, batch = 0,
kwargs...) where {TB}
#TODO check batch iip type stability w/ QuadratureProblem{XXXX}
prob = QuadratureProblem{batch > 1}(f,lb,ub,p; nout = nout, batch = batch)
solve(prob, quadalg; kwargs...)
end

# defines adjoint via βˆ«βˆ‚/βˆ‚p f(x,p) dx
Zygote.@adjoint function integrate(quadalg, adalg::NonfusedAD, f::F, lb::T, ub::T, params::P;
nout = 1, batch = 0, norm = norm,
kwargs...) where {F,T,P}

primal = integrate(quadalg, adalg, f, lb, ub, params;
norm = norm, nout = nout, batch = batch,
kwargs...)

function integrate_pullbacks(Ξ”)
function dfdp(x,params)
_,back = Zygote.pullback(p->f(x,p),params)
back(Ξ”)[1]
end
βˆ‚p = integrate(quadalg, adalg, dfdp, lb, ub, params;
norm = norm, nout = nout*length(params), batch = batch,
kwargs...)
# βˆ‚lb = -f(lb,params) #needs correct for dim > 1
# βˆ‚ub = f(ub,params)
return nothing, nothing, nothing, nothing, nothing, βˆ‚p
end
primal, integrate_pullbacks
end

# defines adjoint via ∫[f(x,p; βˆ‚/βˆ‚p f(x,p)] dx, ie it fuses the primal, post the primal calculation
# has flag to only compute quad norm with respect to only the primal in the pull-back. Gives same quadrature points as doing forwarddiff
Zygote.@adjoint function integrate(quadalg, adalg::PostfusedAD, f::F, lb::T, ub::T, params::P;
nout = 1, batch = 0, norm = norm,
kwargs...) where {F,T,P}

primal = integrate(quadalg, adalg, f, lb, ub, params;
norm = norm, nout = nout, batch = batch,
kwargs...)

_norm = adalg.norm_partials ? norm : primalnorm(nout, norm)

function integrate_pullbacks(Ξ”)
function dfdp(x,params)
y, back = Zygote.pullback(p->f(x,p),params)
[y; back(Ξ”)[1]] #TODO need to match proper arrray type? promote_type???
end
βˆ‚p = integrate(quadalg, adalg, dfdp, lb, ub, params;
norm = _norm, nout = nout + nout*length(params), batch = batch,
kwargs...)
return nothing, nothing, nothing, nothing, nothing, @view βˆ‚p[(nout+1):end]
end
primal, integrate_pullbacks
end

# Fuses primal and partials prior to pullback, I doubt this will stick around based on required system evals.
Zygote.@adjoint function integrate(quadalg, adalg::PrefusedAD, f::F, lb::T, ub::T, params::P;
nout = 1, batch = 0, norm = norm,
kwargs...) where {F,T,P}
# from Seth Axen via Slack
# Does not work w/ ArrayPartition unless with following hack
# Base.similar(A::ArrayPartition, ::Type{T}, dims::NTuple{N,Int}) where {T,N} = similar(Array(A), T, dims)
# TODO add ArrayPartition similar fix upstream, see https://github.com/SciML/RecursiveArrayTools.jl/issues/135
βˆ‚f_βˆ‚params(x, params) = only(Zygote.jacobian(p -> f(x, p), params))
f_augmented(x, params) = [f(x, params); βˆ‚f_βˆ‚params(x, params)...] #TODO need to match proper arrray type? promote_type???
_norm = adalg.norm_partials ? norm : primalnorm(nout, norm)

res = integrate(quadalg, adalg, f_augmented, lb, ub, params;
norm = _norm, nout = nout + nout*length(params), batch = batch,
kwargs...)
primal = first(res)
function integrate_pullback(Ξ”y)
βˆ‚params = Ξ”y .* conj.(@view(res[(nout+1):end]))
return nothing, nothing, nothing, nothing, nothing, βˆ‚params
end
primal, integrate_pullback
end

# define norm function based only on primal part of fused integrand
function primalnorm(nout, fnorm)
x->fnorm(@view x[1:nout])
end
41 changes: 41 additions & 0 deletions src/expectation/problem_types.jl
Original file line number Diff line number Diff line change
@@ -0,0 +1,41 @@

abstract type AbstractUncertaintyProblem end

struct ExpectationProblem{TS, TG, TH, TF, TP} <: AbstractUncertaintyProblem
# defines ∫ g(S(h(x,u0,p)))*f(x)dx
# 𝕏 = uncertainty space, π•Œ = Initial condition space, β„™ = model parameter space,
S::TS # mapping, S: π•Œ Γ— β„™ β†’ π•Œ
g::TG # observable(output_func), g: π•Œ Γ— β„™ β†’ β„βΏα΅’α΅˜α΅—
h::TH # cov(input_func), h: 𝕏 Γ— π•Œ Γ— β„™ β†’ π•Œ Γ— β„™
d::TF # distribution, pdf(d,x): 𝕏 β†’ ℝ
params::TP
nout::Int
end

# Constructor for general maps/functions
function ExpectationProblem(g, pdist, params; nout = 1)
h(x,u,p) = x, p
S(x,p) = x
ExpectationProblem(S, g, h, pdist, params, nout)
end

# Constructor for DEProblems
function ExpectationProblem(sm::SystemMap, g, h, d; nout = 1)
ExpectationProblem(sm, g, h, d,
ArrayPartition(deepcopy(sm.prob.u0),deepcopy(sm.prob.p)),
nout)
end

distribution(prob::ExpectationProblem) = prob.d
mapping(prob::ExpectationProblem) = prob.S
observable(prob::ExpectationProblem) = prob.g
input_cov(prob::ExpectationProblem) = prob.h
parameters(prob::ExpectationProblem) = prob.params

##
# struct CentralMomentProblem
# ns::NTuple{Int,N}
# altype::Union{NestedExpectation, BinomialExpansion} #Should rely be in solve
# exp_prob::ExpectationProblem
# end

14 changes: 14 additions & 0 deletions src/expectation/system_utils.jl
Original file line number Diff line number Diff line change
@@ -0,0 +1,14 @@
#Callable wrapper for DE solves. Enables seperation of args/kwargs...
struct SystemMap{DT<:DiffEqBase.DEProblem,A,K}
prob::DT
args::A
kwargs::K
end
SystemMap(prob, args...; kwargs...) = SystemMap(prob, args, kwargs)

function (sm::SystemMap{DT})(u0,p) where DT
prob::DT = remake(sm.prob,
u0 = convert(typeof(sm.prob.u0),u0),
p = convert(typeof(sm.prob.p), p))
solve(prob, sm.args...; sm.kwargs...)
end
Loading