Skip to content

Commit

Permalink
Merge pull request #12 from lxvm/fixes
Browse files Browse the repository at this point in the history
update tests
  • Loading branch information
ArnoStrouwen authored Feb 23, 2024
2 parents 05b61b2 + 3202674 commit 7336c22
Show file tree
Hide file tree
Showing 3 changed files with 12 additions and 12 deletions.
8 changes: 4 additions & 4 deletions src/expectation.jl
Original file line number Diff line number Diff line change
Expand Up @@ -63,14 +63,14 @@ function build_integrand(prob::ExpectationProblem{F}, ::Koopman, mid, p,
batch::Integer) where {F <: SystemMap}
@unpack S, g, h, d = prob

prob_func = function (prob, i, repeat, x) # TODO is it better to make prob/output funcs outside of integrand, then call w/ closure?
u0, p = h((_make_view(x, i)), prob.u0, prob.p)
remake(prob, u0 = u0, p = p)
function prob_func(prob, i, repeat, x) # TODO is it better to make prob/output funcs outside of integrand, then call w/ closure?
u0, _p = h((_make_view(x, i)), prob.u0, prob.p)
remake(prob, u0 = u0, p = _p)
end

output_func(sol, i, x) = (g(sol, sol.prob.p) * pdf(d, (_make_view(x, i))), false)

function integrand_koopman_systemmap_batch(x, p)
function integrand_koopman_systemmap_batch(x, _)
trajectories = size(x)[end]
# TODO How to inject ensemble method in solve? currently in SystemMap, but does that make sense?
ensprob = EnsembleProblem(S.prob; output_func = (sol, i) -> output_func(sol, i, x),
Expand Down
14 changes: 7 additions & 7 deletions test/interface.jl
Original file line number Diff line number Diff line change
Expand Up @@ -54,8 +54,8 @@ end end
DEU.parameters)
dists = (Uniform(1, 2), Uniform(3, 4), truncated(Normal(0, 1), -5, 5))
gd = GenericDistribution(dists...)
x = [mean(d) for d in dists]
@testset "DiffEq" begin
x = [mean(d) for d in dists]
h(x, u, p) = x, p
prob = ODEProblem(eoms[1], u0s[1], tspan, ps[1])
sm = SystemMap(prob, Tsit5(); saveat = 1.0)
Expand All @@ -66,21 +66,21 @@ end end
for foo in getters
@constinferred foo(ep)
end
f = @constinferred build_integrand(ep, Koopman(), Val(false))
f = @constinferred build_integrand(ep, Koopman(), x, DEU.parameters(ep), nothing)
@constinferred f(x, DEU.parameters(ep))

fbatch = @constinferred build_integrand(ep, Koopman(), Val(true))
fbatch = #= @constinferred =# build_integrand(ep, Koopman(), x, DEU.parameters(ep), 10)
y = reshape(repeat(x, outer = 5), :, 5)
dy = similar(y[1, :])
@constinferred fbatch(dy, y, DEU.parameters(ep))

# nout > 1
g2(soln, p) = [soln[1, end], soln[2, end]]
ep = @constinferred ExpectationProblem(sm, g2, h, gd; nout = 2)
f = @constinferred build_integrand(ep, Koopman(), Val(false))
ep = @constinferred ExpectationProblem(sm, g2, h, gd)
f = @constinferred build_integrand(ep, Koopman(), x, DEU.parameters(ep), nothing)
@constinferred f(x, DEU.parameters(ep))

fbatch = @constinferred build_integrand(ep, Koopman(), Val(true))
fbatch = #= @constinferred =# build_integrand(ep, Koopman(), x, DEU.parameters(ep), 10)
y = reshape(repeat(x, outer = 5), :, 5)
dy = similar(y[1:2, :])
@constinferred fbatch(dy, y, DEU.parameters(ep))
Expand All @@ -91,7 +91,7 @@ end end
for foo in getters
@constinferred foo(ep)
end
f = @constinferred build_integrand(ep, Koopman(), Val(false))
f = @constinferred build_integrand(ep, Koopman(), x, DEU.parameters(ep), nothing)
@constinferred f([0.0, 1.0, 2.0], DEU.parameters(ep))
end
end
Expand Down
2 changes: 1 addition & 1 deletion test/processnoise.jl
Original file line number Diff line number Diff line change
Expand Up @@ -15,7 +15,7 @@ prob = SDEProblem(f, g, u0, (0.0, 1.0), noise = W)
sm = ProcessNoiseSystemMap(prob, 8, LambaEM(), abstol = 1e-3, reltol = 1e-3)
cov(x, u, p) = x, p
observed(sol, p) = sol[:, end]
exprob = ExpectationProblem(sm, observed, cov; nout = length(u0))
exprob = ExpectationProblem(sm, observed, cov)
sol1 = solve(exprob, Koopman(), ireltol = 1e-3, iabstol = 1e-3, batch = 64,
quadalg = CubaDivonne())
sol2 = solve(exprob, MonteCarlo(1_000_000))
Expand Down

0 comments on commit 7336c22

Please sign in to comment.