Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

make AD operators scimloperators #212

Merged
merged 32 commits into from
Feb 19, 2023
Merged

Conversation

vpuri3
Copy link
Contributor

@vpuri3 vpuri3 commented Feb 4, 2023

@codecov-commenter
Copy link

codecov-commenter commented Feb 6, 2023

Codecov Report

Base: 0.22% // Head: 82.92% // Increases project coverage by +82.70% 🎉

Coverage data is based on head (f322ec7) compared to base (a91d1da).
Patch coverage: 84.84% of modified lines in pull request are covered.

📣 This organization is not using Codecov’s GitHub App Integration. We recommend you install it so Codecov can continue to function properly for your repositories. Learn more

Additional details and impacted files
@@             Coverage Diff             @@
##           master     #212       +/-   ##
===========================================
+ Coverage    0.22%   82.92%   +82.70%     
===========================================
  Files          15       15               
  Lines         887      949       +62     
===========================================
+ Hits            2      787      +785     
+ Misses        885      162      -723     
Impacted Files Coverage Δ
src/SparseDiffTools.jl 100.00% <ø> (+60.00%) ⬆️
src/differentiation/vecjac_products.jl 62.74% <73.07%> (+62.74%) ⬆️
src/differentiation/jaches_products_zygote.jl 93.18% <83.33%> (+93.18%) ⬆️
src/differentiation/jaches_products.jl 95.62% <90.74%> (+95.62%) ⬆️
src/differentiation/vecjac_products_zygote.jl 57.14% <100.00%> (+57.14%) ⬆️
... and 9 more

Help us with your feedback. Take ten seconds to tell us how you rate us. Have a feature suggestion? Share it here.

☔ View full report at Codecov.
📢 Do you have feedback about the report comment? Let us know in this issue.

@vpuri3
Copy link
Contributor Author

vpuri3 commented Feb 6, 2023

@ChrisRackauckas @YingboMa can you take a look? Maybe suggest some tests for VecJacProd operator

@vpuri3
Copy link
Contributor Author

vpuri3 commented Feb 6, 2023

TODO - add 5 arg mul! in test, and add tests for VecJacProd.

@vpuri3
Copy link
Contributor Author

vpuri3 commented Feb 7, 2023

TODO - write 5-arg mul! tests after SciML/SciMLOperators.jl#150

@vpuri3 vpuri3 changed the title [WIP] make AD operators scimloperators make AD operators scimloperators Feb 7, 2023
@vpuri3
Copy link
Contributor Author

vpuri3 commented Feb 7, 2023

@ChrisRackauckas this should be good. Please take a look

test/runtests.jl Outdated Show resolved Hide resolved
@vpuri3
Copy link
Contributor Author

vpuri3 commented Feb 8, 2023

@ChrisRackauckas

@ChrisRackauckas
Copy link
Member

That looks fine in general, but why is it not just replacing the old ones? Delete the old ones and use their name?

@vpuri3
Copy link
Contributor Author

vpuri3 commented Feb 8, 2023

OrdinaryDiffEq uses JecVac in src/derivative_utils.jl. So deleting the old ones would break stuff there. Once we're done updating OrdinaryDiffEq, we can delete the old ones

@ChrisRackauckas
Copy link
Member

This should be a drop-in replacement though?

@vpuri3
Copy link
Contributor Author

vpuri3 commented Feb 8, 2023

No. ODE.jl uses the type JacVec. Eg

if W.J !== nothing && !(W.J isa SparseDiffTools.JacVec) 

I don't have the type JacVec anymore. My functions JacVecProd, etc form a FunctionOperator

@ChrisRackauckas
Copy link
Member

Then we can make this a breaking release v2, with this change and adding the ArrayInterface packages from the README as extension packages.

Comment on lines 91 to 105
#=
ff1 = ODEFunction(lorenz, jac_prototype = JacVec{Float64}(lorenz, u0))
ff2 = ODEFunction(lorenz, jac_prototype = JacVec{Float64}(lorenz, u0, autodiff=false))

for ff in [ff1, ff2]
prob = ODEProblem(ff, u0, tspan)
@test solve(prob, TRBDF2()).retcode == :Success
@test solve(prob, TRBDF2(linsolve = KrylovJL_GMRES())).retcode == :Success
@test solve(prob, Exprb32()).retcode == :Success
@test solve(prob, Rosenbrock23()).retcode == :Success
@test solve(prob, Rosenbrock23(linsolve = KrylovJL_GMRES())).retcode == :Success
end
=#

