Skip to content

Commit

Permalink
chore: clean up shadowing tests
Browse files Browse the repository at this point in the history
  • Loading branch information
DhairyaLGandhi committed Jun 25, 2024
1 parent 939634f commit 6239ae1
Show file tree
Hide file tree
Showing 2 changed files with 6 additions and 5 deletions.
4 changes: 2 additions & 2 deletions src/lss.jl
Original file line number Diff line number Diff line change
Expand Up @@ -405,7 +405,7 @@ function shadow_forward(prob::ForwardLSSProblem, sensealg::ForwardLSS,

# windowing (cos)
ts = current_time(sol)
@. window = (ts - first(ts)) * convert(eltype(Δt), 2 * pi / Δt)
@. window = (ts - ts[1]) * convert(eltype(Δt), 2 * pi / Δt)
@. window = one(eltype(window)) - cos(window)
window ./= sum(window)

Expand Down Expand Up @@ -442,7 +442,7 @@ function shadow_forward(prob::ForwardLSSProblem, sensealg::ForwardLSS,

# windowing cos2
ts = current_time(sol)
@. window = (ts - first(ts)) * convert(eltype(Δt), 2 * pi / Δt)
@. window = (ts - ts[1]) * convert(eltype(Δt), 2 * pi / Δt)
@. window = (one(eltype(window)) - cos(window))^2
window ./= sum(window)

Expand Down
7 changes: 4 additions & 3 deletions src/nilsas.jl
Original file line number Diff line number Diff line change
Expand Up @@ -57,9 +57,10 @@ function NILSASProblem(sol, sensealg::NILSAS, alg;
t = nothing, dgdu_discrete = nothing, dgdp_discrete = nothing,
dgdu_continuous = nothing, dgdp_continuous = nothing, g = sensealg.g,
kwargs...)

@unpack tspan, f = sol.prob
p = parameter_values(sol)
u0 = state_values(sol)
p = parameter_values(sol.prob)
u0 = state_values(sol.prob)
tunables, repack, aliases = canonicalize(Tunable(), p)
@unpack nseg, nstep, rng, adjoint_sensealg, M = sensealg #number of segments on time interval, number of steps saved on each segment

Expand Down Expand Up @@ -303,7 +304,7 @@ function adjoint_sense(prob::NILSASProblem, nilsas::NILSAS, alg; kwargs...)
checkpoints = checkpoints, z0 = z0, M = M, nilss = nilss,
tspan = (t1, t2), kwargs...)
_sol = solve(_prob, alg; save_everystep = false, save_start = false,
saveat = eltype(state_values(sol, 1))[],
saveat = eltype(state_values(sol.prob))[],
dt = dtsave,
tstops = checkpoints,
callback = cb,
Expand Down

0 comments on commit 6239ae1

Please sign in to comment.