Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Improve Robustness to RATv3 #963

Closed
wants to merge 9 commits into from
Closed

Improve Robustness to RATv3 #963

wants to merge 9 commits into from

Conversation

ChrisRackauckas
Copy link
Member

Since AbstractVectorOfArray should be handled properly given the deprecation warning

@AayushSabharwal
Copy link
Member

AayushSabharwal commented Dec 29, 2023

So the issue with this is the following (so far as I can determine):
In this line of Zygote, it maps over a VectorOfArray y∂b. The code expects that iterating over it will yield items of type eltype(y∂b), but VoA doesn't work like that and yields the subarrays instead, which breaks the code

EDIT: so fixing this and reusing the Zygote.unbroadcast(::AbstractArray, ...) dispatch for ::AbstractVectorOfArray fixes this issue, but there's another error further in

@AayushSabharwal
Copy link
Member

AayushSabharwal commented Dec 29, 2023

The "fix" I have is also not great. There's the underlying issue that sometimes Zygote turns VectorOfArrays into Refs containing a NamedTuple. There's a hack around this in RAT's ZygoteExt, but it crops up here too.

While stepping through Zygote:

1|julia> ȳ
Base.RefValue{Any}((u = [[1.0, 0.0, 4.0, 0.0], [1.0, 0.0, 4.0, 0.0]],))

This should be a VectorOfArray

@ChrisRackauckas
Copy link
Member Author

Oh that means that the constructor of a VectorOfArray is missing a chain rule.

@AayushSabharwal
Copy link
Member

Could it also be that the Array(::VectorOfArray) functions are missing an adjoint? Adding one for that fixed it for me

@ChrisRackauckas
Copy link
Member Author

Ahh that could do it.

@AayushSabharwal
Copy link
Member

I think we also need a method for this:

ERROR: MethodError: no method matching (::ChainRulesCore.ProjectTo{VectorOfArray, NamedTuple{(), Tuple{}}})(::Vector{Float64})

@AayushSabharwal
Copy link
Member

I don't fully understand what Zygote.unbroadcast is supposed to do. I hacked together something that works specifically for VectorOfArray, but it doesn't work for ODESolution

Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment
Labels
None yet
Projects
None yet
Development

Successfully merging this pull request may close these issues.

2 participants