# HesVec
Copy link
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

?

Copy link
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

OrdinaryDiffEq fails to precompile with this version of SparseDiffTools because of the changes we have discussed above. So I'm going to remove these tests and put them back in once we make the requisite changes in OrdianryDiffEq (in OrdinaryDiffEq.jl/src/derivative_utils.jl). Making a note of that here: SciML/SciMLOperators.jl#142

Comment on lines 97 to 105
isinplace = static_hasmethod(f, typeof((u, p, t)))
outofplace = static_hasmethod(f, typeof((u, u, p, t)))

if !(isinplace) & !(outofplace)
error("$f must have signature f(u, p, t), or f(du, u, p, t)")
end

L = RevModeAutoDiffVecProd(f, u, cache, vecprod, vecprod!; autodiff = autodiff,
isinplace = isinplace, outofplace = outofplace)
Copy link
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

The other function operators need this treatment too

Comment on lines +94 to +95
vecprod = autodiff ? auto_vecjac : num_vecjac
vecprod! = autodiff ? auto_vecjac! : num_vecjac!
Copy link
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

This should also have a choice for Zygote. For stability, JvFormZygote() and such? Use the same one the Jv operators.

Copy link
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

I'll do this as a follow up, it should be included in the breaking update.

@vpuri3
Copy link
Contributor Author

vpuri3 commented Feb 18, 2023

currently, JacVec, HesVec, HesVecGrad do not use p, t. They accept functions like f(x), f(y, x).

function (L::FwdModeAutoDiffVecProd)(v, p, t)
    L.vecprod(L.f, L.u, v)
end

As compared to VecJac that accepts f(u, p, t), f(du, u, p, t)

# Interpret the call as df/du' * u
function (L::RevModeAutoDiffVecProd)(v, p, t)
    L.vecprod(_u -> L.f(_u, p, t), L.u, v)
end

Do we want to enforce the same (u,p,t) signature in the former?

@ChrisRackauckas
Copy link
Member

No, these don't have p and t.

@vpuri3
Copy link
Contributor Author

vpuri3 commented Feb 19, 2023

ok, then this is complete. needs to rerun ci once array interface stuff is cleared out

@@ -7,3 +7,6 @@ function auto_vecjac(f, x, v)
vv, back = Zygote.pullback(f, x)
return vec(back(reshape(v, size(vv)))[1])
end

const ZygoteVecJac = VecJac
Copy link
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

what's this for?

@ChrisRackauckas
Copy link
Member

Just two small questions are left. After this merges, then we should change the autodiff choices to be non-boolean logic and Zygote an extension package in order to release.

@vpuri3
Copy link
Contributor Author

vpuri3 commented Feb 19, 2023

@ChrisRackauckas about the zygote stuff...

You mentioned above that we need Zygote operators for stability. I wrote functions ZygoteVecJac and ZygoteHesVec that use num/auto_vecjac(!), num/autoback_hesvec(!).

The difference between VecJac and ZygoteVecJac is that the former defaults to autodiff = false and the later true. That way VecJac won't error with default kwargs if zygote isn't loaded. Since VecJac uses auto_vecjac(!) when autodiff = true, I included a warning in there to make sure Zygote is loaded:

    if autodiff 
        @assert isdefined(SparseDiffTools, :auto_vecjac) "Please load Zygote with `using Zygote`, or `import Zygote` to use VecJac with `autodiff = true`."
    end

@ChrisRackauckas
Copy link
Member

You mentioned above that we need Zygote operators for stability. I wrote functions ZygoteVecJac and ZygoteHesVec that use num/auto_vecjac(!), num/autoback_hesvec(!).

No, I meant that autodiff shouldn't be boolean logic but multi-valued logic through types choices, like AutoZygote().

@ChrisRackauckas
Copy link
Member

We can follow up with the extension package stuff.

@ChrisRackauckas ChrisRackauckas merged commit d79a00f into JuliaDiff:master Feb 19, 2023
@vpuri3 vpuri3 deleted the scimlops branch February 19, 2023 21:11
Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment
Labels
None yet
Projects
None yet
Development

Successfully merging this pull request may close these issues.

3 participants