Feat: adjoints through observable functions #689
Codecov Report
Attention: Patch coverage is
Additional details and impacted files
@@            Coverage Diff             @@
##           master     #689      +/-   ##
==========================================
- Coverage   31.79%   29.16%   -2.64%
==========================================
  Files          55       55
  Lines        4535     4574      +39
==========================================
- Hits         1442     1334     -108
- Misses       3093     3240     +147
View full report in Codecov by Sentry.
ext/SciMLBaseZygoteExt.jl
Outdated
@adjoint function literal_getproperty(sol::AbstractTimeseriesSolution,
        ::Val{:u})
    function solu_adjoint(Δ)
        zerou = zero(sol.prob.u0)
        _Δ = @. ifelse(Δ === nothing, (zerou,), Δ)
        (build_solution(sol.prob, sol.alg, sol.t, _Δ),)
    end
    sol.u, solu_adjoint
end
# @adjoint function literal_getproperty(sol::AbstractTimeseriesSolution,
#         ::Val{:u})
#     function solu_adjoint(Δ)
#         zerou = zero(sol.prob.u0)
#         _Δ = @. ifelse(Δ === nothing, (zerou,), Δ)
#         (build_solution(sol.prob, sol.alg, sol.t, _Δ),)
#     end
#     sol.u, solu_adjoint
# end
Why is this removed?
It was returning the `ODESolution` as the adjoint. It is also an issue because it short-circuits the gradients through the parameters and replaces them with `sol.prob`, whereas we need to accumulate the gradients here.
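To make the point about accumulation concrete, here is a minimal, dependency-free sketch. It is not the SciMLBase implementation: `ToySolution`, `toy_u_pullback`, and `accum` are hypothetical names invented for this illustration. The pullback returns a plain structural tangent (a `NamedTuple`) instead of a rebuilt solution object, so that contributions from different uses of `sol` can be summed field by field.

```julia
# Toy stand-in for a timeseries solution; NOT the SciMLBase type.
struct ToySolution
    u::Vector{Vector{Float64}}   # saved states, one vector per time point
    p::Vector{Float64}           # parameters
end

# Hypothetical pullback for reading `sol.u`.
# Missing (`nothing`) cotangents are replaced by zeros, and the result is a
# structural tangent rather than a new solution object.
function toy_u_pullback(sol::ToySolution, Δ)
    zerou = zero(first(sol.u))
    Δu = [δ === nothing ? zerou : δ for δ in Δ]
    return (u = Δu, p = nothing)
end

# Field-wise accumulation of tangents; `nothing` means "no contribution".
accum(a::Nothing, b::Nothing) = nothing
accum(a::Nothing, b) = b
accum(a, b::Nothing) = a
accum(a, b) = a .+ b
accum(a::NamedTuple, b::NamedTuple) = map(accum, a, b)

sol = ToySolution([[1.0, 2.0], [3.0, 4.0]], [0.5])

# Contribution from a use of `sol.u` where only the second time point matters.
g1 = toy_u_pullback(sol, [nothing, [1.0, 1.0]])

# Contribution from some other use of `sol` that also touches the parameters.
g2 = (u = [[1.0, 0.0], [0.0, 0.0]], p = [2.0])

g = accum(g1, g2)
```

With a structural tangent, the parameter contribution in `g2` survives the accumulation; wrapping the state cotangents in a fresh solution object would leave no slot to add it to.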
Can you add a unit test in the downstream set which shows this is fine?
Happy to. In fact, that's why I asked if anything was relying on this behavior previously. Could you suggest what kind of test you have in mind?
This seems to be the root cause of many of the test failures? So that means it's caught by the tests already.
I don't think this is what the error is referring to. I am missing a branch https://github.com/DhairyaLGandhi/RecursiveArrayTools.jl/tree/dg/noproj which removes an extra projection rule.
It does refer to projecting to a `VectorOfArray`, and that rule wasn't defined for `Tangent`. Removing it gets us the expected results. If we want to project back to a `VectorOfArray` type, then that needs to be handled elsewhere.
now that we have restored the adjoint, I believe this can be resolved
Co-authored-by: Christopher Rackauckas
Add your unit tests as a new downstream testset.
Note that with SciMLSensitivity.jl#dg/ss (and SciML/SciMLStructures.jl#18), the test at https://github.com/SciML/SciMLSensitivity.jl/blob/32f5ae7529a1957661b153f0ca9eff7e4caf0c5a/test/reversediff_output_types.jl#L14 looks like:
    julia> gs = gradient(u0 -> loss(u0), u0)
    ([-0.7779831009550049, 0.40028226620020263],)
I've added a DAE example in the tests, but switched it off until we get SciMLSensitivity updated as well. The DC motor example fails to initialize currently. If there's a different test case, I can also hook that in.
Project.toml
Outdated
@@ -68,6 +68,7 @@ Logging = "1.10"
Makie = "0.20"
Markdown = "1.10"
ModelingToolkit = "8.75, 9"
ModelingToolkitStandardLibrary = "2.7"
This should be gated in Downstream.
I've added bounds to `test/downstream/Project.toml` in 940ea78. Should I remove anything from the regular test environment, or do I need to declare these in both places?
Remove it from the regular test environment.
d061ce4 does that
@ChrisRackauckas SciMLSensitivity tests pass with d061ce4 (latest commit), but the Core (Downstream) tests get cancelled before anything runs. Is that because the Core (Python) tests fail for unrelated reasons?
So what happens here is:
Both CI / Python and CI / Downgrade seem to be failing on master as well.
The problem I mentioned has not been fixed. It's not a problem with ADTypes per se, it's a problem with environment stacking.
Is there anything left to be done in this PR?
Checklist
contributor guidelines, in particular the SciML Style Guide and COLPRAC.
Additional context
Currently, ADing through observables errors. This PR allows us to AD through the observable function via symbolic indexing, and to accumulate and return gradients against `sol`.
This still needs handling for the case where the observable symbol is in a collection (vector, tuple, ...), and also for other ADs such as ReverseDiff and Enzyme.
Ideally, this would be handled by removing all the adjoints related to `getindex` and letting AD do the heavy lifting for us. But this is faster to implement in its current form.
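As a rough illustration of what a pullback through an observable has to compute, here is a dependency-free sketch. It does not use the SciMLBase or Zygote APIs: `obs`, `observe`, and `observe_pullback` are hypothetical names, and the partial derivatives are written by hand where an AD system would generate them.

```julia
# Toy observable of the state `u` and parameters `p` (hypothetical).
obs(u, p) = p[1] * u[1] + u[2]^2

# Hand-derived partial derivatives of `obs`; an AD system would generate these.
dobs_du(u, p) = [p[1], 2 * u[2]]
dobs_dp(u, p) = [u[1]]

# Forward pass: evaluate the observable at every saved state, mimicking what
# indexing a solution by an observed symbol returns.
observe(us, p) = [obs(u, p) for u in us]

# Pullback: Δ holds one scalar cotangent per saved time point.
# State cotangents stay per time point; parameter cotangents are summed,
# because the same `p` is used at every time point.
function observe_pullback(us, p, Δ)
    Δus = [δ .* dobs_du(u, p) for (u, δ) in zip(us, Δ)]
    Δp = sum(δ .* dobs_dp(u, p) for (u, δ) in zip(us, Δ))
    return Δus, Δp
end

us = [[1.0, 2.0], [3.0, 4.0]]
p = [0.5]

y = observe(us, p)
Δus, Δp = observe_pullback(us, p, ones(length(us)))
```

The summation over time points in `Δp` is the accumulation step: every saved state contributes to the gradient with respect to the shared parameters.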