Skip to content

Modifies rateinterval API to accept computed urate. #288

New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Open
wants to merge 2 commits into
base: master
Choose a base branch
from
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
19 changes: 11 additions & 8 deletions src/aggregators/coevolve.jl
Original file line number Diff line number Diff line change
@@ -1,7 +1,7 @@
"""
Queue method. This method handles variable intensity rates.
"""
mutable struct CoevolveJumpAggregation{T, S, F1, F2, RNG, GR, PQ} <:
mutable struct CoevolveJumpAggregation{T, S, F1, F2, F3, RNG, GR, PQ} <:
AbstractSSAJumpAggregator
next_jump::Int # the next jump to execute
prev_jump::Int # the previous jump that was executed
Expand All @@ -18,7 +18,7 @@ mutable struct CoevolveJumpAggregation{T, S, F1, F2, RNG, GR, PQ} <:
pq::PQ # priority queue of next time
lrates::F1 # vector of rate lower bound functions
urates::F1 # vector of rate upper bound functions
rateintervals::F1 # vector of interval length functions
rateintervals::F3 # vector of interval length functions
haslratevec::Vector{Bool} # vector of whether an lrate was provided for this vrj
end

Expand Down Expand Up @@ -46,7 +46,7 @@ function CoevolveJumpAggregation(nj::Int, njt::T, et::T, crs::Vector{T}, sr::Not
end

pq = MutableBinaryMinHeap{T}()
CoevolveJumpAggregation{T, S, F1, F2, RNG, typeof(dg),
CoevolveJumpAggregation{T, S, F1, F2, typeof(rateintervals), RNG, typeof(dg),
typeof(pq)}(nj, nj, njt, et, crs, sr, maj, rs, affs!, sps, rng,
dg, pq, lrates, urates, rateintervals, haslratevec)
end
Expand All @@ -58,14 +58,17 @@ function aggregate(aggregator::Coevolve, u, p, t, end_time, constant_jumps,
AffectWrapper = FunctionWrappers.FunctionWrapper{Nothing, Tuple{Any}}
RateWrapper = FunctionWrappers.FunctionWrapper{typeof(t),
Tuple{typeof(u), typeof(p), typeof(t)}}
RateIntervalWrapper = FunctionWrappers.FunctionWrapper{typeof(t),
Tuple{typeof(u), typeof(p),
typeof(t), typeof(t)}}

ncrjs = (constant_jumps === nothing) ? 0 : length(constant_jumps)
nvrjs = (variable_jumps === nothing) ? 0 : length(variable_jumps)
nrjs = ncrjs + nvrjs
affects! = Vector{AffectWrapper}(undef, nrjs)
rates = Vector{RateWrapper}(undef, nvrjs)
lrates = similar(rates)
rateintervals = similar(rates)
rateintervals = Vector{RateIntervalWrapper}(undef, nvrjs)
urates = Vector{RateWrapper}(undef, nrjs)
haslratevec = zeros(Bool, nvrjs)

Expand All @@ -84,7 +87,7 @@ function aggregate(aggregator::Coevolve, u, p, t, end_time, constant_jumps,
urates[idx] = RateWrapper(vrj.urate)
idx += 1
rates[i] = RateWrapper(vrj.rate)
rateintervals[i] = RateWrapper(vrj.rateinterval)
rateintervals[i] = RateIntervalWrapper(vrj.rateinterval)
haslratevec[i] = haslrate(vrj)
lrates[i] = haslratevec[i] ? RateWrapper(vrj.lrate) : RateWrapper(nullrate)
end
Expand Down Expand Up @@ -143,8 +146,8 @@ end
@inbounds return p.urates[uidx](u, params, t)
end

@inline function get_rateinterval(p::CoevolveJumpAggregation, lidx, u, params, t)
@inbounds return p.rateintervals[lidx](u, params, t)
@inline function get_rateinterval(p::CoevolveJumpAggregation, lidx, u, params, t, urate)
@inbounds return p.rateintervals[lidx](u, params, t, urate)
end

@inline function get_lrate(p::CoevolveJumpAggregation, lidx, u, params, t)
Expand All @@ -171,7 +174,7 @@ function next_time(p::CoevolveJumpAggregation{T}, u, params, t, i, tstop::T) whe
_t = t + s
if lidx > 0
while t < tstop
rateinterval = get_rateinterval(p, lidx, u, params, t)
rateinterval = get_rateinterval(p, lidx, u, params, t, urate)
if s > rateinterval
t = t + rateinterval
urate = get_urate(p, uidx, u, params, t)
Expand Down
12 changes: 3 additions & 9 deletions test/hawkes_test.jl
Original file line number Diff line number Diff line change
Expand Up @@ -37,17 +37,11 @@ function hawkes_jump(i::Int, g, h; uselrate = true)
urate = rate
if uselrate
lrate(u, p, t) = p[1]
rateinterval = (u, p, t) -> begin
_lrate = lrate(u, p, t)
_urate = urate(u, p, t)
return _urate == _lrate ? typemax(t) : 1 / (2 * _urate)
end
rateinterval = (u, p, t, urate) -> begin return urate == p[1] ? typemax(t) :
1 / (2 * urate) end
else
lrate = nothing
rateinterval = (u, p, t) -> begin
_urate = urate(u, p, t)
return 1 / (2 * _urate)
end
rateinterval = (u, p, t, urate) -> begin return 1 / (2 * urate) end
end
function affect!(integrator)
push!(h[i], integrator.t)
Expand Down