-
-
Notifications
You must be signed in to change notification settings - Fork 104
New issue
Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.
By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.
Already on GitHub? Sign in to your account
fix: Correct gradients for vector of symbols while indexing #678
Conversation
Codecov ReportAttention: Patch coverage is
Additional details and impacted files@@ Coverage Diff @@
## master #678 +/- ##
==========================================
- Coverage 31.72% 31.64% -0.08%
==========================================
Files 55 55
Lines 4505 4519 +14
==========================================
+ Hits 1429 1430 +1
- Misses 3076 3089 +13 ☔ View full report in Codecov by Sentry. |
Add a test? |
Yes, I wanted to open this to gauge feedback and see breakages downstream. |
No breakages downstream. |
There seem to be a lot of failures because of some rate limits being hit with codecov |
Downstream currently has a failure from the RealInput guesses change. Can you confirm the relevant downstream tests are passing locally for you? |
I do see a failure in a clean env, checking ERROR: LoadError: MethodError: no method matching _getindex(::ODESolution{…}, ::NotSymbolic, ::Vector{…})
Closest candidates are:
_getindex(::AbstractVectorOfArray{T, N}, ::NotSymbolic, ::Colon...) where {T<:Number, N}
@ RecursiveArrayTools ~/arpa/jsmo/RecursiveArrayTools.jl/src/vector_of_array.jl:307
_getindex(::AbstractVectorOfArray{T, N}, ::NotSymbolic, ::Colon...) where {T, N}
@ RecursiveArrayTools ~/arpa/jsmo/RecursiveArrayTools.jl/src/vector_of_array.jl:296
_getindex(::AbstractVectorOfArray{T, N}, ::NotSymbolic, ::AbstractArray{Bool}, Colon...) where {T, N}
@ RecursiveArrayTools ~/arpa/jsmo/RecursiveArrayTools.jl/src/vector_of_array.jl:313
... |
That was a typo, and in the process I also figured that some of the The added tests pass for me locally. |
This seems to also "just work" for julia> gxy = getu(prob, [sys.x, sys.y])
julia> gradsii = Zygote.gradient(sol) do sol
sum(sum.(gxy(sol)))
end
julia> gradidx = Zygote.gradient(sol) do sol
sum(sum.(sol[[sys.x, sys.y]]))
end
julia> gradidx[1] == gradsii[1]
true |
Might be because I'm on SciML/SymbolicIndexingInterface.jl#72 |
"Just works" for getu makes sense with these changes. Thanks for checking as well! |
@ChrisRackauckas thoughts on merging? |
Checklist
contributor guidelines, in particular the SciML Style Guide and
COLPRAC.
Additional context
Before this PR, indexing with vector of symbols fails at the check for symbolic variables at
SciMLBase.jl/ext/SciMLBaseZygoteExt.jl
Line 111 in 9d87ca0
and returns
This PR:
Requires SciML/RecursiveArrayTools.jl#367