Skip to content

Commit

Permalink
Merge pull request #231 from lxvm/inftypes
Browse files Browse the repository at this point in the history
Refactor infinity substitution
  • Loading branch information
ChrisRackauckas authored Feb 14, 2024
2 parents 10d9e90 + cd069d8 commit d97ff6f
Show file tree
Hide file tree
Showing 5 changed files with 286 additions and 173 deletions.
2 changes: 1 addition & 1 deletion Project.toml
Original file line number Diff line number Diff line change
Expand Up @@ -42,7 +42,7 @@ Distributions = "0.25.87"
FastGaussQuadrature = "0.5,1"
FiniteDiff = "2.12"
ForwardDiff = "0.10.36"
HCubature = "1.5"
HCubature = "1.5.2"
LinearAlgebra = "1.10"
MCIntegration = "0.4.2"
MonteCarloIntegration = "0.2"
Expand Down
36 changes: 10 additions & 26 deletions ext/IntegralsCubatureExt.jl
Original file line number Diff line number Diff line change
Expand Up @@ -66,42 +66,42 @@ function Integrals.__solvebp_call(prob::IntegralProblem,
end
end
elseif y isa AbstractArray
fsize = size(y)[begin:(end - 1)]
fdim = prod(fsize)
bfsize = size(y)[begin:(end - 1)]
bfdim = prod(bfsize)
if isinplace(prob)
# dx is a Matrix, but to provide a buffer of the same type as y, we make
# would like to make views of a larger buffer, but CubatureJL doesn't set
# a hard limit for max_batch, so we allocate a new buffer with the needed size
f = (x, dx) -> begin
dy = similar(y, fsize..., size(dx, 2))
dy = similar(y, bfsize..., size(dx, 2))
prob.f(dy, x, p)
dx .= reshape(dy, fdim, size(dx, 2))
dx .= reshape(dy, bfdim, size(dx, 2))
end
else
f = (x, dx) -> (dx .= reshape(prob.f(x, p), fdim, size(dx, 2)))
f = (x, dx) -> (dx .= reshape(prob.f(x, p), bfdim, size(dx, 2)))
end
if mid isa Number
if alg isa CubatureJLh
val_, err = Cubature.hquadrature_v(fdim, f, lb, ub;
val_, err = Cubature.hquadrature_v(bfdim, f, lb, ub;
reltol = reltol, abstol = abstol,
maxevals = maxiters, error_norm = alg.error_norm)
else
val_, err = Cubature.pquadrature_v(fdim, f, lb, ub;
val_, err = Cubature.pquadrature_v(bfdim, f, lb, ub;
reltol = reltol, abstol = abstol,
maxevals = maxiters, error_norm = alg.error_norm)
end
else
if alg isa CubatureJLh
val_, err = Cubature.hcubature_v(fdim, f, lb, ub;
val_, err = Cubature.hcubature_v(bfdim, f, lb, ub;
reltol = reltol, abstol = abstol,
maxevals = maxiters, error_norm = alg.error_norm)
else
val_, err = Cubature.pcubature_v(fdim, f, lb, ub;
val_, err = Cubature.pcubature_v(bfdim, f, lb, ub;
reltol = reltol, abstol = abstol,
maxevals = maxiters, error_norm = alg.error_norm)
end
end
val = reshape(val_, fsize...)
val = reshape(val_, bfsize...)
else
error("BatchIntegralFunction integrands must be arrays for Cubature.jl")
end
Expand Down Expand Up @@ -165,22 +165,6 @@ function Integrals.__solvebp_call(prob::IntegralProblem,
error("IntegralFunctions must be scalars or arrays for Cubature.jl")
end
end

#=
nout = prob.nout
if nout == 1
# the output of prob.f could be either scalar or a vector of length 1, however
# the behavior of the output of the integration routine is undefined (could differ
# across algorithms)
# Cubature will output a real number in when called without nout/fdim
if prob.batch == 0
if isinplace(prob)
dx = zeros(eltype(lb), prob.nout)
@@ -181,6 +334,7 @@ function Integrals.__solvebp_call(prob::IntegralProblem,
end
end
end
=#
SciMLBase.build_solution(prob, alg, val, err, retcode = ReturnCode.Success)
end

Expand Down
14 changes: 7 additions & 7 deletions src/Integrals.jl
Original file line number Diff line number Diff line change
Expand Up @@ -98,15 +98,15 @@ function __solvebp_call(cache::IntegralCache, alg::QuadGKJL, sensealg, domain, p
if isinplace(prob)
# quadgk only works with vector buffers. If the buffer is an array, we have to
# turn it into a vector of arrays
u = prob.f.integrand_prototype
f = if u isa AbstractVector
BatchIntegrand((y, x) -> prob.f(y, x, p), similar(u))
bu = prob.f.integrand_prototype
f = if bu isa AbstractVector
BatchIntegrand((y, x) -> prob.f(y, x, p), similar(bu))
else
fsize = size(u)[begin:(end - 1)]
BatchIntegrand{Array{eltype(u),ndims(u)-1}}() do y, x
y_ = similar(u, fsize..., length(y))
fsize = size(bu)[begin:(end - 1)]
BatchIntegrand{Array{eltype(bu),ndims(bu)-1}}() do y, x
y_ = similar(bu, fsize..., length(y))
prob.f(y_, x, p)
map!(collect, y, eachslice(y_; dims=ndims(u)))
map!(collect, y, eachslice(y_; dims=ndims(bu)))
return nothing
end
end
Expand Down
158 changes: 80 additions & 78 deletions src/infinity_handling.jl
Original file line number Diff line number Diff line change
@@ -1,113 +1,116 @@
_oftype(x, y) = oftype(x, y)
_oftype(::SubArray, y) = y

function substitute_bounds(lb, ub)
mid = (lb + ub) / 2 # floating-point promotion
if isinf(lb) && isinf(ub)
lb < 0 || error("Positive infinite lower bound not supported.")
ub > 0 || error("Negative infinite lower bound not supported.")
lb_sub = -one(lb)
ub_sub = one(lb)
lb_sub = flipsign(one(mid), lb)
ub_sub = flipsign(one(mid), ub)
elseif isinf(lb)
lb < 0 || error("Positive infinite lower bound not supported.")
lb_sub = -one(lb)
ub_sub = zero(lb)
lb_sub = flipsign(one(mid), lb)
ub_sub = zero(one(mid))
elseif isinf(ub)
ub > 0 || error("Positive infinite lower bound not supported.")
lb_sub = zero(lb)
ub_sub = one(lb)
lb_sub = zero(one(mid))
ub_sub = flipsign(one(mid), ub)
else
lb_sub = lb
ub_sub = ub
lb_sub = -one(mid)
ub_sub = one(mid)
end
return lb_sub, ub_sub
return lb_sub, ub_sub # unitless
end
function substitute_f(t, p, f, lb::Number, ub::Number)

function substitute_t(t::Number, lb::Number, ub::Number)
u = oneunit(eltype(lb))
# apply correct units
if isinf(lb) && isinf(ub)
return f(t / (1 - t^2), p) * (1 + t^2) / (1 - t^2)^2
den = inv(1 - t^2)
t * den * u, (1 + t^2) * den^2 * u
elseif isinf(lb)
return f(ub + (t / (1 + t)), p) * 1 / ((1 + t)^2)
den = inv(1 - flipsign(t, lb))
ub + t * den * u, den^2 * u
elseif isinf(ub)
return f(lb + (t / (1 - t)), p) * 1 / ((1 - t)^2)
den = inv(1 - flipsign(t, ub))
lb + t * den * u, den^2 * u
else
return f(t, p)
den = (ub - lb) * oftype(t, 0.5)
lb - t * den, den
end
end
function substitute_f_iip(dt, dy, t, p, f, lb::Number, ub::Number)
if isinf(lb) && isinf(ub)
f(dt, t / (1 - t^2), p)
dt .= dy .* ((1 + t^2) / (1 - t^2)^2)
elseif isinf(lb)
return f(ub + (t / (1 + t)), p) * 1 / ((1 + t)^2)
elseif isinf(ub)
return f(lb + (t / (1 - t)), p) * 1 / ((1 - t)^2)
else
return f(t, p)
function substitute_t(t::AbstractVector, lb::AbstractVector, ub::AbstractVector)
x = similar(t, typeof(one(eltype(t))*(first(lb)+first(ub))))
jac = one(eltype(t))
for i in eachindex(lb)
x[i], dj = substitute_t(t[i], lb[i], ub[i])
jac *= dj
end
return _oftype(t, x), jac
end
function substitute_f(t, p, f, lb::AbstractVector, ub::AbstractVector)
x = similar(t)
jac_diag = similar(t)
for i in eachindex(lb)
if isinf(lb[i]) && isinf(ub[i])
x[i] = t[i] / (1 - t[i]^2)
jac_diag[i] = (1 + t[i]^2) / (1 - t[i]^2)^2
elseif isinf(lb[i])
x[i] = ub[i] + (t[i] / (1 + t[i]))
jac_diag[i] = 1 / ((1 + t[i])^2)
elseif isinf(ub[i])
x[i] = lb[i] + (t[i] / (1 - t[i]))
jac_diag[i] = 1 / ((1 - t[i])^2)
else
x[i] = t[i]
jac_diag[i] = one(lb[i])
end

function substitute_f(f::F, t, p, lb, ub) where {F}
x, jac = substitute_t(t, lb, ub)
return f(x, p) * jac
end
function substitute_f(f::F, dt, t, p, lb, ub) where {F}
x, jac = substitute_t(t, lb, ub)
f(dt, x, p)
dt .*= jac
return
end

function substitute_t(t::AbstractVector, lb::Number, ub::Number)
x = similar(t, typeof(one(eltype(t))*(lb+ub)))
jac = similar(x)
for (i, ti) in enumerate(t)
x[i], jac[i] = substitute_t(ti, lb, ub)
end
f(x, p) * prod(jac_diag)
return x, jac
end
function substitute_f_iip(dt, t, p, f, lb, ub)
x = similar(t)
jac_diag = similar(t)
for i in eachindex(lb)
if isinf(lb[i]) && isinf(ub[i])
x[i] = t[i] / (1 - t[i]^2)
jac_diag[i] = (1 + t[i]^2) / (1 - t[i]^2)^2
elseif isinf(lb[i])
x[i] = ub[i] + (t[i] / (1 + t[i]))
jac_diag[i] = 1 / ((1 + t[i])^2)
elseif isinf(ub[i])
x[i] = lb[i] + (t[i] / (1 - t[i]))
jac_diag[i] = 1 / ((1 - t[i])^2)
else
x[i] = t[i]
jac_diag[i] = one(lb[i])
end
function substitute_t(t::AbstractArray, lb::AbstractVector, ub::AbstractVector)
x = similar(t, typeof(one(eltype(t))*(first(lb)+first(ub))))
jac = similar(x, size(t, ndims(t)))
for (i, it) in enumerate(axes(t)[end])
x[axes(x)[begin:end-1]..., i], jac[i] = substitute_t(t[axes(t)[begin:end-1]..., it], lb, ub)
end
return x, jac
end

function substitute_batchf(f::F, t, p, lb, ub) where {F}
x, jac = substitute_t(t, lb, ub)
r = f(x, p)
return r .* reshape(jac, ntuple(d -> d==ndims(r) ? length(jac) : 1, ndims(r)))
end
function substitute_batchf(f::F, dt, t, p, lb, ub) where {F}
x, jac = substitute_t(t, lb, ub)
f(dt, x, p)
dt .*= prod(jac_diag)
for (i, j) in zip(axes(dt)[end], jac)
for idt in CartesianIndices(axes(dt)[begin:end-1])
dt[idt, i] *= j
end
end
return
end

function transformation_if_inf(prob, ::Val{true})
lb, ub = prob.domain
lb, ub = promote(prob.domain...)
f = prob.f
if lb isa Number
lb_sub, ub_sub = substitute_bounds(lb, ub)
else
bounds = substitute_bounds.(lb, ub)
lb_sub = first.(bounds)
ub_sub = last.(bounds)
end
bounds = map(substitute_bounds, lb, ub)
lb_sub = lb isa Number ? first(bounds) : map(first, bounds)
ub_sub = ub isa Number ? last(bounds) : map(last, bounds)
f_sub = if isinplace(prob)
if f isa BatchIntegralFunction
BatchIntegralFunction{true}((dt, t, p) -> substitute_f_iip(dt, t, p, f, lb, ub),
BatchIntegralFunction{true}(let f=f.f; (dt, t, p) -> substitute_batchf(f, dt, t, p, lb, ub); end,
f.integrand_prototype,
max_batch = f.max_batch)
else
IntegralFunction{true}((dt, t, p) -> substitute_f_iip(dt, t, p, f, lb, ub),
IntegralFunction{true}(let f=f.f; (dt, t, p) -> substitute_f(f, dt, t, p, lb, ub); end,
f.integrand_prototype)
end
else
if f isa BatchIntegralFunction
BatchIntegralFunction{false}((t, p) -> substitute_f(t, p, f, lb, ub),
BatchIntegralFunction{false}(let f=f.f; (t, p) -> substitute_batchf(f, t, p, lb, ub); end,
f.integrand_prototype)
else
IntegralFunction{false}((t, p) -> substitute_f(t, p, f, lb, ub),
IntegralFunction{false}(let f=f.f; (t, p) -> substitute_f(f, t, p, lb, ub); end,
f.integrand_prototype)
end
end
Expand All @@ -116,8 +119,7 @@ end

function transformation_if_inf(prob, ::Nothing)
lb, ub = prob.domain
if (lb isa Number && ub isa Number && (ub == Inf || lb == -Inf)) ||
-Inf in lb || Inf in ub
if any(isinf, lb) || any(isinf, ub)
return transformation_if_inf(prob, Val(true))
else
return transformation_if_inf(prob, Val(false))
Expand Down
Loading

0 comments on commit d97ff6f

Please sign in to comment.