-
Notifications
You must be signed in to change notification settings - Fork 43
New issue
Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.
By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.
Already on GitHub? Sign in to your account
make AD operators scimloperators #212
Conversation
Codecov ReportBase: 0.22% // Head: 82.92% // Increases project coverage by
📣 This organization is not using Codecov’s GitHub App Integration. We recommend you install it so Codecov can continue to function properly for your repositories. Learn more Additional details and impacted files@@ Coverage Diff @@
## master #212 +/- ##
===========================================
+ Coverage 0.22% 82.92% +82.70%
===========================================
Files 15 15
Lines 887 949 +62
===========================================
+ Hits 2 787 +785
+ Misses 885 162 -723
Help us with your feedback. Take ten seconds to tell us how you rate us. Have a feature suggestion? Share it here. ☔ View full report at Codecov. |
@ChrisRackauckas @YingboMa can you take a look? Maybe suggest some tests for |
TODO - add 5 arg |
TODO - write 5-arg mul! tests after SciML/SciMLOperators.jl#150 |
@ChrisRackauckas this should be good. Please take a look |
That looks fine in general, but why is it not just replacing the old ones? Delete the old ones and use their name? |
OrdinaryDiffEq uses |
This should be a drop-in replacement though? |
No. ODE.jl uses the type if W.J !== nothing && !(W.J isa SparseDiffTools.JacVec) I don't have the type JacVec anymore. My functions |
Then we can make this a breaking release v2, with this change and adding the ArrayInterface packages from the README as extension packages. |
test/test_jaches_products.jl
Outdated
#= | ||
ff1 = ODEFunction(lorenz, jac_prototype = JacVec{Float64}(lorenz, u0)) | ||
ff2 = ODEFunction(lorenz, jac_prototype = JacVec{Float64}(lorenz, u0, autodiff=false)) | ||
|
||
for ff in [ff1, ff2] | ||
prob = ODEProblem(ff, u0, tspan) | ||
@test solve(prob, TRBDF2()).retcode == :Success | ||
@test solve(prob, TRBDF2(linsolve = KrylovJL_GMRES())).retcode == :Success | ||
@test solve(prob, Exprb32()).retcode == :Success | ||
@test solve(prob, Rosenbrock23()).retcode == :Success | ||
@test solve(prob, Rosenbrock23(linsolve = KrylovJL_GMRES())).retcode == :Success | ||
end | ||
=# | ||
|
||
# HesVec |
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
?
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
OrdinaryDiffEq fails to precompile with this version of SparseDiffTools because of the changes we have discussed above. So I'm going to remove these tests and put them back in once we make the requisite changes in OrdianryDiffEq (in OrdinaryDiffEq.jl/src/derivative_utils.jl
). Making a note of that here: SciML/SciMLOperators.jl#142
isinplace = static_hasmethod(f, typeof((u, p, t))) | ||
outofplace = static_hasmethod(f, typeof((u, u, p, t))) | ||
|
||
if !(isinplace) & !(outofplace) | ||
error("$f must have signature f(u, p, t), or f(du, u, p, t)") | ||
end | ||
|
||
L = RevModeAutoDiffVecProd(f, u, cache, vecprod, vecprod!; autodiff = autodiff, | ||
isinplace = isinplace, outofplace = outofplace) |
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
The other function operators need this treatment too
vecprod = autodiff ? auto_vecjac : num_vecjac | ||
vecprod! = autodiff ? auto_vecjac! : num_vecjac! |
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
This should also have a choice for Zygote. For stability, JvFormZygote() and such? Use the same one the Jv operators.
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
I'll do this as a follow up, it should be included in the breaking update.
currently, function (L::FwdModeAutoDiffVecProd)(v, p, t)
L.vecprod(L.f, L.u, v)
end As compared to # Interpret the call as df/du' * u
function (L::RevModeAutoDiffVecProd)(v, p, t)
L.vecprod(_u -> L.f(_u, p, t), L.u, v)
end Do we want to enforce the same |
No, these don't have p and t. |
ok, then this is complete. needs to rerun ci once array interface stuff is cleared out |
15b01e5
to
ce0febe
Compare
@@ -7,3 +7,6 @@ function auto_vecjac(f, x, v) | |||
vv, back = Zygote.pullback(f, x) | |||
return vec(back(reshape(v, size(vv)))[1]) | |||
end | |||
|
|||
const ZygoteVecJac = VecJac |
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
what's this for?
Just two small questions are left. After this merges, then we should change the autodiff choices to be non-boolean logic and Zygote an extension package in order to release. |
@ChrisRackauckas about the zygote stuff... You mentioned above that we need Zygote operators for stability. I wrote functions The difference between if autodiff
@assert isdefined(SparseDiffTools, :auto_vecjac) "Please load Zygote with `using Zygote`, or `import Zygote` to use VecJac with `autodiff = true`."
end |
No, I meant that |
We can follow up with the extension package stuff. |
Plan: SciML/SciMLOperators.jl#142