Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

fix: Correct gradients for vector of symbols while indexing #678

Merged
merged 7 commits into from
May 2, 2024

Conversation

DhairyaLGandhi
Copy link
Member

@DhairyaLGandhi DhairyaLGandhi commented Apr 24, 2024

Checklist

  • Appropriate tests were added
  • Any code changes were done in a way that does not break public API
  • All documentation related to code changes were updated
  • The new code follows the
    contributor guidelines, in particular the SciML Style Guide and
    COLPRAC.
  • Any new documentation only uses public API

Additional context

Before this PR, indexing with vector of symbols fails at the check for symbolic variables at

i = symbolic_type(sym) != NotSymbolic() ? variable_index(VA, sym) : sym

and returns

@parameters σ ρ β
@variables x(t) y(t) z(t) w(t)

eqs = [D(D(x)) ~ σ * (y - x),
    D(y) ~ x *- z) - y,
    D(z) ~ x * y - β * z,
    w ~ x + y + z]

@mtkbuild sys = ODESystem(eqs, t)

u0 = [D(x) => 2.0,
    x => 1.0,
    y => 0.0,
    z => 0.0]

p ==> 28.0,
    ρ => 10.0,
    β => 8 / 3]

tspan = (0.0, 100.0)
prob = ODEProblem(sys, u0, tspan, p, jac = true)
sol = solve(prob, Tsit5())
julia> gs2 = Zygote.gradient(sol) do sol
    sum(sum.(sol[[sys.x, sys.y]]))
end
(VectorOfArray{Float64, 2, Vector{Vector{Float64}}}([[0.0, 0.0, 0.0, 0.0], [0.0, 0.0, 0.0, 0.0], [0.0, 0.0, 0.0, 0.0], [0.0, 0.0, 0.0, 0.0], [0.0, 0.0, 0.0, 0.0], [0.0, 0.0, 0.0, 0.0], [0.0, 0.0, 0.0, 0.0], [0.0, 0.0, 0.0, 0.0], [0.0, 0.0, 0.0, 0.0], [0.0, 0.0, 0.0, 0.0]    [0.0, 0.0, 0.0, 0.0], [0.0, 0.0, 0.0, 0.0], [0.0, 0.0, 0.0, 0.0], [0.0, 0.0, 0.0, 0.0], [0.0, 0.0, 0.0, 0.0], [0.0, 0.0, 0.0, 0.0], [0.0, 0.0, 0.0, 0.0], [0.0, 0.0, 0.0, 0.0], [0.0, 0.0, 0.0, 0.0], [0.0, 0.0, 0.0, 0.0]]),)

This PR:

julia> gs2 = Zygote.gradient(sol) do sol
    sum(sum.(sol[[sys.x, sys.y]]))
end
(VectorOfArray{Float64, 2, Vector{Vector{Float64}}}([[0.0, 1.0, 0.0, 1.0], [0.0, 1.0, 0.0, 1.0], [0.0, 1.0, 0.0, 1.0], [0.0, 1.0, 0.0, 1.0], [0.0, 1.0, 0.0, 1.0], [0.0, 1.0, 0.0, 1.0], [0.0, 1.0, 0.0, 1.0], [0.0, 1.0, 0.0, 1.0], [0.0, 1.0, 0.0, 1.0], [0.0, 1.0, 0.0, 1.0]    [0.0, 1.0, 0.0, 1.0], [0.0, 1.0, 0.0, 1.0], [0.0, 1.0, 0.0, 1.0], [0.0, 1.0, 0.0, 1.0], [0.0, 1.0, 0.0, 1.0], [0.0, 1.0, 0.0, 1.0], [0.0, 1.0, 0.0, 1.0], [0.0, 1.0, 0.0, 1.0], [0.0, 1.0, 0.0, 1.0], [0.0, 1.0, 0.0, 1.0]]),)

Requires SciML/RecursiveArrayTools.jl#367

Copy link

codecov bot commented Apr 24, 2024

Codecov Report

Attention: Patch coverage is 0% with 20 lines in your changes are missing coverage. Please review.

Project coverage is 31.64%. Comparing base (1238b2b) to head (071ad2f).
Report is 11 commits behind head on master.

Files Patch % Lines
ext/SciMLBaseZygoteExt.jl 0.00% 18 Missing ⚠️
ext/SciMLBaseChainRulesCoreExt.jl 0.00% 2 Missing ⚠️
Additional details and impacted files
@@            Coverage Diff             @@
##           master     #678      +/-   ##
==========================================
- Coverage   31.72%   31.64%   -0.08%     
==========================================
  Files          55       55              
  Lines        4505     4519      +14     
==========================================
+ Hits         1429     1430       +1     
- Misses       3076     3089      +13     

☔ View full report in Codecov by Sentry.
📢 Have feedback on the report? Share it here.

@ChrisRackauckas
Copy link
Member

Add a test?

@DhairyaLGandhi
Copy link
Member Author

Yes, I wanted to open this to gauge feedback and see breakages downstream.

@ChrisRackauckas
Copy link
Member

No breakages downstream.

@DhairyaLGandhi
Copy link
Member Author

There seem to be a lot of failures because of some rate limits being hit with codecov

@ChrisRackauckas
Copy link
Member

Downstream currently has a failure from the RealInput guesses change. Can you confirm the relevant downstream tests are passing locally for you?

@DhairyaLGandhi
Copy link
Member Author

I do see a failure in a clean env, checking

ERROR: LoadError: MethodError: no method matching _getindex(::ODESolution{…}, ::NotSymbolic, ::Vector{…})

Closest candidates are:
  _getindex(::AbstractVectorOfArray{T, N}, ::NotSymbolic, ::Colon...) where {T<:Number, N}
   @ RecursiveArrayTools ~/arpa/jsmo/RecursiveArrayTools.jl/src/vector_of_array.jl:307
  _getindex(::AbstractVectorOfArray{T, N}, ::NotSymbolic, ::Colon...) where {T, N}
   @ RecursiveArrayTools ~/arpa/jsmo/RecursiveArrayTools.jl/src/vector_of_array.jl:296
  _getindex(::AbstractVectorOfArray{T, N}, ::NotSymbolic, ::AbstractArray{Bool}, Colon...) where {T, N}
   @ RecursiveArrayTools ~/arpa/jsmo/RecursiveArrayTools.jl/src/vector_of_array.jl:313
  ...

@DhairyaLGandhi
Copy link
Member Author

That was a typo, and in the process I also figured that some of the rrules in SciMLBaseChainRulesExt were using old SII syntax, so I updated those too.

The added tests pass for me locally.

@AayushSabharwal
Copy link
Member

This seems to also "just work" for getu:

julia> gxy = getu(prob, [sys.x, sys.y])
julia> gradsii = Zygote.gradient(sol) do sol
           sum(sum.(gxy(sol)))
       end
julia> gradidx = Zygote.gradient(sol) do sol
           sum(sum.(sol[[sys.x, sys.y]]))
       end
julia> gradidx[1] == gradsii[1]
true

@AayushSabharwal
Copy link
Member

Might be because I'm on SciML/SymbolicIndexingInterface.jl#72

@DhairyaLGandhi
Copy link
Member Author

DhairyaLGandhi commented May 2, 2024

"Just works" for getu makes sense with these changes. Thanks for checking as well!

@AnasAbdelR
Copy link

@ChrisRackauckas thoughts on merging?

@ChrisRackauckas ChrisRackauckas merged commit c261a03 into SciML:master May 2, 2024
33 of 42 checks passed
Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment
Labels
None yet
Projects
None yet
Development

Successfully merging this pull request may close these issues.

4 participants