Skip to content

Commit

Permalink
Merge pull request #706 from AayushSabharwal/as/hotfix-hotfix
Browse files Browse the repository at this point in the history
test: fix several downstream test failures
  • Loading branch information
ChrisRackauckas authored Jun 7, 2024
2 parents 6b220f9 + 5112e27 commit ff83ff0
Show file tree
Hide file tree
Showing 2 changed files with 48 additions and 4 deletions.
45 changes: 45 additions & 0 deletions ext/SciMLBaseZygoteExt.jl
Original file line number Diff line number Diff line change
Expand Up @@ -125,6 +125,32 @@ end
VA[sym], ODESolution_getindex_pullback
end

@adjoint function Base.getindex(VA::ODESolution, sym, ::Colon)
function ODESolution_getindex_pullback(Δ)
i = symbolic_type(sym) != NotSymbolic() ? variable_index(VA, sym) : sym
if is_observed(VA, sym)
f = observed(VA, sym)
p = parameter_values(VA)
tunables, _, _ = SciMLStructures.canonicalize(SciMLStructures.Tunable(), p)
u = state_values(VA)
t = current_time(VA)
y, back = Zygote.pullback(u, tunables) do u, tunables
f.(u, Ref(tunables), t)
end
gs = back(Δ)
(gs[1], nothing)
elseif i === nothing
throw(error("Zygote AD of purely-symbolic slicing for observed quantities is not yet supported. Work around this by using `A[sym,i]` to access each element sequentially in the function being differentiated."))
else
@show i
Δ′ = [[i == k ? Δ[j] : zero(x[1]) for k in 1:length(x)]
for (x, j) in zip(VA.u, 1:length(VA))]
(Δ′, nothing)
end
end
VA[sym, :], ODESolution_getindex_pullback
end

function obs_grads(VA, sym, obs_idx, Δ)
y, back = Zygote.pullback(VA) do sol
getindex.(Ref(sol), sym[obs_idx])
Expand Down Expand Up @@ -172,6 +198,25 @@ end
VA[sym], ODESolution_getindex_pullback
end

@adjoint function Base.getindex(
VA::ODESolution{T}, sym::Union{Tuple, AbstractVector}, ::Colon) where {T}
function ODESolution_getindex_pullback(Δ)
sym = sym isa Tuple ? collect(sym) : sym
i = map(x -> symbolic_type(x) != NotSymbolic() ? variable_index(VA, x) : x, sym)

obs_idx = findall(s -> is_observed(VA, s), sym)
not_obs_idx = setdiff(1:length(sym), obs_idx)

gs_obs = obs_grads(VA, sym, isempty(obs_idx) ? nothing : obs_idx, Δ)
gs_not_obs = not_obs_grads(VA, sym, not_obs_idx, i, Δ)

a = Zygote.accum(gs_obs[1], gs_not_obs)

(a, nothing)
end
VA[sym], ODESolution_getindex_pullback
end

@adjoint function ODESolution{T1, T2, T3, T4, T5, T6, T7, T8, T9, T10, T11, T12
}(u,
args...) where {T1, T2, T3, T4, T5, T6, T7, T8,
Expand Down
7 changes: 3 additions & 4 deletions test/downstream/modelingtoolkit_remake.jl
Original file line number Diff line number Diff line change
Expand Up @@ -69,8 +69,9 @@ k = ShiftIndex(t)
# Roundabout method to avoid having to specify values for previous timestep
fn = DiscreteFunction(discsys)
ps = ModelingToolkit.MTKParameters(discsys, p)
discu0 = Dict([u0..., x(k-1) => 0.0, y(k-1) => 0.0, z(k-1) => 0.0])
push!(syss, discsys)
push!(probs, DiscreteProblem(fn, [1.0, 0.0, 0.0, 0.0, 0.0, 0.0], (0, 10), ps))
push!(probs, DiscreteProblem(fn, getindex.((discu0,), unknowns(discsys)), (0, 10), ps))

for (sys, prob) in zip(syss, probs)
@test parameter_values(prob) isa ModelingToolkit.MTKParameters
Expand Down Expand Up @@ -147,9 +148,7 @@ for (sys, prob) in zip(syss, probs)
prob2 = @inferred baseType remake(prob; u0 = [sys.x => 0.5σ + 1], p = [sys.β => 0.5x + 1])
@test ugetter(prob2) [15.0, 0.0, 0.0]
@test pgetter(prob2) [28.0, 8.5, 10.0]
prob2 = @inferred baseType remake(prob; u0 = [:x => 0.5σ + 1], p = [ => 0.5x + 1])
@test ugetter(prob2) [15.0, 0.0, 0.0]
@test pgetter(prob2) [28.0, 8.5, 10.0]
# Not testing `Symbol => expr` since nested substitution doesn't work with that
end

@variables ud(t) xd(t) yd(t)
Expand Down

0 comments on commit ff83ff0

Please sign in to comment.